import gradio as gr
import utils
from datasets import load_dataset, concatenate_datasets
from langchain.docstore.document import Document as LangchainDocument
from tqdm import tqdm
import pickle
from ragatouille import RAGPretrainedModel
import chunker
import retriver
import rag
import nltk
import config
import os
import warnings
import sys
import logging

logging.getLogger("langchain").setLevel(logging.ERROR)
warnings.filterwarnings("ignore")
class AnswerSystem:
    def __init__(self, rag_system) -> None:
        self.rag_system = rag_system

    def answer_generate(self, question, bm_25_flag, semantic_flag, temperature):
        # Query the RAG backend and format the supporting documents for display.
        answer, relevant_docs = self.rag_system.answer(
            question=question,
            temperature=temperature,
            bm_25_flag=bm_25_flag,
            semantic_flag=semantic_flag,
            num_retrieved_docs=10,
            num_docs_final=5
        )
        formatted_docs = "\n\n".join([f"Document {i + 1}: {doc}" for i, doc in enumerate(relevant_docs)])
        return answer, formatted_docs
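
# The RAG backend handed to AnswerSystem is expected to expose an `answer` method with the
# shape used above: it takes the question, a temperature, the two retrieval flags, and
# `num_retrieved_docs` / `num_docs_final`, and returns an (answer_text, relevant_docs) pair.
# A minimal stand-in with that interface (hypothetical, useful only for wiring up the UI
# without the full retrieval stack) could look like:
#
#   class EchoRAG:
#       def answer(self, question, temperature, bm_25_flag, semantic_flag,
#                  num_retrieved_docs=10, num_docs_final=5):
#           return f"echo: {question}", []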
def run_app(rag_model):
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # RealTimeData Monthly Collection - BBC News Documentation Assistant
            Welcome! This system is designed to help you explore and find insights from the RealTimeData Monthly Collection - BBC News dataset.
            For example:
            - *"What position does Josko Gvardiol play, and how much did Manchester City pay for him?"*
            """
        )
        # Input fields
        question_input = gr.Textbox(label="Enter your question:",
                                    placeholder="E.g., What position does Josko Gvardiol play, and how much did Manchester City pay for him?")
        bm25_checkbox = gr.Checkbox(label="Enable BM25-based retrieval", value=True)  # BM25 flag
        semantic_checkbox = gr.Checkbox(label="Enable Semantic Search", value=True)  # Semantic flag
        temperature_slider = gr.Slider(label="Response Temperature", minimum=0.1, maximum=1.0, value=0.5,
                                       step=0.1)  # Temperature
        # Search button
        search_button = gr.Button("Search")
        # Output fields
        answer_output = gr.Textbox(label="Answer", interactive=False, lines=5)
        docs_output = gr.Textbox(label="Relevant Documents", interactive=False, lines=10)
        # Search logic
        system = AnswerSystem(rag_model)
        search_button.click(
            system.answer_generate,
            inputs=[question_input, bm25_checkbox, semantic_checkbox, temperature_slider],  # All parameters
            outputs=[answer_output, docs_output]
        )
    # Launch the app
    demo.launch(debug=True, share=True)
def get_rag_data():
    nltk.download('punkt')
    nltk.download('punkt_tab')
    if os.path.exists(config.DOCUMENTS_PATH):
        print(f"Loading preprocessed documents from {config.DOCUMENTS_PATH}")
        with open(config.DOCUMENTS_PATH, "rb") as file:
            docs_processed = pickle.load(file)
    else:
        print("Processing documents...")
        # Load each monthly configuration of the BBC News dataset and align their features.
        datasets_list = [
            utils.align_features(load_dataset("RealTimeData/bbc_news_alltime", dataset_config)["train"])
            for dataset_config in tqdm(config.AVAILABLE_DATASET_CONFIGS)
        ]
        ds = concatenate_datasets(datasets_list)
        RAW_KNOWLEDGE_BASE = [
            LangchainDocument(
                page_content=doc["content"],
                metadata={
                    "title": doc["title"],
                    "published_date": doc["published_date"],
                    "authors": doc["authors"],
                    "section": doc["section"],
                    "description": doc["description"],
                    "link": doc["link"]
                }
            )
            for doc in tqdm(ds)
        ]
        docs_processed = chunker.split_documents(512, RAW_KNOWLEDGE_BASE)
        print(f"Saving preprocessed documents to {config.DOCUMENTS_PATH}")
        with open(config.DOCUMENTS_PATH, "wb") as file:
            pickle.dump(docs_processed, file)
    return docs_processed
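
# `chunker.split_documents(512, RAW_KNOWLEDGE_BASE)` calls a local helper that is not part of
# this file. Its implementation is an assumption here; a sketch consistent with the call,
# built on LangChain's RecursiveCharacterTextSplitter, might be:
#
#   from langchain.text_splitter import RecursiveCharacterTextSplitter
#
#   def split_documents(chunk_size, knowledge_base):
#       splitter = RecursiveCharacterTextSplitter(
#           chunk_size=chunk_size,
#           chunk_overlap=chunk_size // 10,
#           add_start_index=True,
#       )
#       return splitter.split_documents(knowledge_base)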
if __name__ == '__main__':
    docs_processed = get_rag_data()
    bm25 = retriver.create_bm25(docs_processed)
    KNOWLEDGE_VECTOR_DATABASE = retriver.create_vector_db(docs_processed)
    RERANKER = RAGPretrainedModel.from_pretrained(config.CROSS_ENCODER_MODEL)
    rag_generator = rag.RAGAnswerGenerator(
        docs=docs_processed,
        bm25=bm25,
        knowledge_index=KNOWLEDGE_VECTOR_DATABASE,
        reranker=RERANKER
    )
    run_app(rag_generator)
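
# The local modules imported at the top (utils, chunker, retriver, rag, config) ship with the
# Space and are not shown here. In particular, `config` must define the three constants used
# above; placeholder values (assumptions, not the project's actual settings) might be:
#
#   DOCUMENTS_PATH = "docs_processed.pkl"               # cache file for the chunked corpus
#   AVAILABLE_DATASET_CONFIGS = ["2023-08", "2023-09"]  # monthly configs of RealTimeData/bbc_news_alltime
#   CROSS_ENCODER_MODEL = "colbert-ir/colbertv2.0"      # reranker loaded via RAGPretrainedModel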