Aduc-sdr commited on
Commit
18359c8
·
verified ·
1 Parent(s): 1469a4f

Update managers/seedvr_manager.py

Browse files
Files changed (1) hide show
  1. managers/seedvr_manager.py +44 -39
managers/seedvr_manager.py CHANGED
@@ -1,22 +1,12 @@
1
  # managers/seedvr_manager.py
2
- # AducSdr: Uma implementação aberta e funcional da arquitetura ADUC-SDR
3
- # Copyright (C) 4 de Agosto de 2025 Carlos Rodrigues dos Santos
4
  #
5
- # Contato:
6
- # Carlos Rodrigues dos Santos
7
8
- # Rua Eduardo Carlos Pereira, 4125, B1 Ap32, Curitiba, PR, Brazil, CEP 8102025
9
  #
10
- # Repositórios e Projetos Relacionados:
11
- # GitHub: https://github.com/carlex22/Aduc-sdr
12
  #
13
- # PENDING PATENT NOTICE: Please see NOTICE.md.
14
- #
15
- # Version: 2.3.0
16
- #
17
- # This file implements the SeedVrManager, which uses the SeedVR model for
18
- # video super-resolution. It is self-contained, automatically cloning its own
19
- # dependencies from the official SeedVR repository.
20
 
21
  import torch
22
  import os
@@ -31,41 +21,46 @@ import gradio as gr
31
  import mediapy
32
  from einops import rearrange
33
 
 
34
  from tools.tensor_utils import wavelet_reconstruction
35
 
36
  logger = logging.getLogger(__name__)
37
 
38
- # --- Dependency Management ---
39
  DEPS_DIR = Path("./deps")
40
  SEEDVR_REPO_DIR = DEPS_DIR / "SeedVR"
41
  SEEDVR_REPO_URL = "https://github.com/ByteDance-Seed/SeedVR.git"
 
42
  VAE_CONFIG_URL = "https://raw.githubusercontent.com/ByteDance-Seed/SeedVR/main/models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml"
43
 
44
  def setup_seedvr_dependencies():
45
  """
46
- Ensures the SeedVR repository is cloned and available in the sys.path.
 
47
  """
48
  if not SEEDVR_REPO_DIR.exists():
49
- logger.info(f"SeedVR repository not found at '{SEEDVR_REPO_DIR}'. Cloning from GitHub...")
50
  try:
51
  DEPS_DIR.mkdir(exist_ok=True)
52
  subprocess.run(
53
  ["git", "clone", "--depth", "1", SEEDVR_REPO_URL, str(SEEDVR_REPO_DIR)],
54
  check=True, capture_output=True, text=True
55
  )
56
- logger.info("SeedVR repository cloned successfully.")
57
  except subprocess.CalledProcessError as e:
58
- logger.error(f"Failed to clone SeedVR repository. Git stderr: {e.stderr}")
59
- raise RuntimeError("Could not clone the required SeedVR dependency from GitHub.")
60
  else:
61
- logger.info("Found local SeedVR repository.")
62
 
63
  if str(SEEDVR_REPO_DIR.resolve()) not in sys.path:
64
  sys.path.insert(0, str(SEEDVR_REPO_DIR.resolve()))
65
- logger.info(f"Added '{SEEDVR_REPO_DIR.resolve()}' to sys.path.")
66
 
 
67
  setup_seedvr_dependencies()
68
 
 
69
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
70
  from common.config import load_config
71
  from common.seed import set_seed
@@ -78,33 +73,37 @@ from omegaconf import OmegaConf
78
 
79
 
80
  def _load_file_from_url(url, model_dir='./', file_name=None):
 
81
  os.makedirs(model_dir, exist_ok=True)
82
  filename = file_name or os.path.basename(urlparse(url).path)
83
  cached_file = os.path.abspath(os.path.join(model_dir, filename))
84
  if not os.path.exists(cached_file):
85
- logger.info(f'Downloading: "{url}" to {cached_file}')
86
  download_url_to_file(url, cached_file, hash_prefix=None, progress=True)
87
  return cached_file
88
 
89
  class SeedVrManager:
90
  """
91
- Manages the SeedVR model for HD Mastering tasks.
92
  """
93
  def __init__(self, workspace_dir="deformes_workspace"):
