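"""Parity test for the Muon optimizer: one optimizer step is run on
Motif-2.6B with FSDP2-sharded parameters (parallel Muon) and again on a
single device (sequential Muon), using identical seeded gradients, and the
resulting parameters are compared for an exact match.
"""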
import logging

import torch
import torch.distributed as dist
from optimizer.muon import Muon, get_default_muon_param_groups
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Replicate
from transformers import AutoModelForCausalLM

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

def load_model(fsdp: bool) -> torch.nn.Module:
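    """Load Motif-2.6B, attach seeded random gradients to every parameter,
    and optionally shard the model with FSDP2 (``fully_shard``).

    The gradients are generated before sharding, so the FSDP and
    single-device runs receive identical inputs.
    """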
    model_name = "Motif-Technologies/Motif-2.6B"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
    ).bfloat16().cuda()

    # One random gradient per parameter; torch.randn_like already matches
    # the parameter's device and dtype.
    random_grads = [torch.randn_like(param) for param in model.parameters()]
    if fsdp:
        # FSDP2: shard each transformer block first, then the root module.
        for layer in model.model.layers:
            fully_shard(layer)
            layer.reshard()
        fully_shard(model)
        model.reshard()

    for i, param in enumerate(model.parameters()):
        if isinstance(param.data, DTensor):
            # Sharded case: wrap the full gradient as a replicated DTensor,
            # then redistribute it to the parameter's own placements so each
            # rank keeps only its shard.
            unsharded_grad = DTensor.from_local(
                random_grads[i],
                device_mesh=param.data.device_mesh,
                placements=[Replicate()] * param.data.device_mesh.ndim,
            )
            param.grad = unsharded_grad.redistribute(
                device_mesh=param.data.device_mesh,
                placements=param.data.placements,
            )
        else:
            param.grad = random_grads[i]
    return model

def run_muon(fsdp: bool, qk_clip: bool, seed: int) -> torch.nn.Module:
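    """Run a single Muon step with seeded gradients and return the model.

    ``fsdp=True`` exercises the parallel (sharded) path, ``fsdp=False`` the
    sequential single-device path. When ``qk_clip`` is set, random per-layer
    attention max-logits are passed to ``optim.step`` to exercise QK-clip.
    """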
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    model = load_model(fsdp=fsdp)
    params = get_default_muon_param_groups(model)

    qk_logits = None
    if qk_clip:
        # One max-logit value per attention head, for every layer.
        qk_logits = {
            i: torch.rand(model.config.num_attention_heads)
            for i in range(model.config.num_hidden_layers)
        }

    optim = Muon(
        params=params,
        clip_config={
            "q_indices": list(range(model.config.num_attention_heads)),
            "k_indices": list(range(model.config.num_attention_heads)),
            "head_dim":
            model.config.hidden_size // model.config.num_attention_heads,
            "threshold": 0.5,
        },
    )
    optim.step(qk_logits=qk_logits)
    return model

def run_case(qk_clip: bool, seed: int = 0):
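    """Compare one parallel-vs-sequential Muon step for a given qk_clip mode."""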
    parallel_muon_result = run_muon(fsdp=True, qk_clip=qk_clip, seed=seed)
    sequential_muon_result = run_muon(fsdp=False, qk_clip=qk_clip, seed=seed)
    label = f"qk_clip={'ON' if qk_clip else 'OFF'}"
    success = compare_results(parallel_muon_result,
                              sequential_muon_result,
                              label=label)
    return success, label

def test_muon():
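    """Run both cases (qk_clip off and on) and log the outcome of each."""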
    base_result = run_case(qk_clip=False, seed=0)
    clip_result = run_case(qk_clip=True, seed=0)
    for success, label in [base_result, clip_result]:
        if success:
            logger.info(f"[{label}] Models match")
        else:
            logger.error(f"[{label}] Models differ")

def compare_results(parallel_muon_result: torch.nn.Module,
                    sequential_muon_result: torch.nn.Module,
                    label: str) -> bool:
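    """Return True iff every parameter of the two models matches exactly."""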
    success = True
    for (name_p, p), (_, s) in zip(
            parallel_muon_result.named_parameters(),
            sequential_muon_result.named_parameters()):
        # Gather the sharded DTensor parameter before comparing it against
        # the single-device reference.
        p = p.data.full_tensor()
        s = s.data
        # Parallel Muon should exactly match sequential Muon.
        max_diff = torch.abs(p - s).max()
        if max_diff > 0:
            logger.info(f"[{label}] Models differ at parameter {name_p} "
                        f"(max abs diff: {max_diff.item()})")
            success = False
    return success

if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    test_muon()
    dist.destroy_process_group()
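# Launch with one process per GPU via torchrun, e.g. (assuming this file is
# saved as test_muon.py):
#   torchrun --nproc-per-node=8 test_muon.py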