zixinz committed on
Commit
2f713b7
·
1 Parent(s): 5458ff3

chore: ignore pyc and __pycache__

Browse files
Files changed (3) hide show
  1. .gitignore +42 -0
  2. app.py +352 -49
  3. get_assets.sh +59 -0
.gitignore ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python caches
2
+ __pycache__/
3
+ *.py[cod]
4
+ *.pyo
5
+ *.pyd
6
+
7
+ # Build / venv
8
+ *.egg-info/
9
+ .dist/
10
+ build/
11
+ dist/
12
+ .venv/
13
+ venv/
14
+
15
+ # Weights / checkpoints (download via get_assets.sh)
16
+ code_depth/checkpoints/
17
+ code_edit/stage1/
18
+ code_edit/stage2/
19
+
20
+ # Large demo assets - don't version them in Space
21
+ code_depth/assets/
22
+ code_edit/assets/
23
+ code_edit/example_data/
24
+
25
+ # Common big binaries
26
+ *.mp4
27
+ *.mov
28
+ *.avi
29
+ *.webm
30
+ *.mkv
31
+ *.safetensors
32
+ *.pth
33
+ *.pt
34
+ *.npz
35
+ *.exr
36
+ *.zip
37
+ *.tar
38
+ *.tar.gz
39
+ *.7z
40
+
41
+ # Node caches (if any)
42
+ node_modules/
app.py CHANGED
@@ -1,87 +1,390 @@
1
- # app.py
2
  import os
 
3
  import pathlib
4
  import subprocess
 
 
 
5
  import gradio as gr
6
  import spaces
7
  import torch
8
- from PIL import Image
 
 
9
 
 
10
  BASE_DIR = pathlib.Path(__file__).resolve().parent
11
- SCRIPT_DIR = BASE_DIR / "code_depth"
12
- GET_WEIGHTS_SH = SCRIPT_DIR / "get_weights.sh"
 
13
 
14
- # 让我们能 import 到 code_depth/depth_infer.py
15
- import sys
16
- if str(SCRIPT_DIR) not in sys.path:
17
- sys.path.append(str(SCRIPT_DIR))
 
 
 
 
 
 
 
18
 
19
- from depth_infer import DepthModel # noqa
 
 
 
 
 
 
 
 
 
20
 
21
  def _ensure_executable(p: pathlib.Path):
22
  if not p.exists():
23
  raise FileNotFoundError(f"Not found: {p}")
24
  os.chmod(p, os.stat(p).st_mode | 0o111)
25
 
26
- def ensure_weights():
27
- """在 code_depth 目录下运行你的 get_weights.sh。"""
28
- _ensure_executable(GET_WEIGHTS_SH)
 
 
 
 
 
 
29
  subprocess.run(
30
- ["bash", str(GET_WEIGHTS_SH)],
31
  check=True,
32
- cwd=str(SCRIPT_DIR),
33
  env={**os.environ, "HF_HUB_DISABLE_TELEMETRY": "1"},
34
  )
35
- ckpt_dir = SCRIPT_DIR / "checkpoints"
36
- if not ckpt_dir.exists():
37
- raise RuntimeError("weights download script ran but checkpoints/ not found")
38
- return str(ckpt_dir)
39
 
40
- # 启动时下载权重(不开持久化时,若环境重建会再次下载)
41
  try:
42
- CKPT_DIR = ensure_weights()
43
- print(f"✅ Weights ready in: {CKPT_DIR}")
44
  except Exception as e:
45
- print(f"⚠️ Failed to prepare weights: {e}")
46
 
47
- # 模型缓存(按 encoder 复用)
48
  _MODELS: dict[str, DepthModel] = {}
 
49
 
50
  def get_model(encoder: str) -> DepthModel:
51
  if encoder not in _MODELS:
52
  _MODELS[encoder] = DepthModel(BASE_DIR, encoder=encoder)
