Upload 4 files
- main/bm25_retriever.py +180 -0
- main/chatbot.py +652 -0
- main/reranker.py +199 -0
- main/vector_store.py +244 -0
main/bm25_retriever.py
ADDED
@@ -0,0 +1,180 @@
from rank_bm25 import BM25Okapi
from typing import List, Dict, Any, Tuple
import pickle
import os
from utils.text_processor import VietnameseTextProcessor
from config import Config
from tqdm.auto import tqdm

os.makedirs("index", exist_ok=True)


class BM25Retriever:
    """BM25 retriever for initial document retrieval"""

    def __init__(self):
        self.text_processor = VietnameseTextProcessor()
        self.bm25 = None
        self.documents = []
        self.tokenized_corpus = []
        self.index_file = "index/bm25_index.pkl"

    def build_index(self, documents: List[Dict[str, Any]]):
        """Build BM25 index from documents"""
        print("Building BM25 index...")

        self.documents = documents
        self.tokenized_corpus = []

        # Tokenize all documents
        for doc in tqdm(documents):
            content = doc.get("content", "")
            title = doc.get("title", "")

            # Combine title and content for better search
            full_text = f"{title} {content}"

            # Preprocess and tokenize
            processed_text = self.text_processor.preprocess_for_search(full_text)
            tokens = processed_text.split()

            self.tokenized_corpus.append(tokens)

        # Build BM25 index
        self.bm25 = BM25Okapi(self.tokenized_corpus)

        print(f"BM25 index built with {len(self.documents)} documents")

    def save_index(self, filepath: str = None):
        """Save BM25 index to file"""
        if filepath is None:
            filepath = self.index_file

        try:
            index_data = {
                "bm25": self.bm25,
                "documents": self.documents,
                "tokenized_corpus": self.tokenized_corpus,
            }

            with open(filepath, "wb") as f:
                pickle.dump(index_data, f)

            print(f"BM25 index saved to {filepath}")
        except Exception as e:
            print(f"Error saving BM25 index: {e}")

    def load_index(self, filepath: str = None):
        """Load BM25 index from file"""
        if filepath is None:
            filepath = self.index_file

        try:
            if not os.path.exists(filepath):
                print(f"Index file {filepath} not found")
                return False

            with open(filepath, "rb") as f:
                index_data = pickle.load(f)

            self.bm25 = index_data["bm25"]
            self.documents = index_data["documents"]
            self.tokenized_corpus = index_data["tokenized_corpus"]

            print(f"BM25 index loaded from {filepath}")
            return True
        except UnicodeDecodeError as e:
            print(f"Encoding error loading BM25 index: {e}")
            print(f"Removing corrupted index file: {filepath}")
            try:
                os.remove(filepath)
            except OSError:
                pass
            return False
        except (pickle.UnpicklingError, EOFError) as e:
            print(f"Corrupted BM25 index file: {e}")
            print(f"Removing corrupted index file: {filepath}")
            try:
                os.remove(filepath)
            except OSError:
                pass
            return False
        except Exception as e:
            print(f"Error loading BM25 index: {e}")
            return False

    def search(
        self, query: str, top_k: int = None
    ) -> List[Tuple[Dict[str, Any], float]]:
        """Search documents using BM25"""
        if top_k is None:
            top_k = Config.BM25_TOP_K

        if self.bm25 is None:
            print("BM25 index not built. Please build index first.")
            return []

        # Preprocess query
        processed_query = self.text_processor.preprocess_for_search(query)
        query_tokens = processed_query.split()

        if not query_tokens:
            return []

        # Get BM25 scores
        scores = self.bm25.get_scores(query_tokens)

        # Pair every document with its score
        doc_score_pairs = [
            (self.documents[i], scores[i]) for i in range(len(self.documents))
        ]

        # Sort by score (descending)
        doc_score_pairs.sort(key=lambda x: x[1], reverse=True)

        # Return top k
        results = doc_score_pairs[:top_k]

        print(f"BM25 search returned {len(results)} results for query: {query}")
        return results

    def get_relevant_documents(
        self, query: str, top_k: int = None, min_score: float = 0.0
    ) -> List[Dict[str, Any]]:
        """Get relevant documents above a minimum score threshold"""
        results = self.search(query, top_k)

        # Filter by minimum score
        filtered_results = [
            (doc, score) for doc, score in results if score >= min_score
        ]

        # Return only the documents
        return [doc for doc, score in filtered_results]

    def search_with_keywords(
        self, keywords: List[str], top_k: int = None
    ) -> List[Tuple[Dict[str, Any], float]]:
        """Search using multiple keywords"""
        # Combine keywords into a single query
        query = " ".join(keywords)
        return self.search(query, top_k)

    def get_index_stats(self) -> Dict[str, Any]:
        """Get statistics about the BM25 index"""
        if self.bm25 is None:
            return {}

        return {
            "total_documents": len(self.documents),
            "total_tokens": sum(len(tokens) for tokens in self.tokenized_corpus),
            "average_document_length": (
                sum(len(tokens) for tokens in self.tokenized_corpus)
                / len(self.tokenized_corpus)
                if self.tokenized_corpus
                else 0
            ),
            "vocabulary_size": len(
                set(token for tokens in self.tokenized_corpus for token in tokens)
            ),
        }
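
A minimal standalone sketch of the index-and-search cycle that BM25Retriever wraps, using rank_bm25 directly. The toy corpus, query, and whitespace tokenization stand in for VietnameseTextProcessor, which is not part of this diff.

from rank_bm25 import BM25Okapi

# Toy corpus in the same {id, title, content} shape build_index() expects
documents = [
    {"id": "d1", "title": "Hợp đồng lao động", "content": "Quy định về giao kết hợp đồng lao động"},
    {"id": "d2", "title": "Nghỉ hằng năm", "content": "Người lao động được nghỉ phép năm có hưởng lương"},
]

# Tokenize title + content together, mirroring build_index()
tokenized_corpus = [f"{d['title']} {d['content']}".lower().split() for d in documents]
bm25 = BM25Okapi(tokenized_corpus)

# Score a query and pair each document with its score, mirroring search()
query_tokens = "nghỉ phép năm".lower().split()
scores = bm25.get_scores(query_tokens)
ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
print(ranked[0][0]["id"], round(float(ranked[0][1]), 3))  # d2 should rank first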
main/chatbot.py
ADDED
@@ -0,0 +1,652 @@
from typing import List, Dict, Any
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.schema import HumanMessage

from main.vector_store import QdrantVectorStore
from main.bm25_retriever import BM25Retriever
from main.reranker import DocumentReranker
from utils.text_processor import VietnameseTextProcessor
from utils.google_search import GoogleSearchTool
from utils.question_refiner import VietnameseLegalQuestionRefiner
from config import Config


class VietnameseLegalRAG:
    """Vietnamese Legal RAG System"""

    def __init__(self):
        self.vector_store = None
        self.bm25_retriever = None
        self.reranker = None
        self.llm = None
        self.text_processor = VietnameseTextProcessor()
        self.google_search = GoogleSearchTool()
        self.question_refiner = VietnameseLegalQuestionRefiner()

        self._initialize_components()

    def _initialize_components(self):
        """Initialize RAG components"""
        try:
            # Initialize LLM
            if Config.GOOGLE_API_KEY:
                self.llm = ChatGoogleGenerativeAI(
                    model=Config.MODEL_GEN,
                    google_api_key=Config.GOOGLE_API_KEY,
                    temperature=0.1,
                )
                print("Google Gemini LLM initialized")
            else:
                print("Warning: Google API key not found")

            # Initialize vector store
            self.vector_store = QdrantVectorStore()

            # Initialize BM25 retriever
            self.bm25_retriever = BM25Retriever()

            # Initialize reranker if enabled
            if Config.ENABLE_RERANKING:
                self.reranker = DocumentReranker(model_name=Config.RERANKER_MODEL)
            else:
                print("Reranking disabled in configuration")

        except Exception as e:
            print(f"Error initializing RAG components: {e}")

    def setup_indices(self, documents: List[Dict[str, Any]], force_rebuild: bool = False):
        """Setup both vector and BM25 indices"""
        print("Setting up RAG indices...")

        try:
            # Setup vector store
            if self.vector_store:
                try:
                    # First, do a simple existence check
                    collections = self.vector_store.client.get_collections().collections
                    collection_exists = any(
                        col.name == self.vector_store.collection_name for col in collections
                    )
                    print(f"Collection existence check: {collection_exists}")

                    if collection_exists:
                        # Collection exists, try to get detailed info
                        try:
                            collection_info = self.vector_store.get_collection_info()
                            has_documents = collection_info.get('points_count', 0) > 0

                            if force_rebuild:
                                print("Force rebuild requested - recreating vector store...")
                                self.vector_store.create_collection(force_recreate=True)
                                self.vector_store.add_documents(documents)
                            elif not has_documents:
                                print("Collection exists but is empty - adding documents...")
                                self.vector_store.add_documents(documents)
                            else:
                                print(f"Vector store collection already exists with {collection_info.get('points_count', 0)} documents")
                        except Exception as info_e:
                            print(f"Could not get collection info: {info_e}")
                            if force_rebuild:
                                print("Force rebuild requested - recreating vector store...")
                                self.vector_store.create_collection(force_recreate=True)
                                self.vector_store.add_documents(documents)
                            else:
                                print("Assuming collection has documents - skipping setup")
                    else:
                        # Collection doesn't exist, create it
                        print("Collection does not exist - creating new collection...")
                        self.vector_store.create_collection(force_recreate=False)
                        self.vector_store.add_documents(documents)

                except Exception as e:
                    print(f"Error during vector store setup: {e}")
                    print("Attempting to create collection...")
                    self.vector_store.create_collection()
                    self.vector_store.add_documents(documents)

            # Setup BM25 index: rebuild when forced, otherwise try the saved index first
            if self.bm25_retriever:
                if force_rebuild or not self.bm25_retriever.load_index():
                    self.bm25_retriever.build_index(documents)
                    self.bm25_retriever.save_index()
                else:
                    print("BM25 index loaded from file")

            print("RAG indices setup completed")

        except Exception as e:
            print(f"Error setting up indices: {e}")
            raise

    def retrieve_documents(self, query: str, use_hybrid: bool = True, use_reranking: bool = None) -> List[Dict[str, Any]]:
        """Retrieve relevant documents using a hybrid approach with optional reranking"""
        retrieved_docs = []

        # Use config default if not specified
        if use_reranking is None:
            use_reranking = Config.ENABLE_RERANKING

        # Retrieve more candidates when reranking is enabled
        bm25_top_k = Config.RERANK_BEFORE_RETRIEVAL_TOP_K if use_reranking else Config.BM25_TOP_K
        vector_top_k = Config.RERANK_BEFORE_RETRIEVAL_TOP_K if use_reranking else Config.TOP_K_RETRIEVAL

        try:
            if use_hybrid and self.bm25_retriever and self.vector_store:
                # Hybrid retrieval: BM25 + vector search

                # BM25 retrieval
                bm25_results = self.bm25_retriever.get_relevant_documents(
                    query, top_k=bm25_top_k
                )

                # Vector search
                vector_results = self.vector_store.search_similar_documents(
                    query, top_k=vector_top_k
                )

                # Combine and deduplicate results
                all_docs = {}

                # Add BM25 results
                for doc in bm25_results:
                    doc_id = doc.get('id', '')
                    if doc_id:
                        all_docs[doc_id] = {**doc, 'retrieval_method': 'bm25'}

                # Add vector results
                for doc in vector_results:
                    doc_id = doc.get('id', '')
                    if doc_id:
                        if doc_id in all_docs:
                            # Mark documents found by both methods
                            all_docs[doc_id]['retrieval_method'] = 'hybrid'
                            all_docs[doc_id]['vector_score'] = doc.get('score', 0)
                        else:
                            all_docs[doc_id] = {**doc, 'retrieval_method': 'vector'}

                retrieved_docs = list(all_docs.values())

            elif self.vector_store:
                # Vector search only
                retrieved_docs = self.vector_store.search_similar_documents(query, top_k=vector_top_k)

            elif self.bm25_retriever:
                # BM25 only
                retrieved_docs = self.bm25_retriever.get_relevant_documents(query, top_k=bm25_top_k)

            # Similarity filtering: prefer high-quality documents, fall back to
            # moderate ones, and keep the survivors so reranking below still runs
            if retrieved_docs:
                high_quality_docs = []
                moderate_quality_docs = []

                for doc in retrieved_docs:
                    score = doc.get('score', 0)
                    if score >= Config.SIMILARITY_THRESHOLD:
                        high_quality_docs.append(doc)
                    elif score >= Config.MIN_SIMILARITY_FOR_LEGAL_DOCS:
                        moderate_quality_docs.append(doc)

                if high_quality_docs:
                    print(f"Retrieved {len(high_quality_docs)} high-quality documents")
                    retrieved_docs = high_quality_docs
                elif moderate_quality_docs:
                    print(f"Retrieved {len(moderate_quality_docs)} moderate-quality documents")
                    retrieved_docs = moderate_quality_docs
                else:
                    print("No documents found with sufficient similarity scores")
                    return []

            # Apply reranking if enabled
            if use_reranking and self.reranker and retrieved_docs:
                print(f"Applying reranking to {len(retrieved_docs)} documents...")

                if Config.USE_SCORE_FUSION:
                    # Use score fusion for better results
                    retrieved_docs = self.reranker.rerank_with_fusion(
                        query,
                        retrieved_docs,
                        alpha=Config.RERANKER_FUSION_ALPHA,
                        top_k=Config.RERANKER_TOP_K,
                    )
                else:
                    # Use pure reranker scores
                    retrieved_docs = self.reranker.rerank_documents(
                        query,
                        retrieved_docs,
                        top_k=Config.RERANKER_TOP_K,
                    )
                print(f"Reranking completed, returning {len(retrieved_docs)} documents")
            else:
                # No reranking, limit results to the configured count
                retrieved_docs = retrieved_docs[:Config.TOP_K_RETRIEVAL]

            print(f"Retrieved {len(retrieved_docs)} documents for query")
            return retrieved_docs

        except Exception as e:
            print(f"Error retrieving documents: {e}")
            return []

    def format_context(self, documents: List[Dict[str, Any]]) -> str:
        """Format retrieved documents as context for the LLM"""
        if not documents:
            return "Không có tài liệu pháp luật liên quan được tìm thấy."

        context_parts = []

        for i, doc in enumerate(documents, 1):
            title = doc.get('title', 'Không có tiêu đề')
            content = doc.get('content', '')
            doc_id = doc.get('id', '')
            metadata = doc.get('metadata', {})
            law_id = metadata.get('law_id', '')
            article_id = metadata.get('article_id', '')

            # Limit content length
            if len(content) > 500:
                content = content[:500] + "..."

            # Format law and article information
            law_info = f"Luật: {law_id}" if law_id else ""
            article_info = f"Điều {article_id}" if article_id else f"ID: {doc_id}"

            context_part = f"""
Tài liệu {i}:
{law_info}
{article_info}: {title}
Nội dung: {content}
"""
            context_parts.append(context_part.strip())

        return "\n\n".join(context_parts)

    def generate_answer(self, query: str, context: str, is_fallback: bool = False) -> str:
        """Generate an answer using the LLM with context"""
        if not self.llm:
            return "Lỗi: Không thể kết nối với mô hình ngôn ngữ."

        try:
            # Create prompt based on scenario
            if is_fallback:
                prompt = Config.FALLBACK_SYSTEM_PROMPT.format(
                    context=context,
                    question=query,
                )
            else:
                prompt = Config.SYSTEM_PROMPT.format(
                    context=context,
                    question=query,
                )

            # Generate response
            messages = [HumanMessage(content=prompt)]
            response = self.llm.invoke(messages)

            answer = response.content

            # Add a legal disclaimer if enabled and this is a fallback response
            if Config.SHOW_LEGAL_DISCLAIMER and is_fallback:
                disclaimer = "\n\nLưu ý quan trọng: Để đảm bảo quyền lợi của mình, người lao động nên tìm đến các chuyên gia pháp lý hoặc cơ quan chức năng có thẩm quyền để được tư vấn cụ thể và chính xác nhất dựa trên tình hình thực tế của mình."
                answer += disclaimer

            return answer

        except Exception as e:
            print(f"Error generating answer: {e}")
            return f"Lỗi khi tạo câu trả lời: {str(e)}"

    def _is_negative_response(self, response: str) -> bool:
        """Check if the response is a negative/unable-to-answer response"""
        negative_indicators = [
            "không thể trả lời",
            "không tìm thấy",
            "không có thông tin",
            "xin lỗi",
            "không thể tìm thấy",
            "không có dữ liệu",
            "không rõ",
            "không biết",
            "không đủ thông tin",
            "thiếu thông tin",
            "không có trong",
            "ngoài phạm vi",
            # Longer refusal patterns observed in practice
            "không có đủ thông tin trong tài liệu tham khảo được cung cấp để trả lời trực tiếp câu hỏi này",
            "cần tham khảo thêm các văn bản pháp luật khác",
            "tìm kiếm thông tin chuyên sâu hơn về",
            "tài liệu tham khảo không chứa thông tin đầy đủ",
        ]

        response_lower = response.lower()
        return any(indicator in response_lower for indicator in negative_indicators)

    def _enhanced_search_fallback(self, query: str) -> List[Dict[str, Any]]:
        """Enhanced search fallback with multiple strategies"""
        print("Attempting enhanced search fallback...")

        all_docs = []

        # Strategy 1: Lower similarity threshold
        try:
            if self.bm25_retriever and self.vector_store:
                # BM25 with more results
                bm25_results = self.bm25_retriever.get_relevant_documents(
                    query, top_k=15  # Increased from default
                )

                # Vector search with more results
                vector_results = self.vector_store.search_similar_documents(
                    query, top_k=10  # Increased from default
                )

                # Combine results with very low thresholds
                combined_docs = {}

                for doc in bm25_results:
                    doc_id = doc.get('id', '')
                    if doc_id and doc.get('score', 0) >= 0.1:  # Very low threshold
                        combined_docs[doc_id] = {**doc, 'retrieval_method': 'bm25_enhanced'}

                for doc in vector_results:
                    doc_id = doc.get('id', '')
                    if doc_id and doc.get('score', 0) >= 0.15:  # Very low threshold
                        if doc_id in combined_docs:
                            combined_docs[doc_id]['retrieval_method'] = 'hybrid_enhanced'
                        else:
                            combined_docs[doc_id] = {**doc, 'retrieval_method': 'vector_enhanced'}

                all_docs = list(combined_docs.values())

        except Exception as e:
            print(f"Error in enhanced search: {e}")

        # Strategy 2: Keyword-based search if still no results
        if not all_docs:
            try:
                # Extract keywords from the query
                keywords = self.text_processor.extract_keywords(query)
                if keywords:
                    keyword_query = " ".join(keywords[:5])  # Use top 5 keywords

                    if self.bm25_retriever:
                        keyword_results = self.bm25_retriever.get_relevant_documents(
                            keyword_query, top_k=10
                        )
                        all_docs.extend([{**doc, 'retrieval_method': 'keyword_fallback'}
                                         for doc in keyword_results if doc.get('score', 0) >= 0.05])

            except Exception as e:
                print(f"Error in keyword fallback: {e}")

        print(f"Enhanced search found {len(all_docs)} documents")
        return all_docs[:Config.TOP_K_RETRIEVAL]

    def answer_question(self, query: str, use_fallback: bool = True, refine_question: bool = True) -> Dict[str, Any]:
        """Answer a legal question using RAG with negative-response handling and question refinement"""
        print(f"Processing question: {query}")

        # Step 1: Refine the question if enabled
        original_query = query
        refinement_result = None

        if refine_question and Config.ENABLE_QUESTION_REFINEMENT and self.question_refiner:
            print("🔧 Refining question for better search accuracy...")
            refinement_result = self.question_refiner.refine_question(query, use_llm=Config.USE_LLM_FOR_REFINEMENT)

            if refinement_result["refined_question"] != query:
                refined_query = refinement_result["refined_question"]
                print(f"📝 Original: {query}")
                print(f"✨ Refined: {refined_query}")
                query = refined_query

        # Step 2: Retrieve relevant documents using the refined query
        retrieved_docs = self.retrieve_documents(query)

        # Check whether we have relevant documents
        if not retrieved_docs and Config.ENABLE_GOOGLE_SEARCH and use_fallback:
            print("No relevant legal documents found, using Google search fallback")

            # Use Google search as fallback
            search_results = self.google_search.search_legal_info(query)

            if search_results:
                fallback_context = self.google_search.format_search_results(search_results)

                # Generate answer with fallback context
                fallback_answer = self.generate_answer(query, fallback_context, True)

                return {
                    'answer': fallback_answer,
                    'retrieved_documents': [],
                    'fallback_used': True,
                    'search_results': search_results,
                    'context': fallback_context,
                    'search_results_html': self.google_search.format_search_results_for_display(search_results),
                    'original_question': original_query,
                    'refined_question': query,
                    'question_refinement': refinement_result,
                }
            else:
                # Try enhanced search before giving up
                enhanced_docs = self._enhanced_search_fallback(query)
                if enhanced_docs:
                    enhanced_context = self.format_context(enhanced_docs)
                    enhanced_answer = self.generate_answer(query, enhanced_context, False)

                    return {
                        'answer': enhanced_answer,
                        'retrieved_documents': enhanced_docs,
                        'fallback_used': True,
                        'context': enhanced_context,
                        'search_results': [],
                        'search_results_html': "",
                        'enhanced_search_used': True,
                        'original_question': original_query,
                        'refined_question': query,
                        'question_refinement': refinement_result,
                    }
                else:
                    return {
                        'answer': "Xin lỗi, tôi không tìm thấy thông tin pháp luật liên quan đến câu hỏi của bạn trong cơ sở dữ liệu nội bộ và cũng không thể tìm kiếm thông tin trên web. Vui lòng thử lại với câu hỏi khác hoặc liên hệ với chuyên gia pháp lý.",
                        'retrieved_documents': [],
                        'fallback_used': True,
                        'search_results': [],
                        'context': "",
                        'search_results_html': "",
                        'enhanced_search_used': False,
                        'original_question': original_query,
                        'refined_question': query,
                        'question_refinement': refinement_result,
                    }
        elif not retrieved_docs:
            # Try enhanced search before giving a negative response
            enhanced_docs = self._enhanced_search_fallback(query)
            if enhanced_docs:
                enhanced_context = self.format_context(enhanced_docs)
                enhanced_answer = self.generate_answer(query, enhanced_context, False)

                return {
                    'answer': enhanced_answer,
                    'retrieved_documents': enhanced_docs,
                    'fallback_used': False,
                    'context': enhanced_context,
                    'search_results': [],
                    'search_results_html': "",
                    'enhanced_search_used': True,
                    'original_question': original_query,
                    'refined_question': query,
                    'question_refinement': refinement_result,
                }
            else:
                return {
                    'answer': "Xin lỗi, tôi không tìm thấy thông tin pháp luật liên quan đến câu hỏi của bạn trong cơ sở dữ liệu.",
                    'retrieved_documents': [],
                    'fallback_used': False,
                    'context': "",
                    'search_results': [],
                    'search_results_html': "",
                    'enhanced_search_used': False,
                    'original_question': original_query,
                    'refined_question': query,
                    'question_refinement': refinement_result,
                }

        # Format context
        context = self.format_context(retrieved_docs)

        # Generate answer
        answer = self.generate_answer(query, context, False)

        # If the generated answer is negative, retry with enhanced search
        if self._is_negative_response(answer) and use_fallback:
            print("🔍 Detected insufficient information response, activating search tools...")

            # Try enhanced search first
            enhanced_docs = self._enhanced_search_fallback(query)

            if enhanced_docs:
                # Combine original and enhanced docs
                all_docs = retrieved_docs + [doc for doc in enhanced_docs
                                             if doc.get('id') not in [d.get('id') for d in retrieved_docs]]
                enhanced_context = self.format_context(all_docs[:Config.TOP_K_RETRIEVAL])
                enhanced_answer = self.generate_answer(query, enhanced_context, False)

                # If still negative, try Google search
                if self._is_negative_response(enhanced_answer) and Config.ENABLE_GOOGLE_SEARCH:
                    print("📡 Enhanced search still insufficient, trying web search...")
                    search_results = self.google_search.search_legal_info(query)

                    if search_results:
                        fallback_context = self.google_search.format_search_results(search_results)
                        final_answer = self.generate_answer(query, fallback_context, True)

                        # Add a note about web search usage (if enabled)
                        if Config.SHOW_SOURCE_INFO:
                            final_answer = f"{final_answer}\n\n*🌐 Thông tin này được tìm kiếm từ web do không tìm thấy đủ thông tin trong cơ sở dữ liệu pháp luật nội bộ.*"

                        return {
                            'answer': final_answer,
                            'retrieved_documents': all_docs,
                            'fallback_used': True,
                            'context': enhanced_context,
                            'search_results': search_results,
                            'search_results_html': self.google_search.format_search_results_for_display(search_results),
                            'enhanced_search_used': True,
                            'search_triggered': True,
                            'original_question': original_query,
                            'refined_question': query,
                            'question_refinement': refinement_result,
                        }
                    else:
                        # Web search failed, return the enhanced answer with a note (if enabled)
                        if Config.SHOW_SOURCE_INFO:
                            enhanced_answer = f"{enhanced_answer}\n\n*🔍 Đã sử dụng tìm kiếm nâng cao trong cơ sở dữ liệu.*"

                        return {
                            'answer': enhanced_answer,
                            'retrieved_documents': all_docs,
                            'fallback_used': False,
                            'context': enhanced_context,
                            'search_results': [],
                            'search_results_html': "",
                            'enhanced_search_used': True,
                            'search_triggered': True,
                            'original_question': original_query,
                            'refined_question': query,
                            'question_refinement': refinement_result,
                        }
                else:
                    # Enhanced search was sufficient; add a source note (if enabled)
                    if Config.SHOW_SOURCE_INFO:
                        enhanced_answer = f"{enhanced_answer}\n\n*🔍 Đã sử dụng tìm kiếm nâng cao để tìm thông tin này.*"

                    return {
                        'answer': enhanced_answer,
                        'retrieved_documents': all_docs,
                        'fallback_used': False,
                        'context': enhanced_context,
                        'search_results': [],
                        'search_results_html': "",
                        'enhanced_search_used': True,
                        'search_triggered': True,
                        'original_question': original_query,
                        'refined_question': query,
                        'question_refinement': refinement_result,
                    }
            elif Config.ENABLE_GOOGLE_SEARCH:
                # Try Google search as a last resort
                print("📡 Database search failed, trying web search as last resort...")
                search_results = self.google_search.search_legal_info(query)

                if search_results:
                    fallback_context = self.google_search.format_search_results(search_results)
                    final_answer = self.generate_answer(query, fallback_context, True)

                    # Add a note about web search usage (if enabled)
                    if Config.SHOW_SOURCE_INFO:
                        final_answer = f"{final_answer}\n\n*🌐 Thông tin này được tìm kiếm từ web do không tìm thấy trong cơ sở dữ liệu pháp luật nội bộ.*"

                    return {
                        'answer': final_answer,
                        'retrieved_documents': retrieved_docs,
                        'fallback_used': True,
                        'context': context,
                        'search_results': search_results,
                        'search_results_html': self.google_search.format_search_results_for_display(search_results),
                        'enhanced_search_used': False,
                        'search_triggered': True,
                        'original_question': original_query,
                        'refined_question': query,
                        'question_refinement': refinement_result,
                    }

        return {
            'answer': answer,
            'retrieved_documents': retrieved_docs,
            'fallback_used': False,
            'context': context,
            'search_results': [],
            'search_results_html': "",
            'enhanced_search_used': False,
            'search_triggered': False,
            'original_question': original_query,
            'refined_question': query,
            'question_refinement': refinement_result,
        }

    def get_system_status(self) -> Dict[str, Any]:
        """Get the status of RAG system components"""
        status = {
            'llm_available': self.llm is not None,
            'vector_store_available': self.vector_store is not None,
            'bm25_available': self.bm25_retriever is not None,
            'reranker_available': self.reranker is not None and self.reranker.model is not None,
            'reranking_enabled': Config.ENABLE_RERANKING,
            'google_api_configured': bool(Config.GOOGLE_API_KEY),
            'qdrant_configured': bool(Config.QDRANT_URL and Config.QDRANT_API_KEY),
        }

        # Get collection info if available
        if self.vector_store:
            try:
                status['vector_store_info'] = self.vector_store.get_collection_info()
            except Exception:
                status['vector_store_info'] = {}

        # Get BM25 stats if available
        if self.bm25_retriever:
            status['bm25_stats'] = self.bm25_retriever.get_index_stats()

        # Get reranker info if available
        if self.reranker:
            status['reranker_info'] = self.reranker.get_model_info()

        return status
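
A hedged end-to-end usage sketch for VietnameseLegalRAG. It assumes documents follow the {id, title, content, metadata} shape that format_context() and the vector store payload expect, that config.py supplies the Config values referenced above, and that the utils modules and any API keys are available at runtime; the sample article ID and text are illustrative.

from main.chatbot import VietnameseLegalRAG

documents = [
    {
        "id": "blld-dieu-113",  # hypothetical document ID
        "title": "Nghỉ hằng năm",
        "content": "Người lao động làm việc đủ 12 tháng cho một người sử dụng lao động thì được nghỉ hằng năm, hưởng nguyên lương...",
        "metadata": {"law_id": "45/2019/QH14", "article_id": "113"},
    },
]

rag = VietnameseLegalRAG()
rag.setup_indices(documents)  # creates the Qdrant collection and the BM25 index

result = rag.answer_question("Người lao động được nghỉ phép năm bao nhiêu ngày?")
print(result["answer"])
print("fallback used:", result["fallback_used"])
for doc in result["retrieved_documents"]:
    print(doc.get("id"), doc.get("retrieval_method"), doc.get("score"))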
main/reranker.py
ADDED
@@ -0,0 +1,199 @@
from typing import List, Dict, Any
from sentence_transformers import CrossEncoder


class DocumentReranker:
    """Document reranker using cross-encoder models for improved relevance scoring"""

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        """
        Initialize the reranker with a cross-encoder model

        Args:
            model_name: Name of the cross-encoder model to use.
                The default is an English-trained MS MARCO cross-encoder;
                pass a multilingual cross-encoder for better Vietnamese results.
        """
        self.model_name = model_name
        self.model = None
        self._initialize_model()

    def _initialize_model(self):
        """Initialize the cross-encoder model"""
        try:
            print(f"Loading reranker model: {self.model_name}")
            self.model = CrossEncoder(self.model_name)
            print("Reranker model loaded successfully")
        except Exception as e:
            print(f"Error loading reranker model: {e}")
            # Fall back to a lighter model
            try:
                fallback_model = "cross-encoder/ms-marco-TinyBERT-L-2-v2"
                print(f"Trying fallback model: {fallback_model}")
                self.model = CrossEncoder(fallback_model)
                self.model_name = fallback_model
                print("Fallback reranker model loaded successfully")
            except Exception as e2:
                print(f"Error loading fallback model: {e2}")
                self.model = None

    def rerank_documents(self,
                         query: str,
                         documents: List[Dict[str, Any]],
                         top_k: int = None) -> List[Dict[str, Any]]:
        """
        Rerank documents based on relevance to the query

        Args:
            query: The search query
            documents: List of documents to rerank
            top_k: Number of top documents to return (None for all)

        Returns:
            List of reranked documents with updated scores
        """
        if not self.model or not documents:
            return documents

        try:
            # Prepare query-document pairs for the cross-encoder
            pairs = []
            valid_docs = []

            for doc in documents:
                content = doc.get('content', '')
                title = doc.get('title', '')

                # Combine title and content for better matching
                doc_text = f"{title}. {content}" if title else content

                # Truncate very long documents to stay within model limits
                max_length = 512
                if len(doc_text) > max_length:
                    doc_text = doc_text[:max_length] + "..."

                pairs.append([query, doc_text])
                valid_docs.append(doc)

            if not pairs:
                return documents

            # Get relevance scores from the cross-encoder
            scores = self.model.predict(pairs)

            # Update documents with reranker scores
            reranked_docs = []
            for doc, score in zip(valid_docs, scores):
                # Create a copy to avoid modifying the original
                reranked_doc = doc.copy()

                # Store both original and reranker scores
                reranked_doc['reranker_score'] = float(score)
                reranked_doc['original_score'] = doc.get('score', 0.0)

                # Update the main score with the reranker score
                reranked_doc['score'] = float(score)

                # Record the reranking method
                if 'retrieval_method' in reranked_doc:
                    reranked_doc['retrieval_method'] += '_reranked'
                else:
                    reranked_doc['retrieval_method'] = 'reranked'

                reranked_docs.append(reranked_doc)

            # Sort by reranker score (descending)
            reranked_docs.sort(key=lambda x: x['reranker_score'], reverse=True)

            # Return top_k if specified
            if top_k:
                reranked_docs = reranked_docs[:top_k]

            print(f"Reranked {len(reranked_docs)} documents")

            # Log top scores for debugging
            if reranked_docs:
                top_scores = [doc['reranker_score'] for doc in reranked_docs[:3]]
                print(f"Top reranker scores: {top_scores}")

            return reranked_docs

        except Exception as e:
            print(f"Error during reranking: {e}")
            # Return original documents if reranking fails
            return documents

    def rerank_with_fusion(self,
                           query: str,
                           documents: List[Dict[str, Any]],
                           alpha: float = 0.7,
                           top_k: int = None) -> List[Dict[str, Any]]:
        """
        Rerank documents using score fusion between original and reranker scores

        Args:
            query: The search query
            documents: List of documents to rerank
            alpha: Weight for the reranker score (0-1); higher means more weight on the reranker
            top_k: Number of top documents to return

        Returns:
            List of reranked documents with fused scores
        """
        if not self.model or not documents:
            return documents

        try:
            # First get reranker scores
            reranked_docs = self.rerank_documents(query, documents, top_k=None)

            if not reranked_docs:
                return documents

            # Normalize original scores to the 0-1 range
            original_scores = [doc.get('original_score', 0.0) for doc in reranked_docs]
            if max(original_scores) > 0:
                original_scores_norm = [s / max(original_scores) for s in original_scores]
            else:
                original_scores_norm = [0.0] * len(original_scores)

            # Normalize reranker scores to the 0-1 range
            reranker_scores = [doc.get('reranker_score', 0.0) for doc in reranked_docs]
            min_reranker = min(reranker_scores)
            max_reranker = max(reranker_scores)

            if max_reranker > min_reranker:
                reranker_scores_norm = [(s - min_reranker) / (max_reranker - min_reranker)
                                        for s in reranker_scores]
            else:
                reranker_scores_norm = [0.5] * len(reranker_scores)

            # Compute fused scores
            for i, doc in enumerate(reranked_docs):
                fused_score = (alpha * reranker_scores_norm[i] +
                               (1 - alpha) * original_scores_norm[i])
                doc['fused_score'] = fused_score
                doc['score'] = fused_score  # Update the main score

            # Sort by fused score
            reranked_docs.sort(key=lambda x: x['fused_score'], reverse=True)

            # Return top_k if specified
            if top_k:
                reranked_docs = reranked_docs[:top_k]

            print(f"Score fusion reranking completed for {len(reranked_docs)} documents")

            return reranked_docs

        except Exception as e:
            print(f"Error during fusion reranking: {e}")
            return documents

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model"""
        return {
            'model_name': self.model_name,
            'model_loaded': self.model is not None,
            'model_type': 'cross-encoder',
        }
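
A standalone sketch of the fusion arithmetic in rerank_with_fusion(): max-normalize the retrieval scores, min-max-normalize the raw cross-encoder scores (which can be negative logits), then combine them with weight alpha. The numbers are made up for illustration.

def fuse(original_scores, reranker_scores, alpha=0.7):
    # Max-normalize retrieval scores (as in rerank_with_fusion)
    orig_max = max(original_scores)
    orig_norm = [s / orig_max if orig_max > 0 else 0.0 for s in original_scores]
    # Min-max-normalize cross-encoder scores
    lo, hi = min(reranker_scores), max(reranker_scores)
    rr_norm = [(s - lo) / (hi - lo) if hi > lo else 0.5 for s in reranker_scores]
    # Convex combination: alpha weights the reranker
    return [alpha * r + (1 - alpha) * o for r, o in zip(rr_norm, orig_norm)]

print(fuse([0.82, 0.40, 0.75], [-1.2, 3.4, 0.9]))
# -> roughly [0.30, 0.85, 0.59]: with alpha=0.7 the middle document, weakest
#    by retrieval score but strongest by cross-encoder score, ranks first.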
main/vector_store.py
ADDED
@@ -0,0 +1,244 @@
import os
import uuid
from typing import List, Dict, Any
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
from sentence_transformers import SentenceTransformer
from config import Config
from tqdm import tqdm


class QdrantVectorStore:
    """Qdrant vector store for legal document embeddings"""

    def __init__(self):
        self.client = None
        self.embedding_model = None
        self.collection_name = Config.COLLECTION_NAME
        self._initialize_client()
        self._initialize_embedding_model()

    def _initialize_client(self):
        """Initialize Qdrant client"""
        try:
            if Config.QDRANT_URL and Config.QDRANT_API_KEY:
                # Cloud Qdrant
                self.client = QdrantClient(
                    url=Config.QDRANT_URL, api_key=Config.QDRANT_API_KEY
                )
            else:
                # Local Qdrant (fallback)
                self.client = QdrantClient(host="localhost", port=6333)

            print("Qdrant client initialized successfully")
        except Exception as e:
            print(f"Error initializing Qdrant client: {e}")
            raise

    def _initialize_embedding_model(self):
        """Initialize embedding model"""
        try:
            self.embedding_model = SentenceTransformer(Config.EMBEDDING_MODEL)
            print(f"Embedding model {Config.EMBEDDING_MODEL} loaded successfully")
        except UnicodeDecodeError as e:
            # A corrupted download can surface as an encoding error;
            # clear the sentence-transformers cache and retry once
            print(f"Encoding error loading embedding model: {e}")
            print("Trying to clear sentence-transformers cache...")
            try:
                import tempfile
                import shutil

                cache_dir = os.path.join(tempfile.gettempdir(), "sentence_transformers")
                if os.path.exists(cache_dir):
                    shutil.rmtree(cache_dir)
                    print("Cache cleared, retrying...")

                self.embedding_model = SentenceTransformer(Config.EMBEDDING_MODEL)
                print(
                    f"Embedding model {Config.EMBEDDING_MODEL} loaded successfully after cache clear"
                )
            except Exception as retry_e:
                print(
                    f"Failed to load embedding model even after cache clear: {retry_e}"
                )
                raise
        except Exception as e:
            print(f"Error loading embedding model: {e}")
            raise

    def create_collection(self, vector_size: int = 384, force_recreate: bool = False):
        """Create collection in Qdrant"""
        try:
            # Check whether the collection exists
            collections = self.client.get_collections().collections
            collection_exists = any(
                col.name == self.collection_name for col in collections
            )
            print(f"Collection exists: {collection_exists}")

            if collection_exists:
                if not force_recreate:
                    print(
                        f"Collection {self.collection_name} already exists - skipping creation"
                    )
                    return
                print(f"Force recreating collection: {self.collection_name}")
                self.client.delete_collection(self.collection_name)
                print(f"Deleted existing collection: {self.collection_name}")
            else:
                print(
                    f"Collection {self.collection_name} does not exist - creating new collection"
                )

            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=vector_size, distance=Distance.COSINE
                ),
            )
            print(f"Successfully created collection: {self.collection_name}")

        except Exception as e:
            print(f"Error creating collection: {e}")
            raise

    def embed_text(self, text: str) -> List[float]:
        """Generate embedding for text"""
        try:
            embedding = self.embedding_model.encode(text, convert_to_tensor=False)
            return embedding.tolist()
        except Exception as e:
            print(f"Error generating embedding: {e}")
            return []

    def add_documents(self, documents: List[Dict[str, Any]]):
        """Add documents to the vector store"""
        try:
            points = []

            for doc in tqdm(documents):
                # Generate embedding
                content = doc.get("content", "")
                if not content:
                    continue

                embedding = self.embed_text(content)
                if not embedding:
                    continue

                # Create point
                point = PointStruct(
                    id=str(uuid.uuid4()),
                    vector=embedding,
                    payload={
                        "article_id": doc.get("id", ""),
                        "title": doc.get("title", ""),
                        "content": content,
                        "metadata": doc.get("metadata", {}),
                    },
                )
                points.append(point)

            # Batch upload
            batch_size = 100
            for i in range(0, len(points), batch_size):
                batch = points[i : i + batch_size]
                self.client.upsert(collection_name=self.collection_name, points=batch)
                print(
                    f"Uploaded batch {i//batch_size + 1}/{(len(points) + batch_size - 1)//batch_size}"
                )

            print(f"Successfully added {len(points)} documents to vector store")

        except Exception as e:
            print(f"Error adding documents: {e}")
            raise

    def search_similar_documents(
        self, query: str, top_k: int = None, score_threshold: float = None
    ) -> List[Dict[str, Any]]:
        """Search for similar documents"""
        if top_k is None:
            top_k = Config.TOP_K_RETRIEVAL
        if score_threshold is None:
            score_threshold = Config.SIMILARITY_THRESHOLD

        try:
            # Generate query embedding
            query_embedding = self.embed_text(query)
            if not query_embedding:
                return []

            # Search
            search_results = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding,
                limit=top_k,
                score_threshold=score_threshold,
            )

            # Format results
            results = []
            for result in search_results:
                results.append(
                    {
                        "id": result.payload.get("article_id", ""),
                        "title": result.payload.get("title", ""),
                        "content": result.payload.get("content", ""),
                        "score": result.score,
                        "metadata": result.payload.get("metadata", {}),
                    }
                )

            print(f"Found {len(results)} similar documents")
            return results

        except Exception as e:
            print(f"Error searching documents: {e}")
            return []

    def get_collection_info(self) -> Dict[str, Any]:
        """Get collection information"""
        try:
            info = self.client.get_collection(self.collection_name)
            result = {
                "name": self.collection_name,  # Use the collection name we know
                "vectors_count": info.vectors_count,
                "indexed_vectors_count": info.indexed_vectors_count,
                "points_count": info.points_count,
            }
            print(f"Collection info: {result}")
            return result
        except Exception as e:
            print(f"Collection '{self.collection_name}' does not exist: {e}")
            return {}

    def delete_collection(self):
        """Delete collection"""
        try:
            self.client.delete_collection(self.collection_name)
            print(f"Deleted collection: {self.collection_name}")
        except Exception as e:
            print(f"Error deleting collection: {e}")
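
A self-contained sketch of the embed/upsert/search cycle QdrantVectorStore wraps, using an in-memory Qdrant instance so no server or API key is required. The model and collection names are illustrative; all-MiniLM-L6-v2 produces 384-dimensional vectors, matching the create_collection() default.

import uuid
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # 384-dim
client = QdrantClient(":memory:")  # in-memory instance for local testing

client.create_collection(
    collection_name="legal_demo",
    vectors_config=VectorParams(size=384, distance=Distance.COSINE),
)

content = "Người lao động làm việc đủ 12 tháng thì được nghỉ hằng năm."
client.upsert(
    collection_name="legal_demo",
    points=[
        PointStruct(
            id=str(uuid.uuid4()),
            vector=model.encode(content).tolist(),
            payload={"article_id": "113", "title": "Nghỉ hằng năm", "content": content},
        )
    ],
)

hits = client.search(
    collection_name="legal_demo",
    query_vector=model.encode("nghỉ phép năm").tolist(),
    limit=3,
)
for hit in hits:
    print(round(hit.score, 3), hit.payload["title"])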