File size: 8,012 Bytes
da78c2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import gradio as gr
import torch
import torchaudio
import pandas as pd
import os
import torch.nn as nn
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, AutoModel, AutoTokenizer

# Import các class mô hình từ file models.py
from models import MultimodalClassifier, TextClassifier

# --- 1. Thiết lập và Tải Mô hình (Tải một lần khi app khởi động) ---
print("Đang thiết lập thiết bị...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Sử dụng thiết bị: {device}")

# Định nghĩa nhãn
LABELS_A = {0: "Tức giận", 1: "Bình thường", 2: "Vui vẻ"}
LABELS_B = {0: "Đe dọa", 1: "Tức giận", 2: "Tiêu cực thông thường", 3: "Trung tính", 4: "Tích cực", 5: "Vui vẻ", 6: "Châm Biếm"}

# Đường dẫn (Tương đối với thư mục gốc của Space)
MODEL_A_PATH = "saved_models/best_model_A.pth"
MODEL_B_PATH = "saved_models/best_model_B.pth"
FUZZY_RULES_PATH = "data/datafuzzy29d.csv" # Đảm bảo tên file này chính xác

# Tải các mô hình nền (từ Hugging Face Hub)
print("Đang tải các mô hình nền (STT, PhoBERT)...")
audio_processor = Wav2Vec2Processor.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h")
stt_model = Wav2Vec2ForCTC.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h").to(device)
text_tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")
text_feature_extractor = AutoModel.from_pretrained("vinai/phobert-base").to(device)

# Tải các mô hình đã huấn luyện (từ file .pth)
print("Đang tải các mô hình đã huấn luyện (A & B)...")
model_A = MultimodalClassifier(num_classes=len(LABELS_A)).to(device)
model_A.load_state_dict(torch.load(MODEL_A_PATH, map_location=device))
model_A.eval()

model_B = TextClassifier(n_classes=len(LABELS_B)).to(device)
model_B.load_state_dict(torch.load(MODEL_B_PATH, map_location=device))
model_B.eval()

# Đặt các mô hình nền sang chế độ eval
stt_model.eval()
text_feature_extractor.eval()

# Tải luật fuzzy
print("Đang tải luật fuzzy...")
try:
    fuzzy_rules_df = pd.read_csv(FUZZY_RULES_PATH, sep=';')
    fuzzy_rules = {}
    for _, row in fuzzy_rules_df.iterrows():
        # Đảm bảo tên cột khớp với file CSV của bạn
        fuzzy_rules[(row['model_a_label'], row['model_b_label'])] = row['final_label']
    print(f"Đã tải {len(fuzzy_rules)} luật fuzzy.")
except Exception as e:
    print(f"Lỗi khi tải luật fuzzy: {e}. Sử dụng luật dự phòng.")
    fuzzy_rules = {("Bình thường", "Tiêu cực thông thường"): "Nguy cơ thấp (Dự phòng)"} 

print("Tất cả mô hình đã sẵn sàng.")

