euiia committed on
Commit
8434eb9
·
verified ·
1 Parent(s): 97d3157

Update managers/hd_specialist.py

Browse files
Files changed (1) hide show
  1. managers/hd_specialist.py +79 -36
managers/hd_specialist.py CHANGED
@@ -2,43 +2,34 @@
2
  #
3
  # Copyright (C) 2025 Carlos Rodrigues dos Santos
4
  #
 
 
5
  # This file implements the HD Specialist (Δ+), which uses the SeedVR model
6
- # for video super-resolution. It's designed to be called by the ADUC orchestrator
7
- # to perform the final HD mastering pass on a generated video. It manages the
8
- # loading/unloading of the heavy SeedVR models to conserve VRAM and can switch
9
- # between different model sizes (e.g., 3B and 7B).
10
 
11
  import torch
12
- import gradio as gr
13
- import imageio
14
  import os
15
  import gc
16
  import logging
17
- import numpy as np
18
- from PIL import Image
19
- from tqdm import tqdm
20
- import shlex
21
  import subprocess
22
  from pathlib import Path
23
  from urllib.parse import urlparse
24
  from torch.hub import download_url_to_file
25
- from omegaconf import OmegaConf
26
  import mediapy
27
  from einops import rearrange
28
 
29
- # Assuming these files are in the project structure
30
- from projects.video_diffusion_sr.infer import VideoDiffusionInfer
31
- from common.config import load_config
32
- from common.seed import set_seed
33
- from data.image.transforms.divisible_crop import DivisibleCrop
34
- from data.image.transforms.na_resize import NaResize
35
- from data.video.transforms.rearrange import Rearrange
36
- from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
37
- from torchvision.transforms import Compose, Lambda, Normalize
38
- from torchvision.io.video import read_video
39
-
40
  logger = logging.getLogger(__name__)
41
 
 
 
 
 
 
42
  def _load_file_from_url(url, model_dir='./', file_name=None):
43
  """Helper function to download files from a URL to a local directory."""
44
  os.makedirs(model_dir, exist_ok=True)
@@ -59,12 +50,62 @@ class HDSpecialist:
59
  self.runner = None
60
  self.workspace_dir = workspace_dir
61
  self.is_initialized = False
62
- logger.info("HD Specialist (SeedVR) initialized. Model will be loaded on demand.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  def _download_models(self):
65
  """Downloads the necessary checkpoints for SeedVR2."""
66
  logger.info("Verifying and downloading SeedVR2 models...")
67
- ckpt_dir = Path('./ckpts')
68
  ckpt_dir.mkdir(exist_ok=True)
69
 
70
  pretrain_model_urls = {
@@ -76,7 +117,7 @@ class HDSpecialist:
76
  }
77
 
78
  for key, url in pretrain_model_urls.items():
79
- _load_file_from_url(url=url, model_dir='./ckpts/')
80
 
81
  logger.info("SeedVR2 models downloaded successfully.")
82
 
@@ -84,25 +125,26 @@ class HDSpecialist:
84
  """Loads and configures the SeedVR model on demand based on the selected version."""
85
  if self.runner is not None:
86
  return
87
-
 
88
  self._download_models()
89
 
90
  logger.info(f"Initializing SeedVR2 {model_version} runner...")
91
  if model_version == '3B':
92
- config_path = os.path.join('./configs_3b', 'main.yaml')
93
- checkpoint_path = './ckpts/seedvr2_ema_3b.pth'
94
  elif model_version == '7B':
95
- config_path = os.path.join('./configs_7b', 'main.yaml')
96
- checkpoint_path = './ckpts/seedvr2_ema_7b.pth'
97
  else:
98
  raise ValueError(f"Unsupported SeedVR model version: {model_version}")
99
 
100
- config = load_config(config_path)
101
 
102
  self.runner = VideoDiffusionInfer(config)
103
  OmegaConf.set_readonly(self.runner.config, False)
104
 
105
- self.runner.configure_dit_model(device=self.device, checkpoint=checkpoint_path)
106
  self.runner.configure_vae_model()
107
 
108
  if hasattr(self.runner.vae, "set_memory_limit"):
@@ -129,7 +171,6 @@ class HDSpecialist:
129
  self._initialize_runner(model_version)
130
  set_seed(seed, same_across_ranks=True)
131
 
132
- # --- Adapted inference logic from SeedVR scripts ---
133
  self.runner.config.diffusion.timesteps.sampling.steps = steps
134
  self.runner.configure_diffusion()
135
 
@@ -153,8 +194,10 @@ class HDSpecialist:
153
  self.runner.vae.to("cpu"); gc.collect(); torch.cuda.empty_cache()
154
  self.runner.dit.to(self.device)
155
 
156
- text_pos_embeds = torch.load('./ckpts/pos_emb.pt').to(self.device)
157
- text_neg_embeds = torch.load('./ckpts/neg_emb.pt').to(self.device)
 
 
158
  text_embeds_dict = {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]}
159
 
160
  noises = [torch.randn_like(latent) for latent in cond_latents]
@@ -176,7 +219,7 @@ class HDSpecialist:
176
  final_sample = samples[0]
177
  input_video_sample = input_videos[0]
178
 
179
- if final_sample.shape[1] < input_video_sample.shape[1]: # if generated frames are less
180
  input_video_sample = input_video_sample[:, :final_sample.shape[1]]
181
 
182
  final_sample = wavelet_reconstruction(
 
2
  #
3
  # Copyright (C) 2025 Carlos Rodrigues dos Santos
4
  #
5
+ # Version: 2.2.0
6
+ #
7
  # This file implements the HD Specialist (Δ+), which uses the SeedVR model
8
+ # for video super-resolution. It has been refactored to be self-contained by
9
+ # automatically cloning its own dependencies from the official SeedVR repository
10
+ # if they are not found locally. This removes the need for manual file copying
11
+ # and makes the ADUC-SDR framework more robust and portable.
12
 
13
  import torch
 
 
14
  import os
15
  import gc
16
  import logging
17
+ import sys
 
 
 
18
  import subprocess
19
  from pathlib import Path
20
  from urllib.parse import urlparse
21
  from torch.hub import download_url_to_file
22
+ import gradio as gr
23
  import mediapy
24
  from einops import rearrange
25
 
 
 
 
 
 
 
 
 
 
 
 
26
  logger = logging.getLogger(__name__)
27
 
28
+ # --- Dependency Management ---
29
+ DEPS_DIR = Path("./deps")
30
+ SEEDVR_REPO_DIR = DEPS_DIR / "SeedVR"
31
+ SEEDVR_REPO_URL = "https://github.com/ByteDance-Seed/SeedVR.git"
32
+
33
  def _load_file_from_url(url, model_dir='./', file_name=None):
34
  """Helper function to download files from a URL to a local directory."""
35
  os.makedirs(model_dir, exist_ok=True)
 
50
  self.runner = None
51
  self.workspace_dir = workspace_dir
52
  self.is_initialized = False
53
+ self._seedvr_modules_loaded = False
54
+ self._setup_dependencies()
55
+ logger.info("HD Specialist (SeedVR) initialized. Dependencies checked. Model will be loaded on demand.")
56
+
57
def _setup_dependencies(self):
    """
    Make sure the SeedVR sources exist locally and are importable.

    If SEEDVR_REPO_DIR does not exist, the repository at SEEDVR_REPO_URL is
    cloned into it with ``git clone``. Afterwards the resolved repository
    path is prepended to ``sys.path`` unless it is already listed there.

    Raises:
        RuntimeError: If ``git`` cannot be launched or the clone command
            exits with a non-zero status. The underlying exception is
            chained as ``__cause__`` so the root failure stays visible.
    """
    if not SEEDVR_REPO_DIR.exists():
        logger.info(f"SeedVR repository not found at '{SEEDVR_REPO_DIR}'. Cloning from GitHub...")
        # parents=True keeps this working if DEPS_DIR is ever a nested path.
        DEPS_DIR.mkdir(parents=True, exist_ok=True)
        try:
            subprocess.run(
                ["git", "clone", SEEDVR_REPO_URL, str(SEEDVR_REPO_DIR)],
                check=True, capture_output=True, text=True,
            )
        except FileNotFoundError as e:
            # subprocess raises FileNotFoundError when the 'git' executable
            # is missing; report it the same way as a failed clone.
            logger.error("Failed to clone SeedVR repository: the 'git' executable was not found.")
            raise RuntimeError("Could not clone the required SeedVR dependency from GitHub.") from e
        except subprocess.CalledProcessError as e:
            logger.error(f"Failed to clone SeedVR repository. Git stderr: {e.stderr}")
            raise RuntimeError("Could not clone the required SeedVR dependency from GitHub.") from e
        logger.info("SeedVR repository cloned successfully.")
    else:
        logger.info("Found local SeedVR repository.")

    # Add the cloned repo to Python's path to allow direct imports.
    # Resolve once so the membership test, the insert and the log agree.
    repo_path = str(SEEDVR_REPO_DIR.resolve())
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)
        logger.info(f"Added '{repo_path}' to sys.path.")
