litagin committed
Commit 7bdd676 · verified · 1 Parent(s): 165fe2c

Upload 2 files

Files changed (2)
  1. configuration.py +11 -0
  2. modeling.py +135 -0
configuration.py ADDED
@@ -0,0 +1,11 @@
+ from transformers import Wav2Vec2Config
+
+
+ class Wav2Vec2ForEmotionClassificationConfig(Wav2Vec2Config):
+     model_type = "wav2vec2_for_emotion_classification"
+
+     def __init__(
+         self,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
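
For reference (not part of the commit): one way to make this custom config/model pair loadable through the Auto classes is to register it locally. The following is a minimal sketch, assuming both files above are importable from the working directory; the registration calls use only public transformers APIs.

# Sketch: local registration of the custom config and model with the Auto classes.
# Assumes configuration.py and modeling.py sit next to this script.
from transformers import AutoConfig, AutoModel

from configuration import Wav2Vec2ForEmotionClassificationConfig
from modeling import Wav2Vec2ForEmotionClassification

AutoConfig.register(
    "wav2vec2_for_emotion_classification", Wav2Vec2ForEmotionClassificationConfig
)
AutoModel.register(
    Wav2Vec2ForEmotionClassificationConfig, Wav2Vec2ForEmotionClassification
)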
modeling.py ADDED
@@ -0,0 +1,135 @@
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ from transformers.activations import get_activation
+ from transformers.modeling_outputs import SequenceClassifierOutput
+ from transformers.models.wav2vec2.modeling_wav2vec2 import (
+     Wav2Vec2Model,
+     Wav2Vec2PreTrainedModel,
+ )
+
+ from configuration import Wav2Vec2ForEmotionClassificationConfig
+
+ _HIDDEN_STATES_START_POSITION = 2
+
+
+ class ClassificationHead(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.classifier_proj_size)
+         self.layer_norm = nn.LayerNorm(config.classifier_proj_size)
+         self.dropout = nn.Dropout(config.final_dropout)
+         self.out_proj = nn.Linear(config.classifier_proj_size, config.num_labels)
+         self.activation = get_activation(config.head_activation)
+
+     def forward(self, features, **kwargs):
+         x = features
+         x = self.dense(x)
+         x = self.layer_norm(x)
+         x = self.activation(x)
+         x = self.dropout(x)
+         x = self.out_proj(x)
+         return x
+
+
+ class Wav2Vec2ForEmotionClassification(Wav2Vec2PreTrainedModel):
+     """Speech emotion classifier."""
+
+     config_class = Wav2Vec2ForEmotionClassificationConfig
+
+     def __init__(self, config, counts: Optional[dict[int, int]] = None):
+         super().__init__(config)
+
+         self.config = config
+         self.wav2vec2 = Wav2Vec2Model(config)
+         self.classifier = ClassificationHead(config)
+         num_layers = (
+             config.num_hidden_layers + 1
+         )  # transformer layers + input embeddings
+         if config.use_weighted_layer_sum:
+             self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+         self.init_weights()
+
+         # If counts is given, compute per-class weights from the label frequencies
+         if counts is not None:
+             print(f"Using class weights: {counts}")
+             counts_list = [counts[i] for i in range(config.num_labels)]
+             counts_tensor = torch.tensor(
+                 counts_list, dtype=torch.float, device="cuda:0"
+             )
+             total_samples = counts_tensor.sum()
+             class_weights = total_samples / (config.num_labels * counts_tensor)
+             # Normalize the weights (optional)
+             class_weights = class_weights / class_weights.sum() * config.num_labels
+             self.class_weights = class_weights
+         else:
+             self.class_weights = None  # Set to None when counts is not given
+
+     def forward(
+         self,
+         input_values: Optional[torch.Tensor],
+         attention_mask: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         labels: Optional[torch.Tensor] = None,
+     ):
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+         output_hidden_states = (
+             True if self.config.use_weighted_layer_sum else output_hidden_states
+         )
+
+         outputs = self.wav2vec2(
+             input_values,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         if self.config.use_weighted_layer_sum:
+             # Learnable softmax-weighted sum over all hidden states
+             hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+             hidden_states = torch.stack(hidden_states, dim=1)
+             norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+             hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+         else:
+             hidden_states = outputs[0]
+
+         # Mean-pool over time, ignoring padded frames when an attention mask is given
+         if attention_mask is None:
+             pooled_output = hidden_states.mean(dim=1)
+         else:
+             padding_mask = self._get_feature_vector_attention_mask(
+                 hidden_states.shape[1], attention_mask
+             )
+             hidden_states[~padding_mask] = 0.0
+             pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(
+                 -1, 1
+             )
+
+         logits = self.classifier(pooled_output)
+
+         loss = None
+         if labels is not None:
+             # Apply the class weights in CrossEntropyLoss (also works when class_weights is None)
+             loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
+             loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+     def freeze_base_model(self):
+         r"""Freeze the base wav2vec2 model so only the classification head trains."""
+         for param in self.wav2vec2.parameters():
+             param.requires_grad = False
+
+     def freeze_feature_encoder(self):
+         r"""Freeze the convolutional feature encoder."""
+         self.wav2vec2.freeze_feature_encoder()
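
For reference (not part of the commit): a minimal inference sketch showing how the two files could be used directly. The checkpoint name is a placeholder, and its config.json is assumed to define classifier_proj_size, final_dropout, head_activation, num_labels and id2label, which the classification head reads.

# Sketch: loading a fine-tuned checkpoint of this class and classifying one clip.
# "some-user/wav2vec2-emotion-checkpoint" is a placeholder name, not part of this commit.
import torch
from transformers import Wav2Vec2FeatureExtractor

from configuration import Wav2Vec2ForEmotionClassificationConfig
from modeling import Wav2Vec2ForEmotionClassification

checkpoint = "some-user/wav2vec2-emotion-checkpoint"  # placeholder
config = Wav2Vec2ForEmotionClassificationConfig.from_pretrained(checkpoint)
model = Wav2Vec2ForEmotionClassification.from_pretrained(checkpoint, config=config)
model.eval()

feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint)
waveform = torch.zeros(16000)  # 1 s of silence at 16 kHz, stand-in for real audio
inputs = feature_extractor(waveform.numpy(), sampling_rate=16000, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits
predicted_id = logits.argmax(dim=-1).item()
print(config.id2label.get(predicted_id, predicted_id))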