fisherman611 commited on
Commit
c56380d
·
verified ·
1 Parent(s): cf96c3d

Upload 4 files

Browse files
Files changed (4) hide show
  1. main/bm25_retriever.py +180 -0
  2. main/chatbot.py +652 -0
  3. main/reranker.py +199 -0
  4. main/vector_store.py +244 -0
main/bm25_retriever.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from rank_bm25 import BM25Okapi
2
+ from typing import List, Dict, Any, Tuple
3
+ import pickle
4
+ import os
5
+ from utils.text_processor import VietnameseTextProcessor
6
+ from config import Config
7
+ from tqdm.auto import tqdm
8
+
9
+ os.makedirs("index", exist_ok=True)
10
+
11
+
12
class BM25Retriever:
    """BM25 retriever for initial (lexical) document retrieval.

    Documents are dicts expected to carry ``title`` and ``content`` keys; the
    two fields are concatenated, preprocessed for Vietnamese search, whitespace
    tokenized, and indexed with ``rank_bm25.BM25Okapi``. The built index can be
    pickled to disk and reloaded between runs.
    """

    def __init__(self):
        self.text_processor = VietnameseTextProcessor()
        # BM25Okapi instance; stays None until build_index()/load_index() succeeds.
        self.bm25 = None
        # Original document dicts, kept parallel to tokenized_corpus.
        self.documents = []
        self.tokenized_corpus = []
        # Default on-disk location for the pickled index.
        self.index_file = "index/bm25_index.pkl"

    def build_index(self, documents: List[Dict[str, Any]]):
        """Build the in-memory BM25 index from ``documents``.

        Args:
            documents: List of document dicts with ``title``/``content`` keys.
        """
        print("Building BM25 index...")

        self.documents = documents
        self.tokenized_corpus = []

        # Tokenize all documents
        for doc in tqdm(documents):
            content = doc.get("content", "")
            title = doc.get("title", "")

            # Index title together with content for better recall.
            full_text = f"{title} {content}"

            processed_text = self.text_processor.preprocess_for_search(full_text)
            self.tokenized_corpus.append(processed_text.split())

        self.bm25 = BM25Okapi(self.tokenized_corpus)

        print(f"BM25 index built with {len(self.documents)} documents")

    def save_index(self, filepath: str = None):
        """Pickle the index, documents and corpus to ``filepath``.

        Args:
            filepath: Target path; defaults to ``self.index_file``.
        """
        if filepath is None:
            filepath = self.index_file

        try:
            index_data = {
                "bm25": self.bm25,
                "documents": self.documents,
                "tokenized_corpus": self.tokenized_corpus,
            }

            with open(filepath, "wb") as f:
                pickle.dump(index_data, f)

            print(f"BM25 index saved to {filepath}")
        except Exception as e:
            print(f"Error saving BM25 index: {e}")

    def load_index(self, filepath: str = None):
        """Load a previously pickled index.

        Corrupted or unreadable index files are deleted so the next run
        rebuilds from scratch instead of failing repeatedly.

        Returns:
            True on success, False otherwise.
        """
        if filepath is None:
            filepath = self.index_file

        try:
            if not os.path.exists(filepath):
                print(f"Index file {filepath} not found")
                return False

            with open(filepath, "rb") as f:
                index_data = pickle.load(f)

            self.bm25 = index_data["bm25"]
            self.documents = index_data["documents"]
            self.tokenized_corpus = index_data["tokenized_corpus"]

            print(f"BM25 index loaded from {filepath}")
            return True
        except UnicodeDecodeError as e:
            print(f"Encoding error loading BM25 index: {e}")
            print(f"Removing corrupted index file: {filepath}")
            self._remove_quietly(filepath)
            return False
        except (pickle.UnpicklingError, EOFError) as e:
            print(f"Corrupted BM25 index file: {e}")
            print(f"Removing corrupted index file: {filepath}")
            self._remove_quietly(filepath)
            return False
        except Exception as e:
            print(f"Error loading BM25 index: {e}")
            return False

    @staticmethod
    def _remove_quietly(filepath: str):
        """Best-effort file removal used to discard corrupted index files.

        Fix: the original used a bare ``except:`` which also swallowed
        SystemExit/KeyboardInterrupt; only OS-level errors are suppressed now.
        """
        try:
            os.remove(filepath)
        except OSError:
            pass

    def search(
        self, query: str, top_k: int = None
    ) -> List[Tuple[Dict[str, Any], float]]:
        """Return the ``top_k`` (document, BM25 score) pairs for ``query``.

        Returns an empty list when the index is not built or the query
        tokenizes to nothing.
        """
        if top_k is None:
            top_k = Config.BM25_TOP_K

        if self.bm25 is None:
            print("BM25 index not built. Please build index first.")
            return []

        # Preprocess the query the same way documents were indexed.
        processed_query = self.text_processor.preprocess_for_search(query)
        query_tokens = processed_query.split()

        if not query_tokens:
            return []

        scores = self.bm25.get_scores(query_tokens)

        # Pair each document with its score (fix: zip instead of an index loop
        # over range(len(...))), then sort by score descending.
        doc_score_pairs = list(zip(self.documents, scores))
        doc_score_pairs.sort(key=lambda pair: pair[1], reverse=True)

        results = doc_score_pairs[:top_k]

        print(f"BM25 search returned {len(results)} results for query: {query}")
        return results

    def get_relevant_documents(
        self, query: str, top_k: int = None, min_score: float = 0.0
    ) -> List[Dict[str, Any]]:
        """Return documents (without scores) whose BM25 score >= ``min_score``."""
        results = self.search(query, top_k)
        return [doc for doc, score in results if score >= min_score]

    def search_with_keywords(
        self, keywords: List[str], top_k: int = None
    ) -> List[Tuple[Dict[str, Any], float]]:
        """Search using multiple keywords joined into a single query string."""
        return self.search(" ".join(keywords), top_k)

    def get_index_stats(self) -> Dict[str, Any]:
        """Return index statistics, or {} when no index has been built."""
        if self.bm25 is None:
            return {}

        doc_lengths = [len(tokens) for tokens in self.tokenized_corpus]
        total_tokens = sum(doc_lengths)
        return {
            "total_documents": len(self.documents),
            "total_tokens": total_tokens,
            "average_document_length": total_tokens / len(doc_lengths)
            if doc_lengths
            else 0,
            "vocabulary_size": len(
                {token for tokens in self.tokenized_corpus for token in tokens}
            ),
        }
