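"""Entity-level retrieval: contextual BERT embeddings indexed with FAISS.

Builds a flat L2 index over per-entity token embeddings extracted from a
triple-annotated training set, then retrieves the nearest stored entities
for each entity mention in a test set.
"""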
import json
import numpy as np
import faiss
from transformers import AutoTokenizer, AutoModel
import torch
class EntityLevelRetriever:
    def __init__(self, model_name='bert-base-chinese'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()  # inference only; disables dropout
        self.index = faiss.IndexFlatL2(768)  # L2 distance suits BERT embeddings
        self.entity_db = []
        self.metadata = []
    def _get_entity_span(self, text, entity):
        """Locate the entity in the text by exact string match.

        Only the first occurrence is used if the entity appears more than once.
        """
        start = text.find(entity)
        if start == -1:
            return None
        return (start, start + len(entity))
    def _generate_entity_embedding(self, text, entity):
        """Generate a context-dependent embedding for one entity mention."""
        span = self._get_entity_span(text, entity)
        if span is None:
            return None
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Map character offsets to token indices (requires a fast tokenizer)
        start_token = inputs.char_to_token(span[0])
        end_token = inputs.char_to_token(span[1] - 1)
        # char_to_token returns None for positions lost to truncation; compare
        # against None explicitly so token index 0 is not treated as a miss
        if start_token is None or end_token is None:
            return None
        # Average the token embeddings covering the entity span
        entity_embedding = outputs.last_hidden_state[0, start_token:end_token + 1].mean(dim=0).numpy()
        return entity_embedding.astype('float32')
    def build_index(self, train_path):
        """Build the FAISS index over training-set entities."""
        with open(train_path, 'r', encoding='utf-8') as f:
            dataset = json.load(f)
        # Index only a 500-item slice of the training data
        dataset = dataset[500:1000]
        embeddings = []
        meta_info = []
        for idx, item in enumerate(dataset):
            if idx % 100 == 0:
                print(f"Indexing progress: {idx}/{len(dataset)}")
            text = item['text']
            for triple in item['triple_list']:
                # Embed both the subject (triple[0]) and object (triple[2])
                for entity in [triple[0], triple[2]]:
                    try:
                        embedding = self._generate_entity_embedding(text, entity)
                        if embedding is not None:
                            embeddings.append(embedding)
                            meta_info.append({
                                'entity': entity,
                                'type': triple[1],
                                'context': text
                            })
                    except Exception as e:
                        print(f"Error while processing entity {entity}: {e}")
        if embeddings:
            self.entity_db = embeddings
            self.metadata = meta_info
            self.index.add(np.array(embeddings))
            print(f"Index built - vectors: {len(self.entity_db)}, metadata entries: {len(self.metadata)}")
            print(f"Index dimension: {self.index.d}, stored count: {self.index.ntotal}")
        else:
            print("Warning: no valid entity embeddings were added to the index")
    def search_entities(self, test_path, top_k=3, batch_size=32):
        """Optimized (batched) entity retrieval over a test set."""
        with open(test_path, 'r', encoding='utf-8') as f:
            test_data = json.load(f)
        results = []
        for item_idx, item in enumerate(test_data):
            if item_idx % 10 == 0:
                print(f"Search progress: {item_idx}/{len(test_data)}")
            text = item['text']
            entity_results = {}
            batch_embeddings = []
            batch_entities = []
            for triple in item['triple_list']:
                for entity in [triple[0], triple[2]]:
                    embedding = self._generate_entity_embedding(text, entity)
                    if embedding is not None:
                        batch_embeddings.append(embedding)
                        batch_entities.append(entity)
                    if len(batch_embeddings) >= batch_size:
                        self._process_batch(batch_embeddings, batch_entities, entity_results, top_k)
                        batch_embeddings = []
                        batch_entities = []
            # Flush any remaining entities
            if batch_embeddings:
                self._process_batch(batch_embeddings, batch_entities, entity_results, top_k)
            results.append({
                'text': text,
                'entity_matches': entity_results
            })
        return results
    def _process_batch(self, embeddings, entities, entity_results, top_k):
        """Run one FAISS search over a batch of entity embeddings."""
        distances, indices = self.index.search(np.array(embeddings), top_k)
        for entity, dist, ind in zip(entities, distances, indices):
            neighbors = []
            for distance, index in zip(dist, ind):
                # FAISS pads with index -1 when fewer than top_k vectors exist
                if 0 <= index < len(self.metadata):
                    neighbors.append({
                        'entity': self.metadata[index]['entity'],
                        'relation': self.metadata[index]['type'],
                        'context': self.metadata[index]['context'],
                        'distance': float(distance)
                    })
            entity_results[entity] = neighbors
# Usage example
if __name__ == "__main__":
    # Initialize the retrieval system
    retriever = EntityLevelRetriever()
    # Build the training index (roughly 2-5 minutes, depending on data size)
    print("Building training index...")
    retriever.build_index('./data/train_triples.json')
    # Run retrieval on the test set
    print("\nSearching similar entities...")
    results = retriever.search_entities('./data/test_triples.json')
    # Save the results
    with open('./data/entity_search_results.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    # print("Search complete! Results saved to entity_search_results.json")
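# The code above assumes the train/test JSON files are lists of records, each
# with a 'text' field and a 'triple_list' of [subject, relation, object]
# triples -- inferred from the field accesses in build_index/search_entities.
# A minimal illustrative record (hypothetical values):
#
# [
#   {
#     "text": "乔布斯创立了苹果公司。",
#     "triple_list": [["乔布斯", "创始人", "苹果公司"]]
#   }
# ]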