apply all2all scatter gather
Browse files- test/test_muon/muon.py +0 -1
- test/test_muon/optimizer +1 -0
- torch-ext/optimizer/matmul_transpose_triton.py +106 -0
- torch-ext/optimizer/muon.py +261 -93
test/test_muon/muon.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
../../torch-ext/optimizer/muon.py
|
|
|
|
|
|
test/test_muon/optimizer
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
../../torch-ext/optimizer/
|
torch-ext/optimizer/matmul_transpose_triton.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def get_autotune_config():
|
| 7 |
+
return [
|
| 8 |
+
triton.Config(
|
| 9 |
+
{
|
| 10 |
+
'BLOCK_SIZE_M': blk_m,
|
| 11 |
+
'BLOCK_SIZE_K': blk_k,
|
| 12 |
+
'GROUP_SIZE_M': grp_sz
|
| 13 |
+
},
|
| 14 |
+
num_stages=n_stages,
|
| 15 |
+
num_warps=n_warps) for blk_m in [32, 64, 128]
|
| 16 |
+
for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5]
|
| 17 |
+
for n_warps in [4, 8]
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@triton.autotune(
|
| 22 |
+
configs=get_autotune_config(),
|
| 23 |
+
key=['M', 'K'],
|
| 24 |
+
)
|
| 25 |
+
@triton.jit
|
| 26 |
+
def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
|
| 27 |
+
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
| 28 |
+
GROUP_SIZE_M: tl.constexpr):
|
| 29 |
+
"""
|
| 30 |
+
Core kernel jit function of matmul_transpose that computes y = x @ x.T
|
| 31 |
+
The code is a simple adaptation from the triton `matmul` tutorial:
|
| 32 |
+
https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
|
| 33 |
+
"""
|
| 34 |
+
pid = tl.program_id(axis=0)
|
| 35 |
+
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
| 36 |
+
num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
|
| 37 |
+
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
| 38 |
+
group_id = pid // num_pid_in_group
|
| 39 |
+
first_pid_m = group_id * GROUP_SIZE_M
|
| 40 |
+
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
| 41 |
+
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
|
| 42 |
+
pid_n = (pid % num_pid_in_group) // group_size_m
|
| 43 |
+
if pid_m > pid_n:
|
| 44 |
+
return
|
| 45 |
+
|
| 46 |
+
offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
|
| 47 |
+
offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
|
| 48 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
| 49 |
+
# we use a & b ptrs to denote different rows of x.
|
| 50 |
+
a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
|
| 51 |
+
b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)
|
| 52 |
+
|
| 53 |
+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
|
| 54 |
+
|
| 55 |
+
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
| 56 |
+
a = tl.load(a_ptrs,
|
| 57 |
+
mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
|
| 58 |
+
other=0.0)
|
| 59 |
+
b = tl.load(b_ptrs,
|
| 60 |
+
mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
|
| 61 |
+
other=0.0)
|
| 62 |
+
accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
|
| 63 |
+
a_ptrs += BLOCK_SIZE_K * stride_xk
|
| 64 |
+
b_ptrs += BLOCK_SIZE_K * stride_xk
|
| 65 |
+
# use dtype.element_ty to accommodate different input datatypes as in cpp templates
|
| 66 |
+
# https://github.com/triton-lang/triton/issues/2252
|
| 67 |
+
c = accumulator.to(x.dtype.element_ty)
|
| 68 |
+
|
| 69 |
+
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 70 |
+
offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
| 71 |
+
c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
|
| 72 |
+
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
|
| 73 |
+
tl.store(c_ptrs, c, mask=c_mask)
|
| 74 |
+
|
| 75 |
+
# transpose and copy
|
| 76 |
+
if pid_m < pid_n:
|
| 77 |
+
ct_ptrs = y + stride_ym * offs_cn[:,
|
| 78 |
+
None] + stride_yn * offs_cm[None, :]
|
| 79 |
+
ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
|
| 80 |
+
tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def matmul_transpose_assign(d_in, d_out):
|
| 84 |
+
assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
|
| 85 |
+
assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
|
| 86 |
+
assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
|
| 87 |
+
assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
|
| 88 |
+
assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
|
| 89 |
+
assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
|
| 90 |
+
assert d_in.size(0) == d_out.size(0) == d_out.size(0), \
|
| 91 |
+
"First dimension of `d_in` must match first and second dimension of `d_out`"
|
| 92 |
+
|
| 93 |
+
d_in = d_in.contiguous()
|
| 94 |
+
M, K = d_in.shape
|
| 95 |
+
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
|
| 96 |
+
M, META['BLOCK_SIZE_M']), )
|
| 97 |
+
with torch.cuda.device(d_in.device.index):
|
| 98 |
+
mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
|
| 99 |
+
d_out.stride(0), d_out.stride(1))
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def matmul_transpose(d_in):
|
| 103 |
+
M, _ = d_in.shape
|
| 104 |
+
d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype)
|
| 105 |
+
matmul_transpose_assign(d_in, d_out)
|
| 106 |
+
return d_out
|
torch-ext/optimizer/muon.py
CHANGED
|
@@ -8,6 +8,8 @@ import torch
|
|
| 8 |
import torch.distributed as dist
|
| 9 |
from torch.distributed._tensor import DTensor, Replicate, Shard
|
| 10 |
|
|
|
|
|
|
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
| 13 |
|
|
@@ -16,6 +18,7 @@ logger = logging.getLogger(__name__)
|
|
| 16 |
# Muon's Newton–Schulz iteration causes high variance in singular values
|
| 17 |
# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
|
| 18 |
@torch.no_grad()
|
|
|
|
| 19 |
def _zeropower_via_newtonschulz5(G, steps):
|
| 20 |
"""
|
| 21 |
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
|
@@ -28,12 +31,15 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 28 |
"""
|
| 29 |
assert len(G.shape) == 2
|
| 30 |
assert G.dtype == torch.bfloat16
|
|
|
|
| 31 |
X = G # no manual typecast
|
| 32 |
|
| 33 |
if G.size(0) > G.size(1):
|
| 34 |
X = X.T
|
| 35 |
# Ensure spectral norm is at most 1
|
| 36 |
X = X / (X.norm() + 1e-7)
|
|
|
|
|
|
|
| 37 |
# Perform the NS iterations
|
| 38 |
for a, b, c in [
|
| 39 |
(4.0848, -6.8946, 2.9270),
|
|
@@ -42,16 +48,14 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 42 |
(2.8769, -3.1427, 1.2046),
|
| 43 |
(2.8366, -3.0525, 1.2012),
|
| 44 |
]:
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
B = torch.addmm(A, A, A, alpha=c, beta=b)
|
| 50 |
-
# X = a * X + B @ X
|
| 51 |
-
X = torch.addmm(X, B, X, alpha=1.0, beta=a)
|
| 52 |
|
| 53 |
if G.size(0) > G.size(1):
|
| 54 |
X = X.T
|
|
|
|
| 55 |
return X
|
| 56 |
|
| 57 |
|
|
@@ -69,51 +73,130 @@ class _muon_state:
|
|
| 69 |
qk_clip_state = None
|
| 70 |
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
@torch.no_grad()
|
| 73 |
-
def
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
with torch.cuda.stream(comm_stream):
|
| 79 |
-
|
|
|
|
| 80 |
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
)
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
|
| 119 |
@torch.no_grad()
|
|
@@ -127,45 +210,120 @@ def _compute_u(p, state, steps, rank, compute_stream):
|
|
| 127 |
raise RuntimeError("Gather event must be set before compute.")
|
| 128 |
compute_stream.wait_event(state.gather_event)
|
| 129 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
|
|
|
| 130 |
state.computed_u = u
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
|
| 137 |
|
| 138 |
@torch.no_grad()
|
| 139 |
-
def
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
|
|
|
|
|
|
| 143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
with torch.cuda.stream(comm_stream):
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
|
|
|
|
|
|
| 148 |
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
# Clear the gathered gradient to free memory
|
| 152 |
-
state.gathered_grad = None
|
| 153 |
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
)
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
|
| 171 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
@@ -585,11 +743,15 @@ class Muon(torch.optim.Optimizer):
|
|
| 585 |
param_to_state, ordered_params = self.init_state_and_assign_params(
|
| 586 |
names, params, group, qk_logits)
|
| 587 |
|
| 588 |
-
def
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 593 |
|
| 594 |
def enqueue_computes(start_idx, chunk_size):
|
| 595 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
|
@@ -597,10 +759,14 @@ class Muon(torch.optim.Optimizer):
|
|
| 597 |
_compute_u(p, state, group["ns_steps"], self.rank,
|
| 598 |
self.compute_stream)
|
| 599 |
|
| 600 |
-
def
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 604 |
|
| 605 |
def enqueue_update_param(start_idx, chunk_size):
|
| 606 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
|
@@ -615,14 +781,16 @@ class Muon(torch.optim.Optimizer):
|
|
| 615 |
# Wait grad update
|
| 616 |
self.comm_stream.wait_stream(torch.cuda.current_stream())
|
| 617 |
|
| 618 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 619 |
for i in range(0, len(params) + chunk_size - 1, chunk_size):
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
enqueue_scatters(i, chunk_size)
|
| 625 |
-
enqueue_update_param(i, chunk_size)
|
| 626 |
|
| 627 |
# Wait the last update_param to finish
|
| 628 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
|
|
|
| 8 |
import torch.distributed as dist
|
| 9 |
from torch.distributed._tensor import DTensor, Replicate, Shard
|
| 10 |
|
| 11 |
+
from .matmul_transpose_triton import matmul_transpose_assign
|
| 12 |
+
|
| 13 |
logger = logging.getLogger(__name__)
|
| 14 |
|
| 15 |
|
|
|
|
| 18 |
# Muon's Newton–Schulz iteration causes high variance in singular values
|
| 19 |
# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
|
| 20 |
@torch.no_grad()
|
| 21 |
+
# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
|
| 22 |
def _zeropower_via_newtonschulz5(G, steps):
|
| 23 |
"""
|
| 24 |
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
|
|
|
| 31 |
"""
|
| 32 |
assert len(G.shape) == 2
|
| 33 |
assert G.dtype == torch.bfloat16
|
| 34 |
+
G = G.to(thorch.float32)
|
| 35 |
X = G # no manual typecast
|
| 36 |
|
| 37 |
if G.size(0) > G.size(1):
|
| 38 |
X = X.T
|
| 39 |
# Ensure spectral norm is at most 1
|
| 40 |
X = X / (X.norm() + 1e-7)
|
| 41 |
+
buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 42 |
+
buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 43 |
# Perform the NS iterations
|
| 44 |
for a, b, c in [
|
| 45 |
(4.0848, -6.8946, 2.9270),
|
|
|
|
| 48 |
(2.8769, -3.1427, 1.2046),
|
| 49 |
(2.8366, -3.0525, 1.2012),
|
| 50 |
]:
|
| 51 |
+
matmul_transpose_assign(X, buf1)
|
| 52 |
+
matmul_transpose_assign(buf1, buf2)
|
| 53 |
+
buf1.mul_(b).add_(buf2, alpha=c)
|
| 54 |
+
X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
if G.size(0) > G.size(1):
|
| 57 |
X = X.T
|
| 58 |
+
X = X.to(torch.bfloat16)
|
| 59 |
return X
|
| 60 |
|
| 61 |
|
|
|
|
| 73 |
qk_clip_state = None
|
| 74 |
|
| 75 |
|
| 76 |
+
def split_elems_for_src(param, state, src_rank, num_ranks) -> int:
|
| 77 |
+
rows = param.shape[0]
|
| 78 |
+
cols = int(param.numel() // rows)
|
| 79 |
+
base, rem = divmod(rows, num_ranks)
|
| 80 |
+
my_rows = base + (1 if src_rank < rem else 0)
|
| 81 |
+
return my_rows * cols
|
| 82 |
+
|
| 83 |
+
|
| 84 |
@torch.no_grad()
|
| 85 |
+
def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
|
| 86 |
+
with torch.cuda.stream(compute_stream):
|
| 87 |
+
for p in params:
|
| 88 |
+
state = param_to_state[id(p)]
|
| 89 |
+
if rank == state.worker_rank:
|
| 90 |
+
num_ranks = dist.get_world_size(group=state.process_group)
|
| 91 |
+
state.gathered_grad = torch.empty(p.grad.numel(),
|
| 92 |
+
dtype=torch.bfloat16,
|
| 93 |
+
device="cuda")
|
| 94 |
+
else:
|
| 95 |
+
state.gathered_grad = None
|
| 96 |
+
|
| 97 |
+
alloc_event = torch.cuda.Event()
|
| 98 |
+
alloc_event.record(compute_stream)
|
| 99 |
+
return alloc_event
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@torch.no_grad()
|
| 103 |
+
def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
|
| 104 |
+
alloc_event):
|
| 105 |
with torch.cuda.stream(comm_stream):
|
| 106 |
+
process_group = param_to_state[id(params[0])].process_group
|
| 107 |
+
num_ranks = dist.get_world_size(group=process_group)
|
| 108 |
|
| 109 |
+
# Calculate sending tensors
|
| 110 |
+
per_dst = [[] for _ in range(num_ranks)]
|
| 111 |
+
send_counts = [0] * num_ranks
|
| 112 |
+
for p in params:
|
| 113 |
+
state = param_to_state[id(p)]
|
| 114 |
+
dst = state.worker_rank
|
| 115 |
+
shard_elems = split_elems_for_src(p, state, rank, num_ranks)
|
| 116 |
+
g = p.grad
|
| 117 |
+
g = g.to_local().to(torch.bfloat16).contiguous().view(-1)
|
| 118 |
+
assert g.numel() == shard_elems
|
| 119 |
+
per_dst[dst].append(g)
|
| 120 |
+
send_counts[dst] += shard_elems
|
| 121 |
+
|
| 122 |
+
assert all(
|
| 123 |
+
len(v) > 0
|
| 124 |
+
for v in per_dst), "all params should be sharded to all devices"
|
| 125 |
+
|
| 126 |
+
send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
|
| 127 |
+
owner_params = [
|
| 128 |
+
p for p in params if param_to_state[id(p)].worker_rank == rank
|
| 129 |
+
]
|
| 130 |
+
|
| 131 |
+
# Calculate receiving tensors
|
| 132 |
+
recv_counts = [0] * num_ranks
|
| 133 |
+
for src in range(num_ranks):
|
| 134 |
+
total = 0
|
| 135 |
+
for p in owner_params:
|
| 136 |
+
state = param_to_state[id(p)]
|
| 137 |
+
assert state.worker_rank == rank
|
| 138 |
+
total += split_elems_for_src(p, state, src, num_ranks)
|
| 139 |
+
recv_counts[src] = total
|
| 140 |
+
|
| 141 |
+
recv_total = sum(recv_counts)
|
| 142 |
+
recv_buf = torch.empty(recv_total, dtype=torch.bfloat16, device="cuda")
|
| 143 |
+
dist.all_to_all_single(
|
| 144 |
+
recv_buf,
|
| 145 |
+
send_buf,
|
| 146 |
+
output_split_sizes=recv_counts,
|
| 147 |
+
input_split_sizes=send_counts,
|
| 148 |
+
group=process_group,
|
| 149 |
)
|
| 150 |
+
|
| 151 |
+
# Reconstructs gathered grad from the received buffer
|
| 152 |
+
#
|
| 153 |
+
# recv_buf (num ranks = 3)
|
| 154 |
+
#
|
| 155 |
+
# From rank 0 From rank 1 From rank 2
|
| 156 |
+
# | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
|
| 157 |
+
#
|
| 158 |
+
# Outer loop:
|
| 159 |
+
# rank 0 -> rank 1 -> rank2
|
| 160 |
+
#
|
| 161 |
+
# Inner loop:
|
| 162 |
+
# p1_n -> p2_n -> p3_n
|
| 163 |
+
|
| 164 |
+
comm_stream.wait_event(alloc_event)
|
| 165 |
+
|
| 166 |
+
off = 0
|
| 167 |
+
write_offsets = {id(p): 0 for p in owner_params}
|
| 168 |
+
for src in range(num_ranks):
|
| 169 |
+
if recv_counts[src] == 0:
|
| 170 |
+
continue
|
| 171 |
+
|
| 172 |
+
block = recv_counts[src]
|
| 173 |
+
inner_off = 0
|
| 174 |
+
for p in owner_params:
|
| 175 |
+
state = param_to_state[id(p)]
|
| 176 |
+
assert state.worker_rank == rank
|
| 177 |
+
n = split_elems_for_src(p, state, src, num_ranks)
|
| 178 |
+
assert n > 0
|
| 179 |
+
|
| 180 |
+
sg = recv_buf.narrow(0, off + inner_off, n)
|
| 181 |
+
woff = write_offsets[id(p)]
|
| 182 |
+
dst = state.gathered_grad.narrow(0, woff, n)
|
| 183 |
+
dst.copy_(sg)
|
| 184 |
+
|
| 185 |
+
write_offsets[id(p)] += n
|
| 186 |
+
inner_off += n
|
| 187 |
+
off += block
|
| 188 |
+
|
| 189 |
+
for p in params:
|
| 190 |
+
state = param_to_state[id(p)]
|
| 191 |
+
if state.worker_rank == rank:
|
| 192 |
+
state.gathered_grad = state.gathered_grad.view_as(p)
|
| 193 |
+
state.gather_event = torch.cuda.Event()
|
| 194 |
+
state.gather_event.record(comm_stream)
|
| 195 |
+
else:
|
| 196 |
+
state.gathered_grad = None
|
| 197 |
+
state.gather_event = None
|
| 198 |
+
if none_grad:
|
| 199 |
+
p.grad = None
|
| 200 |
|
| 201 |
|
| 202 |
@torch.no_grad()
|
|
|
|
| 210 |
raise RuntimeError("Gather event must be set before compute.")
|
| 211 |
compute_stream.wait_event(state.gather_event)
|
| 212 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
| 213 |
+
state.gathered_grad = None
|
| 214 |
state.computed_u = u
|
| 215 |
+
state.compute_event = torch.cuda.Event()
|
| 216 |
+
state.compute_event.record()
|
| 217 |
+
else:
|
| 218 |
+
state.computed_u = None
|
| 219 |
+
state.compute_event = None
|
| 220 |
|
| 221 |
|
| 222 |
@torch.no_grad()
|
| 223 |
+
def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
|
| 224 |
+
with torch.cuda.stream(compute_stream):
|
| 225 |
+
for p in params:
|
| 226 |
+
state = param_to_state[id(p)]
|
| 227 |
+
state.scattered_u = torch.empty_like(p.to_local(),
|
| 228 |
+
dtype=torch.bfloat16)
|
| 229 |
|
| 230 |
+
alloc_event = torch.cuda.Event()
|
| 231 |
+
alloc_event.record(compute_stream)
|
| 232 |
+
return alloc_event
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
|
| 236 |
with torch.cuda.stream(comm_stream):
|
| 237 |
+
process_group = param_to_state[id(params[0])].process_group
|
| 238 |
+
num_ranks = dist.get_world_size(group=process_group)
|
| 239 |
+
owner_params = [
|
| 240 |
+
p for p in params if param_to_state[id(p)].worker_rank == rank
|
| 241 |
+
]
|
| 242 |
|
| 243 |
+
per_dst = [[] for _ in range(num_ranks)]
|
| 244 |
+
send_counts = [0] * num_ranks
|
|
|
|
|
|
|
| 245 |
|
| 246 |
+
if owner_params:
|
| 247 |
+
for p in owner_params:
|
| 248 |
+
state = param_to_state[id(p)]
|
| 249 |
+
if state.compute_event is None:
|
| 250 |
+
raise RuntimeError(
|
| 251 |
+
"Compute event must be set before scatter.")
|
| 252 |
+
comm_stream.wait_event(state.compute_event)
|
| 253 |
+
state.gathered_grad = None
|
| 254 |
|
| 255 |
+
assert state.computed_u is not None
|
| 256 |
+
|
| 257 |
+
u_full = state.computed_u.to(
|
| 258 |
+
torch.bfloat16).contiguous().view(-1)
|
| 259 |
+
|
| 260 |
+
offset = 0
|
| 261 |
+
for dst in range(num_ranks):
|
| 262 |
+
n = split_elems_for_src(p, state, dst, num_ranks)
|
| 263 |
+
assert n > 0
|
| 264 |
+
|
| 265 |
+
su = u_full.narrow(0, offset, n)
|
| 266 |
+
per_dst[dst].append(su)
|
| 267 |
+
send_counts[dst] += n
|
| 268 |
+
offset += n
|
| 269 |
+
|
| 270 |
+
assert offset == u_full.numel()
|
| 271 |
+
|
| 272 |
+
if any(len(v) > 0 for v in per_dst):
|
| 273 |
+
send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
|
| 274 |
+
else:
|
| 275 |
+
# all_to_all requires participation from all ranks
|
| 276 |
+
# Even non-owner ranks must join the collective call
|
| 277 |
+
send_buf = torch.empty(0, dtype=torch.bfloat16, device="cuda")
|
| 278 |
+
|
| 279 |
+
recv_counts = [0] * num_ranks
|
| 280 |
+
for src in range(num_ranks):
|
| 281 |
+
total = 0
|
| 282 |
+
for p in params:
|
| 283 |
+
state = param_to_state[id(p)]
|
| 284 |
+
if state.worker_rank != src:
|
| 285 |
+
continue
|
| 286 |
+
total += split_elems_for_src(p, state, rank, num_ranks)
|
| 287 |
+
recv_counts[src] = total
|
| 288 |
+
|
| 289 |
+
recv_total = sum(recv_counts)
|
| 290 |
+
assert recv_total > 0
|
| 291 |
+
recv_buf = torch.empty(recv_total, dtype=torch.bfloat16, device="cuda")
|
| 292 |
+
|
| 293 |
+
dist.all_to_all_single(
|
| 294 |
+
recv_buf,
|
| 295 |
+
send_buf,
|
| 296 |
+
output_split_sizes=recv_counts,
|
| 297 |
+
input_split_sizes=send_counts,
|
| 298 |
+
group=process_group,
|
| 299 |
)
|
| 300 |
+
|
| 301 |
+
comm_stream.wait_event(alloc_event)
|
| 302 |
+
|
| 303 |
+
off = 0
|
| 304 |
+
for src in range(num_ranks):
|
| 305 |
+
block = recv_counts[src]
|
| 306 |
+
if block == 0:
|
| 307 |
+
continue
|
| 308 |
+
|
| 309 |
+
inner_off = 0
|
| 310 |
+
for p in params:
|
| 311 |
+
state = param_to_state[id(p)]
|
| 312 |
+
if state.worker_rank != src:
|
| 313 |
+
continue
|
| 314 |
+
n = split_elems_for_src(p, state, rank, num_ranks)
|
| 315 |
+
assert n > 0
|
| 316 |
+
|
| 317 |
+
flat_local = recv_buf.narrow(0, off + inner_off,
|
| 318 |
+
n).view_as(p.to_local())
|
| 319 |
+
state.scattered_u.copy_(flat_local)
|
| 320 |
+
|
| 321 |
+
state.scatter_event = torch.cuda.Event()
|
| 322 |
+
state.scatter_event.record(comm_stream)
|
| 323 |
+
inner_off += n
|
| 324 |
+
|
| 325 |
+
assert inner_off == block
|
| 326 |
+
off += block
|
| 327 |
|
| 328 |
|
| 329 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
|
|
| 743 |
param_to_state, ordered_params = self.init_state_and_assign_params(
|
| 744 |
names, params, group, qk_logits)
|
| 745 |
|
| 746 |
+
def enqueue_all2all_gather(start_idx, chunk_size):
|
| 747 |
+
target_params = ordered_params[start_idx:start_idx + chunk_size]
|
| 748 |
+
if target_params:
|
| 749 |
+
alloc_event = _alloc_gathered_grad(target_params,
|
| 750 |
+
param_to_state, self.rank,
|
| 751 |
+
self.compute_stream)
|
| 752 |
+
_all2all_gather(target_params, param_to_state, self.rank,
|
| 753 |
+
self.comm_stream, group["none_grad"],
|
| 754 |
+
alloc_event)
|
| 755 |
|
| 756 |
def enqueue_computes(start_idx, chunk_size):
|
| 757 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
|
|
|
| 759 |
_compute_u(p, state, group["ns_steps"], self.rank,
|
| 760 |
self.compute_stream)
|
| 761 |
|
| 762 |
+
def enqueue_all2all_scatter(start_idx, chunk_size):
|
| 763 |
+
target_params = ordered_params[start_idx:start_idx + chunk_size]
|
| 764 |
+
if target_params:
|
| 765 |
+
alloc_event = _alloc_scattered_u(target_params, param_to_state,
|
| 766 |
+
self.rank,
|
| 767 |
+
self.compute_stream)
|
| 768 |
+
_all2all_scatter(target_params, param_to_state, self.rank,
|
| 769 |
+
self.comm_stream, alloc_event)
|
| 770 |
|
| 771 |
def enqueue_update_param(start_idx, chunk_size):
|
| 772 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
|
|
|
| 781 |
# Wait grad update
|
| 782 |
self.comm_stream.wait_stream(torch.cuda.current_stream())
|
| 783 |
|
| 784 |
+
PRE_STEP = 5
|
| 785 |
+
for i in range(0, PRE_STEP):
|
| 786 |
+
enqueue_all2all_gather(i * chunk_size, chunk_size)
|
| 787 |
+
enqueue_computes(i * chunk_size, chunk_size)
|
| 788 |
+
|
| 789 |
for i in range(0, len(params) + chunk_size - 1, chunk_size):
|
| 790 |
+
enqueue_all2all_scatter(i, chunk_size)
|
| 791 |
+
enqueue_all2all_gather(i + PRE_STEP * chunk_size, chunk_size)
|
| 792 |
+
enqueue_update_param(i, chunk_size)
|
| 793 |
+
enqueue_computes(i + PRE_STEP * chunk_size, chunk_size)
|
|
|
|
|
|
|
| 794 |
|
| 795 |
# Wait the last update_param to finish
|
| 796 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|