53
  return _MODELS[encoder]
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  @spaces.GPU
56
- def infer_depth(
57
  image: Image.Image,
58
- encoder: str = "vitl",
59
- max_res: int = 1280,
60
- input_size: int = 518,
61
- fp32: bool = False,
62
- grayscale: bool = False,
63
- ) -> Image.Image:
64
- # 这里才真正触发 CUDA 设备占用
65
- device = "cuda" if torch.cuda.is_available() else "cpu"
66
- print(f"[infer] device={device}, encoder={encoder}, max_res={max_res}, input_size={input_size}, fp32={fp32}, gray={grayscale}")
67
- model = get_model(encoder)
68
- return model.infer(image, max_res=max_res, input_size=input_size, fp32=fp32, grayscale=grayscale)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  with gr.Blocks() as demo:
71
- gr.Markdown("## GeoRemover · Depth Preview (Video-Depth-Anything)")
 
72
  with gr.Row():
73
- with gr.Column():
74
- inp = gr.Image(label="Upload image", type="pil")
75
- encoder = gr.Dropdown(["vits", "vitl"], value="vitl", label="Encoder")
76
- max_res = gr.Slider(512, 2048, value=1280, step=64, label="Max resolution")
77
- input_size = gr.Slider(256, 1024, value=518, step=2, label="Model input_size")
78
- fp32 = gr.Checkbox(False, label="Use FP32 (default FP16)")
79
- gray = gr.Checkbox(False, label="Grayscale depth")
80
- btn = gr.Button("Run")
81
- with gr.Column():
82
- out = gr.Image(label="Depth visualization")
83
-
84
- btn.click(fn=infer_depth, inputs=[inp, encoder, max_res, input_size, fp32, gray], outputs=[out])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  if __name__ == "__main__":
87
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
1
  import os
2
+ import sys
3
  import pathlib
4
  import subprocess
5
+ import random
6
+ from typing import Optional, Tuple
7
+
8
  import gradio as gr
9
  import spaces
10
  import torch
11
+ from PIL import Image, ImageOps
12
+ import numpy as np
13
+ import cv2
14
 
15
# ---------------- Paths & assets ----------------
# All locations are resolved relative to the directory containing this file.
BASE_DIR = pathlib.Path(__file__).resolve().parent
CODE_DEPTH = BASE_DIR / "code_depth"
CODE_EDIT = BASE_DIR / "code_edit"
GET_ASSETS = BASE_DIR / "get_assets.sh"

# Files that must exist before the app can run; _have_all_assets() checks
# each of them and get_assets.sh is run when any is missing.
EXPECTED_ASSETS = [
    BASE_DIR / "code_depth" / "checkpoints" / "video_depth_anything_vits.pth",
    BASE_DIR / "code_depth" / "checkpoints" / "video_depth_anything_vitl.pth",
    BASE_DIR / "code_edit" / "stage1" / "checkpoint-4800" / "pytorch_lora_weights.safetensors",
    BASE_DIR / "code_edit" / "stage2" / "checkpoint-20000" / "pytorch_lora_weights.safetensors",
]

# Make code_depth/ importable so the depth helper module can be loaded.
if str(CODE_DEPTH) not in sys.path:
    sys.path.insert(0, str(CODE_DEPTH))
from depth_infer import DepthModel  # noqa: E402

# Put code_edit/diffusers at the front of sys.path before importing the
# customised FLUX fill pipeline from the `diffusers` package.
# NOTE(review): presumably this shadows any installed diffusers - confirm.
if str(CODE_EDIT / "diffusers") not in sys.path:
    sys.path.insert(0, str(CODE_EDIT / "diffusers"))
from diffusers.pipelines.flux.pipeline_flux_fill_unmasked_image_condition_version import (  # type: ignore  # noqa: E402
    FluxFillPipeline_token12_depth_only as FluxFillPipeline,
)
39
+
40
+ # ---------------- Assets ensure (on-demand) ----------------
41
def _have_all_assets() -> bool:
    """Return True only when every path in EXPECTED_ASSETS is a regular file."""
    for asset in EXPECTED_ASSETS:
        if not asset.is_file():
            return False
    return True
43
 
44
def _ensure_executable(p: pathlib.Path):
    """Add the execute bits (user, group and other) to ``p``.

    Raises:
        FileNotFoundError: if ``p`` does not exist.
    """
    if p.exists():
        current_mode = os.stat(p).st_mode
        os.chmod(p, current_mode | 0o111)
        return
    raise FileNotFoundError(f"Not found: {p}")
48
 
