zixinz committed on
Commit
a01e858
·
1 Parent(s): 134053b

chore: ignore pyc and __pycache__

Browse files
Files changed (1) hide show
  1. app.py +220 -119
app.py CHANGED
@@ -4,6 +4,7 @@ import sys, pathlib
4
  BASE_DIR = pathlib.Path(__file__).resolve().parent
5
  LOCAL_DIFFUSERS_SRC = BASE_DIR / "code_edit" / "diffusers" / "src"
6
 
 
7
  if (LOCAL_DIFFUSERS_SRC / "diffusers").exists():
8
  sys.path.insert(0, str(LOCAL_DIFFUSERS_SRC))
9
  else:
@@ -20,11 +21,9 @@ from diffusers.pipelines.flux.pipeline_flux_fill_unmasked_image_condition_versio
20
  # ===========================================================================
21
 
22
  import os
23
- import sys
24
- import pathlib
25
  import subprocess
26
  import random
27
- from typing import Optional, Tuple
28
 
29
  import torch
30
  from PIL import Image, ImageOps
@@ -44,19 +43,19 @@ EXPECTED_ASSETS = [
44
  BASE_DIR / "code_edit" / "stage2" / "checkpoint-20000" / "pytorch_lora_weights.safetensors",
45
  ]
46
 
47
- # import depth helper
48
  if str(CODE_DEPTH) not in sys.path:
49
  sys.path.insert(0, str(CODE_DEPTH))
50
  from depth_infer import DepthModel # noqa: E402
51
 
52
- # import your custom diffusers
53
  if str(CODE_EDIT / "diffusers") not in sys.path:
54
  sys.path.insert(0, str(CODE_EDIT / "diffusers"))
55
  from diffusers.pipelines.flux.pipeline_flux_fill_unmasked_image_condition_version import ( # type: ignore # noqa: E402
56
  FluxFillPipeline_token12_depth_only as FluxFillPipeline,
57
  )
58
 
59
- # ---------------- Assets ensure (on-demand) ----------------
60
  def _have_all_assets() -> bool:
61
  return all(p.is_file() for p in EXPECTED_ASSETS)
62
 
@@ -66,6 +65,10 @@ def _ensure_executable(p: pathlib.Path):
66
  os.chmod(p, os.stat(p).st_mode | 0o111)
67
 
68
  def ensure_assets_if_missing():
 
 
 
 
69
  if os.getenv("SKIP_ASSET_DOWNLOAD") == "1":
70
  print("↪️ SKIP_ASSET_DOWNLOAD=1 -> skip asset download check")
71
  return
@@ -91,7 +94,7 @@ except Exception as e:
91
  print(f"⚠️ Asset prepare failed: {e}")
92
 
93
  # ---------------- Global singletons ----------------
94
- _MODELS: dict[str, DepthModel] = {}
95
  _PIPE: Optional[FluxFillPipeline] = None
96
  # ==== STAGE-2 ONLY ADDED: singleton ====
97
  _PIPE_STAGE2: Optional[FluxFillPipelineStage2] = None
@@ -103,6 +106,9 @@ def get_model(encoder: str) -> DepthModel:
103
  return _MODELS[encoder]
104
 
105
  def get_pipe() -> FluxFillPipeline:
 
 
 
106
  global _PIPE
107
  if _PIPE is not None:
108
  return _PIPE
@@ -124,11 +130,8 @@ def get_pipe() -> FluxFillPipeline:
124
  print(f"[pipe] loading FLUX.1-Fill-dev (dtype={dtype}, device={device}, local={use_local})")
125
  try:
126
  if use_local:
127
- pipe = FluxFillPipeline.from_pretrained(
128
- local_flux, torch_dtype=dtype
129
- ).to(device)
130
  else:
131
- # Fetch online (requires gated access + token)
132
  pipe = FluxFillPipeline.from_pretrained(
133
  "black-forest-labs/FLUX.1-Fill-dev",
134
  torch_dtype=dtype,
@@ -137,31 +140,28 @@ def get_pipe() -> FluxFillPipeline:
137
  except Exception as e:
138
  raise RuntimeError(
139
  "Failed to load FLUX.1-Fill-dev. "
140
- "Make sure your account has access to the gated repo and HF_TOKEN is set as a Space secret, "
141
- "or pre-download to a local cache directory."
142
  ) from e
143
 
144
- # -------- LoRA (stage1) --------
145
  lora_dir = CODE_EDIT / "stage1" / "checkpoint-4800"
146
- lora_file = "pytorch_lora_weights.safetensors" # your actual file name
147
  adapter_name = "stage1"
148
 
149
  if lora_dir.exists():
150
  try:
151
- import peft # assert backend is present
152
  print(f"[pipe] loading LoRA from: {lora_dir}/{lora_file}")
153
  pipe.load_lora_weights(
154
  str(lora_dir),
155
- weight_name=lora_file, # important: specify filename
156
- adapter_name=adapter_name # a switchable name
157
  )
158
- # Newer diffusers prefer set_adapters
159
  try:
160
  pipe.set_adapters(adapter_name, scale=1.0)
161
- print(f"[pipe] set_adapters('{adapter_name}', scale=1.0)")
162
  except Exception as e_set:
163
  print(f"[pipe] set_adapters not available ({e_set}); trying fuse_lora()")
164
- # Older / pipelines without set_adapters: fuse LoRA
165
  try:
166
  pipe.fuse_lora(lora_scale=1.0)
167
  print("[pipe] fuse_lora(lora_scale=1.0) done")
@@ -169,7 +169,7 @@ def get_pipe() -> FluxFillPipeline:
169
  print(f"[pipe] fuse_lora failed: {e_fuse}")
170
  print("[pipe] LoRA ready ✅")
171
  except ImportError:
172
- print("[pipe] peft not installed; LoRA will be skipped (add `peft>=0.11` to requirements).")
173
  except Exception as e:
174
  print(f"[pipe] load_lora_weights failed (continue without): {e}")
175
  else:
@@ -181,7 +181,7 @@ def get_pipe() -> FluxFillPipeline:
181
  # ==== STAGE-2 ONLY ADDED: Stage-2 loader (no change to Stage-1 logic) ====
182
  def get_pipe_stage2() -> FluxFillPipelineStage2:
183
  """
184
- Load Stage-2 FluxFillPipeline_token12_depth and mount the Stage-2 LoRA.
185
  """
186
  global _PIPE_STAGE2
187
  if _PIPE_STAGE2 is not None:
@@ -230,16 +230,13 @@ def get_pipe_stage2() -> FluxFillPipelineStage2:
230
  raise RuntimeError(f"Stage-2 LoRA dir not found: {lora_dir2}")
231
  if weight_name is None:
232
  raise RuntimeError(
233
- f"Stage-2 LoRA weight not found under {lora_dir2}. "
234
- f"Tried: {candidate_names}"
235
  )
236
 
237
  try:
238
  import peft # noqa: F401
239
  except Exception as e:
240
- raise RuntimeError(
241
- "peft is not installed (requires peft>=0.11 to load LoRA)."
242
- ) from e
243
 
244
  try:
245
  print(f"[stage2] loading LoRA: {lora_dir2}/{weight_name}")
@@ -272,15 +269,15 @@ def to_grayscale_mask(im: Image.Image) -> Image.Image:
272
  Output: white = region to remove/fill, black = keep.
273
  """
274
  if im.mode == "RGBA":
275
- mask = im.split()[-1] # alpha as mask
276
  else:
277
  mask = im.convert("L")
278
- # simple binarization & denoise
279
  mask = mask.point(lambda p: 255 if p > 16 else 0)
280
- return mask # do not invert; white = mask region
281
 
282
  def dilate_mask(mask_l: Image.Image, px: int) -> Image.Image:
283
- """Dilate white region by ~px pixels."""
284
  if px <= 0:
285
  return mask_l
286
  arr = np.array(mask_l, dtype=np.uint8)
@@ -291,15 +288,12 @@ def dilate_mask(mask_l: Image.Image, px: int) -> Image.Image:
291
 
292
  def _mask_from_red(img: Image.Image, out_size: Tuple[int, int]) -> Image.Image:
293
  """
294
- Extract "pure red strokes" as a binary mask (white=brush, black=others) from an RGBA/RGB image.
295
- Thresholds are a bit lenient to tolerate compression/resampling.
296
  """
297
  arr = np.array(img.convert("RGBA"))
298
  r, g, b, a = arr[..., 0], arr[..., 1], arr[..., 2], arr[..., 3]
299
-
300
- # condition: high red, low green/blue, and alpha>0
301
  red_hit = (r >= 200) & (g <= 40) & (b <= 40) & (a > 0)
302
-
303
  mask = (red_hit.astype(np.uint8) * 255)
304
  m = Image.fromarray(mask, mode="L").resize(out_size, Image.NEAREST)
305
  return m
@@ -311,9 +305,9 @@ def pick_mask(
311
  dilate_px: int = 0,
312
  ) -> Optional[Image.Image]:
313
  """
314
- Rules:
315
- 1) If user uploaded a mask: use it directly (white=mask)
316
- 2) Otherwise, from ImageEditor output, only recognize "red strokes" as mask:
317
  - Try sketch_data['mask'] first (some versions provide it)
318
  - Else merge red strokes from sketch_data['layers'][*]['image']
319
  - If still none, try sketch_data['composite'] for red strokes
@@ -342,8 +336,7 @@ def pick_mask(
342
  li = lyr.get("image") or lyr.get("mask")
343
  if isinstance(li, Image.Image):
344
  m_layer = _mask_from_red(li, base_image.size)
345
- # merge: any layer with strokes contributes to mask
346
- acc = ImageOps.lighter(acc, m_layer)
347
  if acc.getbbox() is not None:
348
  return dilate_mask(acc, dilate_px) if dilate_px > 0 else acc
349
 
@@ -354,10 +347,9 @@ def pick_mask(
354
  if m_comp.getbbox() is not None:
355
  return dilate_mask(m_comp, dilate_px) if dilate_px > 0 else m_comp
356
 
357
- # 3) still none -> return None (caller will prompt for a mask)
358
  return None
359
 
360
-
361
  def _round_mult64(x: float, mode: str = "nearest") -> int:
362
  """
363
  Align x to a multiple of 64:
@@ -375,17 +367,14 @@ def _round_mult64(x: float, mode: str = "nearest") -> int:
375
  def prepare_size_for_flux(img: Image.Image, target_max: int = 1024) -> tuple[int, int]:
376
  """
377
  Steps:
378
- 1) First round w,h up to multiples of 64 (avoid too-small sizes)
379
  2) Fix the long side to target_max (default 1024)
