'''
Entity-level retrieval using KNN algorithm for geological triplet extraction
'''
import json
from collections import defaultdict

import faiss
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel


class EntityLevelRetriever:
    def __init__(self, model_name='bert-base-chinese'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.index = faiss.IndexFlatL2(768)  # 768-d hidden size of bert-base-chinese
        self.entity_db = []
        self.metadata = []

    def _get_entity_span(self, text, entity):
        """Get character span of entity in text"""
        start = text.find(entity)
        if start == -1:
            return None
        return (start, start + len(entity))

    def _generate_entity_embedding(self, text, entity):
        """Generate entity-level contextual embeddings"""
        span = self._get_entity_span(text, entity)
        if span is None:
            return None

        inputs = self.tokenizer(text, return_tensors='pt', truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Convert character positions to token positions
        start_token = inputs.char_to_token(span[0])
        end_token = inputs.char_to_token(span[1] - 1)
        if start_token is None or end_token is None:
            return None

        # Extract and average token embeddings for the entity
        entity_embedding = outputs.last_hidden_state[0, start_token:end_token + 1].mean(dim=0).numpy()
        return entity_embedding.astype('float32')

    def build_index(self, train_path):
        """Build entity index from training data"""
        with open(train_path, 'r', encoding='utf-8') as f:
            dataset = json.load(f)
        # Index only a slice of the training set
        dataset = dataset[500:1000]

        for item in dataset:
            text = item['text']
            for triple in item['triple_list']:
                # Process head and tail entities
                for entity in [triple[0], triple[2]]:
                    embedding = self._generate_entity_embedding(text, entity)
                    if embedding is not None:
                        self.entity_db.append(embedding)
                        self.metadata.append({
                            'entity': entity,
                            'type': triple[1],  # Save relation type
                            'context': text
                        })

        # print(f"Entity count check - Vector count: {len(self.entity_db)}, Metadata count: {len(self.metadata)}")
        self.index.add(np.array(self.entity_db))
        # print(f"Index dimension: {self.index.d}, Stored count: {self.index.ntotal}")

    def search_texts(self, test_path, top_k=3, score_mode='weighted'):
        """Search for similar texts based on entity matching"""
        with open(test_path, 'r', encoding='utf-8') as f:
            test_data = json.load(f)

        results = []
        for item in test_data:
            text = item['text']
            context_scores = defaultdict(float)
            context_hits = defaultdict(int)

            # Extract entities from query text and generate embeddings
            for triple in item['triple_list']:
                for entity in [triple[0], triple[2]]:
                    embedding = self._generate_entity_embedding(text, entity)
                    if embedding is None:
                        continue

                    distances, indices = self.index.search(np.array([embedding]), top_k)
                    for j in range(top_k):
                        idx = indices[0][j]
                        if 0 <= idx < len(self.metadata):
                            ctx_info = self.metadata[idx]
                            distance = distances[0][j]

                            # Update context scores based on distance
                            if score_mode == 'simple':
                                context_scores[ctx_info['context']] += 1
                            elif score_mode == 'weighted':
                                context_scores[ctx_info['context']] += 1 / (1 + distance)
                            context_hits[ctx_info['context']] += 1

            # Results normalization
            scored_contexts = []
            for ctx, score in context_scores.items():
                # Normalize by hit count
                normalized_score = score / context_hits[ctx] if context_hits[ctx] > 0 else 0
                scored_contexts.append((ctx, normalized_score))

            # Sort by score and take top_k
            scored_contexts.sort(key=lambda x: x[1], reverse=True)
            final_results = [{'context': ctx, 'score': float(score)}
                             for ctx, score in scored_contexts[:top_k]]
            results.append({
                'query_text': text,
                'matched_texts': final_results,
                'total_hits': sum(context_hits.values())
            })

        return results


# Example usage
if __name__ == "__main__":
    # Initialize retriever
    retriever = EntityLevelRetriever()

    # Build training index
    print("Building training index...")
    retriever.build_index('./data/Task1/train_triples.json')

    # Execute test retrieval
    print("\nSearching similar entities...")
    text_results = retriever.search_texts('./data/Task1/GT_500.json', top_k=3)

    # Save results
    with open('./data/Task1/text_retrieval_results.json', 'w', encoding='utf-8') as f:
        json.dump(text_results, f, ensure_ascii=False, indent=2)
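
# ---------------------------------------------------------------------------
# Assumed input format (illustrative sketch only; the example values below are
# hypothetical and not taken from the actual dataset files).
# build_index() and search_texts() both expect a JSON list of items, each with
# a 'text' field and a 'triple_list' of [head_entity, relation, tail_entity]:
#
# [
#   {
#     "text": "Sentence mentioning granite and quartz.",
#     "triple_list": [
#       ["granite", "contains_mineral", "quartz"]
#     ]
#   }
# ]
# ---------------------------------------------------------------------------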