ThanhNguyen1811 commited on
Commit
da78c2f
·
verified ·
1 Parent(s): ece149d

Upload 6 files

Browse files
app.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchaudio
4
+ import pandas as pd
5
+ import os
6
+ import torch.nn as nn
7
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, AutoModel, AutoTokenizer
8
+
9
+ # Import các class mô hình từ file models.py
10
+ from models import MultimodalClassifier, TextClassifier
11
+
12
+ # --- 1. Thiết lập và Tải Mô hình (Tải một lần khi app khởi động) ---
13
+ print("Đang thiết lập thiết bị...")
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ print(f"Sử dụng thiết bị: {device}")
16
+
17
+ # Định nghĩa nhãn
18
+ LABELS_A = {0: "Tức giận", 1: "Bình thường", 2: "Vui vẻ"}
19
+ LABELS_B = {0: "Đe dọa", 1: "Tức giận", 2: "Tiêu cực thông thường", 3: "Trung tính", 4: "Tích cực", 5: "Vui vẻ", 6: "Châm Biếm"}
20
+
21
+ # Đường dẫn (Tương đối với thư mục gốc của Space)
22
+ MODEL_A_PATH = "saved_models/best_model_A.pth"
23
+ MODEL_B_PATH = "saved_models/best_model_B.pth"
24
+ FUZZY_RULES_PATH = "data/datafuzzy29d.csv" # Đảm bảo tên file này chính xác
25
+
26
+ # Tải các mô hình nền (từ Hugging Face Hub)
27
+ print("Đang tải các mô hình nền (STT, PhoBERT)...")
28
+ audio_processor = Wav2Vec2Processor.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h")
29
+ stt_model = Wav2Vec2ForCTC.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h").to(device)
30
+ text_tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")
31
+ text_feature_extractor = AutoModel.from_pretrained("vinai/phobert-base").to(device)
32
+
33
+ # Tải các mô hình đã huấn luyện (từ file .pth)
34
+ print("Đang tải các mô hình đã huấn luyện (A & B)...")
35
+ model_A = MultimodalClassifier(num_classes=len(LABELS_A)).to(device)
36
+ model_A.load_state_dict(torch.load(MODEL_A_PATH, map_location=device))
37
+ model_A.eval()
38
+
39
+ model_B = TextClassifier(n_classes=len(LABELS_B)).to(device)
40
+ model_B.load_state_dict(torch.load(MODEL_B_PATH, map_location=device))
41
+ model_B.eval()
42
+
43
+ # Đặt các mô hình nền sang chế độ eval
44
+ stt_model.eval()
45
+ text_feature_extractor.eval()
46
+
47
+ # Tải luật fuzzy
48
+ print("Đang tải luật fuzzy...")
49
+ try:
50
+ fuzzy_rules_df = pd.read_csv(FUZZY_RULES_PATH, sep=';')
51
+ fuzzy_rules = {}
52
+ for _, row in fuzzy_rules_df.iterrows():
53
+ # Đảm bảo tên cột khớp với file CSV của bạn
54
+ fuzzy_rules[(row['model_a_label'], row['model_b_label'])] = row['final_label']
55
+ print(f"Đã tải {len(fuzzy_rules)} luật fuzzy.")
56
+ except Exception as e:
57
+ print(f"Lỗi khi tải luật fuzzy: {e}. Sử dụng luật dự phòng.")
58
+ fuzzy_rules = {("Bình thường", "Tiêu cực thông thường"): "Nguy cơ thấp (Dự phòng)"}
59
+
60
+ print("Tất cả mô hình đã sẵn sàng.")
61
+
62
+ # --- 2. Định nghĩa Hàm Dự đoán ---
63
+ # Hàm này sẽ được Gradio gọi mỗi khi người dùng nhấn "Submit"
64
+ def predict_sentiment(audio_input):
65
+ if audio_input is None:
66
+ return "[Chưa có âm thanh]", "N/A", "N/A", "N/A"
67
+
68
+ sample_rate, waveform_numpy = audio_input
69
+
70
+ # Đảm bảo waveform là tensor float
71
+ waveform = torch.from_numpy(waveform_numpy).float()
72
+
73
+ # Đảm bảo là 1D (mono) hoặc lấy kênh đầu tiên nếu là stereo
74
+ if waveform.ndim > 1:
75
+ waveform = waveform[0]
76
+
77
+ # Thêm chiều batch (1,)
78
+ waveform = waveform.unsqueeze(0)
79
+
80
+ # --- Bước 1 & 2 (Gộp): STT và Đặc trưng Audio ---
81
+ try:
82
+ # 1a. Resample
83
+ if sample_rate != 16000:
84
+ resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
85
+ waveform = resampler(waveform)
86
+
87
+ # 1b. Chuẩn bị input audio
88
+ input_values = audio_processor(waveform.squeeze(), return_tensors="pt", sampling_rate=16000).input_values.to(device)
89
+
90
+ with torch.no_grad():
91
+ audio_outputs = stt_model(input_values, output_hidden_states=True)
92
+
93
+ # 2a. Trích xuất Văn bản (STT)
94
+ logits = audio_outputs.logits
95
+ predicted_ids = torch.argmax(logits, dim=-1)
96
+ transcribed_text = audio_processor.batch_decode(predicted_ids)[0].lower()
97
+
98
+ if not transcribed_text:
99
+ transcribed_text = "[Không nhận diện được giọng nói]"
100
+
101
+ # 2b. Trích xuất Đặc trưng Audio (cho Model A)
102
+ audio_feat_A = torch.mean(audio_outputs.hidden_states[-1], dim=1)
103
+
104
+ except Exception as e:
105
+ return f"[Lỗi xử lý audio: {e}]", "Lỗi Audio", "Lỗi Audio", "Lỗi Audio"
106
+
107
+ # --- Bước 3: Đặc trưng Text và Dự đoán Model B ---
108
+ try:
109
+ inputs_text = text_tokenizer(
110
+ transcribed_text,
111
+ return_tensors="pt",
112
+ padding=True,
113
+ truncation=True,
114
+ max_length=256
115
+ ).to(device)
116
+
117
+ with torch.no_grad():
118
+ # 3a. Đặc trưng Text (cho Model A)
119
+ text_outputs = text_feature_extractor(**inputs_text)
120
+ text_feat_A = text_outputs.pooler_output
121
+
122
+ # 3b. Dự đoán Model B
123
+ output_B = model_B(inputs_text['input_ids'], inputs_text['attention_mask'])
124
+ pred_idx_B = torch.argmax(output_B, dim=1).item()
125
+ pred_label_B = LABELS_B.get(pred_idx_B, f"Lỗi Nhãn B ({pred_idx_B})")
126
+
127
+ except Exception as e:
128
+ return f"[Lỗi xử lý text: {e}]", "Lỗi Text", "Lỗi Text", "Lỗi Text"
129
+
130
+ # --- Bước 4: Dự đoán Model A ---
131
+ try:
132
+ with torch.no_grad():
133
+ output_A = model_A(text_feat_A, audio_feat_A)
134
+ pred_idx_A = torch.argmax(output_A, dim=1).item()
135
+ pred_label_A = LABELS_A.get(pred_idx_A, f"Lỗi Nhãn A ({pred_idx_A})")
136
+
137
+ except Exception as e:
138
+ return transcribed_text, "Lỗi Model A", pred_label_B, f"[Lỗi Model A: {e}]"
139
+
140
+ # --- Bước 5: Kết hợp Fuzzy Logic ---
141
+ final_prediction = fuzzy_rules.get((pred_label_A, pred_label_B), "Không có luật")
142
+
143
+ # Trả về các giá trị cho các ô output của Gradio
144
+ return transcribed_text, pred_label_A, pred_label_B, final_prediction
145
+
146
+ # --- 3. Xây dựng Giao diện Gradio ---
147
+ print("Đang xây dựng giao diện Gradio...")
148
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
149
+ gr.Markdown("# Ứng dụng Phân tích Cảm xúc Đa phương tiện")
150
+ gr.Markdown("Tải lên một tệp âm thanh (.wav, .mp3, v.v.) **hoặc ghi âm trực tiếp** để dự đoán cảm xúc.")
151
+
152
+ with gr.Row():
153
+ with gr.Column(scale=2):
154
+ # === BỔ SUNG TÍNH NĂNG ===
155
+ # Thêm "microphone" vào sources để cho phép ghi âm
156
+ audio_in = gr.Audio(
157
+ sources=["upload", "microphone"], # Cho phép cả tải lên và ghi âm
158
+ type="numpy",
159
+ label="Tải lên tệp âm thanh hoặc Ghi âm"
160
+ )
161
+ submit_btn = gr.Button("Phân tích", variant="primary")
162
+
163
+ with gr.Column(scale=3):
164
+ gr.Markdown("### Kết quả Phân tích")
165
+ # Các ô output
166
+ text_out = gr.Textbox(label="Văn bản được nhận diện (STT)")
167
+ final_pred_out = gr.Label(label="Kết quả cuối cùng (Nguy cơ)")
168
+
169
+ with gr.Accordion("Xem chi tiết dự đoán của từng mô hình", open=False):
170
+ pred_A_out = gr.Textbox(label="Dự đoán Model A (Đa phương tiện)")
171
+ pred_B_out = gr.Textbox(label="Dự đoán Model B (Chỉ văn bản)")
172
+
173
+ # Liên kết nút bấm với hàm dự đoán
174
+ submit_btn.click(
175
+ fn=predict_sentiment,
176
+ inputs=audio_in,
177
+ outputs=[text_out, pred_A_out, pred_B_out, final_pred_out]
178
+ )
179
+
180
+ gr.Markdown("Lưu ý: Mô hình STT được tối ưu cho tiếng Việt.")
181
+
182
+ print("Đang khởi chạy demo...")
183
+ demo.launch() # Không cần (share=True) khi chạy trên Spaces
data/datafuzzy29d.csv ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_a_label;model_b_label;final_label
2
+ Bình thường;Châm Biếm;Nguy cơ thấp
3
+ Vui vẻ;Châm Biếm;Nguy cơ thấp
4
+ Tức giận;Châm Biếm;Nguy cơ cao
5
+ Bình thường;Đe dọa;Nguy cơ cao
6
+ Vui vẻ;Đe dọa;Nguy cơ cao
7
+ Tức giận;Đe dọa;Nguy cơ cao
8
+ Bình thường;Tích cực;Không có nguy cơ
9
+ Vui vẻ;Tích cực;Không có nguy cơ
10
+ Tức giận;Tích cực;Nguy cơ thấp
11
+ Bình thường;Tiêu cực thông thường;Nguy cơ thấp
12
+ Vui vẻ;Tiêu cực thông thường;Nguy cơ thấp
13
+ Tức giận;Tiêu cực thông thường;Nguy cơ cao
14
+ Bình thường;Tức giận;Nguy cơ cao
15
+ Vui vẻ;Tức giận;Nguy cơ cao
16
+ Tức giận;Tức giận;Nguy cơ cao
17
+ Bình thường;Trung tính;Không có nguy cơ
18
+ Tức giận;Trung tính;Nguy cơ thấp
19
+ Vui vẻ;Trung tính;Không có nguy cơ
20
+ Bình thường;Vui vẻ;Không có nguy cơ
21
+ Vui vẻ;Vui vẻ;Không có nguy cơ
22
+ Tức giận;Vui vẻ;Nguy cơ thấp
models.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import AutoModel
4
+
5
+ # Kiến trúc mô hình A (Multimodal)
6
+ class MultimodalClassifier(nn.Module):
7
+ def __init__(self, num_classes, text_feature_dim=768, audio_feature_dim=768, hidden_dim=512):
8
+ super(MultimodalClassifier, self).__init__()
9
+ self.fc1 = nn.Linear(text_feature_dim + audio_feature_dim, hidden_dim)
10
+ self.relu = nn.ReLU()
11
+ self.dropout = nn.Dropout(0.5)
12
+ self.fc2 = nn.Linear(hidden_dim, num_classes)
13
+
14
+ def forward(self, text_features, audio_features):
15
+ combined_features = torch.cat((text_features, audio_features), dim=1)
16
+ x = self.fc1(combined_features)
17
+ x = self.relu(x)
18
+ x = self.dropout(x)
19
+ x = self.fc2(x)
20
+ return x
21
+
22
+ # Kiến trúc mô hình B (Text-only)
23
+ class TextClassifier(nn.Module):
24
+ def __init__(self, n_classes):
25
+ super(TextClassifier, self).__init__()
26
+ # Load mô hình nền khi khởi tạo class
27
+ self.bert = AutoModel.from_pretrained("vinai/phobert-base")
28
+ self.drop = nn.Dropout(p=0.3)
29
+ self.out = nn.Linear(self.bert.config.hidden_size, n_classes)
30
+
31
+ def forward(self, input_ids, attention_mask):
32
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
33
+ pooled_output = outputs.pooler_output
34
+ output = self.drop(pooled_output)
35
+ return self.out(output)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ torchaudio
3
+ transformers
4
+ pandas
5
+ gradio
6
+ accelerate
saved_models/best_model_A.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9a75d9d345618892c3bf01eace1f7d4c00c3060711c60fe3e825f4c9cb6afb2
3
+ size 3156495
saved_models/best_model_B.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b7dc0cd2284c6b0211ca968493fa5e70f45f9428db4053aab3373b3a18ae376
3
+ size 540097656