euIaxs22 commited on
Commit
d35bae5
·
verified ·
1 Parent(s): 7fdcb08

Update services/ltx_server.py

Browse files
Files changed (1) hide show
  1. services/ltx_server.py +47 -21
services/ltx_server.py CHANGED
@@ -3,15 +3,22 @@ import subprocess
3
  import sys
4
  import time
5
  import shutil
 
6
  from pathlib import Path
7
  from typing import Optional, Tuple
8
 
9
- # Para o download seletivo, hf_hub_download é mais direto que snapshot_download
10
  from huggingface_hub import snapshot_download
11
 
 
12
  APP_HOME = Path(os.environ.get("APP_HOME", "/app"))
13
 
14
  class LTXServer:
 
 
 
 
 
 
15
  _instance = None
16
 
17
  def __new__(cls, *args, **kwargs):
@@ -20,10 +27,12 @@ class LTXServer:
20
  return cls._instance
21
 
22
  def __init__(self):
23
- if hasattr(self, '_initialized') and self._initialized: return
 
24
 
25
- print("🚀 LTXServer (Q8) inicializando com download seletivo...")
26
 
 
27
  self.LTX_REPO_DIR = Path(os.getenv("LTX_REPO_DIR", "/data/LTX-Video"))
28
  self.LTX_CKPT_DIR = Path(os.getenv("LTX_CKPT_DIR", "/data/ckpt/ltxvideo_q8"))
29
  self.OUTPUT_ROOT = APP_HOME / "outputs" / "ltx"
@@ -33,6 +42,7 @@ class LTXServer:
33
  self.MODEL_REPO_ID_ORIGINAL = "Lightricks/LTX-Video"
34
  self.MODEL_REPO_ID_Q8 = "konakona/ltxvideo_q8"
35
 
 
36
  for p in [self.LTX_REPO_DIR.parent, self.LTX_CKPT_DIR, self.OUTPUT_ROOT, self.HF_HOME_CACHE]:
37
  p.mkdir(parents=True, exist_ok=True)
38
 
@@ -41,20 +51,22 @@ class LTXServer:
41
  print("✅ LTXServer (Q8) pronto.")
42
 
43
  def setup_dependencies(self):
 
44
  self._ensure_repo()
45
- self._ensure_model()
46
 
47
  def _ensure_repo(self) -> None:
 
48
  if not (self.LTX_REPO_DIR / ".git").exists():
49
  print(f"[LTXServer] Clonando repositório para {self.LTX_REPO_DIR}...")
50
  subprocess.run(["git", "clone", "--depth", "1", self.REPO_URL, str(self.LTX_REPO_DIR)], check=True)
51
  else:
52
  print("[LTXServer] Repositório LTX-Video já existe.")
53
 
54
- def _ensure_model(self) -> None:
55
  """
56
- Garante que todos os componentes do modelo existam, baixando apenas os
57
- subdiretórios necessários de cada repositório.
58
  """
59
  print(f"[LTXServer] Verificando e baixando componentes em {self.LTX_CKPT_DIR}...")
60
 
@@ -62,12 +74,7 @@ class LTXServer:
62
  cache_dir = str(self.HF_HOME_CACHE)
63
  local_dir = str(self.LTX_CKPT_DIR)
64
 
65
- # ====================================================================
66
- # <<< NOVA LÓGICA DE DOWNLOAD SELETIVO >>>
67
-
68
- # 1. Baixa os componentes base (VAE, Text Encoder, Scheduler) do repo original.
69
- # O padrão "glob" `componente/*` garante que baixemos todo o conteúdo de cada pasta.
70
- componentes_base = ["vae/*", "text_encoder/*", "scheduler/*"]
71
  print(f"[LTXServer] Baixando componentes base de '{self.MODEL_REPO_ID_ORIGINAL}'...")
72
  snapshot_download(
73
  repo_id=self.MODEL_REPO_ID_ORIGINAL,
@@ -75,12 +82,12 @@ class LTXServer:
75
  cache_dir=cache_dir,
76
  repo_type='model',
77
  token=token,
78
- allow_patterns=componentes_base,
79
  resume_download=True,
80
  )
81
  print("[LTXServer] Componentes base (VAE, T5, Scheduler) prontos.")
82
 
