lulavc commited on
Commit
d0bf03c
·
verified ·
1 Parent(s): a1830c8

Add custom AoTI helper for pre-compiled blocks

Browse files
Files changed (1) hide show
  1. aoti.py +33 -0
aoti.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AoTI (Ahead of Time Inductor) helper for ZeroGPU
3
+ Loads pre-compiled transformer blocks from Hugging Face Hub
4
+ """
5
+
6
+ from typing import cast
7
+
8
+ import torch
9
+ from huggingface_hub import hf_hub_download
10
+ from spaces.zero.torch.aoti import ZeroGPUCompiledModel
11
+ from spaces.zero.torch.aoti import ZeroGPUWeights
12
+
13
+
14
+ def aoti_blocks_load(module: torch.nn.Module, repo_id: str, variant: str | None = None):
15
+ """
16
+ Load pre-compiled AoTI blocks from Hub repository.
17
+
18
+ Args:
19
+ module: The transformer module containing layers to replace
20
+ repo_id: HuggingFace repo with pre-compiled blocks (e.g., 'zerogpu-aoti/Z-Image')
21
+ variant: Optional variant like 'fa3' for FlashAttention-3 compiled blocks
22
+ """
23
+ repeated_blocks = cast(list[str], module._repeated_blocks)
24
+ aoti_files = {name: hf_hub_download(
25
+ repo_id=repo_id,
26
+ filename='package.pt2',
27
+ subfolder=name if variant is None else f'{name}.{variant}',
28
+ ) for name in repeated_blocks}
29
+ for block_name, aoti_file in aoti_files.items():
30
+ for block in module.modules():
31
+ if block.__class__.__name__ == block_name:
32
+ weights = ZeroGPUWeights(block.state_dict())
33
+ block.forward = ZeroGPUCompiledModel(aoti_file, weights)