import json
import numpy as np
import faiss
from transformers import AutoTokenizer, AutoModel
import torch
from collections import defaultdict


class EntityLevelRetriever:
    def __init__(self, model_name='bert-base-chinese'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.index = faiss.IndexFlatL2(768)  # L2 distance works well for BERT embeddings
        self.entity_db = []
        self.metadata = []

    def _get_entity_span(self, text, entity):
        """通过精确匹配获取实体在文本中的位置"""
        start = text.find(entity)
        if start == -1:
            return None
        return (start, start + len(entity))

    def _generate_entity_embedding(self, text, entity):
        """生成实体级上下文嵌入"""
        # 通过BERT模型获取实体在文本中的上下文表示
        # 核心实现:提取实体对应token的嵌入并平均
        span = self._get_entity_span(text, entity)
        if not span:
            return None

        inputs = self.tokenizer(text, return_tensors='pt', truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Map character offsets to wordpiece token positions (requires a fast tokenizer)
        start_token = inputs.char_to_token(span[0])
        end_token = inputs.char_to_token(span[1] - 1)

        # char_to_token returns None for characters dropped by truncation
        if start_token is None or end_token is None:
            return None

        # Average the token embeddings over the entity span
        entity_embedding = outputs.last_hidden_state[0, start_token:end_token+1].mean(dim=0).numpy()
        return entity_embedding.astype('float32')
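
    # Quick illustration (an assumption added for clarity, not part of the
    # original code) of the char_to_token mapping used above: with a fast
    # tokenizer, BatchEncoding.char_to_token maps a character offset to its
    # wordpiece index and returns None when that character was dropped by
    # truncation, e.g.:
    #   enc = self.tokenizer("张三在北京工作", return_tensors='pt')
    #   enc.char_to_token(0)  # index of the token covering "张" (typically 1, after [CLS])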

    def build_index(self, train_path):
        """构建实体索引"""
        # 关键设计:为每个三元组中的头实体和尾实体分别建立索引
        # 存储实体嵌入时同时保存关系类型和上下文信息
        with open(train_path, 'r', encoding='utf-8') as f:
            dataset = json.load(f)
        # Only use items 500-1000 of the dataset (slice kept small for demo purposes)
        dataset = dataset[500:1000]
        for item in dataset:
            text = item['text']
            for triple in item['triple_list']:
                # Handle both the head entity and the tail entity
                for entity in [triple[0], triple[2]]:
                    embedding = self._generate_entity_embedding(text, entity)
                    if embedding is not None:
                        self.entity_db.append(embedding)
                        self.metadata.append({
                            'entity': entity,
                            'type': triple[1],  # store the relation type
                            'context': text
                        })

        print(f"实体数量检查 - 向量数: {len(self.entity_db)}, 元数据数: {len(self.metadata)}")
        self.index.add(np.array(self.entity_db))
        print(f"索引维度: {self.index.d}, 存储数量: {self.index.ntotal}")

    def search_texts(self, test_path, top_k=3, score_mode='weighted'):
        """

        基于实体聚合的文本级检索

        :param score_mode: 评分模式,可选'simple'(简单累加)/'weighted'(带距离权重)

        """
        # 通过以下方式实现实体级到文本级的检索转换
        # 1. 对查询文本中的每个实体进行相似搜索
        # 2. 聚合多个实体的匹配结果到上下文层面
        # 3. 通过加权评分机制综合判断文本相似度
        with open(test_path, 'r', encoding='utf-8') as f:
            test_data = json.load(f)

        results = []
        for item in test_data:
            text = item['text']
            context_scores = defaultdict(float)
            context_hits = defaultdict(int)

            # Stage 1: collect the matching contexts of every entity
            for triple in item['triple_list']:
                for entity in [triple[0], triple[2]]:
                    embedding = self._generate_entity_embedding(text, entity)
                    if embedding is None:
                        continue

                    distances, indices = self.index.search(np.array([embedding]), top_k)

                    for j in range(top_k):
                        idx = indices[0][j]
                        if 0 <= idx < len(self.metadata):
                            ctx_info = self.metadata[idx]
                            distance = distances[0][j]

                            # Two scoring modes: plain count vs. distance-weighted
                            if score_mode == 'simple':
                                context_scores[ctx_info['context']] += 1
                            elif score_mode == 'weighted':
                                context_scores[ctx_info['context']] += 1 / (1 + distance)

                            context_hits[ctx_info['context']] += 1

            # Stage 2: normalize the aggregated scores
            scored_contexts = []
            for ctx, score in context_scores.items():
                # Normalize by hit count (note: with score_mode='simple' this yields 1.0 for every context)
                normalized_score = score / context_hits[ctx] if context_hits[ctx] > 0 else 0
                scored_contexts.append((ctx, normalized_score))

            # Sort by score and keep the top_k contexts
            scored_contexts.sort(key=lambda x: x[1], reverse=True)
            final_results = [{'context': ctx, 'score': float(score)}
                           for ctx, score in scored_contexts[:top_k]]

            results.append({
                'query_text': text,
                'matched_texts': final_results,
                'total_hits': sum(context_hits.values())
            })

        return results
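
    # Worked example of the weighted scoring above (illustrative numbers only):
    # a hit at distance 0.0 adds 1 / (1 + 0.0) = 1.0 to its context's score and
    # a hit at distance 3.0 adds 1 / (1 + 3.0) = 0.25; each context's total is
    # then divided by its hit count in the normalization stage. Note that
    # faiss.IndexFlatL2 reports squared L2 distances, so the weight decays with
    # the squared distance.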


# Usage example
if __name__ == "__main__":
    # Initialize the retrieval system
    retriever = EntityLevelRetriever()

    # Build the training index (roughly 2-5 minutes, depending on data size)
    print("Building training index...")
    retriever.build_index('./data/train_triples.json')

    # Run text-level retrieval on the test set
    print("\nSearching for similar texts...")
    text_results = retriever.search_texts('./data/GT_500.json', top_k=3)

    # Save the results to JSON
    with open('./data/text_retrieval_results.json', 'w', encoding='utf-8') as f:
        json.dump(text_results, f, ensure_ascii=False, indent=2)

    print("text_retrieval_results.json")