Kernels

Commit e2b41e5 (unverified) · 1 parent: e907c7d
Committed by wyldecat and github-actions[bot]

Support param group with various placements (#13)

* feat(muon): group parameters by placements for parallel Muon execution

* refactor(muon): refactor step func and group params with its placement

* feat(muon): support general mesh

* refactor(muon): refactor state init

* refactor(muon): refactor test

* fix(muon): fix general mesh, add chunk_size argument

* refactor(muon): change overlap_step to warmup_step

* refactor(muon-test): rewrite README, add conftest.py and use explicit flags

* chore(muon): clarify N-D sharding support and add test reference

* fix: use device_mesh as key to group params

* Add built binary [skip-build]

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.

Files changed (50):
  1. .pre-commit-config.yaml +0 -4
  2. README.md +10 -1
  3. build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  4. build/torch28-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_811726c_dirty.abi3.so → _optimizer_23d68bb_dirty.abi3.so} +2 -2
  5. build/torch28-cxx11-cu126-x86_64-linux/optimizer/distributed/utils.py +174 -0
  6. build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py +377 -206
  7. build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
  8. build/torch28-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_811726c_dirty.abi3.so → _optimizer_23d68bb_dirty.abi3.so} +2 -2
  9. build/torch28-cxx11-cu128-x86_64-linux/optimizer/distributed/utils.py +174 -0
  10. build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py +377 -206
  11. build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py +3 -3
  12. build/torch28-cxx11-cu129-x86_64-linux/optimizer/{_optimizer_811726c_dirty.abi3.so → _optimizer_23d68bb_dirty.abi3.so} +2 -2
  13. build/torch28-cxx11-cu129-x86_64-linux/optimizer/distributed/utils.py +174 -0
  14. build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py +377 -206
  15. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
  16. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_811726c_dirty.abi3.so → _optimizer_23d68bb_dirty.abi3.so} +2 -2
  17. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/distributed/utils.py +174 -0
  18. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py +377 -206
  19. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py +3 -3
  20. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so +3 -0
  21. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so +0 -3
  22. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/distributed/utils.py +174 -0
  23. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py +377 -206
  24. build/torch29-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  25. build/torch29-cxx11-cu126-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so +3 -0
  26. build/torch29-cxx11-cu126-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so +0 -3
  27. build/torch29-cxx11-cu126-x86_64-linux/optimizer/distributed/utils.py +174 -0
  28. build/torch29-cxx11-cu126-x86_64-linux/optimizer/muon.py +377 -206
  29. build/torch29-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
  30. build/torch29-cxx11-cu128-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so +3 -0
  31. build/torch29-cxx11-cu128-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so +0 -3
  32. build/torch29-cxx11-cu128-x86_64-linux/optimizer/distributed/utils.py +174 -0
  33. build/torch29-cxx11-cu128-x86_64-linux/optimizer/muon.py +377 -206
  34. build/torch29-cxx11-cu130-x86_64-linux/optimizer/_ops.py +3 -3
  35. build/torch29-cxx11-cu130-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so +3 -0
  36. build/torch29-cxx11-cu130-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so +0 -3
  37. build/torch29-cxx11-cu130-x86_64-linux/optimizer/distributed/utils.py +174 -0
  38. build/torch29-cxx11-cu130-x86_64-linux/optimizer/muon.py +377 -206
  39. build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
  40. build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so +3 -0
  41. build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so +0 -3
  42. build/torch29-cxx11-rocm63-x86_64-linux/optimizer/distributed/utils.py +174 -0
  43. build/torch29-cxx11-rocm63-x86_64-linux/optimizer/muon.py +377 -206
  44. build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_ops.py +3 -3
  45. build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so +3 -0
  46. build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so +0 -3
  47. build/torch29-cxx11-rocm64-x86_64-linux/optimizer/distributed/utils.py +174 -0
  48. build/torch29-cxx11-rocm64-x86_64-linux/optimizer/muon.py +377 -206
  49. docs/muon/balanced.png +0 -3
  50. docs/muon/distributed_muon.png +0 -3
.pre-commit-config.yaml CHANGED
@@ -31,7 +31,3 @@ repos:
     hooks:
       - id: pymarkdown
         args: [fix]
-  - repo: https://github.com/rhysd/actionlint
-    rev: v1.7.7
-    hooks:
-      - id: actionlint
README.md CHANGED
@@ -11,7 +11,13 @@ Optimizer is a python package that provides:
 - with support for parallelism techniques for efficient large-scale training.
 
 ## Currently implemented
-- [Parallel Muon with FSDP2](./docs/muon/parallel_muon.pdf)
+- Parallel Muon with N-D sharding
+  - arxiv URL: (TBW)
+  - Supports **general N-D sharding configurations**
+    - The implementation is not tied to any specific parallel strategy.
+    - Verified from basic FSDP2 setups up to hybrid configurations such as
+      **(2 TP + 2 DP-Replicate + 2 DP-Shard)**.
+  - Verified configurations can be found in [test_muon.py](./test/test_muon.py)
 
 ## Usage
 
@@ -39,6 +45,9 @@ optim = optimizer.Muon(
 )
 ```
 
+## Test
+- Check [test/README.md](./test/README.md) for how to run the tests.
+
 ## Pre-commit Hooks
 
 This project uses [pre-commit](https://pre-commit.com/) to automatically check and format code before commits.
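For context on the hybrid layout the updated README calls out, here is a minimal, illustrative sketch (not part of this commit) of building such an N-D device mesh with stock PyTorch; the mesh sizes and dimension names are assumptions, not taken from the repository.

```python
# Illustrative 8-GPU layout matching the README's "(2 TP + 2 DP-Replicate + 2 DP-Shard)"
# example. Run under torchrun with 8 processes; init_device_mesh sets up the mesh.
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh(
    "cuda",
    (2, 2, 2),
    mesh_dim_names=("dp_replicate", "dp_shard", "tp"),  # assumed names
)
# Parameters distributed over this mesh become DTensors whose placements mix
# Replicate and Shard dims; this commit lets Muon handle such placements generically.
```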
build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_811726c_dirty
-ops = torch.ops._optimizer_811726c_dirty
+from . import _optimizer_23d68bb_dirty
+ops = torch.ops._optimizer_23d68bb_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_811726c_dirty::{op_name}"
+    return f"_optimizer_23d68bb_dirty::{op_name}"
build/torch28-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_811726c_dirty.abi3.so → _optimizer_23d68bb_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:511199ac2ae46febc8aeeb96e843a748da7d6fdea4922572ccf27ee5eabe312d
-size 1816064
+oid sha256:35708a107d9ac807fa3e63bbacfc6234fd7622a689a79eae3e43fce11f85d3da
+size 1924376
build/torch28-cxx11-cu126-x86_64-linux/optimizer/distributed/utils.py ADDED
@@ -0,0 +1,174 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import (Placement, Shard,
                                                       _StridedShard)


def get_slices_of_dtensor(
    target: DTensor | torch.Tensor,
    local_rank: int,
    shard_mesh: DeviceMesh,
    shard_placements: tuple[Placement],
) -> tuple[slice]:
    """
    Get the slices of the local tensor for a given rank from a tensor.
    Args:
        target (DTensor | torch.Tensor): The target tensor.
        local_rank (int): The local rank of the shard group.
        shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks.
        shard_placements (tuple[Placement]): The shard placements.
    """

    slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()]

    # find the global rank of the local rank in the shard mesh
    rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]

    rank_coords = (shard_mesh.mesh == rank).nonzero()

    assert len(rank_coords) == 1
    rank_coords = tuple(rank_coords[0].tolist())

    assert len(rank_coords) == len(shard_placements)

    # Caution: Assuming replicate-to-shard of the shard mesh goes with
    # left-to-right sharding. This is ensured by the sorting logic of
    # the construct_shard_mesh function.
    for i, (rank_coord,
            placement) in enumerate(zip(rank_coords, shard_placements)):
        assert isinstance(placement, Shard)

        num_ranks = shard_mesh.mesh.shape[i]

        dim = placement.dim
        dim_size = (slices[dim].stop - slices[dim].start)

        if dim_size % num_ranks != 0:
            raise NotImplementedError(
                f"Dimension size {dim_size} is not divisible "
                f"by number of ranks {num_ranks} for shard "
                f"placement on dim {dim}.")

        shard_size = dim_size // num_ranks

        start = slices[dim].start + rank_coord * shard_size
        end = start + shard_size

        assert start < end <= slices[dim].stop

        slices[dim] = slice(start, end)

    return tuple(slices)


_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict()


def construct_shard_mesh(
    placements: tuple[Placement],
    mesh: DeviceMesh,
) -> (DeviceMesh, ProcessGroup, tuple[Placement]):
    """
    Construct Shard Mesh and Placements for unsharding.
    It removes Replicate placements and constructs a new Mesh and ProcessGroup.
    """
    my_rank = dist.get_rank()

    assert mesh.mesh.device.type == 'cpu'

    # Copy mesh to avoid modifying the original mesh
    mesh = mesh.mesh.clone()

    # 1. Sort placements. Replicate first, then Shard by dim ascending.

    # For Shard, a strided shard comes after a regular shard on the same dim
    # to preserve the left-to-right order of replicate-to-shard.
    # This is because a strided shard uses a stride to represent
    # more fine-grained sharding on the same dim.
    # Please check the URL below for _StridedShard.
    # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366

    def placement_sort_key(
        placement_with_index: tuple[int, Placement]
    ) -> tuple[float, float, int]:  # (dim, split factor, original index)
        index, placement = placement_with_index
        is_replicate = placement.is_replicate()
        is_shard = placement.is_shard()
        is_partial = placement.is_partial()

        assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}"
        assert not is_partial, "Partial placement is not supported."

        if is_replicate:
            return (-1.0, 0, index)
        elif is_shard:
            if isinstance(placement, _StridedShard):
                return (placement.dim, 1 / placement.split_factor, index)
            return (placement.dim, 0, index)
        else:
            raise TypeError(f"Unknown placement type: {type(placement)}")

    placements_with_index: list[tuple[int, Placement]] = list(enumerate(placements))
    placements_with_index = sorted(placements_with_index, key=placement_sort_key)

    sorted_indices, sorted_placements = zip(*placements_with_index)

    # 2. Permute mesh according to sorted placements.
    sorted_mesh = mesh.permute(sorted_indices)

    # 3. Collect a list of shard meshes by removing replicate dims.
    #    For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]:
    #    shard_meshes should be a list of 2 * 3 = 6 shard meshes of shape (4, 4).
    num_replicates = sum(1 for p in sorted_placements if p.is_replicate())

    # merge replicate dims
    # shard_meshes becomes a list of shard meshes whose length is the replicate degree
    if num_replicates > 0:
        sorted_mesh = sorted_mesh.flatten(
            0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
        shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
    else:
        shard_meshes = [sorted_mesh]
    shard_placements = sorted_placements[num_replicates:]

    # assume all shard placements are different
    assert len(shard_placements) == len(set(shard_placements))

    # 4. Construct ProcessGroups.
    #    Caution: all groups should be created in the same order in all processes,
    #    even though each process only needs its own group.

    # To use a tensor as a dict key, convert it to a tuple
    def tensor_to_tuple(t):
        if isinstance(t, torch.Tensor):
            t = t.tolist()
        if isinstance(t, list):
            return tuple(tensor_to_tuple(x) for x in t)
        return t

    my_shard_mesh_as_tuple = None
    for shard_mesh in shard_meshes:
        assert isinstance(shard_mesh, torch.Tensor)
        shard_mesh_as_tuple = tensor_to_tuple(shard_mesh)

        if (my_rank == shard_mesh).any().item():
            assert my_shard_mesh_as_tuple is None
            my_shard_mesh_as_tuple = shard_mesh_as_tuple

        # update global cache
        if shard_mesh_as_tuple not in _ranks_to_dist_cache:
            shard_process_group = dist.new_group(shard_mesh.flatten().tolist())
            _ranks_to_dist_cache[shard_mesh_as_tuple] = (
                DeviceMesh(device_type="cuda", mesh=shard_mesh),
                shard_process_group,
            )

    my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[
        my_shard_mesh_as_tuple]

    return my_shard_mesh, my_shard_process_group, shard_placements
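A minimal usage sketch (not part of the diff) of the two helpers added above: `construct_shard_mesh` drops `Replicate` placements and returns the shard-only mesh, its process group, and the remaining shard placements, and `get_slices_of_dtensor` maps a local rank in that group to the slices of the full tensor it owns. The mesh shape, placements, tensor size, and import path are illustrative assumptions; the snippet assumes an initialized 8-rank process group (e.g. under torchrun) and must run on every rank, since `dist.new_group` is collective.

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor

# Assumed import path for the helpers packaged in this build directory.
from optimizer.distributed.utils import construct_shard_mesh, get_slices_of_dtensor

# Hypothetical layout: one replicate dim and two shard dims over 8 GPUs.
mesh = init_device_mesh("cuda", (2, 2, 2))
placements = (Replicate(), Shard(0), Shard(1))
weight = distribute_tensor(torch.randn(1024, 1024), mesh, placements)

# Strip the Replicate dim: only the dims that actually shard the tensor remain.
shard_mesh, shard_pg, shard_placements = construct_shard_mesh(placements, mesh)

# Slices of the full 1024 x 1024 weight owned by this rank within its shard group.
local_rank = dist.get_rank(group=shard_pg)
slices = get_slices_of_dtensor(weight, local_rank, shard_mesh, shard_placements)
# e.g. (slice(0, 512), slice(512, 1024)) on one of the four shard positions
```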
build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py CHANGED
@@ -1,18 +1,24 @@
1
  import logging
2
  import math
3
  import types
 
4
  from dataclasses import dataclass
5
- from typing import List, Optional, Union, cast
6
 
7
  import torch
8
  import torch.distributed as dist
9
- from torch.distributed._tensor import DTensor, Replicate, Shard
 
 
 
10
 
 
11
  from .matmul_transpose_triton import matmul_transpose_assign
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
  COMM_DTYPE = torch.bfloat16
 
16
 
17
 
18
  # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -62,23 +68,39 @@ def _zeropower_via_newtonschulz5(G, steps):
62
  @dataclass
63
  class _muon_state:
64
  # TODO: use Optional
65
- worker_rank: int | None = None
 
 
 
 
 
66
  gathered_grad: torch.Tensor | None = None
67
  scattered_u: DTensor | None = None
68
  computed_u: torch.Tensor | None = None
69
  gather_event: torch.cuda.Event | None = None
70
  compute_event: torch.cuda.Event | None = None
71
  scatter_event: torch.cuda.Event | None = None
72
- process_group = None
73
- qk_clip_state = None
74
 
75
 
76
- def split_elems_for_src(param, src_rank, num_ranks) -> int:
77
- rows = param.shape[0]
78
- cols = int(param.numel() // rows)
79
- base, rem = divmod(rows, num_ranks)
80
- my_rows = base + (1 if src_rank < rem else 0)
81
- return my_rows * cols
82
 
83
 
84
  @torch.no_grad()
@@ -91,8 +113,7 @@ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
91
  for p in params:
92
  state = param_to_state[id(p)]
93
  if rank == state.worker_rank:
94
- num_ranks = dist.get_world_size(group=state.process_group)
95
- state.gathered_grad = torch.empty(p.grad.numel(),
96
  dtype=COMM_DTYPE,
97
  device="cuda")
98
  else:
@@ -121,11 +142,11 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
121
  state = param_to_state[id(p)]
122
  dst = state.worker_rank
123
  assert dst < num_ranks
124
- shard_elems = split_elems_for_src(p, rank, num_ranks)
125
  g = p.grad
126
- g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
127
  assert g.numel() == shard_elems
128
- per_dst[dst].append(g)
129
  send_counts[dst] += shard_elems
130
 
131
  assert any(
@@ -148,13 +169,18 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
148
  for p in owned_params:
149
  state = param_to_state[id(p)]
150
  assert state.worker_rank == rank
151
- total += split_elems_for_src(p, src, num_ranks)
152
  recv_counts[src] = total
153
 
154
  recv_total = sum(recv_counts)
155
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
156
 
157
  #All2All
 
 
 
 
 
158
  dist.all_to_all_single(
159
  recv_buf,
160
  send_buf,
@@ -179,7 +205,6 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
179
  comm_stream.wait_event(alloc_event)
180
 
181
  off = 0
182
- write_offsets = {id(p): 0 for p in owned_params}
183
  for src in range(num_ranks):
184
  if recv_counts[src] == 0:
185
  continue
@@ -189,22 +214,28 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
189
  for p in owned_params:
190
  state = param_to_state[id(p)]
191
  assert state.worker_rank == rank
192
- n = split_elems_for_src(p, src, num_ranks)
 
 
 
 
 
 
 
 
 
193
  assert n > 0
194
 
195
  sg = recv_buf.narrow(0, off + inner_off, n)
196
- woff = write_offsets[id(p)]
197
- dst = state.gathered_grad.narrow(0, woff, n)
198
  dst.copy_(sg)
199
 
200
- write_offsets[id(p)] += n
201
  inner_off += n
202
  off += block
203
 
204
  for p in params:
205
  state = param_to_state[id(p)]
206
  if state.worker_rank == rank:
207
- state.gathered_grad = state.gathered_grad.view_as(p)
208
  state.gather_event = torch.cuda.Event()
209
  state.gather_event.record(comm_stream)
210
  else:
@@ -277,14 +308,19 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
277
 
278
  assert state.computed_u is not None
279
 
280
- u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)
281
 
282
  offset = 0
283
  for dst in range(num_ranks):
284
- n = split_elems_for_src(p, dst, num_ranks)
 
 
 
 
 
 
285
  assert n > 0
286
 
287
- su = u_full.narrow(0, offset, n)
288
  per_dst[dst].append(su)
289
  send_counts[dst] += n
290
  offset += n
@@ -313,7 +349,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
313
  state = param_to_state[id(p)]
314
  if state.worker_rank != src:
315
  continue
316
- total += split_elems_for_src(p, rank, num_ranks)
317
  recv_counts[src] = total
318
 
319
  recv_total = sum(recv_counts)
@@ -357,7 +393,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
357
  state = param_to_state[id(p)]
358
  if state.worker_rank != src:
359
  continue
360
- n = split_elems_for_src(p, rank, num_ranks)
361
  assert n > 0
362
 
363
  flat_local = recv_buf.narrow(0, off + inner_off,
@@ -398,11 +434,23 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
398
  state.scattered_u = None
399
  u_dtensor = None
400
 
401
- scales_full = Muon._compute_scales(p, state.qk_clip_state)
 
 
402
  if scales_full is not None:
403
- num_ranks = dist.get_world_size(group=state.process_group)
404
- local_rank = dist.get_rank(group=state.process_group)
405
- scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank]
 
 
 
 
 
 
 
 
 
 
406
  scales_local = DTensor.from_local(
407
  scales_local,
408
  placements=p.placements,
@@ -478,11 +526,11 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
478
  @dataclass
479
  class QKClipInfo:
480
  """Per-parameter dynamic info computed from config + runtime logits."""
481
- kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None
482
- indices: List[int] # which heads to consider for clipping
483
  head_dim: int # from config
484
  threshold: float # from config
485
- logit: Optional[torch.Tensor]
486
 
487
 
488
  class Muon(torch.optim.Optimizer):
@@ -525,11 +573,16 @@ class Muon(torch.optim.Optimizer):
525
  "head_dim": 128,
526
  "threshold": 100
527
  }
528
- overlap_step : How many all2all gather, compute operations are launched in advance
529
- before the corresponding all2all scatter steps begin.
530
- A higher overlap_step increases memory usage but can improve
531
- performance by overlapping communication.
532
- Parallel muon only.
 
 
 
 
 
533
  """
534
 
535
  def __init__(self,
@@ -549,7 +602,9 @@ class Muon(torch.optim.Optimizer):
549
  "head_dim": 128,
550
  "threshold": 100
551
  },
552
- overlap_step=5):
 
 
553
  defaults = dict(
554
  lr=lr,
555
  weight_decay=weight_decay,
@@ -579,7 +634,9 @@ class Muon(torch.optim.Optimizer):
579
  self.compute_stream = torch.cuda.Stream()
580
  self.debug = debug
581
  self.clip_config = clip_config
582
- self.overlap_step = overlap_step
 
 
583
 
584
  def _calc_flops(self, G, steps):
585
  assert len(G.shape) == 2
@@ -597,6 +654,12 @@ class Muon(torch.optim.Optimizer):
597
  adjusted_lr = lr * adjusted_ratio
598
  return adjusted_lr
599
 
 
 
 
 
 
 
600
  def get_shard_mesh(self, p):
601
  """
602
  Get the shard mesh for a parameter p on the given rank.
@@ -604,26 +667,13 @@ class Muon(torch.optim.Optimizer):
604
  assert isinstance(
605
  p, DTensor), "Parallel Muon only supports DTensor parameters."
606
 
607
- if p.placements == (Shard(dim=0), ):
608
- # Case for FSDP
609
- process_group = p.device_mesh.get_group(mesh_dim=0)
610
- if self.rank is None:
611
- self.rank = dist.get_rank(group=process_group)
612
- else:
613
- assert self.rank == dist.get_rank(group=process_group)
614
- return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
615
- elif p.placements == (Replicate(), Shard(dim=0)):
616
- # Case for HSDP
617
- process_group = p.device_mesh.get_group(mesh_dim=1)
618
- if self.rank is None:
619
- self.rank = dist.get_rank(group=process_group)
620
- else:
621
- assert self.rank == dist.get_rank(group=process_group)
622
- for i, shard_mesh in enumerate(p.device_mesh.mesh):
623
- if self.rank in shard_mesh:
624
- return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
625
- else:
626
- raise ValueError(f"Unsupported placements ({p.placements}).")
627
 
628
  def init_state_and_assign_params(self, names, params, group, qk_logits):
629
  param_to_state = {}
@@ -655,23 +705,32 @@ class Muon(torch.optim.Optimizer):
655
  ordered_params = list(params_sorted)
656
 
657
  round_robin = 0
658
- mesh = None
659
- shard_mesh = None
660
- process_group = None
 
 
 
 
 
661
  for n, p in zip(ordered_names, ordered_params):
662
- if mesh is None:
663
- mesh = p.device_mesh
664
- shard_mesh, process_group = self.get_shard_mesh(p)
665
- elif mesh != p.device_mesh:
666
  raise ValueError("All parameters must be on the same mesh.")
667
- num_ranks = dist.get_world_size(group=process_group)
668
- param_to_state[id(p)] = _muon_state()
669
- param_to_state[id(
670
- p)].worker_rank = shard_mesh[round_robin].item() % num_ranks
671
- param_to_state[id(p)].process_group = process_group
672
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
673
- param_to_state[id(p)].qk_clip_state = qk_clip_state
674
- round_robin = (round_robin + 1) % len(shard_mesh)
 
 
 
 
 
 
 
675
 
676
  return param_to_state, ordered_params
677
 
@@ -705,10 +764,73 @@ class Muon(torch.optim.Optimizer):
705
 
706
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
707
 
708
- scales_full = self._compute_scales(p, qk_clip_state)
 
709
  if scales_full is not None:
710
  Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
711
 
712
  def _update_g(self, p, g, group, momentum):
713
  # calc update
714
  state = self.state[p]
@@ -727,6 +849,9 @@ class Muon(torch.optim.Optimizer):
727
  p.data.add_(u, alpha=-adjusted_lr)
728
 
729
  def get_qk_clip_info(self, n, qk_logits):
 
 
 
730
  head_dim = self.clip_config.get('head_dim')
731
  threshold = self.clip_config.get('threshold')
732
  kind, layer_idx = parse_qk_layer(n)
@@ -737,6 +862,11 @@ class Muon(torch.optim.Optimizer):
737
  indices_key = 'q_indices' if 'q' in kind else 'k_indices'
738
  indices = self.clip_config.get(indices_key, []) or []
739
 
 
 
 
 
 
740
  return QKClipInfo(
741
  kind=kind,
742
  indices=indices,
@@ -835,22 +965,28 @@ class Muon(torch.optim.Optimizer):
835
  _update_param(p, state, lr, adjusted_lr, weight_decay,
836
  self.rank, self.compute_stream)
837
 
838
- chunk_size = dist.get_world_size(param_to_state[id(
839
- params[0])].process_group)
 
 
 
 
 
 
840
 
841
  # Wait grad update
842
  self.comm_stream.wait_stream(torch.cuda.current_stream())
843
 
844
- overlap_step = self.overlap_step
845
- for i in range(0, overlap_step):
846
  enqueue_all2all_gather(i * chunk_size, chunk_size)
847
  enqueue_computes(i * chunk_size, chunk_size)
848
 
849
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
850
  enqueue_all2all_scatter(i, chunk_size)
851
- enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size)
852
  enqueue_update_param(i, chunk_size)
853
- enqueue_computes(i + overlap_step * chunk_size, chunk_size)
854
 
855
  # Wait the last update_param to finish
856
  torch.cuda.current_stream().wait_stream(self.compute_stream)
@@ -866,7 +1002,7 @@ class Muon(torch.optim.Optimizer):
866
  amsgrad: bool,
867
  beta1: float,
868
  beta2: float,
869
- lr: Union[float, torch.Tensor],
870
  weight_decay: float,
871
  eps: float,
872
  maximize: bool,
@@ -876,10 +1012,10 @@ class Muon(torch.optim.Optimizer):
876
 
877
  # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
878
  # treating it as a scalar.
879
- lr_dict: Optional[DeviceDict] = ({
880
  lr.device: lr
881
  } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
882
- None)
883
  grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
884
  [
885
  params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
@@ -926,6 +1062,159 @@ class Muon(torch.optim.Optimizer):
926
  maximize=maximize,
927
  )
928
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
929
  def step(self, closure=None, qk_logits=None):
930
  """Perform a single optimization step.
931
 
@@ -943,127 +1232,9 @@ class Muon(torch.optim.Optimizer):
943
  loss = closure()
944
 
945
  for group in self.param_groups:
946
- params = group["params"]
947
-
948
  if group["use_muon"]:
949
- ############################
950
- # Muon #
951
- ############################
952
- lr = group["lr"]
953
- weight_decay = group["weight_decay"]
954
- momentum = group["momentum"]
955
- names = group["names"]
956
-
957
- param_dtensors = []
958
- param_tensors = []
959
- name_dtensors = []
960
- name_tensors = []
961
-
962
- for n, p in zip(names, params):
963
- if p is None or p.grad is None:
964
- continue
965
- if isinstance(p.data, DTensor):
966
- if all(
967
- isinstance(placement, Replicate)
968
- for placement in p.placements):
969
- param_tensors.append(p)
970
- name_tensors.append(n)
971
- else:
972
- param_dtensors.append(p)
973
- name_dtensors.append(n)
974
- elif isinstance(p.data, torch.Tensor):
975
- param_tensors.append(p)
976
- name_tensors.append(n)
977
- else:
978
- raise TypeError(
979
- f"Unsupported parameter type: {type(p.data)}")
980
-
981
- if self.debug:
982
- print(
983
- f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
984
- flush=True,
985
- )
986
-
987
- if len(param_dtensors) > 0:
988
- if not dist.is_initialized():
989
- raise RuntimeError(
990
- "Parallel Muon requires torch.distributed to be initialized."
991
- )
992
-
993
- self.parallel(
994
- name_dtensors,
995
- param_dtensors,
996
- group,
997
- lr=lr,
998
- weight_decay=weight_decay,
999
- momentum=momentum,
1000
- qk_logits=qk_logits,
1001
- )
1002
-
1003
- if len(param_tensors) > 0:
1004
- self.base(
1005
- name_tensors,
1006
- param_tensors,
1007
- group,
1008
- lr=lr,
1009
- weight_decay=weight_decay,
1010
- momentum=momentum,
1011
- qk_logits=qk_logits,
1012
- )
1013
-
1014
  else:
1015
- ############################
1016
- # AdamW backup #
1017
- ############################
1018
-
1019
- params_with_grads = []
1020
- grads = []
1021
- moment1 = []
1022
- moment2 = []
1023
- max_exp_avg_sqs = []
1024
- state_steps = []
1025
- lr = group["lr"]
1026
- beta1, beta2 = group["adamw_betas"]
1027
- eps = group["adamw_eps"]
1028
- weight_decay = group["weight_decay"]
1029
-
1030
- for p in params:
1031
- g = p.grad
1032
- if g is None:
1033
- continue
1034
- state = self.state[p]
1035
- params_with_grads.append(p)
1036
- grads.append(g)
1037
- if "step" not in state:
1038
- state["step"] = (torch.zeros((),
1039
- dtype=torch.float32,
1040
- device=p.device))
1041
- state["moment1"] = torch.zeros_like(g)
1042
- state["moment2"] = torch.zeros_like(g)
1043
- moment1.append(state["moment1"])
1044
- moment2.append(state["moment2"])
1045
- if not isinstance(state["step"], torch.Tensor):
1046
- step_tensor = torch.tensor(state["step"],
1047
- dtype=torch.float32,
1048
- device=p.device)
1049
- else:
1050
- step_tensor = state["step"]
1051
- state_steps.append(step_tensor)
1052
-
1053
- self._fused_adamw(
1054
- params_with_grads,
1055
- grads,
1056
- moment1,
1057
- moment2,
1058
- max_exp_avg_sqs,
1059
- state_steps,
1060
- amsgrad=False,
1061
- beta1=beta1,
1062
- beta2=beta2,
1063
- lr=lr,
1064
- weight_decay=weight_decay,
1065
- eps=eps,
1066
- maximize=False,
1067
- )
1068
 
1069
  return loss
 
1
  import logging
2
  import math
3
  import types
4
+ from collections import defaultdict
5
  from dataclasses import dataclass
6
+ from typing import Any, cast
7
 
8
  import torch
9
  import torch.distributed as dist
10
+ from torch.distributed import ProcessGroup
11
+ from torch.distributed.device_mesh import DeviceMesh
12
+ from torch.distributed.tensor import DTensor, Replicate
13
+ from torch.distributed.tensor.placement_types import Placement
14
 
15
+ from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor
16
  from .matmul_transpose_triton import matmul_transpose_assign
17
 
18
  logger = logging.getLogger(__name__)
19
 
20
  COMM_DTYPE = torch.bfloat16
21
+ DEFAULT_CHUNK_SIZE_RATIO = 4
22
 
23
 
24
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
68
  @dataclass
69
  class _muon_state:
70
  # TODO: use Optional
71
+ worker_rank: int
72
+ process_group: ProcessGroup
73
+ shard_mesh: DeviceMesh
74
+ shard_placements: tuple[Placement, ...]
75
+ name: str
76
+ qk_clip_state: torch.Tensor | None = None
77
  gathered_grad: torch.Tensor | None = None
78
  scattered_u: DTensor | None = None
79
  computed_u: torch.Tensor | None = None
80
  gather_event: torch.cuda.Event | None = None
81
  compute_event: torch.cuda.Event | None = None
82
  scatter_event: torch.cuda.Event | None = None
 
 
83
 
84
 
85
+ def numel_for_rank(
86
+ param: DTensor,
87
+ local_rank: int,
88
+ state: _muon_state,
89
+ ) -> int:
90
+ slices = get_slices_of_dtensor(
91
+ param,
92
+ local_rank,
93
+ state.shard_mesh,
94
+ state.shard_placements,
95
+ )
96
+
97
+ numel = 1
98
+ for s, dim in zip(slices, param.shape):
99
+ start, stop, step = s.indices(dim)
100
+ length = max(0, (stop - start + (step - 1)) // step)
101
+ numel *= length
102
+
103
+ return numel
104
 
105
 
106
  @torch.no_grad()
 
113
  for p in params:
114
  state = param_to_state[id(p)]
115
  if rank == state.worker_rank:
116
+ state.gathered_grad = torch.empty(p.shape,
 
117
  dtype=COMM_DTYPE,
118
  device="cuda")
119
  else:
 
142
  state = param_to_state[id(p)]
143
  dst = state.worker_rank
144
  assert dst < num_ranks
145
+ shard_elems = numel_for_rank(p, rank, state)
146
  g = p.grad
147
+ g = g.to_local().to(COMM_DTYPE).contiguous()
148
  assert g.numel() == shard_elems
149
+ per_dst[dst].append(g.view(-1))
150
  send_counts[dst] += shard_elems
151
 
152
  assert any(
 
169
  for p in owned_params:
170
  state = param_to_state[id(p)]
171
  assert state.worker_rank == rank
172
+ total += numel_for_rank(p, src, state)
173
  recv_counts[src] = total
174
 
175
  recv_total = sum(recv_counts)
176
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
177
 
178
  #All2All
179
+ logger.debug(f"send_buf size: {send_buf.numel()}, "
180
+ f"recv_buf size: {recv_buf.numel()}, "
181
+ f"recv_counts: {recv_counts}, "
182
+ f"send_counts: {send_counts}, "
183
+ f"process_group: {str(process_group)}")
184
  dist.all_to_all_single(
185
  recv_buf,
186
  send_buf,
 
205
  comm_stream.wait_event(alloc_event)
206
 
207
  off = 0
 
208
  for src in range(num_ranks):
209
  if recv_counts[src] == 0:
210
  continue
 
214
  for p in owned_params:
215
  state = param_to_state[id(p)]
216
  assert state.worker_rank == rank
217
+
218
+ # get the slice of the full dtensor corresponding to rank src.
219
+ slices = get_slices_of_dtensor(state.gathered_grad, src,
220
+ state.shard_mesh,
221
+ state.shard_placements)
222
+
223
+ dst = state.gathered_grad[slices]
224
+ assert dst._base is state.gathered_grad
225
+
226
+ n = dst.numel()
227
  assert n > 0
228
 
229
  sg = recv_buf.narrow(0, off + inner_off, n)
230
+ sg = sg.reshape_as(dst)
 
231
  dst.copy_(sg)
232
 
 
233
  inner_off += n
234
  off += block
235
 
236
  for p in params:
237
  state = param_to_state[id(p)]
238
  if state.worker_rank == rank:
 
239
  state.gather_event = torch.cuda.Event()
240
  state.gather_event.record(comm_stream)
241
  else:
 
308
 
309
  assert state.computed_u is not None
310
 
311
+ u_full = state.computed_u.to(COMM_DTYPE).contiguous()
312
 
313
  offset = 0
314
  for dst in range(num_ranks):
315
+ # get the slice of the full tensor corresponding to rank dst.
316
+ slices = get_slices_of_dtensor(u_full, dst,
317
+ state.shard_mesh,
318
+ state.shard_placements)
319
+ su = u_full[slices].flatten()
320
+
321
+ n = su.numel()
322
  assert n > 0
323
 
 
324
  per_dst[dst].append(su)
325
  send_counts[dst] += n
326
  offset += n
 
349
  state = param_to_state[id(p)]
350
  if state.worker_rank != src:
351
  continue
352
+ total += numel_for_rank(p, rank, state)
353
  recv_counts[src] = total
354
 
355
  recv_total = sum(recv_counts)
 
393
  state = param_to_state[id(p)]
394
  if state.worker_rank != src:
395
  continue
396
+ n = numel_for_rank(p, rank, state)
397
  assert n > 0
398
 
399
  flat_local = recv_buf.narrow(0, off + inner_off,
 
434
  state.scattered_u = None
435
  u_dtensor = None
436
 
437
+ scales_full = Muon._compute_scales(
438
+ p,
439
+ state.qk_clip_state) if state.qk_clip_state is not None else None
440
  if scales_full is not None:
441
+ # Have to slice scales_full among dim 0
442
+ weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh,
443
+ state.shard_placements)
444
+ ratio = p.shape[0] // scales_full.shape[0]
445
+ scales_slice = slice(
446
+ None if weight_slices[0].start is None else
447
+ weight_slices[0].start // ratio,
448
+ None if weight_slices[0].stop is None else
449
+ weight_slices[0].stop // ratio,
450
+ None,
451
+ )
452
+
453
+ scales_local = scales_full[scales_slice]
454
  scales_local = DTensor.from_local(
455
  scales_local,
456
  placements=p.placements,
 
526
  @dataclass
527
  class QKClipInfo:
528
  """Per-parameter dynamic info computed from config + runtime logits."""
529
+ kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
530
+ indices: list[int] # which heads to consider for clipping
531
  head_dim: int # from config
532
  threshold: float # from config
533
+ logit: torch.Tensor | None
534
 
535
 
536
  class Muon(torch.optim.Optimizer):
 
573
  "head_dim": 128,
574
  "threshold": 100
575
  }
576
+ warmup_step : How many all2all gather, compute operations are launched in advance
577
+ before the corresponding all2all scatter steps begin.
578
+ A higher warmup_step increases memory usage but can improve
579
+ performance by overlapping communication.
580
+ Parallel muon only.
581
+ chunk_size : Batch size of parameters to process in each
582
+ all2all gather/compute/scatter step.
583
+ Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
584
+ use_distributed_muon: Use distributed muon by Liu et al. (2024).
585
+ For testing purpose only.
586
  """
587
 
588
  def __init__(self,
 
602
  "head_dim": 128,
603
  "threshold": 100
604
  },
605
+ warmup_step=5,
606
+ chunk_size=-1,
607
+ use_distributed_muon=False):
608
  defaults = dict(
609
  lr=lr,
610
  weight_decay=weight_decay,
 
634
  self.compute_stream = torch.cuda.Stream()
635
  self.debug = debug
636
  self.clip_config = clip_config
637
+ self.warmup_step = warmup_step
638
+ self.chunk_size = chunk_size
639
+ self.use_distributed_muon = use_distributed_muon
640
 
641
  def _calc_flops(self, G, steps):
642
  assert len(G.shape) == 2
 
654
  adjusted_lr = lr * adjusted_ratio
655
  return adjusted_lr
656
 
657
+ def set_rank_once(self, rank):
658
+ if self.rank is None:
659
+ self.rank = rank
660
+ else:
661
+ assert self.rank == rank
662
+
663
  def get_shard_mesh(self, p):
664
  """
665
  Get the shard mesh for a parameter p on the given rank.
 
667
  assert isinstance(
668
  p, DTensor), "Parallel Muon only supports DTensor parameters."
669
 
670
+ shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
671
+ p.placements, p.device_mesh)
672
+
673
+ # set rank with the local rank in the shard process group
674
+ self.set_rank_once(dist.get_rank(group=shard_pg))
675
+
676
+ return shard_mesh, shard_pg, shard_placements
677
 
678
  def init_state_and_assign_params(self, names, params, group, qk_logits):
679
  param_to_state = {}
 
705
  ordered_params = list(params_sorted)
706
 
707
  round_robin = 0
708
+ mesh = ordered_params[0].device_mesh
709
+ placements = ordered_params[0].placements
710
+
711
+ shard_mesh, shard_pg, shard_placements = self.get_shard_mesh(
712
+ ordered_params[0])
713
+ shard_mesh_flattened = shard_mesh.mesh.flatten()
714
+ num_ranks = dist.get_world_size(group=shard_pg)
715
+
716
  for n, p in zip(ordered_names, ordered_params):
717
+ if mesh != p.device_mesh:
 
 
 
718
  raise ValueError("All parameters must be on the same mesh.")
719
+ if placements != p.placements:
720
+ raise ValueError("All parameters must have same placements.")
721
+
722
+ worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
723
+ round_robin = (round_robin + 1) % len(shard_mesh_flattened)
724
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
725
+
726
+ param_to_state[id(p)] = _muon_state(
727
+ worker_rank=worker_rank,
728
+ process_group=shard_pg,
729
+ shard_mesh=shard_mesh,
730
+ shard_placements=shard_placements,
731
+ name=n,
732
+ qk_clip_state=qk_clip_state,
733
+ )
734
 
735
  return param_to_state, ordered_params
736
 
 
764
 
765
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
766
 
767
+ scales_full = self._compute_scales(
768
+ p, qk_clip_state) if qk_clip_state is not None else None
769
  if scales_full is not None:
770
  Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
771
 
772
+ def distributed_muon(
773
+ self,
774
+ names: list[str],
775
+ params: list[torch.nn.Parameter],
776
+ group: dict[str, Any],
777
+ lr: float,
778
+ weight_decay: float,
779
+ momentum: float,
780
+ qk_logits: list[torch.Tensor | DTensor] | None,
781
+ ):
782
+ """ Implementation of Distributed Muon by Liu et al. """
783
+ if qk_logits is not None:
784
+ raise NotImplementedError("QK clipping is not supported yet")
785
+
786
+ if isinstance(params[0], DTensor):
787
+ shard_mesh, _, shard_placements = construct_shard_mesh(
788
+ placements=params[0].placements,
789
+ mesh=params[0].device_mesh,
790
+ )
791
+
792
+ for n, p in zip(names, params):
793
+ g = p.grad
794
+ if g is None:
795
+ continue
796
+ if g.ndim > 2:
797
+ g = g.view(g.size(0), -1)
798
+ assert g is not None
799
+
800
+ # calc update
801
+ state = self.state[p]
802
+ if "momentum_buffer" not in state:
803
+ state["momentum_buffer"] = torch.zeros_like(g)
804
+ buf = state["momentum_buffer"]
805
+ buf.mul_(momentum).add_(g)
806
+ if group["nesterov"]:
807
+ g = g.add(buf, alpha=momentum)
808
+ else:
809
+ g = buf
810
+
811
+ # Gather G
812
+ if isinstance(p.data, DTensor):
813
+ g = g.full_tensor()
814
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
815
+ steps=group["ns_steps"])
816
+
817
+ if isinstance(p.data, DTensor):
818
+ slices = get_slices_of_dtensor(
819
+ target=p,
820
+ local_rank=dist.get_rank(),
821
+ shard_mesh=shard_mesh,
822
+ shard_placements=shard_placements,
823
+ )
824
+ u_shard = u[slices]
825
+ u = DTensor.from_local(
826
+ u_shard,
827
+ device_mesh=p.device_mesh,
828
+ placements=p.placements,
829
+ )
830
+
831
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
832
+ Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
833
+
834
  def _update_g(self, p, g, group, momentum):
835
  # calc update
836
  state = self.state[p]
 
849
  p.data.add_(u, alpha=-adjusted_lr)
850
 
851
  def get_qk_clip_info(self, n, qk_logits):
852
+ if self.clip_config is None:
853
+ return None
854
+
855
  head_dim = self.clip_config.get('head_dim')
856
  threshold = self.clip_config.get('threshold')
857
  kind, layer_idx = parse_qk_layer(n)
 
862
  indices_key = 'q_indices' if 'q' in kind else 'k_indices'
863
  indices = self.clip_config.get(indices_key, []) or []
864
 
865
+ if isinstance(logit, DTensor):
866
+ # In TP settings, qk_logits may be DTensor
867
+ # We convert it to full tensor here for simplicity
868
+ logit = logit.full_tensor()
869
+
870
  return QKClipInfo(
871
  kind=kind,
872
  indices=indices,
 
965
  _update_param(p, state, lr, adjusted_lr, weight_decay,
966
  self.rank, self.compute_stream)
967
 
968
+ if self.chunk_size == -1:
969
+ shard_ranks = dist.get_world_size(param_to_state[id(
970
+ params[0])].process_group)
971
+ chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
972
+ elif self.chunk_size > 0:
973
+ chunk_size = self.chunk_size
974
+ else:
975
+ raise ValueError("chunk_size must be -1 or a positive integer.")
976
 
977
  # Wait grad update
978
  self.comm_stream.wait_stream(torch.cuda.current_stream())
979
 
980
+ warmup_step = self.warmup_step
981
+ for i in range(0, warmup_step):
982
  enqueue_all2all_gather(i * chunk_size, chunk_size)
983
  enqueue_computes(i * chunk_size, chunk_size)
984
 
985
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
986
  enqueue_all2all_scatter(i, chunk_size)
987
+ enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size)
988
  enqueue_update_param(i, chunk_size)
989
+ enqueue_computes(i + warmup_step * chunk_size, chunk_size)
990
 
991
  # Wait the last update_param to finish
992
  torch.cuda.current_stream().wait_stream(self.compute_stream)
 
1002
  amsgrad: bool,
1003
  beta1: float,
1004
  beta2: float,
1005
+ lr: float | torch.Tensor,
1006
  weight_decay: float,
1007
  eps: float,
1008
  maximize: bool,
 
1012
 
1013
  # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
1014
  # treating it as a scalar.
1015
+ lr_dict: DeviceDict | None = ({
1016
  lr.device: lr
1017
  } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
1018
+ None)
1019
  grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
1020
  [
1021
  params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
 
1062
  maximize=maximize,
1063
  )
1064
 
1065
+ def _step_muon(self, group, qk_logits=None):
1066
+ params = group["params"]
1067
+ lr = group["lr"]
1068
+ weight_decay = group["weight_decay"]
1069
+ momentum = group["momentum"]
1070
+ names = group["names"]
1071
+
1072
+ param_dtensors = []
1073
+ param_tensors = []
1074
+ name_dtensors = []
1075
+ name_tensors = []
1076
+
1077
+ if self.use_distributed_muon:
1078
+ self.distributed_muon(names=names,
1079
+ params=params,
1080
+ group=group,
1081
+ lr=lr,
1082
+ weight_decay=weight_decay,
1083
+ momentum=momentum,
1084
+ qk_logits=qk_logits)
1085
+ return
1086
+
1087
+ for n, p in zip(names, params):
1088
+ if p is None or p.grad is None:
1089
+ continue
1090
+ if isinstance(p.data, DTensor):
1091
+ if all(
1092
+ isinstance(placement, Replicate)
1093
+ for placement in p.placements):
1094
+ param_tensors.append(p)
1095
+ name_tensors.append(n)
1096
+ else:
1097
+ param_dtensors.append(p)
1098
+ name_dtensors.append(n)
1099
+ elif isinstance(p.data, torch.Tensor):
1100
+ param_tensors.append(p)
1101
+ name_tensors.append(n)
1102
+ else:
1103
+ raise TypeError(f"Unsupported parameter type: {type(p.data)}")
1104
+
1105
+ logger.debug(
1106
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors"
1107
+ )
1108
+
1109
+ if len(param_dtensors) > 0:
1110
+ if not dist.is_initialized():
1111
+ raise RuntimeError(
1112
+ "Parallel Muon requires torch.distributed to be initialized."
1113
+ )
1114
+
1115
+ # To support different placements, we group parameters by placements
1116
+ # and run parallel Muon on each group.
1117
+
1118
+ placement_to_params = defaultdict(lambda: ([], []))
1119
+ # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]
1120
+
1121
+ assert len(name_dtensors) == len(param_dtensors)
1122
+ for n, p in zip(name_dtensors, param_dtensors):
1123
+ placement_to_params[tuple([p.placements,
1124
+ p.device_mesh])][0].append(n)
1125
+ placement_to_params[tuple([p.placements,
1126
+ p.device_mesh])][1].append(p)
1127
+
1128
+ for _, (names, params) in placement_to_params.items():
1129
+ self.parallel(
1130
+ names,
1131
+ params,
1132
+ group,
1133
+ lr=lr,
1134
+ weight_decay=weight_decay,
1135
+ momentum=momentum,
1136
+ qk_logits=qk_logits,
1137
+ )
1138
+
1139
+ if len(param_tensors) > 0:
1140
+ self.base(
1141
+ name_tensors,
1142
+ param_tensors,
1143
+ group,
1144
+ lr=lr,
1145
+ weight_decay=weight_decay,
1146
+ momentum=momentum,
1147
+ qk_logits=qk_logits,
1148
+ )
1149
+
1150
+ def _step_adamw_params(self, params, group):
1151
+ params_with_grads = []
1152
+ grads = []
1153
+ moment1 = []
1154
+ moment2 = []
1155
+ max_exp_avg_sqs = []
1156
+ state_steps = []
1157
+ lr = group["lr"]
1158
+ beta1, beta2 = group["adamw_betas"]
1159
+ eps = group["adamw_eps"]
1160
+ weight_decay = group["weight_decay"]
1161
+
1162
+ for p in params:
1163
+ g = p.grad
1164
+ if g is None:
1165
+ continue
1166
+ state = self.state[p]
1167
+ params_with_grads.append(p)
1168
+ grads.append(g)
1169
+ if "step" not in state:
1170
+ state["step"] = (torch.zeros((),
1171
+ dtype=torch.float32,
1172
+ device=p.device))
1173
+ state["moment1"] = torch.zeros_like(g)
1174
+ state["moment2"] = torch.zeros_like(g)
1175
+ moment1.append(state["moment1"])
1176
+ moment2.append(state["moment2"])
1177
+ if not isinstance(state["step"], torch.Tensor):
1178
+ step_tensor = torch.tensor(state["step"],
1179
+ dtype=torch.float32,
1180
+ device=p.device)
1181
+ else:
1182
+ step_tensor = state["step"]
1183
+ state_steps.append(step_tensor)
1184
+
1185
+ self._fused_adamw(
1186
+ params_with_grads,
1187
+ grads,
1188
+ moment1,
1189
+ moment2,
1190
+ max_exp_avg_sqs,
1191
+ state_steps,
1192
+ amsgrad=False,
1193
+ beta1=beta1,
1194
+ beta2=beta2,
1195
+ lr=lr,
1196
+ weight_decay=weight_decay,
1197
+ eps=eps,
1198
+ maximize=False,
1199
+ )
1200
+
1201
+ def _step_adamw(self, group):
1202
+ params = group["params"]
1203
+
1204
+ # group params with it's type and placement
1205
+ placement_to_params: dict[tuple[Placement | type,
1206
+ DeviceMesh | None]] = defaultdict(list)
1207
+ for p in params:
1208
+ match p:
1209
+ case DTensor():
1210
+ placement_to_params[tuple([p.placements,
1211
+ p.device_mesh])].append(p)
1212
+ case torch.Tensor():
1213
+ placement_to_params[tuple([torch.Tensor, None])].append(p)
1214
+
1215
+ for params in placement_to_params.values():
1216
+ self._step_adamw_params(params, group)
1217
+
1218
  def step(self, closure=None, qk_logits=None):
1219
  """Perform a single optimization step.
1220
 
 
1232
  loss = closure()
1233
 
1234
  for group in self.param_groups:
 
 
1235
  if group["use_muon"]:
1236
+ self._step_muon(group, qk_logits=qk_logits)
1237
  else:
1238
+ self._step_adamw(group)
1239
 
1240
  return loss
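Condensing the core of the muon.py changes above into a standalone sketch (a paraphrase of the grouping step in the new `_step_muon`, not the committed implementation): DTensor parameters are bucketed by `(placements, device_mesh)` so that each bucket runs parallel Muon over a single shard mesh and process group.

```python
from collections import defaultdict

from torch.distributed.tensor import DTensor


def group_by_placement(names: list[str], params: list[DTensor]):
    """Bucket DTensor parameters by (placements, device_mesh).

    Sketch of the grouping in the new Muon._step_muon: each bucket is then handed
    to one parallel-Muon pass over a single shard mesh / process group.
    """
    buckets = defaultdict(lambda: ([], []))  # (placements, mesh) -> (names, params)
    for n, p in zip(names, params):
        assert isinstance(p, DTensor)
        key = (p.placements, p.device_mesh)
        buckets[key][0].append(n)
        buckets[key][1].append(p)
    return buckets
```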
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_811726c_dirty
-ops = torch.ops._optimizer_811726c_dirty
+from . import _optimizer_23d68bb_dirty
+ops = torch.ops._optimizer_23d68bb_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_811726c_dirty::{op_name}"
+    return f"_optimizer_23d68bb_dirty::{op_name}"
build/torch28-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_811726c_dirty.abi3.so → _optimizer_23d68bb_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:b3cdb515b6c56204224cc307b66d34fcee1cd5e27b4117197a71b784d34fadc5
-size 1871056
+oid sha256:03c3bbbbc5c4ceb5cebfe3a2e411f155bebb390f1921c14d59fcf791dd556da1
+size 1983488
build/torch28-cxx11-cu128-x86_64-linux/optimizer/distributed/utils.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+ from torch.distributed import ProcessGroup
4
+ from torch.distributed.device_mesh import DeviceMesh
5
+ from torch.distributed.tensor import DTensor
6
+ from torch.distributed.tensor.placement_types import (Placement, Shard,
7
+ _StridedShard)
8
+
9
+
10
+ def get_slices_of_dtensor(
11
+ target: DTensor | torch.Tensor,
12
+ local_rank: int,
13
+ shard_mesh: DeviceMesh,
14
+ shard_placements: tuple[Placement],
15
+ ) -> tuple[slice]:
16
+ """
17
+ Get the slice of local tensor for a given rank from a tensor.
18
+ Args:
19
+ target (DTensor | torch.Tensor): The target tensor.
20
+ rank (int): The local rank of the shard group.
21
+ shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks.
22
+ shard_placements (tuple[Placement]): The shard placements.
23
+ """
24
+
25
+ slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()]
26
+
27
+ # find the global rank of the local rank in the shard mesh
28
+ rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
29
+
30
+ rank_coords = (shard_mesh.mesh == rank).nonzero()
31
+
32
+ assert len(rank_coords) == 1
33
+ rank_coords = tuple(rank_coords[0].tolist())
34
+
35
+ assert len(rank_coords) == len(shard_placements)
36
+
37
+ # Caution: Assuming replicate-to-shard of the shard mesh goes with
38
+ # left-to-right sharding. This is ensured by the sorting logic of
39
+ # construct_shard_mesh function.
40
+ for i, (rank_coord,
41
+ placement) in enumerate(zip(rank_coords, shard_placements)):
42
+ assert isinstance(placement, Shard)
43
+
44
+ num_ranks = shard_mesh.mesh.shape[i]
45
+
46
+ dim = placement.dim
47
+ dim_size = (slices[dim].stop - slices[dim].start)
48
+
49
+ if dim_size % num_ranks != 0:
50
+ raise NotImplementedError(
51
+ f"Dimension size {dim_size} is not divisible "
52
+ f"by number of ranks {num_ranks} for shard "
53
+ f"placement on dim {dim}.")
54
+
55
+ shard_size = dim_size // num_ranks
56
+
57
+ start = slices[dim].start + rank_coord * shard_size
58
+ end = start + shard_size
59
+
60
+ assert start < end <= slices[dim].stop
61
+
62
+ slices[dim] = slice(start, end)
63
+
64
+ return tuple(slices)
65
+
66
+
67
+ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict()
68
+
69
+
70
+ def construct_shard_mesh(
71
+ placements: tuple[Placement],
72
+ mesh: DeviceMesh,
73
+ ) -> (DeviceMesh, ProcessGroup, tuple[Placement]):
74
+ """
75
+ Construct Shard Mesh and Placements for unsharding.
76
+ It removes Replicate placements and constructs a new Mesh and ProcessGroup.
77
+ """
78
+ my_rank = dist.get_rank()
79
+
80
+ assert mesh.mesh.device.type == 'cpu'
81
+
82
+ # Copy mesh to avoid modifying the original mesh
83
+ mesh = mesh.mesh.clone()
84
+
85
+ # 1. Sort placements. Replicate first, then Shard by dim ascending.
86
+
87
+ # For Shard, strided shard comes after regular shard on the same dim
88
+ # to preserve left-to-right order of replicate-to-shard.
89
+ # This is because that strided shard is using stride to represent
90
+ # more fine-grained sharding on the same dim.
91
+ # Please check the URL below for _StridedShard.
92
+ # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366
93
+
94
+ def placement_sort_key(
95
+ placement_with_index: tuple[float, Placement]
96
+ ) -> tuple[int, float, int]: # (dim, split factor, original index)
97
+ index, placement = placement_with_index
98
+ is_replicate = placement.is_replicate()
99
+ is_shard = placement.is_shard()
100
+ is_partial = placement.is_partial()
101
+
102
+ assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}"
103
+ assert not is_partial, "Partial placement is not supported."
104
+
105
+ if is_replicate:
106
+ return (-1.0, 0, index)
107
+ elif is_shard:
108
+ if isinstance(placement, _StridedShard):
109
+ return (placement.dim, 1 / placement.split_factor, index)
110
+ return (placement.dim, 0, index)
111
+ else:
112
+ raise TypeError(f"Unknown placement type: {type(placement)}")
113
+
114
+ placements_with_index: list[tuple[int,
115
+ Placement]] = list(enumerate(placements))
116
+ placements_with_index = sorted(placements_with_index,
117
+ key=placement_sort_key)
118
+
119
+ sorted_indices, sorted_placements = zip(*placements_with_index)
120
+
121
+ # 2. Permute mesh according to sorted placements.
122
+ sorted_mesh = mesh.permute(sorted_indices)
123
+
124
+ # 3. Collect list of shard meshes by removing replicate dims
125
+ # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
126
+ # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4)
127
+ num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
128
+
129
+ # merge replicate dims
130
+ # shard_meshes became a list of shard meshes with a length of replicate degree
131
+ if num_replicates > 0:
132
+ sorted_mesh = sorted_mesh.flatten(
133
+ 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
134
+ shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
135
+ else:
136
+ shard_meshes = [sorted_mesh]
137
+ shard_placements = sorted_placements[num_replicates:]
138
+
139
+ # assume all shard placements are different
140
+ assert len(shard_placements) == len(set(shard_placements))
141
+
142
+ # 4. Construct ProcessGroups
143
+ # Caution: all groups should be created in the same order in all processes,
144
+ # even though each process only needs its own group.
145
+
146
+ # To use tensor as dict key, convert it to tuple
147
+ def tensor_to_tuple(t):
148
+ if isinstance(t, torch.Tensor):
149
+ t = t.tolist()
150
+ if isinstance(t, list):
151
+ return tuple(tensor_to_tuple(x) for x in t)
152
+ return t
153
+
154
+ my_shard_mesh_as_tuple = None
155
+ for shard_mesh in shard_meshes:
156
+ assert isinstance(shard_mesh, torch.Tensor)
157
+ shard_mesh_as_tuple = tensor_to_tuple(shard_mesh)
158
+
159
+ if (my_rank == shard_mesh).any().item():
160
+ assert my_shard_mesh_as_tuple is None
161
+ my_shard_mesh_as_tuple = shard_mesh_as_tuple
162
+
163
+ # update global cache
164
+ if shard_mesh_as_tuple not in _ranks_to_dist_cache:
165
+ shard_process_group = dist.new_group(shard_mesh.flatten().tolist())
166
+ _ranks_to_dist_cache[shard_mesh_as_tuple] = (
167
+ DeviceMesh(device_type="cuda", mesh=shard_mesh),
168
+ shard_process_group,
169
+ )
170
+
171
+ my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[
172
+ my_shard_mesh_as_tuple]
173
+
174
+ return my_shard_mesh, my_shard_process_group, shard_placements
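For reference, the transformation this helper performs can be sketched with plain CPU tensors and no process groups. The 2 x 4 mesh, the HSDP-style (Replicate, Shard(0)) layout, and the chosen rank below are illustrative assumptions, not values taken from the repository's tests.

import torch

# Hypothetical 2 x 4 device mesh with placements (Replicate, Shard(0)):
# replicate across mesh dim 0, shard the parameter across mesh dim 1.
mesh = torch.arange(8).reshape(2, 4)

# Steps 1-2: Replicate sorts before Shard, so the permutation is the
# identity here; with (Shard(0), Replicate) the mesh would be permuted
# so that the replicate dim comes first.
sorted_mesh = mesh.permute(0, 1)

# Step 3: drop the replicate dim by unbinding along it; each row is one
# shard mesh of 4 ranks that together hold a full copy of the parameter.
shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
print([m.tolist() for m in shard_meshes])  # [[0, 1, 2, 3], [4, 5, 6, 7]]

# Step 4 (not reproduced here): one ProcessGroup is created per shard
# mesh, in the same order on every rank, and each rank keeps the group
# whose mesh contains it.
rank = 5  # pretend this process is global rank 5
mine = next(m for m in shard_meshes if (m == rank).any())
print(mine.tolist())  # [4, 5, 6, 7]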
build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py CHANGED
@@ -1,18 +1,24 @@
1
  import logging
2
  import math
3
  import types
 
4
  from dataclasses import dataclass
5
- from typing import List, Optional, Union, cast
6
 
7
  import torch
8
  import torch.distributed as dist
9
- from torch.distributed._tensor import DTensor, Replicate, Shard
 
 
 
10
 
 
11
  from .matmul_transpose_triton import matmul_transpose_assign
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
  COMM_DTYPE = torch.bfloat16
 
16
 
17
 
18
  # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -62,23 +68,39 @@ def _zeropower_via_newtonschulz5(G, steps):
62
  @dataclass
63
  class _muon_state:
64
  # TODO: use Optional
65
- worker_rank: int | None = None
 
 
 
 
 
66
  gathered_grad: torch.Tensor | None = None
67
  scattered_u: DTensor | None = None
68
  computed_u: torch.Tensor | None = None
69
  gather_event: torch.cuda.Event | None = None
70
  compute_event: torch.cuda.Event | None = None
71
  scatter_event: torch.cuda.Event | None = None
72
- process_group = None
73
- qk_clip_state = None
74
 
75
 
76
- def split_elems_for_src(param, src_rank, num_ranks) -> int:
77
- rows = param.shape[0]
78
- cols = int(param.numel() // rows)
79
- base, rem = divmod(rows, num_ranks)
80
- my_rows = base + (1 if src_rank < rem else 0)
81
- return my_rows * cols
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
 
84
  @torch.no_grad()
@@ -91,8 +113,7 @@ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
91
  for p in params:
92
  state = param_to_state[id(p)]
93
  if rank == state.worker_rank:
94
- num_ranks = dist.get_world_size(group=state.process_group)
95
- state.gathered_grad = torch.empty(p.grad.numel(),
96
  dtype=COMM_DTYPE,
97
  device="cuda")
98
  else:
@@ -121,11 +142,11 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
121
  state = param_to_state[id(p)]
122
  dst = state.worker_rank
123
  assert dst < num_ranks
124
- shard_elems = split_elems_for_src(p, rank, num_ranks)
125
  g = p.grad
126
- g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
127
  assert g.numel() == shard_elems
128
- per_dst[dst].append(g)
129
  send_counts[dst] += shard_elems
130
 
131
  assert any(
@@ -148,13 +169,18 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
148
  for p in owned_params:
149
  state = param_to_state[id(p)]
150
  assert state.worker_rank == rank
151
- total += split_elems_for_src(p, src, num_ranks)
152
  recv_counts[src] = total
153
 
154
  recv_total = sum(recv_counts)
155
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
156
 
157
  #All2All
 
 
 
 
 
158
  dist.all_to_all_single(
159
  recv_buf,
160
  send_buf,
@@ -179,7 +205,6 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
179
  comm_stream.wait_event(alloc_event)
180
 
181
  off = 0
182
- write_offsets = {id(p): 0 for p in owned_params}
183
  for src in range(num_ranks):
184
  if recv_counts[src] == 0:
185
  continue
@@ -189,22 +214,28 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
189
  for p in owned_params:
190
  state = param_to_state[id(p)]
191
  assert state.worker_rank == rank
192
- n = split_elems_for_src(p, src, num_ranks)
 
 
 
 
 
 
 
 
 
193
  assert n > 0
194
 
195
  sg = recv_buf.narrow(0, off + inner_off, n)
196
- woff = write_offsets[id(p)]
197
- dst = state.gathered_grad.narrow(0, woff, n)
198
  dst.copy_(sg)
199
 
200
- write_offsets[id(p)] += n
201
  inner_off += n
202
  off += block
203
 
204
  for p in params:
205
  state = param_to_state[id(p)]
206
  if state.worker_rank == rank:
207
- state.gathered_grad = state.gathered_grad.view_as(p)
208
  state.gather_event = torch.cuda.Event()
209
  state.gather_event.record(comm_stream)
210
  else:
@@ -277,14 +308,19 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
277
 
278
  assert state.computed_u is not None
279
 
280
- u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)
281
 
282
  offset = 0
283
  for dst in range(num_ranks):
284
- n = split_elems_for_src(p, dst, num_ranks)
 
 
 
 
 
 
285
  assert n > 0
286
 
287
- su = u_full.narrow(0, offset, n)
288
  per_dst[dst].append(su)
289
  send_counts[dst] += n
290
  offset += n
@@ -313,7 +349,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
313
  state = param_to_state[id(p)]
314
  if state.worker_rank != src:
315
  continue
316
- total += split_elems_for_src(p, rank, num_ranks)
317
  recv_counts[src] = total
318
 
319
  recv_total = sum(recv_counts)
@@ -357,7 +393,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
357
  state = param_to_state[id(p)]
358
  if state.worker_rank != src:
359
  continue
360
- n = split_elems_for_src(p, rank, num_ranks)
361
  assert n > 0
362
 
363
  flat_local = recv_buf.narrow(0, off + inner_off,
@@ -398,11 +434,23 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
398
  state.scattered_u = None
399
  u_dtensor = None
400
 
401
- scales_full = Muon._compute_scales(p, state.qk_clip_state)
 
 
402
  if scales_full is not None:
403
- num_ranks = dist.get_world_size(group=state.process_group)
404
- local_rank = dist.get_rank(group=state.process_group)
405
- scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank]
 
 
 
 
 
 
 
 
 
 
406
  scales_local = DTensor.from_local(
407
  scales_local,
408
  placements=p.placements,
@@ -478,11 +526,11 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
478
  @dataclass
479
  class QKClipInfo:
480
  """Per-parameter dynamic info computed from config + runtime logits."""
481
- kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None
482
- indices: List[int] # which heads to consider for clipping
483
  head_dim: int # from config
484
  threshold: float # from config
485
- logit: Optional[torch.Tensor]
486
 
487
 
488
  class Muon(torch.optim.Optimizer):
@@ -525,11 +573,16 @@ class Muon(torch.optim.Optimizer):
525
  "head_dim": 128,
526
  "threshold": 100
527
  }
528
- overlap_step : How many all2all gather, compute operations are launched in advance
529
- before the corresponding all2all scatter steps begin.
530
- A higher overlap_step increases memory usage but can improve
531
- performance by overlapping communication.
532
- Parallel muon only.
 
 
 
 
 
533
  """
534
 
535
  def __init__(self,
@@ -549,7 +602,9 @@ class Muon(torch.optim.Optimizer):
549
  "head_dim": 128,
550
  "threshold": 100
551
  },
552
- overlap_step=5):
 
 
553
  defaults = dict(
554
  lr=lr,
555
  weight_decay=weight_decay,
@@ -579,7 +634,9 @@ class Muon(torch.optim.Optimizer):
579
  self.compute_stream = torch.cuda.Stream()
580
  self.debug = debug
581
  self.clip_config = clip_config
582
- self.overlap_step = overlap_step
 
 
583
 
584
  def _calc_flops(self, G, steps):
585
  assert len(G.shape) == 2
@@ -597,6 +654,12 @@ class Muon(torch.optim.Optimizer):
597
  adjusted_lr = lr * adjusted_ratio
598
  return adjusted_lr
599
 
 
 
 
 
 
 
600
  def get_shard_mesh(self, p):
601
  """
602
  Get the shard mesh for a parameter p on the given rank.
@@ -604,26 +667,13 @@ class Muon(torch.optim.Optimizer):
604
  assert isinstance(
605
  p, DTensor), "Parallel Muon only supports DTensor parameters."
606
 
607
- if p.placements == (Shard(dim=0), ):
608
- # Case for FSDP
609
- process_group = p.device_mesh.get_group(mesh_dim=0)
610
- if self.rank is None:
611
- self.rank = dist.get_rank(group=process_group)
612
- else:
613
- assert self.rank == dist.get_rank(group=process_group)
614
- return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
615
- elif p.placements == (Replicate(), Shard(dim=0)):
616
- # Case for HSDP
617
- process_group = p.device_mesh.get_group(mesh_dim=1)
618
- if self.rank is None:
619
- self.rank = dist.get_rank(group=process_group)
620
- else:
621
- assert self.rank == dist.get_rank(group=process_group)
622
- for i, shard_mesh in enumerate(p.device_mesh.mesh):
623
- if self.rank in shard_mesh:
624
- return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
625
- else:
626
- raise ValueError(f"Unsupported placements ({p.placements}).")
627
 
628
  def init_state_and_assign_params(self, names, params, group, qk_logits):
629
  param_to_state = {}
@@ -655,23 +705,32 @@ class Muon(torch.optim.Optimizer):
655
  ordered_params = list(params_sorted)
656
 
657
  round_robin = 0
658
- mesh = None
659
- shard_mesh = None
660
- process_group = None
 
 
 
 
 
661
  for n, p in zip(ordered_names, ordered_params):
662
- if mesh is None:
663
- mesh = p.device_mesh
664
- shard_mesh, process_group = self.get_shard_mesh(p)
665
- elif mesh != p.device_mesh:
666
  raise ValueError("All parameters must be on the same mesh.")
667
- num_ranks = dist.get_world_size(group=process_group)
668
- param_to_state[id(p)] = _muon_state()
669
- param_to_state[id(
670
- p)].worker_rank = shard_mesh[round_robin].item() % num_ranks
671
- param_to_state[id(p)].process_group = process_group
672
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
673
- param_to_state[id(p)].qk_clip_state = qk_clip_state
674
- round_robin = (round_robin + 1) % len(shard_mesh)
 
 
 
 
 
 
 
675
 
676
  return param_to_state, ordered_params
677
 
@@ -705,10 +764,73 @@ class Muon(torch.optim.Optimizer):
705
 
706
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
707
 
708
- scales_full = self._compute_scales(p, qk_clip_state)
 
709
  if scales_full is not None:
710
  Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
711
 
 
 
712
  def _update_g(self, p, g, group, momentum):
713
  # calc update
714
  state = self.state[p]
@@ -727,6 +849,9 @@ class Muon(torch.optim.Optimizer):
727
  p.data.add_(u, alpha=-adjusted_lr)
728
 
729
  def get_qk_clip_info(self, n, qk_logits):
 
 
 
730
  head_dim = self.clip_config.get('head_dim')
731
  threshold = self.clip_config.get('threshold')
732
  kind, layer_idx = parse_qk_layer(n)
@@ -737,6 +862,11 @@ class Muon(torch.optim.Optimizer):
737
  indices_key = 'q_indices' if 'q' in kind else 'k_indices'
738
  indices = self.clip_config.get(indices_key, []) or []
739
 
 
 
 
 
 
740
  return QKClipInfo(
741
  kind=kind,
742
  indices=indices,
@@ -835,22 +965,28 @@ class Muon(torch.optim.Optimizer):
835
  _update_param(p, state, lr, adjusted_lr, weight_decay,
836
  self.rank, self.compute_stream)
837
 
838
- chunk_size = dist.get_world_size(param_to_state[id(
839
- params[0])].process_group)
 
 
 
 
 
 
840
 
841
  # Wait grad update
842
  self.comm_stream.wait_stream(torch.cuda.current_stream())
843
 
844
- overlap_step = self.overlap_step
845
- for i in range(0, overlap_step):
846
  enqueue_all2all_gather(i * chunk_size, chunk_size)
847
  enqueue_computes(i * chunk_size, chunk_size)
848
 
849
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
850
  enqueue_all2all_scatter(i, chunk_size)
851
- enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size)
852
  enqueue_update_param(i, chunk_size)
853
- enqueue_computes(i + overlap_step * chunk_size, chunk_size)
854
 
855
  # Wait the last update_param to finish
856
  torch.cuda.current_stream().wait_stream(self.compute_stream)
@@ -866,7 +1002,7 @@ class Muon(torch.optim.Optimizer):
866
  amsgrad: bool,
867
  beta1: float,
868
  beta2: float,
869
- lr: Union[float, torch.Tensor],
870
  weight_decay: float,
871
  eps: float,
872
  maximize: bool,
@@ -876,10 +1012,10 @@ class Muon(torch.optim.Optimizer):
876
 
877
  # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
878
  # treating it as a scalar.
879
- lr_dict: Optional[DeviceDict] = ({
880
  lr.device: lr
881
  } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
882
- None)
883
  grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
884
  [
885
  params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
@@ -926,6 +1062,159 @@ class Muon(torch.optim.Optimizer):
926
  maximize=maximize,
927
  )
928
 
 
 
929
  def step(self, closure=None, qk_logits=None):
930
  """Perform a single optimization step.
931
 
@@ -943,127 +1232,9 @@ class Muon(torch.optim.Optimizer):
943
  loss = closure()
944
 
945
  for group in self.param_groups:
946
- params = group["params"]
947
-
948
  if group["use_muon"]:
949
- ############################
950
- # Muon #
951
- ############################
952
- lr = group["lr"]
953
- weight_decay = group["weight_decay"]
954
- momentum = group["momentum"]
955
- names = group["names"]
956
-
957
- param_dtensors = []
958
- param_tensors = []
959
- name_dtensors = []
960
- name_tensors = []
961
-
962
- for n, p in zip(names, params):
963
- if p is None or p.grad is None:
964
- continue
965
- if isinstance(p.data, DTensor):
966
- if all(
967
- isinstance(placement, Replicate)
968
- for placement in p.placements):
969
- param_tensors.append(p)
970
- name_tensors.append(n)
971
- else:
972
- param_dtensors.append(p)
973
- name_dtensors.append(n)
974
- elif isinstance(p.data, torch.Tensor):
975
- param_tensors.append(p)
976
- name_tensors.append(n)
977
- else:
978
- raise TypeError(
979
- f"Unsupported parameter type: {type(p.data)}")
980
-
981
- if self.debug:
982
- print(
983
- f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
984
- flush=True,
985
- )
986
-
987
- if len(param_dtensors) > 0:
988
- if not dist.is_initialized():
989
- raise RuntimeError(
990
- "Parallel Muon requires torch.distributed to be initialized."
991
- )
992
-
993
- self.parallel(
994
- name_dtensors,
995
- param_dtensors,
996
- group,
997
- lr=lr,
998
- weight_decay=weight_decay,
999
- momentum=momentum,
1000
- qk_logits=qk_logits,
1001
- )
1002
-
1003
- if len(param_tensors) > 0:
1004
- self.base(
1005
- name_tensors,
1006
- param_tensors,
1007
- group,
1008
- lr=lr,
1009
- weight_decay=weight_decay,
1010
- momentum=momentum,
1011
- qk_logits=qk_logits,
1012
- )
1013
-
1014
  else:
1015
- ############################
1016
- # AdamW backup #
1017
- ############################
1018
-
1019
- params_with_grads = []
1020
- grads = []
1021
- moment1 = []
1022
- moment2 = []
1023
- max_exp_avg_sqs = []
1024
- state_steps = []
1025
- lr = group["lr"]
1026
- beta1, beta2 = group["adamw_betas"]
1027
- eps = group["adamw_eps"]
1028
- weight_decay = group["weight_decay"]
1029
-
1030
- for p in params:
1031
- g = p.grad
1032
- if g is None:
1033
- continue
1034
- state = self.state[p]
1035
- params_with_grads.append(p)
1036
- grads.append(g)
1037
- if "step" not in state:
1038
- state["step"] = (torch.zeros((),
1039
- dtype=torch.float32,
1040
- device=p.device))
1041
- state["moment1"] = torch.zeros_like(g)
1042
- state["moment2"] = torch.zeros_like(g)
1043
- moment1.append(state["moment1"])
1044
- moment2.append(state["moment2"])
1045
- if not isinstance(state["step"], torch.Tensor):
1046
- step_tensor = torch.tensor(state["step"],
1047
- dtype=torch.float32,
1048
- device=p.device)
1049
- else:
1050
- step_tensor = state["step"]
1051
- state_steps.append(step_tensor)
1052
-
1053
- self._fused_adamw(
1054
- params_with_grads,
1055
- grads,
1056
- moment1,
1057
- moment2,
1058
- max_exp_avg_sqs,
1059
- state_steps,
1060
- amsgrad=False,
1061
- beta1=beta1,
1062
- beta2=beta2,
1063
- lr=lr,
1064
- weight_decay=weight_decay,
1065
- eps=eps,
1066
- maximize=False,
1067
- )
1068
 
1069
  return loss
 
1
  import logging
2
  import math
3
  import types
4
+ from collections import defaultdict
5
  from dataclasses import dataclass
6
+ from typing import Any, cast
7
 
8
  import torch
9
  import torch.distributed as dist
10
+ from torch.distributed import ProcessGroup
11
+ from torch.distributed.device_mesh import DeviceMesh
12
+ from torch.distributed.tensor import DTensor, Replicate
13
+ from torch.distributed.tensor.placement_types import Placement
14
 
15
+ from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor
16
  from .matmul_transpose_triton import matmul_transpose_assign
17
 
18
  logger = logging.getLogger(__name__)
19
 
20
  COMM_DTYPE = torch.bfloat16
21
+ DEFAULT_CHUNK_SIZE_RATIO = 4
22
 
23
 
24
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
68
  @dataclass
69
  class _muon_state:
70
  # TODO: use Optional
71
+ worker_rank: int
72
+ process_group: ProcessGroup
73
+ shard_mesh: DeviceMesh
74
+ shard_placements: tuple[Placement, ...]
75
+ name: str
76
+ qk_clip_state: torch.Tensor | None = None
77
  gathered_grad: torch.Tensor | None = None
78
  scattered_u: DTensor | None = None
79
  computed_u: torch.Tensor | None = None
80
  gather_event: torch.cuda.Event | None = None
81
  compute_event: torch.cuda.Event | None = None
82
  scatter_event: torch.cuda.Event | None = None
 
 
83
 
84
 
85
+ def numel_for_rank(
86
+ param: DTensor,
87
+ local_rank: int,
88
+ state: _muon_state,
89
+ ) -> int:
90
+ slices = get_slices_of_dtensor(
91
+ param,
92
+ local_rank,
93
+ state.shard_mesh,
94
+ state.shard_placements,
95
+ )
96
+
97
+ numel = 1
98
+ for s, dim in zip(slices, param.shape):
99
+ start, stop, step = s.indices(dim)
100
+ length = max(0, (stop - start + (step - 1)) // step)
101
+ numel *= length
102
+
103
+ return numel
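As a rough, DTensor-free illustration of the slice and element-count arithmetic used by numel_for_rank and get_slices_of_dtensor: the snippet below assumes an evenly divisible (8, 16) weight on a hypothetical 2 x 2 shard mesh with Shard(0) and Shard(1) placements; the names and shapes are examples only.

# Hypothetical: an (8, 16) weight sharded over a 2 x 2 mesh as [Shard(0), Shard(1)].
shape = (8, 16)
mesh_shape = (2, 2)   # one mesh dim per shard placement
shard_dims = (0, 1)   # tensor dim that each mesh dim shards

def slices_for_coords(coords):
    # Mirrors the even-split rule: mesh dim i splits tensor dim
    # shard_dims[i] into mesh_shape[i] equal contiguous blocks.
    slices = [slice(0, d) for d in shape]
    for i, coord in enumerate(coords):
        dim = shard_dims[i]
        size = (slices[dim].stop - slices[dim].start) // mesh_shape[i]
        start = slices[dim].start + coord * size
        slices[dim] = slice(start, start + size)
    return tuple(slices)

for rank, coords in enumerate([(0, 0), (0, 1), (1, 0), (1, 1)]):
    s = slices_for_coords(coords)
    numel = 1
    for sl in s:
        numel *= sl.stop - sl.start
    print(rank, s, numel)  # each rank owns a (4, 8) block, i.e. 32 elements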
104
 
105
 
106
  @torch.no_grad()
 
113
  for p in params:
114
  state = param_to_state[id(p)]
115
  if rank == state.worker_rank:
116
+ state.gathered_grad = torch.empty(p.shape,
 
117
  dtype=COMM_DTYPE,
118
  device="cuda")
119
  else:
 
142
  state = param_to_state[id(p)]
143
  dst = state.worker_rank
144
  assert dst < num_ranks
145
+ shard_elems = numel_for_rank(p, rank, state)
146
  g = p.grad
147
+ g = g.to_local().to(COMM_DTYPE).contiguous()
148
  assert g.numel() == shard_elems
149
+ per_dst[dst].append(g.view(-1))
150
  send_counts[dst] += shard_elems
151
 
152
  assert any(
 
169
  for p in owned_params:
170
  state = param_to_state[id(p)]
171
  assert state.worker_rank == rank
172
+ total += numel_for_rank(p, src, state)
173
  recv_counts[src] = total
174
 
175
  recv_total = sum(recv_counts)
176
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
177
 
178
  #All2All
179
+ logger.debug(f"send_buf size: {send_buf.numel()}, "
180
+ f"recv_buf size: {recv_buf.numel()}, "
181
+ f"recv_counts: {recv_counts}, "
182
+ f"send_counts: {send_counts}, "
183
+ f"process_group: {str(process_group)}")
184
  dist.all_to_all_single(
185
  recv_buf,
186
  send_buf,
 
205
  comm_stream.wait_event(alloc_event)
206
 
207
  off = 0
 
208
  for src in range(num_ranks):
209
  if recv_counts[src] == 0:
210
  continue
 
214
  for p in owned_params:
215
  state = param_to_state[id(p)]
216
  assert state.worker_rank == rank
217
+
218
+ # get the slice of the full dtensor corresponding to rank src.
219
+ slices = get_slices_of_dtensor(state.gathered_grad, src,
220
+ state.shard_mesh,
221
+ state.shard_placements)
222
+
223
+ dst = state.gathered_grad[slices]
224
+ assert dst._base is state.gathered_grad
225
+
226
+ n = dst.numel()
227
  assert n > 0
228
 
229
  sg = recv_buf.narrow(0, off + inner_off, n)
230
+ sg = sg.reshape_as(dst)
 
231
  dst.copy_(sg)
232
 
 
233
  inner_off += n
234
  off += block
235
 
236
  for p in params:
237
  state = param_to_state[id(p)]
238
  if state.worker_rank == rank:
 
239
  state.gather_event = torch.cuda.Event()
240
  state.gather_event.record(comm_stream)
241
  else:
 
308
 
309
  assert state.computed_u is not None
310
 
311
+ u_full = state.computed_u.to(COMM_DTYPE).contiguous()
312
 
313
  offset = 0
314
  for dst in range(num_ranks):
315
+ # get the slice of the full tensor corresponding to rank dst.
316
+ slices = get_slices_of_dtensor(u_full, dst,
317
+ state.shard_mesh,
318
+ state.shard_placements)
319
+ su = u_full[slices].flatten()
320
+
321
+ n = su.numel()
322
  assert n > 0
323
 
 
324
  per_dst[dst].append(su)
325
  send_counts[dst] += n
326
  offset += n
 
349
  state = param_to_state[id(p)]
350
  if state.worker_rank != src:
351
  continue
352
+ total += numel_for_rank(p, rank, state)
353
  recv_counts[src] = total
354
 
355
  recv_total = sum(recv_counts)
 
393
  state = param_to_state[id(p)]
394
  if state.worker_rank != src:
395
  continue
396
+ n = numel_for_rank(p, rank, state)
397
  assert n > 0
398
 
399
  flat_local = recv_buf.narrow(0, off + inner_off,
 
434
  state.scattered_u = None
435
  u_dtensor = None
436
 
437
+ scales_full = Muon._compute_scales(
438
+ p,
439
+ state.qk_clip_state) if state.qk_clip_state is not None else None
440
  if scales_full is not None:
441
+ # Have to slice scales_full among dim 0
442
+ weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh,
443
+ state.shard_placements)
444
+ ratio = p.shape[0] // scales_full.shape[0]
445
+ scales_slice = slice(
446
+ None if weight_slices[0].start is None else
447
+ weight_slices[0].start // ratio,
448
+ None if weight_slices[0].stop is None else
449
+ weight_slices[0].stop // ratio,
450
+ None,
451
+ )
452
+
453
+ scales_local = scales_full[scales_slice]
454
  scales_local = DTensor.from_local(
455
  scales_local,
456
  placements=p.placements,
 
526
  @dataclass
527
  class QKClipInfo:
528
  """Per-parameter dynamic info computed from config + runtime logits."""
529
+ kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
530
+ indices: list[int] # which heads to consider for clipping
531
  head_dim: int # from config
532
  threshold: float # from config
533
+ logit: torch.Tensor | None
534
 
535
 
536
  class Muon(torch.optim.Optimizer):
 
573
  "head_dim": 128,
574
  "threshold": 100
575
  }
576
+ warmup_step : How many all2all gather and compute operations are launched in advance
577
+ before the corresponding all2all scatter steps begin.
578
+ A higher warmup_step increases memory usage but can improve
579
+ performance by overlapping communication with computation.
580
+ Parallel muon only.
581
+ chunk_size : Number of parameters to process in each
582
+ all2all gather/compute/scatter step.
583
+ Uses (number of shard ranks) * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
584
+ use_distributed_muon: Use distributed muon by Liu et al. (2024).
585
+ For testing purposes only.
586
  """
587
 
588
  def __init__(self,
 
602
  "head_dim": 128,
603
  "threshold": 100
604
  },
605
+ warmup_step=5,
606
+ chunk_size=-1,
607
+ use_distributed_muon=False):
608
  defaults = dict(
609
  lr=lr,
610
  weight_decay=weight_decay,
 
634
  self.compute_stream = torch.cuda.Stream()
635
  self.debug = debug
636
  self.clip_config = clip_config
637
+ self.warmup_step = warmup_step
638
+ self.chunk_size = chunk_size
639
+ self.use_distributed_muon = use_distributed_muon
640
 
641
  def _calc_flops(self, G, steps):
642
  assert len(G.shape) == 2
 
654
  adjusted_lr = lr * adjusted_ratio
655
  return adjusted_lr
656
 
657
+ def set_rank_once(self, rank):
658
+ if self.rank is None:
659
+ self.rank = rank
660
+ else:
661
+ assert self.rank == rank
662
+
663
  def get_shard_mesh(self, p):
664
  """
665
  Get the shard mesh for a parameter p on the given rank.
 
667
  assert isinstance(
668
  p, DTensor), "Parallel Muon only supports DTensor parameters."
669
 
670
+ shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
671
+ p.placements, p.device_mesh)
672
+
673
+ # set rank with the local rank in the shard process group
674
+ self.set_rank_once(dist.get_rank(group=shard_pg))
675
+
676
+ return shard_mesh, shard_pg, shard_placements
 
 
 
 
 
 
 
 
 
 
 
 
 
677
 
678
  def init_state_and_assign_params(self, names, params, group, qk_logits):
679
  param_to_state = {}
 
705
  ordered_params = list(params_sorted)
706
 
707
  round_robin = 0
708
+ mesh = ordered_params[0].device_mesh
709
+ placements = ordered_params[0].placements
710
+
711
+ shard_mesh, shard_pg, shard_placements = self.get_shard_mesh(
712
+ ordered_params[0])
713
+ shard_mesh_flattened = shard_mesh.mesh.flatten()
714
+ num_ranks = dist.get_world_size(group=shard_pg)
715
+
716
  for n, p in zip(ordered_names, ordered_params):
717
+ if mesh != p.device_mesh:
 
 
 
718
  raise ValueError("All parameters must be on the same mesh.")
719
+ if placements != p.placements:
720
+ raise ValueError("All parameters must have same placements.")
721
+
722
+ worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
723
+ round_robin = (round_robin + 1) % len(shard_mesh_flattened)
724
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
725
+
726
+ param_to_state[id(p)] = _muon_state(
727
+ worker_rank=worker_rank,
728
+ process_group=shard_pg,
729
+ shard_mesh=shard_mesh,
730
+ shard_placements=shard_placements,
731
+ name=n,
732
+ qk_clip_state=qk_clip_state,
733
+ )
734
 
735
  return param_to_state, ordered_params
736
 
 
764
 
765
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
766
 
767
+ scales_full = self._compute_scales(
768
+ p, qk_clip_state) if qk_clip_state is not None else None
769
  if scales_full is not None:
770
  Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
771
 
772
+ def distributed_muon(
773
+ self,
774
+ names: list[str],
775
+ params: list[torch.nn.Parameter],
776
+ group: dict[str, Any],
777
+ lr: float,
778
+ weight_decay: float,
779
+ momentum: float,
780
+ qk_logits: list[torch.Tensor | DTensor] | None,
781
+ ):
782
+ """ Implementation of Distributed Muon by Liu et al. """
783
+ if qk_logits is not None:
784
+ raise NotImplementedError("QK clipping is not supported yet")
785
+
786
+ if isinstance(params[0], DTensor):
787
+ shard_mesh, _, shard_placements = construct_shard_mesh(
788
+ placements=params[0].placements,
789
+ mesh=params[0].device_mesh,
790
+ )
791
+
792
+ for n, p in zip(names, params):
793
+ g = p.grad
794
+ if g is None:
795
+ continue
796
+ if g.ndim > 2:
797
+ g = g.view(g.size(0), -1)
798
+ assert g is not None
799
+
800
+ # calc update
801
+ state = self.state[p]
802
+ if "momentum_buffer" not in state:
803
+ state["momentum_buffer"] = torch.zeros_like(g)
804
+ buf = state["momentum_buffer"]
805
+ buf.mul_(momentum).add_(g)
806
+ if group["nesterov"]:
807
+ g = g.add(buf, alpha=momentum)
808
+ else:
809
+ g = buf
810
+
811
+ # Gather G
812
+ if isinstance(p.data, DTensor):
813
+ g = g.full_tensor()
814
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
815
+ steps=group["ns_steps"])
816
+
817
+ if isinstance(p.data, DTensor):
818
+ slices = get_slices_of_dtensor(
819
+ target=p,
820
+ local_rank=dist.get_rank(),
821
+ shard_mesh=shard_mesh,
822
+ shard_placements=shard_placements,
823
+ )
824
+ u_shard = u[slices]
825
+ u = DTensor.from_local(
826
+ u_shard,
827
+ device_mesh=p.device_mesh,
828
+ placements=p.placements,
829
+ )
830
+
831
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
832
+ Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
833
+
834
  def _update_g(self, p, g, group, momentum):
835
  # calc update
836
  state = self.state[p]
 
849
  p.data.add_(u, alpha=-adjusted_lr)
850
 
851
  def get_qk_clip_info(self, n, qk_logits):
852
+ if self.clip_config is None:
853
+ return None
854
+
855
  head_dim = self.clip_config.get('head_dim')
856
  threshold = self.clip_config.get('threshold')
857
  kind, layer_idx = parse_qk_layer(n)
 
862
  indices_key = 'q_indices' if 'q' in kind else 'k_indices'
863
  indices = self.clip_config.get(indices_key, []) or []
864
 
865
+ if isinstance(logit, DTensor):
866
+ # In TP settings, qk_logits may be a DTensor.
867
+ # We convert it to a full tensor here for simplicity.
868
+ logit = logit.full_tensor()
869
+
870
  return QKClipInfo(
871
  kind=kind,
872
  indices=indices,
 
965
  _update_param(p, state, lr, adjusted_lr, weight_decay,
966
  self.rank, self.compute_stream)
967
 
968
+ if self.chunk_size == -1:
969
+ shard_ranks = dist.get_world_size(param_to_state[id(
970
+ params[0])].process_group)
971
+ chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
972
+ elif self.chunk_size > 0:
973
+ chunk_size = self.chunk_size
974
+ else:
975
+ raise ValueError("chunk_size must be -1 or a positive integer.")
976
 
977
  # Wait grad update
978
  self.comm_stream.wait_stream(torch.cuda.current_stream())
979
 
980
+ warmup_step = self.warmup_step
981
+ for i in range(0, warmup_step):
982
  enqueue_all2all_gather(i * chunk_size, chunk_size)
983
  enqueue_computes(i * chunk_size, chunk_size)
984
 
985
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
986
  enqueue_all2all_scatter(i, chunk_size)
987
+ enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size)
988
  enqueue_update_param(i, chunk_size)
989
+ enqueue_computes(i + warmup_step * chunk_size, chunk_size)
990
 
991
  # Wait the last update_param to finish
992
  torch.cuda.current_stream().wait_stream(self.compute_stream)
 
1002
  amsgrad: bool,
1003
  beta1: float,
1004
  beta2: float,
1005
+ lr: float | torch.Tensor,
1006
  weight_decay: float,
1007
  eps: float,
1008
  maximize: bool,
 
1012
 
1013
  # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
1014
  # treating it as a scalar.
1015
+ lr_dict: DeviceDict | None = ({
1016
  lr.device: lr
1017
  } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
1018
+ None)
1019
  grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
1020
  [
1021
  params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
 
1062
  maximize=maximize,
1063
  )
1064
 
1065
+ def _step_muon(self, group, qk_logits=None):
1066
+ params = group["params"]
1067
+ lr = group["lr"]
1068
+ weight_decay = group["weight_decay"]
1069
+ momentum = group["momentum"]
1070
+ names = group["names"]
1071
+
1072
+ param_dtensors = []
1073
+ param_tensors = []
1074
+ name_dtensors = []
1075
+ name_tensors = []
1076
+
1077
+ if self.use_distributed_muon:
1078
+ self.distributed_muon(names=names,
1079
+ params=params,
1080
+ group=group,
1081
+ lr=lr,
1082
+ weight_decay=weight_decay,
1083
+ momentum=momentum,
1084
+ qk_logits=qk_logits)
1085
+ return
1086
+
1087
+ for n, p in zip(names, params):
1088
+ if p is None or p.grad is None:
1089
+ continue
1090
+ if isinstance(p.data, DTensor):
1091
+ if all(
1092
+ isinstance(placement, Replicate)
1093
+ for placement in p.placements):
1094
+ param_tensors.append(p)
1095
+ name_tensors.append(n)
1096
+ else:
1097
+ param_dtensors.append(p)
1098
+ name_dtensors.append(n)
1099
+ elif isinstance(p.data, torch.Tensor):
1100
+ param_tensors.append(p)
1101
+ name_tensors.append(n)
1102
+ else:
1103
+ raise TypeError(f"Unsupported parameter type: {type(p.data)}")
1104
+
1105
+ logger.debug(
1106
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors"
1107
+ )
1108
+
1109
+ if len(param_dtensors) > 0:
1110
+ if not dist.is_initialized():
1111
+ raise RuntimeError(
1112
+ "Parallel Muon requires torch.distributed to be initialized."
1113
+ )
1114
+
1115
+ # To support different placements, we group parameters by placements
1116
+ # and run parallel Muon on each group.
1117
+
1118
+ placement_to_params = defaultdict(lambda: ([], []))
1119
+ # type: dict[tuple[tuple[Placement, ...], DeviceMesh], tuple[list[str], list[DTensor]]]
1120
+
1121
+ assert len(name_dtensors) == len(param_dtensors)
1122
+ for n, p in zip(name_dtensors, param_dtensors):
1123
+ placement_to_params[tuple([p.placements,
1124
+ p.device_mesh])][0].append(n)
1125
+ placement_to_params[tuple([p.placements,
1126
+ p.device_mesh])][1].append(p)
1127
+
1128
+ for _, (names, params) in placement_to_params.items():
1129
+ self.parallel(
1130
+ names,
1131
+ params,
1132
+ group,
1133
+ lr=lr,
1134
+ weight_decay=weight_decay,
1135
+ momentum=momentum,
1136
+ qk_logits=qk_logits,
1137
+ )
1138
+
1139
+ if len(param_tensors) > 0:
1140
+ self.base(
1141
+ name_tensors,
1142
+ param_tensors,
1143
+ group,
1144
+ lr=lr,
1145
+ weight_decay=weight_decay,
1146
+ momentum=momentum,
1147
+ qk_logits=qk_logits,
1148
+ )
1149
+
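The grouping above can be illustrated with plain dictionaries; the string stand-ins below replace real Placement and DeviceMesh objects purely for readability. Parameters that share both placements and mesh fall into one bucket and are handed to parallel() together.

from collections import defaultdict

# Stand-ins for (p.placements, p.device_mesh); the real code uses DTensor attributes.
params = [
    ("wq", ("Shard(0)",), "mesh_a"),
    ("wk", ("Shard(0)",), "mesh_a"),
    ("wo", ("Replicate()", "Shard(0)"), "mesh_b"),
]

groups = defaultdict(lambda: ([], []))  # key -> (names, params)
for name, placements, mesh in params:
    names, ps = groups[(placements, mesh)]
    names.append(name)
    ps.append(name)  # the real code appends the parameter tensor itself

for key, (names, _) in groups.items():
    print(key, names)  # two buckets: the mesh_a/Shard(0) group and the mesh_b group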
1150
+ def _step_adamw_params(self, params, group):
1151
+ params_with_grads = []
1152
+ grads = []
1153
+ moment1 = []
1154
+ moment2 = []
1155
+ max_exp_avg_sqs = []
1156
+ state_steps = []
1157
+ lr = group["lr"]
1158
+ beta1, beta2 = group["adamw_betas"]
1159
+ eps = group["adamw_eps"]
1160
+ weight_decay = group["weight_decay"]
1161
+
1162
+ for p in params:
1163
+ g = p.grad
1164
+ if g is None:
1165
+ continue
1166
+ state = self.state[p]
1167
+ params_with_grads.append(p)
1168
+ grads.append(g)
1169
+ if "step" not in state:
1170
+ state["step"] = (torch.zeros((),
1171
+ dtype=torch.float32,
1172
+ device=p.device))
1173
+ state["moment1"] = torch.zeros_like(g)
1174
+ state["moment2"] = torch.zeros_like(g)
1175
+ moment1.append(state["moment1"])
1176
+ moment2.append(state["moment2"])
1177
+ if not isinstance(state["step"], torch.Tensor):
1178
+ step_tensor = torch.tensor(state["step"],
1179
+ dtype=torch.float32,
1180
+ device=p.device)
1181
+ else:
1182
+ step_tensor = state["step"]
1183
+ state_steps.append(step_tensor)
1184
+
1185
+ self._fused_adamw(
1186
+ params_with_grads,
1187
+ grads,
1188
+ moment1,
1189
+ moment2,
1190
+ max_exp_avg_sqs,
1191
+ state_steps,
1192
+ amsgrad=False,
1193
+ beta1=beta1,
1194
+ beta2=beta2,
1195
+ lr=lr,
1196
+ weight_decay=weight_decay,
1197
+ eps=eps,
1198
+ maximize=False,
1199
+ )
1200
+
1201
+ def _step_adamw(self, group):
1202
+ params = group["params"]
1203
+
1204
+ # group params by their type and placement
1205
+ placement_to_params: dict[tuple[Placement | type,
1206
+ DeviceMesh | None], list[torch.Tensor]] = defaultdict(list)
1207
+ for p in params:
1208
+ match p:
1209
+ case DTensor():
1210
+ placement_to_params[tuple([p.placements,
1211
+ p.device_mesh])].append(p)
1212
+ case torch.Tensor():
1213
+ placement_to_params[tuple([torch.Tensor, None])].append(p)
1214
+
1215
+ for params in placement_to_params.values():
1216
+ self._step_adamw_params(params, group)
1217
+
1218
  def step(self, closure=None, qk_logits=None):
1219
  """Perform a single optimization step.
1220
 
 
1232
  loss = closure()
1233
 
1234
  for group in self.param_groups:
 
 
1235
  if group["use_muon"]:
1236
+ self._step_muon(group, qk_logits=qk_logits)
 
 
1237
  else:
1238
+ self._step_adamw(group)
 
 
1239
 
1240
  return loss
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_811726c_dirty
3
- ops = torch.ops._optimizer_811726c_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_811726c_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_23d68bb_dirty
3
+ ops = torch.ops._optimizer_23d68bb_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_23d68bb_dirty::{op_name}"
build/torch28-cxx11-cu129-x86_64-linux/optimizer/{_optimizer_811726c_dirty.abi3.so → _optimizer_23d68bb_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b957f60eab442d3ff5a5525d16a1b4b71e8c6be32edb874d9a5681953c61f0c2
3
- size 1871056
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1cbcd3df518412314d547a86b947998802e488e8aec0f22bf8b59fbc2d1c91e8
3
+ size 1983488
build/torch28-cxx11-cu129-x86_64-linux/optimizer/distributed/utils.py ADDED
@@ -0,0 +1,174 @@
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+ from torch.distributed import ProcessGroup
4
+ from torch.distributed.device_mesh import DeviceMesh
5
+ from torch.distributed.tensor import DTensor
6
+ from torch.distributed.tensor.placement_types import (Placement, Shard,
7
+ _StridedShard)
8
+
9
+
10
+ def get_slices_of_dtensor(
11
+ target: DTensor | torch.Tensor,
12
+ local_rank: int,
13
+ shard_mesh: DeviceMesh,
14
+ shard_placements: tuple[Placement],
15
+ ) -> tuple[slice]:
16
+ """
17
+ Get the slices selecting the local shard of a tensor for a given rank.
18
+ Args:
19
+ target (DTensor | torch.Tensor): The target tensor.
20
+ local_rank (int): The local rank within the shard group.
21
+ shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks.
22
+ shard_placements (tuple[Placement]): The shard placements.
23
+ """
24
+
25
+ slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()]
26
+
27
+ # find the global rank of the local rank in the shard mesh
28
+ rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
29
+
30
+ rank_coords = (shard_mesh.mesh == rank).nonzero()
31
+
32
+ assert len(rank_coords) == 1
33
+ rank_coords = tuple(rank_coords[0].tolist())
34
+
35
+ assert len(rank_coords) == len(shard_placements)
36
+
37
+ # Caution: Assuming replicate-to-shard of the shard mesh goes with
38
+ # left-to-right sharding. This is ensured by the sorting logic of
39
+ # construct_shard_mesh function.
40
+ for i, (rank_coord,
41
+ placement) in enumerate(zip(rank_coords, shard_placements)):
42
+ assert isinstance(placement, Shard)
43
+
44
+ num_ranks = shard_mesh.mesh.shape[i]
45
+
46
+ dim = placement.dim
47
+ dim_size = (slices[dim].stop - slices[dim].start)
48
+
49
+ if dim_size % num_ranks != 0:
50
+ raise NotImplementedError(
51
+ f"Dimension size {dim_size} is not divisible "
52
+ f"by number of ranks {num_ranks} for shard "
53
+ f"placement on dim {dim}.")
54
+
55
+ shard_size = dim_size // num_ranks
56
+
57
+ start = slices[dim].start + rank_coord * shard_size
58
+ end = start + shard_size
59
+
60
+ assert start < end <= slices[dim].stop
61
+
62
+ slices[dim] = slice(start, end)
63
+
64
+ return tuple(slices)
65
+
66
+
67
+ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict()
68
+
69
+
70
+ def construct_shard_mesh(
71
+ placements: tuple[Placement],
72
+ mesh: DeviceMesh,
73
+ ) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]:
74
+ """
75
+ Construct Shard Mesh and Placements for unsharding.
76
+ It removes Replicate placements and constructs a new Mesh and ProcessGroup.
77
+ """
78
+ my_rank = dist.get_rank()
79
+
80
+ assert mesh.mesh.device.type == 'cpu'
81
+
82
+ # Copy mesh to avoid modifying the original mesh
83
+ mesh = mesh.mesh.clone()
84
+
85
+ # 1. Sort placements. Replicate first, then Shard by dim ascending.
86
+
87
+ # For Shard, strided shard comes after regular shard on the same dim
88
+ # to preserve left-to-right order of replicate-to-shard.
89
+ # This is because a strided shard uses a stride to represent
90
+ # more fine-grained sharding on the same dim.
91
+ # Please check the URL below for _StridedShard.
92
+ # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366
93
+
94
+ def placement_sort_key(
95
+ placement_with_index: tuple[int, Placement]
96
+ ) -> tuple[float, float, int]:  # (dim, split factor, original index)
97
+ index, placement = placement_with_index
98
+ is_replicate = placement.is_replicate()
99
+ is_shard = placement.is_shard()
100
+ is_partial = placement.is_partial()
101
+
102
+ assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}"
103
+ assert not is_partial, "Partial placement is not supported."
104
+
105
+ if is_replicate:
106
+ return (-1.0, 0, index)
107
+ elif is_shard:
108
+ if isinstance(placement, _StridedShard):
109
+ return (placement.dim, 1 / placement.split_factor, index)
110
+ return (placement.dim, 0, index)
111
+ else:
112
+ raise TypeError(f"Unknown placement type: {type(placement)}")
113
+
114
+ placements_with_index: list[tuple[int,
115
+ Placement]] = list(enumerate(placements))
116
+ placements_with_index = sorted(placements_with_index,
117
+ key=placement_sort_key)
118
+
119
+ sorted_indices, sorted_placements = zip(*placements_with_index)
120
+
121
+ # 2. Permute mesh according to sorted placements.
122
+ sorted_mesh = mesh.permute(sorted_indices)
123
+
124
+ # 3. Collect list of shard meshes by removing replicate dims
125
+ # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
126
+ # shard_meshes should be a list of 2 * 3 = 6 shard meshes of shape (4, 4)
127
+ num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
128
+
129
+ # merge replicate dims
130
+ # shard_meshes becomes a list of shard meshes whose length equals the replicate degree
131
+ if num_replicates > 0:
132
+ sorted_mesh = sorted_mesh.flatten(
133
+ 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
134
+ shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
135
+ else:
136
+ shard_meshes = [sorted_mesh]
137
+ shard_placements = sorted_placements[num_replicates:]
138
+
139
+ # assume all shard placements are different
140
+ assert len(shard_placements) == len(set(shard_placements))
141
+
142
+ # 4. Construct ProcessGroups
143
+ # Caution: all groups should be created in the same order in all processes,
144
+ # even though each process only needs its own group.
145
+
146
+ # To use a tensor as a dict key, convert it to a tuple
147
+ def tensor_to_tuple(t):
148
+ if isinstance(t, torch.Tensor):
149
+ t = t.tolist()
150
+ if isinstance(t, list):
151
+ return tuple(tensor_to_tuple(x) for x in t)
152
+ return t
153
+
154
+ my_shard_mesh_as_tuple = None
155
+ for shard_mesh in shard_meshes:
156
+ assert isinstance(shard_mesh, torch.Tensor)
157
+ shard_mesh_as_tuple = tensor_to_tuple(shard_mesh)
158
+
159
+ if (my_rank == shard_mesh).any().item():
160
+ assert my_shard_mesh_as_tuple is None
161
+ my_shard_mesh_as_tuple = shard_mesh_as_tuple
162
+
163
+ # update global cache
164
+ if shard_mesh_as_tuple not in _ranks_to_dist_cache:
165
+ shard_process_group = dist.new_group(shard_mesh.flatten().tolist())
166
+ _ranks_to_dist_cache[shard_mesh_as_tuple] = (
167
+ DeviceMesh(device_type="cuda", mesh=shard_mesh),
168
+ shard_process_group,
169
+ )
170
+
171
+ my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[
172
+ my_shard_mesh_as_tuple]
173
+
174
+ return my_shard_mesh, my_shard_process_group, shard_placements
build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py CHANGED
@@ -1,18 +1,24 @@
1
  import logging
2
  import math
3
  import types
 
4
  from dataclasses import dataclass
5
- from typing import List, Optional, Union, cast
6
 
7
  import torch
8
  import torch.distributed as dist
9
- from torch.distributed._tensor import DTensor, Replicate, Shard
 
 
 
10
 
 
11
  from .matmul_transpose_triton import matmul_transpose_assign
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
  COMM_DTYPE = torch.bfloat16
 
16
 
17
 
18
  # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -62,23 +68,39 @@ def _zeropower_via_newtonschulz5(G, steps):
62
  @dataclass
63
  class _muon_state:
64
  # TODO: use Optional
65
- worker_rank: int | None = None
 
 
 
 
 
66
  gathered_grad: torch.Tensor | None = None
67
  scattered_u: DTensor | None = None
68
  computed_u: torch.Tensor | None = None
69
  gather_event: torch.cuda.Event | None = None
70
  compute_event: torch.cuda.Event | None = None
71
  scatter_event: torch.cuda.Event | None = None
72
- process_group = None
73
- qk_clip_state = None
74
 
75
 
76
- def split_elems_for_src(param, src_rank, num_ranks) -> int:
77
- rows = param.shape[0]
78
- cols = int(param.numel() // rows)
79
- base, rem = divmod(rows, num_ranks)
80
- my_rows = base + (1 if src_rank < rem else 0)
81
- return my_rows * cols
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
 
84
  @torch.no_grad()
@@ -91,8 +113,7 @@ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
91
  for p in params:
92
  state = param_to_state[id(p)]
93
  if rank == state.worker_rank:
94
- num_ranks = dist.get_world_size(group=state.process_group)
95
- state.gathered_grad = torch.empty(p.grad.numel(),
96
  dtype=COMM_DTYPE,
97
  device="cuda")
98
  else:
@@ -121,11 +142,11 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
121
  state = param_to_state[id(p)]
122
  dst = state.worker_rank
123
  assert dst < num_ranks
124
- shard_elems = split_elems_for_src(p, rank, num_ranks)
125
  g = p.grad
126
- g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
127
  assert g.numel() == shard_elems
128
- per_dst[dst].append(g)
129
  send_counts[dst] += shard_elems
130
 
131
  assert any(
@@ -148,13 +169,18 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
148
  for p in owned_params:
149
  state = param_to_state[id(p)]
150
  assert state.worker_rank == rank
151
- total += split_elems_for_src(p, src, num_ranks)
152
  recv_counts[src] = total
153
 
154
  recv_total = sum(recv_counts)
155
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
156
 
157
  #All2All
 
 
 
 
 
158
  dist.all_to_all_single(
159
  recv_buf,
160
  send_buf,
@@ -179,7 +205,6 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
179
  comm_stream.wait_event(alloc_event)
180
 
181
  off = 0
182
- write_offsets = {id(p): 0 for p in owned_params}
183
  for src in range(num_ranks):
184
  if recv_counts[src] == 0:
185
  continue
@@ -189,22 +214,28 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
189
  for p in owned_params:
190
  state = param_to_state[id(p)]
191
  assert state.worker_rank == rank
192
- n = split_elems_for_src(p, src, num_ranks)
 
 
 
 
 
 
 
 
 
193
  assert n > 0
194
 
195
  sg = recv_buf.narrow(0, off + inner_off, n)
196
- woff = write_offsets[id(p)]
197
- dst = state.gathered_grad.narrow(0, woff, n)
198
  dst.copy_(sg)
199
 
200
- write_offsets[id(p)] += n
201
  inner_off += n
202
  off += block
203
 
204
  for p in params:
205
  state = param_to_state[id(p)]
206
  if state.worker_rank == rank:
207
- state.gathered_grad = state.gathered_grad.view_as(p)
208
  state.gather_event = torch.cuda.Event()
209
  state.gather_event.record(comm_stream)
210
  else:
@@ -277,14 +308,19 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
277
 
278
  assert state.computed_u is not None
279
 
280
- u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)
281
 
282
  offset = 0
283
  for dst in range(num_ranks):
284
- n = split_elems_for_src(p, dst, num_ranks)
 
 
 
 
 
 
285
  assert n > 0
286
 
287
- su = u_full.narrow(0, offset, n)
288
  per_dst[dst].append(su)
289
  send_counts[dst] += n
290
  offset += n
@@ -313,7 +349,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
313
  state = param_to_state[id(p)]
314
  if state.worker_rank != src:
315
  continue
316
- total += split_elems_for_src(p, rank, num_ranks)
317
  recv_counts[src] = total
318
 
319
  recv_total = sum(recv_counts)
@@ -357,7 +393,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
357
  state = param_to_state[id(p)]
358
  if state.worker_rank != src:
359
  continue
360
- n = split_elems_for_src(p, rank, num_ranks)
361
  assert n > 0
362
 
363
  flat_local = recv_buf.narrow(0, off + inner_off,
@@ -398,11 +434,23 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
398
  state.scattered_u = None
399
  u_dtensor = None
400
 
401
- scales_full = Muon._compute_scales(p, state.qk_clip_state)
 
 
402
  if scales_full is not None:
403
- num_ranks = dist.get_world_size(group=state.process_group)
404
- local_rank = dist.get_rank(group=state.process_group)
405
- scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank]
 
 
 
 
 
 
 
 
 
 
406
  scales_local = DTensor.from_local(
407
  scales_local,
408
  placements=p.placements,
@@ -478,11 +526,11 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
478
  @dataclass
479
  class QKClipInfo:
480
  """Per-parameter dynamic info computed from config + runtime logits."""
481
- kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None
482
- indices: List[int] # which heads to consider for clipping
483
  head_dim: int # from config
484
  threshold: float # from config
485
- logit: Optional[torch.Tensor]
486
 
487
 
488
  class Muon(torch.optim.Optimizer):
@@ -525,11 +573,16 @@ class Muon(torch.optim.Optimizer):
525
  "head_dim": 128,
526
  "threshold": 100
527
  }
528
- overlap_step : How many all2all gather, compute operations are launched in advance
529
- before the corresponding all2all scatter steps begin.
530
- A higher overlap_step increases memory usage but can improve
531
- performance by overlapping communication.
532
- Parallel muon only.
 
 
 
 
 
533
  """
534
 
535
  def __init__(self,
@@ -549,7 +602,9 @@ class Muon(torch.optim.Optimizer):
549
  "head_dim": 128,
550
  "threshold": 100
551
  },
552
- overlap_step=5):
 
 
553
  defaults = dict(
554
  lr=lr,
555
  weight_decay=weight_decay,
@@ -579,7 +634,9 @@ class Muon(torch.optim.Optimizer):
579
  self.compute_stream = torch.cuda.Stream()
580
  self.debug = debug
581
  self.clip_config = clip_config
582
- self.overlap_step = overlap_step
 
 
583
 
584
  def _calc_flops(self, G, steps):
585
  assert len(G.shape) == 2
@@ -597,6 +654,12 @@ class Muon(torch.optim.Optimizer):
597
  adjusted_lr = lr * adjusted_ratio
598
  return adjusted_lr
599
 
 
 
 
 
 
 
600
  def get_shard_mesh(self, p):
601
  """
602
  Get the shard mesh for a parameter p on the given rank.
@@ -604,26 +667,13 @@ class Muon(torch.optim.Optimizer):
604
  assert isinstance(
605
  p, DTensor), "Parallel Muon only supports DTensor parameters."
606
 
607
- if p.placements == (Shard(dim=0), ):
608
- # Case for FSDP
609
- process_group = p.device_mesh.get_group(mesh_dim=0)
610
- if self.rank is None:
611
- self.rank = dist.get_rank(group=process_group)
612
- else:
613
- assert self.rank == dist.get_rank(group=process_group)
614
- return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
615
- elif p.placements == (Replicate(), Shard(dim=0)):
616
- # Case for HSDP
617
- process_group = p.device_mesh.get_group(mesh_dim=1)
618
- if self.rank is None:
619
- self.rank = dist.get_rank(group=process_group)
620
- else:
621
- assert self.rank == dist.get_rank(group=process_group)
622
- for i, shard_mesh in enumerate(p.device_mesh.mesh):
623
- if self.rank in shard_mesh:
624
- return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
625
- else:
626
- raise ValueError(f"Unsupported placements ({p.placements}).")
627
 
628
  def init_state_and_assign_params(self, names, params, group, qk_logits):
629
  param_to_state = {}
@@ -655,23 +705,32 @@ class Muon(torch.optim.Optimizer):
655
  ordered_params = list(params_sorted)
656
 
657
  round_robin = 0
658
- mesh = None
659
- shard_mesh = None
660
- process_group = None
 
 
 
 
 
661
  for n, p in zip(ordered_names, ordered_params):
662
- if mesh is None:
663
- mesh = p.device_mesh
664
- shard_mesh, process_group = self.get_shard_mesh(p)
665
- elif mesh != p.device_mesh:
666
  raise ValueError("All parameters must be on the same mesh.")
667
- num_ranks = dist.get_world_size(group=process_group)
668
- param_to_state[id(p)] = _muon_state()
669
- param_to_state[id(
670
- p)].worker_rank = shard_mesh[round_robin].item() % num_ranks
671
- param_to_state[id(p)].process_group = process_group
672
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
673
- param_to_state[id(p)].qk_clip_state = qk_clip_state
674
- round_robin = (round_robin + 1) % len(shard_mesh)
 
 
 
 
 
 
 
675
 
676
  return param_to_state, ordered_params
677
 
@@ -705,10 +764,73 @@ class Muon(torch.optim.Optimizer):
705
 
706
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
707
 
708
- scales_full = self._compute_scales(p, qk_clip_state)
 
709
  if scales_full is not None:
710
  Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
711
 
 
 
712
  def _update_g(self, p, g, group, momentum):
713
  # calc update
714
  state = self.state[p]
@@ -727,6 +849,9 @@ class Muon(torch.optim.Optimizer):
727
  p.data.add_(u, alpha=-adjusted_lr)
728
 
729
  def get_qk_clip_info(self, n, qk_logits):
 
 
 
730
  head_dim = self.clip_config.get('head_dim')
731
  threshold = self.clip_config.get('threshold')
732
  kind, layer_idx = parse_qk_layer(n)
@@ -737,6 +862,11 @@ class Muon(torch.optim.Optimizer):
737
  indices_key = 'q_indices' if 'q' in kind else 'k_indices'
738
  indices = self.clip_config.get(indices_key, []) or []
739
 
 
 
 
 
 
740
  return QKClipInfo(
741
  kind=kind,
742
  indices=indices,
@@ -835,22 +965,28 @@ class Muon(torch.optim.Optimizer):
835
  _update_param(p, state, lr, adjusted_lr, weight_decay,
836
  self.rank, self.compute_stream)
837
 
838
- chunk_size = dist.get_world_size(param_to_state[id(
839
- params[0])].process_group)
 
 
 
 
 
 
840
 
841
  # Wait grad update
842
  self.comm_stream.wait_stream(torch.cuda.current_stream())
843
 
844
- overlap_step = self.overlap_step
845
- for i in range(0, overlap_step):
846
  enqueue_all2all_gather(i * chunk_size, chunk_size)
847
  enqueue_computes(i * chunk_size, chunk_size)
848
 
849
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
850
  enqueue_all2all_scatter(i, chunk_size)
851
- enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size)
852
  enqueue_update_param(i, chunk_size)
853
- enqueue_computes(i + overlap_step * chunk_size, chunk_size)
854
 
855
  # Wait the last update_param to finish
856
  torch.cuda.current_stream().wait_stream(self.compute_stream)
@@ -866,7 +1002,7 @@ class Muon(torch.optim.Optimizer):
866
  amsgrad: bool,
867
  beta1: float,
868
  beta2: float,
869
- lr: Union[float, torch.Tensor],
870
  weight_decay: float,
871
  eps: float,
872
  maximize: bool,
@@ -876,10 +1012,10 @@ class Muon(torch.optim.Optimizer):
876
 
877
  # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
878
  # treating it as a scalar.
879
- lr_dict: Optional[DeviceDict] = ({
880
  lr.device: lr
881
  } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
882
- None)
883
  grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
884
  [
885
  params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
@@ -926,6 +1062,159 @@ class Muon(torch.optim.Optimizer):
926
  maximize=maximize,
927
  )
928
 
 
 
929
  def step(self, closure=None, qk_logits=None):
930
  """Perform a single optimization step.
931
 
@@ -943,127 +1232,9 @@ class Muon(torch.optim.Optimizer):
943
  loss = closure()
944
 
945
  for group in self.param_groups:
946
- params = group["params"]
947
-
948
  if group["use_muon"]:
949
- ############################
950
- # Muon #
951
- ############################
952
- lr = group["lr"]
953
- weight_decay = group["weight_decay"]
954
- momentum = group["momentum"]
955
- names = group["names"]
956
-
957
- param_dtensors = []
958
- param_tensors = []
959
- name_dtensors = []
960
- name_tensors = []
961
-
962
- for n, p in zip(names, params):
963
- if p is None or p.grad is None:
964
- continue
965
- if isinstance(p.data, DTensor):
966
- if all(
967
- isinstance(placement, Replicate)
968
- for placement in p.placements):
969
- param_tensors.append(p)
970
- name_tensors.append(n)
971
- else:
972
- param_dtensors.append(p)
973
- name_dtensors.append(n)
974
- elif isinstance(p.data, torch.Tensor):
975
- param_tensors.append(p)
976
- name_tensors.append(n)
977
- else:
978
- raise TypeError(
979
- f"Unsupported parameter type: {type(p.data)}")
980
-
981
- if self.debug:
982
- print(
983
- f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
984
- flush=True,
985
- )
986
-
987
- if len(param_dtensors) > 0:
988
- if not dist.is_initialized():
989
- raise RuntimeError(
990
- "Parallel Muon requires torch.distributed to be initialized."
991
- )
992
-
993
- self.parallel(
994
- name_dtensors,
995
- param_dtensors,
996
- group,
997
- lr=lr,
998
- weight_decay=weight_decay,
999
- momentum=momentum,
1000
- qk_logits=qk_logits,
1001
- )
1002
-
1003
- if len(param_tensors) > 0:
1004
- self.base(
1005
- name_tensors,
1006
- param_tensors,
1007
- group,
1008
- lr=lr,
1009
- weight_decay=weight_decay,
1010
- momentum=momentum,
1011
- qk_logits=qk_logits,
1012
- )
1013
-
1014
  else:
1015
- ############################
1016
- # AdamW backup #
1017
- ############################
1018
-
1019
- params_with_grads = []
1020
- grads = []
1021
- moment1 = []
1022
- moment2 = []
1023
- max_exp_avg_sqs = []
1024
- state_steps = []
1025
- lr = group["lr"]
1026
- beta1, beta2 = group["adamw_betas"]
1027
- eps = group["adamw_eps"]
1028
- weight_decay = group["weight_decay"]
1029
-
1030
- for p in params:
1031
- g = p.grad
1032
- if g is None:
1033
- continue
1034
- state = self.state[p]
1035
- params_with_grads.append(p)
1036
- grads.append(g)
1037
- if "step" not in state:
1038
- state["step"] = (torch.zeros((),
1039
- dtype=torch.float32,
1040
- device=p.device))
1041
- state["moment1"] = torch.zeros_like(g)
1042
- state["moment2"] = torch.zeros_like(g)
1043
- moment1.append(state["moment1"])
1044
- moment2.append(state["moment2"])
1045
- if not isinstance(state["step"], torch.Tensor):
1046
- step_tensor = torch.tensor(state["step"],
1047
- dtype=torch.float32,
1048
- device=p.device)
1049
- else:
1050
- step_tensor = state["step"]
1051
- state_steps.append(step_tensor)
1052
-
1053
- self._fused_adamw(
1054
- params_with_grads,
1055
- grads,
1056
- moment1,
1057
- moment2,
1058
- max_exp_avg_sqs,
1059
- state_steps,
1060
- amsgrad=False,
1061
- beta1=beta1,
1062
- beta2=beta2,
1063
- lr=lr,
1064
- weight_decay=weight_decay,
1065
- eps=eps,
1066
- maximize=False,
1067
- )
1068
 
1069
  return loss
 
1
  import logging
2
  import math
3
  import types
4
+ from collections import defaultdict
5
  from dataclasses import dataclass
6
+ from typing import Any, cast
7
 
8
  import torch
9
  import torch.distributed as dist
10
+ from torch.distributed import ProcessGroup
11
+ from torch.distributed.device_mesh import DeviceMesh
12
+ from torch.distributed.tensor import DTensor, Replicate
13
+ from torch.distributed.tensor.placement_types import Placement
14
 
15
+ from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor
16
  from .matmul_transpose_triton import matmul_transpose_assign
17
 
18
  logger = logging.getLogger(__name__)
19
 
20
  COMM_DTYPE = torch.bfloat16
21
+ DEFAULT_CHUNK_SIZE_RATIO = 4
22
 
23
 
24
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
68
  @dataclass
69
  class _muon_state:
70
  # TODO: use Optional
71
+ worker_rank: int
72
+ process_group: ProcessGroup
73
+ shard_mesh: DeviceMesh
74
+ shard_placements: tuple[Placement, ...]
75
+ name: str
76
+ qk_clip_state: torch.Tensor | None = None
77
  gathered_grad: torch.Tensor | None = None
78
  scattered_u: DTensor | None = None
79
  computed_u: torch.Tensor | None = None
80
  gather_event: torch.cuda.Event | None = None
81
  compute_event: torch.cuda.Event | None = None
82
  scatter_event: torch.cuda.Event | None = None
 
 
83
 
84
 
85
+ def numel_for_rank(
86
+ param: DTensor,
87
+ local_rank: int,
88
+ state: _muon_state,
89
+ ) -> int:
90
+ slices = get_slices_of_dtensor(
91
+ param,
92
+ local_rank,
93
+ state.shard_mesh,
94
+ state.shard_placements,
95
+ )
96
+
97
+ numel = 1
98
+ for s, dim in zip(slices, param.shape):
99
+ start, stop, step = s.indices(dim)
100
+ length = max(0, (stop - start + (step - 1)) // step)
101
+ numel *= length
102
+
103
+ return numel
104
 
105
 
106
  @torch.no_grad()
 
113
  for p in params:
114
  state = param_to_state[id(p)]
115
  if rank == state.worker_rank:
116
+ state.gathered_grad = torch.empty(p.shape,
 
117
  dtype=COMM_DTYPE,
118
  device="cuda")
119
  else:
 
142
  state = param_to_state[id(p)]
143
  dst = state.worker_rank
144
  assert dst < num_ranks
145
+ shard_elems = numel_for_rank(p, rank, state)
146
  g = p.grad
147
+ g = g.to_local().to(COMM_DTYPE).contiguous()
148
  assert g.numel() == shard_elems
149
+ per_dst[dst].append(g.view(-1))
150
  send_counts[dst] += shard_elems
151
 
152
  assert any(
 
169
  for p in owned_params:
170
  state = param_to_state[id(p)]
171
  assert state.worker_rank == rank
172
+ total += numel_for_rank(p, src, state)
173
  recv_counts[src] = total
174
 
175
  recv_total = sum(recv_counts)
176
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
177
 
178
  #All2All
179
+ logger.debug(f"send_buf size: {send_buf.numel()}, "
180
+ f"recv_buf size: {recv_buf.numel()}, "
181
+ f"recv_counts: {recv_counts}, "
182
+ f"send_counts: {send_counts}, "
183
+ f"process_group: {str(process_group)}")
184
  dist.all_to_all_single(
185
  recv_buf,
186
  send_buf,
 
205
  comm_stream.wait_event(alloc_event)
206
 
207
  off = 0
 
208
  for src in range(num_ranks):
209
  if recv_counts[src] == 0:
210
  continue
 
214
  for p in owned_params:
215
  state = param_to_state[id(p)]
216
  assert state.worker_rank == rank
217
+
218
+ # get the slice of the full dtensor corresponding to rank src.
219
+ slices = get_slices_of_dtensor(state.gathered_grad, src,
220
+ state.shard_mesh,
221
+ state.shard_placements)
222
+
223
+ dst = state.gathered_grad[slices]
224
+ assert dst._base is state.gathered_grad
225
+
226
+ n = dst.numel()
227
  assert n > 0
228
 
229
  sg = recv_buf.narrow(0, off + inner_off, n)
230
+ sg = sg.reshape_as(dst)
 
231
  dst.copy_(sg)
232
 
 
233
  inner_off += n
234
  off += block
235
 
236
  for p in params:
237
  state = param_to_state[id(p)]
238
  if state.worker_rank == rank:
 
239
  state.gather_event = torch.cuda.Event()
240
  state.gather_event.record(comm_stream)
241
  else:
 
308
 
309
  assert state.computed_u is not None
310
 
311
+ u_full = state.computed_u.to(COMM_DTYPE).contiguous()
312
 
313
  offset = 0
314
  for dst in range(num_ranks):
315
+ # get the slice of the full tensor corresponding to rank dst.
316
+ slices = get_slices_of_dtensor(u_full, dst,
317
+ state.shard_mesh,
318
+ state.shard_placements)
319
+ su = u_full[slices].flatten()
320
+
321
+ n = su.numel()
322
  assert n > 0
323
 
 
324
  per_dst[dst].append(su)
325
  send_counts[dst] += n
326
  offset += n
 
349
  state = param_to_state[id(p)]
350
  if state.worker_rank != src:
351
  continue
352
+ total += numel_for_rank(p, rank, state)
353
  recv_counts[src] = total
354
 
355
  recv_total = sum(recv_counts)
 
393
  state = param_to_state[id(p)]
394
  if state.worker_rank != src:
395
  continue
396
+ n = numel_for_rank(p, rank, state)
397
  assert n > 0
398
 
399
  flat_local = recv_buf.narrow(0, off + inner_off,
 
434
  state.scattered_u = None
435
  u_dtensor = None
436
 
437
+ scales_full = Muon._compute_scales(
438
+ p,
439
+ state.qk_clip_state) if state.qk_clip_state is not None else None
440
  if scales_full is not None:
441
+ # Have to slice scales_full along dim 0
442
+ weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh,
443
+ state.shard_placements)
444
+ ratio = p.shape[0] // scales_full.shape[0]
445
+ scales_slice = slice(
446
+ None if weight_slices[0].start is None else
447
+ weight_slices[0].start // ratio,
448
+ None if weight_slices[0].stop is None else
449
+ weight_slices[0].stop // ratio,
450
+ None,
451
+ )
452
+
453
+ scales_local = scales_full[scales_slice]
454
  scales_local = DTensor.from_local(
455
  scales_local,
456
  placements=p.placements,
 
526
  @dataclass
527
  class QKClipInfo:
528
  """Per-parameter dynamic info computed from config + runtime logits."""
529
+ kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
530
+ indices: list[int] # which heads to consider for clipping
531
  head_dim: int # from config
532
  threshold: float # from config
533
+ logit: torch.Tensor | None
534
 
535
 
536
  class Muon(torch.optim.Optimizer):
 
573
  "head_dim": 128,
574
  "threshold": 100
575
  }
576
+ warmup_step : How many all2all gather and compute operations are launched in advance
577
+ before the corresponding all2all scatter steps begin.
578
+ A higher warmup_step increases memory usage but can improve
579
+ performance by overlapping communication with computation.
580
+ Parallel Muon only.
581
+ chunk_size : Number of parameters to process in each
582
+ all2all gather/compute/scatter step.
583
+ Uses (shard world size) * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
584
+ use_distributed_muon: Use Distributed Muon by Liu et al. (2024).
585
+ For testing purposes only.
586
  """
587
 
588
  def __init__(self,
 
602
  "head_dim": 128,
603
  "threshold": 100
604
  },
605
+ warmup_step=5,
606
+ chunk_size=-1,
607
+ use_distributed_muon=False):
608
  defaults = dict(
609
  lr=lr,
610
  weight_decay=weight_decay,
 
634
  self.compute_stream = torch.cuda.Stream()
635
  self.debug = debug
636
  self.clip_config = clip_config
637
+ self.warmup_step = warmup_step
638
+ self.chunk_size = chunk_size
639
+ self.use_distributed_muon = use_distributed_muon
640
 
641
  def _calc_flops(self, G, steps):
642
  assert len(G.shape) == 2
 
654
  adjusted_lr = lr * adjusted_ratio
655
  return adjusted_lr
656
 
657
+ def set_rank_once(self, rank):
658
+ if self.rank is None:
659
+ self.rank = rank
660
+ else:
661
+ assert self.rank == rank
662
+
663
  def get_shard_mesh(self, p):
664
  """
665
  Get the shard mesh for a parameter p on the given rank.
 
667
  assert isinstance(
668
  p, DTensor), "Parallel Muon only supports DTensor parameters."
669
 
670
+ shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
671
+ p.placements, p.device_mesh)
672
+
673
+ # set rank with the local rank in the shard process group
674
+ self.set_rank_once(dist.get_rank(group=shard_pg))
675
+
676
+ return shard_mesh, shard_pg, shard_placements
677
 
678
  def init_state_and_assign_params(self, names, params, group, qk_logits):
679
  param_to_state = {}
 
705
  ordered_params = list(params_sorted)
706
 
707
  round_robin = 0
708
+ mesh = ordered_params[0].device_mesh
709
+ placements = ordered_params[0].placements
710
+
711
+ shard_mesh, shard_pg, shard_placements = self.get_shard_mesh(
712
+ ordered_params[0])
713
+ shard_mesh_flattened = shard_mesh.mesh.flatten()
714
+ num_ranks = dist.get_world_size(group=shard_pg)
715
+
716
  for n, p in zip(ordered_names, ordered_params):
717
+ if mesh != p.device_mesh:
 
 
 
718
  raise ValueError("All parameters must be on the same mesh.")
719
+ if placements != p.placements:
720
+ raise ValueError("All parameters must have same placements.")
721
+
722
+ worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
723
+ round_robin = (round_robin + 1) % len(shard_mesh_flattened)
724
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
725
+
726
+ param_to_state[id(p)] = _muon_state(
727
+ worker_rank=worker_rank,
728
+ process_group=shard_pg,
729
+ shard_mesh=shard_mesh,
730
+ shard_placements=shard_placements,
731
+ name=n,
732
+ qk_clip_state=qk_clip_state,
733
+ )
734
 
735
  return param_to_state, ordered_params
736
 
 
764
 
765
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
766
 
767
+ scales_full = self._compute_scales(
768
+ p, qk_clip_state) if qk_clip_state is not None else None
769
  if scales_full is not None:
770
  Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
771
 
772
+ def distributed_muon(
773
+ self,
774
+ names: list[str],
775
+ params: list[torch.nn.Parameter],
776
+ group: dict[str, Any],
777
+ lr: float,
778
+ weight_decay: float,
779
+ momentum: float,
780
+ qk_logits: list[torch.Tensor | DTensor] | None,
781
+ ):
782
+ """ Implementation of Distributed Muon by Liu et al. """
783
+ if qk_logits is not None:
784
+ raise NotImplementedError("QK clipping is not supported yet")
785
+
786
+ if isinstance(params[0], DTensor):
787
+ shard_mesh, _, shard_placements = construct_shard_mesh(
788
+ placements=params[0].placements,
789
+ mesh=params[0].device_mesh,
790
+ )
791
+
792
+ for n, p in zip(names, params):
793
+ g = p.grad
794
+ if g is None:
795
+ continue
796
+ if g.ndim > 2:
797
+ g = g.view(g.size(0), -1)
798
+ assert g is not None
799
+
800
+ # calc update
801
+ state = self.state[p]
802
+ if "momentum_buffer" not in state:
803
+ state["momentum_buffer"] = torch.zeros_like(g)
804
+ buf = state["momentum_buffer"]
805
+ buf.mul_(momentum).add_(g)
806
+ if group["nesterov"]:
807
+ g = g.add(buf, alpha=momentum)
808
+ else:
809
+ g = buf
810
+
811
+ # Gather G
812
+ if isinstance(p.data, DTensor):
813
+ g = g.full_tensor()
814
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
815
+ steps=group["ns_steps"])
816
+
817
+ if isinstance(p.data, DTensor):
818
+ slices = get_slices_of_dtensor(
819
+ target=p,
820
+ local_rank=dist.get_rank(),
821
+ shard_mesh=shard_mesh,
822
+ shard_placements=shard_placements,
823
+ )
824
+ u_shard = u[slices]
825
+ u = DTensor.from_local(
826
+ u_shard,
827
+ device_mesh=p.device_mesh,
828
+ placements=p.placements,
829
+ )
830
+
831
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
832
+ Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
833
+
834
  def _update_g(self, p, g, group, momentum):
835
  # calc update
836
  state = self.state[p]
 
849
  p.data.add_(u, alpha=-adjusted_lr)
850
 
851
  def get_qk_clip_info(self, n, qk_logits):
852
+ if self.clip_config is None:
853
+ return None
854
+
855
  head_dim = self.clip_config.get('head_dim')
856
  threshold = self.clip_config.get('threshold')
857
  kind, layer_idx = parse_qk_layer(n)
 
862
  indices_key = 'q_indices' if 'q' in kind else 'k_indices'
863
  indices = self.clip_config.get(indices_key, []) or []
864
 
865
+ if isinstance(logit, DTensor):
866
+ # In TP settings, qk_logits may be a DTensor.
867
+ # We convert it to a full tensor here for simplicity.
868
+ logit = logit.full_tensor()
869
+
870
  return QKClipInfo(
871
  kind=kind,
872
  indices=indices,
 
965
  _update_param(p, state, lr, adjusted_lr, weight_decay,
966
  self.rank, self.compute_stream)
967
 
968
+ if self.chunk_size == -1:
969
+ shard_ranks = dist.get_world_size(param_to_state[id(
970
+ params[0])].process_group)
971
+ chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
972
+ elif self.chunk_size > 0:
973
+ chunk_size = self.chunk_size
974
+ else:
975
+ raise ValueError("chunk_size must be -1 or a positive integer.")
976
 
977
  # Wait grad update
978
  self.comm_stream.wait_stream(torch.cuda.current_stream())
979
 
980
+ warmup_step = self.warmup_step
981
+ for i in range(0, warmup_step):
982
  enqueue_all2all_gather(i * chunk_size, chunk_size)
983
  enqueue_computes(i * chunk_size, chunk_size)
984
 
985
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
986
  enqueue_all2all_scatter(i, chunk_size)
987
+ enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size)
988
  enqueue_update_param(i, chunk_size)
989
+ enqueue_computes(i + warmup_step * chunk_size, chunk_size)
990
 
991
  # Wait the last update_param to finish
992
  torch.cuda.current_stream().wait_stream(self.compute_stream)
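# A minimal dry-run sketch of the schedule above, assuming hypothetical values
# len(params) = 8, chunk_size = 2, warmup_step = 2; it only records which offsets
# each phase would touch (offsets past len(params) are presumably no-ops inside
# the enqueue_* helpers, since slicing past the end yields an empty batch).
num_params, chunk, warmup = 8, 2, 2
schedule = [("gather+compute", i * chunk) for i in range(warmup)]
for i in range(0, num_params + chunk - 1, chunk):
    schedule.append(("scatter+update", i))
    schedule.append(("gather+compute", i + warmup * chunk))
assert schedule[0] == ("gather+compute", 0)      # warmup runs gather/compute ahead
assert ("scatter+update", 6) in schedule         # the last real chunk is drained later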
 
1002
  amsgrad: bool,
1003
  beta1: float,
1004
  beta2: float,
1005
+ lr: float | torch.Tensor,
1006
  weight_decay: float,
1007
  eps: float,
1008
  maximize: bool,
 
1012
 
1013
  # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
1014
  # treating it as a scalar.
1015
+ lr_dict: DeviceDict | None = ({
1016
  lr.device: lr
1017
  } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
1018
+ None)
1019
  grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
1020
  [
1021
  params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
 
1062
  maximize=maximize,
1063
  )
1064
 
1065
+ def _step_muon(self, group, qk_logits=None):
1066
+ params = group["params"]
1067
+ lr = group["lr"]
1068
+ weight_decay = group["weight_decay"]
1069
+ momentum = group["momentum"]
1070
+ names = group["names"]
1071
+
1072
+ param_dtensors = []
1073
+ param_tensors = []
1074
+ name_dtensors = []
1075
+ name_tensors = []
1076
+
1077
+ if self.use_distributed_muon:
1078
+ self.distributed_muon(names=names,
1079
+ params=params,
1080
+ group=group,
1081
+ lr=lr,
1082
+ weight_decay=weight_decay,
1083
+ momentum=momentum,
1084
+ qk_logits=qk_logits)
1085
+ return
1086
+
1087
+ for n, p in zip(names, params):
1088
+ if p is None or p.grad is None:
1089
+ continue
1090
+ if isinstance(p.data, DTensor):
1091
+ if all(
1092
+ isinstance(placement, Replicate)
1093
+ for placement in p.placements):
1094
+ param_tensors.append(p)
1095
+ name_tensors.append(n)
1096
+ else:
1097
+ param_dtensors.append(p)
1098
+ name_dtensors.append(n)
1099
+ elif isinstance(p.data, torch.Tensor):
1100
+ param_tensors.append(p)
1101
+ name_tensors.append(n)
1102
+ else:
1103
+ raise TypeError(f"Unsupported parameter type: {type(p.data)}")
1104
+
1105
+ logger.debug(
1106
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors"
1107
+ )
1108
+
1109
+ if len(param_dtensors) > 0:
1110
+ if not dist.is_initialized():
1111
+ raise RuntimeError(
1112
+ "Parallel Muon requires torch.distributed to be initialized."
1113
+ )
1114
+
1115
+ # To support different placements, we group parameters by placements
1116
+ # and run parallel Muon on each group.
1117
+
1118
+ placement_to_params = defaultdict(lambda: ([], []))
1119
+ # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]
1120
+
1121
+ assert len(name_dtensors) == len(param_dtensors)
1122
+ for n, p in zip(name_dtensors, param_dtensors):
1123
+ placement_to_params[tuple([p.placements,
1124
+ p.device_mesh])][0].append(n)
1125
+ placement_to_params[tuple([p.placements,
1126
+ p.device_mesh])][1].append(p)
1127
+
1128
+ for _, (names, params) in placement_to_params.items():
1129
+ self.parallel(
1130
+ names,
1131
+ params,
1132
+ group,
1133
+ lr=lr,
1134
+ weight_decay=weight_decay,
1135
+ momentum=momentum,
1136
+ qk_logits=qk_logits,
1137
+ )
1138
+
1139
+ if len(param_tensors) > 0:
1140
+ self.base(
1141
+ name_tensors,
1142
+ param_tensors,
1143
+ group,
1144
+ lr=lr,
1145
+ weight_decay=weight_decay,
1146
+ momentum=momentum,
1147
+ qk_logits=qk_logits,
1148
+ )
1149
+
1150
+ def _step_adamw_params(self, params, group):
1151
+ params_with_grads = []
1152
+ grads = []
1153
+ moment1 = []
1154
+ moment2 = []
1155
+ max_exp_avg_sqs = []
1156
+ state_steps = []
1157
+ lr = group["lr"]
1158
+ beta1, beta2 = group["adamw_betas"]
1159
+ eps = group["adamw_eps"]
1160
+ weight_decay = group["weight_decay"]
1161
+
1162
+ for p in params:
1163
+ g = p.grad
1164
+ if g is None:
1165
+ continue
1166
+ state = self.state[p]
1167
+ params_with_grads.append(p)
1168
+ grads.append(g)
1169
+ if "step" not in state:
1170
+ state["step"] = (torch.zeros((),
1171
+ dtype=torch.float32,
1172
+ device=p.device))
1173
+ state["moment1"] = torch.zeros_like(g)
1174
+ state["moment2"] = torch.zeros_like(g)
1175
+ moment1.append(state["moment1"])
1176
+ moment2.append(state["moment2"])
1177
+ if not isinstance(state["step"], torch.Tensor):
1178
+ step_tensor = torch.tensor(state["step"],
1179
+ dtype=torch.float32,
1180
+ device=p.device)
1181
+ else:
1182
+ step_tensor = state["step"]
1183
+ state_steps.append(step_tensor)
1184
+
1185
+ self._fused_adamw(
1186
+ params_with_grads,
1187
+ grads,
1188
+ moment1,
1189
+ moment2,
1190
+ max_exp_avg_sqs,
1191
+ state_steps,
1192
+ amsgrad=False,
1193
+ beta1=beta1,
1194
+ beta2=beta2,
1195
+ lr=lr,
1196
+ weight_decay=weight_decay,
1197
+ eps=eps,
1198
+ maximize=False,
1199
+ )
1200
+
1201
+ def _step_adamw(self, group):
1202
+ params = group["params"]
1203
+
1204
+ # group params by their type and placement
1205
+ placement_to_params: dict[tuple[tuple[Placement, ...] | type,
1206
+ DeviceMesh | None], list[torch.Tensor]] = defaultdict(list)
1207
+ for p in params:
1208
+ match p:
1209
+ case DTensor():
1210
+ placement_to_params[tuple([p.placements,
1211
+ p.device_mesh])].append(p)
1212
+ case torch.Tensor():
1213
+ placement_to_params[tuple([torch.Tensor, None])].append(p)
1214
+
1215
+ for params in placement_to_params.values():
1216
+ self._step_adamw_params(params, group)
1217
+
1218
  def step(self, closure=None, qk_logits=None):
1219
  """Perform a single optimization step.
1220
 
 
1232
  loss = closure()
1233
 
1234
  for group in self.param_groups:
 
 
1235
  if group["use_muon"]:
1236
+ self._step_muon(group, qk_logits=qk_logits)
1237
  else:
1238
+ self._step_adamw(group)
1239
 
1240
  return loss
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_811726c_dirty
3
- ops = torch.ops._optimizer_811726c_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_811726c_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_23d68bb_dirty
3
+ ops = torch.ops._optimizer_23d68bb_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_23d68bb_dirty::{op_name}"
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_811726c_dirty.abi3.so → _optimizer_23d68bb_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:898ff08457f77c2f6ef504c73570cc87c5c5fd9a144528dbf8af4c03ffc21049
3
- size 1749232
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a2999010ee158e13e3ef247e877dfab073b5bde7babefe2b2b5273b760c7ddf
3
+ size 1852152
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/distributed/utils.py ADDED
@@ -0,0 +1,174 @@
1
+ import torch
2
+ import torch.distributed as dist
3
+ from torch.distributed import ProcessGroup
4
+ from torch.distributed.device_mesh import DeviceMesh
5
+ from torch.distributed.tensor import DTensor
6
+ from torch.distributed.tensor.placement_types import (Placement, Shard,
7
+ _StridedShard)
8
+
9
+
10
+ def get_slices_of_dtensor(
11
+ target: DTensor | torch.Tensor,
12
+ local_rank: int,
13
+ shard_mesh: DeviceMesh,
14
+ shard_placements: tuple[Placement],
15
+ ) -> tuple[slice]:
16
+ """
17
+ Get the slice of local tensor for a given rank from a tensor.
18
+ Args:
19
+ target (DTensor | torch.Tensor): The target tensor.
20
+ local_rank (int): The local rank within the shard group.
21
+ shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks.
22
+ shard_placements (tuple[Placement]): The shard placements.
23
+ """
24
+
25
+ slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()]
26
+
27
+ # find the global rank of the local rank in the shard mesh
28
+ rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
29
+
30
+ rank_coords = (shard_mesh.mesh == rank).nonzero()
31
+
32
+ assert len(rank_coords) == 1
33
+ rank_coords = tuple(rank_coords[0].tolist())
34
+
35
+ assert len(rank_coords) == len(shard_placements)
36
+
37
+ # Caution: we assume that replicate-to-shard on the shard mesh follows
38
+ # left-to-right sharding order. This is ensured by the sorting logic of
39
+ # the construct_shard_mesh function.
40
+ for i, (rank_coord,
41
+ placement) in enumerate(zip(rank_coords, shard_placements)):
42
+ assert isinstance(placement, Shard)
43
+
44
+ num_ranks = shard_mesh.mesh.shape[i]
45
+
46
+ dim = placement.dim
47
+ dim_size = (slices[dim].stop - slices[dim].start)
48
+
49
+ if dim_size % num_ranks != 0:
50
+ raise NotImplementedError(
51
+ f"Dimension size {dim_size} is not divisible "
52
+ f"by number of ranks {num_ranks} for shard "
53
+ f"placement on dim {dim}.")
54
+
55
+ shard_size = dim_size // num_ranks
56
+
57
+ start = slices[dim].start + rank_coord * shard_size
58
+ end = start + shard_size
59
+
60
+ assert start < end <= slices[dim].stop
61
+
62
+ slices[dim] = slice(start, end)
63
+
64
+ return tuple(slices)
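# A minimal sketch of the slicing rule above for a hypothetical (8, 6) tensor,
# a flat 4-rank shard mesh and a single Shard(dim=0) placement: each rank owns a
# contiguous block of 8 // 4 = 2 rows. `_toy_row_slice` is illustrative only.
def _toy_row_slice(rank_coord: int, dim_size: int, num_ranks: int) -> slice:
    assert dim_size % num_ranks == 0   # uneven splits raise NotImplementedError above
    shard_size = dim_size // num_ranks
    return slice(rank_coord * shard_size, (rank_coord + 1) * shard_size)

assert _toy_row_slice(0, 8, 4) == slice(0, 2)
assert _toy_row_slice(3, 8, 4) == slice(6, 8)   # the last rank owns rows 6:8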
65
+
66
+
67
+ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict()
68
+
69
+
70
+ def construct_shard_mesh(
71
+ placements: tuple[Placement],
72
+ mesh: DeviceMesh,
73
+ ) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]:
74
+ """
75
+ Construct Shard Mesh and Placements for unsharding.
76
+ It removes Replicate placements and constructs a new Mesh and ProcessGroup.
77
+ """
78
+ my_rank = dist.get_rank()
79
+
80
+ assert mesh.mesh.device.type == 'cpu'
81
+
82
+ # Copy mesh to avoid modifying the original mesh
83
+ mesh = mesh.mesh.clone()
84
+
85
+ # 1. Sort placements. Replicate first, then Shard by dim ascending.
86
+
87
+ # For Shard, strided shard comes after regular shard on the same dim
88
+ # to preserve left-to-right order of replicate-to-shard.
89
+ # This is because that strided shard is using stride to represent
90
+ # more fine-grained sharding on the same dim.
91
+ # Please check the URL below for _StridedShard.
92
+ # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366
93
+
94
+ def placement_sort_key(
95
+ placement_with_index: tuple[int, Placement]
96
+ ) -> tuple[int, float, int]: # (dim, split factor, original index)
97
+ index, placement = placement_with_index
98
+ is_replicate = placement.is_replicate()
99
+ is_shard = placement.is_shard()
100
+ is_partial = placement.is_partial()
101
+
102
+ assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}"
103
+ assert not is_partial, "Partial placement is not supported."
104
+
105
+ if is_replicate:
106
+ return (-1.0, 0, index)
107
+ elif is_shard:
108
+ if isinstance(placement, _StridedShard):
109
+ return (placement.dim, 1 / placement.split_factor, index)
110
+ return (placement.dim, 0, index)
111
+ else:
112
+ raise TypeError(f"Unknown placement type: {type(placement)}")
113
+
114
+ placements_with_index: list[tuple[int,
115
+ Placement]] = list(enumerate(placements))
116
+ placements_with_index = sorted(placements_with_index,
117
+ key=placement_sort_key)
118
+
119
+ sorted_indices, sorted_placements = zip(*placements_with_index)
120
+
121
+ # 2. Permute mesh according to sorted placements.
122
+ sorted_mesh = mesh.permute(sorted_indices)
123
+
124
+ # 3. Collect list of shard meshes by removing replicate dims
125
+ # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
126
+ # shard_meshes should be a list of 2 * 3 = 6 shard meshes of shape (4, 4)
127
+ num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
128
+
129
+ # merge replicate dims
130
+ # shard_meshes becomes a list of shard meshes whose length equals the replicate degree
131
+ if num_replicates > 0:
132
+ sorted_mesh = sorted_mesh.flatten(
133
+ 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
134
+ shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
135
+ else:
136
+ shard_meshes = [sorted_mesh]
137
+ shard_placements = sorted_placements[num_replicates:]
138
+
139
+ # assume all shard placements are different
140
+ assert len(shard_placements) == len(set(shard_placements))
141
+
142
+ # 4. Construct ProcessGroups
143
+ # Caution: all groups should be created in the same order in all processes,
144
+ # even though each process only needs its own group.
145
+
146
+ # To use tensor as dict key, convert it to tuple
147
+ def tensor_to_tuple(t):
148
+ if isinstance(t, torch.Tensor):
149
+ t = t.tolist()
150
+ if isinstance(t, list):
151
+ return tuple(tensor_to_tuple(x) for x in t)
152
+ return t
153
+
154
+ my_shard_mesh_as_tuple = None
155
+ for shard_mesh in shard_meshes:
156
+ assert isinstance(shard_mesh, torch.Tensor)
157
+ shard_mesh_as_tuple = tensor_to_tuple(shard_mesh)
158
+
159
+ if (my_rank == shard_mesh).any().item():
160
+ assert my_shard_mesh_as_tuple is None
161
+ my_shard_mesh_as_tuple = shard_mesh_as_tuple
162
+
163
+ # update global cache
164
+ if shard_mesh_as_tuple not in _ranks_to_dist_cache:
165
+ shard_process_group = dist.new_group(shard_mesh.flatten().tolist())
166
+ _ranks_to_dist_cache[shard_mesh_as_tuple] = (
167
+ DeviceMesh(device_type="cuda", mesh=shard_mesh),
168
+ shard_process_group,
169
+ )
170
+
171
+ my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[
172
+ my_shard_mesh_as_tuple]
173
+
174
+ return my_shard_mesh, my_shard_process_group, shard_placements
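# A minimal sketch of the replicate-stripping step above, assuming a hypothetical
# (2, 4) HSDP mesh with placements (Replicate(), Shard(0)); plain tensors stand in
# for the DeviceMesh and no process groups are created (torch is imported above).
_toy_mesh = torch.arange(8).reshape(2, 4)              # dim 0: Replicate, dim 1: Shard(0)
_toy_shard_meshes = list(torch.unbind(_toy_mesh, dim=0))
assert [m.tolist() for m in _toy_shard_meshes] == [[0, 1, 2, 3], [4, 5, 6, 7]]
# e.g. global rank 5 lands in the second 4-way shard group with local rank 1 there,
# and the remaining shard_placements would be (Shard(0),).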
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py CHANGED
@@ -1,18 +1,24 @@
1
  import logging
2
  import math
3
  import types
 
4
  from dataclasses import dataclass
5
- from typing import List, Optional, Union, cast
6
 
7
  import torch
8
  import torch.distributed as dist
9
- from torch.distributed._tensor import DTensor, Replicate, Shard
 
 
 
10
 
 
11
  from .matmul_transpose_triton import matmul_transpose_assign
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
  COMM_DTYPE = torch.bfloat16
 
16
 
17
 
18
  # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -62,23 +68,39 @@ def _zeropower_via_newtonschulz5(G, steps):
62
  @dataclass
63
  class _muon_state:
64
  # TODO: use Optional
65
- worker_rank: int | None = None
 
 
 
 
 
66
  gathered_grad: torch.Tensor | None = None
67
  scattered_u: DTensor | None = None
68
  computed_u: torch.Tensor | None = None
69
  gather_event: torch.cuda.Event | None = None
70
  compute_event: torch.cuda.Event | None = None
71
  scatter_event: torch.cuda.Event | None = None
72
- process_group = None
73
- qk_clip_state = None
74
 
75
 
76
- def split_elems_for_src(param, src_rank, num_ranks) -> int:
77
- rows = param.shape[0]
78
- cols = int(param.numel() // rows)
79
- base, rem = divmod(rows, num_ranks)
80
- my_rows = base + (1 if src_rank < rem else 0)
81
- return my_rows * cols
82
 
83
 
84
  @torch.no_grad()
@@ -91,8 +113,7 @@ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
91
  for p in params:
92
  state = param_to_state[id(p)]
93
  if rank == state.worker_rank:
94
- num_ranks = dist.get_world_size(group=state.process_group)
95
- state.gathered_grad = torch.empty(p.grad.numel(),
96
  dtype=COMM_DTYPE,
97
  device="cuda")
98
  else:
@@ -121,11 +142,11 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
121
  state = param_to_state[id(p)]
122
  dst = state.worker_rank
123
  assert dst < num_ranks
124
- shard_elems = split_elems_for_src(p, rank, num_ranks)
125
  g = p.grad
126
- g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
127
  assert g.numel() == shard_elems
128
- per_dst[dst].append(g)
129
  send_counts[dst] += shard_elems
130
 
131
  assert any(
@@ -148,13 +169,18 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
148
  for p in owned_params:
149
  state = param_to_state[id(p)]
150
  assert state.worker_rank == rank
151
- total += split_elems_for_src(p, src, num_ranks)
152
  recv_counts[src] = total
153
 
154
  recv_total = sum(recv_counts)
155
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
156
 
157
  #All2All
 
 
 
 
 
158
  dist.all_to_all_single(
159
  recv_buf,
160
  send_buf,
@@ -179,7 +205,6 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
179
  comm_stream.wait_event(alloc_event)
180
 
181
  off = 0
182
- write_offsets = {id(p): 0 for p in owned_params}
183
  for src in range(num_ranks):
184
  if recv_counts[src] == 0:
185
  continue
@@ -189,22 +214,28 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
189
  for p in owned_params:
190
  state = param_to_state[id(p)]
191
  assert state.worker_rank == rank
192
- n = split_elems_for_src(p, src, num_ranks)
193
  assert n > 0
194
 
195
  sg = recv_buf.narrow(0, off + inner_off, n)
196
- woff = write_offsets[id(p)]
197
- dst = state.gathered_grad.narrow(0, woff, n)
198
  dst.copy_(sg)
199
 
200
- write_offsets[id(p)] += n
201
  inner_off += n
202
  off += block
203
 
204
  for p in params:
205
  state = param_to_state[id(p)]
206
  if state.worker_rank == rank:
207
- state.gathered_grad = state.gathered_grad.view_as(p)
208
  state.gather_event = torch.cuda.Event()
209
  state.gather_event.record(comm_stream)
210
  else:
@@ -277,14 +308,19 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
277
 
278
  assert state.computed_u is not None
279
 
280
- u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)
281
 
282
  offset = 0
283
  for dst in range(num_ranks):
284
- n = split_elems_for_src(p, dst, num_ranks)
 
 
 
 
 
 
285
  assert n > 0
286
 
287
- su = u_full.narrow(0, offset, n)
288
  per_dst[dst].append(su)
289
  send_counts[dst] += n
290
  offset += n
@@ -313,7 +349,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
313
  state = param_to_state[id(p)]
314
  if state.worker_rank != src:
315
  continue
316
- total += split_elems_for_src(p, rank, num_ranks)
317
  recv_counts[src] = total
318
 
319
  recv_total = sum(recv_counts)
@@ -357,7 +393,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
357
  state = param_to_state[id(p)]
358
  if state.worker_rank != src:
359
  continue
360
- n = split_elems_for_src(p, rank, num_ranks)
361
  assert n > 0
362
 
363
  flat_local = recv_buf.narrow(0, off + inner_off,
@@ -398,11 +434,23 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
398
  state.scattered_u = None
399
  u_dtensor = None
400
 
401
- scales_full = Muon._compute_scales(p, state.qk_clip_state)
 
 
402
  if scales_full is not None:
403
- num_ranks = dist.get_world_size(group=state.process_group)
404
- local_rank = dist.get_rank(group=state.process_group)
405
- scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank]
406
  scales_local = DTensor.from_local(
407
  scales_local,
408
  placements=p.placements,
@@ -478,11 +526,11 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
478
  @dataclass
479
  class QKClipInfo:
480
  """Per-parameter dynamic info computed from config + runtime logits."""
481
- kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None
482
- indices: List[int] # which heads to consider for clipping
483
  head_dim: int # from config
484
  threshold: float # from config
485
- logit: Optional[torch.Tensor]
486
 
487
 
488
  class Muon(torch.optim.Optimizer):
@@ -525,11 +573,16 @@ class Muon(torch.optim.Optimizer):
525
  "head_dim": 128,
526
  "threshold": 100
527
  }
528
- overlap_step : How many all2all gather, compute operations are launched in advance
529
- before the corresponding all2all scatter steps begin.
530
- A higher overlap_step increases memory usage but can improve
531
- performance by overlapping communication.
532
- Parallel muon only.
 
 
 
 
 
533
  """
534
 
535
  def __init__(self,
@@ -549,7 +602,9 @@ class Muon(torch.optim.Optimizer):
549
  "head_dim": 128,
550
  "threshold": 100
551
  },
552
- overlap_step=5):
 
 
553
  defaults = dict(
554
  lr=lr,
555
  weight_decay=weight_decay,
@@ -579,7 +634,9 @@ class Muon(torch.optim.Optimizer):
579
  self.compute_stream = torch.cuda.Stream()
580
  self.debug = debug
581
  self.clip_config = clip_config
582
- self.overlap_step = overlap_step
 
 
583
 
584
  def _calc_flops(self, G, steps):
585
  assert len(G.shape) == 2
@@ -597,6 +654,12 @@ class Muon(torch.optim.Optimizer):
597
  adjusted_lr = lr * adjusted_ratio
598
  return adjusted_lr
599
 
 
 
 
 
 
 
600
  def get_shard_mesh(self, p):
601
  """
602
  Get the shard mesh for a parameter p on the given rank.
@@ -604,26 +667,13 @@ class Muon(torch.optim.Optimizer):
604
  assert isinstance(
605
  p, DTensor), "Parallel Muon only supports DTensor parameters."
606
 
607
- if p.placements == (Shard(dim=0), ):
608
- # Case for FSDP
609
- process_group = p.device_mesh.get_group(mesh_dim=0)
610
- if self.rank is None:
611
- self.rank = dist.get_rank(group=process_group)
612
- else:
613
- assert self.rank == dist.get_rank(group=process_group)
614
- return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
615
- elif p.placements == (Replicate(), Shard(dim=0)):
616
- # Case for HSDP
617
- process_group = p.device_mesh.get_group(mesh_dim=1)
618
- if self.rank is None:
619
- self.rank = dist.get_rank(group=process_group)
620
- else:
621
- assert self.rank == dist.get_rank(group=process_group)
622
- for i, shard_mesh in enumerate(p.device_mesh.mesh):
623
- if self.rank in shard_mesh:
624
- return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
625
- else:
626
- raise ValueError(f"Unsupported placements ({p.placements}).")
627
 
628
  def init_state_and_assign_params(self, names, params, group, qk_logits):
629
  param_to_state = {}
@@ -655,23 +705,32 @@ class Muon(torch.optim.Optimizer):
655
  ordered_params = list(params_sorted)
656
 
657
  round_robin = 0
658
- mesh = None
659
- shard_mesh = None
660
- process_group = None
 
 
 
 
 
661
  for n, p in zip(ordered_names, ordered_params):
662
- if mesh is None:
663
- mesh = p.device_mesh
664
- shard_mesh, process_group = self.get_shard_mesh(p)
665
- elif mesh != p.device_mesh:
666
  raise ValueError("All parameters must be on the same mesh.")
667
- num_ranks = dist.get_world_size(group=process_group)
668
- param_to_state[id(p)] = _muon_state()
669
- param_to_state[id(
670
- p)].worker_rank = shard_mesh[round_robin].item() % num_ranks
671
- param_to_state[id(p)].process_group = process_group
672
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
673
- param_to_state[id(p)].qk_clip_state = qk_clip_state
674
- round_robin = (round_robin + 1) % len(shard_mesh)
 
 
 
 
 
 
 
675
 
676
  return param_to_state, ordered_params
677
 
@@ -705,10 +764,73 @@ class Muon(torch.optim.Optimizer):
705
 
706
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
707
 
708
- scales_full = self._compute_scales(p, qk_clip_state)
 
709
  if scales_full is not None:
710
  Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
711
 
712
  def _update_g(self, p, g, group, momentum):
713
  # calc update
714
  state = self.state[p]
@@ -727,6 +849,9 @@ class Muon(torch.optim.Optimizer):
727
  p.data.add_(u, alpha=-adjusted_lr)
728
 
729
  def get_qk_clip_info(self, n, qk_logits):
 
 
 
730
  head_dim = self.clip_config.get('head_dim')
731
  threshold = self.clip_config.get('threshold')
732
  kind, layer_idx = parse_qk_layer(n)
@@ -737,6 +862,11 @@ class Muon(torch.optim.Optimizer):
737
  indices_key = 'q_indices' if 'q' in kind else 'k_indices'
738
  indices = self.clip_config.get(indices_key, []) or []
739
 
 
 
 
 
 
740
  return QKClipInfo(
741
  kind=kind,
742
  indices=indices,
@@ -835,22 +965,28 @@ class Muon(torch.optim.Optimizer):
835
  _update_param(p, state, lr, adjusted_lr, weight_decay,
836
  self.rank, self.compute_stream)
837
 
838
- chunk_size = dist.get_world_size(param_to_state[id(
839
- params[0])].process_group)
840
 
841
  # Wait grad update
842
  self.comm_stream.wait_stream(torch.cuda.current_stream())
843
 
844
- overlap_step = self.overlap_step
845
- for i in range(0, overlap_step):
846
  enqueue_all2all_gather(i * chunk_size, chunk_size)
847
  enqueue_computes(i * chunk_size, chunk_size)
848
 
849
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
850
  enqueue_all2all_scatter(i, chunk_size)
851
- enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size)
852
  enqueue_update_param(i, chunk_size)
853
- enqueue_computes(i + overlap_step * chunk_size, chunk_size)
854
 
855
  # Wait the last update_param to finish
856
  torch.cuda.current_stream().wait_stream(self.compute_stream)
@@ -866,7 +1002,7 @@ class Muon(torch.optim.Optimizer):
866
  amsgrad: bool,
867
  beta1: float,
868
  beta2: float,
869
- lr: Union[float, torch.Tensor],
870
  weight_decay: float,
871
  eps: float,
872
  maximize: bool,
@@ -876,10 +1012,10 @@ class Muon(torch.optim.Optimizer):
876
 
877
  # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
878
  # treating it as a scalar.
879
- lr_dict: Optional[DeviceDict] = ({
880
  lr.device: lr
881
  } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
882
- None)
883
  grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
884
  [
885
  params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
@@ -926,6 +1062,159 @@ class Muon(torch.optim.Optimizer):
926
  maximize=maximize,
927
  )
928
 
929
  def step(self, closure=None, qk_logits=None):
930
  """Perform a single optimization step.
931
 
@@ -943,127 +1232,9 @@ class Muon(torch.optim.Optimizer):
943
  loss = closure()
944
 
945
  for group in self.param_groups:
946
- params = group["params"]
947
-
948
  if group["use_muon"]:
949
- ############################
950
- # Muon #
951
- ############################
952
- lr = group["lr"]
953
- weight_decay = group["weight_decay"]
954
- momentum = group["momentum"]
955
- names = group["names"]
956
-
957
- param_dtensors = []
958
- param_tensors = []
959
- name_dtensors = []
960
- name_tensors = []
961
-
962
- for n, p in zip(names, params):
963
- if p is None or p.grad is None:
964
- continue
965
- if isinstance(p.data, DTensor):
966
- if all(
967
- isinstance(placement, Replicate)
968
- for placement in p.placements):
969
- param_tensors.append(p)
970
- name_tensors.append(n)
971
- else:
972
- param_dtensors.append(p)
973
- name_dtensors.append(n)
974
- elif isinstance(p.data, torch.Tensor):
975
- param_tensors.append(p)
976
- name_tensors.append(n)
977
- else:
978
- raise TypeError(
979
- f"Unsupported parameter type: {type(p.data)}")
980
-
981
- if self.debug:
982
- print(
983
- f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
984
- flush=True,
985
- )
986
-
987
- if len(param_dtensors) > 0:
988
- if not dist.is_initialized():
989
- raise RuntimeError(
990
- "Parallel Muon requires torch.distributed to be initialized."
991
- )
992
-
993
- self.parallel(
994
- name_dtensors,
995
- param_dtensors,
996
- group,
997
- lr=lr,
998
- weight_decay=weight_decay,
999
- momentum=momentum,
1000
- qk_logits=qk_logits,
1001
- )
1002
-
1003
- if len(param_tensors) > 0:
1004
- self.base(
1005
- name_tensors,
1006
- param_tensors,
1007
- group,
1008
- lr=lr,
1009
- weight_decay=weight_decay,
1010
- momentum=momentum,
1011
- qk_logits=qk_logits,
1012
- )
1013
-
1014
  else:
1015
- ############################
1016
- # AdamW backup #
1017
- ############################
1018
-
1019
- params_with_grads = []
1020
- grads = []
1021
- moment1 = []
1022
- moment2 = []
1023
- max_exp_avg_sqs = []
1024
- state_steps = []
1025
- lr = group["lr"]
1026
- beta1, beta2 = group["adamw_betas"]
1027
- eps = group["adamw_eps"]
1028
- weight_decay = group["weight_decay"]
1029
-
1030
- for p in params:
1031
- g = p.grad
1032
- if g is None:
1033
- continue
1034
- state = self.state[p]
1035
- params_with_grads.append(p)
1036
- grads.append(g)
1037
- if "step" not in state:
1038
- state["step"] = (torch.zeros((),
1039
- dtype=torch.float32,
1040
- device=p.device))
1041
- state["moment1"] = torch.zeros_like(g)
1042
- state["moment2"] = torch.zeros_like(g)
1043
- moment1.append(state["moment1"])
1044
- moment2.append(state["moment2"])
1045
- if not isinstance(state["step"], torch.Tensor):
1046
- step_tensor = torch.tensor(state["step"],
1047
- dtype=torch.float32,
1048
- device=p.device)
1049
- else:
1050
- step_tensor = state["step"]
1051
- state_steps.append(step_tensor)
1052
-
1053
- self._fused_adamw(
1054
- params_with_grads,
1055
- grads,
1056
- moment1,
1057
- moment2,
1058
- max_exp_avg_sqs,
1059
- state_steps,
1060
- amsgrad=False,
1061
- beta1=beta1,
1062
- beta2=beta2,
1063
- lr=lr,
1064
- weight_decay=weight_decay,
1065
- eps=eps,
1066
- maximize=False,
1067
- )
1068
 
1069
  return loss
 
1
  import logging
2
  import math
3
  import types
4
+ from collections import defaultdict
5
  from dataclasses import dataclass
6
+ from typing import Any, cast
7
 
8
  import torch
9
  import torch.distributed as dist
10
+ from torch.distributed import ProcessGroup
11
+ from torch.distributed.device_mesh import DeviceMesh
12
+ from torch.distributed.tensor import DTensor, Replicate
13
+ from torch.distributed.tensor.placement_types import Placement
14
 
15
+ from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor
16
  from .matmul_transpose_triton import matmul_transpose_assign
17
 
18
  logger = logging.getLogger(__name__)
19
 
20
  COMM_DTYPE = torch.bfloat16
21
+ DEFAULT_CHUNK_SIZE_RATIO = 4
22
 
23
 
24
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
68
  @dataclass
69
  class _muon_state:
70
  # TODO: use Optional
71
+ worker_rank: int
72
+ process_group: ProcessGroup
73
+ shard_mesh: DeviceMesh
74
+ shard_placements: tuple[Placement, ...]
75
+ name: str
76
+ qk_clip_state: torch.Tensor | None = None
77
  gathered_grad: torch.Tensor | None = None
78
  scattered_u: DTensor | None = None
79
  computed_u: torch.Tensor | None = None
80
  gather_event: torch.cuda.Event | None = None
81
  compute_event: torch.cuda.Event | None = None
82
  scatter_event: torch.cuda.Event | None = None
 
 
83
 
84
 
85
+ def numel_for_rank(
86
+ param: DTensor,
87
+ local_rank: int,
88
+ state: _muon_state,
89
+ ) -> int:
90
+ slices = get_slices_of_dtensor(
91
+ param,
92
+ local_rank,
93
+ state.shard_mesh,
94
+ state.shard_placements,
95
+ )
96
+
97
+ numel = 1
98
+ for s, dim in zip(slices, param.shape):
99
+ start, stop, step = s.indices(dim)
100
+ length = max(0, (stop - start + (step - 1)) // step)
101
+ numel *= length
102
+
103
+ return numel
104
 
105
 
106
  @torch.no_grad()
 
113
  for p in params:
114
  state = param_to_state[id(p)]
115
  if rank == state.worker_rank:
116
+ state.gathered_grad = torch.empty(p.shape,
 
117
  dtype=COMM_DTYPE,
118
  device="cuda")
119
  else:
 
142
  state = param_to_state[id(p)]
143
  dst = state.worker_rank
144
  assert dst < num_ranks
145
+ shard_elems = numel_for_rank(p, rank, state)
146
  g = p.grad
147
+ g = g.to_local().to(COMM_DTYPE).contiguous()
148
  assert g.numel() == shard_elems
149
+ per_dst[dst].append(g.view(-1))
150
  send_counts[dst] += shard_elems
151
 
152
  assert any(
 
169
  for p in owned_params:
170
  state = param_to_state[id(p)]
171
  assert state.worker_rank == rank
172
+ total += numel_for_rank(p, src, state)
173
  recv_counts[src] = total
174
 
175
  recv_total = sum(recv_counts)
176
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
177
 
178
  #All2All
179
+ logger.debug(f"send_buf size: {send_buf.numel()}, "
180
+ f"recv_buf size: {recv_buf.numel()}, "
181
+ f"recv_counts: {recv_counts}, "
182
+ f"send_counts: {send_counts}, "
183
+ f"process_group: {str(process_group)}")
184
  dist.all_to_all_single(
185
  recv_buf,
186
  send_buf,
 
205
  comm_stream.wait_event(alloc_event)
206
 
207
  off = 0
 
208
  for src in range(num_ranks):
209
  if recv_counts[src] == 0:
210
  continue
 
214
  for p in owned_params:
215
  state = param_to_state[id(p)]
216
  assert state.worker_rank == rank
217
+
218
+ # get the slice of the full dtensor corresponding to rank src.
219
+ slices = get_slices_of_dtensor(state.gathered_grad, src,
220
+ state.shard_mesh,
221
+ state.shard_placements)
222
+
223
+ dst = state.gathered_grad[slices]
224
+ assert dst._base is state.gathered_grad
225
+
226
+ n = dst.numel()
227
  assert n > 0
228
 
229
  sg = recv_buf.narrow(0, off + inner_off, n)
230
+ sg = sg.reshape_as(dst)
 
231
  dst.copy_(sg)
232
 
 
233
  inner_off += n
234
  off += block
235
 
236
  for p in params:
237
  state = param_to_state[id(p)]
238
  if state.worker_rank == rank:
 
239
  state.gather_event = torch.cuda.Event()
240
  state.gather_event.record(comm_stream)
241
  else:
 
308
 
309
  assert state.computed_u is not None
310
 
311
+ u_full = state.computed_u.to(COMM_DTYPE).contiguous()
312
 
313
  offset = 0
314
  for dst in range(num_ranks):
315
+ # get the slice of the full tensor corresponding to rank dst.
316
+ slices = get_slices_of_dtensor(u_full, dst,
317
+ state.shard_mesh,
318
+ state.shard_placements)
319
+ su = u_full[slices].flatten()
320
+
321
+ n = su.numel()
322
  assert n > 0
323
 
 
324
  per_dst[dst].append(su)
325
  send_counts[dst] += n
326
  offset += n
 
349
  state = param_to_state[id(p)]
350
  if state.worker_rank != src:
351
  continue
352
+ total += numel_for_rank(p, rank, state)
353
  recv_counts[src] = total
354
 
355
  recv_total = sum(recv_counts)
 
393
  state = param_to_state[id(p)]
394
  if state.worker_rank != src:
395
  continue
396
+ n = numel_for_rank(p, rank, state)
397
  assert n > 0
398
 
399
  flat_local = recv_buf.narrow(0, off + inner_off,
 
434
  state.scattered_u = None
435
  u_dtensor = None
436
 
437
+ scales_full = Muon._compute_scales(
438
+ p,
439
+ state.qk_clip_state) if state.qk_clip_state is not None else None
440
  if scales_full is not None:
441
+ # Have to slice scales_full along dim 0
442
+ weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh,
443
+ state.shard_placements)
444
+ ratio = p.shape[0] // scales_full.shape[0]
445
+ scales_slice = slice(
446
+ None if weight_slices[0].start is None else
447
+ weight_slices[0].start // ratio,
448
+ None if weight_slices[0].stop is None else
449
+ weight_slices[0].stop // ratio,
450
+ None,
451
+ )
452
+
453
+ scales_local = scales_full[scales_slice]
454
  scales_local = DTensor.from_local(
455
  scales_local,
456
  placements=p.placements,
 
526
  @dataclass
527
  class QKClipInfo:
528
  """Per-parameter dynamic info computed from config + runtime logits."""
529
+ kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
530
+ indices: list[int] # which heads to consider for clipping
531
  head_dim: int # from config
532
  threshold: float # from config
533
+ logit: torch.Tensor | None
534
 
535
 
536
  class Muon(torch.optim.Optimizer):
 
573
  "head_dim": 128,
574
  "threshold": 100
575
  }
576
+ warmup_step : How many all2all gather and compute operations are launched in advance
577
+ before the corresponding all2all scatter steps begin.
578
+ A higher warmup_step increases memory usage but can improve
579
+ performance by overlapping communication with computation.
580
+ Parallel Muon only.
581
+ chunk_size : Number of parameters to process in each
582
+ all2all gather/compute/scatter step.
583
+ Uses (shard world size) * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
584
+ use_distributed_muon: Use Distributed Muon by Liu et al. (2024).
585
+ For testing purposes only.
586
  """
587
 
588
  def __init__(self,
 
602
  "head_dim": 128,
603
  "threshold": 100
604
  },
605
+ warmup_step=5,
606
+ chunk_size=-1,
607
+ use_distributed_muon=False):
608
  defaults = dict(
609
  lr=lr,
610
  weight_decay=weight_decay,
 
634
  self.compute_stream = torch.cuda.Stream()
635
  self.debug = debug
636
  self.clip_config = clip_config
637
+ self.warmup_step = warmup_step
638
+ self.chunk_size = chunk_size
639
+ self.use_distributed_muon = use_distributed_muon
640
 
641
  def _calc_flops(self, G, steps):
642
  assert len(G.shape) == 2
 
654
  adjusted_lr = lr * adjusted_ratio
655
  return adjusted_lr
656
 
657
+ def set_rank_once(self, rank):
658
+ if self.rank is None:
659
+ self.rank = rank
660
+ else:
661
+ assert self.rank == rank
662
+
663
  def get_shard_mesh(self, p):
664
  """
665
  Get the shard mesh for a parameter p on the given rank.
 
667
  assert isinstance(
668
  p, DTensor), "Parallel Muon only supports DTensor parameters."
669
 
670
+ shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
671
+ p.placements, p.device_mesh)
672
+
673
+ # set rank with the local rank in the shard process group
674
+ self.set_rank_once(dist.get_rank(group=shard_pg))
675
+
676
+ return shard_mesh, shard_pg, shard_placements
677
 
678
  def init_state_and_assign_params(self, names, params, group, qk_logits):
679
  param_to_state = {}
 
705
  ordered_params = list(params_sorted)
706
 
707
  round_robin = 0
708
+ mesh = ordered_params[0].device_mesh
709
+ placements = ordered_params[0].placements
710
+
711
+ shard_mesh, shard_pg, shard_placements = self.get_shard_mesh(
712
+ ordered_params[0])
713
+ shard_mesh_flattened = shard_mesh.mesh.flatten()
714
+ num_ranks = dist.get_world_size(group=shard_pg)
715
+
716
  for n, p in zip(ordered_names, ordered_params):
717
+ if mesh != p.device_mesh:
 
 
 
718
  raise ValueError("All parameters must be on the same mesh.")
719
+ if placements != p.placements:
720
+ raise ValueError("All parameters must have same placements.")
721
+
722
+ worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
723
+ round_robin = (round_robin + 1) % len(shard_mesh_flattened)
724
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
725
+
726
+ param_to_state[id(p)] = _muon_state(
727
+ worker_rank=worker_rank,
728
+ process_group=shard_pg,
729
+ shard_mesh=shard_mesh,
730
+ shard_placements=shard_placements,
731
+ name=n,
732
+ qk_clip_state=qk_clip_state,
733
+ )
734
 
735
  return param_to_state, ordered_params
736
 
 
764
 
765
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
766
 
767
+ scales_full = self._compute_scales(
768
+ p, qk_clip_state) if qk_clip_state is not None else None
769
  if scales_full is not None:
770
  Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
771
 
772
+ def distributed_muon(
773
+ self,
774
+ names: list[str],
775
+ params: list[torch.nn.Parameter],
776
+ group: dict[str, Any],
777
+ lr: float,
778
+ weight_decay: float,
779
+ momentum: float,
780
+ qk_logits: list[torch.Tensor | DTensor] | None,
781
+ ):
782
+ """ Implementation of Distributed Muon by Liu et al. """
783
+ if qk_logits is not None:
784
+ raise NotImplementedError("QK clipping is not supported yet")
785
+
786
+ if isinstance(params[0], DTensor):
787
+ shard_mesh, _, shard_placements = construct_shard_mesh(
788
+ placements=params[0].placements,
789
+ mesh=params[0].device_mesh,
790
+ )
791
+
792
+ for n, p in zip(names, params):
793
+ g = p.grad
794
+ if g is None:
795
+ continue
796
+ if g.ndim > 2:
797
+ g = g.view(g.size(0), -1)
798
+ assert g is not None
799
+
800
+ # calc update
801
+ state = self.state[p]
802
+ if "momentum_buffer" not in state:
803
+ state["momentum_buffer"] = torch.zeros_like(g)
804
+ buf = state["momentum_buffer"]
805
+ buf.mul_(momentum).add_(g)
806
+ if group["nesterov"]:
807
+ g = g.add(buf, alpha=momentum)
808
+ else:
809
+ g = buf
810
+
811
+ # Gather G
812
+ if isinstance(p.data, DTensor):
813
+ g = g.full_tensor()
814
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
815
+ steps=group["ns_steps"])
816
+
817
+ if isinstance(p.data, DTensor):
818
+ slices = get_slices_of_dtensor(
819
+ target=p,
820
+ local_rank=dist.get_rank(),
821
+ shard_mesh=shard_mesh,
822
+ shard_placements=shard_placements,
823
+ )
824
+ u_shard = u[slices]
825
+ u = DTensor.from_local(
826
+ u_shard,
827
+ device_mesh=p.device_mesh,
828
+ placements=p.placements,
829
+ )
830
+
831
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
832
+ Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
833
+
834
  def _update_g(self, p, g, group, momentum):
835
  # calc update
836
  state = self.state[p]
 
849
  p.data.add_(u, alpha=-adjusted_lr)
850
 
851
  def get_qk_clip_info(self, n, qk_logits):
852
+ if self.clip_config is None:
853
+ return None
854
+
855
  head_dim = self.clip_config.get('head_dim')
856
  threshold = self.clip_config.get('threshold')
857
  kind, layer_idx = parse_qk_layer(n)
 
862
  indices_key = 'q_indices' if 'q' in kind else 'k_indices'
863
  indices = self.clip_config.get(indices_key, []) or []
864
 
865
+ if isinstance(logit, DTensor):
866
+ # In TP settings, qk_logits may be DTensor
867
+ # We convert it to full tensor here for simplicity
868
+ logit = logit.full_tensor()
869
+
870
  return QKClipInfo(
871
  kind=kind,
872
  indices=indices,
 
965
  _update_param(p, state, lr, adjusted_lr, weight_decay,
966
  self.rank, self.compute_stream)
967
 
968
+ if self.chunk_size == -1:
969
+ shard_ranks = dist.get_world_size(param_to_state[id(
970
+ params[0])].process_group)
971
+ chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
972
+ elif self.chunk_size > 0:
973
+ chunk_size = self.chunk_size
974
+ else:
975
+ raise ValueError("chunk_size must be -1 or a positive integer.")
976
 
977
  # Wait grad update
978
  self.comm_stream.wait_stream(torch.cuda.current_stream())
979
 
980
+ warmup_step = self.warmup_step
981
+ for i in range(0, warmup_step):
982
  enqueue_all2all_gather(i * chunk_size, chunk_size)
983
  enqueue_computes(i * chunk_size, chunk_size)
984
 
985
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
986
  enqueue_all2all_scatter(i, chunk_size)
987
+ enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size)
988
  enqueue_update_param(i, chunk_size)
989
+ enqueue_computes(i + warmup_step * chunk_size, chunk_size)
990
 
991
  # Wait the last update_param to finish
992
  torch.cuda.current_stream().wait_stream(self.compute_stream)
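The warmup/chunked schedule above launches warmup_step rounds of all2all gather + compute before the first scatter, then keeps the gather/compute front warmup_step chunks ahead of the scatter/update front. A pure-Python trace of that ordering, with no CUDA streams and illustrative sizes (chunk_size would default to shard ranks * DEFAULT_CHUNK_SIZE_RATIO):

num_params, chunk_size, warmup_step = 12, 4, 2

def chunk(start):
    # parameter indices touched by one enqueue_* call; empty once start passes the end
    return list(range(start, min(start + chunk_size, num_params))) if start < num_params else []

ops = []
for i in range(warmup_step):                      # warmup: gather/compute only
    ops.append(("gather", chunk(i * chunk_size)))
    ops.append(("compute", chunk(i * chunk_size)))

for i in range(0, num_params + chunk_size - 1, chunk_size):
    ops.append(("scatter", chunk(i)))
    ops.append(("gather", chunk(i + warmup_step * chunk_size)))
    ops.append(("update_param", chunk(i)))
    ops.append(("compute", chunk(i + warmup_step * chunk_size)))

for op, idx in ops:
    print(f"{op:12s} {idx}")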
 
1002
  amsgrad: bool,
1003
  beta1: float,
1004
  beta2: float,
1005
+ lr: float | torch.Tensor,
1006
  weight_decay: float,
1007
  eps: float,
1008
  maximize: bool,
 
1012
 
1013
  # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
1014
  # treating it as a scalar.
1015
+ lr_dict: DeviceDict | None = ({
1016
  lr.device: lr
1017
  } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
1018
+ None)
1019
  grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
1020
  [
1021
  params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
 
1062
  maximize=maximize,
1063
  )
1064
 
1065
+ def _step_muon(self, group, qk_logits=None):
1066
+ params = group["params"]
1067
+ lr = group["lr"]
1068
+ weight_decay = group["weight_decay"]
1069
+ momentum = group["momentum"]
1070
+ names = group["names"]
1071
+
1072
+ param_dtensors = []
1073
+ param_tensors = []
1074
+ name_dtensors = []
1075
+ name_tensors = []
1076
+
1077
+ if self.use_distributed_muon:
1078
+ self.distributed_muon(names=names,
1079
+ params=params,
1080
+ group=group,
1081
+ lr=lr,
1082
+ weight_decay=weight_decay,
1083
+ momentum=momentum,
1084
+ qk_logits=qk_logits)
1085
+ return
1086
+
1087
+ for n, p in zip(names, params):
1088
+ if p is None or p.grad is None:
1089
+ continue
1090
+ if isinstance(p.data, DTensor):
1091
+ if all(
1092
+ isinstance(placement, Replicate)
1093
+ for placement in p.placements):
1094
+ param_tensors.append(p)
1095
+ name_tensors.append(n)
1096
+ else:
1097
+ param_dtensors.append(p)
1098
+ name_dtensors.append(n)
1099
+ elif isinstance(p.data, torch.Tensor):
1100
+ param_tensors.append(p)
1101
+ name_tensors.append(n)
1102
+ else:
1103
+ raise TypeError(f"Unsupported parameter type: {type(p.data)}")
1104
+
1105
+ logger.debug(
1106
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors"
1107
+ )
1108
+
1109
+ if len(param_dtensors) > 0:
1110
+ if not dist.is_initialized():
1111
+ raise RuntimeError(
1112
+ "Parallel Muon requires torch.distributed to be initialized."
1113
+ )
1114
+
1115
+ # To support different placements, we group parameters by placements
1116
+ # and run parallel Muon on each group.
1117
+
1118
+ placement_to_params = defaultdict(lambda: ([], []))
1119
+ # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]
1120
+
1121
+ assert len(name_dtensors) == len(param_dtensors)
1122
+ for n, p in zip(name_dtensors, param_dtensors):
1123
+ placement_to_params[tuple([p.placements,
1124
+ p.device_mesh])][0].append(n)
1125
+ placement_to_params[tuple([p.placements,
1126
+ p.device_mesh])][1].append(p)
1127
+
1128
+ for _, (names, params) in placement_to_params.items():
1129
+ self.parallel(
1130
+ names,
1131
+ params,
1132
+ group,
1133
+ lr=lr,
1134
+ weight_decay=weight_decay,
1135
+ momentum=momentum,
1136
+ qk_logits=qk_logits,
1137
+ )
1138
+
1139
+ if len(param_tensors) > 0:
1140
+ self.base(
1141
+ name_tensors,
1142
+ param_tensors,
1143
+ group,
1144
+ lr=lr,
1145
+ weight_decay=weight_decay,
1146
+ momentum=momentum,
1147
+ qk_logits=qk_logits,
1148
+ )
1149
+
1150
+ def _step_adamw_params(self, params, group):
1151
+ params_with_grads = []
1152
+ grads = []
1153
+ moment1 = []
1154
+ moment2 = []
1155
+ max_exp_avg_sqs = []
1156
+ state_steps = []
1157
+ lr = group["lr"]
1158
+ beta1, beta2 = group["adamw_betas"]
1159
+ eps = group["adamw_eps"]
1160
+ weight_decay = group["weight_decay"]
1161
+
1162
+ for p in params:
1163
+ g = p.grad
1164
+ if g is None:
1165
+ continue
1166
+ state = self.state[p]
1167
+ params_with_grads.append(p)
1168
+ grads.append(g)
1169
+ if "step" not in state:
1170
+ state["step"] = (torch.zeros((),
1171
+ dtype=torch.float32,
1172
+ device=p.device))
1173
+ state["moment1"] = torch.zeros_like(g)
1174
+ state["moment2"] = torch.zeros_like(g)
1175
+ moment1.append(state["moment1"])
1176
+ moment2.append(state["moment2"])
1177
+ if not isinstance(state["step"], torch.Tensor):
1178
+ step_tensor = torch.tensor(state["step"],
1179
+ dtype=torch.float32,
1180
+ device=p.device)
1181
+ else:
1182
+ step_tensor = state["step"]
1183
+ state_steps.append(step_tensor)
1184
+
1185
+ self._fused_adamw(
1186
+ params_with_grads,
1187
+ grads,
1188
+ moment1,
1189
+ moment2,
1190
+ max_exp_avg_sqs,
1191
+ state_steps,
1192
+ amsgrad=False,
1193
+ beta1=beta1,
1194
+ beta2=beta2,
1195
+ lr=lr,
1196
+ weight_decay=weight_decay,
1197
+ eps=eps,
1198
+ maximize=False,
1199
+ )
1200
+
1201
+ def _step_adamw(self, group):
1202
+ params = group["params"]
1203
+
1204
+ # group params by their type and placement
1205
+ placement_to_params: dict[tuple[Placement | type,
1206
+ DeviceMesh | None]] = defaultdict(list)
1207
+ for p in params:
1208
+ match p:
1209
+ case DTensor():
1210
+ placement_to_params[tuple([p.placements,
1211
+ p.device_mesh])].append(p)
1212
+ case torch.Tensor():
1213
+ placement_to_params[tuple([torch.Tensor, None])].append(p)
1214
+
1215
+ for params in placement_to_params.values():
1216
+ self._step_adamw_params(params, group)
1217
+
1218
  def step(self, closure=None, qk_logits=None):
1219
  """Perform a single optimization step.
1220
 
 
1232
  loss = closure()
1233
 
1234
  for group in self.param_groups:
 
 
1235
  if group["use_muon"]:
1236
+ self._step_muon(group, qk_logits=qk_logits)
 
 
1237
  else:
1238
+ self._step_adamw(group)
 
 
1239
 
1240
  return loss
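The refactored step path above ends with _step_muon bucketing DTensor parameters by their (placements, device_mesh) pair and running parallel Muon once per bucket, so param groups that mix shardings still work. A small sketch of that bucketing pattern, with plain tuples standing in for real Placement and DeviceMesh objects (the names below are illustrative, not from the repository):

from collections import defaultdict

# (name, placements, device_mesh) stand-ins; the real code reads these off DTensors.
params = [
    ("wq", ("Shard(0)",), "fsdp_mesh"),
    ("wk", ("Shard(0)",), "fsdp_mesh"),
    ("wo", ("Replicate()", "Shard(0)"), "hsdp_mesh"),
]

# Same shape as the optimizer's defaultdict(lambda: ([], [])): names and params stay aligned.
placement_to_params = defaultdict(lambda: ([], []))
for name, placements, mesh in params:
    key = (placements, mesh)
    placement_to_params[key][0].append(name)
    placement_to_params[key][1].append(name)   # a real DTensor would go here

for key, (names, group) in placement_to_params.items():
    print(key, names)   # each bucket would be handed to self.parallel(...)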
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_811726c_dirty
3
- ops = torch.ops._optimizer_811726c_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_811726c_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_23d68bb_dirty
3
+ ops = torch.ops._optimizer_23d68bb_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_23d68bb_dirty::{op_name}"
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:55f869cf4220f2033d4e499da522da46794a682495c2b688dbcac0ec89135cf4
3
+ size 1852240
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:72d100180fd73094f7b1c6e765eb4a77f103ad392fdee571687cb0c66d304177
3
- size 1749320
 
 
 
 
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/distributed/utils.py ADDED
@@ -0,0 +1,174 @@
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+ from torch.distributed import ProcessGroup
4
+ from torch.distributed.device_mesh import DeviceMesh
5
+ from torch.distributed.tensor import DTensor
6
+ from torch.distributed.tensor.placement_types import (Placement, Shard,
7
+ _StridedShard)
8
+
9
+
10
+ def get_slices_of_dtensor(
11
+ target: DTensor | torch.Tensor,
12
+ local_rank: int,
13
+ shard_mesh: DeviceMesh,
14
+ shard_placements: tuple[Placement],
15
+ ) -> tuple[slice]:
16
+ """
17
+ Get the slices of the local shard for a given rank of the target tensor.
18
+ Args:
19
+ target (DTensor | torch.Tensor): The target tensor.
20
+ local_rank (int): The local rank within the shard group.
21
+ shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks.
22
+ shard_placements (tuple[Placement]): The shard placements.
23
+ """
24
+
25
+ slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()]
26
+
27
+ # find the global rank of the local rank in the shard mesh
28
+ rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
29
+
30
+ rank_coords = (shard_mesh.mesh == rank).nonzero()
31
+
32
+ assert len(rank_coords) == 1
33
+ rank_coords = tuple(rank_coords[0].tolist())
34
+
35
+ assert len(rank_coords) == len(shard_placements)
36
+
37
+ # Caution: Assuming replicate-to-shard of the shard mesh goes with
38
+ # left-to-right sharding. This is ensured by the sorting logic of
39
+ # construct_shard_mesh function.
40
+ for i, (rank_coord,
41
+ placement) in enumerate(zip(rank_coords, shard_placements)):
42
+ assert isinstance(placement, Shard)
43
+
44
+ num_ranks = shard_mesh.mesh.shape[i]
45
+
46
+ dim = placement.dim
47
+ dim_size = (slices[dim].stop - slices[dim].start)
48
+
49
+ if dim_size % num_ranks != 0:
50
+ raise NotImplementedError(
51
+ f"Dimension size {dim_size} is not divisible "
52
+ f"by number of ranks {num_ranks} for shard "
53
+ f"placement on dim {dim}.")
54
+
55
+ shard_size = dim_size // num_ranks
56
+
57
+ start = slices[dim].start + rank_coord * shard_size
58
+ end = start + shard_size
59
+
60
+ assert start < end <= slices[dim].stop
61
+
62
+ slices[dim] = slice(start, end)
63
+
64
+ return tuple(slices)
65
+
66
+
67
+ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict()
68
+
69
+
70
+ def construct_shard_mesh(
71
+ placements: tuple[Placement],
72
+ mesh: DeviceMesh,
73
+ ) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]:
74
+ """
75
+ Construct Shard Mesh and Placements for unsharding.
76
+ It removes Replicate placements and constructs a new Mesh and ProcessGroup.
77
+ """
78
+ my_rank = dist.get_rank()
79
+
80
+ assert mesh.mesh.device.type == 'cpu'
81
+
82
+ # Copy mesh to avoid modifying the original mesh
83
+ mesh = mesh.mesh.clone()
84
+
85
+ # 1. Sort placements. Replicate first, then Shard by dim ascending.
86
+
87
+ # For Shard, strided shard comes after regular shard on the same dim
88
+ # to preserve left-to-right order of replicate-to-shard.
89
+ # This is because a strided shard uses its stride to represent
90
+ # more fine-grained sharding on the same dim.
91
+ # Please check the URL below for _StridedShard.
92
+ # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366
93
+
94
+ def placement_sort_key(
95
+ placement_with_index: tuple[int, Placement]
96
+ ) -> tuple[int, float, int]: # (dim, split factor, original index)
97
+ index, placement = placement_with_index
98
+ is_replicate = placement.is_replicate()
99
+ is_shard = placement.is_shard()
100
+ is_partial = placement.is_partial()
101
+
102
+ assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}"
103
+ assert not is_partial, "Partial placement is not supported."
104
+
105
+ if is_replicate:
106
+ return (-1.0, 0, index)
107
+ elif is_shard:
108
+ if isinstance(placement, _StridedShard):
109
+ return (placement.dim, 1 / placement.split_factor, index)
110
+ return (placement.dim, 0, index)
111
+ else:
112
+ raise TypeError(f"Unknown placement type: {type(placement)}")
113
+
114
+ placements_with_index: list[tuple[int,
115
+ Placement]] = list(enumerate(placements))
116
+ placements_with_index = sorted(placements_with_index,
117
+ key=placement_sort_key)
118
+
119
+ sorted_indices, sorted_placements = zip(*placements_with_index)
120
+
121
+ # 2. Permute mesh according to sorted placements.
122
+ sorted_mesh = mesh.permute(sorted_indices)
123
+
124
+ # 3. Collect list of shard meshes by removing replicate dims
125
+ # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
126
+ # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4)
127
+ num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
128
+
129
+ # merge replicate dims
130
+ # shard_meshes becomes a list of shard meshes whose length equals the replicate degree
131
+ if num_replicates > 0:
132
+ sorted_mesh = sorted_mesh.flatten(
133
+ 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
134
+ shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
135
+ else:
136
+ shard_meshes = [sorted_mesh]
137
+ shard_placements = sorted_placements[num_replicates:]
138
+
139
+ # assume all shard placements are different
140
+ assert len(shard_placements) == len(set(shard_placements))
141
+
142
+ # 4. Construct ProcessGroups
143
+ # Caution: all groups should be created in the same order in all processes,
144
+ # even though each process only needs its own group.
145
+
146
+ # To use tensor as dict key, convert it to tuple
147
+ def tensor_to_tuple(t):
148
+ if isinstance(t, torch.Tensor):
149
+ t = t.tolist()
150
+ if isinstance(t, list):
151
+ return tuple(tensor_to_tuple(x) for x in t)
152
+ return t
153
+
154
+ my_shard_mesh_as_tuple = None
155
+ for shard_mesh in shard_meshes:
156
+ assert isinstance(shard_mesh, torch.Tensor)
157
+ shard_mesh_as_tuple = tensor_to_tuple(shard_mesh)
158
+
159
+ if (my_rank == shard_mesh).any().item():
160
+ assert my_shard_mesh_as_tuple is None
161
+ my_shard_mesh_as_tuple = shard_mesh_as_tuple
162
+
163
+ # update global cache
164
+ if shard_mesh_as_tuple not in _ranks_to_dist_cache:
165
+ shard_process_group = dist.new_group(shard_mesh.flatten().tolist())
166
+ _ranks_to_dist_cache[shard_mesh_as_tuple] = (
167
+ DeviceMesh(device_type="cuda", mesh=shard_mesh),
168
+ shard_process_group,
169
+ )
170
+
171
+ my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[
172
+ my_shard_mesh_as_tuple]
173
+
174
+ return my_shard_mesh, my_shard_process_group, shard_placements
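A CPU-only sketch of the mesh bookkeeping that construct_shard_mesh performs above, skipping the ProcessGroup creation (which requires an initialized torch.distributed). The 2 x 2 x 2 mesh and placement order are illustrative, and string tags stand in for Replicate/Shard objects:

import torch

# 8 ranks on a 3-D mesh with placements (Shard(1), Replicate, Shard(0)).
mesh = torch.arange(8).reshape(2, 2, 2)
placements = ["S(1)", "R", "S(0)"]

# Sort mesh dims: Replicate first, then Shard by tensor dim ascending,
# mirroring placement_sort_key above.
order = sorted(range(len(placements)),
               key=lambda i: -1 if placements[i] == "R" else int(placements[i][2]))
sorted_mesh = mesh.permute(*order)                  # dims now ordered (R, S(0), S(1))
sorted_placements = [placements[i] for i in order]

# Split off the replicate dim: each slice is one shard mesh of shape (2, 2).
shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
shard_placements = sorted_placements[1:]            # ("S(0)", "S(1)")
print(shard_placements, [m.tolist() for m in shard_meshes])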
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py CHANGED
@@ -1,18 +1,24 @@
1
  import logging
2
  import math
3
  import types
 
4
  from dataclasses import dataclass
5
- from typing import List, Optional, Union, cast
6
 
7
  import torch
8
  import torch.distributed as dist
9
- from torch.distributed._tensor import DTensor, Replicate, Shard
 
 
 
10
 
 
11
  from .matmul_transpose_triton import matmul_transpose_assign
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
  COMM_DTYPE = torch.bfloat16
 
16
 
17
 
18
  # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -62,23 +68,39 @@ def _zeropower_via_newtonschulz5(G, steps):
62
  @dataclass
63
  class _muon_state:
64
  # TODO: use Optional
65
- worker_rank: int | None = None
 
 
 
 
 
66
  gathered_grad: torch.Tensor | None = None
67
  scattered_u: DTensor | None = None
68
  computed_u: torch.Tensor | None = None
69
  gather_event: torch.cuda.Event | None = None
70
  compute_event: torch.cuda.Event | None = None
71
  scatter_event: torch.cuda.Event | None = None
72
- process_group = None
73
- qk_clip_state = None
74
 
75
 
76
- def split_elems_for_src(param, src_rank, num_ranks) -> int:
77
- rows = param.shape[0]
78
- cols = int(param.numel() // rows)
79
- base, rem = divmod(rows, num_ranks)
80
- my_rows = base + (1 if src_rank < rem else 0)
81
- return my_rows * cols
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
 
84
  @torch.no_grad()
@@ -91,8 +113,7 @@ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
91
  for p in params:
92
  state = param_to_state[id(p)]
93
  if rank == state.worker_rank:
94
- num_ranks = dist.get_world_size(group=state.process_group)
95
- state.gathered_grad = torch.empty(p.grad.numel(),
96
  dtype=COMM_DTYPE,
97
  device="cuda")
98
  else:
@@ -121,11 +142,11 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
121
  state = param_to_state[id(p)]
122
  dst = state.worker_rank
123
  assert dst < num_ranks
124
- shard_elems = split_elems_for_src(p, rank, num_ranks)
125
  g = p.grad
126
- g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
127
  assert g.numel() == shard_elems
128
- per_dst[dst].append(g)
129
  send_counts[dst] += shard_elems
130
 
131
  assert any(
@@ -148,13 +169,18 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
148
  for p in owned_params:
149
  state = param_to_state[id(p)]
150
  assert state.worker_rank == rank
151
- total += split_elems_for_src(p, src, num_ranks)
152
  recv_counts[src] = total
153
 
154
  recv_total = sum(recv_counts)
155
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
156
 
157
  #All2All
 
 
 
 
 
158
  dist.all_to_all_single(
159
  recv_buf,
160
  send_buf,
@@ -179,7 +205,6 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
179
  comm_stream.wait_event(alloc_event)
180
 
181
  off = 0
182
- write_offsets = {id(p): 0 for p in owned_params}
183
  for src in range(num_ranks):
184
  if recv_counts[src] == 0:
185
  continue
@@ -189,22 +214,28 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
189
  for p in owned_params:
190
  state = param_to_state[id(p)]
191
  assert state.worker_rank == rank
192
- n = split_elems_for_src(p, src, num_ranks)
 
 
 
 
 
 
 
 
 
193
  assert n > 0
194
 
195
  sg = recv_buf.narrow(0, off + inner_off, n)
196
- woff = write_offsets[id(p)]
197
- dst = state.gathered_grad.narrow(0, woff, n)
198
  dst.copy_(sg)
199
 
200
- write_offsets[id(p)] += n
201
  inner_off += n
202
  off += block
203
 
204
  for p in params:
205
  state = param_to_state[id(p)]
206
  if state.worker_rank == rank:
207
- state.gathered_grad = state.gathered_grad.view_as(p)
208
  state.gather_event = torch.cuda.Event()
209
  state.gather_event.record(comm_stream)
210
  else:
@@ -277,14 +308,19 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
277
 
278
  assert state.computed_u is not None
279
 
280
- u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)
281
 
282
  offset = 0
283
  for dst in range(num_ranks):
284
- n = split_elems_for_src(p, dst, num_ranks)
 
 
 
 
 
 
285
  assert n > 0
286
 
287
- su = u_full.narrow(0, offset, n)
288
  per_dst[dst].append(su)
289
  send_counts[dst] += n
290
  offset += n
@@ -313,7 +349,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
313
  state = param_to_state[id(p)]
314
  if state.worker_rank != src:
315
  continue
316
- total += split_elems_for_src(p, rank, num_ranks)
317
  recv_counts[src] = total
318
 
319
  recv_total = sum(recv_counts)
@@ -357,7 +393,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
357
  state = param_to_state[id(p)]
358
  if state.worker_rank != src:
359
  continue
360
- n = split_elems_for_src(p, rank, num_ranks)
361
  assert n > 0
362
 
363
  flat_local = recv_buf.narrow(0, off + inner_off,
@@ -398,11 +434,23 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
398
  state.scattered_u = None
399
  u_dtensor = None
400
 
401
- scales_full = Muon._compute_scales(p, state.qk_clip_state)
 
 
402
  if scales_full is not None:
403
- num_ranks = dist.get_world_size(group=state.process_group)
404
- local_rank = dist.get_rank(group=state.process_group)
405
- scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank]
 
 
 
 
 
 
 
 
 
 
406
  scales_local = DTensor.from_local(
407
  scales_local,
408
  placements=p.placements,
@@ -478,11 +526,11 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
478
  @dataclass
479
  class QKClipInfo:
480
  """Per-parameter dynamic info computed from config + runtime logits."""
481
- kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None
482
- indices: List[int] # which heads to consider for clipping
483
  head_dim: int # from config
484
  threshold: float # from config
485
- logit: Optional[torch.Tensor]
486
 
487
 
488
  class Muon(torch.optim.Optimizer):
@@ -525,11 +573,16 @@ class Muon(torch.optim.Optimizer):
525
  "head_dim": 128,
526
  "threshold": 100
527
  }
528
- overlap_step : How many all2all gather, compute operations are launched in advance
529
- before the corresponding all2all scatter steps begin.
530
- A higher overlap_step increases memory usage but can improve
531
- performance by overlapping communication.
532
- Parallel muon only.
 
 
 
 
 
533
  """
534
 
535
  def __init__(self,
@@ -549,7 +602,9 @@ class Muon(torch.optim.Optimizer):
549
  "head_dim": 128,
550
  "threshold": 100
551
  },
552
- overlap_step=5):
 
 
553
  defaults = dict(
554
  lr=lr,
555
  weight_decay=weight_decay,
@@ -579,7 +634,9 @@ class Muon(torch.optim.Optimizer):
579
  self.compute_stream = torch.cuda.Stream()
580
  self.debug = debug
581
  self.clip_config = clip_config
582
- self.overlap_step = overlap_step
 
 
583
 
584
  def _calc_flops(self, G, steps):
585
  assert len(G.shape) == 2
@@ -597,6 +654,12 @@ class Muon(torch.optim.Optimizer):
597
  adjusted_lr = lr * adjusted_ratio
598
  return adjusted_lr
599
 
 
 
 
 
 
 
600
  def get_shard_mesh(self, p):
601
  """
602
  Get the shard mesh for a parameter p on the given rank.
@@ -604,26 +667,13 @@ class Muon(torch.optim.Optimizer):
604
  assert isinstance(
605
  p, DTensor), "Parallel Muon only supports DTensor parameters."
606
 
607
- if p.placements == (Shard(dim=0), ):
608
- # Case for FSDP
609
- process_group = p.device_mesh.get_group(mesh_dim=0)
610
- if self.rank is None:
611
- self.rank = dist.get_rank(group=process_group)
612
- else:
613
- assert self.rank == dist.get_rank(group=process_group)
614
- return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
615
- elif p.placements == (Replicate(), Shard(dim=0)):
616
- # Case for HSDP
617
- process_group = p.device_mesh.get_group(mesh_dim=1)
618
- if self.rank is None:
619
- self.rank = dist.get_rank(group=process_group)
620
- else:
621
- assert self.rank == dist.get_rank(group=process_group)
622
- for i, shard_mesh in enumerate(p.device_mesh.mesh):
623
- if self.rank in shard_mesh:
624
- return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
625
- else:
626
- raise ValueError(f"Unsupported placements ({p.placements}).")
627
 
628
  def init_state_and_assign_params(self, names, params, group, qk_logits):
629
  param_to_state = {}
@@ -655,23 +705,32 @@ class Muon(torch.optim.Optimizer):
655
  ordered_params = list(params_sorted)
656
 
657
  round_robin = 0
658
- mesh = None
659
- shard_mesh = None
660
- process_group = None
 
 
 
 
 
661
  for n, p in zip(ordered_names, ordered_params):
662
- if mesh is None:
663
- mesh = p.device_mesh
664
- shard_mesh, process_group = self.get_shard_mesh(p)
665
- elif mesh != p.device_mesh:
666
  raise ValueError("All parameters must be on the same mesh.")
667
- num_ranks = dist.get_world_size(group=process_group)
668
- param_to_state[id(p)] = _muon_state()
669
- param_to_state[id(
670
- p)].worker_rank = shard_mesh[round_robin].item() % num_ranks
671
- param_to_state[id(p)].process_group = process_group
672
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
673
- param_to_state[id(p)].qk_clip_state = qk_clip_state
674
- round_robin = (round_robin + 1) % len(shard_mesh)
 
 
 
 
 
 
 
675
 
676
  return param_to_state, ordered_params
677
 
@@ -705,10 +764,73 @@ class Muon(torch.optim.Optimizer):
705
 
706
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
707
 
708
- scales_full = self._compute_scales(p, qk_clip_state)
 
709
  if scales_full is not None:
710
  Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
711
 
 
712
  def _update_g(self, p, g, group, momentum):
713
  # calc update
714
  state = self.state[p]
@@ -727,6 +849,9 @@ class Muon(torch.optim.Optimizer):
727
  p.data.add_(u, alpha=-adjusted_lr)
728
 
729
  def get_qk_clip_info(self, n, qk_logits):
 
 
 
730
  head_dim = self.clip_config.get('head_dim')
731
  threshold = self.clip_config.get('threshold')
732
  kind, layer_idx = parse_qk_layer(n)
@@ -737,6 +862,11 @@ class Muon(torch.optim.Optimizer):
737
  indices_key = 'q_indices' if 'q' in kind else 'k_indices'
738
  indices = self.clip_config.get(indices_key, []) or []
739
 
 
 
 
 
 
740
  return QKClipInfo(
741
  kind=kind,
742
  indices=indices,
@@ -835,22 +965,28 @@ class Muon(torch.optim.Optimizer):
835
  _update_param(p, state, lr, adjusted_lr, weight_decay,
836
  self.rank, self.compute_stream)
837
 
838
- chunk_size = dist.get_world_size(param_to_state[id(
839
- params[0])].process_group)
 
 
 
 
 
 
840
 
841
  # Wait grad update
842
  self.comm_stream.wait_stream(torch.cuda.current_stream())
843
 
844
- overlap_step = self.overlap_step
845
- for i in range(0, overlap_step):
846
  enqueue_all2all_gather(i * chunk_size, chunk_size)
847
  enqueue_computes(i * chunk_size, chunk_size)
848
 
849
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
850
  enqueue_all2all_scatter(i, chunk_size)
851
- enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size)
852
  enqueue_update_param(i, chunk_size)
853
- enqueue_computes(i + overlap_step * chunk_size, chunk_size)
854
 
855
  # Wait the last update_param to finish
856
  torch.cuda.current_stream().wait_stream(self.compute_stream)
@@ -866,7 +1002,7 @@ class Muon(torch.optim.Optimizer):
866
  amsgrad: bool,
867
  beta1: float,
868
  beta2: float,
869
- lr: Union[float, torch.Tensor],
870
  weight_decay: float,
871
  eps: float,
872
  maximize: bool,
@@ -876,10 +1012,10 @@ class Muon(torch.optim.Optimizer):
876
 
877
  # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
878
  # treating it as a scalar.
879
- lr_dict: Optional[DeviceDict] = ({
880
  lr.device: lr
881
  } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
882
- None)
883
  grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
884
  [
885
  params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
@@ -926,6 +1062,159 @@ class Muon(torch.optim.Optimizer):
926
  maximize=maximize,
927
  )
928
 
 
 
929
  def step(self, closure=None, qk_logits=None):
930
  """Perform a single optimization step.
931
 
@@ -943,127 +1232,9 @@ class Muon(torch.optim.Optimizer):
943
  loss = closure()
944
 
945
  for group in self.param_groups:
946
- params = group["params"]
947
-
948
  if group["use_muon"]:
949
- ############################
950
- # Muon #
951
- ############################
952
- lr = group["lr"]
953
- weight_decay = group["weight_decay"]
954
- momentum = group["momentum"]
955
- names = group["names"]
956
-
957
- param_dtensors = []
958
- param_tensors = []
959
- name_dtensors = []
960
- name_tensors = []
961
-
962
- for n, p in zip(names, params):
963
- if p is None or p.grad is None:
964
- continue
965
- if isinstance(p.data, DTensor):
966
- if all(
967
- isinstance(placement, Replicate)
968
- for placement in p.placements):
969
- param_tensors.append(p)
970
- name_tensors.append(n)
971
- else:
972
- param_dtensors.append(p)
973
- name_dtensors.append(n)
974
- elif isinstance(p.data, torch.Tensor):
975
- param_tensors.append(p)
976
- name_tensors.append(n)
977
- else:
978
- raise TypeError(
979
- f"Unsupported parameter type: {type(p.data)}")
980
-
981
- if self.debug:
982
- print(
983
- f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
984
- flush=True,
985
- )
986
-
987
- if len(param_dtensors) > 0:
988
- if not dist.is_initialized():
989
- raise RuntimeError(
990
- "Parallel Muon requires torch.distributed to be initialized."
991
- )
992
-
993
- self.parallel(
994
- name_dtensors,
995
- param_dtensors,
996
- group,
997
- lr=lr,
998
- weight_decay=weight_decay,
999
- momentum=momentum,
1000
- qk_logits=qk_logits,
1001
- )
1002
-
1003
- if len(param_tensors) > 0:
1004
- self.base(
1005
- name_tensors,
1006
- param_tensors,
1007
- group,
1008
- lr=lr,
1009
- weight_decay=weight_decay,
1010
- momentum=momentum,
1011
- qk_logits=qk_logits,
1012
- )
1013
-
1014
  else:
1015
- ############################
1016
- # AdamW backup #
1017
- ############################
1018
-
1019
- params_with_grads = []
1020
- grads = []
1021
- moment1 = []
1022
- moment2 = []
1023
- max_exp_avg_sqs = []
1024
- state_steps = []
1025
- lr = group["lr"]
1026
- beta1, beta2 = group["adamw_betas"]
1027
- eps = group["adamw_eps"]
1028
- weight_decay = group["weight_decay"]
1029
-
1030
- for p in params:
1031
- g = p.grad
1032
- if g is None:
1033
- continue
1034
- state = self.state[p]
1035
- params_with_grads.append(p)
1036
- grads.append(g)
1037
- if "step" not in state:
1038
- state["step"] = (torch.zeros((),
1039
- dtype=torch.float32,
1040
- device=p.device))
1041
- state["moment1"] = torch.zeros_like(g)
1042
- state["moment2"] = torch.zeros_like(g)
1043
- moment1.append(state["moment1"])
1044
- moment2.append(state["moment2"])
1045
- if not isinstance(state["step"], torch.Tensor):
1046
- step_tensor = torch.tensor(state["step"],
1047
- dtype=torch.float32,
1048
- device=p.device)
1049
- else:
1050
- step_tensor = state["step"]
1051
- state_steps.append(step_tensor)
1052
-
1053
- self._fused_adamw(
1054
- params_with_grads,
1055
- grads,
1056
- moment1,
1057
- moment2,
1058
- max_exp_avg_sqs,
1059
- state_steps,
1060
- amsgrad=False,
1061
- beta1=beta1,
1062
- beta2=beta2,
1063
- lr=lr,
1064
- weight_decay=weight_decay,
1065
- eps=eps,
1066
- maximize=False,
1067
- )
1068
 
1069
  return loss
 
1
  import logging
2
  import math
3
  import types
4
+ from collections import defaultdict
5
  from dataclasses import dataclass
6
+ from typing import Any, cast
7
 
8
  import torch
9
  import torch.distributed as dist
10
+ from torch.distributed import ProcessGroup
11
+ from torch.distributed.device_mesh import DeviceMesh
12
+ from torch.distributed.tensor import DTensor, Replicate
13
+ from torch.distributed.tensor.placement_types import Placement
14
 
15
+ from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor
16
  from .matmul_transpose_triton import matmul_transpose_assign
17
 
18
  logger = logging.getLogger(__name__)
19
 
20
  COMM_DTYPE = torch.bfloat16
21
+ DEFAULT_CHUNK_SIZE_RATIO = 4
22
 
23
 
24
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
68
  @dataclass
69
  class _muon_state:
70
  # TODO: use Optional
71
+ worker_rank: int
72
+ process_group: ProcessGroup
73
+ shard_mesh: DeviceMesh
74
+ shard_placements: tuple[Placement, ...]
75
+ name: str
76
+ qk_clip_state: torch.Tensor | None = None
77
  gathered_grad: torch.Tensor | None = None
78
  scattered_u: DTensor | None = None
79
  computed_u: torch.Tensor | None = None
80
  gather_event: torch.cuda.Event | None = None
81
  compute_event: torch.cuda.Event | None = None
82
  scatter_event: torch.cuda.Event | None = None
 
 
83
 
84
 
85
+ def numel_for_rank(
86
+ param: DTensor,
87
+ local_rank: int,
88
+ state: _muon_state,
89
+ ) -> int:
90
+ slices = get_slices_of_dtensor(
91
+ param,
92
+ local_rank,
93
+ state.shard_mesh,
94
+ state.shard_placements,
95
+ )
96
+
97
+ numel = 1
98
+ for s, dim in zip(slices, param.shape):
99
+ start, stop, step = s.indices(dim)
100
+ length = max(0, (stop - start + (step - 1)) // step)
101
+ numel *= length
102
+
103
+ return numel
104
 
105
 
106
  @torch.no_grad()
 
113
  for p in params:
114
  state = param_to_state[id(p)]
115
  if rank == state.worker_rank:
116
+ state.gathered_grad = torch.empty(p.shape,
 
117
  dtype=COMM_DTYPE,
118
  device="cuda")
119
  else:
 
142
  state = param_to_state[id(p)]
143
  dst = state.worker_rank
144
  assert dst < num_ranks
145
+ shard_elems = numel_for_rank(p, rank, state)
146
  g = p.grad
147
+ g = g.to_local().to(COMM_DTYPE).contiguous()
148
  assert g.numel() == shard_elems
149
+ per_dst[dst].append(g.view(-1))
150
  send_counts[dst] += shard_elems
151
 
152
  assert any(
 
169
  for p in owned_params:
170
  state = param_to_state[id(p)]
171
  assert state.worker_rank == rank
172
+ total += numel_for_rank(p, src, state)
173
  recv_counts[src] = total
174
 
175
  recv_total = sum(recv_counts)
176
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
177
 
178
  #All2All
179
+ logger.debug(f"send_buf size: {send_buf.numel()}, "
180
+ f"recv_buf size: {recv_buf.numel()}, "
181
+ f"recv_counts: {recv_counts}, "
182
+ f"send_counts: {send_counts}, "
183
+ f"process_group: {str(process_group)}")
184
  dist.all_to_all_single(
185
  recv_buf,
186
  send_buf,
 
205
  comm_stream.wait_event(alloc_event)
206
 
207
  off = 0
 
208
  for src in range(num_ranks):
209
  if recv_counts[src] == 0:
210
  continue
 
214
  for p in owned_params:
215
  state = param_to_state[id(p)]
216
  assert state.worker_rank == rank
217
+
218
+ # get the slice of the full dtensor corresponding to rank src.
219
+ slices = get_slices_of_dtensor(state.gathered_grad, src,
220
+ state.shard_mesh,
221
+ state.shard_placements)
222
+
223
+ dst = state.gathered_grad[slices]
224
+ assert dst._base is state.gathered_grad
225
+
226
+ n = dst.numel()
227
  assert n > 0
228
 
229
  sg = recv_buf.narrow(0, off + inner_off, n)
230
+ sg = sg.reshape_as(dst)
 
231
  dst.copy_(sg)
232
 
 
233
  inner_off += n
234
  off += block
235
 
236
  for p in params:
237
  state = param_to_state[id(p)]
238
  if state.worker_rank == rank:
 
239
  state.gather_event = torch.cuda.Event()
240
  state.gather_event.record(comm_stream)
241
  else:
 
308
 
309
  assert state.computed_u is not None
310
 
311
+ u_full = state.computed_u.to(COMM_DTYPE).contiguous()
312
 
313
  offset = 0
314
  for dst in range(num_ranks):
315
+ # get the slice of the full tensor corresponding to rank dst.
316
+ slices = get_slices_of_dtensor(u_full, dst,
317
+ state.shard_mesh,
318
+ state.shard_placements)
319
+ su = u_full[slices].flatten()
320
+
321
+ n = su.numel()
322
  assert n > 0
323
 
 
324
  per_dst[dst].append(su)
325
  send_counts[dst] += n
326
  offset += n
 
349
  state = param_to_state[id(p)]
350
  if state.worker_rank != src:
351
  continue
352
+ total += numel_for_rank(p, rank, state)
353
  recv_counts[src] = total
354
 
355
  recv_total = sum(recv_counts)
 
393
  state = param_to_state[id(p)]
394
  if state.worker_rank != src:
395
  continue
396
+ n = numel_for_rank(p, rank, state)
397
  assert n > 0
398
 
399
  flat_local = recv_buf.narrow(0, off + inner_off,
 
434
  state.scattered_u = None
435
  u_dtensor = None
436
 
437
+ scales_full = Muon._compute_scales(
438
+ p,
439
+ state.qk_clip_state) if state.qk_clip_state is not None else None
440
  if scales_full is not None:
441
+ # Slice scales_full along dim 0 to match this rank's shard
442
+ weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh,
443
+ state.shard_placements)
444
+ ratio = p.shape[0] // scales_full.shape[0]
445
+ scales_slice = slice(
446
+ None if weight_slices[0].start is None else
447
+ weight_slices[0].start // ratio,
448
+ None if weight_slices[0].stop is None else
449
+ weight_slices[0].stop // ratio,
450
+ None,
451
+ )
452
+
453
+ scales_local = scales_full[scales_slice]
454
  scales_local = DTensor.from_local(
455
  scales_local,
456
  placements=p.placements,
 
526
  @dataclass
527
  class QKClipInfo:
528
  """Per-parameter dynamic info computed from config + runtime logits."""
529
+ kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
530
+ indices: list[int] # which heads to consider for clipping
531
  head_dim: int # from config
532
  threshold: float # from config
533
+ logit: torch.Tensor | None
534
 
535
 
536
  class Muon(torch.optim.Optimizer):
 
573
  "head_dim": 128,
574
  "threshold": 100
575
  }
576
+ warmup_step : How many all2all gather, compute operations are launched in advance
577
+ before the corresponding all2all scatter steps begin.
578
+ A higher warmup_step increases memory usage but can improve
579
+ performance by overlapping communication.
580
+ Parallel muon only.
581
+ chunk_size : Batch size of parameters to process in each
582
+ all2all gather/compute/scatter step.
583
+ Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
584
+ use_distributed_muon: Use distributed muon by Liu et al. (2024).
585
+ For testing purpose only.
586
  """
587
 
588
  def __init__(self,
 
602
  "head_dim": 128,
603
  "threshold": 100
604
  },
605
+ warmup_step=5,
606
+ chunk_size=-1,
607
+ use_distributed_muon=False):
608
  defaults = dict(
609
  lr=lr,
610
  weight_decay=weight_decay,
 
634
  self.compute_stream = torch.cuda.Stream()
635
  self.debug = debug
636
  self.clip_config = clip_config
637
+ self.warmup_step = warmup_step
638
+ self.chunk_size = chunk_size
639
+ self.use_distributed_muon = use_distributed_muon
640
 
641
  def _calc_flops(self, G, steps):
642
  assert len(G.shape) == 2
 
654
  adjusted_lr = lr * adjusted_ratio
655
  return adjusted_lr
656
 
657
+ def set_rank_once(self, rank):
658
+ if self.rank is None:
659
+ self.rank = rank
660
+ else:
661
+ assert self.rank == rank
662
+
663
  def get_shard_mesh(self, p):
664
  """
665
  Get the shard mesh for a parameter p on the given rank.
 
667
  assert isinstance(
668
  p, DTensor), "Parallel Muon only supports DTensor parameters."
669
 
670
+ shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
671
+ p.placements, p.device_mesh)
672
+
673
+ # set rank with the local rank in the shard process group
674
+ self.set_rank_once(dist.get_rank(group=shard_pg))
675
+
676
+ return shard_mesh, shard_pg, shard_placements
 
 
 
 
 
 
 
 
 
 
 
 
 
677
 
678
  def init_state_and_assign_params(self, names, params, group, qk_logits):
679
  param_to_state = {}
 
705
  ordered_params = list(params_sorted)
706
 
707
  round_robin = 0
708
+ mesh = ordered_params[0].device_mesh
709
+ placements = ordered_params[0].placements
710
+
711
+ shard_mesh, shard_pg, shard_placements = self.get_shard_mesh(
712
+ ordered_params[0])
713
+ shard_mesh_flattened = shard_mesh.mesh.flatten()
714
+ num_ranks = dist.get_world_size(group=shard_pg)
715
+
716
  for n, p in zip(ordered_names, ordered_params):
717
+ if mesh != p.device_mesh:
 
 
 
718
  raise ValueError("All parameters must be on the same mesh.")
719
+ if placements != p.placements:
720
+ raise ValueError("All parameters must have same placements.")
721
+
722
+ worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
723
+ round_robin = (round_robin + 1) % len(shard_mesh_flattened)
724
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
725
+
726
+ param_to_state[id(p)] = _muon_state(
727
+ worker_rank=worker_rank,
728
+ process_group=shard_pg,
729
+ shard_mesh=shard_mesh,
730
+ shard_placements=shard_placements,
731
+ name=n,
732
+ qk_clip_state=qk_clip_state,
733
+ )
734
 
735
  return param_to_state, ordered_params
736
 
 
764
 
765
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
766
 
767
+ scales_full = self._compute_scales(
768
+ p, qk_clip_state) if qk_clip_state is not None else None
769
  if scales_full is not None:
770
  Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
771
 
772
+ def distributed_muon(
773
+ self,
774
+ names: list[str],
775
+ params: list[torch.nn.Parameter],
776
+ group: dict[str, Any],
777
+ lr: float,
778
+ weight_decay: float,
779
+ momentum: float,
780
+ qk_logits: list[torch.Tensor | DTensor] | None,
781
+ ):
782
+ """ Implementation of Distributed Muon by Liu et al. """
783
+ if qk_logits is not None:
784
+ raise NotImplementedError("QK clipping is not supported yet")
785
+
786
+ if isinstance(params[0], DTensor):
787
+ shard_mesh, _, shard_placements = construct_shard_mesh(
788
+ placements=params[0].placements,
789
+ mesh=params[0].device_mesh,
790
+ )
791
+
792
+ for n, p in zip(names, params):
793
+ g = p.grad
794
+ if g is None:
795
+ continue
796
+ if g.ndim > 2:
797
+ g = g.view(g.size(0), -1)
798
+ assert g is not None
799
+
800
+ # calc update
801
+ state = self.state[p]
802
+ if "momentum_buffer" not in state:
803
+ state["momentum_buffer"] = torch.zeros_like(g)
804
+ buf = state["momentum_buffer"]
805
+ buf.mul_(momentum).add_(g)
806
+ if group["nesterov"]:
807
+ g = g.add(buf, alpha=momentum)
808
+ else:
809
+ g = buf
810
+
811
+ # Gather G
812
+ if isinstance(p.data, DTensor):
813
+ g = g.full_tensor()
814
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
815
+ steps=group["ns_steps"])
816
+
817
+ if isinstance(p.data, DTensor):
818
+ slices = get_slices_of_dtensor(
819
+ target=p,
820
+ local_rank=dist.get_rank(),
821
+ shard_mesh=shard_mesh,
822
+ shard_placements=shard_placements,
823
+ )
824
+ u_shard = u[slices]
825
+ u = DTensor.from_local(
826
+ u_shard,
827
+ device_mesh=p.device_mesh,
828
+ placements=p.placements,
829
+ )
830
+
831
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
832
+ Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
833
+
834
  def _update_g(self, p, g, group, momentum):
835
  # calc update
836
  state = self.state[p]
 
849
  p.data.add_(u, alpha=-adjusted_lr)
850
 
851
  def get_qk_clip_info(self, n, qk_logits):
852
+ if self.clip_config is None:
853
+ return None
854
+
855
  head_dim = self.clip_config.get('head_dim')
856
  threshold = self.clip_config.get('threshold')
857
  kind, layer_idx = parse_qk_layer(n)
 
862
  indices_key = 'q_indices' if 'q' in kind else 'k_indices'
863
  indices = self.clip_config.get(indices_key, []) or []
864
 
865
+ if isinstance(logit, DTensor):
866
+ # In TP settings, qk_logits may be DTensor
867
+ # We convert it to full tensor here for simplicity
868
+ logit = logit.full_tensor()
869
+
870
  return QKClipInfo(
871
  kind=kind,
872
  indices=indices,
 
965
  _update_param(p, state, lr, adjusted_lr, weight_decay,
966
  self.rank, self.compute_stream)
967
 
968
+ if self.chunk_size == -1:
969
+ shard_ranks = dist.get_world_size(param_to_state[id(
970
+ params[0])].process_group)
971
+ chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
972
+ elif self.chunk_size > 0:
973
+ chunk_size = self.chunk_size
974
+ else:
975
+ raise ValueError("chunk_size must be -1 or a positive integer.")
976
 
977
  # Wait grad update
978
  self.comm_stream.wait_stream(torch.cuda.current_stream())
979
 
980
+ warmup_step = self.warmup_step
981
+ for i in range(0, warmup_step):
982
  enqueue_all2all_gather(i * chunk_size, chunk_size)
983
  enqueue_computes(i * chunk_size, chunk_size)
984
 
985
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
986
  enqueue_all2all_scatter(i, chunk_size)
987
+ enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size)
988
  enqueue_update_param(i, chunk_size)
989
+ enqueue_computes(i + warmup_step * chunk_size, chunk_size)
990
 
991
  # Wait the last update_param to finish
992
  torch.cuda.current_stream().wait_stream(self.compute_stream)
 
1002
  amsgrad: bool,
1003
  beta1: float,
1004
  beta2: float,
1005
+ lr: float | torch.Tensor,
1006
  weight_decay: float,
1007
  eps: float,
1008
  maximize: bool,
 
1012
 
1013
  # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
1014
  # treating it as a scalar.
1015
+ lr_dict: DeviceDict | None = ({
1016
  lr.device: lr
1017
  } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
1018
+ None)
1019
  grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
1020
  [
1021
  params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
 
1062
  maximize=maximize,
1063
  )
1064
 
1065
+ def _step_muon(self, group, qk_logits=None):
1066
+ params = group["params"]
1067
+ lr = group["lr"]
1068
+ weight_decay = group["weight_decay"]
1069
+ momentum = group["momentum"]
1070
+ names = group["names"]
1071
+
1072
+ param_dtensors = []
1073
+ param_tensors = []
1074
+ name_dtensors = []
1075
+ name_tensors = []
1076
+
1077
+ if self.use_distributed_muon:
1078
+ self.distributed_muon(names=names,
1079
+ params=params,
1080
+ group=group,
1081
+ lr=lr,
1082
+ weight_decay=weight_decay,
1083
+ momentum=momentum,
1084
+ qk_logits=qk_logits)
1085
+ return
1086
+
1087
+ for n, p in zip(names, params):
1088
+ if p is None or p.grad is None:
1089
+ continue
1090
+ if isinstance(p.data, DTensor):
1091
+ if all(
1092
+ isinstance(placement, Replicate)
1093
+ for placement in p.placements):
1094
+ param_tensors.append(p)
1095
+ name_tensors.append(n)
1096
+ else:
1097
+ param_dtensors.append(p)
1098
+ name_dtensors.append(n)
1099
+ elif isinstance(p.data, torch.Tensor):
1100
+ param_tensors.append(p)
1101
+ name_tensors.append(n)
1102
+ else:
1103
+ raise TypeError(f"Unsupported parameter type: {type(p.data)}")
1104
+
1105
+ logger.debug(
1106
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors"
1107
+ )
1108
+
1109
+ if len(param_dtensors) > 0:
1110
+ if not dist.is_initialized():
1111
+ raise RuntimeError(
1112
+ "Parallel Muon requires torch.distributed to be initialized."
1113
+ )
1114
+
1115
+ # To support different placements, we group parameters by placements
1116
+ # and run parallel Muon on each group.
1117
+
1118
+ placement_to_params = defaultdict(lambda: ([], []))
1119
+ # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]
1120
+
1121
+ assert len(name_dtensors) == len(param_dtensors)
1122
+ for n, p in zip(name_dtensors, param_dtensors):
1123
+ placement_to_params[tuple([p.placements,
1124
+ p.device_mesh])][0].append(n)
1125
+ placement_to_params[tuple([p.placements,
1126
+ p.device_mesh])][1].append(p)
1127
+
1128
+ for _, (names, params) in placement_to_params.items():
1129
+ self.parallel(
1130
+ names,
1131
+ params,
1132
+ group,
1133
+ lr=lr,
1134
+ weight_decay=weight_decay,
1135
+ momentum=momentum,
1136
+ qk_logits=qk_logits,
1137
+ )
1138
+
1139
+ if len(param_tensors) > 0:
1140
+ self.base(
1141
+ name_tensors,
1142
+ param_tensors,
1143
+ group,
1144
+ lr=lr,
1145
+ weight_decay=weight_decay,
1146
+ momentum=momentum,
1147
+ qk_logits=qk_logits,
1148
+ )
1149
+
1150
+ def _step_adamw_params(self, params, group):
1151
+ params_with_grads = []
1152
+ grads = []
1153
+ moment1 = []
1154
+ moment2 = []
1155
+ max_exp_avg_sqs = []
1156
+ state_steps = []
1157
+ lr = group["lr"]
1158
+ beta1, beta2 = group["adamw_betas"]
1159
+ eps = group["adamw_eps"]
1160
+ weight_decay = group["weight_decay"]
1161
+
1162
+ for p in params:
1163
+ g = p.grad
1164
+ if g is None:
1165
+ continue
1166
+ state = self.state[p]
1167
+ params_with_grads.append(p)
1168
+ grads.append(g)
1169
+ if "step" not in state:
1170
+ state["step"] = (torch.zeros((),
1171
+ dtype=torch.float32,
1172
+ device=p.device))
1173
+ state["moment1"] = torch.zeros_like(g)
1174
+ state["moment2"] = torch.zeros_like(g)
1175
+ moment1.append(state["moment1"])
1176
+ moment2.append(state["moment2"])
1177
+ if not isinstance(state["step"], torch.Tensor):
1178
+ step_tensor = torch.tensor(state["step"],
1179
+ dtype=torch.float32,
1180
+ device=p.device)
1181
+ else:
1182
+ step_tensor = state["step"]
1183
+ state_steps.append(step_tensor)
1184
+
1185
+ self._fused_adamw(
1186
+ params_with_grads,
1187
+ grads,
1188
+ moment1,
1189
+ moment2,
1190
+ max_exp_avg_sqs,
1191
+ state_steps,
1192
+ amsgrad=False,
1193
+ beta1=beta1,
1194
+ beta2=beta2,
1195
+ lr=lr,
1196
+ weight_decay=weight_decay,
1197
+ eps=eps,
1198
+ maximize=False,
1199
+ )
1200
+
1201
+ def _step_adamw(self, group):
1202
+ params = group["params"]
1203
+
1204
+ # group params by their type and placement
1205
+ placement_to_params: dict[tuple[Placement | type,
1206
+ DeviceMesh | None]] = defaultdict(list)
1207
+ for p in params:
1208
+ match p:
1209
+ case DTensor():
1210
+ placement_to_params[tuple([p.placements,
1211
+ p.device_mesh])].append(p)
1212
+ case torch.Tensor():
1213
+ placement_to_params[tuple([torch.Tensor, None])].append(p)
1214
+
1215
+ for params in placement_to_params.values():
1216
+ self._step_adamw_params(params, group)
1217
+
1218
  def step(self, closure=None, qk_logits=None):
1219
  """Perform a single optimization step.
1220
 
 
1232
  loss = closure()
1233
 
1234
  for group in self.param_groups:
 
 
1235
  if group["use_muon"]:
1236
+ self._step_muon(group, qk_logits=qk_logits)
 
 
1237
  else:
1238
+ self._step_adamw(group)
 
 
1239
 
1240
  return loss
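The AdamW fallback above now buckets parameters the same way the Muon path does: plain tensors share one bucket, DTensors get one bucket per (placements, device_mesh), and _step_adamw_params runs once per bucket. A structural-pattern-matching sketch of that split, with toy dataclasses standing in for torch.Tensor and DTensor (illustrative only; match needs Python 3.10+):

from collections import defaultdict
from dataclasses import dataclass

@dataclass
class FakeTensor:                       # stands in for torch.Tensor
    name: str

@dataclass
class FakeDTensor(FakeTensor):          # stands in for DTensor (a Tensor subclass)
    placements: tuple = ()
    device_mesh: str = ""

params = [FakeTensor("bias"),
          FakeDTensor("wq", ("Shard(0)",), "mesh_a"),
          FakeDTensor("wk", ("Shard(0)",), "mesh_a")]

buckets = defaultdict(list)
for p in params:
    match p:                            # subclass case must come first, as in _step_adamw
        case FakeDTensor():
            buckets[(p.placements, p.device_mesh)].append(p)
        case FakeTensor():
            buckets[(FakeTensor, None)].append(p)

for key, group in buckets.items():
    print(key, [q.name for q in group])  # each bucket -> one _step_adamw_params call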
build/torch29-cxx11-cu126-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_811726c_dirty
3
- ops = torch.ops._optimizer_811726c_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_811726c_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_23d68bb_dirty
3
+ ops = torch.ops._optimizer_23d68bb_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_23d68bb_dirty::{op_name}"
build/torch29-cxx11-cu126-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca847c77875fc19f211a4c8ac217e9664b46c6862aa3234c270aacfea519d0f5
3
+ size 1924376
build/torch29-cxx11-cu126-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:87c8e75ead1c831dabfce1abbd7c100aa72c9b2988dfc0e1554216ca8005267c
3
- size 1816064
 
 
 
 
build/torch29-cxx11-cu126-x86_64-linux/optimizer/distributed/utils.py ADDED
@@ -0,0 +1,174 @@
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+ from torch.distributed import ProcessGroup
4
+ from torch.distributed.device_mesh import DeviceMesh
5
+ from torch.distributed.tensor import DTensor
6
+ from torch.distributed.tensor.placement_types import (Placement, Shard,
7
+ _StridedShard)
8
+
9
+
10
+ def get_slices_of_dtensor(
11
+ target: DTensor | torch.Tensor,
12
+ local_rank: int,
13
+ shard_mesh: DeviceMesh,
14
+ shard_placements: tuple[Placement],
15
+ ) -> tuple[slice]:
16
+ """
17
+ Get the slices of the local shard for a given rank of the target tensor.
18
+ Args:
19
+ target (DTensor | torch.Tensor): The target tensor.
20
+ local_rank (int): The local rank within the shard group.
21
+ shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks.
22
+ shard_placements (tuple[Placement]): The shard placements.
23
+ """
24
+
25
+ slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()]
26
+
27
+ # find the global rank of the local rank in the shard mesh
28
+ rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
29
+
30
+ rank_coords = (shard_mesh.mesh == rank).nonzero()
31
+
32
+ assert len(rank_coords) == 1
33
+ rank_coords = tuple(rank_coords[0].tolist())
34
+
35
+ assert len(rank_coords) == len(shard_placements)
36
+
37
+ # Caution: Assuming replicate-to-shard of the shard mesh goes with
38
+ # left-to-right sharding. This is ensured by the sorting logic of
39
+ # construct_shard_mesh function.
40
+ for i, (rank_coord,
41
+ placement) in enumerate(zip(rank_coords, shard_placements)):
42
+ assert isinstance(placement, Shard)
43
+
44
+ num_ranks = shard_mesh.mesh.shape[i]
45
+
46
+ dim = placement.dim
47
+ dim_size = (slices[dim].stop - slices[dim].start)
48
+
49
+ if dim_size % num_ranks != 0:
50
+ raise NotImplementedError(
51
+ f"Dimension size {dim_size} is not divisible "
52
+ f"by number of ranks {num_ranks} for shard "
53
+ f"placement on dim {dim}.")
54
+
55
+ shard_size = dim_size // num_ranks
56
+
57
+ start = slices[dim].start + rank_coord * shard_size
58
+ end = start + shard_size
59
+
60
+ assert start < end <= slices[dim].stop
61
+
62
+ slices[dim] = slice(start, end)
63
+
64
+ return tuple(slices)
65
+
66
+
67
+ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict()
68
+
69
+
70
+ def construct_shard_mesh(
71
+ placements: tuple[Placement],
72
+ mesh: DeviceMesh,
73
+ ) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]:
74
+ """
75
+ Construct Shard Mesh and Placements for unsharding.
76
+ It removes Replicate placements and constructs a new Mesh and ProcessGroup.
77
+ """
78
+ my_rank = dist.get_rank()
79
+
80
+ assert mesh.mesh.device.type == 'cpu'
81
+
82
+ # Copy mesh to avoid modifying the original mesh
83
+ mesh = mesh.mesh.clone()
84
+
85
+ # 1. Sort placements. Replicate first, then Shard by dim ascending.
86
+
87
+ # For Shard, strided shard comes after regular shard on the same dim
88
+ # to preserve left-to-right order of replicate-to-shard.
89
+ # This is because a strided shard uses its stride to represent
90
+ # more fine-grained sharding on the same dim.
91
+ # Please check the URL below for _StridedShard.
92
+ # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366
93
+
94
+ def placement_sort_key(
95
+ placement_with_index: tuple[int, Placement]
96
+ ) -> tuple[int, float, int]: # (dim, split factor, original index)
97
+ index, placement = placement_with_index
98
+ is_replicate = placement.is_replicate()
99
+ is_shard = placement.is_shard()
100
+ is_partial = placement.is_partial()
101
+
102
+ assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}"
103
+ assert not is_partial, "Partial placement is not supported."
104
+
105
+ if is_replicate:
106
+ return (-1.0, 0, index)
107
+ elif is_shard:
108
+ if isinstance(placement, _StridedShard):
109
+ return (placement.dim, 1 / placement.split_factor, index)
110
+ return (placement.dim, 0, index)
111
+ else:
112
+ raise TypeError(f"Unknown placement type: {type(placement)}")
113
+
114
+ placements_with_index: list[tuple[int,
115
+ Placement]] = list(enumerate(placements))
116
+ placements_with_index = sorted(placements_with_index,
117
+ key=placement_sort_key)
118
+
119
+ sorted_indices, sorted_placements = zip(*placements_with_index)
120
+
121
+ # 2. Permute mesh according to sorted placements.
122
+ sorted_mesh = mesh.permute(sorted_indices)
123
+
124
+ # 3. Collect list of shard meshes by removing replicate dims
125
+ # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
126
+ # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4)
127
+ num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
128
+
129
+ # merge replicate dims
130
+ # shard_meshes becomes a list of shard meshes whose length equals the replicate degree
131
+ if num_replicates > 0:
132
+ sorted_mesh = sorted_mesh.flatten(
133
+ 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
134
+ shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
135
+ else:
136
+ shard_meshes = [sorted_mesh]
137
+ shard_placements = sorted_placements[num_replicates:]
138
+
139
+ # assume all shard placements are different
140
+ assert len(shard_placements) == len(set(shard_placements))
141
+
142
+ # 4. Construct ProcessGroups
143
+ # Caution: all groups should be created in the same order in all processes,
144
+ # even though each process only needs its own group.
145
+
146
+ # To use tensor as dict key, convert it to tuple
147
+ def tensor_to_tuple(t):
148
+ if isinstance(t, torch.Tensor):
149
+ t = t.tolist()
150
+ if isinstance(t, list):
151
+ return tuple(tensor_to_tuple(x) for x in t)
152
+ return t
153
+
154
+ my_shard_mesh_as_tuple = None
155
+ for shard_mesh in shard_meshes:
156
+ assert isinstance(shard_mesh, torch.Tensor)
157
+ shard_mesh_as_tuple = tensor_to_tuple(shard_mesh)
158
+
159
+ if (my_rank == shard_mesh).any().item():
160
+ assert my_shard_mesh_as_tuple is None
161
+ my_shard_mesh_as_tuple = shard_mesh_as_tuple
162
+
163
+ # update global cache
164
+ if shard_mesh_as_tuple not in _ranks_to_dist_cache:
165
+ shard_process_group = dist.new_group(shard_mesh.flatten().tolist())
166
+ _ranks_to_dist_cache[shard_mesh_as_tuple] = (
167
+ DeviceMesh(device_type="cuda", mesh=shard_mesh),
168
+ shard_process_group,
169
+ )
170
+
171
+ my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[
172
+ my_shard_mesh_as_tuple]
173
+
174
+ return my_shard_mesh, my_shard_process_group, shard_placements
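
The comments in construct_shard_mesh above describe dropping Replicate dims and unbinding the mesh into per-replica shard meshes. A minimal standalone sketch of that bookkeeping (an invented 2x4 HSDP-style mesh, string stand-ins for the placements, and no ProcessGroup creation) would look roughly like this:

import torch

# Toy 2x4 mesh of global ranks; mesh dim 0 is the replicate dim, dim 1 the shard dim.
mesh = torch.arange(8).reshape(2, 4)
placements = ["R", "S(0)"]  # stand-ins for Replicate() and Shard(0)

num_replicates = sum(1 for p in placements if p == "R")

# Replicate dims already lead here, so no permute/flatten is needed before unbinding.
shard_meshes = list(torch.unbind(mesh, dim=0)) if num_replicates > 0 else [mesh]
print([m.tolist() for m in shard_meshes])  # [[0, 1, 2, 3], [4, 5, 6, 7]]

# Each rank keeps only the shard mesh that contains it (rank 5 chosen for illustration).
my_rank = 5
my_shard_mesh = next(m for m in shard_meshes if (m == my_rank).any())
print(my_shard_mesh.tolist())  # [4, 5, 6, 7]
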
build/torch29-cxx11-cu126-x86_64-linux/optimizer/muon.py CHANGED
@@ -1,18 +1,24 @@
1
  import logging
2
  import math
3
  import types
 
4
  from dataclasses import dataclass
5
- from typing import List, Optional, Union, cast
6
 
7
  import torch
8
  import torch.distributed as dist
9
- from torch.distributed._tensor import DTensor, Replicate, Shard
 
 
 
10
 
 
11
  from .matmul_transpose_triton import matmul_transpose_assign
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
  COMM_DTYPE = torch.bfloat16
 
16
 
17
 
18
  # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -62,23 +68,39 @@ def _zeropower_via_newtonschulz5(G, steps):
62
  @dataclass
63
  class _muon_state:
64
  # TODO: use Optional
65
- worker_rank: int | None = None
 
 
 
 
 
66
  gathered_grad: torch.Tensor | None = None
67
  scattered_u: DTensor | None = None
68
  computed_u: torch.Tensor | None = None
69
  gather_event: torch.cuda.Event | None = None
70
  compute_event: torch.cuda.Event | None = None
71
  scatter_event: torch.cuda.Event | None = None
72
- process_group = None
73
- qk_clip_state = None
74
 
75
 
76
- def split_elems_for_src(param, src_rank, num_ranks) -> int:
77
- rows = param.shape[0]
78
- cols = int(param.numel() // rows)
79
- base, rem = divmod(rows, num_ranks)
80
- my_rows = base + (1 if src_rank < rem else 0)
81
- return my_rows * cols
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
 
84
  @torch.no_grad()
@@ -91,8 +113,7 @@ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
91
  for p in params:
92
  state = param_to_state[id(p)]
93
  if rank == state.worker_rank:
94
- num_ranks = dist.get_world_size(group=state.process_group)
95
- state.gathered_grad = torch.empty(p.grad.numel(),
96
  dtype=COMM_DTYPE,
97
  device="cuda")
98
  else:
@@ -121,11 +142,11 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
121
  state = param_to_state[id(p)]
122
  dst = state.worker_rank
123
  assert dst < num_ranks
124
- shard_elems = split_elems_for_src(p, rank, num_ranks)
125
  g = p.grad
126
- g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
127
  assert g.numel() == shard_elems
128
- per_dst[dst].append(g)
129
  send_counts[dst] += shard_elems
130
 
131
  assert any(
@@ -148,13 +169,18 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
148
  for p in owned_params:
149
  state = param_to_state[id(p)]
150
  assert state.worker_rank == rank
151
- total += split_elems_for_src(p, src, num_ranks)
152
  recv_counts[src] = total
153
 
154
  recv_total = sum(recv_counts)
155
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
156
 
157
  #All2All
 
 
 
 
 
158
  dist.all_to_all_single(
159
  recv_buf,
160
  send_buf,
@@ -179,7 +205,6 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
179
  comm_stream.wait_event(alloc_event)
180
 
181
  off = 0
182
- write_offsets = {id(p): 0 for p in owned_params}
183
  for src in range(num_ranks):
184
  if recv_counts[src] == 0:
185
  continue
@@ -189,22 +214,28 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
189
  for p in owned_params:
190
  state = param_to_state[id(p)]
191
  assert state.worker_rank == rank
192
- n = split_elems_for_src(p, src, num_ranks)
 
 
 
 
 
 
 
 
 
193
  assert n > 0
194
 
195
  sg = recv_buf.narrow(0, off + inner_off, n)
196
- woff = write_offsets[id(p)]
197
- dst = state.gathered_grad.narrow(0, woff, n)
198
  dst.copy_(sg)
199
 
200
- write_offsets[id(p)] += n
201
  inner_off += n
202
  off += block
203
 
204
  for p in params:
205
  state = param_to_state[id(p)]
206
  if state.worker_rank == rank:
207
- state.gathered_grad = state.gathered_grad.view_as(p)
208
  state.gather_event = torch.cuda.Event()
209
  state.gather_event.record(comm_stream)
210
  else:
@@ -277,14 +308,19 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
277
 
278
  assert state.computed_u is not None
279
 
280
- u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)
281
 
282
  offset = 0
283
  for dst in range(num_ranks):
284
- n = split_elems_for_src(p, dst, num_ranks)
 
 
 
 
 
 
285
  assert n > 0
286
 
287
- su = u_full.narrow(0, offset, n)
288
  per_dst[dst].append(su)
289
  send_counts[dst] += n
290
  offset += n
@@ -313,7 +349,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
313
  state = param_to_state[id(p)]
314
  if state.worker_rank != src:
315
  continue
316
- total += split_elems_for_src(p, rank, num_ranks)
317
  recv_counts[src] = total
318
 
319
  recv_total = sum(recv_counts)
@@ -357,7 +393,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
357
  state = param_to_state[id(p)]
358
  if state.worker_rank != src:
359
  continue
360
- n = split_elems_for_src(p, rank, num_ranks)
361
  assert n > 0
362
 
363
  flat_local = recv_buf.narrow(0, off + inner_off,
@@ -398,11 +434,23 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
398
  state.scattered_u = None
399
  u_dtensor = None
400
 
401
- scales_full = Muon._compute_scales(p, state.qk_clip_state)
 
 
402
  if scales_full is not None:
403
- num_ranks = dist.get_world_size(group=state.process_group)
404
- local_rank = dist.get_rank(group=state.process_group)
405
- scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank]
 
 
 
 
 
 
 
 
 
 
406
  scales_local = DTensor.from_local(
407
  scales_local,
408
  placements=p.placements,
@@ -478,11 +526,11 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
478
  @dataclass
479
  class QKClipInfo:
480
  """Per-parameter dynamic info computed from config + runtime logits."""
481
- kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None
482
- indices: List[int] # which heads to consider for clipping
483
  head_dim: int # from config
484
  threshold: float # from config
485
- logit: Optional[torch.Tensor]
486
 
487
 
488
  class Muon(torch.optim.Optimizer):
@@ -525,11 +573,16 @@ class Muon(torch.optim.Optimizer):
525
  "head_dim": 128,
526
  "threshold": 100
527
  }
528
- overlap_step : How many all2all gather, compute operations are launched in advance
529
- before the corresponding all2all scatter steps begin.
530
- A higher overlap_step increases memory usage but can improve
531
- performance by overlapping communication.
532
- Parallel muon only.
 
 
 
 
 
533
  """
534
 
535
  def __init__(self,
@@ -549,7 +602,9 @@ class Muon(torch.optim.Optimizer):
549
  "head_dim": 128,
550
  "threshold": 100
551
  },
552
- overlap_step=5):
 
 
553
  defaults = dict(
554
  lr=lr,
555
  weight_decay=weight_decay,
@@ -579,7 +634,9 @@ class Muon(torch.optim.Optimizer):
579
  self.compute_stream = torch.cuda.Stream()
580
  self.debug = debug
581
  self.clip_config = clip_config
582
- self.overlap_step = overlap_step
 
 
583
 
584
  def _calc_flops(self, G, steps):
585
  assert len(G.shape) == 2
@@ -597,6 +654,12 @@ class Muon(torch.optim.Optimizer):
597
  adjusted_lr = lr * adjusted_ratio
598
  return adjusted_lr
599
 
 
 
 
 
 
 
600
  def get_shard_mesh(self, p):
601
  """
602
  Get the shard mesh for a parameter p on the given rank.
@@ -604,26 +667,13 @@ class Muon(torch.optim.Optimizer):
604
  assert isinstance(
605
  p, DTensor), "Parallel Muon only supports DTensor parameters."
606
 
607
- if p.placements == (Shard(dim=0), ):
608
- # Case for FSDP
609
- process_group = p.device_mesh.get_group(mesh_dim=0)
610
- if self.rank is None:
611
- self.rank = dist.get_rank(group=process_group)
612
- else:
613
- assert self.rank == dist.get_rank(group=process_group)
614
- return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
615
- elif p.placements == (Replicate(), Shard(dim=0)):
616
- # Case for HSDP
617
- process_group = p.device_mesh.get_group(mesh_dim=1)
618
- if self.rank is None:
619
- self.rank = dist.get_rank(group=process_group)
620
- else:
621
- assert self.rank == dist.get_rank(group=process_group)
622
- for i, shard_mesh in enumerate(p.device_mesh.mesh):
623
- if self.rank in shard_mesh:
624
- return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
625
- else:
626
- raise ValueError(f"Unsupported placements ({p.placements}).")
627
 
628
  def init_state_and_assign_params(self, names, params, group, qk_logits):
629
  param_to_state = {}
@@ -655,23 +705,32 @@ class Muon(torch.optim.Optimizer):
655
  ordered_params = list(params_sorted)
656
 
657
  round_robin = 0
658
- mesh = None
659
- shard_mesh = None
660
- process_group = None
 
 
 
 
 
661
  for n, p in zip(ordered_names, ordered_params):
662
- if mesh is None:
663
- mesh = p.device_mesh
664
- shard_mesh, process_group = self.get_shard_mesh(p)
665
- elif mesh != p.device_mesh:
666
  raise ValueError("All parameters must be on the same mesh.")
667
- num_ranks = dist.get_world_size(group=process_group)
668
- param_to_state[id(p)] = _muon_state()
669
- param_to_state[id(
670
- p)].worker_rank = shard_mesh[round_robin].item() % num_ranks
671
- param_to_state[id(p)].process_group = process_group
672
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
673
- param_to_state[id(p)].qk_clip_state = qk_clip_state
674
- round_robin = (round_robin + 1) % len(shard_mesh)
 
 
 
 
 
 
 
675
 
676
  return param_to_state, ordered_params
677
 
@@ -705,10 +764,73 @@ class Muon(torch.optim.Optimizer):
705
 
706
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
707
 
708
- scales_full = self._compute_scales(p, qk_clip_state)
 
709
  if scales_full is not None:
710
  Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
711
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
712
  def _update_g(self, p, g, group, momentum):
713
  # calc update
714
  state = self.state[p]
@@ -727,6 +849,9 @@ class Muon(torch.optim.Optimizer):
727
  p.data.add_(u, alpha=-adjusted_lr)
728
 
729
  def get_qk_clip_info(self, n, qk_logits):
 
 
 
730
  head_dim = self.clip_config.get('head_dim')
731
  threshold = self.clip_config.get('threshold')
732
  kind, layer_idx = parse_qk_layer(n)
@@ -737,6 +862,11 @@ class Muon(torch.optim.Optimizer):
737
  indices_key = 'q_indices' if 'q' in kind else 'k_indices'
738
  indices = self.clip_config.get(indices_key, []) or []
739
 
 
 
 
 
 
740
  return QKClipInfo(
741
  kind=kind,
742
  indices=indices,
@@ -835,22 +965,28 @@ class Muon(torch.optim.Optimizer):
835
  _update_param(p, state, lr, adjusted_lr, weight_decay,
836
  self.rank, self.compute_stream)
837
 
838
- chunk_size = dist.get_world_size(param_to_state[id(
839
- params[0])].process_group)
 
 
 
 
 
 
840
 
841
  # Wait grad update
842
  self.comm_stream.wait_stream(torch.cuda.current_stream())
843
 
844
- overlap_step = self.overlap_step
845
- for i in range(0, overlap_step):
846
  enqueue_all2all_gather(i * chunk_size, chunk_size)
847
  enqueue_computes(i * chunk_size, chunk_size)
848
 
849
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
850
  enqueue_all2all_scatter(i, chunk_size)
851
- enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size)
852
  enqueue_update_param(i, chunk_size)
853
- enqueue_computes(i + overlap_step * chunk_size, chunk_size)
854
 
855
  # Wait the last update_param to finish
856
  torch.cuda.current_stream().wait_stream(self.compute_stream)
@@ -866,7 +1002,7 @@ class Muon(torch.optim.Optimizer):
866
  amsgrad: bool,
867
  beta1: float,
868
  beta2: float,
869
- lr: Union[float, torch.Tensor],
870
  weight_decay: float,
871
  eps: float,
872
  maximize: bool,
@@ -876,10 +1012,10 @@ class Muon(torch.optim.Optimizer):
876
 
877
  # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
878
  # treating it as a scalar.
879
- lr_dict: Optional[DeviceDict] = ({
880
  lr.device: lr
881
  } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
882
- None)
883
  grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
884
  [
885
  params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
@@ -926,6 +1062,159 @@ class Muon(torch.optim.Optimizer):
926
  maximize=maximize,
927
  )
928
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
929
  def step(self, closure=None, qk_logits=None):
930
  """Perform a single optimization step.
931
 
@@ -943,127 +1232,9 @@ class Muon(torch.optim.Optimizer):
943
  loss = closure()
944
 
945
  for group in self.param_groups:
946
- params = group["params"]
947
-
948
  if group["use_muon"]:
949
- ############################
950
- # Muon #
951
- ############################
952
- lr = group["lr"]
953
- weight_decay = group["weight_decay"]
954
- momentum = group["momentum"]
955
- names = group["names"]
956
-
957
- param_dtensors = []
958
- param_tensors = []
959
- name_dtensors = []
960
- name_tensors = []
961
-
962
- for n, p in zip(names, params):
963
- if p is None or p.grad is None:
964
- continue
965
- if isinstance(p.data, DTensor):
966
- if all(
967
- isinstance(placement, Replicate)
968
- for placement in p.placements):
969
- param_tensors.append(p)
970
- name_tensors.append(n)
971
- else:
972
- param_dtensors.append(p)
973
- name_dtensors.append(n)
974
- elif isinstance(p.data, torch.Tensor):
975
- param_tensors.append(p)
976
- name_tensors.append(n)
977
- else:
978
- raise TypeError(
979
- f"Unsupported parameter type: {type(p.data)}")
980
-
981
- if self.debug:
982
- print(
983
- f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
984
- flush=True,
985
- )
986
-
987
- if len(param_dtensors) > 0:
988
- if not dist.is_initialized():
989
- raise RuntimeError(
990
- "Parallel Muon requires torch.distributed to be initialized."
991
- )
992
-
993
- self.parallel(
994
- name_dtensors,
995
- param_dtensors,
996
- group,
997
- lr=lr,
998
- weight_decay=weight_decay,
999
- momentum=momentum,
1000
- qk_logits=qk_logits,
1001
- )
1002
-
1003
- if len(param_tensors) > 0:
1004
- self.base(
1005
- name_tensors,
1006
- param_tensors,
1007
- group,
1008
- lr=lr,
1009
- weight_decay=weight_decay,
1010
- momentum=momentum,
1011
- qk_logits=qk_logits,
1012
- )
1013
-
1014
  else:
1015
- ############################
1016
- # AdamW backup #
1017
- ############################
1018
-
1019
- params_with_grads = []
1020
- grads = []
1021
- moment1 = []
1022
- moment2 = []
1023
- max_exp_avg_sqs = []
1024
- state_steps = []
1025
- lr = group["lr"]
1026
- beta1, beta2 = group["adamw_betas"]
1027
- eps = group["adamw_eps"]
1028
- weight_decay = group["weight_decay"]
1029
-
1030
- for p in params:
1031
- g = p.grad
1032
- if g is None:
1033
- continue
1034
- state = self.state[p]
1035
- params_with_grads.append(p)
1036
- grads.append(g)
1037
- if "step" not in state:
1038
- state["step"] = (torch.zeros((),
1039
- dtype=torch.float32,
1040
- device=p.device))
1041
- state["moment1"] = torch.zeros_like(g)
1042
- state["moment2"] = torch.zeros_like(g)
1043
- moment1.append(state["moment1"])
1044
- moment2.append(state["moment2"])
1045
- if not isinstance(state["step"], torch.Tensor):
1046
- step_tensor = torch.tensor(state["step"],
1047
- dtype=torch.float32,
1048
- device=p.device)
1049
- else:
1050
- step_tensor = state["step"]
1051
- state_steps.append(step_tensor)
1052
-
1053
- self._fused_adamw(
1054
- params_with_grads,
1055
- grads,
1056
- moment1,
1057
- moment2,
1058
- max_exp_avg_sqs,
1059
- state_steps,
1060
- amsgrad=False,
1061
- beta1=beta1,
1062
- beta2=beta2,
1063
- lr=lr,
1064
- weight_decay=weight_decay,
1065
- eps=eps,
1066
- maximize=False,
1067
- )
1068
 
1069
  return loss
 
1
  import logging
2
  import math
3
  import types
4
+ from collections import defaultdict
5
  from dataclasses import dataclass
6
+ from typing import Any, cast
7
 
8
  import torch
9
  import torch.distributed as dist
10
+ from torch.distributed import ProcessGroup
11
+ from torch.distributed.device_mesh import DeviceMesh
12
+ from torch.distributed.tensor import DTensor, Replicate
13
+ from torch.distributed.tensor.placement_types import Placement
14
 
15
+ from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor
16
  from .matmul_transpose_triton import matmul_transpose_assign
17
 
18
  logger = logging.getLogger(__name__)
19
 
20
  COMM_DTYPE = torch.bfloat16
21
+ DEFAULT_CHUNK_SIZE_RATIO = 4
22
 
23
 
24
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
68
  @dataclass
69
  class _muon_state:
70
  # TODO: use Optional
71
+ worker_rank: int
72
+ process_group: ProcessGroup
73
+ shard_mesh: DeviceMesh
74
+ shard_placements: tuple[Placement, ...]
75
+ name: str
76
+ qk_clip_state: "QKClipInfo | None" = None
77
  gathered_grad: torch.Tensor | None = None
78
  scattered_u: DTensor | None = None
79
  computed_u: torch.Tensor | None = None
80
  gather_event: torch.cuda.Event | None = None
81
  compute_event: torch.cuda.Event | None = None
82
  scatter_event: torch.cuda.Event | None = None
 
 
83
 
84
 
85
+ def numel_for_rank(
86
+ param: DTensor,
87
+ local_rank: int,
88
+ state: _muon_state,
89
+ ) -> int:
90
+ slices = get_slices_of_dtensor(
91
+ param,
92
+ local_rank,
93
+ state.shard_mesh,
94
+ state.shard_placements,
95
+ )
96
+
97
+ numel = 1
98
+ for s, dim in zip(slices, param.shape):
99
+ start, stop, step = s.indices(dim)
100
+ length = max(0, (stop - start + (step - 1)) // step)
101
+ numel *= length
102
+
103
+ return numel
104
 
105
 
106
  @torch.no_grad()
 
113
  for p in params:
114
  state = param_to_state[id(p)]
115
  if rank == state.worker_rank:
116
+ state.gathered_grad = torch.empty(p.shape,
 
117
  dtype=COMM_DTYPE,
118
  device="cuda")
119
  else:
 
142
  state = param_to_state[id(p)]
143
  dst = state.worker_rank
144
  assert dst < num_ranks
145
+ shard_elems = numel_for_rank(p, rank, state)
146
  g = p.grad
147
+ g = g.to_local().to(COMM_DTYPE).contiguous()
148
  assert g.numel() == shard_elems
149
+ per_dst[dst].append(g.view(-1))
150
  send_counts[dst] += shard_elems
151
 
152
  assert any(
 
169
  for p in owned_params:
170
  state = param_to_state[id(p)]
171
  assert state.worker_rank == rank
172
+ total += numel_for_rank(p, src, state)
173
  recv_counts[src] = total
174
 
175
  recv_total = sum(recv_counts)
176
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
177
 
178
  #All2All
179
+ logger.debug(f"send_buf size: {send_buf.numel()}, "
180
+ f"recv_buf size: {recv_buf.numel()}, "
181
+ f"recv_counts: {recv_counts}, "
182
+ f"send_counts: {send_counts}, "
183
+ f"process_group: {str(process_group)}")
184
  dist.all_to_all_single(
185
  recv_buf,
186
  send_buf,
 
205
  comm_stream.wait_event(alloc_event)
206
 
207
  off = 0
 
208
  for src in range(num_ranks):
209
  if recv_counts[src] == 0:
210
  continue
 
214
  for p in owned_params:
215
  state = param_to_state[id(p)]
216
  assert state.worker_rank == rank
217
+
218
+ # get the slice of the full dtensor corresponding to rank src.
219
+ slices = get_slices_of_dtensor(state.gathered_grad, src,
220
+ state.shard_mesh,
221
+ state.shard_placements)
222
+
223
+ dst = state.gathered_grad[slices]
224
+ assert dst._base is state.gathered_grad
225
+
226
+ n = dst.numel()
227
  assert n > 0
228
 
229
  sg = recv_buf.narrow(0, off + inner_off, n)
230
+ sg = sg.reshape_as(dst)
 
231
  dst.copy_(sg)
232
 
 
233
  inner_off += n
234
  off += block
235
 
236
  for p in params:
237
  state = param_to_state[id(p)]
238
  if state.worker_rank == rank:
 
239
  state.gather_event = torch.cuda.Event()
240
  state.gather_event.record(comm_stream)
241
  else:
 
308
 
309
  assert state.computed_u is not None
310
 
311
+ u_full = state.computed_u.to(COMM_DTYPE).contiguous()
312
 
313
  offset = 0
314
  for dst in range(num_ranks):
315
+ # get the slice of the full tensor corresponding to rank dst.
316
+ slices = get_slices_of_dtensor(u_full, dst,
317
+ state.shard_mesh,
318
+ state.shard_placements)
319
+ su = u_full[slices].flatten()
320
+
321
+ n = su.numel()
322
  assert n > 0
323
 
 
324
  per_dst[dst].append(su)
325
  send_counts[dst] += n
326
  offset += n
 
349
  state = param_to_state[id(p)]
350
  if state.worker_rank != src:
351
  continue
352
+ total += numel_for_rank(p, rank, state)
353
  recv_counts[src] = total
354
 
355
  recv_total = sum(recv_counts)
 
393
  state = param_to_state[id(p)]
394
  if state.worker_rank != src:
395
  continue
396
+ n = numel_for_rank(p, rank, state)
397
  assert n > 0
398
 
399
  flat_local = recv_buf.narrow(0, off + inner_off,
 
434
  state.scattered_u = None
435
  u_dtensor = None
436
 
437
+ scales_full = Muon._compute_scales(
438
+ p,
439
+ state.qk_clip_state) if state.qk_clip_state is not None else None
440
  if scales_full is not None:
441
+ # Have to slice scales_full along dim 0
442
+ weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh,
443
+ state.shard_placements)
444
+ ratio = p.shape[0] // scales_full.shape[0]
445
+ scales_slice = slice(
446
+ None if weight_slices[0].start is None else
447
+ weight_slices[0].start // ratio,
448
+ None if weight_slices[0].stop is None else
449
+ weight_slices[0].stop // ratio,
450
+ None,
451
+ )
452
+
453
+ scales_local = scales_full[scales_slice]
454
  scales_local = DTensor.from_local(
455
  scales_local,
456
  placements=p.placements,
 
526
  @dataclass
527
  class QKClipInfo:
528
  """Per-parameter dynamic info computed from config + runtime logits."""
529
+ kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
530
+ indices: list[int] # which heads to consider for clipping
531
  head_dim: int # from config
532
  threshold: float # from config
533
+ logit: torch.Tensor | None
534
 
535
 
536
  class Muon(torch.optim.Optimizer):
 
573
  "head_dim": 128,
574
  "threshold": 100
575
  }
576
+ warmup_step : How many all2all gather and compute operations are launched in advance
577
+ before the corresponding all2all scatter steps begin.
578
+ A higher warmup_step increases memory usage but can improve
579
+ performance by overlapping communication.
580
+ Parallel muon only.
581
+ chunk_size : Batch size of parameters to process in each
582
+ all2all gather/compute/scatter step.
583
+ Uses the number of shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
584
+ use_distributed_muon: Use distributed muon by Liu et al. (2024).
585
+ For testing purposes only.
586
  """
587
 
588
  def __init__(self,
 
602
  "head_dim": 128,
603
  "threshold": 100
604
  },
605
+ warmup_step=5,
606
+ chunk_size=-1,
607
+ use_distributed_muon=False):
608
  defaults = dict(
609
  lr=lr,
610
  weight_decay=weight_decay,
 
634
  self.compute_stream = torch.cuda.Stream()
635
  self.debug = debug
636
  self.clip_config = clip_config
637
+ self.warmup_step = warmup_step
638
+ self.chunk_size = chunk_size
639
+ self.use_distributed_muon = use_distributed_muon
640
 
641
  def _calc_flops(self, G, steps):
642
  assert len(G.shape) == 2
 
654
  adjusted_lr = lr * adjusted_ratio
655
  return adjusted_lr
656
 
657
+ def set_rank_once(self, rank):
658
+ if self.rank is None:
659
+ self.rank = rank
660
+ else:
661
+ assert self.rank == rank
662
+
663
  def get_shard_mesh(self, p):
664
  """
665
  Get the shard mesh for a parameter p on the given rank.
 
667
  assert isinstance(
668
  p, DTensor), "Parallel Muon only supports DTensor parameters."
669
 
670
+ shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
671
+ p.placements, p.device_mesh)
672
+
673
+ # set self.rank to the local rank in the shard process group
674
+ self.set_rank_once(dist.get_rank(group=shard_pg))
675
+
676
+ return shard_mesh, shard_pg, shard_placements
 
 
 
 
 
 
 
 
 
 
 
 
 
677
 
678
  def init_state_and_assign_params(self, names, params, group, qk_logits):
679
  param_to_state = {}
 
705
  ordered_params = list(params_sorted)
706
 
707
  round_robin = 0
708
+ mesh = ordered_params[0].device_mesh
709
+ placements = ordered_params[0].placements
710
+
711
+ shard_mesh, shard_pg, shard_placements = self.get_shard_mesh(
712
+ ordered_params[0])
713
+ shard_mesh_flattened = shard_mesh.mesh.flatten()
714
+ num_ranks = dist.get_world_size(group=shard_pg)
715
+
716
  for n, p in zip(ordered_names, ordered_params):
717
+ if mesh != p.device_mesh:
 
 
 
718
  raise ValueError("All parameters must be on the same mesh.")
719
+ if placements != p.placements:
720
+ raise ValueError("All parameters must have same placements.")
721
+
722
+ worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
723
+ round_robin = (round_robin + 1) % len(shard_mesh_flattened)
724
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
725
+
726
+ param_to_state[id(p)] = _muon_state(
727
+ worker_rank=worker_rank,
728
+ process_group=shard_pg,
729
+ shard_mesh=shard_mesh,
730
+ shard_placements=shard_placements,
731
+ name=n,
732
+ qk_clip_state=qk_clip_state,
733
+ )
734
 
735
  return param_to_state, ordered_params
736
 
 
764
 
765
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
766
 
767
+ scales_full = self._compute_scales(
768
+ p, qk_clip_state) if qk_clip_state is not None else None
769
  if scales_full is not None:
770
  Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
771
 
772
+ def distributed_muon(
773
+ self,
774
+ names: list[str],
775
+ params: list[torch.nn.Parameter],
776
+ group: dict[str, Any],
777
+ lr: float,
778
+ weight_decay: float,
779
+ momentum: float,
780
+ qk_logits: list[torch.Tensor | DTensor] | None,
781
+ ):
782
+ """ Implementation of Distributed Muon by Liu et al. """
783
+ if qk_logits is not None:
784
+ raise NotImplementedError("QK clipping is not supported yet")
785
+
786
+ if isinstance(params[0], DTensor):
787
+ shard_mesh, _, shard_placements = construct_shard_mesh(
788
+ placements=params[0].placements,
789
+ mesh=params[0].device_mesh,
790
+ )
791
+
792
+ for n, p in zip(names, params):
793
+ g = p.grad
794
+ if g is None:
795
+ continue
796
+ if g.ndim > 2:
797
+ g = g.view(g.size(0), -1)
798
+ assert g is not None
799
+
800
+ # calc update
801
+ state = self.state[p]
802
+ if "momentum_buffer" not in state:
803
+ state["momentum_buffer"] = torch.zeros_like(g)
804
+ buf = state["momentum_buffer"]
805
+ buf.mul_(momentum).add_(g)
806
+ if group["nesterov"]:
807
+ g = g.add(buf, alpha=momentum)
808
+ else:
809
+ g = buf
810
+
811
+ # Gather G
812
+ if isinstance(p.data, DTensor):
813
+ g = g.full_tensor()
814
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
815
+ steps=group["ns_steps"])
816
+
817
+ if isinstance(p.data, DTensor):
818
+ slices = get_slices_of_dtensor(
819
+ target=p,
820
+ local_rank=dist.get_rank(),
821
+ shard_mesh=shard_mesh,
822
+ shard_placements=shard_placements,
823
+ )
824
+ u_shard = u[slices]
825
+ u = DTensor.from_local(
826
+ u_shard,
827
+ device_mesh=p.device_mesh,
828
+ placements=p.placements,
829
+ )
830
+
831
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
832
+ Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
833
+
834
  def _update_g(self, p, g, group, momentum):
835
  # calc update
836
  state = self.state[p]
 
849
  p.data.add_(u, alpha=-adjusted_lr)
850
 
851
  def get_qk_clip_info(self, n, qk_logits):
852
+ if self.clip_config is None:
853
+ return None
854
+
855
  head_dim = self.clip_config.get('head_dim')
856
  threshold = self.clip_config.get('threshold')
857
  kind, layer_idx = parse_qk_layer(n)
 
862
  indices_key = 'q_indices' if 'q' in kind else 'k_indices'
863
  indices = self.clip_config.get(indices_key, []) or []
864
 
865
+ if isinstance(logit, DTensor):
866
+ # In TP settings, qk_logits may be a DTensor;
867
+ # we convert it to a full tensor here for simplicity.
868
+ logit = logit.full_tensor()
869
+
870
  return QKClipInfo(
871
  kind=kind,
872
  indices=indices,
 
965
  _update_param(p, state, lr, adjusted_lr, weight_decay,
966
  self.rank, self.compute_stream)
967
 
968
+ if self.chunk_size == -1:
969
+ shard_ranks = dist.get_world_size(param_to_state[id(
970
+ params[0])].process_group)
971
+ chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
972
+ elif self.chunk_size > 0:
973
+ chunk_size = self.chunk_size
974
+ else:
975
+ raise ValueError("chunk_size must be -1 or a positive integer.")
976
 
977
  # Wait grad update
978
  self.comm_stream.wait_stream(torch.cuda.current_stream())
979
 
980
+ warmup_step = self.warmup_step
981
+ for i in range(0, warmup_step):
982
  enqueue_all2all_gather(i * chunk_size, chunk_size)
983
  enqueue_computes(i * chunk_size, chunk_size)
984
 
985
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
986
  enqueue_all2all_scatter(i, chunk_size)
987
+ enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size)
988
  enqueue_update_param(i, chunk_size)
989
+ enqueue_computes(i + warmup_step * chunk_size, chunk_size)
990
 
991
  # Wait the last update_param to finish
992
  torch.cuda.current_stream().wait_stream(self.compute_stream)
 
1002
  amsgrad: bool,
1003
  beta1: float,
1004
  beta2: float,
1005
+ lr: float | torch.Tensor,
1006
  weight_decay: float,
1007
  eps: float,
1008
  maximize: bool,
 
1012
 
1013
  # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
1014
  # treating it as a scalar.
1015
+ lr_dict: DeviceDict | None = ({
1016
  lr.device: lr
1017
  } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
1018
+ None)
1019
  grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
1020
  [
1021
  params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
 
1062
  maximize=maximize,
1063
  )
1064
 
1065
+ def _step_muon(self, group, qk_logits=None):
1066
+ params = group["params"]
1067
+ lr = group["lr"]
1068
+ weight_decay = group["weight_decay"]
1069
+ momentum = group["momentum"]
1070
+ names = group["names"]
1071
+
1072
+ param_dtensors = []
1073
+ param_tensors = []
1074
+ name_dtensors = []
1075
+ name_tensors = []
1076
+
1077
+ if self.use_distributed_muon:
1078
+ self.distributed_muon(names=names,
1079
+ params=params,
1080
+ group=group,
1081
+ lr=lr,
1082
+ weight_decay=weight_decay,
1083
+ momentum=momentum,
1084
+ qk_logits=qk_logits)
1085
+ return
1086
+
1087
+ for n, p in zip(names, params):
1088
+ if p is None or p.grad is None:
1089
+ continue
1090
+ if isinstance(p.data, DTensor):
1091
+ if all(
1092
+ isinstance(placement, Replicate)
1093
+ for placement in p.placements):
1094
+ param_tensors.append(p)
1095
+ name_tensors.append(n)
1096
+ else:
1097
+ param_dtensors.append(p)
1098
+ name_dtensors.append(n)
1099
+ elif isinstance(p.data, torch.Tensor):
1100
+ param_tensors.append(p)
1101
+ name_tensors.append(n)
1102
+ else:
1103
+ raise TypeError(f"Unsupported parameter type: {type(p.data)}")
1104
+
1105
+ logger.debug(
1106
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors"
1107
+ )
1108
+
1109
+ if len(param_dtensors) > 0:
1110
+ if not dist.is_initialized():
1111
+ raise RuntimeError(
1112
+ "Parallel Muon requires torch.distributed to be initialized."
1113
+ )
1114
+
1115
+ # To support different placements, we group parameters by placements
1116
+ # and run parallel Muon on each group.
1117
+
1118
+ placement_to_params = defaultdict(lambda: ([], []))
1119
+ # type: dict[tuple[tuple[Placement, ...], DeviceMesh], tuple[list[str], list[DTensor]]]
1120
+
1121
+ assert len(name_dtensors) == len(param_dtensors)
1122
+ for n, p in zip(name_dtensors, param_dtensors):
1123
+ placement_to_params[tuple([p.placements,
1124
+ p.device_mesh])][0].append(n)
1125
+ placement_to_params[tuple([p.placements,
1126
+ p.device_mesh])][1].append(p)
1127
+
1128
+ for _, (names, params) in placement_to_params.items():
1129
+ self.parallel(
1130
+ names,
1131
+ params,
1132
+ group,
1133
+ lr=lr,
1134
+ weight_decay=weight_decay,
1135
+ momentum=momentum,
1136
+ qk_logits=qk_logits,
1137
+ )
1138
+
1139
+ if len(param_tensors) > 0:
1140
+ self.base(
1141
+ name_tensors,
1142
+ param_tensors,
1143
+ group,
1144
+ lr=lr,
1145
+ weight_decay=weight_decay,
1146
+ momentum=momentum,
1147
+ qk_logits=qk_logits,
1148
+ )
1149
+
1150
+ def _step_adamw_params(self, params, group):
1151
+ params_with_grads = []
1152
+ grads = []
1153
+ moment1 = []
1154
+ moment2 = []
1155
+ max_exp_avg_sqs = []
1156
+ state_steps = []
1157
+ lr = group["lr"]
1158
+ beta1, beta2 = group["adamw_betas"]
1159
+ eps = group["adamw_eps"]
1160
+ weight_decay = group["weight_decay"]
1161
+
1162
+ for p in params:
1163
+ g = p.grad
1164
+ if g is None:
1165
+ continue
1166
+ state = self.state[p]
1167
+ params_with_grads.append(p)
1168
+ grads.append(g)
1169
+ if "step" not in state:
1170
+ state["step"] = (torch.zeros((),
1171
+ dtype=torch.float32,
1172
+ device=p.device))
1173
+ state["moment1"] = torch.zeros_like(g)
1174
+ state["moment2"] = torch.zeros_like(g)
1175
+ moment1.append(state["moment1"])
1176
+ moment2.append(state["moment2"])
1177
+ if not isinstance(state["step"], torch.Tensor):
1178
+ step_tensor = torch.tensor(state["step"],
1179
+ dtype=torch.float32,
1180
+ device=p.device)
1181
+ else:
1182
+ step_tensor = state["step"]
1183
+ state_steps.append(step_tensor)
1184
+
1185
+ self._fused_adamw(
1186
+ params_with_grads,
1187
+ grads,
1188
+ moment1,
1189
+ moment2,
1190
+ max_exp_avg_sqs,
1191
+ state_steps,
1192
+ amsgrad=False,
1193
+ beta1=beta1,
1194
+ beta2=beta2,
1195
+ lr=lr,
1196
+ weight_decay=weight_decay,
1197
+ eps=eps,
1198
+ maximize=False,
1199
+ )
1200
+
1201
+ def _step_adamw(self, group):
1202
+ params = group["params"]
1203
+
1204
+ # group params with it's type and placement
1205
+ placement_to_params: dict[tuple[Placement | type,
1206
+ DeviceMesh | None]] = defaultdict(list)
1207
+ for p in params:
1208
+ match p:
1209
+ case DTensor():
1210
+ placement_to_params[tuple([p.placements,
1211
+ p.device_mesh])].append(p)
1212
+ case torch.Tensor():
1213
+ placement_to_params[tuple([torch.Tensor, None])].append(p)
1214
+
1215
+ for params in placement_to_params.values():
1216
+ self._step_adamw_params(params, group)
1217
+
1218
  def step(self, closure=None, qk_logits=None):
1219
  """Perform a single optimization step.
1220
 
 
1232
  loss = closure()
1233
 
1234
  for group in self.param_groups:
 
 
1235
  if group["use_muon"]:
1236
+ self._step_muon(group, qk_logits=qk_logits)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1237
  else:
1238
+ self._step_adamw(group)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1239
 
1240
  return loss
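
The refactored _step_muon above groups DTensor parameters by their (placements, device_mesh) pair and runs parallel Muon once per group. A rough, self-contained sketch of just that grouping step (FakeParam, the attribute values, and the string meshes are invented for illustration; real parameters carry actual Placement and DeviceMesh objects):

from collections import defaultdict

class FakeParam:
    """Stand-in for a DTensor parameter: only the grouping-key fields matter here."""
    def __init__(self, name, placements, device_mesh):
        self.name = name
        self.placements = placements
        self.device_mesh = device_mesh

params = [
    FakeParam("w1", ("Shard(0)",), "mesh_a"),
    FakeParam("w2", ("Shard(0)",), "mesh_a"),
    FakeParam("w3", ("Replicate()", "Shard(0)"), "mesh_b"),
]

groups = defaultdict(list)
for p in params:
    # Same key the optimizer uses: identical placements on an identical mesh.
    groups[(p.placements, p.device_mesh)].append(p.name)

for key, names in groups.items():
    print(key, names)
# (('Shard(0)',), 'mesh_a') ['w1', 'w2']
# (('Replicate()', 'Shard(0)'), 'mesh_b') ['w3']
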
build/torch29-cxx11-cu128-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_811726c_dirty
3
- ops = torch.ops._optimizer_811726c_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_811726c_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_23d68bb_dirty
3
+ ops = torch.ops._optimizer_23d68bb_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_23d68bb_dirty::{op_name}"
build/torch29-cxx11-cu128-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc97ff00a3255d5eb363958b1e619eadbc4315f1930d0fb59cfc9560c3951721
3
+ size 1983488
build/torch29-cxx11-cu128-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:ab1875be65811d88c407f36077aced58056a4feeb9946d7cd40ec55c7e1025c8
3
- size 1871056
 
 
 
 
build/torch29-cxx11-cu128-x86_64-linux/optimizer/distributed/utils.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+ from torch.distributed import ProcessGroup
4
+ from torch.distributed.device_mesh import DeviceMesh
5
+ from torch.distributed.tensor import DTensor
6
+ from torch.distributed.tensor.placement_types import (Placement, Shard,
7
+ _StridedShard)
8
+
9
+
10
+ def get_slices_of_dtensor(
11
+ target: DTensor | torch.Tensor,
12
+ local_rank: int,
13
+ shard_mesh: DeviceMesh,
14
+ shard_placements: tuple[Placement, ...],
15
+ ) -> tuple[slice, ...]:
16
+ """
17
+ Get the slices of the local shard of the target tensor for a given rank.
18
+ Args:
19
+ target (DTensor | torch.Tensor): The target tensor.
20
+ local_rank (int): The local rank within the shard group.
21
+ shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks.
22
+ shard_placements (tuple[Placement]): The shard placements.
23
+ """
24
+
25
+ slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()]
26
+
27
+ # find the global rank of the local rank in the shard mesh
28
+ rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
29
+
30
+ rank_coords = (shard_mesh.mesh == rank).nonzero()
31
+
32
+ assert len(rank_coords) == 1
33
+ rank_coords = tuple(rank_coords[0].tolist())
34
+
35
+ assert len(rank_coords) == len(shard_placements)
36
+
37
+ # Caution: Assuming replicate-to-shard of the shard mesh goes with
38
+ # left-to-right sharding. This is ensured by the sorting logic of
39
+ # the construct_shard_mesh function.
40
+ for i, (rank_coord,
41
+ placement) in enumerate(zip(rank_coords, shard_placements)):
42
+ assert isinstance(placement, Shard)
43
+
44
+ num_ranks = shard_mesh.mesh.shape[i]
45
+
46
+ dim = placement.dim
47
+ dim_size = (slices[dim].stop - slices[dim].start)
48
+
49
+ if dim_size % num_ranks != 0:
50
+ raise NotImplementedError(
51
+ f"Dimension size {dim_size} is not divisible "
52
+ f"by number of ranks {num_ranks} for shard "
53
+ f"placement on dim {dim}.")
54
+
55
+ shard_size = dim_size // num_ranks
56
+
57
+ start = slices[dim].start + rank_coord * shard_size
58
+ end = start + shard_size
59
+
60
+ assert start < end <= slices[dim].stop
61
+
62
+ slices[dim] = slice(start, end)
63
+
64
+ return tuple(slices)
65
+
66
+
67
+ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict()
68
+
69
+
70
+ def construct_shard_mesh(
71
+ placements: tuple[Placement],
72
+ mesh: DeviceMesh,
73
+ ) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement, ...]]:
74
+ """
75
+ Construct Shard Mesh and Placements for unsharding.
76
+ It removes Replicate placements and constructs a new Mesh and ProcessGroup.
77
+ """
78
+ my_rank = dist.get_rank()
79
+
80
+ assert mesh.mesh.device.type == 'cpu'
81
+
82
+ # Copy mesh to avoid modifying the original mesh
83
+ mesh = mesh.mesh.clone()
84
+
85
+ # 1. Sort placements. Replicate first, then Shard by dim ascending.
86
+
87
+ # For Shard, strided shard comes after regular shard on the same dim
88
+ # to preserve left-to-right order of replicate-to-shard.
89
+ # This is because a strided shard uses its stride to represent
90
+ # more fine-grained sharding on the same dim.
91
+ # Please check the URL below for _StridedShard.
92
+ # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366
93
+
94
+ def placement_sort_key(
95
+ placement_with_index: tuple[int, Placement]
96
+ ) -> tuple[int, float, int]: # (dim, split factor, original index)
97
+ index, placement = placement_with_index
98
+ is_replicate = placement.is_replicate()
99
+ is_shard = placement.is_shard()
100
+ is_partial = placement.is_partial()
101
+
102
+ assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}"
103
+ assert not is_partial, "Partial placement is not supported."
104
+
105
+ if is_replicate:
106
+ return (-1.0, 0, index)
107
+ elif is_shard:
108
+ if isinstance(placement, _StridedShard):
109
+ return (placement.dim, 1 / placement.split_factor, index)
110
+ return (placement.dim, 0, index)
111
+ else:
112
+ raise TypeError(f"Unknown placement type: {type(placement)}")
113
+
114
+ placements_with_index: list[tuple[int,
115
+ Placement]] = list(enumerate(placements))
116
+ placements_with_index = sorted(placements_with_index,
117
+ key=placement_sort_key)
118
+
119
+ sorted_indices, sorted_placements = zip(*placements_with_index)
120
+
121
+ # 2. Permute mesh according to sorted placements.
122
+ sorted_mesh = mesh.permute(sorted_indices)
123
+
124
+ # 3. Collect list of shard meshes by removing replicate dims
125
+ # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
126
+ # shard_meshes should be a list of 2 * 3 = 6 shard meshes of shape (4, 4)
127
+ num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
128
+
129
+ # merge replicate dims
130
+ # shard_meshes becomes a list of shard meshes whose length equals the replicate degree
131
+ if num_replicates > 0:
132
+ sorted_mesh = sorted_mesh.flatten(
133
+ 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
134
+ shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
135
+ else:
136
+ shard_meshes = [sorted_mesh]
137
+ shard_placements = sorted_placements[num_replicates:]
138
+
139
+ # assume all shard placements are different
140
+ assert len(shard_placements) == len(set(shard_placements))
141
+
142
+ # 4. Construct ProcessGroups
143
+ # Caution: all groups should be created in the same order in all processes,
144
+ # even though each process only needs its own group.
145
+
146
+ # To use a tensor as a dict key, convert it to a tuple
147
+ def tensor_to_tuple(t):
148
+ if isinstance(t, torch.Tensor):
149
+ t = t.tolist()
150
+ if isinstance(t, list):
151
+ return tuple(tensor_to_tuple(x) for x in t)
152
+ return t
153
+
154
+ my_shard_mesh_as_tuple = None
155
+ for shard_mesh in shard_meshes:
156
+ assert isinstance(shard_mesh, torch.Tensor)
157
+ shard_mesh_as_tuple = tensor_to_tuple(shard_mesh)
158
+
159
+ if (my_rank == shard_mesh).any().item():
160
+ assert my_shard_mesh_as_tuple is None
161
+ my_shard_mesh_as_tuple = shard_mesh_as_tuple
162
+
163
+ # update global cache
164
+ if shard_mesh_as_tuple not in _ranks_to_dist_cache:
165
+ shard_process_group = dist.new_group(shard_mesh.flatten().tolist())
166
+ _ranks_to_dist_cache[shard_mesh_as_tuple] = (
167
+ DeviceMesh(device_type="cuda", mesh=shard_mesh),
168
+ shard_process_group,
169
+ )
170
+
171
+ my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[
172
+ my_shard_mesh_as_tuple]
173
+
174
+ return my_shard_mesh, my_shard_process_group, shard_placements
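
The slicing arithmetic in get_slices_of_dtensor above can be illustrated with a small standalone sketch (an invented 8x8 tensor, a 2x2 shard mesh, and placements equivalent to (Shard(0), Shard(1)); even divisibility is assumed, as in the real helper):

import torch

full_shape = (8, 8)
shard_mesh = torch.tensor([[0, 1], [2, 3]])  # global ranks
shard_dims = (0, 1)                          # tensor dims sharded by mesh dims 0 and 1

def slices_for(local_rank):
    # Map the local rank to a global rank, then to its coordinates in the shard mesh.
    rank = sorted(shard_mesh.flatten().tolist())[local_rank]
    coords = (shard_mesh == rank).nonzero()[0].tolist()

    slices = [slice(0, s) for s in full_shape]
    for mesh_dim, (coord, dim) in enumerate(zip(coords, shard_dims)):
        num_ranks = shard_mesh.shape[mesh_dim]
        shard_size = (slices[dim].stop - slices[dim].start) // num_ranks
        start = slices[dim].start + coord * shard_size
        slices[dim] = slice(start, start + shard_size)
    return tuple(slices)

print(slices_for(3))  # (slice(4, 8, None), slice(4, 8, None)) -> bottom-right 4x4 block
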
build/torch29-cxx11-cu128-x86_64-linux/optimizer/muon.py CHANGED
@@ -1,18 +1,24 @@
1
  import logging
2
  import math
3
  import types
 
4
  from dataclasses import dataclass
5
- from typing import List, Optional, Union, cast
6
 
7
  import torch
8
  import torch.distributed as dist
9
- from torch.distributed._tensor import DTensor, Replicate, Shard
 
 
 
10
 
 
11
  from .matmul_transpose_triton import matmul_transpose_assign
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
  COMM_DTYPE = torch.bfloat16
 
16
 
17
 
18
  # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -62,23 +68,39 @@ def _zeropower_via_newtonschulz5(G, steps):
62
  @dataclass
63
  class _muon_state:
64
  # TODO: use Optional
65
- worker_rank: int | None = None
 
 
 
 
 
66
  gathered_grad: torch.Tensor | None = None
67
  scattered_u: DTensor | None = None
68
  computed_u: torch.Tensor | None = None
69
  gather_event: torch.cuda.Event | None = None
70
  compute_event: torch.cuda.Event | None = None
71
  scatter_event: torch.cuda.Event | None = None
72
- process_group = None
73
- qk_clip_state = None
74
 
75
 
76
- def split_elems_for_src(param, src_rank, num_ranks) -> int:
77
- rows = param.shape[0]
78
- cols = int(param.numel() // rows)
79
- base, rem = divmod(rows, num_ranks)
80
- my_rows = base + (1 if src_rank < rem else 0)
81
- return my_rows * cols
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
 
84
  @torch.no_grad()
@@ -91,8 +113,7 @@ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
91
  for p in params:
92
  state = param_to_state[id(p)]
93
  if rank == state.worker_rank:
94
- num_ranks = dist.get_world_size(group=state.process_group)
95
- state.gathered_grad = torch.empty(p.grad.numel(),
96
  dtype=COMM_DTYPE,
97
  device="cuda")
98
  else:
@@ -121,11 +142,11 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
121
  state = param_to_state[id(p)]
122
  dst = state.worker_rank
123
  assert dst < num_ranks
124
- shard_elems = split_elems_for_src(p, rank, num_ranks)
125
  g = p.grad
126
- g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
127
  assert g.numel() == shard_elems
128
- per_dst[dst].append(g)
129
  send_counts[dst] += shard_elems
130
 
131
  assert any(
@@ -148,13 +169,18 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
148
  for p in owned_params:
149
  state = param_to_state[id(p)]
150
  assert state.worker_rank == rank
151
- total += split_elems_for_src(p, src, num_ranks)
152
  recv_counts[src] = total
153
 
154
  recv_total = sum(recv_counts)
155
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
156
 
157
  #All2All
 
 
 
 
 
158
  dist.all_to_all_single(
159
  recv_buf,
160
  send_buf,
@@ -179,7 +205,6 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
179
  comm_stream.wait_event(alloc_event)
180
 
181
  off = 0
182
- write_offsets = {id(p): 0 for p in owned_params}
183
  for src in range(num_ranks):
184
  if recv_counts[src] == 0:
185
  continue
@@ -189,22 +214,28 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
189
  for p in owned_params:
190
  state = param_to_state[id(p)]
191
  assert state.worker_rank == rank
192
- n = split_elems_for_src(p, src, num_ranks)
 
 
 
 
 
 
 
 
 
193
  assert n > 0
194
 
195
  sg = recv_buf.narrow(0, off + inner_off, n)
196
- woff = write_offsets[id(p)]
197
- dst = state.gathered_grad.narrow(0, woff, n)
198
  dst.copy_(sg)
199
 
200
- write_offsets[id(p)] += n
201
  inner_off += n
202
  off += block
203
 
204
  for p in params:
205
  state = param_to_state[id(p)]
206
  if state.worker_rank == rank:
207
- state.gathered_grad = state.gathered_grad.view_as(p)
208
  state.gather_event = torch.cuda.Event()
209
  state.gather_event.record(comm_stream)
210
  else:
@@ -277,14 +308,19 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
277
 
278
  assert state.computed_u is not None
279
 
280
- u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)
281
 
282
  offset = 0
283
  for dst in range(num_ranks):
284
- n = split_elems_for_src(p, dst, num_ranks)
 
 
 
 
 
 
285
  assert n > 0
286
 
287
- su = u_full.narrow(0, offset, n)
288
  per_dst[dst].append(su)
289
  send_counts[dst] += n
290
  offset += n
@@ -313,7 +349,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
313
  state = param_to_state[id(p)]
314
  if state.worker_rank != src:
315
  continue
316
- total += split_elems_for_src(p, rank, num_ranks)
317
  recv_counts[src] = total
318
 
319
  recv_total = sum(recv_counts)
@@ -357,7 +393,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
357
  state = param_to_state[id(p)]
358
  if state.worker_rank != src:
359
  continue
360
- n = split_elems_for_src(p, rank, num_ranks)
361
  assert n > 0
362
 
363
  flat_local = recv_buf.narrow(0, off + inner_off,
@@ -398,11 +434,23 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
398
  state.scattered_u = None
399
  u_dtensor = None
400
 
401
- scales_full = Muon._compute_scales(p, state.qk_clip_state)
 
 
402
  if scales_full is not None:
403
- num_ranks = dist.get_world_size(group=state.process_group)
404
- local_rank = dist.get_rank(group=state.process_group)
405
- scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank]
 
 
 
 
 
 
 
 
 
 
406
  scales_local = DTensor.from_local(
407
  scales_local,
408
  placements=p.placements,
@@ -478,11 +526,11 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
478
  @dataclass
479
  class QKClipInfo:
480
  """Per-parameter dynamic info computed from config + runtime logits."""
481
- kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None
482
- indices: List[int] # which heads to consider for clipping
483
  head_dim: int # from config
484
  threshold: float # from config
485
- logit: Optional[torch.Tensor]
486
 
487
 
488
  class Muon(torch.optim.Optimizer):
@@ -525,11 +573,16 @@ class Muon(torch.optim.Optimizer):
525
  "head_dim": 128,
526
  "threshold": 100
527
  }
528
- overlap_step : How many all2all gather, compute operations are launched in advance
529
- before the corresponding all2all scatter steps begin.
530
- A higher overlap_step increases memory usage but can improve
531
- performance by overlapping communication.
532
- Parallel muon only.
 
 
 
 
 
533
  """
534
 
535
  def __init__(self,
@@ -549,7 +602,9 @@ class Muon(torch.optim.Optimizer):
549
  "head_dim": 128,
550
  "threshold": 100
551
  },
552
- overlap_step=5):
 
 
553
  defaults = dict(
554
  lr=lr,
555
  weight_decay=weight_decay,
@@ -579,7 +634,9 @@ class Muon(torch.optim.Optimizer):
579
  self.compute_stream = torch.cuda.Stream()
580
  self.debug = debug
581
  self.clip_config = clip_config
582
- self.overlap_step = overlap_step
 
 
583
 
584
  def _calc_flops(self, G, steps):
585
  assert len(G.shape) == 2
@@ -597,6 +654,12 @@ class Muon(torch.optim.Optimizer):
597
  adjusted_lr = lr * adjusted_ratio
598
  return adjusted_lr
599
 
 
 
 
 
 
 
600
  def get_shard_mesh(self, p):
601
  """
602
  Get the shard mesh for a parameter p on the given rank.
@@ -604,26 +667,13 @@ class Muon(torch.optim.Optimizer):
604
  assert isinstance(
605
  p, DTensor), "Parallel Muon only supports DTensor parameters."
606
 
607
- if p.placements == (Shard(dim=0), ):
608
- # Case for FSDP
609
- process_group = p.device_mesh.get_group(mesh_dim=0)
610
- if self.rank is None:
611
- self.rank = dist.get_rank(group=process_group)
612
- else:
613
- assert self.rank == dist.get_rank(group=process_group)
614
- return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
615
- elif p.placements == (Replicate(), Shard(dim=0)):
616
- # Case for HSDP
617
- process_group = p.device_mesh.get_group(mesh_dim=1)
618
- if self.rank is None:
619
- self.rank = dist.get_rank(group=process_group)
620
- else:
621
- assert self.rank == dist.get_rank(group=process_group)
622
- for i, shard_mesh in enumerate(p.device_mesh.mesh):
623
- if self.rank in shard_mesh:
624
- return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
625
- else:
626
- raise ValueError(f"Unsupported placements ({p.placements}).")
627
 
628
  def init_state_and_assign_params(self, names, params, group, qk_logits):
629
  param_to_state = {}
@@ -655,23 +705,32 @@ class Muon(torch.optim.Optimizer):
655
  ordered_params = list(params_sorted)
656
 
657
  round_robin = 0
658
- mesh = None
659
- shard_mesh = None
660
- process_group = None
 
 
 
 
 
661
  for n, p in zip(ordered_names, ordered_params):
662
- if mesh is None:
663
- mesh = p.device_mesh
664
- shard_mesh, process_group = self.get_shard_mesh(p)
665
- elif mesh != p.device_mesh:
666
  raise ValueError("All parameters must be on the same mesh.")
667
- num_ranks = dist.get_world_size(group=process_group)
668
- param_to_state[id(p)] = _muon_state()
669
- param_to_state[id(
670
- p)].worker_rank = shard_mesh[round_robin].item() % num_ranks
671
- param_to_state[id(p)].process_group = process_group
672
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
673
- param_to_state[id(p)].qk_clip_state = qk_clip_state
674
- round_robin = (round_robin + 1) % len(shard_mesh)
 
 
 
 
 
 
 
675
 
676
  return param_to_state, ordered_params
677
 
@@ -705,10 +764,73 @@ class Muon(torch.optim.Optimizer):
705
 
706
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
707
 
708
- scales_full = self._compute_scales(p, qk_clip_state)
 
709
  if scales_full is not None:
710
  Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
711
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
712
  def _update_g(self, p, g, group, momentum):
713
  # calc update
714
  state = self.state[p]
@@ -727,6 +849,9 @@ class Muon(torch.optim.Optimizer):
727
  p.data.add_(u, alpha=-adjusted_lr)
728
 
729
  def get_qk_clip_info(self, n, qk_logits):
 
 
 
730
  head_dim = self.clip_config.get('head_dim')
731
  threshold = self.clip_config.get('threshold')
732
  kind, layer_idx = parse_qk_layer(n)
@@ -737,6 +862,11 @@ class Muon(torch.optim.Optimizer):
737
  indices_key = 'q_indices' if 'q' in kind else 'k_indices'
738
  indices = self.clip_config.get(indices_key, []) or []
739
 
 
 
 
 
 
740
  return QKClipInfo(
741
  kind=kind,
742
  indices=indices,
@@ -835,22 +965,28 @@ class Muon(torch.optim.Optimizer):
835
  _update_param(p, state, lr, adjusted_lr, weight_decay,
836
  self.rank, self.compute_stream)
837
 
838
- chunk_size = dist.get_world_size(param_to_state[id(
839
- params[0])].process_group)
 
 
 
 
 
 
840
 
841
  # Wait grad update
842
  self.comm_stream.wait_stream(torch.cuda.current_stream())
843
 
844
- overlap_step = self.overlap_step
845
- for i in range(0, overlap_step):
846
  enqueue_all2all_gather(i * chunk_size, chunk_size)
847
  enqueue_computes(i * chunk_size, chunk_size)
848
 
849
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
850
  enqueue_all2all_scatter(i, chunk_size)
851
- enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size)
852
  enqueue_update_param(i, chunk_size)
853
- enqueue_computes(i + overlap_step * chunk_size, chunk_size)
854
 
855
  # Wait the last update_param to finish
856
  torch.cuda.current_stream().wait_stream(self.compute_stream)
@@ -866,7 +1002,7 @@ class Muon(torch.optim.Optimizer):
866
  amsgrad: bool,
867
  beta1: float,
868
  beta2: float,
869
- lr: Union[float, torch.Tensor],
870
  weight_decay: float,
871
  eps: float,
872
  maximize: bool,
@@ -876,10 +1012,10 @@ class Muon(torch.optim.Optimizer):
876
 
877
  # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
878
  # treating it as a scalar.
879
- lr_dict: Optional[DeviceDict] = ({
880
  lr.device: lr
881
  } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
882
- None)
883
  grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
884
  [
885
  params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
@@ -926,6 +1062,159 @@ class Muon(torch.optim.Optimizer):
926
  maximize=maximize,
927
  )
928
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
929
  def step(self, closure=None, qk_logits=None):
930
  """Perform a single optimization step.
931
 
@@ -943,127 +1232,9 @@ class Muon(torch.optim.Optimizer):
943
  loss = closure()
944
 
945
  for group in self.param_groups:
946
- params = group["params"]
947
-
948
  if group["use_muon"]:
949
- ############################
950
- # Muon #
951
- ############################
952
- lr = group["lr"]
953
- weight_decay = group["weight_decay"]
954
- momentum = group["momentum"]
955
- names = group["names"]
956
-
957
- param_dtensors = []
958
- param_tensors = []
959
- name_dtensors = []
960
- name_tensors = []
961
-
962
- for n, p in zip(names, params):
963
- if p is None or p.grad is None:
964
- continue
965
- if isinstance(p.data, DTensor):
966
- if all(
967
- isinstance(placement, Replicate)
968
- for placement in p.placements):
969
- param_tensors.append(p)
970
- name_tensors.append(n)
971
- else:
972
- param_dtensors.append(p)
973
- name_dtensors.append(n)
974
- elif isinstance(p.data, torch.Tensor):
975
- param_tensors.append(p)
976
- name_tensors.append(n)
977
- else:
978
- raise TypeError(
979
- f"Unsupported parameter type: {type(p.data)}")
980
-
981
- if self.debug:
982
- print(
983
- f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
984
- flush=True,
985
- )
986
-
987
- if len(param_dtensors) > 0:
988
- if not dist.is_initialized():
989
- raise RuntimeError(
990
- "Parallel Muon requires torch.distributed to be initialized."
991
- )
992
-
993
- self.parallel(
994
- name_dtensors,
995
- param_dtensors,
996
- group,
997
- lr=lr,
998
- weight_decay=weight_decay,
999
- momentum=momentum,
1000
- qk_logits=qk_logits,
1001
- )
1002
-
1003
- if len(param_tensors) > 0:
1004
- self.base(
1005
- name_tensors,
1006
- param_tensors,
1007
- group,
1008
- lr=lr,
1009
- weight_decay=weight_decay,
1010
- momentum=momentum,
1011
- qk_logits=qk_logits,
1012
- )
1013
-
1014
  else:
1015
- ############################
1016
- # AdamW backup #
1017
- ############################
1018
-
1019
- params_with_grads = []
1020
- grads = []
1021
- moment1 = []
1022
- moment2 = []
1023
- max_exp_avg_sqs = []
1024
- state_steps = []
1025
- lr = group["lr"]
1026
- beta1, beta2 = group["adamw_betas"]
1027
- eps = group["adamw_eps"]
1028
- weight_decay = group["weight_decay"]
1029
-
1030
- for p in params:
1031
- g = p.grad
1032
- if g is None:
1033
- continue
1034
- state = self.state[p]
1035
- params_with_grads.append(p)
1036
- grads.append(g)
1037
- if "step" not in state:
1038
- state["step"] = (torch.zeros((),
1039
- dtype=torch.float32,
1040
- device=p.device))
1041
- state["moment1"] = torch.zeros_like(g)
1042
- state["moment2"] = torch.zeros_like(g)
1043
- moment1.append(state["moment1"])
1044
- moment2.append(state["moment2"])
1045
- if not isinstance(state["step"], torch.Tensor):
1046
- step_tensor = torch.tensor(state["step"],
1047
- dtype=torch.float32,
1048
- device=p.device)
1049
- else:
1050
- step_tensor = state["step"]
1051
- state_steps.append(step_tensor)
1052
-
1053
- self._fused_adamw(
1054
- params_with_grads,
1055
- grads,
1056
- moment1,
1057
- moment2,
1058
- max_exp_avg_sqs,
1059
- state_steps,
1060
- amsgrad=False,
1061
- beta1=beta1,
1062
- beta2=beta2,
1063
- lr=lr,
1064
- weight_decay=weight_decay,
1065
- eps=eps,
1066
- maximize=False,
1067
- )
1068
 
1069
  return loss
 
1
  import logging
2
  import math
3
  import types
4
+ from collections import defaultdict
5
  from dataclasses import dataclass
6
+ from typing import Any, cast
7
 
8
  import torch
9
  import torch.distributed as dist
10
+ from torch.distributed import ProcessGroup
11
+ from torch.distributed.device_mesh import DeviceMesh
12
+ from torch.distributed.tensor import DTensor, Replicate
13
+ from torch.distributed.tensor.placement_types import Placement
14
 
15
+ from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor
16
  from .matmul_transpose_triton import matmul_transpose_assign
17
 
18
  logger = logging.getLogger(__name__)
19
 
20
  COMM_DTYPE = torch.bfloat16
21
+ DEFAULT_CHUNK_SIZE_RATIO = 4
22
 
23
 
24
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
68
  @dataclass
69
  class _muon_state:
70
  # TODO: use Optional
71
+ worker_rank: int
72
+ process_group: ProcessGroup
73
+ shard_mesh: DeviceMesh
74
+ shard_placements: tuple[Placement, ...]
75
+ name: str
76
+ qk_clip_state: torch.Tensor | None = None
77
  gathered_grad: torch.Tensor | None = None
78
  scattered_u: DTensor | None = None
79
  computed_u: torch.Tensor | None = None
80
  gather_event: torch.cuda.Event | None = None
81
  compute_event: torch.cuda.Event | None = None
82
  scatter_event: torch.cuda.Event | None = None
 
 
83
 
84
 
85
+ def numel_for_rank(
86
+ param: DTensor,
87
+ local_rank: int,
88
+ state: _muon_state,
89
+ ) -> int:
90
+ slices = get_slices_of_dtensor(
91
+ param,
92
+ local_rank,
93
+ state.shard_mesh,
94
+ state.shard_placements,
95
+ )
96
+
97
+ numel = 1
98
+ for s, dim in zip(slices, param.shape):
99
+ start, stop, step = s.indices(dim)
100
+ length = max(0, (stop - start + (step - 1)) // step)
101
+ numel *= length
102
+
103
+ return numel
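# Worked example with hypothetical shapes (not taken from this commit): for a
# (1024, 512) DTensor placed as Shard(0) on a 4-rank shard mesh,
# get_slices_of_dtensor returns (slice(256 * r, 256 * (r + 1)), slice(0, 512))
# for local rank r, so numel_for_rank reports 256 * 512 = 131072 elements per rank.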
104
 
105
 
106
  @torch.no_grad()
 
113
  for p in params:
114
  state = param_to_state[id(p)]
115
  if rank == state.worker_rank:
116
+ state.gathered_grad = torch.empty(p.shape,
 
117
  dtype=COMM_DTYPE,
118
  device="cuda")
119
  else:
 
142
  state = param_to_state[id(p)]
143
  dst = state.worker_rank
144
  assert dst < num_ranks
145
+ shard_elems = numel_for_rank(p, rank, state)
146
  g = p.grad
147
+ g = g.to_local().to(COMM_DTYPE).contiguous()
148
  assert g.numel() == shard_elems
149
+ per_dst[dst].append(g.view(-1))
150
  send_counts[dst] += shard_elems
151
 
152
  assert any(
 
169
  for p in owned_params:
170
  state = param_to_state[id(p)]
171
  assert state.worker_rank == rank
172
+ total += numel_for_rank(p, src, state)
173
  recv_counts[src] = total
174
 
175
  recv_total = sum(recv_counts)
176
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
177
 
178
  #All2All
179
+ logger.debug(f"send_buf size: {send_buf.numel()}, "
180
+ f"recv_buf size: {recv_buf.numel()}, "
181
+ f"recv_counts: {recv_counts}, "
182
+ f"send_counts: {send_counts}, "
183
+ f"process_group: {str(process_group)}")
184
  dist.all_to_all_single(
185
  recv_buf,
186
  send_buf,
 
205
  comm_stream.wait_event(alloc_event)
206
 
207
  off = 0
 
208
  for src in range(num_ranks):
209
  if recv_counts[src] == 0:
210
  continue
 
214
  for p in owned_params:
215
  state = param_to_state[id(p)]
216
  assert state.worker_rank == rank
217
+
218
+ # get the slice of the full dtensor corresponding to rank src.
219
+ slices = get_slices_of_dtensor(state.gathered_grad, src,
220
+ state.shard_mesh,
221
+ state.shard_placements)
222
+
223
+ dst = state.gathered_grad[slices]
224
+ assert dst._base is state.gathered_grad
225
+
226
+ n = dst.numel()
227
  assert n > 0
228
 
229
  sg = recv_buf.narrow(0, off + inner_off, n)
230
+ sg = sg.reshape_as(dst)
 
231
  dst.copy_(sg)
232
 
 
233
  inner_off += n
234
  off += block
235
 
236
  for p in params:
237
  state = param_to_state[id(p)]
238
  if state.worker_rank == rank:
 
239
  state.gather_event = torch.cuda.Event()
240
  state.gather_event.record(comm_stream)
241
  else:
 
308
 
309
  assert state.computed_u is not None
310
 
311
+ u_full = state.computed_u.to(COMM_DTYPE).contiguous()
312
 
313
  offset = 0
314
  for dst in range(num_ranks):
315
+ # get the slice of the full tensor corresponding to rank dst.
316
+ slices = get_slices_of_dtensor(u_full, dst,
317
+ state.shard_mesh,
318
+ state.shard_placements)
319
+ su = u_full[slices].flatten()
320
+
321
+ n = su.numel()
322
  assert n > 0
323
 
 
324
  per_dst[dst].append(su)
325
  send_counts[dst] += n
326
  offset += n
 
349
  state = param_to_state[id(p)]
350
  if state.worker_rank != src:
351
  continue
352
+ total += numel_for_rank(p, rank, state)
353
  recv_counts[src] = total
354
 
355
  recv_total = sum(recv_counts)
 
393
  state = param_to_state[id(p)]
394
  if state.worker_rank != src:
395
  continue
396
+ n = numel_for_rank(p, rank, state)
397
  assert n > 0
398
 
399
  flat_local = recv_buf.narrow(0, off + inner_off,
 
434
  state.scattered_u = None
435
  u_dtensor = None
436
 
437
+ scales_full = Muon._compute_scales(
438
+ p,
439
+ state.qk_clip_state) if state.qk_clip_state is not None else None
440
  if scales_full is not None:
441
+ # Have to slice scales_full along dim 0
442
+ weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh,
443
+ state.shard_placements)
444
+ ratio = p.shape[0] // scales_full.shape[0]
445
+ scales_slice = slice(
446
+ None if weight_slices[0].start is None else
447
+ weight_slices[0].start // ratio,
448
+ None if weight_slices[0].stop is None else
449
+ weight_slices[0].stop // ratio,
450
+ None,
451
+ )
452
+
453
+ scales_local = scales_full[scales_slice]
454
  scales_local = DTensor.from_local(
455
  scales_local,
456
  placements=p.placements,
 
526
  @dataclass
527
  class QKClipInfo:
528
  """Per-parameter dynamic info computed from config + runtime logits."""
529
+ kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
530
+ indices: list[int] # which heads to consider for clipping
531
  head_dim: int # from config
532
  threshold: float # from config
533
+ logit: torch.Tensor | None
534
 
535
 
536
  class Muon(torch.optim.Optimizer):
 
573
  "head_dim": 128,
574
  "threshold": 100
575
  }
576
+ warmup_step : How many all2all gather and compute operations are launched in advance
577
+ before the corresponding all2all scatter steps begin.
578
+ A higher warmup_step increases memory usage but can improve
579
+ performance by overlapping communication with compute.
580
+ Parallel Muon only.
581
+ chunk_size : Number of parameters to process in each
582
+ all2all gather/compute/scatter step.
583
+ When -1 is specified, shard ranks * DEFAULT_CHUNK_SIZE_RATIO is used.
584
+ use_distributed_muon: Use Distributed Muon by Liu et al. (2024).
585
+ For testing purposes only.
586
  """
587
 
588
  def __init__(self,
 
602
  "head_dim": 128,
603
  "threshold": 100
604
  },
605
+ warmup_step=5,
606
+ chunk_size=-1,
607
+ use_distributed_muon=False):
608
  defaults = dict(
609
  lr=lr,
610
  weight_decay=weight_decay,
 
634
  self.compute_stream = torch.cuda.Stream()
635
  self.debug = debug
636
  self.clip_config = clip_config
637
+ self.warmup_step = warmup_step
638
+ self.chunk_size = chunk_size
639
+ self.use_distributed_muon = use_distributed_muon
640
 
641
  def _calc_flops(self, G, steps):
642
  assert len(G.shape) == 2
 
654
  adjusted_lr = lr * adjusted_ratio
655
  return adjusted_lr
656
 
657
+ def set_rank_once(self, rank):
658
+ if self.rank is None:
659
+ self.rank = rank
660
+ else:
661
+ assert self.rank == rank
662
+
663
  def get_shard_mesh(self, p):
664
  """
665
  Get the shard mesh for a parameter p on the given rank.
 
667
  assert isinstance(
668
  p, DTensor), "Parallel Muon only supports DTensor parameters."
669
 
670
+ shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
671
+ p.placements, p.device_mesh)
672
+
673
+ # set rank with the local rank in the shard process group
674
+ self.set_rank_once(dist.get_rank(group=shard_pg))
675
+
676
+ return shard_mesh, shard_pg, shard_placements
 
 
 
 
 
 
 
 
 
 
 
 
 
677
 
678
  def init_state_and_assign_params(self, names, params, group, qk_logits):
679
  param_to_state = {}
 
705
  ordered_params = list(params_sorted)
706
 
707
  round_robin = 0
708
+ mesh = ordered_params[0].device_mesh
709
+ placements = ordered_params[0].placements
710
+
711
+ shard_mesh, shard_pg, shard_placements = self.get_shard_mesh(
712
+ ordered_params[0])
713
+ shard_mesh_flattened = shard_mesh.mesh.flatten()
714
+ num_ranks = dist.get_world_size(group=shard_pg)
715
+
716
  for n, p in zip(ordered_names, ordered_params):
717
+ if mesh != p.device_mesh:
 
 
 
718
  raise ValueError("All parameters must be on the same mesh.")
719
+ if placements != p.placements:
720
+ raise ValueError("All parameters must have same placements.")
721
+
722
+ worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
723
+ round_robin = (round_robin + 1) % len(shard_mesh_flattened)
724
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
725
+
726
+ param_to_state[id(p)] = _muon_state(
727
+ worker_rank=worker_rank,
728
+ process_group=shard_pg,
729
+ shard_mesh=shard_mesh,
730
+ shard_placements=shard_placements,
731
+ name=n,
732
+ qk_clip_state=qk_clip_state,
733
+ )
734
 
735
  return param_to_state, ordered_params
736
 
 
764
 
765
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
766
 
767
+ scales_full = self._compute_scales(
768
+ p, qk_clip_state) if qk_clip_state is not None else None
769
  if scales_full is not None:
770
  Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
771
 
772
+ def distributed_muon(
773
+ self,
774
+ names: list[str],
775
+ params: list[torch.nn.Parameter],
776
+ group: dict[str, Any],
777
+ lr: float,
778
+ weight_decay: float,
779
+ momentum: float,
780
+ qk_logits: list[torch.Tensor | DTensor] | None,
781
+ ):
782
+ """ Implementation of Distributed Muon by Liu et al. """
783
+ if qk_logits is not None:
784
+ raise NotImplementedError("QK clipping is not supported yet")
785
+
786
+ if isinstance(params[0], DTensor):
787
+ shard_mesh, _, shard_placements = construct_shard_mesh(
788
+ placements=params[0].placements,
789
+ mesh=params[0].device_mesh,
790
+ )
791
+
792
+ for n, p in zip(names, params):
793
+ g = p.grad
794
+ if g is None:
795
+ continue
796
+ if g.ndim > 2:
797
+ g = g.view(g.size(0), -1)
798
+ assert g is not None
799
+
800
+ # calc update
801
+ state = self.state[p]
802
+ if "momentum_buffer" not in state:
803
+ state["momentum_buffer"] = torch.zeros_like(g)
804
+ buf = state["momentum_buffer"]
805
+ buf.mul_(momentum).add_(g)
806
+ if group["nesterov"]:
807
+ g = g.add(buf, alpha=momentum)
808
+ else:
809
+ g = buf
810
+
811
+ # Gather G
812
+ if isinstance(p.data, DTensor):
813
+ g = g.full_tensor()
814
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
815
+ steps=group["ns_steps"])
816
+
817
+ if isinstance(p.data, DTensor):
818
+ slices = get_slices_of_dtensor(
819
+ target=p,
820
+ local_rank=dist.get_rank(),
821
+ shard_mesh=shard_mesh,
822
+ shard_placements=shard_placements,
823
+ )
824
+ u_shard = u[slices]
825
+ u = DTensor.from_local(
826
+ u_shard,
827
+ device_mesh=p.device_mesh,
828
+ placements=p.placements,
829
+ )
830
+
831
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
832
+ Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
833
+
834
  def _update_g(self, p, g, group, momentum):
835
  # calc update
836
  state = self.state[p]
 
849
  p.data.add_(u, alpha=-adjusted_lr)
850
 
851
  def get_qk_clip_info(self, n, qk_logits):
852
+ if self.clip_config is None:
853
+ return None
854
+
855
  head_dim = self.clip_config.get('head_dim')
856
  threshold = self.clip_config.get('threshold')
857
  kind, layer_idx = parse_qk_layer(n)
 
862
  indices_key = 'q_indices' if 'q' in kind else 'k_indices'
863
  indices = self.clip_config.get(indices_key, []) or []
864
 
865
+ if isinstance(logit, DTensor):
866
+ # In TP settings, qk_logits may be DTensor
867
+ # We convert it to full tensor here for simplicity
868
+ logit = logit.full_tensor()
869
+
870
  return QKClipInfo(
871
  kind=kind,
872
  indices=indices,
 
965
  _update_param(p, state, lr, adjusted_lr, weight_decay,
966
  self.rank, self.compute_stream)
967
 
968
+ if self.chunk_size == -1:
969
+ shard_ranks = dist.get_world_size(param_to_state[id(
970
+ params[0])].process_group)
971
+ chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
972
+ elif self.chunk_size > 0:
973
+ chunk_size = self.chunk_size
974
+ else:
975
+ raise ValueError("chunk_size must be -1 or a positive integer.")
976
 
977
  # Wait grad update
978
  self.comm_stream.wait_stream(torch.cuda.current_stream())
979
 
980
+ warmup_step = self.warmup_step
981
+ for i in range(0, warmup_step):
982
  enqueue_all2all_gather(i * chunk_size, chunk_size)
983
  enqueue_computes(i * chunk_size, chunk_size)
984
 
985
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
986
  enqueue_all2all_scatter(i, chunk_size)
987
+ enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size)
988
  enqueue_update_param(i, chunk_size)
989
+ enqueue_computes(i + warmup_step * chunk_size, chunk_size)
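# Pipeline sketch with hypothetical values: with warmup_step=2 and chunk_size=4,
# the warmup loop enqueues gather/compute for parameter chunks [0:4) and [4:8);
# each main-loop iteration then scatters and applies chunk [i:i+4) while
# prefetching gather/compute for chunk [i+8:i+12), so all2all traffic on
# comm_stream overlaps with Newton-Schulz work on compute_stream.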
990
 
991
  # Wait the last update_param to finish
992
  torch.cuda.current_stream().wait_stream(self.compute_stream)
 
1002
  amsgrad: bool,
1003
  beta1: float,
1004
  beta2: float,
1005
+ lr: float | torch.Tensor,
1006
  weight_decay: float,
1007
  eps: float,
1008
  maximize: bool,
 
1012
 
1013
  # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
1014
  # treating it as a scalar.
1015
+ lr_dict: DeviceDict | None = ({
1016
  lr.device: lr
1017
  } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
1018
+ None)
1019
  grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
1020
  [
1021
  params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
 
1062
  maximize=maximize,
1063
  )
1064
 
1065
+ def _step_muon(self, group, qk_logits=None):
1066
+ params = group["params"]
1067
+ lr = group["lr"]
1068
+ weight_decay = group["weight_decay"]
1069
+ momentum = group["momentum"]
1070
+ names = group["names"]
1071
+
1072
+ param_dtensors = []
1073
+ param_tensors = []
1074
+ name_dtensors = []
1075
+ name_tensors = []
1076
+
1077
+ if self.use_distributed_muon:
1078
+ self.distributed_muon(names=names,
1079
+ params=params,
1080
+ group=group,
1081
+ lr=lr,
1082
+ weight_decay=weight_decay,
1083
+ momentum=momentum,
1084
+ qk_logits=qk_logits)
1085
+ return
1086
+
1087
+ for n, p in zip(names, params):
1088
+ if p is None or p.grad is None:
1089
+ continue
1090
+ if isinstance(p.data, DTensor):
1091
+ if all(
1092
+ isinstance(placement, Replicate)
1093
+ for placement in p.placements):
1094
+ param_tensors.append(p)
1095
+ name_tensors.append(n)
1096
+ else:
1097
+ param_dtensors.append(p)
1098
+ name_dtensors.append(n)
1099
+ elif isinstance(p.data, torch.Tensor):
1100
+ param_tensors.append(p)
1101
+ name_tensors.append(n)
1102
+ else:
1103
+ raise TypeError(f"Unsupported parameter type: {type(p.data)}")
1104
+
1105
+ logger.debug(
1106
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors"
1107
+ )
1108
+
1109
+ if len(param_dtensors) > 0:
1110
+ if not dist.is_initialized():
1111
+ raise RuntimeError(
1112
+ "Parallel Muon requires torch.distributed to be initialized."
1113
+ )
1114
+
1115
+ # To support different placements, we group parameters by placements
1116
+ # and run parallel Muon on each group.
1117
+
1118
+ placement_to_params = defaultdict(lambda: ([], []))
1119
+ # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]
1120
+
1121
+ assert len(name_dtensors) == len(param_dtensors)
1122
+ for n, p in zip(name_dtensors, param_dtensors):
1123
+ placement_to_params[tuple([p.placements,
1124
+ p.device_mesh])][0].append(n)
1125
+ placement_to_params[tuple([p.placements,
1126
+ p.device_mesh])][1].append(p)
1127
+
1128
+ for _, (names, params) in placement_to_params.items():
1129
+ self.parallel(
1130
+ names,
1131
+ params,
1132
+ group,
1133
+ lr=lr,
1134
+ weight_decay=weight_decay,
1135
+ momentum=momentum,
1136
+ qk_logits=qk_logits,
1137
+ )
1138
+
1139
+ if len(param_tensors) > 0:
1140
+ self.base(
1141
+ name_tensors,
1142
+ param_tensors,
1143
+ group,
1144
+ lr=lr,
1145
+ weight_decay=weight_decay,
1146
+ momentum=momentum,
1147
+ qk_logits=qk_logits,
1148
+ )
1149
+
1150
+ def _step_adamw_params(self, params, group):
1151
+ params_with_grads = []
1152
+ grads = []
1153
+ moment1 = []
1154
+ moment2 = []
1155
+ max_exp_avg_sqs = []
1156
+ state_steps = []
1157
+ lr = group["lr"]
1158
+ beta1, beta2 = group["adamw_betas"]
1159
+ eps = group["adamw_eps"]
1160
+ weight_decay = group["weight_decay"]
1161
+
1162
+ for p in params:
1163
+ g = p.grad
1164
+ if g is None:
1165
+ continue
1166
+ state = self.state[p]
1167
+ params_with_grads.append(p)
1168
+ grads.append(g)
1169
+ if "step" not in state:
1170
+ state["step"] = (torch.zeros((),
1171
+ dtype=torch.float32,
1172
+ device=p.device))
1173
+ state["moment1"] = torch.zeros_like(g)
1174
+ state["moment2"] = torch.zeros_like(g)
1175
+ moment1.append(state["moment1"])
1176
+ moment2.append(state["moment2"])
1177
+ if not isinstance(state["step"], torch.Tensor):
1178
+ step_tensor = torch.tensor(state["step"],
1179
+ dtype=torch.float32,
1180
+ device=p.device)
1181
+ else:
1182
+ step_tensor = state["step"]
1183
+ state_steps.append(step_tensor)
1184
+
1185
+ self._fused_adamw(
1186
+ params_with_grads,
1187
+ grads,
1188
+ moment1,
1189
+ moment2,
1190
+ max_exp_avg_sqs,
1191
+ state_steps,
1192
+ amsgrad=False,
1193
+ beta1=beta1,
1194
+ beta2=beta2,
1195
+ lr=lr,
1196
+ weight_decay=weight_decay,
1197
+ eps=eps,
1198
+ maximize=False,
1199
+ )
1200
+
1201
+ def _step_adamw(self, group):
1202
+ params = group["params"]
1203
+
1204
+ # group params by their type and placement
1205
+ placement_to_params: dict[tuple[Placement | type,
1206
+ DeviceMesh | None], list] = defaultdict(list)
1207
+ for p in params:
1208
+ match p:
1209
+ case DTensor():
1210
+ placement_to_params[tuple([p.placements,
1211
+ p.device_mesh])].append(p)
1212
+ case torch.Tensor():
1213
+ placement_to_params[tuple([torch.Tensor, None])].append(p)
1214
+
1215
+ for params in placement_to_params.values():
1216
+ self._step_adamw_params(params, group)
1217
+
1218
  def step(self, closure=None, qk_logits=None):
1219
  """Perform a single optimization step.
1220
 
 
1232
  loss = closure()
1233
 
1234
  for group in self.param_groups:
 
 
1235
  if group["use_muon"]:
1236
+ self._step_muon(group, qk_logits=qk_logits)
1237
  else:
1238
+ self._step_adamw(group)
1239
 
1240
  return loss
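The chunk_size default documented above can be summarized with a small standalone sketch; the helper name resolve_chunk_size and the rank counts below are illustrative assumptions, while the -1 / positive-integer rule mirrors the selection logic added to muon.py:

    # Illustrative sketch of the chunk_size selection rule (helper name is hypothetical).
    DEFAULT_CHUNK_SIZE_RATIO = 4

    def resolve_chunk_size(chunk_size: int, shard_ranks: int) -> int:
        if chunk_size == -1:
            # default: DEFAULT_CHUNK_SIZE_RATIO parameters per shard rank per step
            return shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
        if chunk_size > 0:
            return chunk_size
        raise ValueError("chunk_size must be -1 or a positive integer.")

    assert resolve_chunk_size(-1, shard_ranks=8) == 32   # e.g. FSDP over 8 GPUs
    assert resolve_chunk_size(16, shard_ranks=8) == 16   # explicit override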
build/torch29-cxx11-cu130-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_811726c_dirty
3
- ops = torch.ops._optimizer_811726c_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_811726c_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_23d68bb_dirty
3
+ ops = torch.ops._optimizer_23d68bb_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_23d68bb_dirty::{op_name}"
build/torch29-cxx11-cu130-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa394498c52692c29094cbd2cc3da6c4c37aefaa4454c97487f8e91827fbd814
3
+ size 1988672
build/torch29-cxx11-cu130-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:52a744cf30c60fe1e8fc35ebb0d3421d679bb2047fbb4602846bd6902cfa9e52
3
- size 1872152
 
 
 
 
build/torch29-cxx11-cu130-x86_64-linux/optimizer/distributed/utils.py ADDED
@@ -0,0 +1,174 @@
1
+ import torch
2
+ import torch.distributed as dist
3
+ from torch.distributed import ProcessGroup
4
+ from torch.distributed.device_mesh import DeviceMesh
5
+ from torch.distributed.tensor import DTensor
6
+ from torch.distributed.tensor.placement_types import (Placement, Shard,
7
+ _StridedShard)
8
+
9
+
10
+ def get_slices_of_dtensor(
11
+ target: DTensor | torch.Tensor,
12
+ local_rank: int,
13
+ shard_mesh: DeviceMesh,
14
+ shard_placements: tuple[Placement],
15
+ ) -> tuple[slice]:
16
+ """
17
+ Get the slices of the local shard owned by a given rank of the target tensor.
18
+ Args:
19
+ target (DTensor | torch.Tensor): The target tensor.
20
+ local_rank (int): The local rank within the shard group.
21
+ shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks.
22
+ shard_placements (tuple[Placement]): The shard placements.
23
+ """
24
+
25
+ slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()]
26
+
27
+ # find the global rank of the local rank in the shard mesh
28
+ rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
29
+
30
+ rank_coords = (shard_mesh.mesh == rank).nonzero()
31
+
32
+ assert len(rank_coords) == 1
33
+ rank_coords = tuple(rank_coords[0].tolist())
34
+
35
+ assert len(rank_coords) == len(shard_placements)
36
+
37
+ # Caution: Assuming replicate-to-shard of the shard mesh goes with
38
+ # left-to-right sharding. This is ensured by the sorting logic of
39
+ # the construct_shard_mesh function.
40
+ for i, (rank_coord,
41
+ placement) in enumerate(zip(rank_coords, shard_placements)):
42
+ assert isinstance(placement, Shard)
43
+
44
+ num_ranks = shard_mesh.mesh.shape[i]
45
+
46
+ dim = placement.dim
47
+ dim_size = (slices[dim].stop - slices[dim].start)
48
+
49
+ if dim_size % num_ranks != 0:
50
+ raise NotImplementedError(
51
+ f"Dimension size {dim_size} is not divisible "
52
+ f"by number of ranks {num_ranks} for shard "
53
+ f"placement on dim {dim}.")
54
+
55
+ shard_size = dim_size // num_ranks
56
+
57
+ start = slices[dim].start + rank_coord * shard_size
58
+ end = start + shard_size
59
+
60
+ assert start < end <= slices[dim].stop
61
+
62
+ slices[dim] = slice(start, end)
63
+
64
+ return tuple(slices)
65
+
66
+
67
+ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict()
68
+
69
+
70
+ def construct_shard_mesh(
71
+ placements: tuple[Placement],
72
+ mesh: DeviceMesh,
73
+ ) -> (DeviceMesh, ProcessGroup, tuple[Placement]):
74
+ """
75
+ Construct Shard Mesh and Placements for unsharding.
76
+ It removes Replicate placements and constructs a new Mesh and ProcessGroup.
77
+ """
78
+ my_rank = dist.get_rank()
79
+
80
+ assert mesh.mesh.device.type == 'cpu'
81
+
82
+ # Copy mesh to avoid modifying the original mesh
83
+ mesh = mesh.mesh.clone()
84
+
85
+ # 1. Sort placements. Replicate first, then Shard by dim ascending.
86
+
87
+ # For Shard, strided shard comes after regular shard on the same dim
88
+ # to preserve left-to-right order of replicate-to-shard.
89
+ # This is because a strided shard uses its stride to represent
90
+ # more fine-grained sharding on the same dim.
91
+ # Please check the URL below for _StridedShard.
92
+ # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366
93
+
94
+ def placement_sort_key(
95
+ placement_with_index: tuple[float, Placement]
96
+ ) -> tuple[int, float, int]: # (dim, split factor, original index)
97
+ index, placement = placement_with_index
98
+ is_replicate = placement.is_replicate()
99
+ is_shard = placement.is_shard()
100
+ is_partial = placement.is_partial()
101
+
102
+ assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}"
103
+ assert not is_partial, "Partial placement is not supported."
104
+
105
+ if is_replicate:
106
+ return (-1.0, 0, index)
107
+ elif is_shard:
108
+ if isinstance(placement, _StridedShard):
109
+ return (placement.dim, 1 / placement.split_factor, index)
110
+ return (placement.dim, 0, index)
111
+ else:
112
+ raise TypeError(f"Unknown placement type: {type(placement)}")
113
+
114
+ placements_with_index: list[tuple[int,
115
+ Placement]] = list(enumerate(placements))
116
+ placements_with_index = sorted(placements_with_index,
117
+ key=placement_sort_key)
118
+
119
+ sorted_indices, sorted_placements = zip(*placements_with_index)
120
+
121
+ # 2. Permute mesh according to sorted placements.
122
+ sorted_mesh = mesh.permute(sorted_indices)
123
+
124
+ # 3. Collect list of shard meshes by removing replicate dims
125
+ # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
126
+ # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4)
127
+ num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
128
+
129
+ # merge replicate dims
130
+ # shard_meshes becomes a list of shard meshes whose length equals the replicate degree
131
+ if num_replicates > 0:
132
+ sorted_mesh = sorted_mesh.flatten(
133
+ 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
134
+ shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
135
+ else:
136
+ shard_meshes = [sorted_mesh]
137
+ shard_placements = sorted_placements[num_replicates:]
138
+
139
+ # assume all shard placements are different
140
+ assert len(shard_placements) == len(set(shard_placements))
141
+
142
+ # 4. Construct ProcessGroups
143
+ # Caution: all groups should be created in the same order in all processes,
144
+ # even though each process only needs its own group.
145
+
146
+ # To use tensor as dict key, convert it to tuple
147
+ def tensor_to_tuple(t):
148
+ if isinstance(t, torch.Tensor):
149
+ t = t.tolist()
150
+ if isinstance(t, list):
151
+ return tuple(tensor_to_tuple(x) for x in t)
152
+ return t
153
+
154
+ my_shard_mesh_as_tuple = None
155
+ for shard_mesh in shard_meshes:
156
+ assert isinstance(shard_mesh, torch.Tensor)
157
+ shard_mesh_as_tuple = tensor_to_tuple(shard_mesh)
158
+
159
+ if (my_rank == shard_mesh).any().item():
160
+ assert my_shard_mesh_as_tuple is None
161
+ my_shard_mesh_as_tuple = shard_mesh_as_tuple
162
+
163
+ # update global cache
164
+ if shard_mesh_as_tuple not in _ranks_to_dist_cache:
165
+ shard_process_group = dist.new_group(shard_mesh.flatten().tolist())
166
+ _ranks_to_dist_cache[shard_mesh_as_tuple] = (
167
+ DeviceMesh(device_type="cuda", mesh=shard_mesh),
168
+ shard_process_group,
169
+ )
170
+
171
+ my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[
172
+ my_shard_mesh_as_tuple]
173
+
174
+ return my_shard_mesh, my_shard_process_group, shard_placements
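A standalone sketch of the even-split slicing rule that get_slices_of_dtensor implements; the mesh layout, tensor shape, and helper name below are illustrative assumptions, and the sketch avoids needing an initialized process group:

    import torch

    # Hypothetical 2x2 shard mesh of global ranks; target shape (8, 6) sharded
    # Shard(0) along mesh dim 0 and Shard(1) along mesh dim 1.
    shard_mesh = torch.tensor([[0, 1], [2, 3]])
    shape = (8, 6)
    shard_dims = (0, 1)  # placement.dim for each shard mesh dim

    def slices_for_local_rank(local_rank: int) -> tuple[slice, ...]:
        # local ranks index the sorted global ranks, as in get_slices_of_dtensor
        rank = sorted(shard_mesh.flatten().tolist())[local_rank]
        coords = (shard_mesh == rank).nonzero()[0].tolist()
        slices = [slice(0, s) for s in shape]
        for mesh_dim, (coord, dim) in enumerate(zip(coords, shard_dims)):
            num_ranks = shard_mesh.shape[mesh_dim]
            shard = (slices[dim].stop - slices[dim].start) // num_ranks
            start = slices[dim].start + coord * shard
            slices[dim] = slice(start, start + shard)
        return tuple(slices)

    # local rank 3 (global rank 3) owns the bottom-right (4, 3) block
    assert slices_for_local_rank(3) == (slice(4, 8), slice(3, 6))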
build/torch29-cxx11-cu130-x86_64-linux/optimizer/muon.py CHANGED
@@ -1,18 +1,24 @@
1
  import logging
2
  import math
3
  import types
 
4
  from dataclasses import dataclass
5
- from typing import List, Optional, Union, cast
6
 
7
  import torch
8
  import torch.distributed as dist
9
- from torch.distributed._tensor import DTensor, Replicate, Shard
 
 
 
10
 
 
11
  from .matmul_transpose_triton import matmul_transpose_assign
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
  COMM_DTYPE = torch.bfloat16
 
16
 
17
 
18
  # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -62,23 +68,39 @@ def _zeropower_via_newtonschulz5(G, steps):
62
  @dataclass
63
  class _muon_state:
64
  # TODO: use Optional
65
- worker_rank: int | None = None
 
 
 
 
 
66
  gathered_grad: torch.Tensor | None = None
67
  scattered_u: DTensor | None = None
68
  computed_u: torch.Tensor | None = None
69
  gather_event: torch.cuda.Event | None = None
70
  compute_event: torch.cuda.Event | None = None
71
  scatter_event: torch.cuda.Event | None = None
72
- process_group = None
73
- qk_clip_state = None
74
 
75
 
76
- def split_elems_for_src(param, src_rank, num_ranks) -> int:
77
- rows = param.shape[0]
78
- cols = int(param.numel() // rows)
79
- base, rem = divmod(rows, num_ranks)
80
- my_rows = base + (1 if src_rank < rem else 0)
81
- return my_rows * cols
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
 
84
  @torch.no_grad()
@@ -91,8 +113,7 @@ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
91
  for p in params:
92
  state = param_to_state[id(p)]
93
  if rank == state.worker_rank:
94
- num_ranks = dist.get_world_size(group=state.process_group)
95
- state.gathered_grad = torch.empty(p.grad.numel(),
96
  dtype=COMM_DTYPE,
97
  device="cuda")
98
  else:
@@ -121,11 +142,11 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
121
  state = param_to_state[id(p)]
122
  dst = state.worker_rank
123
  assert dst < num_ranks
124
- shard_elems = split_elems_for_src(p, rank, num_ranks)
125
  g = p.grad
126
- g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
127
  assert g.numel() == shard_elems
128
- per_dst[dst].append(g)
129
  send_counts[dst] += shard_elems
130
 
131
  assert any(
@@ -148,13 +169,18 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
148
  for p in owned_params:
149
  state = param_to_state[id(p)]
150
  assert state.worker_rank == rank
151
- total += split_elems_for_src(p, src, num_ranks)
152
  recv_counts[src] = total
153
 
154
  recv_total = sum(recv_counts)
155
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
156
 
157
  #All2All
 
 
 
 
 
158
  dist.all_to_all_single(
159
  recv_buf,
160
  send_buf,
@@ -179,7 +205,6 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
179
  comm_stream.wait_event(alloc_event)
180
 
181
  off = 0
182
- write_offsets = {id(p): 0 for p in owned_params}
183
  for src in range(num_ranks):
184
  if recv_counts[src] == 0:
185
  continue
@@ -189,22 +214,28 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
189
  for p in owned_params:
190
  state = param_to_state[id(p)]
191
  assert state.worker_rank == rank
192
- n = split_elems_for_src(p, src, num_ranks)
 
 
 
 
 
 
 
 
 
193
  assert n > 0
194
 
195
  sg = recv_buf.narrow(0, off + inner_off, n)
196
- woff = write_offsets[id(p)]
197
- dst = state.gathered_grad.narrow(0, woff, n)
198
  dst.copy_(sg)
199
 
200
- write_offsets[id(p)] += n
201
  inner_off += n
202
  off += block
203
 
204
  for p in params:
205
  state = param_to_state[id(p)]
206
  if state.worker_rank == rank:
207
- state.gathered_grad = state.gathered_grad.view_as(p)
208
  state.gather_event = torch.cuda.Event()
209
  state.gather_event.record(comm_stream)
210
  else:
@@ -277,14 +308,19 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
277
 
278
  assert state.computed_u is not None
279
 
280
- u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)
281
 
282
  offset = 0
283
  for dst in range(num_ranks):
284
- n = split_elems_for_src(p, dst, num_ranks)
 
 
 
 
 
 
285
  assert n > 0
286
 
287
- su = u_full.narrow(0, offset, n)
288
  per_dst[dst].append(su)
289
  send_counts[dst] += n
290
  offset += n
@@ -313,7 +349,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
313
  state = param_to_state[id(p)]
314
  if state.worker_rank != src:
315
  continue
316
- total += split_elems_for_src(p, rank, num_ranks)
317
  recv_counts[src] = total
318
 
319
  recv_total = sum(recv_counts)
@@ -357,7 +393,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
357
  state = param_to_state[id(p)]
358
  if state.worker_rank != src:
359
  continue
360
- n = split_elems_for_src(p, rank, num_ranks)
361
  assert n > 0
362
 
363
  flat_local = recv_buf.narrow(0, off + inner_off,
@@ -398,11 +434,23 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
398
  state.scattered_u = None
399
  u_dtensor = None
400
 
401
- scales_full = Muon._compute_scales(p, state.qk_clip_state)
 
 
402
  if scales_full is not None:
403
- num_ranks = dist.get_world_size(group=state.process_group)
404
- local_rank = dist.get_rank(group=state.process_group)
405
- scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank]
 
 
 
 
 
 
 
 
 
 
406
  scales_local = DTensor.from_local(
407
  scales_local,
408
  placements=p.placements,
@@ -478,11 +526,11 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
478
  @dataclass
479
  class QKClipInfo:
480
  """Per-parameter dynamic info computed from config + runtime logits."""
481
- kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None
482
- indices: List[int] # which heads to consider for clipping
483
  head_dim: int # from config
484
  threshold: float # from config
485
- logit: Optional[torch.Tensor]
486
 
487
 
488
  class Muon(torch.optim.Optimizer):
@@ -525,11 +573,16 @@ class Muon(torch.optim.Optimizer):
525
  "head_dim": 128,
526
  "threshold": 100
527
  }
528
- overlap_step : How many all2all gather, compute operations are launched in advance
529
- before the corresponding all2all scatter steps begin.
530
- A higher overlap_step increases memory usage but can improve
531
- performance by overlapping communication.
532
- Parallel muon only.
 
 
 
 
 
533
  """
534
 
535
  def __init__(self,
@@ -549,7 +602,9 @@ class Muon(torch.optim.Optimizer):
549
  "head_dim": 128,
550
  "threshold": 100
551
  },
552
- overlap_step=5):
 
 
553
  defaults = dict(
554
  lr=lr,
555
  weight_decay=weight_decay,
@@ -579,7 +634,9 @@ class Muon(torch.optim.Optimizer):
579
  self.compute_stream = torch.cuda.Stream()
580
  self.debug = debug
581
  self.clip_config = clip_config
582
- self.overlap_step = overlap_step
 
 
583
 
584
  def _calc_flops(self, G, steps):
585
  assert len(G.shape) == 2
@@ -597,6 +654,12 @@ class Muon(torch.optim.Optimizer):
597
  adjusted_lr = lr * adjusted_ratio
598
  return adjusted_lr
599
 
 
 
 
 
 
 
600
  def get_shard_mesh(self, p):
601
  """
602
  Get the shard mesh for a parameter p on the given rank.
@@ -604,26 +667,13 @@ class Muon(torch.optim.Optimizer):
604
  assert isinstance(
605
  p, DTensor), "Parallel Muon only supports DTensor parameters."
606
 
607
- if p.placements == (Shard(dim=0), ):
608
- # Case for FSDP
609
- process_group = p.device_mesh.get_group(mesh_dim=0)
610
- if self.rank is None:
611
- self.rank = dist.get_rank(group=process_group)
612
- else:
613
- assert self.rank == dist.get_rank(group=process_group)
614
- return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
615
- elif p.placements == (Replicate(), Shard(dim=0)):
616
- # Case for HSDP
617
- process_group = p.device_mesh.get_group(mesh_dim=1)
618
- if self.rank is None:
619
- self.rank = dist.get_rank(group=process_group)
620
- else:
621
- assert self.rank == dist.get_rank(group=process_group)
622
- for i, shard_mesh in enumerate(p.device_mesh.mesh):
623
- if self.rank in shard_mesh:
624
- return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
625
- else:
626
- raise ValueError(f"Unsupported placements ({p.placements}).")
627
 
628
  def init_state_and_assign_params(self, names, params, group, qk_logits):
629
  param_to_state = {}
@@ -655,23 +705,32 @@ class Muon(torch.optim.Optimizer):
655
  ordered_params = list(params_sorted)
656
 
657
  round_robin = 0
658
- mesh = None
659
- shard_mesh = None
660
- process_group = None
 
 
 
 
 
661
  for n, p in zip(ordered_names, ordered_params):
662
- if mesh is None:
663
- mesh = p.device_mesh
664
- shard_mesh, process_group = self.get_shard_mesh(p)
665
- elif mesh != p.device_mesh:
666
  raise ValueError("All parameters must be on the same mesh.")
667
- num_ranks = dist.get_world_size(group=process_group)
668
- param_to_state[id(p)] = _muon_state()
669
- param_to_state[id(
670
- p)].worker_rank = shard_mesh[round_robin].item() % num_ranks
671
- param_to_state[id(p)].process_group = process_group
672
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
673
- param_to_state[id(p)].qk_clip_state = qk_clip_state
674
- round_robin = (round_robin + 1) % len(shard_mesh)
 
 
 
 
 
 
 
675
 
676
  return param_to_state, ordered_params
677
 
@@ -705,10 +764,73 @@ class Muon(torch.optim.Optimizer):
705
 
706
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
707
 
708
- scales_full = self._compute_scales(p, qk_clip_state)
 
709
  if scales_full is not None:
710
  Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
711
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
712
  def _update_g(self, p, g, group, momentum):
713
  # calc update
714
  state = self.state[p]
@@ -727,6 +849,9 @@ class Muon(torch.optim.Optimizer):
727
  p.data.add_(u, alpha=-adjusted_lr)
728
 
729
  def get_qk_clip_info(self, n, qk_logits):
 
 
 
730
  head_dim = self.clip_config.get('head_dim')
731
  threshold = self.clip_config.get('threshold')
732
  kind, layer_idx = parse_qk_layer(n)
@@ -737,6 +862,11 @@ class Muon(torch.optim.Optimizer):
737
  indices_key = 'q_indices' if 'q' in kind else 'k_indices'
738
  indices = self.clip_config.get(indices_key, []) or []
739
 
 
 
 
 
 
740
  return QKClipInfo(
741
  kind=kind,
742
  indices=indices,
@@ -835,22 +965,28 @@ class Muon(torch.optim.Optimizer):
835
  _update_param(p, state, lr, adjusted_lr, weight_decay,
836
  self.rank, self.compute_stream)
837
 
838
- chunk_size = dist.get_world_size(param_to_state[id(
839
- params[0])].process_group)
 
 
 
 
 
 
840
 
841
  # Wait grad update
842
  self.comm_stream.wait_stream(torch.cuda.current_stream())
843
 
844
- overlap_step = self.overlap_step
845
- for i in range(0, overlap_step):
846
  enqueue_all2all_gather(i * chunk_size, chunk_size)
847
  enqueue_computes(i * chunk_size, chunk_size)
848
 
849
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
850
  enqueue_all2all_scatter(i, chunk_size)
851
- enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size)
852
  enqueue_update_param(i, chunk_size)
853
- enqueue_computes(i + overlap_step * chunk_size, chunk_size)
854
 
855
  # Wait the last update_param to finish
856
  torch.cuda.current_stream().wait_stream(self.compute_stream)
@@ -866,7 +1002,7 @@ class Muon(torch.optim.Optimizer):
866
  amsgrad: bool,
867
  beta1: float,
868
  beta2: float,
869
- lr: Union[float, torch.Tensor],
870
  weight_decay: float,
871
  eps: float,
872
  maximize: bool,
@@ -876,10 +1012,10 @@ class Muon(torch.optim.Optimizer):
876
 
877
  # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
878
  # treating it as a scalar.
879
- lr_dict: Optional[DeviceDict] = ({
880
  lr.device: lr
881
  } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
882
- None)
883
  grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
884
  [
885
  params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
@@ -926,6 +1062,159 @@ class Muon(torch.optim.Optimizer):
926
  maximize=maximize,
927
  )
928
 
929
  def step(self, closure=None, qk_logits=None):
930
  """Perform a single optimization step.
931
 
@@ -943,127 +1232,9 @@ class Muon(torch.optim.Optimizer):
943
  loss = closure()
944
 
945
  for group in self.param_groups:
946
- params = group["params"]
947
-
948
  if group["use_muon"]:
949
- ############################
950
- # Muon #
951
- ############################
952
- lr = group["lr"]
953
- weight_decay = group["weight_decay"]
954
- momentum = group["momentum"]
955
- names = group["names"]
956
-
957
- param_dtensors = []
958
- param_tensors = []
959
- name_dtensors = []
960
- name_tensors = []
961
-
962
- for n, p in zip(names, params):
963
- if p is None or p.grad is None:
964
- continue
965
- if isinstance(p.data, DTensor):
966
- if all(
967
- isinstance(placement, Replicate)
968
- for placement in p.placements):
969
- param_tensors.append(p)
970
- name_tensors.append(n)
971
- else:
972
- param_dtensors.append(p)
973
- name_dtensors.append(n)
974
- elif isinstance(p.data, torch.Tensor):
975
- param_tensors.append(p)
976
- name_tensors.append(n)
977
- else:
978
- raise TypeError(
979
- f"Unsupported parameter type: {type(p.data)}")
980
-
981
- if self.debug:
982
- print(
983
- f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
984
- flush=True,
985
- )
986
-
987
- if len(param_dtensors) > 0:
988
- if not dist.is_initialized():
989
- raise RuntimeError(
990
- "Parallel Muon requires torch.distributed to be initialized."
991
- )
992
-
993
- self.parallel(
994
- name_dtensors,
995
- param_dtensors,
996
- group,
997
- lr=lr,
998
- weight_decay=weight_decay,
999
- momentum=momentum,
1000
- qk_logits=qk_logits,
1001
- )
1002
-
1003
- if len(param_tensors) > 0:
1004
- self.base(
1005
- name_tensors,
1006
- param_tensors,
1007
- group,
1008
- lr=lr,
1009
- weight_decay=weight_decay,
1010
- momentum=momentum,
1011
- qk_logits=qk_logits,
1012
- )
1013
-
1014
  else:
1015
- ############################
1016
- # AdamW backup #
1017
- ############################
1018
-
1019
- params_with_grads = []
1020
- grads = []
1021
- moment1 = []
1022
- moment2 = []
1023
- max_exp_avg_sqs = []
1024
- state_steps = []
1025
- lr = group["lr"]
1026
- beta1, beta2 = group["adamw_betas"]
1027
- eps = group["adamw_eps"]
1028
- weight_decay = group["weight_decay"]
1029
-
1030
- for p in params:
1031
- g = p.grad
1032
- if g is None:
1033
- continue
1034
- state = self.state[p]
1035
- params_with_grads.append(p)
1036
- grads.append(g)
1037
- if "step" not in state:
1038
- state["step"] = (torch.zeros((),
1039
- dtype=torch.float32,
1040
- device=p.device))
1041
- state["moment1"] = torch.zeros_like(g)
1042
- state["moment2"] = torch.zeros_like(g)
1043
- moment1.append(state["moment1"])
1044
- moment2.append(state["moment2"])
1045
- if not isinstance(state["step"], torch.Tensor):
1046
- step_tensor = torch.tensor(state["step"],
1047
- dtype=torch.float32,
1048
- device=p.device)
1049
- else:
1050
- step_tensor = state["step"]
1051
- state_steps.append(step_tensor)
1052
-
1053
- self._fused_adamw(
1054
- params_with_grads,
1055
- grads,
1056
- moment1,
1057
- moment2,
1058
- max_exp_avg_sqs,
1059
- state_steps,
1060
- amsgrad=False,
1061
- beta1=beta1,
1062
- beta2=beta2,
1063
- lr=lr,
1064
- weight_decay=weight_decay,
1065
- eps=eps,
1066
- maximize=False,
1067
- )
1068
 
1069
  return loss
 
1
  import logging
2
  import math
3
  import types
4
+ from collections import defaultdict
5
  from dataclasses import dataclass
6
+ from typing import Any, cast
7
 
8
  import torch
9
  import torch.distributed as dist
10
+ from torch.distributed import ProcessGroup
11
+ from torch.distributed.device_mesh import DeviceMesh
12
+ from torch.distributed.tensor import DTensor, Replicate
13
+ from torch.distributed.tensor.placement_types import Placement
14
 
15
+ from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor
16
  from .matmul_transpose_triton import matmul_transpose_assign
17
 
18
  logger = logging.getLogger(__name__)
19
 
20
  COMM_DTYPE = torch.bfloat16
21
+ DEFAULT_CHUNK_SIZE_RATIO = 4
22
 
23
 
24
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
68
  @dataclass
69
  class _muon_state:
70
  # TODO: use Optional
71
+ worker_rank: int
72
+ process_group: ProcessGroup
73
+ shard_mesh: DeviceMesh
74
+ shard_placements: tuple[Placement, ...]
75
+ name: str
76
+ qk_clip_state: torch.Tensor | None = None
77
  gathered_grad: torch.Tensor | None = None
78
  scattered_u: DTensor | None = None
79
  computed_u: torch.Tensor | None = None
80
  gather_event: torch.cuda.Event | None = None
81
  compute_event: torch.cuda.Event | None = None
82
  scatter_event: torch.cuda.Event | None = None
 
 
83
 
84
 
85
+ def numel_for_rank(
86
+ param: DTensor,
87
+ local_rank: int,
88
+ state: _muon_state,
89
+ ) -> int:
90
+ slices = get_slices_of_dtensor(
91
+ param,
92
+ local_rank,
93
+ state.shard_mesh,
94
+ state.shard_placements,
95
+ )
96
+
97
+ numel = 1
98
+ for s, dim in zip(slices, param.shape):
99
+ start, stop, step = s.indices(dim)
100
+ length = max(0, (stop - start + (step - 1)) // step)
101
+ numel *= length
102
+
103
+ return numel
104
 
105
 
106
  @torch.no_grad()
 
113
  for p in params:
114
  state = param_to_state[id(p)]
115
  if rank == state.worker_rank:
116
+ state.gathered_grad = torch.empty(p.shape,
 
117
  dtype=COMM_DTYPE,
118
  device="cuda")
119
  else:
 
142
  state = param_to_state[id(p)]
143
  dst = state.worker_rank
144
  assert dst < num_ranks
145
+ shard_elems = numel_for_rank(p, rank, state)
146
  g = p.grad
147
+ g = g.to_local().to(COMM_DTYPE).contiguous()
148
  assert g.numel() == shard_elems
149
+ per_dst[dst].append(g.view(-1))
150
  send_counts[dst] += shard_elems
151
 
152
  assert any(
 
169
  for p in owned_params:
170
  state = param_to_state[id(p)]
171
  assert state.worker_rank == rank
172
+ total += numel_for_rank(p, src, state)
173
  recv_counts[src] = total
174
 
175
  recv_total = sum(recv_counts)
176
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
177
 
178
  #All2All
179
+ logger.debug(f"send_buf size: {send_buf.numel()}, "
180
+ f"recv_buf size: {recv_buf.numel()}, "
181
+ f"recv_counts: {recv_counts}, "
182
+ f"send_counts: {send_counts}, "
183
+ f"process_group: {str(process_group)}")
184
  dist.all_to_all_single(
185
  recv_buf,
186
  send_buf,
 
205
  comm_stream.wait_event(alloc_event)
206
 
207
  off = 0
 
208
  for src in range(num_ranks):
209
  if recv_counts[src] == 0:
210
  continue
 
214
  for p in owned_params:
215
  state = param_to_state[id(p)]
216
  assert state.worker_rank == rank
217
+
218
+ # get the slice of the full dtensor corresponding to rank src.
219
+ slices = get_slices_of_dtensor(state.gathered_grad, src,
220
+ state.shard_mesh,
221
+ state.shard_placements)
222
+
223
+ dst = state.gathered_grad[slices]
224
+ assert dst._base is state.gathered_grad
225
+
226
+ n = dst.numel()
227
  assert n > 0
228
 
229
  sg = recv_buf.narrow(0, off + inner_off, n)
230
+ sg = sg.reshape_as(dst)
 
231
  dst.copy_(sg)
232
 
 
233
  inner_off += n
234
  off += block
235
 
236
  for p in params:
237
  state = param_to_state[id(p)]
238
  if state.worker_rank == rank:
 
239
  state.gather_event = torch.cuda.Event()
240
  state.gather_event.record(comm_stream)
241
  else:
 
308
 
309
  assert state.computed_u is not None
310
 
311
+ u_full = state.computed_u.to(COMM_DTYPE).contiguous()
312
 
313
  offset = 0
314
  for dst in range(num_ranks):
315
+ # get the slice of the full tensor corresponding to rank dst.
316
+ slices = get_slices_of_dtensor(u_full, dst,
317
+ state.shard_mesh,
318
+ state.shard_placements)
319
+ su = u_full[slices].flatten()
320
+
321
+ n = su.numel()
322
  assert n > 0
323
 
 
324
  per_dst[dst].append(su)
325
  send_counts[dst] += n
326
  offset += n
 
349
  state = param_to_state[id(p)]
350
  if state.worker_rank != src:
351
  continue
352
+ total += numel_for_rank(p, rank, state)
353
  recv_counts[src] = total
354
 
355
  recv_total = sum(recv_counts)
 
393
  state = param_to_state[id(p)]
394
  if state.worker_rank != src:
395
  continue
396
+ n = numel_for_rank(p, rank, state)
397
  assert n > 0
398
 
399
  flat_local = recv_buf.narrow(0, off + inner_off,
 
434
  state.scattered_u = None
435
  u_dtensor = None
436
 
437
+ scales_full = Muon._compute_scales(
438
+ p,
439
+ state.qk_clip_state) if state.qk_clip_state is not None else None
440
  if scales_full is not None:
441
+ # Have to slice scales_full along dim 0
442
+ weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh,
443
+ state.shard_placements)
444
+ ratio = p.shape[0] // scales_full.shape[0]
445
+ scales_slice = slice(
446
+ None if weight_slices[0].start is None else
447
+ weight_slices[0].start // ratio,
448
+ None if weight_slices[0].stop is None else
449
+ weight_slices[0].stop // ratio,
450
+ None,
451
+ )
452
+
453
+ scales_local = scales_full[scales_slice]
454
  scales_local = DTensor.from_local(
455
  scales_local,
456
  placements=p.placements,
 
526
  @dataclass
527
  class QKClipInfo:
528
  """Per-parameter dynamic info computed from config + runtime logits."""
529
+ kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
530
+ indices: list[int] # which heads to consider for clipping
531
  head_dim: int # from config
532
  threshold: float # from config
533
+ logit: torch.Tensor | None
534
 
535
 
536
  class Muon(torch.optim.Optimizer):
 
573
  "head_dim": 128,
574
  "threshold": 100
575
  }
576
+ warmup_step : How many all2all gather and compute operations are launched in advance
577
+ before the corresponding all2all scatter steps begin.
578
+ A higher warmup_step increases memory usage but can improve
579
+ performance by overlapping communication with compute.
580
+ Parallel Muon only.
581
+ chunk_size : Number of parameters to process in each
582
+ all2all gather/compute/scatter step.
583
+ When -1 is specified, shard ranks * DEFAULT_CHUNK_SIZE_RATIO is used.
584
+ use_distributed_muon: Use Distributed Muon by Liu et al. (2024).
585
+ For testing purposes only.
586
  """
587
 
588
  def __init__(self,
 
602
  "head_dim": 128,
603
  "threshold": 100
604
  },
605
+ warmup_step=5,
606
+ chunk_size=-1,
607
+ use_distributed_muon=False):
608
  defaults = dict(
609
  lr=lr,
610
  weight_decay=weight_decay,
 
634
  self.compute_stream = torch.cuda.Stream()
635
  self.debug = debug
636
  self.clip_config = clip_config
637
+ self.warmup_step = warmup_step
638
+ self.chunk_size = chunk_size
639
+ self.use_distributed_muon = use_distributed_muon
640
 
641
  def _calc_flops(self, G, steps):
642
  assert len(G.shape) == 2
 
654
  adjusted_lr = lr * adjusted_ratio
655
  return adjusted_lr
656
 
657
+ def set_rank_once(self, rank):
658
+ if self.rank is None:
659
+ self.rank = rank
660
+ else:
661
+ assert self.rank == rank
662
+
663
  def get_shard_mesh(self, p):
664
  """
665
  Get the shard mesh for a parameter p on the given rank.
 
667
  assert isinstance(
668
  p, DTensor), "Parallel Muon only supports DTensor parameters."
669
 
670
+ shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
671
+ p.placements, p.device_mesh)
672
+
673
+ # set rank with the local rank in the shard process group
674
+ self.set_rank_once(dist.get_rank(group=shard_pg))
675
+
676
+ return shard_mesh, shard_pg, shard_placements
 
 
 
 
 
 
 
 
 
 
 
 
 
677
 
678
  def init_state_and_assign_params(self, names, params, group, qk_logits):
679
  param_to_state = {}
 
705
  ordered_params = list(params_sorted)
706
 
707
  round_robin = 0
708
+ mesh = ordered_params[0].device_mesh
709
+ placements = ordered_params[0].placements
710
+
711
+ shard_mesh, shard_pg, shard_placements = self.get_shard_mesh(
712
+ ordered_params[0])
713
+ shard_mesh_flattened = shard_mesh.mesh.flatten()
714
+ num_ranks = dist.get_world_size(group=shard_pg)
715
+
716
  for n, p in zip(ordered_names, ordered_params):
717
+ if mesh != p.device_mesh:
 
 
 
718
  raise ValueError("All parameters must be on the same mesh.")
719
+ if placements != p.placements:
720
+ raise ValueError("All parameters must have same placements.")
721
+
722
+ worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
723
+ round_robin = (round_robin + 1) % len(shard_mesh_flattened)
724
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
725
+
726
+ param_to_state[id(p)] = _muon_state(
727
+ worker_rank=worker_rank,
728
+ process_group=shard_pg,
729
+ shard_mesh=shard_mesh,
730
+ shard_placements=shard_placements,
731
+ name=n,
732
+ qk_clip_state=qk_clip_state,
733
+ )
734
 
735
  return param_to_state, ordered_params
736
 
 
764
 
765
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
766
 
767
+ scales_full = self._compute_scales(
768
+ p, qk_clip_state) if qk_clip_state is not None else None
769
  if scales_full is not None:
770
  Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
771
 
772
+ def distributed_muon(
773
+ self,
774
+ names: list[str],
775
+ params: list[torch.nn.Parameter],
776
+ group: dict[str, Any],
777
+ lr: float,
778
+ weight_decay: float,
779
+ momentum: float,
780
+ qk_logits: list[torch.Tensor | DTensor] | None,
781
+ ):
782
+ """ Implementation of Distributed Muon by Liu et al. """
783
+ if qk_logits is not None:
784
+ raise NotImplementedError("QK clipping is not supported yet")
785
+
786
+ if isinstance(params[0], DTensor):
787
+ shard_mesh, _, shard_placements = construct_shard_mesh(
788
+ placements=params[0].placements,
789
+ mesh=params[0].device_mesh,
790
+ )
791
+
792
+ for n, p in zip(names, params):
793
+ g = p.grad
794
+ if g is None:
795
+ continue
796
+ if g.ndim > 2:
797
+ g = g.view(g.size(0), -1)
798
+ assert g is not None
799
+
800
+ # calc update
801
+ state = self.state[p]
802
+ if "momentum_buffer" not in state:
803
+ state["momentum_buffer"] = torch.zeros_like(g)
804
+ buf = state["momentum_buffer"]
805
+ buf.mul_(momentum).add_(g)
806
+ if group["nesterov"]:
807
+ g = g.add(buf, alpha=momentum)
808
+ else:
809
+ g = buf
810
+
811
+ # Gather G
812
+ if isinstance(p.data, DTensor):
813
+ g = g.full_tensor()
814
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
815
+ steps=group["ns_steps"])
816
+
817
+ if isinstance(p.data, DTensor):
818
+ slices = get_slices_of_dtensor(
819
+ target=p,
820
+ local_rank=dist.get_rank(),
821
+ shard_mesh=shard_mesh,
822
+ shard_placements=shard_placements,
823
+ )
824
+ u_shard = u[slices]
825
+ u = DTensor.from_local(
826
+ u_shard,
827
+ device_mesh=p.device_mesh,
828
+ placements=p.placements,
829
+ )
830
+
831
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
832
+ Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
833
+
834
  def _update_g(self, p, g, group, momentum):
835
  # calc update
836
  state = self.state[p]
 
849
  p.data.add_(u, alpha=-adjusted_lr)
850
 
851
  def get_qk_clip_info(self, n, qk_logits):
852
+ if self.clip_config is None:
853
+ return None
854
+
855
  head_dim = self.clip_config.get('head_dim')
856
  threshold = self.clip_config.get('threshold')
857
  kind, layer_idx = parse_qk_layer(n)
 
862
  indices_key = 'q_indices' if 'q' in kind else 'k_indices'
863
  indices = self.clip_config.get(indices_key, []) or []
864
 
865
+ if isinstance(logit, DTensor):
866
+ # In TP settings, qk_logits may be a DTensor.
867
+ # We convert it to a full tensor here for simplicity.
868
+ logit = logit.full_tensor()
869
+
870
  return QKClipInfo(
871
  kind=kind,
872
  indices=indices,
 
965
  _update_param(p, state, lr, adjusted_lr, weight_decay,
966
  self.rank, self.compute_stream)
967
 
968
+ if self.chunk_size == -1:
969
+ shard_ranks = dist.get_world_size(param_to_state[id(
970
+ params[0])].process_group)
971
+ chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
972
+ elif self.chunk_size > 0:
973
+ chunk_size = self.chunk_size
974
+ else:
975
+ raise ValueError("chunk_size must be -1 or a positive integer.")
976
 
977
  # Wait grad update
978
  self.comm_stream.wait_stream(torch.cuda.current_stream())
979
 
980
+ warmup_step = self.warmup_step
981
+ for i in range(0, warmup_step):
982
  enqueue_all2all_gather(i * chunk_size, chunk_size)
983
  enqueue_computes(i * chunk_size, chunk_size)
984
 
985
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
986
  enqueue_all2all_scatter(i, chunk_size)
987
+ enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size)
988
  enqueue_update_param(i, chunk_size)
989
+ enqueue_computes(i + warmup_step * chunk_size, chunk_size)
990
 
991
  # Wait the last update_param to finish
992
  torch.cuda.current_stream().wait_stream(self.compute_stream)
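The chunked pipeline above (warmup_step gather/compute launches running ahead of the scatter/update loop, with chunk_size defaulting to shard ranks * DEFAULT_CHUNK_SIZE_RATIO) is easier to see with concrete numbers. The sketch below only prints the enqueue order implied by the two loops; there are no CUDA streams or collectives, num_params/chunk_size/warmup_step are made-up values, and out-of-range chunks simply come out empty in this sketch.

    num_params, chunk_size, warmup_step = 12, 4, 2

    def chunk(start):
        return list(range(start, min(start + chunk_size, num_params)))

    for i in range(warmup_step):
        print("warmup  gather+compute:", chunk(i * chunk_size))

    for i in range(0, num_params + chunk_size - 1, chunk_size):
        print("scatter+update:", chunk(i),
              "| prefetch gather+compute:", chunk(i + warmup_step * chunk_size))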
 
1002
  amsgrad: bool,
1003
  beta1: float,
1004
  beta2: float,
1005
+ lr: float | torch.Tensor,
1006
  weight_decay: float,
1007
  eps: float,
1008
  maximize: bool,
 
1012
 
1013
  # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
1014
  # treating it as a scalar.
1015
+ lr_dict: DeviceDict | None = ({
1016
  lr.device: lr
1017
  } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
1018
+ None)
1019
  grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
1020
  [
1021
  params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
 
1062
  maximize=maximize,
1063
  )
1064
 
1065
+ def _step_muon(self, group, qk_logits=None):
1066
+ params = group["params"]
1067
+ lr = group["lr"]
1068
+ weight_decay = group["weight_decay"]
1069
+ momentum = group["momentum"]
1070
+ names = group["names"]
1071
+
1072
+ param_dtensors = []
1073
+ param_tensors = []
1074
+ name_dtensors = []
1075
+ name_tensors = []
1076
+
1077
+ if self.use_distributed_muon:
1078
+ self.distributed_muon(names=names,
1079
+ params=params,
1080
+ group=group,
1081
+ lr=lr,
1082
+ weight_decay=weight_decay,
1083
+ momentum=momentum,
1084
+ qk_logits=qk_logits)
1085
+ return
1086
+
1087
+ for n, p in zip(names, params):
1088
+ if p is None or p.grad is None:
1089
+ continue
1090
+ if isinstance(p.data, DTensor):
1091
+ if all(
1092
+ isinstance(placement, Replicate)
1093
+ for placement in p.placements):
1094
+ param_tensors.append(p)
1095
+ name_tensors.append(n)
1096
+ else:
1097
+ param_dtensors.append(p)
1098
+ name_dtensors.append(n)
1099
+ elif isinstance(p.data, torch.Tensor):
1100
+ param_tensors.append(p)
1101
+ name_tensors.append(n)
1102
+ else:
1103
+ raise TypeError(f"Unsupported parameter type: {type(p.data)}")
1104
+
1105
+ logger.debug(
1106
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors"
1107
+ )
1108
+
1109
+ if len(param_dtensors) > 0:
1110
+ if not dist.is_initialized():
1111
+ raise RuntimeError(
1112
+ "Parallel Muon requires torch.distributed to be initialized."
1113
+ )
1114
+
1115
+ # To support different placements, we group parameters by placements
1116
+ # and run parallel Muon on each group.
1117
+
1118
+ placement_to_params = defaultdict(lambda: ([], []))
1119
+ # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]
1120
+
1121
+ assert len(name_dtensors) == len(param_dtensors)
1122
+ for n, p in zip(name_dtensors, param_dtensors):
1123
+ placement_to_params[tuple([p.placements,
1124
+ p.device_mesh])][0].append(n)
1125
+ placement_to_params[tuple([p.placements,
1126
+ p.device_mesh])][1].append(p)
1127
+
1128
+ for _, (names, params) in placement_to_params.items():
1129
+ self.parallel(
1130
+ names,
1131
+ params,
1132
+ group,
1133
+ lr=lr,
1134
+ weight_decay=weight_decay,
1135
+ momentum=momentum,
1136
+ qk_logits=qk_logits,
1137
+ )
1138
+
1139
+ if len(param_tensors) > 0:
1140
+ self.base(
1141
+ name_tensors,
1142
+ param_tensors,
1143
+ group,
1144
+ lr=lr,
1145
+ weight_decay=weight_decay,
1146
+ momentum=momentum,
1147
+ qk_logits=qk_logits,
1148
+ )
1149
+
1150
+ def _step_adamw_params(self, params, group):
1151
+ params_with_grads = []
1152
+ grads = []
1153
+ moment1 = []
1154
+ moment2 = []
1155
+ max_exp_avg_sqs = []
1156
+ state_steps = []
1157
+ lr = group["lr"]
1158
+ beta1, beta2 = group["adamw_betas"]
1159
+ eps = group["adamw_eps"]
1160
+ weight_decay = group["weight_decay"]
1161
+
1162
+ for p in params:
1163
+ g = p.grad
1164
+ if g is None:
1165
+ continue
1166
+ state = self.state[p]
1167
+ params_with_grads.append(p)
1168
+ grads.append(g)
1169
+ if "step" not in state:
1170
+ state["step"] = (torch.zeros((),
1171
+ dtype=torch.float32,
1172
+ device=p.device))
1173
+ state["moment1"] = torch.zeros_like(g)
1174
+ state["moment2"] = torch.zeros_like(g)
1175
+ moment1.append(state["moment1"])
1176
+ moment2.append(state["moment2"])
1177
+ if not isinstance(state["step"], torch.Tensor):
1178
+ step_tensor = torch.tensor(state["step"],
1179
+ dtype=torch.float32,
1180
+ device=p.device)
1181
+ else:
1182
+ step_tensor = state["step"]
1183
+ state_steps.append(step_tensor)
1184
+
1185
+ self._fused_adamw(
1186
+ params_with_grads,
1187
+ grads,
1188
+ moment1,
1189
+ moment2,
1190
+ max_exp_avg_sqs,
1191
+ state_steps,
1192
+ amsgrad=False,
1193
+ beta1=beta1,
1194
+ beta2=beta2,
1195
+ lr=lr,
1196
+ weight_decay=weight_decay,
1197
+ eps=eps,
1198
+ maximize=False,
1199
+ )
1200
+
1201
+ def _step_adamw(self, group):
1202
+ params = group["params"]
1203
+
1204
+ # group params with it's type and placement
1205
+ placement_to_params: dict[tuple[Placement | type,
1206
+ DeviceMesh | None]] = defaultdict(list)
1207
+ for p in params:
1208
+ match p:
1209
+ case DTensor():
1210
+ placement_to_params[tuple([p.placements,
1211
+ p.device_mesh])].append(p)
1212
+ case torch.Tensor():
1213
+ placement_to_params[tuple([torch.Tensor, None])].append(p)
1214
+
1215
+ for params in placement_to_params.values():
1216
+ self._step_adamw_params(params, group)
1217
+
1218
  def step(self, closure=None, qk_logits=None):
1219
  """Perform a single optimization step.
1220
 
 
1232
  loss = closure()
1233
 
1234
  for group in self.param_groups:
 
 
1235
  if group["use_muon"]:
1236
+ self._step_muon(group, qk_logits=qk_logits)
1237
  else:
1238
+ self._step_adamw(group)
1239
 
1240
  return loss
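The (placements, device_mesh) bucketing introduced by _step_muon above can be sketched without real DTensors (which require an initialized process group); the string keys below just stand in for placement/mesh pairs, and the parameter names are invented.

    from collections import defaultdict

    fake_params = [
        ("wq", ("Shard(0)", "mesh_a")),
        ("wk", ("Shard(0)", "mesh_a")),
        ("w1", ("Shard(1)", "mesh_a")),
        ("w2", ("Shard(0)", "mesh_b")),
    ]

    placement_to_params = defaultdict(lambda: ([], []))
    for name, key in fake_params:
        placement_to_params[key][0].append(name)
        placement_to_params[key][1].append(name + ".weight")

    for key, (names, params) in placement_to_params.items():
        print(key, "->", names)  # three buckets; each would get its own parallel() call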
build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_811726c_dirty
3
- ops = torch.ops._optimizer_811726c_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_811726c_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_23d68bb_dirty
3
+ ops = torch.ops._optimizer_23d68bb_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_23d68bb_dirty::{op_name}"
build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d297c32252c7f030f3ec60ab1cc908cf145c8ecc710a25690a528d06115ab998
3
+ size 1852184
build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:0661740cd0f97ca56ef83979c5a5fa059bcba411148f89d836e9305065578e73
3
- size 1749264
build/torch29-cxx11-rocm63-x86_64-linux/optimizer/distributed/utils.py ADDED
@@ -0,0 +1,174 @@
1
+ import torch
2
+ import torch.distributed as dist
3
+ from torch.distributed import ProcessGroup
4
+ from torch.distributed.device_mesh import DeviceMesh
5
+ from torch.distributed.tensor import DTensor
6
+ from torch.distributed.tensor.placement_types import (Placement, Shard,
7
+ _StridedShard)
8
+
9
+
10
+ def get_slices_of_dtensor(
11
+ target: DTensor | torch.Tensor,
12
+ local_rank: int,
13
+ shard_mesh: DeviceMesh,
14
+ shard_placements: tuple[Placement],
15
+ ) -> tuple[slice]:
16
+ """
17
+ Get the slices that select the local shard of a tensor for a given rank.
18
+ Args:
19
+ target (DTensor | torch.Tensor): The target tensor.
20
+ local_rank (int): The local rank within the shard group.
21
+ shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks.
22
+ shard_placements (tuple[Placement]): The shard placements.
23
+ """
24
+
25
+ slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()]
26
+
27
+ # find the global rank of the local rank in the shard mesh
28
+ rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
29
+
30
+ rank_coords = (shard_mesh.mesh == rank).nonzero()
31
+
32
+ assert len(rank_coords) == 1
33
+ rank_coords = tuple(rank_coords[0].tolist())
34
+
35
+ assert len(rank_coords) == len(shard_placements)
36
+
37
+ # Caution: Assuming replicate-to-shard of the shard mesh goes with
38
+ # left-to-right sharding. This is ensured by the sorting logic of
39
+ # construct_shard_mesh function.
40
+ for i, (rank_coord,
41
+ placement) in enumerate(zip(rank_coords, shard_placements)):
42
+ assert isinstance(placement, Shard)
43
+
44
+ num_ranks = shard_mesh.mesh.shape[i]
45
+
46
+ dim = placement.dim
47
+ dim_size = (slices[dim].stop - slices[dim].start)
48
+
49
+ if dim_size % num_ranks != 0:
50
+ raise NotImplementedError(
51
+ f"Dimension size {dim_size} is not divisible "
52
+ f"by number of ranks {num_ranks} for shard "
53
+ f"placement on dim {dim}.")
54
+
55
+ shard_size = dim_size // num_ranks
56
+
57
+ start = slices[dim].start + rank_coord * shard_size
58
+ end = start + shard_size
59
+
60
+ assert start < end <= slices[dim].stop
61
+
62
+ slices[dim] = slice(start, end)
63
+
64
+ return tuple(slices)
65
+
66
+
67
+ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict()
68
+
69
+
70
+ def construct_shard_mesh(
71
+ placements: tuple[Placement],
72
+ mesh: DeviceMesh,
73
+ ) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement]]:
74
+ """
75
+ Construct Shard Mesh and Placements for unsharding.
76
+ It removes Replicate placements and constructs a new Mesh and ProcessGroup.
77
+ """
78
+ my_rank = dist.get_rank()
79
+
80
+ assert mesh.mesh.device.type == 'cpu'
81
+
82
+ # Copy mesh to avoid modifying the original mesh
83
+ mesh = mesh.mesh.clone()
84
+
85
+ # 1. Sort placements. Replicate first, then Shard by dim ascending.
86
+
87
+ # For Shard, strided shard comes after regular shard on the same dim
88
+ # to preserve left-to-right order of replicate-to-shard.
89
+ # This is because a strided shard uses a stride to represent
90
+ # more fine-grained sharding on the same dim.
91
+ # Please check the URL below for _StridedShard.
92
+ # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366
93
+
94
+ def placement_sort_key(
95
+ placement_with_index: tuple[int, Placement]
96
+ ) -> tuple[float, float, int]: # (dim, split factor, original index)
97
+ index, placement = placement_with_index
98
+ is_replicate = placement.is_replicate()
99
+ is_shard = placement.is_shard()
100
+ is_partial = placement.is_partial()
101
+
102
+ assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}"
103
+ assert not is_partial, "Partial placement is not supported."
104
+
105
+ if is_replicate:
106
+ return (-1.0, 0, index)
107
+ elif is_shard:
108
+ if isinstance(placement, _StridedShard):
109
+ return (placement.dim, 1 / placement.split_factor, index)
110
+ return (placement.dim, 0, index)
111
+ else:
112
+ raise TypeError(f"Unknown placement type: {type(placement)}")
113
+
114
+ placements_with_index: list[tuple[int,
115
+ Placement]] = list(enumerate(placements))
116
+ placements_with_index = sorted(placements_with_index,
117
+ key=placement_sort_key)
118
+
119
+ sorted_indices, sorted_placements = zip(*placements_with_index)
120
+
121
+ # 2. Permute mesh according to sorted placements.
122
+ sorted_mesh = mesh.permute(sorted_indices)
123
+
124
+ # 3. Collect list of shard meshes by removing replicate dims
125
+ # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
126
+ # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4)
127
+ num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
128
+
129
+ # merge replicate dims
130
+ # shard_meshes becomes a list of shard meshes whose length equals the replicate degree
131
+ if num_replicates > 0:
132
+ sorted_mesh = sorted_mesh.flatten(
133
+ 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
134
+ shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
135
+ else:
136
+ shard_meshes = [sorted_mesh]
137
+ shard_placements = sorted_placements[num_replicates:]
138
+
139
+ # assume all shard placements are different
140
+ assert len(shard_placements) == len(set(shard_placements))
141
+
142
+ # 4. Construct ProcessGroups
143
+ # Caution: all groups should be created in the same order in all processes,
144
+ # even though each process only needs its own group.
145
+
146
+ # To use tensor as dict key, convert it to tuple
147
+ def tensor_to_tuple(t):
148
+ if isinstance(t, torch.Tensor):
149
+ t = t.tolist()
150
+ if isinstance(t, list):
151
+ return tuple(tensor_to_tuple(x) for x in t)
152
+ return t
153
+
154
+ my_shard_mesh_as_tuple = None
155
+ for shard_mesh in shard_meshes:
156
+ assert isinstance(shard_mesh, torch.Tensor)
157
+ shard_mesh_as_tuple = tensor_to_tuple(shard_mesh)
158
+
159
+ if (my_rank == shard_mesh).any().item():
160
+ assert my_shard_mesh_as_tuple is None
161
+ my_shard_mesh_as_tuple = shard_mesh_as_tuple
162
+
163
+ # update global cache
164
+ if shard_mesh_as_tuple not in _ranks_to_dist_cache:
165
+ shard_process_group = dist.new_group(shard_mesh.flatten().tolist())
166
+ _ranks_to_dist_cache[shard_mesh_as_tuple] = (
167
+ DeviceMesh(device_type="cuda", mesh=shard_mesh),
168
+ shard_process_group,
169
+ )
170
+
171
+ my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[
172
+ my_shard_mesh_as_tuple]
173
+
174
+ return my_shard_mesh, my_shard_process_group, shard_placements
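A self-contained sketch of the slicing arithmetic that get_slices_of_dtensor describes (walking the shard mesh dims left to right and narrowing the matching tensor dim, even splits only). It does not call the helper itself, since that needs torch.distributed; the mesh and weight shape below are invented.

    import torch

    shard_mesh = torch.tensor([[0, 1], [2, 3]])  # 2 x 2 shard group of global ranks
    shard_dims = [0, 1]                          # stands in for (Shard(0), Shard(1))
    full_shape = (8, 8)

    def slices_for(local_rank):
        rank = sorted(shard_mesh.flatten().tolist())[local_rank]
        coords = (shard_mesh == rank).nonzero()[0].tolist()
        slc = [slice(0, s) for s in full_shape]
        for mesh_dim, (coord, dim) in enumerate(zip(coords, shard_dims)):
            num_ranks = shard_mesh.shape[mesh_dim]
            size = (slc[dim].stop - slc[dim].start) // num_ranks
            start = slc[dim].start + coord * size
            slc[dim] = slice(start, start + size)
        return tuple(slc)

    for r in range(4):
        print(r, slices_for(r))
    # rank 0 -> rows 0:4, cols 0:4; rank 3 -> rows 4:8, cols 4:8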
build/torch29-cxx11-rocm63-x86_64-linux/optimizer/muon.py CHANGED
@@ -1,18 +1,24 @@
1
  import logging
2
  import math
3
  import types
 
4
  from dataclasses import dataclass
5
- from typing import List, Optional, Union, cast
6
 
7
  import torch
8
  import torch.distributed as dist
9
- from torch.distributed._tensor import DTensor, Replicate, Shard
 
 
 
10
 
 
11
  from .matmul_transpose_triton import matmul_transpose_assign
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
  COMM_DTYPE = torch.bfloat16
 
16
 
17
 
18
  # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -62,23 +68,39 @@ def _zeropower_via_newtonschulz5(G, steps):
62
  @dataclass
63
  class _muon_state:
64
  # TODO: use Optional
65
- worker_rank: int | None = None
 
 
 
 
 
66
  gathered_grad: torch.Tensor | None = None
67
  scattered_u: DTensor | None = None
68
  computed_u: torch.Tensor | None = None
69
  gather_event: torch.cuda.Event | None = None
70
  compute_event: torch.cuda.Event | None = None
71
  scatter_event: torch.cuda.Event | None = None
72
- process_group = None
73
- qk_clip_state = None
74
 
75
 
76
- def split_elems_for_src(param, src_rank, num_ranks) -> int:
77
- rows = param.shape[0]
78
- cols = int(param.numel() // rows)
79
- base, rem = divmod(rows, num_ranks)
80
- my_rows = base + (1 if src_rank < rem else 0)
81
- return my_rows * cols
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
 
84
  @torch.no_grad()
@@ -91,8 +113,7 @@ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
91
  for p in params:
92
  state = param_to_state[id(p)]
93
  if rank == state.worker_rank:
94
- num_ranks = dist.get_world_size(group=state.process_group)
95
- state.gathered_grad = torch.empty(p.grad.numel(),
96
  dtype=COMM_DTYPE,
97
  device="cuda")
98
  else:
@@ -121,11 +142,11 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
121
  state = param_to_state[id(p)]
122
  dst = state.worker_rank
123
  assert dst < num_ranks
124
- shard_elems = split_elems_for_src(p, rank, num_ranks)
125
  g = p.grad
126
- g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
127
  assert g.numel() == shard_elems
128
- per_dst[dst].append(g)
129
  send_counts[dst] += shard_elems
130
 
131
  assert any(
@@ -148,13 +169,18 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
148
  for p in owned_params:
149
  state = param_to_state[id(p)]
150
  assert state.worker_rank == rank
151
- total += split_elems_for_src(p, src, num_ranks)
152
  recv_counts[src] = total
153
 
154
  recv_total = sum(recv_counts)
155
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
156
 
157
  #All2All
 
 
 
 
 
158
  dist.all_to_all_single(
159
  recv_buf,
160
  send_buf,
@@ -179,7 +205,6 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
179
  comm_stream.wait_event(alloc_event)
180
 
181
  off = 0
182
- write_offsets = {id(p): 0 for p in owned_params}
183
  for src in range(num_ranks):
184
  if recv_counts[src] == 0:
185
  continue
@@ -189,22 +214,28 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
189
  for p in owned_params:
190
  state = param_to_state[id(p)]
191
  assert state.worker_rank == rank
192
- n = split_elems_for_src(p, src, num_ranks)
 
 
 
 
 
 
 
 
 
193
  assert n > 0
194
 
195
  sg = recv_buf.narrow(0, off + inner_off, n)
196
- woff = write_offsets[id(p)]
197
- dst = state.gathered_grad.narrow(0, woff, n)
198
  dst.copy_(sg)
199
 
200
- write_offsets[id(p)] += n
201
  inner_off += n
202
  off += block
203
 
204
  for p in params:
205
  state = param_to_state[id(p)]
206
  if state.worker_rank == rank:
207
- state.gathered_grad = state.gathered_grad.view_as(p)
208
  state.gather_event = torch.cuda.Event()
209
  state.gather_event.record(comm_stream)
210
  else:
@@ -277,14 +308,19 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
277
 
278
  assert state.computed_u is not None
279
 
280
- u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)
281
 
282
  offset = 0
283
  for dst in range(num_ranks):
284
- n = split_elems_for_src(p, dst, num_ranks)
 
 
 
 
 
 
285
  assert n > 0
286
 
287
- su = u_full.narrow(0, offset, n)
288
  per_dst[dst].append(su)
289
  send_counts[dst] += n
290
  offset += n
@@ -313,7 +349,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
313
  state = param_to_state[id(p)]
314
  if state.worker_rank != src:
315
  continue
316
- total += split_elems_for_src(p, rank, num_ranks)
317
  recv_counts[src] = total
318
 
319
  recv_total = sum(recv_counts)
@@ -357,7 +393,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
357
  state = param_to_state[id(p)]
358
  if state.worker_rank != src:
359
  continue
360
- n = split_elems_for_src(p, rank, num_ranks)
361
  assert n > 0
362
 
363
  flat_local = recv_buf.narrow(0, off + inner_off,
@@ -398,11 +434,23 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
398
  state.scattered_u = None
399
  u_dtensor = None
400
 
401
- scales_full = Muon._compute_scales(p, state.qk_clip_state)
 
 
402
  if scales_full is not None:
403
- num_ranks = dist.get_world_size(group=state.process_group)
404
- local_rank = dist.get_rank(group=state.process_group)
405
- scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank]
 
 
 
 
 
 
 
 
 
 
406
  scales_local = DTensor.from_local(
407
  scales_local,
408
  placements=p.placements,
@@ -478,11 +526,11 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
478
  @dataclass
479
  class QKClipInfo:
480
  """Per-parameter dynamic info computed from config + runtime logits."""
481
- kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None
482
- indices: List[int] # which heads to consider for clipping
483
  head_dim: int # from config
484
  threshold: float # from config
485
- logit: Optional[torch.Tensor]
486
 
487
 
488
  class Muon(torch.optim.Optimizer):
@@ -525,11 +573,16 @@ class Muon(torch.optim.Optimizer):
525
  "head_dim": 128,
526
  "threshold": 100
527
  }
528
- overlap_step : How many all2all gather, compute operations are launched in advance
529
- before the corresponding all2all scatter steps begin.
530
- A higher overlap_step increases memory usage but can improve
531
- performance by overlapping communication.
532
- Parallel muon only.
 
 
 
 
 
533
  """
534
 
535
  def __init__(self,
@@ -549,7 +602,9 @@ class Muon(torch.optim.Optimizer):
549
  "head_dim": 128,
550
  "threshold": 100
551
  },
552
- overlap_step=5):
 
 
553
  defaults = dict(
554
  lr=lr,
555
  weight_decay=weight_decay,
@@ -579,7 +634,9 @@ class Muon(torch.optim.Optimizer):
579
  self.compute_stream = torch.cuda.Stream()
580
  self.debug = debug
581
  self.clip_config = clip_config
582
- self.overlap_step = overlap_step
 
 
583
 
584
  def _calc_flops(self, G, steps):
585
  assert len(G.shape) == 2
@@ -597,6 +654,12 @@ class Muon(torch.optim.Optimizer):
597
  adjusted_lr = lr * adjusted_ratio
598
  return adjusted_lr
599
 
 
 
 
 
 
 
600
  def get_shard_mesh(self, p):
601
  """
602
  Get the shard mesh for a parameter p on the given rank.
@@ -604,26 +667,13 @@ class Muon(torch.optim.Optimizer):
604
  assert isinstance(
605
  p, DTensor), "Parallel Muon only supports DTensor parameters."
606
 
607
- if p.placements == (Shard(dim=0), ):
608
- # Case for FSDP
609
- process_group = p.device_mesh.get_group(mesh_dim=0)
610
- if self.rank is None:
611
- self.rank = dist.get_rank(group=process_group)
612
- else:
613
- assert self.rank == dist.get_rank(group=process_group)
614
- return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
615
- elif p.placements == (Replicate(), Shard(dim=0)):
616
- # Case for HSDP
617
- process_group = p.device_mesh.get_group(mesh_dim=1)
618
- if self.rank is None:
619
- self.rank = dist.get_rank(group=process_group)
620
- else:
621
- assert self.rank == dist.get_rank(group=process_group)
622
- for i, shard_mesh in enumerate(p.device_mesh.mesh):
623
- if self.rank in shard_mesh:
624
- return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
625
- else:
626
- raise ValueError(f"Unsupported placements ({p.placements}).")
627
 
628
  def init_state_and_assign_params(self, names, params, group, qk_logits):
629
  param_to_state = {}
@@ -655,23 +705,32 @@ class Muon(torch.optim.Optimizer):
655
  ordered_params = list(params_sorted)
656
 
657
  round_robin = 0
658
- mesh = None
659
- shard_mesh = None
660
- process_group = None
 
 
 
 
 
661
  for n, p in zip(ordered_names, ordered_params):
662
- if mesh is None:
663
- mesh = p.device_mesh
664
- shard_mesh, process_group = self.get_shard_mesh(p)
665
- elif mesh != p.device_mesh:
666
  raise ValueError("All parameters must be on the same mesh.")
667
- num_ranks = dist.get_world_size(group=process_group)
668
- param_to_state[id(p)] = _muon_state()
669
- param_to_state[id(
670
- p)].worker_rank = shard_mesh[round_robin].item() % num_ranks
671
- param_to_state[id(p)].process_group = process_group
672
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
673
- param_to_state[id(p)].qk_clip_state = qk_clip_state
674
- round_robin = (round_robin + 1) % len(shard_mesh)
 
 
 
 
 
 
 
675
 
676
  return param_to_state, ordered_params
677
 
@@ -705,10 +764,73 @@ class Muon(torch.optim.Optimizer):
705
 
706
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
707
 
708
- scales_full = self._compute_scales(p, qk_clip_state)
 
709
  if scales_full is not None:
710
  Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
711
 
712
  def _update_g(self, p, g, group, momentum):
713
  # calc update
714
  state = self.state[p]
@@ -727,6 +849,9 @@ class Muon(torch.optim.Optimizer):
727
  p.data.add_(u, alpha=-adjusted_lr)
728
 
729
  def get_qk_clip_info(self, n, qk_logits):
 
 
 
730
  head_dim = self.clip_config.get('head_dim')
731
  threshold = self.clip_config.get('threshold')
732
  kind, layer_idx = parse_qk_layer(n)
@@ -737,6 +862,11 @@ class Muon(torch.optim.Optimizer):
737
  indices_key = 'q_indices' if 'q' in kind else 'k_indices'
738
  indices = self.clip_config.get(indices_key, []) or []
739
 
 
 
 
 
 
740
  return QKClipInfo(
741
  kind=kind,
742
  indices=indices,
@@ -835,22 +965,28 @@ class Muon(torch.optim.Optimizer):
835
  _update_param(p, state, lr, adjusted_lr, weight_decay,
836
  self.rank, self.compute_stream)
837
 
838
- chunk_size = dist.get_world_size(param_to_state[id(
839
- params[0])].process_group)
 
 
 
 
 
 
840
 
841
  # Wait grad update
842
  self.comm_stream.wait_stream(torch.cuda.current_stream())
843
 
844
- overlap_step = self.overlap_step
845
- for i in range(0, overlap_step):
846
  enqueue_all2all_gather(i * chunk_size, chunk_size)
847
  enqueue_computes(i * chunk_size, chunk_size)
848
 
849
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
850
  enqueue_all2all_scatter(i, chunk_size)
851
- enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size)
852
  enqueue_update_param(i, chunk_size)
853
- enqueue_computes(i + overlap_step * chunk_size, chunk_size)
854
 
855
  # Wait the last update_param to finish
856
  torch.cuda.current_stream().wait_stream(self.compute_stream)
@@ -866,7 +1002,7 @@ class Muon(torch.optim.Optimizer):
866
  amsgrad: bool,
867
  beta1: float,
868
  beta2: float,
869
- lr: Union[float, torch.Tensor],
870
  weight_decay: float,
871
  eps: float,
872
  maximize: bool,
@@ -876,10 +1012,10 @@ class Muon(torch.optim.Optimizer):
876
 
877
  # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
878
  # treating it as a scalar.
879
- lr_dict: Optional[DeviceDict] = ({
880
  lr.device: lr
881
  } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
882
- None)
883
  grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
884
  [
885
  params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
@@ -926,6 +1062,159 @@ class Muon(torch.optim.Optimizer):
926
  maximize=maximize,
927
  )
928
 
929
  def step(self, closure=None, qk_logits=None):
930
  """Perform a single optimization step.
931
 
@@ -943,127 +1232,9 @@ class Muon(torch.optim.Optimizer):
943
  loss = closure()
944
 
945
  for group in self.param_groups:
946
- params = group["params"]
947
-
948
  if group["use_muon"]:
949
- ############################
950
- # Muon #
951
- ############################
952
- lr = group["lr"]
953
- weight_decay = group["weight_decay"]
954
- momentum = group["momentum"]
955
- names = group["names"]
956
-
957
- param_dtensors = []
958
- param_tensors = []
959
- name_dtensors = []
960
- name_tensors = []
961
-
962
- for n, p in zip(names, params):
963
- if p is None or p.grad is None:
964
- continue
965
- if isinstance(p.data, DTensor):
966
- if all(
967
- isinstance(placement, Replicate)
968
- for placement in p.placements):
969
- param_tensors.append(p)
970
- name_tensors.append(n)
971
- else:
972
- param_dtensors.append(p)
973
- name_dtensors.append(n)
974
- elif isinstance(p.data, torch.Tensor):
975
- param_tensors.append(p)
976
- name_tensors.append(n)
977
- else:
978
- raise TypeError(
979
- f"Unsupported parameter type: {type(p.data)}")
980
-
981
- if self.debug:
982
- print(
983
- f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
984
- flush=True,
985
- )
986
-
987
- if len(param_dtensors) > 0:
988
- if not dist.is_initialized():
989
- raise RuntimeError(
990
- "Parallel Muon requires torch.distributed to be initialized."
991
- )
992
-
993
- self.parallel(
994
- name_dtensors,
995
- param_dtensors,
996
- group,
997
- lr=lr,
998
- weight_decay=weight_decay,
999
- momentum=momentum,
1000
- qk_logits=qk_logits,
1001
- )
1002
-
1003
- if len(param_tensors) > 0:
1004
- self.base(
1005
- name_tensors,
1006
- param_tensors,
1007
- group,
1008
- lr=lr,
1009
- weight_decay=weight_decay,
1010
- momentum=momentum,
1011
- qk_logits=qk_logits,
1012
- )
1013
-
1014
  else:
1015
- ############################
1016
- # AdamW backup #
1017
- ############################
1018
-
1019
- params_with_grads = []
1020
- grads = []
1021
- moment1 = []
1022
- moment2 = []
1023
- max_exp_avg_sqs = []
1024
- state_steps = []
1025
- lr = group["lr"]
1026
- beta1, beta2 = group["adamw_betas"]
1027
- eps = group["adamw_eps"]
1028
- weight_decay = group["weight_decay"]
1029
-
1030
- for p in params:
1031
- g = p.grad
1032
- if g is None:
1033
- continue
1034
- state = self.state[p]
1035
- params_with_grads.append(p)
1036
- grads.append(g)
1037
- if "step" not in state:
1038
- state["step"] = (torch.zeros((),
1039
- dtype=torch.float32,
1040
- device=p.device))
1041
- state["moment1"] = torch.zeros_like(g)
1042
- state["moment2"] = torch.zeros_like(g)
1043
- moment1.append(state["moment1"])
1044
- moment2.append(state["moment2"])
1045
- if not isinstance(state["step"], torch.Tensor):
1046
- step_tensor = torch.tensor(state["step"],
1047
- dtype=torch.float32,
1048
- device=p.device)
1049
- else:
1050
- step_tensor = state["step"]
1051
- state_steps.append(step_tensor)
1052
-
1053
- self._fused_adamw(
1054
- params_with_grads,
1055
- grads,
1056
- moment1,
1057
- moment2,
1058
- max_exp_avg_sqs,
1059
- state_steps,
1060
- amsgrad=False,
1061
- beta1=beta1,
1062
- beta2=beta2,
1063
- lr=lr,
1064
- weight_decay=weight_decay,
1065
- eps=eps,
1066
- maximize=False,
1067
- )
1068
 
1069
  return loss
 
1
  import logging
2
  import math
3
  import types
4
+ from collections import defaultdict
5
  from dataclasses import dataclass
6
+ from typing import Any, cast
7
 
8
  import torch
9
  import torch.distributed as dist
10
+ from torch.distributed import ProcessGroup
11
+ from torch.distributed.device_mesh import DeviceMesh
12
+ from torch.distributed.tensor import DTensor, Replicate
13
+ from torch.distributed.tensor.placement_types import Placement
14
 
15
+ from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor
16
  from .matmul_transpose_triton import matmul_transpose_assign
17
 
18
  logger = logging.getLogger(__name__)
19
 
20
  COMM_DTYPE = torch.bfloat16
21
+ DEFAULT_CHUNK_SIZE_RATIO = 4
22
 
23
 
24
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
68
  @dataclass
69
  class _muon_state:
70
  # TODO: use Optional
71
+ worker_rank: int
72
+ process_group: ProcessGroup
73
+ shard_mesh: DeviceMesh
74
+ shard_placements: tuple[Placement, ...]
75
+ name: str
76
+ qk_clip_state: torch.Tensor | None = None
77
  gathered_grad: torch.Tensor | None = None
78
  scattered_u: DTensor | None = None
79
  computed_u: torch.Tensor | None = None
80
  gather_event: torch.cuda.Event | None = None
81
  compute_event: torch.cuda.Event | None = None
82
  scatter_event: torch.cuda.Event | None = None
 
 
83
 
84
 
85
+ def numel_for_rank(
86
+ param: DTensor,
87
+ local_rank: int,
88
+ state: _muon_state,
89
+ ) -> int:
90
+ slices = get_slices_of_dtensor(
91
+ param,
92
+ local_rank,
93
+ state.shard_mesh,
94
+ state.shard_placements,
95
+ )
96
+
97
+ numel = 1
98
+ for s, dim in zip(slices, param.shape):
99
+ start, stop, step = s.indices(dim)
100
+ length = max(0, (stop - start + (step - 1)) // step)
101
+ numel *= length
102
+
103
+ return numel
104
 
105
 
106
  @torch.no_grad()
 
113
  for p in params:
114
  state = param_to_state[id(p)]
115
  if rank == state.worker_rank:
116
+ state.gathered_grad = torch.empty(p.shape,
 
117
  dtype=COMM_DTYPE,
118
  device="cuda")
119
  else:
 
142
  state = param_to_state[id(p)]
143
  dst = state.worker_rank
144
  assert dst < num_ranks
145
+ shard_elems = numel_for_rank(p, rank, state)
146
  g = p.grad
147
+ g = g.to_local().to(COMM_DTYPE).contiguous()
148
  assert g.numel() == shard_elems
149
+ per_dst[dst].append(g.view(-1))
150
  send_counts[dst] += shard_elems
151
 
152
  assert any(
 
169
  for p in owned_params:
170
  state = param_to_state[id(p)]
171
  assert state.worker_rank == rank
172
+ total += numel_for_rank(p, src, state)
173
  recv_counts[src] = total
174
 
175
  recv_total = sum(recv_counts)
176
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
177
 
178
  #All2All
179
+ logger.debug(f"send_buf size: {send_buf.numel()}, "
180
+ f"recv_buf size: {recv_buf.numel()}, "
181
+ f"recv_counts: {recv_counts}, "
182
+ f"send_counts: {send_counts}, "
183
+ f"process_group: {str(process_group)}")
184
  dist.all_to_all_single(
185
  recv_buf,
186
  send_buf,
 
205
  comm_stream.wait_event(alloc_event)
206
 
207
  off = 0
 
208
  for src in range(num_ranks):
209
  if recv_counts[src] == 0:
210
  continue
 
214
  for p in owned_params:
215
  state = param_to_state[id(p)]
216
  assert state.worker_rank == rank
217
+
218
+ # get the slice of the full dtensor corresponding to rank src.
219
+ slices = get_slices_of_dtensor(state.gathered_grad, src,
220
+ state.shard_mesh,
221
+ state.shard_placements)
222
+
223
+ dst = state.gathered_grad[slices]
224
+ assert dst._base is state.gathered_grad
225
+
226
+ n = dst.numel()
227
  assert n > 0
228
 
229
  sg = recv_buf.narrow(0, off + inner_off, n)
230
+ sg = sg.reshape_as(dst)
 
231
  dst.copy_(sg)
232
 
 
233
  inner_off += n
234
  off += block
235
 
236
  for p in params:
237
  state = param_to_state[id(p)]
238
  if state.worker_rank == rank:
 
239
  state.gather_event = torch.cuda.Event()
240
  state.gather_event.record(comm_stream)
241
  else:
 
308
 
309
  assert state.computed_u is not None
310
 
311
+ u_full = state.computed_u.to(COMM_DTYPE).contiguous()
312
 
313
  offset = 0
314
  for dst in range(num_ranks):
315
+ # get the slice of the full tensor corresponding to rank dst.
316
+ slices = get_slices_of_dtensor(u_full, dst,
317
+ state.shard_mesh,
318
+ state.shard_placements)
319
+ su = u_full[slices].flatten()
320
+
321
+ n = su.numel()
322
  assert n > 0
323
 
 
324
  per_dst[dst].append(su)
325
  send_counts[dst] += n
326
  offset += n
 
349
  state = param_to_state[id(p)]
350
  if state.worker_rank != src:
351
  continue
352
+ total += numel_for_rank(p, rank, state)
353
  recv_counts[src] = total
354
 
355
  recv_total = sum(recv_counts)
 
393
  state = param_to_state[id(p)]
394
  if state.worker_rank != src:
395
  continue
396
+ n = numel_for_rank(p, rank, state)
397
  assert n > 0
398
 
399
  flat_local = recv_buf.narrow(0, off + inner_off,
 
434
  state.scattered_u = None
435
  u_dtensor = None
436
 
437
+ scales_full = Muon._compute_scales(
438
+ p,
439
+ state.qk_clip_state) if state.qk_clip_state is not None else None
440
  if scales_full is not None:
441
+ # Have to slice scales_full along dim 0
442
+ weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh,
443
+ state.shard_placements)
444
+ ratio = p.shape[0] // scales_full.shape[0]
445
+ scales_slice = slice(
446
+ None if weight_slices[0].start is None else
447
+ weight_slices[0].start // ratio,
448
+ None if weight_slices[0].stop is None else
449
+ weight_slices[0].stop // ratio,
450
+ None,
451
+ )
452
+
453
+ scales_local = scales_full[scales_slice]
454
  scales_local = DTensor.from_local(
455
  scales_local,
456
  placements=p.placements,
 
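The ratio-based slicing just above maps the locally owned weight rows onto entries of scales_full (one entry per `ratio` rows); a toy version with invented shapes:

    import torch

    rows, n_entries = 16, 4                # weight rows, scale entries
    ratio = rows // n_entries              # 4 weight rows per scale entry
    scales_full = torch.arange(n_entries)

    weight_rows = slice(8, 12)             # rows owned by this shard
    scales_slice = slice(weight_rows.start // ratio, weight_rows.stop // ratio)
    print(scales_full[scales_slice])       # tensor([2]); only entry 2 is owned locally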
526
  @dataclass
527
  class QKClipInfo:
528
  """Per-parameter dynamic info computed from config + runtime logits."""
529
+ kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
530
+ indices: list[int] # which heads to consider for clipping
531
  head_dim: int # from config
532
  threshold: float # from config
533
+ logit: torch.Tensor | None
534
 
535
 
536
  class Muon(torch.optim.Optimizer):
 
573
  "head_dim": 128,
574
  "threshold": 100
575
  }
576
+ warmup_step : How many all2all gather and compute operations are launched in advance
577
+ before the corresponding all2all scatter steps begin.
578
+ A higher warmup_step increases memory usage but can improve
579
+ performance by overlapping communication.
580
+ Parallel muon only.
581
+ chunk_size : Batch size of parameters to process in each
582
+ all2all gather/compute/scatter step.
583
+ Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
584
+ use_distributed_muon: Use distributed muon by Liu et al. (2024).
585
+ For testing purposes only.
586
  """
587
 
588
  def __init__(self,
 
602
  "head_dim": 128,
603
  "threshold": 100
604
  },
605
+ warmup_step=5,
606
+ chunk_size=-1,
607
+ use_distributed_muon=False):
608
  defaults = dict(
609
  lr=lr,
610
  weight_decay=weight_decay,
 
634
  self.compute_stream = torch.cuda.Stream()
635
  self.debug = debug
636
  self.clip_config = clip_config
637
+ self.warmup_step = warmup_step
638
+ self.chunk_size = chunk_size
639
+ self.use_distributed_muon = use_distributed_muon
640
 
641
  def _calc_flops(self, G, steps):
642
  assert len(G.shape) == 2
 
654
  adjusted_lr = lr * adjusted_ratio
655
  return adjusted_lr
656
 
657
+ def set_rank_once(self, rank):
658
+ if self.rank is None:
659
+ self.rank = rank
660
+ else:
661
+ assert self.rank == rank
662
+
663
  def get_shard_mesh(self, p):
664
  """
665
  Get the shard mesh for a parameter p on the given rank.
 
667
  assert isinstance(
668
  p, DTensor), "Parallel Muon only supports DTensor parameters."
669
 
670
+ shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
671
+ p.placements, p.device_mesh)
672
+
673
+ # set rank with the local rank in the shard process group
674
+ self.set_rank_once(dist.get_rank(group=shard_pg))
675
+
676
+ return shard_mesh, shard_pg, shard_placements
 
 
 
 
 
 
 
 
 
 
 
 
 
677
 
678
  def init_state_and_assign_params(self, names, params, group, qk_logits):
679
  param_to_state = {}
 
705
  ordered_params = list(params_sorted)
706
 
707
  round_robin = 0
708
+ mesh = ordered_params[0].device_mesh
709
+ placements = ordered_params[0].placements
710
+
711
+ shard_mesh, shard_pg, shard_placements = self.get_shard_mesh(
712
+ ordered_params[0])
713
+ shard_mesh_flattened = shard_mesh.mesh.flatten()
714
+ num_ranks = dist.get_world_size(group=shard_pg)
715
+
716
  for n, p in zip(ordered_names, ordered_params):
717
+ if mesh != p.device_mesh:
 
 
 
718
  raise ValueError("All parameters must be on the same mesh.")
719
+ if placements != p.placements:
720
+ raise ValueError("All parameters must have same placements.")
721
+
722
+ worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
723
+ round_robin = (round_robin + 1) % len(shard_mesh_flattened)
724
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
725
+
726
+ param_to_state[id(p)] = _muon_state(
727
+ worker_rank=worker_rank,
728
+ process_group=shard_pg,
729
+ shard_mesh=shard_mesh,
730
+ shard_placements=shard_placements,
731
+ name=n,
732
+ qk_clip_state=qk_clip_state,
733
+ )
734
 
735
  return param_to_state, ordered_params
736
 
 
764
 
765
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
766
 
767
+ scales_full = self._compute_scales(
768
+ p, qk_clip_state) if qk_clip_state is not None else None
769
  if scales_full is not None:
770
  Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
771
 
772
+ def distributed_muon(
773
+ self,
774
+ names: list[str],
775
+ params: list[torch.nn.Parameter],
776
+ group: dict[str, Any],
777
+ lr: float,
778
+ weight_decay: float,
779
+ momentum: float,
780
+ qk_logits: list[torch.Tensor | DTensor] | None,
781
+ ):
782
+ """ Implementation of Distributed Muon by Liu et al. """
783
+ if qk_logits is not None:
784
+ raise NotImplementedError("QK clipping is not supported yet")
785
+
786
+ if isinstance(params[0], DTensor):
787
+ shard_mesh, _, shard_placements = construct_shard_mesh(
788
+ placements=params[0].placements,
789
+ mesh=params[0].device_mesh,
790
+ )
791
+
792
+ for n, p in zip(names, params):
793
+ g = p.grad
794
+ if g is None:
795
+ continue
796
+ if g.ndim > 2:
797
+ g = g.view(g.size(0), -1)
798
+ assert g is not None
799
+
800
+ # calc update
801
+ state = self.state[p]
802
+ if "momentum_buffer" not in state:
803
+ state["momentum_buffer"] = torch.zeros_like(g)
804
+ buf = state["momentum_buffer"]
805
+ buf.mul_(momentum).add_(g)
806
+ if group["nesterov"]:
807
+ g = g.add(buf, alpha=momentum)
808
+ else:
809
+ g = buf
810
+
811
+ # Gather G
812
+ if isinstance(p.data, DTensor):
813
+ g = g.full_tensor()
814
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
815
+ steps=group["ns_steps"])
816
+
817
+ if isinstance(p.data, DTensor):
818
+ slices = get_slices_of_dtensor(
819
+ target=p,
820
+ local_rank=dist.get_rank(),
821
+ shard_mesh=shard_mesh,
822
+ shard_placements=shard_placements,
823
+ )
824
+ u_shard = u[slices]
825
+ u = DTensor.from_local(
826
+ u_shard,
827
+ device_mesh=p.device_mesh,
828
+ placements=p.placements,
829
+ )
830
+
831
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
832
+ Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
833
+
834
  def _update_g(self, p, g, group, momentum):
835
  # calc update
836
  state = self.state[p]
 
849
  p.data.add_(u, alpha=-adjusted_lr)
850
 
851
  def get_qk_clip_info(self, n, qk_logits):
852
+ if self.clip_config is None:
853
+ return None
854
+
855
  head_dim = self.clip_config.get('head_dim')
856
  threshold = self.clip_config.get('threshold')
857
  kind, layer_idx = parse_qk_layer(n)
 
862
  indices_key = 'q_indices' if 'q' in kind else 'k_indices'
863
  indices = self.clip_config.get(indices_key, []) or []
864
 
865
+ if isinstance(logit, DTensor):
866
+ # In TP settings, qk_logits may be a DTensor;
867
+ # we convert it to a full tensor here for simplicity
868
+ logit = logit.full_tensor()
869
+
870
  return QKClipInfo(
871
  kind=kind,
872
  indices=indices,
 
965
  _update_param(p, state, lr, adjusted_lr, weight_decay,
966
  self.rank, self.compute_stream)
967
 
968
+ if self.chunk_size == -1:
969
+ shard_ranks = dist.get_world_size(param_to_state[id(
970
+ params[0])].process_group)
971
+ chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
972
+ elif self.chunk_size > 0:
973
+ chunk_size = self.chunk_size
974
+ else:
975
+ raise ValueError("chunk_size must be -1 or a positive integer.")
976
 
977
  # Wait grad update
978
  self.comm_stream.wait_stream(torch.cuda.current_stream())
979
 
980
+ warmup_step = self.warmup_step
981
+ for i in range(0, warmup_step):
982
  enqueue_all2all_gather(i * chunk_size, chunk_size)
983
  enqueue_computes(i * chunk_size, chunk_size)
984
 
985
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
986
  enqueue_all2all_scatter(i, chunk_size)
987
+ enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size)
988
  enqueue_update_param(i, chunk_size)
989
+ enqueue_computes(i + warmup_step * chunk_size, chunk_size)
990
 
991
  # Wait the last update_param to finish
992
  torch.cuda.current_stream().wait_stream(self.compute_stream)
 
1002
  amsgrad: bool,
1003
  beta1: float,
1004
  beta2: float,
1005
+ lr: float | torch.Tensor,
1006
  weight_decay: float,
1007
  eps: float,
1008
  maximize: bool,
 
1012
 
1013
  # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
1014
  # treating it as a scalar.
1015
+ lr_dict: DeviceDict | None = ({
1016
  lr.device: lr
1017
  } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
1018
+ None)
1019
  grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
1020
  [
1021
  params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
 
1062
  maximize=maximize,
1063
  )
1064
 
1065
+ def _step_muon(self, group, qk_logits=None):
1066
+ params = group["params"]
1067
+ lr = group["lr"]
1068
+ weight_decay = group["weight_decay"]
1069
+ momentum = group["momentum"]
1070
+ names = group["names"]
1071
+
1072
+ param_dtensors = []
1073
+ param_tensors = []
1074
+ name_dtensors = []
1075
+ name_tensors = []
1076
+
1077
+ if self.use_distributed_muon:
1078
+ self.distributed_muon(names=names,
1079
+ params=params,
1080
+ group=group,
1081
+ lr=lr,
1082
+ weight_decay=weight_decay,
1083
+ momentum=momentum,
1084
+ qk_logits=qk_logits)
1085
+ return
1086
+
1087
+ for n, p in zip(names, params):
1088
+ if p is None or p.grad is None:
1089
+ continue
1090
+ if isinstance(p.data, DTensor):
1091
+ if all(
1092
+ isinstance(placement, Replicate)
1093
+ for placement in p.placements):
1094
+ param_tensors.append(p)
1095
+ name_tensors.append(n)
1096
+ else:
1097
+ param_dtensors.append(p)
1098
+ name_dtensors.append(n)
1099
+ elif isinstance(p.data, torch.Tensor):
1100
+ param_tensors.append(p)
1101
+ name_tensors.append(n)
1102
+ else:
1103
+ raise TypeError(f"Unsupported parameter type: {type(p.data)}")
1104
+
1105
+ logger.debug(
1106
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors"
1107
+ )
1108
+
1109
+ if len(param_dtensors) > 0:
1110
+ if not dist.is_initialized():
1111
+ raise RuntimeError(
1112
+ "Parallel Muon requires torch.distributed to be initialized."
1113
+ )
1114
+
1115
+ # To support different placements, we group parameters by placements
1116
+ # and run parallel Muon on each group.
1117
+
1118
+ placement_to_params = defaultdict(lambda: ([], []))
1119
+ # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]
1120
+
1121
+ assert len(name_dtensors) == len(param_dtensors)
1122
+ for n, p in zip(name_dtensors, param_dtensors):
1123
+ placement_to_params[tuple([p.placements,
1124
+ p.device_mesh])][0].append(n)
1125
+ placement_to_params[tuple([p.placements,
1126
+ p.device_mesh])][1].append(p)
1127
+
1128
+ for _, (names, params) in placement_to_params.items():
1129
+ self.parallel(
1130
+ names,
1131
+ params,
1132
+ group,
1133
+ lr=lr,
1134
+ weight_decay=weight_decay,
1135
+ momentum=momentum,
1136
+ qk_logits=qk_logits,
1137
+ )
1138
+
1139
+ if len(param_tensors) > 0:
1140
+ self.base(
1141
+ name_tensors,
1142
+ param_tensors,
1143
+ group,
1144
+ lr=lr,
1145
+ weight_decay=weight_decay,
1146
+ momentum=momentum,
1147
+ qk_logits=qk_logits,
1148
+ )
1149
+
1150
+ def _step_adamw_params(self, params, group):
1151
+ params_with_grads = []
1152
+ grads = []
1153
+ moment1 = []
1154
+ moment2 = []
1155
+ max_exp_avg_sqs = []
1156
+ state_steps = []
1157
+ lr = group["lr"]
1158
+ beta1, beta2 = group["adamw_betas"]
1159
+ eps = group["adamw_eps"]
1160
+ weight_decay = group["weight_decay"]
1161
+
1162
+ for p in params:
1163
+ g = p.grad
1164
+ if g is None:
1165
+ continue
1166
+ state = self.state[p]
1167
+ params_with_grads.append(p)
1168
+ grads.append(g)
1169
+ if "step" not in state:
1170
+ state["step"] = (torch.zeros((),
1171
+ dtype=torch.float32,
1172
+ device=p.device))
1173
+ state["moment1"] = torch.zeros_like(g)
1174
+ state["moment2"] = torch.zeros_like(g)
1175
+ moment1.append(state["moment1"])
1176
+ moment2.append(state["moment2"])
1177
+ if not isinstance(state["step"], torch.Tensor):
1178
+ step_tensor = torch.tensor(state["step"],
1179
+ dtype=torch.float32,
1180
+ device=p.device)
1181
+ else:
1182
+ step_tensor = state["step"]
1183
+ state_steps.append(step_tensor)
1184
+
1185
+ self._fused_adamw(
1186
+ params_with_grads,
1187
+ grads,
1188
+ moment1,
1189
+ moment2,
1190
+ max_exp_avg_sqs,
1191
+ state_steps,
1192
+ amsgrad=False,
1193
+ beta1=beta1,
1194
+ beta2=beta2,
1195
+ lr=lr,
1196
+ weight_decay=weight_decay,
1197
+ eps=eps,
1198
+ maximize=False,
1199
+ )
1200
+
1201
+ def _step_adamw(self, group):
1202
+ params = group["params"]
1203
+
1204
+ # group params with it's type and placement
1205
+ placement_to_params: dict[tuple[Placement | type,
1206
+ DeviceMesh | None]] = defaultdict(list)
1207
+ for p in params:
1208
+ match p:
1209
+ case DTensor():
1210
+ placement_to_params[tuple([p.placements,
1211
+ p.device_mesh])].append(p)
1212
+ case torch.Tensor():
1213
+ placement_to_params[tuple([torch.Tensor, None])].append(p)
1214
+
1215
+ for params in placement_to_params.values():
1216
+ self._step_adamw_params(params, group)
1217
+
1218
  def step(self, closure=None, qk_logits=None):
1219
  """Perform a single optimization step.
1220
 
 
1232
  loss = closure()
1233
 
1234
  for group in self.param_groups:
 
 
1235
  if group["use_muon"]:
1236
+ self._step_muon(group, qk_logits=qk_logits)
1237
  else:
1238
+ self._step_adamw(group)
1239
 
1240
  return loss
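The _step_adamw path above groups parameters by type and placement before running the fused AdamW update once per bucket. A standalone sketch of that grouping, with a string tag faking the DTensor case (real DTensors need an initialized process group) and made-up parameter values:

    from collections import defaultdict
    import torch

    params = [torch.zeros(2), torch.zeros(3), "dtensor@mesh_a", "dtensor@mesh_a"]

    placement_to_params = defaultdict(list)
    for p in params:
        match p:
            case str():                          # stand-in for the DTensor() branch
                placement_to_params[(p, "mesh_a")].append(p)
            case torch.Tensor():
                placement_to_params[(torch.Tensor, None)].append(p)

    for key, bucket in placement_to_params.items():
        print(key, len(bucket))                  # two buckets: plain tensors, fake DTensors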
build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_811726c_dirty
3
- ops = torch.ops._optimizer_811726c_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_811726c_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_23d68bb_dirty
3
+ ops = torch.ops._optimizer_23d68bb_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_23d68bb_dirty::{op_name}"
build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8de22742ad0d387021a7b812ee3b7d0c8c54191914c8c0469886f6d2c082e9e3
3
+ size 1852272
build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:08b55491319446b12d0d890926506639640414edcba945e0f71afef0fac369d5
3
- size 1749352
build/torch29-cxx11-rocm64-x86_64-linux/optimizer/distributed/utils.py ADDED
@@ -0,0 +1,174 @@
1
+ import torch
2
+ import torch.distributed as dist
3
+ from torch.distributed import ProcessGroup
4
+ from torch.distributed.device_mesh import DeviceMesh
5
+ from torch.distributed.tensor import DTensor
6
+ from torch.distributed.tensor.placement_types import (Placement, Shard,
7
+ _StridedShard)
8
+
9
+
10
+ def get_slices_of_dtensor(
11
+ target: DTensor | torch.Tensor,
12
+ local_rank: int,
13
+ shard_mesh: DeviceMesh,
14
+ shard_placements: tuple[Placement],
15
+ ) -> tuple[slice]:
16
+ """
17
+ Get the slices that select the local shard of a tensor for a given rank.
18
+ Args:
19
+ target (DTensor | torch.Tensor): The target tensor.
20
+ local_rank (int): The local rank within the shard group.
21
+ shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks.
22
+ shard_placements (tuple[Placement]): The shard placements.
23
+ """
24
+
25
+ slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()]
26
+
27
+ # find the global rank of the local rank in the shard mesh
28
+ rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
29
+
30
+ rank_coords = (shard_mesh.mesh == rank).nonzero()
31
+
32
+ assert len(rank_coords) == 1
33
+ rank_coords = tuple(rank_coords[0].tolist())
34
+
35
+ assert len(rank_coords) == len(shard_placements)
36
+
37
+ # Caution: Assuming replicate-to-shard of the shard mesh goes with
38
+ # left-to-right sharding. This is ensured by the sorting logic of
39
+ # construct_shard_mesh function.
40
+ for i, (rank_coord,
41
+ placement) in enumerate(zip(rank_coords, shard_placements)):
42
+ assert isinstance(placement, Shard)
43
+
44
+ num_ranks = shard_mesh.mesh.shape[i]
45
+
46
+ dim = placement.dim
47
+ dim_size = (slices[dim].stop - slices[dim].start)
48
+
49
+ if dim_size % num_ranks != 0:
50
+ raise NotImplementedError(
51
+ f"Dimension size {dim_size} is not divisible "
52
+ f"by number of ranks {num_ranks} for shard "
53
+ f"placement on dim {dim}.")
54
+
55
+ shard_size = dim_size // num_ranks
56
+
57
+ start = slices[dim].start + rank_coord * shard_size
58
+ end = start + shard_size
59
+
60
+ assert start < end <= slices[dim].stop
61
+
62
+ slices[dim] = slice(start, end)
63
+
64
+ return tuple(slices)
65
+
66
+
67
+ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict()
68
+
69
+
70
+ def construct_shard_mesh(
71
+ placements: tuple[Placement],
72
+ mesh: DeviceMesh,
73
+ ) -> tuple[DeviceMesh, ProcessGroup, tuple[Placement]]:
74
+ """
75
+ Construct Shard Mesh and Placements for unsharding.
76
+ It removes Replicate placements and constructs a new Mesh and ProcessGroup.
77
+ """
78
+ my_rank = dist.get_rank()
79
+
80
+ assert mesh.mesh.device.type == 'cpu'
81
+
82
+ # Copy mesh to avoid modifying the original mesh
83
+ mesh = mesh.mesh.clone()
84
+
85
+ # 1. Sort placements. Replicate first, then Shard by dim ascending.
86
+
87
+ # For Shard, strided shard comes after regular shard on the same dim
88
+ # to preserve left-to-right order of replicate-to-shard.
89
+ # This is because a strided shard uses a stride to represent
90
+ # more fine-grained sharding on the same dim.
91
+ # Please check the URL below for _StridedShard.
92
+ # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366
93
+
94
+ def placement_sort_key(
95
+ placement_with_index: tuple[int, Placement]
96
+ ) -> tuple[float, float, int]: # (dim, split factor, original index)
97
+ index, placement = placement_with_index
98
+ is_replicate = placement.is_replicate()
99
+ is_shard = placement.is_shard()
100
+ is_partial = placement.is_partial()
101
+
102
+ assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}"
103
+ assert not is_partial, "Partial placement is not supported."
104
+
105
+ if is_replicate:
106
+ return (-1.0, 0, index)
107
+ elif is_shard:
108
+ if isinstance(placement, _StridedShard):
109
+ return (placement.dim, 1 / placement.split_factor, index)
110
+ return (placement.dim, 0, index)
111
+ else:
112
+ raise TypeError(f"Unknown placement type: {type(placement)}")
113
+
114
+ placements_with_index: list[tuple[int,
115
+ Placement]] = list(enumerate(placements))
116
+ placements_with_index = sorted(placements_with_index,
117
+ key=placement_sort_key)
118
+
119
+ sorted_indices, sorted_placements = zip(*placements_with_index)
120
+
121
+ # 2. Permute mesh according to sorted placements.
122
+ sorted_mesh = mesh.permute(sorted_indices)
123
+
124
+ # 3. Collect list of shard meshes by removing replicate dims
125
+ # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
126
+ # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4)
127
+ num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
128
+
129
+ # merge replicate dims
130
+ # shard_meshes becomes a list of shard meshes whose length equals the replicate degree
131
+ if num_replicates > 0:
132
+ sorted_mesh = sorted_mesh.flatten(
133
+ 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
134
+ shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
135
+ else:
136
+ shard_meshes = [sorted_mesh]
137
+ shard_placements = sorted_placements[num_replicates:]
138
+
139
+ # assume all shard placements are different
140
+ assert len(shard_placements) == len(set(shard_placements))
141
+
142
+ # 4. Construct ProcessGroups
143
+ # Caution: all groups should be created in the same order in all processes,
144
+ # even though each process only needs its own group.
145
+
146
+ # To use tensor as dict key, convert it to tuple
147
+ def tensor_to_tuple(t):
148
+ if isinstance(t, torch.Tensor):
149
+ t = t.tolist()
150
+ if isinstance(t, list):
151
+ return tuple(tensor_to_tuple(x) for x in t)
152
+ return t
153
+
154
+ my_shard_mesh_as_tuple = None
155
+ for shard_mesh in shard_meshes:
156
+ assert isinstance(shard_mesh, torch.Tensor)
157
+ shard_mesh_as_tuple = tensor_to_tuple(shard_mesh)
158
+
159
+ if (my_rank == shard_mesh).any().item():
160
+ assert my_shard_mesh_as_tuple is None
161
+ my_shard_mesh_as_tuple = shard_mesh_as_tuple
162
+
163
+ # update global cache
164
+ if shard_mesh_as_tuple not in _ranks_to_dist_cache:
165
+ shard_process_group = dist.new_group(shard_mesh.flatten().tolist())
166
+ _ranks_to_dist_cache[shard_mesh_as_tuple] = (
167
+ DeviceMesh(device_type="cuda", mesh=shard_mesh),
168
+ shard_process_group,
169
+ )
170
+
171
+ my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[
172
+ my_shard_mesh_as_tuple]
173
+
174
+ return my_shard_mesh, my_shard_process_group, shard_placements
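
For orientation, a minimal sketch of the ordering that placement_sort_key produces, using only the public Replicate/Shard placements (inputs are hypothetical; the _StridedShard tie-break is omitted):

from torch.distributed.tensor import Replicate, Shard

placements = [Shard(1), Replicate(), Shard(0), Replicate()]

def sort_key(item):
    idx, p = item
    # Replicate dims sort first; Shard dims follow in ascending tensor dim.
    return (-1.0, 0, idx) if p.is_replicate() else (float(p.dim), 0, idx)

ordered = sorted(enumerate(placements), key=sort_key)
# original indices (1, 3, 2, 0) -> [Replicate, Replicate, Shard(0), Shard(1)]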
build/torch29-cxx11-rocm64-x86_64-linux/optimizer/muon.py CHANGED
@@ -1,18 +1,24 @@
1
  import logging
2
  import math
3
  import types
 
4
  from dataclasses import dataclass
5
- from typing import List, Optional, Union, cast
6
 
7
  import torch
8
  import torch.distributed as dist
9
- from torch.distributed._tensor import DTensor, Replicate, Shard
 
 
 
10
 
 
11
  from .matmul_transpose_triton import matmul_transpose_assign
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
  COMM_DTYPE = torch.bfloat16
 
16
 
17
 
18
  # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -62,23 +68,39 @@ def _zeropower_via_newtonschulz5(G, steps):
62
  @dataclass
63
  class _muon_state:
64
  # TODO: use Optional
65
- worker_rank: int | None = None
 
 
 
 
 
66
  gathered_grad: torch.Tensor | None = None
67
  scattered_u: DTensor | None = None
68
  computed_u: torch.Tensor | None = None
69
  gather_event: torch.cuda.Event | None = None
70
  compute_event: torch.cuda.Event | None = None
71
  scatter_event: torch.cuda.Event | None = None
72
- process_group = None
73
- qk_clip_state = None
74
 
75
 
76
- def split_elems_for_src(param, src_rank, num_ranks) -> int:
77
- rows = param.shape[0]
78
- cols = int(param.numel() // rows)
79
- base, rem = divmod(rows, num_ranks)
80
- my_rows = base + (1 if src_rank < rem else 0)
81
- return my_rows * cols
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
 
84
  @torch.no_grad()
@@ -91,8 +113,7 @@ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
91
  for p in params:
92
  state = param_to_state[id(p)]
93
  if rank == state.worker_rank:
94
- num_ranks = dist.get_world_size(group=state.process_group)
95
- state.gathered_grad = torch.empty(p.grad.numel(),
96
  dtype=COMM_DTYPE,
97
  device="cuda")
98
  else:
@@ -121,11 +142,11 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
121
  state = param_to_state[id(p)]
122
  dst = state.worker_rank
123
  assert dst < num_ranks
124
- shard_elems = split_elems_for_src(p, rank, num_ranks)
125
  g = p.grad
126
- g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
127
  assert g.numel() == shard_elems
128
- per_dst[dst].append(g)
129
  send_counts[dst] += shard_elems
130
 
131
  assert any(
@@ -148,13 +169,18 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
148
  for p in owned_params:
149
  state = param_to_state[id(p)]
150
  assert state.worker_rank == rank
151
- total += split_elems_for_src(p, src, num_ranks)
152
  recv_counts[src] = total
153
 
154
  recv_total = sum(recv_counts)
155
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
156
 
157
  #All2All
 
 
 
 
 
158
  dist.all_to_all_single(
159
  recv_buf,
160
  send_buf,
@@ -179,7 +205,6 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
179
  comm_stream.wait_event(alloc_event)
180
 
181
  off = 0
182
- write_offsets = {id(p): 0 for p in owned_params}
183
  for src in range(num_ranks):
184
  if recv_counts[src] == 0:
185
  continue
@@ -189,22 +214,28 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
189
  for p in owned_params:
190
  state = param_to_state[id(p)]
191
  assert state.worker_rank == rank
192
- n = split_elems_for_src(p, src, num_ranks)
 
 
 
 
 
 
 
 
 
193
  assert n > 0
194
 
195
  sg = recv_buf.narrow(0, off + inner_off, n)
196
- woff = write_offsets[id(p)]
197
- dst = state.gathered_grad.narrow(0, woff, n)
198
  dst.copy_(sg)
199
 
200
- write_offsets[id(p)] += n
201
  inner_off += n
202
  off += block
203
 
204
  for p in params:
205
  state = param_to_state[id(p)]
206
  if state.worker_rank == rank:
207
- state.gathered_grad = state.gathered_grad.view_as(p)
208
  state.gather_event = torch.cuda.Event()
209
  state.gather_event.record(comm_stream)
210
  else:
@@ -277,14 +308,19 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
277
 
278
  assert state.computed_u is not None
279
 
280
- u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)
281
 
282
  offset = 0
283
  for dst in range(num_ranks):
284
- n = split_elems_for_src(p, dst, num_ranks)
 
 
 
 
 
 
285
  assert n > 0
286
 
287
- su = u_full.narrow(0, offset, n)
288
  per_dst[dst].append(su)
289
  send_counts[dst] += n
290
  offset += n
@@ -313,7 +349,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
313
  state = param_to_state[id(p)]
314
  if state.worker_rank != src:
315
  continue
316
- total += split_elems_for_src(p, rank, num_ranks)
317
  recv_counts[src] = total
318
 
319
  recv_total = sum(recv_counts)
@@ -357,7 +393,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
357
  state = param_to_state[id(p)]
358
  if state.worker_rank != src:
359
  continue
360
- n = split_elems_for_src(p, rank, num_ranks)
361
  assert n > 0
362
 
363
  flat_local = recv_buf.narrow(0, off + inner_off,
@@ -398,11 +434,23 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
398
  state.scattered_u = None
399
  u_dtensor = None
400
 
401
- scales_full = Muon._compute_scales(p, state.qk_clip_state)
 
 
402
  if scales_full is not None:
403
- num_ranks = dist.get_world_size(group=state.process_group)
404
- local_rank = dist.get_rank(group=state.process_group)
405
- scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank]
 
 
 
 
 
 
 
 
 
 
406
  scales_local = DTensor.from_local(
407
  scales_local,
408
  placements=p.placements,
@@ -478,11 +526,11 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
478
  @dataclass
479
  class QKClipInfo:
480
  """Per-parameter dynamic info computed from config + runtime logits."""
481
- kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None
482
- indices: List[int] # which heads to consider for clipping
483
  head_dim: int # from config
484
  threshold: float # from config
485
- logit: Optional[torch.Tensor]
486
 
487
 
488
  class Muon(torch.optim.Optimizer):
@@ -525,11 +573,16 @@ class Muon(torch.optim.Optimizer):
525
  "head_dim": 128,
526
  "threshold": 100
527
  }
528
- overlap_step : How many all2all gather, compute operations are launched in advance
529
- before the corresponding all2all scatter steps begin.
530
- A higher overlap_step increases memory usage but can improve
531
- performance by overlapping communication.
532
- Parallel muon only.
 
 
 
 
 
533
  """
534
 
535
  def __init__(self,
@@ -549,7 +602,9 @@ class Muon(torch.optim.Optimizer):
549
  "head_dim": 128,
550
  "threshold": 100
551
  },
552
- overlap_step=5):
 
 
553
  defaults = dict(
554
  lr=lr,
555
  weight_decay=weight_decay,
@@ -579,7 +634,9 @@ class Muon(torch.optim.Optimizer):
579
  self.compute_stream = torch.cuda.Stream()
580
  self.debug = debug
581
  self.clip_config = clip_config
582
- self.overlap_step = overlap_step
 
 
583
 
584
  def _calc_flops(self, G, steps):
585
  assert len(G.shape) == 2
@@ -597,6 +654,12 @@ class Muon(torch.optim.Optimizer):
597
  adjusted_lr = lr * adjusted_ratio
598
  return adjusted_lr
599
 
 
 
 
 
 
 
600
  def get_shard_mesh(self, p):
601
  """
602
  Get the shard mesh for a parameter p on the given rank.
@@ -604,26 +667,13 @@ class Muon(torch.optim.Optimizer):
604
  assert isinstance(
605
  p, DTensor), "Parallel Muon only supports DTensor parameters."
606
 
607
- if p.placements == (Shard(dim=0), ):
608
- # Case for FSDP
609
- process_group = p.device_mesh.get_group(mesh_dim=0)
610
- if self.rank is None:
611
- self.rank = dist.get_rank(group=process_group)
612
- else:
613
- assert self.rank == dist.get_rank(group=process_group)
614
- return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
615
- elif p.placements == (Replicate(), Shard(dim=0)):
616
- # Case for HSDP
617
- process_group = p.device_mesh.get_group(mesh_dim=1)
618
- if self.rank is None:
619
- self.rank = dist.get_rank(group=process_group)
620
- else:
621
- assert self.rank == dist.get_rank(group=process_group)
622
- for i, shard_mesh in enumerate(p.device_mesh.mesh):
623
- if self.rank in shard_mesh:
624
- return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
625
- else:
626
- raise ValueError(f"Unsupported placements ({p.placements}).")
627
 
628
  def init_state_and_assign_params(self, names, params, group, qk_logits):
629
  param_to_state = {}
@@ -655,23 +705,32 @@ class Muon(torch.optim.Optimizer):
655
  ordered_params = list(params_sorted)
656
 
657
  round_robin = 0
658
- mesh = None
659
- shard_mesh = None
660
- process_group = None
 
 
 
 
 
661
  for n, p in zip(ordered_names, ordered_params):
662
- if mesh is None:
663
- mesh = p.device_mesh
664
- shard_mesh, process_group = self.get_shard_mesh(p)
665
- elif mesh != p.device_mesh:
666
  raise ValueError("All parameters must be on the same mesh.")
667
- num_ranks = dist.get_world_size(group=process_group)
668
- param_to_state[id(p)] = _muon_state()
669
- param_to_state[id(
670
- p)].worker_rank = shard_mesh[round_robin].item() % num_ranks
671
- param_to_state[id(p)].process_group = process_group
672
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
673
- param_to_state[id(p)].qk_clip_state = qk_clip_state
674
- round_robin = (round_robin + 1) % len(shard_mesh)
 
 
 
 
 
 
 
675
 
676
  return param_to_state, ordered_params
677
 
@@ -705,10 +764,73 @@ class Muon(torch.optim.Optimizer):
705
 
706
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
707
 
708
- scales_full = self._compute_scales(p, qk_clip_state)
 
709
  if scales_full is not None:
710
  Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
711
 
712
  def _update_g(self, p, g, group, momentum):
713
  # calc update
714
  state = self.state[p]
@@ -727,6 +849,9 @@ class Muon(torch.optim.Optimizer):
727
  p.data.add_(u, alpha=-adjusted_lr)
728
 
729
  def get_qk_clip_info(self, n, qk_logits):
 
 
 
730
  head_dim = self.clip_config.get('head_dim')
731
  threshold = self.clip_config.get('threshold')
732
  kind, layer_idx = parse_qk_layer(n)
@@ -737,6 +862,11 @@ class Muon(torch.optim.Optimizer):
737
  indices_key = 'q_indices' if 'q' in kind else 'k_indices'
738
  indices = self.clip_config.get(indices_key, []) or []
739
 
 
 
 
 
 
740
  return QKClipInfo(
741
  kind=kind,
742
  indices=indices,
@@ -835,22 +965,28 @@ class Muon(torch.optim.Optimizer):
835
  _update_param(p, state, lr, adjusted_lr, weight_decay,
836
  self.rank, self.compute_stream)
837
 
838
- chunk_size = dist.get_world_size(param_to_state[id(
839
- params[0])].process_group)
 
 
 
 
 
 
840
 
841
  # Wait grad update
842
  self.comm_stream.wait_stream(torch.cuda.current_stream())
843
 
844
- overlap_step = self.overlap_step
845
- for i in range(0, overlap_step):
846
  enqueue_all2all_gather(i * chunk_size, chunk_size)
847
  enqueue_computes(i * chunk_size, chunk_size)
848
 
849
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
850
  enqueue_all2all_scatter(i, chunk_size)
851
- enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size)
852
  enqueue_update_param(i, chunk_size)
853
- enqueue_computes(i + overlap_step * chunk_size, chunk_size)
854
 
855
  # Wait the last update_param to finish
856
  torch.cuda.current_stream().wait_stream(self.compute_stream)
@@ -866,7 +1002,7 @@ class Muon(torch.optim.Optimizer):
866
  amsgrad: bool,
867
  beta1: float,
868
  beta2: float,
869
- lr: Union[float, torch.Tensor],
870
  weight_decay: float,
871
  eps: float,
872
  maximize: bool,
@@ -876,10 +1012,10 @@ class Muon(torch.optim.Optimizer):
876
 
877
  # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
878
  # treating it as a scalar.
879
- lr_dict: Optional[DeviceDict] = ({
880
  lr.device: lr
881
  } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
882
- None)
883
  grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
884
  [
885
  params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
@@ -926,6 +1062,159 @@ class Muon(torch.optim.Optimizer):
926
  maximize=maximize,
927
  )
928
 
929
  def step(self, closure=None, qk_logits=None):
930
  """Perform a single optimization step.
931
 
@@ -943,127 +1232,9 @@ class Muon(torch.optim.Optimizer):
943
  loss = closure()
944
 
945
  for group in self.param_groups:
946
- params = group["params"]
947
-
948
  if group["use_muon"]:
949
- ############################
950
- # Muon #
951
- ############################
952
- lr = group["lr"]
953
- weight_decay = group["weight_decay"]
954
- momentum = group["momentum"]
955
- names = group["names"]
956
-
957
- param_dtensors = []
958
- param_tensors = []
959
- name_dtensors = []
960
- name_tensors = []
961
-
962
- for n, p in zip(names, params):
963
- if p is None or p.grad is None:
964
- continue
965
- if isinstance(p.data, DTensor):
966
- if all(
967
- isinstance(placement, Replicate)
968
- for placement in p.placements):
969
- param_tensors.append(p)
970
- name_tensors.append(n)
971
- else:
972
- param_dtensors.append(p)
973
- name_dtensors.append(n)
974
- elif isinstance(p.data, torch.Tensor):
975
- param_tensors.append(p)
976
- name_tensors.append(n)
977
- else:
978
- raise TypeError(
979
- f"Unsupported parameter type: {type(p.data)}")
980
-
981
- if self.debug:
982
- print(
983
- f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
984
- flush=True,
985
- )
986
-
987
- if len(param_dtensors) > 0:
988
- if not dist.is_initialized():
989
- raise RuntimeError(
990
- "Parallel Muon requires torch.distributed to be initialized."
991
- )
992
-
993
- self.parallel(
994
- name_dtensors,
995
- param_dtensors,
996
- group,
997
- lr=lr,
998
- weight_decay=weight_decay,
999
- momentum=momentum,
1000
- qk_logits=qk_logits,
1001
- )
1002
-
1003
- if len(param_tensors) > 0:
1004
- self.base(
1005
- name_tensors,
1006
- param_tensors,
1007
- group,
1008
- lr=lr,
1009
- weight_decay=weight_decay,
1010
- momentum=momentum,
1011
- qk_logits=qk_logits,
1012
- )
1013
-
1014
  else:
1015
- ############################
1016
- # AdamW backup #
1017
- ############################
1018
-
1019
- params_with_grads = []
1020
- grads = []
1021
- moment1 = []
1022
- moment2 = []
1023
- max_exp_avg_sqs = []
1024
- state_steps = []
1025
- lr = group["lr"]
1026
- beta1, beta2 = group["adamw_betas"]
1027
- eps = group["adamw_eps"]
1028
- weight_decay = group["weight_decay"]
1029
-
1030
- for p in params:
1031
- g = p.grad
1032
- if g is None:
1033
- continue
1034
- state = self.state[p]
1035
- params_with_grads.append(p)
1036
- grads.append(g)
1037
- if "step" not in state:
1038
- state["step"] = (torch.zeros((),
1039
- dtype=torch.float32,
1040
- device=p.device))
1041
- state["moment1"] = torch.zeros_like(g)
1042
- state["moment2"] = torch.zeros_like(g)
1043
- moment1.append(state["moment1"])
1044
- moment2.append(state["moment2"])
1045
- if not isinstance(state["step"], torch.Tensor):
1046
- step_tensor = torch.tensor(state["step"],
1047
- dtype=torch.float32,
1048
- device=p.device)
1049
- else:
1050
- step_tensor = state["step"]
1051
- state_steps.append(step_tensor)
1052
-
1053
- self._fused_adamw(
1054
- params_with_grads,
1055
- grads,
1056
- moment1,
1057
- moment2,
1058
- max_exp_avg_sqs,
1059
- state_steps,
1060
- amsgrad=False,
1061
- beta1=beta1,
1062
- beta2=beta2,
1063
- lr=lr,
1064
- weight_decay=weight_decay,
1065
- eps=eps,
1066
- maximize=False,
1067
- )
1068
 
1069
  return loss
 
1
  import logging
2
  import math
3
  import types
4
+ from collections import defaultdict
5
  from dataclasses import dataclass
6
+ from typing import Any, cast
7
 
8
  import torch
9
  import torch.distributed as dist
10
+ from torch.distributed import ProcessGroup
11
+ from torch.distributed.device_mesh import DeviceMesh
12
+ from torch.distributed.tensor import DTensor, Replicate
13
+ from torch.distributed.tensor.placement_types import Placement
14
 
15
+ from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor
16
  from .matmul_transpose_triton import matmul_transpose_assign
17
 
18
  logger = logging.getLogger(__name__)
19
 
20
  COMM_DTYPE = torch.bfloat16
21
+ DEFAULT_CHUNK_SIZE_RATIO = 4
22
 
23
 
24
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
68
  @dataclass
69
  class _muon_state:
70
  # TODO: use Optional
71
+ worker_rank: int
72
+ process_group: ProcessGroup
73
+ shard_mesh: DeviceMesh
74
+ shard_placements: tuple[Placement, ...]
75
+ name: str
76
+ qk_clip_state: torch.Tensor | None = None
77
  gathered_grad: torch.Tensor | None = None
78
  scattered_u: DTensor | None = None
79
  computed_u: torch.Tensor | None = None
80
  gather_event: torch.cuda.Event | None = None
81
  compute_event: torch.cuda.Event | None = None
82
  scatter_event: torch.cuda.Event | None = None
 
 
83
 
84
 
85
+ def numel_for_rank(
86
+ param: DTensor,
87
+ local_rank: int,
88
+ state: _muon_state,
89
+ ) -> int:
90
+ slices = get_slices_of_dtensor(
91
+ param,
92
+ local_rank,
93
+ state.shard_mesh,
94
+ state.shard_placements,
95
+ )
96
+
97
+ numel = 1
98
+ for s, dim in zip(slices, param.shape):
99
+ start, stop, step = s.indices(dim)
100
+ length = max(0, (stop - start + (step - 1)) // step)
101
+ numel *= length
102
+
103
+ return numel
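
As a quick check of the slice arithmetic in numel_for_rank, the same computation on a hypothetical (10, 8) parameter whose rank owns rows 5..9 (the slices stand in for what get_slices_of_dtensor would return):

shape = (10, 8)                       # full parameter shape
slices = (slice(5, 10), slice(0, 8))  # hypothetical per-rank slices

numel = 1
for s, dim in zip(slices, shape):
    start, stop, step = s.indices(dim)
    numel *= max(0, (stop - start + step - 1) // step)

assert numel == 40  # 5 rows * 8 cols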
104
 
105
 
106
  @torch.no_grad()
 
113
  for p in params:
114
  state = param_to_state[id(p)]
115
  if rank == state.worker_rank:
116
+ state.gathered_grad = torch.empty(p.shape,
 
117
  dtype=COMM_DTYPE,
118
  device="cuda")
119
  else:
 
142
  state = param_to_state[id(p)]
143
  dst = state.worker_rank
144
  assert dst < num_ranks
145
+ shard_elems = numel_for_rank(p, rank, state)
146
  g = p.grad
147
+ g = g.to_local().to(COMM_DTYPE).contiguous()
148
  assert g.numel() == shard_elems
149
+ per_dst[dst].append(g.view(-1))
150
  send_counts[dst] += shard_elems
151
 
152
  assert any(
 
169
  for p in owned_params:
170
  state = param_to_state[id(p)]
171
  assert state.worker_rank == rank
172
+ total += numel_for_rank(p, src, state)
173
  recv_counts[src] = total
174
 
175
  recv_total = sum(recv_counts)
176
  recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
177
 
178
  #All2All
179
+ logger.debug(f"send_buf size: {send_buf.numel()}, "
180
+ f"recv_buf size: {recv_buf.numel()}, "
181
+ f"recv_counts: {recv_counts}, "
182
+ f"send_counts: {send_counts}, "
183
+ f"process_group: {str(process_group)}")
184
  dist.all_to_all_single(
185
  recv_buf,
186
  send_buf,
 
205
  comm_stream.wait_event(alloc_event)
206
 
207
  off = 0
 
208
  for src in range(num_ranks):
209
  if recv_counts[src] == 0:
210
  continue
 
214
  for p in owned_params:
215
  state = param_to_state[id(p)]
216
  assert state.worker_rank == rank
217
+
218
+ # get the slice of the full dtensor corresponding to rank src.
219
+ slices = get_slices_of_dtensor(state.gathered_grad, src,
220
+ state.shard_mesh,
221
+ state.shard_placements)
222
+
223
+ dst = state.gathered_grad[slices]
224
+ assert dst._base is state.gathered_grad
225
+
226
+ n = dst.numel()
227
  assert n > 0
228
 
229
  sg = recv_buf.narrow(0, off + inner_off, n)
230
+ sg = sg.reshape_as(dst)
 
231
  dst.copy_(sg)
232
 
 
233
  inner_off += n
234
  off += block
235
 
236
  for p in params:
237
  state = param_to_state[id(p)]
238
  if state.worker_rank == rank:
 
239
  state.gather_event = torch.cuda.Event()
240
  state.gather_event.record(comm_stream)
241
  else:
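
The in-place fill of state.gathered_grad above relies on basic (slice) indexing returning a view; a small standalone check of that behavior, independent of this module:

import torch

buf = torch.zeros(4, 6)
view = buf[1:3, :]            # slicing returns a view, not a copy
assert view._base is buf      # same invariant asserted in the gather loop
view.copy_(torch.ones(2, 6))  # writes land directly in buf
assert buf[1:3].sum().item() == 12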
 
308
 
309
  assert state.computed_u is not None
310
 
311
+ u_full = state.computed_u.to(COMM_DTYPE).contiguous()
312
 
313
  offset = 0
314
  for dst in range(num_ranks):
315
+ # get the slice of the full tensor corresponding to rank dst.
316
+ slices = get_slices_of_dtensor(u_full, dst,
317
+ state.shard_mesh,
318
+ state.shard_placements)
319
+ su = u_full[slices].flatten()
320
+
321
+ n = su.numel()
322
  assert n > 0
323
 
 
324
  per_dst[dst].append(su)
325
  send_counts[dst] += n
326
  offset += n
 
349
  state = param_to_state[id(p)]
350
  if state.worker_rank != src:
351
  continue
352
+ total += numel_for_rank(p, rank, state)
353
  recv_counts[src] = total
354
 
355
  recv_total = sum(recv_counts)
 
393
  state = param_to_state[id(p)]
394
  if state.worker_rank != src:
395
  continue
396
+ n = numel_for_rank(p, rank, state)
397
  assert n > 0
398
 
399
  flat_local = recv_buf.narrow(0, off + inner_off,
 
434
  state.scattered_u = None
435
  u_dtensor = None
436
 
437
+ scales_full = Muon._compute_scales(
438
+ p,
439
+ state.qk_clip_state) if state.qk_clip_state is not None else None
440
  if scales_full is not None:
441
+             # Slice scales_full along dim 0 to match this rank's weight shard
442
+ weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh,
443
+ state.shard_placements)
444
+ ratio = p.shape[0] // scales_full.shape[0]
445
+ scales_slice = slice(
446
+ None if weight_slices[0].start is None else
447
+ weight_slices[0].start // ratio,
448
+ None if weight_slices[0].stop is None else
449
+ weight_slices[0].stop // ratio,
450
+ None,
451
+ )
452
+
453
+ scales_local = scales_full[scales_slice]
454
  scales_local = DTensor.from_local(
455
  scales_local,
456
  placements=p.placements,
 
526
  @dataclass
527
  class QKClipInfo:
528
  """Per-parameter dynamic info computed from config + runtime logits."""
529
+ kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
530
+ indices: list[int] # which heads to consider for clipping
531
  head_dim: int # from config
532
  threshold: float # from config
533
+ logit: torch.Tensor | None
534
 
535
 
536
  class Muon(torch.optim.Optimizer):
 
573
  "head_dim": 128,
574
  "threshold": 100
575
  }
576
+ warmup_step : How many all2all-gather and compute operations are launched in advance
577
+ before the corresponding all2all scatter steps begin.
578
+ A higher warmup_step increases memory usage but can improve
579
+ performance by overlapping communication.
580
+ Parallel muon only.
581
+ chunk_size : Batch size of parameters to process in each
582
+ all2all gather/compute/scatter step.
583
+ When -1 is specified, shard ranks * DEFAULT_CHUNK_SIZE_RATIO is used.
584
+ use_distributed_muon: Use distributed muon by Liu et al. (2024).
585
+ For testing purposes only.
586
  """
587
 
588
  def __init__(self,
 
602
  "head_dim": 128,
603
  "threshold": 100
604
  },
605
+ warmup_step=5,
606
+ chunk_size=-1,
607
+ use_distributed_muon=False):
608
  defaults = dict(
609
  lr=lr,
610
  weight_decay=weight_decay,
 
634
  self.compute_stream = torch.cuda.Stream()
635
  self.debug = debug
636
  self.clip_config = clip_config
637
+ self.warmup_step = warmup_step
638
+ self.chunk_size = chunk_size
639
+ self.use_distributed_muon = use_distributed_muon
640
 
641
  def _calc_flops(self, G, steps):
642
  assert len(G.shape) == 2
 
654
  adjusted_lr = lr * adjusted_ratio
655
  return adjusted_lr
656
 
657
+ def set_rank_once(self, rank):
658
+ if self.rank is None:
659
+ self.rank = rank
660
+ else:
661
+ assert self.rank == rank
662
+
663
  def get_shard_mesh(self, p):
664
  """
665
  Get the shard mesh for a parameter p on the given rank.
 
667
  assert isinstance(
668
  p, DTensor), "Parallel Muon only supports DTensor parameters."
669
 
670
+ shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
671
+ p.placements, p.device_mesh)
672
+
673
+         # record self.rank as the local rank within the shard process group
674
+ self.set_rank_once(dist.get_rank(group=shard_pg))
675
+
676
+ return shard_mesh, shard_pg, shard_placements
 
 
 
 
 
 
 
 
 
 
 
 
 
677
 
678
  def init_state_and_assign_params(self, names, params, group, qk_logits):
679
  param_to_state = {}
 
705
  ordered_params = list(params_sorted)
706
 
707
  round_robin = 0
708
+ mesh = ordered_params[0].device_mesh
709
+ placements = ordered_params[0].placements
710
+
711
+ shard_mesh, shard_pg, shard_placements = self.get_shard_mesh(
712
+ ordered_params[0])
713
+ shard_mesh_flattened = shard_mesh.mesh.flatten()
714
+ num_ranks = dist.get_world_size(group=shard_pg)
715
+
716
  for n, p in zip(ordered_names, ordered_params):
717
+ if mesh != p.device_mesh:
 
 
 
718
  raise ValueError("All parameters must be on the same mesh.")
719
+ if placements != p.placements:
720
+                 raise ValueError("All parameters must have the same placements.")
721
+
722
+ worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
723
+ round_robin = (round_robin + 1) % len(shard_mesh_flattened)
724
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
725
+
726
+ param_to_state[id(p)] = _muon_state(
727
+ worker_rank=worker_rank,
728
+ process_group=shard_pg,
729
+ shard_mesh=shard_mesh,
730
+ shard_placements=shard_placements,
731
+ name=n,
732
+ qk_clip_state=qk_clip_state,
733
+ )
734
 
735
  return param_to_state, ordered_params
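
A toy version of the round-robin owner assignment above, with a hypothetical flattened shard mesh of four ranks and six parameters:

shard_mesh_flattened = [0, 1, 2, 3]   # hypothetical global ranks in the shard mesh
num_ranks = len(shard_mesh_flattened)

owners, round_robin = [], 0
for _ in range(6):                    # six parameters
    owners.append(shard_mesh_flattened[round_robin] % num_ranks)
    round_robin = (round_robin + 1) % len(shard_mesh_flattened)

assert owners == [0, 1, 2, 3, 0, 1]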
736
 
 
764
 
765
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
766
 
767
+ scales_full = self._compute_scales(
768
+ p, qk_clip_state) if qk_clip_state is not None else None
769
  if scales_full is not None:
770
  Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
771
 
772
+ def distributed_muon(
773
+ self,
774
+ names: list[str],
775
+ params: list[torch.nn.Parameter],
776
+ group: dict[str, Any],
777
+ lr: float,
778
+ weight_decay: float,
779
+ momentum: float,
780
+ qk_logits: list[torch.Tensor | DTensor] | None,
781
+ ):
782
+ """ Implementation of Distributed Muon by Liu et al. """
783
+ if qk_logits is not None:
784
+ raise NotImplementedError("QK clipping is not supported yet")
785
+
786
+ if isinstance(params[0], DTensor):
787
+ shard_mesh, _, shard_placements = construct_shard_mesh(
788
+ placements=params[0].placements,
789
+ mesh=params[0].device_mesh,
790
+ )
791
+
792
+ for n, p in zip(names, params):
793
+ g = p.grad
794
+ if g is None:
795
+ continue
796
+ if g.ndim > 2:
797
+ g = g.view(g.size(0), -1)
798
+ assert g is not None
799
+
800
+ # calc update
801
+ state = self.state[p]
802
+ if "momentum_buffer" not in state:
803
+ state["momentum_buffer"] = torch.zeros_like(g)
804
+ buf = state["momentum_buffer"]
805
+ buf.mul_(momentum).add_(g)
806
+ if group["nesterov"]:
807
+ g = g.add(buf, alpha=momentum)
808
+ else:
809
+ g = buf
810
+
811
+ # Gather G
812
+ if isinstance(p.data, DTensor):
813
+ g = g.full_tensor()
814
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
815
+ steps=group["ns_steps"])
816
+
817
+ if isinstance(p.data, DTensor):
818
+ slices = get_slices_of_dtensor(
819
+ target=p,
820
+ local_rank=dist.get_rank(),
821
+ shard_mesh=shard_mesh,
822
+ shard_placements=shard_placements,
823
+ )
824
+ u_shard = u[slices]
825
+ u = DTensor.from_local(
826
+ u_shard,
827
+ device_mesh=p.device_mesh,
828
+ placements=p.placements,
829
+ )
830
+
831
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
832
+ Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
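
A shape-only sketch of the gather -> orthogonalize -> re-shard flow in distributed_muon, with plain tensors standing in for DTensor and torch.linalg.qr used only as a placeholder for the Newton-Schulz iteration (hypothetical 4-way row sharding):

import torch

world_size, rank = 4, 1
full_grad = torch.randn(8, 8)               # stands in for g.full_tensor()
u, _ = torch.linalg.qr(full_grad)           # placeholder for _zeropower_via_newtonschulz5
rows = full_grad.shape[0] // world_size
u_shard = u[rank * rows:(rank + 1) * rows]  # this rank's row slice of the update
assert u_shard.shape == (2, 8)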
833
+
834
  def _update_g(self, p, g, group, momentum):
835
  # calc update
836
  state = self.state[p]
 
849
  p.data.add_(u, alpha=-adjusted_lr)
850
 
851
  def get_qk_clip_info(self, n, qk_logits):
852
+ if self.clip_config is None:
853
+ return None
854
+
855
  head_dim = self.clip_config.get('head_dim')
856
  threshold = self.clip_config.get('threshold')
857
  kind, layer_idx = parse_qk_layer(n)
 
862
  indices_key = 'q_indices' if 'q' in kind else 'k_indices'
863
  indices = self.clip_config.get(indices_key, []) or []
864
 
865
+ if isinstance(logit, DTensor):
866
+             # In TP settings, the logit may be a DTensor;
867
+             # we convert it to a full tensor here for simplicity.
868
+ logit = logit.full_tensor()
869
+
870
  return QKClipInfo(
871
  kind=kind,
872
  indices=indices,
 
965
  _update_param(p, state, lr, adjusted_lr, weight_decay,
966
  self.rank, self.compute_stream)
967
 
968
+ if self.chunk_size == -1:
969
+ shard_ranks = dist.get_world_size(param_to_state[id(
970
+ params[0])].process_group)
971
+ chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
972
+ elif self.chunk_size > 0:
973
+ chunk_size = self.chunk_size
974
+ else:
975
+ raise ValueError("chunk_size must be -1 or a positive integer.")
976
 
977
  # Wait grad update
978
  self.comm_stream.wait_stream(torch.cuda.current_stream())
979
 
980
+ warmup_step = self.warmup_step
981
+ for i in range(0, warmup_step):
982
  enqueue_all2all_gather(i * chunk_size, chunk_size)
983
  enqueue_computes(i * chunk_size, chunk_size)
984
 
985
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
986
  enqueue_all2all_scatter(i, chunk_size)
987
+ enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size)
988
  enqueue_update_param(i, chunk_size)
989
+ enqueue_computes(i + warmup_step * chunk_size, chunk_size)
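
To make the pipelining above concrete, this standalone schedule enumerates which chunk each stage touches per iteration, for a hypothetical 8 chunks with warmup_step=2; gather/compute run warmup_step chunks ahead of scatter/update:

num_chunks, warmup_step = 8, 2

for i in range(warmup_step):
    print(f"warmup : gather+compute chunk {i}")

for i in range(num_chunks):
    ahead = i + warmup_step
    line = f"step {i} : scatter+update chunk {i}"
    if ahead < num_chunks:
        line += f", gather+compute chunk {ahead}"
    print(line)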
990
 
991
  # Wait the last update_param to finish
992
  torch.cuda.current_stream().wait_stream(self.compute_stream)
 
1002
  amsgrad: bool,
1003
  beta1: float,
1004
  beta2: float,
1005
+ lr: float | torch.Tensor,
1006
  weight_decay: float,
1007
  eps: float,
1008
  maximize: bool,
 
1012
 
1013
  # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
1014
  # treating it as a scalar.
1015
+ lr_dict: DeviceDict | None = ({
1016
  lr.device: lr
1017
  } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
1018
+ None)
1019
  grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
1020
  [
1021
  params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
 
1062
  maximize=maximize,
1063
  )
1064
 
1065
+ def _step_muon(self, group, qk_logits=None):
1066
+ params = group["params"]
1067
+ lr = group["lr"]
1068
+ weight_decay = group["weight_decay"]
1069
+ momentum = group["momentum"]
1070
+ names = group["names"]
1071
+
1072
+ param_dtensors = []
1073
+ param_tensors = []
1074
+ name_dtensors = []
1075
+ name_tensors = []
1076
+
1077
+ if self.use_distributed_muon:
1078
+ self.distributed_muon(names=names,
1079
+ params=params,
1080
+ group=group,
1081
+ lr=lr,
1082
+ weight_decay=weight_decay,
1083
+ momentum=momentum,
1084
+ qk_logits=qk_logits)
1085
+ return
1086
+
1087
+ for n, p in zip(names, params):
1088
+ if p is None or p.grad is None:
1089
+ continue
1090
+ if isinstance(p.data, DTensor):
1091
+ if all(
1092
+ isinstance(placement, Replicate)
1093
+ for placement in p.placements):
1094
+ param_tensors.append(p)
1095
+ name_tensors.append(n)
1096
+ else:
1097
+ param_dtensors.append(p)
1098
+ name_dtensors.append(n)
1099
+ elif isinstance(p.data, torch.Tensor):
1100
+ param_tensors.append(p)
1101
+ name_tensors.append(n)
1102
+ else:
1103
+ raise TypeError(f"Unsupported parameter type: {type(p.data)}")
1104
+
1105
+ logger.debug(
1106
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors"
1107
+ )
1108
+
1109
+ if len(param_dtensors) > 0:
1110
+ if not dist.is_initialized():
1111
+ raise RuntimeError(
1112
+ "Parallel Muon requires torch.distributed to be initialized."
1113
+ )
1114
+
1115
+ # To support different placements, we group parameters by placements
1116
+ # and run parallel Muon on each group.
1117
+
1118
+ placement_to_params = defaultdict(lambda: ([], []))
1119
+             # type: dict[tuple[tuple[Placement, ...], DeviceMesh], tuple[list[str], list[DTensor]]]
1120
+
1121
+ assert len(name_dtensors) == len(param_dtensors)
1122
+ for n, p in zip(name_dtensors, param_dtensors):
1123
+ placement_to_params[tuple([p.placements,
1124
+ p.device_mesh])][0].append(n)
1125
+ placement_to_params[tuple([p.placements,
1126
+ p.device_mesh])][1].append(p)
1127
+
1128
+ for _, (names, params) in placement_to_params.items():
1129
+ self.parallel(
1130
+ names,
1131
+ params,
1132
+ group,
1133
+ lr=lr,
1134
+ weight_decay=weight_decay,
1135
+ momentum=momentum,
1136
+ qk_logits=qk_logits,
1137
+ )
1138
+
1139
+ if len(param_tensors) > 0:
1140
+ self.base(
1141
+ name_tensors,
1142
+ param_tensors,
1143
+ group,
1144
+ lr=lr,
1145
+ weight_decay=weight_decay,
1146
+ momentum=momentum,
1147
+ qk_logits=qk_logits,
1148
+ )
1149
+
1150
+ def _step_adamw_params(self, params, group):
1151
+ params_with_grads = []
1152
+ grads = []
1153
+ moment1 = []
1154
+ moment2 = []
1155
+ max_exp_avg_sqs = []
1156
+ state_steps = []
1157
+ lr = group["lr"]
1158
+ beta1, beta2 = group["adamw_betas"]
1159
+ eps = group["adamw_eps"]
1160
+ weight_decay = group["weight_decay"]
1161
+
1162
+ for p in params:
1163
+ g = p.grad
1164
+ if g is None:
1165
+ continue
1166
+ state = self.state[p]
1167
+ params_with_grads.append(p)
1168
+ grads.append(g)
1169
+ if "step" not in state:
1170
+ state["step"] = (torch.zeros((),
1171
+ dtype=torch.float32,
1172
+ device=p.device))
1173
+ state["moment1"] = torch.zeros_like(g)
1174
+ state["moment2"] = torch.zeros_like(g)
1175
+ moment1.append(state["moment1"])
1176
+ moment2.append(state["moment2"])
1177
+ if not isinstance(state["step"], torch.Tensor):
1178
+ step_tensor = torch.tensor(state["step"],
1179
+ dtype=torch.float32,
1180
+ device=p.device)
1181
+ else:
1182
+ step_tensor = state["step"]
1183
+ state_steps.append(step_tensor)
1184
+
1185
+ self._fused_adamw(
1186
+ params_with_grads,
1187
+ grads,
1188
+ moment1,
1189
+ moment2,
1190
+ max_exp_avg_sqs,
1191
+ state_steps,
1192
+ amsgrad=False,
1193
+ beta1=beta1,
1194
+ beta2=beta2,
1195
+ lr=lr,
1196
+ weight_decay=weight_decay,
1197
+ eps=eps,
1198
+ maximize=False,
1199
+ )
1200
+
1201
+ def _step_adamw(self, group):
1202
+ params = group["params"]
1203
+
1204
+         # group params by their type and placement
1205
+         placement_to_params: dict[tuple[tuple[Placement, ...] | type,
1206
+                                         DeviceMesh | None], list] = defaultdict(list)
1207
+ for p in params:
1208
+ match p:
1209
+ case DTensor():
1210
+ placement_to_params[tuple([p.placements,
1211
+ p.device_mesh])].append(p)
1212
+ case torch.Tensor():
1213
+ placement_to_params[tuple([torch.Tensor, None])].append(p)
1214
+
1215
+ for params in placement_to_params.values():
1216
+ self._step_adamw_params(params, group)
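
The bucketing in _step_muon and _step_adamw follows the same pattern; a minimal sketch with hypothetical string stand-ins for the (placements, device_mesh) keys:

from collections import defaultdict

params = [("w1", "meshA"), ("w2", "meshB"), ("w3", "meshA")]

buckets = defaultdict(list)
for name, placement_key in params:   # placement_key stands in for (placements, device_mesh)
    buckets[placement_key].append(name)

assert dict(buckets) == {"meshA": ["w1", "w3"], "meshB": ["w2"]}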
1217
+
1218
  def step(self, closure=None, qk_logits=None):
1219
  """Perform a single optimization step.
1220
 
 
1232
  loss = closure()
1233
 
1234
  for group in self.param_groups:
 
 
1235
  if group["use_muon"]:
1236
+ self._step_muon(group, qk_logits=qk_logits)
 
1237
  else:
1238
+ self._step_adamw(group)
 
1239
 
1240
  return loss
docs/muon/balanced.png DELETED

Git LFS Details

  • SHA256: 9933e2cd5490513593dd6cf1c5c4f18b7f33fd6e6b11c696784269c2bb78055b
  • Pointer size: 130 Bytes
  • Size of remote file: 98 kB
docs/muon/distributed_muon.png DELETED

Git LFS Details

  • SHA256: 31caea472991fd24a7934bf211b5adcbf154b5295bfe364bba5b603851c2cfae
  • Pointer size: 131 Bytes
  • Size of remote file: 408 kB