import numpy as np
from bert_score import score as score_bert
from scipy.optimize import linear_sum_assignment
from sklearn import preprocessing
from sklearn.metrics import precision_score, recall_score, f1_score

def modify_graph(original_graph):
    """Return a normalised copy of a graph.

    Every element of every triple is converted with ``str()`` and has its
    surrounding whitespace stripped. The input graph is not modified.

    Args:
        original_graph: Iterable of triples (each triple is itself iterable).

    Returns:
        list[list[str]]: One list of cleaned strings per input triple.
    """
    return [
        [str(element).strip() for element in triple]
        for triple in original_graph
    ]

def get_triple_match_f1(gold_graphs, pred_graphs):
    """Compute micro-averaged precision, recall and F1 over exact triple matches.

    Each graph is normalised with ``modify_graph`` and every triple is then
    turned into a single string label (the ``str()`` of its list of cleaned
    elements), so two triples count as the same label only when all of their
    stripped elements are equal.

    Args:
        gold_graphs: Reference graphs, one per sample; each graph is an
            iterable of triples.
        pred_graphs: Predicted graphs, one per sample.

    Returns:
        tuple: ``(precision, recall, f1)`` from scikit-learn's scorers with
        ``average='micro'``.
    """
    def to_labels(graphs):
        # One string label per triple, grouped per sample.
        return [[str(triple) for triple in modify_graph(graph)] for graph in graphs]

    gold_labels = to_labels(gold_graphs)
    pred_labels = to_labels(pred_graphs)

    # Label vocabulary: every distinct triple seen on either side.
    label_space = list({
        label
        for sample in pred_labels + gold_labels
        for label in sample
    })

    # Turn each sample's label list into a binary indicator row.
    binarizer = preprocessing.MultiLabelBinarizer(classes=label_space)
    pred_matrix = binarizer.fit_transform(pred_labels)
    gold_matrix = binarizer.fit_transform(gold_labels)

    precision, recall, f1 = (
        scorer(gold_matrix, pred_matrix, average='micro')
        for scorer in (precision_score, recall_score, f1_score)
    )
    return precision, recall, f1

def split_to_edges(graphs):
    """Flatten every triple of every graph into a ``";"``-joined edge string.

    Also prints the share of graphs that contain no triples at all.

    Cleaning rules applied to each element of a triple:
      * an element equal to ``""`` is replaced by the placeholder ``"none"``;
      * any other element is converted with ``str()`` and stripped.
    A triple is dropped when any cleaned element is still empty (for example
    an element that consisted only of whitespace).

    Args:
        graphs: Sequence of graphs; each graph is a sequence of triples.

    Returns:
        list[list[str]]: One list of edge strings per input graph, in input
        order. An empty ``graphs`` yields an empty list.
    """
    total = len(graphs)
    empty_count = sum(1 for graph in graphs if len(graph) == 0)
    # Guard against ZeroDivisionError when no graphs are supplied at all.
    empty_rate = empty_count / total if total else 0.0
    print(f"空预测率: {empty_rate:.2%}")

    processed_graphs = []
    for graph in graphs:
        valid_triples = []
        for triple in graph:
            # Convert and clean each element.
            cleaned_triple = [
                str(elem).strip() if elem != "" else "none"
                for elem in triple
            ]
            # Keep the triple only if every cleaned element is non-empty.
            if all(cleaned_triple):
                valid_triples.append(";".join(cleaned_triple))
        processed_graphs.append(valid_triples)
    return processed_graphs

