import numpy as np
from rouge_score import rouge_scorer
from bert_score import score as score_bert
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.bleu_score import SmoothingFunction
from scipy.optimize import linear_sum_assignment
from spacy.tokenizer import Tokenizer
# from spacy.lang.en import English
from spacy.lang.zh import Chinese
import re
import networkx as nx
from sklearn import preprocessing
from sklearn.metrics import precision_score, recall_score, f1_score
import json

# Normalize a graph: convert every element of each triple to a string and
# strip leading/trailing whitespace (case is left unchanged).
def modify_graph(original_graph):
    modified_graph = []
    for x in original_graph:
        modified_graph.append([str(t).strip() for t in x])
    return modified_graph

# Triple Match F1:
# 1. Preprocess: normalize formats so all data types are consistent.
# 2. Use MultiLabelBinarizer to turn each graph's triples into a binary
#    matrix (each distinct triple is treated as one label).
# 3. Compute micro-averaged precision, recall and F1.
def get_triple_match_f1(gold_graphs, pred_graphs):
    new_gold_graphs = [modify_graph(graph) for graph in gold_graphs]
    new_pred_graphs = [modify_graph(graph) for graph in pred_graphs]
    # str() turns each triple (a list) into a single hashable label string.
    new_gold_graphs_list = [[str(triple) for triple in graph] for graph in new_gold_graphs]
    new_pred_graphs_list = [[str(triple) for triple in graph] for graph in new_pred_graphs]
    # First get all the classes by combining the triples in the pred_graphs
    # and gold_graphs.
    allclasses = new_pred_graphs_list + new_gold_graphs_list
    allclasses = [item for items in allclasses for item in items]
    allclasses = list(set(allclasses))
    lb = preprocessing.MultiLabelBinarizer(classes=allclasses)
    # Convert the triples into binary indicator matrices.
    mcbin = lb.fit_transform(new_pred_graphs_list)
    mrbin = lb.transform(new_gold_graphs_list)
    # Compute precision, recall and F1 (gold is y_true, pred is y_pred).
    precision = precision_score(mrbin, mcbin, average='micro')
    recall = recall_score(mrbin, mcbin, average='micro')
    f1 = f1_score(mrbin, mcbin, average='micro')
    print('Full triple scores')
    print('-----------------------------------------------------------------')
    print('Precision: ' + str(precision) + ' Recall: ' + str(recall) + '\nF1: ' + str(f1))
    return f1
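
# Example (sketch, made-up triples): with one gold triple and two predicted
# triples of which one matches, micro-P = 1/2, micro-R = 1/1, micro-F1 = 2/3:
#   get_triple_match_f1([[["张三", "出生于", "北京"]]],
#                       [[["张三", "出生于", "北京"], ["张三", "就职于", "A公司"]]])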

def get_triple_match_accuracy(pred_graph, gold_graph):
    """
    Triple match accuracy:
    directly compares predicted and gold triples for exact equality,
    ignoring graph structure. Returns the fraction of predicted triples
    that exactly match a gold triple.
    """
    pred = modify_graph(pred_graph)
    gold = modify_graph(gold_graph)
    if not pred:  # guard against division by zero on an empty prediction
        return 0.0
    matches = 0
    for x in pred:
        if x in gold:
            matches += 1
    acc = matches / len(pred)
    return acc
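
# Example (sketch): one of the two predicted triples appears in the gold
# list, so the accuracy is 0.5 (note: predictions come first):
#   get_triple_match_accuracy([["a", "r", "b"], ["a", "r", "c"]],
#                             [["a", "r", "b"]])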

def get_graph_match_accuracy(pred_graphs, gold_graphs):
    """
    Graph match accuracy:
    converts each list of triples into a directed graph, taking node and
    edge connectivity into account, and uses graph isomorphism to decide
    whether the two structures are identical. Returns the fraction of
    predicted graphs that are isomorphic to their gold graphs.
    """
    matches = 0
    for pred, gold in zip(pred_graphs, gold_graphs):
        g1 = nx.DiGraph()
        g2 = nx.DiGraph()
        for edge in gold:
            g1.add_node(str(edge[0]).strip(), label=str(edge[0]).strip())
            g1.add_node(str(edge[2]).strip(), label=str(edge[2]).strip())
            g1.add_edge(str(edge[0]).strip(), str(edge[2]).strip(), label=str(edge[1]).strip())
        for edge in pred:
            # Pad malformed triples with 'NULL' so the indexing below is safe.
            while len(edge) < 3:
                edge.append('NULL')
            g2.add_node(str(edge[0]).strip(), label=str(edge[0]).strip())
            g2.add_node(str(edge[2]).strip(), label=str(edge[2]).strip())
            g2.add_edge(str(edge[0]).strip(), str(edge[2]).strip(), label=str(edge[1]).strip())
        # Match on both node and edge labels, as the docstring describes.
        if nx.is_isomorphic(g1, g2, node_match=return_eq_node, edge_match=return_eq_edge):
            matches += 1
    acc = matches / len(pred_graphs)
    return acc
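
# Example (sketch): these two one-edge graphs have the same shape but
# different entity labels, so with label matching the pair does not count:
#   get_graph_match_accuracy([[["a", "r", "b"]]], [[["a", "r", "c"]]])  # -> 0.0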

def get_tokens(gold_edges, pred_edges):
    nlp = Chinese()  # use the Chinese tokenizer
    # Split edge strings on the triple separators and Chinese punctuation,
    # and strip bracket prefixes/suffixes.
    tokenizer = Tokenizer(nlp.vocab,
                          infix_finditer=re.compile(r'[;,。]').finditer,
                          prefix_search=re.compile(r'^[【(\[]').search,
                          suffix_search=re.compile(r'[】)\]]$').search)
    gold_tokens = []
    pred_tokens = []
    for i in range(len(gold_edges)):
        gold_tokens_edges = []
        pred_tokens_edges = []
        for sample in tokenizer.pipe(gold_edges[i]):
            gold_tokens_edges.append([j.text for j in sample])
        for sample in tokenizer.pipe(pred_edges[i]):
            pred_tokens_edges.append([j.text for j in sample])
        gold_tokens.append(gold_tokens_edges)
        pred_tokens.append(pred_tokens_edges)
    return gold_tokens, pred_tokens
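
# Example (sketch): with the custom infix rule, the tokenizer splits an edge
# string roughly at the listed punctuation, e.g.
#   "张三;出生于;北京" -> ["张三", ";", "出生于", ";", "北京"]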

# Serialize each triple of a graph into a single "subj;rel;obj" edge string.
def split_to_edges(graphs):
    processed_graphs = []
    for graph in graphs:
        processed_graphs.append([";".join(triple).strip() for triple in graph])
    return processed_graphs
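
# Example (sketch):
#   split_to_edges([[["张三", "出生于", "北京"]]]) -> [["张三;出生于;北京"]]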

def get_bert_score(all_gold_edges, all_pred_edges):
    references = []
    candidates = []
    ref_cand_index = {}
    # Score every (gold edge, pred edge) pair in one batched BERTScore call.
    for (gold_edges, pred_edges) in zip(all_gold_edges, all_pred_edges):
        for (i, gold_edge) in enumerate(gold_edges):
            for (j, pred_edge) in enumerate(pred_edges):
                references.append(gold_edge)
                candidates.append(pred_edge)
                ref_cand_index[(gold_edge, pred_edge)] = len(references) - 1
    # lang='zh' since the edges are Chinese, matching the tokenizer used
    # elsewhere in this pipeline.
    _, _, bs_F1 = score_bert(cands=candidates, refs=references, lang='zh', idf=False)
    print("Computed bert scores for all pairs")
    precisions, recalls, f1s = [], [], []
    for (gold_edges, pred_edges) in zip(all_gold_edges, all_pred_edges):
        score_matrix = np.zeros((len(gold_edges), len(pred_edges)))
        for (i, gold_edge) in enumerate(gold_edges):
            for (j, pred_edge) in enumerate(pred_edges):
                score_matrix[i][j] = bs_F1[ref_cand_index[(gold_edge, pred_edge)]]
        # One-to-one alignment of gold and pred edges (Hungarian algorithm).
        row_ind, col_ind = linear_sum_assignment(score_matrix, maximize=True)
        sample_precision = score_matrix[row_ind, col_ind].sum() / len(pred_edges)
        sample_recall = score_matrix[row_ind, col_ind].sum() / len(gold_edges)
        precisions.append(sample_precision)
        recalls.append(sample_recall)
        if sample_precision + sample_recall > 0:  # guard against 0/0
            f1s.append(2 * sample_precision * sample_recall / (sample_precision + sample_recall))
        else:
            f1s.append(0.0)
    return np.array(precisions), np.array(recalls), np.array(f1s)
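
# Worked example (sketch): with gold edges {g1, g2} and a single pred edge
# {p1}, the assignment picks the best pair, say sim(g1, p1) = 0.9, so
# precision = 0.9 / 1 = 0.9, recall = 0.9 / 2 = 0.45, and F1 is their
# harmonic mean.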

# Note: these graph matching metrics are computed by treating each graph as
# a set of edges and each edge as a sentence.
def get_bleu_rouge(gold_tokens, pred_tokens, gold_sent, pred_sent):
    scorer_rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rouge3', 'rougeL'], use_stemmer=True)

    # Turn an alignment matrix of shape (num_pred, num_gold) into P/R/F1.
    def _scores(cost_matrix):
        row_ind, col_ind = linear_sum_assignment(cost_matrix, maximize=True)
        precision = cost_matrix[row_ind, col_ind].sum() / cost_matrix.shape[0]
        recall = cost_matrix[row_ind, col_ind].sum() / cost_matrix.shape[1]
        f1 = (2 * precision * recall) / (precision + recall) if precision + recall > 0 else 0
        return precision, recall, f1

    precisions_bleu = []
    recalls_bleu = []
    f1s_bleu = []
    precisions_rouge = []
    recalls_rouge = []
    f1s_rouge = []
    for graph_idx in range(len(gold_tokens)):
        score_bleu = np.zeros((len(pred_tokens[graph_idx]), len(gold_tokens[graph_idx])))
        score_rouge = np.zeros((len(pred_tokens[graph_idx]), len(gold_tokens[graph_idx])))
        for p_idx in range(len(pred_tokens[graph_idx])):
            for g_idx in range(len(gold_tokens[graph_idx])):
                score_bleu[p_idx, g_idx] = sentence_bleu(
                    [gold_tokens[graph_idx][g_idx]], pred_tokens[graph_idx][p_idx],
                    smoothing_function=SmoothingFunction().method1)
                # Use the ROUGE-2 F1 (fmeasure) rather than precision.
                score_rouge[p_idx, g_idx] = \
                    scorer_rouge.score(gold_sent[graph_idx][g_idx], pred_sent[graph_idx][p_idx])['rouge2'].fmeasure
        precision_bleu, recall_bleu, f1_bleu = _scores(score_bleu)
        precisions_bleu.append(precision_bleu)
        recalls_bleu.append(recall_bleu)
        f1s_bleu.append(f1_bleu)
        precision_rouge, recall_rouge, f1_rouge = _scores(score_rouge)
        precisions_rouge.append(precision_rouge)
        recalls_rouge.append(recall_rouge)
        f1s_rouge.append(f1_rouge)
    return np.array(precisions_rouge), np.array(recalls_rouge), np.array(f1s_rouge), np.array(
        precisions_bleu), np.array(recalls_bleu), np.array(f1s_bleu)
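
# Usage (sketch): the token lists come from get_tokens and the raw edge
# strings from split_to_edges, e.g.
#   gold_tok, pred_tok = get_tokens(gold_edges, pred_edges)
#   p_r, r_r, f_r, p_b, r_b, f_b = get_bleu_rouge(gold_tok, pred_tok,
#                                                 gold_edges, pred_edges)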

# Equality checks on the node/edge 'label' attributes, used for both the
# isomorphism test and graph edit distance.
def return_eq_node(node1, node2):
    return node1['label'] == node2['label']


def return_eq_edge(edge1, edge2):
    return edge1['label'] == edge2['label']

def get_ged(gold_graph, pred_graph=None):
    g1 = nx.DiGraph()
    g2 = nx.DiGraph()
    for edge in gold_graph:
        g1.add_node(str(edge[0]).strip(), label=str(edge[0]).strip())
        g1.add_node(str(edge[2]).strip(), label=str(edge[2]).strip())
        g1.add_edge(str(edge[0]).strip(), str(edge[2]).strip(), label=str(edge[1]).strip())
    # The upper bound is defined w.r.t. the graph for which GED is the worst:
    # all nodes and edges of the predicted graph are deleted and then all
    # nodes and edges of the gold graph are added. In ExplaGraphs, which by
    # construction allows at most 8 edges (a linear graph of 9 nodes), the
    # worst GED is gold_nodes + gold_edges + 8 + 9; the constant 30 used here
    # relaxes that bound to allow larger predicted graphs. A stricter upper
    # bound could be computed by also considering replacement operations,
    # but we ignore that for convenience.
    normalizing_constant = g1.number_of_nodes() + g1.number_of_edges() + 30
    if pred_graph is None:
        return 1
    for edge in pred_graph:
        # Pad malformed triples with 'NULL' so the indexing below is safe.
        while len(edge) < 3:
            edge.append('NULL')
        g2.add_node(str(edge[0]).strip(), label=str(edge[0]).strip())
        g2.add_node(str(edge[2]).strip(), label=str(edge[2]).strip())
        g2.add_edge(str(edge[0]).strip(), str(edge[2]).strip(), label=str(edge[1]).strip())
    ged = nx.graph_edit_distance(g1, g2, node_match=return_eq_node, edge_match=return_eq_edge)
    assert ged <= normalizing_constant
    return ged / normalizing_constant
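
# Example (sketch): identical gold and predicted graphs have GED 0, so the
# normalized distance is 0.0; a missing prediction scores the worst case, 1:
#   get_ged([["a", "r", "b"]], [["a", "r", "b"]])  # -> 0.0
#   get_ged([["a", "r", "b"]])                     # -> 1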

def clean_and_validate_data(gold_graphs, pred_graphs):
    """Clean and validate the data, dropping unusable gold/pred pairs."""
    valid_gold = []
    valid_pred = []
    for gold, pred in zip(gold_graphs, pred_graphs):
        # Keep the pair only if both sides are non-empty lists.
        if gold and pred and isinstance(gold, list) and isinstance(pred, list):
            valid_gold.append(gold)
            valid_pred.append(pred)
    return valid_gold, valid_pred

def evaluate_triples(gold_graphs, pred_graphs):
    print("Starting evaluation...")
    print("=" * 50)
    # Clean and validate the data.
    gold_graphs, pred_graphs = clean_and_validate_data(gold_graphs, pred_graphs)
    if not gold_graphs or not pred_graphs:
        print("Warning: no valid data to evaluate!")
        return None
    # 1. Triple Match F1
    try:
        triple_f1 = get_triple_match_f1(gold_graphs, pred_graphs)
        print(f"Triple match F1: {triple_f1:.4f}")
    except Exception as e:
        print(f"Error computing Triple Match F1: {str(e)}")
        triple_f1 = 0.0
    # 2. Graph Match Accuracy
    try:
        graph_acc = get_graph_match_accuracy(pred_graphs, gold_graphs)
        print(f"Graph match accuracy: {graph_acc:.4f}")
    except Exception as e:
        print(f"Error computing Graph Match Accuracy: {str(e)}")
        graph_acc = 0.0

    # Guard against division by zero.
    def safe_divide(a, b):
        return a / b if b != 0 else 0.0

    # 3. BERT Score over the serialized edges.
    gold_edges = split_to_edges(gold_graphs)
    pred_edges = split_to_edges(pred_graphs)
    if any(gold_edges) and any(pred_edges):
        precisions_BS, recalls_BS, f1s_BS = get_bert_score(gold_edges, pred_edges)
        print("BERT Score:")
        print(f"- Precision: {safe_divide(precisions_BS.sum(), len(precisions_BS)):.4f}")
        print(f"- Recall: {safe_divide(recalls_BS.sum(), len(recalls_BS)):.4f}")
        print(f"- F1: {safe_divide(f1s_BS.sum(), len(f1s_BS)):.4f}")
    # ... remaining evaluation code ...

def load_data(gold_path, pred_path):
    try:
        # Load the gold data.
        with open(gold_path, 'r', encoding='utf-8') as f:
            gold_data = json.load(f)
        # Load the predicted data.
        with open(pred_path, 'r', encoding='utf-8') as f:
            pred_data = json.load(f)
        # Validate the top-level format.
        if not isinstance(gold_data, list) or not isinstance(pred_data, list):
            print("Warning: malformed data, expected a list at the top level")
            return [], []
        # Extract and validate the triple lists, pairing gold and predicted
        # records by their 'text' field.
        gold_graphs = []
        pred_graphs = []
        for pred_item in pred_data:
            if not isinstance(pred_item, dict) or 'text' not in pred_item or 'triple_list' not in pred_item:
                continue
            pred_text = pred_item['text']
            for gold_item in gold_data:
                if gold_item.get('text') == pred_text:
                    if gold_item.get('triple_list') and pred_item.get('triple_list'):
                        gold_graphs.append(gold_item['triple_list'])
                        pred_graphs.append(pred_item['triple_list'])
                    break
        print(f"Loaded examples: Gold={len(gold_graphs)}, Pred={len(pred_graphs)}")
        print("Gold sample:", gold_graphs[0] if gold_graphs else "empty")
        print("Pred sample:", pred_graphs[0] if pred_graphs else "empty")
        return gold_graphs, pred_graphs
    except Exception as e:
        print(f"Error loading data: {str(e)}")
        return [], []
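

# Minimal driver (sketch). The file names below are hypothetical; point them
# at your own JSON files, each a list of
# {"text": ..., "triple_list": [[subj, rel, obj], ...]} records.
if __name__ == '__main__':
    gold_graphs, pred_graphs = load_data('gold.json', 'pred.json')  # hypothetical paths
    if gold_graphs and pred_graphs:
        evaluate_triples(gold_graphs, pred_graphs)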