Kernels
ca1207 commited on
Commit
811726c
·
1 Parent(s): 0b8d958

fix bug in fsdp

Browse files
Files changed (1) hide show
  1. torch-ext/optimizer/muon.py +5 -0
torch-ext/optimizer/muon.py CHANGED
@@ -606,6 +606,11 @@ class Muon(torch.optim.Optimizer):
606
 
607
  if p.placements == (Shard(dim=0), ):
608
  # Case for FSDP
 
 
 
 
 
609
  return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
610
  elif p.placements == (Replicate(), Shard(dim=0)):
611
  # Case for HSDP
 
606
 
607
  if p.placements == (Shard(dim=0), ):
608
  # Case for FSDP
609
+ process_group = p.device_mesh.get_group(mesh_dim=0)
610
+ if self.rank is None:
611
+ self.rank = dist.get_rank(group=process_group)
612
+ else:
613
+ assert self.rank == dist.get_rank(group=process_group)
614
  return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
615
  elif p.placements == (Replicate(), Shard(dim=0)):
616
  # Case for HSDP