94
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
95
  self.runner = None
96
  self.workspace_dir = workspace_dir
97
  self.is_initialized = False
98
- logger.info("SeedVrManager initialized. Model will be loaded on demand.")
99
 
100
  def _download_models_and_configs(self):
101
- """Downloads the necessary checkpoints AND the missing VAE config file."""
102
- logger.info("Verifying and downloading SeedVR2 models and configs...")
 
 
103
  ckpt_dir = SEEDVR_REPO_DIR / 'ckpts'
104
  config_dir = SEEDVR_REPO_DIR / 'configs' / 'vae'
105
  ckpt_dir.mkdir(exist_ok=True)
106
  config_dir.mkdir(parents=True, exist_ok=True)
107
 
 
108
  _load_file_from_url(url=VAE_CONFIG_URL, model_dir=str(config_dir))
109
 
110
  pretrain_model_urls = {
@@ -118,13 +117,15 @@ class SeedVrManager:
118
  for key, url in pretrain_model_urls.items():
119
  _load_file_from_url(url=url, model_dir=str(ckpt_dir))
120
 
121
- logger.info("SeedVR2 models and configs downloaded successfully.")
122
 
123
  def _initialize_runner(self, model_version: str):
124
- """Loads and configures the SeedVR model on demand based on the selected version."""
125
  if self.runner is not None: return
 
126
  self._download_models_and_configs()
127
- logger.info(f"Initializing SeedVR2 {model_version} runner...")
 
128
  if model_version == '3B':
129
  config_path = SEEDVR_REPO_DIR / 'configs_3b' / 'main.yaml'
130
  checkpoint_path = SEEDVR_REPO_DIR / 'ckpts' / 'seedvr2_ema_3b.pth'
@@ -132,16 +133,18 @@ class SeedVrManager:
132
  config_path = SEEDVR_REPO_DIR / 'configs_7b' / 'main.yaml'
133
  checkpoint_path = SEEDVR_REPO_DIR / 'ckpts' / 'seedvr2_ema_7b.pth'
134
  else:
135
- raise ValueError(f"Unsupported SeedVR model version: {model_version}")
136
 
137
  config = load_config(str(config_path))
 
138
  self.runner = VideoDiffusionInfer(config)
139
  OmegaConf.set_readonly(self.runner.config, False)
 
140
  self.runner.configure_dit_model(device=self.device, checkpoint=str(checkpoint_path))
141
 
142
- # --- PATH CORRECTION ---
143
  correct_vae_config_path = SEEDVR_REPO_DIR / 'configs' / 'vae' / 's8_c16_t4_inflation_sd3.yaml'
144
- logger.info(f"Correcting VAE config path to: {correct_vae_config_path}")
145
  self.runner.config.vae.config = str(correct_vae_config_path)
146
 
147
  self.runner.configure_vae_model()
@@ -150,18 +153,20 @@ class SeedVrManager:
150
  self.runner.vae.set_memory_limit(**self.runner.config.vae.memory_limit)
151
 
152
  self.is_initialized = True
153
- logger.info(f"Runner for SeedVR2 {model_version} initialized and ready.")
154
-
155
  def _unload_runner(self):
 
156
  if self.runner is not None:
157
  del self.runner; self.runner = None
158
  gc.collect(); torch.cuda.empty_cache()
159
  self.is_initialized = False
160
- logger.info("SeedVR2 runner unloaded from VRAM.")
161
 
162
  def process_video(self, input_video_path: str, output_video_path: str, prompt: str,
163
  model_version: str = '3B', steps: int = 50, seed: int = 666,
164
  progress: gr.Progress = None) -> str:
 
165
  try:
166
  self._initialize_runner(model_version)
167
  set_seed(seed, same_across_ranks=True)
@@ -204,10 +209,10 @@ class SeedVrManager:
204
  final_sample = final_sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
205
  final_sample_np = final_sample.to(torch.uint8).cpu().numpy()
206
  mediapy.write_video(output_video_path, final_sample_np, fps=24)
207
- logger.info(f"HD Mastered video saved to: {output_video_path}")
208
  return output_video_path
209
  finally:
210
  self._unload_runner()
211
 
212
- # --- Singleton Instance ---
213
  seedvr_manager_singleton = SeedVrManager()
 
1
  # managers/seedvr_manager.py
 
 
2
  #
3
+ # Copyright (C) 2025 Carlos Rodrigues dos Santos
 
 
 
4
  #
5
+ # Version: 2.3.1
 
6
  #
7
+ # Esta versão adiciona uma correção robusta para o FileNotFoundError da configuração do VAE,
8
+ # sobrescrevendo o caminho fixo incorreto em tempo de execução para apontar para o local
9
+ # correto dentro do repositório clonado.
 
 
 
 
10
 
11
  import torch
12
  import os
 
21
  import mediapy
22
  from einops import rearrange
23
 
24
+ # Utilitário internalizado para correção de cor, garantindo estabilidade.
25
  from tools.tensor_utils import wavelet_reconstruction
26
 
27
  logger = logging.getLogger(__name__)
28
 
29
+ # --- Gerenciamento de Dependências ---
30
  DEPS_DIR = Path("./deps")
31
  SEEDVR_REPO_DIR = DEPS_DIR / "SeedVR"
32
  SEEDVR_REPO_URL = "https://github.com/ByteDance-Seed/SeedVR.git"
33
+ # URL direto para o arquivo de configuração do VAE que pode estar faltando
34
  VAE_CONFIG_URL = "https://raw.githubusercontent.com/ByteDance-Seed/SeedVR/main/models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml"
35
 
36
  def setup_seedvr_dependencies():
37
  """
38
+ Garante que o repositório do SeedVR seja clonado e esteja disponível no sys.path.
39
+ Esta função é executada uma vez quando o módulo é importado pela primeira vez.
40
  """
41
  if not SEEDVR_REPO_DIR.exists():
42
+ logger.info(f"Repositório SeedVR não encontrado em '{SEEDVR_REPO_DIR}'. Clonando do GitHub...")
43
  try:
44
  DEPS_DIR.mkdir(exist_ok=True)
45
  subprocess.run(
46
  ["git", "clone", "--depth", "1", SEEDVR_REPO_URL, str(SEEDVR_REPO_DIR)],
47
  check=True, capture_output=True, text=True
48
  )
49
+ logger.info("Repositório SeedVR clonado com sucesso.")
50
  except subprocess.CalledProcessError as e:
51
+ logger.error(f"Falha ao clonar o repositório SeedVR. Git stderr: {e.stderr}")
52
+ raise RuntimeError("Não foi possível clonar a dependência necessária do SeedVR do GitHub.")
53
  else:
54
+ logger.info("Repositório SeedVR local encontrado.")
55
 
56
  if str(SEEDVR_REPO_DIR.resolve()) not in sys.path:
57
  sys.path.insert(0, str(SEEDVR_REPO_DIR.resolve()))
58
+ logger.info(f"Adicionado '{SEEDVR_REPO_DIR.resolve()}' ao sys.path.")
59
 
60
+ # --- Executa a configuração da dependência imediatamente na importação do módulo ---
61
  setup_seedvr_dependencies()
62
 
63
+ # --- Agora que o caminho está configurado, podemos importar com segurança do repositório clonado ---
64
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
65
  from common.config import load_config
66
  from common.seed import set_seed
 
73
 
74
 
75
  def _load_file_from_url(url, model_dir='./', file_name=None):
76
+ """Função auxiliar para baixar arquivos de uma URL para um diretório local."""
77
  os.makedirs(model_dir, exist_ok=True)
78
  filename = file_name or os.path.basename(urlparse(url).path)
79
  cached_file = os.path.abspath(os.path.join(model_dir, filename))
80
  if not os.path.exists(cached_file):
81
+ logger.info(f'Baixando: "{url}" para {cached_file}')
82
  download_url_to_file(url, cached_file, hash_prefix=None, progress=True)
83
  return cached_file
84
 
85
  class SeedVrManager:
86
  """
87
+ Gerencia o modelo SeedVR para tarefas de Masterização HD.
88
  """
89
  def __init__(self, workspace_dir="deformes_workspace"):
90
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
91
  self.runner = None
92
  self.workspace_dir = workspace_dir
93
  self.is_initialized = False
94
+ logger.info("SeedVrManager inicializado. O modelo será carregado sob demanda.")
95
 
96
  def _download_models_and_configs(self):
97
+ """
98
+ Baixa os checkpoints necessários E o arquivo de configuração do VAE que pode estar faltando.
99
+ """
100
+ logger.info("Verificando e baixando modelos e configurações do SeedVR2...")
101
  ckpt_dir = SEEDVR_REPO_DIR / 'ckpts'
102
  config_dir = SEEDVR_REPO_DIR / 'configs' / 'vae'
103
  ckpt_dir.mkdir(exist_ok=True)
104
  config_dir.mkdir(parents=True, exist_ok=True)
105
 
106
+ # Baixa a configuração do VAE para garantir que ela exista
107
  _load_file_from_url(url=VAE_CONFIG_URL, model_dir=str(config_dir))
108
 
109
  pretrain_model_urls = {
 
117
  for key, url in pretrain_model_urls.items():
118
  _load_file_from_url(url=url, model_dir=str(ckpt_dir))
119
 
120
+ logger.info("Modelos e configurações do SeedVR2 baixados com sucesso.")
121
 
122
  def _initialize_runner(self, model_version: str):
123
+ """Carrega e configura o modelo SeedVR sob demanda com base na versão selecionada."""
124
  if self.runner is not None: return
125
+
126
  self._download_models_and_configs()
127
+
128
+ logger.info(f"Inicializando o executor do SeedVR2 {model_version}...")
129
  if model_version == '3B':
130
  config_path = SEEDVR_REPO_DIR / 'configs_3b' / 'main.yaml'
131
  checkpoint_path = SEEDVR_REPO_DIR / 'ckpts' / 'seedvr2_ema_3b.pth'
 
133
  config_path = SEEDVR_REPO_DIR / 'configs_7b' / 'main.yaml'
134
  checkpoint_path = SEEDVR_REPO_DIR / 'ckpts' / 'seedvr2_ema_7b.pth'
135
  else:
136
+ raise ValueError(f"Versão do modelo SeedVR não suportada: {model_version}")
137
 
138
  config = load_config(str(config_path))
139
+
140
  self.runner = VideoDiffusionInfer(config)
141
  OmegaConf.set_readonly(self.runner.config, False)
142
+
143
  self.runner.configure_dit_model(device=self.device, checkpoint=str(checkpoint_path))
144
 
145
+ # --- CORREÇÃO APLICADA AQUI ---
146
  correct_vae_config_path = SEEDVR_REPO_DIR / 'configs' / 'vae' / 's8_c16_t4_inflation_sd3.yaml'
147
+ logger.info(f"Corrigindo o caminho da configuração do VAE para: {correct_vae_config_path}")
148
  self.runner.config.vae.config = str(correct_vae_config_path)
149
 
150
  self.runner.configure_vae_model()
 
153
  self.runner.vae.set_memory_limit(**self.runner.config.vae.memory_limit)
154
 
155
  self.is_initialized = True
156
+ logger.info(f"Executor para SeedVR2 {model_version} inicializado e pronto.")
157
+
158
  def _unload_runner(self):
159
+ """Remove o executor da VRAM para liberar recursos."""
160
  if self.runner is not None:
161
  del self.runner; self.runner = None
162
  gc.collect(); torch.cuda.empty_cache()
163
  self.is_initialized = False
164
+ logger.info("Executor do SeedVR2 descarregado da VRAM.")
165
 
166
  def process_video(self, input_video_path: str, output_video_path: str, prompt: str,
167
  model_version: str = '3B', steps: int = 50, seed: int = 666,
168
  progress: gr.Progress = None) -> str:
169
+ """Aplica o aprimoramento HD a um vídeo usando a lógica do SeedVR."""
170
  try:
171
  self._initialize_runner(model_version)
172
  set_seed(seed, same_across_ranks=True)
 
209
  final_sample = final_sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
210
  final_sample_np = final_sample.to(torch.uint8).cpu().numpy()
211
  mediapy.write_video(output_video_path, final_sample_np, fps=24)
212
+ logger.info(f"Vídeo Masterizado em HD salvo em: {output_video_path}")
213
  return output_video_path
214
  finally:
215
  self._unload_runner()
216
 
217
+ # --- Instância Singleton ---
218
  seedvr_manager_singleton = SeedVrManager()