lorebianchi98 commited on
Commit
82c0d2b
·
1 Parent(s): b1322c5

Added on demand model loading

Browse files
Files changed (1) hide show
  1. app.py +41 -16
app.py CHANGED
@@ -3,8 +3,21 @@ import gradio as gr
3
  from transformers import Owlv2Processor, Owlv2ForObjectDetection
4
  import os
5
  import torchvision
 
6
 
7
  # --- Setup ---
 
 
 
 
 
 
 
 
 
 
 
 
8
  os.environ["GRADIO_TEMP_DIR"] = "tmp"
9
  os.makedirs(os.environ["GRADIO_TEMP_DIR"], exist_ok=True)
10
 
@@ -19,22 +32,34 @@ except ImportError:
19
 
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
 
22
- # --- Load Models ---
23
- print("Loading models...")
24
- noctowlv2_base = Owlv2ForObjectDetection.from_pretrained(
25
- "lorebianchi98/NoctOWLv2-base-patch16"
26
- ).to(device)
27
- processorv2_base = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
28
 
29
- noctowlv2_large = Owlv2ForObjectDetection.from_pretrained(
30
- "lorebianchi98/NoctOWLv2-large-patch14"
31
- ).to(device)
32
- processorv2_large = Owlv2Processor.from_pretrained("google/owlv2-large-patch14")
33
 
34
- MODELS = {
35
- "NoctOWLv2-Base": (noctowlv2_base, processorv2_base),
36
- "NoctOWLv2-Large": (noctowlv2_large, processorv2_large),
37
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
 
40
  # --- Inference Function ---
@@ -47,7 +72,7 @@ def query_image(img, text_queries, score_threshold, selected_model):
47
  if selected_model is None or selected_model == "":
48
  raise gr.Error("Please select a model before running inference.")
49
 
50
- model, processor = MODELS[selected_model]
51
  model = model.to(device)
52
 
53
  # Prepare text
@@ -154,7 +179,7 @@ with gr.Blocks(title="NoctOWLv2 — Fine-Grained Zero-Shot Object Detection") as
154
  outputs=output_image,
155
  )
156
 
157
- # --- Example Images (without predefined model) ---
158
  gr.Examples(
159
  examples=[
160
  ["assets/desciglio.jpg", "striped football shirt, plain red football shirt, yellow shoes, red shoes", 0.07],
 
3
  from transformers import Owlv2Processor, Owlv2ForObjectDetection
4
  import os
5
  import torchvision
6
+ import shutil
7
 
8
  # --- Setup ---
9
+ # Clean caches each restart (helps avoid 50GB limit)
10
+ for cache_dir in [
11
+ os.path.expanduser("~/.cache/huggingface"),
12
+ os.path.expanduser("~/.cache/torch"),
13
+ ]:
14
+ shutil.rmtree(cache_dir, ignore_errors=True)
15
+
16
+ # Force Hugging Face cache to /tmp (ephemeral)
17
+ os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache"
18
+ os.makedirs(os.environ["HF_HUB_CACHE"], exist_ok=True)
19
+
20
+ # Gradio temp folder
21
  os.environ["GRADIO_TEMP_DIR"] = "tmp"
22
  os.makedirs(os.environ["GRADIO_TEMP_DIR"], exist_ok=True)
23
 
 
32
 
33
  device = "cuda" if torch.cuda.is_available() else "cpu"
34
 
35
+ # --- Lazy Model Loader ---
36
+ MODELS = {}
 
 
 
 
37
 
38
+ def get_model(selected_model):
39
+ """Load model + processor on demand and cache in memory."""
40
+ if selected_model in MODELS:
41
+ return MODELS[selected_model]
42
 
43
+ print(f"Loading {selected_model}...")
44
+
45
+ if selected_model == "NoctOWLv2-Base":
46
+ model = Owlv2ForObjectDetection.from_pretrained(
47
+ "lorebianchi98/NoctOWLv2-base-patch16"
48
+ ).to(device)
49
+ processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
50
+
51
+ elif selected_model == "NoctOWLv2-Large":
52
+ model = Owlv2ForObjectDetection.from_pretrained(
53
+ "lorebianchi98/NoctOWLv2-large-patch14"
54
+ ).to(device)
55
+ processor = Owlv2Processor.from_pretrained("google/owlv2-large-patch14")
56
+
57
+ else:
58
+ raise gr.Error(f"Unknown model: {selected_model}")
59
+
60
+ # Cache in memory so re-selections don't re-load from disk
61
+ MODELS[selected_model] = (model, processor)
62
+ return model, processor
63
 
64
 
65
  # --- Inference Function ---
 
72
  if selected_model is None or selected_model == "":
73
  raise gr.Error("Please select a model before running inference.")
74
 
75
+ model, processor = get_model(selected_model)
76
  model = model.to(device)
77
 
78
  # Prepare text
 
179
  outputs=output_image,
180
  )
181
 
182
+ # --- Example Images ---
183
  gr.Examples(
184
  examples=[
185
  ["assets/desciglio.jpg", "striped football shirt, plain red football shirt, yellow shoes, red shoes", 0.07],