import json

import numpy as np
import faiss
import torch
from transformers import AutoTokenizer, AutoModel


class EntityLevelRetriever:
    def __init__(self, model_name='bert-base-chinese'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.index = faiss.IndexFlatL2(768)  # L2 distance suits BERT embeddings; 768 = BERT-base hidden size
        self.entity_db = []
        self.metadata = []

    def _get_entity_span(self, text, entity):
        """Locate the entity in the text by exact string match."""
        start = text.find(entity)
        if start == -1:
            return None
        return (start, start + len(entity))

    def _generate_entity_embedding(self, text, entity):
        """Generate an entity-level contextual embedding."""
        span = self._get_entity_span(text, entity)
        if not span:
            return None
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Map character positions to token positions
        start_token = inputs.char_to_token(span[0])
        end_token = inputs.char_to_token(span[1] - 1)
        if start_token is None or end_token is None:
            return None
        # Average the token embeddings covering the entity span
        entity_embedding = outputs.last_hidden_state[0, start_token:end_token + 1].mean(dim=0).numpy()
        return entity_embedding.astype('float32')

    def build_index(self, train_path):
        """Build the entity index from the training triples.

        Each item is expected to look like
        {"text": "...", "triple_list": [[subject, relation, object], ...]}.
        """
        with open(train_path, 'r', encoding='utf-8') as f:
            dataset = json.load(f)
        dataset = dataset[500:1000]  # NOTE: only this slice of the training data is indexed

        embeddings = []
        meta_info = []
        for idx, item in enumerate(dataset):
            if idx % 100 == 0:
                print(f"Indexing progress: {idx}/{len(dataset)}")
            text = item['text']
            for triple in item['triple_list']:
                for entity in [triple[0], triple[2]]:
                    try:
                        embedding = self._generate_entity_embedding(text, entity)
                        if embedding is not None:
                            embeddings.append(embedding)
                            meta_info.append({
                                'entity': entity,
                                'type': triple[1],
                                'context': text
                            })
                    except Exception as e:
                        print(f"Error processing entity {entity}: {str(e)}")

        if embeddings:
            self.entity_db = embeddings
            self.metadata = meta_info
            self.index.add(np.array(embeddings))
            print(f"Index built - vectors: {len(self.entity_db)}, metadata entries: {len(self.metadata)}")
            print(f"Index dimension: {self.index.d}, stored vectors: {self.index.ntotal}")
        else:
            print("Warning: no valid entity embeddings were added to the index")

    def search_entities(self, test_path, top_k=3, batch_size=32):
        """Batched entity retrieval over the test set."""
        with open(test_path, 'r', encoding='utf-8') as f:
            test_data = json.load(f)

        results = []
        for item_idx, item in enumerate(test_data):
            if item_idx % 10 == 0:
                print(f"Search progress: {item_idx}/{len(test_data)}")
            text = item['text']
            entity_results = {}
            batch_embeddings = []
            batch_entities = []
            for triple in item['triple_list']:
                for entity in [triple[0], triple[2]]:
                    embedding = self._generate_entity_embedding(text, entity)
                    if embedding is not None:
                        batch_embeddings.append(embedding)
                        batch_entities.append(entity)
                        if len(batch_embeddings) >= batch_size:
                            self._process_batch(batch_embeddings, batch_entities, entity_results, top_k)
                            batch_embeddings = []
                            batch_entities = []
            # Process any remaining entities
            if batch_embeddings:
                self._process_batch(batch_embeddings, batch_entities, entity_results, top_k)
            results.append({
                'text': text,
                'entity_matches': entity_results
            })
        return results

    def _process_batch(self, embeddings, entities, entity_results, top_k):
        """Run a batched FAISS search and collect the neighbors for each entity."""
        distances, indices = self.index.search(np.array(embeddings), top_k)
        for entity, dist, ind in zip(entities, distances, indices):
            neighbors = []
            for distance, index in zip(dist, ind):
                if 0 <= index < len(self.metadata):
                    neighbors.append({
                        'entity': self.metadata[index]['entity'],
                        'relation': self.metadata[index]['type'],
                        'context': self.metadata[index]['context'],
                        'distance': float(distance)
                    })
            entity_results[entity] = neighbors


# Usage example
if __name__ == "__main__":
    # Initialize the retrieval system
    retriever = EntityLevelRetriever()

    # Build the training index (roughly 2-5 minutes, depending on data size)
    print("Building training index...")
    retriever.build_index('./data/train_triples.json')

    # Run retrieval over the test set
    print("\nSearching similar entities...")
    results = retriever.search_entities('./data/test_triples.json')

    # Save the results
    with open('./data/entity_search_results.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    print("Done! Results saved to entity_search_results.json")