wav2vec2 / models.py
ThanhNguyen1811's picture
Upload 6 files
da78c2f verified
raw
history blame
1.44 kB
import torch
import torch.nn as nn
from transformers import AutoModel
# Model A architecture (multimodal)
class MultimodalClassifier(nn.Module):
    """Classify from concatenated text and audio feature vectors.

    The two feature tensors are joined along their last dimension and fed
    through a small MLP: ``Linear -> ReLU -> Dropout -> Linear``.

    Args:
        num_classes: Number of output scores produced per sample.
        text_feature_dim: Size of the last dimension of ``text_features``.
        audio_feature_dim: Size of the last dimension of ``audio_features``.
        hidden_dim: Width of the hidden layer.
        dropout: Dropout probability applied after the hidden activation.
            Defaults to 0.5, the value that used to be hard-coded.
    """

    def __init__(
        self,
        num_classes: int,
        text_feature_dim: int = 768,
        audio_feature_dim: int = 768,
        hidden_dim: int = 512,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        # Submodule names are kept unchanged: they define the state_dict
        # keys ("fc1.weight", "fc2.bias", ...) that saved checkpoints use.
        self.fc1 = nn.Linear(text_feature_dim + audio_feature_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(
        self, text_features: torch.Tensor, audio_features: torch.Tensor
    ) -> torch.Tensor:
        """Return class scores for the given text/audio feature pair.

        Args:
            text_features: Tensor whose last dimension is ``text_feature_dim``.
            audio_features: Tensor whose last dimension is
                ``audio_feature_dim``; all leading dimensions must match
                those of ``text_features``.

        Returns:
            Tensor with the same leading dimensions as the inputs and a last
            dimension of ``num_classes``. The output of the final linear
            layer is returned as-is; no softmax is applied here.
        """
        # Concatenate on the feature (last) axis. For the usual 2-D
        # (batch, dim) inputs this equals dim=1, but it also works for
        # unbatched 1-D vectors and for inputs with extra leading axes.
        combined_features = torch.cat((text_features, audio_features), dim=-1)
        hidden = self.relu(self.fc1(combined_features))
        hidden = self.dropout(hidden)
        return self.fc2(hidden)
# Model B architecture (text-only)
class TextClassifier(nn.Module):
    """Text-only classifier: pretrained encoder -> dropout -> linear head.

    Args:
        n_classes: Number of output scores produced per sample.
        model_name: Identifier or path handed to
            ``AutoModel.from_pretrained``. Defaults to
            ``"vinai/phobert-base"``, the checkpoint that used to be
            hard-coded.
        dropout: Dropout probability applied to the pooled encoder output.
            Defaults to 0.3, the value that used to be hard-coded.
    """

    def __init__(
        self,
        n_classes: int,
        model_name: str = "vinai/phobert-base",
        dropout: float = 0.3,
    ) -> None:
        super().__init__()
        # The base model is loaded when the class is instantiated.
        # Attribute names (bert / drop / out) are kept unchanged because
        # they define the state_dict keys of saved checkpoints.
        self.bert = AutoModel.from_pretrained(model_name)
        self.drop = nn.Dropout(p=dropout)
        # The head's input width follows the encoder's configured hidden size.
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Return class scores for a batch of tokenized text.

        Args:
            input_ids: Token ids, forwarded to the encoder unchanged.
            attention_mask: Attention mask, forwarded to the encoder
                unchanged.

        Returns:
            Output of the linear head applied to the encoder's
            ``pooler_output`` after dropout; no softmax is applied here.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # NOTE(review): assumes the loaded encoder exposes a non-None
        # ``pooler_output`` — confirm if a different model_name is used.
        pooled_output = self.drop(outputs.pooler_output)
        return self.out(pooled_output)