|
|
from typing import Optional, Union |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
|
|
|
|
|
from transformers.cache_utils import Cache |
|
|
from transformers.models.bert.modeling_bert import BertEncoder, BertPooler, BertPreTrainedModel, BertOnlyMLMHead |
|
|
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa |
|
|
from transformers.modeling_outputs import ( |
|
|
BaseModelOutputWithPoolingAndCrossAttentions, |
|
|
MaskedLMOutput, |
|
|
SequenceClassifierOutput, |
|
|
) |
|
|
from transformers.utils import auto_docstring, logging |
|
|
|
|
|
from .configuration_bert_hash import BertHashConfig |
|
|
|
|
|
# Module-level logger from transformers' logging utilities; BertHashForMaskedLM uses it to warn when
# `config.is_decoder` is set.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
class BertHashTokens(nn.Module):
    """
    Module that embeds the token vocabulary into a small intermediate embedding, then projects those embeddings to
    the hidden size.

    The number of projections acts like a hash width. Setting the projections parameter to 5 is like generating a
    160-bit hash (5 x float32) for each token. That hash is then projected to the hidden size by a linear layer
    (weight matrix plus bias).

    This significantly reduces the number of parameters necessary for token embeddings.

    For example:
        Standard token embeddings:
            30,522 (vocab size) x 768 (hidden size) = 23,440,896 parameters
            23,440,896 x 4 (float32) = 93,763,584 bytes

        Hash token embeddings:
            30,522 (vocab size) x 5 (hash buckets)
            + 5 x 768 (projection weight)
            + 768 (projection bias)
            = 157,218 parameters
            157,218 x 4 (float32) = 628,872 bytes
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # (vocab_size, projections) lookup table. Per `nn.Embedding(padding_idx=...)` semantics, the row at
        # `pad_token_id` is zero at construction and receives no gradient.
        self.embeddings = nn.Embedding(config.vocab_size, config.projections, padding_idx=config.pad_token_id)

        # Maps each `projections`-wide token code up to `hidden_size` (weight + bias).
        self.projections = nn.Linear(config.projections, config.hidden_size)

    def forward(self, input_ids):
        """
        Embed token ids.

        Args:
            input_ids: integer tensor of token ids.

        Returns:
            Tensor shaped like `input_ids` with a trailing `hidden_size` dimension.
        """
        return self.projections(self.embeddings(input_ids))
|
|
|
|
|
|
|
|
class BertHashEmbeddings(nn.Module):
    """
    Build the encoder input by summing hashed word embeddings, token-type embeddings and (when the configured
    position embedding type is "absolute") learned position embeddings, followed by LayerNorm and dropout.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = BertHashTokens(config)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # NOTE(review): the CamelCase `LayerNorm` attribute name presumably mirrors stock BERT checkpoint keys;
        # keep it as-is so saved weights still load.
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        # Non-persistent buffers: excluded from the state dict but moved together with the module.
        max_positions = config.max_position_embeddings
        self.register_buffer("position_ids", torch.arange(max_positions).expand((1, -1)), persistent=False)
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """
        Compute the embedding output fed to the encoder.

        Args:
            input_ids: token ids; define the (batch, sequence) shape and are embedded through `word_embeddings`
                when `inputs_embeds` is not given.
            token_type_ids: segment ids; default to the zero buffer expanded over the batch.
            position_ids: position indices; default to the `position_ids` buffer slice starting at
                `past_key_values_length`.
            inputs_embeds: precomputed token embeddings used instead of embedding `input_ids`.
            past_key_values_length: offset applied to the default position ids.

        Returns:
            LayerNorm-ed and dropout-applied sum of the embeddings.
        """
        input_shape = input_ids.size() if input_ids is not None else inputs_embeds.size()[:-1]
        seq_length = input_shape[1]

        if position_ids is None:
            start = past_key_values_length
            position_ids = self.position_ids[:, start : start + seq_length]

        # Default segment ids come from the registered all-zero buffer so no per-call allocation is needed.
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                token_type_ids = self.token_type_ids[:, :seq_length].expand(input_shape[0], seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        summed = inputs_embeds + self.token_type_embeddings(token_type_ids)
        if self.position_embedding_type == "absolute":
            summed += self.position_embeddings(position_ids)

        return self.dropout(self.LayerNorm(summed))
|
|
|
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
    """
)
class BertHashModel(BertPreTrainedModel):
    config_class = BertHashConfig

    # Module class names that must stay on a single device when loading with a `device_map`. This model builds
    # `BertHashEmbeddings` (never the stock `BertEmbeddings`), so it is listed by its real class name; otherwise the
    # word/token-type/position embedding sum could be split across devices.
    _no_split_modules = ["BertHashEmbeddings", "BertLayer"]

    def __init__(self, config, add_pooling_layer=True):
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        """
        super().__init__(config)
        self.config = config

        self.embeddings = BertHashEmbeddings(config)
        self.encoder = BertEncoder(config)

        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.attn_implementation = config._attn_implementation
        # Same fallback as BertHashEmbeddings, so both components agree when the config omits the attribute.
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Returns the (vocab_size, projections) lookup table inside BertHashTokens, not a hidden_size-wide table.
        return self.embeddings.word_embeddings.embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings.embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
    ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        # Resolve output flags from the config when they are not passed explicitly.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Caching only applies in decoder mode; encoder mode always runs with use_cache=False.
        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # Number of positions already held in the cache; handles both the legacy nested-sequence format and
        # `Cache` objects.
        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = (
                past_key_values[0][0].shape[-2]
                if not isinstance(past_key_values, Cache)
                else past_key_values.get_seq_length()
            )

        # Default segment ids: slice of the embeddings' registered zero buffer, expanded over the batch.
        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                token_type_ids = buffered_token_type_ids.expand(batch_size, seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        # With no mask supplied, attend to every current and cached position.
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

        # The SDPA-specific 4-D mask helpers are used only when no feature needs the eager attention path.
        use_sdpa_attention_masks = (
            self.attn_implementation == "sdpa"
            and self.position_embedding_type == "absolute"
            and head_mask is None
            and not output_attentions
        )

        if use_sdpa_attention_masks and attention_mask.dim() == 2:
            # Expand a 2-D padding mask into the 4-D form expected by SDPA (causal in decoder mode).
            if self.config.is_decoder:
                extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                    attention_mask,
                    input_shape,
                    embedding_output,
                    past_key_values_length,
                )
            else:
                extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    attention_mask, embedding_output.dtype, tgt_len=seq_length
                )
        else:
            # Generic path from the base class (also handles masks that are already extended).
            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

        # Cross-attention mask is only built in decoder mode with encoder states present.
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)

            if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2:
                encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
                )
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Normalize the optional head mask to one entry per hidden layer (base-class helper).
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        # Tuple output keeps slot 1 for the pooled output even when it is None; heads index past it.
        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
