import os import json import faiss import pickle import numpy as np from datetime import datetime from uuid import uuid4 from sentence_transformers import SentenceTransformer, CrossEncoder, util from sklearn.feature_extraction.text import TfidfVectorizer from rank_bm25 import BM25Okapi from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, BertTokenizer, BertForSequenceClassification import torch import torch.nn.functional as F class RetrievalScorer: """ Handles different retrieval scoring methods (FAISS, TF-IDF, BM25) and combines them. """ def __init__(self, texts, embedder, tfidf_vectorizer, tfidf_matrix, bm25, index): self.texts = texts self.embedder = embedder self.tfidf_vectorizer = tfidf_vectorizer self.tfidf_matrix = tfidf_matrix self.bm25 = bm25 self.index = index def faiss_score(self, query: str, top_k: int) -> list[tuple[str, float]]: """Scores documents based on FAISS (semantic similarity).""" query_vec = self.embedder.encode([query])[0] distances, indices = self.index.search(np.array([query_vec]), top_k) # Convert distances to scores (higher is better) scores = 1 - distances[0] / (np.max(distances[0]) + 1e-5) if np.max(distances[0]) > 0 else np.zeros_like(distances[0]) return [(self.texts[i], scores[j]) for j, i in enumerate(indices[0])] def tfidf_score(self, query: str, top_k: int) -> list[tuple[str, float]]: """Scores documents based on TF-IDF.""" query_vec = self.tfidf_vectorizer.transform([query]) scores = np.dot(query_vec, self.tfidf_matrix.T).toarray()[0] top_idx = np.argsort(scores)[-top_k:][::-1] return [(self.texts[i], scores[i]) for i in top_idx if scores[i] > 0] # Filter out zero scores def bm25_score(self, query: str, top_k: int) -> list[tuple[str, float]]: """Scores documents based on BM25.""" tokens = query.lower().split() # Lowercase for better matching scores = self.bm25.get_scores(tokens) top_idx = np.argsort(scores)[-top_k:][::-1] return [(self.texts[i], scores[i]) for i in top_idx if scores[i] > 0] # Filter out zero scores def hybrid_score(self, query: str, top_k: int) -> list[str]: """Combines scores from all methods and re-ranks for top_k documents.""" # Retrieve more candidates initially to allow for better re-ranking candidates_multiplier = 3 faiss_candidates = self.faiss_score(query, top_k * candidates_multiplier) tfidf_candidates = self.tfidf_score(query, top_k * candidates_multiplier) bm25_candidates = self.bm25_score(query, top_k * candidates_multiplier) # Aggregate scores by document score_map = {} for doc, score in faiss_candidates + tfidf_candidates + bm25_candidates: score_map[doc] = score_map.get(doc, 0) + score # Sort aggregated scores and return top_k documents sorted_docs = sorted(score_map.items(), key=lambda x: x[1], reverse=True) return [doc for doc, _ in sorted_docs[:top_k]] class RAGPipeline: """ Implements a RAG pipeline with hybrid retrieval, LLM generation, and integrated defense mechanisms against adversarial queries and hallucinations. """ def __init__( self, json_path: str = "calebdata.json", embedder_model: str = "infly/inf-retriever-v1-1.5b", reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2", generator_model: str = "google/flan-t5-base", defense_model_path: str = "./defense_model", # Path to saved BERT defense model cache_dir: str = "cache" ): os.makedirs(cache_dir, exist_ok=True) self.cache_dir = cache_dir self.chunks = self._load_chunks(json_path) self.texts = [chunk["text"] for chunk in self.chunks] # Load models print("Loading embedder model...") self.embedder = SentenceTransformer(embedder_model) print("Loading reranker model...") self.reranker = CrossEncoder(reranker_model) print("Loading generator tokenizer...") self.tokenizer = AutoTokenizer.from_pretrained(generator_model) print("Loading generator model...") self.generator = AutoModelForSeq2SeqLM.from_pretrained(generator_model) self.embeddings = self._load_or_compute_embeddings() self.index = self._load_or_build_faiss_index() self.tfidf_vectorizer, self.tfidf_matrix = self._build_tfidf() # BM25 requires tokenized corpus for initialization self.bm25_corpus = [text.lower().split() for text in self.texts] self.bm25 = BM25Okapi(self.bm25_corpus) self.retriever = RetrievalScorer( self.texts, self.embedder, self.tfidf_vectorizer, self.tfidf_matrix, self.bm25, self.index ) # Initialize defense components self.defense_tokenizer = None self.defense_model = None self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if os.path.exists(defense_model_path) and os.path.isdir(defense_model_path): print(f"Loading defense model from {defense_model_path}...") self.defense_tokenizer = BertTokenizer.from_pretrained(defense_model_path) self.defense_model = BertForSequenceClassification.from_pretrained(defense_model_path) self.defense_model.to(self.device) self.defense_model.eval() print("Defense model loaded successfully.") else: print(f"Warning: Defense model not found at {defense_model_path}. Running without adversarial query detection.") print("Please run `python defense.py` to train and save the model.") def _load_chunks(self, path: str) -> list[dict]: """Loads and preprocesses text chunks from a JSON file.""" with open(path, "r") as f: raw = json.load(f) seen_texts = set() filtered_chunks = [] for item in raw: text = (item.get("text") or "").strip().replace("\n", " ") # Filter out short or duplicate texts if len(text) < 30 or text in seen_texts: continue seen_texts.add(text) filtered_chunks.append({"id": str(uuid4()), "text": text, "metadata": item.get("metadata", {})}) print(f"Loaded {len(filtered_chunks)} chunks from {path}") return filtered_chunks def _load_or_compute_embeddings(self) -> np.ndarray: """Loads embeddings from cache or computes and saves them.""" path = os.path.join(self.cache_dir, "embeddings.pkl") if os.path.exists(path): print("Loading embeddings from cache...") with open(path, "rb") as f: return pickle.load(f) print("Computing embeddings (this may take a while)...") # Ensure texts are strings for embedding embeddings = self.embedder.encode(self.texts, convert_to_numpy=True) with open(path, "wb") as f: pickle.dump(embeddings, f) print("Embeddings computed and saved.") return embeddings def _load_or_build_faiss_index(self) -> faiss.Index: """Loads FAISS index from cache or builds and saves it.""" path = os.path.join(self.cache_dir, "faiss.index") dimension = self.embeddings.shape[1] if os.path.exists(path): print("Loading FAISS index from cache...") return faiss.read_index(path) print("Building FAISS index...") index = faiss.IndexFlatL2(dimension) # Using L2 distance index.add(self.embeddings) faiss.write_index(index, path) print("FAISS index built and saved.") return index def _build_tfidf(self) -> tuple[TfidfVectorizer, np.ndarray]: """Builds TF-IDF vectorizer and matrix.""" print("Building TF-IDF model...") vectorizer = TfidfVectorizer() matrix = vectorizer.fit_transform(self.texts) print("TF-IDF model built.") return vectorizer, matrix def _rerank(self, query: str, docs: list[str], min_score: float = 0.1) -> list[str]: """Re-ranks retrieved documents using a cross-encoder.""" if not docs: return [] pairs = [(query, doc) for doc in docs] scores = self.reranker.predict(pairs) scored_docs = sorted(zip(scores, docs), reverse=True) return [doc for score, doc in scored_docs if score > min_score] def hybrid_search(self, query: str, top_k: int = 5) -> list[str]: """Performs hybrid search and re-ranks results.""" candidates = self.retriever.hybrid_score(query, top_k=top_k * 3) # Get more candidates for reranking reranked = self._rerank(query, candidates) self._log_retrieval(query, reranked[:top_k]) return reranked[:top_k] def _log_retrieval(self, query: str, docs: list[str]): """Logs retrieval events for analysis.""" log = { "timestamp": datetime.now().isoformat(), "query": query, "retrieved_docs": docs } with open(os.path.join(self.cache_dir, "retrieval_log.jsonl"), "a") as f: f.write(json.dumps(log) + "\n") def build_context_window(self, query: str, max_tokens: int = 450, add_poisoned_doc: str = None) -> str: """ Builds the context window from retrieved documents. Includes a conceptual flag for simulating data poisoning. """ passages = self.hybrid_search(query, top_k=10) # --- Conceptual Data Poisoning Simulation --- # If add_poisoned_doc is provided, inject it at the top of the context # This simulates a successful poisoning where a malicious document # is highly ranked and retrieved. if add_poisoned_doc: passages.insert(0, add_poisoned_doc) # ------------------------------------------- context = "" total_tokens = 0 for passage in passages: tokens = self.tokenizer.tokenize(passage) if total_tokens + len(tokens) > max_tokens: break context += passage + "\n" total_tokens += len(tokens) return context.strip() def _is_query_adversarial(self, query: str, threshold: float = 0.7) -> bool: """ Detects if a query is adversarial using the trained defense model. Returns True if adversarial, False otherwise. """ if not self.defense_model or not self.defense_tokenizer: return False # No defense model loaded, bypass check encoding = self.defense_tokenizer.encode_plus( query, add_special_tokens=True, truncation=True, max_length=128, # Match max_len used in training padding='max_length', return_attention_mask=True, return_tensors='pt' ) input_ids = encoding['input_ids'].to(self.device) attention_mask = encoding['attention_mask'].to(self.device) with torch.no_grad(): outputs = self.defense_model(input_ids=input_ids, attention_mask=attention_mask) logits = outputs.logits probabilities = F.softmax(logits, dim=1) # Assuming label 1 is adversarial adversarial_prob = probabilities[0][1].item() return adversarial_prob >= threshold def _check_groundedness_and_hallucination(self, generated_answer: str, context: str, min_overlap_ratio: float = 0.3) -> bool: """ Basic check for groundedness/hallucination: Checks if a significant portion of the generated answer's key phrases are present in the provided context. This is a heuristic. A more advanced approach would use semantic similarity or entailment. """ if not context: return True # If no context, can't check groundedness, assume it's okay or handle as ungrounded # Simple keyword overlap check context_words = set(context.lower().split()) answer_words = set(generated_answer.lower().split()) common_words = context_words.intersection(answer_words) # Filter out very common stopwords from the overlap check for better signal stopwords = set(["a", "an", "the", "is", "are", "was", "were", "and", "or", "in", "on", "at", "for", "with", "from", "to", "of", "about", "this", "that", "it", "its"]) meaningful_common_words = [word for word in common_words if word not in stopwords] if len(answer_words) == 0: return True # Empty answer overlap_ratio = len(meaningful_common_words) / (len(answer_words) - len(answer_words.intersection(stopwords))) if (len(answer_words) - len(answer_words.intersection(stopwords))) > 0 else 0 # Semantic similarity check (complementary to keyword overlap) # Lower sim implies less groundedness semantic_scorer = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L3-v2") context_embedding = semantic_scorer.encode(context, convert_to_tensor=True) answer_embedding = semantic_scorer.encode(generated_answer, convert_to_tensor=True) semantic_similarity = util.pytorch_cos_sim(context_embedding, answer_embedding).item() # You can adjust these thresholds based on empirical testing is_grounded_by_keywords = overlap_ratio >= min_overlap_ratio is_grounded_by_semantic_sim = semantic_similarity >= 0.6 # Example threshold # Consider it not hallucinated if either keyword overlap OR semantic similarity is good # Or, for stricter groundedness, require BOTH to be good. return is_grounded_by_keywords or is_grounded_by_semantic_sim def generate_answer(self, query: str, add_poisoned_doc: str = None) -> dict: """ Generates an answer to the query using the RAG pipeline, with defenses. Returns a dictionary containing the answer, sources, and defense flags. """ is_adversarial_query = self._is_query_adversarial(query) if is_adversarial_query: return { "answer": "I cannot process this request due to potential security concerns. Please rephrase your query.", "sources": [], "defense_triggered": True, "hallucinated": False, "reason": "Adversarial Query Detected" } context = self.build_context_window(query, add_poisoned_doc=add_poisoned_doc) if not context: return { "answer": "I couldn't find enough relevant information in my knowledge base to answer that question.", "sources": [], "defense_triggered": True, "hallucinated": False, "reason": "No Relevant Context Found" } prompt = f"Answer the question based on the context.\n\nContext:\n{context}\n\nQuestion: {query}\nAnswer:" inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, padding=True, max_length=512) # Ensure inputs are on the correct device inputs = {k: v.to(self.device) for k, v in inputs.items()} output = self.generator.generate( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=300, do_sample=True, top_p=0.95, top_k=50, pad_token_id=self.tokenizer.eos_token_id ) answer = self.tokenizer.decode(output[0], skip_special_tokens=True) # Basic hallucination check is_hallucinated = not self._check_groundedness_and_hallucination(answer, context) retrieved_docs = self.hybrid_search(query) # Re-run to get actual retrieved docs without poisoned_doc in logs return { "answer": answer, "sources": retrieved_docs, "defense_triggered": False, # No query defense triggered "hallucinated": is_hallucinated, "reason": "Hallucination Detected" if is_hallucinated else "Normal Operation" } def generate_answer_with_sources(self, query: str, add_poisoned_doc: str = None) -> dict: """ Generates an answer and provides the sources for transparency. This method will leverage the defense mechanisms of generate_answer. """ result = self.generate_answer(query, add_poisoned_doc=add_poisoned_doc) return result