380
- 3) Scale the short side proportionally and align to a multiple of 64 (at least 64)
381
  """
382
  w, h = img.size
383
-
384
- # 1) round each up to multiple of 64
385
  w1 = max(64, _round_mult64(w, mode="ceil"))
386
  h1 = max(64, _round_mult64(h, mode="ceil"))
387
 
388
- # 2) fix long side to target_max; scale short side
389
  if w1 >= h1:
390
  out_w = target_max
391
  scaled_h = h1 * (target_max / w1)
@@ -403,7 +392,6 @@ def preview_depth(image: Optional[Image.Image], encoder: str, max_res: int, inpu
403
  if image is None:
404
  return None
405
  dm = get_model(encoder)
406
- # colored visualization (RGB), consistent with your previous colormap style
407
  d_rgb = dm.infer(image=image, max_res=max_res, input_size=input_size, fp32=fp32, grayscale=False)
408
  return d_rgb
409
 
@@ -411,10 +399,9 @@ def prepare_canvas(image, depth_img, source):
411
  base = depth_img if source == "depth" else image
412
  if base is None:
413
  raise gr.Error('Please upload an image (and wait for the depth preview), then click "Prepare canvas".')
414
- # Use a generic gr.update to set ImageEditor value
415
  return gr.update(value=base)
416
 
417
- # ---------------- Two-stage pipeline: depth(color) -> fill ----------------
418
  @spaces.GPU
419
  def run_depth_and_fill(
420
  image: Image.Image,
@@ -439,14 +426,14 @@ def run_depth_and_fill(
439
  depth_rgb: Image.Image = depth_model.infer(
440
  image=image, max_res=max_res, input_size=input_size, fp32=fp32, grayscale=False
441
  ).convert("RGB")
442
-
443
  print(f"[DEBUG] Depth RGB: mode={depth_rgb.mode}, size={depth_rgb.size}")
444
 
445
  # 2) extract mask (uploaded > drawn)
446
  mask_l = pick_mask(mask_upload, sketch, image, dilate_px=mask_dilate_px)
447
  if (mask_l is None) or (mask_l.getbbox() is None):
448
- raise gr.Error("No valid mask detected: please draw on the canvas or upload a mask image.")
449
-
450
  print(f"[DEBUG] Mask: mode={mask_l.mode}, size={mask_l.size}, bbox={mask_l.getbbox()}")
451
 
452
  # 3) decide output size
@@ -454,14 +441,17 @@ def run_depth_and_fill(
454
  orig_w, orig_h = image.size
455
  print(f"[DEBUG] FLUX size: {width}x{height}, original: {orig_w}x{orig_h}")
456
 
457
- # 4) run FLUX pipeline
458
- # Key fix: pass depth_rgb as `image` instead of the original image
459
  pipe = get_pipe()
460
- generator = torch.Generator("cpu").manual_seed(int(seed)) if (seed is not None and seed >= 0) else torch.Generator("cpu").manual_seed(random.randint(0, 2**31 - 1))
 
 
 
 
461
 
462
  result = pipe(
463
  prompt=prompt,
464
- image=depth_rgb, # FIX: pass the colored depth map, not the original image
465
  mask_image=mask_l,
466
  width=width,
467
  height=height,
@@ -469,11 +459,11 @@ def run_depth_and_fill(
469
  num_inference_steps=int(steps),
470
  max_sequence_length=512,
471
  generator=generator,
472
- depth=depth_rgb, # also feed depth input (colored depth)
473
  ).images[0]
474
 
475
  final_result = result.resize((orig_w, orig_h), Image.BICUBIC)
476
-
477
  # return result and mask preview
478
  mask_preview = mask_l.resize((orig_w, orig_h), Image.NEAREST).convert("RGB")
479
  return final_result, mask_preview
@@ -482,21 +472,20 @@ def _to_pil_rgb(img_like) -> Image.Image:
482
  """Normalize input to PIL RGB. Supports PIL/L/RGBA/np.array."""
483
  if isinstance(img_like, Image.Image):
484
  return img_like.convert("RGB")
485
- # numpy array -> PIL
486
  try:
487
  arr = np.array(img_like)
488
- if arr.ndim == 2: # grayscale
489
  arr = np.stack([arr, arr, arr], axis=-1)
490
  return Image.fromarray(arr.astype(np.uint8), mode="RGB")
491
  except Exception:
492
- raise gr.Error("Stage-2: `depth` / `depth_image` is not a valid image. Please check the provided objects.")
493
 
494
- # ==== STAGE-2 ONLY ADDED: Stage-2 inference (takes Stage-1 output + Stage-1 depth preview) ====
495
  @spaces.GPU
496
  def run_stage2_refine(
497
  image: Image.Image, # original image (RGB)
498
  stage1_out: Image.Image, # output from Stage-1
499
- depth_img_from_stage1_input: Image.Image, # ★ new: Stage-1 depth preview (from UI)
500
  mask_upload: Optional[Image.Image],
501
  sketch: Optional[dict],
502
  prompt: str,
@@ -510,34 +499,38 @@ def run_stage2_refine(
510
  seed: Optional[int],
511
  ) -> Image.Image:
512
  if image is None or stage1_out is None:
513
- raise gr.Error("Please complete Stage-1 generation first (needs original image and Stage-1 output).")
514
 
515
- # allow refine without mask (use all-black -> no masked area)
516
  mask_l = pick_mask(mask_upload, sketch, image, dilate_px=0)
517
  if (mask_l is None) or (mask_l.getbbox() is None):
518
  mask_l = Image.new("L", image.size, 0)
519
 
520
- # unify sizes (based on original image)
521
  width, height = prepare_size_for_flux(image, target_max=max_side)
522
  orig_w, orig_h = image.size
523
 
524
  pipe2 = get_pipe_stage2()
525
- g2 = torch.Generator("cpu").manual_seed(int(seed)) if (seed is not None and seed >= 0) \
 
 
526
  else torch.Generator("cpu").manual_seed(random.randint(0, 2**31 - 1))
 
527
  depth_pil = _to_pil_rgb(stage1_out) # for `depth`
528
  depth_image_pil = _to_pil_rgb(depth_img_from_stage1_input) # for `depth_image`
529
- image_rgb = _to_pil_rgb(image) # normalize original image to RGB
530
 
531
- # resize to (width, height)
532
  depth_pil = depth_pil.resize((width, height), Image.BICUBIC)
533
  depth_image_pil = depth_image_pil.resize((width, height), Image.BICUBIC)
534
- # ★★ Mapping:
535
- # - image = original RGB
536
- # - depth = Stage-1 output (treated as updated geometry)
537
- # - depth_image = Stage-1 input depth (UI's depth preview)
 
538
  out2 = pipe2(
539
  prompt=prompt,
540
- image=image, # original RGB
541
  mask_image=mask_l,
542
  width=width,
543
  height=height,
@@ -545,99 +538,207 @@ def run_stage2_refine(
545
  num_inference_steps=int(steps),
546
  max_sequence_length=512,
547
  generator=g2,
548
- depth=depth_pil, # ← Stage-1 output as `depth`
549
- depth_image=depth_image_pil, # ← Stage-1 depth preview as `depth_image`
550
  ).images[0]
551
 
552
- out2 = out2.resize((orig_w * 3, orig_h), Image.BICUBIC) # preserve your original ×3 display layout
553
  return out2
554
 
555
- # ===================================================================
556
-
557
  # ---------------- UI ----------------
558
  with gr.Blocks() as demo:
559
- gr.Markdown("## GeoRemover · Depth Removal (Depth (colored) → FLUX Fill)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
560
 
561
  with gr.Row():
562
  with gr.Column(scale=1):
563
- # input image
564
- img = gr.Image(label="Upload image", type="pil")
 
 
 
565
 
566
  # Mask: upload or draw
567
  with gr.Tab("Upload mask"):
568
- mask_upload = gr.Image(label="Mask (optional)", type="pil")
 
 
 
569
 
570
  with gr.Tab("Draw mask"):
571
- draw_source = gr.Radio(["image", "depth"], value="image", label="Draw on")
572
- prepare_btn = gr.Button("Prepare canvas")
 
 
 
 
 
 
 
 
 
 
 
 
573
  sketch = gr.ImageEditor(
574
- label="Sketch mask (draw with brush)",
575
  type="pil",
576
- # Provide red-only brush for precise extraction of strokes
577
- brush=gr.Brush(colors=["#FF0000"], default_size=24)
578
  )
579
 
580
- # prompt
581
- prompt = gr.Textbox(label="Prompt", value="A beautiful scene")
 
 
 
 
582
 
583
- # tunables
584
  with gr.Accordion("Advanced (Depth & FLUX)", open=False):
585
- encoder = gr.Dropdown(["vits", "vitl"], value="vitl", label="Depth encoder")
586
- max_res = gr.Slider(512, 2048, value=1280, step=64, label="Depth: max_res")
587
- input_size = gr.Slider(256, 1024, value=518, step=2, label="Depth: input_size")
588
- fp32 = gr.Checkbox(False, label="Depth: use FP32 (default FP16)")
589
- max_side = gr.Slider(512, 1536, value=1024, step=64, label="FLUX: max side (px)")
590
- mask_dilate_px = gr.Slider(0, 128, value=0, step=1, label="Mask dilation (px)")
591
- guidance_scale = gr.Slider(0, 50, value=30, step=0.5, label="FLUX: guidance_scale")
592
- steps = gr.Slider(10, 75, value=50, step=1, label="FLUX: steps")
593
- seed = gr.Number(value=0, precision=0, label="Seed (>=0 fixed; empty = random)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
594
 
595
  run_btn = gr.Button("Run", variant="primary")
596
- # ==== STAGE-2 ONLY ADDED: add Stage-2 button ====
597
- run_btn_stage2 = gr.Button("Run Stage-2 (Refine)", variant="secondary")
598
- # =================================================
599
 
600
  with gr.Column(scale=1):
601
- depth_preview = gr.Image(label="Depth preview (colored)", interactive=False)
602
- mask_preview = gr.Image(label="Mask preview (to be removed)", interactive=False)
603
- out = gr.Image(label="Output")
604
- # ==== STAGE-2 ONLY ADDED: Stage-2 output ====
605
- out_stage2 = gr.Image(label="Output (Stage-2 refine)")
606
- # ============================================
607
-
608
- # Event: when image changes, compute the colored depth preview
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
609
  img.change(
610
  fn=preview_depth,
611
  inputs=[img, encoder, max_res, input_size, fp32],
612
  outputs=[depth_preview],
613
  )
614
 
615
- # Prepare canvas: put original image or colored depth image into ImageEditor
616
  prepare_btn.click(
617
  fn=prepare_canvas,
618
  inputs=[img, depth_preview, draw_source],
619
  outputs=[sketch],
620
  )
621
 
622
- # Run Stage-1 (wiring unchanged)
623
  run_btn.click(
624
  fn=run_depth_and_fill,
625
  inputs=[img, mask_upload, sketch, prompt, encoder, max_res, input_size, fp32,
626
  max_side, mask_dilate_px, guidance_scale, steps, seed],
627
  outputs=[out, mask_preview],
628
  api_name="run",
 
 
 
 
629
  )
630
 
631
- # ==== STAGE-2 ONLY ADDED: run after Stage-1 has produced a result ====
632
  run_btn_stage2.click(
633
  fn=run_stage2_refine,
634
- inputs=[img, out, depth_preview, # ← pass depth_preview as the 3rd input to Stage-2
635
  mask_upload, sketch, prompt, encoder, max_res, input_size, fp32,
636
  max_side, guidance_scale, steps, seed],
637
  outputs=[out_stage2],
638
  api_name="run_stage2",
639
  )
640
- # ====================================================================
641
 
642
  if __name__ == "__main__":
643
  os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
 
4
  BASE_DIR = pathlib.Path(__file__).resolve().parent
5
  LOCAL_DIFFUSERS_SRC = BASE_DIR / "code_edit" / "diffusers" / "src"
6
 
7
+ # Ensure local diffusers is importable
8
  if (LOCAL_DIFFUSERS_SRC / "diffusers").exists():
9
  sys.path.insert(0, str(LOCAL_DIFFUSERS_SRC))
10
  else:
 
21
  # ===========================================================================
22
 
23
  import os
 
 
24
  import subprocess
25
  import random
26
+ from typing import Optional, Tuple, Dict, Any
27
 
28
  import torch
29
  from PIL import Image, ImageOps
 
43
  BASE_DIR / "code_edit" / "stage2" / "checkpoint-20000" / "pytorch_lora_weights.safetensors",
44
  ]
45
 
46
+ # Import depth helper
47
  if str(CODE_DEPTH) not in sys.path:
48
  sys.path.insert(0, str(CODE_DEPTH))
49
  from depth_infer import DepthModel # noqa: E402
50
 
51
+ # Import your custom diffusers (local fork)
52
  if str(CODE_EDIT / "diffusers") not in sys.path:
53
  sys.path.insert(0, str(CODE_EDIT / "diffusers"))
54
  from diffusers.pipelines.flux.pipeline_flux_fill_unmasked_image_condition_version import ( # type: ignore # noqa: E402
55
  FluxFillPipeline_token12_depth_only as FluxFillPipeline,
56
  )
57
 
58
+ # ---------------- Asset preparation (on-demand) ----------------
59
  def _have_all_assets() -> bool:
60
  return all(p.is_file() for p in EXPECTED_ASSETS)
61
 
 
65
  os.chmod(p, os.stat(p).st_mode | 0o111)
66
 
67
  def ensure_assets_if_missing():
68
+ """
69
+ If SKIP_ASSET_DOWNLOAD=1 -> skip checks.
70
+ Otherwise ensure checkpoints/LoRAs exist; if missing, run get_assets.sh.
71
+ """
72
  if os.getenv("SKIP_ASSET_DOWNLOAD") == "1":
73
  print("↪️ SKIP_ASSET_DOWNLOAD=1 -> skip asset download check")
74
  return
 
94
  print(f"⚠️ Asset prepare failed: {e}")
95
 
96
  # ---------------- Global singletons ----------------
97
+ _MODELS: Dict[str, DepthModel] = {}
98
  _PIPE: Optional[FluxFillPipeline] = None
99
  # ==== STAGE-2 ONLY ADDED: singleton ====
100
  _PIPE_STAGE2: Optional[FluxFillPipelineStage2] = None
 
106
  return _MODELS[encoder]
107
 
108
  def get_pipe() -> FluxFillPipeline:
109
+ """
110
+ Load Stage-1 pipeline (FluxFillPipeline_token12_depth_only) and mount Stage-1 LoRA if present.
111
+ """
112
  global _PIPE
113
  if _PIPE is not None:
114
  return _PIPE
 
130
  print(f"[pipe] loading FLUX.1-Fill-dev (dtype={dtype}, device={device}, local={use_local})")
131
  try:
132
  if use_local:
133
+ pipe = FluxFillPipeline.from_pretrained(local_flux, torch_dtype=dtype).to(device)
 
 
134
  else:
 
135
  pipe = FluxFillPipeline.from_pretrained(
136
  "black-forest-labs/FLUX.1-Fill-dev",
137
  torch_dtype=dtype,
 
140
  except Exception as e:
141
  raise RuntimeError(
142
  "Failed to load FLUX.1-Fill-dev. "
143
+ "Ensure gated access and HF_TOKEN; or pre-download to local cache."
 
144
  ) from e
145
 
146
+ # -------- LoRA (Stage-1) --------
147
  lora_dir = CODE_EDIT / "stage1" / "checkpoint-4800"
148
+ lora_file = "pytorch_lora_weights.safetensors"
149
  adapter_name = "stage1"
150
 
151
  if lora_dir.exists():
152
  try:
153
+ import peft # assert backend presence
154
  print(f"[pipe] loading LoRA from: {lora_dir}/{lora_file}")
155
  pipe.load_lora_weights(
156
  str(lora_dir),
157
+ weight_name=lora_file,
158
+ adapter_name=adapter_name,
159
  )
 
160
  try:
161
  pipe.set_adapters(adapter_name, scale=1.0)
162
+ print(f"[pipe] set_adapters('{adapter_name}', 1.0)")
163
  except Exception as e_set:
164
  print(f"[pipe] set_adapters not available ({e_set}); trying fuse_lora()")
 
165
  try:
166
  pipe.fuse_lora(lora_scale=1.0)
167
  print("[pipe] fuse_lora(lora_scale=1.0) done")
 
169
  print(f"[pipe] fuse_lora failed: {e_fuse}")
170
  print("[pipe] LoRA ready ✅")
171
  except ImportError:
172
+ print("[pipe] peft not installed; LoRA skipped (add `peft>=0.11`).")
173
  except Exception as e:
174
  print(f"[pipe] load_lora_weights failed (continue without): {e}")
175
  else:
 
181
  # ==== STAGE-2 ONLY ADDED: Stage-2 loader (no change to Stage-1 logic) ====
182
  def get_pipe_stage2() -> FluxFillPipelineStage2:
183
  """
