| from . import SentenceEvaluator | |
| import torch | |
| from torch.utils.data import DataLoader | |
| import logging | |
| from ..util import batch_to_device | |
| import os | |
| import csv | |
| logger = logging.getLogger(__name__) | |
class LabelAccuracyEvaluator(SentenceEvaluator):
    """
    Evaluate a model based on its accuracy on a labeled dataset.

    This requires a model with LossFunction.SOFTMAX

    The results are written in a CSV. If a CSV already exists, then values are appended.
    """

    def __init__(self, dataloader: DataLoader, name: str = "", softmax_model=None, write_csv: bool = True):
        """
        Constructs an evaluator for the given dataset

        :param dataloader:
            the data for the evaluation
        :param name:
            optional dataset name; it is used in the log output and, when
            non-empty, inserted into the CSV file name
        :param softmax_model:
            callable invoked as ``softmax_model(features, labels=None)``; it
            must return a pair whose second element holds the per-class
            scores of the batch (the arg-max over dim 1 is the prediction)
        :param write_csv:
            whether the results are appended to a CSV file when the evaluator
            is called with an ``output_path``
        """
        self.dataloader = dataloader
        self.name = name
        self.softmax_model = softmax_model
        self.write_csv = write_csv

        # Only named evaluators get a "_<name>" part in the file name.
        file_suffix = "_" + name if name else ""
        self.csv_file = "accuracy_evaluation" + file_suffix + "_results.csv"
        self.csv_headers = ["epoch", "steps", "accuracy"]

    def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
        """
        Computes the label accuracy of ``softmax_model`` on the dataloader.

        :param model:
            the model under evaluation; it is switched to eval mode and
            provides the collate function and the target device
        :param output_path:
            directory for the CSV file; nothing is written when it is None
            or when ``write_csv`` is False
        :param epoch:
            current epoch, -1 if not applicable (only used for logging/CSV)
        :param steps:
            current step within the epoch, -1 if not applicable
            (only used for logging/CSV)
        :return:
            the fraction of correctly classified samples, or 0.0 when the
            dataloader did not yield any sample
        :raises TypeError:
            if the evaluator was created without a ``softmax_model``
        """
        if self.softmax_model is None:
            # Fail early with an explanatory message instead of the opaque
            # "'NoneType' object is not callable" raised inside the loop.
            raise TypeError("LabelAccuracyEvaluator requires a softmax_model, but it is None")

        model.eval()
        total = 0
        correct = 0

        if epoch == -1:
            out_txt = ":"
        elif steps == -1:
            out_txt = " after epoch {}:".format(epoch)
        else:
            out_txt = " in epoch {} after {} steps:".format(epoch, steps)

        logger.info("Evaluation on the %s dataset%s", self.name, out_txt)

        self.dataloader.collate_fn = model.smart_batching_collate
        for features, label_ids in self.dataloader:
            # Move every feature entry and the gold labels to the model's device.
            for idx, feature in enumerate(features):
                features[idx] = batch_to_device(feature, model.device)
            label_ids = label_ids.to(model.device)

            with torch.no_grad():
                _, prediction = self.softmax_model(features, labels=None)

            total += prediction.size(0)
            correct += torch.argmax(prediction, dim=1).eq(label_ids).sum().item()

        if total == 0:
            # An empty dataloader would otherwise raise ZeroDivisionError.
            logger.warning("No samples were evaluated on the %s dataset; reporting an accuracy of 0.0", self.name)
            accuracy = 0.0
        else:
            accuracy = correct / total

        logger.info("Accuracy: %.4f (%d/%d)\n", accuracy, correct, total)

        if output_path is not None and self.write_csv:
            csv_path = os.path.join(output_path, self.csv_file)
            # A new file is created with a header row; an existing file is appended to.
            is_new_file = not os.path.isfile(csv_path)
            with open(csv_path, newline='', mode="w" if is_new_file else "a", encoding="utf-8") as f:
                writer = csv.writer(f)
                if is_new_file:
                    writer.writerow(self.csv_headers)
                writer.writerow([epoch, steps, accuracy])

        return accuracy