import numpy as np
from rouge_score import rouge_scorer
from bert_score import score as score_bert
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.bleu_score import SmoothingFunction
from scipy.optimize import linear_sum_assignment
from spacy.tokenizer import Tokenizer
# from spacy.lang.en import English
from spacy.lang.zh import Chinese
import re
import networkx as nx
from sklearn import preprocessing
from sklearn.metrics import precision_score, recall_score, f1_score
import json

# Normalize a graph: cast every element of every triple to a string and strip
# leading/trailing whitespace (case is left unchanged).
def modify_graph(original_graph):
    modified_graph = []
    for x in original_graph:
        modified_graph.append([str(t).strip() for t in x])
    return modified_graph

# Compute the micro-averaged F1 over exact triple matches.
def get_triple_match_f1(gold_graphs, pred_graphs):
    # 1. Preprocess: normalize the triples so every entry is a stripped string.
    # 2. Use MultiLabelBinarizer to turn each graph's triples into a binary row.
    # 3. Compute micro-averaged precision, recall and F1 over those rows.
    new_gold_graphs = [modify_graph(graph) for graph in gold_graphs]
    new_pred_graphs = [modify_graph(graph) for graph in pred_graphs]
    # Serialize each triple (a list) to a single string so it can act as one label.
    new_gold_graphs_list = [[str(triple) for triple in sublist] for sublist in new_gold_graphs]
    new_pred_graphs_list = [[str(triple) for triple in sublist] for sublist in new_pred_graphs]
    # First get all the classes by combining the triples in the pred_graphs
    # and gold_graphs.
    allclasses = new_pred_graphs_list + new_gold_graphs_list
    allclasses = [item for items in allclasses for item in items]
    allclasses = list(set(allclasses))

    lb = preprocessing.MultiLabelBinarizer(classes=allclasses)
    # Binarize both sides against the same fixed class list.
    mcbin = lb.fit_transform(new_pred_graphs_list)
    mrbin = lb.transform(new_gold_graphs_list)
    # Micro-averaged precision, recall and F1 (gold as y_true, pred as y_pred).
    precision = precision_score(mrbin, mcbin, average='micro')
    recall = recall_score(mrbin, mcbin, average='micro')
    f1 = f1_score(mrbin, mcbin, average='micro')

    print('Full triple scores')
    print('-----------------------------------------------------------------')
    print(f'Precision: {precision} Recall: {recall}\nF1: {f1}')
    return f1
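
# Illustrative sketch with hypothetical values, kept as a doctest-style comment
# so it does not run on import: one gold triple, two predicted triples, one overlap.
#   >>> get_triple_match_f1([[['a', 'r', 'b']]], [[['a', 'r', 'b'], ['a', 'r', 'c']]])
#   micro precision = 1/2, micro recall = 1/1, so F1 = 2 * 0.5 * 1.0 / 1.5 ≈ 0.6667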

def get_triple_match_accuracy(pred_graph, gold_graph):
    """
    Triple match accuracy:
    directly compares predicted and gold triples for exact equality,
    ignoring graph structure.
    Returns the fraction of predicted triples that exactly match a gold triple.
    """
    pred = modify_graph(pred_graph)
    gold = modify_graph(gold_graph)
    matches = 0
    for x in pred:
        if x in gold:
            matches += 1
    # Guard against an empty prediction to avoid a ZeroDivisionError.
    acc = matches / len(pred) if pred else 0.0
    return acc
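
# For example (hypothetical values): one of two predicted triples appears in gold.
#   >>> get_triple_match_accuracy([['a', 'r', 'b'], ['a', 'r', 'c']], [['a', 'r', 'b']])
#   0.5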

def get_graph_match_accuracy(pred_graphs, gold_graphs):
    """
    Graph match accuracy:
    converts each triple list into a directed graph, taking the connectivity
    of nodes and edges into account, and uses graph isomorphism to decide
    whether the two structures are identical.
    Returns the fraction of predicted graphs that are fully isomorphic to
    their gold graph.
    """
    matches = 0
    for pred, gold in zip(pred_graphs, gold_graphs):
        g1 = nx.DiGraph()
        g2 = nx.DiGraph()

        for edge in gold:
            g1.add_node(str(edge[0]).strip(), label=str(edge[0]).strip())
            g1.add_node(str(edge[2]).strip(), label=str(edge[2]).strip())
            g1.add_edge(str(edge[0]).strip(), str(edge[2]).strip(), label=str(edge[1]).strip())

        for edge in pred:
            # Pad malformed triples so that edge[0]/edge[1]/edge[2] always exist.
            if len(edge) == 2:
                edge.append('NULL')
            elif len(edge) == 1:
                edge.append('NULL')
                edge.append('NULL')
            g2.add_node(str(edge[0]).strip(), label=str(edge[0]).strip())
            g2.add_node(str(edge[2]).strip(), label=str(edge[2]).strip())
            g2.add_edge(str(edge[0]).strip(), str(edge[2]).strip(), label=str(edge[1]).strip())

        # Match on node labels as well as edge labels; otherwise two graphs with
        # the same shape but different entities would count as identical.
        if nx.is_isomorphic(g1, g2,
                            node_match=lambda x, y: x == y,
                            edge_match=lambda x, y: x == y):
            matches += 1
    acc = matches / len(pred_graphs) if pred_graphs else 0.0
    return acc
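
# Hypothetical illustration: the same graph written in a different triple order
# still matches, since isomorphism is insensitive to edge ordering.
#   >>> get_graph_match_accuracy([[['a', 'r1', 'b'], ['b', 'r2', 'c']]],
#   ...                          [[['b', 'r2', 'c'], ['a', 'r1', 'b']]])
#   1.0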

def get_tokens(gold_edges, pred_edges):
    nlp = Chinese()  # Chinese vocabulary for the tokenizer
    # Split serialized edges on ';', ',' and '。' (infixes), and strip opening /
    # closing bracket characters at the start / end of a token.
    tokenizer = Tokenizer(nlp.vocab,
                          infix_finditer=re.compile(r'''[;,。]''').finditer,
                          prefix_search=re.compile(r'''^[【(\[]''').search,
                          suffix_search=re.compile(r'''[】)\]]$''').search)
    gold_tokens = []
    pred_tokens = []

    for i in range(len(gold_edges)):
        gold_tokens_edges = []
        pred_tokens_edges = []

        for sample in tokenizer.pipe(gold_edges[i]):
            gold_tokens_edges.append([j.text for j in sample])
        for sample in tokenizer.pipe(pred_edges[i]):
            pred_tokens_edges.append([j.text for j in sample])
        gold_tokens.append(gold_tokens_edges)
        pred_tokens.append(pred_tokens_edges)

    return gold_tokens, pred_tokens
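
# Sketch of what the tokenizer above produces for a serialized edge (assuming
# the "subject;relation;object" format built by split_to_edges below):
#   "北京;首都;中国" -> ['北京', ';', '首都', ';', '中国']
# because ';' is registered as an infix and is therefore split into its own token.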


def split_to_edges(graphs):
    # Serialize each triple into a single "subject;relation;object" string.
    processed_graphs = []
    for graph in graphs:
        processed_graphs.append([";".join(triple).strip() for triple in graph])
    return processed_graphs


def get_bert_score(all_gold_edges, all_pred_edges):
    references = []
    candidates = []

    ref_cand_index = {}
    for (gold_edges, pred_edges) in zip(all_gold_edges, all_pred_edges):
        for (i, gold_edge) in enumerate(gold_edges):
            for (j, pred_edge) in enumerate(pred_edges):
                references.append(gold_edge)
                candidates.append(pred_edge)
                ref_cand_index[(gold_edge, pred_edge)] = len(references) - 1

    # The serialized edges are Chinese, so score them with bert-score's Chinese model.
    _, _, bs_F1 = score_bert(cands=candidates, refs=references, lang='zh', idf=False)
    print("Computed bert scores for all pairs")

    precisions, recalls, f1s = [], [], []
    for (gold_edges, pred_edges) in zip(all_gold_edges, all_pred_edges):
        score_matrix = np.zeros((len(gold_edges), len(pred_edges)))
        for (i, gold_edge) in enumerate(gold_edges):
            for (j, pred_edge) in enumerate(pred_edges):
                score_matrix[i][j] = bs_F1[ref_cand_index[(gold_edge, pred_edge)]]

        row_ind, col_ind = linear_sum_assignment(score_matrix, maximize=True)

        sample_precision = score_matrix[row_ind, col_ind].sum() / len(pred_edges)
        sample_recall = score_matrix[row_ind, col_ind].sum() / len(gold_edges)

        precisions.append(sample_precision)
        recalls.append(sample_recall)
        # Guard the harmonic mean against a zero denominator.
        denom = sample_precision + sample_recall
        f1s.append(2 * sample_precision * sample_recall / denom if denom > 0 else 0.0)

    return np.array(precisions), np.array(recalls), np.array(f1s)
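
# The assignment step above can be sketched on a hypothetical 2x2 similarity
# matrix: linear_sum_assignment(..., maximize=True) pairs each gold edge with a
# distinct predicted edge so that the total similarity is maximal.
#   >>> m = np.array([[0.9, 0.2], [0.3, 0.8]])
#   >>> r, c = linear_sum_assignment(m, maximize=True)
#   >>> m[r, c].sum()   # 0.9 + 0.8
#   1.7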


# Note: These graph matching metrics are computed by considering each graph as a set of edges and each edge as a
# sentence
def get_bleu_rouge(gold_tokens, pred_tokens, gold_sent, pred_sent):
    scorer_rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rouge3', 'rougeL'], use_stemmer=True)

    precisions_bleu = []
    recalls_bleu = []
    f1s_bleu = []

    precisions_rouge = []
    recalls_rouge = []
    f1s_rouge = []

    for graph_idx in range(len(gold_tokens)):
        score_bleu = np.zeros((len(pred_tokens[graph_idx]), len(gold_tokens[graph_idx])))
        score_rouge = np.zeros((len(pred_tokens[graph_idx]), len(gold_tokens[graph_idx])))
        for p_idx in range(len(pred_tokens[graph_idx])):
            for g_idx in range(len(gold_tokens[graph_idx])):
                score_bleu[p_idx, g_idx] = sentence_bleu([gold_tokens[graph_idx][g_idx]],
                                                         pred_tokens[graph_idx][p_idx],
                                                         smoothing_function=SmoothingFunction().method1)
                # ROUGE-2 is scored on the raw edge strings, using the F-measure
                # rather than precision alone.
                score_rouge[p_idx, g_idx] = \
                    scorer_rouge.score(gold_sent[graph_idx][g_idx], pred_sent[graph_idx][p_idx])['rouge2'].fmeasure
                
        def _scores(cost_matrix):
            row_ind, col_ind = linear_sum_assignment(cost_matrix, maximize=True)
            precision = cost_matrix[row_ind, col_ind].sum() / cost_matrix.shape[0]
            recall = cost_matrix[row_ind, col_ind].sum() / cost_matrix.shape[1]
            f1 = (2 * precision * recall) / (precision + recall) if precision + recall > 0 else 0
            return precision, recall, f1

        precision_bleu, recall_bleu, f1_bleu = _scores(score_bleu)
        precisions_bleu.append(precision_bleu)
        recalls_bleu.append(recall_bleu)
        f1s_bleu.append(f1_bleu)

        precision_rouge, recall_rouge, f1_rouge = _scores(score_rouge)
        precisions_rouge.append(precision_rouge)
        recalls_rouge.append(recall_rouge)
        f1s_rouge.append(f1_rouge)

    return np.array(precisions_rouge), np.array(recalls_rouge), np.array(f1s_rouge), np.array(
        precisions_bleu), np.array(recalls_bleu), np.array(f1s_bleu)


def return_eq_node(node1, node2):
    return node1['label'] == node2['label']


def return_eq_edge(edge1, edge2):
    return edge1['label'] == edge2['label']

def get_ged(gold_graph, pred_graph=None):
    g1 = nx.DiGraph()
    g2 = nx.DiGraph()

    for edge in gold_graph:
        g1.add_node(str(edge[0]).strip(), label=str(edge[0]).strip())
        g1.add_node(str(edge[2]).strip(), label=str(edge[2]).strip())
        g1.add_edge(str(edge[0]).strip(), str(edge[2]).strip(), label=str(edge[1]).strip())

    # Normalize by a worst-case GED: delete every node and edge of the predicted
    # graph, then add every node and edge of the gold graph. In ExplaGraphs,
    # where graphs are capped at 8 edges (9 nodes), that worst case is
    # gold_nodes + gold_edges + 8 + 9; this adaptation allows larger predicted
    # graphs and uses a looser additive constant of 30. A tighter bound could
    # account for replacement operations, but that is ignored for convenience.
    normalizing_constant = g1.number_of_nodes() + g1.number_of_edges() + 30

    # A missing prediction receives the worst possible normalized distance.
    if pred_graph is None:
        return 1

    for edge in pred_graph:
        if len(edge) == 2:
            edge.append('NULL')
        elif len(edge) == 1:
            edge.append('NULL')
            edge.append('NULL')
        g2.add_node(str(edge[0]).strip(), label=str(edge[0]).strip())
        g2.add_node(str(edge[2]).strip(), label=str(edge[2]).strip())
        g2.add_edge(str(edge[0]).strip(), str(edge[2]).strip(), label=str(edge[1]).strip())

    ged = nx.graph_edit_distance(g1, g2, node_match=return_eq_node, edge_match=return_eq_edge)

    assert ged <= normalizing_constant

    return ged / normalizing_constant
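
# For example (hypothetical values): identical graphs have edit distance 0, so
# the normalized score is 0; the worse the prediction, the closer the score to 1.
#   >>> get_ged([['a', 'r', 'b']], [['a', 'r', 'b']])
#   0.0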

def clean_and_validate_data(gold_graphs, pred_graphs):
    """Drop example pairs where either side is empty or not a list."""
    valid_gold = []
    valid_pred = []

    for gold, pred in zip(gold_graphs, pred_graphs):
        # Keep a pair only when both sides are non-empty lists.
        if gold and pred and isinstance(gold, list) and isinstance(pred, list):
            valid_gold.append(gold)
            valid_pred.append(pred)

    return valid_gold, valid_pred

def evaluate_triples(gold_graphs, pred_graphs):
    print("Starting evaluation...")
    print("=" * 50)

    # Clean and validate the data first.
    gold_graphs, pred_graphs = clean_and_validate_data(gold_graphs, pred_graphs)

    if not gold_graphs or not pred_graphs:
        print("Warning: no valid data to evaluate!")
        return None

    # 1. Triple Match F1
    try:
        triple_f1 = get_triple_match_f1(gold_graphs, pred_graphs)
        print(f"Triple match F1: {triple_f1:.4f}")
    except Exception as e:
        print(f"Error computing Triple Match F1: {str(e)}")
        triple_f1 = 0.0

    # 2. Graph Match Accuracy
    try:
        graph_acc = get_graph_match_accuracy(pred_graphs, gold_graphs)
        print(f"Graph match accuracy: {graph_acc:.4f}")
    except Exception as e:
        print(f"Error computing Graph Match Accuracy: {str(e)}")
        graph_acc = 0.0

    # Avoid division by zero when averaging.
    def safe_divide(a, b):
        return a / b if b != 0 else 0.0

    # 3. BERT Score over serialized edges
    gold_edges = split_to_edges(gold_graphs)
    pred_edges = split_to_edges(pred_graphs)

    if any(gold_edges) and any(pred_edges):
        precisions_BS, recalls_BS, f1s_BS = get_bert_score(gold_edges, pred_edges)
        print("BERT Score:")
        print(f"- Precision: {safe_divide(precisions_BS.sum(), len(precisions_BS)):.4f}")
        print(f"- Recall: {safe_divide(recalls_BS.sum(), len(recalls_BS)):.4f}")
        print(f"- F1: {safe_divide(f1s_BS.sum(), len(f1s_BS)):.4f}")

    # ... remaining evaluation code ...

def load_data(gold_path, pred_path):
    try:
        # Load the gold annotations.
        with open(gold_path, 'r', encoding='utf-8') as f:
            gold_data = json.load(f)

        # Load the predictions.
        with open(pred_path, 'r', encoding='utf-8') as f:
            pred_data = json.load(f)

        # Basic format validation.
        if not isinstance(gold_data, list) or not isinstance(pred_data, list):
            print("Warning: malformed data, both files should contain lists")
            return [], []

        # Extract and align the triple lists by matching on the 'text' field.
        gold_graphs = []
        pred_graphs = []

        for pred_item in pred_data:
            if not isinstance(pred_item, dict) or 'text' not in pred_item or 'triple_list' not in pred_item:
                continue

            pred_text = pred_item['text']
            for gold_item in gold_data:
                if gold_item.get('text') == pred_text:
                    if gold_item.get('triple_list') and pred_item.get('triple_list'):
                        gold_graphs.append(gold_item['triple_list'])
                        pred_graphs.append(pred_item['triple_list'])
                    break

        print(f"Loaded examples: Gold={len(gold_graphs)}, Pred={len(pred_graphs)}")
        print("Gold sample:", gold_graphs[0] if gold_graphs else "empty")
        print("Pred sample:", pred_graphs[0] if pred_graphs else "empty")

        return gold_graphs, pred_graphs

    except Exception as e:
        print(f"Error loading data: {str(e)}")
        return [], []
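
# Minimal usage sketch, assuming gold and prediction files in the JSON format
# that load_data expects; the paths below are placeholders, not part of the
# original script.
if __name__ == '__main__':
    gold_graphs, pred_graphs = load_data('gold.json', 'pred.json')
    if gold_graphs and pred_graphs:
        evaluate_triples(gold_graphs, pred_graphs)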