feat: updated modeling_bert.py to allow MLM-only training
modeling_bert.py (+19 −15)
@@ -494,24 +494,28 @@ class BertForPreTraining(BertPreTrainedModel):
         )
         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
 
-        total_loss = None
-        if labels is not None and next_sentence_label is not None:
-            if (
-                self.dense_seq_output and labels is not None
-            ):  # prediction_scores are already flattened
-                masked_lm_loss = self.mlm_loss(
-                    prediction_scores, labels.flatten()[masked_token_idx]
-                )
-            else:
-                masked_lm_loss = self.mlm_loss(
-                    rearrange(prediction_scores, "... v -> (...) v"),
-                    rearrange(labels, "... -> (...)"),
-                )
+        if (
+            self.dense_seq_output and labels is not None
+        ):  # prediction_scores are already flattened
+            masked_lm_loss = self.mlm_loss(
+                prediction_scores, labels.flatten()[masked_token_idx]
+            ).float()
+        elif labels is not None:
+            masked_lm_loss = self.mlm_loss(
+                rearrange(prediction_scores, "... v -> (...) v"),
+                rearrange(labels, "... -> (...)"),
+            ).float()
+        else:
+            masked_lm_loss = 0
+        if next_sentence_label is not None:
             next_sentence_loss = self.nsp_loss(
                 rearrange(seq_relationship_score, "... t -> (...) t"),
                 rearrange(next_sentence_label, "... -> (...)"),
-            )
-            total_loss = masked_lm_loss.float() + next_sentence_loss.float()
+            ).float()
+        else:
+            next_sentence_loss = 0
+
+        total_loss = masked_lm_loss + next_sentence_loss
 
         return BertForPreTrainingOutput(
             loss=total_loss,
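With this change the NSP term is optional: when next_sentence_label is omitted, next_sentence_loss is 0 and the returned loss reduces to the masked-LM loss alone (and when both labels are omitted, loss is 0 rather than None). A minimal sketch of an MLM-only training step under this commit; config, mask_token_id, the batch shape, and the 15% masking rate are illustrative assumptions, not part of the change:

import torch

# Sketch of an MLM-only step. config, model, and mask_token_id are
# assumptions for illustration; only the forward() call below exercises
# the code path added in this commit.
model = BertForPreTraining(config)

input_ids = torch.randint(0, config.vocab_size, (8, 128))
labels = input_ids.clone()
mask = torch.rand(labels.shape) < 0.15  # mask ~15% of positions
labels[~mask] = -100  # assumes the mlm_loss criterion ignores index -100
input_ids[mask] = mask_token_id

# No next_sentence_label is passed, so next_sentence_loss is 0 and
# outputs.loss is the masked-LM loss alone.
outputs = model(input_ids, labels=labels)
outputs.loss.backward()

Note that in the dense_seq_output path, prediction_scores has already been computed only at the masked positions, so the matching targets are gathered with labels.flatten()[masked_token_idx] instead of being rearranged.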