import numpy as np
from rouge_score import rouge_scorer
from bert_score import score as score_bert
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.bleu_score import SmoothingFunction
from scipy.optimize import linear_sum_assignment
from spacy.tokenizer import Tokenizer
# from spacy.lang.en import English
from spacy.lang.zh import Chinese
import re
import networkx as nx
from sklearn import preprocessing
from sklearn.metrics import precision_score, recall_score, f1_score
import json
# Normalize a graph's triples for comparison.
def modify_graph(original_graph):
    modified_graph = []
    for x in original_graph:
        # Convert every element to a string and strip surrounding whitespace
        # (no case folding is applied).
        modified_graph.append([str(t).strip() for t in x])
    return modified_graph
# Triple-match F1 over a batch of graphs.
def get_triple_match_f1(gold_graphs, pred_graphs):
    # 1. Preprocess: normalize formats so all values are consistent strings.
    # 2. Use MultiLabelBinarizer to turn each graph's triples into a binary matrix.
    # 3. Compute micro-averaged precision, recall and F1.
    new_gold_graphs = [modify_graph(graph) for graph in gold_graphs]
    new_pred_graphs = [modify_graph(graph) for graph in pred_graphs]
    # str() of each triple (a list) yields a hashable key such as "['h', 'r', 't']".
    new_gold_graphs_list = [[str(triple) for triple in graph] for graph in new_gold_graphs]
    new_pred_graphs_list = [[str(triple) for triple in graph] for graph in new_pred_graphs]
    # First collect all classes by combining the triples of pred_graphs and gold_graphs.
    allclasses = new_pred_graphs_list + new_gold_graphs_list
    allclasses = [item for items in allclasses for item in items]
    allclasses = list(set(allclasses))
    lb = preprocessing.MultiLabelBinarizer(classes=allclasses)
    # Binarize the triples; the class set is fixed, so fit once and transform both.
    mcbin = lb.fit_transform(new_pred_graphs_list)
    mrbin = lb.transform(new_gold_graphs_list)
    # Micro-averaged precision, recall and F1 (y_true first, y_pred second).
    precision = precision_score(mrbin, mcbin, average='micro')
    recall = recall_score(mrbin, mcbin, average='micro')
    f1 = f1_score(mrbin, mcbin, average='micro')
    print('Full triple scores')
    print('-----------------------------------------------------------------')
    print('Precision: ' + str(precision) + ' Recall: ' + str(recall) + '\nF1: ' + str(f1))
    return f1
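# A minimal usage sketch with hypothetical data. Each graph is a list of
# [head, relation, tail] triples; with one of two predicted triples matching
# the single gold triple, micro precision is 0.5, recall 1.0, F1 is ~0.667.
def _example_triple_match_f1():
    gold = [[['中国', '首都', '北京']]]
    pred = [[['中国', '首都', '北京'], ['中国', '位于', '亚洲']]]
    return get_triple_match_f1(gold, pred)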
def get_triple_match_accuracy(pred_graph, gold_graph):
    """
    Triple-match accuracy:
    compares predicted and gold triples for exact equality, ignoring graph structure.
    Returns the fraction of predicted triples that exactly match a gold triple.
    """
    pred = modify_graph(pred_graph)
    gold = modify_graph(gold_graph)
    matches = 0
    for x in pred:
        if x in gold:
            matches += 1
    acc = matches / len(pred) if pred else 0.0
    return acc
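# A small sketch (hypothetical triples): two of the three predictions appear
# in the gold set, so the accuracy is 2/3.
def _example_triple_match_accuracy():
    gold = [['a', 'r1', 'b'], ['b', 'r2', 'c'], ['c', 'r3', 'd']]
    pred = [['a', 'r1', 'b'], ['b', 'r2', 'c'], ['x', 'r9', 'y']]
    return get_triple_match_accuracy(pred, gold)  # 2/3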
def get_graph_match_accuracy(pred_graphs, gold_graphs):
    """
    Graph-match accuracy:
    converts each set of triples into a directed graph, taking node and edge
    connectivity into account, and uses graph isomorphism to decide whether the
    predicted and gold graphs are identical in structure and labels.
    Returns the fraction of predicted graphs isomorphic to their gold graphs.
    """
    matches = 0
    for pred, gold in zip(pred_graphs, gold_graphs):
        g1 = nx.DiGraph()
        g2 = nx.DiGraph()
        for edge in gold:
            g1.add_node(str(edge[0]).strip(), label=str(edge[0]).strip())
            g1.add_node(str(edge[2]).strip(), label=str(edge[2]).strip())
            g1.add_edge(str(edge[0]).strip(), str(edge[2]).strip(), label=str(edge[1]).strip())
        for edge in pred:
            # Pad malformed triples so edge[0]/edge[1]/edge[2] always exist.
            while len(edge) < 3:
                edge.append('NULL')
            g2.add_node(str(edge[0]).strip(), label=str(edge[0]).strip())
            g2.add_node(str(edge[2]).strip(), label=str(edge[2]).strip())
            g2.add_edge(str(edge[0]).strip(), str(edge[2]).strip(), label=str(edge[1]).strip())
        # Match node labels as well as edge labels, so two graphs only count as
        # a match when both the structure and all labels agree.
        if nx.is_isomorphic(g1, g2, node_match=lambda x, y: x == y,
                            edge_match=lambda x, y: x == y):
            matches += 1
    acc = matches / len(pred_graphs) if pred_graphs else 0.0
    return acc
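# A sketch (hypothetical graphs): both graphs contain the same labeled nodes
# and edges, so they are isomorphic and the accuracy is 1.0; reordering the
# triples does not affect the result.
def _example_graph_match_accuracy():
    gold = [[['a', 'r1', 'b'], ['b', 'r2', 'c']]]
    pred = [[['b', 'r2', 'c'], ['a', 'r1', 'b']]]
    return get_graph_match_accuracy(pred, gold)  # 1.0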
def get_tokens(gold_edges, pred_edges):
    nlp = Chinese()  # use the Chinese tokenizer
    # Split on the separators ; , 。 and strip bracket prefixes/suffixes.
    tokenizer = Tokenizer(nlp.vocab,
                          infix_finditer=re.compile(r'[;,。]').finditer,
                          prefix_search=re.compile(r'^[【(\[]').search,
                          suffix_search=re.compile(r'[】)\]]$').search)
    gold_tokens = []
    pred_tokens = []
    for i in range(len(gold_edges)):
        gold_tokens_edges = []
        pred_tokens_edges = []
        for sample in tokenizer.pipe(gold_edges[i]):
            gold_tokens_edges.append([j.text for j in sample])
        for sample in tokenizer.pipe(pred_edges[i]):
            pred_tokens_edges.append([j.text for j in sample])
        gold_tokens.append(gold_tokens_edges)
        pred_tokens.append(pred_tokens_edges)
    return gold_tokens, pred_tokens
def split_to_edges(graphs):
    # Flatten each triple into a single ";"-joined edge string.
    processed_graphs = []
    for graph in graphs:
        processed_graphs.append([";".join(str(t) for t in triple).strip() for triple in graph])
    return processed_graphs
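# A pipeline sketch (hypothetical data): graphs are first flattened into
# ";"-joined edge strings, which get_tokens then splits with the Chinese
# tokenizer so that BLEU can operate on token lists.
def _example_edge_tokenization():
    gold_graphs = [[['中国', '首都', '北京']]]
    pred_graphs = [[['中国', '位于', '亚洲']]]
    gold_edges = split_to_edges(gold_graphs)  # [['中国;首都;北京']]
    pred_edges = split_to_edges(pred_graphs)
    return get_tokens(gold_edges, pred_edges)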
def get_bert_score(all_gold_edges, all_pred_edges):
    references = []
    candidates = []
    ref_cand_index = {}
    # Score every (gold edge, pred edge) pair in one batched BERTScore call.
    for (gold_edges, pred_edges) in zip(all_gold_edges, all_pred_edges):
        for (i, gold_edge) in enumerate(gold_edges):
            for (j, pred_edge) in enumerate(pred_edges):
                references.append(gold_edge)
                candidates.append(pred_edge)
                ref_cand_index[(gold_edge, pred_edge)] = len(references) - 1
    # The edges here are Chinese text, so use the Chinese BERTScore model.
    _, _, bs_F1 = score_bert(cands=candidates, refs=references, lang='zh', idf=False)
    print("Computed bert scores for all pairs")
    precisions, recalls, f1s = [], [], []
    for (gold_edges, pred_edges) in zip(all_gold_edges, all_pred_edges):
        score_matrix = np.zeros((len(gold_edges), len(pred_edges)))
        for (i, gold_edge) in enumerate(gold_edges):
            for (j, pred_edge) in enumerate(pred_edges):
                score_matrix[i][j] = bs_F1[ref_cand_index[(gold_edge, pred_edge)]]
        # Optimal one-to-one alignment between gold and predicted edges.
        row_ind, col_ind = linear_sum_assignment(score_matrix, maximize=True)
        sample_precision = score_matrix[row_ind, col_ind].sum() / len(pred_edges)
        sample_recall = score_matrix[row_ind, col_ind].sum() / len(gold_edges)
        precisions.append(sample_precision)
        recalls.append(sample_recall)
        denom = sample_precision + sample_recall
        f1s.append(2 * sample_precision * sample_recall / denom if denom > 0 else 0.0)
    return np.array(precisions), np.array(recalls), np.array(f1s)
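# A usage sketch (hypothetical edges, already ";"-joined): the returned arrays
# hold one precision/recall/F1 per graph, aligned via the Hungarian assignment.
# Note that score_bert downloads a pretrained model on first use.
def _example_bert_score():
    gold_edges = [['中国;首都;北京']]
    pred_edges = [['中国;位于;亚洲', '中国;首都;北京']]
    precisions, recalls, f1s = get_bert_score(gold_edges, pred_edges)
    return f1s.mean()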
# Note: These graph matching metrics are computed by considering each graph as a set of edges and each edge as a
# sentence
def get_bleu_rouge(gold_tokens, pred_tokens, gold_sent, pred_sent):
    scorer_rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rouge3', 'rougeL'], use_stemmer=True)
    precisions_bleu = []
    recalls_bleu = []
    f1s_bleu = []
    precisions_rouge = []
    recalls_rouge = []
    f1s_rouge = []
    for graph_idx in range(len(gold_tokens)):
        # Rows are predicted edges, columns are gold edges.
        score_bleu = np.zeros((len(pred_tokens[graph_idx]), len(gold_tokens[graph_idx])))
        score_rouge = np.zeros((len(pred_tokens[graph_idx]), len(gold_tokens[graph_idx])))
        for p_idx in range(len(pred_tokens[graph_idx])):
            for g_idx in range(len(gold_tokens[graph_idx])):
                score_bleu[p_idx, g_idx] = sentence_bleu([gold_tokens[graph_idx][g_idx]],
                                                         pred_tokens[graph_idx][p_idx],
                                                         smoothing_function=SmoothingFunction().method1)
                # ROUGE is computed on the raw edge strings; use F1 rather than precision.
                score_rouge[p_idx, g_idx] = \
                    scorer_rouge.score(gold_sent[graph_idx][g_idx], pred_sent[graph_idx][p_idx])['rouge2'].fmeasure

        def _scores(cost_matrix):
            if cost_matrix.size == 0:
                return 0.0, 0.0, 0.0
            row_ind, col_ind = linear_sum_assignment(cost_matrix, maximize=True)
            precision = cost_matrix[row_ind, col_ind].sum() / cost_matrix.shape[0]
            recall = cost_matrix[row_ind, col_ind].sum() / cost_matrix.shape[1]
            f1 = (2 * precision * recall) / (precision + recall) if precision + recall > 0 else 0
            return precision, recall, f1

        precision_bleu, recall_bleu, f1_bleu = _scores(score_bleu)
        precisions_bleu.append(precision_bleu)
        recalls_bleu.append(recall_bleu)
        f1s_bleu.append(f1_bleu)
        precision_rouge, recall_rouge, f1_rouge = _scores(score_rouge)
        precisions_rouge.append(precision_rouge)
        recalls_rouge.append(recall_rouge)
        f1s_rouge.append(f1_rouge)
    return (np.array(precisions_rouge), np.array(recalls_rouge), np.array(f1s_rouge),
            np.array(precisions_bleu), np.array(recalls_bleu), np.array(f1s_bleu))
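# A sketch tying the helpers together (hypothetical data): BLEU is computed on
# the token lists from get_tokens, ROUGE on the raw edge strings, and both are
# aligned per graph by the Hungarian assignment inside get_bleu_rouge.
def _example_bleu_rouge():
    gold_graphs = [[['中国', '首都', '北京']]]
    pred_graphs = [[['中国', '位于', '亚洲']]]
    gold_sent = split_to_edges(gold_graphs)
    pred_sent = split_to_edges(pred_graphs)
    gold_tokens, pred_tokens = get_tokens(gold_sent, pred_sent)
    return get_bleu_rouge(gold_tokens, pred_tokens, gold_sent, pred_sent)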
# Node/edge equality predicates used by graph edit distance.
def return_eq_node(node1, node2):
    return node1['label'] == node2['label']

def return_eq_edge(edge1, edge2):
    return edge1['label'] == edge2['label']
def get_ged(gold_graph, pred_graph=None):
    if pred_graph is None:
        # No prediction: return the worst possible normalized distance.
        return 1
    g1 = nx.DiGraph()
    g2 = nx.DiGraph()
    for edge in gold_graph:
        g1.add_node(str(edge[0]).strip(), label=str(edge[0]).strip())
        g1.add_node(str(edge[2]).strip(), label=str(edge[2]).strip())
        g1.add_edge(str(edge[0]).strip(), str(edge[2]).strip(), label=str(edge[1]).strip())
    # The upper bound is defined w.r.t. the graph for which GED is worst: all nodes
    # and edges of the predicted graph are deleted, then all nodes and edges of the
    # gold graph are added. ExplaGraphs (which caps graphs at 8 edges and 9 nodes)
    # uses gold_nodes + gold_edges + 17; here the constant 30 bounds the assumed
    # worst-case size of the predicted graph. A stricter bound could account for
    # replacement operations, but that is ignored for convenience.
    normalizing_constant = g1.number_of_nodes() + g1.number_of_edges() + 30
    for edge in pred_graph:
        # Pad malformed triples so edge[0]/edge[1]/edge[2] always exist.
        while len(edge) < 3:
            edge.append('NULL')
        g2.add_node(str(edge[0]).strip(), label=str(edge[0]).strip())
        g2.add_node(str(edge[2]).strip(), label=str(edge[2]).strip())
        g2.add_edge(str(edge[0]).strip(), str(edge[2]).strip(), label=str(edge[1]).strip())
    ged = nx.graph_edit_distance(g1, g2, node_match=return_eq_node, edge_match=return_eq_edge)
    assert ged <= normalizing_constant
    return ged / normalizing_constant
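# A sketch (hypothetical graphs): identical graphs give a raw GED of 0, so the
# normalized distance is 0.0; passing pred_graph=None yields the worst score 1.
def _example_ged():
    gold = [['a', 'r1', 'b'], ['b', 'r2', 'c']]
    pred = [['a', 'r1', 'b'], ['b', 'r2', 'c']]
    return get_ged(gold, pred)  # 0.0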
def clean_and_validate_data(gold_graphs, pred_graphs):
    """Clean and validate the data, keeping only paired, non-empty graph lists."""
    valid_gold = []
    valid_pred = []
    for gold, pred in zip(gold_graphs, pred_graphs):
        # Keep a pair only if both sides are non-empty lists.
        if gold and pred and isinstance(gold, list) and isinstance(pred, list):
            valid_gold.append(gold)
            valid_pred.append(pred)
    return valid_gold, valid_pred
def evaluate_triples(gold_graphs, pred_graphs):
    print("Starting evaluation...")
    print("=" * 50)
    # Clean and validate the data.
    gold_graphs, pred_graphs = clean_and_validate_data(gold_graphs, pred_graphs)
    if not gold_graphs or not pred_graphs:
        print("Warning: no valid evaluation data!")
        return None
    # 1. Triple Match F1
    try:
        triple_f1 = get_triple_match_f1(gold_graphs, pred_graphs)
        print(f"Triple match F1: {triple_f1:.4f}")
    except Exception as e:
        print(f"Error computing Triple Match F1: {str(e)}")
        triple_f1 = 0.0
    # 2. Graph Match Accuracy
    try:
        graph_acc = get_graph_match_accuracy(pred_graphs, gold_graphs)
        print(f"Graph match accuracy: {graph_acc:.4f}")
    except Exception as e:
        print(f"Error computing Graph Match Accuracy: {str(e)}")
        graph_acc = 0.0

    # Guard against division by zero.
    def safe_divide(a, b):
        return a / b if b != 0 else 0.0

    # 3. BERT Score over the ";"-joined edge strings.
    gold_edges = split_to_edges(gold_graphs)
    pred_edges = split_to_edges(pred_graphs)
    if any(gold_edges) and any(pred_edges):
        precisions_BS, recalls_BS, f1s_BS = get_bert_score(gold_edges, pred_edges)
        print("BERT Score:")
        print(f"- Precision: {safe_divide(precisions_BS.sum(), len(precisions_BS)):.4f}")
        print(f"- Recall: {safe_divide(recalls_BS.sum(), len(recalls_BS)):.4f}")
        print(f"- F1: {safe_divide(f1s_BS.sum(), len(f1s_BS)):.4f}")
    # ... remaining evaluation code ...
def load_data(gold_path, pred_path):
    try:
        # Load the gold data.
        with open(gold_path, 'r', encoding='utf-8') as f:
            gold_data = json.load(f)
        # Load the predicted data.
        with open(pred_path, 'r', encoding='utf-8') as f:
            pred_data = json.load(f)
        # Validate the data format.
        if not isinstance(gold_data, list) or not isinstance(pred_data, list):
            print("Warning: invalid data format, expected a list")
            return [], []
        # Extract and validate the triple lists, pairing gold and pred by text.
        gold_graphs = []
        pred_graphs = []
        for pred_item in pred_data:
            if not isinstance(pred_item, dict) or 'text' not in pred_item or 'triple_list' not in pred_item:
                continue
            pred_text = pred_item['text']
            for gold_item in gold_data:
                if gold_item.get('text') == pred_text:
                    if gold_item.get('triple_list') and pred_item.get('triple_list'):
                        gold_graphs.append(gold_item['triple_list'])
                        pred_graphs.append(pred_item['triple_list'])
                    break
        print(f"Number of loaded samples: Gold={len(gold_graphs)}, Pred={len(pred_graphs)}")
        print("Gold sample:", gold_graphs[0] if gold_graphs else "empty")
        print("Pred sample:", pred_graphs[0] if pred_graphs else "empty")
        return gold_graphs, pred_graphs
    except Exception as e:
        print(f"Error loading data: {str(e)}")
        return [], []
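# A minimal driver sketch showing how the pieces fit together. The file paths
# below are hypothetical placeholders, not files shipped with this module.
if __name__ == "__main__":
    gold_graphs, pred_graphs = load_data('gold.json', 'pred.json')
    if gold_graphs and pred_graphs:
        evaluate_triples(gold_graphs, pred_graphs)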