"""Correctness test for the Muon optimizer under FSDP.

Runs a single optimizer step on Motif-2.6B twice -- once with FSDP sharding
and once without -- using identical seeded gradients, and checks that the
resulting parameters match exactly, both with and without QK-clip.
"""

import logging

import torch
import torch.distributed as dist
from optimizer.muon import Muon, get_default_muon_param_groups
from torch.distributed.fsdp import FSDPModule, fully_shard
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Replicate
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def load_model(fsdp: bool) -> torch.nn.Module:
    """Load Motif-2.6B and attach identical random gradients to every parameter.

    When ``fsdp`` is True, the model is sharded with ``fully_shard`` and the
    pre-generated full gradients are redistributed to match each parameter's
    DTensor placements, so both code paths see the same gradients.
    """
    model_name = "Motif-Technologies/Motif-2.6B"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
    ).bfloat16().cuda()

    # Generate gradients on the full (unsharded) parameters first so the
    # FSDP and non-FSDP runs start from identical values.
    random_grads = []
    for param in model.parameters():
        random_grad = torch.randn_like(param,
                                       device=param.device,
                                       dtype=param.dtype)
        random_grads.append(random_grad)

    if fsdp:
        for layer in model.model.layers:
            fully_shard(layer)
            layer.reshard()
        fully_shard(model)
        model.reshard()

    for i, param in enumerate(model.parameters()):
        if isinstance(param.data, DTensor):
            # Wrap the full gradient as a replicated DTensor, then
            # redistribute it to the parameter's sharded placements.
            unsharded_grad = DTensor.from_local(
                random_grads[i],
                device_mesh=param.data.device_mesh,
                placements=[Replicate()] * param.data.device_mesh.ndim,
            )
            sharded_grad = unsharded_grad.redistribute(
                device_mesh=param.data.device_mesh,
                placements=param.data.placements)
            param.grad = sharded_grad
        else:
            param.grad = random_grads[i]

    return model


def run_muon(fsdp: bool, qk_clip: bool, seed: int) -> torch.nn.Module:
    """Run one Muon optimizer step and return the updated model."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    model = load_model(fsdp=fsdp)
    params = get_default_muon_param_groups(model)

    # Optionally fabricate per-layer QK logits so the QK-clip path is exercised.
    qk_logits = None
    if qk_clip:
        qk_logits = {
            i: torch.rand(model.config.num_attention_heads)
            for i in range(model.config.num_hidden_layers)
        }

    optim = Muon(
        params=params,
        clip_config={
            "q_indices": list(range(model.config.num_attention_heads)),
            "k_indices": list(range(model.config.num_attention_heads)),
            "head_dim":
            model.config.hidden_size // model.config.num_attention_heads,
            "threshold": 0.5,
        })
    optim.step(qk_logits=qk_logits)

    return model


def run_case(qk_clip: bool, seed: int = 0):
    """Compare a parallel (FSDP) and a sequential Muon step for one setting."""
    parallel_muon_result = run_muon(fsdp=True, qk_clip=qk_clip, seed=seed)
    sequential_muon_result = run_muon(fsdp=False, qk_clip=qk_clip, seed=seed)
    label = f"qk_clip={'ON' if qk_clip else 'OFF'}"
    success = compare_results(parallel_muon_result,
                              sequential_muon_result,
                              label=label)

    return success, label


def test_muon():
    base_result = run_case(qk_clip=False, seed=0)
    clip_result = run_case(qk_clip=True, seed=0)

    for success, label in [base_result, clip_result]:
        if success:
            logger.info(f"[{label}] Models match")
        else:
            logger.error(f"[{label}] Models differ")

    assert all(success for success, _ in (base_result, clip_result)), \
        "parallel and sequential Muon results diverged"


def compare_results(parallel_muon_result: torch.nn.Module,
                    sequential_muon_result: torch.nn.Module,
                    label: str) -> bool:
    success = True
    for (name_p, p), (name_s,
                      s) in zip(parallel_muon_result.named_parameters(),
                                sequential_muon_result.named_parameters()):
        # The FSDP model holds DTensor parameters; gather the full tensor
        # (a collective across ranks) before comparing with the unsharded
        # reference parameter.
        p = p.data.full_tensor()
        s = s.data

        diff = torch.abs(p - s)
        if diff.max() > 0:
            max_diff_index = torch.argmax(diff)
            logger.info(f"[{label}] Models differ at parameter {name_p} "
                        f"(max diff {diff.max().item()} at flat index "
                        f"{max_diff_index.item()})")
            success = False

    return success


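# Launch as a distributed job, e.g. (hypothetical file name and GPU count --
# adjust to your setup):
#   torchrun --nproc_per_node=8 test_muon.py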
if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    test_muon()
    dist.destroy_process_group()