Kernels
mfuntowicz HF Staff commited on
Commit
784e399
·
verified ·
1 Parent(s): 1e6f4cf

Add Windows Kernel for PyTorch 2.9 + CUDA13 (#5)

Browse files

- Build uploaded using `kernels`. (5fafaf51c075affaba2853c085b6aeb73bb8da39)

.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.so filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.so filter=lfs diff=lfs merge=lfs -text
37
+ build/torch29-cu130-x86_64-windows/rotary/_rotary_a793e44.pyd filter=lfs diff=lfs merge=lfs -text
build/torch29-cu130-x86_64-windows/rotary/__init__.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+ import torch
3
+
4
+ from ._ops import ops
5
+
6
+
7
+ def apply_rotary(
8
+ x1: torch.Tensor,
9
+ x2: torch.Tensor,
10
+ cos: torch.Tensor,
11
+ sin: torch.Tensor,
12
+ out1: torch.Tensor,
13
+ out2: torch.Tensor,
14
+ conj: bool,
15
+ ) -> None:
16
+ ops.apply_rotary(x1, x2, cos, sin, out1, out2, conj)
17
+
18
+
19
+ def apply_rotary_transformers(
20
+ q: torch.Tensor,
21
+ k: torch.Tensor,
22
+ cos: torch.Tensor,
23
+ sin: torch.Tensor,
24
+ position_ids: Optional[torch.Tensor] = None,
25
+ unsqueeze_dim: int = 1,
26
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
27
+ """
28
+ Rotary kernel implementation wrapper
29
+ Adapts rotary kernels implementation to match transformers apply_rotary_pos_emb signature
30
+ """
31
+ cos = cos.unsqueeze(unsqueeze_dim)
32
+ sin = sin.unsqueeze(unsqueeze_dim)
33
+
34
+ q_rotated = q.clone()
35
+ k_rotated = k.clone()
36
+
37
+ # Get half dimension for rotation
38
+ half_dim = q.shape[-1] // 2
39
+ q1 = q_rotated[..., :half_dim]
40
+ q2 = q_rotated[..., half_dim:]
41
+ k1 = k_rotated[..., :half_dim]
42
+ k2 = k_rotated[..., half_dim:]
43
+ if cos.shape[-1] != half_dim:
44
+ # Trim cos/sin to match half_dim
45
+ cos = cos[..., :half_dim]
46
+ sin = sin[..., :half_dim]
47
+
48
+ apply_rotary(q1, q2, cos, sin, q1, q2, False)
49
+ apply_rotary(k1, k2, cos, sin, k1, k2, False)
50
+ return q_rotated, k_rotated
51
+
52
+
53
+ __all__ = ["apply_rotary", "apply_rotary_transformers"]
build/torch29-cu130-x86_64-windows/rotary/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _rotary_a793e44
3
+ ops = torch.ops._rotary_a793e44
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_rotary_a793e44::{op_name}"
build/torch29-cu130-x86_64-windows/rotary/_rotary_a793e44.pyd ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:606c6eb81894dc8197f73e0e71a5356f56c61c612e5f77ab5c3d7c351eab8d3a
3
+ size 8007680