diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 75685f8c19716e414e4e3d77d778593c4c5a9094..55f8e34c04aac06db5a3137a475e13e3e5ecf8d5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,7 +31,3 @@ repos: hooks: - id: pymarkdown args: [fix] -- repo: https://github.com/rhysd/actionlint - rev: v1.7.7 - hooks: - - id: actionlint diff --git a/README.md b/README.md index b9fc1f180e924dc089f6de9f8320a50e28bb8c94..0b701d22830d3161178074c188e4ccb1e7bcbfa0 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,13 @@ Optimizer is a python package that provides: - with support for parallelism techniques for efficient large-scale training. ## Currently implemented -- [Parallel Muon with FSDP2](./docs/muon/parallel_muon.pdf) +- Parallel Muon with N-D sharding + - arxiv URL: (TBW) + - Supports **general N-D sharding configurations** + - The implementation is not tied to any specific parallel strategy. + - Verified from basic FSDP2 setups up to hybrid configurations such as + **(2 TP + 2 DP-Replicate + 2 DP-Shard)**. + - Verified configurations can be found in [test_muon.py](./test/test_muon.py) ## Usage @@ -39,6 +45,9 @@ optim = optimizer.Muon( ) ``` +## Test +- Check [test/README.md](./test/README.md) for how to run the tests. + ## Pre-commit Hooks This project uses [pre-commit](https://pre-commit.com/) to automatically check and format code before commits. diff --git a/build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py index 0e7e761d98b3dd32cb29d88cfffcf96d574af438..7d598206add1bca142661a3df6c510e3d9575d54 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_811726c_dirty -ops = torch.ops._optimizer_811726c_dirty +from . import _optimizer_23d68bb_dirty +ops = torch.ops._optimizer_23d68bb_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_811726c_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_23d68bb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..62346e32d4dc69c4cefb083f0c788f6564fb142c --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:35708a107d9ac807fa3e63bbacfc6234fd7622a689a79eae3e43fce11f85d3da +size 1924376 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so deleted file mode 100755 index 9e92fc41c7a93657d40537cd4602f68973eb89f0..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:511199ac2ae46febc8aeeb96e843a748da7d6fdea4922572ccf27ee5eabe312d -size 1816064 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/optimizer/distributed/utils.py b/build/torch28-cxx11-cu126-x86_64-linux/optimizer/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0b4b58bfb329b1c015129e4c4fc99f7bfa2ab30a --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/optimizer/distributed/utils.py @@ -0,0 +1,174 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import (Placement, Shard, + _StridedShard) + + +def get_slices_of_dtensor( + target: DTensor | torch.Tensor, + local_rank: int, + shard_mesh: DeviceMesh, + shard_placements: tuple[Placement], +) -> tuple[slice]: + """ + Get the slice of local tensor for a given rank from a tensor. + Args: + target (DTensor | torch.Tensor): The target tensor. + rank (int): The local rank of the shard group. + shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + shard_placements (tuple[Placement]): The shard placements. + """ + + slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + + # find the global rank of the local rank in the shard mesh + rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] + + rank_coords = (shard_mesh.mesh == rank).nonzero() + + assert len(rank_coords) == 1 + rank_coords = tuple(rank_coords[0].tolist()) + + assert len(rank_coords) == len(shard_placements) + + # Caution: Assuming replicate-to-shard of the shard mesh goes with + # left-to-right sharding. This is ensured by the sorting logic of + # construct_shard_mesh function. 
+ for i, (rank_coord, + placement) in enumerate(zip(rank_coords, shard_placements)): + assert isinstance(placement, Shard) + + num_ranks = shard_mesh.mesh.shape[i] + + dim = placement.dim + dim_size = (slices[dim].stop - slices[dim].start) + + if dim_size % num_ranks != 0: + raise NotImplementedError( + f"Dimension size {dim_size} is not divisible " + f"by number of ranks {num_ranks} for shard " + f"placement on dim {dim}.") + + shard_size = dim_size // num_ranks + + start = slices[dim].start + rank_coord * shard_size + end = start + shard_size + + assert start < end <= slices[dim].stop + + slices[dim] = slice(start, end) + + return tuple(slices) + + +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict() + + +def construct_shard_mesh( + placements: tuple[Placement], + mesh: DeviceMesh, +) -> (DeviceMesh, ProcessGroup, tuple[Placement]): + """ + Construct Shard Mesh and Placements for unsharding. + It removes Replicate placements and constructs a new Mesh and ProcessGroup. + """ + my_rank = dist.get_rank() + + assert mesh.mesh.device.type == 'cpu' + + # Copy mesh to avoid modifying the original mesh + mesh = mesh.mesh.clone() + + # 1. Sort placements. Replicate first, then Shard by dim ascending. + + # For Shard, strided shard comes after regular shard on the same dim + # to preserve left-to-right order of replicate-to-shard. + # This is because that strided shard is using stride to represent + # more fine-grained sharding on the same dim. + # Please check the URL below for _StridedShard. + # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 + + def placement_sort_key( + placement_with_index: tuple[float, Placement] + ) -> tuple[int, float, int]: # (dim, split factor, original index) + index, placement = placement_with_index + is_replicate = placement.is_replicate() + is_shard = placement.is_shard() + is_partial = placement.is_partial() + + assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" + assert not is_partial, "Partial placement is not supported." + + if is_replicate: + return (-1.0, 0, index) + elif is_shard: + if isinstance(placement, _StridedShard): + return (placement.dim, 1 / placement.split_factor, index) + return (placement.dim, 0, index) + else: + raise TypeError(f"Unknown placement type: {type(placement)}") + + placements_with_index: list[tuple[int, + Placement]] = list(enumerate(placements)) + placements_with_index = sorted(placements_with_index, + key=placement_sort_key) + + sorted_indices, sorted_placements = zip(*placements_with_index) + + # 2. Permute mesh according to sorted placements. + sorted_mesh = mesh.permute(sorted_indices) + + # 3. Collect list of shard meshes by removing replicate dims + # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] + # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) + num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + + # merge replicate dims + # shard_meshes became a list of shard meshes with a length of replicate degree + if num_replicates > 0: + sorted_mesh = sorted_mesh.flatten( + 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) + else: + shard_meshes = [sorted_mesh] + shard_placements = sorted_placements[num_replicates:] + + # assume all shard placements are different + assert len(shard_placements) == len(set(shard_placements)) + + # 4. 
Construct ProcessGroups + # Caution: all groups should be created in the same order in all processes, + # even though each process only needs its own group. + + # To use tensor as dict key, convert it to tuple + def tensor_to_tuple(t): + if isinstance(t, torch.Tensor): + t = t.tolist() + if isinstance(t, list): + return tuple(tensor_to_tuple(x) for x in t) + return t + + my_shard_mesh_as_tuple = None + for shard_mesh in shard_meshes: + assert isinstance(shard_mesh, torch.Tensor) + shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) + + if (my_rank == shard_mesh).any().item(): + assert my_shard_mesh_as_tuple is None + my_shard_mesh_as_tuple = shard_mesh_as_tuple + + # update global cache + if shard_mesh_as_tuple not in _ranks_to_dist_cache: + shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) + _ranks_to_dist_cache[shard_mesh_as_tuple] = ( + DeviceMesh(device_type="cuda", mesh=shard_mesh), + shard_process_group, + ) + + my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ + my_shard_mesh_as_tuple] + + return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py b/build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py index f66a1f41fd8b8cbdb453e48b2267ddb0450e5af9..cfbcca71741be70048bfd290c62148b2aceda631 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py @@ -1,18 +1,24 @@ import logging import math import types +from collections import defaultdict from dataclasses import dataclass -from typing import List, Optional, Union, cast +from typing import Any, cast import torch import torch.distributed as dist -from torch.distributed._tensor import DTensor, Replicate, Shard +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate +from torch.distributed.tensor.placement_types import Placement +from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor from .matmul_transpose_triton import matmul_transpose_assign logger = logging.getLogger(__name__) COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 # This code snippet is a modified version adapted from the following GitHub repositories: @@ -62,23 +68,39 @@ def _zeropower_via_newtonschulz5(G, steps): @dataclass class _muon_state: # TODO: use Optional - worker_rank: int | None = None + worker_rank: int + process_group: ProcessGroup + shard_mesh: DeviceMesh + shard_placements: tuple[Placement, ...] 
+ name: str + qk_clip_state: torch.Tensor | None = None gathered_grad: torch.Tensor | None = None scattered_u: DTensor | None = None computed_u: torch.Tensor | None = None gather_event: torch.cuda.Event | None = None compute_event: torch.cuda.Event | None = None scatter_event: torch.cuda.Event | None = None - process_group = None - qk_clip_state = None -def split_elems_for_src(param, src_rank, num_ranks) -> int: - rows = param.shape[0] - cols = int(param.numel() // rows) - base, rem = divmod(rows, num_ranks) - my_rows = base + (1 if src_rank < rem else 0) - return my_rows * cols +def numel_for_rank( + param: DTensor, + local_rank: int, + state: _muon_state, +) -> int: + slices = get_slices_of_dtensor( + param, + local_rank, + state.shard_mesh, + state.shard_placements, + ) + + numel = 1 + for s, dim in zip(slices, param.shape): + start, stop, step = s.indices(dim) + length = max(0, (stop - start + (step - 1)) // step) + numel *= length + + return numel @torch.no_grad() @@ -91,8 +113,7 @@ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): for p in params: state = param_to_state[id(p)] if rank == state.worker_rank: - num_ranks = dist.get_world_size(group=state.process_group) - state.gathered_grad = torch.empty(p.grad.numel(), + state.gathered_grad = torch.empty(p.shape, dtype=COMM_DTYPE, device="cuda") else: @@ -121,11 +142,11 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, state = param_to_state[id(p)] dst = state.worker_rank assert dst < num_ranks - shard_elems = split_elems_for_src(p, rank, num_ranks) + shard_elems = numel_for_rank(p, rank, state) g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous().view(-1) + g = g.to_local().to(COMM_DTYPE).contiguous() assert g.numel() == shard_elems - per_dst[dst].append(g) + per_dst[dst].append(g.view(-1)) send_counts[dst] += shard_elems assert any( @@ -148,13 +169,18 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, for p in owned_params: state = param_to_state[id(p)] assert state.worker_rank == rank - total += split_elems_for_src(p, src, num_ranks) + total += numel_for_rank(p, src, state) recv_counts[src] = total recv_total = sum(recv_counts) recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") #All2All + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") dist.all_to_all_single( recv_buf, send_buf, @@ -179,7 +205,6 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, comm_stream.wait_event(alloc_event) off = 0 - write_offsets = {id(p): 0 for p in owned_params} for src in range(num_ranks): if recv_counts[src] == 0: continue @@ -189,22 +214,28 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, for p in owned_params: state = param_to_state[id(p)] assert state.worker_rank == rank - n = split_elems_for_src(p, src, num_ranks) + + # get the slice of the full dtensor corresponding to rank src. 
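            # Illustrative sketch, assuming a hypothetical 2-way Shard(0)
            # split of an 8x4 parameter: rank src=1 owns rows 4:8, so the
            # `slices` computed below are (slice(4, 8), slice(0, 4)), `dst`
            # views that 16-element region of `gathered_grad`, and the flat
            # segment taken from `recv_buf` is reshaped to the view and
            # copied into it.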
+ slices = get_slices_of_dtensor(state.gathered_grad, src, + state.shard_mesh, + state.shard_placements) + + dst = state.gathered_grad[slices] + assert dst._base is state.gathered_grad + + n = dst.numel() assert n > 0 sg = recv_buf.narrow(0, off + inner_off, n) - woff = write_offsets[id(p)] - dst = state.gathered_grad.narrow(0, woff, n) + sg = sg.reshape_as(dst) dst.copy_(sg) - write_offsets[id(p)] += n inner_off += n off += block for p in params: state = param_to_state[id(p)] if state.worker_rank == rank: - state.gathered_grad = state.gathered_grad.view_as(p) state.gather_event = torch.cuda.Event() state.gather_event.record(comm_stream) else: @@ -277,14 +308,19 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): assert state.computed_u is not None - u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1) + u_full = state.computed_u.to(COMM_DTYPE).contiguous() offset = 0 for dst in range(num_ranks): - n = split_elems_for_src(p, dst, num_ranks) + # get the slice of the full tensor corresponding to rank dst. + slices = get_slices_of_dtensor(u_full, dst, + state.shard_mesh, + state.shard_placements) + su = u_full[slices].flatten() + + n = su.numel() assert n > 0 - su = u_full.narrow(0, offset, n) per_dst[dst].append(su) send_counts[dst] += n offset += n @@ -313,7 +349,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): state = param_to_state[id(p)] if state.worker_rank != src: continue - total += split_elems_for_src(p, rank, num_ranks) + total += numel_for_rank(p, rank, state) recv_counts[src] = total recv_total = sum(recv_counts) @@ -357,7 +393,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): state = param_to_state[id(p)] if state.worker_rank != src: continue - n = split_elems_for_src(p, rank, num_ranks) + n = numel_for_rank(p, rank, state) assert n > 0 flat_local = recv_buf.narrow(0, off + inner_off, @@ -398,11 +434,23 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, state.scattered_u = None u_dtensor = None - scales_full = Muon._compute_scales(p, state.qk_clip_state) + scales_full = Muon._compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None if scales_full is not None: - num_ranks = dist.get_world_size(group=state.process_group) - local_rank = dist.get_rank(group=state.process_group) - scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank] + # Have to slice scales_full among dim 0 + weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, + state.shard_placements) + ratio = p.shape[0] // scales_full.shape[0] + scales_slice = slice( + None if weight_slices[0].start is None else + weight_slices[0].start // ratio, + None if weight_slices[0].stop is None else + weight_slices[0].stop // ratio, + None, + ) + + scales_local = scales_full[scales_slice] scales_local = DTensor.from_local( scales_local, placements=p.placements, @@ -478,11 +526,11 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: List[int] # which heads to consider for clipping + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping head_dim: int # from config threshold: float # from config - logit: Optional[torch.Tensor] + logit: torch.Tensor | None class Muon(torch.optim.Optimizer): @@ -525,11 +573,16 @@ class 
Muon(torch.optim.Optimizer): "head_dim": 128, "threshold": 100 } - overlap_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher overlap_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. + warmup_step : How many all2all gather, compute operations are launched in advance + before the corresponding all2all scatter steps begin. + A higher warmup_step increases memory usage but can improve + performance by overlapping communication. + Parallel muon only. + chunk_size : Batch size of parameters to process in each + all2all gather/compute/scatter step. + Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. + use_distributed_muon: Use distributed muon by Liu et al. (2024). + For testing purpose only. """ def __init__(self, @@ -549,7 +602,9 @@ class Muon(torch.optim.Optimizer): "head_dim": 128, "threshold": 100 }, - overlap_step=5): + warmup_step=5, + chunk_size=-1, + use_distributed_muon=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -579,7 +634,9 @@ class Muon(torch.optim.Optimizer): self.compute_stream = torch.cuda.Stream() self.debug = debug self.clip_config = clip_config - self.overlap_step = overlap_step + self.warmup_step = warmup_step + self.chunk_size = chunk_size + self.use_distributed_muon = use_distributed_muon def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -597,6 +654,12 @@ class Muon(torch.optim.Optimizer): adjusted_lr = lr * adjusted_ratio return adjusted_lr + def set_rank_once(self, rank): + if self.rank is None: + self.rank = rank + else: + assert self.rank == rank + def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. @@ -604,26 +667,13 @@ class Muon(torch.optim.Optimizer): assert isinstance( p, DTensor), "Parallel Muon only supports DTensor parameters." 
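        # Illustrative sketch, assuming a hypothetical 2x2 mesh
        # [[0, 1], [2, 3]]: for an HSDP parameter placed as
        # (Replicate(), Shard(0)), construct_shard_mesh drops the Replicate
        # dim, so rank 2 receives the 1-D shard mesh [2, 3], a process group
        # over ranks {2, 3}, and shard_placements == (Shard(0),), matching
        # what the removed FSDP/HSDP-specific branches computed by hand but
        # derived for arbitrary N-D placements.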
- if p.placements == (Shard(dim=0), ): - # Case for FSDP - process_group = p.device_mesh.get_group(mesh_dim=0) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0) - elif p.placements == (Replicate(), Shard(dim=0)): - # Case for HSDP - process_group = p.device_mesh.get_group(mesh_dim=1) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - for i, shard_mesh in enumerate(p.device_mesh.mesh): - if self.rank in shard_mesh: - return shard_mesh, p.device_mesh.get_group(mesh_dim=1) - else: - raise ValueError(f"Unsupported placements ({p.placements}).") + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + p.placements, p.device_mesh) + + # set rank with the local rank in the shard process group + self.set_rank_once(dist.get_rank(group=shard_pg)) + + return shard_mesh, shard_pg, shard_placements def init_state_and_assign_params(self, names, params, group, qk_logits): param_to_state = {} @@ -655,23 +705,32 @@ class Muon(torch.optim.Optimizer): ordered_params = list(params_sorted) round_robin = 0 - mesh = None - shard_mesh = None - process_group = None + mesh = ordered_params[0].device_mesh + placements = ordered_params[0].placements + + shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( + ordered_params[0]) + shard_mesh_flattened = shard_mesh.mesh.flatten() + num_ranks = dist.get_world_size(group=shard_pg) + for n, p in zip(ordered_names, ordered_params): - if mesh is None: - mesh = p.device_mesh - shard_mesh, process_group = self.get_shard_mesh(p) - elif mesh != p.device_mesh: + if mesh != p.device_mesh: raise ValueError("All parameters must be on the same mesh.") - num_ranks = dist.get_world_size(group=process_group) - param_to_state[id(p)] = _muon_state() - param_to_state[id( - p)].worker_rank = shard_mesh[round_robin].item() % num_ranks - param_to_state[id(p)].process_group = process_group + if placements != p.placements: + raise ValueError("All parameters must have same placements.") + + worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks + round_robin = (round_robin + 1) % len(shard_mesh_flattened) qk_clip_state = self.get_qk_clip_info(n, qk_logits) - param_to_state[id(p)].qk_clip_state = qk_clip_state - round_robin = (round_robin + 1) % len(shard_mesh) + + param_to_state[id(p)] = _muon_state( + worker_rank=worker_rank, + process_group=shard_pg, + shard_mesh=shard_mesh, + shard_placements=shard_placements, + name=n, + qk_clip_state=qk_clip_state, + ) return param_to_state, ordered_params @@ -705,10 +764,73 @@ class Muon(torch.optim.Optimizer): qk_clip_state = self.get_qk_clip_info(n, qk_logits) - scales_full = self._compute_scales(p, qk_clip_state) + scales_full = self._compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + def distributed_muon( + self, + names: list[str], + params: list[torch.nn.Parameter], + group: dict[str, Any], + lr: float, + weight_decay: float, + momentum: float, + qk_logits: list[torch.Tensor | DTensor] | None, + ): + """ Implementation of Distributed Muon by Liu et al. 
""" + if qk_logits is not None: + raise NotImplementedError("QK clipping is not supported yet") + + if isinstance(params[0], DTensor): + shard_mesh, _, shard_placements = construct_shard_mesh( + placements=params[0].placements, + mesh=params[0].device_mesh, + ) + + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + # Gather G + if isinstance(p.data, DTensor): + g = g.full_tensor() + u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) + + if isinstance(p.data, DTensor): + slices = get_slices_of_dtensor( + target=p, + local_rank=dist.get_rank(), + shard_mesh=shard_mesh, + shard_placements=shard_placements, + ) + u_shard = u[slices] + u = DTensor.from_local( + u_shard, + device_mesh=p.device_mesh, + placements=p.placements, + ) + + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + def _update_g(self, p, g, group, momentum): # calc update state = self.state[p] @@ -727,6 +849,9 @@ class Muon(torch.optim.Optimizer): p.data.add_(u, alpha=-adjusted_lr) def get_qk_clip_info(self, n, qk_logits): + if self.clip_config is None: + return None + head_dim = self.clip_config.get('head_dim') threshold = self.clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) @@ -737,6 +862,11 @@ class Muon(torch.optim.Optimizer): indices_key = 'q_indices' if 'q' in kind else 'k_indices' indices = self.clip_config.get(indices_key, []) or [] + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + return QKClipInfo( kind=kind, indices=indices, @@ -835,22 +965,28 @@ class Muon(torch.optim.Optimizer): _update_param(p, state, lr, adjusted_lr, weight_decay, self.rank, self.compute_stream) - chunk_size = dist.get_world_size(param_to_state[id( - params[0])].process_group) + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(param_to_state[id( + params[0])].process_group) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError("chunk_size must be -1 or a positive integer.") # Wait grad update self.comm_stream.wait_stream(torch.cuda.current_stream()) - overlap_step = self.overlap_step - for i in range(0, overlap_step): + warmup_step = self.warmup_step + for i in range(0, warmup_step): enqueue_all2all_gather(i * chunk_size, chunk_size) enqueue_computes(i * chunk_size, chunk_size) for i in range(0, len(params) + chunk_size - 1, chunk_size): enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size) + enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) enqueue_update_param(i, chunk_size) - enqueue_computes(i + overlap_step * chunk_size, chunk_size) + enqueue_computes(i + warmup_step * chunk_size, chunk_size) # Wait the last update_param to finish torch.cuda.current_stream().wait_stream(self.compute_stream) @@ -866,7 +1002,7 @@ class Muon(torch.optim.Optimizer): amsgrad: bool, beta1: float, beta2: float, - lr: Union[float, torch.Tensor], + lr: float | torch.Tensor, weight_decay: float, eps: float, maximize: bool, @@ -876,10 +1012,10 
@@ class Muon(torch.optim.Optimizer): # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer # treating it as a scalar. - lr_dict: Optional[DeviceDict] = ({ + lr_dict: DeviceDict | None = ({ lr.device: lr } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) + None) grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( [ params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, @@ -926,6 +1062,159 @@ class Muon(torch.optim.Optimizer): maximize=maximize, ) + def _step_muon(self, group, qk_logits=None): + params = group["params"] + lr = group["lr"] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + names = group["names"] + + param_dtensors = [] + param_tensors = [] + name_dtensors = [] + name_tensors = [] + + if self.use_distributed_muon: + self.distributed_muon(names=names, + params=params, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits) + return + + for n, p in zip(names, params): + if p is None or p.grad is None: + continue + if isinstance(p.data, DTensor): + if all( + isinstance(placement, Replicate) + for placement in p.placements): + param_tensors.append(p) + name_tensors.append(n) + else: + param_dtensors.append(p) + name_dtensors.append(n) + elif isinstance(p.data, torch.Tensor): + param_tensors.append(p) + name_tensors.append(n) + else: + raise TypeError(f"Unsupported parameter type: {type(p.data)}") + + logger.debug( + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors" + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + # To support different placements, we group parameters by placements + # and run parallel Muon on each group. 
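            # Illustrative sketch, assuming a hypothetical model that mixes
            # weights placed as (Shard(0),) on an FSDP mesh with weights
            # placed as (Shard(0), Shard(1)) on a 2-D TP/FSDP mesh: the two
            # groups land in separate buckets keyed by
            # (placements, device_mesh), and parallel() runs once per bucket.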
+ + placement_to_params = defaultdict(lambda: ([], [])) + # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] + + assert len(name_dtensors) == len(param_dtensors) + for n, p in zip(name_dtensors, param_dtensors): + placement_to_params[tuple([p.placements, + p.device_mesh])][0].append(n) + placement_to_params[tuple([p.placements, + p.device_mesh])][1].append(p) + + for _, (names, params) in placement_to_params.items(): + self.parallel( + names, + params, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_tensors) > 0: + self.base( + name_tensors, + param_tensors, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + def _step_adamw_params(self, params, group): + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + self._fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def _step_adamw(self, group): + params = group["params"] + + # group params with it's type and placement + placement_to_params: dict[tuple[Placement | type, + DeviceMesh | None]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for params in placement_to_params.values(): + self._step_adamw_params(params, group) + def step(self, closure=None, qk_logits=None): """Perform a single optimization step. 
@@ -943,127 +1232,9 @@ class Muon(torch.optim.Optimizer): loss = closure() for group in self.param_groups: - params = group["params"] - if group["use_muon"]: - ############################ - # Muon # - ############################ - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - param_tensors = [] - name_dtensors = [] - name_tensors = [] - - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError( - f"Unsupported parameter type: {type(p.data)}") - - if self.debug: - print( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors", - flush=True, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.parallel( - name_dtensors, - param_dtensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - + self._step_muon(group, qk_logits=qk_logits) else: - ############################ - # AdamW backup # - ############################ - - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) + self._step_adamw(group) return loss diff --git a/build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py index 0e7e761d98b3dd32cb29d88cfffcf96d574af438..7d598206add1bca142661a3df6c510e3d9575d54 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_811726c_dirty -ops = torch.ops._optimizer_811726c_dirty +from . import _optimizer_23d68bb_dirty +ops = torch.ops._optimizer_23d68bb_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_811726c_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_23d68bb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..d31f69b06fba65c78b497ee3f83cdb2b894170b2 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:03c3bbbbc5c4ceb5cebfe3a2e411f155bebb390f1921c14d59fcf791dd556da1 +size 1983488 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so deleted file mode 100755 index a873919931082d6034271687f0224f1accc1bb39..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b3cdb515b6c56204224cc307b66d34fcee1cd5e27b4117197a71b784d34fadc5 -size 1871056 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/optimizer/distributed/utils.py b/build/torch28-cxx11-cu128-x86_64-linux/optimizer/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0b4b58bfb329b1c015129e4c4fc99f7bfa2ab30a --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/optimizer/distributed/utils.py @@ -0,0 +1,174 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import (Placement, Shard, + _StridedShard) + + +def get_slices_of_dtensor( + target: DTensor | torch.Tensor, + local_rank: int, + shard_mesh: DeviceMesh, + shard_placements: tuple[Placement], +) -> tuple[slice]: + """ + Get the slice of local tensor for a given rank from a tensor. + Args: + target (DTensor | torch.Tensor): The target tensor. + rank (int): The local rank of the shard group. + shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + shard_placements (tuple[Placement]): The shard placements. + """ + + slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + + # find the global rank of the local rank in the shard mesh + rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] + + rank_coords = (shard_mesh.mesh == rank).nonzero() + + assert len(rank_coords) == 1 + rank_coords = tuple(rank_coords[0].tolist()) + + assert len(rank_coords) == len(shard_placements) + + # Caution: Assuming replicate-to-shard of the shard mesh goes with + # left-to-right sharding. This is ensured by the sorting logic of + # construct_shard_mesh function. 
+ for i, (rank_coord, + placement) in enumerate(zip(rank_coords, shard_placements)): + assert isinstance(placement, Shard) + + num_ranks = shard_mesh.mesh.shape[i] + + dim = placement.dim + dim_size = (slices[dim].stop - slices[dim].start) + + if dim_size % num_ranks != 0: + raise NotImplementedError( + f"Dimension size {dim_size} is not divisible " + f"by number of ranks {num_ranks} for shard " + f"placement on dim {dim}.") + + shard_size = dim_size // num_ranks + + start = slices[dim].start + rank_coord * shard_size + end = start + shard_size + + assert start < end <= slices[dim].stop + + slices[dim] = slice(start, end) + + return tuple(slices) + + +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict() + + +def construct_shard_mesh( + placements: tuple[Placement], + mesh: DeviceMesh, +) -> (DeviceMesh, ProcessGroup, tuple[Placement]): + """ + Construct Shard Mesh and Placements for unsharding. + It removes Replicate placements and constructs a new Mesh and ProcessGroup. + """ + my_rank = dist.get_rank() + + assert mesh.mesh.device.type == 'cpu' + + # Copy mesh to avoid modifying the original mesh + mesh = mesh.mesh.clone() + + # 1. Sort placements. Replicate first, then Shard by dim ascending. + + # For Shard, strided shard comes after regular shard on the same dim + # to preserve left-to-right order of replicate-to-shard. + # This is because that strided shard is using stride to represent + # more fine-grained sharding on the same dim. + # Please check the URL below for _StridedShard. + # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 + + def placement_sort_key( + placement_with_index: tuple[float, Placement] + ) -> tuple[int, float, int]: # (dim, split factor, original index) + index, placement = placement_with_index + is_replicate = placement.is_replicate() + is_shard = placement.is_shard() + is_partial = placement.is_partial() + + assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" + assert not is_partial, "Partial placement is not supported." + + if is_replicate: + return (-1.0, 0, index) + elif is_shard: + if isinstance(placement, _StridedShard): + return (placement.dim, 1 / placement.split_factor, index) + return (placement.dim, 0, index) + else: + raise TypeError(f"Unknown placement type: {type(placement)}") + + placements_with_index: list[tuple[int, + Placement]] = list(enumerate(placements)) + placements_with_index = sorted(placements_with_index, + key=placement_sort_key) + + sorted_indices, sorted_placements = zip(*placements_with_index) + + # 2. Permute mesh according to sorted placements. + sorted_mesh = mesh.permute(sorted_indices) + + # 3. Collect list of shard meshes by removing replicate dims + # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] + # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) + num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + + # merge replicate dims + # shard_meshes became a list of shard meshes with a length of replicate degree + if num_replicates > 0: + sorted_mesh = sorted_mesh.flatten( + 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) + else: + shard_meshes = [sorted_mesh] + shard_placements = sorted_placements[num_replicates:] + + # assume all shard placements are different + assert len(shard_placements) == len(set(shard_placements)) + + # 4. 
Construct ProcessGroups + # Caution: all groups should be created in the same order in all processes, + # even though each process only needs its own group. + + # To use tensor as dict key, convert it to tuple + def tensor_to_tuple(t): + if isinstance(t, torch.Tensor): + t = t.tolist() + if isinstance(t, list): + return tuple(tensor_to_tuple(x) for x in t) + return t + + my_shard_mesh_as_tuple = None + for shard_mesh in shard_meshes: + assert isinstance(shard_mesh, torch.Tensor) + shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) + + if (my_rank == shard_mesh).any().item(): + assert my_shard_mesh_as_tuple is None + my_shard_mesh_as_tuple = shard_mesh_as_tuple + + # update global cache + if shard_mesh_as_tuple not in _ranks_to_dist_cache: + shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) + _ranks_to_dist_cache[shard_mesh_as_tuple] = ( + DeviceMesh(device_type="cuda", mesh=shard_mesh), + shard_process_group, + ) + + my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ + my_shard_mesh_as_tuple] + + return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py b/build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py index f66a1f41fd8b8cbdb453e48b2267ddb0450e5af9..cfbcca71741be70048bfd290c62148b2aceda631 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py @@ -1,18 +1,24 @@ import logging import math import types +from collections import defaultdict from dataclasses import dataclass -from typing import List, Optional, Union, cast +from typing import Any, cast import torch import torch.distributed as dist -from torch.distributed._tensor import DTensor, Replicate, Shard +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate +from torch.distributed.tensor.placement_types import Placement +from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor from .matmul_transpose_triton import matmul_transpose_assign logger = logging.getLogger(__name__) COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 # This code snippet is a modified version adapted from the following GitHub repositories: @@ -62,23 +68,39 @@ def _zeropower_via_newtonschulz5(G, steps): @dataclass class _muon_state: # TODO: use Optional - worker_rank: int | None = None + worker_rank: int + process_group: ProcessGroup + shard_mesh: DeviceMesh + shard_placements: tuple[Placement, ...] 
+ name: str + qk_clip_state: torch.Tensor | None = None gathered_grad: torch.Tensor | None = None scattered_u: DTensor | None = None computed_u: torch.Tensor | None = None gather_event: torch.cuda.Event | None = None compute_event: torch.cuda.Event | None = None scatter_event: torch.cuda.Event | None = None - process_group = None - qk_clip_state = None -def split_elems_for_src(param, src_rank, num_ranks) -> int: - rows = param.shape[0] - cols = int(param.numel() // rows) - base, rem = divmod(rows, num_ranks) - my_rows = base + (1 if src_rank < rem else 0) - return my_rows * cols +def numel_for_rank( + param: DTensor, + local_rank: int, + state: _muon_state, +) -> int: + slices = get_slices_of_dtensor( + param, + local_rank, + state.shard_mesh, + state.shard_placements, + ) + + numel = 1 + for s, dim in zip(slices, param.shape): + start, stop, step = s.indices(dim) + length = max(0, (stop - start + (step - 1)) // step) + numel *= length + + return numel @torch.no_grad() @@ -91,8 +113,7 @@ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): for p in params: state = param_to_state[id(p)] if rank == state.worker_rank: - num_ranks = dist.get_world_size(group=state.process_group) - state.gathered_grad = torch.empty(p.grad.numel(), + state.gathered_grad = torch.empty(p.shape, dtype=COMM_DTYPE, device="cuda") else: @@ -121,11 +142,11 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, state = param_to_state[id(p)] dst = state.worker_rank assert dst < num_ranks - shard_elems = split_elems_for_src(p, rank, num_ranks) + shard_elems = numel_for_rank(p, rank, state) g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous().view(-1) + g = g.to_local().to(COMM_DTYPE).contiguous() assert g.numel() == shard_elems - per_dst[dst].append(g) + per_dst[dst].append(g.view(-1)) send_counts[dst] += shard_elems assert any( @@ -148,13 +169,18 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, for p in owned_params: state = param_to_state[id(p)] assert state.worker_rank == rank - total += split_elems_for_src(p, src, num_ranks) + total += numel_for_rank(p, src, state) recv_counts[src] = total recv_total = sum(recv_counts) recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") #All2All + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") dist.all_to_all_single( recv_buf, send_buf, @@ -179,7 +205,6 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, comm_stream.wait_event(alloc_event) off = 0 - write_offsets = {id(p): 0 for p in owned_params} for src in range(num_ranks): if recv_counts[src] == 0: continue @@ -189,22 +214,28 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, for p in owned_params: state = param_to_state[id(p)] assert state.worker_rank == rank - n = split_elems_for_src(p, src, num_ranks) + + # get the slice of the full dtensor corresponding to rank src. 
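            # Illustrative sketch, assuming a hypothetical 2-way Shard(0)
            # split of an 8x4 parameter: rank src=1 owns rows 4:8, so the
            # `slices` computed below are (slice(4, 8), slice(0, 4)), `dst`
            # views that 16-element region of `gathered_grad`, and the flat
            # segment taken from `recv_buf` is reshaped to the view and
            # copied into it.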
+ slices = get_slices_of_dtensor(state.gathered_grad, src, + state.shard_mesh, + state.shard_placements) + + dst = state.gathered_grad[slices] + assert dst._base is state.gathered_grad + + n = dst.numel() assert n > 0 sg = recv_buf.narrow(0, off + inner_off, n) - woff = write_offsets[id(p)] - dst = state.gathered_grad.narrow(0, woff, n) + sg = sg.reshape_as(dst) dst.copy_(sg) - write_offsets[id(p)] += n inner_off += n off += block for p in params: state = param_to_state[id(p)] if state.worker_rank == rank: - state.gathered_grad = state.gathered_grad.view_as(p) state.gather_event = torch.cuda.Event() state.gather_event.record(comm_stream) else: @@ -277,14 +308,19 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): assert state.computed_u is not None - u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1) + u_full = state.computed_u.to(COMM_DTYPE).contiguous() offset = 0 for dst in range(num_ranks): - n = split_elems_for_src(p, dst, num_ranks) + # get the slice of the full tensor corresponding to rank dst. + slices = get_slices_of_dtensor(u_full, dst, + state.shard_mesh, + state.shard_placements) + su = u_full[slices].flatten() + + n = su.numel() assert n > 0 - su = u_full.narrow(0, offset, n) per_dst[dst].append(su) send_counts[dst] += n offset += n @@ -313,7 +349,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): state = param_to_state[id(p)] if state.worker_rank != src: continue - total += split_elems_for_src(p, rank, num_ranks) + total += numel_for_rank(p, rank, state) recv_counts[src] = total recv_total = sum(recv_counts) @@ -357,7 +393,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): state = param_to_state[id(p)] if state.worker_rank != src: continue - n = split_elems_for_src(p, rank, num_ranks) + n = numel_for_rank(p, rank, state) assert n > 0 flat_local = recv_buf.narrow(0, off + inner_off, @@ -398,11 +434,23 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, state.scattered_u = None u_dtensor = None - scales_full = Muon._compute_scales(p, state.qk_clip_state) + scales_full = Muon._compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None if scales_full is not None: - num_ranks = dist.get_world_size(group=state.process_group) - local_rank = dist.get_rank(group=state.process_group) - scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank] + # Have to slice scales_full among dim 0 + weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, + state.shard_placements) + ratio = p.shape[0] // scales_full.shape[0] + scales_slice = slice( + None if weight_slices[0].start is None else + weight_slices[0].start // ratio, + None if weight_slices[0].stop is None else + weight_slices[0].stop // ratio, + None, + ) + + scales_local = scales_full[scales_slice] scales_local = DTensor.from_local( scales_local, placements=p.placements, @@ -478,11 +526,11 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: List[int] # which heads to consider for clipping + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping head_dim: int # from config threshold: float # from config - logit: Optional[torch.Tensor] + logit: torch.Tensor | None class Muon(torch.optim.Optimizer): @@ -525,11 +573,16 @@ class 
Muon(torch.optim.Optimizer): "head_dim": 128, "threshold": 100 } - overlap_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher overlap_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. + warmup_step : How many all2all gather, compute operations are launched in advance + before the corresponding all2all scatter steps begin. + A higher warmup_step increases memory usage but can improve + performance by overlapping communication. + Parallel muon only. + chunk_size : Batch size of parameters to process in each + all2all gather/compute/scatter step. + Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. + use_distributed_muon: Use distributed muon by Liu et al. (2024). + For testing purpose only. """ def __init__(self, @@ -549,7 +602,9 @@ class Muon(torch.optim.Optimizer): "head_dim": 128, "threshold": 100 }, - overlap_step=5): + warmup_step=5, + chunk_size=-1, + use_distributed_muon=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -579,7 +634,9 @@ class Muon(torch.optim.Optimizer): self.compute_stream = torch.cuda.Stream() self.debug = debug self.clip_config = clip_config - self.overlap_step = overlap_step + self.warmup_step = warmup_step + self.chunk_size = chunk_size + self.use_distributed_muon = use_distributed_muon def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -597,6 +654,12 @@ class Muon(torch.optim.Optimizer): adjusted_lr = lr * adjusted_ratio return adjusted_lr + def set_rank_once(self, rank): + if self.rank is None: + self.rank = rank + else: + assert self.rank == rank + def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. @@ -604,26 +667,13 @@ class Muon(torch.optim.Optimizer): assert isinstance( p, DTensor), "Parallel Muon only supports DTensor parameters." 
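        # Illustrative sketch, assuming a hypothetical 2x2 mesh
        # [[0, 1], [2, 3]]: for an HSDP parameter placed as
        # (Replicate(), Shard(0)), construct_shard_mesh drops the Replicate
        # dim, so rank 2 receives the 1-D shard mesh [2, 3], a process group
        # over ranks {2, 3}, and shard_placements == (Shard(0),), matching
        # what the removed FSDP/HSDP-specific branches computed by hand but
        # derived for arbitrary N-D placements.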
- if p.placements == (Shard(dim=0), ): - # Case for FSDP - process_group = p.device_mesh.get_group(mesh_dim=0) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0) - elif p.placements == (Replicate(), Shard(dim=0)): - # Case for HSDP - process_group = p.device_mesh.get_group(mesh_dim=1) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - for i, shard_mesh in enumerate(p.device_mesh.mesh): - if self.rank in shard_mesh: - return shard_mesh, p.device_mesh.get_group(mesh_dim=1) - else: - raise ValueError(f"Unsupported placements ({p.placements}).") + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + p.placements, p.device_mesh) + + # set rank with the local rank in the shard process group + self.set_rank_once(dist.get_rank(group=shard_pg)) + + return shard_mesh, shard_pg, shard_placements def init_state_and_assign_params(self, names, params, group, qk_logits): param_to_state = {} @@ -655,23 +705,32 @@ class Muon(torch.optim.Optimizer): ordered_params = list(params_sorted) round_robin = 0 - mesh = None - shard_mesh = None - process_group = None + mesh = ordered_params[0].device_mesh + placements = ordered_params[0].placements + + shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( + ordered_params[0]) + shard_mesh_flattened = shard_mesh.mesh.flatten() + num_ranks = dist.get_world_size(group=shard_pg) + for n, p in zip(ordered_names, ordered_params): - if mesh is None: - mesh = p.device_mesh - shard_mesh, process_group = self.get_shard_mesh(p) - elif mesh != p.device_mesh: + if mesh != p.device_mesh: raise ValueError("All parameters must be on the same mesh.") - num_ranks = dist.get_world_size(group=process_group) - param_to_state[id(p)] = _muon_state() - param_to_state[id( - p)].worker_rank = shard_mesh[round_robin].item() % num_ranks - param_to_state[id(p)].process_group = process_group + if placements != p.placements: + raise ValueError("All parameters must have same placements.") + + worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks + round_robin = (round_robin + 1) % len(shard_mesh_flattened) qk_clip_state = self.get_qk_clip_info(n, qk_logits) - param_to_state[id(p)].qk_clip_state = qk_clip_state - round_robin = (round_robin + 1) % len(shard_mesh) + + param_to_state[id(p)] = _muon_state( + worker_rank=worker_rank, + process_group=shard_pg, + shard_mesh=shard_mesh, + shard_placements=shard_placements, + name=n, + qk_clip_state=qk_clip_state, + ) return param_to_state, ordered_params @@ -705,10 +764,73 @@ class Muon(torch.optim.Optimizer): qk_clip_state = self.get_qk_clip_info(n, qk_logits) - scales_full = self._compute_scales(p, qk_clip_state) + scales_full = self._compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + def distributed_muon( + self, + names: list[str], + params: list[torch.nn.Parameter], + group: dict[str, Any], + lr: float, + weight_decay: float, + momentum: float, + qk_logits: list[torch.Tensor | DTensor] | None, + ): + """ Implementation of Distributed Muon by Liu et al. 
""" + if qk_logits is not None: + raise NotImplementedError("QK clipping is not supported yet") + + if isinstance(params[0], DTensor): + shard_mesh, _, shard_placements = construct_shard_mesh( + placements=params[0].placements, + mesh=params[0].device_mesh, + ) + + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + # Gather G + if isinstance(p.data, DTensor): + g = g.full_tensor() + u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) + + if isinstance(p.data, DTensor): + slices = get_slices_of_dtensor( + target=p, + local_rank=dist.get_rank(), + shard_mesh=shard_mesh, + shard_placements=shard_placements, + ) + u_shard = u[slices] + u = DTensor.from_local( + u_shard, + device_mesh=p.device_mesh, + placements=p.placements, + ) + + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + def _update_g(self, p, g, group, momentum): # calc update state = self.state[p] @@ -727,6 +849,9 @@ class Muon(torch.optim.Optimizer): p.data.add_(u, alpha=-adjusted_lr) def get_qk_clip_info(self, n, qk_logits): + if self.clip_config is None: + return None + head_dim = self.clip_config.get('head_dim') threshold = self.clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) @@ -737,6 +862,11 @@ class Muon(torch.optim.Optimizer): indices_key = 'q_indices' if 'q' in kind else 'k_indices' indices = self.clip_config.get(indices_key, []) or [] + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + return QKClipInfo( kind=kind, indices=indices, @@ -835,22 +965,28 @@ class Muon(torch.optim.Optimizer): _update_param(p, state, lr, adjusted_lr, weight_decay, self.rank, self.compute_stream) - chunk_size = dist.get_world_size(param_to_state[id( - params[0])].process_group) + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(param_to_state[id( + params[0])].process_group) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError("chunk_size must be -1 or a positive integer.") # Wait grad update self.comm_stream.wait_stream(torch.cuda.current_stream()) - overlap_step = self.overlap_step - for i in range(0, overlap_step): + warmup_step = self.warmup_step + for i in range(0, warmup_step): enqueue_all2all_gather(i * chunk_size, chunk_size) enqueue_computes(i * chunk_size, chunk_size) for i in range(0, len(params) + chunk_size - 1, chunk_size): enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size) + enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) enqueue_update_param(i, chunk_size) - enqueue_computes(i + overlap_step * chunk_size, chunk_size) + enqueue_computes(i + warmup_step * chunk_size, chunk_size) # Wait the last update_param to finish torch.cuda.current_stream().wait_stream(self.compute_stream) @@ -866,7 +1002,7 @@ class Muon(torch.optim.Optimizer): amsgrad: bool, beta1: float, beta2: float, - lr: Union[float, torch.Tensor], + lr: float | torch.Tensor, weight_decay: float, eps: float, maximize: bool, @@ -876,10 +1012,10 
@@ class Muon(torch.optim.Optimizer): # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer # treating it as a scalar. - lr_dict: Optional[DeviceDict] = ({ + lr_dict: DeviceDict | None = ({ lr.device: lr } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) + None) grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( [ params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, @@ -926,6 +1062,159 @@ class Muon(torch.optim.Optimizer): maximize=maximize, ) + def _step_muon(self, group, qk_logits=None): + params = group["params"] + lr = group["lr"] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + names = group["names"] + + param_dtensors = [] + param_tensors = [] + name_dtensors = [] + name_tensors = [] + + if self.use_distributed_muon: + self.distributed_muon(names=names, + params=params, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits) + return + + for n, p in zip(names, params): + if p is None or p.grad is None: + continue + if isinstance(p.data, DTensor): + if all( + isinstance(placement, Replicate) + for placement in p.placements): + param_tensors.append(p) + name_tensors.append(n) + else: + param_dtensors.append(p) + name_dtensors.append(n) + elif isinstance(p.data, torch.Tensor): + param_tensors.append(p) + name_tensors.append(n) + else: + raise TypeError(f"Unsupported parameter type: {type(p.data)}") + + logger.debug( + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors" + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + # To support different placements, we group parameters by placements + # and run parallel Muon on each group. 
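+            # Illustrative note (hypothetical buckets, not taken from this
+            # change): with one plain FSDP bucket and one 2-D sharded bucket,
+            # the grouping below would produce keys such as
+            #   ((Shard(0),), mesh_1d)          -> (names_a, params_a)
+            #   ((Shard(0), Shard(1)), mesh_2d) -> (names_b, params_b)
+            # and parallel Muon is then invoked once per bucket.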
+ + placement_to_params = defaultdict(lambda: ([], [])) + # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] + + assert len(name_dtensors) == len(param_dtensors) + for n, p in zip(name_dtensors, param_dtensors): + placement_to_params[tuple([p.placements, + p.device_mesh])][0].append(n) + placement_to_params[tuple([p.placements, + p.device_mesh])][1].append(p) + + for _, (names, params) in placement_to_params.items(): + self.parallel( + names, + params, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_tensors) > 0: + self.base( + name_tensors, + param_tensors, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + def _step_adamw_params(self, params, group): + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + self._fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def _step_adamw(self, group): + params = group["params"] + + # group params with it's type and placement + placement_to_params: dict[tuple[Placement | type, + DeviceMesh | None]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for params in placement_to_params.values(): + self._step_adamw_params(params, group) + def step(self, closure=None, qk_logits=None): """Perform a single optimization step. 
@@ -943,127 +1232,9 @@ class Muon(torch.optim.Optimizer): loss = closure() for group in self.param_groups: - params = group["params"] - if group["use_muon"]: - ############################ - # Muon # - ############################ - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - param_tensors = [] - name_dtensors = [] - name_tensors = [] - - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError( - f"Unsupported parameter type: {type(p.data)}") - - if self.debug: - print( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors", - flush=True, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.parallel( - name_dtensors, - param_dtensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - + self._step_muon(group, qk_logits=qk_logits) else: - ############################ - # AdamW backup # - ############################ - - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) + self._step_adamw(group) return loss diff --git a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py index 0e7e761d98b3dd32cb29d88cfffcf96d574af438..7d598206add1bca142661a3df6c510e3d9575d54 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_811726c_dirty -ops = torch.ops._optimizer_811726c_dirty +from . import _optimizer_23d68bb_dirty +ops = torch.ops._optimizer_23d68bb_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_811726c_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_23d68bb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..1cc1c027dd79defcd367dd32836ace4dc43d3cf1 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1cbcd3df518412314d547a86b947998802e488e8aec0f22bf8b59fbc2d1c91e8 +size 1983488 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so deleted file mode 100755 index 56cc7eee1fd7a21a35397b3ebb44aa0fd4ef2da7..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b957f60eab442d3ff5a5525d16a1b4b71e8c6be32edb874d9a5681953c61f0c2 -size 1871056 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/distributed/utils.py b/build/torch28-cxx11-cu129-x86_64-linux/optimizer/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0b4b58bfb329b1c015129e4c4fc99f7bfa2ab30a --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/optimizer/distributed/utils.py @@ -0,0 +1,174 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import (Placement, Shard, + _StridedShard) + + +def get_slices_of_dtensor( + target: DTensor | torch.Tensor, + local_rank: int, + shard_mesh: DeviceMesh, + shard_placements: tuple[Placement], +) -> tuple[slice]: + """ + Get the slice of local tensor for a given rank from a tensor. + Args: + target (DTensor | torch.Tensor): The target tensor. + rank (int): The local rank of the shard group. + shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + shard_placements (tuple[Placement]): The shard placements. + """ + + slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + + # find the global rank of the local rank in the shard mesh + rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] + + rank_coords = (shard_mesh.mesh == rank).nonzero() + + assert len(rank_coords) == 1 + rank_coords = tuple(rank_coords[0].tolist()) + + assert len(rank_coords) == len(shard_placements) + + # Caution: Assuming replicate-to-shard of the shard mesh goes with + # left-to-right sharding. This is ensured by the sorting logic of + # construct_shard_mesh function. 
+ for i, (rank_coord, + placement) in enumerate(zip(rank_coords, shard_placements)): + assert isinstance(placement, Shard) + + num_ranks = shard_mesh.mesh.shape[i] + + dim = placement.dim + dim_size = (slices[dim].stop - slices[dim].start) + + if dim_size % num_ranks != 0: + raise NotImplementedError( + f"Dimension size {dim_size} is not divisible " + f"by number of ranks {num_ranks} for shard " + f"placement on dim {dim}.") + + shard_size = dim_size // num_ranks + + start = slices[dim].start + rank_coord * shard_size + end = start + shard_size + + assert start < end <= slices[dim].stop + + slices[dim] = slice(start, end) + + return tuple(slices) + + +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict() + + +def construct_shard_mesh( + placements: tuple[Placement], + mesh: DeviceMesh, +) -> (DeviceMesh, ProcessGroup, tuple[Placement]): + """ + Construct Shard Mesh and Placements for unsharding. + It removes Replicate placements and constructs a new Mesh and ProcessGroup. + """ + my_rank = dist.get_rank() + + assert mesh.mesh.device.type == 'cpu' + + # Copy mesh to avoid modifying the original mesh + mesh = mesh.mesh.clone() + + # 1. Sort placements. Replicate first, then Shard by dim ascending. + + # For Shard, strided shard comes after regular shard on the same dim + # to preserve left-to-right order of replicate-to-shard. + # This is because that strided shard is using stride to represent + # more fine-grained sharding on the same dim. + # Please check the URL below for _StridedShard. + # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 + + def placement_sort_key( + placement_with_index: tuple[float, Placement] + ) -> tuple[int, float, int]: # (dim, split factor, original index) + index, placement = placement_with_index + is_replicate = placement.is_replicate() + is_shard = placement.is_shard() + is_partial = placement.is_partial() + + assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" + assert not is_partial, "Partial placement is not supported." + + if is_replicate: + return (-1.0, 0, index) + elif is_shard: + if isinstance(placement, _StridedShard): + return (placement.dim, 1 / placement.split_factor, index) + return (placement.dim, 0, index) + else: + raise TypeError(f"Unknown placement type: {type(placement)}") + + placements_with_index: list[tuple[int, + Placement]] = list(enumerate(placements)) + placements_with_index = sorted(placements_with_index, + key=placement_sort_key) + + sorted_indices, sorted_placements = zip(*placements_with_index) + + # 2. Permute mesh according to sorted placements. + sorted_mesh = mesh.permute(sorted_indices) + + # 3. Collect list of shard meshes by removing replicate dims + # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] + # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) + num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + + # merge replicate dims + # shard_meshes became a list of shard meshes with a length of replicate degree + if num_replicates > 0: + sorted_mesh = sorted_mesh.flatten( + 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) + else: + shard_meshes = [sorted_mesh] + shard_placements = sorted_placements[num_replicates:] + + # assume all shard placements are different + assert len(shard_placements) == len(set(shard_placements)) + + # 4. 
Construct ProcessGroups + # Caution: all groups should be created in the same order in all processes, + # even though each process only needs its own group. + + # To use tensor as dict key, convert it to tuple + def tensor_to_tuple(t): + if isinstance(t, torch.Tensor): + t = t.tolist() + if isinstance(t, list): + return tuple(tensor_to_tuple(x) for x in t) + return t + + my_shard_mesh_as_tuple = None + for shard_mesh in shard_meshes: + assert isinstance(shard_mesh, torch.Tensor) + shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) + + if (my_rank == shard_mesh).any().item(): + assert my_shard_mesh_as_tuple is None + my_shard_mesh_as_tuple = shard_mesh_as_tuple + + # update global cache + if shard_mesh_as_tuple not in _ranks_to_dist_cache: + shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) + _ranks_to_dist_cache[shard_mesh_as_tuple] = ( + DeviceMesh(device_type="cuda", mesh=shard_mesh), + shard_process_group, + ) + + my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ + my_shard_mesh_as_tuple] + + return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py b/build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py index f66a1f41fd8b8cbdb453e48b2267ddb0450e5af9..cfbcca71741be70048bfd290c62148b2aceda631 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py @@ -1,18 +1,24 @@ import logging import math import types +from collections import defaultdict from dataclasses import dataclass -from typing import List, Optional, Union, cast +from typing import Any, cast import torch import torch.distributed as dist -from torch.distributed._tensor import DTensor, Replicate, Shard +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate +from torch.distributed.tensor.placement_types import Placement +from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor from .matmul_transpose_triton import matmul_transpose_assign logger = logging.getLogger(__name__) COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 # This code snippet is a modified version adapted from the following GitHub repositories: @@ -62,23 +68,39 @@ def _zeropower_via_newtonschulz5(G, steps): @dataclass class _muon_state: # TODO: use Optional - worker_rank: int | None = None + worker_rank: int + process_group: ProcessGroup + shard_mesh: DeviceMesh + shard_placements: tuple[Placement, ...] 
+ name: str + qk_clip_state: torch.Tensor | None = None gathered_grad: torch.Tensor | None = None scattered_u: DTensor | None = None computed_u: torch.Tensor | None = None gather_event: torch.cuda.Event | None = None compute_event: torch.cuda.Event | None = None scatter_event: torch.cuda.Event | None = None - process_group = None - qk_clip_state = None -def split_elems_for_src(param, src_rank, num_ranks) -> int: - rows = param.shape[0] - cols = int(param.numel() // rows) - base, rem = divmod(rows, num_ranks) - my_rows = base + (1 if src_rank < rem else 0) - return my_rows * cols +def numel_for_rank( + param: DTensor, + local_rank: int, + state: _muon_state, +) -> int: + slices = get_slices_of_dtensor( + param, + local_rank, + state.shard_mesh, + state.shard_placements, + ) + + numel = 1 + for s, dim in zip(slices, param.shape): + start, stop, step = s.indices(dim) + length = max(0, (stop - start + (step - 1)) // step) + numel *= length + + return numel @torch.no_grad() @@ -91,8 +113,7 @@ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): for p in params: state = param_to_state[id(p)] if rank == state.worker_rank: - num_ranks = dist.get_world_size(group=state.process_group) - state.gathered_grad = torch.empty(p.grad.numel(), + state.gathered_grad = torch.empty(p.shape, dtype=COMM_DTYPE, device="cuda") else: @@ -121,11 +142,11 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, state = param_to_state[id(p)] dst = state.worker_rank assert dst < num_ranks - shard_elems = split_elems_for_src(p, rank, num_ranks) + shard_elems = numel_for_rank(p, rank, state) g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous().view(-1) + g = g.to_local().to(COMM_DTYPE).contiguous() assert g.numel() == shard_elems - per_dst[dst].append(g) + per_dst[dst].append(g.view(-1)) send_counts[dst] += shard_elems assert any( @@ -148,13 +169,18 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, for p in owned_params: state = param_to_state[id(p)] assert state.worker_rank == rank - total += split_elems_for_src(p, src, num_ranks) + total += numel_for_rank(p, src, state) recv_counts[src] = total recv_total = sum(recv_counts) recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") #All2All + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") dist.all_to_all_single( recv_buf, send_buf, @@ -179,7 +205,6 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, comm_stream.wait_event(alloc_event) off = 0 - write_offsets = {id(p): 0 for p in owned_params} for src in range(num_ranks): if recv_counts[src] == 0: continue @@ -189,22 +214,28 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, for p in owned_params: state = param_to_state[id(p)] assert state.worker_rank == rank - n = split_elems_for_src(p, src, num_ranks) + + # get the slice of the full dtensor corresponding to rank src. 
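+                    # Because gathered_grad was allocated with the full
+                    # parameter shape, the slice taken below is a writable
+                    # view into it; copying the flat chunk received from rank
+                    # src into that view reassembles the full gradient in
+                    # place, without an extra reshape/concat pass.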
+ slices = get_slices_of_dtensor(state.gathered_grad, src, + state.shard_mesh, + state.shard_placements) + + dst = state.gathered_grad[slices] + assert dst._base is state.gathered_grad + + n = dst.numel() assert n > 0 sg = recv_buf.narrow(0, off + inner_off, n) - woff = write_offsets[id(p)] - dst = state.gathered_grad.narrow(0, woff, n) + sg = sg.reshape_as(dst) dst.copy_(sg) - write_offsets[id(p)] += n inner_off += n off += block for p in params: state = param_to_state[id(p)] if state.worker_rank == rank: - state.gathered_grad = state.gathered_grad.view_as(p) state.gather_event = torch.cuda.Event() state.gather_event.record(comm_stream) else: @@ -277,14 +308,19 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): assert state.computed_u is not None - u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1) + u_full = state.computed_u.to(COMM_DTYPE).contiguous() offset = 0 for dst in range(num_ranks): - n = split_elems_for_src(p, dst, num_ranks) + # get the slice of the full tensor corresponding to rank dst. + slices = get_slices_of_dtensor(u_full, dst, + state.shard_mesh, + state.shard_placements) + su = u_full[slices].flatten() + + n = su.numel() assert n > 0 - su = u_full.narrow(0, offset, n) per_dst[dst].append(su) send_counts[dst] += n offset += n @@ -313,7 +349,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): state = param_to_state[id(p)] if state.worker_rank != src: continue - total += split_elems_for_src(p, rank, num_ranks) + total += numel_for_rank(p, rank, state) recv_counts[src] = total recv_total = sum(recv_counts) @@ -357,7 +393,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): state = param_to_state[id(p)] if state.worker_rank != src: continue - n = split_elems_for_src(p, rank, num_ranks) + n = numel_for_rank(p, rank, state) assert n > 0 flat_local = recv_buf.narrow(0, off + inner_off, @@ -398,11 +434,23 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, state.scattered_u = None u_dtensor = None - scales_full = Muon._compute_scales(p, state.qk_clip_state) + scales_full = Muon._compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None if scales_full is not None: - num_ranks = dist.get_world_size(group=state.process_group) - local_rank = dist.get_rank(group=state.process_group) - scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank] + # Have to slice scales_full among dim 0 + weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, + state.shard_placements) + ratio = p.shape[0] // scales_full.shape[0] + scales_slice = slice( + None if weight_slices[0].start is None else + weight_slices[0].start // ratio, + None if weight_slices[0].stop is None else + weight_slices[0].stop // ratio, + None, + ) + + scales_local = scales_full[scales_slice] scales_local = DTensor.from_local( scales_local, placements=p.placements, @@ -478,11 +526,11 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: List[int] # which heads to consider for clipping + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping head_dim: int # from config threshold: float # from config - logit: Optional[torch.Tensor] + logit: torch.Tensor | None class Muon(torch.optim.Optimizer): @@ -525,11 +573,16 @@ class 
Muon(torch.optim.Optimizer): "head_dim": 128, "threshold": 100 } - overlap_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher overlap_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. + warmup_step : How many all2all gather, compute operations are launched in advance + before the corresponding all2all scatter steps begin. + A higher warmup_step increases memory usage but can improve + performance by overlapping communication. + Parallel muon only. + chunk_size : Batch size of parameters to process in each + all2all gather/compute/scatter step. + Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. + use_distributed_muon: Use distributed muon by Liu et al. (2024). + For testing purpose only. """ def __init__(self, @@ -549,7 +602,9 @@ class Muon(torch.optim.Optimizer): "head_dim": 128, "threshold": 100 }, - overlap_step=5): + warmup_step=5, + chunk_size=-1, + use_distributed_muon=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -579,7 +634,9 @@ class Muon(torch.optim.Optimizer): self.compute_stream = torch.cuda.Stream() self.debug = debug self.clip_config = clip_config - self.overlap_step = overlap_step + self.warmup_step = warmup_step + self.chunk_size = chunk_size + self.use_distributed_muon = use_distributed_muon def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -597,6 +654,12 @@ class Muon(torch.optim.Optimizer): adjusted_lr = lr * adjusted_ratio return adjusted_lr + def set_rank_once(self, rank): + if self.rank is None: + self.rank = rank + else: + assert self.rank == rank + def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. @@ -604,26 +667,13 @@ class Muon(torch.optim.Optimizer): assert isinstance( p, DTensor), "Parallel Muon only supports DTensor parameters." 
- if p.placements == (Shard(dim=0), ): - # Case for FSDP - process_group = p.device_mesh.get_group(mesh_dim=0) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0) - elif p.placements == (Replicate(), Shard(dim=0)): - # Case for HSDP - process_group = p.device_mesh.get_group(mesh_dim=1) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - for i, shard_mesh in enumerate(p.device_mesh.mesh): - if self.rank in shard_mesh: - return shard_mesh, p.device_mesh.get_group(mesh_dim=1) - else: - raise ValueError(f"Unsupported placements ({p.placements}).") + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + p.placements, p.device_mesh) + + # set rank with the local rank in the shard process group + self.set_rank_once(dist.get_rank(group=shard_pg)) + + return shard_mesh, shard_pg, shard_placements def init_state_and_assign_params(self, names, params, group, qk_logits): param_to_state = {} @@ -655,23 +705,32 @@ class Muon(torch.optim.Optimizer): ordered_params = list(params_sorted) round_robin = 0 - mesh = None - shard_mesh = None - process_group = None + mesh = ordered_params[0].device_mesh + placements = ordered_params[0].placements + + shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( + ordered_params[0]) + shard_mesh_flattened = shard_mesh.mesh.flatten() + num_ranks = dist.get_world_size(group=shard_pg) + for n, p in zip(ordered_names, ordered_params): - if mesh is None: - mesh = p.device_mesh - shard_mesh, process_group = self.get_shard_mesh(p) - elif mesh != p.device_mesh: + if mesh != p.device_mesh: raise ValueError("All parameters must be on the same mesh.") - num_ranks = dist.get_world_size(group=process_group) - param_to_state[id(p)] = _muon_state() - param_to_state[id( - p)].worker_rank = shard_mesh[round_robin].item() % num_ranks - param_to_state[id(p)].process_group = process_group + if placements != p.placements: + raise ValueError("All parameters must have same placements.") + + worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks + round_robin = (round_robin + 1) % len(shard_mesh_flattened) qk_clip_state = self.get_qk_clip_info(n, qk_logits) - param_to_state[id(p)].qk_clip_state = qk_clip_state - round_robin = (round_robin + 1) % len(shard_mesh) + + param_to_state[id(p)] = _muon_state( + worker_rank=worker_rank, + process_group=shard_pg, + shard_mesh=shard_mesh, + shard_placements=shard_placements, + name=n, + qk_clip_state=qk_clip_state, + ) return param_to_state, ordered_params @@ -705,10 +764,73 @@ class Muon(torch.optim.Optimizer): qk_clip_state = self.get_qk_clip_info(n, qk_logits) - scales_full = self._compute_scales(p, qk_clip_state) + scales_full = self._compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + def distributed_muon( + self, + names: list[str], + params: list[torch.nn.Parameter], + group: dict[str, Any], + lr: float, + weight_decay: float, + momentum: float, + qk_logits: list[torch.Tensor | DTensor] | None, + ): + """ Implementation of Distributed Muon by Liu et al. 
""" + if qk_logits is not None: + raise NotImplementedError("QK clipping is not supported yet") + + if isinstance(params[0], DTensor): + shard_mesh, _, shard_placements = construct_shard_mesh( + placements=params[0].placements, + mesh=params[0].device_mesh, + ) + + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + # Gather G + if isinstance(p.data, DTensor): + g = g.full_tensor() + u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) + + if isinstance(p.data, DTensor): + slices = get_slices_of_dtensor( + target=p, + local_rank=dist.get_rank(), + shard_mesh=shard_mesh, + shard_placements=shard_placements, + ) + u_shard = u[slices] + u = DTensor.from_local( + u_shard, + device_mesh=p.device_mesh, + placements=p.placements, + ) + + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + def _update_g(self, p, g, group, momentum): # calc update state = self.state[p] @@ -727,6 +849,9 @@ class Muon(torch.optim.Optimizer): p.data.add_(u, alpha=-adjusted_lr) def get_qk_clip_info(self, n, qk_logits): + if self.clip_config is None: + return None + head_dim = self.clip_config.get('head_dim') threshold = self.clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) @@ -737,6 +862,11 @@ class Muon(torch.optim.Optimizer): indices_key = 'q_indices' if 'q' in kind else 'k_indices' indices = self.clip_config.get(indices_key, []) or [] + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + return QKClipInfo( kind=kind, indices=indices, @@ -835,22 +965,28 @@ class Muon(torch.optim.Optimizer): _update_param(p, state, lr, adjusted_lr, weight_decay, self.rank, self.compute_stream) - chunk_size = dist.get_world_size(param_to_state[id( - params[0])].process_group) + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(param_to_state[id( + params[0])].process_group) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError("chunk_size must be -1 or a positive integer.") # Wait grad update self.comm_stream.wait_stream(torch.cuda.current_stream()) - overlap_step = self.overlap_step - for i in range(0, overlap_step): + warmup_step = self.warmup_step + for i in range(0, warmup_step): enqueue_all2all_gather(i * chunk_size, chunk_size) enqueue_computes(i * chunk_size, chunk_size) for i in range(0, len(params) + chunk_size - 1, chunk_size): enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size) + enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) enqueue_update_param(i, chunk_size) - enqueue_computes(i + overlap_step * chunk_size, chunk_size) + enqueue_computes(i + warmup_step * chunk_size, chunk_size) # Wait the last update_param to finish torch.cuda.current_stream().wait_stream(self.compute_stream) @@ -866,7 +1002,7 @@ class Muon(torch.optim.Optimizer): amsgrad: bool, beta1: float, beta2: float, - lr: Union[float, torch.Tensor], + lr: float | torch.Tensor, weight_decay: float, eps: float, maximize: bool, @@ -876,10 +1012,10 
@@ class Muon(torch.optim.Optimizer): # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer # treating it as a scalar. - lr_dict: Optional[DeviceDict] = ({ + lr_dict: DeviceDict | None = ({ lr.device: lr } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) + None) grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( [ params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, @@ -926,6 +1062,159 @@ class Muon(torch.optim.Optimizer): maximize=maximize, ) + def _step_muon(self, group, qk_logits=None): + params = group["params"] + lr = group["lr"] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + names = group["names"] + + param_dtensors = [] + param_tensors = [] + name_dtensors = [] + name_tensors = [] + + if self.use_distributed_muon: + self.distributed_muon(names=names, + params=params, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits) + return + + for n, p in zip(names, params): + if p is None or p.grad is None: + continue + if isinstance(p.data, DTensor): + if all( + isinstance(placement, Replicate) + for placement in p.placements): + param_tensors.append(p) + name_tensors.append(n) + else: + param_dtensors.append(p) + name_dtensors.append(n) + elif isinstance(p.data, torch.Tensor): + param_tensors.append(p) + name_tensors.append(n) + else: + raise TypeError(f"Unsupported parameter type: {type(p.data)}") + + logger.debug( + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors" + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + # To support different placements, we group parameters by placements + # and run parallel Muon on each group. 
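+            # Note: the bucket key is simply (p.placements, p.device_mesh);
+            # this relies on DeviceMesh implementing __hash__/__eq__ so that
+            # parameters sharing a mesh land in the same bucket, which holds
+            # for the DeviceMesh API targeted here (torch 2.8).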
+ + placement_to_params = defaultdict(lambda: ([], [])) + # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] + + assert len(name_dtensors) == len(param_dtensors) + for n, p in zip(name_dtensors, param_dtensors): + placement_to_params[tuple([p.placements, + p.device_mesh])][0].append(n) + placement_to_params[tuple([p.placements, + p.device_mesh])][1].append(p) + + for _, (names, params) in placement_to_params.items(): + self.parallel( + names, + params, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_tensors) > 0: + self.base( + name_tensors, + param_tensors, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + def _step_adamw_params(self, params, group): + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + self._fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def _step_adamw(self, group): + params = group["params"] + + # group params with it's type and placement + placement_to_params: dict[tuple[Placement | type, + DeviceMesh | None]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for params in placement_to_params.values(): + self._step_adamw_params(params, group) + def step(self, closure=None, qk_logits=None): """Perform a single optimization step. 
@@ -943,127 +1232,9 @@ class Muon(torch.optim.Optimizer): loss = closure() for group in self.param_groups: - params = group["params"] - if group["use_muon"]: - ############################ - # Muon # - ############################ - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - param_tensors = [] - name_dtensors = [] - name_tensors = [] - - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError( - f"Unsupported parameter type: {type(p.data)}") - - if self.debug: - print( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors", - flush=True, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.parallel( - name_dtensors, - param_dtensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - + self._step_muon(group, qk_logits=qk_logits) else: - ############################ - # AdamW backup # - ############################ - - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) + self._step_adamw(group) return loss diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py b/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py index 0e7e761d98b3dd32cb29d88cfffcf96d574af438..7d598206add1bca142661a3df6c510e3d9575d54 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_811726c_dirty -ops = torch.ops._optimizer_811726c_dirty +from . import _optimizer_23d68bb_dirty +ops = torch.ops._optimizer_23d68bb_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_811726c_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_23d68bb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..965a07d2753d33cb8afcabbeb81d4c2f28517ce2 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8a2999010ee158e13e3ef247e877dfab073b5bde7babefe2b2b5273b760c7ddf +size 1852152 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so deleted file mode 100755 index fa2d9f45c0e3cbbcb6b59c9a7ef9c1802df37c3a..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:898ff08457f77c2f6ef504c73570cc87c5c5fd9a144528dbf8af4c03ffc21049 -size 1749232 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/distributed/utils.py b/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0b4b58bfb329b1c015129e4c4fc99f7bfa2ab30a --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/distributed/utils.py @@ -0,0 +1,174 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import (Placement, Shard, + _StridedShard) + + +def get_slices_of_dtensor( + target: DTensor | torch.Tensor, + local_rank: int, + shard_mesh: DeviceMesh, + shard_placements: tuple[Placement], +) -> tuple[slice]: + """ + Get the slice of local tensor for a given rank from a tensor. + Args: + target (DTensor | torch.Tensor): The target tensor. + rank (int): The local rank of the shard group. + shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + shard_placements (tuple[Placement]): The shard placements. + """ + + slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + + # find the global rank of the local rank in the shard mesh + rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] + + rank_coords = (shard_mesh.mesh == rank).nonzero() + + assert len(rank_coords) == 1 + rank_coords = tuple(rank_coords[0].tolist()) + + assert len(rank_coords) == len(shard_placements) + + # Caution: Assuming replicate-to-shard of the shard mesh goes with + # left-to-right sharding. This is ensured by the sorting logic of + # construct_shard_mesh function. 
+ for i, (rank_coord, + placement) in enumerate(zip(rank_coords, shard_placements)): + assert isinstance(placement, Shard) + + num_ranks = shard_mesh.mesh.shape[i] + + dim = placement.dim + dim_size = (slices[dim].stop - slices[dim].start) + + if dim_size % num_ranks != 0: + raise NotImplementedError( + f"Dimension size {dim_size} is not divisible " + f"by number of ranks {num_ranks} for shard " + f"placement on dim {dim}.") + + shard_size = dim_size // num_ranks + + start = slices[dim].start + rank_coord * shard_size + end = start + shard_size + + assert start < end <= slices[dim].stop + + slices[dim] = slice(start, end) + + return tuple(slices) + + +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict() + + +def construct_shard_mesh( + placements: tuple[Placement], + mesh: DeviceMesh, +) -> (DeviceMesh, ProcessGroup, tuple[Placement]): + """ + Construct Shard Mesh and Placements for unsharding. + It removes Replicate placements and constructs a new Mesh and ProcessGroup. + """ + my_rank = dist.get_rank() + + assert mesh.mesh.device.type == 'cpu' + + # Copy mesh to avoid modifying the original mesh + mesh = mesh.mesh.clone() + + # 1. Sort placements. Replicate first, then Shard by dim ascending. + + # For Shard, strided shard comes after regular shard on the same dim + # to preserve left-to-right order of replicate-to-shard. + # This is because that strided shard is using stride to represent + # more fine-grained sharding on the same dim. + # Please check the URL below for _StridedShard. + # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 + + def placement_sort_key( + placement_with_index: tuple[float, Placement] + ) -> tuple[int, float, int]: # (dim, split factor, original index) + index, placement = placement_with_index + is_replicate = placement.is_replicate() + is_shard = placement.is_shard() + is_partial = placement.is_partial() + + assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" + assert not is_partial, "Partial placement is not supported." + + if is_replicate: + return (-1.0, 0, index) + elif is_shard: + if isinstance(placement, _StridedShard): + return (placement.dim, 1 / placement.split_factor, index) + return (placement.dim, 0, index) + else: + raise TypeError(f"Unknown placement type: {type(placement)}") + + placements_with_index: list[tuple[int, + Placement]] = list(enumerate(placements)) + placements_with_index = sorted(placements_with_index, + key=placement_sort_key) + + sorted_indices, sorted_placements = zip(*placements_with_index) + + # 2. Permute mesh according to sorted placements. + sorted_mesh = mesh.permute(sorted_indices) + + # 3. Collect list of shard meshes by removing replicate dims + # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] + # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) + num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + + # merge replicate dims + # shard_meshes became a list of shard meshes with a length of replicate degree + if num_replicates > 0: + sorted_mesh = sorted_mesh.flatten( + 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) + else: + shard_meshes = [sorted_mesh] + shard_placements = sorted_placements[num_replicates:] + + # assume all shard placements are different + assert len(shard_placements) == len(set(shard_placements)) + + # 4. 
Construct ProcessGroups + # Caution: all groups should be created in the same order in all processes, + # even though each process only needs its own group. + + # To use tensor as dict key, convert it to tuple + def tensor_to_tuple(t): + if isinstance(t, torch.Tensor): + t = t.tolist() + if isinstance(t, list): + return tuple(tensor_to_tuple(x) for x in t) + return t + + my_shard_mesh_as_tuple = None + for shard_mesh in shard_meshes: + assert isinstance(shard_mesh, torch.Tensor) + shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) + + if (my_rank == shard_mesh).any().item(): + assert my_shard_mesh_as_tuple is None + my_shard_mesh_as_tuple = shard_mesh_as_tuple + + # update global cache + if shard_mesh_as_tuple not in _ranks_to_dist_cache: + shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) + _ranks_to_dist_cache[shard_mesh_as_tuple] = ( + DeviceMesh(device_type="cuda", mesh=shard_mesh), + shard_process_group, + ) + + my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ + my_shard_mesh_as_tuple] + + return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py b/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py index f66a1f41fd8b8cbdb453e48b2267ddb0450e5af9..cfbcca71741be70048bfd290c62148b2aceda631 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py @@ -1,18 +1,24 @@ import logging import math import types +from collections import defaultdict from dataclasses import dataclass -from typing import List, Optional, Union, cast +from typing import Any, cast import torch import torch.distributed as dist -from torch.distributed._tensor import DTensor, Replicate, Shard +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate +from torch.distributed.tensor.placement_types import Placement +from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor from .matmul_transpose_triton import matmul_transpose_assign logger = logging.getLogger(__name__) COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 # This code snippet is a modified version adapted from the following GitHub repositories: @@ -62,23 +68,39 @@ def _zeropower_via_newtonschulz5(G, steps): @dataclass class _muon_state: # TODO: use Optional - worker_rank: int | None = None + worker_rank: int + process_group: ProcessGroup + shard_mesh: DeviceMesh + shard_placements: tuple[Placement, ...] 
+ name: str + qk_clip_state: torch.Tensor | None = None gathered_grad: torch.Tensor | None = None scattered_u: DTensor | None = None computed_u: torch.Tensor | None = None gather_event: torch.cuda.Event | None = None compute_event: torch.cuda.Event | None = None scatter_event: torch.cuda.Event | None = None - process_group = None - qk_clip_state = None -def split_elems_for_src(param, src_rank, num_ranks) -> int: - rows = param.shape[0] - cols = int(param.numel() // rows) - base, rem = divmod(rows, num_ranks) - my_rows = base + (1 if src_rank < rem else 0) - return my_rows * cols +def numel_for_rank( + param: DTensor, + local_rank: int, + state: _muon_state, +) -> int: + slices = get_slices_of_dtensor( + param, + local_rank, + state.shard_mesh, + state.shard_placements, + ) + + numel = 1 + for s, dim in zip(slices, param.shape): + start, stop, step = s.indices(dim) + length = max(0, (stop - start + (step - 1)) // step) + numel *= length + + return numel @torch.no_grad() @@ -91,8 +113,7 @@ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): for p in params: state = param_to_state[id(p)] if rank == state.worker_rank: - num_ranks = dist.get_world_size(group=state.process_group) - state.gathered_grad = torch.empty(p.grad.numel(), + state.gathered_grad = torch.empty(p.shape, dtype=COMM_DTYPE, device="cuda") else: @@ -121,11 +142,11 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, state = param_to_state[id(p)] dst = state.worker_rank assert dst < num_ranks - shard_elems = split_elems_for_src(p, rank, num_ranks) + shard_elems = numel_for_rank(p, rank, state) g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous().view(-1) + g = g.to_local().to(COMM_DTYPE).contiguous() assert g.numel() == shard_elems - per_dst[dst].append(g) + per_dst[dst].append(g.view(-1)) send_counts[dst] += shard_elems assert any( @@ -148,13 +169,18 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, for p in owned_params: state = param_to_state[id(p)] assert state.worker_rank == rank - total += split_elems_for_src(p, src, num_ranks) + total += numel_for_rank(p, src, state) recv_counts[src] = total recv_total = sum(recv_counts) recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") #All2All + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") dist.all_to_all_single( recv_buf, send_buf, @@ -179,7 +205,6 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, comm_stream.wait_event(alloc_event) off = 0 - write_offsets = {id(p): 0 for p in owned_params} for src in range(num_ranks): if recv_counts[src] == 0: continue @@ -189,22 +214,28 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, for p in owned_params: state = param_to_state[id(p)] assert state.worker_rank == rank - n = split_elems_for_src(p, src, num_ranks) + + # get the slice of the full dtensor corresponding to rank src. 
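+                    # The unpacking order here (src rank outer, owned params
+                    # inner, in the shared params order) mirrors how each
+                    # sender packed per_dst above, so every narrow() below
+                    # lands on the chunk that rank src actually sent for this
+                    # parameter.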
+ slices = get_slices_of_dtensor(state.gathered_grad, src, + state.shard_mesh, + state.shard_placements) + + dst = state.gathered_grad[slices] + assert dst._base is state.gathered_grad + + n = dst.numel() assert n > 0 sg = recv_buf.narrow(0, off + inner_off, n) - woff = write_offsets[id(p)] - dst = state.gathered_grad.narrow(0, woff, n) + sg = sg.reshape_as(dst) dst.copy_(sg) - write_offsets[id(p)] += n inner_off += n off += block for p in params: state = param_to_state[id(p)] if state.worker_rank == rank: - state.gathered_grad = state.gathered_grad.view_as(p) state.gather_event = torch.cuda.Event() state.gather_event.record(comm_stream) else: @@ -277,14 +308,19 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): assert state.computed_u is not None - u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1) + u_full = state.computed_u.to(COMM_DTYPE).contiguous() offset = 0 for dst in range(num_ranks): - n = split_elems_for_src(p, dst, num_ranks) + # get the slice of the full tensor corresponding to rank dst. + slices = get_slices_of_dtensor(u_full, dst, + state.shard_mesh, + state.shard_placements) + su = u_full[slices].flatten() + + n = su.numel() assert n > 0 - su = u_full.narrow(0, offset, n) per_dst[dst].append(su) send_counts[dst] += n offset += n @@ -313,7 +349,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): state = param_to_state[id(p)] if state.worker_rank != src: continue - total += split_elems_for_src(p, rank, num_ranks) + total += numel_for_rank(p, rank, state) recv_counts[src] = total recv_total = sum(recv_counts) @@ -357,7 +393,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): state = param_to_state[id(p)] if state.worker_rank != src: continue - n = split_elems_for_src(p, rank, num_ranks) + n = numel_for_rank(p, rank, state) assert n > 0 flat_local = recv_buf.narrow(0, off + inner_off, @@ -398,11 +434,23 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, state.scattered_u = None u_dtensor = None - scales_full = Muon._compute_scales(p, state.qk_clip_state) + scales_full = Muon._compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None if scales_full is not None: - num_ranks = dist.get_world_size(group=state.process_group) - local_rank = dist.get_rank(group=state.process_group) - scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank] + # Have to slice scales_full among dim 0 + weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, + state.shard_placements) + ratio = p.shape[0] // scales_full.shape[0] + scales_slice = slice( + None if weight_slices[0].start is None else + weight_slices[0].start // ratio, + None if weight_slices[0].stop is None else + weight_slices[0].stop // ratio, + None, + ) + + scales_local = scales_full[scales_slice] scales_local = DTensor.from_local( scales_local, placements=p.placements, @@ -478,11 +526,11 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: List[int] # which heads to consider for clipping + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping head_dim: int # from config threshold: float # from config - logit: Optional[torch.Tensor] + logit: torch.Tensor | None class Muon(torch.optim.Optimizer): @@ -525,11 +573,16 @@ class 
Muon(torch.optim.Optimizer): "head_dim": 128, "threshold": 100 } - overlap_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher overlap_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. + warmup_step : How many all2all gather, compute operations are launched in advance + before the corresponding all2all scatter steps begin. + A higher warmup_step increases memory usage but can improve + performance by overlapping communication. + Parallel muon only. + chunk_size : Batch size of parameters to process in each + all2all gather/compute/scatter step. + Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. + use_distributed_muon: Use distributed muon by Liu et al. (2024). + For testing purpose only. """ def __init__(self, @@ -549,7 +602,9 @@ class Muon(torch.optim.Optimizer): "head_dim": 128, "threshold": 100 }, - overlap_step=5): + warmup_step=5, + chunk_size=-1, + use_distributed_muon=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -579,7 +634,9 @@ class Muon(torch.optim.Optimizer): self.compute_stream = torch.cuda.Stream() self.debug = debug self.clip_config = clip_config - self.overlap_step = overlap_step + self.warmup_step = warmup_step + self.chunk_size = chunk_size + self.use_distributed_muon = use_distributed_muon def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -597,6 +654,12 @@ class Muon(torch.optim.Optimizer): adjusted_lr = lr * adjusted_ratio return adjusted_lr + def set_rank_once(self, rank): + if self.rank is None: + self.rank = rank + else: + assert self.rank == rank + def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. @@ -604,26 +667,13 @@ class Muon(torch.optim.Optimizer): assert isinstance( p, DTensor), "Parallel Muon only supports DTensor parameters." 
- if p.placements == (Shard(dim=0), ): - # Case for FSDP - process_group = p.device_mesh.get_group(mesh_dim=0) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0) - elif p.placements == (Replicate(), Shard(dim=0)): - # Case for HSDP - process_group = p.device_mesh.get_group(mesh_dim=1) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - for i, shard_mesh in enumerate(p.device_mesh.mesh): - if self.rank in shard_mesh: - return shard_mesh, p.device_mesh.get_group(mesh_dim=1) - else: - raise ValueError(f"Unsupported placements ({p.placements}).") + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + p.placements, p.device_mesh) + + # set rank with the local rank in the shard process group + self.set_rank_once(dist.get_rank(group=shard_pg)) + + return shard_mesh, shard_pg, shard_placements def init_state_and_assign_params(self, names, params, group, qk_logits): param_to_state = {} @@ -655,23 +705,32 @@ class Muon(torch.optim.Optimizer): ordered_params = list(params_sorted) round_robin = 0 - mesh = None - shard_mesh = None - process_group = None + mesh = ordered_params[0].device_mesh + placements = ordered_params[0].placements + + shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( + ordered_params[0]) + shard_mesh_flattened = shard_mesh.mesh.flatten() + num_ranks = dist.get_world_size(group=shard_pg) + for n, p in zip(ordered_names, ordered_params): - if mesh is None: - mesh = p.device_mesh - shard_mesh, process_group = self.get_shard_mesh(p) - elif mesh != p.device_mesh: + if mesh != p.device_mesh: raise ValueError("All parameters must be on the same mesh.") - num_ranks = dist.get_world_size(group=process_group) - param_to_state[id(p)] = _muon_state() - param_to_state[id( - p)].worker_rank = shard_mesh[round_robin].item() % num_ranks - param_to_state[id(p)].process_group = process_group + if placements != p.placements: + raise ValueError("All parameters must have same placements.") + + worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks + round_robin = (round_robin + 1) % len(shard_mesh_flattened) qk_clip_state = self.get_qk_clip_info(n, qk_logits) - param_to_state[id(p)].qk_clip_state = qk_clip_state - round_robin = (round_robin + 1) % len(shard_mesh) + + param_to_state[id(p)] = _muon_state( + worker_rank=worker_rank, + process_group=shard_pg, + shard_mesh=shard_mesh, + shard_placements=shard_placements, + name=n, + qk_clip_state=qk_clip_state, + ) return param_to_state, ordered_params @@ -705,10 +764,73 @@ class Muon(torch.optim.Optimizer): qk_clip_state = self.get_qk_clip_info(n, qk_logits) - scales_full = self._compute_scales(p, qk_clip_state) + scales_full = self._compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + def distributed_muon( + self, + names: list[str], + params: list[torch.nn.Parameter], + group: dict[str, Any], + lr: float, + weight_decay: float, + momentum: float, + qk_logits: list[torch.Tensor | DTensor] | None, + ): + """ Implementation of Distributed Muon by Liu et al. 
""" + if qk_logits is not None: + raise NotImplementedError("QK clipping is not supported yet") + + if isinstance(params[0], DTensor): + shard_mesh, _, shard_placements = construct_shard_mesh( + placements=params[0].placements, + mesh=params[0].device_mesh, + ) + + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + # Gather G + if isinstance(p.data, DTensor): + g = g.full_tensor() + u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) + + if isinstance(p.data, DTensor): + slices = get_slices_of_dtensor( + target=p, + local_rank=dist.get_rank(), + shard_mesh=shard_mesh, + shard_placements=shard_placements, + ) + u_shard = u[slices] + u = DTensor.from_local( + u_shard, + device_mesh=p.device_mesh, + placements=p.placements, + ) + + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + def _update_g(self, p, g, group, momentum): # calc update state = self.state[p] @@ -727,6 +849,9 @@ class Muon(torch.optim.Optimizer): p.data.add_(u, alpha=-adjusted_lr) def get_qk_clip_info(self, n, qk_logits): + if self.clip_config is None: + return None + head_dim = self.clip_config.get('head_dim') threshold = self.clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) @@ -737,6 +862,11 @@ class Muon(torch.optim.Optimizer): indices_key = 'q_indices' if 'q' in kind else 'k_indices' indices = self.clip_config.get(indices_key, []) or [] + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + return QKClipInfo( kind=kind, indices=indices, @@ -835,22 +965,28 @@ class Muon(torch.optim.Optimizer): _update_param(p, state, lr, adjusted_lr, weight_decay, self.rank, self.compute_stream) - chunk_size = dist.get_world_size(param_to_state[id( - params[0])].process_group) + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(param_to_state[id( + params[0])].process_group) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError("chunk_size must be -1 or a positive integer.") # Wait grad update self.comm_stream.wait_stream(torch.cuda.current_stream()) - overlap_step = self.overlap_step - for i in range(0, overlap_step): + warmup_step = self.warmup_step + for i in range(0, warmup_step): enqueue_all2all_gather(i * chunk_size, chunk_size) enqueue_computes(i * chunk_size, chunk_size) for i in range(0, len(params) + chunk_size - 1, chunk_size): enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size) + enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) enqueue_update_param(i, chunk_size) - enqueue_computes(i + overlap_step * chunk_size, chunk_size) + enqueue_computes(i + warmup_step * chunk_size, chunk_size) # Wait the last update_param to finish torch.cuda.current_stream().wait_stream(self.compute_stream) @@ -866,7 +1002,7 @@ class Muon(torch.optim.Optimizer): amsgrad: bool, beta1: float, beta2: float, - lr: Union[float, torch.Tensor], + lr: float | torch.Tensor, weight_decay: float, eps: float, maximize: bool, @@ -876,10 +1012,10 
@@ class Muon(torch.optim.Optimizer): # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer # treating it as a scalar. - lr_dict: Optional[DeviceDict] = ({ + lr_dict: DeviceDict | None = ({ lr.device: lr } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) + None) grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( [ params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, @@ -926,6 +1062,159 @@ class Muon(torch.optim.Optimizer): maximize=maximize, ) + def _step_muon(self, group, qk_logits=None): + params = group["params"] + lr = group["lr"] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + names = group["names"] + + param_dtensors = [] + param_tensors = [] + name_dtensors = [] + name_tensors = [] + + if self.use_distributed_muon: + self.distributed_muon(names=names, + params=params, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits) + return + + for n, p in zip(names, params): + if p is None or p.grad is None: + continue + if isinstance(p.data, DTensor): + if all( + isinstance(placement, Replicate) + for placement in p.placements): + param_tensors.append(p) + name_tensors.append(n) + else: + param_dtensors.append(p) + name_dtensors.append(n) + elif isinstance(p.data, torch.Tensor): + param_tensors.append(p) + name_tensors.append(n) + else: + raise TypeError(f"Unsupported parameter type: {type(p.data)}") + + logger.debug( + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors" + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + # To support different placements, we group parameters by placements + # and run parallel Muon on each group. 
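The comment above summarizes the key change in `_step_muon`: DTensor parameters are bucketed by their (placements, device_mesh) pair and the parallel path runs once per bucket. A minimal, self-contained sketch of that bookkeeping, using plain dicts as stand-ins for DTensor attributes (the `group_by_layout` helper and the toy records are illustrative, not part of this module):

```python
from collections import defaultdict

def group_by_layout(names, params):
    """Bucket (name, param) pairs so that each bucket shares one sharding
    layout and can be handled by a single parallel-Muon pass."""
    buckets = defaultdict(lambda: ([], []))
    for n, p in zip(names, params):
        key = (p["placements"], p["device_mesh"])   # hashable layout key
        buckets[key][0].append(n)
        buckets[key][1].append(p)
    return buckets

toy_params = [
    {"name": "w1", "placements": ("Shard(0)",),               "device_mesh": "mesh_a"},
    {"name": "w2", "placements": ("Shard(0)",),               "device_mesh": "mesh_a"},
    {"name": "w3", "placements": ("Replicate()", "Shard(0)"), "device_mesh": "mesh_b"},
]
buckets = group_by_layout([p["name"] for p in toy_params], toy_params)
for (placements, mesh), (ns, _ps) in buckets.items():
    print(placements, mesh, ns)   # two buckets: ['w1', 'w2'] and ['w3']
```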
+ + placement_to_params = defaultdict(lambda: ([], [])) + # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] + + assert len(name_dtensors) == len(param_dtensors) + for n, p in zip(name_dtensors, param_dtensors): + placement_to_params[tuple([p.placements, + p.device_mesh])][0].append(n) + placement_to_params[tuple([p.placements, + p.device_mesh])][1].append(p) + + for _, (names, params) in placement_to_params.items(): + self.parallel( + names, + params, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_tensors) > 0: + self.base( + name_tensors, + param_tensors, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + def _step_adamw_params(self, params, group): + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + self._fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def _step_adamw(self, group): + params = group["params"] + + # group params with it's type and placement + placement_to_params: dict[tuple[Placement | type, + DeviceMesh | None]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for params in placement_to_params.values(): + self._step_adamw_params(params, group) + def step(self, closure=None, qk_logits=None): """Perform a single optimization step. 
@@ -943,127 +1232,9 @@ class Muon(torch.optim.Optimizer): loss = closure() for group in self.param_groups: - params = group["params"] - if group["use_muon"]: - ############################ - # Muon # - ############################ - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - param_tensors = [] - name_dtensors = [] - name_tensors = [] - - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError( - f"Unsupported parameter type: {type(p.data)}") - - if self.debug: - print( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors", - flush=True, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.parallel( - name_dtensors, - param_dtensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - + self._step_muon(group, qk_logits=qk_logits) else: - ############################ - # AdamW backup # - ############################ - - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) + self._step_adamw(group) return loss diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py b/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py index 0e7e761d98b3dd32cb29d88cfffcf96d574af438..7d598206add1bca142661a3df6c510e3d9575d54 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_811726c_dirty -ops = torch.ops._optimizer_811726c_dirty +from . import _optimizer_23d68bb_dirty +ops = torch.ops._optimizer_23d68bb_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_811726c_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_23d68bb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..61f550abf74c8dc521bf56e2c9e2a904b2582331 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:55f869cf4220f2033d4e499da522da46794a682495c2b688dbcac0ec89135cf4 +size 1852240 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so deleted file mode 100755 index 264ba7487b636164b2b498aaec0420484213e908..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:72d100180fd73094f7b1c6e765eb4a77f103ad392fdee571687cb0c66d304177 -size 1749320 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/distributed/utils.py b/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0b4b58bfb329b1c015129e4c4fc99f7bfa2ab30a --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/distributed/utils.py @@ -0,0 +1,174 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import (Placement, Shard, + _StridedShard) + + +def get_slices_of_dtensor( + target: DTensor | torch.Tensor, + local_rank: int, + shard_mesh: DeviceMesh, + shard_placements: tuple[Placement], +) -> tuple[slice]: + """ + Get the slice of local tensor for a given rank from a tensor. + Args: + target (DTensor | torch.Tensor): The target tensor. + rank (int): The local rank of the shard group. + shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + shard_placements (tuple[Placement]): The shard placements. + """ + + slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + + # find the global rank of the local rank in the shard mesh + rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] + + rank_coords = (shard_mesh.mesh == rank).nonzero() + + assert len(rank_coords) == 1 + rank_coords = tuple(rank_coords[0].tolist()) + + assert len(rank_coords) == len(shard_placements) + + # Caution: Assuming replicate-to-shard of the shard mesh goes with + # left-to-right sharding. This is ensured by the sorting logic of + # construct_shard_mesh function. 
+ for i, (rank_coord, + placement) in enumerate(zip(rank_coords, shard_placements)): + assert isinstance(placement, Shard) + + num_ranks = shard_mesh.mesh.shape[i] + + dim = placement.dim + dim_size = (slices[dim].stop - slices[dim].start) + + if dim_size % num_ranks != 0: + raise NotImplementedError( + f"Dimension size {dim_size} is not divisible " + f"by number of ranks {num_ranks} for shard " + f"placement on dim {dim}.") + + shard_size = dim_size // num_ranks + + start = slices[dim].start + rank_coord * shard_size + end = start + shard_size + + assert start < end <= slices[dim].stop + + slices[dim] = slice(start, end) + + return tuple(slices) + + +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict() + + +def construct_shard_mesh( + placements: tuple[Placement], + mesh: DeviceMesh, +) -> (DeviceMesh, ProcessGroup, tuple[Placement]): + """ + Construct Shard Mesh and Placements for unsharding. + It removes Replicate placements and constructs a new Mesh and ProcessGroup. + """ + my_rank = dist.get_rank() + + assert mesh.mesh.device.type == 'cpu' + + # Copy mesh to avoid modifying the original mesh + mesh = mesh.mesh.clone() + + # 1. Sort placements. Replicate first, then Shard by dim ascending. + + # For Shard, strided shard comes after regular shard on the same dim + # to preserve left-to-right order of replicate-to-shard. + # This is because that strided shard is using stride to represent + # more fine-grained sharding on the same dim. + # Please check the URL below for _StridedShard. + # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 + + def placement_sort_key( + placement_with_index: tuple[float, Placement] + ) -> tuple[int, float, int]: # (dim, split factor, original index) + index, placement = placement_with_index + is_replicate = placement.is_replicate() + is_shard = placement.is_shard() + is_partial = placement.is_partial() + + assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" + assert not is_partial, "Partial placement is not supported." + + if is_replicate: + return (-1.0, 0, index) + elif is_shard: + if isinstance(placement, _StridedShard): + return (placement.dim, 1 / placement.split_factor, index) + return (placement.dim, 0, index) + else: + raise TypeError(f"Unknown placement type: {type(placement)}") + + placements_with_index: list[tuple[int, + Placement]] = list(enumerate(placements)) + placements_with_index = sorted(placements_with_index, + key=placement_sort_key) + + sorted_indices, sorted_placements = zip(*placements_with_index) + + # 2. Permute mesh according to sorted placements. + sorted_mesh = mesh.permute(sorted_indices) + + # 3. Collect list of shard meshes by removing replicate dims + # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] + # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) + num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + + # merge replicate dims + # shard_meshes became a list of shard meshes with a length of replicate degree + if num_replicates > 0: + sorted_mesh = sorted_mesh.flatten( + 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) + else: + shard_meshes = [sorted_mesh] + shard_placements = sorted_placements[num_replicates:] + + # assume all shard placements are different + assert len(shard_placements) == len(set(shard_placements)) + + # 4. 
Construct ProcessGroups + # Caution: all groups should be created in the same order in all processes, + # even though each process only needs its own group. + + # To use tensor as dict key, convert it to tuple + def tensor_to_tuple(t): + if isinstance(t, torch.Tensor): + t = t.tolist() + if isinstance(t, list): + return tuple(tensor_to_tuple(x) for x in t) + return t + + my_shard_mesh_as_tuple = None + for shard_mesh in shard_meshes: + assert isinstance(shard_mesh, torch.Tensor) + shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) + + if (my_rank == shard_mesh).any().item(): + assert my_shard_mesh_as_tuple is None + my_shard_mesh_as_tuple = shard_mesh_as_tuple + + # update global cache + if shard_mesh_as_tuple not in _ranks_to_dist_cache: + shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) + _ranks_to_dist_cache[shard_mesh_as_tuple] = ( + DeviceMesh(device_type="cuda", mesh=shard_mesh), + shard_process_group, + ) + + my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ + my_shard_mesh_as_tuple] + + return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py b/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py index f66a1f41fd8b8cbdb453e48b2267ddb0450e5af9..cfbcca71741be70048bfd290c62148b2aceda631 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py @@ -1,18 +1,24 @@ import logging import math import types +from collections import defaultdict from dataclasses import dataclass -from typing import List, Optional, Union, cast +from typing import Any, cast import torch import torch.distributed as dist -from torch.distributed._tensor import DTensor, Replicate, Shard +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate +from torch.distributed.tensor.placement_types import Placement +from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor from .matmul_transpose_triton import matmul_transpose_assign logger = logging.getLogger(__name__) COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 # This code snippet is a modified version adapted from the following GitHub repositories: @@ -62,23 +68,39 @@ def _zeropower_via_newtonschulz5(G, steps): @dataclass class _muon_state: # TODO: use Optional - worker_rank: int | None = None + worker_rank: int + process_group: ProcessGroup + shard_mesh: DeviceMesh + shard_placements: tuple[Placement, ...] 
+ name: str + qk_clip_state: torch.Tensor | None = None gathered_grad: torch.Tensor | None = None scattered_u: DTensor | None = None computed_u: torch.Tensor | None = None gather_event: torch.cuda.Event | None = None compute_event: torch.cuda.Event | None = None scatter_event: torch.cuda.Event | None = None - process_group = None - qk_clip_state = None -def split_elems_for_src(param, src_rank, num_ranks) -> int: - rows = param.shape[0] - cols = int(param.numel() // rows) - base, rem = divmod(rows, num_ranks) - my_rows = base + (1 if src_rank < rem else 0) - return my_rows * cols +def numel_for_rank( + param: DTensor, + local_rank: int, + state: _muon_state, +) -> int: + slices = get_slices_of_dtensor( + param, + local_rank, + state.shard_mesh, + state.shard_placements, + ) + + numel = 1 + for s, dim in zip(slices, param.shape): + start, stop, step = s.indices(dim) + length = max(0, (stop - start + (step - 1)) // step) + numel *= length + + return numel @torch.no_grad() @@ -91,8 +113,7 @@ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): for p in params: state = param_to_state[id(p)] if rank == state.worker_rank: - num_ranks = dist.get_world_size(group=state.process_group) - state.gathered_grad = torch.empty(p.grad.numel(), + state.gathered_grad = torch.empty(p.shape, dtype=COMM_DTYPE, device="cuda") else: @@ -121,11 +142,11 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, state = param_to_state[id(p)] dst = state.worker_rank assert dst < num_ranks - shard_elems = split_elems_for_src(p, rank, num_ranks) + shard_elems = numel_for_rank(p, rank, state) g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous().view(-1) + g = g.to_local().to(COMM_DTYPE).contiguous() assert g.numel() == shard_elems - per_dst[dst].append(g) + per_dst[dst].append(g.view(-1)) send_counts[dst] += shard_elems assert any( @@ -148,13 +169,18 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, for p in owned_params: state = param_to_state[id(p)] assert state.worker_rank == rank - total += split_elems_for_src(p, src, num_ranks) + total += numel_for_rank(p, src, state) recv_counts[src] = total recv_total = sum(recv_counts) recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") #All2All + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") dist.all_to_all_single( recv_buf, send_buf, @@ -179,7 +205,6 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, comm_stream.wait_event(alloc_event) off = 0 - write_offsets = {id(p): 0 for p in owned_params} for src in range(num_ranks): if recv_counts[src] == 0: continue @@ -189,22 +214,28 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, for p in owned_params: state = param_to_state[id(p)] assert state.worker_rank == rank - n = split_elems_for_src(p, src, num_ranks) + + # get the slice of the full dtensor corresponding to rank src. 
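The copy performed after this comment is the inverse of the slicing above: each peer's flat chunk is reshaped to the shape of its slice and written into a view of the gathered buffer, so the full tensor is reassembled in place. A toy sketch of that reassembly for two ranks sharding dim 0 (the shapes and the `chunks_by_rank` dict are made up for illustration):

```python
import torch

full = torch.empty(4, 6)                      # gathered (unsharded) buffer
# Two ranks shard dim 0: rank r owns rows r*2:(r+1)*2.
chunks_by_rank = {0: torch.arange(12.0), 1: torch.arange(12.0, 24.0)}

for r, flat in chunks_by_rank.items():
    dst = full[r * 2:(r + 1) * 2, :]          # slice view into `full`
    assert dst._base is full                  # still a view, so copy_ fills `full`
    dst.copy_(flat.reshape_as(dst))           # reshape the flat payload, write in place

print(full[0, :3], full[2, :3])               # tensor([0., 1., 2.]) tensor([12., 13., 14.])
```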
+ slices = get_slices_of_dtensor(state.gathered_grad, src, + state.shard_mesh, + state.shard_placements) + + dst = state.gathered_grad[slices] + assert dst._base is state.gathered_grad + + n = dst.numel() assert n > 0 sg = recv_buf.narrow(0, off + inner_off, n) - woff = write_offsets[id(p)] - dst = state.gathered_grad.narrow(0, woff, n) + sg = sg.reshape_as(dst) dst.copy_(sg) - write_offsets[id(p)] += n inner_off += n off += block for p in params: state = param_to_state[id(p)] if state.worker_rank == rank: - state.gathered_grad = state.gathered_grad.view_as(p) state.gather_event = torch.cuda.Event() state.gather_event.record(comm_stream) else: @@ -277,14 +308,19 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): assert state.computed_u is not None - u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1) + u_full = state.computed_u.to(COMM_DTYPE).contiguous() offset = 0 for dst in range(num_ranks): - n = split_elems_for_src(p, dst, num_ranks) + # get the slice of the full tensor corresponding to rank dst. + slices = get_slices_of_dtensor(u_full, dst, + state.shard_mesh, + state.shard_placements) + su = u_full[slices].flatten() + + n = su.numel() assert n > 0 - su = u_full.narrow(0, offset, n) per_dst[dst].append(su) send_counts[dst] += n offset += n @@ -313,7 +349,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): state = param_to_state[id(p)] if state.worker_rank != src: continue - total += split_elems_for_src(p, rank, num_ranks) + total += numel_for_rank(p, rank, state) recv_counts[src] = total recv_total = sum(recv_counts) @@ -357,7 +393,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): state = param_to_state[id(p)] if state.worker_rank != src: continue - n = split_elems_for_src(p, rank, num_ranks) + n = numel_for_rank(p, rank, state) assert n > 0 flat_local = recv_buf.narrow(0, off + inner_off, @@ -398,11 +434,23 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, state.scattered_u = None u_dtensor = None - scales_full = Muon._compute_scales(p, state.qk_clip_state) + scales_full = Muon._compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None if scales_full is not None: - num_ranks = dist.get_world_size(group=state.process_group) - local_rank = dist.get_rank(group=state.process_group) - scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank] + # Have to slice scales_full among dim 0 + weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, + state.shard_placements) + ratio = p.shape[0] // scales_full.shape[0] + scales_slice = slice( + None if weight_slices[0].start is None else + weight_slices[0].start // ratio, + None if weight_slices[0].stop is None else + weight_slices[0].stop // ratio, + None, + ) + + scales_local = scales_full[scales_slice] scales_local = DTensor.from_local( scales_local, placements=p.placements, @@ -478,11 +526,11 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: List[int] # which heads to consider for clipping + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping head_dim: int # from config threshold: float # from config - logit: Optional[torch.Tensor] + logit: torch.Tensor | None class Muon(torch.optim.Optimizer): @@ -525,11 +573,16 @@ class 
Muon(torch.optim.Optimizer): "head_dim": 128, "threshold": 100 } - overlap_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher overlap_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. + warmup_step : How many all2all gather, compute operations are launched in advance + before the corresponding all2all scatter steps begin. + A higher warmup_step increases memory usage but can improve + performance by overlapping communication. + Parallel muon only. + chunk_size : Batch size of parameters to process in each + all2all gather/compute/scatter step. + Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. + use_distributed_muon: Use distributed muon by Liu et al. (2024). + For testing purpose only. """ def __init__(self, @@ -549,7 +602,9 @@ class Muon(torch.optim.Optimizer): "head_dim": 128, "threshold": 100 }, - overlap_step=5): + warmup_step=5, + chunk_size=-1, + use_distributed_muon=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -579,7 +634,9 @@ class Muon(torch.optim.Optimizer): self.compute_stream = torch.cuda.Stream() self.debug = debug self.clip_config = clip_config - self.overlap_step = overlap_step + self.warmup_step = warmup_step + self.chunk_size = chunk_size + self.use_distributed_muon = use_distributed_muon def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -597,6 +654,12 @@ class Muon(torch.optim.Optimizer): adjusted_lr = lr * adjusted_ratio return adjusted_lr + def set_rank_once(self, rank): + if self.rank is None: + self.rank = rank + else: + assert self.rank == rank + def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. @@ -604,26 +667,13 @@ class Muon(torch.optim.Optimizer): assert isinstance( p, DTensor), "Parallel Muon only supports DTensor parameters." 
- if p.placements == (Shard(dim=0), ): - # Case for FSDP - process_group = p.device_mesh.get_group(mesh_dim=0) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0) - elif p.placements == (Replicate(), Shard(dim=0)): - # Case for HSDP - process_group = p.device_mesh.get_group(mesh_dim=1) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - for i, shard_mesh in enumerate(p.device_mesh.mesh): - if self.rank in shard_mesh: - return shard_mesh, p.device_mesh.get_group(mesh_dim=1) - else: - raise ValueError(f"Unsupported placements ({p.placements}).") + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + p.placements, p.device_mesh) + + # set rank with the local rank in the shard process group + self.set_rank_once(dist.get_rank(group=shard_pg)) + + return shard_mesh, shard_pg, shard_placements def init_state_and_assign_params(self, names, params, group, qk_logits): param_to_state = {} @@ -655,23 +705,32 @@ class Muon(torch.optim.Optimizer): ordered_params = list(params_sorted) round_robin = 0 - mesh = None - shard_mesh = None - process_group = None + mesh = ordered_params[0].device_mesh + placements = ordered_params[0].placements + + shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( + ordered_params[0]) + shard_mesh_flattened = shard_mesh.mesh.flatten() + num_ranks = dist.get_world_size(group=shard_pg) + for n, p in zip(ordered_names, ordered_params): - if mesh is None: - mesh = p.device_mesh - shard_mesh, process_group = self.get_shard_mesh(p) - elif mesh != p.device_mesh: + if mesh != p.device_mesh: raise ValueError("All parameters must be on the same mesh.") - num_ranks = dist.get_world_size(group=process_group) - param_to_state[id(p)] = _muon_state() - param_to_state[id( - p)].worker_rank = shard_mesh[round_robin].item() % num_ranks - param_to_state[id(p)].process_group = process_group + if placements != p.placements: + raise ValueError("All parameters must have same placements.") + + worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks + round_robin = (round_robin + 1) % len(shard_mesh_flattened) qk_clip_state = self.get_qk_clip_info(n, qk_logits) - param_to_state[id(p)].qk_clip_state = qk_clip_state - round_robin = (round_robin + 1) % len(shard_mesh) + + param_to_state[id(p)] = _muon_state( + worker_rank=worker_rank, + process_group=shard_pg, + shard_mesh=shard_mesh, + shard_placements=shard_placements, + name=n, + qk_clip_state=qk_clip_state, + ) return param_to_state, ordered_params @@ -705,10 +764,73 @@ class Muon(torch.optim.Optimizer): qk_clip_state = self.get_qk_clip_info(n, qk_logits) - scales_full = self._compute_scales(p, qk_clip_state) + scales_full = self._compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + def distributed_muon( + self, + names: list[str], + params: list[torch.nn.Parameter], + group: dict[str, Any], + lr: float, + weight_decay: float, + momentum: float, + qk_logits: list[torch.Tensor | DTensor] | None, + ): + """ Implementation of Distributed Muon by Liu et al. 
""" + if qk_logits is not None: + raise NotImplementedError("QK clipping is not supported yet") + + if isinstance(params[0], DTensor): + shard_mesh, _, shard_placements = construct_shard_mesh( + placements=params[0].placements, + mesh=params[0].device_mesh, + ) + + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + # Gather G + if isinstance(p.data, DTensor): + g = g.full_tensor() + u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) + + if isinstance(p.data, DTensor): + slices = get_slices_of_dtensor( + target=p, + local_rank=dist.get_rank(), + shard_mesh=shard_mesh, + shard_placements=shard_placements, + ) + u_shard = u[slices] + u = DTensor.from_local( + u_shard, + device_mesh=p.device_mesh, + placements=p.placements, + ) + + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + def _update_g(self, p, g, group, momentum): # calc update state = self.state[p] @@ -727,6 +849,9 @@ class Muon(torch.optim.Optimizer): p.data.add_(u, alpha=-adjusted_lr) def get_qk_clip_info(self, n, qk_logits): + if self.clip_config is None: + return None + head_dim = self.clip_config.get('head_dim') threshold = self.clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) @@ -737,6 +862,11 @@ class Muon(torch.optim.Optimizer): indices_key = 'q_indices' if 'q' in kind else 'k_indices' indices = self.clip_config.get(indices_key, []) or [] + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + return QKClipInfo( kind=kind, indices=indices, @@ -835,22 +965,28 @@ class Muon(torch.optim.Optimizer): _update_param(p, state, lr, adjusted_lr, weight_decay, self.rank, self.compute_stream) - chunk_size = dist.get_world_size(param_to_state[id( - params[0])].process_group) + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(param_to_state[id( + params[0])].process_group) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError("chunk_size must be -1 or a positive integer.") # Wait grad update self.comm_stream.wait_stream(torch.cuda.current_stream()) - overlap_step = self.overlap_step - for i in range(0, overlap_step): + warmup_step = self.warmup_step + for i in range(0, warmup_step): enqueue_all2all_gather(i * chunk_size, chunk_size) enqueue_computes(i * chunk_size, chunk_size) for i in range(0, len(params) + chunk_size - 1, chunk_size): enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size) + enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) enqueue_update_param(i, chunk_size) - enqueue_computes(i + overlap_step * chunk_size, chunk_size) + enqueue_computes(i + warmup_step * chunk_size, chunk_size) # Wait the last update_param to finish torch.cuda.current_stream().wait_stream(self.compute_stream) @@ -866,7 +1002,7 @@ class Muon(torch.optim.Optimizer): amsgrad: bool, beta1: float, beta2: float, - lr: Union[float, torch.Tensor], + lr: float | torch.Tensor, weight_decay: float, eps: float, maximize: bool, @@ -876,10 +1012,10 
@@ class Muon(torch.optim.Optimizer): # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer # treating it as a scalar. - lr_dict: Optional[DeviceDict] = ({ + lr_dict: DeviceDict | None = ({ lr.device: lr } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) + None) grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( [ params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, @@ -926,6 +1062,159 @@ class Muon(torch.optim.Optimizer): maximize=maximize, ) + def _step_muon(self, group, qk_logits=None): + params = group["params"] + lr = group["lr"] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + names = group["names"] + + param_dtensors = [] + param_tensors = [] + name_dtensors = [] + name_tensors = [] + + if self.use_distributed_muon: + self.distributed_muon(names=names, + params=params, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits) + return + + for n, p in zip(names, params): + if p is None or p.grad is None: + continue + if isinstance(p.data, DTensor): + if all( + isinstance(placement, Replicate) + for placement in p.placements): + param_tensors.append(p) + name_tensors.append(n) + else: + param_dtensors.append(p) + name_dtensors.append(n) + elif isinstance(p.data, torch.Tensor): + param_tensors.append(p) + name_tensors.append(n) + else: + raise TypeError(f"Unsupported parameter type: {type(p.data)}") + + logger.debug( + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors" + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + # To support different placements, we group parameters by placements + # and run parallel Muon on each group. 
+ + placement_to_params = defaultdict(lambda: ([], [])) + # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] + + assert len(name_dtensors) == len(param_dtensors) + for n, p in zip(name_dtensors, param_dtensors): + placement_to_params[tuple([p.placements, + p.device_mesh])][0].append(n) + placement_to_params[tuple([p.placements, + p.device_mesh])][1].append(p) + + for _, (names, params) in placement_to_params.items(): + self.parallel( + names, + params, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_tensors) > 0: + self.base( + name_tensors, + param_tensors, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + def _step_adamw_params(self, params, group): + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + self._fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def _step_adamw(self, group): + params = group["params"] + + # group params with it's type and placement + placement_to_params: dict[tuple[Placement | type, + DeviceMesh | None]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for params in placement_to_params.values(): + self._step_adamw_params(params, group) + def step(self, closure=None, qk_logits=None): """Perform a single optimization step. 
@@ -943,127 +1232,9 @@ class Muon(torch.optim.Optimizer): loss = closure() for group in self.param_groups: - params = group["params"] - if group["use_muon"]: - ############################ - # Muon # - ############################ - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - param_tensors = [] - name_dtensors = [] - name_tensors = [] - - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError( - f"Unsupported parameter type: {type(p.data)}") - - if self.debug: - print( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors", - flush=True, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.parallel( - name_dtensors, - param_dtensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - + self._step_muon(group, qk_logits=qk_logits) else: - ############################ - # AdamW backup # - ############################ - - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) + self._step_adamw(group) return loss diff --git a/build/torch29-cxx11-cu126-x86_64-linux/optimizer/_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/optimizer/_ops.py index 0e7e761d98b3dd32cb29d88cfffcf96d574af438..7d598206add1bca142661a3df6c510e3d9575d54 100644 --- a/build/torch29-cxx11-cu126-x86_64-linux/optimizer/_ops.py +++ b/build/torch29-cxx11-cu126-x86_64-linux/optimizer/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_811726c_dirty -ops = torch.ops._optimizer_811726c_dirty +from . import _optimizer_23d68bb_dirty +ops = torch.ops._optimizer_23d68bb_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_811726c_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_23d68bb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..d5bbae6d37395b7e65f64cabcca135df1faac8b3 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ca847c77875fc19f211a4c8ac217e9664b46c6862aa3234c270aacfea519d0f5 +size 1924376 diff --git a/build/torch29-cxx11-cu126-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so deleted file mode 100755 index e5bd4712fd89d9725a902d1fda8cbed459f28873..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu126-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:87c8e75ead1c831dabfce1abbd7c100aa72c9b2988dfc0e1554216ca8005267c -size 1816064 diff --git a/build/torch29-cxx11-cu126-x86_64-linux/optimizer/distributed/utils.py b/build/torch29-cxx11-cu126-x86_64-linux/optimizer/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0b4b58bfb329b1c015129e4c4fc99f7bfa2ab30a --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/optimizer/distributed/utils.py @@ -0,0 +1,174 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import (Placement, Shard, + _StridedShard) + + +def get_slices_of_dtensor( + target: DTensor | torch.Tensor, + local_rank: int, + shard_mesh: DeviceMesh, + shard_placements: tuple[Placement], +) -> tuple[slice]: + """ + Get the slice of local tensor for a given rank from a tensor. + Args: + target (DTensor | torch.Tensor): The target tensor. + rank (int): The local rank of the shard group. + shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + shard_placements (tuple[Placement]): The shard placements. + """ + + slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + + # find the global rank of the local rank in the shard mesh + rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] + + rank_coords = (shard_mesh.mesh == rank).nonzero() + + assert len(rank_coords) == 1 + rank_coords = tuple(rank_coords[0].tolist()) + + assert len(rank_coords) == len(shard_placements) + + # Caution: Assuming replicate-to-shard of the shard mesh goes with + # left-to-right sharding. This is ensured by the sorting logic of + # construct_shard_mesh function. 
+ for i, (rank_coord, + placement) in enumerate(zip(rank_coords, shard_placements)): + assert isinstance(placement, Shard) + + num_ranks = shard_mesh.mesh.shape[i] + + dim = placement.dim + dim_size = (slices[dim].stop - slices[dim].start) + + if dim_size % num_ranks != 0: + raise NotImplementedError( + f"Dimension size {dim_size} is not divisible " + f"by number of ranks {num_ranks} for shard " + f"placement on dim {dim}.") + + shard_size = dim_size // num_ranks + + start = slices[dim].start + rank_coord * shard_size + end = start + shard_size + + assert start < end <= slices[dim].stop + + slices[dim] = slice(start, end) + + return tuple(slices) + + +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict() + + +def construct_shard_mesh( + placements: tuple[Placement], + mesh: DeviceMesh, +) -> (DeviceMesh, ProcessGroup, tuple[Placement]): + """ + Construct Shard Mesh and Placements for unsharding. + It removes Replicate placements and constructs a new Mesh and ProcessGroup. + """ + my_rank = dist.get_rank() + + assert mesh.mesh.device.type == 'cpu' + + # Copy mesh to avoid modifying the original mesh + mesh = mesh.mesh.clone() + + # 1. Sort placements. Replicate first, then Shard by dim ascending. + + # For Shard, strided shard comes after regular shard on the same dim + # to preserve left-to-right order of replicate-to-shard. + # This is because that strided shard is using stride to represent + # more fine-grained sharding on the same dim. + # Please check the URL below for _StridedShard. + # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 + + def placement_sort_key( + placement_with_index: tuple[float, Placement] + ) -> tuple[int, float, int]: # (dim, split factor, original index) + index, placement = placement_with_index + is_replicate = placement.is_replicate() + is_shard = placement.is_shard() + is_partial = placement.is_partial() + + assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" + assert not is_partial, "Partial placement is not supported." + + if is_replicate: + return (-1.0, 0, index) + elif is_shard: + if isinstance(placement, _StridedShard): + return (placement.dim, 1 / placement.split_factor, index) + return (placement.dim, 0, index) + else: + raise TypeError(f"Unknown placement type: {type(placement)}") + + placements_with_index: list[tuple[int, + Placement]] = list(enumerate(placements)) + placements_with_index = sorted(placements_with_index, + key=placement_sort_key) + + sorted_indices, sorted_placements = zip(*placements_with_index) + + # 2. Permute mesh according to sorted placements. + sorted_mesh = mesh.permute(sorted_indices) + + # 3. Collect list of shard meshes by removing replicate dims + # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] + # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) + num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + + # merge replicate dims + # shard_meshes became a list of shard meshes with a length of replicate degree + if num_replicates > 0: + sorted_mesh = sorted_mesh.flatten( + 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) + else: + shard_meshes = [sorted_mesh] + shard_placements = sorted_placements[num_replicates:] + + # assume all shard placements are different + assert len(shard_placements) == len(set(shard_placements)) + + # 4. 
Construct ProcessGroups + # Caution: all groups should be created in the same order in all processes, + # even though each process only needs its own group. + + # To use tensor as dict key, convert it to tuple + def tensor_to_tuple(t): + if isinstance(t, torch.Tensor): + t = t.tolist() + if isinstance(t, list): + return tuple(tensor_to_tuple(x) for x in t) + return t + + my_shard_mesh_as_tuple = None + for shard_mesh in shard_meshes: + assert isinstance(shard_mesh, torch.Tensor) + shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) + + if (my_rank == shard_mesh).any().item(): + assert my_shard_mesh_as_tuple is None + my_shard_mesh_as_tuple = shard_mesh_as_tuple + + # update global cache + if shard_mesh_as_tuple not in _ranks_to_dist_cache: + shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) + _ranks_to_dist_cache[shard_mesh_as_tuple] = ( + DeviceMesh(device_type="cuda", mesh=shard_mesh), + shard_process_group, + ) + + my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ + my_shard_mesh_as_tuple] + + return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch29-cxx11-cu126-x86_64-linux/optimizer/muon.py b/build/torch29-cxx11-cu126-x86_64-linux/optimizer/muon.py index f66a1f41fd8b8cbdb453e48b2267ddb0450e5af9..cfbcca71741be70048bfd290c62148b2aceda631 100644 --- a/build/torch29-cxx11-cu126-x86_64-linux/optimizer/muon.py +++ b/build/torch29-cxx11-cu126-x86_64-linux/optimizer/muon.py @@ -1,18 +1,24 @@ import logging import math import types +from collections import defaultdict from dataclasses import dataclass -from typing import List, Optional, Union, cast +from typing import Any, cast import torch import torch.distributed as dist -from torch.distributed._tensor import DTensor, Replicate, Shard +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate +from torch.distributed.tensor.placement_types import Placement +from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor from .matmul_transpose_triton import matmul_transpose_assign logger = logging.getLogger(__name__) COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 # This code snippet is a modified version adapted from the following GitHub repositories: @@ -62,23 +68,39 @@ def _zeropower_via_newtonschulz5(G, steps): @dataclass class _muon_state: # TODO: use Optional - worker_rank: int | None = None + worker_rank: int + process_group: ProcessGroup + shard_mesh: DeviceMesh + shard_placements: tuple[Placement, ...] 
+ name: str + qk_clip_state: torch.Tensor | None = None gathered_grad: torch.Tensor | None = None scattered_u: DTensor | None = None computed_u: torch.Tensor | None = None gather_event: torch.cuda.Event | None = None compute_event: torch.cuda.Event | None = None scatter_event: torch.cuda.Event | None = None - process_group = None - qk_clip_state = None -def split_elems_for_src(param, src_rank, num_ranks) -> int: - rows = param.shape[0] - cols = int(param.numel() // rows) - base, rem = divmod(rows, num_ranks) - my_rows = base + (1 if src_rank < rem else 0) - return my_rows * cols +def numel_for_rank( + param: DTensor, + local_rank: int, + state: _muon_state, +) -> int: + slices = get_slices_of_dtensor( + param, + local_rank, + state.shard_mesh, + state.shard_placements, + ) + + numel = 1 + for s, dim in zip(slices, param.shape): + start, stop, step = s.indices(dim) + length = max(0, (stop - start + (step - 1)) // step) + numel *= length + + return numel @torch.no_grad() @@ -91,8 +113,7 @@ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): for p in params: state = param_to_state[id(p)] if rank == state.worker_rank: - num_ranks = dist.get_world_size(group=state.process_group) - state.gathered_grad = torch.empty(p.grad.numel(), + state.gathered_grad = torch.empty(p.shape, dtype=COMM_DTYPE, device="cuda") else: @@ -121,11 +142,11 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, state = param_to_state[id(p)] dst = state.worker_rank assert dst < num_ranks - shard_elems = split_elems_for_src(p, rank, num_ranks) + shard_elems = numel_for_rank(p, rank, state) g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous().view(-1) + g = g.to_local().to(COMM_DTYPE).contiguous() assert g.numel() == shard_elems - per_dst[dst].append(g) + per_dst[dst].append(g.view(-1)) send_counts[dst] += shard_elems assert any( @@ -148,13 +169,18 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, for p in owned_params: state = param_to_state[id(p)] assert state.worker_rank == rank - total += split_elems_for_src(p, src, num_ranks) + total += numel_for_rank(p, src, state) recv_counts[src] = total recv_total = sum(recv_counts) recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") #All2All + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") dist.all_to_all_single( recv_buf, send_buf, @@ -179,7 +205,6 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, comm_stream.wait_event(alloc_event) off = 0 - write_offsets = {id(p): 0 for p in owned_params} for src in range(num_ranks): if recv_counts[src] == 0: continue @@ -189,22 +214,28 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, for p in owned_params: state = param_to_state[id(p)] assert state.worker_rank == rank - n = split_elems_for_src(p, src, num_ranks) + + # get the slice of the full dtensor corresponding to rank src. 
+ slices = get_slices_of_dtensor(state.gathered_grad, src, + state.shard_mesh, + state.shard_placements) + + dst = state.gathered_grad[slices] + assert dst._base is state.gathered_grad + + n = dst.numel() assert n > 0 sg = recv_buf.narrow(0, off + inner_off, n) - woff = write_offsets[id(p)] - dst = state.gathered_grad.narrow(0, woff, n) + sg = sg.reshape_as(dst) dst.copy_(sg) - write_offsets[id(p)] += n inner_off += n off += block for p in params: state = param_to_state[id(p)] if state.worker_rank == rank: - state.gathered_grad = state.gathered_grad.view_as(p) state.gather_event = torch.cuda.Event() state.gather_event.record(comm_stream) else: @@ -277,14 +308,19 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): assert state.computed_u is not None - u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1) + u_full = state.computed_u.to(COMM_DTYPE).contiguous() offset = 0 for dst in range(num_ranks): - n = split_elems_for_src(p, dst, num_ranks) + # get the slice of the full tensor corresponding to rank dst. + slices = get_slices_of_dtensor(u_full, dst, + state.shard_mesh, + state.shard_placements) + su = u_full[slices].flatten() + + n = su.numel() assert n > 0 - su = u_full.narrow(0, offset, n) per_dst[dst].append(su) send_counts[dst] += n offset += n @@ -313,7 +349,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): state = param_to_state[id(p)] if state.worker_rank != src: continue - total += split_elems_for_src(p, rank, num_ranks) + total += numel_for_rank(p, rank, state) recv_counts[src] = total recv_total = sum(recv_counts) @@ -357,7 +393,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): state = param_to_state[id(p)] if state.worker_rank != src: continue - n = split_elems_for_src(p, rank, num_ranks) + n = numel_for_rank(p, rank, state) assert n > 0 flat_local = recv_buf.narrow(0, off + inner_off, @@ -398,11 +434,23 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, state.scattered_u = None u_dtensor = None - scales_full = Muon._compute_scales(p, state.qk_clip_state) + scales_full = Muon._compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None if scales_full is not None: - num_ranks = dist.get_world_size(group=state.process_group) - local_rank = dist.get_rank(group=state.process_group) - scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank] + # Have to slice scales_full among dim 0 + weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, + state.shard_placements) + ratio = p.shape[0] // scales_full.shape[0] + scales_slice = slice( + None if weight_slices[0].start is None else + weight_slices[0].start // ratio, + None if weight_slices[0].stop is None else + weight_slices[0].stop // ratio, + None, + ) + + scales_local = scales_full[scales_slice] scales_local = DTensor.from_local( scales_local, placements=p.placements, @@ -478,11 +526,11 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: List[int] # which heads to consider for clipping + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping head_dim: int # from config threshold: float # from config - logit: Optional[torch.Tensor] + logit: torch.Tensor | None class Muon(torch.optim.Optimizer): @@ -525,11 +573,16 @@ class 
Muon(torch.optim.Optimizer): "head_dim": 128, "threshold": 100 } - overlap_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher overlap_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. + warmup_step : How many all2all gather, compute operations are launched in advance + before the corresponding all2all scatter steps begin. + A higher warmup_step increases memory usage but can improve + performance by overlapping communication. + Parallel muon only. + chunk_size : Batch size of parameters to process in each + all2all gather/compute/scatter step. + Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. + use_distributed_muon: Use distributed muon by Liu et al. (2024). + For testing purpose only. """ def __init__(self, @@ -549,7 +602,9 @@ class Muon(torch.optim.Optimizer): "head_dim": 128, "threshold": 100 }, - overlap_step=5): + warmup_step=5, + chunk_size=-1, + use_distributed_muon=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -579,7 +634,9 @@ class Muon(torch.optim.Optimizer): self.compute_stream = torch.cuda.Stream() self.debug = debug self.clip_config = clip_config - self.overlap_step = overlap_step + self.warmup_step = warmup_step + self.chunk_size = chunk_size + self.use_distributed_muon = use_distributed_muon def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -597,6 +654,12 @@ class Muon(torch.optim.Optimizer): adjusted_lr = lr * adjusted_ratio return adjusted_lr + def set_rank_once(self, rank): + if self.rank is None: + self.rank = rank + else: + assert self.rank == rank + def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. @@ -604,26 +667,13 @@ class Muon(torch.optim.Optimizer): assert isinstance( p, DTensor), "Parallel Muon only supports DTensor parameters." 
- if p.placements == (Shard(dim=0), ): - # Case for FSDP - process_group = p.device_mesh.get_group(mesh_dim=0) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0) - elif p.placements == (Replicate(), Shard(dim=0)): - # Case for HSDP - process_group = p.device_mesh.get_group(mesh_dim=1) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - for i, shard_mesh in enumerate(p.device_mesh.mesh): - if self.rank in shard_mesh: - return shard_mesh, p.device_mesh.get_group(mesh_dim=1) - else: - raise ValueError(f"Unsupported placements ({p.placements}).") + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + p.placements, p.device_mesh) + + # set rank with the local rank in the shard process group + self.set_rank_once(dist.get_rank(group=shard_pg)) + + return shard_mesh, shard_pg, shard_placements def init_state_and_assign_params(self, names, params, group, qk_logits): param_to_state = {} @@ -655,23 +705,32 @@ class Muon(torch.optim.Optimizer): ordered_params = list(params_sorted) round_robin = 0 - mesh = None - shard_mesh = None - process_group = None + mesh = ordered_params[0].device_mesh + placements = ordered_params[0].placements + + shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( + ordered_params[0]) + shard_mesh_flattened = shard_mesh.mesh.flatten() + num_ranks = dist.get_world_size(group=shard_pg) + for n, p in zip(ordered_names, ordered_params): - if mesh is None: - mesh = p.device_mesh - shard_mesh, process_group = self.get_shard_mesh(p) - elif mesh != p.device_mesh: + if mesh != p.device_mesh: raise ValueError("All parameters must be on the same mesh.") - num_ranks = dist.get_world_size(group=process_group) - param_to_state[id(p)] = _muon_state() - param_to_state[id( - p)].worker_rank = shard_mesh[round_robin].item() % num_ranks - param_to_state[id(p)].process_group = process_group + if placements != p.placements: + raise ValueError("All parameters must have same placements.") + + worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks + round_robin = (round_robin + 1) % len(shard_mesh_flattened) qk_clip_state = self.get_qk_clip_info(n, qk_logits) - param_to_state[id(p)].qk_clip_state = qk_clip_state - round_robin = (round_robin + 1) % len(shard_mesh) + + param_to_state[id(p)] = _muon_state( + worker_rank=worker_rank, + process_group=shard_pg, + shard_mesh=shard_mesh, + shard_placements=shard_placements, + name=n, + qk_clip_state=qk_clip_state, + ) return param_to_state, ordered_params @@ -705,10 +764,73 @@ class Muon(torch.optim.Optimizer): qk_clip_state = self.get_qk_clip_info(n, qk_logits) - scales_full = self._compute_scales(p, qk_clip_state) + scales_full = self._compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + def distributed_muon( + self, + names: list[str], + params: list[torch.nn.Parameter], + group: dict[str, Any], + lr: float, + weight_decay: float, + momentum: float, + qk_logits: list[torch.Tensor | DTensor] | None, + ): + """ Implementation of Distributed Muon by Liu et al. 
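+        Each rank gathers the full gradient, runs Newton-Schulz redundantly,
+        and keeps only its local shard of the resulting update.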
""" + if qk_logits is not None: + raise NotImplementedError("QK clipping is not supported yet") + + if isinstance(params[0], DTensor): + shard_mesh, _, shard_placements = construct_shard_mesh( + placements=params[0].placements, + mesh=params[0].device_mesh, + ) + + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + # Gather G + if isinstance(p.data, DTensor): + g = g.full_tensor() + u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) + + if isinstance(p.data, DTensor): + slices = get_slices_of_dtensor( + target=p, + local_rank=dist.get_rank(), + shard_mesh=shard_mesh, + shard_placements=shard_placements, + ) + u_shard = u[slices] + u = DTensor.from_local( + u_shard, + device_mesh=p.device_mesh, + placements=p.placements, + ) + + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + def _update_g(self, p, g, group, momentum): # calc update state = self.state[p] @@ -727,6 +849,9 @@ class Muon(torch.optim.Optimizer): p.data.add_(u, alpha=-adjusted_lr) def get_qk_clip_info(self, n, qk_logits): + if self.clip_config is None: + return None + head_dim = self.clip_config.get('head_dim') threshold = self.clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) @@ -737,6 +862,11 @@ class Muon(torch.optim.Optimizer): indices_key = 'q_indices' if 'q' in kind else 'k_indices' indices = self.clip_config.get(indices_key, []) or [] + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + return QKClipInfo( kind=kind, indices=indices, @@ -835,22 +965,28 @@ class Muon(torch.optim.Optimizer): _update_param(p, state, lr, adjusted_lr, weight_decay, self.rank, self.compute_stream) - chunk_size = dist.get_world_size(param_to_state[id( - params[0])].process_group) + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(param_to_state[id( + params[0])].process_group) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError("chunk_size must be -1 or a positive integer.") # Wait grad update self.comm_stream.wait_stream(torch.cuda.current_stream()) - overlap_step = self.overlap_step - for i in range(0, overlap_step): + warmup_step = self.warmup_step + for i in range(0, warmup_step): enqueue_all2all_gather(i * chunk_size, chunk_size) enqueue_computes(i * chunk_size, chunk_size) for i in range(0, len(params) + chunk_size - 1, chunk_size): enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size) + enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) enqueue_update_param(i, chunk_size) - enqueue_computes(i + overlap_step * chunk_size, chunk_size) + enqueue_computes(i + warmup_step * chunk_size, chunk_size) # Wait the last update_param to finish torch.cuda.current_stream().wait_stream(self.compute_stream) @@ -866,7 +1002,7 @@ class Muon(torch.optim.Optimizer): amsgrad: bool, beta1: float, beta2: float, - lr: Union[float, torch.Tensor], + lr: float | torch.Tensor, weight_decay: float, eps: float, maximize: bool, @@ -876,10 +1012,10 
@@ class Muon(torch.optim.Optimizer): # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer # treating it as a scalar. - lr_dict: Optional[DeviceDict] = ({ + lr_dict: DeviceDict | None = ({ lr.device: lr } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) + None) grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( [ params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, @@ -926,6 +1062,159 @@ class Muon(torch.optim.Optimizer): maximize=maximize, ) + def _step_muon(self, group, qk_logits=None): + params = group["params"] + lr = group["lr"] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + names = group["names"] + + param_dtensors = [] + param_tensors = [] + name_dtensors = [] + name_tensors = [] + + if self.use_distributed_muon: + self.distributed_muon(names=names, + params=params, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits) + return + + for n, p in zip(names, params): + if p is None or p.grad is None: + continue + if isinstance(p.data, DTensor): + if all( + isinstance(placement, Replicate) + for placement in p.placements): + param_tensors.append(p) + name_tensors.append(n) + else: + param_dtensors.append(p) + name_dtensors.append(n) + elif isinstance(p.data, torch.Tensor): + param_tensors.append(p) + name_tensors.append(n) + else: + raise TypeError(f"Unsupported parameter type: {type(p.data)}") + + logger.debug( + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors" + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + # To support different placements, we group parameters by placements + # and run parallel Muon on each group. 
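+            # Each (placements, device_mesh) bucket gets its own shard mesh and
+            # process group, so mixed sharding layouts can coexist in one step.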
+ + placement_to_params = defaultdict(lambda: ([], [])) + # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] + + assert len(name_dtensors) == len(param_dtensors) + for n, p in zip(name_dtensors, param_dtensors): + placement_to_params[tuple([p.placements, + p.device_mesh])][0].append(n) + placement_to_params[tuple([p.placements, + p.device_mesh])][1].append(p) + + for _, (names, params) in placement_to_params.items(): + self.parallel( + names, + params, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_tensors) > 0: + self.base( + name_tensors, + param_tensors, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + def _step_adamw_params(self, params, group): + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + self._fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def _step_adamw(self, group): + params = group["params"] + + # group params with it's type and placement + placement_to_params: dict[tuple[Placement | type, + DeviceMesh | None]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for params in placement_to_params.values(): + self._step_adamw_params(params, group) + def step(self, closure=None, qk_logits=None): """Perform a single optimization step. 
@@ -943,127 +1232,9 @@ class Muon(torch.optim.Optimizer): loss = closure() for group in self.param_groups: - params = group["params"] - if group["use_muon"]: - ############################ - # Muon # - ############################ - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - param_tensors = [] - name_dtensors = [] - name_tensors = [] - - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError( - f"Unsupported parameter type: {type(p.data)}") - - if self.debug: - print( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors", - flush=True, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.parallel( - name_dtensors, - param_dtensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - + self._step_muon(group, qk_logits=qk_logits) else: - ############################ - # AdamW backup # - ############################ - - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) + self._step_adamw(group) return loss diff --git a/build/torch29-cxx11-cu128-x86_64-linux/optimizer/_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/optimizer/_ops.py index 0e7e761d98b3dd32cb29d88cfffcf96d574af438..7d598206add1bca142661a3df6c510e3d9575d54 100644 --- a/build/torch29-cxx11-cu128-x86_64-linux/optimizer/_ops.py +++ b/build/torch29-cxx11-cu128-x86_64-linux/optimizer/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_811726c_dirty -ops = torch.ops._optimizer_811726c_dirty +from . import _optimizer_23d68bb_dirty +ops = torch.ops._optimizer_23d68bb_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_811726c_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_23d68bb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..3fd487db1ff3b2ee3b5ab65ea2272e7fe95e5c76 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fc97ff00a3255d5eb363958b1e619eadbc4315f1930d0fb59cfc9560c3951721 +size 1983488 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so deleted file mode 100755 index bc35be0a7df2a86ce8f85d1f605fc8f74d16ced1..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu128-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ab1875be65811d88c407f36077aced58056a4feeb9946d7cd40ec55c7e1025c8 -size 1871056 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/optimizer/distributed/utils.py b/build/torch29-cxx11-cu128-x86_64-linux/optimizer/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0b4b58bfb329b1c015129e4c4fc99f7bfa2ab30a --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/optimizer/distributed/utils.py @@ -0,0 +1,174 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import (Placement, Shard, + _StridedShard) + + +def get_slices_of_dtensor( + target: DTensor | torch.Tensor, + local_rank: int, + shard_mesh: DeviceMesh, + shard_placements: tuple[Placement], +) -> tuple[slice]: + """ + Get the slice of local tensor for a given rank from a tensor. + Args: + target (DTensor | torch.Tensor): The target tensor. + rank (int): The local rank of the shard group. + shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + shard_placements (tuple[Placement]): The shard placements. + """ + + slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + + # find the global rank of the local rank in the shard mesh + rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] + + rank_coords = (shard_mesh.mesh == rank).nonzero() + + assert len(rank_coords) == 1 + rank_coords = tuple(rank_coords[0].tolist()) + + assert len(rank_coords) == len(shard_placements) + + # Caution: Assuming replicate-to-shard of the shard mesh goes with + # left-to-right sharding. This is ensured by the sorting logic of + # construct_shard_mesh function. 
+ for i, (rank_coord, + placement) in enumerate(zip(rank_coords, shard_placements)): + assert isinstance(placement, Shard) + + num_ranks = shard_mesh.mesh.shape[i] + + dim = placement.dim + dim_size = (slices[dim].stop - slices[dim].start) + + if dim_size % num_ranks != 0: + raise NotImplementedError( + f"Dimension size {dim_size} is not divisible " + f"by number of ranks {num_ranks} for shard " + f"placement on dim {dim}.") + + shard_size = dim_size // num_ranks + + start = slices[dim].start + rank_coord * shard_size + end = start + shard_size + + assert start < end <= slices[dim].stop + + slices[dim] = slice(start, end) + + return tuple(slices) + + +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict() + + +def construct_shard_mesh( + placements: tuple[Placement], + mesh: DeviceMesh, +) -> (DeviceMesh, ProcessGroup, tuple[Placement]): + """ + Construct Shard Mesh and Placements for unsharding. + It removes Replicate placements and constructs a new Mesh and ProcessGroup. + """ + my_rank = dist.get_rank() + + assert mesh.mesh.device.type == 'cpu' + + # Copy mesh to avoid modifying the original mesh + mesh = mesh.mesh.clone() + + # 1. Sort placements. Replicate first, then Shard by dim ascending. + + # For Shard, strided shard comes after regular shard on the same dim + # to preserve left-to-right order of replicate-to-shard. + # This is because that strided shard is using stride to represent + # more fine-grained sharding on the same dim. + # Please check the URL below for _StridedShard. + # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 + + def placement_sort_key( + placement_with_index: tuple[float, Placement] + ) -> tuple[int, float, int]: # (dim, split factor, original index) + index, placement = placement_with_index + is_replicate = placement.is_replicate() + is_shard = placement.is_shard() + is_partial = placement.is_partial() + + assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" + assert not is_partial, "Partial placement is not supported." + + if is_replicate: + return (-1.0, 0, index) + elif is_shard: + if isinstance(placement, _StridedShard): + return (placement.dim, 1 / placement.split_factor, index) + return (placement.dim, 0, index) + else: + raise TypeError(f"Unknown placement type: {type(placement)}") + + placements_with_index: list[tuple[int, + Placement]] = list(enumerate(placements)) + placements_with_index = sorted(placements_with_index, + key=placement_sort_key) + + sorted_indices, sorted_placements = zip(*placements_with_index) + + # 2. Permute mesh according to sorted placements. + sorted_mesh = mesh.permute(sorted_indices) + + # 3. Collect list of shard meshes by removing replicate dims + # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] + # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) + num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + + # merge replicate dims + # shard_meshes became a list of shard meshes with a length of replicate degree + if num_replicates > 0: + sorted_mesh = sorted_mesh.flatten( + 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) + else: + shard_meshes = [sorted_mesh] + shard_placements = sorted_placements[num_replicates:] + + # assume all shard placements are different + assert len(shard_placements) == len(set(shard_placements)) + + # 4. 
Construct ProcessGroups + # Caution: all groups should be created in the same order in all processes, + # even though each process only needs its own group. + + # To use tensor as dict key, convert it to tuple + def tensor_to_tuple(t): + if isinstance(t, torch.Tensor): + t = t.tolist() + if isinstance(t, list): + return tuple(tensor_to_tuple(x) for x in t) + return t + + my_shard_mesh_as_tuple = None + for shard_mesh in shard_meshes: + assert isinstance(shard_mesh, torch.Tensor) + shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) + + if (my_rank == shard_mesh).any().item(): + assert my_shard_mesh_as_tuple is None + my_shard_mesh_as_tuple = shard_mesh_as_tuple + + # update global cache + if shard_mesh_as_tuple not in _ranks_to_dist_cache: + shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) + _ranks_to_dist_cache[shard_mesh_as_tuple] = ( + DeviceMesh(device_type="cuda", mesh=shard_mesh), + shard_process_group, + ) + + my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ + my_shard_mesh_as_tuple] + + return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch29-cxx11-cu128-x86_64-linux/optimizer/muon.py b/build/torch29-cxx11-cu128-x86_64-linux/optimizer/muon.py index f66a1f41fd8b8cbdb453e48b2267ddb0450e5af9..cfbcca71741be70048bfd290c62148b2aceda631 100644 --- a/build/torch29-cxx11-cu128-x86_64-linux/optimizer/muon.py +++ b/build/torch29-cxx11-cu128-x86_64-linux/optimizer/muon.py @@ -1,18 +1,24 @@ import logging import math import types +from collections import defaultdict from dataclasses import dataclass -from typing import List, Optional, Union, cast +from typing import Any, cast import torch import torch.distributed as dist -from torch.distributed._tensor import DTensor, Replicate, Shard +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate +from torch.distributed.tensor.placement_types import Placement +from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor from .matmul_transpose_triton import matmul_transpose_assign logger = logging.getLogger(__name__) COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 # This code snippet is a modified version adapted from the following GitHub repositories: @@ -62,23 +68,39 @@ def _zeropower_via_newtonschulz5(G, steps): @dataclass class _muon_state: # TODO: use Optional - worker_rank: int | None = None + worker_rank: int + process_group: ProcessGroup + shard_mesh: DeviceMesh + shard_placements: tuple[Placement, ...] 
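+    # shard_mesh and shard_placements describe how this parameter is sharded
+    # across its shard process group and drive the per-rank all2all slicing.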
+ name: str + qk_clip_state: torch.Tensor | None = None gathered_grad: torch.Tensor | None = None scattered_u: DTensor | None = None computed_u: torch.Tensor | None = None gather_event: torch.cuda.Event | None = None compute_event: torch.cuda.Event | None = None scatter_event: torch.cuda.Event | None = None - process_group = None - qk_clip_state = None -def split_elems_for_src(param, src_rank, num_ranks) -> int: - rows = param.shape[0] - cols = int(param.numel() // rows) - base, rem = divmod(rows, num_ranks) - my_rows = base + (1 if src_rank < rem else 0) - return my_rows * cols +def numel_for_rank( + param: DTensor, + local_rank: int, + state: _muon_state, +) -> int: + slices = get_slices_of_dtensor( + param, + local_rank, + state.shard_mesh, + state.shard_placements, + ) + + numel = 1 + for s, dim in zip(slices, param.shape): + start, stop, step = s.indices(dim) + length = max(0, (stop - start + (step - 1)) // step) + numel *= length + + return numel @torch.no_grad() @@ -91,8 +113,7 @@ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): for p in params: state = param_to_state[id(p)] if rank == state.worker_rank: - num_ranks = dist.get_world_size(group=state.process_group) - state.gathered_grad = torch.empty(p.grad.numel(), + state.gathered_grad = torch.empty(p.shape, dtype=COMM_DTYPE, device="cuda") else: @@ -121,11 +142,11 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, state = param_to_state[id(p)] dst = state.worker_rank assert dst < num_ranks - shard_elems = split_elems_for_src(p, rank, num_ranks) + shard_elems = numel_for_rank(p, rank, state) g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous().view(-1) + g = g.to_local().to(COMM_DTYPE).contiguous() assert g.numel() == shard_elems - per_dst[dst].append(g) + per_dst[dst].append(g.view(-1)) send_counts[dst] += shard_elems assert any( @@ -148,13 +169,18 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, for p in owned_params: state = param_to_state[id(p)] assert state.worker_rank == rank - total += split_elems_for_src(p, src, num_ranks) + total += numel_for_rank(p, src, state) recv_counts[src] = total recv_total = sum(recv_counts) recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") #All2All + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") dist.all_to_all_single( recv_buf, send_buf, @@ -179,7 +205,6 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, comm_stream.wait_event(alloc_event) off = 0 - write_offsets = {id(p): 0 for p in owned_params} for src in range(num_ranks): if recv_counts[src] == 0: continue @@ -189,22 +214,28 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, for p in owned_params: state = param_to_state[id(p)] assert state.worker_rank == rank - n = split_elems_for_src(p, src, num_ranks) + + # get the slice of the full dtensor corresponding to rank src. 
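+                # gathered_grad holds the full parameter shape on this worker,
+                # so copying each source shard into its slice reassembles the
+                # unsharded gradient.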
+ slices = get_slices_of_dtensor(state.gathered_grad, src, + state.shard_mesh, + state.shard_placements) + + dst = state.gathered_grad[slices] + assert dst._base is state.gathered_grad + + n = dst.numel() assert n > 0 sg = recv_buf.narrow(0, off + inner_off, n) - woff = write_offsets[id(p)] - dst = state.gathered_grad.narrow(0, woff, n) + sg = sg.reshape_as(dst) dst.copy_(sg) - write_offsets[id(p)] += n inner_off += n off += block for p in params: state = param_to_state[id(p)] if state.worker_rank == rank: - state.gathered_grad = state.gathered_grad.view_as(p) state.gather_event = torch.cuda.Event() state.gather_event.record(comm_stream) else: @@ -277,14 +308,19 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): assert state.computed_u is not None - u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1) + u_full = state.computed_u.to(COMM_DTYPE).contiguous() offset = 0 for dst in range(num_ranks): - n = split_elems_for_src(p, dst, num_ranks) + # get the slice of the full tensor corresponding to rank dst. + slices = get_slices_of_dtensor(u_full, dst, + state.shard_mesh, + state.shard_placements) + su = u_full[slices].flatten() + + n = su.numel() assert n > 0 - su = u_full.narrow(0, offset, n) per_dst[dst].append(su) send_counts[dst] += n offset += n @@ -313,7 +349,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): state = param_to_state[id(p)] if state.worker_rank != src: continue - total += split_elems_for_src(p, rank, num_ranks) + total += numel_for_rank(p, rank, state) recv_counts[src] = total recv_total = sum(recv_counts) @@ -357,7 +393,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): state = param_to_state[id(p)] if state.worker_rank != src: continue - n = split_elems_for_src(p, rank, num_ranks) + n = numel_for_rank(p, rank, state) assert n > 0 flat_local = recv_buf.narrow(0, off + inner_off, @@ -398,11 +434,23 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, state.scattered_u = None u_dtensor = None - scales_full = Muon._compute_scales(p, state.qk_clip_state) + scales_full = Muon._compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None if scales_full is not None: - num_ranks = dist.get_world_size(group=state.process_group) - local_rank = dist.get_rank(group=state.process_group) - scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank] + # Have to slice scales_full among dim 0 + weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, + state.shard_placements) + ratio = p.shape[0] // scales_full.shape[0] + scales_slice = slice( + None if weight_slices[0].start is None else + weight_slices[0].start // ratio, + None if weight_slices[0].stop is None else + weight_slices[0].stop // ratio, + None, + ) + + scales_local = scales_full[scales_slice] scales_local = DTensor.from_local( scales_local, placements=p.placements, @@ -478,11 +526,11 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: List[int] # which heads to consider for clipping + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping head_dim: int # from config threshold: float # from config - logit: Optional[torch.Tensor] + logit: torch.Tensor | None class Muon(torch.optim.Optimizer): @@ -525,11 +573,16 @@ class 
Muon(torch.optim.Optimizer): "head_dim": 128, "threshold": 100 } - overlap_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher overlap_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. + warmup_step : How many all2all gather, compute operations are launched in advance + before the corresponding all2all scatter steps begin. + A higher warmup_step increases memory usage but can improve + performance by overlapping communication. + Parallel muon only. + chunk_size : Batch size of parameters to process in each + all2all gather/compute/scatter step. + Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. + use_distributed_muon: Use distributed muon by Liu et al. (2024). + For testing purpose only. """ def __init__(self, @@ -549,7 +602,9 @@ class Muon(torch.optim.Optimizer): "head_dim": 128, "threshold": 100 }, - overlap_step=5): + warmup_step=5, + chunk_size=-1, + use_distributed_muon=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -579,7 +634,9 @@ class Muon(torch.optim.Optimizer): self.compute_stream = torch.cuda.Stream() self.debug = debug self.clip_config = clip_config - self.overlap_step = overlap_step + self.warmup_step = warmup_step + self.chunk_size = chunk_size + self.use_distributed_muon = use_distributed_muon def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -597,6 +654,12 @@ class Muon(torch.optim.Optimizer): adjusted_lr = lr * adjusted_ratio return adjusted_lr + def set_rank_once(self, rank): + if self.rank is None: + self.rank = rank + else: + assert self.rank == rank + def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. @@ -604,26 +667,13 @@ class Muon(torch.optim.Optimizer): assert isinstance( p, DTensor), "Parallel Muon only supports DTensor parameters." 
- if p.placements == (Shard(dim=0), ): - # Case for FSDP - process_group = p.device_mesh.get_group(mesh_dim=0) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0) - elif p.placements == (Replicate(), Shard(dim=0)): - # Case for HSDP - process_group = p.device_mesh.get_group(mesh_dim=1) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - for i, shard_mesh in enumerate(p.device_mesh.mesh): - if self.rank in shard_mesh: - return shard_mesh, p.device_mesh.get_group(mesh_dim=1) - else: - raise ValueError(f"Unsupported placements ({p.placements}).") + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + p.placements, p.device_mesh) + + # set rank with the local rank in the shard process group + self.set_rank_once(dist.get_rank(group=shard_pg)) + + return shard_mesh, shard_pg, shard_placements def init_state_and_assign_params(self, names, params, group, qk_logits): param_to_state = {} @@ -655,23 +705,32 @@ class Muon(torch.optim.Optimizer): ordered_params = list(params_sorted) round_robin = 0 - mesh = None - shard_mesh = None - process_group = None + mesh = ordered_params[0].device_mesh + placements = ordered_params[0].placements + + shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( + ordered_params[0]) + shard_mesh_flattened = shard_mesh.mesh.flatten() + num_ranks = dist.get_world_size(group=shard_pg) + for n, p in zip(ordered_names, ordered_params): - if mesh is None: - mesh = p.device_mesh - shard_mesh, process_group = self.get_shard_mesh(p) - elif mesh != p.device_mesh: + if mesh != p.device_mesh: raise ValueError("All parameters must be on the same mesh.") - num_ranks = dist.get_world_size(group=process_group) - param_to_state[id(p)] = _muon_state() - param_to_state[id( - p)].worker_rank = shard_mesh[round_robin].item() % num_ranks - param_to_state[id(p)].process_group = process_group + if placements != p.placements: + raise ValueError("All parameters must have same placements.") + + worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks + round_robin = (round_robin + 1) % len(shard_mesh_flattened) qk_clip_state = self.get_qk_clip_info(n, qk_logits) - param_to_state[id(p)].qk_clip_state = qk_clip_state - round_robin = (round_robin + 1) % len(shard_mesh) + + param_to_state[id(p)] = _muon_state( + worker_rank=worker_rank, + process_group=shard_pg, + shard_mesh=shard_mesh, + shard_placements=shard_placements, + name=n, + qk_clip_state=qk_clip_state, + ) return param_to_state, ordered_params @@ -705,10 +764,73 @@ class Muon(torch.optim.Optimizer): qk_clip_state = self.get_qk_clip_info(n, qk_logits) - scales_full = self._compute_scales(p, qk_clip_state) + scales_full = self._compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + def distributed_muon( + self, + names: list[str], + params: list[torch.nn.Parameter], + group: dict[str, Any], + lr: float, + weight_decay: float, + momentum: float, + qk_logits: list[torch.Tensor | DTensor] | None, + ): + """ Implementation of Distributed Muon by Liu et al. 
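+        Each rank gathers the full gradient, runs Newton-Schulz redundantly,
+        and keeps only its local shard of the resulting update.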
""" + if qk_logits is not None: + raise NotImplementedError("QK clipping is not supported yet") + + if isinstance(params[0], DTensor): + shard_mesh, _, shard_placements = construct_shard_mesh( + placements=params[0].placements, + mesh=params[0].device_mesh, + ) + + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + # Gather G + if isinstance(p.data, DTensor): + g = g.full_tensor() + u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) + + if isinstance(p.data, DTensor): + slices = get_slices_of_dtensor( + target=p, + local_rank=dist.get_rank(), + shard_mesh=shard_mesh, + shard_placements=shard_placements, + ) + u_shard = u[slices] + u = DTensor.from_local( + u_shard, + device_mesh=p.device_mesh, + placements=p.placements, + ) + + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + def _update_g(self, p, g, group, momentum): # calc update state = self.state[p] @@ -727,6 +849,9 @@ class Muon(torch.optim.Optimizer): p.data.add_(u, alpha=-adjusted_lr) def get_qk_clip_info(self, n, qk_logits): + if self.clip_config is None: + return None + head_dim = self.clip_config.get('head_dim') threshold = self.clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) @@ -737,6 +862,11 @@ class Muon(torch.optim.Optimizer): indices_key = 'q_indices' if 'q' in kind else 'k_indices' indices = self.clip_config.get(indices_key, []) or [] + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + return QKClipInfo( kind=kind, indices=indices, @@ -835,22 +965,28 @@ class Muon(torch.optim.Optimizer): _update_param(p, state, lr, adjusted_lr, weight_decay, self.rank, self.compute_stream) - chunk_size = dist.get_world_size(param_to_state[id( - params[0])].process_group) + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(param_to_state[id( + params[0])].process_group) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError("chunk_size must be -1 or a positive integer.") # Wait grad update self.comm_stream.wait_stream(torch.cuda.current_stream()) - overlap_step = self.overlap_step - for i in range(0, overlap_step): + warmup_step = self.warmup_step + for i in range(0, warmup_step): enqueue_all2all_gather(i * chunk_size, chunk_size) enqueue_computes(i * chunk_size, chunk_size) for i in range(0, len(params) + chunk_size - 1, chunk_size): enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size) + enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) enqueue_update_param(i, chunk_size) - enqueue_computes(i + overlap_step * chunk_size, chunk_size) + enqueue_computes(i + warmup_step * chunk_size, chunk_size) # Wait the last update_param to finish torch.cuda.current_stream().wait_stream(self.compute_stream) @@ -866,7 +1002,7 @@ class Muon(torch.optim.Optimizer): amsgrad: bool, beta1: float, beta2: float, - lr: Union[float, torch.Tensor], + lr: float | torch.Tensor, weight_decay: float, eps: float, maximize: bool, @@ -876,10 +1012,10 
@@ class Muon(torch.optim.Optimizer): # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer # treating it as a scalar. - lr_dict: Optional[DeviceDict] = ({ + lr_dict: DeviceDict | None = ({ lr.device: lr } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) + None) grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( [ params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, @@ -926,6 +1062,159 @@ class Muon(torch.optim.Optimizer): maximize=maximize, ) + def _step_muon(self, group, qk_logits=None): + params = group["params"] + lr = group["lr"] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + names = group["names"] + + param_dtensors = [] + param_tensors = [] + name_dtensors = [] + name_tensors = [] + + if self.use_distributed_muon: + self.distributed_muon(names=names, + params=params, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits) + return + + for n, p in zip(names, params): + if p is None or p.grad is None: + continue + if isinstance(p.data, DTensor): + if all( + isinstance(placement, Replicate) + for placement in p.placements): + param_tensors.append(p) + name_tensors.append(n) + else: + param_dtensors.append(p) + name_dtensors.append(n) + elif isinstance(p.data, torch.Tensor): + param_tensors.append(p) + name_tensors.append(n) + else: + raise TypeError(f"Unsupported parameter type: {type(p.data)}") + + logger.debug( + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors" + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + # To support different placements, we group parameters by placements + # and run parallel Muon on each group. 
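+            # Each (placements, device_mesh) bucket gets its own shard mesh and
+            # process group, so mixed sharding layouts can coexist in one step.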
+ + placement_to_params = defaultdict(lambda: ([], [])) + # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] + + assert len(name_dtensors) == len(param_dtensors) + for n, p in zip(name_dtensors, param_dtensors): + placement_to_params[tuple([p.placements, + p.device_mesh])][0].append(n) + placement_to_params[tuple([p.placements, + p.device_mesh])][1].append(p) + + for _, (names, params) in placement_to_params.items(): + self.parallel( + names, + params, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_tensors) > 0: + self.base( + name_tensors, + param_tensors, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + def _step_adamw_params(self, params, group): + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + self._fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def _step_adamw(self, group): + params = group["params"] + + # group params with it's type and placement + placement_to_params: dict[tuple[Placement | type, + DeviceMesh | None]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for params in placement_to_params.values(): + self._step_adamw_params(params, group) + def step(self, closure=None, qk_logits=None): """Perform a single optimization step. 
@@ -943,127 +1232,9 @@ class Muon(torch.optim.Optimizer): loss = closure() for group in self.param_groups: - params = group["params"] - if group["use_muon"]: - ############################ - # Muon # - ############################ - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - param_tensors = [] - name_dtensors = [] - name_tensors = [] - - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError( - f"Unsupported parameter type: {type(p.data)}") - - if self.debug: - print( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors", - flush=True, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.parallel( - name_dtensors, - param_dtensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - + self._step_muon(group, qk_logits=qk_logits) else: - ############################ - # AdamW backup # - ############################ - - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) + self._step_adamw(group) return loss diff --git a/build/torch29-cxx11-cu130-x86_64-linux/optimizer/_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/optimizer/_ops.py index 0e7e761d98b3dd32cb29d88cfffcf96d574af438..7d598206add1bca142661a3df6c510e3d9575d54 100644 --- a/build/torch29-cxx11-cu130-x86_64-linux/optimizer/_ops.py +++ b/build/torch29-cxx11-cu130-x86_64-linux/optimizer/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_811726c_dirty -ops = torch.ops._optimizer_811726c_dirty +from . import _optimizer_23d68bb_dirty +ops = torch.ops._optimizer_23d68bb_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_811726c_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_23d68bb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..8910b678b8c1c618c797441b964171862abbb32d --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aa394498c52692c29094cbd2cc3da6c4c37aefaa4454c97487f8e91827fbd814 +size 1988672 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so deleted file mode 100755 index 3c17d334255f92a6d0c52c10ec5bfd9e92d12d38..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:52a744cf30c60fe1e8fc35ebb0d3421d679bb2047fbb4602846bd6902cfa9e52 -size 1872152 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/optimizer/distributed/utils.py b/build/torch29-cxx11-cu130-x86_64-linux/optimizer/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0b4b58bfb329b1c015129e4c4fc99f7bfa2ab30a --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/optimizer/distributed/utils.py @@ -0,0 +1,174 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import (Placement, Shard, + _StridedShard) + + +def get_slices_of_dtensor( + target: DTensor | torch.Tensor, + local_rank: int, + shard_mesh: DeviceMesh, + shard_placements: tuple[Placement], +) -> tuple[slice]: + """ + Get the slice of local tensor for a given rank from a tensor. + Args: + target (DTensor | torch.Tensor): The target tensor. + rank (int): The local rank of the shard group. + shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + shard_placements (tuple[Placement]): The shard placements. + """ + + slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + + # find the global rank of the local rank in the shard mesh + rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] + + rank_coords = (shard_mesh.mesh == rank).nonzero() + + assert len(rank_coords) == 1 + rank_coords = tuple(rank_coords[0].tolist()) + + assert len(rank_coords) == len(shard_placements) + + # Caution: Assuming replicate-to-shard of the shard mesh goes with + # left-to-right sharding. This is ensured by the sorting logic of + # construct_shard_mesh function. 
+ for i, (rank_coord, + placement) in enumerate(zip(rank_coords, shard_placements)): + assert isinstance(placement, Shard) + + num_ranks = shard_mesh.mesh.shape[i] + + dim = placement.dim + dim_size = (slices[dim].stop - slices[dim].start) + + if dim_size % num_ranks != 0: + raise NotImplementedError( + f"Dimension size {dim_size} is not divisible " + f"by number of ranks {num_ranks} for shard " + f"placement on dim {dim}.") + + shard_size = dim_size // num_ranks + + start = slices[dim].start + rank_coord * shard_size + end = start + shard_size + + assert start < end <= slices[dim].stop + + slices[dim] = slice(start, end) + + return tuple(slices) + + +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict() + + +def construct_shard_mesh( + placements: tuple[Placement], + mesh: DeviceMesh, +) -> (DeviceMesh, ProcessGroup, tuple[Placement]): + """ + Construct Shard Mesh and Placements for unsharding. + It removes Replicate placements and constructs a new Mesh and ProcessGroup. + """ + my_rank = dist.get_rank() + + assert mesh.mesh.device.type == 'cpu' + + # Copy mesh to avoid modifying the original mesh + mesh = mesh.mesh.clone() + + # 1. Sort placements. Replicate first, then Shard by dim ascending. + + # For Shard, strided shard comes after regular shard on the same dim + # to preserve left-to-right order of replicate-to-shard. + # This is because that strided shard is using stride to represent + # more fine-grained sharding on the same dim. + # Please check the URL below for _StridedShard. + # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 + + def placement_sort_key( + placement_with_index: tuple[float, Placement] + ) -> tuple[int, float, int]: # (dim, split factor, original index) + index, placement = placement_with_index + is_replicate = placement.is_replicate() + is_shard = placement.is_shard() + is_partial = placement.is_partial() + + assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" + assert not is_partial, "Partial placement is not supported." + + if is_replicate: + return (-1.0, 0, index) + elif is_shard: + if isinstance(placement, _StridedShard): + return (placement.dim, 1 / placement.split_factor, index) + return (placement.dim, 0, index) + else: + raise TypeError(f"Unknown placement type: {type(placement)}") + + placements_with_index: list[tuple[int, + Placement]] = list(enumerate(placements)) + placements_with_index = sorted(placements_with_index, + key=placement_sort_key) + + sorted_indices, sorted_placements = zip(*placements_with_index) + + # 2. Permute mesh according to sorted placements. + sorted_mesh = mesh.permute(sorted_indices) + + # 3. Collect list of shard meshes by removing replicate dims + # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] + # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) + num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + + # merge replicate dims + # shard_meshes became a list of shard meshes with a length of replicate degree + if num_replicates > 0: + sorted_mesh = sorted_mesh.flatten( + 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) + else: + shard_meshes = [sorted_mesh] + shard_placements = sorted_placements[num_replicates:] + + # assume all shard placements are different + assert len(shard_placements) == len(set(shard_placements)) + + # 4. 
Construct ProcessGroups + # Caution: all groups should be created in the same order in all processes, + # even though each process only needs its own group. + + # To use tensor as dict key, convert it to tuple + def tensor_to_tuple(t): + if isinstance(t, torch.Tensor): + t = t.tolist() + if isinstance(t, list): + return tuple(tensor_to_tuple(x) for x in t) + return t + + my_shard_mesh_as_tuple = None + for shard_mesh in shard_meshes: + assert isinstance(shard_mesh, torch.Tensor) + shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) + + if (my_rank == shard_mesh).any().item(): + assert my_shard_mesh_as_tuple is None + my_shard_mesh_as_tuple = shard_mesh_as_tuple + + # update global cache + if shard_mesh_as_tuple not in _ranks_to_dist_cache: + shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) + _ranks_to_dist_cache[shard_mesh_as_tuple] = ( + DeviceMesh(device_type="cuda", mesh=shard_mesh), + shard_process_group, + ) + + my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ + my_shard_mesh_as_tuple] + + return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch29-cxx11-cu130-x86_64-linux/optimizer/muon.py b/build/torch29-cxx11-cu130-x86_64-linux/optimizer/muon.py index f66a1f41fd8b8cbdb453e48b2267ddb0450e5af9..cfbcca71741be70048bfd290c62148b2aceda631 100644 --- a/build/torch29-cxx11-cu130-x86_64-linux/optimizer/muon.py +++ b/build/torch29-cxx11-cu130-x86_64-linux/optimizer/muon.py @@ -1,18 +1,24 @@ import logging import math import types +from collections import defaultdict from dataclasses import dataclass -from typing import List, Optional, Union, cast +from typing import Any, cast import torch import torch.distributed as dist -from torch.distributed._tensor import DTensor, Replicate, Shard +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate +from torch.distributed.tensor.placement_types import Placement +from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor from .matmul_transpose_triton import matmul_transpose_assign logger = logging.getLogger(__name__) COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 # This code snippet is a modified version adapted from the following GitHub repositories: @@ -62,23 +68,39 @@ def _zeropower_via_newtonschulz5(G, steps): @dataclass class _muon_state: # TODO: use Optional - worker_rank: int | None = None + worker_rank: int + process_group: ProcessGroup + shard_mesh: DeviceMesh + shard_placements: tuple[Placement, ...] 
+ name: str + qk_clip_state: torch.Tensor | None = None gathered_grad: torch.Tensor | None = None scattered_u: DTensor | None = None computed_u: torch.Tensor | None = None gather_event: torch.cuda.Event | None = None compute_event: torch.cuda.Event | None = None scatter_event: torch.cuda.Event | None = None - process_group = None - qk_clip_state = None -def split_elems_for_src(param, src_rank, num_ranks) -> int: - rows = param.shape[0] - cols = int(param.numel() // rows) - base, rem = divmod(rows, num_ranks) - my_rows = base + (1 if src_rank < rem else 0) - return my_rows * cols +def numel_for_rank( + param: DTensor, + local_rank: int, + state: _muon_state, +) -> int: + slices = get_slices_of_dtensor( + param, + local_rank, + state.shard_mesh, + state.shard_placements, + ) + + numel = 1 + for s, dim in zip(slices, param.shape): + start, stop, step = s.indices(dim) + length = max(0, (stop - start + (step - 1)) // step) + numel *= length + + return numel @torch.no_grad() @@ -91,8 +113,7 @@ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): for p in params: state = param_to_state[id(p)] if rank == state.worker_rank: - num_ranks = dist.get_world_size(group=state.process_group) - state.gathered_grad = torch.empty(p.grad.numel(), + state.gathered_grad = torch.empty(p.shape, dtype=COMM_DTYPE, device="cuda") else: @@ -121,11 +142,11 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, state = param_to_state[id(p)] dst = state.worker_rank assert dst < num_ranks - shard_elems = split_elems_for_src(p, rank, num_ranks) + shard_elems = numel_for_rank(p, rank, state) g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous().view(-1) + g = g.to_local().to(COMM_DTYPE).contiguous() assert g.numel() == shard_elems - per_dst[dst].append(g) + per_dst[dst].append(g.view(-1)) send_counts[dst] += shard_elems assert any( @@ -148,13 +169,18 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, for p in owned_params: state = param_to_state[id(p)] assert state.worker_rank == rank - total += split_elems_for_src(p, src, num_ranks) + total += numel_for_rank(p, src, state) recv_counts[src] = total recv_total = sum(recv_counts) recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") #All2All + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") dist.all_to_all_single( recv_buf, send_buf, @@ -179,7 +205,6 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, comm_stream.wait_event(alloc_event) off = 0 - write_offsets = {id(p): 0 for p in owned_params} for src in range(num_ranks): if recv_counts[src] == 0: continue @@ -189,22 +214,28 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, for p in owned_params: state = param_to_state[id(p)] assert state.worker_rank == rank - n = split_elems_for_src(p, src, num_ranks) + + # get the slice of the full dtensor corresponding to rank src. 
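+                # recv_buf carries one contiguous block per source rank,
+                # and within each block the chunks follow owned_params
+                # order; each chunk is reshaped and copied straight into
+                # its view of the full gathered_grad.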
+ slices = get_slices_of_dtensor(state.gathered_grad, src, + state.shard_mesh, + state.shard_placements) + + dst = state.gathered_grad[slices] + assert dst._base is state.gathered_grad + + n = dst.numel() assert n > 0 sg = recv_buf.narrow(0, off + inner_off, n) - woff = write_offsets[id(p)] - dst = state.gathered_grad.narrow(0, woff, n) + sg = sg.reshape_as(dst) dst.copy_(sg) - write_offsets[id(p)] += n inner_off += n off += block for p in params: state = param_to_state[id(p)] if state.worker_rank == rank: - state.gathered_grad = state.gathered_grad.view_as(p) state.gather_event = torch.cuda.Event() state.gather_event.record(comm_stream) else: @@ -277,14 +308,19 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): assert state.computed_u is not None - u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1) + u_full = state.computed_u.to(COMM_DTYPE).contiguous() offset = 0 for dst in range(num_ranks): - n = split_elems_for_src(p, dst, num_ranks) + # get the slice of the full tensor corresponding to rank dst. + slices = get_slices_of_dtensor(u_full, dst, + state.shard_mesh, + state.shard_placements) + su = u_full[slices].flatten() + + n = su.numel() assert n > 0 - su = u_full.narrow(0, offset, n) per_dst[dst].append(su) send_counts[dst] += n offset += n @@ -313,7 +349,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): state = param_to_state[id(p)] if state.worker_rank != src: continue - total += split_elems_for_src(p, rank, num_ranks) + total += numel_for_rank(p, rank, state) recv_counts[src] = total recv_total = sum(recv_counts) @@ -357,7 +393,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): state = param_to_state[id(p)] if state.worker_rank != src: continue - n = split_elems_for_src(p, rank, num_ranks) + n = numel_for_rank(p, rank, state) assert n > 0 flat_local = recv_buf.narrow(0, off + inner_off, @@ -398,11 +434,23 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, state.scattered_u = None u_dtensor = None - scales_full = Muon._compute_scales(p, state.qk_clip_state) + scales_full = Muon._compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None if scales_full is not None: - num_ranks = dist.get_world_size(group=state.process_group) - local_rank = dist.get_rank(group=state.process_group) - scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank] + # Have to slice scales_full among dim 0 + weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, + state.shard_placements) + ratio = p.shape[0] // scales_full.shape[0] + scales_slice = slice( + None if weight_slices[0].start is None else + weight_slices[0].start // ratio, + None if weight_slices[0].stop is None else + weight_slices[0].stop // ratio, + None, + ) + + scales_local = scales_full[scales_slice] scales_local = DTensor.from_local( scales_local, placements=p.placements, @@ -478,11 +526,11 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: List[int] # which heads to consider for clipping + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping head_dim: int # from config threshold: float # from config - logit: Optional[torch.Tensor] + logit: torch.Tensor | None class Muon(torch.optim.Optimizer): @@ -525,11 +573,16 @@ class 
Muon(torch.optim.Optimizer): "head_dim": 128, "threshold": 100 } - overlap_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher overlap_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. + warmup_step : How many all2all gather, compute operations are launched in advance + before the corresponding all2all scatter steps begin. + A higher warmup_step increases memory usage but can improve + performance by overlapping communication. + Parallel muon only. + chunk_size : Batch size of parameters to process in each + all2all gather/compute/scatter step. + Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. + use_distributed_muon: Use distributed muon by Liu et al. (2024). + For testing purpose only. """ def __init__(self, @@ -549,7 +602,9 @@ class Muon(torch.optim.Optimizer): "head_dim": 128, "threshold": 100 }, - overlap_step=5): + warmup_step=5, + chunk_size=-1, + use_distributed_muon=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -579,7 +634,9 @@ class Muon(torch.optim.Optimizer): self.compute_stream = torch.cuda.Stream() self.debug = debug self.clip_config = clip_config - self.overlap_step = overlap_step + self.warmup_step = warmup_step + self.chunk_size = chunk_size + self.use_distributed_muon = use_distributed_muon def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -597,6 +654,12 @@ class Muon(torch.optim.Optimizer): adjusted_lr = lr * adjusted_ratio return adjusted_lr + def set_rank_once(self, rank): + if self.rank is None: + self.rank = rank + else: + assert self.rank == rank + def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. @@ -604,26 +667,13 @@ class Muon(torch.optim.Optimizer): assert isinstance( p, DTensor), "Parallel Muon only supports DTensor parameters." 
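+        # construct_shard_mesh replaces the earlier FSDP/HSDP special
+        # cases: any mix of Replicate and Shard placements is reduced
+        # to a pure shard mesh, its process group, and the remaining
+        # shard placements.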
- if p.placements == (Shard(dim=0), ): - # Case for FSDP - process_group = p.device_mesh.get_group(mesh_dim=0) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0) - elif p.placements == (Replicate(), Shard(dim=0)): - # Case for HSDP - process_group = p.device_mesh.get_group(mesh_dim=1) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - for i, shard_mesh in enumerate(p.device_mesh.mesh): - if self.rank in shard_mesh: - return shard_mesh, p.device_mesh.get_group(mesh_dim=1) - else: - raise ValueError(f"Unsupported placements ({p.placements}).") + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + p.placements, p.device_mesh) + + # set rank with the local rank in the shard process group + self.set_rank_once(dist.get_rank(group=shard_pg)) + + return shard_mesh, shard_pg, shard_placements def init_state_and_assign_params(self, names, params, group, qk_logits): param_to_state = {} @@ -655,23 +705,32 @@ class Muon(torch.optim.Optimizer): ordered_params = list(params_sorted) round_robin = 0 - mesh = None - shard_mesh = None - process_group = None + mesh = ordered_params[0].device_mesh + placements = ordered_params[0].placements + + shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( + ordered_params[0]) + shard_mesh_flattened = shard_mesh.mesh.flatten() + num_ranks = dist.get_world_size(group=shard_pg) + for n, p in zip(ordered_names, ordered_params): - if mesh is None: - mesh = p.device_mesh - shard_mesh, process_group = self.get_shard_mesh(p) - elif mesh != p.device_mesh: + if mesh != p.device_mesh: raise ValueError("All parameters must be on the same mesh.") - num_ranks = dist.get_world_size(group=process_group) - param_to_state[id(p)] = _muon_state() - param_to_state[id( - p)].worker_rank = shard_mesh[round_robin].item() % num_ranks - param_to_state[id(p)].process_group = process_group + if placements != p.placements: + raise ValueError("All parameters must have same placements.") + + worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks + round_robin = (round_robin + 1) % len(shard_mesh_flattened) qk_clip_state = self.get_qk_clip_info(n, qk_logits) - param_to_state[id(p)].qk_clip_state = qk_clip_state - round_robin = (round_robin + 1) % len(shard_mesh) + + param_to_state[id(p)] = _muon_state( + worker_rank=worker_rank, + process_group=shard_pg, + shard_mesh=shard_mesh, + shard_placements=shard_placements, + name=n, + qk_clip_state=qk_clip_state, + ) return param_to_state, ordered_params @@ -705,10 +764,73 @@ class Muon(torch.optim.Optimizer): qk_clip_state = self.get_qk_clip_info(n, qk_logits) - scales_full = self._compute_scales(p, qk_clip_state) + scales_full = self._compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + def distributed_muon( + self, + names: list[str], + params: list[torch.nn.Parameter], + group: dict[str, Any], + lr: float, + weight_decay: float, + momentum: float, + qk_logits: list[torch.Tensor | DTensor] | None, + ): + """ Implementation of Distributed Muon by Liu et al. 
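+        Reference (non-overlapped) path: update the momentum buffer on
+        the local shard, all-gather the full gradient, run Newton-Schulz
+        on the full matrix, then slice out this rank's shard of the
+        update and wrap it back into a DTensor. Intended for testing
+        against parallel Muon rather than for performance.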
""" + if qk_logits is not None: + raise NotImplementedError("QK clipping is not supported yet") + + if isinstance(params[0], DTensor): + shard_mesh, _, shard_placements = construct_shard_mesh( + placements=params[0].placements, + mesh=params[0].device_mesh, + ) + + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + # Gather G + if isinstance(p.data, DTensor): + g = g.full_tensor() + u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) + + if isinstance(p.data, DTensor): + slices = get_slices_of_dtensor( + target=p, + local_rank=dist.get_rank(), + shard_mesh=shard_mesh, + shard_placements=shard_placements, + ) + u_shard = u[slices] + u = DTensor.from_local( + u_shard, + device_mesh=p.device_mesh, + placements=p.placements, + ) + + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + def _update_g(self, p, g, group, momentum): # calc update state = self.state[p] @@ -727,6 +849,9 @@ class Muon(torch.optim.Optimizer): p.data.add_(u, alpha=-adjusted_lr) def get_qk_clip_info(self, n, qk_logits): + if self.clip_config is None: + return None + head_dim = self.clip_config.get('head_dim') threshold = self.clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) @@ -737,6 +862,11 @@ class Muon(torch.optim.Optimizer): indices_key = 'q_indices' if 'q' in kind else 'k_indices' indices = self.clip_config.get(indices_key, []) or [] + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + return QKClipInfo( kind=kind, indices=indices, @@ -835,22 +965,28 @@ class Muon(torch.optim.Optimizer): _update_param(p, state, lr, adjusted_lr, weight_decay, self.rank, self.compute_stream) - chunk_size = dist.get_world_size(param_to_state[id( - params[0])].process_group) + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(param_to_state[id( + params[0])].process_group) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError("chunk_size must be -1 or a positive integer.") # Wait grad update self.comm_stream.wait_stream(torch.cuda.current_stream()) - overlap_step = self.overlap_step - for i in range(0, overlap_step): + warmup_step = self.warmup_step + for i in range(0, warmup_step): enqueue_all2all_gather(i * chunk_size, chunk_size) enqueue_computes(i * chunk_size, chunk_size) for i in range(0, len(params) + chunk_size - 1, chunk_size): enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size) + enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) enqueue_update_param(i, chunk_size) - enqueue_computes(i + overlap_step * chunk_size, chunk_size) + enqueue_computes(i + warmup_step * chunk_size, chunk_size) # Wait the last update_param to finish torch.cuda.current_stream().wait_stream(self.compute_stream) @@ -866,7 +1002,7 @@ class Muon(torch.optim.Optimizer): amsgrad: bool, beta1: float, beta2: float, - lr: Union[float, torch.Tensor], + lr: float | torch.Tensor, weight_decay: float, eps: float, maximize: bool, @@ -876,10 +1012,10 
@@ class Muon(torch.optim.Optimizer): # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer # treating it as a scalar. - lr_dict: Optional[DeviceDict] = ({ + lr_dict: DeviceDict | None = ({ lr.device: lr } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) + None) grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( [ params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, @@ -926,6 +1062,159 @@ class Muon(torch.optim.Optimizer): maximize=maximize, ) + def _step_muon(self, group, qk_logits=None): + params = group["params"] + lr = group["lr"] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + names = group["names"] + + param_dtensors = [] + param_tensors = [] + name_dtensors = [] + name_tensors = [] + + if self.use_distributed_muon: + self.distributed_muon(names=names, + params=params, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits) + return + + for n, p in zip(names, params): + if p is None or p.grad is None: + continue + if isinstance(p.data, DTensor): + if all( + isinstance(placement, Replicate) + for placement in p.placements): + param_tensors.append(p) + name_tensors.append(n) + else: + param_dtensors.append(p) + name_dtensors.append(n) + elif isinstance(p.data, torch.Tensor): + param_tensors.append(p) + name_tensors.append(n) + else: + raise TypeError(f"Unsupported parameter type: {type(p.data)}") + + logger.debug( + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors" + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + # To support different placements, we group parameters by placements + # and run parallel Muon on each group. 
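+            # The grouping key is (placements, device_mesh), so each
+            # parallel() call below operates on parameters that share a
+            # single sharding layout and hence a single shard process
+            # group.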
+ + placement_to_params = defaultdict(lambda: ([], [])) + # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] + + assert len(name_dtensors) == len(param_dtensors) + for n, p in zip(name_dtensors, param_dtensors): + placement_to_params[tuple([p.placements, + p.device_mesh])][0].append(n) + placement_to_params[tuple([p.placements, + p.device_mesh])][1].append(p) + + for _, (names, params) in placement_to_params.items(): + self.parallel( + names, + params, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_tensors) > 0: + self.base( + name_tensors, + param_tensors, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + def _step_adamw_params(self, params, group): + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + self._fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def _step_adamw(self, group): + params = group["params"] + + # group params with it's type and placement + placement_to_params: dict[tuple[Placement | type, + DeviceMesh | None]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for params in placement_to_params.values(): + self._step_adamw_params(params, group) + def step(self, closure=None, qk_logits=None): """Perform a single optimization step. 
@@ -943,127 +1232,9 @@ class Muon(torch.optim.Optimizer): loss = closure() for group in self.param_groups: - params = group["params"] - if group["use_muon"]: - ############################ - # Muon # - ############################ - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - param_tensors = [] - name_dtensors = [] - name_tensors = [] - - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError( - f"Unsupported parameter type: {type(p.data)}") - - if self.debug: - print( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors", - flush=True, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.parallel( - name_dtensors, - param_dtensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - + self._step_muon(group, qk_logits=qk_logits) else: - ############################ - # AdamW backup # - ############################ - - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) + self._step_adamw(group) return loss diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_ops.py b/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_ops.py index 0e7e761d98b3dd32cb29d88cfffcf96d574af438..7d598206add1bca142661a3df6c510e3d9575d54 100644 --- a/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +++ b/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_811726c_dirty -ops = torch.ops._optimizer_811726c_dirty +from . import _optimizer_23d68bb_dirty +ops = torch.ops._optimizer_23d68bb_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_811726c_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_23d68bb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so b/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..fd962f33db89d6740c8d181f6b4e3ded3220fec0 --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d297c32252c7f030f3ec60ab1cc908cf145c8ecc710a25690a528d06115ab998 +size 1852184 diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so b/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so deleted file mode 100755 index a20d2608b0326217b1a99f7313f2eb110e5634ab..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:0661740cd0f97ca56ef83979c5a5fa059bcba411148f89d836e9305065578e73 -size 1749264 diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/distributed/utils.py b/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0b4b58bfb329b1c015129e4c4fc99f7bfa2ab30a --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/distributed/utils.py @@ -0,0 +1,174 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import (Placement, Shard, + _StridedShard) + + +def get_slices_of_dtensor( + target: DTensor | torch.Tensor, + local_rank: int, + shard_mesh: DeviceMesh, + shard_placements: tuple[Placement], +) -> tuple[slice]: + """ + Get the slice of local tensor for a given rank from a tensor. + Args: + target (DTensor | torch.Tensor): The target tensor. + rank (int): The local rank of the shard group. + shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + shard_placements (tuple[Placement]): The shard placements. + """ + + slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + + # find the global rank of the local rank in the shard mesh + rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] + + rank_coords = (shard_mesh.mesh == rank).nonzero() + + assert len(rank_coords) == 1 + rank_coords = tuple(rank_coords[0].tolist()) + + assert len(rank_coords) == len(shard_placements) + + # Caution: Assuming replicate-to-shard of the shard mesh goes with + # left-to-right sharding. This is ensured by the sorting logic of + # construct_shard_mesh function. 
+ for i, (rank_coord, + placement) in enumerate(zip(rank_coords, shard_placements)): + assert isinstance(placement, Shard) + + num_ranks = shard_mesh.mesh.shape[i] + + dim = placement.dim + dim_size = (slices[dim].stop - slices[dim].start) + + if dim_size % num_ranks != 0: + raise NotImplementedError( + f"Dimension size {dim_size} is not divisible " + f"by number of ranks {num_ranks} for shard " + f"placement on dim {dim}.") + + shard_size = dim_size // num_ranks + + start = slices[dim].start + rank_coord * shard_size + end = start + shard_size + + assert start < end <= slices[dim].stop + + slices[dim] = slice(start, end) + + return tuple(slices) + + +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict() + + +def construct_shard_mesh( + placements: tuple[Placement], + mesh: DeviceMesh, +) -> (DeviceMesh, ProcessGroup, tuple[Placement]): + """ + Construct Shard Mesh and Placements for unsharding. + It removes Replicate placements and constructs a new Mesh and ProcessGroup. + """ + my_rank = dist.get_rank() + + assert mesh.mesh.device.type == 'cpu' + + # Copy mesh to avoid modifying the original mesh + mesh = mesh.mesh.clone() + + # 1. Sort placements. Replicate first, then Shard by dim ascending. + + # For Shard, strided shard comes after regular shard on the same dim + # to preserve left-to-right order of replicate-to-shard. + # This is because that strided shard is using stride to represent + # more fine-grained sharding on the same dim. + # Please check the URL below for _StridedShard. + # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 + + def placement_sort_key( + placement_with_index: tuple[float, Placement] + ) -> tuple[int, float, int]: # (dim, split factor, original index) + index, placement = placement_with_index + is_replicate = placement.is_replicate() + is_shard = placement.is_shard() + is_partial = placement.is_partial() + + assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" + assert not is_partial, "Partial placement is not supported." + + if is_replicate: + return (-1.0, 0, index) + elif is_shard: + if isinstance(placement, _StridedShard): + return (placement.dim, 1 / placement.split_factor, index) + return (placement.dim, 0, index) + else: + raise TypeError(f"Unknown placement type: {type(placement)}") + + placements_with_index: list[tuple[int, + Placement]] = list(enumerate(placements)) + placements_with_index = sorted(placements_with_index, + key=placement_sort_key) + + sorted_indices, sorted_placements = zip(*placements_with_index) + + # 2. Permute mesh according to sorted placements. + sorted_mesh = mesh.permute(sorted_indices) + + # 3. Collect list of shard meshes by removing replicate dims + # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] + # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) + num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + + # merge replicate dims + # shard_meshes became a list of shard meshes with a length of replicate degree + if num_replicates > 0: + sorted_mesh = sorted_mesh.flatten( + 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) + else: + shard_meshes = [sorted_mesh] + shard_placements = sorted_placements[num_replicates:] + + # assume all shard placements are different + assert len(shard_placements) == len(set(shard_placements)) + + # 4. 
Construct ProcessGroups + # Caution: all groups should be created in the same order in all processes, + # even though each process only needs its own group. + + # To use tensor as dict key, convert it to tuple + def tensor_to_tuple(t): + if isinstance(t, torch.Tensor): + t = t.tolist() + if isinstance(t, list): + return tuple(tensor_to_tuple(x) for x in t) + return t + + my_shard_mesh_as_tuple = None + for shard_mesh in shard_meshes: + assert isinstance(shard_mesh, torch.Tensor) + shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) + + if (my_rank == shard_mesh).any().item(): + assert my_shard_mesh_as_tuple is None + my_shard_mesh_as_tuple = shard_mesh_as_tuple + + # update global cache + if shard_mesh_as_tuple not in _ranks_to_dist_cache: + shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) + _ranks_to_dist_cache[shard_mesh_as_tuple] = ( + DeviceMesh(device_type="cuda", mesh=shard_mesh), + shard_process_group, + ) + + my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ + my_shard_mesh_as_tuple] + + return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/muon.py b/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/muon.py index f66a1f41fd8b8cbdb453e48b2267ddb0450e5af9..cfbcca71741be70048bfd290c62148b2aceda631 100644 --- a/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/muon.py +++ b/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/muon.py @@ -1,18 +1,24 @@ import logging import math import types +from collections import defaultdict from dataclasses import dataclass -from typing import List, Optional, Union, cast +from typing import Any, cast import torch import torch.distributed as dist -from torch.distributed._tensor import DTensor, Replicate, Shard +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate +from torch.distributed.tensor.placement_types import Placement +from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor from .matmul_transpose_triton import matmul_transpose_assign logger = logging.getLogger(__name__) COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 # This code snippet is a modified version adapted from the following GitHub repositories: @@ -62,23 +68,39 @@ def _zeropower_via_newtonschulz5(G, steps): @dataclass class _muon_state: # TODO: use Optional - worker_rank: int | None = None + worker_rank: int + process_group: ProcessGroup + shard_mesh: DeviceMesh + shard_placements: tuple[Placement, ...] 
+ name: str + qk_clip_state: torch.Tensor | None = None gathered_grad: torch.Tensor | None = None scattered_u: DTensor | None = None computed_u: torch.Tensor | None = None gather_event: torch.cuda.Event | None = None compute_event: torch.cuda.Event | None = None scatter_event: torch.cuda.Event | None = None - process_group = None - qk_clip_state = None -def split_elems_for_src(param, src_rank, num_ranks) -> int: - rows = param.shape[0] - cols = int(param.numel() // rows) - base, rem = divmod(rows, num_ranks) - my_rows = base + (1 if src_rank < rem else 0) - return my_rows * cols +def numel_for_rank( + param: DTensor, + local_rank: int, + state: _muon_state, +) -> int: + slices = get_slices_of_dtensor( + param, + local_rank, + state.shard_mesh, + state.shard_placements, + ) + + numel = 1 + for s, dim in zip(slices, param.shape): + start, stop, step = s.indices(dim) + length = max(0, (stop - start + (step - 1)) // step) + numel *= length + + return numel @torch.no_grad() @@ -91,8 +113,7 @@ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): for p in params: state = param_to_state[id(p)] if rank == state.worker_rank: - num_ranks = dist.get_world_size(group=state.process_group) - state.gathered_grad = torch.empty(p.grad.numel(), + state.gathered_grad = torch.empty(p.shape, dtype=COMM_DTYPE, device="cuda") else: @@ -121,11 +142,11 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, state = param_to_state[id(p)] dst = state.worker_rank assert dst < num_ranks - shard_elems = split_elems_for_src(p, rank, num_ranks) + shard_elems = numel_for_rank(p, rank, state) g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous().view(-1) + g = g.to_local().to(COMM_DTYPE).contiguous() assert g.numel() == shard_elems - per_dst[dst].append(g) + per_dst[dst].append(g.view(-1)) send_counts[dst] += shard_elems assert any( @@ -148,13 +169,18 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, for p in owned_params: state = param_to_state[id(p)] assert state.worker_rank == rank - total += split_elems_for_src(p, src, num_ranks) + total += numel_for_rank(p, src, state) recv_counts[src] = total recv_total = sum(recv_counts) recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") #All2All + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") dist.all_to_all_single( recv_buf, send_buf, @@ -179,7 +205,6 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, comm_stream.wait_event(alloc_event) off = 0 - write_offsets = {id(p): 0 for p in owned_params} for src in range(num_ranks): if recv_counts[src] == 0: continue @@ -189,22 +214,28 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, for p in owned_params: state = param_to_state[id(p)] assert state.worker_rank == rank - n = split_elems_for_src(p, src, num_ranks) + + # get the slice of the full dtensor corresponding to rank src. 
+ slices = get_slices_of_dtensor(state.gathered_grad, src, + state.shard_mesh, + state.shard_placements) + + dst = state.gathered_grad[slices] + assert dst._base is state.gathered_grad + + n = dst.numel() assert n > 0 sg = recv_buf.narrow(0, off + inner_off, n) - woff = write_offsets[id(p)] - dst = state.gathered_grad.narrow(0, woff, n) + sg = sg.reshape_as(dst) dst.copy_(sg) - write_offsets[id(p)] += n inner_off += n off += block for p in params: state = param_to_state[id(p)] if state.worker_rank == rank: - state.gathered_grad = state.gathered_grad.view_as(p) state.gather_event = torch.cuda.Event() state.gather_event.record(comm_stream) else: @@ -277,14 +308,19 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): assert state.computed_u is not None - u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1) + u_full = state.computed_u.to(COMM_DTYPE).contiguous() offset = 0 for dst in range(num_ranks): - n = split_elems_for_src(p, dst, num_ranks) + # get the slice of the full tensor corresponding to rank dst. + slices = get_slices_of_dtensor(u_full, dst, + state.shard_mesh, + state.shard_placements) + su = u_full[slices].flatten() + + n = su.numel() assert n > 0 - su = u_full.narrow(0, offset, n) per_dst[dst].append(su) send_counts[dst] += n offset += n @@ -313,7 +349,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): state = param_to_state[id(p)] if state.worker_rank != src: continue - total += split_elems_for_src(p, rank, num_ranks) + total += numel_for_rank(p, rank, state) recv_counts[src] = total recv_total = sum(recv_counts) @@ -357,7 +393,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): state = param_to_state[id(p)] if state.worker_rank != src: continue - n = split_elems_for_src(p, rank, num_ranks) + n = numel_for_rank(p, rank, state) assert n > 0 flat_local = recv_buf.narrow(0, off + inner_off, @@ -398,11 +434,23 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, state.scattered_u = None u_dtensor = None - scales_full = Muon._compute_scales(p, state.qk_clip_state) + scales_full = Muon._compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None if scales_full is not None: - num_ranks = dist.get_world_size(group=state.process_group) - local_rank = dist.get_rank(group=state.process_group) - scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank] + # Have to slice scales_full among dim 0 + weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, + state.shard_placements) + ratio = p.shape[0] // scales_full.shape[0] + scales_slice = slice( + None if weight_slices[0].start is None else + weight_slices[0].start // ratio, + None if weight_slices[0].stop is None else + weight_slices[0].stop // ratio, + None, + ) + + scales_local = scales_full[scales_slice] scales_local = DTensor.from_local( scales_local, placements=p.placements, @@ -478,11 +526,11 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: List[int] # which heads to consider for clipping + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping head_dim: int # from config threshold: float # from config - logit: Optional[torch.Tensor] + logit: torch.Tensor | None class Muon(torch.optim.Optimizer): @@ -525,11 +573,16 @@ class 
Muon(torch.optim.Optimizer): "head_dim": 128, "threshold": 100 } - overlap_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher overlap_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. + warmup_step : How many all2all gather, compute operations are launched in advance + before the corresponding all2all scatter steps begin. + A higher warmup_step increases memory usage but can improve + performance by overlapping communication. + Parallel muon only. + chunk_size : Batch size of parameters to process in each + all2all gather/compute/scatter step. + Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. + use_distributed_muon: Use distributed muon by Liu et al. (2024). + For testing purpose only. """ def __init__(self, @@ -549,7 +602,9 @@ class Muon(torch.optim.Optimizer): "head_dim": 128, "threshold": 100 }, - overlap_step=5): + warmup_step=5, + chunk_size=-1, + use_distributed_muon=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -579,7 +634,9 @@ class Muon(torch.optim.Optimizer): self.compute_stream = torch.cuda.Stream() self.debug = debug self.clip_config = clip_config - self.overlap_step = overlap_step + self.warmup_step = warmup_step + self.chunk_size = chunk_size + self.use_distributed_muon = use_distributed_muon def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -597,6 +654,12 @@ class Muon(torch.optim.Optimizer): adjusted_lr = lr * adjusted_ratio return adjusted_lr + def set_rank_once(self, rank): + if self.rank is None: + self.rank = rank + else: + assert self.rank == rank + def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. @@ -604,26 +667,13 @@ class Muon(torch.optim.Optimizer): assert isinstance( p, DTensor), "Parallel Muon only supports DTensor parameters." 
- if p.placements == (Shard(dim=0), ): - # Case for FSDP - process_group = p.device_mesh.get_group(mesh_dim=0) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0) - elif p.placements == (Replicate(), Shard(dim=0)): - # Case for HSDP - process_group = p.device_mesh.get_group(mesh_dim=1) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - for i, shard_mesh in enumerate(p.device_mesh.mesh): - if self.rank in shard_mesh: - return shard_mesh, p.device_mesh.get_group(mesh_dim=1) - else: - raise ValueError(f"Unsupported placements ({p.placements}).") + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + p.placements, p.device_mesh) + + # set rank with the local rank in the shard process group + self.set_rank_once(dist.get_rank(group=shard_pg)) + + return shard_mesh, shard_pg, shard_placements def init_state_and_assign_params(self, names, params, group, qk_logits): param_to_state = {} @@ -655,23 +705,32 @@ class Muon(torch.optim.Optimizer): ordered_params = list(params_sorted) round_robin = 0 - mesh = None - shard_mesh = None - process_group = None + mesh = ordered_params[0].device_mesh + placements = ordered_params[0].placements + + shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( + ordered_params[0]) + shard_mesh_flattened = shard_mesh.mesh.flatten() + num_ranks = dist.get_world_size(group=shard_pg) + for n, p in zip(ordered_names, ordered_params): - if mesh is None: - mesh = p.device_mesh - shard_mesh, process_group = self.get_shard_mesh(p) - elif mesh != p.device_mesh: + if mesh != p.device_mesh: raise ValueError("All parameters must be on the same mesh.") - num_ranks = dist.get_world_size(group=process_group) - param_to_state[id(p)] = _muon_state() - param_to_state[id( - p)].worker_rank = shard_mesh[round_robin].item() % num_ranks - param_to_state[id(p)].process_group = process_group + if placements != p.placements: + raise ValueError("All parameters must have same placements.") + + worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks + round_robin = (round_robin + 1) % len(shard_mesh_flattened) qk_clip_state = self.get_qk_clip_info(n, qk_logits) - param_to_state[id(p)].qk_clip_state = qk_clip_state - round_robin = (round_robin + 1) % len(shard_mesh) + + param_to_state[id(p)] = _muon_state( + worker_rank=worker_rank, + process_group=shard_pg, + shard_mesh=shard_mesh, + shard_placements=shard_placements, + name=n, + qk_clip_state=qk_clip_state, + ) return param_to_state, ordered_params @@ -705,10 +764,73 @@ class Muon(torch.optim.Optimizer): qk_clip_state = self.get_qk_clip_info(n, qk_logits) - scales_full = self._compute_scales(p, qk_clip_state) + scales_full = self._compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + def distributed_muon( + self, + names: list[str], + params: list[torch.nn.Parameter], + group: dict[str, Any], + lr: float, + weight_decay: float, + momentum: float, + qk_logits: list[torch.Tensor | DTensor] | None, + ): + """ Implementation of Distributed Muon by Liu et al. 
""" + if qk_logits is not None: + raise NotImplementedError("QK clipping is not supported yet") + + if isinstance(params[0], DTensor): + shard_mesh, _, shard_placements = construct_shard_mesh( + placements=params[0].placements, + mesh=params[0].device_mesh, + ) + + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + # Gather G + if isinstance(p.data, DTensor): + g = g.full_tensor() + u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) + + if isinstance(p.data, DTensor): + slices = get_slices_of_dtensor( + target=p, + local_rank=dist.get_rank(), + shard_mesh=shard_mesh, + shard_placements=shard_placements, + ) + u_shard = u[slices] + u = DTensor.from_local( + u_shard, + device_mesh=p.device_mesh, + placements=p.placements, + ) + + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + def _update_g(self, p, g, group, momentum): # calc update state = self.state[p] @@ -727,6 +849,9 @@ class Muon(torch.optim.Optimizer): p.data.add_(u, alpha=-adjusted_lr) def get_qk_clip_info(self, n, qk_logits): + if self.clip_config is None: + return None + head_dim = self.clip_config.get('head_dim') threshold = self.clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) @@ -737,6 +862,11 @@ class Muon(torch.optim.Optimizer): indices_key = 'q_indices' if 'q' in kind else 'k_indices' indices = self.clip_config.get(indices_key, []) or [] + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + return QKClipInfo( kind=kind, indices=indices, @@ -835,22 +965,28 @@ class Muon(torch.optim.Optimizer): _update_param(p, state, lr, adjusted_lr, weight_decay, self.rank, self.compute_stream) - chunk_size = dist.get_world_size(param_to_state[id( - params[0])].process_group) + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(param_to_state[id( + params[0])].process_group) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError("chunk_size must be -1 or a positive integer.") # Wait grad update self.comm_stream.wait_stream(torch.cuda.current_stream()) - overlap_step = self.overlap_step - for i in range(0, overlap_step): + warmup_step = self.warmup_step + for i in range(0, warmup_step): enqueue_all2all_gather(i * chunk_size, chunk_size) enqueue_computes(i * chunk_size, chunk_size) for i in range(0, len(params) + chunk_size - 1, chunk_size): enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size) + enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) enqueue_update_param(i, chunk_size) - enqueue_computes(i + overlap_step * chunk_size, chunk_size) + enqueue_computes(i + warmup_step * chunk_size, chunk_size) # Wait the last update_param to finish torch.cuda.current_stream().wait_stream(self.compute_stream) @@ -866,7 +1002,7 @@ class Muon(torch.optim.Optimizer): amsgrad: bool, beta1: float, beta2: float, - lr: Union[float, torch.Tensor], + lr: float | torch.Tensor, weight_decay: float, eps: float, maximize: bool, @@ -876,10 +1012,10 
@@ class Muon(torch.optim.Optimizer): # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer # treating it as a scalar. - lr_dict: Optional[DeviceDict] = ({ + lr_dict: DeviceDict | None = ({ lr.device: lr } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) + None) grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( [ params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, @@ -926,6 +1062,159 @@ class Muon(torch.optim.Optimizer): maximize=maximize, ) + def _step_muon(self, group, qk_logits=None): + params = group["params"] + lr = group["lr"] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + names = group["names"] + + param_dtensors = [] + param_tensors = [] + name_dtensors = [] + name_tensors = [] + + if self.use_distributed_muon: + self.distributed_muon(names=names, + params=params, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits) + return + + for n, p in zip(names, params): + if p is None or p.grad is None: + continue + if isinstance(p.data, DTensor): + if all( + isinstance(placement, Replicate) + for placement in p.placements): + param_tensors.append(p) + name_tensors.append(n) + else: + param_dtensors.append(p) + name_dtensors.append(n) + elif isinstance(p.data, torch.Tensor): + param_tensors.append(p) + name_tensors.append(n) + else: + raise TypeError(f"Unsupported parameter type: {type(p.data)}") + + logger.debug( + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors" + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + # To support different placements, we group parameters by placements + # and run parallel Muon on each group. 
+ + placement_to_params = defaultdict(lambda: ([], [])) + # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] + + assert len(name_dtensors) == len(param_dtensors) + for n, p in zip(name_dtensors, param_dtensors): + placement_to_params[tuple([p.placements, + p.device_mesh])][0].append(n) + placement_to_params[tuple([p.placements, + p.device_mesh])][1].append(p) + + for _, (names, params) in placement_to_params.items(): + self.parallel( + names, + params, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_tensors) > 0: + self.base( + name_tensors, + param_tensors, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + def _step_adamw_params(self, params, group): + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + self._fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def _step_adamw(self, group): + params = group["params"] + + # group params with it's type and placement + placement_to_params: dict[tuple[Placement | type, + DeviceMesh | None]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for params in placement_to_params.values(): + self._step_adamw_params(params, group) + def step(self, closure=None, qk_logits=None): """Perform a single optimization step. 
@@ -943,127 +1232,9 @@ class Muon(torch.optim.Optimizer): loss = closure() for group in self.param_groups: - params = group["params"] - if group["use_muon"]: - ############################ - # Muon # - ############################ - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - param_tensors = [] - name_dtensors = [] - name_tensors = [] - - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError( - f"Unsupported parameter type: {type(p.data)}") - - if self.debug: - print( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors", - flush=True, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.parallel( - name_dtensors, - param_dtensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - + self._step_muon(group, qk_logits=qk_logits) else: - ############################ - # AdamW backup # - ############################ - - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) + self._step_adamw(group) return loss diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_ops.py b/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_ops.py index 0e7e761d98b3dd32cb29d88cfffcf96d574af438..7d598206add1bca142661a3df6c510e3d9575d54 100644 --- a/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_ops.py +++ b/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_811726c_dirty -ops = torch.ops._optimizer_811726c_dirty +from . import _optimizer_23d68bb_dirty +ops = torch.ops._optimizer_23d68bb_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_811726c_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_23d68bb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so b/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..eaf3eac7689223f26618fe6b233e8a98058cb637 --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8de22742ad0d387021a7b812ee3b7d0c8c54191914c8c0469886f6d2c082e9e3 +size 1852272 diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so b/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so deleted file mode 100755 index 88ccaddae0db7abb1c9eef791ef0c790f0e8fc3b..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_811726c_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:08b55491319446b12d0d890926506639640414edcba945e0f71afef0fac369d5 -size 1749352 diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/distributed/utils.py b/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0b4b58bfb329b1c015129e4c4fc99f7bfa2ab30a --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/distributed/utils.py @@ -0,0 +1,174 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import (Placement, Shard, + _StridedShard) + + +def get_slices_of_dtensor( + target: DTensor | torch.Tensor, + local_rank: int, + shard_mesh: DeviceMesh, + shard_placements: tuple[Placement], +) -> tuple[slice]: + """ + Get the slice of local tensor for a given rank from a tensor. + Args: + target (DTensor | torch.Tensor): The target tensor. + rank (int): The local rank of the shard group. + shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + shard_placements (tuple[Placement]): The shard placements. + """ + + slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + + # find the global rank of the local rank in the shard mesh + rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] + + rank_coords = (shard_mesh.mesh == rank).nonzero() + + assert len(rank_coords) == 1 + rank_coords = tuple(rank_coords[0].tolist()) + + assert len(rank_coords) == len(shard_placements) + + # Caution: Assuming replicate-to-shard of the shard mesh goes with + # left-to-right sharding. This is ensured by the sorting logic of + # construct_shard_mesh function. 
+ for i, (rank_coord, + placement) in enumerate(zip(rank_coords, shard_placements)): + assert isinstance(placement, Shard) + + num_ranks = shard_mesh.mesh.shape[i] + + dim = placement.dim + dim_size = (slices[dim].stop - slices[dim].start) + + if dim_size % num_ranks != 0: + raise NotImplementedError( + f"Dimension size {dim_size} is not divisible " + f"by number of ranks {num_ranks} for shard " + f"placement on dim {dim}.") + + shard_size = dim_size // num_ranks + + start = slices[dim].start + rank_coord * shard_size + end = start + shard_size + + assert start < end <= slices[dim].stop + + slices[dim] = slice(start, end) + + return tuple(slices) + + +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict() + + +def construct_shard_mesh( + placements: tuple[Placement], + mesh: DeviceMesh, +) -> (DeviceMesh, ProcessGroup, tuple[Placement]): + """ + Construct Shard Mesh and Placements for unsharding. + It removes Replicate placements and constructs a new Mesh and ProcessGroup. + """ + my_rank = dist.get_rank() + + assert mesh.mesh.device.type == 'cpu' + + # Copy mesh to avoid modifying the original mesh + mesh = mesh.mesh.clone() + + # 1. Sort placements. Replicate first, then Shard by dim ascending. + + # For Shard, strided shard comes after regular shard on the same dim + # to preserve left-to-right order of replicate-to-shard. + # This is because that strided shard is using stride to represent + # more fine-grained sharding on the same dim. + # Please check the URL below for _StridedShard. + # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 + + def placement_sort_key( + placement_with_index: tuple[float, Placement] + ) -> tuple[int, float, int]: # (dim, split factor, original index) + index, placement = placement_with_index + is_replicate = placement.is_replicate() + is_shard = placement.is_shard() + is_partial = placement.is_partial() + + assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" + assert not is_partial, "Partial placement is not supported." + + if is_replicate: + return (-1.0, 0, index) + elif is_shard: + if isinstance(placement, _StridedShard): + return (placement.dim, 1 / placement.split_factor, index) + return (placement.dim, 0, index) + else: + raise TypeError(f"Unknown placement type: {type(placement)}") + + placements_with_index: list[tuple[int, + Placement]] = list(enumerate(placements)) + placements_with_index = sorted(placements_with_index, + key=placement_sort_key) + + sorted_indices, sorted_placements = zip(*placements_with_index) + + # 2. Permute mesh according to sorted placements. + sorted_mesh = mesh.permute(sorted_indices) + + # 3. Collect list of shard meshes by removing replicate dims + # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] + # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) + num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + + # merge replicate dims + # shard_meshes became a list of shard meshes with a length of replicate degree + if num_replicates > 0: + sorted_mesh = sorted_mesh.flatten( + 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) + else: + shard_meshes = [sorted_mesh] + shard_placements = sorted_placements[num_replicates:] + + # assume all shard placements are different + assert len(shard_placements) == len(set(shard_placements)) + + # 4. 
Construct ProcessGroups + # Caution: all groups should be created in the same order in all processes, + # even though each process only needs its own group. + + # To use tensor as dict key, convert it to tuple + def tensor_to_tuple(t): + if isinstance(t, torch.Tensor): + t = t.tolist() + if isinstance(t, list): + return tuple(tensor_to_tuple(x) for x in t) + return t + + my_shard_mesh_as_tuple = None + for shard_mesh in shard_meshes: + assert isinstance(shard_mesh, torch.Tensor) + shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) + + if (my_rank == shard_mesh).any().item(): + assert my_shard_mesh_as_tuple is None + my_shard_mesh_as_tuple = shard_mesh_as_tuple + + # update global cache + if shard_mesh_as_tuple not in _ranks_to_dist_cache: + shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) + _ranks_to_dist_cache[shard_mesh_as_tuple] = ( + DeviceMesh(device_type="cuda", mesh=shard_mesh), + shard_process_group, + ) + + my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ + my_shard_mesh_as_tuple] + + return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/muon.py b/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/muon.py index f66a1f41fd8b8cbdb453e48b2267ddb0450e5af9..cfbcca71741be70048bfd290c62148b2aceda631 100644 --- a/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/muon.py +++ b/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/muon.py @@ -1,18 +1,24 @@ import logging import math import types +from collections import defaultdict from dataclasses import dataclass -from typing import List, Optional, Union, cast +from typing import Any, cast import torch import torch.distributed as dist -from torch.distributed._tensor import DTensor, Replicate, Shard +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate +from torch.distributed.tensor.placement_types import Placement +from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor from .matmul_transpose_triton import matmul_transpose_assign logger = logging.getLogger(__name__) COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 # This code snippet is a modified version adapted from the following GitHub repositories: @@ -62,23 +68,39 @@ def _zeropower_via_newtonschulz5(G, steps): @dataclass class _muon_state: # TODO: use Optional - worker_rank: int | None = None + worker_rank: int + process_group: ProcessGroup + shard_mesh: DeviceMesh + shard_placements: tuple[Placement, ...] 
+ name: str + qk_clip_state: torch.Tensor | None = None gathered_grad: torch.Tensor | None = None scattered_u: DTensor | None = None computed_u: torch.Tensor | None = None gather_event: torch.cuda.Event | None = None compute_event: torch.cuda.Event | None = None scatter_event: torch.cuda.Event | None = None - process_group = None - qk_clip_state = None -def split_elems_for_src(param, src_rank, num_ranks) -> int: - rows = param.shape[0] - cols = int(param.numel() // rows) - base, rem = divmod(rows, num_ranks) - my_rows = base + (1 if src_rank < rem else 0) - return my_rows * cols +def numel_for_rank( + param: DTensor, + local_rank: int, + state: _muon_state, +) -> int: + slices = get_slices_of_dtensor( + param, + local_rank, + state.shard_mesh, + state.shard_placements, + ) + + numel = 1 + for s, dim in zip(slices, param.shape): + start, stop, step = s.indices(dim) + length = max(0, (stop - start + (step - 1)) // step) + numel *= length + + return numel @torch.no_grad() @@ -91,8 +113,7 @@ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): for p in params: state = param_to_state[id(p)] if rank == state.worker_rank: - num_ranks = dist.get_world_size(group=state.process_group) - state.gathered_grad = torch.empty(p.grad.numel(), + state.gathered_grad = torch.empty(p.shape, dtype=COMM_DTYPE, device="cuda") else: @@ -121,11 +142,11 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, state = param_to_state[id(p)] dst = state.worker_rank assert dst < num_ranks - shard_elems = split_elems_for_src(p, rank, num_ranks) + shard_elems = numel_for_rank(p, rank, state) g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous().view(-1) + g = g.to_local().to(COMM_DTYPE).contiguous() assert g.numel() == shard_elems - per_dst[dst].append(g) + per_dst[dst].append(g.view(-1)) send_counts[dst] += shard_elems assert any( @@ -148,13 +169,18 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, for p in owned_params: state = param_to_state[id(p)] assert state.worker_rank == rank - total += split_elems_for_src(p, src, num_ranks) + total += numel_for_rank(p, src, state) recv_counts[src] = total recv_total = sum(recv_counts) recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") #All2All + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") dist.all_to_all_single( recv_buf, send_buf, @@ -179,7 +205,6 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, comm_stream.wait_event(alloc_event) off = 0 - write_offsets = {id(p): 0 for p in owned_params} for src in range(num_ranks): if recv_counts[src] == 0: continue @@ -189,22 +214,28 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, for p in owned_params: state = param_to_state[id(p)] assert state.worker_rank == rank - n = split_elems_for_src(p, src, num_ranks) + + # get the slice of the full dtensor corresponding to rank src. 
+ slices = get_slices_of_dtensor(state.gathered_grad, src, + state.shard_mesh, + state.shard_placements) + + dst = state.gathered_grad[slices] + assert dst._base is state.gathered_grad + + n = dst.numel() assert n > 0 sg = recv_buf.narrow(0, off + inner_off, n) - woff = write_offsets[id(p)] - dst = state.gathered_grad.narrow(0, woff, n) + sg = sg.reshape_as(dst) dst.copy_(sg) - write_offsets[id(p)] += n inner_off += n off += block for p in params: state = param_to_state[id(p)] if state.worker_rank == rank: - state.gathered_grad = state.gathered_grad.view_as(p) state.gather_event = torch.cuda.Event() state.gather_event.record(comm_stream) else: @@ -277,14 +308,19 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): assert state.computed_u is not None - u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1) + u_full = state.computed_u.to(COMM_DTYPE).contiguous() offset = 0 for dst in range(num_ranks): - n = split_elems_for_src(p, dst, num_ranks) + # get the slice of the full tensor corresponding to rank dst. + slices = get_slices_of_dtensor(u_full, dst, + state.shard_mesh, + state.shard_placements) + su = u_full[slices].flatten() + + n = su.numel() assert n > 0 - su = u_full.narrow(0, offset, n) per_dst[dst].append(su) send_counts[dst] += n offset += n @@ -313,7 +349,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): state = param_to_state[id(p)] if state.worker_rank != src: continue - total += split_elems_for_src(p, rank, num_ranks) + total += numel_for_rank(p, rank, state) recv_counts[src] = total recv_total = sum(recv_counts) @@ -357,7 +393,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): state = param_to_state[id(p)] if state.worker_rank != src: continue - n = split_elems_for_src(p, rank, num_ranks) + n = numel_for_rank(p, rank, state) assert n > 0 flat_local = recv_buf.narrow(0, off + inner_off, @@ -398,11 +434,23 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, state.scattered_u = None u_dtensor = None - scales_full = Muon._compute_scales(p, state.qk_clip_state) + scales_full = Muon._compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None if scales_full is not None: - num_ranks = dist.get_world_size(group=state.process_group) - local_rank = dist.get_rank(group=state.process_group) - scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank] + # Have to slice scales_full among dim 0 + weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, + state.shard_placements) + ratio = p.shape[0] // scales_full.shape[0] + scales_slice = slice( + None if weight_slices[0].start is None else + weight_slices[0].start // ratio, + None if weight_slices[0].stop is None else + weight_slices[0].stop // ratio, + None, + ) + + scales_local = scales_full[scales_slice] scales_local = DTensor.from_local( scales_local, placements=p.placements, @@ -478,11 +526,11 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: List[int] # which heads to consider for clipping + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping head_dim: int # from config threshold: float # from config - logit: Optional[torch.Tensor] + logit: torch.Tensor | None class Muon(torch.optim.Optimizer): @@ -525,11 +573,16 @@ class 
Muon(torch.optim.Optimizer): "head_dim": 128, "threshold": 100 } - overlap_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher overlap_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. + warmup_step : How many all2all gather, compute operations are launched in advance + before the corresponding all2all scatter steps begin. + A higher warmup_step increases memory usage but can improve + performance by overlapping communication. + Parallel muon only. + chunk_size : Batch size of parameters to process in each + all2all gather/compute/scatter step. + Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. + use_distributed_muon: Use distributed muon by Liu et al. (2024). + For testing purpose only. """ def __init__(self, @@ -549,7 +602,9 @@ class Muon(torch.optim.Optimizer): "head_dim": 128, "threshold": 100 }, - overlap_step=5): + warmup_step=5, + chunk_size=-1, + use_distributed_muon=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -579,7 +634,9 @@ class Muon(torch.optim.Optimizer): self.compute_stream = torch.cuda.Stream() self.debug = debug self.clip_config = clip_config - self.overlap_step = overlap_step + self.warmup_step = warmup_step + self.chunk_size = chunk_size + self.use_distributed_muon = use_distributed_muon def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -597,6 +654,12 @@ class Muon(torch.optim.Optimizer): adjusted_lr = lr * adjusted_ratio return adjusted_lr + def set_rank_once(self, rank): + if self.rank is None: + self.rank = rank + else: + assert self.rank == rank + def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. @@ -604,26 +667,13 @@ class Muon(torch.optim.Optimizer): assert isinstance( p, DTensor), "Parallel Muon only supports DTensor parameters." 
- if p.placements == (Shard(dim=0), ): - # Case for FSDP - process_group = p.device_mesh.get_group(mesh_dim=0) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0) - elif p.placements == (Replicate(), Shard(dim=0)): - # Case for HSDP - process_group = p.device_mesh.get_group(mesh_dim=1) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - for i, shard_mesh in enumerate(p.device_mesh.mesh): - if self.rank in shard_mesh: - return shard_mesh, p.device_mesh.get_group(mesh_dim=1) - else: - raise ValueError(f"Unsupported placements ({p.placements}).") + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + p.placements, p.device_mesh) + + # set rank with the local rank in the shard process group + self.set_rank_once(dist.get_rank(group=shard_pg)) + + return shard_mesh, shard_pg, shard_placements def init_state_and_assign_params(self, names, params, group, qk_logits): param_to_state = {} @@ -655,23 +705,32 @@ class Muon(torch.optim.Optimizer): ordered_params = list(params_sorted) round_robin = 0 - mesh = None - shard_mesh = None - process_group = None + mesh = ordered_params[0].device_mesh + placements = ordered_params[0].placements + + shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( + ordered_params[0]) + shard_mesh_flattened = shard_mesh.mesh.flatten() + num_ranks = dist.get_world_size(group=shard_pg) + for n, p in zip(ordered_names, ordered_params): - if mesh is None: - mesh = p.device_mesh - shard_mesh, process_group = self.get_shard_mesh(p) - elif mesh != p.device_mesh: + if mesh != p.device_mesh: raise ValueError("All parameters must be on the same mesh.") - num_ranks = dist.get_world_size(group=process_group) - param_to_state[id(p)] = _muon_state() - param_to_state[id( - p)].worker_rank = shard_mesh[round_robin].item() % num_ranks - param_to_state[id(p)].process_group = process_group + if placements != p.placements: + raise ValueError("All parameters must have same placements.") + + worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks + round_robin = (round_robin + 1) % len(shard_mesh_flattened) qk_clip_state = self.get_qk_clip_info(n, qk_logits) - param_to_state[id(p)].qk_clip_state = qk_clip_state - round_robin = (round_robin + 1) % len(shard_mesh) + + param_to_state[id(p)] = _muon_state( + worker_rank=worker_rank, + process_group=shard_pg, + shard_mesh=shard_mesh, + shard_placements=shard_placements, + name=n, + qk_clip_state=qk_clip_state, + ) return param_to_state, ordered_params @@ -705,10 +764,73 @@ class Muon(torch.optim.Optimizer): qk_clip_state = self.get_qk_clip_info(n, qk_logits) - scales_full = self._compute_scales(p, qk_clip_state) + scales_full = self._compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + def distributed_muon( + self, + names: list[str], + params: list[torch.nn.Parameter], + group: dict[str, Any], + lr: float, + weight_decay: float, + momentum: float, + qk_logits: list[torch.Tensor | DTensor] | None, + ): + """ Implementation of Distributed Muon by Liu et al. 
""" + if qk_logits is not None: + raise NotImplementedError("QK clipping is not supported yet") + + if isinstance(params[0], DTensor): + shard_mesh, _, shard_placements = construct_shard_mesh( + placements=params[0].placements, + mesh=params[0].device_mesh, + ) + + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + # Gather G + if isinstance(p.data, DTensor): + g = g.full_tensor() + u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) + + if isinstance(p.data, DTensor): + slices = get_slices_of_dtensor( + target=p, + local_rank=dist.get_rank(), + shard_mesh=shard_mesh, + shard_placements=shard_placements, + ) + u_shard = u[slices] + u = DTensor.from_local( + u_shard, + device_mesh=p.device_mesh, + placements=p.placements, + ) + + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + def _update_g(self, p, g, group, momentum): # calc update state = self.state[p] @@ -727,6 +849,9 @@ class Muon(torch.optim.Optimizer): p.data.add_(u, alpha=-adjusted_lr) def get_qk_clip_info(self, n, qk_logits): + if self.clip_config is None: + return None + head_dim = self.clip_config.get('head_dim') threshold = self.clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) @@ -737,6 +862,11 @@ class Muon(torch.optim.Optimizer): indices_key = 'q_indices' if 'q' in kind else 'k_indices' indices = self.clip_config.get(indices_key, []) or [] + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + return QKClipInfo( kind=kind, indices=indices, @@ -835,22 +965,28 @@ class Muon(torch.optim.Optimizer): _update_param(p, state, lr, adjusted_lr, weight_decay, self.rank, self.compute_stream) - chunk_size = dist.get_world_size(param_to_state[id( - params[0])].process_group) + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(param_to_state[id( + params[0])].process_group) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError("chunk_size must be -1 or a positive integer.") # Wait grad update self.comm_stream.wait_stream(torch.cuda.current_stream()) - overlap_step = self.overlap_step - for i in range(0, overlap_step): + warmup_step = self.warmup_step + for i in range(0, warmup_step): enqueue_all2all_gather(i * chunk_size, chunk_size) enqueue_computes(i * chunk_size, chunk_size) for i in range(0, len(params) + chunk_size - 1, chunk_size): enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size) + enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) enqueue_update_param(i, chunk_size) - enqueue_computes(i + overlap_step * chunk_size, chunk_size) + enqueue_computes(i + warmup_step * chunk_size, chunk_size) # Wait the last update_param to finish torch.cuda.current_stream().wait_stream(self.compute_stream) @@ -866,7 +1002,7 @@ class Muon(torch.optim.Optimizer): amsgrad: bool, beta1: float, beta2: float, - lr: Union[float, torch.Tensor], + lr: float | torch.Tensor, weight_decay: float, eps: float, maximize: bool, @@ -876,10 +1012,10 
@@ class Muon(torch.optim.Optimizer): # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer # treating it as a scalar. - lr_dict: Optional[DeviceDict] = ({ + lr_dict: DeviceDict | None = ({ lr.device: lr } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) + None) grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( [ params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, @@ -926,6 +1062,159 @@ class Muon(torch.optim.Optimizer): maximize=maximize, ) + def _step_muon(self, group, qk_logits=None): + params = group["params"] + lr = group["lr"] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + names = group["names"] + + param_dtensors = [] + param_tensors = [] + name_dtensors = [] + name_tensors = [] + + if self.use_distributed_muon: + self.distributed_muon(names=names, + params=params, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits) + return + + for n, p in zip(names, params): + if p is None or p.grad is None: + continue + if isinstance(p.data, DTensor): + if all( + isinstance(placement, Replicate) + for placement in p.placements): + param_tensors.append(p) + name_tensors.append(n) + else: + param_dtensors.append(p) + name_dtensors.append(n) + elif isinstance(p.data, torch.Tensor): + param_tensors.append(p) + name_tensors.append(n) + else: + raise TypeError(f"Unsupported parameter type: {type(p.data)}") + + logger.debug( + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors" + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + # To support different placements, we group parameters by placements + # and run parallel Muon on each group. 
+ + placement_to_params = defaultdict(lambda: ([], [])) + # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] + + assert len(name_dtensors) == len(param_dtensors) + for n, p in zip(name_dtensors, param_dtensors): + placement_to_params[tuple([p.placements, + p.device_mesh])][0].append(n) + placement_to_params[tuple([p.placements, + p.device_mesh])][1].append(p) + + for _, (names, params) in placement_to_params.items(): + self.parallel( + names, + params, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_tensors) > 0: + self.base( + name_tensors, + param_tensors, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + def _step_adamw_params(self, params, group): + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + self._fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def _step_adamw(self, group): + params = group["params"] + + # group params with it's type and placement + placement_to_params: dict[tuple[Placement | type, + DeviceMesh | None]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for params in placement_to_params.values(): + self._step_adamw_params(params, group) + def step(self, closure=None, qk_logits=None): """Perform a single optimization step. 
@@ -943,127 +1232,9 @@ class Muon(torch.optim.Optimizer): loss = closure() for group in self.param_groups: - params = group["params"] - if group["use_muon"]: - ############################ - # Muon # - ############################ - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - param_tensors = [] - name_dtensors = [] - name_tensors = [] - - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError( - f"Unsupported parameter type: {type(p.data)}") - - if self.debug: - print( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors", - flush=True, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.parallel( - name_dtensors, - param_dtensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - + self._step_muon(group, qk_logits=qk_logits) else: - ############################ - # AdamW backup # - ############################ - - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) + self._step_adamw(group) return loss diff --git a/docs/muon/balanced.png b/docs/muon/balanced.png deleted file mode 100644 index 2076978a5a0149d598b419bfc45c508405dca0df..0000000000000000000000000000000000000000 --- a/docs/muon/balanced.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:9933e2cd5490513593dd6cf1c5c4f18b7f33fd6e6b11c696784269c2bb78055b -size 98003 diff --git a/docs/muon/distributed_muon.png b/docs/muon/distributed_muon.png deleted file mode 100644 index 26544c9e035afae48d1b32cd6ae729c600a47f33..0000000000000000000000000000000000000000 --- a/docs/muon/distributed_muon.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:31caea472991fd24a7934bf211b5adcbf154b5295bfe364bba5b603851c2cfae -size 407912 diff --git a/docs/muon/distributed_muon_execution.png 
b/docs/muon/distributed_muon_execution.png deleted file mode 100644 index 824c728b78c73ca0d5b70a169ed2e5e50a59946c..0000000000000000000000000000000000000000 --- a/docs/muon/distributed_muon_execution.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:72ab4d8076f1e182900d71636dd22c32b20bface38890cef72a0c94c496d5f02 -size 57140 diff --git a/docs/muon/imbalance.png b/docs/muon/imbalance.png deleted file mode 100644 index d63f0a034912195910cfac8a49f0533ac99968b1..0000000000000000000000000000000000000000 --- a/docs/muon/imbalance.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c71d5faed05d46b2269fefa3b6bea6791d7bf51744f47aa4bb8c311eda1b27ff -size 56528 diff --git a/docs/muon/main.tex b/docs/muon/main.tex deleted file mode 100644 index 41ad35f7b47a0bcfd9cc5bcc879b5fa1bf56c6f4..0000000000000000000000000000000000000000 --- a/docs/muon/main.tex +++ /dev/null @@ -1,142 +0,0 @@ -\documentclass{article} -\usepackage{graphicx} -\usepackage{hyperref} -\usepackage{amsmath} -\usepackage{caption} -\usepackage{tgtermes} -\usepackage{float} -\usepackage[a4paper, margin=1in]{geometry} -\usepackage{booktabs} -\usepackage{algorithm} -\usepackage{algorithmicx} -\usepackage{algpseudocode} -\date{} - -\begin{document} - -{\LARGE \bfseries Parallelize Muon with FSDP2 \par} -\vspace{1em} % 제목 아래 간격 조정 - -\section*{Motivation} - -\begin{figure}[H] - \centering - \includegraphics[width=0.8\textwidth]{distributed_muon.png} - \caption*{Distributed Muon by Moonlight} -\end{figure} - -While a distributed version of Muon is available, it has the drawback of redundant computations across GPUs. - -\begin{figure}[H] - \centering - \includegraphics[width=1.0\textwidth]{distributed_muon_execution.png} - \caption*{Execution timeline of Distributed Muon} -\end{figure} - -\begin{itemize} - \item \texttt{C[i]} : Compute Newton-Schulz(G) for i-th gradient - \item \texttt{AG[i]} : AllGather i-th gradient - \item \texttt{G[i]} : Gather i-th gradient - \item \texttt{SC[i]} : Scatter i-th gradient -\end{itemize} -\clearpage -\section*{Algorithm} - -\subsection*{Parallel Muon} - -\begin{algorithm} -\caption{Parallel Muon} -\textbf{Require:} DP partitioned gradient $\mathbf{g}$, DP partitioned Momentum $\mathbf{m}$, DP partitioned parameter $\mathbf{p}$, momentum $\mu$, local rank $\mathbf{r}$ -\begin{algorithmic}[1] -\State \texttt{// Apply momentum to $\mathbf{g}$ using local partitioned momentum $\mathbf{m}$} -\State $\mathbf{g'} \gets \text{update\_with\_momentum}(\mathbf{g}, \mathbf{m}, \mu)$ -\State \texttt{// Schedule $\mathbf{g'}$ to rank $\mathbf{R}$} -\State $\mathbf{R} \gets \text{schedule}(\mathbf{g'}, \text{dp\_group})$ -\State \texttt{// Gather $\mathbf{g'}$ across DP into a full matrix $\mathbf{G}$ to rank $\mathbf{R}$} -\State $\mathbf{G} \gets \text{gather}(\mathbf{g'}, \text{dp\_group}, \text{dst=}\mathbf{R})$ -\State \texttt{// Calculate Newton-Schulz only in $\mathbf{R}$} -\If{$\mathbf{r}$ == $\mathbf{R}$} - \State $\mathbf{u} \gets \text{Newton-Schulz}(\mathbf{G})$ -\Else - \State $\mathbf{u} \gets None$ -\EndIf - -\State \texttt{// Scatter a full matrix $\mathbf{u}$ across DP} -\State $\mathbf{u'} \gets \text{scatter}(\mathbf{u},\text{dp\_group},\text{src=}\mathbf{R})$ -\State \texttt{// Apply DP partitioned $\mathbf{u'}$ to $\mathbf{p}$} -\State $\mathbf{p'} \gets \text{apply\_update}(\mathbf{p}, \mathbf{u'})$ -\State \textbf{return $\mathbf{p'}$} -\end{algorithmic} -\end{algorithm} - -We eliminate redundant computation by assigning 
each parameter to a specific GPU. - -However, without proper scheduling, this optimization can lead to poor GPU utilization. In particular, although redundant computation is avoided by assigning each parameter to a specific rank, it causes idle time—since all other ranks must wait for the scatter communication to complete before proceeding. - -\begin{figure}[H] - \centering - \includegraphics[width=1.0\textwidth]{naive_execution.png} - \caption*{Execution timeline of Parallel Muon} -\end{figure} - -\subsection*{Scheduling Sub-Operations} - -We can schedule the whole sub-operations as follows, due to the following reasons: -\begin{itemize} - \item There are no dependencies between parameters. - \item GPUs can execute computation and communication concurrently. -\end{itemize} - -\begin{figure}[H] - \centering - \includegraphics[width=1.0\textwidth]{pipelined.png} - \caption*{Execution timeline of re-scheduled Parallel Muon} -\end{figure} - -We define the chunk size $C$ as the number of GPUs and schedule each sub-operation in batches of size $C$. This scheduling allows each GPU to continue computation even while waiting for collective communication to complete. - -\textbf{[Algorithm]} (To be written) -\clearpage -\subsection*{Load Balancing} - -If parameters in a chunk have imbalanced computation loads, idle bubbles may occur. \\ -To mitigate this, we apply load balancing based on per-parameter FLOPs. - -\vspace{1em} -\textbf{Imbalanced (Round Robin)} - -\begin{figure}[H] - \centering - \includegraphics[width=1.0\textwidth]{imbalance.png} -\end{figure} - -\textbf{After Load Balancing} - -\begin{figure}[H] - \centering - \includegraphics[width=1.0\textwidth]{balanced.png} -\end{figure} - -\section*{Implementation} - -The full implementation is available in \texttt{optimizer/torch-ext/optimizer/muon.py}. -To enable concurrent computation and communication, we use separate compute and communication streams (\texttt{torch.cuda.Stream}) and use \texttt{torch.cuda.Event} to synchronize between sub-operations. - -Thanks to the simplicity of \texttt{torch.DTensor} and \texttt{torch.distributed}, the implementation remains straightforward and low in complexity. - -\section*{Evaluation} -We evaluated the performance using 10B model currently in development, achieving 151 TFLOPS per GPU during the optimizer step. - -\begin{table}[H] - \centering - \begin{tabular}{@{}lllll@{}} - \toprule - Model Size & TFLOPs for Muon & GPUs & Elapsed time & TFLOPS/GPU \\ - \midrule - 10B & 847.45 & 4xMI250 (8 devices) & 1.4 s & 151 \\ - \bottomrule - \end{tabular} -\end{table} -Based on the breakdown, 7\% of the time is attributed to updating sharded gradients and parameters, 78\% to GEMM operations, and the remaining 15\% to non-overlapped communication overhead. 
- -\end{document} \ No newline at end of file diff --git a/docs/muon/naive_execution.png b/docs/muon/naive_execution.png deleted file mode 100644 index e8f3c4ce721cda02eb95f569c58739d36008b525..0000000000000000000000000000000000000000 --- a/docs/muon/naive_execution.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:eaacd3625f33cee9735ed0d96b95f98c696dfc771976be970a38c991e2ce84ab -size 42729 diff --git a/docs/muon/parallel_muon.pdf b/docs/muon/parallel_muon.pdf deleted file mode 100644 index 8321c572edfae32e963a013d69187d58971fc27e..0000000000000000000000000000000000000000 --- a/docs/muon/parallel_muon.pdf +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c1a88537a50ecc3db52d6e148d3513b31e2c9810c09df0da8f6aff03fa652fe5 -size 654538 diff --git a/docs/muon/pipelined.png b/docs/muon/pipelined.png deleted file mode 100644 index 7e3d51f98c8f2e501704298c6ec48dca08203884..0000000000000000000000000000000000000000 --- a/docs/muon/pipelined.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:a1f8043cc58e7d8d9da5694ad7bccd1b9fe0210349b9aa9a62652a97f75cf097 -size 64316 diff --git a/test/README.md b/test/README.md new file mode 100644 index 0000000000000000000000000000000000000000..35ae009acbca01163cc3a3cddac000fa957c0a4f --- /dev/null +++ b/test/README.md @@ -0,0 +1,42 @@ +# Muon Optimizer Test + +This directory contains a test script for the **Muon optimizer**. + +## Prerequisites + +- **GPU Requirement** + - All tests require **8 GPUs** by default. + - If you have fewer GPUs available: + - Modify the parallelism configurations in `test_muon.py`. + +- **Model Access** + - The tests require access to the private model repository: + - `Motif-Technologies/Motif-2.6B-4layer-random` on Hugging Face. + - Set your Hugging Face token via the environment variable `HF_TOKEN`. + - If you don’t have access, please contact the maintainer. + +- **Using a Different Model (Optional)** + - You may modify the test to use a different model by: + - Updating the model name in `conftest.py::inputs`. + - Adjusting the tensor parallel rules in `utils.py::_apply_tp`. + +## Usage + +- To execute the test with 8 GPUs, simply run: + +```bash +./run_test.sh +``` + +- To check the other available options, you can use: + +```bash +pytest --help +... +Custom options: + --measure-perf Measure execution time and peak memory usage during optimizer step. + --do-profile Enable profiling during tests. + --skip-verify Skip verification of optimizer step correctness with sequential implementation. + This can be useful when GPU memory is limited. +... 
+``` diff --git a/test/test_muon/__init__.py b/test/__init__.py similarity index 100% rename from test/test_muon/__init__.py rename to test/__init__.py diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..15177262eb39e8f60c95742bb372faf2f3857ae9 --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,124 @@ +import logging + +import pytest +import torch +import torch.distributed as dist +from packaging import version +from transformers import AutoModelForCausalLM + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +SEED = 0xdeadbeef + + +def pytest_addoption(parser): + parser.addoption( + "--measure-perf", + action="store_true", + default=False, + help= + "Measure execution time and peak memory usage during optimizer step.", + ) + + parser.addoption( + "--do-profile", + action="store_true", + default=False, + help="Enable profiling during tests.", + ) + + parser.addoption( + "--skip-verify", + action="store_true", + default=False, + help= + "Skip verification of optimizer step correctness with sequential implementation.\n" + "This can be useful when GPU memory is limited.", + ) + + +def pytest_configure(config): + if config.getoption( + "--do-profile") and not config.getoption("--measure-perf"): + raise pytest.UsageError( + "--do-profile requires --measure-perf. Please enable both flags.") + + +@pytest.fixture(scope="session") +def measure_perf(request): + return request.config.getoption("--measure-perf") + + +@pytest.fixture(scope="session") +def do_profile(request): + return request.config.getoption("--do-profile") + + +@pytest.fixture(scope="session") +def skip_verify(request): + return request.config.getoption("--skip-verify") + + +@pytest.fixture(scope="session", autouse=True) +def init_dist(request): + if version.parse(torch.__version__) < version.parse("2.8"): + pytest.skip("torch>=2.8.0 is required for parallel muon") + return + + try: + dist.init_process_group(backend="nccl") + torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) + except Exception as e: + print(f"Failed to initialize torch.distributed: {e}") + pytest.skip("Failed to initialize torch.distributed") + + if dist.get_world_size() != 8: + pytest.skip("Need 8 processes in the dist group. " + "You can run with `torchrun --nproc-per-node=8 " + "--local-ranks-filter 0 -m pytest " + "test_muon.py`. " + "To run with fewer than 8 GPUs, modify " + "the test cases accordingly.") + + yield + dist.destroy_process_group() + + +@pytest.fixture(scope="session") +def inputs(): + """Load Motif-2.6B model and generate random gradients for testing. Returns: + tuple[torch.nn.Module, list[torch.Tensor], dict[int, torch.Tensor]]: + - torch.nn.Module: The Motif-2.6B model. + - list[torch.Tensor]: A list of random gradients for each model parameter. + - dict[int, torch.Tensor]: A dictionary mapping layer indices to random QK logits. + """ + model_name = "Motif-Technologies/Motif-2.6B-4layer-random" + + torch.manual_seed(SEED) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(SEED) + + model = AutoModelForCausalLM.from_pretrained( + model_name, + trust_remote_code=True, + ) + logger.info( + f"Loaded model {model_name}. 
({len(list(model.parameters()))} parameters)" + ) + + grads: list[torch.Tensor] = [] + for param in model.parameters(): + grad = torch.randn_like(param, device=param.device, dtype=param.dtype) + grads.append(grad) + + qk_logits: dict[int, torch.Tensor] = { + i: + torch.randn(model.config.num_attention_heads, + device=model.device, + dtype=torch.bfloat16) + for i in range(model.config.num_hidden_layers) + } + + return [model, grads, qk_logits] diff --git a/test/optimizer b/test/optimizer new file mode 120000 index 0000000000000000000000000000000000000000..c7ff828a90e1c2a67535184c5e89724fb52bea24 --- /dev/null +++ b/test/optimizer @@ -0,0 +1 @@ +../torch-ext/optimizer/ \ No newline at end of file diff --git a/test/pytest.ini b/test/pytest.ini new file mode 100644 index 0000000000000000000000000000000000000000..11c72fa2e2812b16b1c2e92fb0d78cb4adbda2e5 --- /dev/null +++ b/test/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +log_cli = true +log_cli_level = INFO diff --git a/test/run_test.sh b/test/run_test.sh new file mode 100755 index 0000000000000000000000000000000000000000..2c2bd5b362a36dd4facebee1454eb7b4809118f1 --- /dev/null +++ b/test/run_test.sh @@ -0,0 +1 @@ +torchrun --nproc-per-node=8 --local-ranks-filter=0 -m pytest test_muon.py diff --git a/test/test_muon.py b/test/test_muon.py new file mode 100644 index 0000000000000000000000000000000000000000..58f0ad64d2e9eded31d0b142b0c12210b73a3803 --- /dev/null +++ b/test/test_muon.py @@ -0,0 +1,235 @@ +import copy +import logging +import time +from contextlib import nullcontext + +import pytest +import torch +import torch.distributed as dist +from optimizer.muon import Muon, get_default_muon_param_groups +from torch.distributed.tensor import DTensor, Replicate +from torch.profiler import ProfilerActivity, profile + +from .utils import (ParallelDims, assert_params_equal, parallelize_motif, + parallelize_qk_logits) + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def apply_muon_step( + model: torch.nn.Module, + parallel_dims: ParallelDims | None, + grads: list[torch.Tensor], + warmup_step: int, + chunk_size: int, + qk_logits: dict[int, torch.Tensor] | None = None, + use_distributed_muon: bool = False, + measure_perf: bool = False, + do_profile: bool = False, +) -> tuple[torch.nn.Module, tuple[float, float] | None]: + """ apply single Muon step with optional QK clipping """ + + # 1. Apply gradients to model parameters + assert len(grads) == len(list(model.parameters())) + for grad, param in zip(grads, model.parameters()): + grad = grad.to(param.device) + if isinstance(param.data, DTensor): + unsharded_grad = DTensor.from_local( + grad, + device_mesh=param.data.device_mesh, + placements=[Replicate()] * param.data.device_mesh.ndim, + ) + sharded_grad = unsharded_grad.redistribute( + device_mesh=param.data.device_mesh, + placements=param.data.placements) + param.grad = sharded_grad + else: + param.grad = grad + + # 2. 
Setup Muon optimizer + params = get_default_muon_param_groups(model) + clip_config = dict({ + "q_indices": + list(range(model.config.num_attention_heads)), + "k_indices": + list(range(model.config.num_attention_heads)), + "head_dim": + model.config.hidden_size // model.config.num_attention_heads, + "threshold": + 0.5 + }) + optim = Muon( + params=params, + clip_config=clip_config if qk_logits is not None else None, + none_grad=False, + warmup_step=warmup_step, + chunk_size=chunk_size, + use_distributed_muon=use_distributed_muon, + ) + + optim.step(qk_logits=qk_logits) + + timing_result: tuple[float, float] | None = None + + if measure_perf: + # extra warm up + optim.step(qk_logits=qk_logits) + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + num_iters = 20 + current_mem = torch.cuda.memory_allocated() + + if do_profile: + context = profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=True) + else: + context = nullcontext() + + with context as prof: + for _i in range(num_iters): + optim.step(qk_logits=qk_logits) + + end.record() + end.synchronize() + + if prof is not None and dist.get_rank() == 0: + date = time.strftime("%Y%m%d_%H%M%S", time.localtime()) + profile_name = "trace" + profile_name += f"_{date}" + profile_name += f"_{parallel_dims}" + profile_name += f"_{chunk_size}" + profile_name += f"_{warmup_step}" + profile_name += f"_{qk_logits is not None}" + profile_name += f"_{use_distributed_muon}" + + prof.export_chrome_trace(f"{profile_name}.json") + + peak_memory = torch.cuda.max_memory_allocated() - current_mem + + elapsed_time_ms = start.elapsed_time(end) / num_iters + + timing_result = (elapsed_time_ms, peak_memory) + + return model, timing_result + + +@pytest.fixture(scope="session") +def sequential_muon_result( + skip_verify, # from conftest.py + inputs # from conftest.py +) -> dict[bool, torch.nn.Module]: + """Run Muon optimizer to sequential model for baseline results.""" + if skip_verify: + logger.info("Skipping verification tests as per user request") + return None + + model, grads, qk_logits = inputs + + result = apply_muon_step( + model=copy.deepcopy(model).cuda(), + parallel_dims=None, + grads=grads, + warmup_step=-1, + chunk_size=-1, + qk_logits=None, + )[0].cpu() + + result_qk_clip = apply_muon_step( + model=copy.deepcopy(model).cuda(), + parallel_dims=None, + grads=grads, + warmup_step=-1, + chunk_size=-1, + qk_logits=qk_logits, + )[0].cpu() + + return { + False: result, + True: result_qk_clip, + } + + +OVERLAP_STEPS = [5] +CHUNK_SIZES = [8] + + +@pytest.mark.parametrize("parallel_dims", [ + pytest.param(ParallelDims(8, 1, 1), id="base"), + pytest.param(ParallelDims(1, 8, 1), id="fsdp"), + pytest.param(ParallelDims(2, 4, 1), id="hsdp"), + pytest.param(ParallelDims(1, 1, 8), id="tp"), + pytest.param(ParallelDims(2, 2, 2), id="hsdp+tp"), + pytest.param(ParallelDims(1, 2, 4), id="fsdp+tp"), +]) +@pytest.mark.parametrize("apply_qk_clip", [False, True]) +@pytest.mark.parametrize("use_distributed_muon", [False]) +@pytest.mark.parametrize("warmup_step", OVERLAP_STEPS) +@pytest.mark.parametrize("chunk_size", CHUNK_SIZES) +def test_parallel_muon( + request, + sequential_muon_result: dict[bool, torch.nn.Module], + parallel_dims: ParallelDims, + apply_qk_clip: bool, + use_distributed_muon: bool, + warmup_step: int, + chunk_size: int, + inputs: tuple[torch.nn.Module, list[torch.Tensor], + dict[int, torch.Tensor]], # from conftest.py + measure_perf, # from conftest.py + do_profile, # 
from conftest.py +) -> None: + if use_distributed_muon and chunk_size != CHUNK_SIZES[0]: + pytest.skip("Distributed Muon is not affected by chunk size") + if use_distributed_muon and warmup_step != OVERLAP_STEPS[0]: + pytest.skip("Distributed Muon is not affected by warmup step") + + model, grads, qk_logits = inputs + + if not apply_qk_clip: + qk_logits = None + + # Deepcopy the model to avoid in-place modification + model = copy.deepcopy(model).cuda() + + parallelized_model = parallelize_motif(model, parallel_dims) + + if qk_logits is not None: + # Deepcopy the qk logits to avoid in-place modification + qk_logits = copy.deepcopy(qk_logits) + qk_logits = parallelize_qk_logits(qk_logits, parallel_dims) + + parallelized_model, timing_result = apply_muon_step( + model=parallelized_model, + parallel_dims=parallel_dims, + grads=grads, + warmup_step=warmup_step, + chunk_size=chunk_size, + qk_logits=qk_logits, + use_distributed_muon=use_distributed_muon, + measure_perf=measure_perf, + do_profile=do_profile, + ) + + if measure_perf: + assert timing_result is not None + avg_time_ms, peak_memory = timing_result + logger.info( + f"\nParallel dims: {parallel_dims}, " + f"\nUse distributed Muon: {use_distributed_muon}, " + f"\nApply QK clip: {apply_qk_clip} => " + f"\nChunk Size, Warmup Step, Avg Time (ms), Peak Memory (MB):" + f"\n{chunk_size}, {warmup_step}, {avg_time_ms:.2f}, {peak_memory / (1024**2):.2f}," + ) + + if sequential_muon_result is None: + logger.info("Skipping correctness check as sequential result is None") + elif measure_perf: + logger.info("Skipping correctness check as timing is enabled") + else: + assert_params_equal(parallelized_model, + sequential_muon_result[apply_qk_clip]) diff --git a/test/test_muon/README.md b/test/test_muon/README.md deleted file mode 100644 index ba95986801265fe049581ac970725008cea7b48a..0000000000000000000000000000000000000000 --- a/test/test_muon/README.md +++ /dev/null @@ -1,21 +0,0 @@ -# Muon Optimizer Test - -This directory contains a test script for the **Muon optimizer**. - -To execute the test, simply run: - -```bash -# By default, the test will use 8 GPUs. -./run_test.sh -``` - -The number of GPUs can be controlled with the NGPU environment variable. -For example, to run with 4 GPUs: - -```bash -NGPU=4 ./run_test.sh -``` - -## Limitations: -- Multi-node execution is not supported yet. -- Ensure that the specified number of GPUs is available on your machine before running.
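If fewer than 8 GPUs are available, the new test README suggests adjusting the parallelism configurations in `test_muon.py`. A hedged sketch of what that could look like on 4 GPUs is shown below; the `FOUR_GPU_DIMS` name and the exact degree combinations are illustrative rather than part of the repository, each `ParallelDims(dp_replicate, dp_shard, tp)` product should match the launched world size, and the `world_size != 8` check in `conftest.py::init_dist` would need the same adjustment:

```python
# Hypothetical 4-GPU parametrization for test_parallel_muon; the degree
# product of every ParallelDims entry equals the number of launched ranks.
import pytest

from .utils import ParallelDims

FOUR_GPU_DIMS = [
    pytest.param(ParallelDims(4, 1, 1), id="base"),
    pytest.param(ParallelDims(1, 4, 1), id="fsdp"),
    pytest.param(ParallelDims(2, 2, 1), id="hsdp"),
    pytest.param(ParallelDims(1, 1, 4), id="tp"),
    pytest.param(ParallelDims(1, 2, 2), id="fsdp+tp"),
]

# Launch with, e.g.:
#   torchrun --nproc-per-node=4 --local-ranks-filter=0 -m pytest test_muon.py
```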
diff --git a/test/test_muon/optimizer b/test/test_muon/optimizer deleted file mode 120000 index 908a9c4bc91cfcb6cae8cf7de49671ee7e249017..0000000000000000000000000000000000000000 --- a/test/test_muon/optimizer +++ /dev/null @@ -1 +0,0 @@ -../../torch-ext/optimizer/ \ No newline at end of file diff --git a/test/test_muon/run_test.sh b/test/test_muon/run_test.sh deleted file mode 100755 index 8ee6e5003b5ca8f466c3cbe8bd6031a8c1aebdf1..0000000000000000000000000000000000000000 --- a/test/test_muon/run_test.sh +++ /dev/null @@ -1,2 +0,0 @@ -NGPU=${NGPU:-"8"} -torchrun --nproc-per-node=8 test.py diff --git a/test/test_muon/test.py b/test/test_muon/test.py deleted file mode 100644 index d9ab7ce581dbd83b9cc497554b75467d615e80c3..0000000000000000000000000000000000000000 --- a/test/test_muon/test.py +++ /dev/null @@ -1,122 +0,0 @@ -import logging - -import torch -import torch.distributed as dist -from optimizer.muon import Muon, get_default_muon_param_groups -from torch.distributed.fsdp import FSDPModule, fully_shard -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.placement_types import Replicate -from transformers import (AutoModelForCausalLM, AutoTokenizer, - PreTrainedTokenizerBase) - -logger = logging.getLogger(__name__) -logging.basicConfig(level=logging.INFO) - - -def load_model(fsdp: bool) -> torch.nn.Module: - model_name = "Motif-Technologies/Motif-2.6B" - model = AutoModelForCausalLM.from_pretrained( - model_name, - trust_remote_code=True, - ).bfloat16().cuda() - - random_grads = [] - for param in model.parameters(): - random_grad = torch.randn_like(param, - device=param.device, - dtype=param.dtype) - random_grads.append(random_grad) - - if fsdp: - for layer in model.model.layers: - fully_shard(layer) - layer.reshard() - fully_shard(model) - model.reshard() - - for i, param in enumerate(model.parameters()): - if isinstance(param.data, DTensor): - unsharded_grad = DTensor.from_local( - random_grads[i], - device_mesh=param.data.device_mesh, - placements=[Replicate()] * param.data.device_mesh.ndim, - ) - sharded_grad = unsharded_grad.redistribute( - device_mesh=param.data.device_mesh, - placements=param.data.placements) - param.grad = sharded_grad - else: - param.grad = random_grads[i] - - return model - - -def run_muon(fsdp: bool, qk_clip: bool, seed: int) -> torch.nn.Module: - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) - model = load_model(fsdp=fsdp) - params = get_default_muon_param_groups(model) - qk_logits = None - if qk_clip: - qk_logits = { - i: torch.rand(model.config.num_attention_heads) - for i in range(model.config.num_hidden_layers) - } - optim = Muon( - params=params, - clip_config={ - "q_indices": list(range(model.config.num_attention_heads)), - "k_indices": list(range(model.config.num_attention_heads)), - "head_dim": - model.config.hidden_size // model.config.num_attention_heads, - "threshold": 0.5 - }) - optim.step(qk_logits=qk_logits) - - return model - - -def run_case(qk_clip: bool, seed: int = 0): - parallel_muon_result = run_muon(fsdp=True, qk_clip=qk_clip, seed=seed) - sequential_muon_result = run_muon(fsdp=False, qk_clip=qk_clip, seed=seed) - label = f"qk_clip={'ON' if qk_clip else 'OFF'}" - success = compare_results(parallel_muon_result, - sequential_muon_result, - label=label) - - return success, label - - -def test_muon(): - - base_result = run_case(qk_clip=False, seed=0) - clip_result = run_case(qk_clip=True, seed=0) - - for success, label in [base_result, clip_result]: - if success: - 
logger.info(f"[{label}] Models match") - - -def compare_results(parallel_muon_result: torch.nn.Module, - sequential_muon_result: torch.nn.Module, - label: str) -> None: - success = True - for (name_p, p), (name_s, - s) in zip(parallel_muon_result.named_parameters(), - sequential_muon_result.named_parameters()): - p = p.data.full_tensor() - s = s.data - # Parallel Muon should exactly match Sequential Muon - if torch.abs(p - s).max() > 0: - max_diff_index = torch.argmax(torch.abs(p - s)) - logger.info(f"Models differ at parameter {name_p}") - success = False - - return success - - -if __name__ == "__main__": - dist.init_process_group(backend="nccl") - torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) - test_muon() diff --git a/test/utils.py b/test/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..494c09de1f3241a5ef5028e47f21d17c7342645a --- /dev/null +++ b/test/utils.py @@ -0,0 +1,241 @@ +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed.fsdp import fully_shard +from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard +from torch.distributed.tensor.parallel import (ColwiseParallel, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, + parallelize_module) + + +@dataclass +class ParallelDims: + dp_replicate_degree: int + dp_shard_degree: int + tp_degree: int + + def __str__(self) -> str: + return (f"dp_replicate-{self.dp_replicate_degree}_" + f"dp_shard-{self.dp_shard_degree}_" + f"tp-{self.tp_degree}") + + +def _construct_device_mesh(parallel_dims: ParallelDims) -> DeviceMesh: + """Constructs a DeviceMesh based on the given parallel dimensions. + + Args: + parallel_dims (ParallelDims): The parallelism configuration. + + Returns: + DeviceMesh: The constructed device mesh. + """ + world_size = dist.get_world_size() + expected_devices = (parallel_dims.dp_replicate_degree * + parallel_dims.dp_shard_degree * + parallel_dims.tp_degree) + if world_size < expected_devices: + raise ValueError( + f"Not enough devices: found {world_size}, " + f"but expected at least {expected_devices}. ({parallel_dims})") + + degrees = [ + parallel_dims.dp_replicate_degree, parallel_dims.dp_shard_degree, + parallel_dims.tp_degree + ] + dim_names = ["dp_replicate", "dp_shard", "tp"] + + mesh_shape = [] + mesh_dim_names = [] + for degree, dim_name in zip(degrees, dim_names): + if degree > 1: + mesh_shape.append(degree) + mesh_dim_names.append(dim_name) + + device_mesh = dist.init_device_mesh("cuda", + mesh_shape, + mesh_dim_names=mesh_dim_names) + + return device_mesh + + +def _apply_tp( + model: torch.nn.Module, + tp_mesh: DeviceMesh, +): + """Apply tensor parallelism.""" + + # Layer names must match Motif model definition + # https://huggingface.co/Motif-Technologies/Motif-2.6B/blob/main/modeling_motif.py + + assert type(model).__name__ == "MotifForCausalLM" + + # 1. Parallelize the embedding and shard its outputs (which are the first + # transformer block's inputs) + # 2. Parallelize the root norm layer over the sequence dim + # 3. Parallelize the final linear output layer + + parallelize_module( + model, + tp_mesh, + { + # This below separate tie_weights and make difficult to compare + # the answer with non-tensor-parallel version. 
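As an aside on `_construct_device_mesh` defined above: axes with degree 1 are dropped before `init_device_mesh` is called, so the mesh only carries the parallelism dimensions that are actually active. A small sketch, assuming `ParallelDims` and `_construct_device_mesh` are importable from this test module and the job runs with 8 ranks (e.g. `torchrun --nproc-per-node=8`):

```python
from test.utils import ParallelDims, _construct_device_mesh  # assumed import path

# (1 DP-Replicate, 4 DP-Shard, 2 TP): the replicate axis has degree 1 and is
# dropped, leaving a 2-D mesh named ("dp_shard", "tp") of shape (4, 2).
dims = ParallelDims(dp_replicate_degree=1, dp_shard_degree=4, tp_degree=2)
mesh = _construct_device_mesh(dims)
assert mesh.mesh_dim_names == ("dp_shard", "tp")
assert tuple(mesh.shape) == (4, 2)
```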
+ # TODO(jeesoo): check correctness for training semantic + + #"model.embed_tokens": + #RowwiseParallel( + # input_layouts=Replicate(), + # output_layouts=Shard(1), + #), + "model.norm": + SequenceParallel(), + "output": + ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1), # loss_parallel + use_local_output=False, + ), + }, + ) + + # Apply tensor + sequence parallelism to every transformer block + for transformer_block in model.model.layers: + layer_plan = { + "input_layernorm": + SequenceParallel(), + "post_attention_layernorm": + SequenceParallel(), + "self_attn": + PrepareModuleInput( + # x, freqs_cis, attention_mask, position_ids, qk_clip + input_layouts=(Shard(1), Replicate(), None, None, None), + desired_input_layouts=(Replicate(), Replicate(), None, None, + None), + ), + "self_attn.q_proj": + ColwiseParallel(), + "self_attn.k_proj": + ColwiseParallel(), + "self_attn.v_proj": + ColwiseParallel(), + "self_attn.o_proj": + RowwiseParallel(output_layouts=Shard(1)), + "mlp": + PrepareModuleInput( + input_layouts=(Shard(1), ), + desired_input_layouts=(Replicate(), ), + ), + "mlp.gate_proj": + ColwiseParallel(), + "mlp.down_proj": + RowwiseParallel(output_layouts=Shard(1)), + "mlp.up_proj": + ColwiseParallel(), + } + + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=layer_plan, + ) + + +def _apply_fsdp( + model: torch.nn.Module, + dp_mesh: DeviceMesh, +): + for layer in model.model.layers: + fully_shard(layer, mesh=dp_mesh) + layer.reshard() + fully_shard(model, mesh=dp_mesh) + model.reshard() + + +def parallelize_motif(model: torch.nn.Module, + parallel_dims: ParallelDims) -> torch.nn.Module: + """Parallelize the Motif model according to the given parallel dimensions. + + Args: + model (torch.nn.Module): The Motif model to be parallelized. + parallel_dims (ParallelDims): The parallelism configuration. + + Returns: + torch.nn.Module: The parallelized Motif model. + """ + + mesh = _construct_device_mesh(parallel_dims) + + if parallel_dims.tp_degree > 1: + _apply_tp(model, mesh["tp"]) + + if parallel_dims.dp_shard_degree > 1: + if parallel_dims.dp_replicate_degree > 1: + dp_dim_names = ("dp_replicate", "dp_shard") + else: + dp_dim_names = ("dp_shard", ) + _apply_fsdp(model, mesh[dp_dim_names]) + + return model + + +def parallelize_qk_logits( + qk_logits: dict[int, torch.Tensor], + parallel_dims: ParallelDims, +) -> dict[int, torch.Tensor]: + """Parallelize the QK logits according to the given parallel dimensions. + + Args: + qk_logits (dict[int, torch.Tensor]): The QK logits to be parallelized. + parallel_dims (ParallelDims): The parallelism configuration. + + Returns: + dict[int, torch.Tensor]: The parallelized QK logits. + """ + + mesh = _construct_device_mesh(parallel_dims) + + if parallel_dims.tp_degree > 1: + tp_rank = mesh["tp"].get_local_rank() + placements = [ + Shard(0) if dim_name == "tp" else Replicate() + for dim_name in mesh.mesh_dim_names + ] + for layer_idx, logits in qk_logits.items(): + assert logits.size(0) % parallel_dims.tp_degree == 0 + local_logits = logits.chunk(parallel_dims.tp_degree, + dim=0)[tp_rank].contiguous() + + qk_logits[layer_idx] = DTensor.from_local( + local_tensor=local_logits, + device_mesh=mesh, + placements=placements, + ) + + return qk_logits + + +def assert_params_equal(actual: torch.nn.Module, + expected: torch.nn.Module) -> None: + """Asserts that the parameters of two models are equal. + + Args: + actual (torch.nn.Module): The actual model. 
+ expected (torch.nn.Module): The expected model. + Returns: + None + """ + + def get_full_param(param: torch.nn.Parameter) -> torch.Tensor: + if isinstance(param.data, DTensor): + return param.data.full_tensor() + return param.data + + for (name_p, p), (name_s, s) in zip(actual.named_parameters(), + expected.named_parameters()): + p = get_full_param(p.cuda()) + s = get_full_param(s.cuda()) + + torch.testing.assert_close(p, s, atol=0, rtol=0) diff --git a/torch-ext/optimizer/distributed/utils.py b/torch-ext/optimizer/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0b4b58bfb329b1c015129e4c4fc99f7bfa2ab30a --- /dev/null +++ b/torch-ext/optimizer/distributed/utils.py @@ -0,0 +1,174 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import (Placement, Shard, + _StridedShard) + + +def get_slices_of_dtensor( + target: DTensor | torch.Tensor, + local_rank: int, + shard_mesh: DeviceMesh, + shard_placements: tuple[Placement], +) -> tuple[slice]: + """ + Get the slice of local tensor for a given rank from a tensor. + Args: + target (DTensor | torch.Tensor): The target tensor. + rank (int): The local rank of the shard group. + shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + shard_placements (tuple[Placement]): The shard placements. + """ + + slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + + # find the global rank of the local rank in the shard mesh + rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] + + rank_coords = (shard_mesh.mesh == rank).nonzero() + + assert len(rank_coords) == 1 + rank_coords = tuple(rank_coords[0].tolist()) + + assert len(rank_coords) == len(shard_placements) + + # Caution: Assuming replicate-to-shard of the shard mesh goes with + # left-to-right sharding. This is ensured by the sorting logic of + # construct_shard_mesh function. + for i, (rank_coord, + placement) in enumerate(zip(rank_coords, shard_placements)): + assert isinstance(placement, Shard) + + num_ranks = shard_mesh.mesh.shape[i] + + dim = placement.dim + dim_size = (slices[dim].stop - slices[dim].start) + + if dim_size % num_ranks != 0: + raise NotImplementedError( + f"Dimension size {dim_size} is not divisible " + f"by number of ranks {num_ranks} for shard " + f"placement on dim {dim}.") + + shard_size = dim_size // num_ranks + + start = slices[dim].start + rank_coord * shard_size + end = start + shard_size + + assert start < end <= slices[dim].stop + + slices[dim] = slice(start, end) + + return tuple(slices) + + +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict() + + +def construct_shard_mesh( + placements: tuple[Placement], + mesh: DeviceMesh, +) -> (DeviceMesh, ProcessGroup, tuple[Placement]): + """ + Construct Shard Mesh and Placements for unsharding. + It removes Replicate placements and constructs a new Mesh and ProcessGroup. + """ + my_rank = dist.get_rank() + + assert mesh.mesh.device.type == 'cpu' + + # Copy mesh to avoid modifying the original mesh + mesh = mesh.mesh.clone() + + # 1. Sort placements. Replicate first, then Shard by dim ascending. + + # For Shard, strided shard comes after regular shard on the same dim + # to preserve left-to-right order of replicate-to-shard. 
+    # This is because a strided shard uses its stride to represent
+    # more fine-grained sharding on the same dim.
+    # Please check the URL below for _StridedShard.
+    # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366
+
+    def placement_sort_key(
+        placement_with_index: tuple[int, Placement]
+    ) -> tuple[float, float, int]:  # (dim, split factor, original index)
+        index, placement = placement_with_index
+        is_replicate = placement.is_replicate()
+        is_shard = placement.is_shard()
+        is_partial = placement.is_partial()
+
+        assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}"
+        assert not is_partial, "Partial placement is not supported."
+
+        if is_replicate:
+            return (-1.0, 0, index)
+        elif is_shard:
+            if isinstance(placement, _StridedShard):
+                return (placement.dim, 1 / placement.split_factor, index)
+            return (placement.dim, 0, index)
+        else:
+            raise TypeError(f"Unknown placement type: {type(placement)}")
+
+    placements_with_index: list[tuple[int,
+                                      Placement]] = list(enumerate(placements))
+    placements_with_index = sorted(placements_with_index,
+                                   key=placement_sort_key)
+
+    sorted_indices, sorted_placements = zip(*placements_with_index)
+
+    # 2. Permute mesh according to sorted placements.
+    sorted_mesh = mesh.permute(sorted_indices)
+
+    # 3. Collect list of shard meshes by removing replicate dims
+    # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
+    # shard_meshes should be a list with 2 * 3 = 6 shard meshes of shape (4, 4)
+    num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
+
+    # Merge replicate dims.
+    # shard_meshes becomes a list of shard meshes whose length equals the replicate degree.
+    if num_replicates > 0:
+        sorted_mesh = sorted_mesh.flatten(
+            0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
+        shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
+    else:
+        shard_meshes = [sorted_mesh]
+    shard_placements = sorted_placements[num_replicates:]
+
+    # assume all shard placements are different
+    assert len(shard_placements) == len(set(shard_placements))
+
+    # 4. Construct ProcessGroups
+    # Caution: all groups should be created in the same order in all processes,
+    # even though each process only needs its own group.
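To make the ordering above concrete, here is a standalone sketch of steps 1-3 (sort, permute, split off the replicate axis) for a (2, 2, 2) mesh with placements `(Shard(1), Replicate(), Shard(0))`. It ignores `_StridedShard` and the ProcessGroup bookkeeping, and the mesh values simply stand in for global ranks:

```python
import torch
from torch.distributed.tensor.placement_types import Replicate, Shard

placements = (Shard(1), Replicate(), Shard(0))
# Replicate sorts before Shard; Shard dims ascend left to right.
order = sorted(range(len(placements)),
               key=lambda i: (-1 if placements[i].is_replicate()
                              else placements[i].dim))
mesh = torch.arange(8).reshape(2, 2, 2)    # illustrative global ranks
sorted_mesh = mesh.permute(order)          # replicate axis moves to dim 0
shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
print(order)                               # [1, 2, 0]
print([m.tolist() for m in shard_meshes])  # two (2, 2) shard groups
# The resulting shard_placements are (Shard(0), Shard(1)). For an (8, 6)
# parameter, the rank at coordinates (i, j) of its shard group then owns
# rows 4*i:4*(i+1) and columns 3*j:3*(j+1), which are the slices that
# get_slices_of_dtensor computes.
```

In the real function, every rank also calls `dist.new_group` for each of these shard groups in the same order, which is why the loop below iterates over all `shard_meshes` rather than only the caller's own group.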
+ + # To use tensor as dict key, convert it to tuple + def tensor_to_tuple(t): + if isinstance(t, torch.Tensor): + t = t.tolist() + if isinstance(t, list): + return tuple(tensor_to_tuple(x) for x in t) + return t + + my_shard_mesh_as_tuple = None + for shard_mesh in shard_meshes: + assert isinstance(shard_mesh, torch.Tensor) + shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) + + if (my_rank == shard_mesh).any().item(): + assert my_shard_mesh_as_tuple is None + my_shard_mesh_as_tuple = shard_mesh_as_tuple + + # update global cache + if shard_mesh_as_tuple not in _ranks_to_dist_cache: + shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) + _ranks_to_dist_cache[shard_mesh_as_tuple] = ( + DeviceMesh(device_type="cuda", mesh=shard_mesh), + shard_process_group, + ) + + my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ + my_shard_mesh_as_tuple] + + return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/torch-ext/optimizer/muon.py b/torch-ext/optimizer/muon.py index f66a1f41fd8b8cbdb453e48b2267ddb0450e5af9..cfbcca71741be70048bfd290c62148b2aceda631 100644 --- a/torch-ext/optimizer/muon.py +++ b/torch-ext/optimizer/muon.py @@ -1,18 +1,24 @@ import logging import math import types +from collections import defaultdict from dataclasses import dataclass -from typing import List, Optional, Union, cast +from typing import Any, cast import torch import torch.distributed as dist -from torch.distributed._tensor import DTensor, Replicate, Shard +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate +from torch.distributed.tensor.placement_types import Placement +from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor from .matmul_transpose_triton import matmul_transpose_assign logger = logging.getLogger(__name__) COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 # This code snippet is a modified version adapted from the following GitHub repositories: @@ -62,23 +68,39 @@ def _zeropower_via_newtonschulz5(G, steps): @dataclass class _muon_state: # TODO: use Optional - worker_rank: int | None = None + worker_rank: int + process_group: ProcessGroup + shard_mesh: DeviceMesh + shard_placements: tuple[Placement, ...] 
+ name: str + qk_clip_state: torch.Tensor | None = None gathered_grad: torch.Tensor | None = None scattered_u: DTensor | None = None computed_u: torch.Tensor | None = None gather_event: torch.cuda.Event | None = None compute_event: torch.cuda.Event | None = None scatter_event: torch.cuda.Event | None = None - process_group = None - qk_clip_state = None -def split_elems_for_src(param, src_rank, num_ranks) -> int: - rows = param.shape[0] - cols = int(param.numel() // rows) - base, rem = divmod(rows, num_ranks) - my_rows = base + (1 if src_rank < rem else 0) - return my_rows * cols +def numel_for_rank( + param: DTensor, + local_rank: int, + state: _muon_state, +) -> int: + slices = get_slices_of_dtensor( + param, + local_rank, + state.shard_mesh, + state.shard_placements, + ) + + numel = 1 + for s, dim in zip(slices, param.shape): + start, stop, step = s.indices(dim) + length = max(0, (stop - start + (step - 1)) // step) + numel *= length + + return numel @torch.no_grad() @@ -91,8 +113,7 @@ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): for p in params: state = param_to_state[id(p)] if rank == state.worker_rank: - num_ranks = dist.get_world_size(group=state.process_group) - state.gathered_grad = torch.empty(p.grad.numel(), + state.gathered_grad = torch.empty(p.shape, dtype=COMM_DTYPE, device="cuda") else: @@ -121,11 +142,11 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, state = param_to_state[id(p)] dst = state.worker_rank assert dst < num_ranks - shard_elems = split_elems_for_src(p, rank, num_ranks) + shard_elems = numel_for_rank(p, rank, state) g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous().view(-1) + g = g.to_local().to(COMM_DTYPE).contiguous() assert g.numel() == shard_elems - per_dst[dst].append(g) + per_dst[dst].append(g.view(-1)) send_counts[dst] += shard_elems assert any( @@ -148,13 +169,18 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, for p in owned_params: state = param_to_state[id(p)] assert state.worker_rank == rank - total += split_elems_for_src(p, src, num_ranks) + total += numel_for_rank(p, src, state) recv_counts[src] = total recv_total = sum(recv_counts) recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") #All2All + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") dist.all_to_all_single( recv_buf, send_buf, @@ -179,7 +205,6 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, comm_stream.wait_event(alloc_event) off = 0 - write_offsets = {id(p): 0 for p in owned_params} for src in range(num_ranks): if recv_counts[src] == 0: continue @@ -189,22 +214,28 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, for p in owned_params: state = param_to_state[id(p)] assert state.worker_rank == rank - n = split_elems_for_src(p, src, num_ranks) + + # get the slice of the full dtensor corresponding to rank src. 
+ slices = get_slices_of_dtensor(state.gathered_grad, src, + state.shard_mesh, + state.shard_placements) + + dst = state.gathered_grad[slices] + assert dst._base is state.gathered_grad + + n = dst.numel() assert n > 0 sg = recv_buf.narrow(0, off + inner_off, n) - woff = write_offsets[id(p)] - dst = state.gathered_grad.narrow(0, woff, n) + sg = sg.reshape_as(dst) dst.copy_(sg) - write_offsets[id(p)] += n inner_off += n off += block for p in params: state = param_to_state[id(p)] if state.worker_rank == rank: - state.gathered_grad = state.gathered_grad.view_as(p) state.gather_event = torch.cuda.Event() state.gather_event.record(comm_stream) else: @@ -277,14 +308,19 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): assert state.computed_u is not None - u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1) + u_full = state.computed_u.to(COMM_DTYPE).contiguous() offset = 0 for dst in range(num_ranks): - n = split_elems_for_src(p, dst, num_ranks) + # get the slice of the full tensor corresponding to rank dst. + slices = get_slices_of_dtensor(u_full, dst, + state.shard_mesh, + state.shard_placements) + su = u_full[slices].flatten() + + n = su.numel() assert n > 0 - su = u_full.narrow(0, offset, n) per_dst[dst].append(su) send_counts[dst] += n offset += n @@ -313,7 +349,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): state = param_to_state[id(p)] if state.worker_rank != src: continue - total += split_elems_for_src(p, rank, num_ranks) + total += numel_for_rank(p, rank, state) recv_counts[src] = total recv_total = sum(recv_counts) @@ -357,7 +393,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): state = param_to_state[id(p)] if state.worker_rank != src: continue - n = split_elems_for_src(p, rank, num_ranks) + n = numel_for_rank(p, rank, state) assert n > 0 flat_local = recv_buf.narrow(0, off + inner_off, @@ -398,11 +434,23 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, state.scattered_u = None u_dtensor = None - scales_full = Muon._compute_scales(p, state.qk_clip_state) + scales_full = Muon._compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None if scales_full is not None: - num_ranks = dist.get_world_size(group=state.process_group) - local_rank = dist.get_rank(group=state.process_group) - scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank] + # Have to slice scales_full among dim 0 + weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, + state.shard_placements) + ratio = p.shape[0] // scales_full.shape[0] + scales_slice = slice( + None if weight_slices[0].start is None else + weight_slices[0].start // ratio, + None if weight_slices[0].stop is None else + weight_slices[0].stop // ratio, + None, + ) + + scales_local = scales_full[scales_slice] scales_local = DTensor.from_local( scales_local, placements=p.placements, @@ -478,11 +526,11 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: List[int] # which heads to consider for clipping + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping head_dim: int # from config threshold: float # from config - logit: Optional[torch.Tensor] + logit: torch.Tensor | None class Muon(torch.optim.Optimizer): @@ -525,11 +573,16 @@ class 
Muon(torch.optim.Optimizer): "head_dim": 128, "threshold": 100 } - overlap_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher overlap_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. + warmup_step : How many all2all gather, compute operations are launched in advance + before the corresponding all2all scatter steps begin. + A higher warmup_step increases memory usage but can improve + performance by overlapping communication. + Parallel muon only. + chunk_size : Batch size of parameters to process in each + all2all gather/compute/scatter step. + Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. + use_distributed_muon: Use distributed muon by Liu et al. (2024). + For testing purpose only. """ def __init__(self, @@ -549,7 +602,9 @@ class Muon(torch.optim.Optimizer): "head_dim": 128, "threshold": 100 }, - overlap_step=5): + warmup_step=5, + chunk_size=-1, + use_distributed_muon=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -579,7 +634,9 @@ class Muon(torch.optim.Optimizer): self.compute_stream = torch.cuda.Stream() self.debug = debug self.clip_config = clip_config - self.overlap_step = overlap_step + self.warmup_step = warmup_step + self.chunk_size = chunk_size + self.use_distributed_muon = use_distributed_muon def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -597,6 +654,12 @@ class Muon(torch.optim.Optimizer): adjusted_lr = lr * adjusted_ratio return adjusted_lr + def set_rank_once(self, rank): + if self.rank is None: + self.rank = rank + else: + assert self.rank == rank + def get_shard_mesh(self, p): """ Get the shard mesh for a parameter p on the given rank. @@ -604,26 +667,13 @@ class Muon(torch.optim.Optimizer): assert isinstance( p, DTensor), "Parallel Muon only supports DTensor parameters." 
- if p.placements == (Shard(dim=0), ): - # Case for FSDP - process_group = p.device_mesh.get_group(mesh_dim=0) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0) - elif p.placements == (Replicate(), Shard(dim=0)): - # Case for HSDP - process_group = p.device_mesh.get_group(mesh_dim=1) - if self.rank is None: - self.rank = dist.get_rank(group=process_group) - else: - assert self.rank == dist.get_rank(group=process_group) - for i, shard_mesh in enumerate(p.device_mesh.mesh): - if self.rank in shard_mesh: - return shard_mesh, p.device_mesh.get_group(mesh_dim=1) - else: - raise ValueError(f"Unsupported placements ({p.placements}).") + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + p.placements, p.device_mesh) + + # set rank with the local rank in the shard process group + self.set_rank_once(dist.get_rank(group=shard_pg)) + + return shard_mesh, shard_pg, shard_placements def init_state_and_assign_params(self, names, params, group, qk_logits): param_to_state = {} @@ -655,23 +705,32 @@ class Muon(torch.optim.Optimizer): ordered_params = list(params_sorted) round_robin = 0 - mesh = None - shard_mesh = None - process_group = None + mesh = ordered_params[0].device_mesh + placements = ordered_params[0].placements + + shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( + ordered_params[0]) + shard_mesh_flattened = shard_mesh.mesh.flatten() + num_ranks = dist.get_world_size(group=shard_pg) + for n, p in zip(ordered_names, ordered_params): - if mesh is None: - mesh = p.device_mesh - shard_mesh, process_group = self.get_shard_mesh(p) - elif mesh != p.device_mesh: + if mesh != p.device_mesh: raise ValueError("All parameters must be on the same mesh.") - num_ranks = dist.get_world_size(group=process_group) - param_to_state[id(p)] = _muon_state() - param_to_state[id( - p)].worker_rank = shard_mesh[round_robin].item() % num_ranks - param_to_state[id(p)].process_group = process_group + if placements != p.placements: + raise ValueError("All parameters must have same placements.") + + worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks + round_robin = (round_robin + 1) % len(shard_mesh_flattened) qk_clip_state = self.get_qk_clip_info(n, qk_logits) - param_to_state[id(p)].qk_clip_state = qk_clip_state - round_robin = (round_robin + 1) % len(shard_mesh) + + param_to_state[id(p)] = _muon_state( + worker_rank=worker_rank, + process_group=shard_pg, + shard_mesh=shard_mesh, + shard_placements=shard_placements, + name=n, + qk_clip_state=qk_clip_state, + ) return param_to_state, ordered_params @@ -705,10 +764,73 @@ class Muon(torch.optim.Optimizer): qk_clip_state = self.get_qk_clip_info(n, qk_logits) - scales_full = self._compute_scales(p, qk_clip_state) + scales_full = self._compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + def distributed_muon( + self, + names: list[str], + params: list[torch.nn.Parameter], + group: dict[str, Any], + lr: float, + weight_decay: float, + momentum: float, + qk_logits: list[torch.Tensor | DTensor] | None, + ): + """ Implementation of Distributed Muon by Liu et al. 
""" + if qk_logits is not None: + raise NotImplementedError("QK clipping is not supported yet") + + if isinstance(params[0], DTensor): + shard_mesh, _, shard_placements = construct_shard_mesh( + placements=params[0].placements, + mesh=params[0].device_mesh, + ) + + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + # Gather G + if isinstance(p.data, DTensor): + g = g.full_tensor() + u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) + + if isinstance(p.data, DTensor): + slices = get_slices_of_dtensor( + target=p, + local_rank=dist.get_rank(), + shard_mesh=shard_mesh, + shard_placements=shard_placements, + ) + u_shard = u[slices] + u = DTensor.from_local( + u_shard, + device_mesh=p.device_mesh, + placements=p.placements, + ) + + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + def _update_g(self, p, g, group, momentum): # calc update state = self.state[p] @@ -727,6 +849,9 @@ class Muon(torch.optim.Optimizer): p.data.add_(u, alpha=-adjusted_lr) def get_qk_clip_info(self, n, qk_logits): + if self.clip_config is None: + return None + head_dim = self.clip_config.get('head_dim') threshold = self.clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) @@ -737,6 +862,11 @@ class Muon(torch.optim.Optimizer): indices_key = 'q_indices' if 'q' in kind else 'k_indices' indices = self.clip_config.get(indices_key, []) or [] + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + return QKClipInfo( kind=kind, indices=indices, @@ -835,22 +965,28 @@ class Muon(torch.optim.Optimizer): _update_param(p, state, lr, adjusted_lr, weight_decay, self.rank, self.compute_stream) - chunk_size = dist.get_world_size(param_to_state[id( - params[0])].process_group) + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(param_to_state[id( + params[0])].process_group) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError("chunk_size must be -1 or a positive integer.") # Wait grad update self.comm_stream.wait_stream(torch.cuda.current_stream()) - overlap_step = self.overlap_step - for i in range(0, overlap_step): + warmup_step = self.warmup_step + for i in range(0, warmup_step): enqueue_all2all_gather(i * chunk_size, chunk_size) enqueue_computes(i * chunk_size, chunk_size) for i in range(0, len(params) + chunk_size - 1, chunk_size): enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size) + enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) enqueue_update_param(i, chunk_size) - enqueue_computes(i + overlap_step * chunk_size, chunk_size) + enqueue_computes(i + warmup_step * chunk_size, chunk_size) # Wait the last update_param to finish torch.cuda.current_stream().wait_stream(self.compute_stream) @@ -866,7 +1002,7 @@ class Muon(torch.optim.Optimizer): amsgrad: bool, beta1: float, beta2: float, - lr: Union[float, torch.Tensor], + lr: float | torch.Tensor, weight_decay: float, eps: float, maximize: bool, @@ -876,10 +1012,10 
@@ class Muon(torch.optim.Optimizer): # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer # treating it as a scalar. - lr_dict: Optional[DeviceDict] = ({ + lr_dict: DeviceDict | None = ({ lr.device: lr } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) + None) grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( [ params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, @@ -926,6 +1062,159 @@ class Muon(torch.optim.Optimizer): maximize=maximize, ) + def _step_muon(self, group, qk_logits=None): + params = group["params"] + lr = group["lr"] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + names = group["names"] + + param_dtensors = [] + param_tensors = [] + name_dtensors = [] + name_tensors = [] + + if self.use_distributed_muon: + self.distributed_muon(names=names, + params=params, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits) + return + + for n, p in zip(names, params): + if p is None or p.grad is None: + continue + if isinstance(p.data, DTensor): + if all( + isinstance(placement, Replicate) + for placement in p.placements): + param_tensors.append(p) + name_tensors.append(n) + else: + param_dtensors.append(p) + name_dtensors.append(n) + elif isinstance(p.data, torch.Tensor): + param_tensors.append(p) + name_tensors.append(n) + else: + raise TypeError(f"Unsupported parameter type: {type(p.data)}") + + logger.debug( + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors" + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + # To support different placements, we group parameters by placements + # and run parallel Muon on each group. 
+ + placement_to_params = defaultdict(lambda: ([], [])) + # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] + + assert len(name_dtensors) == len(param_dtensors) + for n, p in zip(name_dtensors, param_dtensors): + placement_to_params[tuple([p.placements, + p.device_mesh])][0].append(n) + placement_to_params[tuple([p.placements, + p.device_mesh])][1].append(p) + + for _, (names, params) in placement_to_params.items(): + self.parallel( + names, + params, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_tensors) > 0: + self.base( + name_tensors, + param_tensors, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + def _step_adamw_params(self, params, group): + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + self._fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def _step_adamw(self, group): + params = group["params"] + + # group params with it's type and placement + placement_to_params: dict[tuple[Placement | type, + DeviceMesh | None]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for params in placement_to_params.values(): + self._step_adamw_params(params, group) + def step(self, closure=None, qk_logits=None): """Perform a single optimization step. 
@@ -943,127 +1232,9 @@ class Muon(torch.optim.Optimizer): loss = closure() for group in self.param_groups: - params = group["params"] - if group["use_muon"]: - ############################ - # Muon # - ############################ - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - param_tensors = [] - name_dtensors = [] - name_tensors = [] - - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError( - f"Unsupported parameter type: {type(p.data)}") - - if self.debug: - print( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors", - flush=True, - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - self.parallel( - name_dtensors, - param_dtensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - + self._step_muon(group, qk_logits=qk_logits) else: - ############################ - # AdamW backup # - ############################ - - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) + self._step_adamw(group) return loss
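For reference, here is a hedged end-to-end sketch of driving the refactored `step()`: parameter groups with `use_muon` go through `_step_muon` (the parallel path for sharded DTensors, the base path otherwise), and everything else falls back to `_step_adamw`. The function name, hyperparameter values, and clip indices below are illustrative only; arguments not shown keep their defaults, and gradients are assumed to be already populated on `model`.

```python
import torch
from optimizer.muon import Muon, get_default_muon_param_groups


def muon_step_example(model: torch.nn.Module,
                      num_heads: int,
                      num_layers: int,
                      head_dim: int) -> None:
    """Sketch: one optimizer step with the new warmup_step/chunk_size knobs."""
    params = get_default_muon_param_groups(model)
    optim = Muon(
        params=params,
        warmup_step=5,    # all2all gathers/computes launched ahead of scatters
        chunk_size=-1,    # -1 -> shard ranks * DEFAULT_CHUNK_SIZE_RATIO
        clip_config={
            "q_indices": list(range(num_heads)),
            "k_indices": list(range(num_heads)),
            "head_dim": head_dim,
            "threshold": 100,
        },
    )
    # Per-layer QK logits: one tensor of per-head values per layer index.
    # Under TP these may arrive as DTensors; get_qk_clip_info makes them full.
    qk_logits = {i: torch.rand(num_heads) for i in range(num_layers)}
    optim.step(qk_logits=qk_logits)
```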