from typing import Any

import torch
from torch import nn
from transformers import (
    PreTrainedModel,
    XLMRobertaConfig,
    XLMRobertaModel,
)

from .configuration_comet import CometModelConfig


class Encoder(nn.Module):
    """Encoder module based on XLMRoberta."""

    def __init__(self):
        super().__init__()
        self.model = XLMRobertaModel(
            config=XLMRobertaConfig.from_pretrained("microsoft/infoxlm-large"),
            add_pooling_layer=False,
        )
        self.model.encoder.output_hidden_states = True

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs: Any
    ) -> tuple[torch.Tensor, ...]:
        # With return_dict=False and output_hidden_states=True, the last element of the
        # output tuple is the tuple of hidden states (embeddings + each transformer layer).
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=False,
        )[-1]

    @property
    def num_layers(self) -> int:
        """Number of layer representations available (embeddings + transformer layers)."""
        return self.model.config.num_hidden_layers + 1

    @property
    def output_units(self) -> int:
        """Dimensionality of the encoder's hidden states."""
        return self.model.config.hidden_size


class LayerwiseAttention(nn.Module):
    """Combines the encoder's layer representations with learned softmax-normalized weights (scalar mix)."""

    def __init__(
        self,
        num_layers: int,
        layer_weights: list[float] | None = None,
    ) -> None:
        super().__init__()
        layer_weights = layer_weights or [0.0] * num_layers
        self.scalar_parameters = nn.ParameterList(
            [
                nn.Parameter(torch.HalfTensor([layer_weights[i]]), requires_grad=True)
                for i in range(num_layers)
            ]
        )
        self.weight = nn.Parameter(torch.HalfTensor([1.0]), requires_grad=True)

    def forward(
        self,
        tensors: list[torch.Tensor],
        mask: torch.Tensor,  # unused; kept for interface compatibility
    ) -> torch.Tensor:
        weights = torch.cat([parameter for parameter in self.scalar_parameters])
        normed_weights = torch.softmax(weights, dim=0)
        normed_weights = torch.split(normed_weights, split_size_or_sections=1)
        # Weighted sum of layer representations, scaled by a learned global weight.
        return self.weight * sum(
            weight * tensor for weight, tensor in zip(normed_weights, tensors)
        )


class Estimator(nn.Module):
    """Feed-forward estimator module."""

    def _get_activation(self, activation: str) -> nn.Module:
        """Get activation function by name."""
        if hasattr(nn, activation.title()):
            return getattr(nn, activation.title())()
        raise ValueError(f"{activation} is not a valid activation function!")

    def __init__(
        self,
        in_dim: int,
        out_dim: int = 1,
        hidden_sizes: list[int] = [3072, 1024],
        activations: str = "Tanh",
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        modules: list[nn.Module] = []
        # First layer
        modules.append(nn.Linear(in_dim, hidden_sizes[0]))
        modules.append(self._get_activation(activations))
        modules.append(nn.Dropout(dropout))
        # Hidden layers
        for i in range(1, len(hidden_sizes)):
            modules.append(nn.Linear(hidden_sizes[i - 1], hidden_sizes[i]))
            modules.append(self._get_activation(activations))
            modules.append(nn.Dropout(dropout))
        # Output layer
        modules.append(nn.Linear(hidden_sizes[-1], int(out_dim)))
        self.ff = nn.Sequential(*modules)

    def forward(self, in_features: torch.Tensor) -> torch.Tensor:
        return self.ff(in_features)


class CometModel(PreTrainedModel):
    """COMET-style quality estimation model: XLM-R encoder, layerwise attention, and a feed-forward estimator."""

    config_class = CometModelConfig
    _no_split_modules = ["Encoder", "LayerwiseAttention", "Estimator"]

    def __init__(self, config: CometModelConfig) -> None:
        super().__init__(config)
        self.encoder = Encoder()
        self.layerwise_attention = LayerwiseAttention(
            num_layers=self.encoder.num_layers
        )
        self.estimator = Estimator(
            in_dim=self.encoder.output_units,
            hidden_sizes=config.hidden_sizes,
            activations=config.activations,
            dropout=config.dropout,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor | None = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        # token_type_ids is accepted for interface compatibility but not used by the encoder.
        encoder_out = self.encoder(
            input_ids,
            attention_mask,
            token_type_ids=token_type_ids,
        )
        embeddings = self.layerwise_attention(
            encoder_out,
            attention_mask,
        )
        # Use CLS token as sentence embedding
        embedding = embeddings[:, 0, :]
        return self.estimator(embedding).view(-1)
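

# Minimal usage sketch (not part of the model definition). Assumptions: CometModelConfig
# can be constructed with its defaults, and "microsoft/infoxlm-large" is the tokenizer
# matching the encoder above; in practice trained weights would be loaded with
# CometModel.from_pretrained(<checkpoint>) rather than built from scratch as done here.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("microsoft/infoxlm-large")
    config = CometModelConfig()  # assumed default hidden_sizes / activations / dropout
    model = CometModel(config).eval()  # randomly initialized encoder; illustration only

    # Encode a small batch. COMET-style inputs are usually a source/hypothesis pair packed
    # into one sequence, but a single segment is enough to exercise the forward pass.
    batch = tokenizer(
        ["Das ist ein Test.", "Another example sentence."],
        padding=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        scores = model(**batch)  # one scalar score per input, shape (batch_size,)
    print(scores)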