euiia commited on
Commit
c80b8f1
·
verified ·
1 Parent(s): 62b888b

Update upscaler_specialist.py

Browse files
Files changed (1) hide show
  1. upscaler_specialist.py +35 -39
upscaler_specialist.py CHANGED
@@ -10,53 +10,49 @@ from ltx_manager_helpers import ltx_manager_singleton
10
 
11
  logger = logging.getLogger(__name__)
12
 
 
13
  class UpscalerSpecialist:
14
- """
15
- Especialista responsável por aumentar a resolução espacial de tensores latentes
16
- usando o LTX Video Spatial Upscaler.
17
- """
18
- def __init__(self, base_vae):
19
- self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
20
  self.pipe_upsample = None
21
- if base_vae is not None:
22
- logger.info("Inicializando Especialista de Upscale Latente...")
 
 
 
23
  try:
 
 
 
 
 
 
 
 
 
 
 
 
24
  self.pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
25
  "linoyts/LTX-Video-spatial-upscaler-0.9.8",
26
- vae=base_vae,
27
- torch_dtype=torch.bfloat16,
28
  ).to(self.device)
29
- logger.info("Especialista de Upscale Latente pronto.")
30
  except Exception as e:
31
- logger.error(f"Falha ao carregar o modelo de upscale: {e}", exc_info=True)
32
- else:
33
- logger.warning("VAE base não fornecido. Especialista de Upscale desativado.")
34
 
35
- @torch.no_grad()
36
  def upscale(self, latents: torch.Tensor) -> torch.Tensor:
37
- """
38
- Aplica o upscaling 2x nos tensores latentes fornecidos.
39
- """
40
  if self.pipe_upsample is None:
41
- logger.warning("Upscaler não está disponível. Retornando latentes originais.")
 
 
 
 
 
 
 
42
  return latents
43
-
44
- logger.info(f"Upscaler: Recebeu latentes com shape {latents.shape}. Aplicando upscale 2x...")
45
-
46
- # O upscaler opera em um batch de latentes.
47
- upscaled_latents = self.pipe_upsample(
48
- latents=latents,
49
- output_type="latent"
50
- ).frames
51
-
52
- logger.info(f"Upscaler: Latentes redimensionados para {upscaled_latents.shape}.")
53
- return upscaled_latents
54
-
55
- # Instanciação Singleton
56
- # Depende do VAE do ltx_manager, então o obtemos de lá.
57
- try:
58
- base_vae_for_upscaler = ltx_manager_singleton.workers[0].pipeline.vae
59
- upscaler_specialist_singleton = UpscalerSpecialist(base_vae=base_vae_for_upscaler)
60
- except Exception as e:
61
- logger.error(f"Não foi possível inicializar o UpscalerSpecialist Singleton: {e}")
62
- upscaler_specialist_singleton = None
 
10
 
11
  logger = logging.getLogger(__name__)
12
 
13
+
14
  class UpscalerSpecialist:
15
+ def __init__(self, device="cuda"):
16
+ self.device = device if torch.cuda.is_available() else "cpu"
 
 
 
 
17
  self.pipe_upsample = None
18
+ self.base_vae = None
19
+
20
+ def _lazy_init(self):
21
+ """Inicializa o VAE e o pipeline somente quando for chamado."""
22
+ if self.base_vae is None:
23
  try:
24
+ from ltx_manager_helpers import ltx_manager_singleton
25
+ if ltx_manager_singleton.workers:
26
+ self.base_vae = ltx_manager_singleton.workers[0].pipeline.vae
27
+ else:
28
+ logger.warning("[Upscaler] Nenhum worker disponível no ltx_manager_singleton.")
29
+ except Exception as e:
30
+ logger.error(f"[Upscaler] Falha ao inicializar VAE: {e}")
31
+ return
32
+
33
+ if self.pipe_upsample is None and self.base_vae is not None:
34
+ try:
35
+ from ltx_video.pipelines.latent_upscale import LTXLatentUpsamplePipeline
36
  self.pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
37
  "linoyts/LTX-Video-spatial-upscaler-0.9.8",
38
+ vae=self.base_vae,
39
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
40
  ).to(self.device)
41
+ logger.info("[Upscaler] Pipeline carregado com sucesso.")
42
  except Exception as e:
43
+ logger.error(f"[Upscaler] Falha ao carregar pipeline: {e}")
 
 
44
 
 
45
  def upscale(self, latents: torch.Tensor) -> torch.Tensor:
46
+ self._lazy_init()
 
 
47
  if self.pipe_upsample is None:
48
+ logger.warning("[Upscaler] Pipeline indisponível. Retornando latentes originais.")
49
+ return latents
50
+ try:
51
+ with torch.no_grad():
52
+ result = self.pipe_upsample(latents=latents, output_type="latent")
53
+ return result.latents
54
+ except Exception as e:
55
+ logger.error(f"[Upscaler] Erro durante upscale: {e}")
56
  return latents
57
+
58
+