File size: 4,122 Bytes
fb56537
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# aduc_framework/managers/vae_manager.py
#
# Versão 2.1.0 (Correção de Timestep no Decode)
# Copyright (C) August 4, 2025  Carlos Rodrigues dos Santos
#
# - Corrige um `AssertionError` na função `decode` ao não passar o argumento
#   `timestep` esperado pelo decodificador do VAE.
# - Adiciona um `timestep` padrão (0.05) para a decodificação, garantindo
#   uma reconstrução de imagem limpa e estável.

import torch
import logging
import gc
import yaml
from typing import List
from PIL import Image
import numpy as np

from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode
from ..tools.hardware_manager import hardware_manager

logger = logging.getLogger(__name__)

class VaeManager:
    """
    Especialista VAE "Hot" e Persistente.
    Carrega o modelo VAE em uma GPU dedicada uma única vez e o mantém lá,
    pronto para processar requisições de encode/decode com latência mínima.
    """
    def __init__(self):
        with open("config.yaml", 'r') as f:
            config = yaml.safe_load(f)
        gpus_required = config['specialists'].get('vae', {}).get('gpus_required', 0)

        if gpus_required > 0 and torch.cuda.is_available():
            device_id = hardware_manager.allocate_gpus('VAE_Manager', gpus_required)[0]
            self.device = torch.device(device_id)
            logger.info(f"VaeManager: GPU dedicada '{device_id}' alocada.")
        else:
            self.device = torch.device('cpu')
            logger.warning("VaeManager: Nenhuma GPU dedicada foi alocada no config.yaml. Operando em modo CPU.")

        try:
            from ..managers.ltx_manager import ltx_manager_singleton
            self.vae = ltx_manager_singleton.workers[0].pipeline.vae
        except ImportError as e:
            logger.critical("Falha ao importar ltx_manager_singleton. Garanta que VaeManager seja importado DEPOIS de LtxManager.", exc_info=True)
            raise e

        self.vae.to(self.device)
        self.vae.eval()
        self.dtype = self.vae.dtype
        logger.info(f"VaeManager inicializado. Modelo VAE está 'quente' e pronto na {self.device} com dtype {self.dtype}.")

    def _preprocess_pil_image(self, pil_image: Image.Image, target_resolution: tuple) -> torch.Tensor:
        from PIL import ImageOps
        img = pil_image.convert("RGB")
        processed_img = ImageOps.fit(img, target_resolution, Image.Resampling.LANCZOS)
        image_np = np.array(processed_img).astype(np.float32) / 255.0
        tensor = torch.from_numpy(image_np).permute(2, 0, 1).unsqueeze(0).unsqueeze(2)
        return (tensor * 2.0) - 1.0

    @torch.no_grad()
    def encode_batch(self, pil_images: List[Image.Image], target_resolution: tuple) -> List[torch.Tensor]:
        if not pil_images:
            return []
        
        latents_list = []
        for img in pil_images:
            pixel_tensor = self._preprocess_pil_image(img, target_resolution)
            pixel_tensor_gpu = pixel_tensor.to(self.device, dtype=self.dtype)
            latents = vae_encode(pixel_tensor_gpu, self.vae, vae_per_channel_normalize=True)
            latents_list.append(latents.cpu())
        return latents_list

    @torch.no_grad()
    def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
        """Decodifica um tensor latente para o espaço de pixels."""
        latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.dtype)
        
        # --- CORREÇÃO APLICADA AQUI ---
        # O modelo espera um tensor de timestep, um para cada item no batch.
        num_items_in_batch = latent_tensor_gpu.shape[0]
        timestep_tensor = torch.tensor([decode_timestep] * num_items_in_batch, device=self.device, dtype=self.dtype)
        
        pixels = vae_decode(
            latent_tensor_gpu, 
            self.vae, 
            is_video=True, 
            timestep=timestep_tensor,  # Passando o tensor de timestep
            vae_per_channel_normalize=True
        )
        # --- FIM DA CORREÇÃO ---
        
        return pixels.cpu()

# --- Instância Singleton ---
vae_manager_singleton = VaeManager()