Spaces:
Runtime error
Runtime error
File size: 10,188 Bytes
9fcad90 f5c99ab 18359c8 b5023f2 3772b14 b5023f2 3772b14 f5c99ab 5137a03 3dce029 5137a03 8434eb9 3dce029 f5c99ab 8434eb9 f5c99ab 3dce029 9fcad90 5137a03 18359c8 8434eb9 9a0d6a9 8434eb9 9fcad90 3772b14 9fcad90 18359c8 9fcad90 3772b14 18359c8 9fcad90 18359c8 9fcad90 18359c8 9fcad90 18359c8 9fcad90 478560e 18359c8 478560e 9fcad90 3772b14 3dce029 5137a03 3dce029 18359c8 478560e 9a0d6a9 3772b14 18359c8 8434eb9 9a0d6a9 3dce029 9a0d6a9 f5c99ab 9a0d6a9 f5c99ab 3dce029 f5c99ab 8434eb9 18359c8 f5c99ab 3772b14 9fcad90 9a0d6a9 18359c8 f5c99ab 8434eb9 f5c99ab 8434eb9 f5c99ab 18359c8 f5c99ab 3772b14 3dce029 8434eb9 3dce029 18359c8 3dce029 18359c8 3dce029 9fcad90 3dce029 18359c8 3dce029 f5c99ab 18359c8 3dce029 f5c99ab 3dce029 f5c99ab 8434eb9 f5c99ab 9fcad90 f5c99ab 8434eb9 f5c99ab 9a0d6a9 f5c99ab 18359c8 f5c99ab 3dce029 18359c8 9fcad90 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
# managers/seedvr_manager.py
#
# Copyright (C) 2025 Carlos Rodrigues dos Santos
#
# Version: 2.3.2
#
# This version implements a robust fix for the VAE configuration FileNotFoundError:
# it anticipates the failure, loads the configurations manually and merges them to
# work around the problematic hard-coded path in the external library.
import torch
import os
import gc
import logging
import sys
import subprocess
from pathlib import Path
from urllib.parse import urlparse
from torch.hub import download_url_to_file
import gradio as gr
import mediapy
from einops import rearrange
from tools.tensor_utils import wavelet_reconstruction
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# --- Dependency management ---
# Directory that holds external dependency checkouts (relative to the current working directory).
DEPS_DIR = Path("./deps")
# Local checkout location of the SeedVR repository.
SEEDVR_REPO_DIR = DEPS_DIR / "SeedVR"
# Git remote that setup_seedvr_dependencies() clones into SEEDVR_REPO_DIR.
SEEDVR_REPO_URL = "https://github.com/ByteDance-Seed/SeedVR.git"
# Raw URL of the VAE YAML config; SeedVrManager downloads it into SEEDVR_REPO_DIR/configs/vae.
VAE_CONFIG_URL = "https://raw.githubusercontent.com/ByteDance-Seed/SeedVR/main/models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml"
def setup_seedvr_dependencies():
    """Ensure the SeedVR repository is cloned locally and importable.

    When ``SEEDVR_REPO_DIR`` does not exist, a shallow clone (depth 1) of
    ``SEEDVR_REPO_URL`` is created there. Afterwards the resolved repository
    path is inserted at the front of ``sys.path`` unless it is already listed.

    Raises:
        RuntimeError: If the clone fails, either because ``git`` exits with a
            non-zero status or because the ``git`` executable cannot be
            found. The original exception is chained as ``__cause__``.
    """
    if SEEDVR_REPO_DIR.exists():
        logger.info("Repositório SeedVR local encontrado.")
    else:
        logger.info(f"Repositório SeedVR não encontrado em '{SEEDVR_REPO_DIR}'. Clonando do GitHub...")
        try:
            DEPS_DIR.mkdir(parents=True, exist_ok=True)
            subprocess.run(
                ["git", "clone", "--depth", "1", SEEDVR_REPO_URL, str(SEEDVR_REPO_DIR)],
                check=True,
                capture_output=True,
                text=True,
            )
        except subprocess.CalledProcessError as e:
            logger.error(f"Falha ao clonar o repositório SeedVR. Git stderr: {e.stderr}")
            raise RuntimeError("Não foi possível clonar a dependência necessária do SeedVR do GitHub.") from e
        except FileNotFoundError as e:
            # subprocess.run raises FileNotFoundError when the `git` executable
            # itself is missing; surface it as the same RuntimeError.
            logger.error("Falha ao clonar o repositório SeedVR. Executável 'git' não encontrado.")
            raise RuntimeError("Não foi possível clonar a dependência necessária do SeedVR do GitHub.") from e
        logger.info("Repositório SeedVR clonado com sucesso.")

    repo_path = str(SEEDVR_REPO_DIR.resolve())
    if repo_path not in sys.path:
        # Front of the path so the SeedVR packages win over same-named modules.
        sys.path.insert(0, repo_path)
        logger.info(f"Adicionado '{SEEDVR_REPO_DIR.resolve()}' ao sys.path.")
