import torch import torch.nn.functional as F import os from transformers import BertJapaneseTokenizer import warnings warnings.filterwarnings("ignore") warnings.filterwarnings("ignore", message="Distant resource does not have an ETag") warnings.filterwarnings("ignore", category=UserWarning) # 年代モデルと性別モデルの定義をインポート from SupervisedLearning import BertForAgeClassification, DEVICE, NUM_AGE_CLASSIFIERS, AGE_CATEGORIES from GenderLearning import BertForGenderClassification, NUM_GENDER_LABELS # 統合モデル用のクラス定義 class BertForClassification(torch.nn.Module): """統合分類モデル(年代と性別を同時に分類)""" def __init__(self, model_name, num_classes): super().__init__() from transformers import BertModel if model_name is None: self.bert = BertModel.from_pretrained('cl-tohoku/bert-large-japanese', use_safetensors=True) else: self.bert = BertModel.from_pretrained(model_name, use_safetensors=True) self.dropout = torch.nn.Dropout(0.3) self.classifier = torch.nn.Linear(self.bert.config.hidden_size, num_classes) def forward(self, input_ids, attention_mask, labels=None): outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) pooled_output = outputs.pooler_output pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) loss = None if labels is not None: loss_fct = torch.nn.CrossEntropyLoss() loss = loss_fct(logits, labels) return loss, logits # モデルファイルのパス AGE_MODEL_PATH = 'bert_age_model.bin' GENDER_MODEL_PATH = 'bert_gender_model.bin' # 性別のカテゴリマッピング GENDER_CATEGORIES = ["male", "female"] GENDER_CATEGORIES_JP = ["男性", "女性"] # --- グローバル変数としてモデルとトークナイザを一度だけロード --- TOKENIZER = None AGE_MODEL = None GENDER_MODEL = None MODELS_LOADED = False # モデル読み込み状態を追跡 def load_shared_tokenizer(): """共有トークナイザーを読み込む""" global TOKENIZER if TOKENIZER is not None: return # 既に読み込み済み print("🔧 共有トークナイザーの読み込みを開始します...") # 共有トークナイザー読み込み戦略(Largeモデル用) tokenizer_loaded = False tokenizer_models = ['cl-tohoku/bert-large-japanese', 'cl-tohoku/bert-base-japanese-v3'] for model_name in tokenizer_models: if tokenizer_loaded: break print(f"共有トークナイザー読み込み試行: {model_name}") # 戦略1: オンラインモード try: print(f"オンラインモードでトークナイザーを読み込み中... ({model_name})") TOKENIZER = BertJapaneseTokenizer.from_pretrained( model_name, use_fast=False, force_download=False, resume_download=True ) print(f"✅ オンラインモードでトークナイザーの読み込みが完了しました ({model_name})") tokenizer_loaded = True except Exception as e: print(f"オンラインモード失敗 ({model_name}): {e}") # 戦略2: オフラインモード try: print(f"オフラインモードでトークナイザーを読み込み中... ({model_name})") os.environ['TRANSFORMERS_OFFLINE'] = '1' TOKENIZER = BertJapaneseTokenizer.from_pretrained( model_name, local_files_only=True, use_fast=False ) print(f"✅ オフラインモードでトークナイザーの読み込みが完了しました ({model_name})") tokenizer_loaded = True except Exception as e2: print(f"オフラインモード失敗 ({model_name}): {e2}") if not tokenizer_loaded: raise Exception("共有トークナイザーの読み込みに失敗しました") def load_age_model(): """年代予測用モデルを読み込む""" global TOKENIZER, AGE_MODEL # モデルファイルの存在確認 if not os.path.exists(AGE_MODEL_PATH): raise FileNotFoundError(f"エラー: 年代学習済みモデル '{AGE_MODEL_PATH}' が見つかりません。") print("--- 年代モデルの読み込みを開始します ---") # 共有トークナイザーの読み込み load_shared_tokenizer() # 年代モデルの読み込み(bert-large-japanese を使用) print("📊 年代モデルの読み込みを開始します...") age_model_loaded = False # 戦略1: bert-large-japanese で試行(学習時と同じ) try: print("年代モデル: bert-large-japanese で試行中...") AGE_MODEL = BertForAgeClassification('cl-tohoku/bert-large-japanese', NUM_AGE_CLASSIFIERS) AGE_MODEL.load_state_dict(torch.load(AGE_MODEL_PATH, map_location=DEVICE)) AGE_MODEL.to(DEVICE) AGE_MODEL.eval() print("✅ 年代モデル(bert-large-japanese)の読み込みが完了しました") age_model_loaded = True except Exception as e: print(f"年代モデル(bert-large-japanese)失敗: {e}") # 戦略2: bert-base-japanese-v3 で試行 try: print("年代モデル: bert-base-japanese-v3 で試行中...") AGE_MODEL = BertForAgeClassification('cl-tohoku/bert-base-japanese-v3', NUM_AGE_CLASSIFIERS) AGE_MODEL.load_state_dict(torch.load(AGE_MODEL_PATH, map_location=DEVICE)) AGE_MODEL.to(DEVICE) AGE_MODEL.eval() print("✅ 年代モデル(bert-base-japanese-v3)の読み込みが完了しました") age_model_loaded = True except Exception as e2: print(f"年代モデル(bert-base-japanese-v3)も失敗: {e2}") if not age_model_loaded: raise Exception("年代モデルの読み込みに失敗しました") print("🎉 年代モデルの読み込みが完了しました!") def load_gender_model(): """性別予測用モデルを読み込む""" global TOKENIZER, GENDER_MODEL # 統合モデルファイルの存在確認 classification_model_path = 'bert_classification_model.bin' if not os.path.exists(classification_model_path): print(f"警告: 統合学習済みモデル '{classification_model_path}' が見つかりません。") print("性別予測は利用できません。") return print("--- 性別モデルの読み込みを開始します(統合モデル使用) ---") # 共有トークナイザーの読み込み load_shared_tokenizer() # 統合モデルの読み込み(bert-large-japanese を使用) print("👥 性別モデルの読み込みを開始します(統合モデル)...") gender_model_loaded = False # 戦略1: bert-large-japanese で試行(統合モデルはLarge) try: print("性別モデル: bert-large-japanese で試行中(統合モデル)...") GENDER_MODEL = BertForClassification('cl-tohoku/bert-large-japanese', 12) # 12クラス分類 GENDER_MODEL.load_state_dict(torch.load(classification_model_path, map_location=DEVICE)) GENDER_MODEL.to(DEVICE) GENDER_MODEL.eval() print("✅ 性別モデル(統合モデル)の読み込みが完了しました") gender_model_loaded = True except Exception as e: print(f"性別モデル(統合モデル)失敗: {e}") # 戦略2: bert-base-japanese-v3 で試行 try: print("性別モデル: bert-base-japanese-v3 で試行中(統合モデル)...") GENDER_MODEL = BertForClassification('cl-tohoku/bert-base-japanese-v3', 12) # 12クラス分類 GENDER_MODEL.load_state_dict(torch.load(classification_model_path, map_location=DEVICE)) GENDER_MODEL.to(DEVICE) GENDER_MODEL.eval() print("✅ 性別モデル(統合モデル)の読み込みが完了しました") gender_model_loaded = True except Exception as e2: print(f"性別モデル(統合モデル)も失敗: {e2}") if not gender_model_loaded: print("❌ 性別モデルの読み込みに失敗しました") print("⚠️ 性別予測は利用できません") GENDER_MODEL = None else: print("🎉 性別モデル(統合モデル)の読み込みが完了しました!") def load_models(): """アプリケーション起動時にモデルを一度だけ読み込む(後方互換性のため)""" global MODELS_LOADED # 既に読み込み済みの場合はスキップ if MODELS_LOADED: print("✅ モデルは既に読み込み済みです。スキップします。") return # 年代モデルと性別モデルを個別に読み込み load_age_model() load_gender_model() MODELS_LOADED = True # 読み込み完了フラグを設定 def predict_age(text): """テキストから年代を予測(統合関数を使用)""" # 統合予測を実行 full_result = predict_text(text) return full_result['age_percentages'] def predict_gender(text): """テキストから性別を予測(統合関数を使用)""" # 統合予測を実行 full_result = predict_text(text) return full_result['gender_percentages'] def predict_text(text): """テキストから年代と性別を統合予測""" global TOKENIZER, AGE_MODEL, GENDER_MODEL # モデルが読み込まれていない場合は読み込み if AGE_MODEL is None: load_age_model() if GENDER_MODEL is None: load_gender_model() if TOKENIZER is None or AGE_MODEL is None or GENDER_MODEL is None: raise Exception("モデルが読み込まれていません。") # テキストの前処理 if not text or not text.strip(): raise ValueError("テキストが空です。") # トークン化(一度だけ実行) inputs = TOKENIZER.encode_plus( text, add_special_tokens=True, max_length=128, padding='max_length', truncation=True, return_tensors='pt' ) # デバイスに移動 input_ids = inputs['input_ids'].to(DEVICE) attention_mask = inputs['attention_mask'].to(DEVICE) # 年代と性別を同時に予測 with torch.no_grad(): # 年代予測 age_output = AGE_MODEL(input_ids, attention_mask) if isinstance(age_output, tuple): age_logits = age_output[1] else: age_logits = age_output age_probs = torch.sigmoid(age_logits) age_probs = age_probs.cpu().numpy().flatten() # 性別予測(統合モデルから性別部分を抽出) classification_output = GENDER_MODEL(input_ids, attention_mask) if isinstance(classification_output, tuple): all_logits = classification_output[1] else: all_logits = classification_output # 統合モデルの12クラス出力から性別部分(最後の2クラス)を抽出 gender_logits = all_logits[:, -2:] gender_probs = torch.softmax(gender_logits, dim=1) gender_probs = gender_probs.cpu().numpy().flatten() # 結果の整形 age_result = {} for i, category in enumerate(AGE_CATEGORIES): age_result[category] = float(age_probs[i] * 100) gender_result = {} for i, category in enumerate(GENDER_CATEGORIES_JP): gender_result[category] = float(gender_probs[i] * 100) result = { 'age_percentages': age_result, 'gender_percentages': gender_result } return result def get_top_predictions(result, top_k=3): """予測結果から上位k個を取得""" # 年代の上位予測 age_sorted = sorted(result['age_percentages'].items(), key=lambda x: x[1], reverse=True) top_ages = age_sorted[:top_k] # 性別の上位予測 gender_sorted = sorted(result['gender_percentages'].items(), key=lambda x: x[1], reverse=True) top_genders = gender_sorted[:top_k] return { 'top_ages': top_ages, 'top_genders': top_genders } def format_prediction_result(result): """予測結果を読みやすい形式に整形""" formatted = "=== 予測結果 ===\n" # 年代予測結果 formatted += "\n📊 年代予測:\n" for age, percentage in sorted(result['age_percentages'].items(), key=lambda x: x[1], reverse=True): formatted += f" {age}: {percentage:.1f}%\n" # 性別予測結果 formatted += "\n👥 性別予測:\n" for gender, percentage in sorted(result['gender_percentages'].items(), key=lambda x: x[1], reverse=True): formatted += f" {gender}: {percentage:.1f}%\n" return formatted # テスト用の関数 def test_prediction(): """予測機能のテスト(統合予測)""" test_text = "こんにちは、今日は良い天気ですね。" print(f"テストテキスト: {test_text}") try: result = predict_text(test_text) print(format_prediction_result(result)) # 個別予測もテスト print("\n=== 個別予測テスト ===") age_result = predict_age(test_text) print(f"年代予測: {age_result}") gender_result = predict_gender(test_text) print(f"性別予測: {gender_result}") except Exception as e: print(f"予測エラー: {e}") if __name__ == "__main__": # モデルの読み込みテスト try: load_models() test_prediction() except Exception as e: print(f"エラー: {e}")