import json
import numpy as np
import faiss
from transformers import AutoTokenizer, AutoModel
import torch
from collections import defaultdict
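# Expected input format (inferred from the field access patterns in the code
# below; not a published schema):
# [
#   {
#     "text": "...",
#     "triple_list": [["head_entity", "relation", "tail_entity"], ...]
#   },
#   ...
# ]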
class EntityLevelRetriever:
    def __init__(self, model_name='bert-base-chinese'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.index = faiss.IndexFlatL2(768)  # L2 distance works well with BERT embeddings
        self.entity_db = []
        self.metadata = []
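    # Note: 768 matches the hidden size of bert-base-chinese. If a different
    # model is passed in, the dimension should instead be read from
    # self.model.config.hidden_size so the FAISS index stays consistent.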
    def _get_entity_span(self, text, entity):
        """Locate an entity in the text via exact string matching."""
        start = text.find(entity)
        if start == -1:
            return None
        return (start, start + len(entity))
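    # Note: str.find matches only the first occurrence, so repeated mentions
    # of the same entity in one text all resolve to its earliest span.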
    def _generate_entity_embedding(self, text, entity):
        """Generate a contextual entity-level embedding."""
        # Core idea: run the text through BERT and average the embeddings of
        # the tokens that make up the entity span.
        span = self._get_entity_span(text, entity)
        if span is None:
            return None
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Map character positions to token positions; char_to_token returns
        # None for special tokens or characters cut off by truncation.
        start_token = inputs.char_to_token(span[0])
        end_token = inputs.char_to_token(span[1] - 1)
        # Compare against None explicitly: token index 0 is a valid position
        if start_token is None or end_token is None:
            return None
        # Average the entity's token embeddings from the last hidden layer
        entity_embedding = outputs.last_hidden_state[0, start_token:end_token + 1].mean(dim=0).numpy()
        return entity_embedding.astype('float32')
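    # Illustrative example (hypothetical input): for text "苹果公司成立于1976年"
    # and entity "苹果公司", _get_entity_span returns (0, 4). bert-base-chinese
    # tokenizes Chinese character by character, so chars 0 and 3 map to token
    # positions 1 and 4 (position 0 is [CLS]), and tokens 1..4 are averaged.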
    def build_index(self, train_path):
        """Build the entity index."""
        # Key design: index the head entity and the tail entity of every
        # triple separately, storing the relation type and the source
        # context alongside each embedding.
        with open(train_path, 'r', encoding='utf-8') as f:
            dataset = json.load(f)
        # Only process items 500-1000 (slice used for the demo)
        dataset = dataset[500:1000]
        for item in dataset:
            text = item['text']
            for triple in item['triple_list']:
                # Embed both the head entity and the tail entity
                for entity in [triple[0], triple[2]]:
                    embedding = self._generate_entity_embedding(text, entity)
                    if embedding is not None:
                        self.entity_db.append(embedding)
                        self.metadata.append({
                            'entity': entity,
                            'type': triple[1],  # store the relation type
                            'context': text
                        })
        print(f"Entity count check - vectors: {len(self.entity_db)}, metadata entries: {len(self.metadata)}")
        # Guard against an empty batch: faiss rejects a zero-row array
        if self.entity_db:
            self.index.add(np.array(self.entity_db))
        print(f"Index dimension: {self.index.d}, stored vectors: {self.index.ntotal}")
    def search_texts(self, test_path, top_k=3, score_mode='weighted'):
        """
        Text-level retrieval via entity aggregation.
        :param score_mode: scoring mode, either 'simple' (plain hit count)
                           or 'weighted' (distance-weighted)
        """
        # Entity-level matches are lifted to text-level results as follows:
        # 1. Run a similarity search for every entity in the query text
        # 2. Aggregate the per-entity matches at the context (text) level
        # 3. Combine them with a weighted score to rank similar texts
        with open(test_path, 'r', encoding='utf-8') as f:
            test_data = json.load(f)
        results = []
        for item in test_data:
            text = item['text']
            context_scores = defaultdict(float)
            context_hits = defaultdict(int)
            # Stage 1: collect the matched contexts for every entity
            for triple in item['triple_list']:
                for entity in [triple[0], triple[2]]:
                    embedding = self._generate_entity_embedding(text, entity)
                    if embedding is None:
                        continue
                    distances, indices = self.index.search(np.array([embedding]), top_k)
                    for j in range(top_k):
                        idx = indices[0][j]
                        if 0 <= idx < len(self.metadata):
                            ctx_info = self.metadata[idx]
                            distance = distances[0][j]
                            # Two scoring modes
                            if score_mode == 'simple':
                                context_scores[ctx_info['context']] += 1
                            elif score_mode == 'weighted':
                                context_scores[ctx_info['context']] += 1 / (1 + distance)
                            context_hits[ctx_info['context']] += 1
            # Stage 2: normalize the aggregated scores
            scored_contexts = []
            for ctx, score in context_scores.items():
                # Normalize by the number of hits per context
                normalized_score = score / context_hits[ctx] if context_hits[ctx] > 0 else 0
                scored_contexts.append((ctx, normalized_score))
            # Sort by score and keep the top_k contexts
            scored_contexts.sort(key=lambda x: x[1], reverse=True)
            final_results = [{'context': ctx, 'score': float(score)}
                             for ctx, score in scored_contexts[:top_k]]
            results.append({
                'query_text': text,
                'matched_texts': final_results,
                'total_hits': sum(context_hits.values())
            })
        return results
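    # Worked example of 'weighted' scoring (illustrative numbers): a context
    # hit twice with L2 distances 1.0 and 3.0 accumulates
    # 1/(1+1.0) + 1/(1+3.0) = 0.75, and normalization by its 2 hits yields
    # a final score of 0.375.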
# Usage example
if __name__ == "__main__":
    # Initialize the retrieval system
    retriever = EntityLevelRetriever()
    # Build the training index (roughly 2-5 minutes, depending on data size)
    print("Building training index...")
    retriever.build_index('./data/train_triples.json')
    # Run the text-level retrieval on the test set
    print("\nSearching similar entities...")
    text_results = retriever.search_texts('./data/GT_500.json', top_k=3)
    # Save the results
    with open('./data/text_retrieval_results.json', 'w', encoding='utf-8') as f:
        json.dump(text_results, f, ensure_ascii=False, indent=2)
    print("Results saved to ./data/text_retrieval_results.json")