setup_seedvr_dependencies()
from projects.video_diffusion_sr.infer import VideoDiffusionInfer
from common.config import load_config
from common.seed import set_seed
from data.image.transforms.divisible_crop import DivisibleCrop
from data.image.transforms.na_resize import NaResize
from data.video.transforms.rearrange import Rearrange
from torchvision.transforms import Compose, Lambda, Normalize
from torchvision.io.video import read_video
from omegaconf import OmegaConf
def _load_file_from_url(url, model_dir='./', file_name=None):
os.makedirs(model_dir, exist_ok=True)
filename = file_name or os.path.basename(urlparse(url).path)
cached_file = os.path.abspath(os.path.join(model_dir, filename))
if not os.path.exists(cached_file):
logger.info(f'Baixando: "{url}" para {cached_file}')
download_url_to_file(url, cached_file, hash_prefix=None, progress=True)
return cached_file
class SeedVrManager:
    """Manage the SeedVR2 model for HD mastering tasks.

    The model runner is created on demand inside :meth:`process_video` and is
    released again when that call finishes, whether it succeeded or not.
    """

    # Supported model versions -> (config directory, DiT checkpoint file name),
    # both relative to the SeedVR checkout.
    _MODEL_VARIANTS = {
        '3B': ('configs_3b', 'seedvr2_ema_3b.pth'),
        '7B': ('configs_7b', 'seedvr2_ema_7b.pth'),
    }
    # Frame rate written when the input video does not report one.
    _FALLBACK_FPS = 24

    def __init__(self, workspace_dir="deformes_workspace"):
        """Record the target device and workspace; no model is loaded here.

        Args:
            workspace_dir: Stored on the instance; not otherwise used by this class.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.runner = None
        self.workspace_dir = workspace_dir
        self.is_initialized = False
        logger.info("SeedVrManager inicializado. O modelo será carregado sob demanda.")

    def _download_models_and_configs(self, model_version=None):
        """Download the checkpoints and the VAE config file that may be missing.

        Files that already exist locally are not downloaded again.

        Args:
            model_version: When given (e.g. ``'3B'``), only the DiT checkpoint
                of that version is fetched in addition to the shared files.
                ``None`` fetches every checkpoint.
        """
        logger.info("Verificando e baixando modelos e configurações do SeedVR2...")
        ckpt_dir = SEEDVR_REPO_DIR / 'ckpts'
        config_dir = SEEDVR_REPO_DIR / 'configs' / 'vae'
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        config_dir.mkdir(parents=True, exist_ok=True)
        _load_file_from_url(url=VAE_CONFIG_URL, model_dir=str(config_dir))
        pretrain_model_urls = {
            'vae_ckpt': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
            'dit_3b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
            'dit_7b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-7B/resolve/main/seedvr2_ema_7b.pth',
            'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
            'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt'
        }
        # Only one DiT checkpoint is loaded per run, so skip the other one
        # when the caller says which version it needs.
        wanted_dit = None if model_version is None else f"dit_{str(model_version).lower()}"
        for key, url in pretrain_model_urls.items():
            if wanted_dit is not None and key.startswith('dit_') and key != wanted_dit:
                continue
            _load_file_from_url(url=url, model_dir=str(ckpt_dir))
        logger.info("Modelos e configurações do SeedVR2 baixados com sucesso.")

    def _initialize_runner(self, model_version: str):
        """Load and configure the SeedVR runner for ``model_version``.

        Does nothing when a runner is already loaded. The version is checked
        before any download starts.

        Args:
            model_version: ``'3B'`` or ``'7B'``.

        Raises:
            ValueError: If ``model_version`` is not a supported version.
        """
        if self.runner is not None:
            return
        # Validate first so an unsupported version fails before the downloads.
        try:
            config_subdir, checkpoint_name = self._MODEL_VARIANTS[model_version]
        except (KeyError, TypeError):
            raise ValueError(f"Versão do modelo SeedVR não suportada: {model_version}") from None
        self._download_models_and_configs(model_version)
        logger.info(f"Inicializando o executor do SeedVR2 {model_version}...")
        config_path = SEEDVR_REPO_DIR / config_subdir / 'main.yaml'
        checkpoint_path = SEEDVR_REPO_DIR / 'ckpts' / checkpoint_name
        try:
            config = load_config(str(config_path))
        except FileNotFoundError:
            # Fallback: load the main config without the library loader and
            # plug in the VAE config downloaded by _download_models_and_configs.
            logger.warning("FileNotFoundError esperado capturado. Carregando config manualmente.")
            config = OmegaConf.load(str(config_path))
            correct_vae_config_path = SEEDVR_REPO_DIR / 'configs' / 'vae' / 's8_c16_t4_inflation_sd3.yaml'
            vae_config = OmegaConf.load(str(correct_vae_config_path))
            # NOTE(review): this replaces the whole `vae` section rather than
            # merging into it; `config.vae.memory_limit` is read below, so
            # confirm the downloaded YAML provides every key the runner needs.
            config.vae = vae_config
            logger.info("Configuração carregada e corrigida manualmente com sucesso.")
        self.runner = VideoDiffusionInfer(config)
        OmegaConf.set_readonly(self.runner.config, False)
        self.runner.configure_dit_model(device=self.device, checkpoint=str(checkpoint_path))
        self.runner.configure_vae_model()
        if hasattr(self.runner.vae, "set_memory_limit"):
            self.runner.vae.set_memory_limit(**self.runner.config.vae.memory_limit)
        self.is_initialized = True
        logger.info(f"Executor para SeedVR2 {model_version} inicializado e pronto.")

    def _unload_runner(self):
        """Drop the runner and ask Python and CUDA to release its memory."""
        if self.runner is not None:
            self.runner = None
            gc.collect()
            torch.cuda.empty_cache()
            self.is_initialized = False
            logger.info("Executor do SeedVR2 descarregado da VRAM.")

    def process_video(self, input_video_path: str, output_video_path: str, prompt: str,
                      model_version: str = '3B', steps: int = 50, seed: int = 666,
                      progress: "gr.Progress" = None) -> str:
        """Apply the SeedVR HD enhancement to a video and write the result.

        Args:
            input_video_path: Path of the video to enhance.
            output_video_path: Path the enhanced video is written to.
            prompt: Accepted for interface compatibility; not used here, the
                text embeddings come from the downloaded ``pos_emb.pt`` and
                ``neg_emb.pt`` files.
            model_version: ``'3B'`` or ``'7B'``.
            steps: Number of sampling steps set on the runner configuration.
            seed: Seed passed to ``set_seed``.
            progress: Accepted for interface compatibility; not used here.

        Returns:
            ``output_video_path``.

        Raises:
            ValueError: If ``model_version`` is unsupported or no frame could
                be read from ``input_video_path``.
        """
        try:
            self._initialize_runner(model_version)
            set_seed(seed, same_across_ranks=True)
            self.runner.config.diffusion.timesteps.sampling.steps = steps
            self.runner.configure_diffusion()

            frames, _audio, video_info = read_video(input_video_path, output_format="TCHW")
            if frames.shape[0] == 0:
                raise ValueError(f"Nenhum quadro de vídeo pôde ser lido de: {input_video_path}")
            # Keep the source frame rate; fall back when the container has none.
            output_fps = video_info.get("video_fps") or self._FALLBACK_FPS
            video_tensor = frames / 255.0
            res_h, res_w = video_tensor.shape[-2:]
            video_transform = Compose([
                NaResize(resolution=(res_h * res_w) ** 0.5, mode="area", downsample_only=False),
                Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
                DivisibleCrop((16, 16)),
                Normalize(0.5, 0.5),
                Rearrange("t c h w -> c t h w"),
            ])
            cond_latents = [video_transform(video_tensor.to(self.device))]
            # Keep the transformed input; it is needed after decoding.
            input_videos = cond_latents

            # Only one of DiT / VAE stays on the compute device at a time.
            self.runner.dit.to("cpu")
            self.runner.vae.to(self.device)
            cond_latents = self.runner.vae_encode(cond_latents)
            self.runner.vae.to("cpu")
            gc.collect()
            torch.cuda.empty_cache()
            self.runner.dit.to(self.device)

            pos_emb_path = SEEDVR_REPO_DIR / 'ckpts' / 'pos_emb.pt'
            neg_emb_path = SEEDVR_REPO_DIR / 'ckpts' / 'neg_emb.pt'
            # NOTE: torch.load unpickles the file; these files come from the
            # fixed URLs in _download_models_and_configs. Loading onto the CPU
            # first keeps this working on hosts without CUDA.
            text_pos_embeds = torch.load(pos_emb_path, map_location="cpu").to(self.device)
            text_neg_embeds = torch.load(neg_emb_path, map_location="cpu").to(self.device)
            text_embeds_dict = {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]}

            noises = [torch.randn_like(latent) for latent in cond_latents]
            conditions = [
                self.runner.get_condition(noise, latent_blur=latent, task="sr")
                for noise, latent in zip(noises, cond_latents)
            ]
            with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
                video_tensors = self.runner.inference(
                    noises=noises, conditions=conditions, dit_offload=True, **text_embeds_dict
                )

            self.runner.dit.to("cpu")
            gc.collect()
            torch.cuda.empty_cache()
            self.runner.vae.to(self.device)
            samples = self.runner.vae_decode(video_tensors)

            final_sample = samples[0]
            input_video_sample = input_videos[0]
            # Trim the input along dim 1 when the output has fewer entries there.
            if final_sample.shape[1] < input_video_sample.shape[1]:
                input_video_sample = input_video_sample[:, :final_sample.shape[1]]
            final_sample = wavelet_reconstruction(
                rearrange(final_sample, "c t h w -> t c h w"),
                rearrange(input_video_sample, "c t h w -> t c h w"),
            )
            final_sample = rearrange(final_sample, "t c h w -> t h w c")
            # Map [-1, 1] to [0, 255] before the uint8 conversion.
            final_sample = final_sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
            final_sample_np = final_sample.to(torch.uint8).cpu().numpy()
            mediapy.write_video(output_video_path, final_sample_np, fps=output_fps)
            logger.info(f"Vídeo Masterizado em HD salvo em: {output_video_path}")
            return output_video_path
        finally:
            # Always release the model, including when an error occurred.
            self._unload_runner()
# --- Singleton instance ---
# Shared manager created at import time; the constructor loads no model.
seedvr_manager_singleton = SeedVrManager()