"""Score predicted graphs against gold graphs with several graph-matching metrics.

Usage:
    python <this script> --pred_file PRED --gold_file GOLD

Both files hold one Python literal per line (parsed with ``ast.literal_eval``);
line *i* of the prediction file is scored against line *i* of the gold file.
"""
import ast
import argparse

from graph_matching import (
    split_to_edges,
    get_tokens,
    get_bleu_rouge,
    get_bert_score,
    get_ged,
    get_triple_match_f1,
    get_graph_match_accuracy,
)


def _read_graphs(path):
    """Return the graphs stored in *path*, one parsed literal per non-blank line.

    Args:
        path: Path of a UTF-8 text file with one Python literal per line.

    Returns:
        A list with one parsed object per non-blank line, in file order.

    Raises:
        ValueError: If a non-blank line is not a valid Python literal. The
            message carries the file name and 1-based line number.
    """
    graphs = []
    with open(path, 'r', encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            text = line.strip()
            # Blank lines (typically a trailing newline at end of file) carry
            # no graph; literal_eval('') would raise a bare SyntaxError.
            if not text:
                continue
            try:
                graphs.append(ast.literal_eval(text))
            except (ValueError, SyntaxError) as err:
                raise ValueError(
                    f"{path}:{line_no}: cannot parse graph literal: {err}"
                ) from err
    return graphs


def _mean(scores, count):
    """Return ``scores.sum() / count`` for a score container exposing ``.sum()``."""
    return scores.sum() / count


def main():
    """Parse the command line, compute every metric and print the report.

    Raises:
        ValueError: If the gold file contains no graphs, or the two files do
            not contain the same number of graphs.
    """
    # Parse command-line arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument("--pred_file", default=None, type=str, required=True)
    parser.add_argument("--gold_file", default=None, type=str, required=True)
    args = parser.parse_args()

    # Load gold graphs first, then predictions.
    gold_graphs = _read_graphs(args.gold_file)
    pred_graphs = _read_graphs(args.pred_file)

    # Explicit checks instead of `assert`: assertions are stripped under
    # `python -O`, and a silent length mismatch would be truncated by zip()
    # below and skew every averaged score.
    num_graphs = len(gold_graphs)
    if num_graphs == 0:
        raise ValueError(f"no graphs found in gold file {args.gold_file!r}")
    if len(pred_graphs) != num_graphs:
        raise ValueError(
            f"gold file has {num_graphs} graphs but prediction file has "
            f"{len(pred_graphs)}; the two files must align line by line"
        )

    # Triple-match F1.
    triple_match_f1 = get_triple_match_f1(gold_graphs, pred_graphs)

    # Graph-match accuracy.
    # NOTE(review): argument order here is (pred, gold), the opposite of the
    # call above. Preserved from the original; confirm against graph_matching.
    graph_match_accuracy = get_graph_match_accuracy(pred_graphs, gold_graphs)

    # Accumulate the per-pair value returned by get_ged; averaged when printed.
    overall_ged = sum(
        (get_ged(gold, pred) for gold, pred in zip(gold_graphs, pred_graphs)),
        0.,
    )

    # Edge-level representations feed the BLEU/ROUGE and BERTScore metrics.
    gold_edges = split_to_edges(gold_graphs)
    pred_edges = split_to_edges(pred_graphs)

    gold_tokens, pred_tokens = get_tokens(gold_edges, pred_edges)

    (precisions_rouge, recalls_rouge, f1s_rouge,
     precisions_bleu, recalls_bleu, f1s_bleu) = get_bleu_rouge(
        gold_tokens, pred_tokens, gold_edges, pred_edges)

    precisions_BS, recalls_BS, f1s_BS = get_bert_score(gold_edges, pred_edges)

    print(f'Triple Match F1 Score: {triple_match_f1:.4f}\n')
    # Label kept verbatim for output compatibility; the value printed is the
    # result of get_graph_match_accuracy.
    print(f'Graph Match F1 Score: {graph_match_accuracy:.4f}\n')

    print(f'G-BLEU Precision: {_mean(precisions_bleu, num_graphs):.4f}')
    print(f'G-BLEU Recall: {_mean(recalls_bleu, num_graphs):.4f}')
    print(f'G-BLEU F1: {_mean(f1s_bleu, num_graphs):.4f}\n')

    print(f'G-Rouge Precision: {_mean(precisions_rouge, num_graphs):.4f}')
    print(f'G-Rouge Recall Score: {_mean(recalls_rouge, num_graphs):.4f}')
    print(f'G-Rouge F1 Score: {_mean(f1s_rouge, num_graphs):.4f}\n')

    print(f'G-BertScore Precision Score: {_mean(precisions_BS, num_graphs):.4f}')
    print(f'G-BertScore Recall Score: {_mean(recalls_BS, num_graphs):.4f}')
    print(f'G-BertScore F1 Score: {_mean(f1s_BS, num_graphs):.4f}\n')

    print(f'Graph Edit Distance (GED): {overall_ged / num_graphs:.4f}\n')


if __name__ == '__main__':
    main()