fix bug in fsdp
Browse files
torch-ext/optimizer/muon.py
CHANGED
|
@@ -606,6 +606,11 @@ class Muon(torch.optim.Optimizer):
|
|
| 606 |
|
| 607 |
if p.placements == (Shard(dim=0), ):
|
| 608 |
# Case for FSDP
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 609 |
return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
|
| 610 |
elif p.placements == (Replicate(), Shard(dim=0)):
|
| 611 |
# Case for HSDP
|
|
|
|
| 606 |
|
| 607 |
if p.placements == (Shard(dim=0), ):
|
| 608 |
# Case for FSDP
|
| 609 |
+
process_group = p.device_mesh.get_group(mesh_dim=0)
|
| 610 |
+
if self.rank is None:
|
| 611 |
+
self.rank = dist.get_rank(group=process_group)
|
| 612 |
+
else:
|
| 613 |
+
assert self.rank == dist.get_rank(group=process_group)
|
| 614 |
return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
|
| 615 |
elif p.placements == (Replicate(), Shard(dim=0)):
|
| 616 |
# Case for HSDP
|