euIaxs22 committed on
Commit
654e6f7
·
verified ·
1 Parent(s): 20ef566

Update services/vince_server.py

Browse files
Files changed (1) hide show
  1. services/vince_server.py +160 -237
services/vince_server.py CHANGED
@@ -1,247 +1,170 @@
1
  #!/usr/bin/env python3
2
- """
3
- services/vincie.py
4
-
5
- VincieService — preparação e execução CLI do VINCIE (upstream)
6
- - Garante repositório íntegro (clona/repara se faltarem main.py/.git).
7
- - Baixa snapshot completo do modelo no HF_HUB_CACHE.
8
- - Cria symlink idempotente ckpt/VINCIE-3B (no repo e em /app/ckpt) apontando para o snapshot (contém dit.pth, vae.pth, llm14b).
9
- - Valida artefatos esperados pelo generate.yaml.
10
- - Executa main.py do upstream com overrides de geração (sem mexer em ckpt.path).
11
- - Limpa VRAM levemente após cada job.
12
-
13
- Observação:
14
- - Para latência mínima, preferir o vince_server in-process (pipeline aquecida).
15
- - Este serviço via subprocess é fiel ao upstream e útil como fallback/diag.
16
- """
17
-
18
- import os
19
- import json
20
- import subprocess
21
  from pathlib import Path
22
  from typing import List, Optional
 
23
 
24
- from huggingface_hub import snapshot_download
 
 
25
 
