Kernels
ca1207 committed
Commit ff6d675 · 1 Parent(s): d65066c

apply all2all scatter gather

test/test_muon/muon.py DELETED
@@ -1 +0,0 @@
- ../../torch-ext/optimizer/muon.py

test/test_muon/optimizer ADDED
@@ -0,0 +1 @@
+ ../../torch-ext/optimizer/
torch-ext/optimizer/matmul_transpose_triton.py ADDED
@@ -0,0 +1,106 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ def get_autotune_config():
+     return [
+         triton.Config(
+             {
+                 'BLOCK_SIZE_M': blk_m,
+                 'BLOCK_SIZE_K': blk_k,
+                 'GROUP_SIZE_M': grp_sz
+             },
+             num_stages=n_stages,
+             num_warps=n_warps) for blk_m in [32, 64, 128]
+         for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5]
+         for n_warps in [4, 8]
+     ]
+
+
+ @triton.autotune(
+     configs=get_autotune_config(),
+     key=['M', 'K'],
+ )
+ @triton.jit
+ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
+                BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+                GROUP_SIZE_M: tl.constexpr):
+     """
+     Core kernel jit function of matmul_transpose that computes y = x @ x.T
+     The code is a simple adaptation from the triton `matmul` tutorial:
+     https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
+     """
+     pid = tl.program_id(axis=0)
+     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+     num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
+     num_pid_in_group = GROUP_SIZE_M * num_pid_n
+     group_id = pid // num_pid_in_group
+     first_pid_m = group_id * GROUP_SIZE_M
+     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+     pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+     pid_n = (pid % num_pid_in_group) // group_size_m
+     # the output is symmetric, so only blocks on or above the diagonal are computed
+     if pid_m > pid_n:
+         return
+
+     offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+     offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+     offs_k = tl.arange(0, BLOCK_SIZE_K)
+     # we use a & b ptrs to denote different rows of x.
+     a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
+     b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)
+
+     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
+
+     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+         a = tl.load(a_ptrs,
+                     mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
+                     other=0.0)
+         b = tl.load(b_ptrs,
+                     mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
+                     other=0.0)
+         accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
+         a_ptrs += BLOCK_SIZE_K * stride_xk
+         b_ptrs += BLOCK_SIZE_K * stride_xk
+     # use dtype.element_ty to accommodate different input datatypes as in cpp templates
+     # https://github.com/triton-lang/triton/issues/2252
+     c = accumulator.to(x.dtype.element_ty)
+
+     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+     offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+     c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
+     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
+     tl.store(c_ptrs, c, mask=c_mask)
+
+     # transpose and copy the block to its mirrored position below the diagonal
+     if pid_m < pid_n:
+         ct_ptrs = y + stride_ym * offs_cn[:, None] + stride_yn * offs_cm[None, :]
+         ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
+         tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
+
+
+ def matmul_transpose_assign(d_in, d_out):
+     assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
+     assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
+     assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
+     assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
+     assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
+     assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
+     assert d_in.size(0) == d_out.size(0) == d_out.size(1), \
+         "First dimension of `d_in` must match first and second dimension of `d_out`"
+
+     d_in = d_in.contiguous()
+     M, K = d_in.shape
+     grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
+         M, META['BLOCK_SIZE_M']), )
+     with torch.cuda.device(d_in.device.index):
+         mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
+                          d_out.stride(0), d_out.stride(1))
+
+
+ def matmul_transpose(d_in):
+     M, _ = d_in.shape
+     d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype)
+     matmul_transpose_assign(d_in, d_out)
+     return d_out
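
Since the file above is new in this commit, a quick sanity check is to compare it against a plain matmul: the kernel computes y = x @ x.T, filling only the upper-triangular blocks and mirroring them. A minimal smoke-test sketch, not part of the commit; the shape, tolerances, and import path (e.g. running from torch-ext/optimizer/) are illustrative assumptions:

import torch
from matmul_transpose_triton import matmul_transpose  # assumed import path

x = torch.randn(512, 1024, device="cuda", dtype=torch.bfloat16)
y = matmul_transpose(x)   # Triton kernel: y = x @ x.T with fp32 accumulation
ref = x @ x.T             # cuBLAS reference
# both paths round to bfloat16, so compare with loose, illustrative tolerances
torch.testing.assert_close(y, ref, rtol=2e-2, atol=1e-1)
print("max abs diff:", (y - ref).abs().max().item())
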
torch-ext/optimizer/muon.py CHANGED
@@ -8,6 +8,8 @@ import torch
  import torch.distributed as dist
  from torch.distributed._tensor import DTensor, Replicate, Shard
 
+ from .matmul_transpose_triton import matmul_transpose_assign
+
  logger = logging.getLogger(__name__)
 
 
@@ -16,6 +18,7 @@ logger = logging.getLogger(__name__)
  # Muon's Newton–Schulz iteration causes high variance in singular values
  # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
  @torch.no_grad()
+ # matmul_transpose_assign from: https://github.com/nil0x9/flash-muon
  def _zeropower_via_newtonschulz5(G, steps):
      """
      Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
@@ -28,12 +31,15 @@ def _zeropower_via_newtonschulz5(G, steps):
      """
      assert len(G.shape) == 2
      assert G.dtype == torch.bfloat16
+     G = G.to(torch.float32)
      X = G  # no manual typecast
 
      if G.size(0) > G.size(1):
          X = X.T
      # Ensure spectral norm is at most 1
      X = X / (X.norm() + 1e-7)
+     buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
+     buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
      # Perform the NS iterations
      for a, b, c in [
          (4.0848, -6.8946, 2.9270),
@@ -42,16 +48,14 @@
          (2.8769, -3.1427, 1.2046),
          (2.8366, -3.0525, 1.2012),
      ]:
-         A = X @ X.T
-         # B = (
-         #     b * A + c * A @ A
-         # )
-         B = torch.addmm(A, A, A, alpha=c, beta=b)
-         # X = a * X + B @ X
-         X = torch.addmm(X, B, X, alpha=1.0, beta=a)
+         matmul_transpose_assign(X, buf1)
+         matmul_transpose_assign(buf1, buf2)
+         buf1.mul_(b).add_(buf2, alpha=c)
+         X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
 
      if G.size(0) > G.size(1):
          X = X.T
+     X = X.to(torch.bfloat16)
      return X
 
 
@@ -69,51 +73,130 @@ class _muon_state:
      qk_clip_state = None
 
 
+ def split_elems_for_src(param, state, src_rank, num_ranks) -> int:
+     rows = param.shape[0]
+     cols = int(param.numel() // rows)
+     base, rem = divmod(rows, num_ranks)
+     my_rows = base + (1 if src_rank < rem else 0)
+     return my_rows * cols
+
+
  @torch.no_grad()
- def _gather(p, state, rank, comm_stream, none_grad):
-     """
-     Gather the gradients to worker_rank.
-     If none_grad is True, free p.grad after the gather.
-     """
+ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
+     with torch.cuda.stream(compute_stream):
+         for p in params:
+             state = param_to_state[id(p)]
+             if rank == state.worker_rank:
+                 num_ranks = dist.get_world_size(group=state.process_group)
+                 state.gathered_grad = torch.empty(p.grad.numel(),
+                                                   dtype=torch.bfloat16,
+                                                   device="cuda")
+             else:
+                 state.gathered_grad = None
+
+         alloc_event = torch.cuda.Event()
+         alloc_event.record(compute_stream)
+         return alloc_event
+
+
+ @torch.no_grad()
+ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
+                     alloc_event):
      with torch.cuda.stream(comm_stream):
-         g = p.grad
+         process_group = param_to_state[id(params[0])].process_group
+         num_ranks = dist.get_world_size(group=process_group)
 
-         if rank == state.worker_rank:
-             num_ranks = dist.get_world_size(group=state.process_group)
-             gather_list = [
-                 torch.empty_like(g.to_local(), dtype=torch.bfloat16)
-                 for _ in range(num_ranks)
-             ]
-         else:
-             gather_list = None
-
-         g = g.to(torch.bfloat16)
-         torch.distributed.gather(
-             g.to_local(),
-             dst=state.worker_rank,
-             gather_list=gather_list,
-             group=state.process_group,
+         # Calculate sending tensors
+         per_dst = [[] for _ in range(num_ranks)]
+         send_counts = [0] * num_ranks
+         for p in params:
+             state = param_to_state[id(p)]
+             dst = state.worker_rank
+             shard_elems = split_elems_for_src(p, state, rank, num_ranks)
+             g = p.grad
+             g = g.to_local().to(torch.bfloat16).contiguous().view(-1)
+             assert g.numel() == shard_elems
+             per_dst[dst].append(g)
+             send_counts[dst] += shard_elems
+
+         assert all(
+             len(v) > 0
+             for v in per_dst), "all params should be sharded to all devices"
+
+         send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
+         owner_params = [
+             p for p in params if param_to_state[id(p)].worker_rank == rank
+         ]
+
+         # Calculate receiving tensors
+         recv_counts = [0] * num_ranks
+         for src in range(num_ranks):
+             total = 0
+             for p in owner_params:
+                 state = param_to_state[id(p)]
+                 assert state.worker_rank == rank
+                 total += split_elems_for_src(p, state, src, num_ranks)
+             recv_counts[src] = total
+
+         recv_total = sum(recv_counts)
+         recv_buf = torch.empty(recv_total, dtype=torch.bfloat16, device="cuda")
+         dist.all_to_all_single(
+             recv_buf,
+             send_buf,
+             output_split_sizes=recv_counts,
+             input_split_sizes=send_counts,
+             group=process_group,
          )
-         if rank == state.worker_rank:
-             if state.gathered_grad is not None:
-                 raise RuntimeError(
-                     "Gather event already exists, which should not happen.")
-             state.gathered_grad = torch.cat(gather_list, dim=0)
-             state.gather_event = torch.cuda.Event()
-             state.gather_event.record()
-         else:
-             state.gathered_grad = None
-             state.gather_event = None
-         gather_list = None
-         if none_grad:
-             # We can safely free p.grad without calling record_stream:
-             #   p.grad.to_local().record_stream(comm_stream)
-             # Explanation:
-             # 1. p.grad is created on the default stream, but the default stream
-             #    is synchronized with the comm stream later.
-             # 2. There is no further activity on the default stream before the optimizer finishes.
-             # Therefore, it is safe to free p.grad directly on the comm stream.
-             p.grad = None
+
+         # Reconstruct the gathered grad from the received buffer
+         #
+         # recv_buf (num ranks = 3)
+         #
+         #   From rank 0         From rank 1         From rank 2
+         # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
+         #
+         # Outer loop:
+         #   rank 0 -> rank 1 -> rank 2
+         #
+         # Inner loop:
+         #   p1_n -> p2_n -> p3_n
+
+         comm_stream.wait_event(alloc_event)
+
+         off = 0
+         write_offsets = {id(p): 0 for p in owner_params}
+         for src in range(num_ranks):
+             if recv_counts[src] == 0:
+                 continue
+
+             block = recv_counts[src]
+             inner_off = 0
+             for p in owner_params:
+                 state = param_to_state[id(p)]
+                 assert state.worker_rank == rank
+                 n = split_elems_for_src(p, state, src, num_ranks)
+                 assert n > 0
+
+                 sg = recv_buf.narrow(0, off + inner_off, n)
+                 woff = write_offsets[id(p)]
+                 dst = state.gathered_grad.narrow(0, woff, n)
+                 dst.copy_(sg)
+
+                 write_offsets[id(p)] += n
+                 inner_off += n
+             off += block
+
+         for p in params:
+             state = param_to_state[id(p)]
+             if state.worker_rank == rank:
+                 state.gathered_grad = state.gathered_grad.view_as(p)
+                 state.gather_event = torch.cuda.Event()
+                 state.gather_event.record(comm_stream)
+             else:
+                 state.gathered_grad = None
+                 state.gather_event = None
+             if none_grad:
+                 p.grad = None
 
 
  @torch.no_grad()
@@ -127,45 +210,120 @@ def _compute_u(p, state, steps, rank, compute_stream):
                  raise RuntimeError("Gather event must be set before compute.")
              compute_stream.wait_event(state.gather_event)
              u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
+             state.gathered_grad = None
              state.computed_u = u
-             state.scattered_u = torch.empty_like(p.to_local(),
-                                                  dtype=torch.bfloat16)
-             state.compute_event = torch.cuda.Event()
-             state.compute_event.record()
-             u = None
+             state.compute_event = torch.cuda.Event()
+             state.compute_event.record()
+         else:
+             state.computed_u = None
+             state.compute_event = None
 
 
  @torch.no_grad()
- def _scatter(p, state, rank, comm_stream):
-     """
-     Scatter the computed_u from worker_rank to all ranks.
-     """
+ def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
+     with torch.cuda.stream(compute_stream):
+         for p in params:
+             state = param_to_state[id(p)]
+             state.scattered_u = torch.empty_like(p.to_local(),
+                                                  dtype=torch.bfloat16)
 
+         alloc_event = torch.cuda.Event()
+         alloc_event.record(compute_stream)
+         return alloc_event
+
+
+ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
      with torch.cuda.stream(comm_stream):
-         if state.compute_event is None:
-             raise RuntimeError("Compute event must be set before scatter.")
-         comm_stream.wait_event(state.compute_event)
+         process_group = param_to_state[id(params[0])].process_group
+         num_ranks = dist.get_world_size(group=process_group)
+         owner_params = [
+             p for p in params if param_to_state[id(p)].worker_rank == rank
+         ]
 
-         if rank == state.worker_rank:
-             num_ranks = dist.get_world_size(group=state.process_group)
-             # Clear the gathered gradient to free memory
-             state.gathered_grad = None
+         per_dst = [[] for _ in range(num_ranks)]
+         send_counts = [0] * num_ranks
 
-             u = state.computed_u
-             scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
-             scatter_list = [s.contiguous() for s in scatter_list]
-         else:
-             scatter_list = None
+         if owner_params:
+             for p in owner_params:
+                 state = param_to_state[id(p)]
+                 if state.compute_event is None:
+                     raise RuntimeError(
+                         "Compute event must be set before scatter.")
+                 comm_stream.wait_event(state.compute_event)
+                 state.gathered_grad = None
 
-         torch.distributed.scatter(
-             state.scattered_u,
-             scatter_list=scatter_list,
-             src=state.worker_rank,
-             group=state.process_group,
+                 assert state.computed_u is not None
+
+                 u_full = state.computed_u.to(
+                     torch.bfloat16).contiguous().view(-1)
+
+                 offset = 0
+                 for dst in range(num_ranks):
+                     n = split_elems_for_src(p, state, dst, num_ranks)
+                     assert n > 0
+
+                     su = u_full.narrow(0, offset, n)
+                     per_dst[dst].append(su)
+                     send_counts[dst] += n
+                     offset += n
+
+                 assert offset == u_full.numel()
+
+         if any(len(v) > 0 for v in per_dst):
+             send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
+         else:
+             # all_to_all requires participation from all ranks
+             # Even non-owner ranks must join the collective call
+             send_buf = torch.empty(0, dtype=torch.bfloat16, device="cuda")
+
+         recv_counts = [0] * num_ranks
+         for src in range(num_ranks):
+             total = 0
+             for p in params:
+                 state = param_to_state[id(p)]
+                 if state.worker_rank != src:
+                     continue
+                 total += split_elems_for_src(p, state, rank, num_ranks)
+             recv_counts[src] = total
+
+         recv_total = sum(recv_counts)
+         assert recv_total > 0
+         recv_buf = torch.empty(recv_total, dtype=torch.bfloat16, device="cuda")
+
+         dist.all_to_all_single(
+             recv_buf,
+             send_buf,
+             output_split_sizes=recv_counts,
+             input_split_sizes=send_counts,
+             group=process_group,
          )
-         state.scatter_event = torch.cuda.Event()
-         state.scatter_event.record()
-         scatter_list = None
+
+         comm_stream.wait_event(alloc_event)
+
+         off = 0
+         for src in range(num_ranks):
+             block = recv_counts[src]
+             if block == 0:
+                 continue
+
+             inner_off = 0
+             for p in params:
+                 state = param_to_state[id(p)]
+                 if state.worker_rank != src:
+                     continue
+                 n = split_elems_for_src(p, state, rank, num_ranks)
+                 assert n > 0
+
+                 flat_local = recv_buf.narrow(0, off + inner_off,
+                                              n).view_as(p.to_local())
+                 state.scattered_u.copy_(flat_local)
+
+                 state.scatter_event = torch.cuda.Event()
+                 state.scatter_event.record(comm_stream)
+                 inner_off += n
+
+             assert inner_off == block
+             off += block
 
 
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
@@ -585,11 +743,15 @@ class Muon(torch.optim.Optimizer):
          param_to_state, ordered_params = self.init_state_and_assign_params(
              names, params, group, qk_logits)
 
-         def enqueue_gathers(start_idx, chunk_size):
-             for p in ordered_params[start_idx:start_idx + chunk_size]:
-                 state = param_to_state[id(p)]
-                 _gather(p, state, self.rank, self.comm_stream,
-                         group["none_grad"])
+         def enqueue_all2all_gather(start_idx, chunk_size):
+             target_params = ordered_params[start_idx:start_idx + chunk_size]
+             if target_params:
+                 alloc_event = _alloc_gathered_grad(target_params,
+                                                    param_to_state, self.rank,
+                                                    self.compute_stream)
+                 _all2all_gather(target_params, param_to_state, self.rank,
+                                 self.comm_stream, group["none_grad"],
+                                 alloc_event)
 
          def enqueue_computes(start_idx, chunk_size):
              for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -597,10 +759,14 @@
                  _compute_u(p, state, group["ns_steps"], self.rank,
                             self.compute_stream)
 
-         def enqueue_scatters(start_idx, chunk_size):
-             for p in ordered_params[start_idx:start_idx + chunk_size]:
-                 state = param_to_state[id(p)]
-                 _scatter(p, state, self.rank, self.comm_stream)
+         def enqueue_all2all_scatter(start_idx, chunk_size):
+             target_params = ordered_params[start_idx:start_idx + chunk_size]
+             if target_params:
+                 alloc_event = _alloc_scattered_u(target_params, param_to_state,
+                                                  self.rank,
+                                                  self.compute_stream)
+                 _all2all_scatter(target_params, param_to_state, self.rank,
+                                  self.comm_stream, alloc_event)
 
          def enqueue_update_param(start_idx, chunk_size):
              for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -615,14 +781,16 @@
          # Wait grad update
          self.comm_stream.wait_stream(torch.cuda.current_stream())
 
-         enqueue_gathers(0, chunk_size)
+         PRE_STEP = 5
+         for i in range(0, PRE_STEP):
+             enqueue_all2all_gather(i * chunk_size, chunk_size)
+             enqueue_computes(i * chunk_size, chunk_size)
+
          for i in range(0, len(params) + chunk_size - 1, chunk_size):
-             enqueue_computes(i, chunk_size)
-             if i > 0:
-                 enqueue_update_param(i - chunk_size, chunk_size)
-             enqueue_gathers(i + chunk_size, chunk_size)
-             enqueue_scatters(i, chunk_size)
-             enqueue_update_param(i, chunk_size)
+             enqueue_all2all_scatter(i, chunk_size)
+             enqueue_all2all_gather(i + PRE_STEP * chunk_size, chunk_size)
+             enqueue_update_param(i, chunk_size)
+             enqueue_computes(i + PRE_STEP * chunk_size, chunk_size)
 
          # Wait the last update_param to finish
          torch.cuda.current_stream().wait_stream(self.compute_stream)
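
Both _all2all_gather and _all2all_scatter above reduce to a single dist.all_to_all_single call with explicit split sizes: each rank packs its per-destination shards into one flat bfloat16 send buffer and receives one contiguous block per source rank, which is then unpacked per parameter. A standalone sketch of just that collective pattern, separate from the optimizer code; the element counts are hypothetical and it assumes a NCCL process group launched with torchrun on one GPU per rank:

# hypothetical demo, e.g. torchrun --nproc_per_node=2 all2all_demo.py
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank = dist.get_rank()
world = dist.get_world_size()
torch.cuda.set_device(rank)

# this rank sends (rank + 1) bf16 elements to every destination rank
send_counts = [rank + 1] * world
send_buf = torch.full((sum(send_counts),), float(rank),
                      dtype=torch.bfloat16, device="cuda")

# and receives (src + 1) elements from every source rank
recv_counts = [src + 1 for src in range(world)]
recv_buf = torch.empty(sum(recv_counts), dtype=torch.bfloat16, device="cuda")

dist.all_to_all_single(recv_buf, send_buf,
                       output_split_sizes=recv_counts,
                       input_split_sizes=send_counts)

# recv_buf now holds one contiguous block per source rank, in rank order;
# this is the layout the copy loops in _all2all_gather/_all2all_scatter unpack
print(rank, recv_buf.tolist())
dist.destroy_process_group()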
 
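The rewritten step loop above turns the per-chunk gather/compute/scatter sequence into a software pipeline: PRE_STEP chunks of all2all-gather plus compute are enqueued up front, and each iteration then scatters and updates chunk i while prefetching chunk i + PRE_STEP. A CPU-only dry run with made-up sizes (print statements standing in for the CUDA/stream work) that shows the resulting enqueue order:

# dry-run sketch only; num_params and chunk_size are hypothetical
PRE_STEP = 5
chunk_size = 2
num_params = 16

def enqueue(op, start_idx):
    # mirrors the empty-slice check in enqueue_all2all_gather/scatter
    if start_idx < num_params:
        print(f"{op}(params[{start_idx}:{start_idx + chunk_size}])")

# prologue: prefetch PRE_STEP chunks of gather + compute
for i in range(PRE_STEP):
    enqueue("all2all_gather", i * chunk_size)
    enqueue("compute_u", i * chunk_size)

# steady state: scatter/update chunk i, prefetch chunk i + PRE_STEP
for i in range(0, num_params + chunk_size - 1, chunk_size):
    enqueue("all2all_scatter", i)
    enqueue("all2all_gather", i + PRE_STEP * chunk_size)
    enqueue("update_param", i)
    enqueue("compute_u", i + PRE_STEP * chunk_size)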