Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import torchaudio | |
| import pandas as pd | |
| import os | |
| import torch.nn as nn | |
| from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, AutoModel, AutoTokenizer | |
| # Import các class mô hình từ file models.py | |
| from models import MultimodalClassifier, TextClassifier | |
| # --- 1. Thiết lập và Tải Mô hình (Tải một lần khi app khởi động) --- | |
| print("Đang thiết lập thiết bị...") | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| print(f"Sử dụng thiết bị: {device}") | |
| # Định nghĩa nhãn | |
| LABELS_A = {0: "Tức giận", 1: "Bình thường", 2: "Vui vẻ"} | |
| LABELS_B = {0: "Đe dọa", 1: "Tức giận", 2: "Tiêu cực thông thường", 3: "Trung tính", 4: "Tích cực", 5: "Vui vẻ", 6: "Châm Biếm"} | |
| # Đường dẫn (Tương đối với thư mục gốc của Space) | |
| MODEL_A_PATH = "saved_models/best_model_A.pth" | |
| MODEL_B_PATH = "saved_models/best_model_B.pth" | |
| FUZZY_RULES_PATH = "data/datafuzzy29d.csv" # Đảm bảo tên file này chính xác | |
| # Tải các mô hình nền (từ Hugging Face Hub) | |
| print("Đang tải các mô hình nền (STT, PhoBERT)...") | |
| audio_processor = Wav2Vec2Processor.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h") | |
| stt_model = Wav2Vec2ForCTC.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h").to(device) | |
| text_tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base") | |
| text_feature_extractor = AutoModel.from_pretrained("vinai/phobert-base").to(device) | |
| # Tải các mô hình đã huấn luyện (từ file .pth) | |
| print("Đang tải các mô hình đã huấn luyện (A & B)...") | |
| model_A = MultimodalClassifier(num_classes=len(LABELS_A)).to(device) | |
| model_A.load_state_dict(torch.load(MODEL_A_PATH, map_location=device)) | |
| model_A.eval() | |
| model_B = TextClassifier(n_classes=len(LABELS_B)).to(device) | |
| model_B.load_state_dict(torch.load(MODEL_B_PATH, map_location=device)) | |
| model_B.eval() | |
| # Đặt các mô hình nền sang chế độ eval | |
| stt_model.eval() | |
| text_feature_extractor.eval() | |
| # Tải luật fuzzy | |
| print("Đang tải luật fuzzy...") | |
| try: | |
| fuzzy_rules_df = pd.read_csv(FUZZY_RULES_PATH, sep=';') | |
| fuzzy_rules = {} | |
| for _, row in fuzzy_rules_df.iterrows(): | |
| # Đảm bảo tên cột khớp với file CSV của bạn | |
| fuzzy_rules[(row['model_a_label'], row['model_b_label'])] = row['final_label'] | |
| print(f"Đã tải {len(fuzzy_rules)} luật fuzzy.") | |
| except Exception as e: | |
| print(f"Lỗi khi tải luật fuzzy: {e}. Sử dụng luật dự phòng.") | |
| fuzzy_rules = {("Bình thường", "Tiêu cực thông thường"): "Nguy cơ thấp (Dự phòng)"} | |
| print("Tất cả mô hình đã sẵn sàng.") | |
| # --- 2. Định nghĩa Hàm Dự đoán --- | |
| # Hàm này sẽ được Gradio gọi mỗi khi người dùng nhấn "Submit" | |
| def predict_sentiment(audio_input): | |
| if audio_input is None: | |
| return "[Chưa có âm thanh]", "N/A", "N/A", "N/A" | |
| sample_rate, waveform_numpy = audio_input | |
| # Đảm bảo waveform là tensor float | |
| waveform = torch.from_numpy(waveform_numpy).float() | |
| # Đảm bảo là 1D (mono) hoặc lấy kênh đầu tiên nếu là stereo | |
| if waveform.ndim > 1: | |
| waveform = waveform[0] | |
| # Thêm chiều batch (1,) | |
| waveform = waveform.unsqueeze(0) | |
| # --- Bước 1 & 2 (Gộp): STT và Đặc trưng Audio --- | |
| try: | |
| # 1a. Resample | |
| if sample_rate != 16000: | |
| resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000) | |
| waveform = resampler(waveform) | |
| # 1b. Chuẩn bị input audio | |
| input_values = audio_processor(waveform.squeeze(), return_tensors="pt", sampling_rate=16000).input_values.to(device) | |
| with torch.no_grad(): | |
| audio_outputs = stt_model(input_values, output_hidden_states=True) | |
| # 2a. Trích xuất Văn bản (STT) | |
| logits = audio_outputs.logits | |
| predicted_ids = torch.argmax(logits, dim=-1) | |
| transcribed_text = audio_processor.batch_decode(predicted_ids)[0].lower() | |
| if not transcribed_text: | |
| transcribed_text = "[Không nhận diện được giọng nói]" | |
| # 2b. Trích xuất Đặc trưng Audio (cho Model A) | |
| audio_feat_A = torch.mean(audio_outputs.hidden_states[-1], dim=1) | |
| except Exception as e: | |
| return f"[Lỗi xử lý audio: {e}]", "Lỗi Audio", "Lỗi Audio", "Lỗi Audio" | |
| # --- Bước 3: Đặc trưng Text và Dự đoán Model B --- | |
| try: | |
| inputs_text = text_tokenizer( | |
| transcribed_text, | |
| return_tensors="pt", | |
| padding=True, | |
| truncation=True, | |
| max_length=256 | |
| ).to(device) | |
| with torch.no_grad(): | |
| # 3a. Đặc trưng Text (cho Model A) | |
| text_outputs = text_feature_extractor(**inputs_text) | |
| text_feat_A = text_outputs.pooler_output | |
| # 3b. Dự đoán Model B | |
| output_B = model_B(inputs_text['input_ids'], inputs_text['attention_mask']) | |
| pred_idx_B = torch.argmax(output_B, dim=1).item() | |
| pred_label_B = LABELS_B.get(pred_idx_B, f"Lỗi Nhãn B ({pred_idx_B})") | |
| except Exception as e: | |
| return f"[Lỗi xử lý text: {e}]", "Lỗi Text", "Lỗi Text", "Lỗi Text" | |
| # --- Bước 4: Dự đoán Model A --- | |
| try: | |
| with torch.no_grad(): | |
| output_A = model_A(text_feat_A, audio_feat_A) | |
| pred_idx_A = torch.argmax(output_A, dim=1).item() | |
| pred_label_A = LABELS_A.get(pred_idx_A, f"Lỗi Nhãn A ({pred_idx_A})") | |
| except Exception as e: | |
| return transcribed_text, "Lỗi Model A", pred_label_B, f"[Lỗi Model A: {e}]" | |
| # --- Bước 5: Kết hợp Fuzzy Logic --- | |
| final_prediction = fuzzy_rules.get((pred_label_A, pred_label_B), "Không có luật") | |
| # Trả về các giá trị cho các ô output của Gradio | |
| return transcribed_text, pred_label_A, pred_label_B, final_prediction | |
| # --- 3. Xây dựng Giao diện Gradio --- | |
| print("Đang xây dựng giao diện Gradio...") | |
| with gr.Blocks(theme=gr.themes.Soft()) as demo: | |
| gr.Markdown("# Ứng dụng Phân tích Cảm xúc Đa phương tiện") | |
| gr.Markdown("Tải lên một tệp âm thanh (.wav, .mp3, v.v.) **hoặc ghi âm trực tiếp** để dự đoán cảm xúc.") | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| # === BỔ SUNG TÍNH NĂNG === | |
| # Thêm "microphone" vào sources để cho phép ghi âm | |
| audio_in = gr.Audio( | |
| sources=["upload", "microphone"], # Cho phép cả tải lên và ghi âm | |
| type="numpy", | |
| label="Tải lên tệp âm thanh hoặc Ghi âm" | |
| ) | |
| submit_btn = gr.Button("Phân tích", variant="primary") | |
| with gr.Column(scale=3): | |
| gr.Markdown("### Kết quả Phân tích") | |
| # Các ô output | |
| text_out = gr.Textbox(label="Văn bản được nhận diện (STT)") | |
| final_pred_out = gr.Label(label="Kết quả cuối cùng (Nguy cơ)") | |
| with gr.Accordion("Xem chi tiết dự đoán của từng mô hình", open=False): | |
| pred_A_out = gr.Textbox(label="Dự đoán Model A (Đa phương tiện)") | |
| pred_B_out = gr.Textbox(label="Dự đoán Model B (Chỉ văn bản)") | |
| # Liên kết nút bấm với hàm dự đoán | |
| submit_btn.click( | |
| fn=predict_sentiment, | |
| inputs=audio_in, | |
| outputs=[text_out, pred_A_out, pred_B_out, final_pred_out] | |
| ) | |
| gr.Markdown("Lưu ý: Mô hình STT được tối ưu cho tiếng Việt.") | |
| print("Đang khởi chạy demo...") | |
| demo.launch() # Không cần (share=True) khi chạy trên Spaces |