fix BertForMaskedLM
modeling_bert.py  +8 -8
@@ -752,18 +752,18 @@ class BertForMaskedLM(BertPreTrainedModel):
         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
 
         if (
-
+            self.dense_seq_output and labels is not None
         ): # prediction_scores are already flattened
             masked_lm_loss = self.mlm_loss(
                 prediction_scores, labels.flatten()[masked_token_idx]
             ).float()
-
-
-
-
-
-
-
+        elif labels is not None:
+            masked_lm_loss = self.mlm_loss(
+                rearrange(prediction_scores, "... v -> (...) v"),
+                rearrange(labels, "... -> (...)"),
+            ).float()
+        else:
+            raise ValueError('MLM labels must not be None')
 
         return BertForPreTrainingOutput(
             loss=masked_lm_loss,
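For reference, a minimal, self-contained sketch of what the two branches compute. It assumes `self.mlm_loss` is a `torch.nn.CrossEntropyLoss(ignore_index=-100)`, that `prediction_scores` has shape `(batch, seqlen, vocab_size)` on the standard path, and that on the `dense_seq_output` path the head has already gathered scores at the masked positions indexed by `masked_token_idx`; the shapes, constants, and the way `masked_token_idx` is built below are illustrative assumptions, not the repository's code.

# Illustrative sketch only: shapes, the ignore_index convention, and how
# masked_token_idx is derived are assumptions, not the repository's code.
import torch
from einops import rearrange

mlm_loss = torch.nn.CrossEntropyLoss(ignore_index=-100)  # assumed loss; -100 marks unmasked tokens

batch, seqlen, vocab = 2, 8, 30522
labels = torch.full((batch, seqlen), -100, dtype=torch.long)
labels[0, 3], labels[1, 5] = 42, 7  # two masked positions with their target token ids
prediction_scores = torch.randn(batch, seqlen, vocab)

# Standard path (elif branch): flatten batch and sequence dims on both tensors
# before the cross-entropy call; ignore_index drops the unmasked positions.
loss_standard = mlm_loss(
    rearrange(prediction_scores, "... v -> (...) v"),  # (batch * seqlen, vocab)
    rearrange(labels, "... -> (...)"),                 # (batch * seqlen,)
).float()

# dense_seq_output path (if branch): the head already returns scores only for the
# masked positions, so the labels are gathered with the same flat indices.
masked_token_idx = torch.nonzero(labels.flatten() != -100, as_tuple=False).flatten()
dense_scores = prediction_scores.reshape(-1, vocab)[masked_token_idx]  # "already flattened"
loss_dense = mlm_loss(dense_scores, labels.flatten()[masked_token_idx]).float()

assert torch.allclose(loss_standard, loss_dense)  # both branches agree on this toy input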