File size: 2,799 Bytes
badcf3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import ast
import argparse
from graph_matching import split_to_edges, get_tokens, get_bleu_rouge, get_bert_score, get_ged, get_triple_match_f1, get_graph_match_accuracy


def load_graphs(path):
    """Parse one Python-literal graph per non-blank line of *path*.

    Each line is evaluated with ``ast.literal_eval``, which only accepts
    Python literals, so the input files cannot execute code. Whitespace-only
    lines (e.g. a trailing empty line) are skipped instead of aborting the run.

    Args:
        path: Path to a UTF-8 encoded text file.

    Returns:
        A list of the parsed objects, in file order.

    Raises:
        ValueError: If a non-blank line is not a valid Python literal; the
            message names the file and the 1-based line number.
    """
    graphs = []
    with open(path, 'r', encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            text = line.strip()
            if not text:
                continue
            try:
                graphs.append(ast.literal_eval(text))
            except (ValueError, TypeError, SyntaxError) as err:
                raise ValueError(
                    f'{path}:{lineno}: cannot parse graph literal: {err}'
                ) from err
    return graphs


def main():
    """Evaluate predicted graphs against gold graphs and print every metric.

    Command-line arguments:
        --pred_file: File with one predicted graph (Python literal) per line.
        --gold_file: File with one gold graph (Python literal) per line,
            aligned line by line with --pred_file.

    Exits through ``argparse``'s error handler when the two files contain a
    different number of graphs, or no graphs at all.
    """
    parser = argparse.ArgumentParser(
        description="Compare predicted graphs with gold graphs.")
    parser.add_argument("--pred_file", default=None, type=str, required=True,
                        help="file with one predicted graph per line")
    parser.add_argument("--gold_file", default=None, type=str, required=True,
                        help="file with one gold graph per line")
    args = parser.parse_args()

    gold_graphs = load_graphs(args.gold_file)
    pred_graphs = load_graphs(args.pred_file)

    # Explicit checks rather than `assert`, which is stripped under
    # `python -O`: without them zip() below would silently truncate to the
    # shorter file, and the per-graph averages would divide by zero.
    if len(gold_graphs) != len(pred_graphs):
        parser.error(
            f'gold file has {len(gold_graphs)} graphs but pred file has '
            f'{len(pred_graphs)}; they must be aligned line by line')
    if not gold_graphs:
        parser.error('no graphs found in the input files')
    num_graphs = len(gold_graphs)

    # Triple-match F1 over the whole corpus (called with gold first).
    triple_match_f1 = get_triple_match_f1(gold_graphs, pred_graphs)

    # Graph-match accuracy. NOTE(review): this helper is called with pred
    # first, unlike get_triple_match_f1; the original order is kept --
    # confirm it matches graph_matching's signature.
    graph_match_accuracy = get_graph_match_accuracy(pred_graphs, gold_graphs)

    # Total graph edit distance over aligned (gold, pred) pairs; it is
    # averaged per graph when printed.
    overall_ged = sum(
        (get_ged(gold, pred) for gold, pred in zip(gold_graphs, pred_graphs)),
        0.0,
    )

    # The text-similarity metrics work on the graphs split into edges.
    gold_edges = split_to_edges(gold_graphs)
    pred_edges = split_to_edges(pred_graphs)
    gold_tokens, pred_tokens = get_tokens(gold_edges, pred_edges)

    # BLEU and ROUGE precision/recall/F1.
    (precisions_rouge, recalls_rouge, f1s_rouge,
     precisions_bleu, recalls_bleu, f1s_bleu) = get_bleu_rouge(
        gold_tokens, pred_tokens, gold_edges, pred_edges)

    # BERTScore precision/recall/F1.
    precisions_BS, recalls_BS, f1s_BS = get_bert_score(gold_edges, pred_edges)

    # NOTE(review): the second label says "F1" but the value comes from
    # get_graph_match_accuracy; the text is kept unchanged in case
    # downstream tooling parses this output.
    print(f'Triple Match F1 Score: {triple_match_f1:.4f}\n')
    print(f'Graph Match F1 Score: {graph_match_accuracy:.4f}\n')

    # The metric containers are summed and divided by the number of graphs.
    print(f'G-BLEU Precision: {precisions_bleu.sum() / num_graphs:.4f}')
    print(f'G-BLEU Recall: {recalls_bleu.sum() / num_graphs:.4f}')
    print(f'G-BLEU F1: {f1s_bleu.sum() / num_graphs:.4f}\n')

    print(f'G-Rouge Precision: {precisions_rouge.sum() / num_graphs:.4f}')
    print(f'G-Rouge Recall Score: {recalls_rouge.sum() / num_graphs:.4f}')
    print(f'G-Rouge F1 Score: {f1s_rouge.sum() / num_graphs:.4f}\n')

    print(f'G-BertScore Precision Score: {precisions_BS.sum() / num_graphs:.4f}')
    print(f'G-BertScore Recall Score: {recalls_BS.sum() / num_graphs:.4f}')
    print(f'G-BertScore F1 Score: {f1s_BS.sum() / num_graphs:.4f}\n')

    print(f'Graph Edit Distance (GED): {overall_ged / num_graphs:.4f}\n')


if __name__ == '__main__':
    main()