83
- # 2. Baixa apenas o UNet quantizado (Q8) do repo do konakona.
84
  print(f"[LTXServer] Baixando UNet quantizado (Q8) de '{self.MODEL_REPO_ID_Q8}'...")
85
  snapshot_download(
86
  repo_id=self.MODEL_REPO_ID_Q8,
@@ -88,18 +95,33 @@ class LTXServer:
88
  cache_dir=cache_dir,
89
  repo_type='model',
90
  token=token,
91
- allow_patterns=["unet/*"], # Baixa apenas a pasta unet
92
  resume_download=True,
93
  )
94
  print("[LTXServer] UNet quantizado (Q8) pronto.")
95
- # ====================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  print("[LTXServer] Todos os componentes do modelo foram baixados e mesclados com sucesso.")
98
 
99
  def run_inference(self, prompt: str, image_path: str, height: int, width: int, num_frames: int, seed: int) -> str:
100
- # ... (O resto da classe permanece o mesmo da resposta anterior)
101
  script_path = self.LTX_REPO_DIR / "inference.py"
102
- if not script_path.exists(): raise FileNotFoundError(f"Script de inferência não encontrado: {script_path}")
 
103
 
104
  job_output_dir = self.OUTPUT_ROOT / f"run_{int(time.time())}_{os.urandom(4).hex()}"
105
  job_output_dir.mkdir(parents=True)
@@ -123,8 +145,12 @@ class LTXServer:
123
 
124
  try:
125
  subprocess.run(
126
- cmd, cwd=str(self.LTX_REPO_DIR), check=True,
127
- env=os.environ.copy(), stdout=sys.stdout, stderr=sys.stderr
 
 
 
 
128
  )
129
  except Exception as e:
130
  print(f"[LTXServer] Erro na execução da inferência: {e}")
 
3
  import sys
4
  import time
5
  import shutil
6
+ import json # Importa o módulo json para manipular o arquivo de configuração
7
  from pathlib import Path
8
  from typing import Optional, Tuple
9
 
 
10
  from huggingface_hub import snapshot_download
11
 
12
+ # Define a raiz da aplicação a partir de variáveis de ambiente para robustez.
13
  APP_HOME = Path(os.environ.get("APP_HOME", "/app"))
14
 
15
  class LTXServer:
16
+ """
17
+ Gerencia o setup e a execução da inferência para o LTX-Video Q8.
18
+ - Baixa os componentes de modelo de dois repositórios diferentes.
19
+ - Aplica um patch no config.json do VAE para garantir compatibilidade.
20
+ - Executa o script 'inference.py' como um subprocesso.
21
+ """
22
  _instance = None
23
 
24
  def __new__(cls, *args, **kwargs):
 
27
  return cls._instance
28
 
29
  def __init__(self):
30
+ if hasattr(self, '_initialized') and self._initialized:
31
+ return
32
 
33
+ print("🚀 LTXServer (Q8) inicializando e preparando o ambiente...")
34
 
35
+ # Define os caminhos principais
36
  self.LTX_REPO_DIR = Path(os.getenv("LTX_REPO_DIR", "/data/LTX-Video"))
37
  self.LTX_CKPT_DIR = Path(os.getenv("LTX_CKPT_DIR", "/data/ckpt/ltxvideo_q8"))
38
  self.OUTPUT_ROOT = APP_HOME / "outputs" / "ltx"
 
42
  self.MODEL_REPO_ID_ORIGINAL = "Lightricks/LTX-Video"
43
  self.MODEL_REPO_ID_Q8 = "konakona/ltxvideo_q8"
44
 
45
+ # Garante que os diretórios existam
46
  for p in [self.LTX_REPO_DIR.parent, self.LTX_CKPT_DIR, self.OUTPUT_ROOT, self.HF_HOME_CACHE]:
47
  p.mkdir(parents=True, exist_ok=True)
48
 
 
51
  print("✅ LTXServer (Q8) pronto.")
52
 
53
  def setup_dependencies(self):
54
+ """Orquestra o setup: clona o repo, baixa os modelos e aplica o patch."""
55
  self._ensure_repo()
56
+ self._ensure_model_and_patch_config()
57
 
58
  def _ensure_repo(self) -> None:
59
+ """Clona o repositório do LTX-Video se ele não existir."""
60
  if not (self.LTX_REPO_DIR / ".git").exists():
61
  print(f"[LTXServer] Clonando repositório para {self.LTX_REPO_DIR}...")
62
  subprocess.run(["git", "clone", "--depth", "1", self.REPO_URL, str(self.LTX_REPO_DIR)], check=True)
63
  else:
64
  print("[LTXServer] Repositório LTX-Video já existe.")
65
 
66
+ def _ensure_model_and_patch_config(self) -> None:
67
  """
68
+ Garante que todos os componentes existam e aplica um patch no config.json
69
+ do VAE para corresponder ao que o código do fork espera.
70
  """
71
  print(f"[LTXServer] Verificando e baixando componentes em {self.LTX_CKPT_DIR}...")
72
 
 
74
  cache_dir = str(self.HF_HOME_CACHE)
75
  local_dir = str(self.LTX_CKPT_DIR)
76
 
77
+ # 1. Baixa os componentes base (VAE, Text Encoder, Scheduler) do repo original
 
 
 
 
 
78
  print(f"[LTXServer] Baixando componentes base de '{self.MODEL_REPO_ID_ORIGINAL}'...")
79
  snapshot_download(
80
  repo_id=self.MODEL_REPO_ID_ORIGINAL,
 
82
  cache_dir=cache_dir,
83
  repo_type='model',
84
  token=token,
85
+ ignore_patterns=["unet/*", "*.safetensors"],
86
  resume_download=True,
87
  )
88
  print("[LTXServer] Componentes base (VAE, T5, Scheduler) prontos.")
89
 
90
+ # 2. Baixa apenas o UNet quantizado (Q8) do repo do konakona
91
  print(f"[LTXServer] Baixando UNet quantizado (Q8) de '{self.MODEL_REPO_ID_Q8}'...")
92
  snapshot_download(
93
  repo_id=self.MODEL_REPO_ID_Q8,
 
95
  cache_dir=cache_dir,
96
  repo_type='model',
97
  token=token,
98
+ allow_patterns=["unet/*"],
99
  resume_download=True,
100
  )
101
  print("[LTXServer] UNet quantizado (Q8) pronto.")
102
+
103
+ # 3. Aplica o patch no config.json do VAE para resolver o AssertionError
104
+ vae_config_path = self.LTX_CKPT_DIR / "vae" / "config.json"
105
+ if vae_config_path.exists():
106
+ print(f"[LTXServer] Aplicando patch de compatibilidade ao '{vae_config_path.name}' do VAE...")
107
+ with open(vae_config_path, 'r+') as f:
108
+ config_data = json.load(f)
109
+ # Adiciona ou sobrescreve a chave para corresponder ao que o código do fork espera
110
+ config_data["_class_name"] = "CausalVideoAutoencoder"
111
+ f.seek(0)
112
+ json.dump(config_data, f, indent=4)
113
+ f.truncate()
114
+ print("[LTXServer] Patch do config.json aplicado com sucesso.")
115
+ else:
116
+ # Isso seria um erro crítico, pois significa que o download falhou
117
+ raise FileNotFoundError(f"Não foi possível encontrar {vae_config_path} para aplicar o patch.")
118
 
119
  print("[LTXServer] Todos os componentes do modelo foram baixados e mesclados com sucesso.")
120
 
121
  def run_inference(self, prompt: str, image_path: str, height: int, width: int, num_frames: int, seed: int) -> str:
 
122
  script_path = self.LTX_REPO_DIR / "inference.py"
123
+ if not script_path.exists():
124
+ raise FileNotFoundError(f"Script de inferência não encontrado: {script_path}")
125
 
126
  job_output_dir = self.OUTPUT_ROOT / f"run_{int(time.time())}_{os.urandom(4).hex()}"
127
  job_output_dir.mkdir(parents=True)
 
145
 
146
  try:
147
  subprocess.run(
148
+ cmd,
149
+ cwd=str(self.LTX_REPO_DIR),
150
+ check=True,
151
+ env=os.environ.copy(),
152
+ stdout=sys.stdout,
153
+ stderr=sys.stderr
154
  )
155
  except Exception as e:
156
  print(f"[LTXServer] Erro na execução da inferência: {e}")