feat: added head_mask
modeling_bert.py (+4 -0)
```diff
@@ -379,12 +379,16 @@ class BertModel(BertPreTrainedModel):
         task_type_ids=None,
         attention_mask=None,
         masked_tokens_mask=None,
+        head_mask=None,
     ):
         """If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
         we only want the output for the masked tokens. This means that we only compute the last
         layer output for these tokens.
         masked_tokens_mask: (batch, seqlen), dtype=torch.bool
         """
+        if head_mask is not None:
+            raise NotImplementedError('Masking heads is not supported')
+
         hidden_states = self.embeddings(
             input_ids, position_ids=position_ids, token_type_ids=token_type_ids
         )
```
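For context, here is a minimal, hypothetical sketch of the pattern this change follows: the forward signature gains a `head_mask` argument (presumably so HF-transformers-style callers can pass it without a `TypeError`), but any non-`None` value fails fast because head masking is not implemented in this model. `TinyModel` and its body are stand-ins, not the actual `BertModel`:

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Hypothetical stand-in for BertModel, illustrating the head_mask guard."""

    def forward(self, input_ids, head_mask=None):
        # Accept head_mask in the signature for API compatibility,
        # but reject any actual mask, mirroring the commit's guard.
        if head_mask is not None:
            raise NotImplementedError('Masking heads is not supported')
        return input_ids  # placeholder; the real forward runs embeddings + encoder

model = TinyModel()
x = torch.zeros(1, 4, dtype=torch.long)
model(x)  # fine: head_mask defaults to None
try:
    model(x, head_mask=torch.ones(12))  # any non-None mask raises
except NotImplementedError as e:
    print(e)  # -> Masking heads is not supported
```

Raising rather than silently ignoring the argument keeps failures loud: a caller that actually depends on head pruning gets an explicit error instead of subtly wrong outputs.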