| """Template Demo for IBM Granite Hugging Face spaces.""" | |
| import os | |
| import time | |
| from pathlib import Path | |
| import re | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from gradio_pdf import PDF | |
| from sandbox.light_rag.light_rag import LightRAG | |
| from themes.research_monochrome import theme | |
| dir_ = Path(__file__).parent.parent | |
| TITLE = "Multimodal RAG with Granite Vision 3.2" | |
| DESCRIPTION = """ | |
| <p>This experimental demo highlights granite-vision-3.2-2b capabilities within a multimodal retrieval-augmented generation (RAG) pipeline, demonstrating Granite's document understanding in real-world applications. Explore the sample document excerpts and try the sample prompts or enter your own. Keep in mind that AI can occasionally make mistakes. | |
| <span class="gr_docs_link"> | |
| <a href="https://www.ibm.com/granite/docs/models/vision/">View Documentation <i class="fa fa-external-link"></i></a> | |
| </span> | |
| </p> | |
| """ | |
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

BASE_PATH = dir_ / "data" / "final_v2_mar04"
PDFS_PATH = BASE_PATH / "pdfs"
MILVUS_PATH = BASE_PATH / "milvus"
IMAGES_PATH = BASE_PATH / "images"
PREVIEWS_PATH = BASE_PATH / "preview"
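# Sample documents shown in the gallery. Each entry bundles a preview image, example prompts,
# the source PDF excerpt, the Milvus collection and DB used for retrieval, and where extracted
# figure images are stored.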
sample_data = [
    {
        "preview_image": str(PREVIEWS_PATH / "IBM-financial-2010.png"),
        "prompts": """Where geographically was the greatest growth in revenue in 2007?
Which year had the highest income in billion?
Did the net income decrease in 2007 compared to 2006?
Net cash from operations in 2005?
What does it mean to be a Globally Integrated Enterprise?
What are the segments for pretax income?""".split("\n"),
        "pdf": str(PDFS_PATH / "IBM_Annual_Report_2007_3-20.pdf"),
        "index": "ibm_report_2007_short_text_milvus_lite_2048_128_slate_278m_cosine",
        "db": str(MILVUS_PATH / "milvus.db"),
        "name": "IBM annual report 2007",
        "origin": "https://www.ibm.com/investor/att/pdf/IBM_Annual_Report_2007.pdf",
        "image_paths": {"prefix": str(IMAGES_PATH / "ibm_report_2007") + "/", "use_last": 2},
    },
    {
        "preview_image": str(PREVIEWS_PATH / "Wilhlborg-financial.png"),
        "prompts": """Where does Wihlborgs mainly operate?
Which year had the second lowest Equity/assets ratio?
Which year had the highest Project investments value?
What is the trend of equity/assets ratio?
What was the Growth percentage in income from property management in 2020?
Has the company’s interest coverage ratio increased or decreased in recent years?""".split("\n"),
        "pdf": str(PDFS_PATH / "wihlborgs-2-13_16-18.pdf"),
        "index": "wihlborgs_short_text_milvus_lite_2048_128_slate_278m_cosine",
        "db": str(MILVUS_PATH / "milvus.db"),
        "name": "Wihlborgs Report 2020",
        "origin": "https://www.wihlborgs.se/globalassets/investor-relations/rapporter/2021/20210401-wihlborgs-annual-report-and-sustainability-report-2020-c24a6b51-c124-44fc-a4af-4237a33a29fb.pdf",
        "image_paths": {"prefix": str(IMAGES_PATH / "wihlborgs") + "/", "use_last": 2},
    },
]
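# LightRAG configuration: the Granite embedding model used for dense retrieval and the Granite
# instruct model used to generate answers from the retrieved context.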
config = {
    "embedding_model_id": "ibm-granite/granite-embedding-278m-multilingual",
    "generation_model_id": "ibm-granite/granite-3.1-8b-instruct",
    "milvus_collection_name": "granite_vision_tech_report_text_milvus_lite_512_128_slate_125m_cosine",
    "milvus_db_path": str(dir_ / "data" / MILVUS_PATH / "milvus_text_sample.db"),
}
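# Build the RAG pipeline once (gr.NO_RELOAD skips this during Gradio hot-reload) and, unless
# LAZY_LOADING is enabled, pre-cache each sample's Milvus collection so the first query is fast.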
if gr.NO_RELOAD:
    light_rag: LightRAG = LightRAG(config)
    if os.environ.get("LAZY_LOADING") != "true":
        for sample in sample_data:
            light_rag.precache_milvus(sample["index"], sample["db"])
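# Demote markdown H1/H2 headers in retrieved text to H3 so they do not overpower the chat layout.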
def lower_md_headers(md: str) -> str:
    return re.sub(r'(?:^|\n)##?\s(.+)', lambda m: '\n### ' + m.group(1), md)
# Parser for retrieval results
def format_retrieval_result(i, d, cb, selected_sample):
    image_paths = sample_data[selected_sample]["image_paths"]
    if d.metadata["type"] == "text":
        context_string = f"---\n## Context {i + 1}\n#### (text extracted from document)\n{lower_md_headers(d.page_content)}\n"
        cb.append(gr.ChatMessage(role="assistant", content=context_string))
        return True
    elif d.metadata["type"] == "image_description":
        context_string = f"---\n## Context {i + 1}\n#### (image description generated by Granite Vision)"
        cb.append(gr.ChatMessage(role="assistant", content=context_string))
        # /dccstor/mm-rag/idanfr/granite_vision_demo/wdu_output/IBM_Annual_Report_2007/images/IBM_Annual_Report_2007_im_image_7_1.png
        image_path_parts = d.metadata["image_fullpath"].split("/")
        image_path = image_paths["prefix"] + ("/".join(image_path_parts[-image_paths["use_last"]:]))
        # print(f"image_path: {image_path}")
        cb.append(gr.ChatMessage(role="assistant", content=gr.Image(image_path)))
        cb.append(gr.ChatMessage(role="assistant", content=f"\n{lower_md_headers(d.metadata['image_description'])}\n"))
        return True
    return False
chatbot = gr.Chatbot(
    examples=[{"text": x} for x in sample_data[0]["prompts"]],
    type="messages",
    label=f"Q&A about {sample_data[0]['name']}",
    height=685,
    group_consecutive_messages=True,
    autoscroll=False,
    elem_classes=["chatbot_view"],
)
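# Answer generation. When the NO_LLM environment variable is set, the model call is stubbed out
# so the UI flow can be exercised without loading the generation model.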
def generate_with_llm(query, context):
    if os.environ.get("NO_LLM"):
        time.sleep(2)
        return "Now answer, just a string", query
    return light_rag.generate(query=query, context=context)


# TODO: maybe add GPU back?
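# Dense retrieval over the selected document's Milvus collection; returns the top three chunks.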
def retrieval(collection, db, q):
    return light_rag.search(q, top_n=3, collection=collection, db=db)


# ################
# User Interface
# ################
css_file_path = Path(Path(__file__).parent / "app.css")
head_file_path = Path(Path(__file__).parent / "app_head.html")

with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=theme, title=TITLE) as demo:
    is_in_edit_mode = gr.State(True)  # in block to be reactive
    selected_doc = gr.State(0)
    current_question = gr.State("")

    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # LEFT COLUMN: Sample selection, download, and PDF viewer
        with gr.Column():
            # Show preview images
            images_only = [sd["preview_image"] for sd in sample_data]
            document_gallery = gr.Gallery(
                images_only,
                label="Select a document",
                rows=1,
                columns=3,
                height="125px",
                # width="125px",
                allow_preview=False,
                selected_index=0,
                elem_classes=["preview_im_element"],
            )

            with gr.Group():
                pdf_display = PDF(
                    sample_data[0]["pdf"],
                    label=f"Preview for {sample_data[0]['name']}",
                    height=460,
                    interactive=False,
                    elem_classes=["pdf_viewer"],
                )
                dl_btn = gr.DownloadButton(
                    label=f"Download PDF ({sample_data[0]['name']})", value=sample_data[0]["pdf"], visible=True
                )

            def sample_image_selected(d: gr.SelectData):
                dx = sample_data[d.index]
                # print(f"DX:{dx}")
                return (
                    gr.update(examples=[{"text": x} for x in dx["prompts"]], label=f"Q&A about {dx['name']}"),
                    gr.update(value=dx["pdf"], label=f"Preview for {dx['name']}"),
                    gr.DownloadButton(value=dx["pdf"], label=f"Download PDF ({dx['name']})"),
                    d.index,
                )
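            # Selecting a gallery item clears the chat and swaps in the selected document's
            # example prompts, PDF preview, and download target.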
            document_gallery.select(lambda: [], outputs=[chatbot])
            document_gallery.select(
                sample_image_selected, inputs=[], outputs=[chatbot, pdf_display, dl_btn, selected_doc]
            )

        # Right Column: Chat interface
        with gr.Column():
            # Render ChatBot
            chatbot.render()

            # Define behavior for example selection
            def update_user_chat_x(x: gr.SelectData):
                return [gr.ChatMessage(role="user", content=x.value["text"])]

            def question_from_selection(x: gr.SelectData):
                return x.value["text"]

            def _decorate_yield_result(cb, fb_status=False, gallery_status=False):
                return cb, gr.Button(interactive=fb_status), gr.Gallery(
                    elem_classes=["preview_im_element"] if gallery_status else ["preview_im_element", "inactive_div"]
                )
            def send_generate(msg, cb, selected_sample):
                collection = sample_data[selected_sample]["index"]
                db = sample_data[selected_sample]["db"]
                # print(f"collection: {collection}, {db}")
                original_msg = gr.ChatMessage(role="user", content=msg)
                cb.append(original_msg)
                waiting_for_retrieval_msg = gr.ChatMessage(
                    role="assistant",
                    content='## Answer\n*Querying Index*<span class="jumping-dots"><span class="dot-1">.</span> <span class="dot-2">.</span> <span class="dot-3">.</span></span>',
                )
                cb.append(waiting_for_retrieval_msg)
                yield _decorate_yield_result(cb)

                q = msg.strip()
                results = retrieval(collection, db, q)
                # for d in results:
                #     print(f"****\n{d}")
                context_string = "## Context Documents for Answer\n\n"
                for i, d in enumerate(results):
                    if format_retrieval_result(i, d, cb, selected_sample):
                        yield _decorate_yield_result(cb)

                waiting_for_llm_msg = gr.ChatMessage(
                    role="assistant",
                    content='## Answer\n *Waiting for LLM* <span class="jumping-dots"><span class="dot-1">.</span> <span class="dot-2">.</span> <span class="dot-3">.</span></span> ',
                )
                cb[1] = waiting_for_llm_msg
                yield _decorate_yield_result(cb)

                answer, prompt = generate_with_llm(q, results)
                cb[1] = gr.ChatMessage(role="assistant", content=f"## Answer\n<b>{answer.strip()}</b>")
                # cb.pop()
                # cb.append(gr.ChatMessage(role="assistant", content=f"## Answer\n<b>{answer.strip()}</b>"))
                yield _decorate_yield_result(cb, fb_status=True, gallery_status=True)
            # Create User Chat Textbox and Reset Button
            tbb = gr.Textbox(submit_btn=True, show_label=False, placeholder="Type a message...")
            fb = gr.Button("Ask new question", visible=False)
            fb.click(lambda: [], outputs=[chatbot])
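            # Clicking an example prompt leaves edit mode, stores the question, and runs send_generate.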
            chatbot.example_select(lambda: False, outputs=is_in_edit_mode)
            # chatbot.example_select(update_user_chat_x, outputs=[chatbot])
            chatbot.example_select(
                question_from_selection, inputs=[], outputs=[current_question]
            ).then(
                send_generate,
                inputs=[current_question, chatbot, selected_doc],
                outputs=[chatbot, fb, document_gallery],
            )

            def textbox_switch(e_mode):  # Handle toggling between edit and non-edit mode
                if not e_mode:
                    return [gr.update(visible=False), gr.update(visible=True)]
                else:
                    return [gr.update(visible=True), gr.update(visible=False)]

            tbb.submit(lambda: False, outputs=[is_in_edit_mode])
            fb.click(lambda: True, outputs=[is_in_edit_mode])
            is_in_edit_mode.change(textbox_switch, inputs=[is_in_edit_mode], outputs=[tbb, fb])

            # submit user question
            # tbb.submit(lambda x: [gr.ChatMessage(role="user", content=x)], inputs=tbb, outputs=chatbot)
            tbb.submit(
                lambda x: x, inputs=[tbb], outputs=[current_question]
            ).then(
                send_generate,
                inputs=[current_question, chatbot, selected_doc],
                outputs=[chatbot, fb, document_gallery],
            )

if __name__ == "__main__":
    # demo.queue(max_size=20).launch()
    demo.launch()