main/chatbot.py ADDED
@@ -0,0 +1,652 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Any, Optional
2
+ from langchain_google_genai import ChatGoogleGenerativeAI
3
+ from langchain.prompts import PromptTemplate
4
+ from langchain.schema import HumanMessage, SystemMessage
5
+
6
+ from main.vector_store import QdrantVectorStore
7
+ from main.bm25_retriever import BM25Retriever
8
+ from main.reranker import DocumentReranker
9
+ from utils.text_processor import VietnameseTextProcessor
10
+ from utils.google_search import GoogleSearchTool
11
+ from utils.question_refiner import VietnameseLegalQuestionRefiner
12
+ from config import Config
13
+
14
class VietnameseLegalRAG:
    """Vietnamese Legal RAG System.

    Orchestrates hybrid retrieval (BM25 + Qdrant vector search), optional
    cross-encoder reranking, question refinement, and a Google-search
    fallback, then generates answers with a Gemini LLM. All failures are
    reported via print() and degrade gracefully rather than raising.
    """

    def __init__(self):
        # Heavy components are created in _initialize_components(); these stay
        # None if initialization fails so callers can probe availability.
        self.vector_store = None
        self.bm25_retriever = None
        self.reranker = None
        self.llm = None
        self.text_processor = VietnameseTextProcessor()
        self.google_search = GoogleSearchTool()
        self.question_refiner = VietnameseLegalQuestionRefiner()

        self._initialize_components()

    def _initialize_components(self):
        """Initialize RAG components (LLM, vector store, BM25, reranker)."""
        try:
            # Initialize LLM only when an API key is configured.
            if Config.GOOGLE_API_KEY:
                self.llm = ChatGoogleGenerativeAI(
                    model=Config.MODEL_GEN,
                    google_api_key=Config.GOOGLE_API_KEY,
                    temperature=0.1
                )
                print("Google Gemini LLM initialized")
            else:
                print("Warning: Google API key not found")

            # Initialize vector store
            self.vector_store = QdrantVectorStore()

            # Initialize BM25 retriever
            self.bm25_retriever = BM25Retriever()

            # Initialize reranker if enabled
            if Config.ENABLE_RERANKING:
                self.reranker = DocumentReranker(model_name=Config.RERANKER_MODEL)
            else:
                print("Reranking disabled in configuration")

        except Exception as e:
            # Swallowed deliberately: a partially-initialized system is still
            # usable (see get_system_status()).
            print(f"Error initializing RAG components: {e}")

    def setup_indices(self, documents: List[Dict[str, Any]], force_rebuild: bool = False):
        """Setup both vector and BM25 indices.

        Args:
            documents: Corpus of legal document dicts to index.
            force_rebuild: Recreate indices even if they already exist.

        Raises:
            Exception: re-raised after logging if index setup fails outright.
        """
        print("Setting up RAG indices...")

        try:
            # Setup vector store
            if self.vector_store:
                # Check if we need to create collection
                try:
                    # First, do a simple existence check against Qdrant.
                    collections = self.vector_store.client.get_collections().collections
                    collection_exists = any(col.name == self.vector_store.collection_name for col in collections)
                    print(f"Collection existence check: {collection_exists}")

                    if collection_exists:
                        # Collection exists, try to get detailed info
                        try:
                            collection_info = self.vector_store.get_collection_info()
                            has_documents = collection_info.get('points_count', 0) > 0

                            if force_rebuild:
                                print("Force rebuild requested - recreating vector store...")
                                self.vector_store.create_collection(force_recreate=True)
                                self.vector_store.add_documents(documents)
                            elif not has_documents:
                                print("Collection exists but is empty - adding documents...")
                                self.vector_store.add_documents(documents)
                            else:
                                print(f"Vector store collection already exists with {collection_info.get('points_count', 0)} documents")
                        except Exception as info_e:
                            # Info call failed; only rebuild when explicitly asked.
                            print(f"Could not get collection info: {info_e}")
                            if force_rebuild:
                                print("Force rebuild requested - recreating vector store...")
                                self.vector_store.create_collection(force_recreate=True)
                                self.vector_store.add_documents(documents)
                            else:
                                print("Assuming collection has documents - skipping setup")
                    else:
                        # Collection doesn't exist, create it
                        print("Collection does not exist - creating new collection...")
                        self.vector_store.create_collection(force_recreate=False)
                        self.vector_store.add_documents(documents)

                except Exception as e:
                    # Last-resort path: try a plain create + add.
                    print(f"Error during vector store setup: {e}")
                    print("Attempting to create collection...")
                    self.vector_store.create_collection()
                    self.vector_store.add_documents(documents)

            # Setup BM25 index
            if self.bm25_retriever:
                # Try to load existing index.
                # NOTE(review): load_index() runs (and may load a stale index)
                # even when force_rebuild is True; `force_rebuild or not load`
                # would short-circuit the wasted load — confirm intent.
                if not self.bm25_retriever.load_index() or force_rebuild:
                    self.bm25_retriever.build_index(documents)
                    self.bm25_retriever.save_index()
                else:
                    print("BM25 index loaded from file")

            print("RAG indices setup completed")

        except Exception as e:
            print(f"Error setting up indices: {e}")
            raise

    def retrieve_documents(self, query: str, use_hybrid: bool = True, use_reranking: bool = None) -> List[Dict[str, Any]]:
        """Retrieve relevant documents using hybrid approach with optional reranking.

        Args:
            query: Search query (ideally already refined).
            use_hybrid: Combine BM25 and vector search when both are available.
            use_reranking: Override Config.ENABLE_RERANKING; None = use config.

        Returns:
            Document dicts tagged with a 'retrieval_method' key; empty list on
            failure or no sufficiently similar documents.
        """
        retrieved_docs = []

        # Use config default if not specified
        if use_reranking is None:
            use_reranking = Config.ENABLE_RERANKING

        # Retrieve more candidates when a reranker will narrow them later.
        bm25_top_k = Config.RERANK_BEFORE_RETRIEVAL_TOP_K if use_reranking else Config.BM25_TOP_K
        vector_top_k = Config.RERANK_BEFORE_RETRIEVAL_TOP_K if use_reranking else Config.TOP_K_RETRIEVAL

        try:
            if use_hybrid and self.bm25_retriever and self.vector_store:
                # Hybrid retrieval: BM25 + Vector Search

                # BM25 retrieval
                bm25_results = self.bm25_retriever.get_relevant_documents(
                    query, top_k=bm25_top_k
                )

                # Vector search
                vector_results = self.vector_store.search_similar_documents(
                    query, top_k=vector_top_k
                )

                # Combine and deduplicate results, keyed by document id.
                all_docs = {}

                # Add BM25 results
                for doc in bm25_results:
                    doc_id = doc.get('id', '')
                    if doc_id:
                        all_docs[doc_id] = {**doc, 'retrieval_method': 'bm25'}

                # Add vector results
                for doc in vector_results:
                    doc_id = doc.get('id', '')
                    if doc_id:
                        if doc_id in all_docs:
                            # Document found by both methods: mark as hybrid and
                            # keep the vector score alongside.
                            all_docs[doc_id]['retrieval_method'] = 'hybrid'
                            all_docs[doc_id]['vector_score'] = doc.get('score', 0)
                        else:
                            all_docs[doc_id] = {**doc, 'retrieval_method': 'vector'}

                retrieved_docs = list(all_docs.values())

            elif self.vector_store:
                # Vector search only
                retrieved_docs = self.vector_store.search_similar_documents(query, top_k=vector_top_k)

            elif self.bm25_retriever:
                # BM25 only
                retrieved_docs = self.bm25_retriever.get_relevant_documents(query, top_k=bm25_top_k)

            # Similarity filtering: split candidates into quality tiers.
            # NOTE(review): every branch of this block returns, so whenever
            # retrieved_docs is non-empty the reranking section below is
            # unreachable — reranking only executes for an empty result set.
            # Looks unintended; confirm against the reranker config's purpose.
            if retrieved_docs:
                high_quality_docs = []
                moderate_quality_docs = []

                for doc in retrieved_docs:
                    score = doc.get('score', 0)
                    if score >= Config.SIMILARITY_THRESHOLD:
                        high_quality_docs.append(doc)
                    elif score >= Config.MIN_SIMILARITY_FOR_LEGAL_DOCS:
                        moderate_quality_docs.append(doc)

                # Return high quality docs if available
                if high_quality_docs:
                    print(f"Retrieved {len(high_quality_docs)} high-quality documents")
                    return high_quality_docs[:Config.TOP_K_RETRIEVAL]

                # Return moderate quality docs if no high quality ones
                elif moderate_quality_docs:
                    print(f"Retrieved {len(moderate_quality_docs)} moderate-quality documents")
                    return moderate_quality_docs[:Config.TOP_K_RETRIEVAL]

                else:
                    print("No documents found with sufficient similarity scores")
                    return []

            # Apply reranking if enabled (see NOTE above about reachability).
            if use_reranking and self.reranker and retrieved_docs:
                print(f"Applying reranking to {len(retrieved_docs)} documents...")

                if Config.USE_SCORE_FUSION:
                    # Blend retrieval and reranker scores.
                    retrieved_docs = self.reranker.rerank_with_fusion(
                        query,
                        retrieved_docs,
                        alpha=Config.RERANKER_FUSION_ALPHA,
                        top_k=Config.RERANKER_TOP_K
                    )
                else:
                    # Use pure reranker scores
                    retrieved_docs = self.reranker.rerank_documents(
                        query,
                        retrieved_docs,
                        top_k=Config.RERANKER_TOP_K
                    )
                print(f"Reranking completed, returning {len(retrieved_docs)} documents")
            else:
                # No reranking, limit results to original config
                retrieved_docs = retrieved_docs[:Config.TOP_K_RETRIEVAL]

            print(f"Retrieved {len(retrieved_docs)} documents for query")
            return retrieved_docs

        except Exception as e:
            print(f"Error retrieving documents: {e}")
            return []

    def format_context(self, documents: List[Dict[str, Any]]) -> str:
        """Format retrieved documents as a Vietnamese context string for the LLM."""
        if not documents:
            return "Không có tài liệu pháp luật liên quan được tìm thấy."

        context_parts = []

        for i, doc in enumerate(documents, 1):
            title = doc.get('title', 'Không có tiêu đề')
            content = doc.get('content', '')
            doc_id = doc.get('id', '')
            metadata = doc.get('metadata', {})
            law_id = metadata.get('law_id', '')
            article_id = metadata.get('article_id', '')

            # Truncate long documents to keep the prompt within budget.
            if len(content) > 500:
                content = content[:500] + "..."

            # Format law and article information; fall back to the raw doc id
            # when no article number is available.
            law_info = f"Luật: {law_id}" if law_id else ""
            article_info = f"Điều {article_id}" if article_id else f"ID: {doc_id}"

            context_part = f"""
Tài liệu {i}:
{law_info}
{article_info}: {title}
Nội dung: {content}
"""
            context_parts.append(context_part.strip())

        return "\n\n".join(context_parts)

    def generate_answer(self, query: str, context: str, is_fallback: bool = False) -> str:
        """Generate an answer with the LLM given retrieved context.

        Args:
            query: User question.
            context: Formatted document context or web-search context.
            is_fallback: Use the fallback prompt and append a legal disclaimer.
        """
        if not self.llm:
            return "Lỗi: Không thể kết nối với mô hình ngôn ngữ."

        try:
            # Select the prompt template for the scenario.
            if is_fallback:
                prompt = Config.FALLBACK_SYSTEM_PROMPT.format(
                    context=context,
                    question=query
                )
            else:
                prompt = Config.SYSTEM_PROMPT.format(
                    context=context,
                    question=query
                )

            # Generate response.
            # NOTE(review): calling the LLM object directly is deprecated in
            # newer LangChain versions in favor of .invoke(messages) — confirm
            # the pinned langchain version still supports __call__.
            messages = [HumanMessage(content=prompt)]
            response = self.llm(messages)

            answer = response.content

            # Append a legal disclaimer to web-fallback answers when enabled.
            if Config.SHOW_LEGAL_DISCLAIMER and is_fallback:
                disclaimer = "\n\nLưu ý quan trọng: Để đảm bảo quyền lợi của mình, người lao động nên tìm đến các chuyên gia pháp lý hoặc cơ quan chức năng có thẩm quyền để được tư vấn cụ thể và chính xác nhất dựa trên tình hình thực tế của mình."
                answer += disclaimer

            return answer

        except Exception as e:
            print(f"Error generating answer: {e}")
            return f"Lỗi khi tạo câu trả lời: {str(e)}"

    def _is_negative_response(self, response: str) -> bool:
        """Return True when the LLM answer is an 'unable to answer' response.

        Substring match (case-insensitive) against known Vietnamese phrases.
        """
        negative_indicators = [
            "không thể trả lời",
            "không tìm thấy",
            "không có thông tin",
            "xin lỗi",
            "không thể tìm thấy",
            "không có dữ liệu",
            "không rõ",
            "không biết",
            "không đủ thông tin",
            "thiếu thông tin",
            "không có trong",
            "ngoài phạm vi",
            # Specific long-form refusal patterns observed in practice.
            "không có đủ thông tin trong tài liệu tham khảo được cung cấp để trả lời trực tiếp câu hỏi này",
            "cần tham khảo thêm các văn bản pháp luật khác",
            "tìm kiếm thông tin chuyên sâu hơn về",
            "tài liệu tham khảo không chứa thông tin đầy đủ"
        ]

        response_lower = response.lower()
        return any(indicator in response_lower for indicator in negative_indicators)

    def _enhanced_search_fallback(self, query: str) -> List[Dict[str, Any]]:
        """Enhanced search fallback with multiple strategies.

        Strategy 1 widens top-k and lowers score thresholds for the hybrid
        search; Strategy 2 retries BM25 with extracted keywords only.
        """
        print("Attempting enhanced search fallback...")

        all_docs = []

        # Strategy 1: Lower similarity threshold
        try:
            if self.bm25_retriever and self.vector_store:
                # BM25 with more results
                bm25_results = self.bm25_retriever.get_relevant_documents(
                    query, top_k=15  # Increased from default
                )

                # Vector search with lower threshold
                vector_results = self.vector_store.search_similar_documents(
                    query, top_k=10  # Increased from default
                )

                # Combine results with very low threshold.
                # NOTE(review): BM25Retriever.get_relevant_documents returns
                # docs without a 'score' key, so doc.get('score', 0) >= 0.1
                # would drop every BM25 result here — verify the documents
                # actually carry a 'score' field.
                combined_docs = {}

                for doc in bm25_results:
                    doc_id = doc.get('id', '')
                    if doc_id and doc.get('score', 0) >= 0.1:  # Very low threshold
                        combined_docs[doc_id] = {**doc, 'retrieval_method': 'bm25_enhanced'}

                for doc in vector_results:
                    doc_id = doc.get('id', '')
                    if doc_id and doc.get('score', 0) >= 0.15:  # Very low threshold
                        if doc_id in combined_docs:
                            combined_docs[doc_id]['retrieval_method'] = 'hybrid_enhanced'
                        else:
                            combined_docs[doc_id] = {**doc, 'retrieval_method': 'vector_enhanced'}

                all_docs = list(combined_docs.values())

        except Exception as e:
            print(f"Error in enhanced search: {e}")

        # Strategy 2: Keyword-based search if still no results
        if not all_docs:
            try:
                # Extract keywords from query
                keywords = self.text_processor.extract_keywords(query)
                if keywords:
                    keyword_query = " ".join(keywords[:5])  # Use top 5 keywords

                    if self.bm25_retriever:
                        keyword_results = self.bm25_retriever.get_relevant_documents(
                            keyword_query, top_k=10
                        )
                        all_docs.extend([{**doc, 'retrieval_method': 'keyword_fallback'}
                                         for doc in keyword_results if doc.get('score', 0) >= 0.05])

            except Exception as e:
                print(f"Error in keyword fallback: {e}")

        print(f"Enhanced search found {len(all_docs)} documents")
        return all_docs[:Config.TOP_K_RETRIEVAL]

    def answer_question(self, query: str, use_fallback: bool = True, refine_question: bool = True) -> Dict[str, Any]:
        """Answer a legal question end-to-end.

        Pipeline: refine question -> hybrid retrieval -> (on miss) Google /
        enhanced fallback -> generate -> (on refusal) enhanced search, then
        web search as last resort.

        Returns:
            Dict with 'answer', 'retrieved_documents', 'fallback_used',
            'context', 'search_results', 'search_results_html', refinement
            metadata, and (on most paths) 'enhanced_search_used' /
            'search_triggered' flags.
        """
        print(f"Processing question: {query}")

        # Step 1: Refine the question if enabled; keep the original for the
        # response metadata.
        original_query = query
        refinement_result = None

        if refine_question and Config.ENABLE_QUESTION_REFINEMENT and self.question_refiner:
            print("🔧 Refining question for better search accuracy...")
            refinement_result = self.question_refiner.refine_question(query, use_llm=Config.USE_LLM_FOR_REFINEMENT)

            if refinement_result["refined_question"] != query:
                refined_query = refinement_result["refined_question"]
                print(f"📝 Original: {query}")
                print(f"✨ Refined: {refined_query}")
                query = refined_query

        # Step 2: Retrieve relevant documents using the (possibly refined) query.
        retrieved_docs = self.retrieve_documents(query)

        # No documents: try web fallback first when allowed.
        if not retrieved_docs and Config.ENABLE_GOOGLE_SEARCH and use_fallback:
            print("No relevant legal documents found, using Google search fallback")

            # Use Google search as fallback
            search_results = self.google_search.search_legal_info(query)

            if search_results:
                fallback_context = self.google_search.format_search_results(search_results)

                # Generate answer with fallback context
                fallback_answer = self.generate_answer(query, fallback_context, True)

                return {
                    'answer': fallback_answer,
                    'retrieved_documents': [],
                    'fallback_used': True,
                    'search_results': search_results,
                    'context': fallback_context,
                    'search_results_html': self.google_search.format_search_results_for_display(search_results),
                    'original_question': original_query,
                    'refined_question': query,
                    'question_refinement': refinement_result
                }
            else:
                # Web search returned nothing: try enhanced DB search before giving up.
                enhanced_docs = self._enhanced_search_fallback(query)
                if enhanced_docs:
                    enhanced_context = self.format_context(enhanced_docs)
                    enhanced_answer = self.generate_answer(query, enhanced_context, False)

                    return {
                        'answer': enhanced_answer,
                        'retrieved_documents': enhanced_docs,
                        'fallback_used': True,
                        'context': enhanced_context,
                        'search_results': [],
                        'search_results_html': "",
                        'enhanced_search_used': True,
                        'original_question': original_query,
                        'refined_question': query,
                        'question_refinement': refinement_result
                    }
                else:
                    return {
                        'answer': "Xin lỗi, tôi không tìm thấy thông tin pháp luật liên quan đến câu hỏi của bạn trong cơ sở dữ liệu nội bộ và cũng không thể tìm kiếm thông tin trên web. Vui lòng thử lại với câu hỏi khác hoặc liên hệ với chuyên gia pháp lý.",
                        'retrieved_documents': [],
                        'fallback_used': True,
                        'search_results': [],
                        'context': "",
                        'search_results_html': "",
                        'enhanced_search_used': False,
                        'original_question': original_query,
                        'refined_question': query,
                        'question_refinement': refinement_result
                    }
        elif not retrieved_docs:
            # No documents and web fallback disabled: enhanced search only.
            enhanced_docs = self._enhanced_search_fallback(query)
            if enhanced_docs:
                enhanced_context = self.format_context(enhanced_docs)
                enhanced_answer = self.generate_answer(query, enhanced_context, False)

                return {
                    'answer': enhanced_answer,
                    'retrieved_documents': enhanced_docs,
                    'fallback_used': False,
                    'context': enhanced_context,
                    'search_results': [],
                    'search_results_html': "",
                    'enhanced_search_used': True,
                    'original_question': original_query,
                    'refined_question': query,
                    'question_refinement': refinement_result
                }
            else:
                return {
                    'answer': "Xin lỗi, tôi không tìm thấy thông tin pháp luật liên quan đến câu hỏi của bạn trong cơ sở dữ liệu.",
                    'retrieved_documents': [],
                    'fallback_used': False,
                    'context': "",
                    'search_results': [],
                    'search_results_html': "",
                    'enhanced_search_used': False,
                    'original_question': original_query,
                    'refined_question': query,
                    'question_refinement': refinement_result
                }

        # Format context
        context = self.format_context(retrieved_docs)

        # Generate answer
        answer = self.generate_answer(query, context, False)

        # Check if the generated answer is a refusal and retry with enhanced search.
        if self._is_negative_response(answer) and use_fallback:
            print("🔍 Detected insufficient information response, activating search tools...")

            # NOTE(review): search_notification is assigned but never used —
            # either append it to the answer or remove it.
            search_notification = f"\n\n*🔍 Đang tìm kiếm thông tin bổ sung để trả lời câu hỏi của bạn...*"

            # Try enhanced search first
            enhanced_docs = self._enhanced_search_fallback(query)

            if enhanced_docs:
                # Merge original and enhanced docs, deduplicating by id.
                all_docs = retrieved_docs + [doc for doc in enhanced_docs
                                             if doc.get('id') not in [d.get('id') for d in retrieved_docs]]
                enhanced_context = self.format_context(all_docs[:Config.TOP_K_RETRIEVAL])
                enhanced_answer = self.generate_answer(query, enhanced_context, False)

                # If still a refusal, escalate to web search.
                if self._is_negative_response(enhanced_answer) and Config.ENABLE_GOOGLE_SEARCH:
                    print("📡 Enhanced search still insufficient, trying web search...")
                    search_results = self.google_search.search_legal_info(query)

                    if search_results:
                        fallback_context = self.google_search.format_search_results(search_results)
                        final_answer = self.generate_answer(query, fallback_context, True)

                        # Add notification about web search usage (if enabled)
                        if Config.SHOW_SOURCE_INFO:
                            final_answer = f"{final_answer}\n\n*🌐 Thông tin này được tìm kiếm từ web do không tìm thấy đủ thông tin trong cơ sở dữ liệu pháp luật nội bộ.*"

                        return {
                            'answer': final_answer,
                            'retrieved_documents': all_docs,
                            'fallback_used': True,
                            'context': enhanced_context,
                            'search_results': search_results,
                            'search_results_html': self.google_search.format_search_results_for_display(search_results),
                            'enhanced_search_used': True,
                            'search_triggered': True,
                            'original_question': original_query,
                            'refined_question': query,
                            'question_refinement': refinement_result
                        }
                    else:
                        # Web search failed, return enhanced answer with notification (if enabled)
                        if Config.SHOW_SOURCE_INFO:
                            enhanced_answer = f"{enhanced_answer}\n\n*🔍 Đã sử dụng tìm kiếm nâng cao trong cơ sở dữ liệu.*"

                        return {
                            'answer': enhanced_answer,
                            'retrieved_documents': all_docs,
                            'fallback_used': False,
                            'context': enhanced_context,
                            'search_results': [],
                            'search_results_html': "",
                            'enhanced_search_used': True,
                            'search_triggered': True,
                            'original_question': original_query,
                            'refined_question': query,
                            'question_refinement': refinement_result
                        }
                else:
                    # Enhanced search was sufficient (if enabled)
                    if Config.SHOW_SOURCE_INFO:
                        enhanced_answer = f"{enhanced_answer}\n\n*🔍 Đã sử dụng tìm kiếm nâng cao để tìm thông tin này.*"

                    return {
                        'answer': enhanced_answer,
                        'retrieved_documents': all_docs,
                        'fallback_used': False,
                        'context': enhanced_context,
                        'search_results': [],
                        'search_results_html': "",
                        'enhanced_search_used': True,
                        'search_triggered': True,
                        'original_question': original_query,
                        'refined_question': query,
                        'question_refinement': refinement_result
                    }
            elif Config.ENABLE_GOOGLE_SEARCH:
                # Try Google search as last resort
                print("📡 Database search failed, trying web search as last resort...")
                search_results = self.google_search.search_legal_info(query)

                if search_results:
                    fallback_context = self.google_search.format_search_results(search_results)
                    final_answer = self.generate_answer(query, fallback_context, True)

                    # Add notification about web search usage (if enabled)
                    if Config.SHOW_SOURCE_INFO:
                        final_answer = f"{final_answer}\n\n*🌐 Thông tin này được tìm kiếm từ web do không tìm thấy trong cơ sở dữ liệu pháp luật nội bộ.*"

                    return {
                        'answer': final_answer,
                        'retrieved_documents': retrieved_docs,
                        'fallback_used': True,
                        'context': context,
                        'search_results': search_results,
                        'search_results_html': self.google_search.format_search_results_for_display(search_results),
                        'enhanced_search_used': False,
                        'search_triggered': True,
                        'original_question': original_query,
                        'refined_question': query,
                        'question_refinement': refinement_result
                    }

        # Happy path: answer generated directly from retrieved documents.
        return {
            'answer': answer,
            'retrieved_documents': retrieved_docs,
            'fallback_used': False,
            'context': context,
            'search_results': [],
            'search_results_html': "",
            'enhanced_search_used': False,
            'search_triggered': False,
            'original_question': original_query,
            'refined_question': query,
            'question_refinement': refinement_result
        }

    def get_system_status(self) -> Dict[str, Any]:
        """Get availability/configuration status of all RAG components."""
        status = {
            'llm_available': self.llm is not None,
            'vector_store_available': self.vector_store is not None,
            'bm25_available': self.bm25_retriever is not None,
            'reranker_available': self.reranker is not None and self.reranker.model is not None,
            'reranking_enabled': Config.ENABLE_RERANKING,
            'google_api_configured': bool(Config.GOOGLE_API_KEY),
            'qdrant_configured': bool(Config.QDRANT_URL and Config.QDRANT_API_KEY)
        }

        # Get collection info if available
        if self.vector_store:
            try:
                status['vector_store_info'] = self.vector_store.get_collection_info()
            # NOTE(review): bare except also catches KeyboardInterrupt/SystemExit;
            # prefer `except Exception`.
            except:
                status['vector_store_info'] = {}

        # Get BM25 stats if available
        if self.bm25_retriever:
            status['bm25_stats'] = self.bm25_retriever.get_index_stats()

        # Get reranker info if available
        if self.reranker:
            status['reranker_info'] = self.reranker.get_model_info()

        return status