49
def ensure_assets_if_missing():
    """Run get_assets.sh when any expected asset file is missing.

    Does nothing when the environment variable SKIP_ASSET_DOWNLOAD is "1" or
    when every path in EXPECTED_ASSETS is already a regular file. Otherwise
    the download script is made executable and run with BASE_DIR as its
    working directory.

    Raises:
        FileNotFoundError: if get_assets.sh does not exist.
        subprocess.CalledProcessError: if the script exits non-zero.
        RuntimeError: if assets are still missing after the script ran.
    """
    if os.getenv("SKIP_ASSET_DOWNLOAD") == "1":
        print("↪️ SKIP_ASSET_DOWNLOAD=1 -> 跳过资产下载检查")
        return
    if _have_all_assets():
        print("✅ Assets already present")
        return

    print("⬇️ Missing assets, running get_assets.sh ...")
    _ensure_executable(GET_ASSETS)
    subprocess.run(
        ["bash", str(GET_ASSETS)],
        check=True,
        cwd=str(BASE_DIR),
        env={**os.environ, "HF_HUB_DISABLE_TELEMETRY": "1"},
    )

    # Use the same is_file() predicate as _have_all_assets(): with exists(),
    # a directory sitting at an asset path produced an error with an empty
    # "missing" list.
    missing = [str(p.relative_to(BASE_DIR)) for p in EXPECTED_ASSETS if not p.is_file()]
    if missing:
        raise RuntimeError(f"Assets missing after get_assets.sh: {missing}")
    print("✅ Assets ready.")
68
 
 
69
# Best-effort asset preparation at import time: any failure is printed and
# the module keeps loading.
try:
    ensure_assets_if_missing()
except Exception as e:
    print(f"⚠️ Asset prepare failed: {e}")
73
 
74
# ---------------- Global singletons ----------------
# Depth models cached per encoder name; filled lazily by get_model().
_MODELS: dict[str, DepthModel] = {}
# The FLUX fill pipeline; created on the first get_pipe() call.
_PIPE: Optional[FluxFillPipeline] = None
77
 
78
def get_model(encoder: str) -> DepthModel:
    """Return the cached DepthModel for ``encoder``, building it on first use."""
    cached = _MODELS.get(encoder)
    if cached is None:
        cached = DepthModel(BASE_DIR, encoder=encoder)
        _MODELS[encoder] = cached
    return cached
82
 
83
def get_pipe() -> FluxFillPipeline:
    """Return the process-wide FLUX fill pipeline, loading it on first call.

    The pipeline is loaded from "black-forest-labs/FLUX.1-Fill-dev" in
    bfloat16 on CUDA when available, otherwise float32 on CPU. The stage1
    LoRA is applied when its directory exists; a LoRA load failure is
    printed and the pipeline is used without it.
    """
    global _PIPE
    if _PIPE is None:
        on_gpu = torch.cuda.is_available()
        device = "cuda" if on_gpu else "cpu"
        dtype = torch.bfloat16 if on_gpu else torch.float32
        print(f"[pipe] load FLUX.1-Fill-dev dtype={dtype}, device={device}")
        loaded = FluxFillPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=dtype
        ).to(device)

        # LoRA (stage1)
        lora_dir = CODE_EDIT / "stage1" / "checkpoint-4800"
        if not lora_dir.exists():
            print(f"[pipe] LoRA path not found: {lora_dir} (continue without)")
        else:
            try:
                loaded.load_lora_weights(str(lora_dir))  # requires peft
                print(f"[pipe] loaded LoRA from: {lora_dir}")
            except Exception as e:
                print(f"[pipe] load LoRA failed (continue without): {e}")

        _PIPE = loaded
    return _PIPE
105
+
106
+ # ---------------- Mask helpers ----------------
107
def to_grayscale_mask(im: Image.Image) -> Image.Image:
    """Convert an image of any mode into a binary "L" mask.

    Output convention: white (255) marks the region to remove/fill, black (0)
    marks the region to keep. The mask is not inverted.

    For images with an alpha channel ("RGBA" or "LA") the alpha channel is
    used as the mask, but only when it actually carries information. A fully
    opaque alpha channel (typical for a black/white mask saved as an RGBA PNG)
    would otherwise turn the whole image into mask, so in that case the
    luminance is used instead.
    """
    source = None
    if im.mode in ("RGBA", "LA"):
        alpha = im.getchannel("A")
        # getextrema() on a single band returns (min, max); (255, 255) means
        # every pixel is fully opaque and alpha says nothing about the mask.
        if alpha.getextrema() != (255, 255):
            source = alpha
    if source is None:
        source = im.convert("L")
    # Simple threshold to binarise and suppress faint noise.
    return source.point(lambda p: 255 if p > 16 else 0)
