# managers/seedvr_manager.py
#
# Copyright (C) 2025 Carlos Rodrigues dos Santos
#
# Version: 2.3.2
#
# This version implements a robust fix for the VAE configuration FileNotFoundError by
# anticipating the failure, loading the configurations manually, and merging them to
# work around the problematic hard-coded path in the external library.

import torch
import os
import gc
import logging
import sys
import subprocess
from pathlib import Path
from urllib.parse import urlparse
from torch.hub import download_url_to_file
import gradio as gr
import mediapy
from einops import rearrange

from tools.tensor_utils import wavelet_reconstruction

logger = logging.getLogger(__name__)

# --- Dependency Management ---
DEPS_DIR = Path("./deps")
SEEDVR_REPO_DIR = DEPS_DIR / "SeedVR"
SEEDVR_REPO_URL = "https://github.com/ByteDance-Seed/SeedVR.git"
VAE_CONFIG_URL = "https://raw.githubusercontent.com/ByteDance-Seed/SeedVR/main/models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml"

def setup_seedvr_dependencies():
    """Garante que o repositório do SeedVR seja clonado e esteja disponível no sys.path."""
    if not SEEDVR_REPO_DIR.exists():
        logger.info(f"Repositório SeedVR não encontrado em '{SEEDVR_REPO_DIR}'. Clonando do GitHub...")
        try:
            DEPS_DIR.mkdir(exist_ok=True)
            subprocess.run(["git", "clone", "--depth", "1", SEEDVR_REPO_URL, str(SEEDVR_REPO_DIR)], check=True, capture_output=True, text=True)
            logger.info("Repositório SeedVR clonado com sucesso.")
        except subprocess.CalledProcessError as e:
            logger.error(f"Falha ao clonar o repositório SeedVR. Git stderr: {e.stderr}")
            raise RuntimeError("Não foi possível clonar a dependência necessária do SeedVR do GitHub.")
    else:
        logger.info("Repositório SeedVR local encontrado.")
    
    if str(SEEDVR_REPO_DIR.resolve()) not in sys.path:
        sys.path.insert(0, str(SEEDVR_REPO_DIR.resolve()))
        logger.info(f"Adicionado '{SEEDVR_REPO_DIR.resolve()}' ao sys.path.")

setup_seedvr_dependencies()

from projects.video_diffusion_sr.infer import VideoDiffusionInfer
from common.config import load_config
from common.seed import set_seed
from data.image.transforms.divisible_crop import DivisibleCrop
from data.image.transforms.na_resize import NaResize
from data.video.transforms.rearrange import Rearrange
from torchvision.transforms import Compose, Lambda, Normalize
from torchvision.io.video import read_video
from omegaconf import OmegaConf

def _load_file_from_url(url, model_dir='./', file_name=None):
    os.makedirs(model_dir, exist_ok=True)
    filename = file_name or os.path.basename(urlparse(url).path)
    cached_file = os.path.abspath(os.path.join(model_dir, filename))
    if not os.path.exists(cached_file):
        logger.info(f'Downloading: "{url}" to {cached_file}')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=True)
    return cached_file

class SeedVrManager:
    """Gerencia o modelo SeedVR para tarefas de Masterização HD."""
    def __init__(self, workspace_dir="deformes_workspace"):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.runner = None
        self.workspace_dir = workspace_dir
        self.is_initialized = False
        logger.info("SeedVrManager inicializado. O modelo será carregado sob demanda.")

    def _download_models_and_configs(self):
        """Baixa os checkpoints necessários E o arquivo de configuração do VAE que pode estar faltando."""
        logger.info("Verificando e baixando modelos e configurações do SeedVR2...")
        ckpt_dir = SEEDVR_REPO_DIR / 'ckpts'
        config_dir = SEEDVR_REPO_DIR / 'configs' / 'vae'
        ckpt_dir.mkdir(exist_ok=True)
        config_dir.mkdir(parents=True, exist_ok=True)
        _load_file_from_url(url=VAE_CONFIG_URL, model_dir=str(config_dir))
        pretrain_model_urls = {
            'vae_ckpt': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
            'dit_3b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
            'dit_7b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-7B/resolve/main/seedvr2_ema_7b.pth',
            'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
            'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt'
        }
        for key, url in pretrain_model_urls.items():
            _load_file_from_url(url=url, model_dir=str(ckpt_dir))
        logger.info("Modelos e configurações do SeedVR2 baixados com sucesso.")

    def _initialize_runner(self, model_version: str):
        """Carrega e configura o modelo SeedVR, com uma correção robusta para o caminho da config do VAE."""
        if self.runner is not None: return
        self._download_models_and_configs()
        logger.info(f"Inicializando o executor do SeedVR2 {model_version}...")
        if model_version == '3B':
            config_path = SEEDVR_REPO_DIR / 'configs_3b' / 'main.yaml'
            checkpoint_path = SEEDVR_REPO_DIR / 'ckpts' / 'seedvr2_ema_3b.pth'
        elif model_version == '7B':
            config_path = SEEDVR_REPO_DIR / 'configs_7b' / 'main.yaml'
            checkpoint_path = SEEDVR_REPO_DIR / 'ckpts' / 'seedvr2_ema_7b.pth'
        else:
            raise ValueError(f"Versão do modelo SeedVR não suportada: {model_version}")

        try:
            config = load_config(str(config_path))
        except FileNotFoundError:
            logger.warning("FileNotFoundError esperado capturado. Carregando config manualmente.")
            config = OmegaConf.load(str(config_path))
            correct_vae_config_path = SEEDVR_REPO_DIR / 'configs' / 'vae' / 's8_c16_t4_inflation_sd3.yaml'
            vae_config = OmegaConf.load(str(correct_vae_config_path))
            config.vae = vae_config
            logger.info("Configuração carregada e corrigida manualmente com sucesso.")

        self.runner = VideoDiffusionInfer(config)
        OmegaConf.set_readonly(self.runner.config, False)
        self.runner.configure_dit_model(device=self.device, checkpoint=str(checkpoint_path))
        self.runner.configure_vae_model()
        if hasattr(self.runner.vae, "set_memory_limit"):
            self.runner.vae.set_memory_limit(**self.runner.config.vae.memory_limit)
        self.is_initialized = True
        logger.info(f"Executor para SeedVR2 {model_version} inicializado e pronto.")
    
    def _unload_runner(self):
        """Remove o executor da VRAM para liberar recursos."""
        if self.runner is not None:
            del self.runner; self.runner = None
            gc.collect(); torch.cuda.empty_cache()
            self.is_initialized = False
            logger.info("Executor do SeedVR2 descarregado da VRAM.")

    def process_video(self, input_video_path: str, output_video_path: str, prompt: str,
                      model_version: str = '3B', steps: int = 50, seed: int = 666, 
                      progress: gr.Progress = None) -> str:
        """Aplica o aprimoramento HD a um vídeo usando a lógica do SeedVR."""
        try:
            self._initialize_runner(model_version)
            set_seed(seed, same_across_ranks=True)
            self.runner.config.diffusion.timesteps.sampling.steps = steps
            self.runner.configure_diffusion()
            video_tensor = read_video(input_video_path, output_format="TCHW")[0] / 255.0
            res_h, res_w = video_tensor.shape[-2:]
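            # Preprocessing: area-based resize derived from the clip's own resolution, clamp to [0, 1],
            # crop to multiples of 16, normalize to [-1, 1], and reorder to (c, t, h, w).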
            video_transform = Compose([
                NaResize(resolution=(res_h * res_w) ** 0.5, mode="area", downsample_only=False),
                Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
                DivisibleCrop((16, 16)),
                Normalize(0.5, 0.5),
                Rearrange("t c h w -> c t h w"),
            ])
            cond_latents = [video_transform(video_tensor.to(self.device))]
            input_videos = cond_latents
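            # VRAM management: keep only the VAE on the GPU while encoding the conditioning latents.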
            self.runner.dit.to("cpu")
            self.runner.vae.to(self.device)
            cond_latents = self.runner.vae_encode(cond_latents)
            self.runner.vae.to("cpu"); gc.collect(); torch.cuda.empty_cache()
            self.runner.dit.to(self.device)
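            # Load the precomputed positive/negative text embeddings downloaded alongside the checkpoints
            # (the 'prompt' argument is not used here).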
            pos_emb_path = SEEDVR_REPO_DIR / 'ckpts' / 'pos_emb.pt'
            neg_emb_path = SEEDVR_REPO_DIR / 'ckpts' / 'neg_emb.pt'
            text_pos_embeds = torch.load(pos_emb_path).to(self.device)
            text_neg_embeds = torch.load(neg_emb_path).to(self.device)
            text_embeds_dict = {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]}
            noises = [torch.randn_like(latent) for latent in cond_latents]
            conditions = [self.runner.get_condition(noise, latent_blur=latent, task="sr") for noise, latent in zip(noises, cond_latents)]
            with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
                video_tensors = self.runner.inference(noises=noises, conditions=conditions, dit_offload=True, **text_embeds_dict)
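            # Swap models again: offload the DiT and bring the VAE back onto the GPU to decode the latents.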
            self.runner.dit.to("cpu"); gc.collect(); torch.cuda.empty_cache()
            self.runner.vae.to(self.device)
            samples = self.runner.vae_decode(video_tensors)
            final_sample = samples[0]
            input_video_sample = input_videos[0]
            if final_sample.shape[1] < input_video_sample.shape[1]:
                input_video_sample = input_video_sample[:, :final_sample.shape[1]]
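            # Blend the restored frames with the original clip via wavelet reconstruction
            # (typically used to keep overall color/brightness consistent with the source).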
            final_sample = wavelet_reconstruction(rearrange(final_sample, "c t h w -> t c h w"), rearrange(input_video_sample, "c t h w -> t c h w"))
            final_sample = rearrange(final_sample, "t c h w -> t h w c")
            final_sample = final_sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
            final_sample_np = final_sample.to(torch.uint8).cpu().numpy()
            mediapy.write_video(output_video_path, final_sample_np, fps=24)
            logger.info(f"Vídeo Masterizado em HD salvo em: {output_video_path}")
            return output_video_path
        finally:
            self._unload_runner()

# --- Singleton Instance ---
seedvr_manager_singleton = SeedVrManager()
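
# Usage sketch (illustrative only; the file paths below are hypothetical):
#
#     from managers.seedvr_manager import seedvr_manager_singleton
#
#     upscaled_path = seedvr_manager_singleton.process_video(
#         input_video_path="samples/input_clip.mp4",
#         output_video_path="samples/output_hd.mp4",
#         prompt="",              # currently unused by process_video
#         model_version="3B",     # or "7B"
#         steps=50,
#         seed=666,
#     )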