26
-
27
- class VincieService:
28
- def __init__(
29
- self,
30
- repo_dir: str = "/app/VINCIE",
31
- python_bin: str = "python",
32
- repo_url: str = "https://github.com/ByteDance-Seed/VINCIE",
33
- model_repo: str = "ByteDance-Seed/VINCIE-3B",
34
- output_root: str = "/app/outputs",
35
- ):
36
- self.repo_dir = Path(repo_dir)
37
- self.python = python_bin
38
- self.repo_url = repo_url
39
- self.model_repo = model_repo
40
- self.output_root = Path(output_root)
41
- self.output_root.mkdir(parents=True, exist_ok=True)
42
-
43
- self.generate_yaml = self.repo_dir / "configs" / "generate.yaml"
44
- self.ckpt_link_repo = self.repo_dir / "ckpt" / "VINCIE-3B"
45
- self.ckpt_link_app = Path("/app/ckpt") / "VINCIE-3B"
46
-
47
- self.ckpt_dir: Optional[Path] = None
48
- self._env = os.environ.copy()
49
-
50
- # ---------- util ----------
51
-
52
- @staticmethod
53
- def _run(cmd: List[str], cwd: Optional[Path] = None, env=None):
54
- subprocess.run(cmd, cwd=str(cwd) if cwd else None, check=True, env=env)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  @staticmethod
57
- def _ensure_symlink(link: Path, target: Path):
58
- link.parent.mkdir(parents=True, exist_ok=True)
59
- if link.is_symlink():
60
- try:
61
- if link.resolve() != target:
62
- link.unlink()
63
- link.symlink_to(target, target_is_directory=True)
64
- except Exception:
65
- # relinka a partir do zero
66
- link.unlink(missing_ok=True)
67
- link.symlink_to(target, target_is_directory=True)
68
- elif link.exists():
69
- VincieService._run(["rm", "-rf", str(link)])
70
- link.symlink_to(target, target_is_directory=True)
71
- else:
72
- link.symlink_to(target, target_is_directory=True)
73
-
74
- # ---------- repo/modelo ----------
75
-
76
- def ensure_repo(self) -> None:
77
- self.repo_dir.mkdir(parents=True, exist_ok=True)
78
- main_py = self.repo_dir / "main.py"
79
- git_dir = self.repo_dir / ".git"
80
- if main_py.exists() and git_dir.exists():
81
- return
82
- tmp = self.repo_dir.with_name(self.repo_dir.name + ".tmp")
83
- if tmp.exists():
84
- self._run(["rm", "-rf", str(tmp)])
85
- self._run(["git", "clone", self.repo_url, str(tmp)])
86
- # swap atômico simples
87
- if self.repo_dir.exists():
88
- self._run(["rm", "-rf", str(self.repo_dir)])
89
- tmp.rename(self.repo_dir)
90
-
91
- def ensure_model(self, revision: Optional[str] = None, token: Optional[str] = None) -> None:
92
- cache_dir = os.environ.get("HF_HUB_CACHE")
93
- snapshot_path = snapshot_download(
94
- repo_id=self.model_repo,
95
- revision=revision,
96
- token=token or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN"),
97
- cache_dir=cache_dir,
98
- resume_download=True,
99
- )
100
- self.ckpt_dir = Path(snapshot_path)
101
- # symlinks idempotentes
102
- self._ensure_symlink(self.ckpt_link_repo, self.ckpt_dir)
103
- self._ensure_symlink(self.ckpt_link_app, self.ckpt_dir)
104
-
105
- def validate_assets(self) -> None:
106
- # exige generate.yaml/main.py e conteúdo essencial no snapshot
107
- if not self.generate_yaml.exists() or not (self.repo_dir / "main.py").exists():
108
- raise RuntimeError("VINCIE repo inválido (faltando generate.yaml ou main.py)")
109
- target = self.ckpt_dir or self.ckpt_link_repo
110
- need = [target / "dit.pth", target / "vae.pth", target / "llm14b"]
111
- missing = [str(p) for p in need if not p.exists()]
112
- if missing:
113
- raise RuntimeError(f"Snapshot incompleto: {missing}")
114
- # também requer que o link repo exista (a config usa ckpt/ relativo)
115
- if not self.ckpt_link_repo.exists():
116
- raise RuntimeError("ckpt link ausente no repo: ckpt/VINCIE-3B")
117
-
118
- # ---------- execução ----------
119
-
120
- def _build_overrides(
121
- self,
122
- extra_overrides: Optional[List[str]] = None,
123
- cfg_scale: Optional[float] = None,
124
- resolution_input: Optional[int] = None,
125
- aspect_ratio_input: Optional[str] = None,
126
- steps: Optional[int] = None,
127
- ) -> List[str]:
128
- overrides = list(extra_overrides or [])
129
- # não altera ckpt.path; respeita o YAML
130
- if cfg_scale is not None:
131
- overrides.append(f"generation.cfg_scale={cfg_scale}")
132
- if resolution_input is not None:
133
- overrides.append(f"generation.resolution_input={resolution_input}")
134
- if aspect_ratio_input is not None:
135
- overrides.append(f"generation.aspect_ratio_input={aspect_ratio_input}")
136
- if steps is not None:
137
- overrides.append(f"generation.steps={steps}")
138
- return overrides
139
-
140
- def _clean_gpu_memory(self) -> None:
141
- code = r"""
142
- import torch, gc
143
- try:
144
- torch.cuda.synchronize()
145
- except Exception:
146
- pass
147
- gc.collect()
148
- try:
149
- torch.cuda.empty_cache()
150
- torch.cuda.memory.reset_peak_memory_stats()
151
- except Exception:
152
- pass
153
- """
154
- self._run([self.python, "-c", code], env=self._env)
155
-
156
- # ---------- APIs ----------
157
-
158
- def multi_turn_edit(
159
- self,
160
- input_image: str,
161
- turns: List[str],
162
- out_dir_name: Optional[str] = None,
163
- *,
164
- cfg_scale: Optional[float] = None,
165
- resolution_input: Optional[int] = None,
166
- aspect_ratio_input: Optional[str] = None,
167
- steps: Optional[int] = None,
168
- ) -> str:
169
- self.ensure_repo()
170
- self.ensure_model()
171
- self.validate_assets()
172
-
173
- out_dir = self.output_root / (out_dir_name or f"multi_turn_{Path(input_image).stem}")
174
  out_dir.mkdir(parents=True, exist_ok=True)
