# NoctOWL / app.py
import os
import shutil

# --- Setup ---
# Clean caches on each restart (helps stay under the Space's 50 GB storage limit)
for cache_dir in [
    os.path.expanduser("~/.cache/huggingface"),
    os.path.expanduser("~/.cache/torch"),
]:
    shutil.rmtree(cache_dir, ignore_errors=True)

# Force the Hugging Face cache to /tmp (ephemeral). Set this before importing
# transformers, since the cache location is resolved when the library is imported.
os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache"
os.makedirs(os.environ["HF_HUB_CACHE"], exist_ok=True)

# Gradio temp folder
os.environ["GRADIO_TEMP_DIR"] = "tmp"
os.makedirs(os.environ["GRADIO_TEMP_DIR"], exist_ok=True)

import gradio as gr
import torch
import torchvision
from transformers import Owlv2Processor, Owlv2ForObjectDetection
# Handle ZeroGPU safely for local debugging
try:
    import spaces
except ImportError:
    # Fallback: a no-op stand-in so @spaces.GPU works outside a ZeroGPU Space
    class spaces:
        @staticmethod
        def GPU(*args, **kwargs):
            def decorator(fn):
                return fn
            return decorator
device = "cuda" if torch.cuda.is_available() else "cpu"
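
# Note: on ZeroGPU hardware, a GPU is attached only while a @spaces.GPU-decorated
# function is running, which is why query_image() moves the model to `device`
# again at call time.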
# --- Lazy Model Loader ---
MODELS = {}

def get_model(selected_model):
    """Load model + processor on demand and cache them in memory."""
    if selected_model in MODELS:
        return MODELS[selected_model]

    print(f"Loading {selected_model}...")
    if selected_model == "NoctOWLv2-Base":
        model = Owlv2ForObjectDetection.from_pretrained(
            "lorebianchi98/NoctOWLv2-base-patch16"
        ).to(device)
        processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
    elif selected_model == "NoctOWLv2-Large":
        model = Owlv2ForObjectDetection.from_pretrained(
            "lorebianchi98/NoctOWLv2-large-patch14"
        ).to(device)
        processor = Owlv2Processor.from_pretrained("google/owlv2-large-patch14")
    else:
        raise gr.Error(f"Unknown model: {selected_model}")

    # Cache in memory so re-selections don't reload from disk
    MODELS[selected_model] = (model, processor)
    return model, processor
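
# Usage: `model, processor = get_model("NoctOWLv2-Base")`. The first call
# downloads the checkpoint into /tmp/hf_cache; later calls hit the in-memory cache.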
# --- Inference Function ---
@spaces.GPU(duration=120)
def query_image(img, text_queries, score_threshold, selected_model):
    if img is None:
        raise gr.Error("Please upload or select an example image first.")
    if not text_queries.strip():
        raise gr.Error("Please enter at least one text query.")
    if not selected_model:
        raise gr.Error("Please select a model before running inference.")

    model, processor = get_model(selected_model)
    model = model.to(device)  # re-attach weights now that the GPU is available
    # Prepare text: "red ball, blue cup" -> ["a red ball", "a blue cup"]
    text_queries = [f"a {t.strip()}" for t in text_queries.split(",") if t.strip()]
    if not text_queries:
        raise gr.Error("No valid queries found. Please check your input text.")

    # Preprocess. The OWLv2 processor pads the image to a square, so boxes are
    # rescaled against the longest side of the original image.
    size = max(img.shape[:2])
    target_sizes = torch.Tensor([[size, size]])
    inputs = processor(text=text_queries, images=img, return_tensors="pt").to(device)
    # Inference
    with torch.no_grad():
        outputs = model(**inputs)

    # Postprocess
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()
    results = processor.post_process_object_detection(
        outputs=outputs, target_sizes=target_sizes, threshold=score_threshold
    )
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
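
    # Note: torchvision.ops.nms below is class-agnostic (labels are ignored while
    # suppressing). torchvision.ops.batched_nms(boxes, scores, labels, 0.5) would
    # suppress overlaps per label instead, if per-class behavior is preferred.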
    # Non-Maximum Suppression
    keep = torchvision.ops.nms(boxes, scores, iou_threshold=0.5)
    boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

    # Format output
    result_labels = []
    for box, score, label in zip(boxes, scores, labels):
        if score < score_threshold:
            continue
        box = [int(i) for i in box.tolist()]
        result_labels.append((box, f"{text_queries[label.item()]} ({score:.2f})"))
    return img, result_labels
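
# The (image, [(box, label_text), ...]) return value above is the format that
# gr.AnnotatedImage expects, with each box given as [x_min, y_min, x_max, y_max]
# pixel coordinates.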
# --- Interface Description ---
description = """
# 🦉 **NoctOWLv2: Fine-Grained Open-Vocabulary Object Detection**

**NoctOWL** (***N***ot **o**nly **c**oarse-**t**ext **OWL**) extends **OWL-ViT** and **OWLv2** for **Fine-Grained Open-Vocabulary Detection (FG-OVD)**.
It can recognize subtle object differences such as **color, texture, and material**, while retaining strong coarse-grained detection abilities.

**Available Models:**
- 🧩 **NoctOWLv2-Base** — Smaller and faster.
- 🧠 **NoctOWLv2-Large** — More accurate, higher capacity.

📘 [Training & evaluation code](https://github.com/lorebianchi98/FG-OVD/NoctOWL)
"""
# --- Create Interface Layout ---
with gr.Blocks(title="NoctOWLv2 — Fine-Grained Zero-Shot Object Detection") as demo:
    gr.Markdown(description)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image")
            text_queries = gr.Textbox(
                label="Text Queries (comma-separated)",
                placeholder="e.g., red shoes, striped shirt, yellow ball",
            )
            score_threshold = gr.Slider(
                0, 1, value=0.1, step=0.01, label="Score Threshold"
            )
            model_dropdown = gr.Dropdown(
                choices=["NoctOWLv2-Base", "NoctOWLv2-Large"],
                label="Select Model",
                value=None,
                info="Select which model to use for detection",
            )
            run_button = gr.Button("🚀 Run Detection", interactive=False)
        with gr.Column():
            output_image = gr.AnnotatedImage(label="Detected Objects")
    # --- Enable / Disable Run Button ---
    def toggle_button(model, text):
        # Clickable only once both a model and a non-empty query are set
        return gr.update(interactive=bool(model and text.strip()))

    model_dropdown.change(
        fn=toggle_button,
        inputs=[model_dropdown, text_queries],
        outputs=run_button,
    )
    text_queries.change(
        fn=toggle_button,
        inputs=[model_dropdown, text_queries],
        outputs=run_button,
    )
    # --- Connect Button to Inference ---
    run_button.click(
        fn=query_image,
        inputs=[input_image, text_queries, score_threshold, model_dropdown],
        outputs=output_image,
    )

    # --- Example Images ---
    gr.Examples(
        examples=[
            ["assets/desciglio.jpg", "striped football shirt, plain red football shirt, yellow shoes, red shoes", 0.07],
            ["assets/pool.jpg", "white ball, blue ball, black ball, yellow ball", 0.1],
            ["assets/patio.jpg", "ceramic mug, glass mug, pink flowers, blue flowers", 0.09],
        ],
        inputs=[input_image, text_queries, score_threshold],
    )

demo.launch()