import torch
from torch import nn, Tensor
from typing import Union, Tuple, List, Iterable, Dict, Callable
from ..SentenceTransformer import SentenceTransformer
import logging


logger = logging.getLogger(__name__)
class SoftmaxLoss(nn.Module):
    """
    This loss was used in our SBERT publication (https://arxiv.org/abs/1908.10084) to train the SentenceTransformer
    model on NLI data. It adds a softmax classifier on top of the output of two transformer networks.

    :param model: SentenceTransformer model
    :param sentence_embedding_dimension: Dimension of your sentence embeddings
    :param num_labels: Number of different labels
    :param concatenation_sent_rep: Concatenate vectors u,v for the softmax classifier?
    :param concatenation_sent_difference: Add abs(u-v) for the softmax classifier?
    :param concatenation_sent_multiplication: Add u*v for the softmax classifier?
    :param loss_fct: Optional: Custom pytorch loss function. If not set, uses nn.CrossEntropyLoss()

    Example::

        from torch.utils.data import DataLoader
        from sentence_transformers import SentenceTransformer, SentencesDataset, losses
        from sentence_transformers.readers import InputExample

        model = SentenceTransformer('distilbert-base-nli-mean-tokens')
        train_examples = [InputExample(texts=['First pair, sent A', 'First pair, sent B'], label=0),
                          InputExample(texts=['Second pair, sent A', 'Second pair, sent B'], label=3)]
        train_dataset = SentencesDataset(train_examples, model)
        train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=16)
        train_loss = losses.SoftmaxLoss(model=model,
                                        sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
                                        num_labels=4)
    """
    def __init__(self,
                 model: SentenceTransformer,
                 sentence_embedding_dimension: int,
                 num_labels: int,
                 concatenation_sent_rep: bool = True,
                 concatenation_sent_difference: bool = True,
                 concatenation_sent_multiplication: bool = False,
                 loss_fct: Callable = nn.CrossEntropyLoss()):
        super(SoftmaxLoss, self).__init__()
        self.model = model
        self.num_labels = num_labels
        self.concatenation_sent_rep = concatenation_sent_rep
        self.concatenation_sent_difference = concatenation_sent_difference
        self.concatenation_sent_multiplication = concatenation_sent_multiplication

        # Count how many embedding-sized vectors are concatenated into the classifier input:
        # u and v (2), |u - v| (1), and u * v (1), depending on the flags above.
        num_vectors_concatenated = 0
        if concatenation_sent_rep:
            num_vectors_concatenated += 2
        if concatenation_sent_difference:
            num_vectors_concatenated += 1
        if concatenation_sent_multiplication:
            num_vectors_concatenated += 1
        logger.info("Softmax loss: #Vectors concatenated: {}".format(num_vectors_concatenated))

        # With the defaults (u, v, |u - v|), the classifier input has 3 * sentence_embedding_dimension features.
        self.classifier = nn.Linear(num_vectors_concatenated * sentence_embedding_dimension, num_labels)
        self.loss_fct = loss_fct
    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        # Embed both sentences of each pair with the shared SentenceTransformer model.
        reps = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features]
        rep_a, rep_b = reps

        # Build the classifier input from the configured combination of u, v, |u - v| and u * v.
        vectors_concat = []
        if self.concatenation_sent_rep:
            vectors_concat.append(rep_a)
            vectors_concat.append(rep_b)
        if self.concatenation_sent_difference:
            vectors_concat.append(torch.abs(rep_a - rep_b))
        if self.concatenation_sent_multiplication:
            vectors_concat.append(rep_a * rep_b)
        features = torch.cat(vectors_concat, 1)

        output = self.classifier(features)
        if labels is not None:
            loss = self.loss_fct(output, labels.view(-1))
            return loss
        else:
            return reps, output
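

# The snippet below is a minimal, hypothetical end-to-end sketch of how this loss is typically wired
# into training. It assumes the file lives at sentence_transformers/losses/SoftmaxLoss.py and is run
# as a module (python -m sentence_transformers.losses.SoftmaxLoss) so the relative imports resolve;
# the checkpoint name, batch size, label count and epoch count are illustrative choices, not values
# required by SoftmaxLoss itself.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from sentence_transformers import SentenceTransformer, SentencesDataset, losses
    from sentence_transformers.readers import InputExample

    model = SentenceTransformer('distilbert-base-nli-mean-tokens')
    train_examples = [InputExample(texts=['First pair, sent A', 'First pair, sent B'], label=0),
                      InputExample(texts=['Second pair, sent A', 'Second pair, sent B'], label=1)]
    train_dataset = SentencesDataset(train_examples, model)
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=2)
    train_loss = losses.SoftmaxLoss(model=model,
                                    sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
                                    num_labels=2)

    # model.fit pairs each dataloader with its loss and runs the training loop.
    model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=10)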