119
+
120
def dilate_mask(mask_l: Image.Image, px: int) -> Image.Image:
    """Grow the white region of an "L" mask.

    ``px`` is an approximate growth amount: the mask is dilated with a 3x3
    kernel for ``max(1, px // 2)`` iterations (an empirical choice). A
    non-positive ``px`` returns the mask unchanged.
    """
    if px <= 0:
        return mask_l
    pixels = np.array(mask_l, dtype=np.uint8)
    structuring = np.ones((3, 3), np.uint8)
    passes = max(1, int(px // 2))
    grown = cv2.dilate(pixels, structuring, iterations=passes)
    return Image.fromarray(grown, mode="L")
129
+
130
def _mask_from_red(img: Image.Image, out_size: Tuple[int, int]) -> Image.Image:
    """Extract pure-red brush strokes from an image as a binary "L" mask.

    A pixel counts as a stroke (white) when R >= 200, G <= 40, B <= 40 and
    alpha > 0; everything else is black. The thresholds are deliberately a
    little loose to tolerate compression and interpolation. The mask is
    resized to ``out_size`` with nearest-neighbour sampling.
    """
    rgba = np.array(img.convert("RGBA"))
    red, green, blue, alpha = (rgba[..., channel] for channel in range(4))

    is_stroke = (red >= 200) & (green <= 40) & (blue <= 40) & (alpha > 0)
    binary = is_stroke.astype(np.uint8) * 255

    return Image.fromarray(binary, mode="L").resize(out_size, Image.NEAREST)
144
+
145
def pick_mask(
    upload_mask: Optional[Image.Image],
    sketch_data: Optional[dict],
    base_image: Image.Image,
    dilate_px: int = 0,
) -> Optional[Image.Image]:
    """Choose the inpainting mask from an uploaded image or the editor sketch.

    Priority:
      1. An uploaded mask is used directly (white = mask).
      2. Otherwise the ImageEditor value is inspected and only red brush
         strokes count as mask:
         a. ``sketch_data["mask"]`` if present (some versions provide it);
         b. the union of the red strokes in ``sketch_data["layers"]``;
         c. red strokes found in ``sketch_data["composite"]``.
      3. If nothing is found, None is returned so the caller can ask the
         user for a mask.

    The result is an "L" image sized like ``base_image``; when ``dilate_px``
    is positive the white region is dilated via ``dilate_mask``.
    """
    # Function-scope import: the module only imports Image and ImageOps from PIL.
    from PIL import ImageChops

    target_size = base_image.size

    def _finish(mask: Image.Image) -> Image.Image:
        return dilate_mask(mask, dilate_px) if dilate_px > 0 else mask

    # 1) An uploaded mask takes precedence.
    if isinstance(upload_mask, Image.Image):
        return _finish(to_grayscale_mask(upload_mask).resize(target_size, Image.NEAREST))

    # 2) Hand-drawn mask from the ImageEditor.
    if not isinstance(sketch_data, dict):
        return None

    # 2a) Explicit mask entry (still supported).
    explicit = sketch_data.get("mask")
    if isinstance(explicit, Image.Image):
        return _finish(to_grayscale_mask(explicit).resize(target_size, Image.NEAREST))

    # 2b) Union of the red strokes across all layers.
    layers = sketch_data.get("layers")
    if isinstance(layers, (list, tuple)) and layers:
        merged = Image.new("L", target_size, 0)
        for layer in layers:
            # With type="pil" the editor delivers each layer as a PIL image;
            # dict-shaped layers are still accepted for compatibility.
            if isinstance(layer, dict):
                layer = layer.get("image") or layer.get("mask")
            if isinstance(layer, Image.Image):
                # Per-pixel maximum: painted on any layer means masked.
                # (ImageOps has no `lighter`; the function lives in ImageChops.)
                merged = ImageChops.lighter(merged, _mask_from_red(layer, target_size))
        if merged.getbbox() is not None:
            return _finish(merged)

    # 2c) Last resort: look for red strokes in the flattened composite.
    composite = sketch_data.get("composite")
    if isinstance(composite, Image.Image):
        from_composite = _mask_from_red(composite, target_size)
        if from_composite.getbbox() is not None:
            return _finish(from_composite)

    # 3) Nothing usable was found.
    return None
197
+
198
+
199
def _round_mult64(x: float, mode: str = "nearest") -> int:
    """Align ``x`` to a multiple of 64.

    Args:
        x: value to align (int or float).
        mode: "ceil" rounds up, "floor" rounds down; any other value
            (including the default "nearest") rounds to the closest
            multiple, with exact halves rounding up.

    Returns:
        The aligned value as an int.
    """
    if mode == "ceil":
        # Negated floor division is a true ceiling for floats as well as
        # ints; the former (x + 63) // 64 under-rounded fractional inputs
        # (64.5 became 64 instead of 128).
        return int(-(-x // 64)) * 64
    if mode == "floor":
        return int(x // 64) * 64
    # nearest
    return int((x + 32) // 64) * 64
212
+
213
def prepare_size_for_flux(img: Image.Image, target_max: int = 1024) -> tuple[int, int]:
    """Compute the (width, height) to request from the FLUX pipeline.

    Steps:
      1. Round the original width and height up to multiples of 64 (so small
         images do not collapse), with a floor of 64.
      2. Pin the longer side to ``target_max``.
      3. Scale the shorter side proportionally and snap it to the nearest
         multiple of 64, with a floor of 64.
    """
    src_w, src_h = img.size

    aligned_w = max(64, _round_mult64(src_w, mode="ceil"))
    aligned_h = max(64, _round_mult64(src_h, mode="ceil"))

    # Ties count as landscape, so the width is the pinned side.
    landscape = aligned_w >= aligned_h
    long_side, short_side = (aligned_w, aligned_h) if landscape else (aligned_h, aligned_w)

    scaled_short = short_side * (target_max / long_side)
    short_out = max(64, _round_mult64(scaled_short, mode="nearest"))

    if landscape:
        return int(target_max), int(short_out)
    return int(short_out), int(target_max)
237
+
238
+ # ---------------- Preview depth for canvas (彩色) ----------------
239
def preview_depth(image: Optional[Image.Image], encoder: str, max_res: int, input_size: int, fp32: bool):
    """Return the colour depth visualisation for ``image``, or None when no image is given."""
    if image is None:
        return None
    # grayscale=False requests the colour-mapped (RGB) rendering.
    return get_model(encoder).infer(
        image=image,
        max_res=max_res,
        input_size=input_size,
        fp32=fp32,
        grayscale=False,
    )
246
+
247
def prepare_canvas(image, depth_img, source):
    """Load the chosen background into the sketch editor.

    Uses ``depth_img`` when ``source`` is "depth", otherwise ``image``.
    Raises gr.Error when the selected background is not available yet.
    """
    canvas = image if source != "depth" else depth_img
    if canvas is None:
        raise gr.Error("请先上传图片(并等待深度预览出来),再点击\"Prepare canvas\"。")
    # A generic gr.update sets the ImageEditor value.
    return gr.update(value=canvas)
253
+
254
+ # ---------------- Two-stage pipeline: depth(color) -> fill ----------------
255
@spaces.GPU
def run_depth_and_fill(
    image: Image.Image,
    mask_upload: Optional[Image.Image],
    sketch: Optional[dict],
    prompt: str,
    encoder: str,
    max_res: int,
    input_size: int,
    fp32: bool,
    max_side: int,
    mask_dilate_px: int,
    guidance_scale: float,
    steps: int,
    seed: Optional[int],
) -> Tuple[Image.Image, Image.Image]:
    """Run the two-stage pipeline: colour depth map, then FLUX fill.

    Returns ``(result, mask_preview)``, both resized to the size of the input
    image. Raises gr.Error when no image was supplied or no usable mask could
    be obtained from the upload or the sketch.
    """
    if image is None:
        raise gr.Error("请先上传一张图片。")

    # 1) Colour (RGB) depth map.
    depth_rgb: Image.Image = get_model(encoder).infer(
        image=image, max_res=max_res, input_size=input_size, fp32=fp32, grayscale=False
    ).convert("RGB")
    print(f"[DEBUG] Depth RGB: mode={depth_rgb.mode}, size={depth_rgb.size}")

    # 2) Mask: an upload wins over the hand-drawn sketch.
    mask_l = pick_mask(mask_upload, sketch, image, dilate_px=mask_dilate_px)
    if mask_l is None or mask_l.getbbox() is None:
        raise gr.Error("没有检测到有效的 mask:请确认已在画布上涂抹或上传 mask 图片。")
    print(f"[DEBUG] Mask: mode={mask_l.mode}, size={mask_l.size}, bbox={mask_l.getbbox()}")

    # 3) Output size for FLUX.
    width, height = prepare_size_for_flux(depth_rgb, target_max=max_side)
    orig_w, orig_h = image.size
    print(f"[DEBUG] FLUX size: {width}x{height}, original: {orig_w}x{orig_h}")

    # 4) Run the FLUX pipeline. Both `image` and `depth` receive the colour
    #    depth map rather than the original photo.
    pipe = get_pipe()
    if seed is not None and seed >= 0:
        seed_value = int(seed)
    else:
        # Missing or negative seed: draw a random one.
        seed_value = random.randint(0, 2**31 - 1)
    generator = torch.Generator("cpu").manual_seed(seed_value)

    output = pipe(
        prompt=prompt,
        image=depth_rgb,
        mask_image=mask_l,
        width=width,
        height=height,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(steps),
        max_sequence_length=512,
        generator=generator,
        depth=depth_rgb,
    )
    result = output.images[0]

    original_size = (orig_w, orig_h)
    final_result = result.resize(original_size, Image.BICUBIC)

    # Return the result together with a preview of the mask that was used.
    mask_preview = mask_l.resize(original_size, Image.NEAREST).convert("RGB")
    return final_result, mask_preview
317
+
318
+
319
+ # ---------------- UI ----------------
320
# Gradio layout: inputs and parameters in the left column, previews and the
# result in the right column, followed by the event wiring.
with gr.Blocks() as demo:
    gr.Markdown("## GeoRemover · Depth Removal (Depth(color) → FLUX Fill)")

    with gr.Row():
        with gr.Column(scale=1):
            # Input image
            img = gr.Image(label="Upload image", type="pil")

            # Two ways to provide a mask: upload one, or draw it.
            with gr.Tab("Upload mask"):
                mask_upload = gr.Image(label="Mask (optional)", type="pil")

            with gr.Tab("Draw mask"):
                draw_source = gr.Radio(["image", "depth"], value="image", label="Draw on")
                prepare_btn = gr.Button("Prepare canvas")
                sketch = gr.ImageEditor(
                    label="Sketch mask (draw with brush)",
                    type="pil",
                    # The brush offers pure red only, so strokes can be
                    # extracted reliably by colour.
                    brush=gr.Brush(colors=["#FF0000"], default_size=24)
                )

            # Prompt
            prompt = gr.Textbox(label="Prompt", value="A beautiful scene")

            # Tunable parameters
            with gr.Accordion("Advanced (Depth & FLUX)", open=False):
                encoder = gr.Dropdown(["vits", "vitl"], value="vitl", label="Depth encoder")
                max_res = gr.Slider(512, 2048, value=1280, step=64, label="Depth: max_res")
                input_size = gr.Slider(256, 1024, value=518, step=2, label="Depth: input_size")
                fp32 = gr.Checkbox(False, label="Depth: use FP32 (default FP16)")
                max_side = gr.Slider(512, 1536, value=1024, step=64, label="FLUX: max side (px)")
                mask_dilate_px = gr.Slider(0, 128, value=0, step=1, label="Mask dilation (px)")
                guidance_scale = gr.Slider(0, 50, value=30, step=0.5, label="FLUX: guidance_scale")
                steps = gr.Slider(10, 75, value=50, step=1, label="FLUX: steps")
                seed = gr.Number(value=0, precision=0, label="Seed (>=0 固定;留空随机)")

            run_btn = gr.Button("Run", variant="primary")

        with gr.Column(scale=1):
            depth_preview = gr.Image(label="Depth preview (colored)", interactive=False)
            mask_preview = gr.Image(label="Mask preview (what will be removed)", interactive=False)
            out = gr.Image(label="Output")

    # Event: whenever the input image changes, compute the colour depth preview.
    img.change(
        fn=preview_depth,
        inputs=[img, encoder, max_res, input_size, fp32],
        outputs=[depth_preview],
    )

    # Prepare the canvas: load the original image or the colour depth map
    # into the ImageEditor.
    prepare_btn.click(
        fn=prepare_canvas,
        inputs=[img, depth_preview, draw_source],
        outputs=[sketch],
    )

    # Run the full pipeline.
    run_btn.click(
        fn=run_depth_and_fill,
        inputs=[img, mask_upload, sketch, prompt, encoder, max_res, input_size, fp32,
                max_side, mask_dilate_px, guidance_scale, steps, seed],
        outputs=[out, mask_preview],
        api_name="run",
    )
387
 
388
# Script entry point: default Hugging Face Hub telemetry to off (unless the
# variable is already set) and serve on all interfaces, port 7860.
if __name__ == "__main__":
    os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
    demo.launch(server_name="0.0.0.0", server_port=7860)
get_assets.sh ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ # 必要目录
5
+ mkdir -p code_depth/checkpoints \
6
+ code_edit/stage1/checkpoint-4800 \
7
+ code_edit/stage2/checkpoint-20000
8
+
9
# Generic download helper: fetch <url> <output-path>
# Prefers curl and falls back to wget; both retry and resume.
# The data is written to "<output-path>.part" and renamed only after the
# downloader reports success. Previously the download went straight to the
# final path, so an interrupted run left a non-empty partial file that the
# "already exists" check then accepted as complete and never resumed.
fetch() {
  local url="$1"
  local out="$2"
  local tmp="${out}.part"

  # Skip when the finished file is already there.
  if [ -s "$out" ]; then
    echo "✔ Exists: $out"
    return 0
  fi

  echo "↓ Fetch: $url -> $out"
  if command -v curl >/dev/null 2>&1; then
    # --retry retries on network/5xx/timeout errors; -C - resumes a partial
    # download; -f turns HTTP 4xx/5xx into a non-zero exit status.
    curl -fL --retry 5 --retry-all-errors --connect-timeout 20 -C - -o "$tmp" "${url}?download=1" || return 1
  else
    # wget: same idea, with retries (--tries) and resume (-c).
    wget --tries=5 -c -O "$tmp" "${url}?download=1" || return 1
  fi

  # Publish the file only once the download has completed successfully.
  mv -f "$tmp" "$out"
}
27
+
28
# 1) Video-Depth-Anything (VDA) weights
fetch "https://huggingface.co/depth-anything/Video-Depth-Anything-Small/resolve/main/video_depth_anything_vits.pth" \
      "code_depth/checkpoints/video_depth_anything_vits.pth"

fetch "https://huggingface.co/depth-anything/Video-Depth-Anything-Large/resolve/main/video_depth_anything_vitl.pth" \
      "code_depth/checkpoints/video_depth_anything_vitl.pth"

# 2) The stage1 / stage2 LoRA safetensors files
fetch "https://huggingface.co/buxiangzhiren/GeoRemover/resolve/main/stage1/checkpoint-4800/pytorch_lora_weights.safetensors" \
      "code_edit/stage1/checkpoint-4800/pytorch_lora_weights.safetensors"

fetch "https://huggingface.co/buxiangzhiren/GeoRemover/resolve/main/stage2/checkpoint-20000/pytorch_lora_weights.safetensors" \
      "code_edit/stage2/checkpoint-20000/pytorch_lora_weights.safetensors"

# Final verification: report every file that is still missing or empty.
missing=()
need=(
  "code_depth/checkpoints/video_depth_anything_vits.pth"
  "code_depth/checkpoints/video_depth_anything_vitl.pth"
  "code_edit/stage1/checkpoint-4800/pytorch_lora_weights.safetensors"
  "code_edit/stage2/checkpoint-20000/pytorch_lora_weights.safetensors"
)
for f in "${need[@]}"; do
  [ -s "$f" ] || missing+=("$f")
done
if [ ${#missing[@]} -ne 0 ]; then
  echo "❌ Missing after download:" >&2
  printf ' - %s\n' "${missing[@]}" >&2
  exit 1
fi

echo "✅ All assets ready."