linoyts HF Staff commited on
Commit
c6f7d14
·
verified ·
1 Parent(s): cda033e

Create optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +41 -0
optimization.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+
4
+ from typing import Any
5
+ from typing import Callable
6
+ from typing import ParamSpec
7
+ import spaces
8
+ import torch
9
+
10
+
11
+ P = ParamSpec('P')
12
+
13
+
14
+ INDUCTOR_CONFIGS = {
15
+ 'conv_1x1_as_mm': True,
16
+ 'epilogue_fusion': False,
17
+ 'coordinate_descent_tuning': True,
18
+ 'coordinate_descent_check_all_directions': True,
19
+ 'max_autotune': True,
20
+ 'triton.cudagraphs': True,
21
+ }
22
+
23
+
24
+ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
25
+
26
+ @spaces.GPU(duration=1500)
27
+ def compile_transformer():
28
+
29
+ with spaces.aoti_capture(pipeline.transformer) as call:
30
+ pipeline(*args, **kwargs)
31
+
32
+ exported = torch.export.export(
33
+ mod=pipeline.transformer,
34
+ args=call.args,
35
+ kwargs=call.kwargs,
36
+ )
37
+
38
+ return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
39
+
40
+ pipeline.transformer.fuse_qkv_projections()
41
+ spaces.aoti_apply(compile_transformer(), pipeline.transformer)