81
+
82
def _lazy_load_seedvr_modules(self):
    """
    Import the SeedVR-related modules on first use and publish them as
    module-level globals.

    The imports run inside this method rather than at module import time,
    so they are only attempted when this method is called. Calls made after
    a successful load return immediately.
    """
    if not self._seedvr_modules_loaded:
        # Names below are bound at module scope so other methods can use
        # them as if they had been imported at the top of the file.
        global VideoDiffusionInfer, load_config, set_seed
        global DivisibleCrop, NaResize, Rearrange, wavelet_reconstruction
        global Compose, Lambda, Normalize, read_video, OmegaConf

        # Import order is kept exactly as in the original implementation.
        from projects.video_diffusion_sr.infer import VideoDiffusionInfer
        from common.config import load_config
        from common.seed import set_seed
        from data.image.transforms.divisible_crop import DivisibleCrop
        from data.image.transforms.na_resize import NaResize
        from data.video.transforms.rearrange import Rearrange
        from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
        from torchvision.transforms import Compose, Lambda, Normalize
        from torchvision.io.video import read_video
        from omegaconf import OmegaConf

        # Only flagged as loaded once every import above has succeeded.
        self._seedvr_modules_loaded = True
        logger.info("SeedVR modules have been dynamically loaded.")
104
 
105
  def _download_models(self):
106
  """Downloads the necessary checkpoints for SeedVR2."""
107
  logger.info("Verifying and downloading SeedVR2 models...")
108
+ ckpt_dir = SEEDVR_REPO_DIR / 'ckpts'
109
  ckpt_dir.mkdir(exist_ok=True)
110
 
111
  pretrain_model_urls = {
 
117
  }
118
 
119
  for key, url in pretrain_model_urls.items():
120
+ _load_file_from_url(url=url, model_dir=str(ckpt_dir))
121
 
122
  logger.info("SeedVR2 models downloaded successfully.")
123
 
 
125
  """Loads and configures the SeedVR model on demand based on the selected version."""
126
  if self.runner is not None:
127
  return
128
+
129
+ self._lazy_load_seedvr_modules()
130
  self._download_models()
131
 
132
  logger.info(f"Initializing SeedVR2 {model_version} runner...")
133
  if model_version == '3B':
134
+ config_path = SEEDVR_REPO_DIR / 'configs_3b' / 'main.yaml'
135
+ checkpoint_path = SEEDVR_REPO_DIR / 'ckpts' / 'seedvr2_ema_3b.pth'
136
  elif model_version == '7B':
137
+ config_path = SEEDVR_REPO_DIR / 'configs_7b' / 'main.yaml'
138
+ checkpoint_path = SEEDVR_REPO_DIR / 'ckpts' / 'seedvr2_ema_7b.pth'
139
  else:
140
  raise ValueError(f"Unsupported SeedVR model version: {model_version}")
141
 
142
+ config = load_config(str(config_path))
143
 
144
  self.runner = VideoDiffusionInfer(config)
145
  OmegaConf.set_readonly(self.runner.config, False)
146
 
147
+ self.runner.configure_dit_model(device=self.device, checkpoint=str(checkpoint_path))
148
  self.runner.configure_vae_model()
149
 
150
  if hasattr(self.runner.vae, "set_memory_limit"):
 
171
  self._initialize_runner(model_version)
172
  set_seed(seed, same_across_ranks=True)
173
 
 
174
  self.runner.config.diffusion.timesteps.sampling.steps = steps
175
  self.runner.configure_diffusion()
176
 
 
194
  self.runner.vae.to("cpu"); gc.collect(); torch.cuda.empty_cache()
195
  self.runner.dit.to(self.device)
196
 
197
+ pos_emb_path = SEEDVR_REPO_DIR / 'ckpts' / 'pos_emb.pt'
198
+ neg_emb_path = SEEDVR_REPO_DIR / 'ckpts' / 'neg_emb.pt'
199
+ text_pos_embeds = torch.load(pos_emb_path).to(self.device)
200
+ text_neg_embeds = torch.load(neg_emb_path).to(self.device)
201
  text_embeds_dict = {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]}
202
 
203
  noises = [torch.randn_like(latent) for latent in cond_latents]
 
219
  final_sample = samples[0]
220
  input_video_sample = input_videos[0]
221
 
222
+ if final_sample.shape[1] < input_video_sample.shape[1]:
223
  input_video_sample = input_video_sample[:, :final_sample.shape[1]]
224
 
225
  final_sample = wavelet_reconstruction(