Spaces:
Runtime error
Runtime error
| import json | |
| import numpy as np | |
| import faiss | |
| from transformers import AutoTokenizer, AutoModel | |
| import torch | |
| from collections import defaultdict | |
| class EntityLevelRetriever: | |
| def __init__(self, model_name='bert-base-chinese'): | |
| self.tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| self.model = AutoModel.from_pretrained(model_name) | |
| self.index = faiss.IndexFlatL2(768) # L2距离更适合BERT嵌入 | |
| self.entity_db = [] | |
| self.metadata = [] | |
| def _get_entity_span(self, text, entity): | |
| """通过精确匹配获取实体在文本中的位置""" | |
| start = text.find(entity) | |
| if start == -1: | |
| return None | |
| return (start, start + len(entity)) | |
| def _generate_entity_embedding(self, text, entity): | |
| """生成实体级上下文嵌入""" | |
| # 通过BERT模型获取实体在文本中的上下文表示 | |
| # 核心实现:提取实体对应token的嵌入并平均 | |
| span = self._get_entity_span(text, entity) | |
| if not span: | |
| return None | |
| inputs = self.tokenizer(text, return_tensors='pt', truncation=True) | |
| with torch.no_grad(): | |
| outputs = self.model(**inputs) | |
| # 将字符位置转换为token位置 | |
| char_to_token = lambda x: inputs.char_to_token(x) | |
| start_token = char_to_token(span[0]) | |
| end_token = char_to_token(span[1]-1) | |
| if not start_token or not end_token: | |
| return None | |
| # 提取实体对应的token嵌入并平均 | |
| entity_embedding = outputs.last_hidden_state[0, start_token:end_token+1].mean(dim=0).numpy() | |
| return entity_embedding.astype('float32') | |
| def build_index(self, train_path): | |
| """构建实体索引""" | |
| # 关键设计:为每个三元组中的头实体和尾实体分别建立索引 | |
| # 存储实体嵌入时同时保存关系类型和上下文信息 | |
| with open(train_path, 'r', encoding='utf-8') as f: | |
| dataset = json.load(f) | |
| # 仅处理500-1000索引的数据(演示用切片操作) | |
| dataset = dataset[500:1000] | |
| for item in dataset: | |
| text = item['text'] | |
| for triple in item['triple_list']: | |
| # 处理头实体和尾实体 | |
| for entity in [triple[0], triple[2]]: | |
| embedding = self._generate_entity_embedding(text, entity) | |
| if embedding is not None: | |
| self.entity_db.append(embedding) | |
| self.metadata.append({ | |
| 'entity': entity, | |
| 'type': triple[1], # 保存关系类型 | |
| 'context': text | |
| }) | |
| print(f"实体数量检查 - 向量数: {len(self.entity_db)}, 元数据数: {len(self.metadata)}") | |
| self.index.add(np.array(self.entity_db)) | |
| print(f"索引维度: {self.index.d}, 存储数量: {self.index.ntotal}") | |
| def search_texts(self, test_path, top_k=3, score_mode='weighted'): | |
| """ | |
| 基于实体聚合的文本级检索 | |
| :param score_mode: 评分模式,可选'simple'(简单累加)/'weighted'(带距离权重) | |
| """ | |
| # 通过以下方式实现实体级到文本级的检索转换 | |
| # 1. 对查询文本中的每个实体进行相似搜索 | |
| # 2. 聚合多个实体的匹配结果到上下文层面 | |
| # 3. 通过加权评分机制综合判断文本相似度 | |
| with open(test_path, 'r', encoding='utf-8') as f: | |
| test_data = json.load(f) | |
| results = [] | |
| for item in test_data: | |
| text = item['text'] | |
| context_scores = defaultdict(float) | |
| context_hits = defaultdict(int) | |
| # 第一阶段:收集所有实体的匹配上下文 | |
| for triple in item['triple_list']: | |
| for entity in [triple[0], triple[2]]: | |
| embedding = self._generate_entity_embedding(text, entity) | |
| if embedding is None: | |
| continue | |
| distances, indices = self.index.search(np.array([embedding]), top_k) | |
| for j in range(top_k): | |
| idx = indices[0][j] | |
| if 0 <= idx < len(self.metadata): | |
| ctx_info = self.metadata[idx] | |
| distance = distances[0][j] | |
| # 两种评分模式 | |
| if score_mode == 'simple': | |
| context_scores[ctx_info['context']] += 1 | |
| elif score_mode == 'weighted': | |
| context_scores[ctx_info['context']] += 1 / (1 + distance) | |
| context_hits[ctx_info['context']] += 1 | |
| # 第二阶段:结果归一化处理 | |
| scored_contexts = [] | |
| for ctx, score in context_scores.items(): | |
| # 根据命中次数进行归一化 | |
| normalized_score = score / context_hits[ctx] if context_hits[ctx] > 0 else 0 | |
| scored_contexts.append((ctx, normalized_score)) | |
| # 按分数排序取前top_k | |
| scored_contexts.sort(key=lambda x: x[1], reverse=True) | |
| final_results = [{'context': ctx, 'score': float(score)} | |
| for ctx, score in scored_contexts[:top_k]] | |
| results.append({ | |
| 'query_text': text, | |
| 'matched_texts': final_results, | |
| 'total_hits': sum(context_hits.values()) | |
| }) | |
| return results | |
| # 使用示例 | |
| if __name__ == "__main__": | |
| # 初始化检索系统 | |
| retriever = EntityLevelRetriever() | |
| # 构建训练索引(约需2-5分钟,取决于数据量) | |
| print("Building training index...") | |
| retriever.build_index('./data/train_triples.json') | |
| # 执行测试检索 | |
| print("\nSearching similar entities...") | |
| # 执行改进后的检索 | |
| text_results = retriever.search_texts('./data/GT_500.json', top_k=3) | |
| # 保存结果 | |
| with open('./data/text_retrieval_results.json', 'w', encoding='utf-8') as f: | |
| json.dump(text_results, f, ensure_ascii=False, indent=2) | |
| print("text_retrieval_results.json") | |