# --- 2. Định nghĩa Hàm Dự đoán ---
# Hàm này sẽ được Gradio gọi mỗi khi người dùng nhấn "Submit"
def predict_sentiment(audio_input):
    if audio_input is None:
        return "[Chưa có âm thanh]", "N/A", "N/A", "N/A"

    sample_rate, waveform_numpy = audio_input
    
    # Đảm bảo waveform là tensor float
    waveform = torch.from_numpy(waveform_numpy).float()
    
    # Đảm bảo là 1D (mono) hoặc lấy kênh đầu tiên nếu là stereo
    if waveform.ndim > 1:
        waveform = waveform[0]
        
    # Thêm chiều batch (1,)
    waveform = waveform.unsqueeze(0)

    # --- Bước 1 & 2 (Gộp): STT và Đặc trưng Audio ---
    try:
        # 1a. Resample
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            waveform = resampler(waveform)

        # 1b. Chuẩn bị input audio
        input_values = audio_processor(waveform.squeeze(), return_tensors="pt", sampling_rate=16000).input_values.to(device)

        with torch.no_grad():
            audio_outputs = stt_model(input_values, output_hidden_states=True)

        # 2a. Trích xuất Văn bản (STT)
        logits = audio_outputs.logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcribed_text = audio_processor.batch_decode(predicted_ids)[0].lower()
        
        if not transcribed_text:
            transcribed_text = "[Không nhận diện được giọng nói]"

        # 2b. Trích xuất Đặc trưng Audio (cho Model A)
        audio_feat_A = torch.mean(audio_outputs.hidden_states[-1], dim=1)

    except Exception as e:
        return f"[Lỗi xử lý audio: {e}]", "Lỗi Audio", "Lỗi Audio", "Lỗi Audio"

    # --- Bước 3: Đặc trưng Text và Dự đoán Model B ---
    try:
        inputs_text = text_tokenizer(
            transcribed_text, 
            return_tensors="pt", 
            padding=True, 
            truncation=True, 
            max_length=256
        ).to(device)

        with torch.no_grad():
            # 3a. Đặc trưng Text (cho Model A)
            text_outputs = text_feature_extractor(**inputs_text)
            text_feat_A = text_outputs.pooler_output

            # 3b. Dự đoán Model B
            output_B = model_B(inputs_text['input_ids'], inputs_text['attention_mask'])
            pred_idx_B = torch.argmax(output_B, dim=1).item()
            pred_label_B = LABELS_B.get(pred_idx_B, f"Lỗi Nhãn B ({pred_idx_B})")

    except Exception as e:
         return f"[Lỗi xử lý text: {e}]", "Lỗi Text", "Lỗi Text", "Lỗi Text"

    # --- Bước 4: Dự đoán Model A ---
    try:
        with torch.no_grad():
            output_A = model_A(text_feat_A, audio_feat_A)
            pred_idx_A = torch.argmax(output_A, dim=1).item()
            pred_label_A = LABELS_A.get(pred_idx_A, f"Lỗi Nhãn A ({pred_idx_A})")
    
    except Exception as e:
        return transcribed_text, "Lỗi Model A", pred_label_B, f"[Lỗi Model A: {e}]"

    # --- Bước 5: Kết hợp Fuzzy Logic ---
    final_prediction = fuzzy_rules.get((pred_label_A, pred_label_B), "Không có luật")

    # Trả về các giá trị cho các ô output của Gradio
    return transcribed_text, pred_label_A, pred_label_B, final_prediction

# --- 3. Xây dựng Giao diện Gradio ---
print("Đang xây dựng giao diện Gradio...")
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Ứng dụng Phân tích Cảm xúc Đa phương tiện")
    gr.Markdown("Tải lên một tệp âm thanh (.wav, .mp3, v.v.) **hoặc ghi âm trực tiếp** để dự đoán cảm xúc.")
    
    with gr.Row():
        with gr.Column(scale=2):
            # === BỔ SUNG TÍNH NĂNG ===
            # Thêm "microphone" vào sources để cho phép ghi âm
            audio_in = gr.Audio(
                sources=["upload", "microphone"],  # Cho phép cả tải lên và ghi âm
                type="numpy", 
                label="Tải lên tệp âm thanh hoặc Ghi âm"
            )
            submit_btn = gr.Button("Phân tích", variant="primary")
        
        with gr.Column(scale=3):
            gr.Markdown("### Kết quả Phân tích")
            # Các ô output
            text_out = gr.Textbox(label="Văn bản được nhận diện (STT)")
            final_pred_out = gr.Label(label="Kết quả cuối cùng (Nguy cơ)")
            
            with gr.Accordion("Xem chi tiết dự đoán của từng mô hình", open=False):
                pred_A_out = gr.Textbox(label="Dự đoán Model A (Đa phương tiện)")
                pred_B_out = gr.Textbox(label="Dự đoán Model B (Chỉ văn bản)")

    # Liên kết nút bấm với hàm dự đoán
    submit_btn.click(
        fn=predict_sentiment,
        inputs=audio_in,
        outputs=[text_out, pred_A_out, pred_B_out, final_pred_out]
    )
    
    gr.Markdown("Lưu ý: Mô hình STT được tối ưu cho tiếng Việt.")

print("Đang khởi chạy demo...")
demo.launch() # Không cần (share=True) khi chạy trên Spaces