def get_bert_score(all_gold_edges, all_pred_edges):
    """Score predicted edges against gold edges with BERTScore and optimal matching.

    For every sample, each (gold edge, predicted edge) pair is scored with the
    BERTScore F1 (``lang='zh'``, ``idf=False``). The pairs are then matched
    one-to-one with ``linear_sum_assignment`` so that the total score is
    maximised, and the sum of matched scores is divided by the number of
    predicted edges (precision) and of gold edges (recall).

    A small epsilon is added to numerators and denominators, so a sample with
    no predicted edges gets a precision of 1.0 and a sample with no gold edges
    gets a recall of 1.0.

    Args:
        all_gold_edges (list): Gold edges; one collection of edge strings per
            sample.
        all_pred_edges (list): Predicted edges; one collection of edge strings
            per sample, aligned with ``all_gold_edges`` by position.

    Returns:
        tuple: ``(precisions, recalls, f1s)``, three numpy arrays with one
        entry per sample.
    """
    references = []   # gold side of each distinct pair
    candidates = []   # predicted side of each distinct pair
    pair_index = {}   # (gold_edge, pred_edge) -> position in the two lists

    # Collect each distinct pair once; repeated pairs share a single score.
    for gold_edges, pred_edges in zip(all_gold_edges, all_pred_edges):
        for gold_edge in gold_edges:
            for pred_edge in pred_edges:
                key = (gold_edge, pred_edge)
                if key not in pair_index:
                    pair_index[key] = len(references)
                    references.append(gold_edge)
                    candidates.append(pred_edge)

    # Only invoke BERTScore when there is something to score: with no pairs
    # the model would be handed empty input.
    bs_F1 = None
    if references:
        _, _, bs_F1 = score_bert(cands=candidates, refs=references, lang='zh', idf=False)
        print("已完成所有边对的BERTScore计算")

    epsilon = 1e-10  # avoids division by zero for empty samples
    precisions, recalls, f1s = [], [], []

    for gold_edges, pred_edges in zip(all_gold_edges, all_pred_edges):
        n_gold = len(gold_edges)
        n_pred = len(pred_edges)

        if n_gold > 0 and n_pred > 0:
            # Score matrix: rows are gold edges, columns are predicted edges.
            score_matrix = np.zeros((n_gold, n_pred))
            for i, gold_edge in enumerate(gold_edges):
                for j, pred_edge in enumerate(pred_edges):
                    score_matrix[i, j] = float(bs_F1[pair_index[(gold_edge, pred_edge)]])

            # Optimal one-to-one matching that maximises the total score.
            row_ind, col_ind = linear_sum_assignment(score_matrix, maximize=True)
            matched_total = score_matrix[row_ind, col_ind].sum()
        else:
            # Nothing can be matched when either side is empty.
            matched_total = 0.0

        sample_precision = (matched_total + epsilon) / (n_pred + epsilon)
        sample_recall = (matched_total + epsilon) / (n_gold + epsilon)
        f1 = 2 * sample_precision * sample_recall / (sample_precision + sample_recall + epsilon)

        precisions.append(sample_precision)
        recalls.append(sample_recall)
        f1s.append(f1)

    return np.array(precisions), np.array(recalls), np.array(f1s)

def clean_and_validate_data(gold_graphs, pred_graphs):
    """Keep only the sample pairs whose gold and predicted graphs are both non-empty lists.

    The two inputs are walked in lockstep; a pair is kept only when both
    members are ``list`` instances and both are non-empty. Pairs beyond the
    length of the shorter input are ignored by ``zip``.

    Args:
        gold_graphs: Iterable of gold graphs, one per sample.
        pred_graphs: Iterable of predicted graphs, one per sample.

    Returns:
        tuple[list, list]: ``(valid_gold, valid_pred)``, the kept gold and
        predicted graphs, still aligned by position.
    """
    valid_gold = []
    valid_pred = []

    for gold, pred in zip(gold_graphs, pred_graphs):
        # Check the type before the truthiness: some non-list objects (e.g. a
        # multi-element numpy array) raise when evaluated as a boolean, and
        # they should simply be filtered out here.
        if isinstance(gold, list) and isinstance(pred, list) and gold and pred:
            valid_gold.append(gold)
            valid_pred.append(pred)

    return valid_gold, valid_pred
