Spaces:
Runtime error
Runtime error
| import json | |
| import numpy as np | |
| from transformers import AutoTokenizer, AutoModel | |
| from sklearn.neighbors import NearestNeighbors | |
| from tqdm import tqdm | |
# Triple retrieval system: embeds texts with a transformer encoder and, for each
# test text, finds the most similar training texts with a cosine k-NN index.
class TripleRetrievalSystem:
    """Embedding-based nearest-neighbour retrieval over the ``'text'`` field of JSON records.

    Typical use: call :meth:`load_train_data` (once, or several times to add more
    data), then :meth:`retrieve_similar`.
    """

    def __init__(self, model_name='bert-base-uncased'):
        """Load the tokenizer and encoder.

        Args:
            model_name: Model identifier or local path passed to
                ``AutoTokenizer.from_pretrained`` and ``AutoModel.from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # Inference only: from_pretrained normally returns the model in eval mode
        # already; calling eval() makes the "no dropout" intent explicit.
        self.model.eval()
        # One mean-pooled vector per training text. This starts as a list;
        # load_train_data stacks it into a 2-D numpy array.
        self.train_embeddings = []
        self.train_texts = []
        # Cosine k-NN index, built by load_train_data.
        self.nbrs = None

    def _model_max_length(self, default=512):
        """Return the largest input length (in tokens) the encoder can accept.

        Uses the smaller of ``model.config.max_position_embeddings`` and the
        tokenizer's ``model_max_length``. Falls back to ``default`` when neither
        is available.
        """
        limits = []
        config = getattr(self.model, 'config', None)
        positions = getattr(config, 'max_position_embeddings', None)
        if isinstance(positions, int) and positions > 0:
            limits.append(positions)
        tokenizer_limit = getattr(self.tokenizer, 'model_max_length', None)
        # Tokenizers without a configured limit report a huge sentinel value; ignore it.
        if isinstance(tokenizer_limit, int) and 0 < tokenizer_limit < 1_000_000:
            limits.append(tokenizer_limit)
        return min(limits) if limits else default

    def _generate_embeddings(self, text, max_length=None):
        """Return the contextual (token-level) embeddings of ``text``.

        Args:
            text: Input string.
            max_length: Truncation length in tokens. ``None`` (the default) uses
                the encoder's own limit (512 for ``bert-base-uncased``).

        Returns:
            A numpy array holding the model's last hidden layer for this single
            input, with the batch dimension removed (one row per token).
        """
        import torch  # only needed here; return_tensors='pt' already requires PyTorch

        if max_length is None:
            max_length = self._model_max_length()
        # Never truncate beyond the model's position-embedding table. For BERT-style
        # models, longer sequences make the forward pass fail with an index error.
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True, max_length=max_length)
        # Feature extraction needs no gradients, so don't build an autograd graph.
        with torch.no_grad():
            outputs = self.model(**inputs)
        # [0] drops the batch dimension; .cpu() allows the model to live on a GPU.
        return outputs.last_hidden_state[0].detach().cpu().numpy()

    def load_train_data(self, train_path, start=500, end=1000):
        """Embed a slice of the training file and (re)build the k-NN index.

        Each text is embedded and mean-pooled over its tokens into one vector.
        Repeated calls add to the data already loaded, and the index is rebuilt
        over everything.

        Args:
            train_path: JSON file containing a list of objects with a ``'text'`` key.
            start: Start of the slice of that list to index (default 500).
            end: End of the slice (default 1000). The defaults keep the original
                demo slice ``[500:1000]``; pass ``start=0, end=None`` for all records.

        Raises:
            ValueError: If nothing would be indexed (empty slice and no earlier data).
        """
        with open(train_path, encoding='utf-8') as f:
            train_data = json.load(f)
        print("Processing training data...")
        train_data = train_data[start:end]
        if not train_data and not self.train_texts:
            raise ValueError(
                f"No training records in slice [{start}:{end}] of {train_path!r}; "
                "cannot build the nearest-neighbour index."
            )

        new_texts = []
        new_embeddings = []
        for item in tqdm(train_data):
            text = item['text']
            # Mean-pool the token vectors into a single text vector.
            new_embeddings.append(self._generate_embeddings(text).mean(axis=0))
            new_texts.append(text)

        # Commit only after every text has been embedded, so texts and embeddings
        # stay aligned. list() works both for the initial empty list and for the
        # 2-D array left by a previous call (ndarray has no .append).
        self.train_embeddings = np.array(list(self.train_embeddings) + new_embeddings)
        self.train_texts.extend(new_texts)
        # Cosine-distance k-NN index over all loaded embeddings.
        self.nbrs = NearestNeighbors(n_neighbors=1, metric='cosine').fit(self.train_embeddings)

    def retrieve_similar(self, test_path, output_path, top_k=1):
        """Find the most similar training texts for every test text and save them as JSON.

        Args:
            test_path: JSON file containing a list of objects with a ``'text'`` key.
            output_path: Destination JSON file. Each entry has the form
                ``{"test_text": ..., "relevant_train_texts": [...]}``.
            top_k: Number of neighbours to return per test text, capped at the
                number of training texts. The default of 1 reproduces the original output.

        Raises:
            RuntimeError: If :meth:`load_train_data` has not been called yet.
            ValueError: If ``top_k`` is smaller than 1.
        """
        if getattr(self, 'nbrs', None) is None:
            raise RuntimeError("load_train_data() must be called before retrieve_similar().")
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")
        k = min(top_k, len(self.train_texts))

        with open(test_path, encoding='utf-8') as f:
            test_data = json.load(f)
        results = []
        print("Processing test data...")
        for item in tqdm(test_data):
            test_text = item['text']
            test_embed = self._generate_embeddings(test_text).mean(axis=0)
            # Neighbour indices are ordered by increasing cosine distance.
            _, indices = self.nbrs.kneighbors([test_embed], n_neighbors=k)
            results.append({
                "test_text": test_text,
                "relevant_train_texts": [self.train_texts[i] for i in indices[0]],
            })
        # ensure_ascii=False keeps non-ASCII (e.g. Chinese) text readable in the file.
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
# Example run: index the training triples, then retrieve neighbours for the test set.
if __name__ == "__main__":
    retriever = TripleRetrievalSystem()
    retriever.load_train_data('./data/train_triples.json')
    # Embed each test text and write its closest training texts to disk.
    retriever.retrieve_similar(
        test_path='./data/test_triples.json',
        output_path='./data/output_results.json',
    )