File size: 3,736 Bytes
795e89c
bc96793
795e89c
 
 
 
 
dc7f43f
795e89c
 
 
 
bc96793
 
 
 
 
 
be876b2
c80b8f1
bc96793
c80b8f1
5ce2d70
c80b8f1
96817a0
5ce2d70
 
 
 
 
 
 
 
 
 
 
795e89c
5ce2d70
c80b8f1
795e89c
5ce2d70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96817a0
5ce2d70
 
 
795e89c
5ce2d70
bc96793
795e89c
bc96793
c80b8f1
795e89c
c80b8f1
 
bc96793
c80b8f1
bc96793
9bfc665
b88a339
 
 
9bfc665
 
 
 
c80b8f1
bc96793
795e89c
c80b8f1
bc96793
 
 
 
9bfc665
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# upscaler_specialist.py
# Copyright (C) 2025 Carlos Rodrigues
# Especialista ADUC para upscaling espacial de tensores latentes.

import torch
import logging
from diffusers import LTXLatentUpsamplePipeline
from managers.ltx_manager import ltx_manager_singleton

logger = logging.getLogger(__name__)

class UpscalerSpecialist:
    """
    ADUC specialist responsible for increasing the spatial resolution of
    latent tensors using the LTX Video Spatial Upscaler.

    The heavy model objects (VAE + upsample pipeline) are loaded lazily on
    first use via `_lazy_init()` and cached for all subsequent calls.
    """

    # Hugging Face repo providing both the fallback VAE and the upsampler.
    _UPSCALER_REPO = "linoyts/LTX-Video-spatial-upscaler-0.9.8"

    def __init__(self):
        # Prefer CUDA when available; otherwise fall back to CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.base_vae = None       # VAE used by the upsample pipeline
        self.pipe_upsample = None  # None until _lazy_init() succeeds

    @property
    def _dtype(self) -> torch.dtype:
        # fp16 on GPU to halve memory; fp32 on CPU where fp16 kernels are
        # poorly supported.
        return torch.float16 if self.device == "cuda" else torch.float32

    def _load_vae_manually(self):
        """Load an AutoencoderKLLTXVideo VAE directly from the upscaler repo."""
        from diffusers.models.autoencoders import AutoencoderKLLTXVideo
        return AutoencoderKLLTXVideo.from_pretrained(
            self._UPSCALER_REPO,
            subfolder="vae",
            torch_dtype=self._dtype,
        ).to(self.device)

    def _lazy_init(self):
        """Initialize the VAE and upsample pipeline once; no-op thereafter.

        On failure, leaves `self.pipe_upsample` as None so `upscale()` can
        degrade gracefully instead of raising to the caller.
        """
        # BUGFIX: previously there was no guard here, so every upscale() call
        # re-ran from_pretrained() and reloaded the full pipeline.
        if self.pipe_upsample is not None:
            return
        try:
            # Prefer reusing the VAE already held by the ltx_manager workers,
            # but only if it is the expected AutoencoderKLLTXVideo class.
            if ltx_manager_singleton.workers:
                candidate_vae = ltx_manager_singleton.workers[0].pipeline.vae
                if candidate_vae.__class__.__name__ == "AutoencoderKLLTXVideo":
                    self.base_vae = candidate_vae
                    logger.info("[Upscaler] Usando VAE do ltx_manager (AutoencoderKLLTXVideo).")
                else:
                    logger.warning(f"[Upscaler] VAE incompatível: {type(candidate_vae)}. "
                                   "Carregando AutoencoderKLLTXVideo manualmente...")
                    self.base_vae = self._load_vae_manually()
            else:
                logger.warning("[Upscaler] Nenhum worker disponível, carregando VAE manualmente...")
                self.base_vae = self._load_vae_manually()

            # Build the upsample pipeline on top of the chosen VAE.
            self.pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
                self._UPSCALER_REPO,
                vae=self.base_vae,
                torch_dtype=self._dtype,
            ).to(self.device)

            logger.info("[Upscaler] Pipeline carregado com sucesso.")

        except Exception as e:
            # Best-effort: record the failure and keep the specialist usable
            # (upscale() will return its input unchanged).
            logger.error(f"[Upscaler] Falha ao carregar pipeline: {e}")
            self.pipe_upsample = None

    @torch.no_grad()
    def upscale(self, latents: torch.Tensor) -> torch.Tensor:
        """Apply 2x spatial upscaling to the given latent tensors.

        Returns the upscaled latents; if the pipeline is unavailable or the
        upscale raises, the original tensor is returned unchanged.
        """
        self._lazy_init()
        if self.pipe_upsample is None:
            logger.warning("[Upscaler] Pipeline indisponível. Retornando latentes originais.")
            return latents

        try:
            logger.info(f"[Upscaler] Recebido shape {latents.shape}. Executando upscale em {self.device}...")

            # Per the official docs, with output_type="latent" the result
            # tensor is exposed on the `.frames` attribute.
            result = self.pipe_upsample(latents=latents, output_type="latent")
            output_tensor = result.frames

            logger.info(f"[Upscaler] Upscale concluído. Novo shape: {output_tensor.shape}")
            return output_tensor

        except Exception as e:
            logger.error(f"[Upscaler] Erro durante upscale: {e}", exc_info=True)
            return latents


# ---------------------------
# Global singleton
# ---------------------------
# Shared module-level instance. Construction is cheap: __init__ only picks a
# device and sets attributes — the heavy pipeline is loaded lazily by
# _lazy_init() when upscale() is first called.
upscaler_specialist_singleton = UpscalerSpecialist()