| from enum import Enum | |
| from typing import Iterable, Dict | |
| import torch.nn.functional as F | |
| from torch import nn, Tensor | |
| from sentence_transformers.SentenceTransformer import SentenceTransformer | |
class SiameseDistanceMetric(Enum):
    """
    Pre-defined distance functions for ContrastiveLoss.

    Each attribute is a callable ``(x, y) -> Tensor`` that compares two batches of
    embeddings pair by pair. Callers use the attribute directly as a function,
    e.g. ``SiameseDistanceMetric.EUCLIDEAN(a, b)``.
    Inputs are presumably 2-D (batch, dim) embedding tensors; TODO confirm.
    """

    # L2 (Euclidean) distance between corresponding rows.
    EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2.0)  # noqa: E731
    # L1 (Manhattan) distance between corresponding rows.
    MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1.0)  # noqa: E731
    # One minus the cosine similarity computed along dim=1.
    COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y, dim=1)  # noqa: E731
class ContrastiveLoss(nn.Module):
    """
    Contrastive loss. Expects as input two texts and a label of either 0 or 1. If the label == 1, then the distance between the
    two embeddings is reduced. If the label == 0, then the distance between the embeddings is increased.

    Further information: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    :param model: SentenceTransformer model
    :param distance_metric: Function that returns a distance between two embeddings. The class SiameseDistanceMetric contains pre-defined metrics that can be used
    :param margin: Negative samples (label == 0) should have a distance of at least the margin value.
    :param size_average: Average by the size of the mini-batch. If False, the per-pair losses are summed instead.

    Example::

        from sentence_transformers import SentenceTransformer, LoggingHandler, losses, InputExample
        from torch.utils.data import DataLoader

        model = SentenceTransformer('all-MiniLM-L6-v2')
        train_examples = [
            InputExample(texts=['This is a positive pair', 'Where the distance will be minimized'], label=1),
            InputExample(texts=['This is a negative pair', 'Their distance will be increased'], label=0)]

        train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
        train_loss = losses.ContrastiveLoss(model=model)

        model.fit([(train_dataloader, train_loss)], show_progress_bar=True)
    """

    def __init__(
        self,
        model: SentenceTransformer,
        distance_metric=SiameseDistanceMetric.COSINE_DISTANCE,
        margin: float = 0.5,
        size_average: bool = True,
    ):
        super().__init__()
        self.distance_metric = distance_metric
        self.margin = margin
        self.model = model
        self.size_average = size_average

    def get_config_dict(self):
        """
        Return the loss hyper-parameters as a plain dict.

        If ``distance_metric`` is one of the SiameseDistanceMetric functions it is reported as
        ``"SiameseDistanceMetric.<NAME>"``. Otherwise the callable's ``__name__`` is used, or its type
        name when it has no ``__name__`` (e.g. ``functools.partial`` or a callable object).
        """
        distance_metric_name = getattr(
            self.distance_metric, "__name__", type(self.distance_metric).__name__
        )
        for name, value in vars(SiameseDistanceMetric).items():
            # Identity check: we are looking for the exact function object, and it avoids
            # invoking a custom __eq__ against unrelated class-dict entries.
            if value is self.distance_metric:
                distance_metric_name = "SiameseDistanceMetric.{}".format(name)
                break
        return {'distance_metric': distance_metric_name, 'margin': self.margin, 'size_average': self.size_average}

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        """
        Compute the contrastive loss for a batch of text pairs.

        :param sentence_features: Exactly two feature dicts (anchor and other); each is passed through
            ``self.model`` and its ``'sentence_embedding'`` entry is used.
        :param labels: 1 for positive pairs, 0 for negative pairs. Bool, integer and float tensors are
            accepted. A tensor with the same number of elements as the distances but a different shape
            (e.g. (batch, 1)) is reshaped to match.
        :return: Scalar tensor: the mean of the per-pair losses if ``size_average`` else their sum.
        :raises ValueError: if ``sentence_features`` does not yield exactly two inputs.
        """
        reps = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features]
        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if len(reps) != 2:
            raise ValueError(
                "ContrastiveLoss expects exactly 2 inputs (anchor, other), got {}".format(len(reps))
            )
        rep_anchor, rep_other = reps
        distances = self.distance_metric(rep_anchor, rep_other)

        # Convert once up front: ``1 - labels`` is not supported on bool tensors.
        labels = labels.float()
        if labels.shape != distances.shape and labels.numel() == distances.numel():
            # e.g. (batch, 1) labels against (batch,) distances would otherwise broadcast
            # into a (batch, batch) matrix and silently produce a wrong loss.
            labels = labels.reshape(distances.shape)

        # Positive pairs are pulled together quadratically.
        positive_loss = labels * distances.pow(2)
        # Negative pairs are pushed apart only while they are closer than the margin.
        negative_loss = (1 - labels) * F.relu(self.margin - distances).pow(2)
        losses = 0.5 * (positive_loss + negative_loss)
        return losses.mean() if self.size_average else losses.sum()