Flash attention optimization for significant speedup. - old title: Optimization tips to maximize generation speed?
EDIT: For a big boost in generation speed:
- update KJ nodes
- install flash-attn if you don't already have it ( https://mjunya.com/flash-attention-prebuild-wheels/ )
- use "Patch Flash Attention KJ"
I am seeing about a 50% speedup on a 5090.
Hiya,
The model is quite demanding: even on a 5090, generating a ~2 Mpix image takes over a minute.
I've tried the KJ Sage attention patcher and torch.compile, and neither seems to work or do anything. Could be user error.
EasyCache works but decreases quality by smoothing out the fine details, which are what this model really shines at.
Any good tips to make it run faster?
Thanks for supporting this model, it's awesome!
Oh, I just noticed the new Patch Flash Attention KJ node and gave it a try.
Down to ~1.7s/it from ~2.6s/it on a 2048x1088 image, using fp8. Holy moly!
Yeah, I found that flash attention actually gives a considerable speed boost over SDPA:
torch 2.12.0+cu132 | NVIDIA GeForce RTX 4090 | flash_attn: True
----------------------------------------------------------------------------------------------------
B=1 H=18 L= 6385 D=256 | mem_eff= 9.007ms math= 46.183ms flash= 5.221ms | mem_eff/flash= 1.73x
B=1 H=18 L= 6385 D=128 | mem_eff= 3.852ms math= 34.916ms flash= 2.626ms | mem_eff/flash= 1.47x
B=1 H=18 L= 6385 D= 64 | mem_eff= 1.564ms math= 33.437ms flash= 1.356ms | mem_eff/flash= 1.15x
----------------------------------------------------------------------------------------------------
B=1 H=24 L= 4096 D=128 | mem_eff= 2.015ms math= 18.854ms flash= 1.280ms | mem_eff/flash= 1.57x
B=1 H=24 L= 4096 D= 64 | mem_eff= 0.829ms math= 17.414ms flash= 0.662ms | mem_eff/flash= 1.25x
Lots of easy install wheels available here: https://mjunya.com/flash-attention-prebuild-wheels/
SageAttention can't work because the model uses head dim 256; sage only supports up to 128.
Another small boost and peak VRAM reduction comes from chunking the FFN and running RoPE in bf16 (will have to see if this can be default), available to test via this node:
You can also pretty safely run the uncond model as nvfp4.
EasyCache has a pretty bad quality hit, but if you update ComfyUI to the latest nightly, you can try using the CFG Override to drop cfg to 1.0, so that those steps are twice as fast. That didn't seem to hit the quality as much.
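On the FFN chunking mentioned above: the reason it can lower peak VRAM without changing the output is that a feed-forward block acts on each token independently, so the token dimension can be processed in slices and the wide hidden activation only ever exists for one slice at a time. A toy pure-Python sketch of that equivalence follows. It is not the node's actual implementation (which isn't shown in this thread); `ffn`, `chunked_ffn`, the ReLU activation and the tiny weights are all made up for illustration.

```python
def ffn(tokens, w1, w2):
    """Toy position-wise FFN: ReLU(x @ w1) @ w2, applied to each token independently.

    tokens: list of token vectors (length d)
    w1: d x h matrix as a list of rows, w2: h x d matrix as a list of rows
    """
    out = []
    for x in tokens:
        # Hidden activation for this token (length h); this is the wide tensor
        # that dominates peak memory in a real batched implementation.
        hidden = [max(0.0, sum(xi * w for xi, w in zip(x, col))) for col in zip(*w1)]
        out.append([sum(hi * w for hi, w in zip(hidden, col)) for col in zip(*w2)])
    return out


def chunked_ffn(tokens, w1, w2, chunk_size):
    """Run the same FFN over slices of the token dimension and concatenate."""
    out = []
    for start in range(0, len(tokens), chunk_size):
        # In a tensor implementation only chunk_size x h hidden values are
        # alive at once here, instead of len(tokens) x h.
        out.extend(ffn(tokens[start:start + chunk_size], w1, w2))
    return out
```

In a real tensor implementation the saving comes from the hidden activation being `chunk_size x h` instead of `L x h`; the toy above only demonstrates that the result is identical either way.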
CFG Override to drop cfg to 1.0
So I don't have to load the unconditional model?
You should still use the unconditional model in most cases. The CFG Override comment means that if you drop cfg to 1.0 instead of 3.0, the steps it's active at (the last few) are done at cfg 1.0, which halves the inference time for those steps. It trades some sharpness for faster inference.
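To make the "halves the inference time for those steps" part concrete, here is a minimal sketch assuming the standard classifier-free guidance formula, `uncond + scale * (cond - uncond)`. At scale 1.0 the unconditional term cancels, so a sampler can skip that forward pass entirely. How the CFG Override is actually implemented in ComfyUI isn't shown in this thread; `cfg_combine` and `forward_passes` are hypothetical helper names, and the cfg 3.0 default is taken from the reply above.

```python
def cfg_combine(cond, uncond, scale):
    """Standard classifier-free guidance: uncond + scale * (cond - uncond)."""
    return [u + scale * (c - u) for c, u in zip(cond, uncond)]


def forward_passes(total_steps, override_steps, base_cfg=3.0, override_cfg=1.0):
    """Count model evaluations for a run where the last `override_steps` steps
    use `override_cfg`. Assumes 2 passes (cond + uncond) per step when cfg != 1
    and 1 pass (cond only) when cfg == 1."""
    def passes_per_step(cfg):
        return 1 if cfg == 1.0 else 2

    normal_steps = total_steps - override_steps
    return (normal_steps * passes_per_step(base_cfg)
            + override_steps * passes_per_step(override_cfg))


# At scale 1.0 the result is just the conditional prediction,
# so the unconditional pass contributes nothing.
print(cfg_combine([1.0, 2.0], [0.5, 0.0], 1.0))
print(forward_passes(20, 0), forward_passes(20, 5))
```

As a made-up example, a 20-step run with the override active on the last 5 steps needs 35 forward passes rather than 40, which is why the overall gain is modest unless the override covers many steps.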
10-15% speed boost with fa on a 4060 Ti 16GB, torch 2.11.0+cu130.
Got torch.compile to work with TorchCompileModelAdvanced (KJNodes):
- Updated to the latest KJNodes (as I noticed a torch.compile specific update https://github.com/kijai/ComfyUI-KJNodes/commit/fadde42973faa83b50cb73c8cd7d584d3744febb)
- Set "dynamic" to false.
2048x1088, fp8:
torch.compile disabled, ~1.7s/it
torch.compile enabled, ~1.6s/it
So that's a pretty slim ~6% increase, but not nothing. The compile is very quick so at least for now, I can't see any downsides to using this.
Odd, I see 0% speed increase on Blackwell. Only the compile node helps slightly. CFG override set to 1 doesn't seem to do much of anything either.
Nvfp4 unconditional does help, but it seems to affect lighting pretty drastically.
Flash attention is only expected to give a boost on Windows. If you're on Linux, the PyTorch attention already has a flash attention kernel that's automatically selected and just as fast.
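For anyone who wants to check whether their own PyTorch build already has a usable flash kernel, here is a rough sketch that times PyTorch's built-in SDPA backends one at a time. It is not the script that produced the table earlier in the thread (that one isn't shown, and it measured the separate flash_attn package); `bench_sdpa` and `ratio` are made-up names, and the default shape is just the L=6385, H=18, D=256 row from that table. It assumes a CUDA GPU and a PyTorch recent enough to provide `torch.nn.attention.sdpa_kernel`.

```python
import time


def ratio(slow_ms, fast_ms):
    """Speedup of the faster kernel over the slower one, e.g. mem_eff / flash."""
    return slow_ms / fast_ms


def bench_sdpa(L=6385, H=18, D=256, iters=20):
    """Return {backend name: ms per call, or None if that backend is unavailable}."""
    # Imported lazily so the small helper above can be used without PyTorch.
    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    q, k, v = (torch.randn(1, H, L, D, device="cuda", dtype=torch.bfloat16)
               for _ in range(3))
    backends = {
        "flash": SDPBackend.FLASH_ATTENTION,
        "mem_eff": SDPBackend.EFFICIENT_ATTENTION,
        "math": SDPBackend.MATH,
    }
    results = {}
    for name, backend in backends.items():
        try:
            with sdpa_kernel(backend):  # restrict SDPA to this one backend
                for _ in range(3):  # warm-up
                    F.scaled_dot_product_attention(q, k, v)
                torch.cuda.synchronize()
                start = time.perf_counter()
                for _ in range(iters):
                    F.scaled_dot_product_attention(q, k, v)
                torch.cuda.synchronize()
            results[name] = (time.perf_counter() - start) * 1000 / iters
        except RuntimeError:
            # Raised when the build has no kernel for this backend and shape.
            results[name] = None
    return results
```

Calling `bench_sdpa()` on the target machine and seeing `None` for `"flash"` would suggest the build has no built-in flash kernel for that shape, which is the situation where installing flash-attn is expected to help.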
Maybe unrelated, but do you know why ComfyUI does not support fa3? It's faster than fa2 for me in vLLM with Hopper cards.
OK, thanks for the response, that's exactly what I was thinking. PyTorch's built-in flex attention should have had the same effect.
Is there a possibility that you will ever be able to create a "patch" to make sage attention work with it?
Yes, I've already done tests with that for our upcoming comfy-kitchen sageattention implementation, and it is a lot faster at higher resolution.
woct0rdho is also working on that in his sageattn branch. I haven't tested it myself yet, but it should be functional if you compile it yourself; there are no pre-built wheels yet from what I can see:
https://github.com/woct0rdho/SageAttention/commit/364d0534a4fbfe089347883648ee8dc501d0d7c8
This would need no node changes; sageattention itself just needs to be able to handle head dim 256.
Legend! Can't wait.
I compiled woct0rdho's SageAttention repo. Some results on an RTX 5090, Win 11, Python 3.13, torch 2.11, CUDA 13.0:
1024x1024, 28-steps, res_2m
bl: 1.07s/it (or 0.93it/s)
fa: 1.23it/s
sa: 1.39it/s
2048x2048, 20-steps, res_2m
bl: 20/20 [02:35<00:00, 7.76s/it]
fa: 20/20 [01:29<00:00, 4.45s/it]
sa: 20/20 [01:04<00:00, 3.22s/it]
bl=baseline, fa=flash-attention, sa=sageattention (using KJ patcher at "auto" setting)
SageAttention provides a 1.38x speedup over flash-attention at 4Mpix resolution and 2.41x over the attention that we Windows users have by default.
The gains are substantially smaller at lower resolutions, as seen in the 1Mpix results: "only" 1.49x over baseline.
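Since the numbers above mix s/it and it/s, here is a small sketch of the arithmetic behind those ratios. The timings are copied from the results in this post; `to_s_per_it` and `speedup` are just illustrative helper names.

```python
def to_s_per_it(value, unit):
    """Normalise a reading to seconds per iteration."""
    if unit == "s/it":
        return value
    if unit == "it/s":
        return 1.0 / value
    raise ValueError(f"unknown unit: {unit}")


# Readings as reported above (bl=baseline, fa=flash-attention, sa=sageattention).
results = {
    "1024x1024": {"bl": (1.07, "s/it"), "fa": (1.23, "it/s"), "sa": (1.39, "it/s")},
    "2048x2048": {"bl": (7.76, "s/it"), "fa": (4.45, "s/it"), "sa": (3.22, "s/it")},
}


def speedup(resolution, slow, fast):
    """How many times faster `fast` is than `slow` at the given resolution."""
    readings = results[resolution]
    return to_s_per_it(*readings[slow]) / to_s_per_it(*readings[fast])


for res in results:
    print(res,
          "sa vs fa:", round(speedup(res, "fa", "sa"), 2),
          "sa vs bl:", round(speedup(res, "bl", "sa"), 2))
```

The same helper can be reused for other readings in the thread by adding them to `results` with the right unit.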