175
-
176
- image_json = json.dumps([str(input_image)])
177
- prompts_json = json.dumps(turns)
178
-
179
- base_overrides = [
180
- f"generation.positive_prompt.image_path={image_json}",
181
- f"generation.positive_prompt.prompts={prompts_json}",
182
- ]
183
- overrides = self._build_overrides(
184
- extra_overrides=base_overrides,
185
- cfg_scale=cfg_scale,
186
- resolution_input=resolution_input,
187
- aspect_ratio_input=aspect_ratio_input,
188
- steps=steps,
189
- )
190
-
191
- cmd = [
192
- self.python,
193
- "main.py",
194
- str(self.generate_yaml),
195
- *overrides,
196
- f"generation.output.dir={str(out_dir)}",
197
- ]
198
- self._run(cmd, cwd=self.repo_dir, env=self._env)
199
- self._clean_gpu_memory()
200
- return str(out_dir)
201
-
202
- def multi_concept_compose(
203
- self,
204
- concept_images: List[str],
205
- concept_prompts: List[str],
206
- final_prompt: str,
207
- out_dir_name: Optional[str] = None,
208
- *,
209
- cfg_scale: Optional[float] = None,
210
- resolution_input: Optional[int] = None,
211
- aspect_ratio_input: Optional[str] = None,
212
- steps: Optional[int] = None,
213
- ) -> str:
214
- self.ensure_repo()
215
- self.ensure_model()
216
- self.validate_assets()
217
-
218
  out_dir = self.output_root / (out_dir_name or "multi_concept")
219
  out_dir.mkdir(parents=True, exist_ok=True)
220
-
221
- imgs_json = json.dumps([str(p) for p in concept_images])
222
- prompts_all = concept_prompts + [final_prompt]
223
- prompts_json = json.dumps(prompts_all)
224
-
225
- base_overrides = [
226
- f"generation.positive_prompt.image_path={imgs_json}",
227
- f"generation.positive_prompt.prompts={prompts_json}",
228
- "generation.pad_img_placehoder=False",
229
- ]
230
- overrides = self._build_overrides(
231
- extra_overrides=base_overrides,
232
- cfg_scale=cfg_scale,
233
- resolution_input=resolution_input,
234
- aspect_ratio_input=aspect_ratio_input,
235
- steps=steps,
236
- )
237
-
238
- cmd = [
239
- self.python,
240
- "main.py",
241
- str(self.generate_yaml),
242
- *overrides,
243
- f"generation.output.dir={str(out_dir)}",
244
- ]
245
- self._run(cmd, cwd=self.repo_dir, env=self._env)
246
- self._clean_gpu_memory()
247
- return str(out_dir)
 
1
  #!/usr/bin/env python3
2
+ import os, sys, gc, subprocess
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from pathlib import Path
4
  from typing import List, Optional
5
+ from omegaconf import OmegaConf, open_dict
6
 
7
+ VINCIE_DIR = Path(os.getenv("VINCIE_DIR", "/app/VINCIE"))
8
+ if str(VINCIE_DIR) not in sys.path:
9
+ sys.path.insert(0, str(VINCIE_DIR))
10
 