main/reranker.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import logging
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from sentence_transformers import CrossEncoder
5
+
6
class DocumentReranker:
    """Document reranker using cross-encoder models for improved relevance scoring."""

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        """
        Initialize the reranker with a cross-encoder model.

        Args:
            model_name: Name of the cross-encoder model to use.
                Default is a multilingual model that works well with Vietnamese.
        """
        self.model_name = model_name
        self.model = None
        self._initialize_model()

    def _initialize_model(self):
        """Load the cross-encoder; fall back to a lighter model, else leave model=None."""
        try:
            print(f"Loading reranker model: {self.model_name}")
            self.model = CrossEncoder(self.model_name)
            print("Reranker model loaded successfully")
        except Exception as e:
            print(f"Error loading reranker model: {e}")
            # Fallback to a lighter model
            try:
                fallback_model = "cross-encoder/ms-marco-TinyBERT-L-2-v2"
                print(f"Trying fallback model: {fallback_model}")
                self.model = CrossEncoder(fallback_model)
                self.model_name = fallback_model
                print("Fallback reranker model loaded successfully")
            except Exception as e2:
                print(f"Error loading fallback model: {e2}")
                # model=None makes every rerank call a no-op pass-through.
                self.model = None

    def rerank_documents(self,
                         query: str,
                         documents: List[Dict[str, Any]],
                         top_k: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Rerank documents based on relevance to the query.

        Args:
            query: The search query.
            documents: List of documents to rerank.
            top_k: Number of top documents to return (None for all).

        Returns:
            List of reranked documents with updated scores; the original
            list is returned unchanged if no model is loaded or on error.
        """
        if not self.model or not documents:
            return documents

        try:
            # Prepare query-document pairs for the cross-encoder
            pairs = []
            valid_docs = []

            for doc in documents:
                content = doc.get('content', '')
                title = doc.get('title', '')

                # Combine title and content for better matching
                doc_text = f"{title}. {content}" if title else content

                # Truncate very long documents to avoid model limits
                # (character-based cut; the model tokenizer truncates further).
                max_length = 512
                if len(doc_text) > max_length:
                    doc_text = doc_text[:max_length] + "..."

                pairs.append([query, doc_text])
                valid_docs.append(doc)

            if not pairs:
                return documents

            # Get relevance scores from cross-encoder
            scores = self.model.predict(pairs)

            # Update documents with reranker scores
            reranked_docs = []
            for doc, score in zip(valid_docs, scores):
                # Copy to avoid mutating the caller's documents.
                reranked_doc = doc.copy()

                # Keep both the original and the reranker score for fusion/debugging.
                reranked_doc['reranker_score'] = float(score)
                reranked_doc['original_score'] = doc.get('score', 0.0)

                # The reranker score becomes the main ranking score.
                reranked_doc['score'] = float(score)

                # Record that reranking was applied.
                if 'retrieval_method' in reranked_doc:
                    reranked_doc['retrieval_method'] += '_reranked'
                else:
                    reranked_doc['retrieval_method'] = 'reranked'

                reranked_docs.append(reranked_doc)

            # Sort by reranker score (descending)
            reranked_docs.sort(key=lambda x: x['reranker_score'], reverse=True)

            # `is not None` so an explicit top_k=0 is honored
            # (the old truthiness check silently returned all documents).
            if top_k is not None:
                reranked_docs = reranked_docs[:top_k]

            print(f"Reranked {len(reranked_docs)} documents")

            # Log top scores for debugging
            if reranked_docs:
                top_scores = [doc['reranker_score'] for doc in reranked_docs[:3]]
                print(f"Top reranker scores: {top_scores}")

            return reranked_docs

        except Exception as e:
            print(f"Error during reranking: {e}")
            # Return original documents if reranking fails
            return documents

    def rerank_with_fusion(self,
                           query: str,
                           documents: List[Dict[str, Any]],
                           alpha: float = 0.7,
                           top_k: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Rerank documents using score fusion between original and reranker scores.

        Args:
            query: The search query.
            documents: List of documents to rerank.
            alpha: Weight for reranker score (0-1); higher means more weight
                on the reranker relative to the original retrieval score.
            top_k: Number of top documents to return (None for all).

        Returns:
            List of reranked documents with fused scores; the original list
            is returned unchanged if no model is loaded or on error.
        """
        if not self.model or not documents:
            return documents

        try:
            # First get reranker scores for every document.
            reranked_docs = self.rerank_documents(query, documents, top_k=None)

            if not reranked_docs:
                return documents

            # Scale original scores into 0-1 by dividing by the max.
            # NOTE(review): assumes original scores are non-negative
            # (BM25 / cosine similarity) — confirm if other retrievers are added.
            original_scores = [doc.get('original_score', 0.0) for doc in reranked_docs]
            if max(original_scores) > 0:
                original_scores_norm = [s / max(original_scores) for s in original_scores]
            else:
                original_scores_norm = [0.0] * len(original_scores)

            # Min-max normalize reranker scores (cross-encoder logits can be negative).
            reranker_scores = [doc.get('reranker_score', 0.0) for doc in reranked_docs]
            min_reranker = min(reranker_scores)
            max_reranker = max(reranker_scores)

            if max_reranker > min_reranker:
                reranker_scores_norm = [(s - min_reranker) / (max_reranker - min_reranker)
                                        for s in reranker_scores]
            else:
                # All scores equal: neutral midpoint avoids a divide-by-zero.
                reranker_scores_norm = [0.5] * len(reranker_scores)

            # Convex combination of the two normalized scores.
            for i, doc in enumerate(reranked_docs):
                fused_score = (alpha * reranker_scores_norm[i] +
                               (1 - alpha) * original_scores_norm[i])
                doc['fused_score'] = fused_score
                doc['score'] = fused_score  # Update main score

            # Sort by fused score
            reranked_docs.sort(key=lambda x: x['fused_score'], reverse=True)

            # `is not None` so an explicit top_k=0 is honored.
            if top_k is not None:
                reranked_docs = reranked_docs[:top_k]

            print(f"Score fusion reranking completed for {len(reranked_docs)} documents")

            return reranked_docs

        except Exception as e:
            print(f"Error during fusion reranking: {e}")
            return documents

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model."""
        return {
            'model_name': self.model_name,
            'model_loaded': self.model is not None,
            'model_type': 'cross-encoder'
        }
main/vector_store.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ from typing import List, Dict, Any, Optional
4
+ from qdrant_client import QdrantClient
5
+ from qdrant_client.models import (
6
+ Distance,
7
+ VectorParams,
8
+ PointStruct,
9
+ Filter,
10
+ FieldCondition,
11
+ MatchValue,
12
+ )
13
+ from sentence_transformers import SentenceTransformer
14
+ import numpy as np
15
+ from config import Config
16
+ from tqdm import tqdm
17
+
18
+
19
+ class QdrantVectorStore:
20
+ """QDrant vector store for legal document embeddings"""
21
+
22
+ def __init__(self):
23
+ self.client = None
24
+ self.embedding_model = None
25
+ self.collection_name = Config.COLLECTION_NAME
26
+ self._initialize_client()
27
+ self._initialize_embedding_model()
28
+
29
+ def _initialize_client(self):
30
+ """Initialize QDrant client"""
31
+ try:
32
+ if Config.QDRANT_URL and Config.QDRANT_API_KEY:
33
+ # Cloud QDrant
34
+ self.client = QdrantClient(
35
+ url=Config.QDRANT_URL, api_key=Config.QDRANT_API_KEY
36
+ )
37
+ else:
38
+ # Local QDrant (fallback)
39
+ self.client = QdrantClient(host="localhost", port=6333)
40
+
41
+ print("QDrant client initialized successfully")
42
+ except Exception as e:
43
+ print(f"Error initializing QDrant client: {e}")
44
+ raise
45
+
46
+ def _initialize_embedding_model(self):
47
+ """Initialize embedding model"""
48
+ try:
49
+ # Clear any potentially corrupted cache
50
+ import tempfile
51
+ import shutil
52
+
53
+ cache_dir = os.path.join(tempfile.gettempdir(), "sentence_transformers")
54
+
55
+ self.embedding_model = SentenceTransformer(Config.EMBEDDING_MODEL)
56
+ print(f"Embedding model {Config.EMBEDDING_MODEL} loaded successfully")
57
+ except UnicodeDecodeError as e:
58
+ print(f"Encoding error loading embedding model: {e}")
59
+ print("Trying to clear sentence-transformers cache...")
60
+ try:
61
+ import tempfile
62
+ import shutil
63
+
64
+ cache_dir = os.path.join(tempfile.gettempdir(), "sentence_transformers")
65
+ if os.path.exists(cache_dir):
66
+ shutil.rmtree(cache_dir)
67
+ print("Cache cleared, retrying...")
68
+
69
+ self.embedding_model = SentenceTransformer(Config.EMBEDDING_MODEL)
70
+ print(
71
+ f"Embedding model {Config.EMBEDDING_MODEL} loaded successfully after cache clear"
72
+ )
73
+ except Exception as retry_e:
74
+ print(
75
+ f"Failed to load embedding model even after cache clear: {retry_e}"
76
+ )
77
+ raise
78
+ except Exception as e:
79
+ print(f"Error loading embedding model: {e}")
80
+ raise
81
+
82
+ def create_collection(self, vector_size: int = 384, force_recreate: bool = False):
83
+ """Create collection in QDrant"""
84
+ try:
85
+ # Check if collection exists
86
+ collections = self.client.get_collections().collections
87
+ collection_exists = any(
88
+ col.name == self.collection_name for col in collections
89
+ )
90
+ print(f"Collection exists: {collection_exists}")
91
+
92
+ if collection_exists:
93
+ if force_recreate:
94
+ print(f"Force recreating collection: {self.collection_name}")
95
+ self.client.delete_collection(self.collection_name)
96
+ print(f"Deleted existing collection: {self.collection_name}")
97
+ # Create collection
98
+ self.client.create_collection(
99
+ collection_name=self.collection_name,
100
+ vectors_config=VectorParams(
101
+ size=vector_size, distance=Distance.COSINE
102
+ ),
103
+ )
104
+ print(f"Successfully created collection: {self.collection_name}")
105
+ else:
106
+ print(
107
+ f"Collection {self.collection_name} already exists - skipping creation"
108
+ )
109
+ return
110
+ else:
111
+ print(
112
+ f"Collection {self.collection_name} does not exist - creating new collection"
113
+ )
114
+ # Create collection
115
+ self.client.create_collection(
116
+ collection_name=self.collection_name,
117
+ vectors_config=VectorParams(
118
+ size=vector_size, distance=Distance.COSINE
119
+ ),
120
+ )
121
+ print(f"Successfully created collection: {self.collection_name}")
122
+
123
+ except Exception as e:
124
+ print(f"Error creating collection: {e}")
125
+ raise
126
+
127
+ def embed_text(self, text: str) -> List[float]:
128
+ """Generate embedding for text"""
129
+ try:
130
+ embedding = self.embedding_model.encode(text, convert_to_tensor=False)
131
+ return embedding.tolist()
132
+ except Exception as e:
133
+ print(f"Error generating embedding: {e}")
134
+ return []
135
+
136
+ def add_documents(self, documents: List[Dict[str, Any]]):
137
+ """Add documents to vector store"""
138
+ try:
139
+ points = []
140
+
141
+ for doc in tqdm(documents):
142
+ # Generate embedding
143
+ content = doc.get("content", "")
144
+ if not content:
145
+ continue
146
+
147
+ embedding = self.embed_text(content)
148
+ if not embedding:
149
+ continue
150
+
151
+ # Create point
152
+ point = PointStruct(
153
+ id=str(uuid.uuid4()),
154
+ vector=embedding,
155
+ payload={
156
+ "article_id": doc.get("id", ""),
157
+ "title": doc.get("title", ""),
158
+ "content": content,
159
+ "metadata": doc.get("metadata", {}),
160
+ },
161
+ )
162
+ points.append(point)
163
+
164
+ # Batch upload
165
+ batch_size = 100
166
+ for i in range(0, len(points), batch_size):
167
+ batch = points[i : i + batch_size]
168
+ self.client.upsert(collection_name=self.collection_name, points=batch)
169
+ print(
170
+ f"Uploaded batch {i//batch_size + 1}/{(len(points) + batch_size - 1)//batch_size}"
171
+ )
172
+
173
+ print(f"Successfully added {len(points)} documents to vector store")
174
+
175
+ except Exception as e:
176
+ print(f"Error adding documents: {e}")
177
+ raise
178
+
179
+ def search_similar_documents(
180
+ self, query: str, top_k: int = None, score_threshold: float = None
181
+ ) -> List[Dict[str, Any]]:
182
+ """Search for similar documents"""
183
+ if top_k is None:
184
+ top_k = Config.TOP_K_RETRIEVAL
185
+ if score_threshold is None:
186
+ score_threshold = Config.SIMILARITY_THRESHOLD
187
+
188
+ try:
189
+ # Generate query embedding
190
+ query_embedding = self.embed_text(query)
191
+ if not query_embedding:
192
+ return []
193
+
194
+ # Search
195
+ search_results = self.client.search(
196
+ collection_name=self.collection_name,
197
+ query_vector=query_embedding,
198
+ limit=top_k,
199
+ score_threshold=score_threshold,
200
+ )
201
+
202
+ # Format results
203
+ results = []
204
+ for result in search_results:
205
+ results.append(
206
+ {
207
+ "id": result.payload.get("article_id", ""),
208
+ "title": result.payload.get("title", ""),
209
+ "content": result.payload.get("content", ""),
210
+ "score": result.score,
211
+ "metadata": result.payload.get("metadata", {}),
212
+ }
213
+ )
214
+
215
+ print(f"Found {len(results)} similar documents")
216
+ return results
217
+
218
+ except Exception as e:
219
+ print(f"Error searching documents: {e}")
220
+ return []
221
+
222
+ def get_collection_info(self) -> Dict[str, Any]:
223
+ """Get collection information"""
224
+ try:
225
+ info = self.client.get_collection(self.collection_name)
226
+ result = {
227
+ "name": self.collection_name, # Use the collection name we know
228
+ "vectors_count": info.vectors_count,
229
+ "indexed_vectors_count": info.indexed_vectors_count,
230
+ "points_count": info.points_count,
231
+ }
232
+ print(f"Collection info: {result}")
233
+ return result
234
+ except Exception as e:
235
+ print(f"Collection '{self.collection_name}' does not exist: {e}")
236
+ return {}
237
+
238
+ def delete_collection(self):
239
+ """Delete collection"""
240
+ try:
241
+ self.client.delete_collection(self.collection_name)
242
+ print(f"Deleted collection: {self.collection_name}")
243
+ except Exception as e:
244
+ print(f"Error deleting collection: {e}")