184
+ Load Stage-2 FluxFillPipeline_token12_depth and mount Stage-2 LoRA.
185
  """
186
  global _PIPE_STAGE2
187
  if _PIPE_STAGE2 is not None:
 
230
  raise RuntimeError(f"Stage-2 LoRA dir not found: {lora_dir2}")
231
  if weight_name is None:
232
  raise RuntimeError(
233
+ f"Stage-2 LoRA weight not found under {lora_dir2}. Tried: {candidate_names}"
 
234
  )
235
 
236
  try:
237
  import peft # noqa: F401
238
  except Exception as e:
239
+ raise RuntimeError("peft is not installed (requires peft>=0.11).") from e
 
 
240
 
241
  try:
242
  print(f"[stage2] loading LoRA: {lora_dir2}/{weight_name}")
 
269
  Output: white = region to remove/fill, black = keep.
270
  """
271
  if im.mode == "RGBA":
272
+ mask = im.split()[-1] # alpha as mask
273
  else:
274
  mask = im.convert("L")
275
+ # Simple binarization & denoise
276
  mask = mask.point(lambda p: 255 if p > 16 else 0)
277
+ return mask # Do not invert; white = mask region
278
 
279
  def dilate_mask(mask_l: Image.Image, px: int) -> Image.Image:
280
+ """Dilate the white region by ~px pixels."""
281
  if px <= 0:
282
  return mask_l
283
  arr = np.array(mask_l, dtype=np.uint8)
 
288
 
289
  def _mask_from_red(img: Image.Image, out_size: Tuple[int, int]) -> Image.Image:
290
  """
291
+ Extract "pure red strokes" as a binary mask (white=brush, black=others) from RGBA/RGB.
292
+ Thresholds are lenient to tolerate compression/resampling.
293
  """
294
  arr = np.array(img.convert("RGBA"))
295
  r, g, b, a = arr[..., 0], arr[..., 1], arr[..., 2], arr[..., 3]
 
 
296
  red_hit = (r >= 200) & (g <= 40) & (b <= 40) & (a > 0)
 
297
  mask = (red_hit.astype(np.uint8) * 255)
298
  m = Image.fromarray(mask, mode="L").resize(out_size, Image.NEAREST)
299
  return m
 
305
  dilate_px: int = 0,
306
  ) -> Optional[Image.Image]:
307
  """
308
+ Selection rules:
309
+ 1) If a mask is uploaded: use it directly (white=mask)
310
+ 2) Else from ImageEditor output, only red strokes are recognized as mask:
311
  - Try sketch_data['mask'] first (some versions provide it)
312
  - Else merge red strokes from sketch_data['layers'][*]['image']
313
  - If still none, try sketch_data['composite'] for red strokes
 
336
  li = lyr.get("image") or lyr.get("mask")
337
  if isinstance(li, Image.Image):
338
  m_layer = _mask_from_red(li, base_image.size)
339
+ acc = ImageOps.lighter(acc, m_layer) # union
 
340
  if acc.getbbox() is not None:
341
  return dilate_mask(acc, dilate_px) if dilate_px > 0 else acc
342
 
 
347
  if m_comp.getbbox() is not None:
348
  return dilate_mask(m_comp, dilate_px) if dilate_px > 0 else m_comp
349
 
350
+ # 3) No valid mask
351
  return None
352
 
 
353
  def _round_mult64(x: float, mode: str = "nearest") -> int:
354
  """
