Diffusers updates
- .gitattributes +0 -6
- README.md +244 -16
- before_denoise.py +516 -111
- decoders.py +18 -10
- denoise.py +39 -70
- encoders.py +18 -10
- modular_blocks.py +11 -14
- modular_config.json +1 -1
- modular_model_index.json +3 -3
- transformer/attention.py +72 -124
- transformer/causal_model.py +2 -1
- transformer/model.py +6 -8
.gitattributes
CHANGED
@@ -33,9 +33,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
-hf_assets/3.mp4 filter=lfs diff=lfs merge=lfs -text
-hf_assets/masonry.mp4 filter=lfs diff=lfs merge=lfs -text
-hf_assets/trimferarricrop2_2x_speed.mp4 filter=lfs diff=lfs merge=lfs -text
-hf_assets/v2v_me_crop_1final.mov filter=lfs diff=lfs merge=lfs -text
-hf_assets/vertical_grid_all_videos.mp4 filter=lfs diff=lfs merge=lfs -text
-hf_assets/vertical_grid_output_reordered.mp4 filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -12,7 +12,7 @@ library_name: diffusers
 ---
 Krea Realtime 14B is distilled from the [Wan 2.1 14B text-to-video model](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) using Self-Forcing, a technique for converting regular video diffusion models into autoregressive models. It achieves a text-to-video inference speed of **11fps** using 4 inference steps on a single NVIDIA B200 GPU. For more details on our training methodology and sampling innovations, refer to our [technical blog post](https://www.krea.ai/blog/krea-realtime-14b).
 
-Inference code can be found [here](https://github.com/krea-ai/realtime-video).
+Inference code can be found [here](https://github.com/krea-ai/realtime-video).
 
 
 <video width="100%" controls>
@@ -118,60 +118,288 @@ export CUDA_VISIBLE_DEVICES=0 # pick the GPU you want to serve on
 export DO_COMPILE=true
 
 uvicorn release_server:app --host 0.0.0.0 --port 8000
+```
 
+And use the web app at http://localhost:8000/ in your browser
 (for more advanced use-cases and custom pipeline check out our GitHub repository: https://github.com/krea-ai/realtime-video)
 
 # Use it with 🧨 diffusers
 
+Krea Realtime 14B can be used with the `diffusers` library utilizing the new Modular Diffusers structure
 
 ```bash
 # Install diffusers from main
 pip install git+https://github.com/huggingface/diffusers.git
+```
+
+<details>
+<summary>Text to Video</summary>
 
 ```py
 import torch
+from tqdm import tqdm
 from diffusers.utils import export_to_video
+from diffusers import ModularPipeline
+from diffusers.modular_pipelines import PipelineState
 
 repo_id = "krea/krea-realtime-video"
+pipe = ModularPipeline.from_pretrained(repo_id, trust_remote_code=True)
+pipe.load_components(
+    trust_remote_code=True,
+    device_map="cuda",
+    torch_dtype={"default": torch.bfloat16, "vae": torch.float16},
+)
+for block in pipe.transformer.blocks:
+    block.self_attn.fuse_projections()
+
+num_blocks = 9
 
+frames = []
+state = PipelineState()
+prompt = ["a cat sitting on a boat"]
+
+generator = torch.Generator(device=pipe.device).manual_seed(42)
+for block_idx in tqdm(range(num_blocks)):
+    state = pipe(
+        state,
+        prompt=prompt,
+        num_inference_steps=6,
+        num_blocks=num_blocks,
+        block_idx=block_idx,
+        generator=generator,
+    )
+    frames.extend(state.values["videos"][0])
+
+export_to_video(frames, "output.mp4", fps=24)
+```
+</details>
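Aside from the editor, not part of the model card: the `before_denoise` changes in this commit declare `height` and `width` as pipeline inputs with defaults of 480 and 832, so a different output resolution can presumably be requested per call. A minimal, untested sketch of the `pipe(...)` call from the loop above with the resolution spelled out (the values shown are just the declared defaults):

```py
# Hedged sketch: same call as in the Text to Video loop above, with the
# resolution inputs passed explicitly. 480x832 are the defaults declared by
# WanRTPrepareLatentsStep; other sizes are an assumption, not a documented path.
state = pipe(
    state,
    prompt=prompt,
    height=480,
    width=832,
    num_inference_steps=6,
    num_blocks=num_blocks,
    block_idx=block_idx,
    generator=generator,
)
```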
+
+<details>
+<summary>Video to Video</summary>
+
+```py
+import torch
+from tqdm import tqdm
+from diffusers.utils import load_video, export_to_video
+from diffusers import ModularPipeline
+from diffusers.modular_pipelines import PipelineState
+
+repo_id = "krea/krea-realtime-video"
+pipe = ModularPipeline.from_pretrained(repo_id, trust_remote_code=True)
 pipe.load_components(
     trust_remote_code=True,
     device_map="cuda",
     torch_dtype={"default": torch.bfloat16, "vae": torch.float16},
 )
+for block in pipe.transformer.blocks:
+    block.self_attn.fuse_projections()
 
 num_blocks = 9
+video = load_video("https://app-uploads.krea.ai/public/a8218957-1a80-43dc-81b2-da970b5f2221-video.mp4")
 
 frames = []
+prompt = ["A car racing down a snowy mountain"]
+
 state = PipelineState()
+generator = torch.Generator("cuda").manual_seed(42)
+for block_idx in tqdm(range(num_blocks)):
+    state = pipe(
+        state,
+        video=video,
+        prompt=prompt,
+        num_inference_steps=6,
+        strength=0.3,
+        block_idx=block_idx,
+        generator=generator,
+    )
+    frames.extend(state.values["videos"][0])
 
+export_to_video(frames, "output-v2v.mp4", fps=24)
+```
+</details>
+
+<details>
+<summary>Streaming Video to Video</summary>
 
+Using the `video_stream` input will process video frames as they arrive, while maintaining temporal consistency across chunks.
+
+```py
+import torch
+from collections import deque
+from tqdm import tqdm
+from diffusers.utils import load_video, export_to_video
+from diffusers import ModularPipeline
+from diffusers.modular_pipelines import PipelineState
+
+repo_id = "krea/krea-realtime-video"
+pipe = ModularPipeline.from_pretrained(repo_id, trust_remote_code=True)
+pipe.load_components(
+    trust_remote_code=True,
+    device_map="cuda",
+    torch_dtype={"default": torch.bfloat16, "vae": torch.float16},
+)
 for block in pipe.transformer.blocks:
     block.self_attn.fuse_projections()
 
+n_samples = 9
+frame_sample_len = 12
+video = load_video(
+    "https://app-uploads.krea.ai/public/a8218957-1a80-43dc-81b2-da970b5f2221-video.mp4"
+)
+
+# Simulate streaming video input
+frame_samples = [
+    video[sample_start : sample_start + frame_sample_len]
+    for sample_start in range(0, n_samples * frame_sample_len, frame_sample_len)
+]
+
+frames = []
+state = PipelineState()
+prompt = ["A car racing down a snowy mountain road"]
+
+block_idx = 0
+generator = torch.Generator("cpu").manual_seed(42)
+for frame_sample in tqdm(frame_samples):
+    state = pipe(
+        state,
+        video_stream=frame_sample,
+        prompt=prompt,
+        num_inference_steps=6,
+        strength=0.3,
+        block_idx=block_idx,
+        generator=generator,
+    )
+    frames.extend(state.values["videos"][0])
+
+    block_idx += 1
+
+export_to_video(frames, "output-v2v-streaming.mp4", fps=24)
+```
+</details>
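Aside from the editor: the example above simulates streaming by slicing a downloaded clip into 12-frame chunks. A hedged sketch of feeding `video_stream` from a live source instead, assuming `opencv-python` is installed and reusing `pipe`, `prompt`, and the imports from the streaming example; the chunk size of 12 simply mirrors `frame_sample_len` above.

```py
import cv2
from PIL import Image

# Sketch only: pull webcam frames in 12-frame chunks and pass each chunk as
# `video_stream`, exactly like one entry of `frame_samples` in the example above.
cap = cv2.VideoCapture(0)
state = PipelineState()
generator = torch.Generator("cpu").manual_seed(42)
frames, block_idx = [], 0
while block_idx < 9:
    chunk = []
    while len(chunk) < 12:
        ok, bgr = cap.read()
        if not ok:
            break
        chunk.append(Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)))
    if not chunk:
        break
    state = pipe(
        state,
        video_stream=chunk,
        prompt=prompt,
        num_inference_steps=6,
        strength=0.3,
        block_idx=block_idx,
        generator=generator,
    )
    frames.extend(state.values["videos"][0])
    block_idx += 1
cap.release()
export_to_video(frames, "output-webcam.mp4", fps=24)
```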
+
+<details>
+<summary>Using LoRAs</summary>
+
+```py
+import torch
+from collections import deque
+from tqdm import tqdm
+from diffusers.utils import export_to_video
+from diffusers import ModularPipeline
+from diffusers.modular_pipelines import PipelineState
+
+repo_id = "krea/krea-realtime-video"
+pipe = ModularPipeline.from_pretrained(repo_id, trust_remote_code=True)
+pipe.load_components(
+    trust_remote_code=True,
+    device_map="cuda",
+    torch_dtype={"default": torch.bfloat16, "vae": torch.float16},
+)
+pipe.transformer.load_lora_adapter(
+    "shauray/Origami_WanLora",
+    prefix="diffusion_model",
+    weight_name="origami_000000500.safetensors",
+    adapter_name="origami",
+)
+for block in pipe.transformer.blocks:
+    block.self_attn.fuse_projections()
+
+num_blocks = 9
+
+frames = []
+state = PipelineState()
+prompt = ["[origami] a cat sitting on a boat"]
+
+generator = torch.Generator("cuda").manual_seed(42)
+for block_idx in tqdm(range(num_blocks)):
     state = pipe(
         state,
         prompt=prompt,
         num_inference_steps=6,
         num_blocks=num_blocks,
+        block_idx=block_idx,
+        generator=generator,
+    )
+    frames.extend(state.values["videos"][0])
+
+export_to_video(frames, "output.mp4", fps=24)
+```
+</details>
|
| 328 |
+
|
| 329 |
+
<details>
|
| 330 |
+
<summary>Optimized Inference</summary>
|
| 331 |
+
|
| 332 |
+
To optimize inference speed and memory usage on Hopper level GPUs (H100s), we recommend using `torch.compile`, Flash Attention 3 and FP8 quantization with [torchao](https://github.com/pytorch/ao).
|
| 333 |
+
|
| 334 |
+
First let's set up our depedencies by enabling Flash Attention 3 via Hub [kernels](https://huggingface.co/docs/kernels/en/index) and installing the `torchao` and `kernels` packages.
|
| 335 |
+
|
| 336 |
+
```shell
|
| 337 |
+
export DIFFUSERS_ENABLE_HUB_KERNELS=true
|
| 338 |
+
pip install -U kernels torchao
|
| 339 |
+
```
|
| 340 |
+
|
| 341 |
+
Then we will iterate over the blocks of the transformer and apply quantization and `torch.compile`.
|
| 342 |
+
|
| 343 |
+
```py
|
| 344 |
+
import torch
|
| 345 |
+
from collections import deque
|
| 346 |
+
from tqdm import tqdm
|
| 347 |
+
from diffusers.utils import export_to_video
|
| 348 |
+
from diffusers import ModularPipeline
|
| 349 |
+
from diffusers.modular_pipelines import PipelineState
|
| 350 |
+
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_
|
| 351 |
+
|
| 352 |
+
repo_id = "krea/krea-realtime-video"
|
| 353 |
+
pipe = ModularPipeline.from_pretrained(repo_id, trust_remote_code=True)
|
| 354 |
+
pipe.load_components(
|
| 355 |
+
trust_remote_code=True,
|
| 356 |
+
device_map="cuda",
|
| 357 |
+
torch_dtype={"default": torch.bfloat16, "vae": torch.float16},
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
for block in pipe.transformer.blocks:
|
| 361 |
+
block.self_attn.fuse_projections()
|
| 362 |
+
|
| 363 |
+
# Quantize just the transformer blocks
|
| 364 |
+
for block in pipe.transformer.blocks:
|
| 365 |
+
quantize_(block, Float8DynamicActivationFloat8WeightConfig())
|
| 366 |
+
|
| 367 |
+
# Compile just the attention modules
|
| 368 |
+
for submod in pipe.transformer.modules():
|
| 369 |
+
if submod.__class__.__name__ in ["CausalWanAttentionBlock"]:
|
| 370 |
+
submod.compile(fullgraph=False)
|
| 371 |
+
|
| 372 |
+
num_blocks = 9
|
| 373 |
+
|
| 374 |
+
state = PipelineState()
|
| 375 |
+
prompt = ["a cat sitting on a boat"]
|
| 376 |
+
|
| 377 |
+
# Compile warmup
|
| 378 |
+
for block_idx in range(num_blocks):
|
| 379 |
+
state = pipe(
|
| 380 |
+
state,
|
| 381 |
+
prompt=prompt,
|
| 382 |
+
num_inference_steps=2,
|
| 383 |
+
num_blocks=num_blocks,
|
| 384 |
num_frames_per_block=num_frames_per_block,
|
| 385 |
block_idx=block_idx,
|
| 386 |
generator=torch.Generator("cuda").manual_seed(42),
|
| 387 |
)
|
| 388 |
+
|
| 389 |
+
# Reset state
|
| 390 |
+
state = PipelineState()
|
| 391 |
+
generator = torch.Generator("cuda").manual_seed(42)
|
| 392 |
+
for block_idx in tqdm(range(num_blocks)):
|
| 393 |
+
state = pipe(
|
| 394 |
+
state,
|
| 395 |
+
prompt=prompt,
|
| 396 |
+
num_inference_steps=6,
|
| 397 |
+
num_blocks=num_blocks,
|
| 398 |
+
block_idx=block_idx,
|
| 399 |
+
generator=generator,
|
| 400 |
+
)
|
| 401 |
frames.extend(state.values["videos"][0])
|
| 402 |
|
| 403 |
+
export_to_video(frames, "output.mp4", fps=24)
|
| 404 |
+
```
|
| 405 |
+
</details>
|
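Aside from the editor, not part of the model card: since the card quotes roughly 11fps on a B200, a simple way to gauge effective throughput on your own hardware is to time the block loop. This is only a sketch; it reuses `pipe`, `prompt`, `num_blocks`, and the imports from the examples above.

```py
import time

frames, state = [], PipelineState()
generator = torch.Generator("cuda").manual_seed(42)
start = time.perf_counter()
for block_idx in range(num_blocks):
    state = pipe(
        state,
        prompt=prompt,
        num_inference_steps=6,
        num_blocks=num_blocks,
        block_idx=block_idx,
        generator=generator,
    )
    frames.extend(state.values["videos"][0])
elapsed = time.perf_counter() - start
# Effective frames-per-second over the whole run (includes VAE decode time).
print(f"{len(frames) / elapsed:.1f} frames/s ({len(frames)} frames in {elapsed:.1f}s)")
```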
before_denoise.py
CHANGED
@@ -14,6 +14,7 @@
 
 import inspect
 from typing import List, Optional, Union, Dict
+from collections import deque
 
 import torch
 
@@ -25,6 +26,7 @@ from diffusers.modular_pipelines import (
     ModularPipeline,
     ModularPipelineBlocks,
     SequentialPipelineBlocks,
+    AutoPipelineBlocks,
     PipelineState,
 )
 from diffusers.modular_pipelines.modular_pipeline_utils import (
@@ -221,7 +223,7 @@ def _initialize_crossattn_cache(
 
 
 class WanInputStep(ModularPipelineBlocks):
+    model_name = "wan"
 
     @property
     def description(self) -> str:
@@ -237,7 +239,11 @@ class WanInputStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         return [
+            InputParam(
+                "num_videos_per_prompt",
+                default=1,
+                description="Number of videos to generate per prompt",
+            ),
             InputParam(
                 "prompt_embeds",
                 required=True,
@@ -331,8 +337,8 @@ class WanInputStep(ModularPipelineBlocks):
         return components, state
 
 
-class WanRTStreamingSetTimestepsStep(ModularPipelineBlocks):
+class WanRTSetTimestepsStep(ModularPipelineBlocks):
+    model_name = "wan"
 
     @property
     def expected_components(self) -> List[ComponentSpec]:
@@ -350,6 +356,7 @@ class WanRTStreamingSetTimestepsStep(ModularPipelineBlocks):
             InputParam("num_inference_steps", default=4),
             InputParam("timesteps"),
             InputParam("sigmas"),
+            InputParam("strength", default=1.0),
         ]
 
     @property
@@ -391,7 +398,10 @@ class WanRTStreamingSetTimestepsStep(ModularPipelineBlocks):
            ]
        )
        denoising_steps = torch.linspace(
+            block_state.strength * 1000,
+            0,
+            block_state.num_inference_steps,
+            dtype=torch.float32,
        ).to(torch.long)
 
        block_state.timesteps = zero_padded_timesteps[1000 - denoising_steps]
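A standalone illustration from the editor (not part of the diff) of the schedule the new `strength` input produces in `WanRTSetTimestepsStep`: with `strength=0.3` and 6 steps, only the last 30% of the noise range is traversed, which is what the video-to-video examples in the README rely on. The index values below only reproduce the `torch.linspace(...).to(torch.long)` expression above; the actual timestep values then depend on the scheduler's `zero_padded_timesteps` table.

```py
import torch

strength, num_inference_steps = 0.3, 6
denoising_steps = torch.linspace(
    strength * 1000, 0, num_inference_steps, dtype=torch.float32
).to(torch.long)
print(denoising_steps)         # tensor([300, 240, 180, 120,  60,   0])
print(1000 - denoising_steps)  # tensor([700, 760, 820, 880, 940, 1000]) -> lookup indices
```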
|
| 413 |
return components, state
|
| 414 |
|
| 415 |
|
| 416 |
+
class WanRTPrepareLatentsStep(ModularPipelineBlocks):
|
| 417 |
+
model_name = "wan"
|
| 418 |
|
| 419 |
@property
|
| 420 |
def expected_components(self) -> List[ComponentSpec]:
|
|
|
|
| 433 |
@property
|
| 434 |
def inputs(self) -> List[InputParam]:
|
| 435 |
return [
|
| 436 |
+
InputParam(
|
| 437 |
+
"height",
|
| 438 |
+
type_hint=int,
|
| 439 |
+
description="Height of the video to generate in pixels",
|
| 440 |
+
),
|
| 441 |
+
InputParam(
|
| 442 |
+
"width",
|
| 443 |
+
type_hint=int,
|
| 444 |
+
description="Width of the video to generate in pixels",
|
| 445 |
+
),
|
| 446 |
+
InputParam(
|
| 447 |
+
"num_blocks",
|
| 448 |
+
type_hint=int,
|
| 449 |
+
description="Number of temporal blocks to generate",
|
| 450 |
+
),
|
| 451 |
+
InputParam(
|
| 452 |
+
"init_latents",
|
| 453 |
+
type_hint=Optional[torch.Tensor],
|
| 454 |
+
description="Pre-initialized latents to use instead of random noise",
|
| 455 |
+
),
|
| 456 |
+
InputParam(
|
| 457 |
+
"num_videos_per_prompt",
|
| 458 |
+
type_hint=int,
|
| 459 |
+
default=1,
|
| 460 |
+
description="Number of videos to generate per prompt",
|
| 461 |
+
),
|
| 462 |
+
InputParam(
|
| 463 |
+
"generator",
|
| 464 |
+
description="Random number generator for reproducible generation",
|
| 465 |
+
),
|
| 466 |
InputParam(
|
| 467 |
"dtype",
|
| 468 |
type_hint=torch.dtype,
|
|
|
|
| 473 |
@property
|
| 474 |
def intermediate_outputs(self) -> List[OutputParam]:
|
| 475 |
return [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 476 |
OutputParam(
|
| 477 |
"init_latents",
|
| 478 |
type_hint=torch.Tensor,
|
| 479 |
description="The initial latents to use for the denoising process",
|
| 480 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 481 |
]
|
| 482 |
|
| 483 |
@staticmethod
|
|
|
|
| 498 |
components,
|
| 499 |
batch_size: int,
|
| 500 |
num_channels_latents: int = 16,
|
| 501 |
+
height: int = 480,
|
| 502 |
+
width: int = 832,
|
| 503 |
num_blocks: int = 9,
|
| 504 |
num_frames_per_block: int = 3,
|
| 505 |
dtype: Optional[torch.dtype] = None,
|
|
|
|
| 558 |
block_state.generator,
|
| 559 |
block_state.init_latents,
|
| 560 |
)
|
| 561 |
+
block_state.init_latents = block_state.init_latents.contiguous()
|
|
|
|
|
|
|
|
|
|
| 562 |
self.set_block_state(state, block_state)
|
| 563 |
|
| 564 |
return components, state
|
| 565 |
|
| 566 |
|
| 567 |
+
class WanRTPrepareVideoLatentStep(ModularPipelineBlocks):
|
| 568 |
"""
|
| 569 |
+
Prepares video latents from input PIL images for video-to-video generation.
|
| 570 |
|
| 571 |
+
This block:
|
| 572 |
+
1. Processes input PIL images
|
| 573 |
+
2. Encodes them to latent space using the VAE encoder
|
| 574 |
+
3. Adds noise based on denoising strength for partial denoising
|
| 575 |
"""
|
| 576 |
|
| 577 |
+
model_name = "wan"
|
| 578 |
|
| 579 |
@property
|
| 580 |
def expected_components(self) -> List[ComponentSpec]:
|
| 581 |
+
return [
|
| 582 |
+
ComponentSpec("vae", AutoencoderKLWan),
|
| 583 |
+
]
|
| 584 |
|
| 585 |
@property
|
| 586 |
def description(self) -> str:
|
| 587 |
return (
|
| 588 |
+
"Prepares video latents from input PIL images by encoding to latent space "
|
| 589 |
+
"and optionally adding noise for video-to-video generation."
|
| 590 |
)
|
| 591 |
|
| 592 |
@property
|
| 593 |
def inputs(self) -> List[InputParam]:
|
| 594 |
return [
|
| 595 |
InputParam(
|
| 596 |
+
"video",
|
| 597 |
+
type_hint=list,
|
| 598 |
+
description="List of PIL Images for input video",
|
| 599 |
+
),
|
| 600 |
+
InputParam(
|
| 601 |
+
"height",
|
| 602 |
+
type_hint=int,
|
| 603 |
+
default=480,
|
| 604 |
+
description="Target height for video processing",
|
| 605 |
+
),
|
| 606 |
+
InputParam(
|
| 607 |
+
"width",
|
| 608 |
+
type_hint=int,
|
| 609 |
+
default=832,
|
| 610 |
+
description="Target width for video processing",
|
| 611 |
+
),
|
| 612 |
+
InputParam(
|
| 613 |
+
"strength",
|
| 614 |
+
type_hint=float,
|
| 615 |
+
default=1.0,
|
| 616 |
+
description="Denoising strength (0-1). Lower values preserve more of original video.",
|
| 617 |
+
),
|
| 618 |
+
InputParam(
|
| 619 |
+
"generator",
|
| 620 |
+
description="Random generator for noise",
|
| 621 |
+
),
|
| 622 |
+
InputParam(
|
| 623 |
+
"timesteps",
|
| 624 |
type_hint=torch.Tensor,
|
| 625 |
+
description="All timesteps for noise scheduling",
|
| 626 |
+
),
|
| 627 |
+
InputParam(
|
| 628 |
+
"num_blocks",
|
| 629 |
+
type_hint=int,
|
| 630 |
+
description="Number of blocks for generation",
|
| 631 |
),
|
| 632 |
InputParam(
|
| 633 |
"init_latents",
|
|
|
|
| 634 |
type_hint=torch.Tensor,
|
| 635 |
+
),
|
| 636 |
+
]
|
| 637 |
+
|
| 638 |
+
@property
|
| 639 |
+
def intermediate_outputs(self) -> List[OutputParam]:
|
| 640 |
+
return [
|
| 641 |
+
OutputParam(
|
| 642 |
+
"init_latents",
|
| 643 |
+
type_hint=torch.Tensor,
|
| 644 |
+
description="Noised latents from input video ready for denoising",
|
| 645 |
+
),
|
| 646 |
+
OutputParam(
|
| 647 |
+
"num_blocks",
|
| 648 |
+
type_hint=int,
|
| 649 |
+
description="Updated number of blocks based on video length",
|
| 650 |
+
),
|
| 651 |
+
]
|
| 652 |
+
|
| 653 |
+
def encode_frames(
|
| 654 |
+
self,
|
| 655 |
+
components,
|
| 656 |
+
video: Optional[torch.Tensor] = None,
|
| 657 |
+
timesteps: Optional[torch.Tensor] = None,
|
| 658 |
+
generator: Optional[torch.Generator] = None,
|
| 659 |
+
dtype: Optional[torch.dtype] = None,
|
| 660 |
+
device: Optional[torch.device] = None,
|
| 661 |
+
latents: Optional[torch.Tensor] = None,
|
| 662 |
+
):
|
| 663 |
+
if latents is not None:
|
| 664 |
+
return latents.to(device, dtype)
|
| 665 |
+
|
| 666 |
+
if not hasattr(components.vae, "_enc_feat_map"):
|
| 667 |
+
components.vae.clear_cache()
|
| 668 |
+
else:
|
| 669 |
+
components.vae._enc_feat_map = [None] * 55
|
| 670 |
+
|
| 671 |
+
init_latents = [
|
| 672 |
+
retrieve_latents(
|
| 673 |
+
components.vae.encode(vid.unsqueeze(0).transpose(2, 1)),
|
| 674 |
+
sample_mode="argmax",
|
| 675 |
+
)
|
| 676 |
+
for vid in video
|
| 677 |
+
]
|
| 678 |
+
init_latents = torch.cat(init_latents, dim=0).to(dtype)
|
| 679 |
+
|
| 680 |
+
latents_mean = (
|
| 681 |
+
torch.tensor(components.vae.config.latents_mean)
|
| 682 |
+
.view(1, components.vae.config.z_dim, 1, 1, 1)
|
| 683 |
+
.to(device, dtype)
|
| 684 |
+
)
|
| 685 |
+
latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
|
| 686 |
+
1, components.vae.config.z_dim, 1, 1, 1
|
| 687 |
+
).to(device, dtype)
|
| 688 |
+
init_latents = (init_latents - latents_mean) * latents_std
|
| 689 |
+
init_denoising_strength = timesteps[0] / 1000.0
|
| 690 |
+
|
| 691 |
+
# Add noise to latents
|
| 692 |
+
noise = randn_tensor(
|
| 693 |
+
init_latents.shape,
|
| 694 |
+
device=init_latents.device,
|
| 695 |
+
dtype=init_latents.dtype,
|
| 696 |
+
generator=generator,
|
| 697 |
+
)
|
| 698 |
+
init_latents = (
|
| 699 |
+
init_latents * (1.0 - init_denoising_strength)
|
| 700 |
+
+ noise * init_denoising_strength
|
| 701 |
+
)
|
| 702 |
+
init_latents = init_latents.to(components.transformer.dtype).contiguous()
|
| 703 |
+
|
| 704 |
+
return init_latents
|
| 705 |
+
|
| 706 |
+
@torch.no_grad()
|
| 707 |
+
def __call__(
|
| 708 |
+
self, components: ModularPipeline, state: PipelineState
|
| 709 |
+
) -> PipelineState:
|
| 710 |
+
block_state = self.get_block_state(state)
|
| 711 |
+
|
| 712 |
+
if block_state.init_latents is not None:
|
| 713 |
+
block_state.init_latents = block_state.init_latents.to(
|
| 714 |
+
components.transformer.dtype
|
| 715 |
+
)
|
| 716 |
+
self.set_block_state(state, block_state)
|
| 717 |
+
return components, state
|
| 718 |
+
|
| 719 |
+
video = (
|
| 720 |
+
components.video_processor.preprocess(
|
| 721 |
+
block_state.video, block_state.height, block_state.width
|
| 722 |
+
)
|
| 723 |
+
.unsqueeze(0)
|
| 724 |
+
.to(components.vae.device, components.vae.dtype)
|
| 725 |
+
)
|
| 726 |
+
block_state.init_latents = self.encode_frames(
|
| 727 |
+
components,
|
| 728 |
+
video,
|
| 729 |
+
block_state.timesteps,
|
| 730 |
+
block_state.generator,
|
| 731 |
+
components.vae.dtype,
|
| 732 |
+
components.vae.device,
|
| 733 |
+
block_state.init_latents,
|
| 734 |
+
)
|
| 735 |
+
block_state.init_latents = block_state.init_latents.to(
|
| 736 |
+
components.transformer.dtype
|
| 737 |
+
)
|
| 738 |
+
|
| 739 |
+
self.set_block_state(state, block_state)
|
| 740 |
+
return components, state
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
class WanRTStreamPrepareVideoLatentStep(ModularPipelineBlocks):
|
| 744 |
+
"""
|
| 745 |
+
Prepares video latents from input PIL images for video-to-video generation.
|
| 746 |
+
|
| 747 |
+
This block:
|
| 748 |
+
1. Processes input PIL images
|
| 749 |
+
2. Encodes them to latent space using the VAE encoder
|
| 750 |
+
3. Adds noise based on denoising strength for partial denoising
|
| 751 |
+
"""
|
| 752 |
+
|
| 753 |
+
model_name = "wan"
|
| 754 |
+
|
| 755 |
+
@property
|
| 756 |
+
def expected_components(self) -> List[ComponentSpec]:
|
| 757 |
+
return [
|
| 758 |
+
ComponentSpec("vae", AutoencoderKLWan),
|
| 759 |
+
]
|
| 760 |
+
|
| 761 |
+
@property
|
| 762 |
+
def description(self) -> str:
|
| 763 |
+
return (
|
| 764 |
+
"Prepares video latents from input PIL images by encoding to latent space "
|
| 765 |
+
"and optionally adding noise for video-to-video generation."
|
| 766 |
+
)
|
| 767 |
+
|
| 768 |
+
@property
|
| 769 |
+
def inputs(self) -> List[InputParam]:
|
| 770 |
+
return [
|
| 771 |
+
InputParam(
|
| 772 |
+
"video_stream",
|
| 773 |
+
type_hint=list,
|
| 774 |
+
description="List of PIL Images for input video",
|
| 775 |
+
),
|
| 776 |
+
InputParam(
|
| 777 |
+
"height",
|
| 778 |
+
type_hint=int,
|
| 779 |
+
default=480,
|
| 780 |
+
description="Target height for video processing",
|
| 781 |
),
|
| 782 |
InputParam(
|
| 783 |
+
"width",
|
| 784 |
+
type_hint=int,
|
| 785 |
+
default=832,
|
| 786 |
+
description="Target width for video processing",
|
| 787 |
+
),
|
| 788 |
+
InputParam(
|
| 789 |
+
"generator",
|
| 790 |
+
type_hint=torch.Generator,
|
| 791 |
+
description="Random generator for noise",
|
| 792 |
+
),
|
| 793 |
+
InputParam(
|
| 794 |
+
"timesteps",
|
| 795 |
+
type_hint=torch.Tensor,
|
| 796 |
+
description="All timesteps for noise scheduling",
|
| 797 |
+
),
|
| 798 |
+
InputParam(
|
| 799 |
+
"block_idx",
|
| 800 |
+
type_hint=int,
|
| 801 |
+
description="Index of current block to denoise",
|
| 802 |
+
),
|
| 803 |
+
InputParam(
|
| 804 |
+
"num_blocks",
|
| 805 |
+
type_hint=int,
|
| 806 |
+
description="Total number of blocks to denoise",
|
| 807 |
+
),
|
| 808 |
+
InputParam(
|
| 809 |
+
"input_frames_cache",
|
| 810 |
+
default=deque(maxlen=24),
|
| 811 |
+
description="Cached input video frames for context encoding",
|
| 812 |
+
),
|
| 813 |
+
]
|
| 814 |
+
|
| 815 |
+
@property
|
| 816 |
+
def intermediate_outputs(self) -> List[OutputParam]:
|
| 817 |
+
return [
|
| 818 |
+
OutputParam(
|
| 819 |
"latents",
|
| 820 |
type_hint=torch.Tensor,
|
| 821 |
+
description="Noised latents from input video ready for denoising",
|
| 822 |
),
|
| 823 |
+
OutputParam(
|
| 824 |
+
"current_start_frame",
|
| 825 |
+
type_hint=int,
|
| 826 |
+
),
|
| 827 |
+
]
|
| 828 |
+
|
| 829 |
+
def encode_frames(
|
| 830 |
+
self,
|
| 831 |
+
components,
|
| 832 |
+
video: Optional[torch.Tensor] = None,
|
| 833 |
+
dtype: Optional[torch.dtype] = None,
|
| 834 |
+
device: Optional[torch.device] = None,
|
| 835 |
+
latents: Optional[torch.Tensor] = None,
|
| 836 |
+
):
|
| 837 |
+
if latents is not None:
|
| 838 |
+
return latents.to(device, dtype)
|
| 839 |
+
|
| 840 |
+
if not hasattr(components.vae, "_enc_feat_map"):
|
| 841 |
+
components.vae.clear_cache()
|
| 842 |
+
else:
|
| 843 |
+
components.vae._enc_feat_map = [None] * 55
|
| 844 |
+
|
| 845 |
+
init_latents = [
|
| 846 |
+
retrieve_latents(
|
| 847 |
+
components.vae.encode(vid.unsqueeze(0).transpose(2, 1)),
|
| 848 |
+
sample_mode="argmax",
|
| 849 |
+
)
|
| 850 |
+
for vid in video
|
| 851 |
+
]
|
| 852 |
+
init_latents = torch.cat(init_latents, dim=0).to(dtype)
|
| 853 |
+
|
| 854 |
+
latents_mean = (
|
| 855 |
+
torch.tensor(components.vae.config.latents_mean)
|
| 856 |
+
.view(1, components.vae.config.z_dim, 1, 1, 1)
|
| 857 |
+
.to(device, dtype)
|
| 858 |
+
)
|
| 859 |
+
latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
|
| 860 |
+
1, components.vae.config.z_dim, 1, 1, 1
|
| 861 |
+
).to(device, dtype)
|
| 862 |
+
init_latents = (init_latents - latents_mean) * latents_std
|
| 863 |
+
|
| 864 |
+
return init_latents
|
| 865 |
+
|
| 866 |
+
def resample_frames(self, frames, target_length):
|
| 867 |
+
"""Resample a list to the target length using linear interpolation of indices"""
|
| 868 |
+
if len(frames) == target_length:
|
| 869 |
+
return frames
|
| 870 |
+
|
| 871 |
+
indices = (
|
| 872 |
+
torch.linspace(0, len(frames) - 1, target_length, device="cpu")
|
| 873 |
+
.round()
|
| 874 |
+
.long()
|
| 875 |
+
)
|
| 876 |
+
return [frames[i] for i in indices]
|
| 877 |
+
|
| 878 |
+
@torch.no_grad()
|
| 879 |
+
def __call__(
|
| 880 |
+
self, components: ModularPipeline, state: PipelineState
|
| 881 |
+
) -> PipelineState:
|
| 882 |
+
block_state = self.get_block_state(state)
|
| 883 |
+
|
| 884 |
+
if block_state.video_stream is None:
|
| 885 |
+
raise ValueError(
|
| 886 |
+
"Stream Video to Video requires an input video. Please provide a`video` input to the Pipeline"
|
| 887 |
+
)
|
| 888 |
+
|
| 889 |
+
block_state.input_frames_cache.extend(block_state.video_stream)
|
| 890 |
+
video = (
|
| 891 |
+
components.video_processor.preprocess(
|
| 892 |
+
list(block_state.input_frames_cache),
|
| 893 |
+
block_state.height,
|
| 894 |
+
block_state.width,
|
| 895 |
+
)
|
| 896 |
+
.unsqueeze(0)
|
| 897 |
+
.to(components.vae.device, components.vae.dtype)
|
| 898 |
+
)
|
| 899 |
+
|
| 900 |
+
block_state.current_start_frame = (
|
| 901 |
+
block_state.block_idx * components.config.num_frames_per_block
|
| 902 |
+
)
|
| 903 |
+
init_latents = self.encode_frames(
|
| 904 |
+
components,
|
| 905 |
+
video,
|
| 906 |
+
components.vae.dtype,
|
| 907 |
+
components.vae.device,
|
| 908 |
+
None,
|
| 909 |
+
)
|
| 910 |
+
init_latents = init_latents[:, :, -components.config.num_frames_per_block :]
|
| 911 |
+
|
| 912 |
+
strength = block_state.timesteps[0] / 1000.0
|
| 913 |
+
noise = randn_tensor(
|
| 914 |
+
init_latents.shape,
|
| 915 |
+
device=components.transformer.device,
|
| 916 |
+
dtype=components.transformer.dtype,
|
| 917 |
+
generator=block_state.generator,
|
| 918 |
+
)
|
| 919 |
+
|
| 920 |
+
init_latents = init_latents * (1.0 - strength) + noise * strength
|
| 921 |
+
init_latents = init_latents.to(components.transformer.dtype).contiguous()
|
| 922 |
+
|
| 923 |
+
block_state.latents = init_latents
|
| 924 |
+
|
| 925 |
+
self.set_block_state(state, block_state)
|
| 926 |
+
return components, state
|
| 927 |
+
|
| 928 |
+
|
| 929 |
+
class WanRTExtractBlockLatentsStep(ModularPipelineBlocks):
|
| 930 |
+
"""
|
| 931 |
+
Extracts a single block of latents from the full video buffer for streaming generation.
|
| 932 |
+
|
| 933 |
+
This block simply slices the final_latents buffer to get the current block's latents.
|
| 934 |
+
The final_latents buffer should be created beforehand using WanRTPrepareAllLatents.
|
| 935 |
+
"""
|
| 936 |
+
|
| 937 |
+
model_name = "wan"
|
| 938 |
+
|
| 939 |
+
@property
|
| 940 |
+
def expected_components(self) -> List[ComponentSpec]:
|
| 941 |
+
return []
|
| 942 |
+
|
| 943 |
+
@property
|
| 944 |
+
def description(self) -> str:
|
| 945 |
+
return (
|
| 946 |
+
"Extracts a single block from the full latent buffer for streaming generation. "
|
| 947 |
+
"Slices final_latents based on block_idx to get current block's latents."
|
| 948 |
+
)
|
| 949 |
+
|
| 950 |
+
@property
|
| 951 |
+
def inputs(self) -> List[InputParam]:
|
| 952 |
+
return [
|
| 953 |
InputParam(
|
| 954 |
"block_idx",
|
| 955 |
required=True,
|
|
|
|
| 957 |
default=0,
|
| 958 |
description="Current block index to process",
|
| 959 |
),
|
| 960 |
+
InputParam(
|
| 961 |
+
"init_latents",
|
| 962 |
+
required=True,
|
| 963 |
+
type_hint=torch.Tensor,
|
| 964 |
+
description="Full latent buffer [B, C, total_frames, H, W]",
|
| 965 |
+
),
|
| 966 |
InputParam(
|
| 967 |
"num_frames_per_block",
|
| 968 |
required=True,
|
|
|
|
| 993 |
) -> PipelineState:
|
| 994 |
block_state = self.get_block_state(state)
|
| 995 |
|
| 996 |
+
num_frames_per_block = components.config.num_frames_per_block
|
| 997 |
block_idx = block_state.block_idx
|
| 998 |
|
| 999 |
# Calculate frame range for current block
|
|
|
|
| 1012 |
return components, state
|
| 1013 |
|
| 1014 |
|
| 1015 |
+
class WanRTSetupKVCache(ModularPipelineBlocks):
|
| 1016 |
"""
|
| 1017 |
Initializes KV cache and cross-attention cache for streaming generation.
|
| 1018 |
|
|
|
|
| 1021 |
Should be called once at the start of streaming generation.
|
| 1022 |
"""
|
| 1023 |
|
| 1024 |
+
model_name = "wan"
|
| 1025 |
|
| 1026 |
@property
|
| 1027 |
def expected_components(self) -> List[ComponentSpec]:
|
|
|
|
| 1142 |
return components, state
|
| 1143 |
|
| 1144 |
|
| 1145 |
+
class WanRTRecomputeKVCache(ModularPipelineBlocks):
|
| 1146 |
@property
|
| 1147 |
def inputs(self) -> List[InputParam]:
|
| 1148 |
return [
|
|
|
|
| 1152 |
description="Current block latents [B, C, num_frames_per_block, H, W]",
|
| 1153 |
),
|
| 1154 |
InputParam(
|
| 1155 |
+
"num_blocks",
|
| 1156 |
type_hint=int,
|
| 1157 |
+
description="Number of blocks to denoise",
|
| 1158 |
),
|
| 1159 |
InputParam(
|
| 1160 |
"block_idx",
|
| 1161 |
type_hint=int,
|
| 1162 |
description="Current block index to process",
|
| 1163 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1164 |
InputParam(
|
| 1165 |
"current_start_frame",
|
| 1166 |
type_hint=int,
|
| 1167 |
description="Starting frame index for current block",
|
| 1168 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1169 |
InputParam(
|
| 1170 |
"prompt_embeds",
|
| 1171 |
type_hint=torch.Tensor,
|
|
|
|
| 1181 |
type_hint=torch.Tensor,
|
| 1182 |
description="Cross-attention cache",
|
| 1183 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1184 |
InputParam(
|
| 1185 |
"frame_cache_context",
|
| 1186 |
description="Cached context frames for reencoding",
|
| 1187 |
),
|
| 1188 |
InputParam(
|
| 1189 |
+
"current_denoised_latents",
|
| 1190 |
+
type_hint=torch.Tensor,
|
| 1191 |
+
description="Current denoised latents",
|
| 1192 |
),
|
| 1193 |
]
|
| 1194 |
|
|
|
|
| 1196 |
def expected_configs(self) -> List[ConfigSpec]:
|
| 1197 |
return [ConfigSpec("seq_length", 32760)]
|
| 1198 |
|
| 1199 |
+
def prepare_latents(self, components, frames):
|
|
|
|
|
|
|
| 1200 |
components.vae._enc_feat_map = [None] * 55
|
| 1201 |
latents = retrieve_latents(components.vae.encode(frames), sample_mode="argmax")
|
| 1202 |
latents_mean = (
|
|
|
|
| 1213 |
|
| 1214 |
def get_context_frames(self, components, block_state):
|
| 1215 |
current_kv_cache_num_frames = components.config.kv_cache_num_frames
|
| 1216 |
+
total_frames_generated = (
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1217 |
block_state.block_idx - 1
|
| 1218 |
+
) * components.config.num_frames_per_block
|
| 1219 |
+
|
| 1220 |
+
if total_frames_generated < current_kv_cache_num_frames:
|
| 1221 |
+
context_frames = block_state.current_denoised_latents[
|
| 1222 |
+
:, :, :current_kv_cache_num_frames
|
| 1223 |
+
]
|
| 1224 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1225 |
else:
|
| 1226 |
+
context_frames = block_state.current_denoised_latents
|
| 1227 |
context_frames = context_frames[:, :, 1:][
|
| 1228 |
:, :, -current_kv_cache_num_frames + 1 :
|
| 1229 |
]
|
| 1230 |
+
first_frame_latent = self.prepare_latents(
|
| 1231 |
+
components, frames=block_state.frame_cache_context[0].half()
|
| 1232 |
+
)
|
| 1233 |
first_frame_latent = first_frame_latent.to(block_state.latents)
|
| 1234 |
context_frames = torch.cat((first_frame_latent, context_frames), dim=2)
|
| 1235 |
|
|
|
|
| 1240 |
if block_state.block_idx == 0:
|
| 1241 |
return components, state
|
| 1242 |
|
|
|
|
|
|
|
|
|
|
| 1243 |
context_frames = self.get_context_frames(components, block_state)
|
| 1244 |
+
block_mask = components.transformer._prepare_blockwise_causal_attn_mask(
|
| 1245 |
+
components.transformer.device,
|
| 1246 |
+
num_frames=context_frames.shape[2],
|
| 1247 |
+
frame_seqlen=components.config.frame_seq_length,
|
| 1248 |
+
num_frame_per_block=components.config.num_frames_per_block,
|
| 1249 |
+
local_attn_size=-1,
|
|
|
|
|
|
|
| 1250 |
)
|
| 1251 |
+
components.transformer.block_mask = block_mask
|
| 1252 |
context_timestep = torch.zeros(
|
| 1253 |
(context_frames.shape[0], context_frames.shape[2]),
|
| 1254 |
device=components.transformer.device,
|
|
|
|
| 1261 |
kv_cache=block_state.kv_cache,
|
| 1262 |
seq_len=components.config.seq_length,
|
| 1263 |
crossattn_cache=block_state.crossattn_cache,
|
| 1264 |
+
current_start=0, # when updating the kv cache with block_mask the current_start is unused
|
| 1265 |
cache_start=None,
|
| 1266 |
)
|
| 1267 |
components.transformer.block_mask = None
|
|
|
|
| 1269 |
return components, state
|
| 1270 |
|
| 1271 |
|
| 1272 |
+
class WanRTBeforeDenoiseStep(SequentialPipelineBlocks):
|
| 1273 |
block_classes = [
|
| 1274 |
+
WanRTSetTimestepsStep,
|
| 1275 |
+
WanRTPrepareLatentsStep,
|
| 1276 |
+
WanRTExtractBlockLatentsStep,
|
| 1277 |
+
WanRTSetupKVCache,
|
| 1278 |
+
WanRTRecomputeKVCache,
|
| 1279 |
]
|
| 1280 |
block_names = [
|
| 1281 |
"set_timesteps",
|
|
|
|
| 1293 |
+ " - `WanRTInputStep` is used to adjust the batch size of the model inputs\n"
|
| 1294 |
+ " - `WanRTSetTimestepsStep` is used to set the timesteps\n"
|
| 1295 |
+ " - `WanRTPrepareLatentsStep` is used to prepare the latents\n"
|
| 1296 |
+
+ " - `WanRTPrepareVideoLatentStep` is used to prepare video latents from input video\n"
|
| 1297 |
+
)
|
| 1298 |
+
|
| 1299 |
+
|
| 1300 |
+
class WanRTVideoToVideoBeforeDenoiseStep(SequentialPipelineBlocks):
|
| 1301 |
+
block_classes = [
|
| 1302 |
+
WanRTSetTimestepsStep,
|
| 1303 |
+
WanRTPrepareVideoLatentStep,
|
| 1304 |
+
WanRTExtractBlockLatentsStep,
|
| 1305 |
+
WanRTSetupKVCache,
|
| 1306 |
+
WanRTRecomputeKVCache,
|
| 1307 |
+
]
|
| 1308 |
+
block_names = [
|
| 1309 |
+
"set_timesteps",
|
| 1310 |
+
"prepare_video_latents",
|
| 1311 |
+
"extract_block_init_latents",
|
| 1312 |
+
"setup_kv_cache",
|
| 1313 |
+
"recompute_kv_cache",
|
| 1314 |
+
]
|
| 1315 |
+
|
| 1316 |
+
@property
|
| 1317 |
+
def description(self):
|
| 1318 |
+
return (
|
| 1319 |
+
"Before denoise step that prepare the inputs for the denoise step.\n"
|
| 1320 |
+
+ "This is a sequential pipeline blocks:\n"
|
| 1321 |
+
+ " - `WanRTInputStep` is used to adjust the batch size of the model inputs\n"
|
| 1322 |
+
+ " - `WanRTSetTimestepsStep` is used to set the timesteps\n"
|
| 1323 |
+
+ " - `WanRTPrepareLatentsStep` is used to prepare the latents\n"
|
| 1324 |
+
+ " - `WanRTPrepareVideoLatentStep` is used to prepare video latents from input video\n"
|
| 1325 |
)
|
| 1326 |
+
|
| 1327 |
+
|
| 1328 |
+
class WanRTStreamVideoToVideoBeforeDenoiseStep(SequentialPipelineBlocks):
|
| 1329 |
+
block_classes = [
|
| 1330 |
+
WanRTSetTimestepsStep,
|
| 1331 |
+
WanRTStreamPrepareVideoLatentStep,
|
| 1332 |
+
WanRTSetupKVCache,
|
| 1333 |
+
WanRTRecomputeKVCache,
|
| 1334 |
+
]
|
| 1335 |
+
block_names = [
|
| 1336 |
+
"set_timesteps",
|
| 1337 |
+
"prepare_video_latents",
|
| 1338 |
+
"setup_kv_cache",
|
| 1339 |
+
"recompute_kv_cache",
|
| 1340 |
+
]
|
| 1341 |
+
|
| 1342 |
+
@property
|
| 1343 |
+
def description(self):
|
| 1344 |
+
return (
|
| 1345 |
+
"Before denoise step that prepare the inputs for the denoise step.\n"
|
| 1346 |
+
+ "This is a sequential pipeline blocks:\n"
|
| 1347 |
+
+ " - `WanRTInputStep` is used to adjust the batch size of the model inputs\n"
|
| 1348 |
+
+ " - `WanRTSetTimestepsStep` is used to set the timesteps\n"
|
| 1349 |
+
+ " - `WanRTPrepareLatentsStep` is used to prepare the latents\n"
|
| 1350 |
+
+ " - `WanRTPrepareVideoLatentStep` is used to prepare video latents from input video\n"
|
| 1351 |
+
)
|
| 1352 |
+
|
| 1353 |
+
|
| 1354 |
+
class WanRTAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
| 1355 |
+
block_classes = [
|
| 1356 |
+
WanRTVideoToVideoBeforeDenoiseStep,
|
| 1357 |
+
WanRTStreamVideoToVideoBeforeDenoiseStep,
|
| 1358 |
+
WanRTBeforeDenoiseStep,
|
| 1359 |
+
]
|
| 1360 |
+
block_names = ["video-to-video", "stream-to-video", "text-to-video"]
|
| 1361 |
+
block_trigger_inputs = ["video", "video_stream", None]
|
decoders.py
CHANGED
`WanRTDecodeStep` now sets `model_name = "wan"`, imports `deque`, declares `output_type` and `block_idx` as documented inputs (the old `block_idx` entry, which carried a copy-pasted description, is replaced), documents `frame_cache_context` and `decoder_cache`, and lazily initializes the frame cache inside `__call__`:

```py
from typing import Any, List, Tuple, Union
from collections import deque

import numpy as np
import PIL


class WanRTDecodeStep(ModularPipelineBlocks):
    model_name = "wan"
    decoder_cache = []

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam(
                "output_type",
                default="pil",
                description="The output format for the generated videos (pil, latent, pt, or np)",
            ),
            InputParam(
                "block_idx",
                description="Index of the current block being decoded",
            ),
            InputParam(
                "latents",
                required=True,
                # ...
            ),
            InputParam(
                "frame_cache_context",
                description="Deque object to store most recently decoded frames",
                type_hint=deque,
            ),
            InputParam(
                "decoder_cache",
                description="Decoder feature cache",
            ),
        ]
```

The updated `__call__` creates the frame-cache deque on first use and keeps the existing decode/postprocess path:

```py
        block_state = self.get_block_state(state)
        vae_dtype = components.vae.dtype

        if block_state.frame_cache_context is None:
            frame_cache_len = 1 + (components.config.kv_cache_num_frames - 1) * 4
            block_state.frame_cache_context = deque(maxlen=frame_cache_len)

        # Disable clearing cache
        if block_state.block_idx == 0:
            components.vae.clear_cache()
        # ...
        block_state.decoder_cache = components.vae._feat_map
        block_state.frame_cache_context.extend(videos.split(1, dim=2))
        videos = components.video_processor.postprocess_video(
            videos, output_type=block_state.output_type
        )
        block_state.videos = videos
        self.set_block_state(state, block_state)

        return components, state
```
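The `1 + (n - 1) * 4` sizing reflects the Wan VAE's temporal compression: the first latent frame decodes to a single video frame and each subsequent latent frame to four, so the deque holds exactly the decoded frames that correspond to the `kv_cache_num_frames` latent frames kept as context. A small sketch of the sizing and the deque's eviction behaviour, using the value 3 shipped in `modular_model_index.json`:

```py
from collections import deque

kv_cache_num_frames = 3  # value from modular_model_index.json
frame_cache_len = 1 + (kv_cache_num_frames - 1) * 4
print(frame_cache_len)  # 9 decoded video frames

# deque(maxlen=...) silently drops the oldest entries once full,
# so the cache always holds the most recent context window.
frame_cache = deque(maxlen=frame_cache_len)
for frame_idx in range(12):
    frame_cache.append(f"frame_{frame_idx}")
print(frame_cache[0], frame_cache[-1])  # frame_3 frame_11
```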
denoise.py
CHANGED
The streaming denoise blocks are renamed: `WanRTStreamingLoopDenoiser` → `WanRTLoopDenoiser`, `WanRTStreamingLoopAfterDenoiser` → `WanRTLoopAfterDenoiser`, `WanRTStreamingDenoiseLoopWrapper` → `WanRTDenoiseLoopWrapper`, `WanRTStreamingDenoiseStep` → `WanRTDenoiseStep`, all with `model_name = "wan"`. The unused `FrozenDict` and `ClassifierFreeGuidance` imports are removed, as are the `attention_kwargs`, `block_idx`, `num_frames_per_block`, per-denoiser `num_inference_steps`, and `guider_input_fields` input parameters; the remaining inputs (`prompt_embeds`, `kv_cache`, `crossattn_cache`, `current_start_frame`) gain descriptions.

`WanRTLoopDenoiser` now reads the per-block frame count from the component config and passes an explicit `cache_start`:

```py
        block_state.noise_pred = components.transformer(
            x=block_state.latents,
            t=t.expand(
                block_state.latents.shape[0], components.config.num_frames_per_block
            ),
            context=block_state.prompt_embeds,
            kv_cache=block_state.kv_cache,
            seq_len=components.config.seq_length,
            crossattn_cache=block_state.crossattn_cache,
            current_start=start_frame * components.config.frame_seq_length,
            cache_start=None,
        )
        return components, block_state
```

`WanRTLoopAfterDenoiser` declares its inputs directly (the old empty `inputs` plus separate `intermediate_inputs` property is gone) and keeps the double-precision latent update:

```py
    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam("latents", description="Current latents being denoised"),
            InputParam("all_timesteps", description="All timesteps for the denoising process"),
            InputParam("sigmas", description="Noise schedule sigmas for each timestep"),
        ]

    # ... inside __call__:
        latents = (
            block_state.latents.double()
            - sigma_t.double() * block_state.noise_pred.double()
        ).to(latents_dtype)
        block_state.latents = latents
        return components, block_state
```

`WanRTDenoiseLoopWrapper.add_noise` drops its extra `components` argument, the wrapper's inputs gain `current_denoised_latents` and a typed `generator` while `num_frames_per_block` and `block_idx` disappear, and the loop records the block result in `current_denoised_latents` instead of slicing it into `final_latents`:

```py
    def add_noise(self, block_state, sample, noise, timestep):
        timesteps = block_state.all_timesteps
        sigmas = block_state.sigmas.to(timesteps.device)
        # ...

    # ... later, inside the @torch.no_grad() __call__ loop:
        block_state.latents = (
            self.add_noise(
                block_state,
                block_state.latents.transpose(1, 2).squeeze(0),
                randn_tensor(
                    # ...
                ),
                t1.expand(
                    block_state.latents.shape[0],
                    components.config.num_frames_per_block,
                ),
            )
            .unsqueeze(0)
            .transpose(1, 2)
        )

        block_state.current_denoised_latents = block_state.latents
        self.set_block_state(state, block_state)
        return components, state
```

The renamed entry point wires the two loop blocks together:

```py
class WanRTDenoiseStep(WanRTDenoiseLoopWrapper):
    block_classes = [
        WanRTLoopDenoiser,
        WanRTLoopAfterDenoiser,
    ]
    block_names = ["denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoise the latents. \n"
            "Its loop logic is defined in `WanRTDenoiseLoopWrapper.__call__` method \n"
            "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
            " - `WanRTLoopDenoiser`\n"
            " - `WanRTLoopAfterDenoiser`\n"
            "This block supports both text2vid tasks."
        )
```
encoders.py
CHANGED
The text-encoder block is renamed from `WanRTStreamingTextEncoderStep` to `WanRTTextEncoderStep` (its `model_name` stays `"WanRTStreaming"`). `prompt`, `negative_prompt`, and `attention_kwargs` are now declared as documented inputs, the T5 helpers are called through the new class name, and the prompt embeddings are made contiguous before being written back to the state:

```py
class WanRTTextEncoderStep(ModularPipelineBlocks):
    model_name = "WanRTStreaming"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "prompt",
                description="The prompt or prompts to guide the video generation",
            ),
            InputParam(
                "negative_prompt",
                description="The prompt or prompts not to guide the video generation",
            ),
            InputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                # ...
            ),
            InputParam(
                "negative_prompt_embeds",
                type_hint=torch.Tensor,
                description="negative text embeddings used to guide the image generation",
            ),
            InputParam(
                "attention_kwargs",
                description="Additional keyword arguments to pass to the attention mechanism",
            ),
        ]
```

```py
        batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = WanRTTextEncoderStep._get_t5_prompt_embeds(
                components, prompt, max_sequence_length, device
            )
        # ...
        negative_prompt_embeds = WanRTTextEncoderStep._get_t5_prompt_embeds(
            components, negative_prompt, max_sequence_length, device
        )

        bs_embed, seq_len, _ = prompt_embeds.shape
```

```py
        (
            block_state.prompt_embeds,
            block_state.negative_prompt_embeds,
        ) = WanRTTextEncoderStep.encode_prompt(
            components,
            block_state.prompt,
            block_state.device,
            # ...
            prompt_embeds=block_state.prompt_embeds,
            negative_prompt_embeds=block_state.negative_prompt_embeds,
        )
        block_state.prompt_embeds = block_state.prompt_embeds.contiguous()

        # Add outputs
        self.set_block_state(state, block_state)
```
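Making the embeddings contiguous guarantees a dense memory layout after the batch expansion above, which avoids stride-related surprises when the tensor is later consumed by compiled or custom attention kernels. A tiny illustration of the property itself (sizes are illustrative):

```py
import torch

prompt_embeds = torch.randn(1, 512, 4096)   # (batch, seq_len, dim), illustrative sizes
expanded = prompt_embeds.expand(2, -1, -1)  # broadcasted view, not densely laid out
print(expanded.is_contiguous())             # False
print(expanded.contiguous().is_contiguous())  # True
```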
modular_blocks.py
CHANGED
The per-task `TEXT2VIDEO_BLOCKS` / `ALL_BLOCKS` registry is replaced by a single `AUTO_BLOCKS` dict wired to the new block classes:

```py
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.modular_pipeline_utils import InsertableDict

from .before_denoise import WanRTAutoBeforeDenoiseStep
from .decoders import WanRTDecodeStep
from .encoders import WanRTTextEncoderStep
from .denoise import WanRTDenoiseStep

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


AUTO_BLOCKS = InsertableDict(
    [
        ("text_encoder", WanRTTextEncoderStep),
        ("before_denoise", WanRTAutoBeforeDenoiseStep),
        ("denoise", WanRTDenoiseStep),
        ("decode", WanRTDecodeStep),
    ]
)


class WanRTBlocks(SequentialPipelineBlocks):
    block_classes = list(AUTO_BLOCKS.copy().values())
    block_names = list(AUTO_BLOCKS.copy().keys())
```
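A quick way to see the resulting execution order is to zip the two class attributes; this is plain Python over the definitions above, not a diffusers API call:

```py
# Prints the four stages in the order WanRTBlocks will run them.
for name, cls in zip(WanRTBlocks.block_names, WanRTBlocks.block_classes):
    print(f"{name:15s} -> {cls.__name__}")
# text_encoder    -> WanRTTextEncoderStep
# before_denoise  -> WanRTAutoBeforeDenoiseStep
# denoise         -> WanRTDenoiseStep
# decode          -> WanRTDecodeStep
```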
modular_config.json
CHANGED
`auto_map` now points at the renamed blocks class:

```json
{
  "_class_name": "WanRTBlocks",
  "_diffusers_version": "0.36.0.dev0",
  "auto_map": {
    "ModularPipelineBlocks": "modular_blocks.WanRTBlocks"
  }
}
```
modular_model_index.json
CHANGED
The pipeline metadata is updated to the new class names, and the transformer component now resolves to this repository (excerpt):

```json
{
  "_blocks_class_name": "WanRTBlocks",
  "_class_name": "WanModularPipeline",
  "_diffusers_version": "0.36.0.dev0",
  "frame_seq_length": 1560,
  "kv_cache_num_frames": 3,
  ...
      "repo": "krea/krea-realtime-video",
      "revision": null,
      "subfolder": "transformer",
      "type_hint": [
  ...
}
```
transformer/attention.py
CHANGED
The lazy `_init_sageattention` / `_init_flash_attention_*` helpers and their compatibility getters are removed. Backend availability is now probed once at import time into module-level flags (`SAGEATTN_AVAILABLE`, `FLASH_ATTN_2_AVAILABLE`, `FLASH_ATTN_3_AVAILABLE`, `FLASH_ATTN_3_HUB_AVAILABLE`), the SageAttention custom op gains a fake registration so it can be traced/compiled, and FlashAttention 3 can optionally be fetched from the Hub via the `kernels` package:

```py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import os
import warnings
from typing import Optional
from diffusers.utils import is_kernels_available

SAGEATTN_AVAILABLE = False
try:
    if os.getenv("DISABLE_SAGEATTENTION", "0") != "0":
        raise Exception("DISABLE_SAGEATTENTION is set")

    from sageattention import sageattn

    @torch.library.custom_op(
        "mylib::sageattn", mutates_args={"q", "k", "v"}, device_types="cuda"
    )
    def sageattn_func(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0,
        is_causal: bool = False,
    ) -> torch.Tensor:
        return sageattn(
            q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
        )

    @sageattn_func.register_fake
    def _sageattn_fake(q, k, v, attn_mask=None, dropout_p=0, is_causal=False):
        return torch.empty(*q.shape, device=q.device, dtype=q.dtype)

    SAGEATTN_AVAILABLE = True

except Exception as e:
    sageattn_func = None


def _is_hopper_gpu():
    # ...
    device_name = torch.cuda.get_device_name(0).lower()
    return "h100" in device_name or "hopper" in device_name


FLASH_ATTN_3_AVAILABLE = False
try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = _is_hopper_gpu()
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

FLASH_ATTN_3_HUB_AVAILABLE = False
try:
    use_hub_kernels = os.getenv("DIFFUSERS_ENABLE_HUB_KERNELS", "false").upper() in ["1", "TRUE"]
    if use_hub_kernels and not is_kernels_available():
        raise EnvironmentError((
            "Attempting to use Hub Kernels for Flash Attention 3,"
            "but the `kernels` library was not found in your environment. "
            "Please install via `pip install kernels`"
        ))

    from kernels import get_kernel
    flash_attn_3_hub = get_kernel("kernels-community/flash-attn3", revision="fake-ops-return-probs")

    FLASH_ATTN_3_HUB_AVAILABLE = _is_hopper_gpu()

except:
    FLASH_ATTN_3_HUB_AVAILABLE = False

FLASH_ATTN_2_AVAILABLE = False
try:
    import flash_attn

    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

__all__ = ["flash_attention", "attention"]
```

`flash_attention` and `attention` now branch on those flags instead of calling the init helpers (excerpts):

```py
def flash_attention(q, k, v, ...):
    """
    ...
    deterministic: bool. If True, slightly slower and uses more memory.
    dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
    """
    if not FLASH_ATTN_3_AVAILABLE or not FLASH_ATTN_3_HUB_AVAILABLE:
        return flash_attn.flash_attn_func(
            q,
            k,
            v,
        )

    elif FLASH_ATTN_3_HUB_AVAILABLE:
        return flash_attn_3_hub.flash_attn_func(
            q,
            k,
            v,
            # ...
        )
    # ...
    if q_scale is not None:
        q = q * q_scale

    if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
        warnings.warn(
            "Flash attention 3 is not available, use flash attention 2 instead."
        )

    # apply attention
    if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
        # Note: dropout_p, window_size are not supported in FA3 now.
        x = flash_attn_interface.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            # ...
            deterministic=deterministic,
        ).unflatten(0, (b, lq))
    else:
        assert FLASH_ATTN_3_AVAILABLE
        x = flash_attn.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            # ...
        )


def attention(
    # ...
    fa_version=None,
    # og_dtype=torch.bfloat16,
):
    if SAGEATTN_AVAILABLE:
        # print("Using sageattention")
        attn_mask = None
        # ...
        k = k.transpose(1, 2).to(dtype)
        v = v.transpose(1, 2).to(dtype)

        out = sageattn_func(
            q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p
        )

        out = out.transpose(1, 2).contiguous().to(og_dtype)
        return out

    elif FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
        return flash_attention(
            q=q,
            k=k,
            # ...
        )
```
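Because the flags are resolved once at import time, backend selection can be steered with the environment variables the module reads. A sketch of doing that from Python before the module is imported (the import path is illustrative; adjust it to how the transformer package is laid out in your checkout):

```py
import os

# Set both flags before transformer/attention.py is imported.
os.environ["DISABLE_SAGEATTENTION"] = "1"            # any value other than "0" skips SageAttention
os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "true"  # "1"/"true" opts into kernels-community/flash-attn3

from transformer.attention import (  # illustrative import path
    SAGEATTN_AVAILABLE,
    FLASH_ATTN_2_AVAILABLE,
    FLASH_ATTN_3_AVAILABLE,
    FLASH_ATTN_3_HUB_AVAILABLE,
)

print(SAGEATTN_AVAILABLE, FLASH_ATTN_2_AVAILABLE, FLASH_ATTN_3_AVAILABLE, FLASH_ATTN_3_HUB_AVAILABLE)
```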
transformer/causal_model.py
CHANGED
`CausalWanModel` now also inherits `PeftAdapterMixin`, so PEFT/LoRA adapters can be attached to the causal transformer:

```py
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.loaders import PeftAdapterMixin

flex_attention = torch.compile(
    flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
)

# ...

class CausalWanModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
    r"""
    Wan diffusion backbone supporting both text-to-video and image-to-video.
    """
```
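With the mixin in place, a PEFT LoRA config can be injected into the model. A minimal sketch, assuming an already-loaded `CausalWanModel` instance named `transformer`; the rank and `target_modules` names below are hypothetical and must be matched to the model's actual linear-layer names:

```py
from peft import LoraConfig

# Hypothetical adapter settings for illustration only.
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["q", "k", "v", "o"],
)

# add_adapter / enable_adapters / disable_adapters come from PeftAdapterMixin.
transformer.add_adapter(lora_config)
transformer.disable_adapters()  # temporarily fall back to the base weights
transformer.enable_adapters()
```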
transformer/model.py
CHANGED
`WanModel` imports the new public availability flags (instead of the old private `_SAGEATTN_AVAILABLE` names), drops the import-time `print`, and branches on the flags in both self-attention and cross-attention:

```py
from .attention import (
    flash_attention,
    sageattn_func,
    SAGEATTN_AVAILABLE,
    FLASH_ATTN_2_AVAILABLE,
    FLASH_ATTN_3_AVAILABLE,
)

__all__ = ["WanModel"]
```

In `WanSelfAttention`:

```py
        q, k, v = qkv_fn(x)

        if SAGEATTN_AVAILABLE:
            # print("Using sageattention in crossattn")
            og_dtype = q.dtype
            q = q.transpose(1, 2).to(dtype)
```

In `WanT2VCrossAttention`:

```py
        v = self.v(context).view(b, -1, n, d)

        # compute attention
        if SAGEATTN_AVAILABLE:
            # print("Using sageattention in crossattn")
            dtype = torch.bfloat16
            og_dtype = q.dtype
            # ...
            x = x.transpose(1, 2).contiguous().to(og_dtype)
        elif FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
            x = flash_attention(q, k, v, k_lens=context_lens)
        else:
            dtype = torch.bfloat16
            # ...
```
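The resulting precedence is straightforward: SageAttention if it imported, otherwise FlashAttention 2/3, otherwise the module's plain-PyTorch fallback path. A small standalone sketch of that ordering (flag names mirror the module above; the values and the "fallback" label are illustrative):

```py
def pick_attention_backend(sage: bool, fa2: bool, fa3: bool) -> str:
    # Mirrors the if/elif/else chain in WanT2VCrossAttention.forward.
    if sage:
        return "sageattention"
    if fa2 or fa3:
        return "flash_attention"
    return "fallback"


print(pick_attention_backend(sage=True, fa2=True, fa3=False))    # sageattention
print(pick_attention_backend(sage=False, fa2=False, fa3=True))   # flash_attention
print(pick_attention_backend(sage=False, fa2=False, fa3=False))  # fallback
```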