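"""Streamlit chat app for asking questions about LTU programmes.

Answers are produced by a RAG (Retrieval-Augmented Generation) pipeline backed
by a local Qdrant vector store. Run with (assuming this file is saved as
app.py):

    streamlit run app.py
"""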
import logging
import os
from typing import Any, Dict, List

import streamlit as st
from haystack import Document

from data_processor import load_json_data, process_documents, split_documents
from rag_pipeline import RAGPipeline
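# The helpers imported from data_processor are defined elsewhere in this Space.
# From how they are used below, their signatures are assumed to be roughly:
#
#     load_json_data(path: str) -> Any            # parse the JSON file, or None on failure
#     process_documents(data) -> List[Document]   # build Haystack Documents with meta["url"]
#     split_documents(docs, chunk_size, overlap) -> List[Document]  # chunk document content
#
# (Sketch only; the actual implementations live in data_processor.py.)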
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
DATA_PATH = "ltu_programme_data.json"
QDRANT_PATH = "./qdrant_data"
EMBEDDING_MODEL = "BAAI/bge-en-icl"
LLM_MODEL = "meta-llama/Llama-3.3-70B-Instruct"
# Initialize session state
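# Streamlit reruns this script top-to-bottom on every user interaction, so the
# chat history must live in st.session_state to survive reruns.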
| if "messages" not in st.session_state: | |
| st.session_state.messages = [] | |
@st.cache_resource  # build the pipeline once and reuse it across reruns
def get_rag_pipeline():
    return RAGPipeline(
        embedding_model_name=EMBEDDING_MODEL,
        llm_model_name=LLM_MODEL,
        qdrant_path=QDRANT_PATH,
    )
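# RAGPipeline is implemented in rag_pipeline.py (not shown). Judging from its
# use in this file, its interface is assumed to look roughly like:
#
#     class RAGPipeline:
#         def __init__(self, embedding_model_name: str, llm_model_name: str,
#                      qdrant_path: str): ...
#         def index_documents(self, documents: List[Document]) -> None: ...
#         def get_document_count(self) -> int: ...
#         def query(self, question: str, top_k: int = 5) -> Dict[str, Any]:
#             """Returns {"answer": str, "documents": List[Document]}."""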
def load_and_index_documents(rag_pipeline: RAGPipeline) -> bool:
    """Load and index documents."""
    if not os.path.exists(DATA_PATH):
        st.error(f"Data file not found: {DATA_PATH}")
        return False

    with st.spinner("Loading and processing documents..."):
        # Load data
        data = load_json_data(DATA_PATH)
        if not data:
            st.error("Failed to load data")
            return False

        # Process documents
        processed_docs = process_documents(data)
        if not processed_docs:
            st.error("Failed to process documents")
            return False

        # Split documents
        chunked_docs = split_documents(processed_docs, chunk_size=1000, overlap=100)
        if not chunked_docs:
            st.error("Failed to split documents")
            return False

    # Index documents
    with st.spinner(f"Indexing {len(chunked_docs)} document chunks..."):
        rag_pipeline.index_documents(chunked_docs)

    return True
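# With chunk_size=1000 and overlap=100, consecutive chunks are assumed to share
# 100 characters of context: e.g. chunk 1 covers characters 0-999 of a document,
# chunk 2 covers 900-1899, chunk 3 covers 1800-2799, and so on. The overlap
# keeps sentences that straddle a chunk boundary retrievable from either side.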
def display_document_sources(documents: List[Document]):
    """Display the sources of the retrieved documents."""
    if documents:
        with st.expander("View Sources"):
            for i, doc in enumerate(documents):
                st.markdown(f"**Source {i+1}**: [{doc.meta.get('url', 'Unknown')}]({doc.meta.get('url', '#')})")
                st.markdown(f"**Excerpt**: {doc.content[:200]}...")
                st.markdown("---")
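# NOTE: check_documents_indexed() below is not called anywhere in this file;
# main() asks the pipeline itself via rag_pipeline.get_document_count(). It
# remains useful as a standalone way to inspect an existing local Qdrant store
# without constructing the full pipeline.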
def check_documents_indexed(qdrant_path: str) -> int:
    """Return the number of documents already indexed in the Qdrant store."""
    try:
        from haystack_integrations.document_stores.qdrant import QdrantDocumentStore

        # Open the existing document store without recreating the index
        document_store = QdrantDocumentStore(
            path=qdrant_path,
            embedding_dim=4096,  # BAAI/bge-en-icl produces 4096-dimensional embeddings
            recreate_index=False,
            index="ltu_documents",
        )
        return document_store.count_documents()
    except Exception:
        # If anything fails (e.g. the store has never been created), report 0
        return 0
def main():
    # Set page config
    st.set_page_config(
        page_title="LTU Chat - QA App",
        page_icon="π",
        layout="wide"
    )

    # Header
    st.title("π LTU Chat - QA App")
    st.markdown("""
    Ask questions about LTU programmes and get answers powered by AI.
    This app uses RAG (Retrieval Augmented Generation) to provide accurate information.
    """)
    rag_pipeline = get_rag_pipeline()

    # Sidebar
    with st.sidebar:
        st.header("Settings")
        # Check if documents are already indexed
        documents_indexed = rag_pipeline.get_document_count()
        if not documents_indexed:
            if st.button("Index Documents"):
                success = load_and_index_documents(rag_pipeline)
                if success:
                    st.success("Documents indexed successfully!")
                    # Refresh the indexed status and count
                    documents_indexed = True
                    count = rag_pipeline.get_document_count()
                    st.info(f"Indexed {count} document chunks in the vector store.")
        else:
            st.success(f"{documents_indexed} documents are indexed and ready!")
        top_k = st.slider("Number of documents to retrieve", min_value=1, max_value=10, value=5)

        # Work in progress
        st.title("Work in progress")
        st.toggle("Hybrid retrieval", disabled=True)
        st.toggle("Self RAG", disabled=True)
        st.toggle("Query Expansion", disabled=True)
        st.toggle("Graph RAG", disabled=True)
        st.toggle("Prompt engineering (CoT, Step-Back Prompt, Active Prompt)", disabled=True)
    # Display chat messages
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
            if message.get("documents"):
                display_document_sources(message["documents"])
    # Chat input
    if prompt := st.chat_input("Ask a question about LTU programmes"):
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Display user message
        with st.chat_message("user"):
            st.markdown(prompt)

        # Generate response
        if rag_pipeline and documents_indexed:
            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    # Query the RAG pipeline
                    result = rag_pipeline.query(prompt, top_k=top_k)

                    # Display the answer
                    st.markdown(result["answer"])

                    # Display sources
                    if result.get("documents"):
                        display_document_sources(result["documents"])

                    # Add assistant message to chat history
                    st.session_state.messages.append({
                        "role": "assistant",
                        "content": result["answer"],
                        "documents": result.get("documents", [])
                    })
        else:
            with st.chat_message("assistant"):
                if not rag_pipeline:
                    error_message = "Please initialize the RAG pipeline first."
                else:
                    error_message = "Please index documents first."
                st.error(error_message)
                st.session_state.messages.append({"role": "assistant", "content": error_message})
if __name__ == "__main__":
    main()