Kernels
dongseokmotif and github-actions[bot] committed
Commit d65066c · unverified · 1 parent: 8997e30

feat(muon_clip) : add muon clip (#6)

* feat(muon_clip) : add muon clip

* fix(muon_clip): delete comment

* fix(muon_clip): delete comment

* fix(muon_clip): considering when nkvheadgroup>1

* docs(muon_clip): refine __init__ docstring and add clip_info argument description

* refactor(muon_clip): refactor clip info using dataclass

* fix(muon_clip): change min -> new_scaling compare

* test(muon): add qk_clip=False case to model comparison

* test(muon): show results

* fix(muon_clip): change default is muon func

* Add built binary [ci skip]

---------

Co-authored-by: dongseokmotif <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
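
Taken together, these commits add an optional QK-clip ("muon clip") pass to the Muon optimizer: parameter groups now carry parameter names, the constructor accepts a clip_config, and step() takes per-layer maximum QK logits so that attention heads whose logits exceed the threshold have their query/key projection rows scaled down. A rough usage sketch based on the new signatures in the diff below (the model, its head counts, the data loader, and the collect_max_qk_logits helper are assumptions for illustration, not part of this commit):

from optimizer.muon import Muon, get_default_muon_param_groups  # import path mirrors this build tree (assumption)

model = build_model()  # hypothetical transformer whose attention weights are named *.q_proj / *.k_proj

# Two groups: Muon params (with their names, needed for QK clip) and everything else (internal AdamW).
param_groups = get_default_muon_param_groups(model)

optimizer = Muon(
    param_groups,
    lr=1e-3,
    clip_config={
        "q_indices": list(range(32)),  # query heads to monitor (example values)
        "k_indices": list(range(8)),   # key heads to monitor (example values)
        "head_dim": 128,
        "threshold": 100,
    },
)

for batch in data_loader:  # hypothetical training loop
    loss = model(batch)
    loss.backward()
    # qk_logits maps layer index -> 1D tensor of per-head max QK logits,
    # i.e. the maximum over tokens of (1 / sqrt(head_dim)) * (Q @ K^T);
    # collecting it is left to the training code and is not part of this commit.
    qk_logits = collect_max_qk_logits(model)  # hypothetical helper
    optimizer.step(qk_logits=qk_logits)
    optimizer.zero_grad()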

Files changed (29)
  1. build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
  2. build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_4043ece_dirty.abi3.so → _optimizer_9c21645_dirty.abi3.so} +1 -1
  3. build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +198 -39
  4. build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  5. build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_4043ece_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_9c21645_dirty.abi3.so} +1 -1
  6. build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +198 -39
  7. build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
  8. build/{torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_4043ece_dirty.abi3.so → torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_9c21645_dirty.abi3.so} +1 -1
  9. build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +198 -39
  10. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
  11. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_4043ece_dirty.abi3.so → _optimizer_9c21645_dirty.abi3.so} +1 -1
  12. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +198 -39
  13. build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  14. build/{torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_4043ece_dirty.abi3.so → torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_9c21645_dirty.abi3.so} +1 -1
  15. build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py +198 -39
  16. build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
  17. build/{torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_4043ece_dirty.abi3.so → torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_9c21645_dirty.abi3.so} +1 -1
  18. build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py +198 -39
  19. build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py +3 -3
  20. build/{torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_4043ece_dirty.abi3.so → torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_9c21645_dirty.abi3.so} +1 -1
  21. build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py +198 -39
  22. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
  23. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_4043ece_dirty.abi3.so → _optimizer_9c21645_dirty.abi3.so} +1 -1
  24. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py +198 -39
  25. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py +3 -3
  26. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/{_optimizer_4043ece_dirty.abi3.so → _optimizer_9c21645_dirty.abi3.so} +1 -1
  27. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py +198 -39
  28. test/test_muon/test.py +47 -14
  29. torch-ext/optimizer/muon.py +198 -39
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch

-from . import _optimizer_4043ece_dirty
-ops = torch.ops._optimizer_4043ece_dirty
+from . import _optimizer_9c21645_dirty
+ops = torch.ops._optimizer_9c21645_dirty


 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_4043ece_dirty::{op_name}"
+    return f"_optimizer_9c21645_dirty::{op_name}"
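
The only change in each _ops.py is the build hash embedded in the op namespace. For illustration, with the new namespace the helper behaves as below ("gemm" is a placeholder op name, not one defined by this kernel):

from optimizer._ops import add_op_namespace_prefix  # import path mirrors this build tree (assumption)

print(add_op_namespace_prefix("gemm"))  # -> "_optimizer_9c21645_dirty::gemm"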
build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_4043ece_dirty.abi3.so → _optimizer_9c21645_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:cd521b375aefeabe5cd5b38215d71b393e3902ed347426c64307e37c01f79a7c
+oid sha256:bf8b97161714dff91953d26ae0bf59ebc9f3653ce57a3998723cc08aa97b71e6
 size 1787368
build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py CHANGED
@@ -2,7 +2,7 @@ import logging
 import math
 import types
 from dataclasses import dataclass
-from typing import Optional, Union, cast
+from typing import List, Optional, Union, cast

 import torch
 import torch.distributed as dist
@@ -66,6 +66,7 @@ class _muon_state:
     compute_event: torch.cuda.Event | None = None
     scatter_event: torch.cuda.Event | None = None
     process_group = None
+    qk_clip_state = None


 @torch.no_grad()
@@ -193,32 +194,93 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
     state.scattered_u = None
     u_dtensor = None

+    scales_full = Muon._compute_scales(p, state.qk_clip_state)
+    if scales_full is not None:
+        num_ranks = dist.get_world_size(group=state.process_group)
+        local_rank = dist.get_rank(group=state.process_group)
+        scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank]
+        scales_local = DTensor.from_local(
+            scales_local,
+            placements=p.placements,
+            device_mesh=p.device_mesh,
+        )
+        Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim)
+

 def default_is_muon(name, x):
-    return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
+    skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
+    return x.ndim >= 2 and not any(key in name for key in skip_keys)


 def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
+    muon_params, muon_names = [], []
+    non_muon_params = []
+
+    for n, p in model.named_parameters():
+        if not p.requires_grad:
+            continue
+        if is_muon_func(n, p):
+            muon_params.append(p)
+            muon_names.append(n)
+        else:
+            non_muon_params.append(p)
+
     return [
         {
-            "params": [
-                p for n, p in model.named_parameters()
-                if (is_muon_func(n, p) and p.requires_grad)
-            ],
-            "use_muon":
-            True
+            "params": muon_params,
+            "names": muon_names,
+            "use_muon": True,
         },
         {
-            "params": [
-                p for n, p in model.named_parameters()
-                if (not is_muon_func(n, p) and p.requires_grad)
-            ],
-            "use_muon":
-            False
+            "params": non_muon_params,
+            "use_muon": False,
         },
     ]


+def parse_qk_layer(name: str) -> tuple[str | None, int]:
+    """
+    Parse a parameter name to check if it is a query/key projection layer
+    ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
+
+    Returns:
+        (kind, layer_idx) or (None, -1) if not matched.
+
+    Example:
+        'model.3.attn.wq.weight' -> ('wq', 3)
+        'model.5.attn.wk.weight' -> ('wk', 5)
+        'model.2.attn.q_proj.weight' -> ('q_proj', 2)
+        'model.7.attn.k_proj.weight' -> ('k_proj', 7)
+        'model.4.attn.v_proj.weight' -> (None, -1)
+    """
+    parts = name.split('.')
+    if len(parts) < 3:
+        return None, -1
+
+    kind = parts[-2]
+
+    layer_idx = -1
+    for part in reversed(parts):
+        if part.isdigit():
+            layer_idx = int(part)
+            break
+
+    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
+        return kind, layer_idx
+
+    return None, -1
+
+
+@dataclass
+class QKClipInfo:
+    """Per-parameter dynamic info computed from config + runtime logits."""
+    kind: Optional[str]  # 'wq'/'q_proj' or 'wk'/'k_proj' or None
+    indices: List[int]  # which heads to consider for clipping
+    head_dim: int  # from config
+    threshold: float  # from config
+    logit: Optional[torch.Tensor]
+
+
 class Muon(torch.optim.Optimizer):
     """
     Muon - MomentUm Orthogonalized by Newton-schulz
@@ -246,21 +308,38 @@ class Muon(torch.optim.Optimizer):
         adamw_eps: The epsilon for the internal AdamW.
         none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
         debug: Whether to print debug information.
+        clip_info : Configuration for QK clipping. Expected keys:
+            - "q_indices" (list[int]): Indices of query heads to consider.
+            - "k_indices" (list[int]): Indices of key heads to consider.
+            - "head_dim" (int): Dimensionality of each attention head.
+            - "threshold" (float): Threshold value; heads whose QK logits exceed
+              this value will be scaled down.
+            Default is:
+                {
+                    "q_indices": [],
+                    "k_indices": [],
+                    "head_dim": 128,
+                    "threshold": 100
+                }
     """

-    def __init__(
-        self,
-        params,
-        lr=1e-3,
-        momentum=0.95,
-        nesterov=True,
-        ns_steps=5,
-        weight_decay=0.1,
-        adamw_betas=(0.9, 0.95),
-        adamw_eps=1e-8,
-        none_grad=True,
-        debug=False,
-    ):
+    def __init__(self,
+                 params,
+                 lr=1e-3,
+                 momentum=0.95,
+                 nesterov=True,
+                 ns_steps=5,
+                 weight_decay=0.1,
+                 adamw_betas=(0.9, 0.95),
+                 adamw_eps=1e-8,
+                 none_grad=True,
+                 debug=False,
+                 clip_config={
+                     "q_indices": [],
+                     "k_indices": [],
+                     "head_dim": 128,
+                     "threshold": 100
+                 }):
         defaults = dict(
             lr=lr,
             weight_decay=weight_decay,
@@ -292,6 +371,7 @@ class Muon(torch.optim.Optimizer):
         self.comm_stream = torch.cuda.Stream()
         self.compute_stream = torch.cuda.Stream()
         self.debug = debug
+        self.clip_config = clip_config

     def _calc_flops(self, G, steps):
         assert len(G.shape) == 2
@@ -327,7 +407,7 @@ class Muon(torch.optim.Optimizer):
         else:
             raise ValueError(f"Unsupported placements ({p.placements}).")

-    def init_state_and_assign_params(self, params, group):
+    def init_state_and_assign_params(self, names, params, group, qk_logits):
         param_to_state = {}
         param_to_flops = {}

@@ -346,15 +426,21 @@ class Muon(torch.optim.Optimizer):
         print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
               flush=True)

-        ordered_params = sorted(params,
-                                key=lambda p: param_to_flops[id(p)],
-                                reverse=True)
+        paired = list(zip(names, params))
+
+        paired_sorted = sorted(paired,
+                               key=lambda x: param_to_flops[id(x[1])],
+                               reverse=True)
+
+        names_sorted, params_sorted = zip(*paired_sorted)
+        ordered_names = list(names_sorted)
+        ordered_params = list(params_sorted)

         round_robin = 0
         mesh = None
         shard_mesh = None
         process_group = None
-        for p in ordered_params:
+        for n, p in zip(ordered_names, ordered_params):
             if mesh is None:
                 mesh = p.device_mesh
                 shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
@@ -364,14 +450,16 @@ class Muon(torch.optim.Optimizer):
             param_to_state[id(p)] = _muon_state()
             param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
             param_to_state[id(p)].process_group = process_group
-
+            qk_clip_state = self.get_qk_clip_info(n, qk_logits)
+            param_to_state[id(p)].qk_clip_state = qk_clip_state
             round_robin = (round_robin + 1) % len(shard_mesh)

         return param_to_state, ordered_params

-    def base(self, params, group, lr, weight_decay, momentum):
+    def base(self, names, params, group, lr, weight_decay, momentum,
+             qk_logits):
         # generate weight updates in distributed fashion
-        for p in params:
+        for n, p in zip(names, params):
             g = p.grad
             if g is None:
                 continue
@@ -396,6 +484,12 @@ class Muon(torch.optim.Optimizer):
             adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
             Muon._update_p(p, u, lr, adjusted_lr, weight_decay)

+            qk_clip_state = self.get_qk_clip_info(n, qk_logits)
+
+            scales_full = self._compute_scales(p, qk_clip_state)
+            if scales_full is not None:
+                Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
+
     def _update_g(self, p, g, group, momentum):
         # calc update
         state = self.state[p]
@@ -416,7 +510,58 @@ class Muon(torch.optim.Optimizer):
         # apply update
         p.data.add_(u, alpha=-adjusted_lr)

-    def parallel(self, params, group, lr, weight_decay, momentum):
+    def get_qk_clip_info(self, n, qk_logits):
+        head_dim = self.clip_config.get('head_dim')
+        threshold = self.clip_config.get('threshold')
+        kind, layer_idx = parse_qk_layer(n)
+
+        logit, indices = None, []
+        if qk_logits is not None and kind is not None:
+            logit = qk_logits[layer_idx]
+            indices_key = 'q_indices' if 'q' in kind else 'k_indices'
+            indices = self.clip_config.get(indices_key, []) or []
+
+        return QKClipInfo(
+            kind=kind,
+            indices=indices,
+            head_dim=head_dim,
+            threshold=threshold,
+            logit=logit,
+        )
+
+    @staticmethod
+    def _compute_scales(p, qk_clip_state):
+        kind = qk_clip_state.kind
+        indices = qk_clip_state.indices
+        head_dim = qk_clip_state.head_dim
+        threshold = qk_clip_state.threshold
+        logit = qk_clip_state.logit
+
+        H_global = p.shape[0] // head_dim
+        scales_full = torch.ones(H_global, device=p.data.device)
+        scaling = 0
+
+        for logit_idx, head_idx in enumerate(indices):
+            v_ele = float(logit[logit_idx])
+            if v_ele > threshold:
+                new_scale = math.sqrt(threshold / v_ele)
+                if new_scale < scales_full[head_idx]:
+                    scales_full[head_idx] = new_scale
+                    logger.info(
+                        f"[{kind}] Head {head_idx} exceeded threshold "
+                        f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
+                    )
+                    scaling += 1
+
+        return scales_full if scaling > 0 else None
+
+    @staticmethod
+    def _qk_clip(p, scales, head_dim):
+        W = p.data.view(-1, head_dim, p.data.shape[1])
+        W.mul_(scales.view(-1, 1, 1))
+
+    def parallel(self, names, params, group, lr, weight_decay, momentum,
+                 qk_logits):
         """
         Perform a parallel optimization step using Muon.
         """
@@ -438,7 +583,7 @@ class Muon(torch.optim.Optimizer):
             p.grad = g

         param_to_state, ordered_params = self.init_state_and_assign_params(
-            params, group)
+            names, params, group, qk_logits)

         def enqueue_gathers(start_idx, chunk_size):
             for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -553,12 +698,16 @@ class Muon(torch.optim.Optimizer):
                 maximize=maximize,
             )

-    def step(self, closure=None):
+    def step(self, closure=None, qk_logits=None):
         """Perform a single optimization step.

         Args:
             closure (Callable, optional): A closure that reevaluates the model
                 and returns the loss.
+            qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
+                to 1D tensors of shape (num_heads,), representing the maximum
+                QK logits across all tokens, computed as
+                (1 / sqrt(head_dim)) * (Q @ K^T).
         """
         loss = None
         if closure is not None:
@@ -575,11 +724,14 @@ class Muon(torch.optim.Optimizer):
             lr = group["lr"]
             weight_decay = group["weight_decay"]
             momentum = group["momentum"]
+            names = group["names"]

             param_dtensors = []
             param_tensors = []
+            name_dtensors = []
+            name_tensors = []

-            for p in params:
+            for n, p in zip(names, params):
                 if p is None or p.grad is None:
                     continue
                 if isinstance(p.data, DTensor):
@@ -587,10 +739,13 @@ class Muon(torch.optim.Optimizer):
                             isinstance(placement, Replicate)
                             for placement in p.placements):
                         param_tensors.append(p)
+                        name_tensors.append(n)
                     else:
                         param_dtensors.append(p)
+                        name_dtensors.append(n)
                 elif isinstance(p.data, torch.Tensor):
                     param_tensors.append(p)
+                    name_tensors.append(n)
                 else:
                     raise TypeError(
                         f"Unsupported parameter type: {type(p.data)}")
@@ -608,20 +763,24 @@ class Muon(torch.optim.Optimizer):
                 )

                 self.parallel(
+                    name_dtensors,
                     param_dtensors,
                     group,
                     lr=lr,
                     weight_decay=weight_decay,
                     momentum=momentum,
+                    qk_logits=qk_logits,
                 )

             if len(param_tensors) > 0:
                 self.base(
+                    name_tensors,
                     param_tensors,
                     group,
                     lr=lr,
                     weight_decay=weight_decay,
                     momentum=momentum,
+                    qk_logits=qk_logits,
                 )

         else:
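
For reference, the clipping rule implemented by _compute_scales and _qk_clip above reduces to a per-head factor of sqrt(threshold / max_logit) whenever a monitored head's maximum QK logit exceeds the threshold; when both the query and the key projection of a head are clipped, their product (and hence the QK logit) is scaled by threshold / max_logit. A standalone sketch of the same rule, vectorized over heads for clarity (the commit itself loops over the configured head indices):

import torch


def compute_head_scales(max_logits: torch.Tensor, threshold: float) -> torch.Tensor:
    # Per-head scale factors: 1.0 for heads under the threshold,
    # sqrt(threshold / logit) for heads above it.
    scales = torch.ones_like(max_logits)
    over = max_logits > threshold
    scales[over] = (threshold / max_logits[over]).sqrt()
    return scales


# Example with threshold 100: a head peaking at 400 gets scale sqrt(100/400) = 0.5.
print(compute_head_scales(torch.tensor([80.0, 400.0, 120.0]), threshold=100.0))
# tensor([1.0000, 0.5000, 0.9129])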
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py CHANGED
(Same +3 -3 change as build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above: the op namespace is renamed from _optimizer_4043ece_dirty to _optimizer_9c21645_dirty.)
build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_4043ece_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_9c21645_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:282b0d443dc7b9c82703e5fd0f1a0faea94370934a92bef5042bf53ac3cae39c
+oid sha256:42ae6ac1cf967d7d23cac7930c8db635105f60631220a60b9cee060d082f40ae
 size 1824256
build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py CHANGED
(Identical +198 -39 diff to build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above.)
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py CHANGED
(Same +3 -3 namespace rename as build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above.)
build/{torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_4043ece_dirty.abi3.so → torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_9c21645_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:5cd3d459f72674bcd05ba7cb96111bc90b08eeda3cbe1cd81ec5c0cd11730990
+oid sha256:dae71b7e998e72130093a86f8c983c3379510e23525e3cdcd4afe5c21bf4d3db
 size 1883344
build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py CHANGED
(Identical +198 -39 diff to build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above.)
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py CHANGED
(Same +3 -3 namespace rename as build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above.)
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_4043ece_dirty.abi3.so → _optimizer_9c21645_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:36782463aaeaa8b35d9770743fe068b907085876d957c9d830d468fff4ebc735
+oid sha256:41492cb1479920b654768a5597d88670dd0caeedbdcd73fd63afa31ffc6961d6
 size 1749776
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py CHANGED
@@ -2,7 +2,7 @@ import logging
2
  import math
3
  import types
4
  from dataclasses import dataclass
5
- from typing import Optional, Union, cast
6
 
7
  import torch
8
  import torch.distributed as dist
@@ -66,6 +66,7 @@ class _muon_state:
66
  compute_event: torch.cuda.Event | None = None
67
  scatter_event: torch.cuda.Event | None = None
68
  process_group = None
 
69
 
70
 
71
  @torch.no_grad()
@@ -193,32 +194,93 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
193
  state.scattered_u = None
194
  u_dtensor = None
195
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
  def default_is_muon(name, x):
198
- return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
 
199
 
200
 
201
  def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
 
 
 
 
 
 
 
 
 
 
 
 
202
  return [
203
  {
204
- "params": [
205
- p for n, p in model.named_parameters()
206
- if (is_muon_func(n, p) and p.requires_grad)
207
- ],
208
- "use_muon":
209
- True
210
  },
211
  {
212
- "params": [
213
- p for n, p in model.named_parameters()
214
- if (not is_muon_func(n, p) and p.requires_grad)
215
- ],
216
- "use_muon":
217
- False
218
  },
219
  ]
220
 
221
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  class Muon(torch.optim.Optimizer):
223
  """
224
  Muon - MomentUm Orthogonalized by Newton-schulz
@@ -246,21 +308,38 @@ class Muon(torch.optim.Optimizer):
246
  adamw_eps: The epsilon for the internal AdamW.
247
  none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
248
  debug: Whether to print debug information.
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  """
250
 
251
- def __init__(
252
- self,
253
- params,
254
- lr=1e-3,
255
- momentum=0.95,
256
- nesterov=True,
257
- ns_steps=5,
258
- weight_decay=0.1,
259
- adamw_betas=(0.9, 0.95),
260
- adamw_eps=1e-8,
261
- none_grad=True,
262
- debug=False,
263
- ):
 
 
 
 
264
  defaults = dict(
265
  lr=lr,
266
  weight_decay=weight_decay,
@@ -292,6 +371,7 @@ class Muon(torch.optim.Optimizer):
292
  self.comm_stream = torch.cuda.Stream()
293
  self.compute_stream = torch.cuda.Stream()
294
  self.debug = debug
 
295
 
296
  def _calc_flops(self, G, steps):
297
  assert len(G.shape) == 2
@@ -327,7 +407,7 @@ class Muon(torch.optim.Optimizer):
327
  else:
328
  raise ValueError(f"Unsupported placements ({p.placements}).")
329
 
330
- def init_state_and_assign_params(self, params, group):
331
  param_to_state = {}
332
  param_to_flops = {}
333
 
@@ -346,15 +426,21 @@ class Muon(torch.optim.Optimizer):
346
  print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
347
  flush=True)
348
 
349
- ordered_params = sorted(params,
350
- key=lambda p: param_to_flops[id(p)],
351
- reverse=True)
 
 
 
 
 
 
352
 
353
  round_robin = 0
354
  mesh = None
355
  shard_mesh = None
356
  process_group = None
357
- for p in ordered_params:
358
  if mesh is None:
359
  mesh = p.device_mesh
360
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
@@ -364,14 +450,16 @@ class Muon(torch.optim.Optimizer):
364
  param_to_state[id(p)] = _muon_state()
365
  param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
366
  param_to_state[id(p)].process_group = process_group
367
-
 
368
  round_robin = (round_robin + 1) % len(shard_mesh)
369
 
370
  return param_to_state, ordered_params
371
 
372
- def base(self, params, group, lr, weight_decay, momentum):
 
373
  # generate weight updates in distributed fashion
374
- for p in params:
375
  g = p.grad
376
  if g is None:
377
  continue
@@ -396,6 +484,12 @@ class Muon(torch.optim.Optimizer):
396
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
397
  Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
398
 
 
 
 
 
 
 
399
  def _update_g(self, p, g, group, momentum):
400
  # calc update
401
  state = self.state[p]
@@ -416,7 +510,58 @@ class Muon(torch.optim.Optimizer):
416
  # apply update
417
  p.data.add_(u, alpha=-adjusted_lr)
418
 
419
- def parallel(self, params, group, lr, weight_decay, momentum):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
  """
421
  Perform a parallel optimization step using Muon.
422
  """
@@ -438,7 +583,7 @@ class Muon(torch.optim.Optimizer):
438
  p.grad = g
439
 
440
  param_to_state, ordered_params = self.init_state_and_assign_params(
441
- params, group)
442
 
443
  def enqueue_gathers(start_idx, chunk_size):
444
  for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -553,12 +698,16 @@ class Muon(torch.optim.Optimizer):
553
  maximize=maximize,
554
  )
555
 
556
- def step(self, closure=None):
557
  """Perform a single optimization step.
558
 
559
  Args:
560
  closure (Callable, optional): A closure that reevaluates the model
561
  and returns the loss.
 
 
 
 
562
  """
563
  loss = None
564
  if closure is not None:
@@ -575,11 +724,14 @@ class Muon(torch.optim.Optimizer):
575
  lr = group["lr"]
576
  weight_decay = group["weight_decay"]
577
  momentum = group["momentum"]
 
578
 
579
  param_dtensors = []
580
  param_tensors = []
 
 
581
 
582
- for p in params:
583
  if p is None or p.grad is None:
584
  continue
585
  if isinstance(p.data, DTensor):
@@ -587,10 +739,13 @@ class Muon(torch.optim.Optimizer):
587
  isinstance(placement, Replicate)
588
  for placement in p.placements):
589
  param_tensors.append(p)
 
590
  else:
591
  param_dtensors.append(p)
 
592
  elif isinstance(p.data, torch.Tensor):
593
  param_tensors.append(p)
 
594
  else:
595
  raise TypeError(
596
  f"Unsupported parameter type: {type(p.data)}")
@@ -608,20 +763,24 @@ class Muon(torch.optim.Optimizer):
608
  )
609
 
610
  self.parallel(
 
611
  param_dtensors,
612
  group,
613
  lr=lr,
614
  weight_decay=weight_decay,
615
  momentum=momentum,
 
616
  )
617
 
618
  if len(param_tensors) > 0:
619
  self.base(
 
620
  param_tensors,
621
  group,
622
  lr=lr,
623
  weight_decay=weight_decay,
624
  momentum=momentum,
 
625
  )
626
 
627
  else:
 
2
  import math
3
  import types
4
  from dataclasses import dataclass
5
+ from typing import List, Optional, Union, cast
6
 
7
  import torch
8
  import torch.distributed as dist
 
66
  compute_event: torch.cuda.Event | None = None
67
  scatter_event: torch.cuda.Event | None = None
68
  process_group = None
69
+ qk_clip_state = None
70
 
71
 
72
  @torch.no_grad()
 
194
  state.scattered_u = None
195
  u_dtensor = None
196
 
197
+ scales_full = Muon._compute_scales(p, state.qk_clip_state)
198
+ if scales_full is not None:
199
+ num_ranks = dist.get_world_size(group=state.process_group)
200
+ local_rank = dist.get_rank(group=state.process_group)
201
+ scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank]
202
+ scales_local = DTensor.from_local(
203
+ scales_local,
204
+ placements=p.placements,
205
+ device_mesh=p.device_mesh,
206
+ )
207
+ Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim)
208
+
209
 
210
  def default_is_muon(name, x):
211
+ skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
212
+ return x.ndim >= 2 and not any(key in name for key in skip_keys)
213
 
214
 
215
  def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
216
+ muon_params, muon_names = [], []
217
+ non_muon_params = []
218
+
219
+ for n, p in model.named_parameters():
220
+ if not p.requires_grad:
221
+ continue
222
+ if is_muon_func(n, p):
223
+ muon_params.append(p)
224
+ muon_names.append(n)
225
+ else:
226
+ non_muon_params.append(p)
227
+
228
  return [
229
  {
230
+ "params": muon_params,
231
+ "names": muon_names,
232
+ "use_muon": True,
 
 
 
233
  },
234
  {
235
+ "params": non_muon_params,
236
+ "use_muon": False,
 
 
 
 
237
  },
238
  ]
239
 
240
 
241
+ def parse_qk_layer(name: str) -> tuple[str | None, int]:
242
+ """
243
+ Parse a parameter name to check if it is a query/key projection layer
244
+ ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
245
+
246
+ Returns:
247
+ (kind, layer_idx) or (None, -1) if not matched.
248
+
249
+ Example:
250
+ 'model.3.attn.wq.weight' -> ('wq', 3)
251
+ 'model.5.attn.wk.weight' -> ('wk', 5)
252
+ 'model.2.attn.q_proj.weight' -> ('q_proj', 2)
253
+ 'model.7.attn.k_proj.weight' -> ('k_proj', 7)
254
+ 'model.4.attn.v_proj.weight' -> (None, -1)
255
+ """
256
+ parts = name.split('.')
257
+ if len(parts) < 3:
258
+ return None, -1
259
+
260
+ kind = parts[-2]
261
+
262
+ layer_idx = -1
263
+ for part in reversed(parts):
264
+ if part.isdigit():
265
+ layer_idx = int(part)
266
+ break
267
+
268
+ if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
269
+ return kind, layer_idx
270
+
271
+ return None, -1
272
+
273
+
274
+ @dataclass
275
+ class QKClipInfo:
276
+ """Per-parameter dynamic info computed from config + runtime logits."""
277
+ kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None
278
+ indices: List[int] # which heads to consider for clipping
279
+ head_dim: int # from config
280
+ threshold: float # from config
281
+ logit: Optional[torch.Tensor]
282
+
283
+
284
  class Muon(torch.optim.Optimizer):
285
  """
286
  Muon - MomentUm Orthogonalized by Newton-schulz
 
308
  adamw_eps: The epsilon for the internal AdamW.
309
  none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
310
  debug: Whether to print debug information.
311
+ clip_config: Configuration for QK clipping. Expected keys:
312
+ - "q_indices" (list[int]): Indices of query heads to consider.
313
+ - "k_indices" (list[int]): Indices of key heads to consider.
314
+ - "head_dim" (int): Dimensionality of each attention head.
315
+ - "threshold" (float): Threshold value; heads whose QK logits exceed
316
+ this value will be scaled down.
317
+ Default is:
318
+ {
319
+ "q_indices": [],
320
+ "k_indices": [],
321
+ "head_dim": 128,
322
+ "threshold": 100
323
+ }
324
  """
325
 
326
+ def __init__(self,
327
+ params,
328
+ lr=1e-3,
329
+ momentum=0.95,
330
+ nesterov=True,
331
+ ns_steps=5,
332
+ weight_decay=0.1,
333
+ adamw_betas=(0.9, 0.95),
334
+ adamw_eps=1e-8,
335
+ none_grad=True,
336
+ debug=False,
337
+ clip_config={
338
+ "q_indices": [],
339
+ "k_indices": [],
340
+ "head_dim": 128,
341
+ "threshold": 100
342
+ }):
343
  defaults = dict(
344
  lr=lr,
345
  weight_decay=weight_decay,
 
371
  self.comm_stream = torch.cuda.Stream()
372
  self.compute_stream = torch.cuda.Stream()
373
  self.debug = debug
374
+ self.clip_config = clip_config
375
 
376
  def _calc_flops(self, G, steps):
377
  assert len(G.shape) == 2
 
407
  else:
408
  raise ValueError(f"Unsupported placements ({p.placements}).")
409
 
410
+ def init_state_and_assign_params(self, names, params, group, qk_logits):
411
  param_to_state = {}
412
  param_to_flops = {}
413
 
 
426
  print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
427
  flush=True)
428
 
429
+ paired = list(zip(names, params))
430
+
431
+ paired_sorted = sorted(paired,
432
+ key=lambda x: param_to_flops[id(x[1])],
433
+ reverse=True)
434
+
435
+ names_sorted, params_sorted = zip(*paired_sorted)
436
+ ordered_names = list(names_sorted)
437
+ ordered_params = list(params_sorted)
438
 
439
  round_robin = 0
440
  mesh = None
441
  shard_mesh = None
442
  process_group = None
443
+ for n, p in zip(ordered_names, ordered_params):
444
  if mesh is None:
445
  mesh = p.device_mesh
446
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
 
450
  param_to_state[id(p)] = _muon_state()
451
  param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
452
  param_to_state[id(p)].process_group = process_group
453
+ qk_clip_state = self.get_qk_clip_info(n, qk_logits)
454
+ param_to_state[id(p)].qk_clip_state = qk_clip_state
455
  round_robin = (round_robin + 1) % len(shard_mesh)
456
 
457
  return param_to_state, ordered_params
458
 
459
+ def base(self, names, params, group, lr, weight_decay, momentum,
460
+ qk_logits):
461
  # generate weight updates in distributed fashion
462
+ for n, p in zip(names, params):
463
  g = p.grad
464
  if g is None:
465
  continue
 
484
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
485
  Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
486
 
487
+ qk_clip_state = self.get_qk_clip_info(n, qk_logits)
488
+
489
+ scales_full = self._compute_scales(p, qk_clip_state)
490
+ if scales_full is not None:
491
+ Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
492
+
493
  def _update_g(self, p, g, group, momentum):
494
  # calc update
495
  state = self.state[p]
 
510
  # apply update
511
  p.data.add_(u, alpha=-adjusted_lr)
512
 
513
+ def get_qk_clip_info(self, n, qk_logits):
514
+ head_dim = self.clip_config.get('head_dim')
515
+ threshold = self.clip_config.get('threshold')
516
+ kind, layer_idx = parse_qk_layer(n)
517
+
518
+ logit, indices = None, []
519
+ if qk_logits is not None and kind is not None:
520
+ logit = qk_logits[layer_idx]
521
+ indices_key = 'q_indices' if 'q' in kind else 'k_indices'
522
+ indices = self.clip_config.get(indices_key, []) or []
523
+
524
+ return QKClipInfo(
525
+ kind=kind,
526
+ indices=indices,
527
+ head_dim=head_dim,
528
+ threshold=threshold,
529
+ logit=logit,
530
+ )
531
+
532
+ @staticmethod
533
+ def _compute_scales(p, qk_clip_state):
534
+ kind = qk_clip_state.kind
535
+ indices = qk_clip_state.indices
536
+ head_dim = qk_clip_state.head_dim
537
+ threshold = qk_clip_state.threshold
538
+ logit = qk_clip_state.logit
539
+
540
+ H_global = p.shape[0] // head_dim
541
+ scales_full = torch.ones(H_global, device=p.data.device)
542
+ scaling = 0
543
+
544
+ for logit_idx, head_idx in enumerate(indices):
545
+ v_ele = float(logit[logit_idx])
546
+ if v_ele > threshold:
547
+ new_scale = math.sqrt(threshold / v_ele)
548
+ if new_scale < scales_full[head_idx]:
549
+ scales_full[head_idx] = new_scale
550
+ logger.info(
551
+ f"[{kind}] Head {head_idx} exceeded threshold "
552
+ f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
553
+ )
554
+ scaling += 1
555
+
556
+ return scales_full if scaling > 0 else None
557
+
558
+ @staticmethod
559
+ def _qk_clip(p, scales, head_dim):
560
+ W = p.data.view(-1, head_dim, p.data.shape[1])
561
+ W.mul_(scales.view(-1, 1, 1))
562
+
563
+ def parallel(self, names, params, group, lr, weight_decay, momentum,
564
+ qk_logits):
565
  """
566
  Perform a parallel optimization step using Muon.
567
  """
 
583
  p.grad = g
584
 
585
  param_to_state, ordered_params = self.init_state_and_assign_params(
586
+ names, params, group, qk_logits)
587
 
588
  def enqueue_gathers(start_idx, chunk_size):
589
  for p in ordered_params[start_idx:start_idx + chunk_size]:
 
698
  maximize=maximize,
699
  )
700
 
701
+ def step(self, closure=None, qk_logits=None):
702
  """Perform a single optimization step.
703
 
704
  Args:
705
  closure (Callable, optional): A closure that reevaluates the model
706
  and returns the loss.
707
+ qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
708
+ to 1D tensors of shape (num_heads,), representing the maximum
709
+ QK logits across all tokens, computed as
710
+ (1 / sqrt(head_dim)) * (Q @ K^T).
711
  """
712
  loss = None
713
  if closure is not None:
 
724
  lr = group["lr"]
725
  weight_decay = group["weight_decay"]
726
  momentum = group["momentum"]
727
+ names = group["names"]
728
 
729
  param_dtensors = []
730
  param_tensors = []
731
+ name_dtensors = []
732
+ name_tensors = []
733
 
734
+ for n, p in zip(names, params):
735
  if p is None or p.grad is None:
736
  continue
737
  if isinstance(p.data, DTensor):
 
739
  isinstance(placement, Replicate)
740
  for placement in p.placements):
741
  param_tensors.append(p)
742
+ name_tensors.append(n)
743
  else:
744
  param_dtensors.append(p)
745
+ name_dtensors.append(n)
746
  elif isinstance(p.data, torch.Tensor):
747
  param_tensors.append(p)
748
+ name_tensors.append(n)
749
  else:
750
  raise TypeError(
751
  f"Unsupported parameter type: {type(p.data)}")
 
763
  )
764
 
765
  self.parallel(
766
+ name_dtensors,
767
  param_dtensors,
768
  group,
769
  lr=lr,
770
  weight_decay=weight_decay,
771
  momentum=momentum,
772
+ qk_logits=qk_logits,
773
  )
774
 
775
  if len(param_tensors) > 0:
776
  self.base(
777
+ name_tensors,
778
  param_tensors,
779
  group,
780
  lr=lr,
781
  weight_decay=weight_decay,
782
  momentum=momentum,
783
+ qk_logits=qk_logits,
784
  )
785
 
786
  else:
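
The core of the change, independent of the sharding details above, is the per-head rescaling rule in _compute_scales/_qk_clip. The following standalone sketch (plain PyTorch, not taken from this repo) reproduces that rule on a dummy projection weight; when both the query and the key projection of a head are scaled by sqrt(threshold / logit), that head's QK logit is pulled back down to roughly the threshold.

import math
import torch

head_dim, n_heads, hidden = 4, 3, 12
threshold = 100.0

W = torch.randn(n_heads * head_dim, hidden)        # stand-in for one q_proj or k_proj weight
max_logits = [80.0, 150.0, 400.0]                  # pretend per-head max QK logits

scales = torch.ones(n_heads)
for head_idx, v in enumerate(max_logits):
    if v > threshold:
        scales[head_idx] = math.sqrt(threshold / v)

# same reshape as Muon._qk_clip: view as (heads, head_dim, hidden) and scale per head
W.view(n_heads, head_dim, hidden).mul_(scales.view(-1, 1, 1))
print(scales)   # tensor([1.0000, 0.8165, 0.5000])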
build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_4043ece_dirty
3
- ops = torch.ops._optimizer_4043ece_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_4043ece_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_9c21645_dirty
3
+ ops = torch.ops._optimizer_9c21645_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_9c21645_dirty::{op_name}"
build/{torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_4043ece_dirty.abi3.so β†’ torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_9c21645_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:282b0d443dc7b9c82703e5fd0f1a0faea94370934a92bef5042bf53ac3cae39c
3
  size 1824256
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42ae6ac1cf967d7d23cac7930c8db635105f60631220a60b9cee060d082f40ae
3
  size 1824256
build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py CHANGED
@@ -2,7 +2,7 @@ import logging
2
  import math
3
  import types
4
  from dataclasses import dataclass
5
- from typing import Optional, Union, cast
6
 
7
  import torch
8
  import torch.distributed as dist
@@ -66,6 +66,7 @@ class _muon_state:
66
  compute_event: torch.cuda.Event | None = None
67
  scatter_event: torch.cuda.Event | None = None
68
  process_group = None
 
69
 
70
 
71
  @torch.no_grad()
@@ -193,32 +194,93 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
193
  state.scattered_u = None
194
  u_dtensor = None
195
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
  def default_is_muon(name, x):
198
- return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
 
199
 
200
 
201
  def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
 
 
 
 
 
 
 
 
 
 
 
 
202
  return [
203
  {
204
- "params": [
205
- p for n, p in model.named_parameters()
206
- if (is_muon_func(n, p) and p.requires_grad)
207
- ],
208
- "use_muon":
209
- True
210
  },
211
  {
212
- "params": [
213
- p for n, p in model.named_parameters()
214
- if (not is_muon_func(n, p) and p.requires_grad)
215
- ],
216
- "use_muon":
217
- False
218
  },
219
  ]
220
 
221
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  class Muon(torch.optim.Optimizer):
223
  """
224
  Muon - MomentUm Orthogonalized by Newton-schulz
@@ -246,21 +308,38 @@ class Muon(torch.optim.Optimizer):
246
  adamw_eps: The epsilon for the internal AdamW.
247
  none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
248
  debug: Whether to print debug information.
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  """
250
 
251
- def __init__(
252
- self,
253
- params,
254
- lr=1e-3,
255
- momentum=0.95,
256
- nesterov=True,
257
- ns_steps=5,
258
- weight_decay=0.1,
259
- adamw_betas=(0.9, 0.95),
260
- adamw_eps=1e-8,
261
- none_grad=True,
262
- debug=False,
263
- ):
 
 
 
 
264
  defaults = dict(
265
  lr=lr,
266
  weight_decay=weight_decay,
@@ -292,6 +371,7 @@ class Muon(torch.optim.Optimizer):
292
  self.comm_stream = torch.cuda.Stream()
293
  self.compute_stream = torch.cuda.Stream()
294
  self.debug = debug
 
295
 
296
  def _calc_flops(self, G, steps):
297
  assert len(G.shape) == 2
@@ -327,7 +407,7 @@ class Muon(torch.optim.Optimizer):
327
  else:
328
  raise ValueError(f"Unsupported placements ({p.placements}).")
329
 
330
- def init_state_and_assign_params(self, params, group):
331
  param_to_state = {}
332
  param_to_flops = {}
333
 
@@ -346,15 +426,21 @@ class Muon(torch.optim.Optimizer):
346
  print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
347
  flush=True)
348
 
349
- ordered_params = sorted(params,
350
- key=lambda p: param_to_flops[id(p)],
351
- reverse=True)
 
 
 
 
 
 
352
 
353
  round_robin = 0
354
  mesh = None
355
  shard_mesh = None
356
  process_group = None
357
- for p in ordered_params:
358
  if mesh is None:
359
  mesh = p.device_mesh
360
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
@@ -364,14 +450,16 @@ class Muon(torch.optim.Optimizer):
364
  param_to_state[id(p)] = _muon_state()
365
  param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
366
  param_to_state[id(p)].process_group = process_group
367
-
 
368
  round_robin = (round_robin + 1) % len(shard_mesh)
369
 
370
  return param_to_state, ordered_params
371
 
372
- def base(self, params, group, lr, weight_decay, momentum):
 
373
  # generate weight updates in distributed fashion
374
- for p in params:
375
  g = p.grad
376
  if g is None:
377
  continue
@@ -396,6 +484,12 @@ class Muon(torch.optim.Optimizer):
396
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
397
  Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
398
 
 
 
 
 
 
 
399
  def _update_g(self, p, g, group, momentum):
400
  # calc update
401
  state = self.state[p]
@@ -416,7 +510,58 @@ class Muon(torch.optim.Optimizer):
416
  # apply update
417
  p.data.add_(u, alpha=-adjusted_lr)
418
 
419
- def parallel(self, params, group, lr, weight_decay, momentum):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
  """
421
  Perform a parallel optimization step using Muon.
422
  """
@@ -438,7 +583,7 @@ class Muon(torch.optim.Optimizer):
438
  p.grad = g
439
 
440
  param_to_state, ordered_params = self.init_state_and_assign_params(
441
- params, group)
442
 
443
  def enqueue_gathers(start_idx, chunk_size):
444
  for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -553,12 +698,16 @@ class Muon(torch.optim.Optimizer):
553
  maximize=maximize,
554
  )
555
 
556
- def step(self, closure=None):
557
  """Perform a single optimization step.
558
 
559
  Args:
560
  closure (Callable, optional): A closure that reevaluates the model
561
  and returns the loss.
 
 
 
 
562
  """
563
  loss = None
564
  if closure is not None:
@@ -575,11 +724,14 @@ class Muon(torch.optim.Optimizer):
575
  lr = group["lr"]
576
  weight_decay = group["weight_decay"]
577
  momentum = group["momentum"]
 
578
 
579
  param_dtensors = []
580
  param_tensors = []
 
 
581
 
582
- for p in params:
583
  if p is None or p.grad is None:
584
  continue
585
  if isinstance(p.data, DTensor):
@@ -587,10 +739,13 @@ class Muon(torch.optim.Optimizer):
587
  isinstance(placement, Replicate)
588
  for placement in p.placements):
589
  param_tensors.append(p)
 
590
  else:
591
  param_dtensors.append(p)
 
592
  elif isinstance(p.data, torch.Tensor):
593
  param_tensors.append(p)
 
594
  else:
595
  raise TypeError(
596
  f"Unsupported parameter type: {type(p.data)}")
@@ -608,20 +763,24 @@ class Muon(torch.optim.Optimizer):
608
  )
609
 
610
  self.parallel(
 
611
  param_dtensors,
612
  group,
613
  lr=lr,
614
  weight_decay=weight_decay,
615
  momentum=momentum,
 
616
  )
617
 
618
  if len(param_tensors) > 0:
619
  self.base(
 
620
  param_tensors,
621
  group,
622
  lr=lr,
623
  weight_decay=weight_decay,
624
  momentum=momentum,
 
625
  )
626
 
627
  else:
 
2
  import math
3
  import types
4
  from dataclasses import dataclass
5
+ from typing import List, Optional, Union, cast
6
 
7
  import torch
8
  import torch.distributed as dist
 
66
  compute_event: torch.cuda.Event | None = None
67
  scatter_event: torch.cuda.Event | None = None
68
  process_group = None
69
+ qk_clip_state = None
70
 
71
 
72
  @torch.no_grad()
 
194
  state.scattered_u = None
195
  u_dtensor = None
196
 
197
+ scales_full = Muon._compute_scales(p, state.qk_clip_state)
198
+ if scales_full is not None:
199
+ num_ranks = dist.get_world_size(group=state.process_group)
200
+ local_rank = dist.get_rank(group=state.process_group)
201
+ scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank]
202
+ scales_local = DTensor.from_local(
203
+ scales_local,
204
+ placements=p.placements,
205
+ device_mesh=p.device_mesh,
206
+ )
207
+ Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim)
208
+
209
 
210
  def default_is_muon(name, x):
211
+ skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
212
+ return x.ndim >= 2 and not any(key in name for key in skip_keys)
213
 
214
 
215
  def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
216
+ muon_params, muon_names = [], []
217
+ non_muon_params = []
218
+
219
+ for n, p in model.named_parameters():
220
+ if not p.requires_grad:
221
+ continue
222
+ if is_muon_func(n, p):
223
+ muon_params.append(p)
224
+ muon_names.append(n)
225
+ else:
226
+ non_muon_params.append(p)
227
+
228
  return [
229
  {
230
+ "params": muon_params,
231
+ "names": muon_names,
232
+ "use_muon": True,
 
 
 
233
  },
234
  {
235
+ "params": non_muon_params,
236
+ "use_muon": False,
 
 
 
 
237
  },
238
  ]
239
 
240
 
241
+ def parse_qk_layer(name: str) -> tuple[str | None, int]:
242
+ """
243
+ Parse a parameter name to check if it is a query/key projection layer
244
+ ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
245
+
246
+ Returns:
247
+ (kind, layer_idx) or (None, -1) if not matched.
248
+
249
+ Example:
250
+ 'model.3.attn.wq.weight' -> ('wq', 3)
251
+ 'model.5.attn.wk.weight' -> ('wk', 5)
252
+ 'model.2.attn.q_proj.weight' -> ('q_proj', 2)
253
+ 'model.7.attn.k_proj.weight' -> ('k_proj', 7)
254
+ 'model.4.attn.v_proj.weight' -> (None, -1)
255
+ """
256
+ parts = name.split('.')
257
+ if len(parts) < 3:
258
+ return None, -1
259
+
260
+ kind = parts[-2]
261
+
262
+ layer_idx = -1
263
+ for part in reversed(parts):
264
+ if part.isdigit():
265
+ layer_idx = int(part)
266
+ break
267
+
268
+ if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
269
+ return kind, layer_idx
270
+
271
+ return None, -1
272
+
273
+
274
+ @dataclass
275
+ class QKClipInfo:
276
+ """Per-parameter dynamic info computed from config + runtime logits."""
277
+ kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None
278
+ indices: List[int] # which heads to consider for clipping
279
+ head_dim: int # from config
280
+ threshold: float # from config
281
+ logit: Optional[torch.Tensor]
282
+
283
+
284
  class Muon(torch.optim.Optimizer):
285
  """
286
  Muon - MomentUm Orthogonalized by Newton-schulz
 
308
  adamw_eps: The epsilon for the internal AdamW.
309
  none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
310
  debug: Whether to print debug information.
311
+ clip_config: Configuration for QK clipping. Expected keys:
312
+ - "q_indices" (list[int]): Indices of query heads to consider.
313
+ - "k_indices" (list[int]): Indices of key heads to consider.
314
+ - "head_dim" (int): Dimensionality of each attention head.
315
+ - "threshold" (float): Threshold value; heads whose QK logits exceed
316
+ this value will be scaled down.
317
+ Default is:
318
+ {
319
+ "q_indices": [],
320
+ "k_indices": [],
321
+ "head_dim": 128,
322
+ "threshold": 100
323
+ }
324
  """
325
 
326
+ def __init__(self,
327
+ params,
328
+ lr=1e-3,
329
+ momentum=0.95,
330
+ nesterov=True,
331
+ ns_steps=5,
332
+ weight_decay=0.1,
333
+ adamw_betas=(0.9, 0.95),
334
+ adamw_eps=1e-8,
335
+ none_grad=True,
336
+ debug=False,
337
+ clip_config={
338
+ "q_indices": [],
339
+ "k_indices": [],
340
+ "head_dim": 128,
341
+ "threshold": 100
342
+ }):
343
  defaults = dict(
344
  lr=lr,
345
  weight_decay=weight_decay,
 
371
  self.comm_stream = torch.cuda.Stream()
372
  self.compute_stream = torch.cuda.Stream()
373
  self.debug = debug
374
+ self.clip_config = clip_config
375
 
376
  def _calc_flops(self, G, steps):
377
  assert len(G.shape) == 2
 
407
  else:
408
  raise ValueError(f"Unsupported placements ({p.placements}).")
409
 
410
+ def init_state_and_assign_params(self, names, params, group, qk_logits):
411
  param_to_state = {}
412
  param_to_flops = {}
413
 
 
426
  print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
427
  flush=True)
428
 
429
+ paired = list(zip(names, params))
430
+
431
+ paired_sorted = sorted(paired,
432
+ key=lambda x: param_to_flops[id(x[1])],
433
+ reverse=True)
434
+
435
+ names_sorted, params_sorted = zip(*paired_sorted)
436
+ ordered_names = list(names_sorted)
437
+ ordered_params = list(params_sorted)
438
 
439
  round_robin = 0
440
  mesh = None
441
  shard_mesh = None
442
  process_group = None
443
+ for n, p in zip(ordered_names, ordered_params):
444
  if mesh is None:
445
  mesh = p.device_mesh
446
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
 
450
  param_to_state[id(p)] = _muon_state()
451
  param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
452
  param_to_state[id(p)].process_group = process_group
453
+ qk_clip_state = self.get_qk_clip_info(n, qk_logits)
454
+ param_to_state[id(p)].qk_clip_state = qk_clip_state
455
  round_robin = (round_robin + 1) % len(shard_mesh)
456
 
457
  return param_to_state, ordered_params
458
 
459
+ def base(self, names, params, group, lr, weight_decay, momentum,
460
+ qk_logits):
461
  # generate weight updates in distributed fashion
462
+ for n, p in zip(names, params):
463
  g = p.grad
464
  if g is None:
465
  continue
 
484
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
485
  Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
486
 
487
+ qk_clip_state = self.get_qk_clip_info(n, qk_logits)
488
+
489
+ scales_full = self._compute_scales(p, qk_clip_state)
490
+ if scales_full is not None:
491
+ Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
492
+
493
  def _update_g(self, p, g, group, momentum):
494
  # calc update
495
  state = self.state[p]
 
510
  # apply update
511
  p.data.add_(u, alpha=-adjusted_lr)
512
 
513
+ def get_qk_clip_info(self, n, qk_logits):
514
+ head_dim = self.clip_config.get('head_dim')
515
+ threshold = self.clip_config.get('threshold')
516
+ kind, layer_idx = parse_qk_layer(n)
517
+
518
+ logit, indices = None, []
519
+ if qk_logits is not None and kind is not None:
520
+ logit = qk_logits[layer_idx]
521
+ indices_key = 'q_indices' if 'q' in kind else 'k_indices'
522
+ indices = self.clip_config.get(indices_key, []) or []
523
+
524
+ return QKClipInfo(
525
+ kind=kind,
526
+ indices=indices,
527
+ head_dim=head_dim,
528
+ threshold=threshold,
529
+ logit=logit,
530
+ )
531
+
532
+ @staticmethod
533
+ def _compute_scales(p, qk_clip_state):
534
+ kind = qk_clip_state.kind
535
+ indices = qk_clip_state.indices
536
+ head_dim = qk_clip_state.head_dim
537
+ threshold = qk_clip_state.threshold
538
+ logit = qk_clip_state.logit
539
+
540
+ H_global = p.shape[0] // head_dim
541
+ scales_full = torch.ones(H_global, device=p.data.device)
542
+ scaling = 0
543
+
544
+ for logit_idx, head_idx in enumerate(indices):
545
+ v_ele = float(logit[logit_idx])
546
+ if v_ele > threshold:
547
+ new_scale = math.sqrt(threshold / v_ele)
548
+ if new_scale < scales_full[head_idx]:
549
+ scales_full[head_idx] = new_scale
550
+ logger.info(
551
+ f"[{kind}] Head {head_idx} exceeded threshold "
552
+ f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
553
+ )
554
+ scaling += 1
555
+
556
+ return scales_full if scaling > 0 else None
557
+
558
+ @staticmethod
559
+ def _qk_clip(p, scales, head_dim):
560
+ W = p.data.view(-1, head_dim, p.data.shape[1])
561
+ W.mul_(scales.view(-1, 1, 1))
562
+
563
+ def parallel(self, names, params, group, lr, weight_decay, momentum,
564
+ qk_logits):
565
  """
566
  Perform a parallel optimization step using Muon.
567
  """
 
583
  p.grad = g
584
 
585
  param_to_state, ordered_params = self.init_state_and_assign_params(
586
+ names, params, group, qk_logits)
587
 
588
  def enqueue_gathers(start_idx, chunk_size):
589
  for p in ordered_params[start_idx:start_idx + chunk_size]:
 
698
  maximize=maximize,
699
  )
700
 
701
+ def step(self, closure=None, qk_logits=None):
702
  """Perform a single optimization step.
703
 
704
  Args:
705
  closure (Callable, optional): A closure that reevaluates the model
706
  and returns the loss.
707
+ qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
708
+ to 1D tensors of shape (num_heads,), representing the maximum
709
+ QK logits across all tokens, computed as
710
+ (1 / sqrt(head_dim)) * (Q @ K^T).
711
  """
712
  loss = None
713
  if closure is not None:
 
724
  lr = group["lr"]
725
  weight_decay = group["weight_decay"]
726
  momentum = group["momentum"]
727
+ names = group["names"]
728
 
729
  param_dtensors = []
730
  param_tensors = []
731
+ name_dtensors = []
732
+ name_tensors = []
733
 
734
+ for n, p in zip(names, params):
735
  if p is None or p.grad is None:
736
  continue
737
  if isinstance(p.data, DTensor):
 
739
  isinstance(placement, Replicate)
740
  for placement in p.placements):
741
  param_tensors.append(p)
742
+ name_tensors.append(n)
743
  else:
744
  param_dtensors.append(p)
745
+ name_dtensors.append(n)
746
  elif isinstance(p.data, torch.Tensor):
747
  param_tensors.append(p)
748
+ name_tensors.append(n)
749
  else:
750
  raise TypeError(
751
  f"Unsupported parameter type: {type(p.data)}")
 
763
  )
764
 
765
  self.parallel(
766
+ name_dtensors,
767
  param_dtensors,
768
  group,
769
  lr=lr,
770
  weight_decay=weight_decay,
771
  momentum=momentum,
772
+ qk_logits=qk_logits,
773
  )
774
 
775
  if len(param_tensors) > 0:
776
  self.base(
777
+ name_tensors,
778
  param_tensors,
779
  group,
780
  lr=lr,
781
  weight_decay=weight_decay,
782
  momentum=momentum,
783
+ qk_logits=qk_logits,
784
  )
785
 
786
  else:
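
step() only consumes the qk_logits dict; producing it is left to the training loop. One plausible way to build it (an illustration, not an API of this repo) is to reduce the scaled attention scores per head, matching the (1 / sqrt(head_dim)) * (Q @ K^T) convention in the step() docstring:

import math
import torch

def max_qk_logits(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """q, k: (batch, num_heads, seq, head_dim) -> per-head max logit, shape (num_heads,)."""
    head_dim = q.shape[-1]
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)   # (batch, heads, seq, seq)
    return scores.amax(dim=(0, 2, 3))                          # max over batch and both token axes

# toy shapes: batch=2, heads=4, seq=16, head_dim=8
q = torch.randn(2, 4, 16, 8)
k = torch.randn(2, 4, 16, 8)
qk_logits = {0: max_qk_logits(q, k)}    # {layer_idx: (num_heads,)} as step() expects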
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_4043ece_dirty
3
- ops = torch.ops._optimizer_4043ece_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_4043ece_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_9c21645_dirty
3
+ ops = torch.ops._optimizer_9c21645_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_9c21645_dirty::{op_name}"
build/{torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_4043ece_dirty.abi3.so β†’ torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_9c21645_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:dff83fb4e6107a9447ae36fa98c19a873d71525898fde676c51252396c02a633
3
  size 1883344
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dae71b7e998e72130093a86f8c983c3379510e23525e3cdcd4afe5c21bf4d3db
3
  size 1883344
build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py CHANGED
@@ -2,7 +2,7 @@ import logging
2
  import math
3
  import types
4
  from dataclasses import dataclass
5
- from typing import Optional, Union, cast
6
 
7
  import torch
8
  import torch.distributed as dist
@@ -66,6 +66,7 @@ class _muon_state:
66
  compute_event: torch.cuda.Event | None = None
67
  scatter_event: torch.cuda.Event | None = None
68
  process_group = None
 
69
 
70
 
71
  @torch.no_grad()
@@ -193,32 +194,93 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
193
  state.scattered_u = None
194
  u_dtensor = None
195
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
  def default_is_muon(name, x):
198
- return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
 
199
 
200
 
201
  def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
 
 
 
 
 
 
 
 
 
 
 
 
202
  return [
203
  {
204
- "params": [
205
- p for n, p in model.named_parameters()
206
- if (is_muon_func(n, p) and p.requires_grad)
207
- ],
208
- "use_muon":
209
- True
210
  },
211
  {
212
- "params": [
213
- p for n, p in model.named_parameters()
214
- if (not is_muon_func(n, p) and p.requires_grad)
215
- ],
216
- "use_muon":
217
- False
218
  },
219
  ]
220
 
221
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  class Muon(torch.optim.Optimizer):
223
  """
224
  Muon - MomentUm Orthogonalized by Newton-schulz
@@ -246,21 +308,38 @@ class Muon(torch.optim.Optimizer):
246
  adamw_eps: The epsilon for the internal AdamW.
247
  none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
248
  debug: Whether to print debug information.
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  """
250
 
251
- def __init__(
252
- self,
253
- params,
254
- lr=1e-3,
255
- momentum=0.95,
256
- nesterov=True,
257
- ns_steps=5,
258
- weight_decay=0.1,
259
- adamw_betas=(0.9, 0.95),
260
- adamw_eps=1e-8,
261
- none_grad=True,
262
- debug=False,
263
- ):
 
 
 
 
264
  defaults = dict(
265
  lr=lr,
266
  weight_decay=weight_decay,
@@ -292,6 +371,7 @@ class Muon(torch.optim.Optimizer):
292
  self.comm_stream = torch.cuda.Stream()
293
  self.compute_stream = torch.cuda.Stream()
294
  self.debug = debug
 
295
 
296
  def _calc_flops(self, G, steps):
297
  assert len(G.shape) == 2
@@ -327,7 +407,7 @@ class Muon(torch.optim.Optimizer):
327
  else:
328
  raise ValueError(f"Unsupported placements ({p.placements}).")
329
 
330
- def init_state_and_assign_params(self, params, group):
331
  param_to_state = {}
332
  param_to_flops = {}
333
 
@@ -346,15 +426,21 @@ class Muon(torch.optim.Optimizer):
346
  print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
347
  flush=True)
348
 
349
- ordered_params = sorted(params,
350
- key=lambda p: param_to_flops[id(p)],
351
- reverse=True)
 
 
 
 
 
 
352
 
353
  round_robin = 0
354
  mesh = None
355
  shard_mesh = None
356
  process_group = None
357
- for p in ordered_params:
358
  if mesh is None:
359
  mesh = p.device_mesh
360
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
@@ -364,14 +450,16 @@ class Muon(torch.optim.Optimizer):
364
  param_to_state[id(p)] = _muon_state()
365
  param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
366
  param_to_state[id(p)].process_group = process_group
367
-
 
368
  round_robin = (round_robin + 1) % len(shard_mesh)
369
 
370
  return param_to_state, ordered_params
371
 
372
- def base(self, params, group, lr, weight_decay, momentum):
 
373
  # generate weight updates in distributed fashion
374
- for p in params:
375
  g = p.grad
376
  if g is None:
377
  continue
@@ -396,6 +484,12 @@ class Muon(torch.optim.Optimizer):
396
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
397
  Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
398
 
 
 
 
 
 
 
399
  def _update_g(self, p, g, group, momentum):
400
  # calc update
401
  state = self.state[p]
@@ -416,7 +510,58 @@ class Muon(torch.optim.Optimizer):
416
  # apply update
417
  p.data.add_(u, alpha=-adjusted_lr)
418
 
419
- def parallel(self, params, group, lr, weight_decay, momentum):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
  """
421
  Perform a parallel optimization step using Muon.
422
  """
@@ -438,7 +583,7 @@ class Muon(torch.optim.Optimizer):
438
  p.grad = g
439
 
440
  param_to_state, ordered_params = self.init_state_and_assign_params(
441
- params, group)
442
 
443
  def enqueue_gathers(start_idx, chunk_size):
444
  for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -553,12 +698,16 @@ class Muon(torch.optim.Optimizer):
553
  maximize=maximize,
554
  )
555
 
556
- def step(self, closure=None):
557
  """Perform a single optimization step.
558
 
559
  Args:
560
  closure (Callable, optional): A closure that reevaluates the model
561
  and returns the loss.
 
 
 
 
562
  """
563
  loss = None
564
  if closure is not None:
@@ -575,11 +724,14 @@ class Muon(torch.optim.Optimizer):
575
  lr = group["lr"]
576
  weight_decay = group["weight_decay"]
577
  momentum = group["momentum"]
 
578
 
579
  param_dtensors = []
580
  param_tensors = []
 
 
581
 
582
- for p in params:
583
  if p is None or p.grad is None:
584
  continue
585
  if isinstance(p.data, DTensor):
@@ -587,10 +739,13 @@ class Muon(torch.optim.Optimizer):
587
  isinstance(placement, Replicate)
588
  for placement in p.placements):
589
  param_tensors.append(p)
 
590
  else:
591
  param_dtensors.append(p)
 
592
  elif isinstance(p.data, torch.Tensor):
593
  param_tensors.append(p)
 
594
  else:
595
  raise TypeError(
596
  f"Unsupported parameter type: {type(p.data)}")
@@ -608,20 +763,24 @@ class Muon(torch.optim.Optimizer):
608
  )
609
 
610
  self.parallel(
 
611
  param_dtensors,
612
  group,
613
  lr=lr,
614
  weight_decay=weight_decay,
615
  momentum=momentum,
 
616
  )
617
 
618
  if len(param_tensors) > 0:
619
  self.base(
 
620
  param_tensors,
621
  group,
622
  lr=lr,
623
  weight_decay=weight_decay,
624
  momentum=momentum,
 
625
  )
626
 
627
  else:
 
2
  import math
3
  import types
4
  from dataclasses import dataclass
5
+ from typing import List, Optional, Union, cast
6
 
7
  import torch
8
  import torch.distributed as dist
 
66
  compute_event: torch.cuda.Event | None = None
67
  scatter_event: torch.cuda.Event | None = None
68
  process_group = None
69
+ qk_clip_state = None
70
 
71
 
72
  @torch.no_grad()
 
194
  state.scattered_u = None
195
  u_dtensor = None
196
 
197
+ scales_full = Muon._compute_scales(p, state.qk_clip_state)
198
+ if scales_full is not None:
199
+ num_ranks = dist.get_world_size(group=state.process_group)
200
+ local_rank = dist.get_rank(group=state.process_group)
201
+ scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank]
202
+ scales_local = DTensor.from_local(
203
+ scales_local,
204
+ placements=p.placements,
205
+ device_mesh=p.device_mesh,
206
+ )
207
+ Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim)
208
+
209
 
210
  def default_is_muon(name, x):
211
+ skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
212
+ return x.ndim >= 2 and not any(key in name for key in skip_keys)
213
 
214
 
215
  def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
216
+ muon_params, muon_names = [], []
217
+ non_muon_params = []
218
+
219
+ for n, p in model.named_parameters():
220
+ if not p.requires_grad:
221
+ continue
222
+ if is_muon_func(n, p):
223
+ muon_params.append(p)
224
+ muon_names.append(n)
225
+ else:
226
+ non_muon_params.append(p)
227
+
228
  return [
229
  {
230
+ "params": muon_params,
231
+ "names": muon_names,
232
+ "use_muon": True,
 
 
 
233
  },
234
  {
235
+ "params": non_muon_params,
236
+ "use_muon": False,
 
 
 
 
237
  },
238
  ]
239
 
240
 
241
+ def parse_qk_layer(name: str) -> tuple[str | None, int]:
242
+ """
243
+ Parse a parameter name to check if it is a query/key projection layer
244
+ ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
245
+
246
+ Returns:
247
+ (kind, layer_idx) or (None, -1) if not matched.
248
+
249
+ Example:
250
+ 'model.3.attn.wq.weight' -> ('wq', 3)
251
+ 'model.5.attn.wk.weight' -> ('wk', 5)
252
+ 'model.2.attn.q_proj.weight' -> ('q_proj', 2)
253
+ 'model.7.attn.k_proj.weight' -> ('k_proj', 7)
254
+ 'model.4.attn.v_proj.weight' -> (None, -1)
255
+ """
256
+ parts = name.split('.')
257
+ if len(parts) < 3:
258
+ return None, -1
259
+
260
+ kind = parts[-2]
261
+
262
+ layer_idx = -1
263
+ for part in reversed(parts):
264
+ if part.isdigit():
265
+ layer_idx = int(part)
266
+ break
267
+
268
+ if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
269
+ return kind, layer_idx
270
+
271
+ return None, -1
272
+
273
+
274
+ @dataclass
275
+ class QKClipInfo:
276
+ """Per-parameter dynamic info computed from config + runtime logits."""
277
+ kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None
278
+ indices: List[int] # which heads to consider for clipping
279
+ head_dim: int # from config
280
+ threshold: float # from config
281
+ logit: Optional[torch.Tensor]
282
+
283
+
284
  class Muon(torch.optim.Optimizer):
285
  """
286
  Muon - MomentUm Orthogonalized by Newton-schulz
 
308
  adamw_eps: The epsilon for the internal AdamW.
309
  none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
310
  debug: Whether to print debug information.
311
+ clip_config: Configuration for QK clipping. Expected keys:
312
+ - "q_indices" (list[int]): Indices of query heads to consider.
313
+ - "k_indices" (list[int]): Indices of key heads to consider.
314
+ - "head_dim" (int): Dimensionality of each attention head.
315
+ - "threshold" (float): Threshold value; heads whose QK logits exceed
316
+ this value will be scaled down.
317
+ Default is:
318
+ {
319
+ "q_indices": [],
320
+ "k_indices": [],
321
+ "head_dim": 128,
322
+ "threshold": 100
323
+ }
324
  """
325
 
326
+ def __init__(self,
327
+ params,
328
+ lr=1e-3,
329
+ momentum=0.95,
330
+ nesterov=True,
331
+ ns_steps=5,
332
+ weight_decay=0.1,
333
+ adamw_betas=(0.9, 0.95),
334
+ adamw_eps=1e-8,
335
+ none_grad=True,
336
+ debug=False,
337
+ clip_config={
338
+ "q_indices": [],
339
+ "k_indices": [],
340
+ "head_dim": 128,
341
+ "threshold": 100
342
+ }):
343
  defaults = dict(
344
  lr=lr,
345
  weight_decay=weight_decay,
 
371
  self.comm_stream = torch.cuda.Stream()
372
  self.compute_stream = torch.cuda.Stream()
373
  self.debug = debug
374
+ self.clip_config = clip_config
375
 
376
  def _calc_flops(self, G, steps):
377
  assert len(G.shape) == 2
 
407
  else:
408
  raise ValueError(f"Unsupported placements ({p.placements}).")
409
 
410
+ def init_state_and_assign_params(self, names, params, group, qk_logits):
411
  param_to_state = {}
412
  param_to_flops = {}
413
 
 
426
  print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
427
  flush=True)
428
 
429
+ paired = list(zip(names, params))
430
+
431
+ paired_sorted = sorted(paired,
432
+ key=lambda x: param_to_flops[id(x[1])],
433
+ reverse=True)
434
+
435
+ names_sorted, params_sorted = zip(*paired_sorted)
436
+ ordered_names = list(names_sorted)
437
+ ordered_params = list(params_sorted)
438
 
439
  round_robin = 0
440
  mesh = None
441
  shard_mesh = None
442
  process_group = None
443
+ for n, p in zip(ordered_names, ordered_params):
444
  if mesh is None:
445
  mesh = p.device_mesh
446
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
 
450
  param_to_state[id(p)] = _muon_state()
451
  param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
452
  param_to_state[id(p)].process_group = process_group
453
+ qk_clip_state = self.get_qk_clip_info(n, qk_logits)
454
+ param_to_state[id(p)].qk_clip_state = qk_clip_state
455
  round_robin = (round_robin + 1) % len(shard_mesh)
456
 
457
  return param_to_state, ordered_params
458
 
459
+ def base(self, names, params, group, lr, weight_decay, momentum,
460
+ qk_logits):
461
  # generate weight updates in distributed fashion
462
+ for n, p in zip(names, params):
463
  g = p.grad
464
  if g is None:
465
  continue
 
484
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
485
  Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
486
 
487
+ qk_clip_state = self.get_qk_clip_info(n, qk_logits)
488
+
489
+ scales_full = self._compute_scales(p, qk_clip_state)
490
+ if scales_full is not None:
491
+ Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
492
+
493
  def _update_g(self, p, g, group, momentum):
494
  # calc update
495
  state = self.state[p]
 
510
  # apply update
511
  p.data.add_(u, alpha=-adjusted_lr)
512
 
513
+ def get_qk_clip_info(self, n, qk_logits):
514
+ head_dim = self.clip_config.get('head_dim')
515
+ threshold = self.clip_config.get('threshold')
516
+ kind, layer_idx = parse_qk_layer(n)
517
+
518
+ logit, indices = None, []
519
+ if qk_logits is not None and kind is not None:
520
+ logit = qk_logits[layer_idx]
521
+ indices_key = 'q_indices' if 'q' in kind else 'k_indices'
522
+ indices = self.clip_config.get(indices_key, []) or []
523
+
524
+ return QKClipInfo(
525
+ kind=kind,
526
+ indices=indices,
527
+ head_dim=head_dim,
528
+ threshold=threshold,
529
+ logit=logit,
530
+ )
531
+
532
+ @staticmethod
533
+ def _compute_scales(p, qk_clip_state):
534
+ kind = qk_clip_state.kind
535
+ indices = qk_clip_state.indices
536
+ head_dim = qk_clip_state.head_dim
537
+ threshold = qk_clip_state.threshold
538
+ logit = qk_clip_state.logit
539
+
540
+ H_global = p.shape[0] // head_dim
541
+ scales_full = torch.ones(H_global, device=p.data.device)
542
+ scaling = 0
543
+
544
+ for logit_idx, head_idx in enumerate(indices):
545
+ v_ele = float(logit[logit_idx])
546
+ if v_ele > threshold:
547
+ new_scale = math.sqrt(threshold / v_ele)
548
+ if new_scale < scales_full[head_idx]:
549
+ scales_full[head_idx] = new_scale
550
+ logger.info(
551
+ f"[{kind}] Head {head_idx} exceeded threshold "
552
+ f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
553
+ )
554
+ scaling += 1
555
+
556
+ return scales_full if scaling > 0 else None
557
+
558
+ @staticmethod
559
+ def _qk_clip(p, scales, head_dim):
560
+ W = p.data.view(-1, head_dim, p.data.shape[1])
561
+ W.mul_(scales.view(-1, 1, 1))
562
+
563
+ def parallel(self, names, params, group, lr, weight_decay, momentum,
564
+ qk_logits):
565
  """
566
  Perform a parallel optimization step using Muon.
567
  """
 
583
  p.grad = g
584
 
585
  param_to_state, ordered_params = self.init_state_and_assign_params(
586
+ names, params, group, qk_logits)
587
 
588
  def enqueue_gathers(start_idx, chunk_size):
589
  for p in ordered_params[start_idx:start_idx + chunk_size]:
 
698
  maximize=maximize,
699
  )
700
 
701
+ def step(self, closure=None, qk_logits=None):
702
  """Perform a single optimization step.
703
 
704
  Args:
705
  closure (Callable, optional): A closure that reevaluates the model
706
  and returns the loss.
707
+ qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
708
+ to 1D tensors of shape (num_heads,), representing the maximum
709
+ QK logits across all tokens, computed as
710
+ (1 / sqrt(head_dim)) * (Q @ K^T).
711
  """
712
  loss = None
713
  if closure is not None:
 
724
  lr = group["lr"]
725
  weight_decay = group["weight_decay"]
726
  momentum = group["momentum"]
727
+ names = group["names"]
728
 
729
  param_dtensors = []
730
  param_tensors = []
731
+ name_dtensors = []
732
+ name_tensors = []
733
 
734
+ for n, p in zip(names, params):
735
  if p is None or p.grad is None:
736
  continue
737
  if isinstance(p.data, DTensor):
 
739
  isinstance(placement, Replicate)
740
  for placement in p.placements):
741
  param_tensors.append(p)
742
+ name_tensors.append(n)
743
  else:
744
  param_dtensors.append(p)
745
+ name_dtensors.append(n)
746
  elif isinstance(p.data, torch.Tensor):
747
  param_tensors.append(p)
748
+ name_tensors.append(n)
749
  else:
750
  raise TypeError(
751
  f"Unsupported parameter type: {type(p.data)}")
 
763
  )
764
 
765
  self.parallel(
766
+ name_dtensors,
767
  param_dtensors,
768
  group,
769
  lr=lr,
770
  weight_decay=weight_decay,
771
  momentum=momentum,
772
+ qk_logits=qk_logits,
773
  )
774
 
775
  if len(param_tensors) > 0:
776
  self.base(
777
+ name_tensors,
778
  param_tensors,
779
  group,
780
  lr=lr,
781
  weight_decay=weight_decay,
782
  momentum=momentum,
783
+ qk_logits=qk_logits,
784
  )
785
 
786
  else:
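
Which parameters participate in clipping is decided purely by name, via parse_qk_layer. A quick interactive check (assuming the module is importable as optimizer.muon, as laid out in the build directories) shows the mapping its docstring describes:

from optimizer.muon import parse_qk_layer

print(parse_qk_layer("model.3.attn.q_proj.weight"))   # ('q_proj', 3)
print(parse_qk_layer("model.5.attn.wk.weight"))       # ('wk', 5)
print(parse_qk_layer("model.4.attn.v_proj.weight"))   # (None, -1)  value projections are ignored
print(parse_qk_layer("lm_head.weight"))               # (None, -1)  too few name parts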
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_4043ece_dirty
3
- ops = torch.ops._optimizer_4043ece_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_4043ece_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_9c21645_dirty
3
+ ops = torch.ops._optimizer_9c21645_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_9c21645_dirty::{op_name}"
build/{torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_4043ece_dirty.abi3.so β†’ torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_9c21645_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5cd3d459f72674bcd05ba7cb96111bc90b08eeda3cbe1cd81ec5c0cd11730990
3
  size 1883344
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eb40a06623bb3668b82ff248b5a3c1bcf41e7f3f860888b261505b3a71257bc7
3
  size 1883344
build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py CHANGED
@@ -2,7 +2,7 @@ import logging
2
  import math
3
  import types
4
  from dataclasses import dataclass
5
- from typing import Optional, Union, cast
6
 
7
  import torch
8
  import torch.distributed as dist
@@ -66,6 +66,7 @@ class _muon_state:
66
  compute_event: torch.cuda.Event | None = None
67
  scatter_event: torch.cuda.Event | None = None
68
  process_group = None
 
69
 
70
 
71
  @torch.no_grad()
@@ -193,32 +194,93 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
193
(remainder of this muon.py diff is identical to the build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py diff below)
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_4043ece_dirty
-ops = torch.ops._optimizer_4043ece_dirty
+from . import _optimizer_9c21645_dirty
+ops = torch.ops._optimizer_9c21645_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_4043ece_dirty::{op_name}"
+    return f"_optimizer_9c21645_dirty::{op_name}"
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_4043ece_dirty.abi3.so → _optimizer_9c21645_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:752f4346d75c9ede747a6baf4102022bc4bd776db86b5dbd74e47c2a112547ea
+oid sha256:d8f845b8df6426eb5db57e4525b8dd3c80004c44759b01a3e39cc37a817813b5
 size 1749936
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py CHANGED
@@ -2,7 +2,7 @@ import logging
 import math
 import types
 from dataclasses import dataclass
-from typing import Optional, Union, cast
+from typing import List, Optional, Union, cast
 
 import torch
 import torch.distributed as dist
@@ -66,6 +66,7 @@ class _muon_state:
     compute_event: torch.cuda.Event | None = None
     scatter_event: torch.cuda.Event | None = None
     process_group = None
+    qk_clip_state = None
 
 
 @torch.no_grad()
@@ -193,32 +194,93 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
     state.scattered_u = None
     u_dtensor = None
 
+    scales_full = Muon._compute_scales(p, state.qk_clip_state)
+    if scales_full is not None:
+        num_ranks = dist.get_world_size(group=state.process_group)
+        local_rank = dist.get_rank(group=state.process_group)
+        scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank]
+        scales_local = DTensor.from_local(
+            scales_local,
+            placements=p.placements,
+            device_mesh=p.device_mesh,
+        )
+        Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim)
+
 
 def default_is_muon(name, x):
-    return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
+    skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
+    return x.ndim >= 2 and not any(key in name for key in skip_keys)
 
 
 def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
+    muon_params, muon_names = [], []
+    non_muon_params = []
+
+    for n, p in model.named_parameters():
+        if not p.requires_grad:
+            continue
+        if is_muon_func(n, p):
+            muon_params.append(p)
+            muon_names.append(n)
+        else:
+            non_muon_params.append(p)
+
     return [
         {
-            "params": [
-                p for n, p in model.named_parameters()
-                if (is_muon_func(n, p) and p.requires_grad)
-            ],
-            "use_muon":
-            True
+            "params": muon_params,
+            "names": muon_names,
+            "use_muon": True,
         },
         {
-            "params": [
-                p for n, p in model.named_parameters()
-                if (not is_muon_func(n, p) and p.requires_grad)
-            ],
-            "use_muon":
-            False
+            "params": non_muon_params,
+            "use_muon": False,
         },
     ]
 
 
+def parse_qk_layer(name: str) -> tuple[str | None, int]:
+    """
+    Parse a parameter name to check if it is a query/key projection layer
+    ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
+
+    Returns:
+        (kind, layer_idx) or (None, -1) if not matched.
+
+    Example:
+        'model.3.attn.wq.weight'     -> ('wq', 3)
+        'model.5.attn.wk.weight'     -> ('wk', 5)
+        'model.2.attn.q_proj.weight' -> ('q_proj', 2)
+        'model.7.attn.k_proj.weight' -> ('k_proj', 7)
+        'model.4.attn.v_proj.weight' -> (None, -1)
+    """
+    parts = name.split('.')
+    if len(parts) < 3:
+        return None, -1
+
+    kind = parts[-2]
+
+    layer_idx = -1
+    for part in reversed(parts):
+        if part.isdigit():
+            layer_idx = int(part)
+            break
+
+    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
+        return kind, layer_idx
+
+    return None, -1
+
+
+@dataclass
+class QKClipInfo:
+    """Per-parameter dynamic info computed from config + runtime logits."""
+    kind: Optional[str]  # 'wq'/'q_proj' or 'wk'/'k_proj' or None
+    indices: List[int]  # which heads to consider for clipping
+    head_dim: int  # from config
+    threshold: float  # from config
+    logit: Optional[torch.Tensor]
+
+
 class Muon(torch.optim.Optimizer):
     """
     Muon - MomentUm Orthogonalized by Newton-schulz
@@ -246,21 +308,38 @@ class Muon(torch.optim.Optimizer):
         adamw_eps: The epsilon for the internal AdamW.
         none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
         debug: Whether to print debug information.
+        clip_config: Configuration for QK clipping. Expected keys:
+            - "q_indices" (list[int]): Indices of query heads to consider.
+            - "k_indices" (list[int]): Indices of key heads to consider.
+            - "head_dim" (int): Dimensionality of each attention head.
+            - "threshold" (float): Threshold value; heads whose QK logits exceed
+              this value will be scaled down.
+            Default is:
+            {
+                "q_indices": [],
+                "k_indices": [],
+                "head_dim": 128,
+                "threshold": 100
+            }
     """
 
-    def __init__(
-        self,
-        params,
-        lr=1e-3,
-        momentum=0.95,
-        nesterov=True,
-        ns_steps=5,
-        weight_decay=0.1,
-        adamw_betas=(0.9, 0.95),
-        adamw_eps=1e-8,
-        none_grad=True,
-        debug=False,
-    ):
+    def __init__(self,
+                 params,
+                 lr=1e-3,
+                 momentum=0.95,
+                 nesterov=True,
+                 ns_steps=5,
+                 weight_decay=0.1,
+                 adamw_betas=(0.9, 0.95),
+                 adamw_eps=1e-8,
+                 none_grad=True,
+                 debug=False,
+                 clip_config={
+                     "q_indices": [],
+                     "k_indices": [],
+                     "head_dim": 128,
+                     "threshold": 100
+                 }):
         defaults = dict(
             lr=lr,
             weight_decay=weight_decay,
@@ -292,6 +371,7 @@ class Muon(torch.optim.Optimizer):
         self.comm_stream = torch.cuda.Stream()
         self.compute_stream = torch.cuda.Stream()
         self.debug = debug
+        self.clip_config = clip_config
 
     def _calc_flops(self, G, steps):
         assert len(G.shape) == 2
@@ -327,7 +407,7 @@ class Muon(torch.optim.Optimizer):
         else:
             raise ValueError(f"Unsupported placements ({p.placements}).")
 
-    def init_state_and_assign_params(self, params, group):
+    def init_state_and_assign_params(self, names, params, group, qk_logits):
         param_to_state = {}
         param_to_flops = {}
 
@@ -346,15 +426,21 @@ class Muon(torch.optim.Optimizer):
             print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
                   flush=True)
 
-        ordered_params = sorted(params,
-                                key=lambda p: param_to_flops[id(p)],
-                                reverse=True)
+        paired = list(zip(names, params))
+
+        paired_sorted = sorted(paired,
+                               key=lambda x: param_to_flops[id(x[1])],
+                               reverse=True)
+
+        names_sorted, params_sorted = zip(*paired_sorted)
+        ordered_names = list(names_sorted)
+        ordered_params = list(params_sorted)
 
         round_robin = 0
         mesh = None
         shard_mesh = None
         process_group = None
-        for p in ordered_params:
+        for n, p in zip(ordered_names, ordered_params):
             if mesh is None:
                 mesh = p.device_mesh
                 shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
@@ -364,14 +450,16 @@ class Muon(torch.optim.Optimizer):
             param_to_state[id(p)] = _muon_state()
             param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
             param_to_state[id(p)].process_group = process_group
-
+            qk_clip_state = self.get_qk_clip_info(n, qk_logits)
+            param_to_state[id(p)].qk_clip_state = qk_clip_state
             round_robin = (round_robin + 1) % len(shard_mesh)
 
         return param_to_state, ordered_params
 
-    def base(self, params, group, lr, weight_decay, momentum):
+    def base(self, names, params, group, lr, weight_decay, momentum,
+             qk_logits):
         # generate weight updates in distributed fashion
-        for p in params:
+        for n, p in zip(names, params):
             g = p.grad
             if g is None:
                 continue
@@ -396,6 +484,12 @@ class Muon(torch.optim.Optimizer):
             adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
             Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
 
+            qk_clip_state = self.get_qk_clip_info(n, qk_logits)
+
+            scales_full = self._compute_scales(p, qk_clip_state)
+            if scales_full is not None:
+                Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
+
     def _update_g(self, p, g, group, momentum):
         # calc update
         state = self.state[p]
@@ -416,7 +510,58 @@ class Muon(torch.optim.Optimizer):
         # apply update
         p.data.add_(u, alpha=-adjusted_lr)
 
-    def parallel(self, params, group, lr, weight_decay, momentum):
+    def get_qk_clip_info(self, n, qk_logits):
+        head_dim = self.clip_config.get('head_dim')
+        threshold = self.clip_config.get('threshold')
+        kind, layer_idx = parse_qk_layer(n)
+
+        logit, indices = None, []
+        if qk_logits is not None and kind is not None:
+            logit = qk_logits[layer_idx]
+            indices_key = 'q_indices' if 'q' in kind else 'k_indices'
+            indices = self.clip_config.get(indices_key, []) or []
+
+        return QKClipInfo(
+            kind=kind,
+            indices=indices,
+            head_dim=head_dim,
+            threshold=threshold,
+            logit=logit,
+        )
+
+    @staticmethod
+    def _compute_scales(p, qk_clip_state):
+        kind = qk_clip_state.kind
+        indices = qk_clip_state.indices
+        head_dim = qk_clip_state.head_dim
+        threshold = qk_clip_state.threshold
+        logit = qk_clip_state.logit
+
+        H_global = p.shape[0] // head_dim
+        scales_full = torch.ones(H_global, device=p.data.device)
+        scaling = 0
+
+        for logit_idx, head_idx in enumerate(indices):
+            v_ele = float(logit[logit_idx])
+            if v_ele > threshold:
+                new_scale = math.sqrt(threshold / v_ele)
+                if new_scale < scales_full[head_idx]:
+                    scales_full[head_idx] = new_scale
+                    logger.info(
+                        f"[{kind}] Head {head_idx} exceeded threshold "
+                        f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
+                    )
+                    scaling += 1
+
+        return scales_full if scaling > 0 else None
+
+    @staticmethod
+    def _qk_clip(p, scales, head_dim):
+        W = p.data.view(-1, head_dim, p.data.shape[1])
+        W.mul_(scales.view(-1, 1, 1))
+
+    def parallel(self, names, params, group, lr, weight_decay, momentum,
+                 qk_logits):
         """
         Perform a parallel optimization step using Muon.
         """
@@ -438,7 +583,7 @@ class Muon(torch.optim.Optimizer):
             p.grad = g
 
         param_to_state, ordered_params = self.init_state_and_assign_params(
-            params, group)
+            names, params, group, qk_logits)
 
         def enqueue_gathers(start_idx, chunk_size):
             for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -553,12 +698,16 @@ class Muon(torch.optim.Optimizer):
                 maximize=maximize,
             )
 
-    def step(self, closure=None):
+    def step(self, closure=None, qk_logits=None):
        """Perform a single optimization step.
 
         Args:
             closure (Callable, optional): A closure that reevaluates the model
                 and returns the loss.
+            qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
+                to 1D tensors of shape (num_heads,), representing the maximum
+                QK logits across all tokens, computed as
+                (1 / sqrt(head_dim)) * (Q @ K^T).
         """
         loss = None
         if closure is not None:
@@ -575,11 +724,14 @@ class Muon(torch.optim.Optimizer):
             lr = group["lr"]
             weight_decay = group["weight_decay"]
             momentum = group["momentum"]
+            names = group["names"]
 
             param_dtensors = []
             param_tensors = []
+            name_dtensors = []
+            name_tensors = []
 
-            for p in params:
+            for n, p in zip(names, params):
                 if p is None or p.grad is None:
                     continue
                 if isinstance(p.data, DTensor):
@@ -587,10 +739,13 @@ class Muon(torch.optim.Optimizer):
                             isinstance(placement, Replicate)
                             for placement in p.placements):
                         param_tensors.append(p)
+                        name_tensors.append(n)
                     else:
                         param_dtensors.append(p)
+                        name_dtensors.append(n)
                 elif isinstance(p.data, torch.Tensor):
                     param_tensors.append(p)
+                    name_tensors.append(n)
                 else:
                     raise TypeError(
                         f"Unsupported parameter type: {type(p.data)}")
@@ -608,20 +763,24 @@ class Muon(torch.optim.Optimizer):
                 )
 
                 self.parallel(
+                    name_dtensors,
                     param_dtensors,
                     group,
                     lr=lr,
                     weight_decay=weight_decay,
                     momentum=momentum,
+                    qk_logits=qk_logits,
                 )
 
             if len(param_tensors) > 0:
                 self.base(
+                    name_tensors,
                     param_tensors,
                     group,
                     lr=lr,
                     weight_decay=weight_decay,
                     momentum=momentum,
+                    qk_logits=qk_logits,
                 )
 
         else:
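
For reference, the per-head rule implemented by _compute_scales and _qk_clip above reduces, for a single unsharded projection weight, to the minimal sketch below. The tensor sizes and logit values are invented for illustration; only the head_dim/threshold config keys and the sqrt(threshold / logit) scale come from the diff.

import math
import torch

head_dim, num_heads, threshold = 4, 3, 100.0
w_q = torch.randn(num_heads * head_dim, 16)          # stand-in for a q_proj / wq weight
max_qk_logits = [50.0, 400.0, 100.0]                 # invented per-head maxima for one layer

scales = torch.ones(num_heads)
for head_idx, v in enumerate(max_qk_logits):
    if v > threshold:                                # only heads above the threshold are touched
        scales[head_idx] = math.sqrt(threshold / v)  # here: head 1 gets sqrt(100/400) = 0.5

# same reshape-and-scale as Muon._qk_clip: one scale per head_dim-sized block of rows
w_q.view(num_heads, head_dim, -1).mul_(scales.view(-1, 1, 1))

When the matching key heads are listed in k_indices as well, both projections shrink by the square root of the factor, so the offending QK logit is pulled back to roughly the threshold.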
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_4043ece_dirty
-ops = torch.ops._optimizer_4043ece_dirty
+from . import _optimizer_9c21645_dirty
+ops = torch.ops._optimizer_9c21645_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_4043ece_dirty::{op_name}"
+    return f"_optimizer_9c21645_dirty::{op_name}"
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/{_optimizer_4043ece_dirty.abi3.so → _optimizer_9c21645_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:2ccbbb16b7d65cd7a4cb562dbef2a3d963f042836c700527a6bb755a8277f0c1
+oid sha256:9a477575e3cc30e54d355b3e778240dc25fb0dab30362f3540dc5f925ac03ba1
 size 1750024
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py CHANGED
(identical to the build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py diff above)
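
Before the test diff below, a minimal usage sketch of the new surface. It assumes the kernel is installed as the optimizer package, a CUDA device, whatever process-group setup the kernel expects for the plain-tensor path (the test below drives both the FSDP and non-FSDP paths under one harness), and a module whose projections follow the q_proj/k_proj naming that parse_qk_layer recognizes; the toy module, sizes, and fabricated gradients and logits are not part of this commit.

import torch
import torch.nn as nn
from optimizer import Muon, get_default_muon_param_groups  # assumed import path

class TinyAttn(nn.Module):
    # toy block: parameter names like "0.q_proj.weight" match parse_qk_layer's convention
    def __init__(self, hidden=256):
        super().__init__()
        self.q_proj = nn.Linear(hidden, hidden, bias=False)
        self.k_proj = nn.Linear(hidden, hidden, bias=False)
        self.v_proj = nn.Linear(hidden, hidden, bias=False)

num_layers, num_heads, head_dim = 2, 2, 128
model = nn.ModuleList([TinyAttn() for _ in range(num_layers)]).cuda()
for p in model.parameters():                       # stand-in for a real backward pass
    p.grad = torch.randn_like(p)

optim = Muon(params=get_default_muon_param_groups(model),
             clip_config={
                 "q_indices": list(range(num_heads)),
                 "k_indices": list(range(num_heads)),
                 "head_dim": head_dim,
                 "threshold": 100,
             })

# per-layer max QK logits, normally measured during the forward pass
qk_logits = {i: torch.tensor([150.0, 80.0]) for i in range(num_layers)}
optim.step(qk_logits=qk_logits)                    # head 0 of each q/k projection is scaled down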
test/test_muon/test.py CHANGED
@@ -20,7 +20,6 @@ def load_model(fsdp: bool) -> torch.nn.Module:
         trust_remote_code=True,
     ).bfloat16().cuda()
 
-    torch.manual_seed(0)
     random_grads = []
     for param in model.parameters():
         random_grad = torch.randn_like(param,
@@ -52,17 +51,57 @@ def load_model(fsdp: bool) -> torch.nn.Module:
     return model
 
 
-def run_muon(fsdp: bool) -> torch.nn.Module:
+def run_muon(fsdp: bool, qk_clip: bool, seed: int) -> torch.nn.Module:
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
     model = load_model(fsdp=fsdp)
     params = get_default_muon_param_groups(model)
-    optim = Muon(params=params)
-    optim.step()
+    qk_logits = None
+    if qk_clip:
+        qk_logits = {
+            i: torch.rand(model.config.num_attention_heads)
+            for i in range(model.config.num_hidden_layers)
+        }
+    optim = Muon(
+        params=params,
+        clip_config={
+            "q_indices": list(range(model.config.num_attention_heads)),
+            "k_indices": list(range(model.config.num_attention_heads)),
+            "head_dim":
+            model.config.hidden_size // model.config.num_attention_heads,
+            "threshold": 0.5
+        })
+    optim.step(qk_logits=qk_logits)
 
     return model
 
 
+def run_case(qk_clip: bool, seed: int = 0):
+    parallel_muon_result = run_muon(fsdp=True, qk_clip=qk_clip, seed=seed)
+    sequential_muon_result = run_muon(fsdp=False, qk_clip=qk_clip, seed=seed)
+    label = f"qk_clip={'ON' if qk_clip else 'OFF'}"
+    success = compare_results(parallel_muon_result,
+                              sequential_muon_result,
+                              label=label)
+
+    return success, label
+
+
+def test_muon():
+    base_result = run_case(qk_clip=False, seed=0)
+    clip_result = run_case(qk_clip=True, seed=0)
+
+    for success, label in [base_result, clip_result]:
+        if success:
+            logger.info(f"[{label}] Models match")
+
+
 def compare_results(parallel_muon_result: torch.nn.Module,
-                    sequential_muon_result: torch.nn.Module) -> None:
+                    sequential_muon_result: torch.nn.Module,
+                    label: str) -> bool:
+    success = True
     for (name_p, p), (name_s,
                       s) in zip(parallel_muon_result.named_parameters(),
                                 sequential_muon_result.named_parameters()):
@@ -71,16 +110,10 @@ def compare_results(parallel_muon_result: torch.nn.Module,
         # Parallel Muon should exactly match Sequential Muon
         if torch.abs(p - s).max() > 0:
             max_diff_index = torch.argmax(torch.abs(p - s))
-            logger.error(f"Models differ at parameter {name_p}")
-            return
-    logger.info("Models match!")
-
-
-def test_muon():
-    parallel_muon_result = run_muon(fsdp=True)
-    sequential_muon_result = run_muon(fsdp=False)
+            logger.info(f"Models differ at parameter {name_p}")
+            success = False
 
-    compare_results(parallel_muon_result, sequential_muon_result)
+    return success
 
 
 if __name__ == "__main__":
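
The test above fabricates qk_logits with torch.rand; in real training they have to be measured from the attention forward pass. Below is a hedged sketch of one way to collect them. The record helper and the reduction over batch and token dimensions are assumptions, not part of this commit; only the (1 / sqrt(head_dim)) * (Q @ K^T) definition and the {layer_idx: tensor of shape (num_heads,)} format come from the step docstring.

import math
import torch

def max_qk_per_head(q, k):
    # q, k: (batch, num_heads, seq, head_dim) -> (num_heads,) max logit per head
    head_dim = q.shape[-1]
    logits = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)   # (batch, heads, seq, seq)
    return logits.amax(dim=(0, 2, 3))                          # max over batch and all token pairs

qk_logits = {}

def record(layer_idx, q, k):
    # call once per attention layer per step, e.g. from inside the attention forward
    cur = max_qk_per_head(q, k).detach()
    prev = qk_logits.get(layer_idx)
    qk_logits[layer_idx] = cur if prev is None else torch.maximum(prev, cur)

q = torch.randn(1, 4, 8, 16)
k = torch.randn(1, 4, 8, 16)
record(0, q, k)   # qk_logits[0] is now a (4,) tensor ready for optim.step(qk_logits=qk_logits)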
torch-ext/optimizer/muon.py CHANGED
@@ -2,7 +2,7 @@ import logging
2
  import math
3
  import types
4
  from dataclasses import dataclass
5
- from typing import Optional, Union, cast
6
 
7
  import torch
8
  import torch.distributed as dist
@@ -66,6 +66,7 @@ class _muon_state:
66
  compute_event: torch.cuda.Event | None = None
67
  scatter_event: torch.cuda.Event | None = None
68
  process_group = None
 
69
 
70
 
71
  @torch.no_grad()
@@ -193,32 +194,93 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
193
  state.scattered_u = None
194
  u_dtensor = None
195
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
  def default_is_muon(name, x):
198
- return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
 
199
 
200
 
201
  def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
 
 
 
 
 
 
 
 
 
 
 
 
202
  return [
203
  {
204
- "params": [
205
- p for n, p in model.named_parameters()
206
- if (is_muon_func(n, p) and p.requires_grad)
207
- ],
208
- "use_muon":
209
- True
210
  },
211
  {
212
- "params": [
213
- p for n, p in model.named_parameters()
214
- if (not is_muon_func(n, p) and p.requires_grad)
215
- ],
216
- "use_muon":
217
- False
218
  },
219
  ]
220
 
221
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  class Muon(torch.optim.Optimizer):
223
  """
224
  Muon - MomentUm Orthogonalized by Newton-schulz
@@ -246,21 +308,38 @@ class Muon(torch.optim.Optimizer):
246
  adamw_eps: The epsilon for the internal AdamW.
247
  none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
248
  debug: Whether to print debug information.
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  """
250
 
251
- def __init__(
252
- self,
253
- params,
254
- lr=1e-3,
255
- momentum=0.95,
256
- nesterov=True,
257
- ns_steps=5,
258
- weight_decay=0.1,
259
- adamw_betas=(0.9, 0.95),
260
- adamw_eps=1e-8,
261
- none_grad=True,
262
- debug=False,
263
- ):
 
 
 
 
264
  defaults = dict(
265
  lr=lr,
266
  weight_decay=weight_decay,
@@ -292,6 +371,7 @@ class Muon(torch.optim.Optimizer):
292
  self.comm_stream = torch.cuda.Stream()
293
  self.compute_stream = torch.cuda.Stream()
294
  self.debug = debug
 
295
 
296
  def _calc_flops(self, G, steps):
297
  assert len(G.shape) == 2
@@ -327,7 +407,7 @@ class Muon(torch.optim.Optimizer):
327
  else:
328
  raise ValueError(f"Unsupported placements ({p.placements}).")
329
 
330
- def init_state_and_assign_params(self, params, group):
331
  param_to_state = {}
332
  param_to_flops = {}
333
 
@@ -346,15 +426,21 @@ class Muon(torch.optim.Optimizer):
346
  print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
347
  flush=True)
348
 
349
- ordered_params = sorted(params,
350
- key=lambda p: param_to_flops[id(p)],
351
- reverse=True)
 
 
 
 
 
 
352
 
353
  round_robin = 0
354
  mesh = None
355
  shard_mesh = None
356
  process_group = None
357
- for p in ordered_params:
358
  if mesh is None:
359
  mesh = p.device_mesh
360
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
@@ -364,14 +450,16 @@ class Muon(torch.optim.Optimizer):
364
  param_to_state[id(p)] = _muon_state()
365
  param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
366
  param_to_state[id(p)].process_group = process_group
367
-
 
368
  round_robin = (round_robin + 1) % len(shard_mesh)
369
 
370
  return param_to_state, ordered_params
371
 
372
- def base(self, params, group, lr, weight_decay, momentum):
 
373
  # generate weight updates in distributed fashion
374
- for p in params:
375
  g = p.grad
376
  if g is None:
377
  continue
@@ -396,6 +484,12 @@ class Muon(torch.optim.Optimizer):
396
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
397
  Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
398
 
 
 
 
 
 
 
399
  def _update_g(self, p, g, group, momentum):
400
  # calc update
401
  state = self.state[p]
@@ -416,7 +510,58 @@ class Muon(torch.optim.Optimizer):
416
  # apply update
417
  p.data.add_(u, alpha=-adjusted_lr)
418
 
419
- def parallel(self, params, group, lr, weight_decay, momentum):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
  """
  Perform a parallel optimization step using Muon.
  """
@@ -438,7 +583,7 @@ class Muon(torch.optim.Optimizer):
  p.grad = g

  param_to_state, ordered_params = self.init_state_and_assign_params(
- params, group)

  def enqueue_gathers(start_idx, chunk_size):
  for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -553,12 +698,16 @@ class Muon(torch.optim.Optimizer):
  maximize=maximize,
  )

- def step(self, closure=None):
  """Perform a single optimization step.

  Args:
  closure (Callable, optional): A closure that reevaluates the model
  and returns the loss.
  """
  loss = None
  if closure is not None:
@@ -575,11 +724,14 @@ class Muon(torch.optim.Optimizer):
  lr = group["lr"]
  weight_decay = group["weight_decay"]
  momentum = group["momentum"]

  param_dtensors = []
  param_tensors = []

- for p in params:
  if p is None or p.grad is None:
  continue
  if isinstance(p.data, DTensor):
@@ -587,10 +739,13 @@ class Muon(torch.optim.Optimizer):
  isinstance(placement, Replicate)
  for placement in p.placements):
  param_tensors.append(p)
  else:
  param_dtensors.append(p)
  elif isinstance(p.data, torch.Tensor):
  param_tensors.append(p)
  else:
  raise TypeError(
  f"Unsupported parameter type: {type(p.data)}")
@@ -608,20 +763,24 @@ class Muon(torch.optim.Optimizer):
  )

  self.parallel(
  param_dtensors,
  group,
  lr=lr,
  weight_decay=weight_decay,
  momentum=momentum,
  )

  if len(param_tensors) > 0:
  self.base(
  param_tensors,
  group,
  lr=lr,
  weight_decay=weight_decay,
  momentum=momentum,
  )

  else:

  import math
  import types
  from dataclasses import dataclass
+ from typing import List, Optional, Union, cast

  import torch
  import torch.distributed as dist

  compute_event: torch.cuda.Event | None = None
  scatter_event: torch.cuda.Event | None = None
  process_group = None
+ qk_clip_state = None


  @torch.no_grad()

  state.scattered_u = None
  u_dtensor = None

+ scales_full = Muon._compute_scales(p, state.qk_clip_state)
+ if scales_full is not None:
+ num_ranks = dist.get_world_size(group=state.process_group)
+ local_rank = dist.get_rank(group=state.process_group)
+ scales_local = scales_full.chunk(num_ranks, dim=0)[local_rank]
+ scales_local = DTensor.from_local(
+ scales_local,
+ placements=p.placements,
+ device_mesh=p.device_mesh,
+ )
+ Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim)
+
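The hunk above applies QK-clip to a row-sharded DTensor weight: the per-head scale vector is computed for the full matrix, then chunked so each rank only rescales the heads whose rows it owns. A minimal single-process sketch of that indexing, with made-up sizes (the real code does this per rank via dist.get_rank and DTensor.from_local):

import torch

H_global, head_dim, hidden, num_ranks = 8, 4, 16, 2
scales_full = torch.ones(H_global)
scales_full[5] = 0.5  # pretend head 5 exceeded the threshold

for rank in range(num_ranks):
    # rank owns heads [rank * H_global // num_ranks, (rank + 1) * H_global // num_ranks)
    scales_local = scales_full.chunk(num_ranks, dim=0)[rank]
    w_local = torch.randn(H_global // num_ranks * head_dim, hidden)
    w_local.view(-1, head_dim, hidden).mul_(scales_local.view(-1, 1, 1))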
 
  def default_is_muon(name, x):
+ skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
+ return x.ndim >= 2 and not any(key in name for key in skip_keys)


  def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
+ muon_params, muon_names = [], []
+ non_muon_params = []
+
+ for n, p in model.named_parameters():
+ if not p.requires_grad:
+ continue
+ if is_muon_func(n, p):
+ muon_params.append(p)
+ muon_names.append(n)
+ else:
+ non_muon_params.append(p)
+
  return [
  {
+ "params": muon_params,
+ "names": muon_names,
+ "use_muon": True,
  },
  {
+ "params": non_muon_params,
+ "use_muon": False,
  },
  ]

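For orientation, a hedged usage sketch of the two helpers above together with the new clip_config/qk_logits plumbing (the model, head counts, and logit values are placeholders, and constructing the optimizer itself still assumes the usual CUDA + torch.distributed setup expected by this module):

import torch

model = torch.nn.Linear(64, 64)  # stand-in module; 2D weight -> Muon group, bias -> AdamW group
groups = get_default_muon_param_groups(model)

# opt = Muon(groups,
#            lr=1e-3,
#            clip_config={"q_indices": list(range(8)),
#                         "k_indices": list(range(8)),
#                         "head_dim": 128,
#                         "threshold": 100})
# ...forward/backward...
# opt.step(qk_logits={0: torch.full((8,), 90.0)})  # per-layer max QK logits, see step() below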
 
+ def parse_qk_layer(name: str) -> tuple[str | None, int]:
+ """
+ Parse a parameter name to check if it is a query/key projection layer
+ ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
+
+ Returns:
+ (kind, layer_idx) or (None, -1) if not matched.
+
+ Example:
+ 'model.3.attn.wq.weight' -> ('wq', 3)
+ 'model.5.attn.wk.weight' -> ('wk', 5)
+ 'model.2.attn.q_proj.weight' -> ('q_proj', 2)
+ 'model.7.attn.k_proj.weight' -> ('k_proj', 7)
+ 'model.4.attn.v_proj.weight' -> (None, -1)
+ """
+ parts = name.split('.')
+ if len(parts) < 3:
+ return None, -1
+
+ kind = parts[-2]
+
+ layer_idx = -1
+ for part in reversed(parts):
+ if part.isdigit():
+ layer_idx = int(part)
+ break
+
+ if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
+ return kind, layer_idx
+
+ return None, -1
+
+
+ @dataclass
+ class QKClipInfo:
+ """Per-parameter dynamic info computed from config + runtime logits."""
+ kind: Optional[str] # 'wq'/'q_proj' or 'wk'/'k_proj' or None
+ indices: List[int] # which heads to consider for clipping
+ head_dim: int # from config
+ threshold: float # from config
+ logit: Optional[torch.Tensor]
+
+
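To illustrate how a parameter name turns into clip metadata (an assumed example; the GQA-style index mapping is a configuration choice, not fixed by this diff): with 8 query heads sharing 4 KV heads, q_indices maps each per-layer logit to its query head, while k_indices maps several logits onto the same KV head, and _compute_scales later in this diff keeps the smallest resulting scale per head via its new_scale < scales_full[head_idx] check.

kind, layer_idx = parse_qk_layer("model.3.attn.k_proj.weight")  # -> ("k_proj", 3)

clip_config = {
    "q_indices": [0, 1, 2, 3, 4, 5, 6, 7],  # logit i -> query head i
    "k_indices": [0, 0, 1, 1, 2, 2, 3, 3],  # logit i -> owning KV head
    "head_dim": 128,
    "threshold": 100,
}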
  class Muon(torch.optim.Optimizer):
  """
  Muon - MomentUm Orthogonalized by Newton-schulz

  adamw_eps: The epsilon for the internal AdamW.
  none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
  debug: Whether to print debug information.
+ clip_info : Configuration for QK clipping. Expected keys:
+ - "q_indices" (list[int]): Indices of query heads to consider.
+ - "k_indices" (list[int]): Indices of key heads to consider.
+ - "head_dim" (int): Dimensionality of each attention head.
+ - "threshold" (float): Threshold value; heads whose QK logits exceed
+ this value will be scaled down.
+ Default is:
+ {
+ "q_indices": [],
+ "k_indices": [],
+ "head_dim": 128,
+ "threshold": 100
+ }
  """

+ def __init__(self,
+ params,
+ lr=1e-3,
+ momentum=0.95,
+ nesterov=True,
+ ns_steps=5,
+ weight_decay=0.1,
+ adamw_betas=(0.9, 0.95),
+ adamw_eps=1e-8,
+ none_grad=True,
+ debug=False,
+ clip_config={
+ "q_indices": [],
+ "k_indices": [],
+ "head_dim": 128,
+ "threshold": 100
+ }):
  defaults = dict(
  lr=lr,
  weight_decay=weight_decay,

  self.comm_stream = torch.cuda.Stream()
  self.compute_stream = torch.cuda.Stream()
  self.debug = debug
+ self.clip_config = clip_config

  def _calc_flops(self, G, steps):
  assert len(G.shape) == 2

  else:
  raise ValueError(f"Unsupported placements ({p.placements}).")

+ def init_state_and_assign_params(self, names, params, group, qk_logits):
  param_to_state = {}
  param_to_flops = {}


  print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
  flush=True)

+ paired = list(zip(names, params))
+
+ paired_sorted = sorted(paired,
+ key=lambda x: param_to_flops[id(x[1])],
+ reverse=True)
+
+ names_sorted, params_sorted = zip(*paired_sorted)
+ ordered_names = list(names_sorted)
+ ordered_params = list(params_sorted)

  round_robin = 0
  mesh = None
  shard_mesh = None
  process_group = None
+ for n, p in zip(ordered_names, ordered_params):
  if mesh is None:
  mesh = p.device_mesh
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)

  param_to_state[id(p)] = _muon_state()
  param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
  param_to_state[id(p)].process_group = process_group
+ qk_clip_state = self.get_qk_clip_info(n, qk_logits)
+ param_to_state[id(p)].qk_clip_state = qk_clip_state
  round_robin = (round_robin + 1) % len(shard_mesh)

  return param_to_state, ordered_params

+ def base(self, names, params, group, lr, weight_decay, momentum,
+ qk_logits):
  # generate weight updates in distributed fashion
+ for n, p in zip(names, params):
  g = p.grad
  if g is None:
  continue

  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
  Muon._update_p(p, u, lr, adjusted_lr, weight_decay)

+ qk_clip_state = self.get_qk_clip_info(n, qk_logits)
+
+ scales_full = self._compute_scales(p, qk_clip_state)
+ if scales_full is not None:
+ Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
+
  def _update_g(self, p, g, group, momentum):
  # calc update
  state = self.state[p]

  # apply update
  p.data.add_(u, alpha=-adjusted_lr)

+ def get_qk_clip_info(self, n, qk_logits):
+ head_dim = self.clip_config.get('head_dim')
+ threshold = self.clip_config.get('threshold')
+ kind, layer_idx = parse_qk_layer(n)
+
+ logit, indices = None, []
+ if qk_logits is not None and kind is not None:
+ logit = qk_logits[layer_idx]
+ indices_key = 'q_indices' if 'q' in kind else 'k_indices'
+ indices = self.clip_config.get(indices_key, []) or []
+
+ return QKClipInfo(
+ kind=kind,
+ indices=indices,
+ head_dim=head_dim,
+ threshold=threshold,
+ logit=logit,
+ )
+
+ @staticmethod
+ def _compute_scales(p, qk_clip_state):
+ kind = qk_clip_state.kind
+ indices = qk_clip_state.indices
+ head_dim = qk_clip_state.head_dim
+ threshold = qk_clip_state.threshold
+ logit = qk_clip_state.logit
+
+ H_global = p.shape[0] // head_dim
+ scales_full = torch.ones(H_global, device=p.data.device)
+ scaling = 0
+
+ for logit_idx, head_idx in enumerate(indices):
+ v_ele = float(logit[logit_idx])
+ if v_ele > threshold:
+ new_scale = math.sqrt(threshold / v_ele)
+ if new_scale < scales_full[head_idx]:
+ scales_full[head_idx] = new_scale
+ logger.info(
+ f"[{kind}] Head {head_idx} exceeded threshold "
+ f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
+ )
+ scaling += 1
+
+ return scales_full if scaling > 0 else None
+
+ @staticmethod
+ def _qk_clip(p, scales, head_dim):
+ W = p.data.view(-1, head_dim, p.data.shape[1])
+ W.mul_(scales.view(-1, 1, 1))
+
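A quick worked instance of the clipping rule in _compute_scales above (numbers are illustrative): when a head is listed in both q_indices and k_indices, the same sqrt(threshold / logit) factor is applied to its query rows and its matching key rows, so the QK logit itself shrinks by threshold / logit.

import math

threshold, max_logit = 100.0, 144.0       # this head exceeded the threshold
scale = math.sqrt(threshold / max_logit)  # = 10 / 12 ~= 0.8333
# q-head and k-head rows are each scaled by `scale`, so the logit scales by scale**2
assert abs(max_logit * scale ** 2 - threshold) < 1e-9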
+ def parallel(self, names, params, group, lr, weight_decay, momentum,
+ qk_logits):
  """
  Perform a parallel optimization step using Muon.
  """

  p.grad = g

  param_to_state, ordered_params = self.init_state_and_assign_params(
+ names, params, group, qk_logits)

  def enqueue_gathers(start_idx, chunk_size):
  for p in ordered_params[start_idx:start_idx + chunk_size]:

  maximize=maximize,
  )

+ def step(self, closure=None, qk_logits=None):
  """Perform a single optimization step.

  Args:
  closure (Callable, optional): A closure that reevaluates the model
  and returns the loss.
+ qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
+ to 1D tensors of shape (num_heads,), representing the maximum
+ QK logits across all tokens, computed as
+ (1 / sqrt(head_dim)) * (Q @ K^T).
  """
  loss = None
  if closure is not None:
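For reference, a hedged sketch of how the qk_logits argument documented above might be collected during the forward pass (the helper name, tensor layout, and reduction are assumptions, not part of this diff):

qk_logits = {}

def record_layer_max(layer_idx, attn_scores):
    # attn_scores: (batch, num_heads, q_len, k_len), already scaled by 1/sqrt(head_dim)
    qk_logits[layer_idx] = attn_scores.amax(dim=(0, 2, 3)).detach()

# later, after backward():
# optimizer.step(qk_logits=qk_logits)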
 
  lr = group["lr"]
  weight_decay = group["weight_decay"]
  momentum = group["momentum"]
+ names = group["names"]

  param_dtensors = []
  param_tensors = []
+ name_dtensors = []
+ name_tensors = []

+ for n, p in zip(names, params):
  if p is None or p.grad is None:
  continue
  if isinstance(p.data, DTensor):

  isinstance(placement, Replicate)
  for placement in p.placements):
  param_tensors.append(p)
+ name_tensors.append(n)
  else:
  param_dtensors.append(p)
+ name_dtensors.append(n)
  elif isinstance(p.data, torch.Tensor):
  param_tensors.append(p)
+ name_tensors.append(n)
  else:
  raise TypeError(
  f"Unsupported parameter type: {type(p.data)}")

  )

  self.parallel(
+ name_dtensors,
  param_dtensors,
  group,
  lr=lr,
  weight_decay=weight_decay,
  momentum=momentum,
+ qk_logits=qk_logits,
  )

  if len(param_tensors) > 0:
  self.base(
+ name_tensors,
  param_tensors,
  group,
  lr=lr,
  weight_decay=weight_decay,
  momentum=momentum,
+ qk_logits=qk_logits,
  )

  else: