wav2vec2 / app.py
ThanhNguyen1811's picture
Upload 6 files
da78c2f verified
raw
history blame
8.01 kB
import gradio as gr
import torch
import torchaudio
import pandas as pd
import os
import torch.nn as nn
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, AutoModel, AutoTokenizer
# Import các class mô hình từ file models.py
from models import MultimodalClassifier, TextClassifier
# --- 1. Thiết lập và Tải Mô hình (Tải một lần khi app khởi động) ---
print("Đang thiết lập thiết bị...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Sử dụng thiết bị: {device}")
# Định nghĩa nhãn
LABELS_A = {0: "Tức giận", 1: "Bình thường", 2: "Vui vẻ"}
LABELS_B = {0: "Đe dọa", 1: "Tức giận", 2: "Tiêu cực thông thường", 3: "Trung tính", 4: "Tích cực", 5: "Vui vẻ", 6: "Châm Biếm"}
# Đường dẫn (Tương đối với thư mục gốc của Space)
MODEL_A_PATH = "saved_models/best_model_A.pth"
MODEL_B_PATH = "saved_models/best_model_B.pth"
FUZZY_RULES_PATH = "data/datafuzzy29d.csv" # Đảm bảo tên file này chính xác
# Tải các mô hình nền (từ Hugging Face Hub)
print("Đang tải các mô hình nền (STT, PhoBERT)...")
audio_processor = Wav2Vec2Processor.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h")
stt_model = Wav2Vec2ForCTC.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h").to(device)
text_tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")
text_feature_extractor = AutoModel.from_pretrained("vinai/phobert-base").to(device)
# Tải các mô hình đã huấn luyện (từ file .pth)
print("Đang tải các mô hình đã huấn luyện (A & B)...")
model_A = MultimodalClassifier(num_classes=len(LABELS_A)).to(device)
model_A.load_state_dict(torch.load(MODEL_A_PATH, map_location=device))
model_A.eval()
model_B = TextClassifier(n_classes=len(LABELS_B)).to(device)
model_B.load_state_dict(torch.load(MODEL_B_PATH, map_location=device))
model_B.eval()
# Đặt các mô hình nền sang chế độ eval
stt_model.eval()
text_feature_extractor.eval()
# Tải luật fuzzy
print("Đang tải luật fuzzy...")
try:
fuzzy_rules_df = pd.read_csv(FUZZY_RULES_PATH, sep=';')
fuzzy_rules = {}
for _, row in fuzzy_rules_df.iterrows():
# Đảm bảo tên cột khớp với file CSV của bạn
fuzzy_rules[(row['model_a_label'], row['model_b_label'])] = row['final_label']
print(f"Đã tải {len(fuzzy_rules)} luật fuzzy.")
except Exception as e:
print(f"Lỗi khi tải luật fuzzy: {e}. Sử dụng luật dự phòng.")
fuzzy_rules = {("Bình thường", "Tiêu cực thông thường"): "Nguy cơ thấp (Dự phòng)"}
print("Tất cả mô hình đã sẵn sàng.")
# --- 2. Định nghĩa Hàm Dự đoán ---
# Hàm này sẽ được Gradio gọi mỗi khi người dùng nhấn "Submit"
def predict_sentiment(audio_input):
if audio_input is None:
return "[Chưa có âm thanh]", "N/A", "N/A", "N/A"
sample_rate, waveform_numpy = audio_input
# Đảm bảo waveform là tensor float
waveform = torch.from_numpy(waveform_numpy).float()
# Đảm bảo là 1D (mono) hoặc lấy kênh đầu tiên nếu là stereo
if waveform.ndim > 1:
waveform = waveform[0]
# Thêm chiều batch (1,)
waveform = waveform.unsqueeze(0)
# --- Bước 1 & 2 (Gộp): STT và Đặc trưng Audio ---
try:
# 1a. Resample
if sample_rate != 16000:
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
waveform = resampler(waveform)
# 1b. Chuẩn bị input audio
input_values = audio_processor(waveform.squeeze(), return_tensors="pt", sampling_rate=16000).input_values.to(device)
with torch.no_grad():
audio_outputs = stt_model(input_values, output_hidden_states=True)
# 2a. Trích xuất Văn bản (STT)
logits = audio_outputs.logits
predicted_ids = torch.argmax(logits, dim=-1)
transcribed_text = audio_processor.batch_decode(predicted_ids)[0].lower()
if not transcribed_text:
transcribed_text = "[Không nhận diện được giọng nói]"
# 2b. Trích xuất Đặc trưng Audio (cho Model A)
audio_feat_A = torch.mean(audio_outputs.hidden_states[-1], dim=1)
except Exception as e:
return f"[Lỗi xử lý audio: {e}]", "Lỗi Audio", "Lỗi Audio", "Lỗi Audio"
# --- Bước 3: Đặc trưng Text và Dự đoán Model B ---
try:
inputs_text = text_tokenizer(
transcribed_text,
return_tensors="pt",
padding=True,
truncation=True,
max_length=256
).to(device)
with torch.no_grad():
# 3a. Đặc trưng Text (cho Model A)
text_outputs = text_feature_extractor(**inputs_text)
text_feat_A = text_outputs.pooler_output
# 3b. Dự đoán Model B
output_B = model_B(inputs_text['input_ids'], inputs_text['attention_mask'])
pred_idx_B = torch.argmax(output_B, dim=1).item()
pred_label_B = LABELS_B.get(pred_idx_B, f"Lỗi Nhãn B ({pred_idx_B})")
except Exception as e:
return f"[Lỗi xử lý text: {e}]", "Lỗi Text", "Lỗi Text", "Lỗi Text"
# --- Bước 4: Dự đoán Model A ---
try:
with torch.no_grad():
output_A = model_A(text_feat_A, audio_feat_A)
pred_idx_A = torch.argmax(output_A, dim=1).item()
pred_label_A = LABELS_A.get(pred_idx_A, f"Lỗi Nhãn A ({pred_idx_A})")
except Exception as e:
return transcribed_text, "Lỗi Model A", pred_label_B, f"[Lỗi Model A: {e}]"
# --- Bước 5: Kết hợp Fuzzy Logic ---
final_prediction = fuzzy_rules.get((pred_label_A, pred_label_B), "Không có luật")
# Trả về các giá trị cho các ô output của Gradio
return transcribed_text, pred_label_A, pred_label_B, final_prediction
# --- 3. Xây dựng Giao diện Gradio ---
print("Đang xây dựng giao diện Gradio...")
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# Ứng dụng Phân tích Cảm xúc Đa phương tiện")
gr.Markdown("Tải lên một tệp âm thanh (.wav, .mp3, v.v.) **hoặc ghi âm trực tiếp** để dự đoán cảm xúc.")
with gr.Row():
with gr.Column(scale=2):
# === BỔ SUNG TÍNH NĂNG ===
# Thêm "microphone" vào sources để cho phép ghi âm
audio_in = gr.Audio(
sources=["upload", "microphone"], # Cho phép cả tải lên và ghi âm
type="numpy",
label="Tải lên tệp âm thanh hoặc Ghi âm"
)
submit_btn = gr.Button("Phân tích", variant="primary")
with gr.Column(scale=3):
gr.Markdown("### Kết quả Phân tích")
# Các ô output
text_out = gr.Textbox(label="Văn bản được nhận diện (STT)")
final_pred_out = gr.Label(label="Kết quả cuối cùng (Nguy cơ)")
with gr.Accordion("Xem chi tiết dự đoán của từng mô hình", open=False):
pred_A_out = gr.Textbox(label="Dự đoán Model A (Đa phương tiện)")
pred_B_out = gr.Textbox(label="Dự đoán Model B (Chỉ văn bản)")
# Liên kết nút bấm với hàm dự đoán
submit_btn.click(
fn=predict_sentiment,
inputs=audio_in,
outputs=[text_out, pred_A_out, pred_B_out, final_pred_out]
)
gr.Markdown("Lưu ý: Mô hình STT được tối ưu cho tiếng Việt.")
print("Đang khởi chạy demo...")
demo.launch() # Không cần (share=True) khi chạy trên Spaces