from functools import partial

import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer
import spaces
import os
import tempfile
from PIL import Image, ImageDraw
import re  # Import the regular expression library

# --- 1. Load Model and Tokenizer (Done only once at startup) ---
print("Loading model and tokenizer...")
model_name = "deepseek-ai/DeepSeek-OCR"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Load the model to CPU first; it will be moved to GPU during processing
model = AutoModel.from_pretrained(
    model_name,
    _attn_implementation="flash_attention_2",
    trust_remote_code=True,
    use_safetensors=True,
)
model = model.eval()
print("✅ Model loaded successfully.")


# --- Helper function to find pre-generated result images ---
def find_result_image(path):
    for filename in os.listdir(path):
        if "grounding" in filename or "result" in filename:
            try:
                image_path = os.path.join(path, filename)
                result = Image.open(image_path)
                result.load()  # Read into memory so the temp dir can be deleted safely
                return result
            except Exception as e:
                print(f"Error opening result image {filename}: {e}")
    return None


# --- 2. Main Processing Function (UPDATED for multi-bbox drawing) ---
@spaces.GPU
def process_ocr_task(image, model_size, ref_text, task_type):
    """
    Processes an image with DeepSeek-OCR for all supported tasks.
    Draws ALL detected bounding boxes for ANY task.
    """
    if image is None:
        return "Please upload an image first.", None

    print("🚀 Moving model to GPU...")
    model_gpu = model.cuda().to(torch.bfloat16)
    print("✅ Model is on GPU.")

    with tempfile.TemporaryDirectory() as output_path:
        # Build the prompt for the selected task
        if task_type == "📝 Free OCR":
            prompt = "<image>\nFree OCR."
        elif task_type == "📄 Convert to Markdown":
            prompt = "<image>\n<|grounding|>Convert the document to markdown."
        elif task_type == "📈 Parse Figure":
            prompt = "<image>\nParse the figure."
        elif task_type == "🔍 Locate Object by Reference":
            if not ref_text or ref_text.strip() == "":
                raise gr.Error("For the 'Locate' task, you must provide the reference text to find!")
            prompt = f"<image>\nLocate <|ref|>{ref_text.strip()}<|/ref|> in the image."
        else:
            prompt = "<image>\nFree OCR."

        temp_image_path = os.path.join(output_path, "temp_image.png")
        image.save(temp_image_path)

        # Configure the resolution preset; "Gundam" pairs a large base size
        # with smaller tiles and crop mode for dense documents
        size_configs = {
            "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
            "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
            "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
            "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
            "Gundam (Recommended)": {"base_size": 1024, "image_size": 640, "crop_mode": True},
        }
        config = size_configs.get(model_size, size_configs["Gundam (Recommended)"])

        print(f"🏃 Running inference with prompt: {prompt}")
        text_result = model_gpu.infer(
            tokenizer,
            prompt=prompt,
            image_file=temp_image_path,
            output_path=output_path,
            base_size=config["base_size"],
            image_size=config["image_size"],
            crop_mode=config["crop_mode"],
            save_results=True,
            test_compress=True,
            eval_mode=True,
        )
        print(f"====\n📄 Text Result: {text_result}\n====")

        # --- Always try to find and draw all bounding boxes ---
        result_image_pil = None
        # Pattern for grounding coordinates like <|det|>[[280, 15, 696, 997]]<|/det|>
        pattern = re.compile(r"<\|det\|>\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]\]<\|/det\|>")
        matches = list(pattern.finditer(text_result))  # finditer collects every match

        if matches:
            print(f"✅ Found {len(matches)} bounding box(es). Drawing on the original image.")
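            # An assumed sample of the grounding output this regex parses
            # (the exact serialization may vary across model versions):
            #   <|ref|>the teacher<|/ref|><|det|>[[280, 15, 696, 997]]<|/det|>
            # The coordinates live in a normalized 1000x1000 space and are
            # rescaled to the real image dimensions below.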
            # Draw every detected box on a copy of the original image
            image_with_bboxes = image.copy()
            draw = ImageDraw.Draw(image_with_bboxes)
            w, h = image.size  # Original image dimensions

            for match in matches:
                # Extract the four coordinates as integers
                coords_norm = [int(c) for c in match.groups()]
                x1_norm, y1_norm, x2_norm, y2_norm = coords_norm

                # Scale the normalized coordinates (from 1000x1000 space) to the image's actual size
                x1 = int(x1_norm / 1000 * w)
                y1 = int(y1_norm / 1000 * h)
                x2 = int(x2_norm / 1000 * w)
                y2 = int(y2_norm / 1000 * h)

                # Draw the box outline; cropping here instead would keep only
                # the last box and shrink the image on every iteration
                draw.rectangle([x1, y1, x2, y2], outline="red", width=3)

            result_image_pil = image_with_bboxes
        else:
            # If no coordinates are found in the text, fall back to a pre-generated result image
            print("⚠️ No bounding box coordinates found in text result. Falling back to search for a result image file.")
            result_image_pil = find_result_image(output_path)

        return text_result, result_image_pil


# --- 3. Build the Gradio Interface (UPDATED) ---
with gr.Blocks(title="Text Extraction Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🐳 Full Demo of DeepSeek-OCR 🐳
        Use the tabs below to switch between Free OCR and Locate modes.
        """
    )

    with gr.Tabs():
        with gr.TabItem("Free OCR"):
            with gr.Row():
                with gr.Column(scale=1):
                    free_image = gr.Image(type="pil", label="🖼️ Upload Image", sources=["upload", "clipboard"])
                    free_model_size = gr.Dropdown(
                        choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"],
                        value="Base",
                        label="⚙️ Resolution Size",
                    )
                    free_btn = gr.Button("Run Free OCR", variant="primary")
                with gr.Column(scale=2):
                    free_output_text = gr.Textbox(label="📄 Text Result", lines=15, show_copy_button=True)
                    free_output_image = gr.Image(label="🖼️ Image Result (if any)", type="pil")

            # Wire the Free OCR button; partial pre-binds the task type and an empty reference text
            free_ocr = partial(process_ocr_task, task_type="📝 Free OCR", ref_text="")
            free_btn.click(fn=free_ocr, inputs=[free_image, free_model_size], outputs=[free_output_text, free_output_image])

        with gr.TabItem("Locate"):
            with gr.Row():
                with gr.Column(scale=1):
                    loc_image = gr.Image(type="pil", label="🖼️ Upload Image", sources=["upload", "clipboard"])
                    loc_model_size = gr.Dropdown(
                        choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"],
                        value="Base",
                        label="⚙️ Resolution Size",
                    )
                    ref_text_input = gr.Textbox(label="📝 Reference Text (what to locate)", placeholder="e.g., the teacher, 20-10, a red car...")
                    loc_btn = gr.Button("Locate", variant="primary")
                with gr.Column(scale=2):
                    loc_output_text = gr.Textbox(label="📄 Text Result", lines=15, show_copy_button=True)
                    loc_output_image = gr.Image(label="🖼️ Image Result (if any)", type="pil")

            # Wire the Locate button; the reference text comes from the textbox,
            # so only the task type is pre-bound
            locate_task = partial(process_ocr_task, task_type="🔍 Locate Object by Reference")
            loc_btn.click(fn=locate_task, inputs=[loc_image, loc_model_size, ref_text_input], outputs=[loc_output_text, loc_output_image])

    # Examples run process_ocr_task directly; hidden components supply the
    # ref_text and task_type columns that the visible Free OCR inputs lack
    example_ref_text = gr.Textbox(value="", visible=False)
    example_task_type = gr.Textbox(value="📝 Free OCR", visible=False)
    gr.Examples(
        examples=[
            ["examples/doc_markdown.png", "Gundam (Recommended)", "", "📄 Convert to Markdown"],
            ["examples/chart.png", "Gundam (Recommended)", "", "📈 Parse Figure"],
            ["examples/teacher.jpg", "Base", "the teacher", "🔍 Locate Object by Reference"],
            ["examples/math_locate.jpg", "Small", "20-10", "🔍 Locate Object by Reference"],
            ["examples/receipt.jpg", "Base", "", "📝 Free OCR"],
        ],
        inputs=[free_image, free_model_size, example_ref_text, example_task_type],
        outputs=[free_output_text, free_output_image],
        fn=process_ocr_task,
        cache_examples=False,
    )
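
# A minimal, model-independent sketch of the coordinate rescaling used in
# process_ocr_task, assuming the 1000x1000 normalized space noted above.
# This helper is illustrative only and is not wired into the app.
def scale_bbox(coords_norm, width, height):
    """Scale an [x1, y1, x2, y2] box from 1000x1000 space to pixel space.

    >>> scale_bbox([280, 15, 696, 997], 500, 200)
    [140, 3, 348, 199]
    """
    x1, y1, x2, y2 = coords_norm
    return [
        int(x1 / 1000 * width),
        int(y1 / 1000 * height),
        int(x2 / 1000 * width),
        int(y2 / 1000 * height),
    ]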

# --- 4. Launch the App ---
if __name__ == "__main__":
    if not os.path.exists("examples"):
        os.makedirs("examples")
    # Make sure the referenced image files exist in your "examples" folder,
    # e.g., doc_markdown.png, chart.png, teacher.jpg, math_locate.jpg, receipt.jpg
    demo.queue(max_size=20).launch(share=True)
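
# Usage note: running this script assumes a CUDA GPU and the flash-attn
# package, since the model is loaded with _attn_implementation="flash_attention_2";
# the @spaces.GPU decorator has no effect outside Hugging Face Spaces.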