import torch
from torch import nn, Tensor
from typing import Union, Tuple, List, Iterable, Dict, Callable
from ..SentenceTransformer import SentenceTransformer
import logging


logger = logging.getLogger(__name__)


class SoftmaxLoss(nn.Module):
    """
    This loss was used in our SBERT publication (https://arxiv.org/abs/1908.10084) to train the SentenceTransformer
    model on NLI data. It adds a softmax classifier on top of the output of two transformer networks.

    :param model: SentenceTransformer model
    :param sentence_embedding_dimension: Dimension of your sentence embeddings
    :param num_labels: Number of different labels
    :param concatenation_sent_rep: Concatenate vectors u,v for the softmax classifier?
    :param concatenation_sent_difference: Add abs(u-v) for the softmax classifier?
    :param concatenation_sent_multiplication: Add u*v for the softmax classifier?
    :param loss_fct: Optional: Custom pytorch loss function. If not set, uses nn.CrossEntropyLoss()

    Example::

        from torch.utils.data import DataLoader
        from sentence_transformers import SentenceTransformer, SentencesDataset, losses
        from sentence_transformers.readers import InputExample

        model = SentenceTransformer('distilbert-base-nli-mean-tokens')
        train_examples = [InputExample(texts=['First pair, sent A', 'First pair, sent B'], label=0),
            InputExample(texts=['Second Pair, sent A', 'Second Pair, sent B'], label=3)]
        train_dataset = SentencesDataset(train_examples, model)
        train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)
        train_loss = losses.SoftmaxLoss(model=model, sentence_embedding_dimension=model.get_sentence_embedding_dimension(), num_labels=train_num_labels)
    """
    def __init__(self,
                 model: SentenceTransformer,
                 sentence_embedding_dimension: int,
                 num_labels: int,
                 concatenation_sent_rep: bool = True,
                 concatenation_sent_difference: bool = True,
                 concatenation_sent_multiplication: bool = False,
                 loss_fct: Callable = nn.CrossEntropyLoss()):
        super(SoftmaxLoss, self).__init__()
        self.model = model
        self.num_labels = num_labels
        self.concatenation_sent_rep = concatenation_sent_rep
        self.concatenation_sent_difference = concatenation_sent_difference
        self.concatenation_sent_multiplication = concatenation_sent_multiplication

        # The classifier input size depends on which combinations of u and v are concatenated
        num_vectors_concatenated = 0
        if concatenation_sent_rep:
            num_vectors_concatenated += 2
        if concatenation_sent_difference:
            num_vectors_concatenated += 1
        if concatenation_sent_multiplication:
            num_vectors_concatenated += 1
        logger.info("Softmax loss: #Vectors concatenated: {}".format(num_vectors_concatenated))
        self.classifier = nn.Linear(num_vectors_concatenated * sentence_embedding_dimension, num_labels)
        self.loss_fct = loss_fct

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        # Embed both sentences of the pair with the shared SentenceTransformer model
        reps = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features]
        rep_a, rep_b = reps

        # Build the classifier input from the enabled combinations of u and v
        vectors_concat = []
        if self.concatenation_sent_rep:
            vectors_concat.append(rep_a)
            vectors_concat.append(rep_b)

        if self.concatenation_sent_difference:
            vectors_concat.append(torch.abs(rep_a - rep_b))

        if self.concatenation_sent_multiplication:
            vectors_concat.append(rep_a * rep_b)

        features = torch.cat(vectors_concat, 1)

        output = self.classifier(features)

        if labels is not None:
            loss = self.loss_fct(output, labels.view(-1))
            return loss
        else:
            return reps, output
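

if __name__ == "__main__":
    # Minimal sketch of the feature construction performed in forward(), using random
    # tensors instead of a real SentenceTransformer. The embedding dimension, batch size,
    # and label count below are illustrative assumptions, not values used by the library.
    embedding_dim, num_labels, batch_size = 8, 3, 4
    rep_a = torch.randn(batch_size, embedding_dim)  # stands in for sentence embedding u
    rep_b = torch.randn(batch_size, embedding_dim)  # stands in for sentence embedding v

    # With the default settings (u, v and |u - v|), three vectors are concatenated,
    # so the classifier maps from 3 * embedding_dim to num_labels.
    features = torch.cat([rep_a, rep_b, torch.abs(rep_a - rep_b)], dim=1)
    classifier = nn.Linear(3 * embedding_dim, num_labels)
    logits = classifier(features)

    labels = torch.randint(0, num_labels, (batch_size,))
    print("Cross-entropy loss on random data:", nn.CrossEntropyLoss()(logits, labels).item())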