Aduc-sdr commited on
Commit
9a0d6a9
·
verified ·
1 Parent(s): 5f2000a

Update managers/seedvr_manager.py

Browse files
Files changed (1) hide show
  1. managers/seedvr_manager.py +18 -38
managers/seedvr_manager.py CHANGED
@@ -31,7 +31,6 @@ import gradio as gr
31
  import mediapy
32
  from einops import rearrange
33
 
34
- # Internalized utility for color correction, ensuring stability.
35
  from tools.tensor_utils import wavelet_reconstruction
36
 
37
  logger = logging.getLogger(__name__)
@@ -40,17 +39,16 @@ logger = logging.getLogger(__name__)
40
  DEPS_DIR = Path("./deps")
41
  SEEDVR_REPO_DIR = DEPS_DIR / "SeedVR"
42
  SEEDVR_REPO_URL = "https://github.com/ByteDance-Seed/SeedVR.git"
 
43
 
44
  def setup_seedvr_dependencies():
45
  """
46
  Ensures the SeedVR repository is cloned and available in the sys.path.
47
- This function is run once when the module is first imported.
48
  """
49
  if not SEEDVR_REPO_DIR.exists():
50
  logger.info(f"SeedVR repository not found at '{SEEDVR_REPO_DIR}'. Cloning from GitHub...")
51
  try:
52
  DEPS_DIR.mkdir(exist_ok=True)
