Spaces:
Runtime error
Runtime error
| import json | |
| import numpy as np | |
| from transformers import AutoTokenizer, AutoModel | |
| from sklearn.neighbors import NearestNeighbors | |
| from tqdm import tqdm | |
| import pandas as pd | |
# Define the triple retrieval system class
class TripleRetrievalSystem:
    """Nearest-neighbour retrieval over (text, question, answer) triples.

    Every sample is embedded by mean-pooling the encoder's last hidden states
    over the string "<Text> <Question>". Training vectors are indexed with a
    cosine-distance ``NearestNeighbors`` model, which is then queried for each
    test row to find the most similar training triples.
    """

    # Used when neither the tokenizer nor the model config reports a limit.
    _FALLBACK_MAX_LENGTH = 512
    # Limits at or above this are treated as "no limit known" (tokenizers
    # may report a very large sentinel value in that case).
    _MAX_PLAUSIBLE_LENGTH = 1_000_000

    def __init__(self, model_name='bert-base-chinese', n_neighbors=3):
        """Load the tokenizer/encoder and prepare empty training storage.

        Args:
            model_name: Name or path passed to ``from_pretrained``.
            n_neighbors: Maximum number of matches returned per query.

        Raises:
            ValueError: If ``n_neighbors`` is smaller than 1.
        """
        if n_neighbors < 1:
            raise ValueError("n_neighbors must be at least 1")
        self.n_neighbors = n_neighbors

        # Initialize tokenizer and model (pre-trained BERT base by default)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()  # inference only: make sure dropout is disabled

        # Truncate to what the encoder can actually handle; asking for more
        # tokens than it has position embeddings fails on long inputs.
        self.max_length = self._resolve_max_length(
            getattr(self.tokenizer, 'model_max_length', None),
            getattr(getattr(self.model, 'config', None),
                    'max_position_embeddings', None),
        )

        # Initialize training data storage structure
        self.train_embeddings = []
        self.train_full_data = []
        self.nbrs = None  # built by load_train_data()

    @classmethod
    def _resolve_max_length(cls, *limits):
        """Return the smallest plausible limit, or the fallback if none is."""
        usable = [
            limit for limit in limits
            if isinstance(limit, int) and 0 < limit < cls._MAX_PLAUSIBLE_LENGTH
        ]
        return min(usable) if usable else cls._FALLBACK_MAX_LENGTH

    def _generate_embeddings(self, text):
        """Return the last-layer hidden states for ``text``.

        Args:
            text: String to encode; truncated to ``self.max_length`` tokens.

        Returns:
            numpy array of per-token vectors (batch dimension removed).
        """
        # Local import: the tensors requested below ('pt') are torch tensors,
        # so torch is necessarily installed wherever this method can run.
        import torch

        inputs = self.tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            max_length=self.max_length,
        )
        # No gradients are needed for retrieval; skip building the graph.
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Convert output to numpy array and remove batch dimension
        return outputs.last_hidden_state.detach().numpy()[0]

    def _embed_row(self, row):
        """Return one mean-pooled vector for a row's "<Text> <Question>"."""
        combined = f"{row['Text']} {row['Question']}"
        return self._generate_embeddings(combined).mean(axis=0)

    def load_train_data(self, train_path, sheet_name, sheet_suffix=' Train'):
        """Load training data from Excel and build the neighbour index.

        Args:
            train_path: Path of the Excel workbook.
            sheet_name: Sheet-name prefix (e.g. the question type).
            sheet_suffix: Appended to ``sheet_name`` to form the sheet that is
                read. Defaults to ' Train' so the index is built from the
                training sheet; pass ' Test' to index the test sheet instead.

        Raises:
            ValueError: If the selected sheet contains no rows.
        """
        # Read specified sheet from Excel file
        train_df = pd.read_excel(train_path, sheet_name=sheet_name + sheet_suffix)
        if train_df.empty:
            raise ValueError(
                f"Sheet '{sheet_name + sheet_suffix}' in {train_path} has no rows"
            )

        print("Processing training data...")
        embeddings = []
        full_data = []  # Store complete (text, question, answer)
        for _, row in tqdm(train_df.iterrows(), total=len(train_df)):
            embeddings.append(self._embed_row(row))
            # Store complete triplet data
            full_data.append({
                "text": row['Text'],
                "question": row['Question'],
                "answer": row['Answer'],
            })

        # Publish the new state only after every row was embedded.
        self.train_embeddings = np.array(embeddings)
        self.train_full_data = full_data
        # Asking for more neighbours than indexed samples makes kneighbors
        # raise, so clamp to the number of samples available.
        k = min(self.n_neighbors, len(full_data))
        self.nbrs = NearestNeighbors(n_neighbors=k, metric='cosine').fit(
            self.train_embeddings
        )

    def retrieve_similar(self, test_path, output_path, sheet_name,
                         sheet_suffix=' Test'):
        """Process test data, find similar training samples, write JSON.

        Args:
            test_path: Path of the Excel workbook.
            output_path: Destination of the UTF-8 JSON result file.
            sheet_name: Sheet-name prefix (e.g. the question type).
            sheet_suffix: Appended to ``sheet_name`` to form the sheet that is
                read. Defaults to ' Test'.

        Raises:
            RuntimeError: If no index has been built with load_train_data().
        """
        if getattr(self, 'nbrs', None) is None:
            raise RuntimeError(
                "No training index available; call load_train_data() first."
            )

        test_df = pd.read_excel(test_path, sheet_name=sheet_name + sheet_suffix)
        results = []
        print("Processing test data...")
        for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
            test_embed = self._embed_row(row)
            # Get the closest training samples for this query
            distances, indices = self.nbrs.kneighbors([test_embed])
            matched = []
            for dist, idx in zip(distances[0], indices[0]):
                # Get the complete information of the matched sample
                matched_data = self.train_full_data[idx]
                matched.append({
                    "context": matched_data['text'],
                    "question": matched_data['question'],
                    "answer": matched_data['answer'],
                    # cosine distance -> cosine similarity
                    "score": float(1 - dist),
                })
            results.append({
                "test_query": {
                    "text": row['Text'],
                    "question": row['Question'],
                },
                "matched_samples": matched,
            })

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
if __name__ == "__main__":
    # Every call below reads from this one workbook.
    workbook = './data/Task2/data.xlsx'
    # (sheet-name prefix, JSON output path) per question type, in run order.
    jobs = (
        ('Factoid', './data/Task2/knn_factoid.json'),
        ('Yes or No', './data/Task2/knn_yes_no1.json'),
    )

    # One system instance is reused; the index is rebuilt for each job.
    system = TripleRetrievalSystem()
    for question_type, result_path in jobs:
        system.load_train_data(workbook, question_type)
        system.retrieve_similar(workbook, result_path, question_type)