11
+ # inclui 'models/' relativo
12
+ try:
13
+ app_models = Path("/app/models"); vincie_models = VINCIE_DIR / "models"
14
+ if not app_models.exists() and vincie_models.exists():
15
+ app_models.symlink_to(vincie_models, target_is_directory=True)
16
+ except Exception as e:
17
+ print("[vince_server] warn: link /app/models failed:", e)
18
+
19
+ from common.config import load_config, create_object # type: ignore
20
+
21
+ class VinceServer:
22
+ def __init__(self, config_path: str="/app/VINCIE/configs/generate.yaml",
23
+ *, base_overrides: Optional[List[str]]=None,
24
+ output_root: str="/app/outputs", chdir_repo: bool=True):
25
+ self.config_path = config_path
26
+ self.output_root = Path(output_root); self.output_root.mkdir(parents=True, exist_ok=True)
27
+ overrides = list(base_overrides or [])
28
+ if chdir_repo:
29
+ try: os.chdir(str(VINCIE_DIR))
30
+ except Exception as e: print("[vince_server] warn: chdir repo failed:", e)
31
+
32
+ try:
33
+ self._load_and_bootstrap(overrides)
34
+ except Exception as e:
35
+ print("[vince_server] bootstrap failed; repairing symlinks:", e)
36
+ self._repair_ckpt_links()
37
+ self._load_and_bootstrap(overrides)
38
+
39
+ def _load_and_bootstrap(self, overrides: List[str]):
40
+ self._assert_ckpt_ready()
41
+ self.config = load_config(self.config_path, overrides)
42
+ self.gen = create_object(self.config)
43
+ self._bootstrap_models()
44
+
45
+ def _repair_ckpt_links(self):
46
+ # reusa snapshot atual para recriar links idempotentes
47
+ from services.vincie import VincieService
48
+ svc = VincieService(); svc.ensure_model()
49
+ snapshot = Path(str(svc.ckpt_dir))
50
+ for link in (VINCIE_DIR/"ckpt"/"VINCIE-3B", Path("/app/ckpt")/"VINCIE-3B"):
51
+ link.parent.mkdir(parents=True, exist_ok=True)
52
+ if link.is_symlink() and link.resolve()!=snapshot:
53
+ link.unlink()
54
+ if link.exists() and (not link.is_symlink()):
55
+ subprocess.run(["rm","-rf",str(link)], check=True)
56
+ if not link.exists():
57
+ link.symlink_to(snapshot, target_is_directory=True)
58
+ print("[vince_server] ckpt symlinks refreshed")
59
+
60
+ def _assert_ckpt_ready(self):
61
+ # ambos caminhos funcionam; a config usa 'ckpt/...' com cwd no repo
62
+ repo_link = VINCIE_DIR / "ckpt" / "VINCIE-3B"
63
+ if not repo_link.exists():
64
+ raise RuntimeError("missing ckpt link: /app/VINCIE/ckpt/VINCIE-3B")
65
+ must = [repo_link/"dit.pth", repo_link/"vae.pth", repo_link/"llm14b"]
66
+ missing = [str(p) for p in must if not p.exists()]
67
+ if missing:
68
+ raise RuntimeError(f"missing ckpt content: {missing}")
69
 
70
  @staticmethod
