dn6 (HF Staff) committed · Commit 64f40e5 · verified · 1 Parent(s): 5b1c701

Diffusers updates

.gitattributes CHANGED
@@ -33,9 +33,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
- hf_assets/3.mp4 filter=lfs diff=lfs merge=lfs -text
37
- hf_assets/masonry.mp4 filter=lfs diff=lfs merge=lfs -text
38
- hf_assets/trimferarricrop2_2x_speed.mp4 filter=lfs diff=lfs merge=lfs -text
39
- hf_assets/v2v_me_crop_1final.mov filter=lfs diff=lfs merge=lfs -text
40
- hf_assets/vertical_grid_all_videos.mp4 filter=lfs diff=lfs merge=lfs -text
41
- hf_assets/vertical_grid_output_reordered.mp4 filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -12,7 +12,7 @@ library_name: diffusers
12
  ---
13
  Krea Realtime 14B is distilled from the [Wan 2.1 14B text-to-video model](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) using Self-Forcing, a technique for converting regular video diffusion models into autoregressive models. It achieves a text-to-video inference speed of **11fps** using 4 inference steps on a single NVIDIA B200 GPU. For more details on our training methodology and sampling innovations, refer to our [technical blog post](https://www.krea.ai/blog/krea-realtime-14b).
14
 
15
- Inference code can be found [here](https://github.com/krea-ai/realtime-video).
16
 
17
 
18
  <video width="100%" controls>
@@ -118,60 +118,288 @@ export CUDA_VISIBLE_DEVICES=0 # pick the GPU you want to serve on
118
  export DO_COMPILE=true
119
 
120
  uvicorn release_server:app --host 0.0.0.0 --port 8000
121
- ```
122
 
123
- And use the web app at http://localhost:8000/ in your browser
124
  (for more advanced use-cases and custom pipeline check out our GitHub repository: https://github.com/krea-ai/realtime-video)
125
 
126
  # Use it with 🧨 diffusers
127
 
128
- Krea Realtime 14B can be used with the `diffusers` library utilizing the new Modular Diffusers structure (for now supporting text-to-video, video-to-video coming soon)
129
 
130
  ```bash
131
  # Install diffusers from main
132
  pip install git+https://github.com/huggingface/diffusers.git
133
- ```
134
 
135
  ```py
136
  import torch
137
- from collections import deque
138
  from diffusers.utils import export_to_video
139
- from diffusers import ModularPipelineBlocks
140
- from diffusers.modular_pipelines import PipelineState, WanModularPipeline
141
 
142
  repo_id = "krea/krea-realtime-video"
143
- blocks = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True)
144
- pipe = WanModularPipeline(blocks, repo_id)
145
146
  pipe.load_components(
147
  trust_remote_code=True,
148
  device_map="cuda",
149
  torch_dtype={"default": torch.bfloat16, "vae": torch.float16},
150
  )
 
 
151
 
152
- num_frames_per_block = 3
153
  num_blocks = 9
 
154
 
155
  frames = []
 
 
156
  state = PipelineState()
157
- state.set("frame_cache_context", deque(maxlen=pipe.config.frame_cache_len))
158
 
159
- prompt = ["a cat sitting on a boat"]
160
 
161
  for block in pipe.transformer.blocks:
162
  block.self_attn.fuse_projections()
163
 
164
- for block_idx in range(num_blocks):
165
  state = pipe(
166
  state,
167
  prompt=prompt,
168
  num_inference_steps=6,
169
  num_blocks=num_blocks,
170
  num_frames_per_block=num_frames_per_block,
171
  block_idx=block_idx,
172
  generator=torch.Generator("cuda").manual_seed(42),
173
  )
174
  frames.extend(state.values["videos"][0])
175
 
176
- export_to_video(frames, "output.mp4", fps=16)
177
- ```
 
 
12
  ---
13
  Krea Realtime 14B is distilled from the [Wan 2.1 14B text-to-video model](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) using Self-Forcing, a technique for converting regular video diffusion models into autoregressive models. It achieves a text-to-video inference speed of **11fps** using 4 inference steps on a single NVIDIA B200 GPU. For more details on our training methodology and sampling innovations, refer to our [technical blog post](https://www.krea.ai/blog/krea-realtime-14b).
14
 
15
+ Inference code can be found [here](https://github.com/krea-ai/realtime-video).
16
 
17
 
18
  <video width="100%" controls>
 
118
  export DO_COMPILE=true
119
 
120
  uvicorn release_server:app --host 0.0.0.0 --port 8000
121
+ ```
122
 
123
+ Then open the web app at http://localhost:8000/ in your browser.
124
  (For more advanced use cases and custom pipelines, check out our GitHub repository: https://github.com/krea-ai/realtime-video)
125
 
126
  # Use it with 🧨 diffusers
127
 
128
+ Krea Realtime 14B can be used with the `diffusers` library via the new Modular Diffusers structure.
129
 
130
  ```bash
131
  # Install diffusers from main
132
  pip install git+https://github.com/huggingface/diffusers.git
133
+ ```
134
+
135
+ <details>
136
+ <summary>Text to Video</summary>
137
 
138
  ```py
139
  import torch
140
+ from tqdm import tqdm
141
  from diffusers.utils import export_to_video
142
+ from diffusers import ModularPipeline
143
+ from diffusers.modular_pipelines import PipelineState
144
 
145
  repo_id = "krea/krea-realtime-video"
146
+ pipe = ModularPipeline.from_pretrained(repo_id, trust_remote_code=True)
147
+ pipe.load_components(
148
+ trust_remote_code=True,
149
+ device_map="cuda",
150
+ torch_dtype={"default": torch.bfloat16, "vae": torch.float16},
151
+ )
152
+ for block in pipe.transformer.blocks:
153
+ block.self_attn.fuse_projections()
154
+
155
+ num_blocks = 9
156
 
157
+ frames = []
158
+ state = PipelineState()
159
+ prompt = ["a cat sitting on a boat"]
160
+
161
+ generator = torch.Generator(device=pipe.device).manual_seed(42)
162
+ for block_idx in tqdm(range(num_blocks)):
163
+ state = pipe(
164
+ state,
165
+ prompt=prompt,
166
+ num_inference_steps=6,
167
+ num_blocks=num_blocks,
168
+ block_idx=block_idx,
169
+ generator=generator,
170
+ )
171
+ frames.extend(state.values["videos"][0])
172
+
173
+ export_to_video(frames, "output.mp4", fps=24)
174
+ ```
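+ Because each loop iteration returns the decoded frames for that block, you can hand frames off for display or encoding as soon as a block finishes instead of waiting for the whole clip. A minimal sketch of the loop body (`handle_frames` is a hypothetical callback standing in for your own player or encoder):
+
+ ```py
+ # sketch: replace the body of the block loop above
+ block_frames = state.values["videos"][0]
+ frames.extend(block_frames)
+ handle_frames(block_frames)  # hypothetical: push frames to a display or streaming encoder
+ ```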
175
+ </details>
176
+
177
+ <details>
178
+ <summary>Video to Video</summary>
179
+
180
+ ```py
181
+ import torch
182
+ from tqdm import tqdm
183
+ from diffusers.utils import load_video, export_to_video
184
+ from diffusers import ModularPipeline
185
+ from diffusers.modular_pipelines import PipelineState
186
+
187
+ repo_id = "krea/krea-realtime-video"
188
+ pipe = ModularPipeline.from_pretrained(repo_id, trust_remote_code=True)
189
  pipe.load_components(
190
  trust_remote_code=True,
191
  device_map="cuda",
192
  torch_dtype={"default": torch.bfloat16, "vae": torch.float16},
193
  )
194
+ for block in pipe.transformer.blocks:
195
+ block.self_attn.fuse_projections()
196
 
 
197
  num_blocks = 9
198
+ video = load_video("https://app-uploads.krea.ai/public/a8218957-1a80-43dc-81b2-da970b5f2221-video.mp4")
199
 
200
  frames = []
201
+ prompt = ["A car racing down a snowy mountain"]
202
+
203
  state = PipelineState()
204
+ generator = torch.Generator("cuda").manual_seed(42)
205
+ for block_idx in tqdm(range(num_blocks)):
206
+ state = pipe(
207
+ state,
208
+ video=video,
209
+ prompt=prompt,
210
+ num_inference_steps=6,
211
+ strength=0.3,
212
+ block_idx=block_idx,
213
+ generator=generator,
214
+ )
215
+ frames.extend(state.values["videos"][0])
216
 
217
+ export_to_video(frames, "output-v2v.mp4", fps=24)
218
+ ```
219
+ </details>
220
+
221
+ <details>
222
+ <summary>Streaming Video to Video</summary>
223
 
224
+ Using the `video_stream` input processes video frames as they arrive, while maintaining temporal consistency across chunks.
225
+
226
+ ```py
227
+ import torch
228
+ from collections import deque
229
+ from tqdm import tqdm
230
+ from diffusers.utils import load_video, export_to_video
231
+ from diffusers import ModularPipeline
232
+ from diffusers.modular_pipelines import PipelineState
233
+
234
+ repo_id = "krea/krea-realtime-video"
235
+ pipe = ModularPipeline.from_pretrained(repo_id, trust_remote_code=True)
236
+ pipe.load_components(
237
+ trust_remote_code=True,
238
+ device_map="cuda",
239
+ torch_dtype={"default": torch.bfloat16, "vae": torch.float16},
240
+ )
241
  for block in pipe.transformer.blocks:
242
  block.self_attn.fuse_projections()
243
 
244
+ n_samples = 9
245
+ frame_sample_len = 12
246
+ video = load_video(
247
+ "https://app-uploads.krea.ai/public/a8218957-1a80-43dc-81b2-da970b5f2221-video.mp4"
248
+ )
249
+
250
+ # Simulate streaming video input
251
+ frame_samples = [
252
+ video[sample_start : sample_start + frame_sample_len]
253
+ for sample_start in range(0, n_samples * frame_sample_len, frame_sample_len)
254
+ ]
255
+
256
+ frames = []
257
+ state = PipelineState()
258
+ prompt = ["A car racing down a snowy mountain road"]
259
+
260
+ block_idx = 0
261
+ generator = torch.Generator("cpu").manual_seed(42)
262
+ for frame_sample in tqdm(frame_samples):
263
+ state = pipe(
264
+ state,
265
+ video_stream=frame_sample,
266
+ prompt=prompt,
267
+ num_inference_steps=6,
268
+ strength=0.3,
269
+ block_idx=block_idx,
270
+ generator=generator,
271
+ )
272
+ frames.extend(state.values["videos"][0])
273
+
274
+ block_idx += 1
275
+
276
+ export_to_video(frames, "output-v2v-streaming.mp4", fps=24)
277
+ ```
278
+ </details>
279
+
280
+ <details>
281
+ <summary>Using LoRAs</summary>
282
+
283
+ ```py
284
+ import torch
285
+ from collections import deque
286
+ from tqdm import tqdm
287
+ from diffusers.utils import export_to_video
288
+ from diffusers import ModularPipeline
289
+ from diffusers.modular_pipelines import PipelineState
290
+
291
+ repo_id = "krea/krea-realtime-video"
292
+ pipe = ModularPipeline.from_pretrained(repo_id, trust_remote_code=True)
293
+ pipe.load_components(
294
+ trust_remote_code=True,
295
+ device_map="cuda",
296
+ torch_dtype={"default": torch.bfloat16, "vae": torch.float16},
297
+ )
298
+ pipe.transformer.load_lora_adapter(
299
+ "shauray/Origami_WanLora",
300
+ prefix="diffusion_model",
301
+ weight_name="origami_000000500.safetensors",
302
+ adapter_name="origami",
303
+ )
304
+ for block in pipe.transformer.blocks:
305
+ block.self_attn.fuse_projections()
306
+
307
+ num_blocks = 9
308
+
309
+ frames = []
310
+ state = PipelineState()
311
+ prompt = ["[origami] a cat sitting on a boat"]
312
+
313
+ generator = torch.Generator("cuda").manual_seed(42)
314
+ for block_idx in tqdm(range(num_blocks)):
315
  state = pipe(
316
  state,
317
  prompt=prompt,
318
  num_inference_steps=6,
319
  num_blocks=num_blocks,
320
+ block_idx=block_idx,
321
+ generator=generator,
322
+ )
323
+ frames.extend(state.values["videos"][0])
324
+
325
+ export_to_video(frames, "output.mp4", fps=24)
326
+ ```
327
+ </details>
328
+
329
+ <details>
330
+ <summary>Optimized Inference</summary>
331
+
332
+ To optimize inference speed and memory usage on Hopper-class GPUs (e.g. H100), we recommend using `torch.compile`, Flash Attention 3, and FP8 quantization with [torchao](https://github.com/pytorch/ao).
333
+
334
+ First, let's set up our dependencies by enabling Flash Attention 3 via Hub [kernels](https://huggingface.co/docs/kernels/en/index) and installing the `torchao` and `kernels` packages.
335
+
336
+ ```shell
337
+ export DIFFUSERS_ENABLE_HUB_KERNELS=true
338
+ pip install -U kernels torchao
339
+ ```
340
+
341
+ Then we will iterate over the blocks of the transformer and apply quantization and `torch.compile`.
342
+
343
+ ```py
344
+ import torch
345
+ from collections import deque
346
+ from tqdm import tqdm
347
+ from diffusers.utils import export_to_video
348
+ from diffusers import ModularPipeline
349
+ from diffusers.modular_pipelines import PipelineState
350
+ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_
351
+
352
+ repo_id = "krea/krea-realtime-video"
353
+ pipe = ModularPipeline.from_pretrained(repo_id, trust_remote_code=True)
354
+ pipe.load_components(
355
+ trust_remote_code=True,
356
+ device_map="cuda",
357
+ torch_dtype={"default": torch.bfloat16, "vae": torch.float16},
358
+ )
359
+
360
+ for block in pipe.transformer.blocks:
361
+ block.self_attn.fuse_projections()
362
+
363
+ # Quantize just the transformer blocks
364
+ for block in pipe.transformer.blocks:
365
+ quantize_(block, Float8DynamicActivationFloat8WeightConfig())
366
+
367
+ # Compile just the attention modules
368
+ for submod in pipe.transformer.modules():
369
+ if submod.__class__.__name__ in ["CausalWanAttentionBlock"]:
370
+ submod.compile(fullgraph=False)
371
+
372
+ num_blocks = 9
373
+
374
+ state = PipelineState()
375
+ prompt = ["a cat sitting on a boat"]
376
+
377
+ # Compile warmup
378
+ for block_idx in range(num_blocks):
379
+ state = pipe(
380
+ state,
381
+ prompt=prompt,
382
+ num_inference_steps=2,
383
+ num_blocks=num_blocks,
384
  block_idx=block_idx,
386
  generator=torch.Generator("cuda").manual_seed(42),
387
  )
388
+
389
+ # Reset state
390
+ state = PipelineState()
+ frames = []
391
+ generator = torch.Generator("cuda").manual_seed(42)
392
+ for block_idx in tqdm(range(num_blocks)):
393
+ state = pipe(
394
+ state,
395
+ prompt=prompt,
396
+ num_inference_steps=6,
397
+ num_blocks=num_blocks,
398
+ block_idx=block_idx,
399
+ generator=generator,
400
+ )
401
  frames.extend(state.values["videos"][0])
402
 
403
+ export_to_video(frames, "output.mp4", fps=24)
404
+ ```
405
+ </details>
before_denoise.py CHANGED
@@ -14,6 +14,7 @@
14
 
15
  import inspect
16
  from typing import List, Optional, Union, Dict
 
17
 
18
  import torch
19
 
@@ -25,6 +26,7 @@ from diffusers.modular_pipelines import (
25
  ModularPipeline,
26
  ModularPipelineBlocks,
27
  SequentialPipelineBlocks,
 
28
  PipelineState,
29
  )
30
  from diffusers.modular_pipelines.modular_pipeline_utils import (
@@ -221,7 +223,7 @@ def _initialize_crossattn_cache(
221
 
222
 
223
  class WanInputStep(ModularPipelineBlocks):
224
- model_name = "WanRT"
225
 
226
  @property
227
  def description(self) -> str:
@@ -237,7 +239,11 @@ class WanInputStep(ModularPipelineBlocks):
237
  @property
238
  def inputs(self) -> List[InputParam]:
239
  return [
240
- InputParam("num_videos_per_prompt", default=1),
241
  InputParam(
242
  "prompt_embeds",
243
  required=True,
@@ -331,8 +337,8 @@ class WanInputStep(ModularPipelineBlocks):
331
  return components, state
332
 
333
 
334
- class WanRTStreamingSetTimestepsStep(ModularPipelineBlocks):
335
- model_name = "WanRT"
336
 
337
  @property
338
  def expected_components(self) -> List[ComponentSpec]:
@@ -350,6 +356,7 @@ class WanRTStreamingSetTimestepsStep(ModularPipelineBlocks):
350
  InputParam("num_inference_steps", default=4),
351
  InputParam("timesteps"),
352
  InputParam("sigmas"),
 
353
  ]
354
 
355
  @property
@@ -391,7 +398,10 @@ class WanRTStreamingSetTimestepsStep(ModularPipelineBlocks):
391
  ]
392
  )
393
  denoising_steps = torch.linspace(
394
- 1000, 0, block_state.num_inference_steps, dtype=torch.float32
395
  ).to(torch.long)
396
 
397
  block_state.timesteps = zero_padded_timesteps[1000 - denoising_steps]
@@ -403,8 +413,8 @@ class WanRTStreamingSetTimestepsStep(ModularPipelineBlocks):
403
  return components, state
404
 
405
 
406
- class WanRTStreamingPrepareLatentsStep(ModularPipelineBlocks):
407
- model_name = "WanRT"
408
 
409
  @property
410
  def expected_components(self) -> List[ComponentSpec]:
@@ -423,15 +433,36 @@ class WanRTStreamingPrepareLatentsStep(ModularPipelineBlocks):
423
  @property
424
  def inputs(self) -> List[InputParam]:
425
  return [
426
- InputParam("height", type_hint=int),
427
- InputParam("width", type_hint=int),
428
- InputParam("num_blocks", type_hint=int),
429
- InputParam("num_frames_per_block", type_hint=int),
430
- InputParam("latents", type_hint=Optional[torch.Tensor]),
431
- InputParam("init_latents", type_hint=Optional[torch.Tensor]),
432
- InputParam("final_latents", type_hint=Optional[torch.Tensor]),
433
- InputParam("num_videos_per_prompt", type_hint=int, default=1),
434
- InputParam("generator"),
435
  InputParam(
436
  "dtype",
437
  type_hint=torch.dtype,
@@ -442,20 +473,11 @@ class WanRTStreamingPrepareLatentsStep(ModularPipelineBlocks):
442
  @property
443
  def intermediate_outputs(self) -> List[OutputParam]:
444
  return [
445
- OutputParam(
446
- "latents",
447
- type_hint=torch.Tensor,
448
- description="The initial latents to use for the denoising process",
449
- ),
450
  OutputParam(
451
  "init_latents",
452
  type_hint=torch.Tensor,
453
  description="The initial latents to use for the denoising process",
454
  ),
455
- OutputParam(
456
- "final_latents",
457
- type_hint=torch.Tensor,
458
- ),
459
  ]
460
 
461
  @staticmethod
@@ -476,8 +498,8 @@ class WanRTStreamingPrepareLatentsStep(ModularPipelineBlocks):
476
  components,
477
  batch_size: int,
478
  num_channels_latents: int = 16,
479
- height: int = 352,
480
- width: int = 640,
481
  num_blocks: int = 9,
482
  num_frames_per_block: int = 3,
483
  dtype: Optional[torch.dtype] = None,
@@ -536,56 +558,398 @@ class WanRTStreamingPrepareLatentsStep(ModularPipelineBlocks):
536
  block_state.generator,
537
  block_state.init_latents,
538
  )
539
- if block_state.final_latents is None:
540
- block_state.final_latents = torch.zeros_like(
541
- block_state.init_latents, device=components.transformer.device
542
- )
543
  self.set_block_state(state, block_state)
544
 
545
  return components, state
546
 
547
 
548
- class WanRTStreamingExtractBlockLatentsStep(ModularPipelineBlocks):
549
  """
550
- Extracts a single block of latents from the full video buffer for streaming generation.
551
 
552
- This block simply slices the final_latents buffer to get the current block's latents.
553
- The final_latents buffer should be created beforehand using WanRTStreamingPrepareAllLatents.
 
 
554
  """
555
 
556
- model_name = "WanRT"
557
 
558
  @property
559
  def expected_components(self) -> List[ComponentSpec]:
560
- return []
 
 
561
 
562
  @property
563
  def description(self) -> str:
564
  return (
565
- "Extracts a single block from the full latent buffer for streaming generation. "
566
- "Slices final_latents based on block_idx to get current block's latents."
567
  )
568
 
569
  @property
570
  def inputs(self) -> List[InputParam]:
571
  return [
572
  InputParam(
573
- "final_latents",
574
- required=True,
575
  type_hint=torch.Tensor,
576
- description="Full latent buffer [B, C, total_frames, H, W]",
577
  ),
578
  InputParam(
579
  "init_latents",
580
- required=True,
581
  type_hint=torch.Tensor,
582
- description="Full latent buffer [B, C, total_frames, H, W]",
583
  ),
584
  InputParam(
585
  "latents",
586
  type_hint=torch.Tensor,
587
- description="Full latent buffer [B, C, total_frames, H, W]",
588
  ),
589
  InputParam(
590
  "block_idx",
591
  required=True,
@@ -593,6 +957,12 @@ class WanRTStreamingExtractBlockLatentsStep(ModularPipelineBlocks):
593
  default=0,
594
  description="Current block index to process",
595
  ),
596
  InputParam(
597
  "num_frames_per_block",
598
  required=True,
@@ -623,7 +993,7 @@ class WanRTStreamingExtractBlockLatentsStep(ModularPipelineBlocks):
623
  ) -> PipelineState:
624
  block_state = self.get_block_state(state)
625
 
626
- num_frames_per_block = block_state.num_frames_per_block
627
  block_idx = block_state.block_idx
628
 
629
  # Calculate frame range for current block
@@ -642,7 +1012,7 @@ class WanRTStreamingExtractBlockLatentsStep(ModularPipelineBlocks):
642
  return components, state
643
 
644
 
645
- class WanRTStreamingSetupKVCache(ModularPipelineBlocks):
646
  """
647
  Initializes KV cache and cross-attention cache for streaming generation.
648
 
@@ -651,7 +1021,7 @@ class WanRTStreamingSetupKVCache(ModularPipelineBlocks):
651
  Should be called once at the start of streaming generation.
652
  """
653
 
654
- model_name = "WanRT"
655
 
656
  @property
657
  def expected_components(self) -> List[ComponentSpec]:
@@ -772,7 +1142,7 @@ class WanRTStreamingSetupKVCache(ModularPipelineBlocks):
772
  return components, state
773
 
774
 
775
- class WanRTStreamingRecomputeKVCache(ModularPipelineBlocks):
776
  @property
777
  def inputs(self) -> List[InputParam]:
778
  return [
@@ -782,34 +1152,20 @@ class WanRTStreamingRecomputeKVCache(ModularPipelineBlocks):
782
  description="Current block latents [B, C, num_frames_per_block, H, W]",
783
  ),
784
  InputParam(
785
- "num_frames_per_block",
786
  type_hint=int,
787
- description="Number of frames per block",
788
  ),
789
  InputParam(
790
  "block_idx",
791
  type_hint=int,
792
  description="Current block index to process",
793
  ),
794
- InputParam(
795
- "block_mask",
796
- description="Block-wise causal attention mask",
797
- ),
798
  InputParam(
799
  "current_start_frame",
800
  type_hint=int,
801
  description="Starting frame index for current block",
802
  ),
803
- InputParam(
804
- "videos",
805
- type_hint=torch.Tensor,
806
- description="Video frames for context encoding",
807
- ),
808
- InputParam(
809
- "final_latents",
810
- type_hint=torch.Tensor,
811
- description="Full latent buffer [B, C, total_frames, H, W]",
812
- ),
813
  InputParam(
814
  "prompt_embeds",
815
  type_hint=torch.Tensor,
@@ -825,16 +1181,14 @@ class WanRTStreamingRecomputeKVCache(ModularPipelineBlocks):
825
  type_hint=torch.Tensor,
826
  description="Cross-attention cache",
827
  ),
828
- InputParam(
829
- "encoder_cache",
830
- description="Encoder feature cache",
831
- ),
832
  InputParam(
833
  "frame_cache_context",
834
  description="Cached context frames for reencoding",
835
  ),
836
  InputParam(
837
- "local_attn_size",
 
 
838
  ),
839
  ]
840
 
@@ -842,9 +1196,7 @@ class WanRTStreamingRecomputeKVCache(ModularPipelineBlocks):
842
  def expected_configs(self) -> List[ConfigSpec]:
843
  return [ConfigSpec("seq_length", 32760)]
844
 
845
- def prepare_latents(self, components, block_state):
846
- frames = block_state.frame_cache_context[0].half()
847
-
848
  components.vae._enc_feat_map = [None] * 55
849
  latents = retrieve_latents(components.vae.encode(frames), sample_mode="argmax")
850
  latents_mean = (
@@ -861,30 +1213,23 @@ class WanRTStreamingRecomputeKVCache(ModularPipelineBlocks):
861
 
862
  def get_context_frames(self, components, block_state):
863
  current_kv_cache_num_frames = components.config.kv_cache_num_frames
864
- context_frames = block_state.final_latents[
865
- :, :, : block_state.current_start_frame
866
- ]
867
-
868
- if (
869
  block_state.block_idx - 1
870
- ) * block_state.num_frames_per_block < current_kv_cache_num_frames:
871
- if current_kv_cache_num_frames == 1:
872
- context_frames = context_frames[:, :, :1]
873
- else:
874
- context_frames = torch.cat(
875
- (
876
- context_frames[:, :, :1],
877
- context_frames[:, :, 1:][
878
- :, :, -current_kv_cache_num_frames + 1 :
879
- ],
880
- ),
881
- dim=2,
882
- )
883
  else:
 
884
  context_frames = context_frames[:, :, 1:][
885
  :, :, -current_kv_cache_num_frames + 1 :
886
  ]
887
- first_frame_latent = self.prepare_latents(components, block_state)
 
 
888
  first_frame_latent = first_frame_latent.to(block_state.latents)
889
  context_frames = torch.cat((first_frame_latent, context_frames), dim=2)
890
 
@@ -895,20 +1240,15 @@ class WanRTStreamingRecomputeKVCache(ModularPipelineBlocks):
895
  if block_state.block_idx == 0:
896
  return components, state
897
 
898
- start_frame = min(
899
- block_state.current_start_frame, components.config.kv_cache_num_frames
900
- )
901
  context_frames = self.get_context_frames(components, block_state)
902
- block_state.block_mask = (
903
- components.transformer._prepare_blockwise_causal_attn_mask(
904
- components.transformer.device,
905
- num_frames=context_frames.shape[2],
906
- frame_seqlen=components.config.frame_seq_length,
907
- num_frame_per_block=block_state.num_frames_per_block,
908
- local_attn_size=-1,
909
- )
910
  )
911
- components.transformer.block_mask = block_state.block_mask
912
  context_timestep = torch.zeros(
913
  (context_frames.shape[0], context_frames.shape[2]),
914
  device=components.transformer.device,
@@ -921,7 +1261,7 @@ class WanRTStreamingRecomputeKVCache(ModularPipelineBlocks):
921
  kv_cache=block_state.kv_cache,
922
  seq_len=components.config.seq_length,
923
  crossattn_cache=block_state.crossattn_cache,
924
- current_start=start_frame * components.config.frame_seq_length,
925
  cache_start=None,
926
  )
927
  components.transformer.block_mask = None
@@ -929,13 +1269,13 @@ class WanRTStreamingRecomputeKVCache(ModularPipelineBlocks):
929
  return components, state
930
 
931
 
932
- class WanRTStreamingBeforeDenoiseStep(SequentialPipelineBlocks):
933
  block_classes = [
934
- WanRTStreamingSetTimestepsStep,
935
- WanRTStreamingPrepareLatentsStep,
936
- WanRTStreamingExtractBlockLatentsStep,
937
- WanRTStreamingSetupKVCache,
938
- WanRTStreamingRecomputeKVCache,
939
  ]
940
  block_names = [
941
  "set_timesteps",
@@ -953,4 +1293,69 @@ class WanRTStreamingBeforeDenoiseStep(SequentialPipelineBlocks):
953
  + " - `WanRTInputStep` is used to adjust the batch size of the model inputs\n"
954
  + " - `WanRTSetTimestepsStep` is used to set the timesteps\n"
955
  + " - `WanRTPrepareLatentsStep` is used to prepare the latents\n"
956
  )
14
 
15
  import inspect
16
  from typing import List, Optional, Union, Dict
17
+ from collections import deque
18
 
19
  import torch
20
 
 
26
  ModularPipeline,
27
  ModularPipelineBlocks,
28
  SequentialPipelineBlocks,
29
+ AutoPipelineBlocks,
30
  PipelineState,
31
  )
32
  from diffusers.modular_pipelines.modular_pipeline_utils import (
 
223
 
224
 
225
  class WanInputStep(ModularPipelineBlocks):
226
+ model_name = "wan"
227
 
228
  @property
229
  def description(self) -> str:
 
239
  @property
240
  def inputs(self) -> List[InputParam]:
241
  return [
242
+ InputParam(
243
+ "num_videos_per_prompt",
244
+ default=1,
245
+ description="Number of videos to generate per prompt",
246
+ ),
247
  InputParam(
248
  "prompt_embeds",
249
  required=True,
 
337
  return components, state
338
 
339
 
340
+ class WanRTSetTimestepsStep(ModularPipelineBlocks):
341
+ model_name = "wan"
342
 
343
  @property
344
  def expected_components(self) -> List[ComponentSpec]:
 
356
  InputParam("num_inference_steps", default=4),
357
  InputParam("timesteps"),
358
  InputParam("sigmas"),
359
+ InputParam("strength", default=1.0),
360
  ]
361
 
362
  @property
 
398
  ]
399
  )
400
  denoising_steps = torch.linspace(
401
+ block_state.strength * 1000,
402
+ 0,
403
+ block_state.num_inference_steps,
404
+ dtype=torch.float32,
405
  ).to(torch.long)
406
 
407
  block_state.timesteps = zero_padded_timesteps[1000 - denoising_steps]
 
413
  return components, state
414
 
415
 
416
+ class WanRTPrepareLatentsStep(ModularPipelineBlocks):
417
+ model_name = "wan"
418
 
419
  @property
420
  def expected_components(self) -> List[ComponentSpec]:
 
433
  @property
434
  def inputs(self) -> List[InputParam]:
435
  return [
436
+ InputParam(
437
+ "height",
438
+ type_hint=int,
439
+ description="Height of the video to generate in pixels",
440
+ ),
441
+ InputParam(
442
+ "width",
443
+ type_hint=int,
444
+ description="Width of the video to generate in pixels",
445
+ ),
446
+ InputParam(
447
+ "num_blocks",
448
+ type_hint=int,
449
+ description="Number of temporal blocks to generate",
450
+ ),
451
+ InputParam(
452
+ "init_latents",
453
+ type_hint=Optional[torch.Tensor],
454
+ description="Pre-initialized latents to use instead of random noise",
455
+ ),
456
+ InputParam(
457
+ "num_videos_per_prompt",
458
+ type_hint=int,
459
+ default=1,
460
+ description="Number of videos to generate per prompt",
461
+ ),
462
+ InputParam(
463
+ "generator",
464
+ description="Random number generator for reproducible generation",
465
+ ),
466
  InputParam(
467
  "dtype",
468
  type_hint=torch.dtype,
 
473
  @property
474
  def intermediate_outputs(self) -> List[OutputParam]:
475
  return [
476
  OutputParam(
477
  "init_latents",
478
  type_hint=torch.Tensor,
479
  description="The initial latents to use for the denoising process",
480
  ),
481
  ]
482
 
483
  @staticmethod
 
498
  components,
499
  batch_size: int,
500
  num_channels_latents: int = 16,
501
+ height: int = 480,
502
+ width: int = 832,
503
  num_blocks: int = 9,
504
  num_frames_per_block: int = 3,
505
  dtype: Optional[torch.dtype] = None,
 
558
  block_state.generator,
559
  block_state.init_latents,
560
  )
561
+ block_state.init_latents = block_state.init_latents.contiguous()
562
  self.set_block_state(state, block_state)
563
 
564
  return components, state
565
 
566
 
567
+ class WanRTPrepareVideoLatentStep(ModularPipelineBlocks):
568
  """
569
+ Prepares video latents from input PIL images for video-to-video generation.
570
 
571
+ This block:
572
+ 1. Processes input PIL images
573
+ 2. Encodes them to latent space using the VAE encoder
574
+ 3. Adds noise based on denoising strength for partial denoising
575
  """
576
 
577
+ model_name = "wan"
578
 
579
  @property
580
  def expected_components(self) -> List[ComponentSpec]:
581
+ return [
582
+ ComponentSpec("vae", AutoencoderKLWan),
583
+ ]
584
 
585
  @property
586
  def description(self) -> str:
587
  return (
588
+ "Prepares video latents from input PIL images by encoding to latent space "
589
+ "and optionally adding noise for video-to-video generation."
590
  )
591
 
592
  @property
593
  def inputs(self) -> List[InputParam]:
594
  return [
595
  InputParam(
596
+ "video",
597
+ type_hint=list,
598
+ description="List of PIL Images for input video",
599
+ ),
600
+ InputParam(
601
+ "height",
602
+ type_hint=int,
603
+ default=480,
604
+ description="Target height for video processing",
605
+ ),
606
+ InputParam(
607
+ "width",
608
+ type_hint=int,
609
+ default=832,
610
+ description="Target width for video processing",
611
+ ),
612
+ InputParam(
613
+ "strength",
614
+ type_hint=float,
615
+ default=1.0,
616
+ description="Denoising strength (0-1). Lower values preserve more of original video.",
617
+ ),
618
+ InputParam(
619
+ "generator",
620
+ description="Random generator for noise",
621
+ ),
622
+ InputParam(
623
+ "timesteps",
624
  type_hint=torch.Tensor,
625
+ description="All timesteps for noise scheduling",
626
+ ),
627
+ InputParam(
628
+ "num_blocks",
629
+ type_hint=int,
630
+ description="Number of blocks for generation",
631
  ),
632
  InputParam(
633
  "init_latents",
 
634
  type_hint=torch.Tensor,
635
+ ),
636
+ ]
637
+
638
+ @property
639
+ def intermediate_outputs(self) -> List[OutputParam]:
640
+ return [
641
+ OutputParam(
642
+ "init_latents",
643
+ type_hint=torch.Tensor,
644
+ description="Noised latents from input video ready for denoising",
645
+ ),
646
+ OutputParam(
647
+ "num_blocks",
648
+ type_hint=int,
649
+ description="Updated number of blocks based on video length",
650
+ ),
651
+ ]
652
+
653
+ def encode_frames(
654
+ self,
655
+ components,
656
+ video: Optional[torch.Tensor] = None,
657
+ timesteps: Optional[torch.Tensor] = None,
658
+ generator: Optional[torch.Generator] = None,
659
+ dtype: Optional[torch.dtype] = None,
660
+ device: Optional[torch.device] = None,
661
+ latents: Optional[torch.Tensor] = None,
662
+ ):
663
+ if latents is not None:
664
+ return latents.to(device, dtype)
665
+
666
+ if not hasattr(components.vae, "_enc_feat_map"):
667
+ components.vae.clear_cache()
668
+ else:
669
+ components.vae._enc_feat_map = [None] * 55
670
+
671
+ init_latents = [
672
+ retrieve_latents(
673
+ components.vae.encode(vid.unsqueeze(0).transpose(2, 1)),
674
+ sample_mode="argmax",
675
+ )
676
+ for vid in video
677
+ ]
678
+ init_latents = torch.cat(init_latents, dim=0).to(dtype)
679
+
680
+ latents_mean = (
681
+ torch.tensor(components.vae.config.latents_mean)
682
+ .view(1, components.vae.config.z_dim, 1, 1, 1)
683
+ .to(device, dtype)
684
+ )
685
+ latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
686
+ 1, components.vae.config.z_dim, 1, 1, 1
687
+ ).to(device, dtype)
688
+ init_latents = (init_latents - latents_mean) * latents_std
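+ # Blend the encoded video latents with noise: the first timestep tracks `strength`, so lower strength preserves more of the input video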
689
+ init_denoising_strength = timesteps[0] / 1000.0
690
+
691
+ # Add noise to latents
692
+ noise = randn_tensor(
693
+ init_latents.shape,
694
+ device=init_latents.device,
695
+ dtype=init_latents.dtype,
696
+ generator=generator,
697
+ )
698
+ init_latents = (
699
+ init_latents * (1.0 - init_denoising_strength)
700
+ + noise * init_denoising_strength
701
+ )
702
+ init_latents = init_latents.to(components.transformer.dtype).contiguous()
703
+
704
+ return init_latents
705
+
706
+ @torch.no_grad()
707
+ def __call__(
708
+ self, components: ModularPipeline, state: PipelineState
709
+ ) -> PipelineState:
710
+ block_state = self.get_block_state(state)
711
+
712
+ if block_state.init_latents is not None:
713
+ block_state.init_latents = block_state.init_latents.to(
714
+ components.transformer.dtype
715
+ )
716
+ self.set_block_state(state, block_state)
717
+ return components, state
718
+
719
+ video = (
720
+ components.video_processor.preprocess(
721
+ block_state.video, block_state.height, block_state.width
722
+ )
723
+ .unsqueeze(0)
724
+ .to(components.vae.device, components.vae.dtype)
725
+ )
726
+ block_state.init_latents = self.encode_frames(
727
+ components,
728
+ video,
729
+ block_state.timesteps,
730
+ block_state.generator,
731
+ components.vae.dtype,
732
+ components.vae.device,
733
+ block_state.init_latents,
734
+ )
735
+ block_state.init_latents = block_state.init_latents.to(
736
+ components.transformer.dtype
737
+ )
738
+
739
+ self.set_block_state(state, block_state)
740
+ return components, state
741
+
742
+
743
+ class WanRTStreamPrepareVideoLatentStep(ModularPipelineBlocks):
744
+ """
745
+ Prepares video latents from input PIL images for video-to-video generation.
746
+
747
+ This block:
748
+ 1. Processes input PIL images
749
+ 2. Encodes them to latent space using the VAE encoder
750
+ 3. Adds noise based on denoising strength for partial denoising
751
+ """
752
+
753
+ model_name = "wan"
754
+
755
+ @property
756
+ def expected_components(self) -> List[ComponentSpec]:
757
+ return [
758
+ ComponentSpec("vae", AutoencoderKLWan),
759
+ ]
760
+
761
+ @property
762
+ def description(self) -> str:
763
+ return (
764
+ "Prepares video latents from input PIL images by encoding to latent space "
765
+ "and optionally adding noise for video-to-video generation."
766
+ )
767
+
768
+ @property
769
+ def inputs(self) -> List[InputParam]:
770
+ return [
771
+ InputParam(
772
+ "video_stream",
773
+ type_hint=list,
774
+ description="List of PIL Images for input video",
775
+ ),
776
+ InputParam(
777
+ "height",
778
+ type_hint=int,
779
+ default=480,
780
+ description="Target height for video processing",
781
  ),
782
  InputParam(
783
+ "width",
784
+ type_hint=int,
785
+ default=832,
786
+ description="Target width for video processing",
787
+ ),
788
+ InputParam(
789
+ "generator",
790
+ type_hint=torch.Generator,
791
+ description="Random generator for noise",
792
+ ),
793
+ InputParam(
794
+ "timesteps",
795
+ type_hint=torch.Tensor,
796
+ description="All timesteps for noise scheduling",
797
+ ),
798
+ InputParam(
799
+ "block_idx",
800
+ type_hint=int,
801
+ description="Index of current block to denoise",
802
+ ),
803
+ InputParam(
804
+ "num_blocks",
805
+ type_hint=int,
806
+ description="Total number of blocks to denoise",
807
+ ),
808
+ InputParam(
809
+ "input_frames_cache",
810
+ default=deque(maxlen=24),
811
+ description="Cached input video frames for context encoding",
812
+ ),
813
+ ]
814
+
815
+ @property
816
+ def intermediate_outputs(self) -> List[OutputParam]:
817
+ return [
818
+ OutputParam(
819
  "latents",
820
  type_hint=torch.Tensor,
821
+ description="Noised latents from input video ready for denoising",
822
  ),
823
+ OutputParam(
824
+ "current_start_frame",
825
+ type_hint=int,
826
+ ),
827
+ ]
828
+
829
+ def encode_frames(
830
+ self,
831
+ components,
832
+ video: Optional[torch.Tensor] = None,
833
+ dtype: Optional[torch.dtype] = None,
834
+ device: Optional[torch.device] = None,
835
+ latents: Optional[torch.Tensor] = None,
836
+ ):
837
+ if latents is not None:
838
+ return latents.to(device, dtype)
839
+
840
+ if not hasattr(components.vae, "_enc_feat_map"):
841
+ components.vae.clear_cache()
842
+ else:
843
+ components.vae._enc_feat_map = [None] * 55
844
+
845
+ init_latents = [
846
+ retrieve_latents(
847
+ components.vae.encode(vid.unsqueeze(0).transpose(2, 1)),
848
+ sample_mode="argmax",
849
+ )
850
+ for vid in video
851
+ ]
852
+ init_latents = torch.cat(init_latents, dim=0).to(dtype)
853
+
854
+ latents_mean = (
855
+ torch.tensor(components.vae.config.latents_mean)
856
+ .view(1, components.vae.config.z_dim, 1, 1, 1)
857
+ .to(device, dtype)
858
+ )
859
+ latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
860
+ 1, components.vae.config.z_dim, 1, 1, 1
861
+ ).to(device, dtype)
862
+ init_latents = (init_latents - latents_mean) * latents_std
863
+
864
+ return init_latents
865
+
866
+ def resample_frames(self, frames, target_length):
867
+ """Resample a list to the target length using linear interpolation of indices"""
868
+ if len(frames) == target_length:
869
+ return frames
870
+
871
+ indices = (
872
+ torch.linspace(0, len(frames) - 1, target_length, device="cpu")
873
+ .round()
874
+ .long()
875
+ )
876
+ return [frames[i] for i in indices]
877
+
878
+ @torch.no_grad()
879
+ def __call__(
880
+ self, components: ModularPipeline, state: PipelineState
881
+ ) -> PipelineState:
882
+ block_state = self.get_block_state(state)
883
+
884
+ if block_state.video_stream is None:
885
+ raise ValueError(
886
+ "Streaming video-to-video requires an input video. Please provide a `video_stream` input to the pipeline."
887
+ )
888
+
889
+ block_state.input_frames_cache.extend(block_state.video_stream)
890
+ video = (
891
+ components.video_processor.preprocess(
892
+ list(block_state.input_frames_cache),
893
+ block_state.height,
894
+ block_state.width,
895
+ )
896
+ .unsqueeze(0)
897
+ .to(components.vae.device, components.vae.dtype)
898
+ )
899
+
900
+ block_state.current_start_frame = (
901
+ block_state.block_idx * components.config.num_frames_per_block
902
+ )
903
+ init_latents = self.encode_frames(
904
+ components,
905
+ video,
906
+ components.vae.dtype,
907
+ components.vae.device,
908
+ None,
909
+ )
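+ # Keep only the latent frames for the newest chunk; earlier context is carried by the KV cache and frame cache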
910
+ init_latents = init_latents[:, :, -components.config.num_frames_per_block :]
911
+
912
+ strength = block_state.timesteps[0] / 1000.0
913
+ noise = randn_tensor(
914
+ init_latents.shape,
915
+ device=components.transformer.device,
916
+ dtype=components.transformer.dtype,
917
+ generator=block_state.generator,
918
+ )
919
+
920
+ init_latents = init_latents * (1.0 - strength) + noise * strength
921
+ init_latents = init_latents.to(components.transformer.dtype).contiguous()
922
+
923
+ block_state.latents = init_latents
924
+
925
+ self.set_block_state(state, block_state)
926
+ return components, state
927
+
928
+
929
+ class WanRTExtractBlockLatentsStep(ModularPipelineBlocks):
930
+ """
931
+ Extracts a single block of latents from the full video buffer for streaming generation.
932
+
933
+ This block simply slices the final_latents buffer to get the current block's latents.
934
+ The final_latents buffer should be created beforehand using WanRTPrepareAllLatents.
935
+ """
936
+
937
+ model_name = "wan"
938
+
939
+ @property
940
+ def expected_components(self) -> List[ComponentSpec]:
941
+ return []
942
+
943
+ @property
944
+ def description(self) -> str:
945
+ return (
946
+ "Extracts a single block from the full latent buffer for streaming generation. "
947
+ "Slices final_latents based on block_idx to get current block's latents."
948
+ )
949
+
950
+ @property
951
+ def inputs(self) -> List[InputParam]:
952
+ return [
953
  InputParam(
954
  "block_idx",
955
  required=True,
 
957
  default=0,
958
  description="Current block index to process",
959
  ),
960
+ InputParam(
961
+ "init_latents",
962
+ required=True,
963
+ type_hint=torch.Tensor,
964
+ description="Full latent buffer [B, C, total_frames, H, W]",
965
+ ),
966
  InputParam(
967
  "num_frames_per_block",
968
  required=True,
 
993
  ) -> PipelineState:
994
  block_state = self.get_block_state(state)
995
 
996
+ num_frames_per_block = components.config.num_frames_per_block
997
  block_idx = block_state.block_idx
998
 
999
  # Calculate frame range for current block
 
1012
  return components, state
1013
 
1014
 
1015
+ class WanRTSetupKVCache(ModularPipelineBlocks):
1016
  """
1017
  Initializes KV cache and cross-attention cache for streaming generation.
1018
 
 
1021
  Should be called once at the start of streaming generation.
1022
  """
1023
 
1024
+ model_name = "wan"
1025
 
1026
  @property
1027
  def expected_components(self) -> List[ComponentSpec]:
 
1142
  return components, state
1143
 
1144
 
1145
+ class WanRTRecomputeKVCache(ModularPipelineBlocks):
1146
  @property
1147
  def inputs(self) -> List[InputParam]:
1148
  return [
 
1152
  description="Current block latents [B, C, num_frames_per_block, H, W]",
1153
  ),
1154
  InputParam(
1155
+ "num_blocks",
1156
  type_hint=int,
1157
+ description="Number of blocks to denoise",
1158
  ),
1159
  InputParam(
1160
  "block_idx",
1161
  type_hint=int,
1162
  description="Current block index to process",
1163
  ),
 
1164
  InputParam(
1165
  "current_start_frame",
1166
  type_hint=int,
1167
  description="Starting frame index for current block",
1168
  ),
1169
  InputParam(
1170
  "prompt_embeds",
1171
  type_hint=torch.Tensor,
 
1181
  type_hint=torch.Tensor,
1182
  description="Cross-attention cache",
1183
  ),
1184
  InputParam(
1185
  "frame_cache_context",
1186
  description="Cached context frames for reencoding",
1187
  ),
1188
  InputParam(
1189
+ "current_denoised_latents",
1190
+ type_hint=torch.Tensor,
1191
+ description="Current denoised latents",
1192
  ),
1193
  ]
1194
 
 
1196
  def expected_configs(self) -> List[ConfigSpec]:
1197
  return [ConfigSpec("seq_length", 32760)]
1198
 
1199
+ def prepare_latents(self, components, frames):
 
 
1200
  components.vae._enc_feat_map = [None] * 55
1201
  latents = retrieve_latents(components.vae.encode(frames), sample_mode="argmax")
1202
  latents_mean = (
 
1213
 
1214
  def get_context_frames(self, components, block_state):
1215
  current_kv_cache_num_frames = components.config.kv_cache_num_frames
1216
+ total_frames_generated = (
1217
  block_state.block_idx - 1
1218
+ ) * components.config.num_frames_per_block
1219
+
1220
+ if total_frames_generated < current_kv_cache_num_frames:
1221
+ context_frames = block_state.current_denoised_latents[
1222
+ :, :, :current_kv_cache_num_frames
1223
+ ]
1224
+
1225
  else:
1226
+ context_frames = block_state.current_denoised_latents
1227
  context_frames = context_frames[:, :, 1:][
1228
  :, :, -current_kv_cache_num_frames + 1 :
1229
  ]
1230
+ first_frame_latent = self.prepare_latents(
1231
+ components, frames=block_state.frame_cache_context[0].half()
1232
+ )
1233
  first_frame_latent = first_frame_latent.to(block_state.latents)
1234
  context_frames = torch.cat((first_frame_latent, context_frames), dim=2)
1235
 
 
1240
  if block_state.block_idx == 0:
1241
  return components, state
1242
 
1243
  context_frames = self.get_context_frames(components, block_state)
1244
+ block_mask = components.transformer._prepare_blockwise_causal_attn_mask(
1245
+ components.transformer.device,
1246
+ num_frames=context_frames.shape[2],
1247
+ frame_seqlen=components.config.frame_seq_length,
1248
+ num_frame_per_block=components.config.num_frames_per_block,
1249
+ local_attn_size=-1,
 
 
1250
  )
1251
+ components.transformer.block_mask = block_mask
1252
  context_timestep = torch.zeros(
1253
  (context_frames.shape[0], context_frames.shape[2]),
1254
  device=components.transformer.device,
 
1261
  kv_cache=block_state.kv_cache,
1262
  seq_len=components.config.seq_length,
1263
  crossattn_cache=block_state.crossattn_cache,
1264
+ current_start=0, # when updating the kv cache with block_mask the current_start is unused
1265
  cache_start=None,
1266
  )
1267
  components.transformer.block_mask = None
 
1269
  return components, state
1270
 
1271
 
1272
+ class WanRTBeforeDenoiseStep(SequentialPipelineBlocks):
1273
  block_classes = [
1274
+ WanRTSetTimestepsStep,
1275
+ WanRTPrepareLatentsStep,
1276
+ WanRTExtractBlockLatentsStep,
1277
+ WanRTSetupKVCache,
1278
+ WanRTRecomputeKVCache,
1279
  ]
1280
  block_names = [
1281
  "set_timesteps",
 
1293
  + " - `WanRTInputStep` is used to adjust the batch size of the model inputs\n"
1294
  + " - `WanRTSetTimestepsStep` is used to set the timesteps\n"
1295
  + " - `WanRTPrepareLatentsStep` is used to prepare the latents\n"
1296
+ + " - `WanRTPrepareVideoLatentStep` is used to prepare video latents from input video\n"
1297
+ )
1298
+
1299
+
1300
+ class WanRTVideoToVideoBeforeDenoiseStep(SequentialPipelineBlocks):
1301
+ block_classes = [
1302
+ WanRTSetTimestepsStep,
1303
+ WanRTPrepareVideoLatentStep,
1304
+ WanRTExtractBlockLatentsStep,
1305
+ WanRTSetupKVCache,
1306
+ WanRTRecomputeKVCache,
1307
+ ]
1308
+ block_names = [
1309
+ "set_timesteps",
1310
+ "prepare_video_latents",
1311
+ "extract_block_init_latents",
1312
+ "setup_kv_cache",
1313
+ "recompute_kv_cache",
1314
+ ]
1315
+
1316
+ @property
1317
+ def description(self):
1318
+ return (
1319
+ "Before denoise step that prepares the inputs for the denoise step.\n"
1320
+ + "This is a sequential pipeline block:\n"
1321
+ + " - `WanRTInputStep` is used to adjust the batch size of the model inputs\n"
1322
+ + " - `WanRTSetTimestepsStep` is used to set the timesteps\n"
1323
+ + " - `WanRTPrepareLatentsStep` is used to prepare the latents\n"
1324
+ + " - `WanRTPrepareVideoLatentStep` is used to prepare video latents from input video\n"
1325
  )
1326
+
1327
+
1328
+ class WanRTStreamVideoToVideoBeforeDenoiseStep(SequentialPipelineBlocks):
1329
+ block_classes = [
1330
+ WanRTSetTimestepsStep,
1331
+ WanRTStreamPrepareVideoLatentStep,
1332
+ WanRTSetupKVCache,
1333
+ WanRTRecomputeKVCache,
1334
+ ]
1335
+ block_names = [
1336
+ "set_timesteps",
1337
+ "prepare_video_latents",
1338
+ "setup_kv_cache",
1339
+ "recompute_kv_cache",
1340
+ ]
1341
+
1342
+ @property
1343
+ def description(self):
1344
+ return (
1345
+ "Before denoise step that prepares the inputs for the denoise step.\n"
1346
+ + "This is a sequential pipeline block:\n"
1347
+ + " - `WanRTInputStep` is used to adjust the batch size of the model inputs\n"
1348
+ + " - `WanRTSetTimestepsStep` is used to set the timesteps\n"
1349
+ + " - `WanRTPrepareLatentsStep` is used to prepare the latents\n"
1350
+ + " - `WanRTPrepareVideoLatentStep` is used to prepare video latents from input video\n"
1351
+ )
1352
+
1353
+
1354
+ class WanRTAutoBeforeDenoiseStep(AutoPipelineBlocks):
1355
+ block_classes = [
1356
+ WanRTVideoToVideoBeforeDenoiseStep,
1357
+ WanRTStreamVideoToVideoBeforeDenoiseStep,
1358
+ WanRTBeforeDenoiseStep,
1359
+ ]
1360
+ block_names = ["video-to-video", "stream-to-video", "text-to-video"]
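+ # AutoPipelineBlocks dispatch: a `video` input selects video-to-video, `video_stream` selects streaming video-to-video, and neither falls back to text-to-video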
1361
+ block_trigger_inputs = ["video", "video_stream", None]
decoders.py CHANGED
@@ -13,6 +13,7 @@
13
  # limitations under the License.
14
 
15
  from typing import Any, List, Tuple, Union
 
16
 
17
  import numpy as np
18
  import PIL
@@ -35,7 +36,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
 
36
 
37
  class WanRTDecodeStep(ModularPipelineBlocks):
38
- model_name = "WanRT"
39
  decoder_cache = []
40
 
41
  @property
@@ -62,7 +63,15 @@ class WanRTDecodeStep(ModularPipelineBlocks):
62
  @property
63
  def inputs(self) -> List[Tuple[str, Any]]:
64
  return [
65
- InputParam("output_type", default="pil"),
66
  InputParam(
67
  "latents",
68
  required=True,
@@ -71,15 +80,12 @@ class WanRTDecodeStep(ModularPipelineBlocks):
71
  ),
72
  InputParam(
73
  "frame_cache_context",
74
- description="The denoised latents from the denoising step",
75
- ),
76
- InputParam(
77
- "block_idx",
78
- description="The denoised latents from the denoising step",
79
  ),
80
  InputParam(
81
  "decoder_cache",
82
- description="The denoised latents from the denoising step",
83
  ),
84
  ]
85
 
@@ -100,6 +106,10 @@ class WanRTDecodeStep(ModularPipelineBlocks):
100
  block_state = self.get_block_state(state)
101
  vae_dtype = components.vae.dtype
102
 
 
  # Disable clearing cache
104
  if block_state.block_idx == 0:
105
  components.vae.clear_cache()
@@ -134,12 +144,10 @@ class WanRTDecodeStep(ModularPipelineBlocks):
134
 
135
  block_state.decoder_cache = components.vae._feat_map
136
  block_state.frame_cache_context.extend(videos.split(1, dim=2))
137
-
138
  videos = components.video_processor.postprocess_video(
139
  videos, output_type=block_state.output_type
140
  )
141
  block_state.videos = videos
142
-
143
  self.set_block_state(state, block_state)
144
 
145
  return components, state
 
13
  # limitations under the License.
14
 
15
  from typing import Any, List, Tuple, Union
16
+ from collections import deque
17
 
18
  import numpy as np
19
  import PIL
 
36
 
37
 
38
  class WanRTDecodeStep(ModularPipelineBlocks):
39
+ model_name = "wan"
40
  decoder_cache = []
41
 
42
  @property
 
63
  @property
64
  def inputs(self) -> List[Tuple[str, Any]]:
65
  return [
66
+ InputParam(
67
+ "output_type",
68
+ default="pil",
69
+ description="The output format for the generated videos (pil, latent, pt, or np)",
70
+ ),
71
+ InputParam(
72
+ "block_idx",
73
+ description="Index of the current block being decoded",
74
+ ),
75
  InputParam(
76
  "latents",
77
  required=True,
 
80
  ),
81
  InputParam(
82
  "frame_cache_context",
83
+ description="Deque object to store most recently decoded frames",
84
+ type_hint=deque
85
  ),
86
  InputParam(
87
  "decoder_cache",
88
+ description="Decoder feature cache",
89
  ),
90
  ]
91
 
 
106
  block_state = self.get_block_state(state)
107
  vae_dtype = components.vae.dtype
108
 
109
+ if block_state.frame_cache_context is None:
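+ # Each latent frame after the first decodes to 4 pixel frames (Wan VAE 4x temporal compression), so cache 1 + (k - 1) * 4 frames for k cached latent frames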
110
+ frame_cache_len = 1 + (components.config.kv_cache_num_frames - 1) * 4
111
+ block_state.frame_cache_context = deque(maxlen=frame_cache_len)
112
+
113
  # Disable clearing cache
114
  if block_state.block_idx == 0:
115
  components.vae.clear_cache()
 
144
 
145
  block_state.decoder_cache = components.vae._feat_map
146
  block_state.frame_cache_context.extend(videos.split(1, dim=2))
 
147
  videos = components.video_processor.postprocess_video(
148
  videos, output_type=block_state.output_type
149
  )
150
  block_state.videos = videos
 
151
  self.set_block_state(state, block_state)
152
 
153
  return components, state
denoise.py CHANGED
@@ -16,8 +16,6 @@ from typing import Any, List, Tuple
16
 
17
  import torch
18
 
19
- from diffusers.configuration_utils import FrozenDict
20
- from diffusers.guiders import ClassifierFreeGuidance
21
  from diffusers.models import AutoModel
22
  from diffusers.schedulers import UniPCMultistepScheduler
23
  from diffusers.utils import logging
@@ -39,8 +37,8 @@ from diffusers.modular_pipelines.modular_pipeline_utils import (
39
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
 
41
 
42
- class WanRTStreamingLoopDenoiser(ModularPipelineBlocks):
43
- model_name = "WanRTStreaming"
44
 
45
  @property
46
  def expected_components(self) -> List[ComponentSpec]:
@@ -51,14 +49,12 @@ class WanRTStreamingLoopDenoiser(ModularPipelineBlocks):
51
  return (
52
  "Step within the denoising loop that denoise the latents with guidance. "
53
  "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
54
- "object (e.g. `WanRTStreamingDenoiseLoopWrapper`)"
55
  )
56
 
57
  @property
58
  def inputs(self) -> List[Tuple[str, Any]]:
59
  return [
60
- InputParam("attention_kwargs"),
61
- InputParam("block_idx"),
62
  InputParam(
63
  "latents",
64
  required=True,
@@ -69,36 +65,25 @@ class WanRTStreamingLoopDenoiser(ModularPipelineBlocks):
69
  "prompt_embeds",
70
  required=True,
71
  type_hint=torch.Tensor,
 
72
  ),
73
  InputParam(
74
  "kv_cache",
75
  required=True,
76
  type_hint=torch.Tensor,
 
77
  ),
78
  InputParam(
79
  "crossattn_cache",
80
  required=True,
81
  type_hint=torch.Tensor,
 
82
  ),
83
  InputParam(
84
  "current_start_frame",
85
  required=True,
86
  type_hint=torch.Tensor,
87
- ),
88
- InputParam(
89
- "num_inference_steps",
90
- required=True,
91
- type_hint=int,
92
- default=4,
93
- description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
94
- ),
95
- InputParam(
96
- kwargs_type="guider_input_fields",
97
- description=(
98
- "All conditional model inputs that need to be prepared with guider. "
99
- "It should contain prompt_embeds/negative_prompt_embeds. "
100
- "Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
101
- ),
102
  ),
103
  ]
104
 
@@ -116,20 +101,21 @@ class WanRTStreamingLoopDenoiser(ModularPipelineBlocks):
116
 
117
  block_state.noise_pred = components.transformer(
118
  x=block_state.latents,
119
- t=t.expand(block_state.latents.shape[0], block_state.num_frames_per_block),
 
 
120
  context=block_state.prompt_embeds,
121
  kv_cache=block_state.kv_cache,
122
  seq_len=components.config.seq_length,
123
  crossattn_cache=block_state.crossattn_cache,
124
  current_start=start_frame * components.config.frame_seq_length,
125
- cache_start=start_frame * components.config.frame_seq_length,
126
  )
127
-
128
  return components, block_state
129
 
130
 
131
- class WanRTStreamingLoopAfterDenoiser(ModularPipelineBlocks):
132
- model_name = "WanRTStreaming"
133
 
134
  @property
135
  def expected_components(self) -> List[ComponentSpec]:
@@ -142,18 +128,24 @@ class WanRTStreamingLoopAfterDenoiser(ModularPipelineBlocks):
142
  return (
143
  "step within the denoising loop that update the latents. "
144
  "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
145
- "object (e.g. `WanRTStreamingDenoiseLoopWrapper`)"
146
  )
147
 
148
  @property
149
  def inputs(self) -> List[Tuple[str, Any]]:
150
- return []
151
-
152
- @property
153
- def intermediate_inputs(self) -> List[str]:
154
  return [
155
- InputParam("generator"),
156
- InputParam("block_id"),
157
  ]
158
 
159
  @property
@@ -185,14 +177,13 @@ class WanRTStreamingLoopAfterDenoiser(ModularPipelineBlocks):
185
  block_state.latents.double()
186
  - sigma_t.double() * block_state.noise_pred.double()
187
  ).to(latents_dtype)
188
-
189
  block_state.latents = latents
190
 
191
  return components, block_state
192
 
193
 
194
- class WanRTStreamingDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
195
- model_name = "WanRTStreaming"
196
 
197
  @property
198
  def description(self) -> str:
@@ -201,7 +192,7 @@ class WanRTStreamingDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
201
  "Recomputes cache from context frames, denoises current block, and updates cache."
202
  )
203
 
204
- def add_noise(self, components, block_state, sample, noise, timestep, index):
205
  timesteps = block_state.all_timesteps
206
  sigmas = block_state.sigmas.to(timesteps.device)
207
 
@@ -232,38 +223,25 @@ class WanRTStreamingDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
232
  "all_timesteps",
233
  required=True,
234
  type_hint=torch.Tensor,
235
- description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
236
  ),
237
  InputParam(
238
  "sigmas",
239
  required=True,
240
  type_hint=torch.Tensor,
241
- description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
242
  ),
243
- InputParam("final_latents", type_hint=torch.Tensor),
244
  InputParam(
245
  "num_inference_steps",
246
  required=True,
247
  type_hint=int,
248
  description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
249
  ),
250
- InputParam(
251
- "num_frames_per_block",
252
- required=True,
253
- type_hint=int,
254
- default=3,
255
- ),
256
  InputParam(
257
  "current_start_frame",
258
  required=True,
259
  type_hint=int,
260
  ),
261
- InputParam(
262
- "block_idx",
263
- ),
264
- InputParam(
265
- "generator",
266
- ),
267
  ]
268
 
269
  @torch.no_grad()
@@ -279,7 +257,6 @@ class WanRTStreamingDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
279
 
280
  block_state.latents = (
281
  self.add_noise(
282
- components,
283
  block_state,
284
  block_state.latents.transpose(1, 2).squeeze(0),
285
  randn_tensor(
@@ -290,31 +267,23 @@ class WanRTStreamingDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
290
  ),
291
  t1.expand(
292
  block_state.latents.shape[0],
293
- block_state.num_frames_per_block,
294
  ),
295
- i,
296
  )
297
  .unsqueeze(0)
298
  .transpose(1, 2)
299
  )
300
 
301
- # Update the state
302
- block_state.final_latents[
303
- :,
304
- :,
305
- block_state.current_start_frame : block_state.current_start_frame
306
- + block_state.num_frames_per_block,
307
- ] = block_state.latents
308
-
309
  self.set_block_state(state, block_state)
310
 
311
  return components, state
312
 
313
 
314
- class WanRTStreamingDenoiseStep(WanRTStreamingDenoiseLoopWrapper):
315
  block_classes = [
316
- WanRTStreamingLoopDenoiser,
317
- WanRTStreamingLoopAfterDenoiser,
318
  ]
319
  block_names = ["denoiser", "after_denoiser"]
320
 
@@ -322,9 +291,9 @@ class WanRTStreamingDenoiseStep(WanRTStreamingDenoiseLoopWrapper):
322
  def description(self) -> str:
323
  return (
324
  "Denoise step that iteratively denoise the latents. \n"
325
- "Its loop logic is defined in `WanRTStreamingDenoiseLoopWrapper.__call__` method \n"
326
  "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
327
- " - `WanRTStreamingLoopDenoiser`\n"
328
- " - `WanRTStreamingLoopAfterDenoiser`\n"
329
  "This block supports both text2vid tasks."
330
  )
 
16
 
17
  import torch
18
 
 
 
19
  from diffusers.models import AutoModel
20
  from diffusers.schedulers import UniPCMultistepScheduler
21
  from diffusers.utils import logging
 
37
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
 
39
 
40
+ class WanRTLoopDenoiser(ModularPipelineBlocks):
41
+ model_name = "wan"
42
 
43
  @property
44
  def expected_components(self) -> List[ComponentSpec]:
 
49
  return (
50
  "Step within the denoising loop that denoise the latents with guidance. "
51
  "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
52
+ "object (e.g. `WanRTDenoiseLoopWrapper`)"
53
  )
54
 
55
  @property
56
  def inputs(self) -> List[Tuple[str, Any]]:
57
  return [
 
 
58
  InputParam(
59
  "latents",
60
  required=True,
 
65
  "prompt_embeds",
66
  required=True,
67
  type_hint=torch.Tensor,
68
+ description="Text embeddings to condition the denoising process",
69
  ),
70
  InputParam(
71
  "kv_cache",
72
  required=True,
73
  type_hint=torch.Tensor,
74
+ description="KV Cache of the transformer model",
75
  ),
76
  InputParam(
77
  "crossattn_cache",
78
  required=True,
79
  type_hint=torch.Tensor,
80
+ description="Cross Attention Cache of the transformer model",
81
  ),
82
  InputParam(
83
  "current_start_frame",
84
  required=True,
85
  type_hint=torch.Tensor,
86
+ description="Starting frame index for the current block in the streaming generation",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  ),
88
  ]
89
 
 
101
 
102
  block_state.noise_pred = components.transformer(
103
  x=block_state.latents,
104
+ t=t.expand(
105
+ block_state.latents.shape[0], components.config.num_frames_per_block
106
+ ),
107
  context=block_state.prompt_embeds,
108
  kv_cache=block_state.kv_cache,
109
  seq_len=components.config.seq_length,
110
  crossattn_cache=block_state.crossattn_cache,
111
  current_start=start_frame * components.config.frame_seq_length,
112
+ cache_start=None,
113
  )
 
114
  return components, block_state
115
 
116
 
117
+ class WanRTLoopAfterDenoiser(ModularPipelineBlocks):
118
+ model_name = "wan"
119
 
120
  @property
121
  def expected_components(self) -> List[ComponentSpec]:
 
128
  return (
129
  "step within the denoising loop that update the latents. "
130
  "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
131
+ "object (e.g. `WanRTDenoiseLoopWrapper`)"
132
  )
133
 
134
  @property
135
  def inputs(self) -> List[Tuple[str, Any]]:
 
 
 
 
136
  return [
137
+ InputParam(
138
+ "latents",
139
+ description="Current latents being denoised",
140
+ ),
141
+ InputParam(
142
+ "all_timesteps",
143
+ description="All timesteps for the denoising process",
144
+ ),
145
+ InputParam(
146
+ "sigmas",
147
+ description="Noise schedule sigmas for each timestep",
148
+ ),
149
  ]
150
 
151
  @property
 
177
  block_state.latents.double()
178
  - sigma_t.double() * block_state.noise_pred.double()
179
  ).to(latents_dtype)
 
180
  block_state.latents = latents
181
 
182
  return components, block_state
183
 
184
 
185
+ class WanRTDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
186
+ model_name = "wan"
187
 
188
  @property
189
  def description(self) -> str:
 
192
  "Recomputes cache from context frames, denoises current block, and updates cache."
193
  )
194
 
195
+ def add_noise(self, block_state, sample, noise, timestep):
196
  timesteps = block_state.all_timesteps
197
  sigmas = block_state.sigmas.to(timesteps.device)
198
 
 
223
  "all_timesteps",
224
  required=True,
225
  type_hint=torch.Tensor,
 
226
  ),
227
  InputParam(
228
  "sigmas",
229
  required=True,
230
  type_hint=torch.Tensor,
 
231
  ),
232
+ InputParam("current_denoised_latents", type_hint=torch.Tensor),
233
  InputParam(
234
  "num_inference_steps",
235
  required=True,
236
  type_hint=int,
237
  description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
238
  ),
 
 
 
 
 
 
239
  InputParam(
240
  "current_start_frame",
241
  required=True,
242
  type_hint=int,
243
  ),
244
+ InputParam("generator", type_hint=torch.Generator),
 
 
 
 
 
245
  ]
246
 
247
  @torch.no_grad()
 
257
 
258
  block_state.latents = (
259
  self.add_noise(
 
260
  block_state,
261
  block_state.latents.transpose(1, 2).squeeze(0),
262
  randn_tensor(
 
267
  ),
268
  t1.expand(
269
  block_state.latents.shape[0],
270
+ components.config.num_frames_per_block,
271
  ),
 
272
  )
273
  .unsqueeze(0)
274
  .transpose(1, 2)
275
  )
276
 
277
+ block_state.current_denoised_latents = block_state.latents
 
 
 
 
 
 
 
278
  self.set_block_state(state, block_state)
279
 
280
  return components, state
281
 
282
 
283
+ class WanRTDenoiseStep(WanRTDenoiseLoopWrapper):
284
  block_classes = [
285
+ WanRTLoopDenoiser,
286
+ WanRTLoopAfterDenoiser,
287
  ]
288
  block_names = ["denoiser", "after_denoiser"]
289
 
 
291
  def description(self) -> str:
292
  return (
293
  "Denoise step that iteratively denoise the latents. \n"
294
+ "Its loop logic is defined in `WanRTDenoiseLoopWrapper.__call__` method \n"
295
  "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
296
+ " - `WanRTLoopDenoiser`\n"
297
+ " - `WanRTLoopAfterDenoiser`\n"
298
  "This block supports both text2vid tasks."
299
  )
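
With this change, `num_frames_per_block` is no longer a runtime input to the denoise loop; it is read from the pipeline config as `components.config.num_frames_per_block`. A minimal sketch of picking the value up on the caller side, assuming the loaded pipeline exposes the same config key:

```py
# Sketch only: read the streaming block size from the pipeline config instead of hard-coding it.
# Assumes `pipe` is the loaded WanModularPipeline and that the config key matches denoise.py above.
num_frames_per_block = pipe.config.num_frames_per_block
```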
encoders.py CHANGED
@@ -56,7 +56,7 @@ def prompt_clean(text):
  return text

- class WanRTStreamingTextEncoderStep(ModularPipelineBlocks):
  model_name = "WanRTStreaming"

  @property
@@ -83,8 +83,14 @@ class WanRTStreamingTextEncoderStep(ModularPipelineBlocks):
  @property
  def inputs(self) -> List[InputParam]:
  return [
- InputParam("prompt"),
- InputParam("negative_prompt"),
  InputParam(
  "prompt_embeds",
  type_hint=torch.Tensor,
@@ -95,7 +101,10 @@ class WanRTStreamingTextEncoderStep(ModularPipelineBlocks):
  type_hint=torch.Tensor,
  description="negative text embeddings used to guide the image generation",
  ),
- InputParam("attention_kwargs"),
  ]

  @property
@@ -205,7 +214,7 @@ class WanRTStreamingTextEncoderStep(ModularPipelineBlocks):
  batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]

  if prompt_embeds is None:
- prompt_embeds = WanRTStreamingTextEncoderStep._get_t5_prompt_embeds(
  components, prompt, max_sequence_length, device
  )

@@ -229,10 +238,8 @@ class WanRTStreamingTextEncoderStep(ModularPipelineBlocks):
  " the batch size of `prompt`."
  )

- negative_prompt_embeds = (
- WanRTStreamingTextEncoderStep._get_t5_prompt_embeds(
- components, negative_prompt, max_sequence_length, device
- )
  )

  bs_embed, seq_len, _ = prompt_embeds.shape
@@ -266,7 +273,7 @@ class WanRTStreamingTextEncoderStep(ModularPipelineBlocks):
  (
  block_state.prompt_embeds,
  block_state.negative_prompt_embeds,
- ) = WanRTStreamingTextEncoderStep.encode_prompt(
  components,
  block_state.prompt,
  block_state.device,
@@ -276,6 +283,7 @@ class WanRTStreamingTextEncoderStep(ModularPipelineBlocks):
  prompt_embeds=block_state.prompt_embeds,
  negative_prompt_embeds=block_state.negative_prompt_embeds,
  )

  # Add outputs
  self.set_block_state(state, block_state)

  return text

+ class WanRTTextEncoderStep(ModularPipelineBlocks):
  model_name = "WanRTStreaming"

  @property

  @property
  def inputs(self) -> List[InputParam]:
  return [
+ InputParam(
+ "prompt",
+ description="The prompt or prompts to guide the video generation",
+ ),
+ InputParam(
+ "negative_prompt",
+ description="The prompt or prompts not to guide the video generation",
+ ),
  InputParam(
  "prompt_embeds",
  type_hint=torch.Tensor,

  type_hint=torch.Tensor,
  description="negative text embeddings used to guide the image generation",
  ),
+ InputParam(
+ "attention_kwargs",
+ description="Additional keyword arguments to pass to the attention mechanism",
+ ),
  ]

  @property

  batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]

  if prompt_embeds is None:
+ prompt_embeds = WanRTTextEncoderStep._get_t5_prompt_embeds(
  components, prompt, max_sequence_length, device
  )

  " the batch size of `prompt`."
  )

+ negative_prompt_embeds = WanRTTextEncoderStep._get_t5_prompt_embeds(
+ components, negative_prompt, max_sequence_length, device
  )

  bs_embed, seq_len, _ = prompt_embeds.shape

  (
  block_state.prompt_embeds,
  block_state.negative_prompt_embeds,
+ ) = WanRTTextEncoderStep.encode_prompt(
  components,
  block_state.prompt,
  block_state.device,

  prompt_embeds=block_state.prompt_embeds,
  negative_prompt_embeds=block_state.negative_prompt_embeds,
  )
+ block_state.prompt_embeds = block_state.prompt_embeds.contiguous()

  # Add outputs
  self.set_block_state(state, block_state)
modular_blocks.py CHANGED
@@ -16,27 +16,24 @@ from diffusers.utils import logging
  from diffusers.modular_pipelines import SequentialPipelineBlocks
  from diffusers.modular_pipelines.modular_pipeline_utils import InsertableDict

- from .before_denoise import WanRTStreamingBeforeDenoiseStep
  from .decoders import WanRTDecodeStep
- from .encoders import WanRTStreamingTextEncoderStep
- from .denoise import WanRTStreamingDenoiseStep

  logger = logging.get_logger(__name__) # pylint: disable=invalid-name

- TEXT2VIDEO_BLOCKS = InsertableDict(
  [
- ("text_encoder", WanRTStreamingTextEncoderStep),
- ("before_denoise", WanRTStreamingBeforeDenoiseStep),
- ("denoise", WanRTStreamingDenoiseStep),
  ("decode", WanRTDecodeStep),
  ]
  )

- ALL_BLOCKS = {
- "text2video": TEXT2VIDEO_BLOCKS,
- }
-

- class WanStreamingRTBlocks(SequentialPipelineBlocks):
- block_classes = list(TEXT2VIDEO_BLOCKS.copy().values())
- block_names = list(TEXT2VIDEO_BLOCKS.copy().keys())

  from diffusers.modular_pipelines import SequentialPipelineBlocks
  from diffusers.modular_pipelines.modular_pipeline_utils import InsertableDict

+ from .before_denoise import WanRTAutoBeforeDenoiseStep
  from .decoders import WanRTDecodeStep
+ from .encoders import WanRTTextEncoderStep
+ from .denoise import WanRTDenoiseStep

  logger = logging.get_logger(__name__) # pylint: disable=invalid-name

+
+ AUTO_BLOCKS = InsertableDict(
  [
+ ("text_encoder", WanRTTextEncoderStep),
+ ("before_denoise", WanRTAutoBeforeDenoiseStep),
+ ("denoise", WanRTDenoiseStep),
  ("decode", WanRTDecodeStep),
  ]
  )

+ class WanRTBlocks(SequentialPipelineBlocks):
+ block_classes = list(AUTO_BLOCKS.copy().values())
+ block_names = list(AUTO_BLOCKS.copy().keys())
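
The block preset is now named `AUTO_BLOCKS` and the top-level class `WanRTBlocks`. A small sketch for checking the assembled order locally, assuming the repository files are importable and that the preset class can be instantiated directly like other modular presets:

```py
# Sketch: list the sub-blocks to confirm the text_encoder -> before_denoise -> denoise -> decode order.
from modular_blocks import WanRTBlocks  # local module from this repository

blocks = WanRTBlocks()
print(blocks.sub_blocks)  # InsertableDict built from AUTO_BLOCKS
```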
modular_config.json CHANGED
@@ -2,6 +2,6 @@
  "_class_name": "WanRTBlocks",
  "_diffusers_version": "0.36.0.dev0",
  "auto_map": {
- "ModularPipelineBlocks": "modular_blocks.WanStreamingRTBlocks"
  }
  }

  "_class_name": "WanRTBlocks",
  "_diffusers_version": "0.36.0.dev0",
  "auto_map": {
+ "ModularPipelineBlocks": "modular_blocks.WanRTBlocks"
  }
  }
modular_model_index.json CHANGED
@@ -1,6 +1,6 @@
  {
- "_blocks_class_name": "WanStreamingRTBlocks",
- "_class_name": "WanRTStreamingPipeline",
  "_diffusers_version": "0.36.0.dev0",
  "frame_seq_length": 1560,
  "kv_cache_num_frames": 3,
@@ -52,7 +52,7 @@
  null,
  null,
  {
- "repo": "diffusers-internal-dev/krt",
  "revision": null,
  "subfolder": "transformer",
  "type_hint": [

  {
+ "_blocks_class_name": "WanRTBlocks",
+ "_class_name": "WanModularPipeline",
  "_diffusers_version": "0.36.0.dev0",
  "frame_seq_length": 1560,
  "kv_cache_num_frames": 3,

  null,
  null,
  {
+ "repo": "krea/krea-realtime-video",
  "revision": null,
  "subfolder": "transformer",
  "type_hint": [
transformer/attention.py CHANGED
@@ -1,65 +1,40 @@
  # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
  import torch
- from typing import Optional
  import os
  import warnings
-
- # Global state for lazy initialization
- _SAGEATTN_AVAILABLE = None
- _FLASH_ATTN_3_AVAILABLE = None
- _FLASH_ATTN_2_AVAILABLE = None
- _sageattn_func = None
- _flash_attn_func = None
- _flash_attn_interface = None
- _flash_attn = None
-
-
- def _init_sageattention():
- """Lazy initialization for SageAttention."""
- global _SAGEATTN_AVAILABLE, _sageattn_func
-
- if _SAGEATTN_AVAILABLE is not None:
- return _SAGEATTN_AVAILABLE
-
- _SAGEATTN_AVAILABLE = False
- try:
- if os.getenv("DISABLE_SAGEATTENTION", "0") != "0":
- raise Exception("DISABLE_SAGEATTENTION is set")
-
- from sageattention import sageattn
-
- @torch.library.custom_op(
- "mylib::sageattn", mutates_args={"q", "k", "v"}, device_types="cuda"
  )
- def sageattn_func(
- q: torch.Tensor,
- k: torch.Tensor,
- v: torch.Tensor,
- attn_mask: Optional[torch.Tensor] = None,
- dropout_p: float = 0,
- is_causal: bool = False,
- ) -> torch.Tensor:
- return sageattn(
- q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
- )

- @sageattn_func.register_fake
- def _sageattn_fake(q, k, v, attn_mask=None, dropout_p=0, is_causal=False):
- return torch.empty(*q.shape, device=q.device, dtype=q.dtype)

- print("SageAttention loaded successfully")
- _sageattn_func = sageattn_func
- _SAGEATTN_AVAILABLE = True

- except Exception as e:
- print(f"Warning: Could not load sageattention: {str(e)}")
- if isinstance(e, ModuleNotFoundError):
- print("sageattention package is not installed")
- elif isinstance(e, ImportError) and "DLL" in str(e):
- print("sageattention DLL loading error")
- _sageattn_func = None
-
- return _SAGEATTN_AVAILABLE


  def _is_hopper_gpu():
@@ -69,65 +44,41 @@ def _is_hopper_gpu():
  device_name = torch.cuda.get_device_name(0).lower()
  return "h100" in device_name or "hopper" in device_name

- def _init_flash_attention_3():
- """Lazy initialization for Flash Attention 3."""
- global _FLASH_ATTN_3_AVAILABLE, _flash_attn_func, _flash_attn_interface
-
- if _FLASH_ATTN_3_AVAILABLE is not None:
- return _FLASH_ATTN_3_AVAILABLE
-
- _FLASH_ATTN_3_AVAILABLE = False
- try:
- from flash_attn import flash_attn_func
- import flash_attn_interface

- # Always set the function reference if flash_attn is available
- _flash_attn_func = flash_attn_func
- _flash_attn_interface = flash_attn_interface
- # FA3 optimizations only available on Hopper GPUs
- _FLASH_ATTN_3_AVAILABLE = _is_hopper_gpu()
- except ModuleNotFoundError:
- _FLASH_ATTN_3_AVAILABLE = False
- _flash_attn_func = None
- _flash_attn_interface = None

- return _FLASH_ATTN_3_AVAILABLE

- def _init_flash_attention_2():
- """Lazy initialization for Flash Attention 2."""
- global _FLASH_ATTN_2_AVAILABLE, _flash_attn

- if _FLASH_ATTN_2_AVAILABLE is not None:
- return _FLASH_ATTN_2_AVAILABLE
-
- _FLASH_ATTN_2_AVAILABLE = False
- try:
- import flash_attn
-
- _flash_attn = flash_attn
- _FLASH_ATTN_2_AVAILABLE = True
- except ModuleNotFoundError:
- _FLASH_ATTN_2_AVAILABLE = False
-
- return _FLASH_ATTN_2_AVAILABLE

  __all__ = ["flash_attention", "attention"]

-
- # Compatibility getters for external code
- def sageattn_func():
- """Getter for sageattn_func - initializes if needed."""
- _init_sageattention()
- return _sageattn_func
-
-
- def SAGEATTN_AVAILABLE():
- """Getter for SAGEATTN_AVAILABLE - initializes if needed."""
- return _init_sageattention()
-
-
  def flash_attention(
  q,
  k,
@@ -156,14 +107,15 @@ def flash_attention(
  deterministic: bool. If True, slightly slower and uses more memory.
  dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
  """
- # Initialize flash attention modules
- flash_attn_3_available = _init_flash_attention_3()
- flash_attn_2_available = _init_flash_attention_2()
-
- # Early fallback for simple cases when advanced features aren't needed
- # Only use this path if flash_attn is available but we're not using FA3 features
- if not flash_attn_3_available and _flash_attn_func is not None and q_lens is None and k_lens is None:
- return _flash_attn_func(
  q,
  k,
  v,
@@ -205,15 +157,15 @@ def flash_attention(
  if q_scale is not None:
  q = q * q_scale

- if version is not None and version == 3 and not flash_attn_3_available:
  warnings.warn(
  "Flash attention 3 is not available, use flash attention 2 instead."
  )

  # apply attention
- if (version is None or version == 3) and flash_attn_3_available:
  # Note: dropout_p, window_size are not supported in FA3 now.
- x = _flash_attn_interface.flash_attn_varlen_func(
  q=q,
  k=k,
  v=v,
@@ -230,8 +182,8 @@ def flash_attention(
  deterministic=deterministic,
  ).unflatten(0, (b, lq))
  else:
- assert flash_attn_2_available
- x = _flash_attn.flash_attn_varlen_func(
  q=q,
  k=k,
  v=v,
@@ -270,12 +222,8 @@ def attention(
  fa_version=None,
  # og_dtype=torch.bfloat16,
  ):
- # Initialize attention modules
- sageattn_available = _init_sageattention()
- flash_attn_2_available = _init_flash_attention_2()
- flash_attn_3_available = _init_flash_attention_3()

- if sageattn_available:
  # print("Using sageattention")
  attn_mask = None

@@ -284,14 +232,14 @@ def attention(
  k = k.transpose(1, 2).to(dtype)
  v = v.transpose(1, 2).to(dtype)

- out = _sageattn_func(
  q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p
  )

  out = out.transpose(1, 2).contiguous().to(og_dtype)
  return out

- elif flash_attn_2_available or flash_attn_3_available:
  return flash_attention(
  q=q,
  k=k,

  # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
  import torch
  import os
  import warnings
+ from typing import Optional
+ from diffusers.utils import is_kernels_available
+
+ SAGEATTN_AVAILABLE = False
+ try:
+ if os.getenv("DISABLE_SAGEATTENTION", "0") != "0":
+ raise Exception("DISABLE_SAGEATTENTION is set")
+
+ from sageattention import sageattn
+
+ @torch.library.custom_op(
+ "mylib::sageattn", mutates_args={"q", "k", "v"}, device_types="cuda"
+ )
+ def sageattn_func(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0,
+ is_causal: bool = False,
+ ) -> torch.Tensor:
+ return sageattn(
+ q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
  )

+ @sageattn_func.register_fake
+ def _sageattn_fake(q, k, v, attn_mask=None, dropout_p=0, is_causal=False):
+ return torch.empty(*q.shape, device=q.device, dtype=q.dtype)

+ SAGEATTN_AVAILABLE = True

+ except Exception as e:
+ sageattn_func = None


  def _is_hopper_gpu():
  device_name = torch.cuda.get_device_name(0).lower()
  return "h100" in device_name or "hopper" in device_name

+ FLASH_ATTN_3_AVAILABLE = False
+ try:
+ import flash_attn_interface
+ FLASH_ATTN_3_AVAILABLE = _is_hopper_gpu()
+ except ModuleNotFoundError:
+ FLASH_ATTN_3_AVAILABLE = False

+ FLASH_ATTN_3_HUB_AVAILABLE = False
+ try:
+ use_hub_kernels = os.getenv("DIFFUSERS_ENABLE_HUB_KERNELS", "false").upper() in ["1", "TRUE"]
+ if use_hub_kernels and not is_kernels_available():
+ raise EnvironmentError((
+ "Attempting to use Hub Kernels for Flash Attention 3,"
+ "but the `kernels` library was not found in your environment. "
+ "Please install via `pip install kernels`"
+ ))

+ from kernels import get_kernel
+ flash_attn_3_hub = get_kernel("kernels-community/flash-attn3", revision="fake-ops-return-probs")

+ FLASH_ATTN_3_HUB_AVAILABLE = _is_hopper_gpu()

+ except:
+ FLASH_ATTN_3_HUB_AVAILABLE = False

+ FLASH_ATTN_2_AVAILABLE = False
+ try:
+ import flash_attn

+ FLASH_ATTN_2_AVAILABLE = True
+ except ModuleNotFoundError:
+ FLASH_ATTN_2_AVAILABLE = False

  __all__ = ["flash_attention", "attention"]

  def flash_attention(
  q,
  k,

  deterministic: bool. If True, slightly slower and uses more memory.
  dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
  """
+ if not FLASH_ATTN_3_AVAILABLE or not FLASH_ATTN_3_HUB_AVAILABLE:
+ return flash_attn.flash_attn_func(
+ q,
+ k,
+ v,
+ )
+
+ elif FLASH_ATTN_3_HUB_AVAILABLE:
+ return flash_attn_3_hub.flash_attn_func(
  q,
  k,
  v,

  if q_scale is not None:
  q = q * q_scale

+ if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
  warnings.warn(
  "Flash attention 3 is not available, use flash attention 2 instead."
  )

  # apply attention
+ if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
  # Note: dropout_p, window_size are not supported in FA3 now.
+ x = flash_attn_interface.flash_attn_varlen_func(
  q=q,
  k=k,
  v=v,

  deterministic=deterministic,
  ).unflatten(0, (b, lq))
  else:
+ assert FLASH_ATTN_3_AVAILABLE
+ x = flash_attn.flash_attn_varlen_func(
  q=q,
  k=k,
  v=v,

  fa_version=None,
  # og_dtype=torch.bfloat16,
  ):

+ if SAGEATTN_AVAILABLE:
  # print("Using sageattention")
  attn_mask = None

  k = k.transpose(1, 2).to(dtype)
  v = v.transpose(1, 2).to(dtype)

+ out = sageattn_func(
  q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p
  )

  out = out.transpose(1, 2).contiguous().to(og_dtype)
  return out

+ elif FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
  return flash_attention(
  q=q,
  k=k,
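
The rewritten module resolves the attention backend at import time and adds an opt-in path for the Flash Attention 3 kernel fetched from the Hub via the `kernels` package. A minimal sketch of opting in, based on the environment variable checked above (setting it before the module is first imported is an assumption about ordering):

```py
# Sketch: opt into the Hub-hosted FA3 kernel (requires `pip install kernels` and a Hopper GPU).
import os

os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "1"  # read by transformer/attention.py at import time
```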
transformer/causal_model.py CHANGED
@@ -18,6 +18,7 @@ from torch.nn.attention.flex_attention import BlockMask
  from diffusers.configuration_utils import ConfigMixin, register_to_config
  from diffusers.models.modeling_utils import ModelMixin

  flex_attention = torch.compile(
  flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
@@ -642,7 +643,7 @@ class CausalHead(nn.Module):
  return x

- class CausalWanModel(ModelMixin, ConfigMixin):
  r"""
  Wan diffusion backbone supporting both text-to-video and image-to-video.
  """

  from diffusers.configuration_utils import ConfigMixin, register_to_config
  from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.loaders import PeftAdapterMixin

  flex_attention = torch.compile(
  flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"

  return x

+ class CausalWanModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
  r"""
  Wan diffusion backbone supporting both text-to-video and image-to-video.
  """
transformer/model.py CHANGED
@@ -10,13 +10,11 @@ from einops import repeat
  from .attention import (
  flash_attention,
  sageattn_func,
- _SAGEATTN_AVAILABLE,
- _FLASH_ATTN_2_AVAILABLE,
- _FLASH_ATTN_3_AVAILABLE,
  )

- print("SAGEATTN_AVAILABLE:", _SAGEATTN_AVAILABLE)
-
  __all__ = ["WanModel"]

@@ -153,7 +151,7 @@ class WanSelfAttention(nn.Module):
  q, k, v = qkv_fn(x)

- if _SAGEATTN_AVAILABLE:
  # print("Using sageattention in crossattn")
  og_dtype = q.dtype
  q = q.transpose(1, 2).to(dtype)
@@ -209,7 +207,7 @@ class WanT2VCrossAttention(WanSelfAttention):
  v = self.v(context).view(b, -1, n, d)

  # compute attention
- if _SAGEATTN_AVAILABLE:
  # print("Using sageattention in crossattn")
  dtype = torch.bfloat16
  og_dtype = q.dtype
@@ -222,7 +220,7 @@ class WanT2VCrossAttention(WanSelfAttention):
  v=v,
  )
  x = x.transpose(1, 2).contiguous().to(og_dtype)
- elif _FLASH_ATTN_2_AVAILABLE or _FLASH_ATTN_3_AVAILABLE:
  x = flash_attention(q, k, v, k_lens=context_lens)
  else:
  dtype = torch.bfloat16

  from .attention import (
  flash_attention,
  sageattn_func,
+ SAGEATTN_AVAILABLE,
+ FLASH_ATTN_2_AVAILABLE,
+ FLASH_ATTN_3_AVAILABLE,
  )

  __all__ = ["WanModel"]

  q, k, v = qkv_fn(x)

+ if SAGEATTN_AVAILABLE:
  # print("Using sageattention in crossattn")
  og_dtype = q.dtype
  q = q.transpose(1, 2).to(dtype)

  v = self.v(context).view(b, -1, n, d)

  # compute attention
+ if SAGEATTN_AVAILABLE:
  # print("Using sageattention in crossattn")
  dtype = torch.bfloat16
  og_dtype = q.dtype

  v=v,
  )
  x = x.transpose(1, 2).contiguous().to(og_dtype)
+ elif FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
  x = flash_attention(q, k, v, k_lens=context_lens)
  else:
  dtype = torch.bfloat16