|
|
|
|
|
|
|
|
@auto_docstring
class BertHashForMaskedLM(BertPreTrainedModel):
    # NOTE(review): carried over from the stock BERT MLM head. This class defines no `get_output_embeddings`, so
    # presumably no input/output weight tying takes place (the input table is (vocab_size, projections)). Confirm
    # that listing the decoder weight here is still intended.
    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
    config_class = BertHashConfig

    def __init__(self, config):
        super().__init__(config)

        # Masked-LM needs bidirectional attention; a decoder config still loads but triggers a warning.
        if config.is_decoder:
            logger.warning(
                "If you want to use `BertHashForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.bert = BertHashModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            # CrossEntropyLoss ignores index -100 by default, matching the labels convention above.
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            # outputs[1] is the (None) pooled output of the pooler-less base model; skip it.
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
        """
        Append one PAD "dummy" token to every sequence and extend the attention mask with a masked-out column.

        Raises:
            ValueError: if `config.pad_token_id` is not set.
        """
        input_shape = input_ids.shape
        effective_batch_size = input_shape[0]

        # add a dummy token
        if self.config.pad_token_id is None:
            raise ValueError("The PAD token should be defined for generation")

        # The parameter defaults to None; fall back to attending to every existing token instead of crashing on
        # `None.new_zeros`.
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
        dummy_token = torch.full(
            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
        )
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        return {"input_ids": input_ids, "attention_mask": attention_mask}

    @classmethod
    def can_generate(cls) -> bool:
        """
        Legacy correction: BertHashForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a
        `prepare_inputs_for_generation` method.
        """
        return False
|
|
|
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """
)
class BertHashForSequenceClassification(BertPreTrainedModel):
    config_class = BertHashConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertHashModel(config)
        # Use the classifier-specific dropout rate when configured, otherwise the encoder's hidden dropout.
        dropout_rate = config.hidden_dropout_prob if config.classifier_dropout is None else config.classifier_dropout
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def _sequence_classification_loss(self, logits, labels):
        """Infer (and record on the config) the problem type if unset, then return the matching loss or None."""
        if self.config.problem_type is None:
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                self.config.problem_type = "single_label_classification"
            else:
                self.config.problem_type = "multi_label_classification"

        kind = self.config.problem_type
        if kind == "regression":
            mse = MSELoss()
            if self.num_labels == 1:
                return mse(logits.squeeze(), labels.squeeze())
            return mse(logits, labels)
        if kind == "single_label_classification":
            return CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
        if kind == "multi_label_classification":
            return BCEWithLogitsLoss()(logits, labels)
        # Unrecognized problem type: no loss is produced.
        return None

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Targets for the sequence classification/regression loss. For classification, indices should be in
            `[0, ..., config.num_labels - 1]`. When `config.num_labels == 1` a regression (mean-squared error) loss is
            computed; when `config.num_labels > 1` a classification (cross-entropy) loss is computed.
        """
        use_dict = self.config.use_return_dict if return_dict is None else return_dict

        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=use_dict,
        )

        # Index 1 is the pooler output of the base model.
        logits = self.classifier(self.dropout(bert_outputs[1]))
        loss = None if labels is None else self._sequence_classification_loss(logits, labels)

        if use_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=bert_outputs.hidden_states,
                attentions=bert_outputs.attentions,
            )

        result = (logits,) + bert_outputs[2:]
        return result if loss is None else (loss,) + result
|
|
|