"""BM25 keyword retriever used for the initial document retrieval stage."""

from rank_bm25 import BM25Okapi
from typing import List, Dict, Any, Optional, Tuple
import pickle
import os
from utils.text_processor import VietnameseTextProcessor
from config import Config
from tqdm.auto import tqdm


class BM25Retriever:
    """BM25 retriever for initial document retrieval"""

    def __init__(self):
        self.text_processor = VietnameseTextProcessor()
        self.bm25 = None            # BM25Okapi instance, set by build_index/load_index
        self.documents = []         # original document dicts, aligned with corpus order
        self.tokenized_corpus = []  # one token list per document
        self.index_file = "bm25_index.pkl"

    def build_index(self, documents: List[Dict[str, Any]]):
        """Build BM25 index from documents"""
        print("Building BM25 index...")

        self.documents = documents
        self.tokenized_corpus = []

        for doc in tqdm(documents):
            content = doc.get("content", "")
            title = doc.get("title", "")

            # Index the title together with the body so title terms are searchable
            full_text = f"{title} {content}"

            # Apply the same Vietnamese preprocessing used at query time,
            # then whitespace-tokenize for rank_bm25
            processed_text = self.text_processor.preprocess_for_search(full_text)
            tokens = processed_text.split()
            self.tokenized_corpus.append(tokens)

        # k1 controls term-frequency saturation, b controls document-length
        # normalization; both are read from the project config
        self.bm25 = BM25Okapi(self.tokenized_corpus, b=Config.BM25_B, k1=Config.BM25_K1)
        print(f"BM25 index built with {len(self.documents)} documents")

    def save_index(self, filepath: Optional[str] = None):
        """Save BM25 index to file"""
        if filepath is None:
            filepath = self.index_file

        try:
            # Pickle the fitted BM25 model together with the documents and
            # tokenized corpus so the index can be restored in one read
            index_data = {
                "bm25": self.bm25,
                "documents": self.documents,
                "tokenized_corpus": self.tokenized_corpus,
            }
            with open(filepath, "wb") as f:
                pickle.dump(index_data, f)
            print(f"BM25 index saved to {filepath}")
        except Exception as e:
            print(f"Error saving BM25 index: {e}")

    def load_index(self, filepath: Optional[str] = None) -> bool:
        """Load BM25 index from file. Returns True on success."""
        if filepath is None:
            filepath = self.index_file

        try:
            if not os.path.exists(filepath):
                print(f"Index file {filepath} not found")
                return False

            # Note: pickle.load can execute arbitrary code; only load index
            # files this application created itself
            with open(filepath, "rb") as f:
                index_data = pickle.load(f)

            self.bm25 = index_data["bm25"]
            self.documents = index_data["documents"]
            self.tokenized_corpus = index_data["tokenized_corpus"]

            print(f"BM25 index loaded from {filepath}")
            return True
        except (pickle.UnpicklingError, EOFError, UnicodeDecodeError) as e:
            # A truncated write or encoding mismatch leaves an unreadable
            # pickle; delete it so the index is rebuilt on the next run
            print(f"Corrupted BM25 index file: {e}")
            print(f"Removing corrupted index file: {filepath}")
            try:
                os.remove(filepath)
            except OSError:
                pass
            return False
        except Exception as e:
            print(f"Error loading BM25 index: {e}")
            return False
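
    # Typical call pattern (sketch): prefer a cached index and fall back to a
    # fresh build when loading fails, e.g.
    #
    #   if not retriever.load_index():
    #       retriever.build_index(documents)
    #       retriever.save_index()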

    def search(
        self, query: str, top_k: Optional[int] = None
    ) -> List[Tuple[Dict[str, Any], float]]:
        """Search documents using BM25, returning (document, score) pairs"""
        if top_k is None:
            top_k = Config.BM25_TOP_K

        if self.bm25 is None:
            print("BM25 index not built. Please build index first.")
            return []

        # Preprocess the query exactly like the corpus so tokens match
        processed_query = self.text_processor.preprocess_for_search(query)
        query_tokens = processed_query.split()
        if not query_tokens:
            return []

        # Score every document, then keep the top_k highest-scoring ones
        scores = self.bm25.get_scores(query_tokens)
        doc_score_pairs = [
            (self.documents[i], scores[i]) for i in range(len(self.documents))
        ]
        doc_score_pairs.sort(key=lambda x: x[1], reverse=True)
        results = doc_score_pairs[:top_k]

        print(f"BM25 search returned {len(results)} results for query: {query}")
        return results
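
    # The return value is a list of (document, raw_score) tuples in
    # descending score order. Raw BM25 scores are not bounded to [0, 1];
    # get_relevant_documents below normalizes them for downstream use.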

    def get_relevant_documents(
        self, query: str, top_k: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """Get relevant documents with min-max normalized scores"""
        results = self.search(query, top_k)

        max_score = max(score for _, score in results) if results else 0.0
        min_score = min(score for _, score in results) if results else 0.0

        scored_results = []
        for doc, score in results:
            doc_with_score = doc.copy()
            # Min-max normalize raw BM25 scores to [0, 1] so they can be
            # combined with scores from other retrievers; when all scores
            # are equal the range is zero, so fall back to 0.0
            if max_score > min_score:
                doc_with_score["score"] = float((score - min_score) / (max_score - min_score))
            else:
                doc_with_score["score"] = 0.0
            doc_with_score["retrieval_method"] = "bm25"
            scored_results.append(doc_with_score)

        print(f"BM25 found {len(scored_results)} relevant documents")
        return scored_results
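
    # Worked example of the normalization above: raw BM25 scores
    # [7.5, 3.0, 0.0] have min 0.0 and max 7.5, so the normalized scores
    # are [(7.5-0)/7.5, (3.0-0)/7.5, (0.0-0)/7.5] = [1.0, 0.4, 0.0].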

    def search_with_keywords(
        self, keywords: List[str], top_k: Optional[int] = None
    ) -> List[Tuple[Dict[str, Any], float]]:
        """Search using multiple keywords joined into a single query"""
        query = " ".join(keywords)
        return self.search(query, top_k)

    def get_index_stats(self) -> Dict[str, Any]:
        """Get statistics about the BM25 index"""
        if self.bm25 is None:
            return {}

        total_tokens = sum(len(tokens) for tokens in self.tokenized_corpus)
        return {
            "total_documents": len(self.documents),
            "total_tokens": total_tokens,
            "average_document_length": (
                total_tokens / len(self.tokenized_corpus)
                if self.tokenized_corpus
                else 0
            ),
            "vocabulary_size": len(
                {token for tokens in self.tokenized_corpus for token in tokens}
            ),
        }
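

# Minimal usage sketch, not part of the class: assumes Config defines
# BM25_B, BM25_K1, and BM25_TOP_K, and that VietnameseTextProcessor is
# importable as above. The document list and query are illustrative only.
if __name__ == "__main__":
    retriever = BM25Retriever()

    # Reuse a cached index when one exists; otherwise build and persist one
    if not retriever.load_index():
        docs = [
            {"title": "Sample title", "content": "Sample body text."},
            {"title": "Another title", "content": "More body text."},
        ]
        retriever.build_index(docs)
        retriever.save_index()

    print(retriever.get_index_stats())
    for doc in retriever.get_relevant_documents("sample query"):
        print(doc["score"], doc["retrieval_method"], doc.get("title", ""))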