Aduc-sdr-2_5s / services /vince_server.py
euIaxs22's picture
Update services/vince_server.py
654e6f7 verified
raw
history blame
8.62 kB
#!/usr/bin/env python3
import os, sys, gc, subprocess
from pathlib import Path
from typing import List, Optional
from omegaconf import OmegaConf, open_dict
# Root of the VINCIE checkout; the VINCIE_DIR environment variable overrides
# the default location.
VINCIE_DIR = Path(os.getenv("VINCIE_DIR", "/app/VINCIE"))

# Put the checkout at the front of the import path (once) so modules that
# live inside it can be imported by the rest of this file.
if sys.path.count(str(VINCIE_DIR)) == 0:
    sys.path.insert(0, str(VINCIE_DIR))

# Expose the checkout's relative 'models/' directory as /app/models.
# Best-effort: a failure is reported on stdout and the import continues.
try:
    app_models = Path("/app/models")
    vincie_models = VINCIE_DIR / "models"
    if not app_models.exists():
        # Only link when there is nothing at /app/models yet and the
        # checkout actually ships a models directory.
        if vincie_models.exists():
            app_models.symlink_to(vincie_models, target_is_directory=True)
except Exception as e:
    print("[vince_server] warn: link /app/models failed:", e)
from common.config import load_config, create_object # type: ignore
class VinceServer:
    """In-process wrapper around the generator object built from a VINCIE config.

    On construction the config at ``config_path`` is loaded with
    ``load_config``, the generator is built with ``create_object`` and its
    ``configure_*`` steps are run once. The two public ``generate_*`` methods
    then patch the generator's config per request, run one inference pass and
    return the output directory as a string.
    """

    def __init__(self, config_path: str = "/app/VINCIE/configs/generate.yaml",
                 *, base_overrides: Optional[List[str]] = None,
                 output_root: str = "/app/outputs", chdir_repo: bool = True):
        """Load the config, build the generator and bootstrap its components.

        Args:
            config_path: Path of the generator config passed to ``load_config``.
            base_overrides: Optional override strings forwarded to ``load_config``.
            output_root: Directory under which per-request output folders are
                created; it is created if missing.
            chdir_repo: When true, change the working directory to
                ``VINCIE_DIR`` first (the config refers to ``ckpt/...``
                relative to the repo). A failed chdir is only reported.

        Raises:
            Exception: whatever the second bootstrap attempt raises, if
                repairing the checkpoint links did not help.
        """
        self.config_path = config_path
        self.output_root = Path(output_root)
        self.output_root.mkdir(parents=True, exist_ok=True)
        # (sampler, original timesteps) pair recorded by _set_steps so that a
        # per-request step override never shrinks the schedule permanently.
        self._base_schedule = None
        overrides = list(base_overrides or [])
        if chdir_repo:
            try:
                os.chdir(str(VINCIE_DIR))
            except Exception as e:
                print("[vince_server] warn: chdir repo failed:", e)
        try:
            self._load_and_bootstrap(overrides)
        except Exception as e:
            # One retry after refreshing the checkpoint symlinks; a second
            # failure propagates to the caller.
            print("[vince_server] bootstrap failed; repairing symlinks:", e)
            self._repair_ckpt_links()
            self._load_and_bootstrap(overrides)

    def _load_and_bootstrap(self, overrides: List[str]):
        """Verify the checkpoint, load the config, build and configure the generator."""
        self._assert_ckpt_ready()
        self.config = load_config(self.config_path, overrides)
        self.gen = create_object(self.config)
        self._bootstrap_models()

    def _repair_ckpt_links(self):
        """Re-create the ``VINCIE-3B`` checkpoint symlinks for the current snapshot.

        Calls ``VincieService().ensure_model()`` and links both
        ``VINCIE_DIR/ckpt/VINCIE-3B`` and ``/app/ckpt/VINCIE-3B`` to the
        service's ``ckpt_dir``. Links that already point there are kept.
        """
        # Imported lazily: only needed on the repair path.
        from services.vincie import VincieService
        svc = VincieService()
        svc.ensure_model()
        snapshot = Path(str(svc.ckpt_dir))
        for link in (VINCIE_DIR / "ckpt" / "VINCIE-3B", Path("/app/ckpt") / "VINCIE-3B"):
            link.parent.mkdir(parents=True, exist_ok=True)
            # Drop symlinks that point elsewhere, and dangling ones: a
            # dangling link reports exists() == False, so re-creating it
            # below without unlinking would fail with FileExistsError.
            if link.is_symlink() and (link.resolve() != snapshot or not link.exists()):
                link.unlink()
            # A real file/directory squatting on the link path is removed.
            if link.exists() and (not link.is_symlink()):
                subprocess.run(["rm", "-rf", str(link)], check=True)
            if not link.exists():
                link.symlink_to(snapshot, target_is_directory=True)
        print("[vince_server] ckpt symlinks refreshed")

    def _assert_ckpt_ready(self):
        """Raise ``RuntimeError`` unless the repo checkpoint link and its contents exist.

        Only the repo-relative link is checked; the config uses ``ckpt/...``
        with the working directory set to the repo.
        """
        repo_link = VINCIE_DIR / "ckpt" / "VINCIE-3B"
        if not repo_link.exists():
            # Report the path that was actually checked (VINCIE_DIR may be
            # overridden through the environment).
            raise RuntimeError(f"missing ckpt link: {repo_link}")
        must = [repo_link / "dit.pth", repo_link / "vae.pth", repo_link / "llm14b"]
        missing = [str(p) for p in must if not p.exists()]
        if missing:
            raise RuntimeError(f"missing ckpt content: {missing}")

    @staticmethod
    def _make_writable(cfg):
        """Best-effort: clear OmegaConf's read-only and struct flags on ``cfg``."""
        try:
            OmegaConf.set_readonly(cfg, False)
            OmegaConf.set_struct(cfg, False)
        except Exception:
            # Deliberately ignored: callers still wrap their writes in open_dict().
            pass

    def _bootstrap_models(self):
        """Run the generator's ``configure_*`` steps in order.

        Raises:
            RuntimeError: if a step is missing/not callable, or if the
                generator has no ``sampler`` attribute afterwards.
        """
        for name in ("configure_persistence", "configure_models",
                     "configure_diffusion", "configure_sampler"):
            fn = getattr(self.gen, name, None)
            if not callable(fn):
                raise RuntimeError(f"[vince_server] missing step: {name}")
            fn()
        if not hasattr(self.gen, "sampler"):
            raise RuntimeError("[vince_server] missing component: sampler")

    def _set_steps(self, steps: Optional[int]):
        """Sub-sample the sampler's timestep schedule to ``steps`` entries.

        The schedule seen the first time a given sampler is touched is kept as
        the baseline; every override is computed from that baseline, and a
        falsy ``steps`` restores it. This keeps one request's step count from
        leaking into later requests. ``steps`` is clamped to
        ``[1, len(baseline)]``. Failures are reported and otherwise ignored.
        """
        sampler = getattr(self.gen, "sampler", None)
        current = getattr(sampler, "timesteps", None)
        if sampler is None or current is None:
            return
        cached = getattr(self, "_base_schedule", None)
        if cached is not None and cached[0] is sampler:
            base = cached[1]
        else:
            # First time this sampler object is seen: remember its schedule.
            base = current
            self._base_schedule = (sampler, base)
        if not steps:
            # No override requested: run with the full baseline schedule.
            sampler.timesteps = base
            return
        try:
            import torch
            if hasattr(base, "__len__") and len(base) > 0:
                steps = max(1, min(int(steps), len(base)))
                if steps < len(base):
                    # Evenly spaced indices that include both endpoints.
                    idx = torch.linspace(0, len(base) - 1, steps).round().long().tolist()
                    sampler.timesteps = [base[i] for i in idx]
                else:
                    sampler.timesteps = base
        except Exception as e:
            print(f"[vince_server] Warning: set_steps failed: {e}")

    def _apply_generation_overrides(self, *, out_dir: Path,
                                    image_paths: Optional[List[str]] = None,
                                    prompts: Optional[List[str]] = None,
                                    final_prompt: Optional[str] = None,
                                    cfg_scale: Optional[float] = None,
                                    aspect_ratio_input: Optional[str] = None,
                                    resolution_input: Optional[int] = None,
                                    steps: Optional[int] = None):
        """Write per-request settings into ``self.gen.config.generation``.

        Only arguments that are not ``None`` are applied (``out_dir`` always
        is). ``final_prompt`` is accepted for interface compatibility but not
        written anywhere by this method; callers append it to ``prompts``.
        ``steps`` is forwarded to ``_set_steps``.
        """
        self._make_writable(self.gen.config)
        g = self.gen.config.generation
        self._make_writable(g)
        self._make_writable(g.output)
        self._make_writable(g.positive_prompt)
        with open_dict(g):
            g.output.dir = str(out_dir)
            if image_paths is not None:
                g.positive_prompt.image_path = list(image_paths)
            if prompts is not None:
                g.positive_prompt.prompts = list(prompts)
            if cfg_scale is not None:
                try:
                    g.cfg_scale = float(cfg_scale)
                except Exception:
                    # Fallback location for the guidance scale.
                    with open_dict(self.gen.config):
                        try:
                            self.gen.config.diffusion.cfg.scale = float(cfg_scale)
                        except Exception:
                            print("[vince_server] Warning: unable to set cfg_scale")
            if aspect_ratio_input is not None:
                g.aspect_ratio_input = str(aspect_ratio_input)
            if resolution_input is not None:
                try:
                    g.resolution_input = int(resolution_input)
                except Exception:
                    # Fallback key for the resolution.
                    try:
                        g.resolution = int(resolution_input)
                    except Exception:
                        print("[vince_server] Warning: unable to set resolution")
        self._set_steps(steps)

    def _infer_once(self):
        """Call the first available of ``inference_loop``, ``entrypoint``, ``run``.

        Raises:
            RuntimeError: if the generator exposes none of them.
        """
        for name in ("inference_loop", "entrypoint", "run"):
            fn = getattr(self.gen, name, None)
            if callable(fn):
                fn()
                return
        raise RuntimeError("No valid inference method found on generator")

    def _cleanup(self):
        """Best-effort release of memory after a run; never raises."""
        try:
            import torch
            torch.cuda.synchronize()
        except Exception:
            # torch or a CUDA device may be unavailable; nothing to wait for.
            pass
        gc.collect()
        try:
            import torch
            torch.cuda.empty_cache()
            torch.cuda.memory.reset_peak_memory_stats()
        except Exception:
            pass

    # Public API
    def generate_multi_turn(self, image_path: str, turns: List[str], *,
                            out_dir_name: Optional[str] = None, cfg_scale: Optional[float] = None,
                            aspect_ratio_input: Optional[str] = None, resolution_input: Optional[int] = None,
                            steps: Optional[int] = None) -> str:
        """Run one inference pass with a single input image and a list of prompts.

        Args:
            image_path: Input image; its stem names the default output folder.
            turns: Prompts written to ``generation.positive_prompt.prompts``.
            out_dir_name: Output folder name under ``output_root``; defaults
                to ``multi_turn_<image stem>``.
            cfg_scale, aspect_ratio_input, resolution_input, steps: Optional
                per-request overrides.

        Returns:
            The output directory path as a string.
        """
        out_dir = self.output_root / (out_dir_name or f"multi_turn_{Path(image_path).stem}")
        out_dir.mkdir(parents=True, exist_ok=True)
        self._apply_generation_overrides(out_dir=out_dir, image_paths=[str(image_path)], prompts=list(turns),
                                         cfg_scale=cfg_scale, aspect_ratio_input=aspect_ratio_input,
                                         resolution_input=resolution_input, steps=steps)
        try:
            self._infer_once()
        finally:
            # Release memory even when inference fails.
            self._cleanup()
        return str(out_dir)

    def generate_multi_concept(self, concept_images: List[str], concept_prompts: List[str], final_prompt: str, *,
                               out_dir_name: Optional[str] = None, cfg_scale: Optional[float] = None,
                               aspect_ratio_input: Optional[str] = None, resolution_input: Optional[int] = None,
                               steps: Optional[int] = None) -> str:
        """Run one inference pass with several images and their prompts.

        Args:
            concept_images: Input images, written as the config's image paths.
            concept_prompts: One prompt per concept.
            final_prompt: Appended after ``concept_prompts`` when non-empty.
            out_dir_name: Output folder name under ``output_root``; defaults
                to ``multi_concept``.
            cfg_scale, aspect_ratio_input, resolution_input, steps: Optional
                per-request overrides.

        Returns:
            The output directory path as a string.
        """
        out_dir = self.output_root / (out_dir_name or "multi_concept")
        out_dir.mkdir(parents=True, exist_ok=True)
        prompts_all = list(concept_prompts) + ([final_prompt] if final_prompt else [])
        self._apply_generation_overrides(out_dir=out_dir, image_paths=[str(p) for p in concept_images],
                                         prompts=prompts_all, final_prompt=final_prompt, cfg_scale=cfg_scale,
                                         aspect_ratio_input=aspect_ratio_input, resolution_input=resolution_input,
                                         steps=steps)
        try:
            self._infer_once()
        finally:
            # Release memory even when inference fails.
            self._cleanup()
        return str(out_dir)
# Module-level singleton, constructed when this module is imported. The
# config path and output root come from VINCE_CONFIG / VINCE_OUTPUT, with the
# defaults shown below when the variables are unset.
server = VinceServer(
    config_path=os.environ.get("VINCE_CONFIG", "/app/VINCIE/configs/generate.yaml"),
    output_root=os.environ.get("VINCE_OUTPUT", "/app/outputs"),
    chdir_repo=True,
)