Spaces:
Runtime error
Runtime error
| import ast | |
| import argparse | |
| from graph_matching import split_to_edges, get_tokens, get_bleu_rouge, get_bert_score, get_ged, get_triple_match_f1, get_graph_match_accuracy | |
def load_graphs(path: str) -> list:
    """Read one graph per line from *path*.

    Every non-blank line is parsed with ``ast.literal_eval``. Blank or
    whitespace-only lines (for example stray newlines at the end of the
    file) are skipped instead of crashing the parser.

    Args:
        path: Path of a UTF-8 text file holding one Python literal per line.

    Returns:
        The parsed literals, in file order.

    Raises:
        ValueError: If a non-blank line is not a valid Python literal. The
            message names the file and the 1-based line number.
    """
    graphs = []
    with open(path, 'r', encoding="utf-8") as f:
        # Iterate the file lazily; there is no need to materialize all lines.
        for line_no, line in enumerate(f, start=1):
            text = line.strip()
            if not text:
                continue
            try:
                graphs.append(ast.literal_eval(text))
            except (ValueError, SyntaxError) as err:
                raise ValueError(
                    f"{path}:{line_no}: cannot parse graph: {err}") from err
    return graphs


def main() -> None:
    """Score a prediction file against a gold file and print the metrics.

    Command-line arguments:
        --pred_file: file with one predicted graph per line (required).
        --gold_file: file with one gold graph per line (required).

    Exits through ``parser.error`` (status 2) when the two files do not
    contain the same number of graphs or contain no graphs at all, because
    every reported average divides by the number of gold graphs.
    """
    # Parse command-line arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument("--pred_file", default=None, type=str, required=True)
    parser.add_argument("--gold_file", default=None, type=str, required=True)
    args = parser.parse_args()

    # Load gold graphs first, then predictions.
    gold_graphs = load_graphs(args.gold_file)
    pred_graphs = load_graphs(args.pred_file)

    # Explicit validation instead of `assert`, which is removed under -O.
    if len(gold_graphs) != len(pred_graphs):
        parser.error(
            f"gold file has {len(gold_graphs)} graphs but prediction file "
            f"has {len(pred_graphs)}")
    num_graphs = len(gold_graphs)
    if num_graphs == 0:
        parser.error("no graphs found in the input files")

    # Triple-match F1.
    triple_match_f1 = get_triple_match_f1(gold_graphs, pred_graphs)

    # Graph-match accuracy. Argument order (pred, gold) is kept as in the
    # original call -- TODO confirm against the helper's signature.
    graph_match_accuracy = get_graph_match_accuracy(pred_graphs, gold_graphs)

    # Graph edit distance, accumulated over aligned (gold, pred) pairs.
    overall_ged = 0.
    for gold, pred in zip(gold_graphs, pred_graphs):
        overall_ged += get_ged(gold, pred)

    # Edge-level representations feed the token- and embedding-based metrics.
    gold_edges = split_to_edges(gold_graphs)
    pred_edges = split_to_edges(pred_graphs)
    gold_tokens, pred_tokens = get_tokens(gold_edges, pred_edges)

    # G-BLEU / G-ROUGE; the helper returns the ROUGE triple first.
    (precisions_rouge, recalls_rouge, f1s_rouge,
     precisions_bleu, recalls_bleu, f1s_bleu) = get_bleu_rouge(
        gold_tokens, pred_tokens, gold_edges, pred_edges)

    # G-BERTScore.
    precisions_BS, recalls_BS, f1s_BS = get_bert_score(gold_edges, pred_edges)

    print(f'Triple Match F1 Score: {triple_match_f1:.4f}\n')
    # NOTE(review): the label says "F1" but the value comes from
    # get_graph_match_accuracy; the string is kept unchanged so existing
    # consumers of this output keep working.
    print(f'Graph Match F1 Score: {graph_match_accuracy:.4f}\n')
    print(f'G-BLEU Precision: {precisions_bleu.sum() / num_graphs:.4f}')
    print(f'G-BLEU Recall: {recalls_bleu.sum() / num_graphs:.4f}')
    print(f'G-BLEU F1: {f1s_bleu.sum() / num_graphs:.4f}\n')
    print(f'G-Rouge Precision: {precisions_rouge.sum() / num_graphs:.4f}')
    print(f'G-Rouge Recall Score: {recalls_rouge.sum() / num_graphs:.4f}')
    print(f'G-Rouge F1 Score: {f1s_rouge.sum() / num_graphs:.4f}\n')
    print(f'G-BertScore Precision Score: {precisions_BS.sum() / num_graphs:.4f}')
    print(f'G-BertScore Recall Score: {recalls_BS.sum() / num_graphs:.4f}')
    print(f'G-BertScore F1 Score: {f1s_BS.sum() / num_graphs:.4f}\n')
    print(f'Graph Edit Distance (GED): {overall_ged / num_graphs:.4f}\n')


if __name__ == '__main__':
    main()