fix: try to skip initialization of task type embeddings
Browse files- modeling_bert.py +1 -1
modeling_bert.py
CHANGED
|
@@ -145,7 +145,7 @@ def _init_weights(module, initializer_range=0.02):
|
|
| 145 |
nn.init.normal_(module.weight, std=initializer_range)
|
| 146 |
if module.bias is not None:
|
| 147 |
nn.init.zeros_(module.bias)
|
| 148 |
-
elif isinstance(module, nn.Embedding) and not module
|
| 149 |
nn.init.normal_(module.weight, std=initializer_range)
|
| 150 |
if module.padding_idx is not None:
|
| 151 |
nn.init.zeros_(module.weight[module.padding_idx])
|
|
|
|
| 145 |
nn.init.normal_(module.weight, std=initializer_range)
|
| 146 |
if module.bias is not None:
|
| 147 |
nn.init.zeros_(module.bias)
|
| 148 |
+
elif isinstance(module, nn.Embedding) and not getattr(module, "skip_init", False):
|
| 149 |
nn.init.normal_(module.weight, std=initializer_range)
|
| 150 |
if module.padding_idx is not None:
|
| 151 |
nn.init.zeros_(module.weight[module.padding_idx])
|