53
- # Use --depth 1 for a shallow clone to save space and time
54
  subprocess.run(
55
  ["git", "clone", "--depth", "1", SEEDVR_REPO_URL, str(SEEDVR_REPO_DIR)],
56
  check=True, capture_output=True, text=True
@@ -62,15 +60,12 @@ def setup_seedvr_dependencies():
62
  else:
63
  logger.info("Found local SeedVR repository.")
64
 
65
- # Add the cloned repo to Python's path to allow direct imports
66
  if str(SEEDVR_REPO_DIR.resolve()) not in sys.path:
67
  sys.path.insert(0, str(SEEDVR_REPO_DIR.resolve()))
68
  logger.info(f"Added '{SEEDVR_REPO_DIR.resolve()}' to sys.path.")
69
 
70
- # --- Execute dependency setup immediately upon module import ---
71
  setup_seedvr_dependencies()
72
 
73
- # --- Now that the path is set, we can safely import from the cloned repo ---
74
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
75
  from common.config import load_config
76
  from common.seed import set_seed
@@ -83,7 +78,6 @@ from omegaconf import OmegaConf
83
 
84
 
85
  def _load_file_from_url(url, model_dir='./', file_name=None):
86
- """Helper function to download files from a URL to a local directory."""
87
  os.makedirs(model_dir, exist_ok=True)
88
  filename = file_name or os.path.basename(urlparse(url).path)
89
  cached_file = os.path.abspath(os.path.join(model_dir, filename))
@@ -103,14 +97,18 @@ class SeedVrManager:
103
  self.is_initialized = False
104
  logger.info("SeedVrManager initialized. Model will be loaded on demand.")
105
 
106
- def _download_models(self):
107
- """Downloads the necessary checkpoints for SeedVR2."""
108
- logger.info("Verifying and downloading SeedVR2 models...")
109
  ckpt_dir = SEEDVR_REPO_DIR / 'ckpts'
 
110
  ckpt_dir.mkdir(exist_ok=True)
 
 
 
111
 
112
  pretrain_model_urls = {
113
- 'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
114
  'dit_3b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
115
  'dit_7b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-7B/resolve/main/seedvr2_ema_7b.pth',
116
  'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
@@ -120,14 +118,12 @@ class SeedVrManager:
120
  for key, url in pretrain_model_urls.items():
121
  _load_file_from_url(url=url, model_dir=str(ckpt_dir))
122
 
123
- logger.info("SeedVR2 models downloaded successfully.")
124
 
125
  def _initialize_runner(self, model_version: str):
126
  """Loads and configures the SeedVR model on demand based on the selected version."""
127
  if self.runner is not None: return
128
-
129
- self._download_models()
130
-
131
  logger.info(f"Initializing SeedVR2 {model_version} runner...")
132
  if model_version == '3B':
133
  config_path = SEEDVR_REPO_DIR / 'configs_3b' / 'main.yaml'
@@ -139,11 +135,15 @@ class SeedVrManager:
139
  raise ValueError(f"Unsupported SeedVR model version: {model_version}")
140
 
141
  config = load_config(str(config_path))
142
-
143
  self.runner = VideoDiffusionInfer(config)
144
  OmegaConf.set_readonly(self.runner.config, False)
145
-
146
  self.runner.configure_dit_model(device=self.device, checkpoint=str(checkpoint_path))
 
 
 
 
 
 
147
  self.runner.configure_vae_model()
148
 
149
  if hasattr(self.runner.vae, "set_memory_limit"):
@@ -153,7 +153,6 @@ class SeedVrManager:
153
  logger.info(f"Runner for SeedVR2 {model_version} initialized and ready.")
154
 
155
  def _unload_runner(self):
156
- """Removes the runner from VRAM to free resources."""
157
  if self.runner is not None:
158
  del self.runner; self.runner = None
159
  gc.collect(); torch.cuda.empty_cache()
@@ -163,17 +162,13 @@ class SeedVrManager:
163
  def process_video(self, input_video_path: str, output_video_path: str, prompt: str,
164
  model_version: str = '3B', steps: int = 50, seed: int = 666,
165
  progress: gr.Progress = None) -> str:
166
- """Applies HD enhancement to a video using the SeedVR logic."""
167
  try:
168
  self._initialize_runner(model_version)
169
  set_seed(seed, same_across_ranks=True)
170
-
171
  self.runner.config.diffusion.timesteps.sampling.steps = steps
172
  self.runner.configure_diffusion()
173
-
174
  video_tensor = read_video(input_video_path, output_format="TCHW")[0] / 255.0
175
  res_h, res_w = video_tensor.shape[-2:]
176
-
177
  video_transform = Compose([
178
  NaResize(resolution=(res_h * res_w) ** 0.5, mode="area", downsample_only=False),
179
  Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
@@ -181,48 +176,33 @@ class SeedVrManager:
181
  Normalize(0.5, 0.5),
182
  Rearrange("t c h w -> c t h w"),
183
  ])
184
-
185
  cond_latents = [video_transform(video_tensor.to(self.device))]
186
  input_videos = cond_latents
187
-
188
  self.runner.dit.to("cpu")
189
  self.runner.vae.to(self.device)
190
  cond_latents = self.runner.vae_encode(cond_latents)
191
  self.runner.vae.to("cpu"); gc.collect(); torch.cuda.empty_cache()
192
  self.runner.dit.to(self.device)
193
-
194
  pos_emb_path = SEEDVR_REPO_DIR / 'ckpts' / 'pos_emb.pt'
195
  neg_emb_path = SEEDVR_REPO_DIR / 'ckpts' / 'neg_emb.pt'
196
  text_pos_embeds = torch.load(pos_emb_path).to(self.device)
197
  text_neg_embeds = torch.load(neg_emb_path).to(self.device)
198
  text_embeds_dict = {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]}
199
-
200
  noises = [torch.randn_like(latent) for latent in cond_latents]
201
  conditions = [self.runner.get_condition(noise, latent_blur=latent, task="sr") for noise, latent in zip(noises, cond_latents)]
202
-
203
  with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
204
  video_tensors = self.runner.inference(noises=noises, conditions=conditions, dit_offload=True, **text_embeds_dict)
205
-
206
  self.runner.dit.to("cpu"); gc.collect(); torch.cuda.empty_cache()
207
-
208
  self.runner.vae.to(self.device)
209
  samples = self.runner.vae_decode(video_tensors)
210
-
211
  final_sample = samples[0]
212
  input_video_sample = input_videos[0]
213
-
214
  if final_sample.shape[1] < input_video_sample.shape[1]:
215
  input_video_sample = input_video_sample[:, :final_sample.shape[1]]
216
-
217
- final_sample = wavelet_reconstruction(
218
- rearrange(final_sample, "c t h w -> t c h w"),
219
- rearrange(input_video_sample, "c t h w -> t c h w")
220
- )
221
-
222
  final_sample = rearrange(final_sample, "t c h w -> t h w c")
223
  final_sample = final_sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
224
  final_sample_np = final_sample.to(torch.uint8).cpu().numpy()
225
-
226
  mediapy.write_video(output_video_path, final_sample_np, fps=24)
227
  logger.info(f"HD Mastered video saved to: {output_video_path}")
228
  return output_video_path
 
31
  import mediapy
32
  from einops import rearrange
33
 
 
34
  from tools.tensor_utils import wavelet_reconstruction
35
 
36
  logger = logging.getLogger(__name__)
 
39
  DEPS_DIR = Path("./deps")
40
  SEEDVR_REPO_DIR = DEPS_DIR / "SeedVR"
41
  SEEDVR_REPO_URL = "https://github.com/ByteDance-Seed/SeedVR.git"
42
+ VAE_CONFIG_URL = "https://raw.githubusercontent.com/ByteDance-Seed/SeedVR/main/models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml"
43
 
44
  def setup_seedvr_dependencies():
45
  """
46
  Ensures the SeedVR repository is cloned and available in the sys.path.
 
47
  """
48
  if not SEEDVR_REPO_DIR.exists():
49
  logger.info(f"SeedVR repository not found at '{SEEDVR_REPO_DIR}'. Cloning from GitHub...")
50
  try:
51
  DEPS_DIR.mkdir(exist_ok=True)
 
52
  subprocess.run(
53
  ["git", "clone", "--depth", "1", SEEDVR_REPO_URL, str(SEEDVR_REPO_DIR)],
54
  check=True, capture_output=True, text=True
 
60
  else:
61
  logger.info("Found local SeedVR repository.")
62
 
 
63
  if str(SEEDVR_REPO_DIR.resolve()) not in sys.path:
64
  sys.path.insert(0, str(SEEDVR_REPO_DIR.resolve()))
65
  logger.info(f"Added '{SEEDVR_REPO_DIR.resolve()}' to sys.path.")
66
 
 
67
  setup_seedvr_dependencies()
68
 
 
69
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
70
  from common.config import load_config
71
  from common.seed import set_seed
 
78
 
79
 
80
  def _load_file_from_url(url, model_dir='./', file_name=None):
 
81
  os.makedirs(model_dir, exist_ok=True)
82
  filename = file_name or os.path.basename(urlparse(url).path)
83
  cached_file = os.path.abspath(os.path.join(model_dir, filename))
 
97
  self.is_initialized = False
98
  logger.info("SeedVrManager initialized. Model will be loaded on demand.")
99
 
100
+ def _download_models_and_configs(self):
101
+ """Downloads the necessary checkpoints AND the missing VAE config file."""
102
+ logger.info("Verifying and downloading SeedVR2 models and configs...")
103
  ckpt_dir = SEEDVR_REPO_DIR / 'ckpts'
104
+ config_dir = SEEDVR_REPO_DIR / 'configs' / 'vae'
105
  ckpt_dir.mkdir(exist_ok=True)
106
+ config_dir.mkdir(parents=True, exist_ok=True)
107
+
108
+ _load_file_from_url(url=VAE_CONFIG_URL, model_dir=str(config_dir))
109
 
110
  pretrain_model_urls = {
111
+ 'vae_ckpt': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
112
  'dit_3b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
113
  'dit_7b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-7B/resolve/main/seedvr2_ema_7b.pth',
114
  'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
 
118
  for key, url in pretrain_model_urls.items():
119
  _load_file_from_url(url=url, model_dir=str(ckpt_dir))
120
 
121
+ logger.info("SeedVR2 models and configs downloaded successfully.")
122
 
123
  def _initialize_runner(self, model_version: str):
124
  """Loads and configures the SeedVR model on demand based on the selected version."""
125
  if self.runner is not None: return
126
+ self._download_models_and_configs()
 
 
127
  logger.info(f"Initializing SeedVR2 {model_version} runner...")
128
  if model_version == '3B':
129
  config_path = SEEDVR_REPO_DIR / 'configs_3b' / 'main.yaml'
 
135
  raise ValueError(f"Unsupported SeedVR model version: {model_version}")
136
 
137
  config = load_config(str(config_path))
 
138
  self.runner = VideoDiffusionInfer(config)
139
  OmegaConf.set_readonly(self.runner.config, False)
 
140
  self.runner.configure_dit_model(device=self.device, checkpoint=str(checkpoint_path))
141
+
142
+ # --- PATH CORRECTION ---
143
+ correct_vae_config_path = SEEDVR_REPO_DIR / 'configs' / 'vae' / 's8_c16_t4_inflation_sd3.yaml'
144
+ logger.info(f"Correcting VAE config path to: {correct_vae_config_path}")
145
+ self.runner.config.vae.config = str(correct_vae_config_path)
146
+
147
  self.runner.configure_vae_model()
148
 
149
  if hasattr(self.runner.vae, "set_memory_limit"):
 
153
  logger.info(f"Runner for SeedVR2 {model_version} initialized and ready.")
154
 
155
  def _unload_runner(self):
 
156
  if self.runner is not None:
157
  del self.runner; self.runner = None
158
  gc.collect(); torch.cuda.empty_cache()
 
162
  def process_video(self, input_video_path: str, output_video_path: str, prompt: str,
163
  model_version: str = '3B', steps: int = 50, seed: int = 666,
164
  progress: gr.Progress = None) -> str:
 
165
  try:
166
  self._initialize_runner(model_version)
167
  set_seed(seed, same_across_ranks=True)
 
168
  self.runner.config.diffusion.timesteps.sampling.steps = steps
169
  self.runner.configure_diffusion()
 
170
  video_tensor = read_video(input_video_path, output_format="TCHW")[0] / 255.0
171
  res_h, res_w = video_tensor.shape[-2:]
 
172
  video_transform = Compose([
173
  NaResize(resolution=(res_h * res_w) ** 0.5, mode="area", downsample_only=False),
174
  Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
 
176
  Normalize(0.5, 0.5),
177
  Rearrange("t c h w -> c t h w"),
178
  ])
 
179
  cond_latents = [video_transform(video_tensor.to(self.device))]
180
  input_videos = cond_latents
 
181
  self.runner.dit.to("cpu")
182
  self.runner.vae.to(self.device)
183
  cond_latents = self.runner.vae_encode(cond_latents)
184
  self.runner.vae.to("cpu"); gc.collect(); torch.cuda.empty_cache()
185
  self.runner.dit.to(self.device)
 
186
  pos_emb_path = SEEDVR_REPO_DIR / 'ckpts' / 'pos_emb.pt'
187
  neg_emb_path = SEEDVR_REPO_DIR / 'ckpts' / 'neg_emb.pt'
188
  text_pos_embeds = torch.load(pos_emb_path).to(self.device)
189
  text_neg_embeds = torch.load(neg_emb_path).to(self.device)
190
  text_embeds_dict = {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]}
 
191
  noises = [torch.randn_like(latent) for latent in cond_latents]
192
  conditions = [self.runner.get_condition(noise, latent_blur=latent, task="sr") for noise, latent in zip(noises, cond_latents)]
 
193
  with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
194
  video_tensors = self.runner.inference(noises=noises, conditions=conditions, dit_offload=True, **text_embeds_dict)
 
195
  self.runner.dit.to("cpu"); gc.collect(); torch.cuda.empty_cache()
 
196
  self.runner.vae.to(self.device)
197
  samples = self.runner.vae_decode(video_tensors)
 
198
  final_sample = samples[0]
199
  input_video_sample = input_videos[0]
 
200
  if final_sample.shape[1] < input_video_sample.shape[1]:
201
  input_video_sample = input_video_sample[:, :final_sample.shape[1]]
202
+ final_sample = wavelet_reconstruction(rearrange(final_sample, "c t h w -> t c h w"), rearrange(input_video_sample, "c t h w -> t c h w"))
 
 
 
 
 
203
  final_sample = rearrange(final_sample, "t c h w -> t h w c")
204
  final_sample = final_sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
205
  final_sample_np = final_sample.to(torch.uint8).cpu().numpy()
 
206
  mediapy.write_video(output_video_path, final_sample_np, fps=24)
207
  logger.info(f"HD Mastered video saved to: {output_video_path}")
208
  return output_video_path