71
+ def _make_writable(cfg):
72
+ try: OmegaConf.set_readonly(cfg, False); OmegaConf.set_struct(cfg, False)
73
+ except Exception: pass
74
+
75
+ def _bootstrap_models(self):
76
+ for name in ("configure_persistence","configure_models","configure_diffusion","configure_sampler"):
77
+ fn = getattr(self.gen, name, None)
78
+ if not callable(fn): raise RuntimeError(f"[vince_server] missing step: {name}")
79
+ fn()
80
+ if not hasattr(self.gen, "sampler"):
81
+ raise RuntimeError("[vince_server] missing component: sampler")
82
+
83
+ def _set_steps(self, steps: Optional[int]):
84
+ if not steps: return
85
+ sampler = getattr(self.gen, "sampler", None); t = getattr(sampler, "timesteps", None)
86
+ if sampler is None or t is None: return
87
+ try:
88
+ import torch
89
+ if hasattr(t,"__len__") and len(t)>0:
90
+ steps = max(1, min(int(steps), len(t)))
91
+ if steps < len(t):
92
+ idx = torch.linspace(0, len(t)-1, steps).round().long().tolist()
93
+ sampler.timesteps = [t[i] for i in idx]
94
+ except Exception as e:
95
+ print(f"[vince_server] Warning: set_steps failed: {e}")
96
+
97
+ def _apply_generation_overrides(self, *, out_dir: Path,
98
+ image_paths: Optional[List[str]]=None,
99
+ prompts: Optional[List[str]]=None,
100
+ final_prompt: Optional[str]=None,
101
+ cfg_scale: Optional[float]=None,
102
+ aspect_ratio_input: Optional[str]=None,
103
+ resolution_input: Optional[int]=None,
104
+ steps: Optional[int]=None):
105
+ self._make_writable(self.gen.config)
106
+ g = self.gen.config.generation
107
+ self._make_writable(g); self._make_writable(g.output); self._make_writable(g.positive_prompt)
108
+ with open_dict(g):
109
+ g.output.dir = str(out_dir)
110
+ if image_paths is not None: g.positive_prompt.image_path = list(image_paths)
111
+ if prompts is not None: g.positive_prompt.prompts = list(prompts)
112
+ if cfg_scale is not None:
113
+ try: g.cfg_scale = float(cfg_scale)
114
+ except Exception:
115
+ with open_dict(self.gen.config):
116
+ try: self.gen.config.diffusion.cfg.scale = float(cfg_scale)
117
+ except Exception: print("[vince_server] Warning: unable to set cfg_scale")
118
+ if aspect_ratio_input is not None: g.aspect_ratio_input = str(aspect_ratio_input)
119
+ if resolution_input is not None:
120
+ try: g.resolution_input = int(resolution_input)
121
+ except Exception:
122
+ try: g.resolution = int(resolution_input)
123
+ except Exception: print("[vince_server] Warning: unable to set resolution")
124
+ self._set_steps(steps)
125
+
126
+ def _infer_once(self):
127
+ for name in ("inference_loop","entrypoint","run"):
128
+ fn = getattr(self.gen, name, None)
129
+ if callable(fn): fn(); return
130
+ raise RuntimeError("No valid inference method found on generator")
131
+
132
+ def _cleanup(self):
133
+ try:
134
+ import torch; torch.cuda.synchronize()
135
+ except Exception: pass
136
+ gc.collect()
137
+ try:
138
+ import torch; torch.cuda.empty_cache(); torch.cuda.memory.reset_peak_memory_stats()
139
+ except Exception: pass
140
+
141
+ # APIs
142
+ def generate_multi_turn(self, image_path: str, turns: List[str], *,
143
+ out_dir_name: Optional[str]=None, cfg_scale: Optional[float]=None,
144
+ aspect_ratio_input: Optional[str]=None, resolution_input: Optional[int]=None,
145
+ steps: Optional[int]=None) -> str:
146
+ out_dir = self.output_root / (out_dir_name or f"multi_turn_{Path(image_path).stem}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  out_dir.mkdir(parents=True, exist_ok=True)
148
+ self._apply_generation_overrides(out_dir=out_dir, image_paths=[str(image_path)], prompts=list(turns),
149
+ cfg_scale=cfg_scale, aspect_ratio_input=aspect_ratio_input,
150
+ resolution_input=resolution_input, steps=steps)
151
+ self._infer_once(); self._cleanup(); return str(out_dir)
152
+
153
+ def generate_multi_concept(self, concept_images: List[str], concept_prompts: List[str], final_prompt: str, *,
154
+ out_dir_name: Optional[str]=None, cfg_scale: Optional[float]=None,
155
+ aspect_ratio_input: Optional[str]=None, resolution_input: Optional[int]=None,
156
+ steps: Optional[int]=None) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  out_dir = self.output_root / (out_dir_name or "multi_concept")
158
  out_dir.mkdir(parents=True, exist_ok=True)
159
+ prompts_all = list(concept_prompts) + ([final_prompt] if final_prompt else [])
160
+ self._apply_generation_overrides(out_dir=out_dir, image_paths=[str(p) for p in concept_images],
161
+ prompts=prompts_all, final_prompt=final_prompt, cfg_scale=cfg_scale,
162
+ aspect_ratio_input=aspect_ratio_input, resolution_input=resolution_input,
163
+ steps=steps)
164
+ self._infer_once(); self._cleanup(); return str(out_dir)
165
+
166
+ server = VinceServer(
167
+ config_path=os.getenv("VINCE_CONFIG", "/app/VINCIE/configs/generate.yaml"),
168
+ output_root=os.getenv("VINCE_OUTPUT", "/app/outputs"),
169
+ chdir_repo=True,
170
+ )