import logging

import torch
import torch.distributed as dist
from optimizer.muon import Muon, get_default_muon_param_groups
from torch.distributed.fsdp import FSDPModule, fully_shard
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Replicate
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def load_model(fsdp: bool) -> torch.nn.Module:
    model_name = "Motif-Technologies/Motif-2.6B"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
    ).bfloat16().cuda()

    # Generate one random gradient per parameter before sharding so the FSDP
    # and non-FSDP runs see identical gradients (the RNG is seeded by the
    # caller).
    random_grads = []
    for param in model.parameters():
        random_grad = torch.randn_like(param,
                                       device=param.device,
                                       dtype=param.dtype)
        random_grads.append(random_grad)

    if fsdp:
        for layer in model.model.layers:
            fully_shard(layer)
            layer.reshard()
        fully_shard(model)
        model.reshard()

    for i, param in enumerate(model.parameters()):
        if isinstance(param.data, DTensor):
            # Wrap the full gradient as a replicated DTensor, then shard it to
            # match the parameter's placements.
            unsharded_grad = DTensor.from_local(
                random_grads[i],
                device_mesh=param.data.device_mesh,
                placements=[Replicate()] * param.data.device_mesh.ndim,
            )
            sharded_grad = unsharded_grad.redistribute(
                device_mesh=param.data.device_mesh,
                placements=param.data.placements)
            param.grad = sharded_grad
        else:
            param.grad = random_grads[i]

    return model


def run_muon(fsdp: bool, qk_clip: bool, seed: int) -> torch.nn.Module:
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    model = load_model(fsdp=fsdp)
    params = get_default_muon_param_groups(model)

    # Optional per-layer QK logits used to exercise the qk-clip path.
    qk_logits = None
    if qk_clip:
        qk_logits = {
            i: torch.rand(model.config.num_attention_heads)
            for i in range(model.config.num_hidden_layers)
        }

    optim = Muon(
        params=params,
        clip_config={
            "q_indices": list(range(model.config.num_attention_heads)),
            "k_indices": list(range(model.config.num_attention_heads)),
            "head_dim":
            model.config.hidden_size // model.config.num_attention_heads,
            "threshold": 0.5,
        })

    optim.step(qk_logits=qk_logits)

    return model


def run_case(qk_clip: bool, seed: int = 0):
    parallel_muon_result = run_muon(fsdp=True, qk_clip=qk_clip, seed=seed)
    sequential_muon_result = run_muon(fsdp=False, qk_clip=qk_clip, seed=seed)
    label = f"qk_clip={'ON' if qk_clip else 'OFF'}"
    success = compare_results(parallel_muon_result,
                              sequential_muon_result,
                              label=label)
    return success, label


def test_muon():
    base_result = run_case(qk_clip=False, seed=0)
    clip_result = run_case(qk_clip=True, seed=0)

    for success, label in [base_result, clip_result]:
        if success:
            logger.info(f"[{label}] Models match")
        else:
            logger.error(f"[{label}] Models do not match")


def compare_results(parallel_muon_result: torch.nn.Module,
                    sequential_muon_result: torch.nn.Module,
                    label: str) -> bool:
    success = True
    for (name_p, p), (name_s, s) in zip(
            parallel_muon_result.named_parameters(),
            sequential_muon_result.named_parameters()):
        p = p.data.full_tensor()
        s = s.data

        # Parallel Muon should exactly match Sequential Muon.
        diff = torch.abs(p - s)
        if diff.max() > 0:
            logger.info(f"[{label}] Models differ at parameter {name_p} "
                        f"(max abs diff: {diff.max().item()})")
            success = False
    return success


if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    test_muon()
    dist.destroy_process_group()
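
# Usage sketch (assumptions: this file is saved as e.g. test_muon.py and the
# `optimizer.muon` package is importable). Since the script initializes an NCCL
# process group and shards the model with fully_shard, it is meant to be
# launched with torchrun on a multi-GPU node, for example:
#
#   torchrun --nproc_per_node=2 test_muon.py
#
# Every rank attaches the same seeded gradients, so the FSDP ("parallel") and
# single-device ("sequential") Muon updates can be compared element-wise.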