import json
from collections import defaultdict

import faiss
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel


class EntityLevelRetriever:
    def __init__(self, model_name='bert-base-chinese'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()  # disable dropout so embeddings are deterministic
        self.index = faiss.IndexFlatL2(768)  # flat L2 index over 768-dim BERT hidden states
        self.entity_db = []
        self.metadata = []

    def _get_entity_span(self, text, entity):
        """Locate the entity in the text by exact string match."""
        start = text.find(entity)
        if start == -1:
            return None
        return (start, start + len(entity))

    def _generate_entity_embedding(self, text, entity):
        """Generate a contextual, entity-level embedding.

        Core idea: run the full sentence through BERT, then average the
        hidden states of the tokens covering the entity span.
        """
        span = self._get_entity_span(text, entity)
        if span is None:
            return None
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Map character positions to token positions (requires a fast tokenizer).
        start_token = inputs.char_to_token(span[0])
        end_token = inputs.char_to_token(span[1] - 1)
        # Compare against None explicitly: token index 0 is falsy but valid.
        if start_token is None or end_token is None:
            return None
        # Average the token embeddings that make up the entity span.
        entity_embedding = outputs.last_hidden_state[0, start_token:end_token + 1].mean(dim=0).numpy()
        return entity_embedding.astype('float32')

    def build_index(self, train_path):
        """Build the entity index.

        Key design: index the head and tail entity of every triple
        separately, storing the relation label and the source context
        alongside each embedding.
        """
        with open(train_path, 'r', encoding='utf-8') as f:
            dataset = json.load(f)
        # Only process items 500-1000 (a slice kept small for demonstration).
        dataset = dataset[500:1000]
        for item in dataset:
            text = item['text']
            for triple in item['triple_list']:
                # Index both the head and the tail entity.
                for entity in [triple[0], triple[2]]:
                    embedding = self._generate_entity_embedding(text, entity)
                    if embedding is not None:
                        self.entity_db.append(embedding)
                        self.metadata.append({
                            'entity': entity,
                            'type': triple[1],  # relation label from the triple
                            'context': text
                        })
        print(f"Entity count check - vectors: {len(self.entity_db)}, metadata entries: {len(self.metadata)}")
        if self.entity_db:  # guard against an empty index build
            self.index.add(np.array(self.entity_db, dtype='float32'))
        print(f"Index dimension: {self.index.d}, stored vectors: {self.index.ntotal}")

    def search_texts(self, test_path, top_k=3, score_mode='weighted'):
        """Text-level retrieval via entity aggregation.

        :param score_mode: scoring mode, either 'simple' (plain hit counting)
                           or 'weighted' (distance-weighted voting)
        """
        # Entity-level retrieval is lifted to text-level retrieval by:
        # 1. running a nearest-neighbour search for every entity in the query text,
        # 2. aggregating the per-entity matches up to their source contexts,
        # 3. scoring each candidate context with a (weighted) voting scheme.
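        # Illustrative arithmetic for the 'weighted' mode (made-up distances,
        # not taken from any real run): a neighbour at L2 distance 0.0 adds
        # 1 / (1 + 0.0) = 1.00 to its context's score, one at distance 1.0
        # adds 0.50, and one at distance 4.0 adds only 0.20, so close matches
        # dominate a context's aggregate before hit-count normalization.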
        with open(test_path, 'r', encoding='utf-8') as f:
            test_data = json.load(f)
        results = []
        for item in test_data:
            text = item['text']
            context_scores = defaultdict(float)
            context_hits = defaultdict(int)
            # Stage 1: collect the matched contexts for every query entity.
            for triple in item['triple_list']:
                for entity in [triple[0], triple[2]]:
                    embedding = self._generate_entity_embedding(text, entity)
                    if embedding is None:
                        continue
                    distances, indices = self.index.search(np.array([embedding]), top_k)
                    for j in range(top_k):
                        idx = indices[0][j]
                        if 0 <= idx < len(self.metadata):
                            ctx_info = self.metadata[idx]
                            distance = distances[0][j]
                            # Two scoring modes.
                            if score_mode == 'simple':
                                context_scores[ctx_info['context']] += 1
                            elif score_mode == 'weighted':
                                context_scores[ctx_info['context']] += 1 / (1 + distance)
                            context_hits[ctx_info['context']] += 1
            # Stage 2: normalize the scores.
            scored_contexts = []
            for ctx, score in context_scores.items():
                # Normalize by hit count in 'weighted' mode only; in 'simple'
                # mode the raw count is kept, since dividing a pure count by
                # itself would flatten every matched context to a score of 1.0.
                if score_mode == 'weighted' and context_hits[ctx] > 0:
                    normalized_score = score / context_hits[ctx]
                else:
                    normalized_score = score
                scored_contexts.append((ctx, normalized_score))
            # Sort by score and keep the top_k contexts.
            scored_contexts.sort(key=lambda x: x[1], reverse=True)
            final_results = [{'context': ctx, 'score': float(score)}
                             for ctx, score in scored_contexts[:top_k]]
            results.append({
                'query_text': text,
                'matched_texts': final_results,
                'total_hits': sum(context_hits.values())
            })
        return results


# Usage example
if __name__ == "__main__":
    # Initialize the retrieval system.
    retriever = EntityLevelRetriever()

    # Build the training index (roughly 2-5 minutes, depending on data size).
    print("Building training index...")
    retriever.build_index('./data/train_triples.json')

    # Run text-level retrieval over the test set.
    print("\nSearching similar entities...")
    text_results = retriever.search_texts('./data/GT_500.json', top_k=3)

    # Save the results.
    with open('./data/text_retrieval_results.json', 'w', encoding='utf-8') as f:
        json.dump(text_results, f, ensure_ascii=False, indent=2)
    print("Results saved to ./data/text_retrieval_results.json")
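# ---------------------------------------------------------------------------
# Assumed input schema (hypothetical sample; the real ./data files are not
# shown in this script). Both train_triples.json and GT_500.json are expected
# to be JSON arrays of records pairing a sentence with its triples, e.g.:
#
#   [
#     {
#       "text": "乔布斯创立了苹果公司。",
#       "triple_list": [["乔布斯", "创始人", "苹果公司"]]
#     }
#   ]
#
# Each triple is read as (head entity, relation, tail entity), which is why
# build_index and search_texts take triple[0] and triple[2] as the entities
# and store triple[1] as the relation label.
# ---------------------------------------------------------------------------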