| """ | |
| Tests the correct computation of evaluation scores from BinaryClassificationEvaluator | |
| """ | |
| from sentence_transformers import SentenceTransformer, evaluation, util, losses, LoggingHandler | |
| import logging | |
| import unittest | |
| from sklearn.metrics import f1_score, accuracy_score | |
| import numpy as np | |
| import gzip | |
| import csv | |
| from sentence_transformers import InputExample | |
| from torch.utils.data import DataLoader | |
| import os | |


class EvaluatorTest(unittest.TestCase):
    def test_BinaryClassificationEvaluator_find_best_f1_and_threshold(self):
        """Tests that the F1 score for the computed threshold is correct"""
        y_true = np.random.randint(0, 2, 1000)
        y_pred_cosine = np.random.randn(1000)
        best_f1, best_precision, best_recall, threshold = evaluation.BinaryClassificationEvaluator.find_best_f1_and_threshold(y_pred_cosine, y_true, high_score_more_similar=True)
        # Binarize the scores with the returned threshold and check that the
        # reported F1 matches sklearn's computation on the same predictions
        y_pred_labels = [1 if pred >= threshold else 0 for pred in y_pred_cosine]
        sklearn_f1score = f1_score(y_true, y_pred_labels)
        assert np.abs(best_f1 - sklearn_f1score) < 1e-6

    def test_BinaryClassificationEvaluator_find_best_accuracy_and_threshold(self):
        """Tests that the accuracy score for the computed threshold is correct"""
        y_true = np.random.randint(0, 2, 1000)
        y_pred_cosine = np.random.randn(1000)
        max_acc, threshold = evaluation.BinaryClassificationEvaluator.find_best_acc_and_threshold(y_pred_cosine, y_true, high_score_more_similar=True)
        # Binarize the scores with the returned threshold and check that the
        # reported accuracy matches sklearn's computation on the same predictions
        y_pred_labels = [1 if pred >= threshold else 0 for pred in y_pred_cosine]
        sklearn_acc = accuracy_score(y_true, y_pred_labels)
        assert np.abs(max_acc - sklearn_acc) < 1e-6

    def test_LabelAccuracyEvaluator(self):
        """Tests that the LabelAccuracyEvaluator can be loaded correctly"""
        model = SentenceTransformer('paraphrase-distilroberta-base-v1')

        nli_dataset_path = 'datasets/AllNLI.tsv.gz'
        if not os.path.exists(nli_dataset_path):
            util.http_get('https://sbert.net/datasets/AllNLI.tsv.gz', nli_dataset_path)

        label2int = {"contradiction": 0, "entailment": 1, "neutral": 2}
        dev_samples = []
        with gzip.open(nli_dataset_path, 'rt', encoding='utf8') as fIn:
            reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
            for row in reader:
                if row['split'] == 'train':
                    label_id = label2int[row['label']]
                    dev_samples.append(InputExample(texts=[row['sentence1'], row['sentence2']], label=label_id))
                    if len(dev_samples) >= 100:
                        break

        # The evaluator needs the classification head from SoftmaxLoss to map
        # sentence embedding pairs to the three NLI labels
        train_loss = losses.SoftmaxLoss(model=model, sentence_embedding_dimension=model.get_sentence_embedding_dimension(), num_labels=len(label2int))

        dev_dataloader = DataLoader(dev_samples, shuffle=False, batch_size=16)
        evaluator = evaluation.LabelAccuracyEvaluator(dev_dataloader, softmax_model=train_loss)
        acc = evaluator(model)
        assert acc > 0.2

    def test_ParaphraseMiningEvaluator(self):
        """Tests that the ParaphraseMiningEvaluator can be loaded and mines the expected paraphrase pairs"""
        model = SentenceTransformer('paraphrase-distilroberta-base-v1')
        sentences = {0: "Hello World", 1: "Hello World!", 2: "The cat is on the table", 3: "On the table the cat is"}
        data_eval = evaluation.ParaphraseMiningEvaluator(sentences, [(0, 1), (2, 3)])
        score = data_eval(model)
        assert score > 0.99
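

# Minimal runner so the module can be executed directly; an assumption, since
# the file may instead be collected by pytest/unittest discovery, in which
# case this guard is simply never triggered.
if __name__ == "__main__":
    unittest.main()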