| """ | |
| Visual-CoT: Chain-of-Thought Reasoning Demo on Hugging Face Spaces | |
| Showcasing Visual Chain-of-Thought with Interactive Benchmark Examples | |
| Paper: Visual CoT: Advancing Multi-Modal Language Models with a Comprehensive | |
| Dataset and Benchmark for Chain-of-Thought Reasoning | |
| https://arxiv.org/abs/2403.16999 | |
| """ | |
| import os | |
| import torch | |
| import gradio as gr | |
| from PIL import Image, ImageDraw, ImageFont | |
| import re | |
| import json | |
| import spaces | |
| from pathlib import Path | |
| import requests | |
| from io import BytesIO | |
| from llava.constants import ( | |
| IMAGE_TOKEN_INDEX, | |
| DEFAULT_IMAGE_TOKEN, | |
| DEFAULT_IM_START_TOKEN, | |
| DEFAULT_IM_END_TOKEN, | |
| ) | |
| from llava.conversation import conv_templates | |
| from llava.model.builder import load_pretrained_model | |
| from llava.utils import disable_torch_init | |
| from llava.mm_utils import ( | |
| process_images, | |
| tokenizer_image_token, | |
| get_model_name_from_path, | |
| ) | |
# =============================================================================
# Configuration
# =============================================================================
MODEL_PATH = "deepcs233/VisCoT-7b-336"  # Hugging Face model ID
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Benchmark datasets available
BENCHMARK_DATASETS = [
    "docvqa",
    "flickr30k",
    "gqa",
    "infographicsvqa",
    "openimages",
    "textcap",
    "textvqa",
    "vsr",
    "cub",
]

# Global model variables (lazy loading)
tokenizer, model, image_processor, context_len = None, None, None, None
# =============================================================================
# Model Loading (with Zero GPU optimization)
# =============================================================================
def load_model_once():
    """Load the model once and cache it in module-level globals."""
    global tokenizer, model, image_processor, context_len
    if model is not None:
        return tokenizer, model, image_processor, context_len

    print("🔄 Loading Visual-CoT model...")
    disable_torch_init()
    model_name = get_model_name_from_path(MODEL_PATH)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        MODEL_PATH,
        None,
        model_name,
        load_8bit=False,
        load_4bit=False,
        device=DEVICE,
    )
    print("✓ Model loaded successfully!")
    return tokenizer, model, image_processor, context_len
# =============================================================================
# Utility Functions
# =============================================================================
def parse_bbox(text):
    """Parse a normalized bounding box [x1, y1, x2, y2] from the model output."""
    pattern1 = r"###\[([\d\.]+),\s*([\d\.]+),\s*([\d\.]+),\s*([\d\.]+)\]"
    pattern2 = r"\[([\d\.]+),\s*([\d\.]+),\s*([\d\.]+),\s*([\d\.]+)\]"
    matches = re.findall(pattern1, text)
    if not matches:
        matches = re.findall(pattern2, text)
    if matches:
        bbox = [float(x) for x in matches[-1]]
        if all(0 <= x <= 1 for x in bbox):
            return bbox
    return None
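
# Quick illustration of the parser on a typical step-1 response. The phrasing
# below is assumed for illustration only; the model's actual wording may differ,
# which is why both the "###[...]" and plain "[...]" patterns are tried.
#
#   >>> parse_bbox("The key region is ###[0.12, 0.34, 0.56, 0.78]")
#   [0.12, 0.34, 0.56, 0.78]
#   >>> parse_bbox("No coordinates in this response.") is None
#   True
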
def draw_bounding_box(image, bbox, color="red", width=5):
    """Draw a bounding box (normalized coordinates) on a copy of the image."""
    if bbox is None:
        return image

    img = image.copy()
    draw = ImageDraw.Draw(img)
    img_width, img_height = img.size

    # Convert normalized to pixel coordinates
    x1 = int(bbox[0] * img_width)
    y1 = int(bbox[1] * img_height)
    x2 = int(bbox[2] * img_width)
    y2 = int(bbox[3] * img_height)

    # Draw rectangle
    draw.rectangle([x1, y1, x2, y2], outline=color, width=width)

    # Draw label
    label = f"ROI: [{bbox[0]:.3f}, {bbox[1]:.3f}, {bbox[2]:.3f}, {bbox[3]:.3f}]"
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 14)
    except Exception:
        font = ImageFont.load_default()

    # Text background
    bbox_text = draw.textbbox((x1, y1 - 22), label, font=font)
    draw.rectangle([bbox_text[0] - 2, bbox_text[1] - 2, bbox_text[2] + 2, bbox_text[3] + 2], fill=color)
    draw.text((x1, y1 - 22), label, fill="white", font=font)
    return img
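
# Usage sketch (illustrative): coordinates are normalized to [0, 1], so the call
# below marks roughly the upper-left quadrant of a blank test image. The output
# has the same size as the input; only a copy is annotated.
#
#   >>> test_img = Image.new("RGB", (640, 480), "gray")
#   >>> boxed = draw_bounding_box(test_img, [0.05, 0.05, 0.50, 0.50])
#   >>> boxed.size
#   (640, 480)
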
def load_benchmark_examples(dataset_name, num_examples=5):
    """
    Load examples from a benchmark dataset.

    Returns a list of dicts with keys: 'image', 'question', 'gt_bbox',
    'gt_answer', and 'dataset'.
    """
    benchmark_file = f"viscot_benchmark/benchmark/{dataset_name}.json"
    if not os.path.exists(benchmark_file):
        return []

    try:
        with open(benchmark_file, 'r') as f:
            data = json.load(f)
        examples = []
        for item in data[:num_examples]:
            # Extract information based on dataset structure
            image_file = item.get('image', '')
            question = item['conversations'][0]['value'].replace('<image>\n', '').split('Please provide')[0].strip()
            gt_bbox_str = item['conversations'][1]['value'] if len(item['conversations']) > 1 else None
            gt_answer = item['conversations'][3]['value'] if len(item['conversations']) > 3 else None
            examples.append({
                'image': image_file,
                'question': question,
                'gt_bbox': gt_bbox_str,
                'gt_answer': gt_answer,
                'dataset': dataset_name,
            })
        return examples
    except Exception as e:
        print(f"Error loading {dataset_name}: {e}")
        return []
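
# The loader above assumes each benchmark JSON item follows a LLaVA-style
# conversation layout. Sketch of the assumed structure (the field names match
# the accesses in the code above; the concrete values are invented for
# illustration):
#
#   {
#       "image": "gqa/images/2375429.jpg",
#       "conversations": [
#           {"from": "human", "value": "<image>\nWhat color is the cup? Please provide ..."},
#           {"from": "gpt",   "value": "[0.42, 0.31, 0.58, 0.49]"},
#           {"from": "human", "value": "Please answer the question based on the original image ..."},
#           {"from": "gpt",   "value": "The cup is red."}
#       ]
#   }
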
# =============================================================================
# Main Inference Function (with @spaces.GPU decorator)
# =============================================================================
@spaces.GPU(duration=120)  # Zero GPU allocation for 120 seconds
def generate_viscot_response(image, question, temperature=0.2, max_tokens=512):
    """
    Generate a Visual-CoT response with bounding-box detection.

    Args:
        image: PIL Image
        question: str
        temperature: float
        max_tokens: int

    Returns:
        tuple: (bbox_response, final_answer, image_with_bbox, processing_info)
    """
    if image is None:
        return "❌ Please upload an image!", "", None, ""
    if not question.strip():
        return "❌ Please enter a question!", "", None, ""

    try:
        # Load model (lazy loading)
        tokenizer, model, image_processor, context_len = load_model_once()

        # Initialize conversation
        conv_mode = "llava_v1"
        conv = conv_templates[conv_mode].copy()

        # =====================================================================
        # STEP 1: Detect Region of Interest (ROI)
        # =====================================================================
        prompt_step1 = (
            f"{DEFAULT_IMAGE_TOKEN}\n{question} "
            f"Please provide the bounding box coordinate of the region this question asks about."
        )
        conv.append_message(conv.roles[0], prompt_step1)
        conv.append_message(conv.roles[1], None)
        prompt1 = conv.get_prompt()

        # Process image
        image_tensor = process_images([image], image_processor, model.config)
        if isinstance(image_tensor, list):
            image_tensor = [img.to(DEVICE, dtype=torch.bfloat16) for img in image_tensor]
        else:
            image_tensor = image_tensor.to(DEVICE, dtype=torch.bfloat16)

        # Tokenize
        input_ids = tokenizer_image_token(
            prompt1, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
        ).unsqueeze(0).to(DEVICE)

        # Generate bbox
        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor,
                do_sample=temperature > 0.001,
                temperature=max(temperature, 0.01),
                max_new_tokens=128,
                use_cache=True,
            )
        bbox_response = tokenizer.decode(
            output_ids[0, input_ids.shape[1]:], skip_special_tokens=True
        ).strip()

        # Parse bbox
        bbox = parse_bbox(bbox_response)

        # =====================================================================
        # STEP 2: Answer Question with ROI Context
        # =====================================================================
        conv.messages[-1][-1] = bbox_response
        second_question = (
            f"Please answer the question based on the original image and local detail image. {question}"
        )
        conv.append_message(conv.roles[0], second_question)
        conv.append_message(conv.roles[1], None)
        prompt2 = conv.get_prompt()

        input_ids = tokenizer_image_token(
            prompt2, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
        ).unsqueeze(0).to(DEVICE)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor,
                do_sample=temperature > 0.001,
                temperature=max(temperature, 0.01),
                max_new_tokens=max_tokens,
                use_cache=True,
            )
        final_answer = tokenizer.decode(
            output_ids[0, input_ids.shape[1]:], skip_special_tokens=True
        ).strip()

        # Visualization
        image_with_bbox = draw_bounding_box(image, bbox) if bbox else image

        # Processing info
        processing_info = f"✓ Processed successfully | Bbox: {bbox if bbox else 'Not detected'}"

        return bbox_response, final_answer, image_with_bbox, processing_info

    except Exception as e:
        import traceback
        error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
        return error_msg, "", None, error_msg
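
# Minimal manual smoke test (not executed by the Space itself; assumes the
# bundled example image exists and a GPU is available for the generate calls):
#
#   >>> img = Image.open("examples/extreme_ironing.jpg").convert("RGB")
#   >>> bbox_text, answer, vis, info = generate_viscot_response(
#   ...     img, "What is unusual about this image?", temperature=0.2, max_tokens=256
#   ... )
#   >>> print(bbox_text)   # e.g. "[0.31, 0.28, 0.62, 0.85]" (illustrative output)
#   >>> print(answer)
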
# =============================================================================
# Gradio Interface
# =============================================================================
def create_demo():
    """Create the Gradio interface."""
    # Custom CSS for the UI
    custom_css = """
    .gradio-container {
        font-family: 'Inter', sans-serif;
    }
    .header {
        text-align: center;
        padding: 20px;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        border-radius: 10px;
        margin-bottom: 20px;
    }
    .info-box {
        background: #f0f7ff;
        border-left: 4px solid #3b82f6;
        padding: 15px;
        border-radius: 5px;
        margin: 10px 0;
    }
    .example-box {
        border: 2px solid #e5e7eb;
        border-radius: 8px;
        padding: 10px;
        margin: 5px 0;
    }
    .metric-card {
        background: white;
        border-radius: 8px;
        padding: 15px;
        box-shadow: 0 1px 3px rgba(0,0,0,0.1);
        margin: 10px 0;
    }
    """

    with gr.Blocks(
        theme=gr.themes.Soft(
            primary_hue="indigo",
            secondary_hue="purple",
        ),
        css=custom_css,
        title="Visual-CoT Demo",
    ) as demo:
        # Header
        gr.HTML("""
        <div class="header">
            <h1>🌋 Visual-CoT: Chain-of-Thought Reasoning</h1>
            <p style="font-size: 18px; margin: 10px 0;">
                Advancing Multi-Modal Language Models with Visual Chain-of-Thought
            </p>
            <p style="font-size: 14px; opacity: 0.9;">
                📄 <a href="https://arxiv.org/abs/2403.16999" style="color: white; text-decoration: underline;">
                    Paper (NeurIPS 2024 Spotlight)
                </a> |
                💻 <a href="https://github.com/deepcs233/Visual-CoT" style="color: white; text-decoration: underline;">
                    GitHub
                </a> |
                🤗 <a href="https://huggingface.co/datasets/deepcs233/Visual-CoT" style="color: white; text-decoration: underline;">
                    Dataset
                </a>
            </p>
        </div>
        """)

        # Introduction
        gr.Markdown("""
        ## 🎯 What is Visual-CoT?

        **Visual Chain-of-Thought (VisCoT)** enables AI models to:
        - 🎯 **Identify important regions** in images using bounding boxes
        - 💭 **Reason step-by-step** like humans (Chain-of-Thought)
        - 💡 **Answer questions** about visual content with interpretable explanations

        ### 📊 Dataset & Model
        - **438K** Q&A pairs with bounding box annotations
        - **13 diverse benchmarks** (DocVQA, GQA, TextVQA, etc.)
        - **LLaVA-1.5-based** architecture with CLIP ViT-L/14
        """)
        with gr.Tabs():
            # ============================================================
            # Tab 1: Interactive Demo
            # ============================================================
            with gr.Tab("🎨 Interactive Demo"):
                gr.Markdown("""
                ### Try Visual-CoT with Your Own Images!

                Upload an image and ask a question. The model will:
                1. **Detect** the region of interest (ROI) → output a bounding box
                2. **Analyze** the ROI and the full image → generate an answer
                """)

                with gr.Row():
                    with gr.Column(scale=1):
                        # Input
                        image_input = gr.Image(
                            type="pil",
                            label="📸 Upload Image",
                            height=400,
                        )
                        question_input = gr.Textbox(
                            label="❓ Your Question",
                            placeholder="Example: What is unusual about this image?",
                            lines=3,
                        )
                        with gr.Accordion("⚙️ Advanced Settings", open=False):
                            temperature = gr.Slider(
                                minimum=0.0,
                                maximum=1.0,
                                value=0.2,
                                step=0.05,
                                label="🌡️ Temperature",
                                info="0 = deterministic, 1 = creative",
                            )
                            max_tokens = gr.Slider(
                                minimum=128,
                                maximum=1024,
                                value=512,
                                step=64,
                                label="📝 Max Output Tokens",
                            )
                        submit_btn = gr.Button("🚀 Analyze Image", variant="primary", size="lg")
                        clear_btn = gr.Button("🗑️ Clear", size="sm")

                    with gr.Column(scale=1):
                        # Output
                        gr.Markdown("### 📤 Results")
                        with gr.Group():
                            gr.Markdown("#### 🎯 Step 1: Region Detection")
                            bbox_output = gr.Textbox(
                                label="Detected Bounding Box",
                                lines=2,
                                show_copy_button=True,
                            )
                        with gr.Group():
                            gr.Markdown("#### 💡 Step 2: Answer")
                            answer_output = gr.Textbox(
                                label="Final Answer",
                                lines=6,
                                show_copy_button=True,
                            )
                        with gr.Group():
                            gr.Markdown("#### 🖼️ Visualization")
                            image_output = gr.Image(
                                label="Image with Bounding Box",
                                type="pil",
                                height=350,
                            )
                        info_output = gr.Textbox(
                            label="Processing Info",
                            lines=1,
                            visible=False,
                        )

                # Example images
                gr.Markdown("### 📋 Try These Examples")
                gr.Examples(
                    examples=[
                        ["examples/extreme_ironing.jpg", "What is unusual about this image?"],
                        ["examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
                    ],
                    inputs=[image_input, question_input],
                    label="Click to load example",
                )

                # Event handlers
                submit_btn.click(
                    fn=generate_viscot_response,
                    inputs=[image_input, question_input, temperature, max_tokens],
                    outputs=[bbox_output, answer_output, image_output, info_output],
                )
                clear_btn.click(
                    fn=lambda: (None, "", "", "", None, ""),
                    outputs=[image_input, question_input, bbox_output, answer_output, image_output, info_output],
                )
            # ============================================================
            # Tab 2: Benchmark Explorer
            # ============================================================
            with gr.Tab("📊 Benchmark Explorer"):
                gr.Markdown("""
                ### Explore Visual-CoT Benchmark Examples

                Select a benchmark dataset and browse annotated examples from the evaluation suite.
                These examples showcase the model's performance across diverse visual reasoning tasks.
                """)

                with gr.Row():
                    dataset_dropdown = gr.Dropdown(
                        choices=BENCHMARK_DATASETS,
                        value="gqa",
                        label="🗂️ Select Benchmark Dataset",
                        info="Choose one of the available benchmark subsets",
                    )
                    load_examples_btn = gr.Button("📥 Load Examples", variant="secondary")

                benchmark_gallery = gr.Gallery(
                    label="Benchmark Examples",
                    columns=3,
                    height=400,
                    object_fit="contain",
                )
                benchmark_info = gr.Markdown("""
                **Select a dataset and click "Load Examples" to view benchmark samples.**

                Available benchmarks:
                - **DocVQA**: document visual question answering
                - **GQA**: scene-graph question answering
                - **TextVQA**: text-based VQA
                - **Flickr30k**: image captioning & grounding
                - **InfographicsVQA**: infographic understanding
                - **OpenImages**: object detection & description
                - And more...
                """)

                # Placeholder for benchmark loading (actual loading not yet implemented)
                load_examples_btn.click(
                    fn=lambda x: gr.Info(f"Loading {x} examples... (Feature coming soon!)"),
                    inputs=[dataset_dropdown],
                    outputs=None,
                )
            # ============================================================
            # Tab 3: About & Paper
            # ============================================================
            with gr.Tab("📚 About"):
                gr.Markdown("""
                ## 📄 Paper Information

                **Title:** Visual CoT: Advancing Multi-Modal Language Models with a Comprehensive Dataset and Benchmark for Chain-of-Thought Reasoning

                **Authors:** Hao Shao, Shengju Qian, Han Xiao, Guanglu Song, Zhuofan Zong, Letian Wang, Yu Liu, Hongsheng Li

                **Conference:** NeurIPS 2024 (Spotlight) 🎉

                **Abstract:**
                We introduce Visual-CoT, a comprehensive dataset and benchmark for evaluating chain-of-thought reasoning
                in multi-modal language models. Our dataset comprises 438K question-answer pairs with intermediate bounding-box
                annotations highlighting key regions essential for answering questions. We propose a multi-turn processing
                pipeline that dynamically focuses on visual inputs and provides interpretable reasoning steps.

                ---

                ## 🏗️ Model Architecture

                ```
                ┌─────────────────────────────────────┐
                │        Visual-CoT Pipeline          │
                ├─────────────────────────────────────┤
                │                                     │
                │  📸 Image Input                     │
                │        ↓                            │
                │  🔍 CLIP ViT-L/14 (Vision Encoder)  │
                │        ↓                            │
                │  🔗 MLP Projector (2-layer)         │
                │        ↓                            │
                │  🧠 LLaMA/Vicuna (Language Model)   │
                │        ↓                            │
                │  ┌──────────────┐                   │
                │  │ Step 1: ROI  │ → Bounding Box    │
                │  └──────────────┘                   │
                │        ↓                            │
                │  ┌──────────────┐                   │
                │  │ Step 2: QA   │ → Final Answer    │
                │  └──────────────┘                   │
                │                                     │
                └─────────────────────────────────────┘
                ```

                ---

                ## 📊 Key Results
                - **Detection accuracy**: 75.3% (IoU > 0.5)
                - **Answer accuracy**: 82.7% (GPT-3.5 evaluated)
                - **Benchmarks**: state-of-the-art on 10+ visual reasoning tasks
                - **Model sizes**: 7B and 13B parameters
                - **Resolutions**: 224px and 336px

                ---

                ## 🔗 Resources
                - 📄 **Paper**: [arXiv:2403.16999](https://arxiv.org/abs/2403.16999)
                - 💻 **Code**: [GitHub](https://github.com/deepcs233/Visual-CoT)
                - 🤗 **Dataset**: [Hugging Face](https://huggingface.co/datasets/deepcs233/Visual-CoT)
                - 🌐 **Project Page**: [https://hao-shao.com/projects/viscot.html](https://hao-shao.com/projects/viscot.html)
                - 🎯 **Models**:
                  - [VisCoT-7b-224](https://huggingface.co/deepcs233/VisCoT-7b-224)
                  - [VisCoT-7b-336](https://huggingface.co/deepcs233/VisCoT-7b-336)
                  - [VisCoT-13b-224](https://huggingface.co/deepcs233/VisCoT-13b-224)
                  - [VisCoT-13b-336](https://huggingface.co/deepcs233/VisCoT-13b-336)

                ---

                ## 📜 Citation
                If you find our work useful, please cite:
                ```bibtex
                @article{shao2024visual,
                  title={Visual CoT: Unleashing Chain-of-Thought Reasoning in Multi-Modal Language Models},
                  author={Shao, Hao and Qian, Shengju and Xiao, Han and Song, Guanglu and Zong, Zhuofan and Wang, Letian and Liu, Yu and Li, Hongsheng},
                  journal={arXiv preprint arXiv:2403.16999},
                  year={2024}
                }
                ```

                ---

                ## ⚖️ License
                - **Code**: Apache License 2.0
                - **Dataset**: research use only
                - **Models**: subject to the base LLM license (LLaMA)

                ---

                ## 🙏 Acknowledgements
                This work is built upon:
                - [LLaVA](https://github.com/haotian-liu/LLaVA) - base architecture
                - [Shikra](https://github.com/shikras/shikra) - positional annotations
                - [Vicuna](https://github.com/lm-sys/FastChat) - language model
                - [CLIP](https://github.com/openai/CLIP) - vision encoder
                """)

        # Footer
        gr.Markdown("""
        ---
        <div style="text-align: center; color: #666; padding: 20px;">
            <p>🚀 Powered by <a href="https://huggingface.co/docs/hub/spaces-zerogpu">Zero GPU</a> on Hugging Face Spaces</p>
            <p>Made with ❤️ by the Visual-CoT Team</p>
        </div>
        """)
    return demo


# =============================================================================
# Launch
# =============================================================================
if __name__ == "__main__":
    demo = create_demo()
    demo.queue(max_size=20)  # Enable queue for Zero GPU
    demo.launch()