import os
import shutil

import torch
import torchvision
import gradio as gr
from transformers import Owlv2Processor, Owlv2ForObjectDetection
# --- Setup ---
# Clean caches on each restart (helps stay under the 50 GB storage limit).
for cache_dir in [
    os.path.expanduser("~/.cache/huggingface"),
    os.path.expanduser("~/.cache/torch"),
]:
    shutil.rmtree(cache_dir, ignore_errors=True)

# Force the Hugging Face download cache to /tmp (ephemeral).
os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache"
os.makedirs(os.environ["HF_HUB_CACHE"], exist_ok=True)

# Gradio temp folder.
os.environ["GRADIO_TEMP_DIR"] = "tmp"
os.makedirs(os.environ["GRADIO_TEMP_DIR"], exist_ok=True)
# Handle ZeroGPU safely for local debugging: on Spaces the `spaces` package
# provides the GPU decorator; locally we fall back to a no-op stand-in.
try:
    import spaces
except ImportError:
    class spaces:
        @staticmethod
        def GPU(*args, **kwargs):
            # Support both `@spaces.GPU` and `@spaces.GPU(duration=...)`.
            if args and callable(args[0]):
                return args[0]
            def decorator(fn):
                return fn
            return decorator
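# The stand-in mirrors the two ways the real decorator is applied:
#   @spaces.GPU                -> GPU(fn) returns fn unchanged
#   @spaces.GPU(duration=120)  -> GPU(duration=120) returns a no-op decorator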
device = "cuda" if torch.cuda.is_available() else "cpu"
# --- Lazy Model Loader ---
MODELS = {}

def get_model(selected_model):
    """Load model + processor on demand and cache them in memory."""
    if selected_model in MODELS:
        return MODELS[selected_model]

    print(f"Loading {selected_model}...")
    if selected_model == "NoctOWLv2-Base":
        model = Owlv2ForObjectDetection.from_pretrained(
            "lorebianchi98/NoctOWLv2-base-patch16"
        ).to(device)
        processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
    elif selected_model == "NoctOWLv2-Large":
        model = Owlv2ForObjectDetection.from_pretrained(
            "lorebianchi98/NoctOWLv2-large-patch14"
        ).to(device)
        processor = Owlv2Processor.from_pretrained("google/owlv2-large-patch14")
    else:
        raise gr.Error(f"Unknown model: {selected_model}")

    # Cache in memory so re-selections don't reload from disk.
    MODELS[selected_model] = (model, processor)
    return model, processor
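# Optional (hypothetical) warm-up: pre-load one model at startup so the first
# request doesn't pay the download/initialization cost.
# get_model("NoctOWLv2-Base")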
# --- Inference Function ---
@spaces.GPU  # Request a ZeroGPU slot on Spaces; a no-op when running locally.
def query_image(img, text_queries, score_threshold, selected_model):
    if img is None:
        raise gr.Error("Please upload or select an example image first.")
    if not text_queries.strip():
        raise gr.Error("Please enter at least one text query.")
    if not selected_model:
        raise gr.Error("Please select a model before running inference.")

    model, processor = get_model(selected_model)
    model = model.to(device)

    # Prepare text: one "a <query>" prompt per comma-separated entry.
    text_queries = [f"a {t.strip()}" for t in text_queries.split(",") if t.strip()]
    if not text_queries:
        raise gr.Error("No valid queries found. Please check your input text.")

    # Preprocess. OWLv2 pads images to a square, so the target size is the
    # longer image side in both dimensions.
    size = max(img.shape[:2])
    target_sizes = torch.Tensor([[size, size]])
    inputs = processor(text=text_queries, images=img, return_tensors="pt").to(device)

    # Inference
    with torch.no_grad():
        outputs = model(**inputs)

    # Postprocess on CPU.
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()
    results = processor.post_process_object_detection(
        outputs=outputs, target_sizes=target_sizes, threshold=score_threshold
    )
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]

    # Class-agnostic non-maximum suppression: drop boxes that overlap a
    # higher-scoring box by IoU > 0.5, regardless of label.
    keep = torchvision.ops.nms(boxes, scores, iou_threshold=0.5)
    boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

    # Format output for gr.AnnotatedImage: (box, label) pairs.
    result_labels = []
    for box, score, label in zip(boxes, scores, labels):
        if score < score_threshold:
            continue
        box = [int(i) for i in box.tolist()]
        result_labels.append((box, f"{text_queries[label.item()]} ({score:.2f})"))
    return img, result_labels
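# Example (hypothetical), for quick local debugging without the UI. Assumes an
# image exists at "assets/pool.jpg"; adjust the path and queries as needed.
#
#   from PIL import Image
#   import numpy as np
#   img = np.array(Image.open("assets/pool.jpg").convert("RGB"))
#   _, detections = query_image(img, "white ball, blue ball", 0.1, "NoctOWLv2-Base")
#   print(detections)  # e.g. [([x0, y0, x1, y1], "a white ball (0.32)"), ...]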
# --- Interface Description ---
description = """
# 🦉 **NoctOWLv2: Fine-Grained Open-Vocabulary Object Detection**

**NoctOWL** (***N***ot **o**nly **c**oarse-**t**ext **OWL**) extends **OWL-ViT** and **OWLv2** for **Fine-Grained Open-Vocabulary Detection (FG-OVD)**.
It can recognize subtle object differences such as **color, texture, and material**, while retaining strong coarse-grained detection abilities.

**Available Models:**
- 🧩 **NoctOWLv2-Base** – smaller and faster.
- 🧠 **NoctOWLv2-Large** – more accurate, higher capacity.

🔗 [Training & evaluation code](https://github.com/lorebianchi98/FG-OVD/NoctOWL)
"""
# --- Create Interface Layout ---
with gr.Blocks(title="NoctOWLv2 – Fine-Grained Zero-Shot Object Detection") as demo:
    gr.Markdown(description)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image")
            text_queries = gr.Textbox(
                label="Text Queries (comma-separated)",
                placeholder="e.g., red shoes, striped shirt, yellow ball",
            )
            score_threshold = gr.Slider(
                0, 1, value=0.1, step=0.01, label="Score Threshold"
            )
            model_dropdown = gr.Dropdown(
                choices=["NoctOWLv2-Base", "NoctOWLv2-Large"],
                label="Select Model",
                value=None,
                info="Select which model to use for detection",
            )
            run_button = gr.Button("🚀 Run Detection", interactive=False)
        with gr.Column():
            output_image = gr.AnnotatedImage(label="Detected Objects")

    # --- Enable / Disable Run Button ---
    def toggle_button(model, text):
        return gr.update(interactive=bool(model and text.strip()))

    model_dropdown.change(
        fn=toggle_button,
        inputs=[model_dropdown, text_queries],
        outputs=run_button,
    )
    text_queries.change(
        fn=toggle_button,
        inputs=[model_dropdown, text_queries],
        outputs=run_button,
    )

    # --- Connect Button to Inference ---
    run_button.click(
        fn=query_image,
        inputs=[input_image, text_queries, score_threshold, model_dropdown],
        outputs=output_image,
    )

    # --- Example Images ---
    gr.Examples(
        examples=[
            ["assets/desciglio.jpg", "striped football shirt, plain red football shirt, yellow shoes, red shoes", 0.07],
            ["assets/pool.jpg", "white ball, blue ball, black ball, yellow ball", 0.1],
            ["assets/patio.jpg", "ceramic mug, glass mug, pink flowers, blue flowers", 0.09],
        ],
        inputs=[input_image, text_queries, score_threshold],
    )

demo.launch()
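# On Spaces the app is served automatically; when run locally, demo.launch()
# serves the UI at http://127.0.0.1:7860 by default.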