dn6 (HF Staff) committed
Commit 8c2cdba · verified · 1 Parent(s): af08689

Upload transformer/attention.py with huggingface_hub

Files changed (1):
  transformer/attention.py (+106 -25)
transformer/attention.py CHANGED
@@ -36,6 +36,52 @@ try:
 except Exception as e:
     sageattn_func = None
 
+use_hub_kernels = os.getenv("DIFFUSERS_ENABLE_HUB_KERNELS", "false").upper() in [
+    "1",
+    "TRUE",
+]
+
+SAGEATTN_HUB_AVAILABLE = False
+try:
+    if use_hub_kernels and not is_kernels_available():
+        raise EnvironmentError(
+            (
+                "Attempting to use Hub Kernels for Flash Attention 3,"
+                "but the `kernels` library was not found in your environment. "
+                "Please install via `pip install kernels`"
+            )
+        )
+    if os.getenv("DISABLE_SAGEATTENTION", "0") != "0":
+        raise Exception("DISABLE_SAGEATTENTION is set")
+
+    from kernels import get_kernel
+
+    sageattn_hub = get_kernel("kernels-community/sage_attention")
+
+    @torch.library.custom_op(
+        "mylib::sageattn_hub", mutates_args={"q", "k", "v"}, device_types="cuda"
+    )
+    def sageattn_hub_func(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        dropout_p: float = 0,
+        is_causal: bool = False,
+    ) -> torch.Tensor:
+        return sageattn_hub(
+            q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
+        )
+
+    @sageattn_func.register_fake
+    def _sageattn_fake(q, k, v, attn_mask=None, dropout_p=0, is_causal=False):
+        return torch.empty(*q.shape, device=q.device, dtype=q.dtype)
+
+    SAGEATTN_HUB_AVAILABLE = True
+
+except Exception as e:
+    sageattn_hub_func = None
+
 
 def _is_hopper_gpu():
     """Check if the current GPU is a Hopper architecture."""
@@ -44,41 +90,66 @@ def _is_hopper_gpu():
     device_name = torch.cuda.get_device_name(0).lower()
     return "h100" in device_name or "hopper" in device_name
 
+
 FLASH_ATTN_3_AVAILABLE = False
 try:
     import flash_attn_interface
+
     FLASH_ATTN_3_AVAILABLE = _is_hopper_gpu()
-except ModuleNotFoundError:
-    FLASH_ATTN_3_AVAILABLE = False
+except:
+    flash_attn_interface = None
 
 FLASH_ATTN_3_HUB_AVAILABLE = False
 try:
-    use_hub_kernels = os.getenv("DIFFUSERS_ENABLE_HUB_KERNELS", "false").upper() in ["1", "TRUE"]
     if use_hub_kernels and not is_kernels_available():
-        raise EnvironmentError((
-            "Attempting to use Hub Kernels for Flash Attention 3,"
-            "but the `kernels` library was not found in your environment. "
-            "Please install via `pip install kernels`"
-        ))
+        raise EnvironmentError(
+            (
+                "Attempting to use Hub Kernels for Flash Attention 3,"
+                "but the `kernels` library was not found in your environment. "
+                "Please install via `pip install kernels`"
+            )
+        )
 
     from kernels import get_kernel
+
     flash_attn_3_hub = get_kernel("kernels-community/flash-attn3", revision="fake-ops-return-probs")
 
     FLASH_ATTN_3_HUB_AVAILABLE = _is_hopper_gpu()
 
 except:
-    FLASH_ATTN_3_HUB_AVAILABLE = False
+    flash_attn_3_hub = None
 
 FLASH_ATTN_2_AVAILABLE = False
 try:
     import flash_attn
 
     FLASH_ATTN_2_AVAILABLE = True
-except ModuleNotFoundError:
-    FLASH_ATTN_2_AVAILABLE = False
+except:
+    flash_attn = None
+
+
+FLASH_ATTN_2_HUB_AVAILABLE = False
+try:
+    if use_hub_kernels and not is_kernels_available():
+        raise EnvironmentError(
+            (
+                "Attempting to use Hub Kernels for Flash Attention 3,"
+                "but the `kernels` library was not found in your environment. "
+                "Please install via `pip install kernels`"
+            )
+        )
+
+    from kernels import get_kernel
+
+    flash_attn_2_hub = get_kernel("kernels-community/flash-attn2")
+
+    FLASH_ATTN_2_HUB_AVAILABLE = True
+except:
+    flash_attn_2_hub = None
 
 __all__ = ["flash_attention", "attention"]
 
+
 def flash_attention(
     q,
     k,
@@ -107,12 +178,19 @@ def flash_attention(
     deterministic: bool. If True, slightly slower and uses more memory.
     dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
     """
-    if not FLASH_ATTN_3_AVAILABLE or not FLASH_ATTN_3_HUB_AVAILABLE:
-        return flash_attn.flash_attn_func(
-            q,
-            k,
-            v,
-        )
+    if FLASH_ATTN_3_AVAILABLE and not FLASH_ATTN_3_HUB_AVAILABLE:
+        if FLASH_ATTN_2_HUB_AVAILABLE:
+            return flash_attn_2_hub.flash_attn_func(
+                q,
+                k,
+                v,
+            )
+        else:
+            return flash_attn.flash_attn_func(
+                q,
+                k,
+                v,
+            )
 
     elif FLASH_ATTN_3_HUB_AVAILABLE:
         return flash_attn_3_hub.flash_attn_func(
@@ -182,7 +260,7 @@ def flash_attention(
             deterministic=deterministic,
         ).unflatten(0, (b, lq))
     else:
-        assert FLASH_ATTN_3_AVAILABLE
+        assert FLASH_ATTN_2_AVAILABLE
         x = flash_attn.flash_attn_varlen_func(
             q=q,
             k=k,
@@ -222,9 +300,7 @@ def attention(
     fa_version=None,
     # og_dtype=torch.bfloat16,
 ):
-
-    if SAGEATTN_AVAILABLE:
-        # print("Using sageattention")
+    if SAGEATTN_AVAILABLE or SAGEATTN_HUB_AVAILABLE:
         attn_mask = None
 
         og_dtype = q.dtype
@@ -232,14 +308,19 @@
         k = k.transpose(1, 2).to(dtype)
         v = v.transpose(1, 2).to(dtype)
 
-        out = sageattn_func(
-            q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p
-        )
+        if SAGEATTN_HUB_AVAILABLE:
+            out = sageattn_hub_func(
+                q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p
+            )
+        else:
+            out = sageattn_func(
+                q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p
+            )
 
         out = out.transpose(1, 2).contiguous().to(og_dtype)
         return out
 
-    elif FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
+    elif FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE or FLASH_ATTN_3_HUB_AVAILABLE:
         return flash_attention(
             q=q,
             k=k,
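
Note: the Hub SageAttention wrapper added in the first hunk relies on torch.library.custom_op plus a fake (meta) implementation so the downloaded kernel stays traceable under torch.compile / FakeTensor. Below is a minimal, self-contained sketch of that pattern, not this commit's code: the op name "mylib::sage_like_attn" is illustrative, scaled_dot_product_attention stands in for the Hub kernel, and the fake implementation is registered on the newly defined op itself.

from typing import Optional

import torch
import torch.nn.functional as F


@torch.library.custom_op("mylib::sage_like_attn", mutates_args=(), device_types="cuda")
def sage_like_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
) -> torch.Tensor:
    # Stand-in body: SDPA in place of the kernel fetched from the Hub.
    return F.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
    )


@sage_like_attn.register_fake
def _(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False):
    # Meta implementation: propagate shape/dtype/device only, no computation.
    return torch.empty_like(q)


if torch.cuda.is_available():
    # [B, H, L, D] layout expected by scaled_dot_product_attention.
    q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
    print(sage_like_attn(q, k, v, is_causal=True).shape)  # torch.Size([1, 8, 128, 64])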
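
Note: the availability flags above are evaluated at import time, so DIFFUSERS_ENABLE_HUB_KERNELS and DISABLE_SAGEATTENTION must be set before the module is imported. A usage sketch under assumptions not stated in the commit (the import path transformer.attention, the [B, L, N, C] tensor layout inferred from the transpose(1, 2) calls, and default keyword arguments):

import os

# Opt into kernels fetched from the Hub; requires `pip install kernels`.
os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "1"
# Keep "0" so the Hub SageAttention block is not skipped; a non-zero value disables it.
os.environ.setdefault("DISABLE_SAGEATTENTION", "0")

import torch
from transformer.attention import attention  # assumed import path

q = torch.randn(1, 1024, 16, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 1024, 16, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 1024, 16, 64, device="cuda", dtype=torch.bfloat16)

# Dispatches to SageAttention (Hub or local) when available, otherwise to the
# FlashAttention 3 / FlashAttention 2 branches inside flash_attention().
out = attention(q, k, v)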