multimodalart (HF Staff) committed
Commit 41e72bb · verified · 1 Parent(s): 35ce3f1

Update app.py

Files changed (1):
  1. app.py  +157 -243
app.py CHANGED
@@ -1,30 +1,35 @@
 import os
 import shutil
-import random
 import sys
 import tempfile
 from typing import Sequence, Mapping, Any, Union

-import spaces
 import torch
 import gradio as gr
 from PIL import Image
 from huggingface_hub import hf_hub_download
-from comfy import model_management

 def hf_hub_download_local(repo_id, filename, local_dir, **kwargs):
     downloaded_path = hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
     os.makedirs(local_dir, exist_ok=True)
     base_filename = os.path.basename(filename)
     target_path = os.path.join(local_dir, base_filename)

     if os.path.exists(target_path) or os.path.islink(target_path):
         os.remove(target_path)

     os.symlink(downloaded_path, target_path)
     return target_path

-# --- Model Downloads ---
 print("Downloading models from Hugging Face Hub...")
 hf_hub_download_local(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors", local_dir="models/text_encoders")
 hf_hub_download_local(repo_id="Comfy-Org/Wan_2.2_ComfyUI_Repackaged", filename="split_files/diffusion_models/wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", local_dir="models/unet")
@@ -35,241 +40,151 @@ hf_hub_download_local(repo_id="Kijai/WanVideo_comfy", filename="Wan22-Lightning/
 hf_hub_download_local(repo_id="Kijai/WanVideo_comfy", filename="Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors", local_dir="models/loras")
 print("Downloads complete.")

-model_management.vram_state = model_management.VRAMState.HIGH_VRAM
-
-# --- Image Processing Functions ---
-def calculate_video_dimensions(width, height, max_size=832, min_size=480):
-    """
-    Calculate video dimensions based on input image size.
-    Larger dimension becomes max_size, smaller becomes proportional.
-    If square, use min_size x min_size.
-    Results are rounded to nearest multiple of 16.
-    """
-    # Handle square images
-    if width == height:
-        video_width = min_size
-        video_height = min_size
-    else:
-        # Calculate aspect ratio
-        aspect_ratio = width / height
-
-        if width > height:
-            # Landscape orientation
-            video_width = max_size
-            video_height = int(max_size / aspect_ratio)
-        else:
-            # Portrait orientation
-            video_height = max_size
-            video_width = int(max_size * aspect_ratio)
-
-    # Round to nearest multiple of 16
-    video_width = round(video_width / 16) * 16
-    video_height = round(video_height / 16) * 16
-
-    # Ensure minimum size
-    video_width = max(video_width, 16)
-    video_height = max(video_height, 16)
-
-    return video_width, video_height
-
-def resize_and_crop_to_match(target_image, reference_image):
-    """
-    Resize and center crop target_image to match reference_image dimensions.
-    """
-    ref_width, ref_height = reference_image.size
-    target_width, target_height = target_image.size
-
-    # Calculate scaling factor to ensure target covers reference dimensions
-    scale = max(ref_width / target_width, ref_height / target_height)
-
-    # Resize target image
-    new_width = int(target_width * scale)
-    new_height = int(target_height * scale)
-    resized = target_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
-
-    # Center crop to match reference dimensions
-    left = (new_width - ref_width) // 2
-    top = (new_height - ref_height) // 2
-    right = left + ref_width
-    bottom = top + ref_height
-
-    cropped = resized.crop((left, top, right, bottom))
-    return cropped
-
-# --- Boilerplate code from the original script ---
-def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
-    """Returns the value at the given index of a sequence or mapping.
-
-    If the object is a sequence (like list or string), returns the value at the given index.
-    If the object is a mapping (like a dictionary), returns the value at the index-th key.
-
-    Some return a dictionary, in these cases, we look for the "results" key
-
-    Args:
-        obj (Union[Sequence, Mapping]): The object to retrieve the value from.
-        index (int): The index of the value to retrieve.
-
-    Returns:
-        Any: The value at the given index.
-
-    Raises:
-        IndexError: If the index is out of bounds for the object and the object is not a mapping.
-    """
-    try:
-        return obj[index]
-    except KeyError:
-        # This is a fallback for custom node outputs that might be dictionaries
-        if isinstance(obj, Mapping) and "result" in obj:
-            return obj["result"][index]
-        raise

 def find_path(name: str, path: str = None) -> str:
-    """
-    Recursively looks at parent folders starting from the given path until it finds the given name.
-    Returns the path as a Path object if found, or None otherwise.
-    """
-    if path is None:
-        path = os.getcwd()
-
-    if name in os.listdir(path):
-        path_name = os.path.join(path, name)
-        print(f"'{name}' found: {path_name}")
-        return path_name
-
     parent_directory = os.path.dirname(path)
-    if parent_directory == path:
-        return None
-
-    return find_path(name, parent_directory)
-

 def add_comfyui_directory_to_sys_path() -> None:
-    """
-    Add 'ComfyUI' to the sys.path
-    """
     comfyui_path = find_path("ComfyUI")
-    if comfyui_path is not None and os.path.isdir(comfyui_path):
         sys.path.append(comfyui_path)
         print(f"'{comfyui_path}' added to sys.path")
-    else:
-        print("Could not find ComfyUI directory. Please run from a parent folder of ComfyUI.")

 def add_extra_model_paths() -> None:
-    """
-    Parse the optional extra_model_paths.yaml file and add the parsed paths to the sys.path.
-    """
-    try:
-        from main import load_extra_path_config
-    except ImportError:
-        print(
-            "Could not import load_extra_path_config from main.py. This might be okay if you don't use it."
-        )
-        return
-
-    extra_model_paths = find_path("extra_model_paths.yaml")
-    if extra_model_paths is not None:
-        load_extra_path_config(extra_model_paths)
-    else:
-        print("Could not find an optional 'extra_model_paths.yaml' config file.")

 def import_custom_nodes() -> None:
-    """Find all custom nodes in the custom_nodes folder and add those node objects to NODE_CLASS_MAPPINGS
-    This function sets up a new asyncio event loop, initializes the PromptServer,
-    creates a PromptQueue, and initializes the custom nodes.
-    """
-    import asyncio
-    import execution
-    from nodes import init_extra_nodes
-    import server
-
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
-    server_instance = server.PromptServer(loop)
-    execution.PromptQueue(server_instance)
-    loop.run_until_complete(init_extra_nodes(init_custom_nodes=True))
-

-# --- Model Loading and Caching ---
-MODELS_AND_NODES = {}
-
-print("Setting up ComfyUI paths...")
 add_comfyui_directory_to_sys_path()
 add_extra_model_paths()
-
-print("Importing custom nodes...")
 import_custom_nodes()

-# Now that paths are set up, we can import from nodes
 from nodes import NODE_CLASS_MAPPINGS
-global folder_paths  # Make folder_paths globally accessible
 import folder_paths

-print("Loading models into memory. This may take a few minutes...")

-# Load Text-to-Image models (CLIP, UNETs, VAE)
-cliploader = NODE_CLASS_MAPPINGS["CLIPLoader"]()
-MODELS_AND_NODES["clip"] = cliploader.load_clip(
-    clip_name="umt5_xxl_fp8_e4m3fn_scaled.safetensors", type="wan"
-)

-unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
-unet_low_noise = unetloader.load_unet(
-    unet_name="wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors",
-    weight_dtype="default",
-)
-unet_high_noise = unetloader.load_unet(
-    unet_name="wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors",
-    weight_dtype="default",
-)

 vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
 MODELS_AND_NODES["vae"] = vaeloader.load_vae(vae_name="wan_2.1_vae.safetensors")

-# Load LoRAs
-loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
-MODELS_AND_NODES["model_low_noise"] = loraloadermodelonly.load_lora_model_only(
     lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors",
-    strength_model=0.8,
-    model=get_value_at_index(unet_low_noise, 0),
-)
-MODELS_AND_NODES["model_high_noise"] = loraloadermodelonly.load_lora_model_only(
-    lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors",
-    strength_model=0.8,
-    model=get_value_at_index(unet_high_noise, 0),
-)

-# Load Vision model
-clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
-MODELS_AND_NODES["clip_vision"] = clipvisionloader.load_clip(
-    clip_name="clip_vision_h.safetensors"
-)

-# Instantiate all required node classes
 MODELS_AND_NODES["CLIPTextEncode"] = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
 MODELS_AND_NODES["LoadImage"] = NODE_CLASS_MAPPINGS["LoadImage"]()
 MODELS_AND_NODES["CLIPVisionEncode"] = NODE_CLASS_MAPPINGS["CLIPVisionEncode"]()
-MODELS_AND_NODES["ModelSamplingSD3"] = NODE_CLASS_MAPPINGS["ModelSamplingSD3"]()
-MODELS_AND_NODES["PathchSageAttentionKJ"] = NODE_CLASS_MAPPINGS["PathchSageAttentionKJ"]()
 MODELS_AND_NODES["WanFirstLastFrameToVideo"] = NODE_CLASS_MAPPINGS["WanFirstLastFrameToVideo"]()
 MODELS_AND_NODES["KSamplerAdvanced"] = NODE_CLASS_MAPPINGS["KSamplerAdvanced"]()
 MODELS_AND_NODES["VAEDecode"] = NODE_CLASS_MAPPINGS["VAEDecode"]()
 MODELS_AND_NODES["CreateVideo"] = NODE_CLASS_MAPPINGS["CreateVideo"]()
 MODELS_AND_NODES["SaveVideo"] = NODE_CLASS_MAPPINGS["SaveVideo"]()

-print("Pre-loading main models onto GPU...")
-model_loaders = [
     MODELS_AND_NODES["clip"],
     MODELS_AND_NODES["vae"],
-    MODELS_AND_NODES["model_low_noise"],  # This is the UNET + LoRA
-    MODELS_AND_NODES["model_high_noise"],  # This is the other UNET + LoRA
     MODELS_AND_NODES["clip_vision"],
 ]
 model_management.load_models_gpu([
-    loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0] for loader in model_loaders
-])
-print("All models loaded successfully!")

-# --- Main Video Generation Logic ---
 @spaces.GPU(duration=120)
 def generate_video(
     start_image_pil,
@@ -280,65 +195,59 @@ def generate_video(
     progress=gr.Progress(track_tqdm=True)
 ):
     """
-    The main function to generate a video based on user inputs.
-    This function is called every time the user clicks the 'Generate' button.
     """
     FPS = 16
-
-    # Process images: resize and crop second image to match first
-    # The first image determines the dimensions
-    processed_start_image = start_image_pil.copy()
-    processed_end_image = resize_and_crop_to_match(end_image_pil, start_image_pil)
-
-    # Calculate video dimensions based on the first image
-    video_width, video_height = calculate_video_dimensions(
-        processed_start_image.width,
-        processed_start_image.height
-    )
-
-    print(f"Input image size: {processed_start_image.width}x{processed_start_image.height}")
-    print(f"Video dimensions: {video_width}x{video_height}")
-
     clip = MODELS_AND_NODES["clip"]
     vae = MODELS_AND_NODES["vae"]
-    model_low_noise = MODELS_AND_NODES["model_low_noise"]
-    model_high_noise = MODELS_AND_NODES["model_high_noise"]
     clip_vision = MODELS_AND_NODES["clip_vision"]

     cliptextencode = MODELS_AND_NODES["CLIPTextEncode"]
     loadimage = MODELS_AND_NODES["LoadImage"]
     clipvisionencode = MODELS_AND_NODES["CLIPVisionEncode"]
-    modelsamplingsd3 = MODELS_AND_NODES["ModelSamplingSD3"]
-    pathchsageattentionkj = MODELS_AND_NODES["PathchSageAttentionKJ"]
     wanfirstlastframetovideo = MODELS_AND_NODES["WanFirstLastFrameToVideo"]
     ksampleradvanced = MODELS_AND_NODES["KSamplerAdvanced"]
     vaedecode = MODELS_AND_NODES["VAEDecode"]
     createvideo = MODELS_AND_NODES["CreateVideo"]
     savevideo = MODELS_AND_NODES["SaveVideo"]

-    # Save processed images to temporary files
-    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as start_file, \
-         tempfile.NamedTemporaryFile(suffix=".png", delete=False) as end_file:
         processed_start_image.save(start_file.name)
         processed_end_image.save(end_file.name)
-        start_image_path = start_file.name
-        end_image_path = end_file.name
-
     with torch.inference_mode():
         progress(0.1, desc="Encoding text and images...")
-        # --- Workflow execution ---
         positive_conditioning = cliptextencode.encode(text=prompt, clip=get_value_at_index(clip, 0))
         negative_conditioning = cliptextencode.encode(text=negative_prompt, clip=get_value_at_index(clip, 0))

         start_image_loaded = loadimage.load_image(image=start_image_path)
         end_image_loaded = loadimage.load_image(image=end_image_path)

-        clip_vision_encoded_start = clipvisionencode.encode(
-            crop="none", clip_vision=get_value_at_index(clip_vision, 0), image=get_value_at_index(start_image_loaded, 0)
-        )
-        clip_vision_encoded_end = clipvisionencode.encode(
-            crop="none", clip_vision=get_value_at_index(clip_vision, 0), image=get_value_at_index(end_image_loaded, 0)
-        )

         progress(0.2, desc="Preparing initial latents...")
         initial_latents = wanfirstlastframetovideo.EXECUTE_NORMALIZED(
@@ -352,30 +261,27 @@ def generate_video(
             end_image=get_value_at_index(end_image_loaded, 0),
         )

-        progress(0.3, desc="Patching models...")
-        model_low_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_low_noise, 0))
-        model_low_final = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_low_patched, 0))
-
-        model_high_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_high_noise, 0))
-        model_high_final = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_high_patched, 0))
-
-        progress(0.5, desc="Running KSampler (Step 1/2)...")
         latent_step1 = ksampleradvanced.sample(
             add_noise="enable", noise_seed=random.randint(1, 2**64), steps=8, cfg=1,
             sampler_name="euler", scheduler="simple", start_at_step=0, end_at_step=4,
             return_with_leftover_noise="enable", model=get_value_at_index(model_high_final, 0),
-            positive=get_value_at_index(initial_latents, 0),
-            negative=get_value_at_index(initial_latents, 1),
-            latent_image=get_value_at_index(initial_latents, 2),
         )

-        progress(0.7, desc="Running KSampler (Step 2/2)...")
         latent_step2 = ksampleradvanced.sample(
             add_noise="disable", noise_seed=random.randint(1, 2**64), steps=8, cfg=1,
             sampler_name="euler", scheduler="simple", start_at_step=4, end_at_step=10000,
             return_with_leftover_noise="disable", model=get_value_at_index(model_low_final, 0),
-            positive=get_value_at_index(initial_latents, 0),
-            negative=get_value_at_index(initial_latents, 1),
             latent_image=get_value_at_index(latent_step1, 0),
         )
@@ -385,15 +291,23 @@ def generate_video(
         progress(0.9, desc="Creating and saving video...")
         video_data = createvideo.create_video(fps=FPS, images=get_value_at_index(decoded_images, 0))

-        # Save the video to ComfyUI's output directory
         save_result = savevideo.save_video(
             filename_prefix="GradioVideo", format="mp4", codec="h264",
             video=get_value_at_index(video_data, 0),
         )
-
         progress(1.0, desc="Done!")
-        return f"output/{save_result['ui']['images'][0]['filename']}"


 css = '''
 
app.py (new version):

 import os
 import shutil
 import sys
+import subprocess
+import asyncio
+import uuid
+import random
 import tempfile
 from typing import Sequence, Mapping, Any, Union

 import torch
 import gradio as gr
 from PIL import Image
 from huggingface_hub import hf_hub_download
+import spaces
+
+# --- 1. Model Download and Setup ---

 def hf_hub_download_local(repo_id, filename, local_dir, **kwargs):
+    """Downloads a file from Hugging Face Hub and symlinks it to a local directory."""
     downloaded_path = hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
     os.makedirs(local_dir, exist_ok=True)
     base_filename = os.path.basename(filename)
     target_path = os.path.join(local_dir, base_filename)

+    # Remove existing symlink or file to avoid errors
     if os.path.exists(target_path) or os.path.islink(target_path):
         os.remove(target_path)

     os.symlink(downloaded_path, target_path)
     return target_path

 print("Downloading models from Hugging Face Hub...")
 hf_hub_download_local(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors", local_dir="models/text_encoders")
 hf_hub_download_local(repo_id="Comfy-Org/Wan_2.2_ComfyUI_Repackaged", filename="split_files/diffusion_models/wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", local_dir="models/unet")
[...]
 hf_hub_download_local(repo_id="Kijai/WanVideo_comfy", filename="Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors", local_dir="models/loras")
 print("Downloads complete.")
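For orientation: with the helper above, the checkpoint files themselves stay in the huggingface_hub cache and only symlinks land in the ComfyUI-style models/ folders. A minimal sketch of how one might verify that layout after the downloads, assuming the definitions above are in scope (the resolved location is simply wherever hf_hub_download placed the file):

# Sketch only: inspect the symlink created for the text-encoder download above.
link = "models/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors"
print(os.path.islink(link))    # True: the models/ entry is only a symlink
print(os.path.realpath(link))  # resolves to the downloaded file in the huggingface_hub cache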
+# --- 2. ComfyUI Backend Initialization ---

 def find_path(name: str, path: str = None) -> str:
+    """Recursively finds a directory with a given name."""
+    if path is None: path = os.getcwd()
+    if name in os.listdir(path): return os.path.join(path, name)
     parent_directory = os.path.dirname(path)
+    return find_path(name, parent_directory) if parent_directory != path else None

 def add_comfyui_directory_to_sys_path() -> None:
+    """Adds the ComfyUI directory to sys.path for imports."""
     comfyui_path = find_path("ComfyUI")
+    if comfyui_path and os.path.isdir(comfyui_path):
         sys.path.append(comfyui_path)
         print(f"'{comfyui_path}' added to sys.path")

 def add_extra_model_paths() -> None:
+    """Initializes ComfyUI's folder_paths with custom paths."""
+    from main import apply_custom_paths
+    apply_custom_paths()

 def import_custom_nodes() -> None:
+    """Initializes all ComfyUI custom nodes."""
+    import nodes
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
+    loop.run_until_complete(nodes.init_extra_nodes(init_custom_nodes=True))

+print("Setting up ComfyUI paths and nodes...")
 add_comfyui_directory_to_sys_path()
 add_extra_model_paths()
 import_custom_nodes()
+print("ComfyUI setup complete.")
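As a quick, hypothetical illustration of the find_path lookup above (the directory layout is invented for the example and not taken from the Space; assumes the function above is in scope):

# Hypothetical layout: the script runs from /home/user/app and a ComfyUI
# checkout lives at /home/user/app/ComfyUI.
find_path("ComfyUI")         # -> "/home/user/app/ComfyUI" (found in the starting directory)
find_path("does-not-exist")  # -> None, once os.path.dirname() stops changing at the filesystem root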
+# --- 3. Global Model & Node Loading and Patching ---

 from nodes import NODE_CLASS_MAPPINGS
 import folder_paths
+from comfy import model_management

+# Set VRAM mode to HIGH to prevent models from being offloaded from GPU after use.
+model_management.vram_state = model_management.VRAMState.HIGH_VRAM

+MODELS_AND_NODES = {}

+def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
+    """Helper to safely access outputs from ComfyUI nodes, which are often tuples."""
+    try:
+        return obj[index]
+    except (KeyError, TypeError):
+        # Fallback for custom nodes that might return a dictionary with a 'result' key
+        if isinstance(obj, Mapping) and "result" in obj:
+            return obj["result"][index]
+        raise
+
+print("Loading models and instantiating nodes into memory. This may take a few minutes...")

+# Instantiate node classes that will be used for loading and patching
+cliploader = NODE_CLASS_MAPPINGS["CLIPLoader"]()
+unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
 vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
+clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
+loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
+modelsamplingsd3 = NODE_CLASS_MAPPINGS["ModelSamplingSD3"]()
+pathchsageattentionkj = NODE_CLASS_MAPPINGS["PathchSageAttentionKJ"]()
+
+# Load base models into CPU RAM initially
+MODELS_AND_NODES["clip"] = cliploader.load_clip(clip_name="umt5_xxl_fp8_e4m3fn_scaled.safetensors", type="wan")
+unet_low_noise = unetloader.load_unet(unet_name="wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", weight_dtype="default")
+unet_high_noise = unetloader.load_unet(unet_name="wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors", weight_dtype="default")
 MODELS_AND_NODES["vae"] = vaeloader.load_vae(vae_name="wan_2.1_vae.safetensors")
+MODELS_AND_NODES["clip_vision"] = clipvisionloader.load_clip(clip_name="clip_vision_h.safetensors")

+# Chain all patching operations together for the final models
+print("Applying all patches to models...")
+
+# --- Low Noise Model Chain ---
+model_low_with_lora = loraloadermodelonly.load_lora_model_only(
     lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors",
+    strength_model=0.8, model=get_value_at_index(unet_low_noise, 0))
+model_low_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_low_with_lora, 0))
+MODELS_AND_NODES["model_low_noise"] = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_low_patched, 0))

+# --- High Noise Model Chain ---
+model_high_with_lora = loraloadermodelonly.load_lora_model_only(
+    lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors",
+    strength_model=0.8, model=get_value_at_index(unet_high_noise, 0))
+model_high_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_high_with_lora, 0))
+MODELS_AND_NODES["model_high_noise"] = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_high_patched, 0))

+# Instantiate all other node classes ONCE and store them
 MODELS_AND_NODES["CLIPTextEncode"] = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
 MODELS_AND_NODES["LoadImage"] = NODE_CLASS_MAPPINGS["LoadImage"]()
 MODELS_AND_NODES["CLIPVisionEncode"] = NODE_CLASS_MAPPINGS["CLIPVisionEncode"]()
 MODELS_AND_NODES["WanFirstLastFrameToVideo"] = NODE_CLASS_MAPPINGS["WanFirstLastFrameToVideo"]()
 MODELS_AND_NODES["KSamplerAdvanced"] = NODE_CLASS_MAPPINGS["KSamplerAdvanced"]()
 MODELS_AND_NODES["VAEDecode"] = NODE_CLASS_MAPPINGS["VAEDecode"]()
 MODELS_AND_NODES["CreateVideo"] = NODE_CLASS_MAPPINGS["CreateVideo"]()
 MODELS_AND_NODES["SaveVideo"] = NODE_CLASS_MAPPINGS["SaveVideo"]()

+# Move all final, fully-patched models to the GPU
+print("Moving final models to GPU...")
+model_loaders_final = [
     MODELS_AND_NODES["clip"],
     MODELS_AND_NODES["vae"],
+    MODELS_AND_NODES["model_low_noise"],
+    MODELS_AND_NODES["model_high_noise"],
     MODELS_AND_NODES["clip_vision"],
 ]
 model_management.load_models_gpu([
+    loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0] for loader in model_loaders_final
+], force_patch_weights=True)  # force_patch_weights permanently merges the LoRA
+
+print("All models loaded, patched, and on GPU. Gradio app is ready.")
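An aside on get_value_at_index: ComfyUI node methods usually return their outputs as tuples, while some custom nodes wrap them in a dict under a "result" key, which is exactly what the helper above papers over. A small sketch with made-up stand-in values (assumes the helper defined above is in scope; the strings stand in for real tensors):

# Hypothetical illustration only: the tuple/dict below stand in for real node outputs.
node_output_tuple = ("conditioning_tensor",)        # typical node return: a plain tuple
node_output_dict = {"result": ("latent_tensor",)}   # some custom nodes wrap results in a dict

assert get_value_at_index(node_output_tuple, 0) == "conditioning_tensor"  # plain indexing
assert get_value_at_index(node_output_dict, 0) == "latent_tensor"         # falls back to obj["result"][index]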
+# --- 4. Application Logic and Gradio Interface ---
+
+def calculate_video_dimensions(width, height, max_size=832, min_size=480):
+    """Calculates video dimensions, ensuring they are multiples of 16."""
+    if width == height:
+        return min_size, min_size
+    aspect_ratio = width / height
+    if width > height:
+        video_width = max_size
+        video_height = int(max_size / aspect_ratio)
+    else:
+        video_height = max_size
+        video_width = int(max_size * aspect_ratio)
+    video_width = max(16, round(video_width / 16) * 16)
+    video_height = max(16, round(video_height / 16) * 16)
+    return video_width, video_height
+
+def resize_and_crop_to_match(target_image, reference_image):
+    """Resizes and center-crops the target image to match the reference image's dimensions."""
+    ref_width, ref_height = reference_image.size
+    target_width, target_height = target_image.size
+    scale = max(ref_width / target_width, ref_height / target_height)
+    new_width, new_height = int(target_width * scale), int(target_height * scale)
+    resized = target_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
+    left, top = (new_width - ref_width) // 2, (new_height - ref_height) // 2
+    return resized.crop((left, top, left + ref_width, top + ref_height))
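To make the sizing rule concrete, a few hand-computed examples of what calculate_video_dimensions returns (the input resolutions are assumed for illustration, not taken from the app; assumes the function above is in scope):

# Illustrative only: results follow from the rounding-to-16 logic above.
print(calculate_video_dimensions(1920, 1080))  # landscape 16:9 -> (832, 464); 468 rounds to 464
print(calculate_video_dimensions(1080, 1920))  # portrait 9:16  -> (464, 832)
print(calculate_video_dimensions(1024, 1024))  # square         -> (480, 480)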
 
 
 @spaces.GPU(duration=120)
 def generate_video(
     start_image_pil,
[...]
     progress=gr.Progress(track_tqdm=True)
 ):
     """
+    Generates a video by interpolating between a start and end image, guided by a text prompt.
+    This function relies on globally pre-loaded models and pre-instantiated ComfyUI nodes.
     """
     FPS = 16
+
+    # --- 1. Retrieve Pre-loaded and Pre-patched Models & Node Instances ---
+    # These are not re-instantiated; we are just getting references to the global objects.
     clip = MODELS_AND_NODES["clip"]
     vae = MODELS_AND_NODES["vae"]
+    model_low_final = MODELS_AND_NODES["model_low_noise"]
+    model_high_final = MODELS_AND_NODES["model_high_noise"]
     clip_vision = MODELS_AND_NODES["clip_vision"]

     cliptextencode = MODELS_AND_NODES["CLIPTextEncode"]
     loadimage = MODELS_AND_NODES["LoadImage"]
     clipvisionencode = MODELS_AND_NODES["CLIPVisionEncode"]
     wanfirstlastframetovideo = MODELS_AND_NODES["WanFirstLastFrameToVideo"]
     ksampleradvanced = MODELS_AND_NODES["KSamplerAdvanced"]
     vaedecode = MODELS_AND_NODES["VAEDecode"]
     createvideo = MODELS_AND_NODES["CreateVideo"]
     savevideo = MODELS_AND_NODES["SaveVideo"]

+    # --- 2. Image Preprocessing for the Current Run ---
+    print("Preprocessing images with Pillow...")
+    processed_start_image = start_image_pil.copy()
+    processed_end_image = resize_and_crop_to_match(end_image_pil, start_image_pil)
+    video_width, video_height = calculate_video_dimensions(processed_start_image.width, processed_start_image.height)
+
+    # Save processed images to temporary files for the LoadImage node
+    temp_dir = "input"  # ComfyUI's default input directory
+    os.makedirs(temp_dir, exist_ok=True)
+
+    with tempfile.NamedTemporaryFile(suffix=".png", delete=False, dir=temp_dir) as start_file, \
+         tempfile.NamedTemporaryFile(suffix=".png", delete=False, dir=temp_dir) as end_file:
         processed_start_image.save(start_file.name)
         processed_end_image.save(end_file.name)
+        start_image_path = os.path.basename(start_file.name)
+        end_image_path = os.path.basename(end_file.name)
+    print(f"Images resized to {video_width}x{video_height} and saved temporarily.")
+
+    # --- 3. Execute the ComfyUI Workflow in Inference Mode ---
     with torch.inference_mode():
         progress(0.1, desc="Encoding text and images...")
+
+        # Encode prompts and vision models
         positive_conditioning = cliptextencode.encode(text=prompt, clip=get_value_at_index(clip, 0))
         negative_conditioning = cliptextencode.encode(text=negative_prompt, clip=get_value_at_index(clip, 0))

         start_image_loaded = loadimage.load_image(image=start_image_path)
         end_image_loaded = loadimage.load_image(image=end_image_path)

+        clip_vision_encoded_start = clipvisionencode.encode(crop="none", clip_vision=get_value_at_index(clip_vision, 0), image=get_value_at_index(start_image_loaded, 0))
+        clip_vision_encoded_end = clipvisionencode.encode(crop="none", clip_vision=get_value_at_index(clip_vision, 0), image=get_value_at_index(end_image_loaded, 0))

         progress(0.2, desc="Preparing initial latents...")
         initial_latents = wanfirstlastframetovideo.EXECUTE_NORMALIZED(
[...]
             end_image=get_value_at_index(end_image_loaded, 0),
         )

+        ksampler_positive = get_value_at_index(initial_latents, 0)
+        ksampler_negative = get_value_at_index(initial_latents, 1)
+        ksampler_latent = get_value_at_index(initial_latents, 2)
+
+        progress(0.5, desc="Denoising (Step 1/2)...")
         latent_step1 = ksampleradvanced.sample(
             add_noise="enable", noise_seed=random.randint(1, 2**64), steps=8, cfg=1,
             sampler_name="euler", scheduler="simple", start_at_step=0, end_at_step=4,
             return_with_leftover_noise="enable", model=get_value_at_index(model_high_final, 0),
+            positive=ksampler_positive,
+            negative=ksampler_negative,
+            latent_image=ksampler_latent,
         )

+        progress(0.7, desc="Denoising (Step 2/2)...")
         latent_step2 = ksampleradvanced.sample(
             add_noise="disable", noise_seed=random.randint(1, 2**64), steps=8, cfg=1,
             sampler_name="euler", scheduler="simple", start_at_step=4, end_at_step=10000,
             return_with_leftover_noise="disable", model=get_value_at_index(model_low_final, 0),
+            positive=ksampler_positive,
+            negative=ksampler_negative,
             latent_image=get_value_at_index(latent_step1, 0),
         )

[...]
         progress(0.9, desc="Creating and saving video...")
         video_data = createvideo.create_video(fps=FPS, images=get_value_at_index(decoded_images, 0))

+        # Save the video to ComfyUI's default output directory
         save_result = savevideo.save_video(
             filename_prefix="GradioVideo", format="mp4", codec="h264",
             video=get_value_at_index(video_data, 0),
         )
+
         progress(1.0, desc="Done!")

+    # --- 4. Cleanup and Return ---
+    try:
+        os.remove(start_file.name)
+        os.remove(end_file.name)
+    except Exception as e:
+        print(f"Error cleaning up temporary files: {e}")
+
+    # Gradio video component expects a filepath relative to the root of the app
+    return f"output/{save_result['ui']['images'][0]['filename']}"


 css = '''