355
  Align x to a multiple of 64:
 
367
  def prepare_size_for_flux(img: Image.Image, target_max: int = 1024) -> tuple[int, int]:
368
  """
369
  Steps:
370
+ 1) Round w,h up to multiples of 64 (avoid too-small sizes)
371
  2) Fix the long side to target_max (default 1024)
372
+ 3) Scale the short side proportionally and align to a multiple of 64 (>= 64)
373
  """
374
  w, h = img.size
 
 
375
  w1 = max(64, _round_mult64(w, mode="ceil"))
376
  h1 = max(64, _round_mult64(h, mode="ceil"))
377
 
 
378
  if w1 >= h1:
379
  out_w = target_max
380
  scaled_h = h1 * (target_max / w1)
 
392
  if image is None:
393
  return None
394
  dm = get_model(encoder)
 
395
  d_rgb = dm.infer(image=image, max_res=max_res, input_size=input_size, fp32=fp32, grayscale=False)
396
  return d_rgb
397
 
 
399
  base = depth_img if source == "depth" else image
400
  if base is None:
401
  raise gr.Error('Please upload an image (and wait for the depth preview), then click "Prepare canvas".')
 
402
  return gr.update(value=base)
403
 
404
+ # ---------------- Stage-1: depth(color) -> fill ----------------
405
  @spaces.GPU
406
  def run_depth_and_fill(
407
  image: Image.Image,
 
426
  depth_rgb: Image.Image = depth_model.infer(
427
  image=image, max_res=max_res, input_size=input_size, fp32=fp32, grayscale=False
428
  ).convert("RGB")
429
+
430
  print(f"[DEBUG] Depth RGB: mode={depth_rgb.mode}, size={depth_rgb.size}")
431
 
432
  # 2) extract mask (uploaded > drawn)
433
  mask_l = pick_mask(mask_upload, sketch, image, dilate_px=mask_dilate_px)
434
  if (mask_l is None) or (mask_l.getbbox() is None):
435
+ raise gr.Error("No valid mask detected: please draw with the red brush or upload a binary mask.")
436
+
437
  print(f"[DEBUG] Mask: mode={mask_l.mode}, size={mask_l.size}, bbox={mask_l.getbbox()}")
438
 
439
  # 3) decide output size
 
441
  orig_w, orig_h = image.size
442
  print(f"[DEBUG] FLUX size: {width}x{height}, original: {orig_w}x{orig_h}")
443
 
444
+ # 4) run FLUX pipeline (key: use depth_rgb as both image and depth input)
 
445
  pipe = get_pipe()
446
+ generator = (
447
+ torch.Generator("cpu").manual_seed(int(seed))
448
+ if (seed is not None and seed >= 0)
449
+ else torch.Generator("cpu").manual_seed(random.randint(0, 2**31 - 1))
450
+ )
451
 
452
  result = pipe(
453
  prompt=prompt,
454
+ image=depth_rgb, # use the colored depth map instead of original image
455
  mask_image=mask_l,
456
  width=width,
457
  height=height,
 
459
  num_inference_steps=int(steps),
460
  max_sequence_length=512,
461
  generator=generator,
462
+ depth=depth_rgb, # feed depth (colored)
463
  ).images[0]
464
 
465
  final_result = result.resize((orig_w, orig_h), Image.BICUBIC)
466
+
467
  # return result and mask preview
468
  mask_preview = mask_l.resize((orig_w, orig_h), Image.NEAREST).convert("RGB")
469
  return final_result, mask_preview
 
472
  """Normalize input to PIL RGB. Supports PIL/L/RGBA/np.array."""
473
  if isinstance(img_like, Image.Image):
474
  return img_like.convert("RGB")
 
475
  try:
476
  arr = np.array(img_like)
477
+ if arr.ndim == 2:
478
  arr = np.stack([arr, arr, arr], axis=-1)
479
  return Image.fromarray(arr.astype(np.uint8), mode="RGB")
480
  except Exception:
481
+ raise gr.Error("Stage-2: `depth` / `depth_image` is not a valid image object.")
482
 
483
+ # ---------------- Stage-2: REQUIRED refine/render ----------------
484
  @spaces.GPU
485
  def run_stage2_refine(
486
  image: Image.Image, # original image (RGB)
487
  stage1_out: Image.Image, # output from Stage-1
488
+ depth_img_from_stage1_input: Image.Image, # Stage-1 depth preview (from UI)
489
  mask_upload: Optional[Image.Image],
490
  sketch: Optional[dict],
491
  prompt: str,
 
499
  seed: Optional[int],
500
  ) -> Image.Image:
501
  if image is None or stage1_out is None:
502
+ raise gr.Error("Please complete Stage-1 first (needs original image and Stage-1 output).")
503
 
504
+ # Allow refine without mask (use all-black)
505
  mask_l = pick_mask(mask_upload, sketch, image, dilate_px=0)
506
  if (mask_l is None) or (mask_l.getbbox() is None):
507
  mask_l = Image.new("L", image.size, 0)
508
 
509
+ # Unify sizes
510
  width, height = prepare_size_for_flux(image, target_max=max_side)
511
  orig_w, orig_h = image.size
512
 
513
  pipe2 = get_pipe_stage2()
514
+ g2 = (
515
+ torch.Generator("cpu").manual_seed(int(seed))
516
+ if (seed is not None and seed >= 0)
517
  else torch.Generator("cpu").manual_seed(random.randint(0, 2**31 - 1))
518
+ )
519
  depth_pil = _to_pil_rgb(stage1_out) # for `depth`
520
  depth_image_pil = _to_pil_rgb(depth_img_from_stage1_input) # for `depth_image`
521
+ image_rgb = _to_pil_rgb(image)
522
 
523
+ # Resize to (width, height)
524
  depth_pil = depth_pil.resize((width, height), Image.BICUBIC)
525
  depth_image_pil = depth_image_pil.resize((width, height), Image.BICUBIC)
526
+
527
+ # Mapping:
528
+ # image = original RGB
529
+ # depth = Stage-1 output (updated geometry)
530
+ # depth_image = Stage-1 input depth (UI depth preview)
531
  out2 = pipe2(
532
  prompt=prompt,
533
+ image=image, # original image
534
  mask_image=mask_l,
535
  width=width,
536
  height=height,
 
538
  num_inference_steps=int(steps),
539
  max_sequence_length=512,
540
  generator=g2,
541
+ depth=depth_pil,
542
+ depth_image=depth_image_pil,
543
  ).images[0]
544
 
545
+ out2 = out2.resize((orig_w * 3, orig_h), Image.BICUBIC) # keep your 3× showcase layout
546
  return out2
547
 
 
 
548
  # ---------------- UI ----------------
549
with gr.Blocks() as demo:
    # Page header: pipeline overview, quick-start steps and mask conventions.
    gr.Markdown(
        """
        # GeoRemover · Depth-Guided Object Removal (Two-Stage, Stage-2 REQUIRED)

        **Pipeline overview**
        1) Compute a **colored depth map** from your input image.
        2) You create a **removal mask** (red brush or upload).
        3) **Stage-1** runs FLUX Fill with depth guidance to get a first pass.
        4) **Stage-2 (REQUIRED)** renders the final result from depth → image using Stage-1 output and the original depth.

        > ⚠️ **Stage-2 is required.** Always click **Run Stage-2 (Render)** *after* Stage-1 finishes. Stage-1 alone is not the final output.

        ---

        ### Quick start
        1. **Upload image** (left). Wait for **Depth preview (colored)** (right).
        2. In **Draw mask**, pick **Draw on: _image_** or **_depth_**, then click **Prepare canvas**.
        3. Paint the region to remove using the **red brush** (**red = remove**).
        4. Optionally adjust **Mask dilation** for thin edges.
        5. Enter a concise **Prompt** describing the fill content.
        6. Click **Run** → produces **Stage-1** (first pass).
        7. Click **Run Stage-2 (Render)** → produces the **final** result.

        ---

        ### Mask rules & tips
        - Only **red strokes** are treated as mask (**white = remove, black = keep** internally).
        - Paint **slightly larger** than the object boundary to avoid seams/halos.
        - If you have a binary mask already, use **Upload mask**.
        - **Mask dilation (px)** expands the mask to cover thin borders.
        """
    )

    with gr.Row():
        # ---- Left column: image input, mask tools, prompt, tunables, run buttons ----
        with gr.Column(scale=1):
            img = gr.Image(label="Upload image", type="pil")

            # The mask can be supplied in either of two tabs: uploaded or drawn.
            with gr.Tab("Upload mask"):
                mask_upload = gr.Image(label="Mask (optional)", type="pil")

            with gr.Tab("Draw mask"):
                draw_source = gr.Radio(["image", "depth"], value="image", label="Draw on")
                prepare_btn = gr.Button("Prepare canvas", variant="secondary")
                gr.Markdown(
                    """
                    **Canvas usage**
                    - Click **Prepare canvas** after selecting *image* or *depth*.
                    - Use the **red brush** only—red strokes are extracted as the removal mask.
                    - Switch tabs anytime if you prefer uploading a ready-made mask.
                    """
                )
                # The brush palette offers red only, in line with the
                # "red = remove" convention stated in the editor label.
                sketch = gr.ImageEditor(
                    label="Sketch mask (red = remove)",
                    type="pil",
                    brush=gr.Brush(colors=["#FF0000"], default_size=24),
                )

            prompt = gr.Textbox(
                label="Prompt",
                value="A beautiful scene",
                placeholder="don't change it",
            )

            # Depth-estimation and FLUX knobs, collapsed by default.
            with gr.Accordion("Advanced (Depth & FLUX)", open=False):
                encoder = gr.Dropdown(["vits", "vitl"], value="vitl", label="Depth encoder")
                max_res = gr.Slider(512, 2048, value=1280, step=64, label="Depth: max_res")
                input_size = gr.Slider(256, 1024, value=518, step=2, label="Depth: input_size")
                fp32 = gr.Checkbox(False, label="Depth: use FP32 (default FP16)")
                max_side = gr.Slider(512, 1536, value=1024, step=64, label="FLUX: max side (px)")
                mask_dilate_px = gr.Slider(0, 128, value=0, step=1, label="Mask dilation (px)")
                guidance_scale = gr.Slider(0, 50, value=30, step=0.5, label="FLUX: guidance_scale")
                steps = gr.Slider(10, 75, value=50, step=1, label="FLUX: steps")
                seed = gr.Number(value=0, precision=0, label="Seed (>=0 = fixed; empty = random)")

            run_btn = gr.Button("Run", variant="primary")
            # Stage-2 is REQUIRED, but the button is created non-interactive so
            # it cannot be clicked until something explicitly unlocks it.
            run_btn_stage2 = gr.Button(
                "Run Stage-2 (Render)",
                variant="secondary",
                interactive=False,
            )

        # ---- Right column: previews and the outputs of both stages ----
        with gr.Column(scale=1):
            depth_preview = gr.Image(label="Depth preview (colored)", interactive=False)
            mask_preview = gr.Image(label="Mask preview (areas to remove)", interactive=False)
            out = gr.Image(label="Output (Stage-1 first pass)")
            out_stage2 = gr.Image(label="Final Output (Stage-2)")

            # NOTE(review): the nesting of this help text (inside the right
            # column vs. full-width below the row) is assumed — confirm
            # against the intended layout.
            gr.Markdown(
                """
                ### Why Stage-2 is required
                Stage-1 provides a depth-guided fill that is *not final*. **Stage-2 renders** the definitive image by leveraging:
                - **Stage-1 output** as updated geometry hints, and
                - **Original colored depth** as `depth_image` guidance.
                Skipping Stage-2 will leave the process incomplete.

                ### Troubleshooting
                - **“No valid mask detected”**: Either upload a binary mask (white=remove) **or** draw with **red brush** after clicking **Prepare canvas**.
                - **Seams/halos**: Increase **Mask dilation (px)** (e.g., 8–16) and re-run both stages.
                - **Prompt not followed**: Lower **guidance_scale** (e.g., 18–24) and make the prompt more concrete.
                - **Depth looks noisy**: Use **vitl**, increase **Depth: max_res**, or enable **FP32**.
                """
            )
701
+
702
+ # ===== Helpers to toggle Stage-2 button =====
703
+ def _enable_button():
704
+ return gr.update(interactive=True)
705
+
706
+ # Auto depth preview on image change
707
  img.change(
708
  fn=preview_depth,
709
  inputs=[img, encoder, max_res, input_size, fp32],
710
  outputs=[depth_preview],
711
  )
712
 
713
+ # Prepare canvas for drawing on image or depth
714
  prepare_btn.click(
715
  fn=prepare_canvas,
716
  inputs=[img, depth_preview, draw_source],
717
  outputs=[sketch],
718
  )
719
 
720
+ # Stage-1
721
  run_btn.click(
722
  fn=run_depth_and_fill,
723
  inputs=[img, mask_upload, sketch, prompt, encoder, max_res, input_size, fp32,
724
  max_side, mask_dilate_px, guidance_scale, steps, seed],
725
  outputs=[out, mask_preview],
726
  api_name="run",
727
+ ).then( # Enable Stage-2 only after Stage-1 completes
728
+ fn=_enable_button,
729
+ inputs=[],
730
+ outputs=[run_btn_stage2],
731
  )
732
 
733
+ # Stage-2 (REQUIRED; unlocked after Stage-1)
734
  run_btn_stage2.click(
735
  fn=run_stage2_refine,
736
+ inputs=[img, out, depth_preview,
737
  mask_upload, sketch, prompt, encoder, max_res, input_size, fp32,
738
  max_side, guidance_scale, steps, seed],
739
  outputs=[out_stage2],
740
  api_name="run_stage2",
741
  )
 
742
 
743
  if __name__ == "__main__":
744
  os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")