Commit dd04b01
Parent(s): 1e2b528

feat(muon): support HSDP

Changed files:
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_1f13dae_dirty.abi3.so → _optimizer_2dc97a1_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +33 -15
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +33 -15
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
- build/{torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +33 -15
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_1f13dae_dirty.abi3.so → _optimizer_2dc97a1_dirty.abi3.so} +1 -1
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +33 -15
- build/torch28-cxx11-cu126-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
- build/torch28-cxx11-cu126-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
- build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/{torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so} +1 -1
- build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py +33 -15
- build/torch28-cxx11-cu128-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
- build/torch28-cxx11-cu128-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
- build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
- build/{torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so} +1 -1
- build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py +33 -15
- build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
- build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
- build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py +3 -3
- build/{torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so} +1 -1
- build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py +33 -15
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_1f13dae_dirty.abi3.so → _optimizer_2dc97a1_dirty.abi3.so} +1 -1
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py +33 -15
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/{_optimizer_1f13dae_dirty.abi3.so → _optimizer_2dc97a1_dirty.abi3.so} +1 -1
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py +33 -15
- torch-ext/optimizer/muon.py +33 -15
build/torch27-cxx11-cu118-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc
CHANGED
Binary files differ

build/torch27-cxx11-cu118-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc
CHANGED
Binary files differ
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py
CHANGED

@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_1f13dae_dirty
-ops = torch.ops._optimizer_1f13dae_dirty
+from . import _optimizer_2dc97a1_dirty
+ops = torch.ops._optimizer_2dc97a1_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_1f13dae_dirty::{op_name}"
+    return f"_optimizer_2dc97a1_dirty::{op_name}"
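For orientation, a minimal sketch of what the regenerated shim does after the rebuild; the op name "foo" is a placeholder for illustration, and only the _optimizer_2dc97a1_dirty namespace string comes from the diff above.

def add_op_namespace_prefix(op_name: str) -> str:
    # Mirrors the updated _ops.py: qualify an op name with the
    # build-specific namespace of the freshly compiled extension.
    return f"_optimizer_2dc97a1_dirty::{op_name}"

print(add_op_namespace_prefix("foo"))  # -> _optimizer_2dc97a1_dirty::foo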
build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_1f13dae_dirty.abi3.so → _optimizer_2dc97a1_dirty.abi3.so}
RENAMED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:9112c8dde01baefa0e3130e143288cd3073ccbab47369a6dc925ce0d35400c6d
 size 1787368
build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py
CHANGED

@@ -3,7 +3,7 @@ from dataclasses import dataclass
 
 import torch
 import torch.distributed as dist
-from torch.distributed._tensor import DTensor, Replicate
+from torch.distributed._tensor import DTensor, Replicate, Shard
 
 
 # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -50,15 +50,16 @@ class _muon_state:
     computed_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
+    process_group = None
 
 
 @torch.no_grad()
 def _gather(p, state, rank, comm_stream, none_grad):
     g = p.grad
-    mesh = g.device_mesh
 
     if rank == state.worker_rank:
-
+        num_ranks = dist.get_world_size(group=state.process_group)
+        gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
     else:
         gather_list = None
 
@@ -67,7 +68,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
             g.to_local(),
             dst=state.worker_rank,
             gather_list=gather_list,
-            group=
+            group=state.process_group,
         )
         if rank == state.worker_rank:
             if state.gathered_grad is not None:
@@ -105,14 +106,14 @@ def _compute_u(state, steps, rank, compute_stream):
 @torch.no_grad()
 def _scatter(p, state, lr, weight_decay, rank, comm_stream):
     u = state.computed_u
-    mesh = p.device_mesh
 
     with torch.cuda.stream(comm_stream):
         if rank == state.worker_rank:
+            num_ranks = dist.get_world_size(group=state.process_group)
             if state.compute_event is None:
                 raise RuntimeError("Compute event must be set before scatter.")
             comm_stream.wait_event(state.compute_event)
-            scatter_list = list(torch.split(u, p.size(0) //
+            scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
         else:
             scatter_list = None
 
@@ -121,7 +122,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
             u,
             scatter_list=scatter_list,
             src=state.worker_rank,
-            group=
+            group=state.process_group,
        )
         if rank == state.worker_rank:
             # Clear u to free memory
@@ -129,7 +130,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
         u = DTensor.from_local(
             u,
             placements=p.placements,
-            device_mesh=
+            device_mesh=p.device_mesh,
         )
         p.data.mul_(1 - lr * weight_decay)
         p.data.add_(u, alpha=-lr)
@@ -235,6 +236,23 @@ class Muon(torch.optim.Optimizer):
         adjusted_lr = lr * adjusted_ratio
         return adjusted_lr
 
+    def get_shard_mesh(self, p, rank):
+        """
+        Get the shard mesh for a parameter p on the given rank.
+        """
+        assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
+
+        if p.placements == (Shard(dim=0),):
+            # Case for FSDP
+            return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
+        elif p.placements == (Replicate(), Shard(dim=0)):
+            # Case for HSDP
+            for i, shard_mesh in enumerate(p.device_mesh.mesh):
+                if rank in shard_mesh:
+                    return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
+        else:
+            raise ValueError(f"Unsupported placements ({p.placements}).")
+
     def init_state_and_assign_params(self, params, group):
         param_to_state = {}
         param_to_flops = {}
@@ -259,20 +277,20 @@ class Muon(torch.optim.Optimizer):
 
         round_robin = 0
         mesh = None
+        shard_mesh = None
+        process_group = None
         for p in ordered_params:
             if mesh is None:
                 mesh = p.device_mesh
-
-                raise NotImplementedError(
-                    "Muon requires a 1D mesh for distributed training yet."
-                )
+                shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
             elif mesh != p.device_mesh:
                 raise ValueError("All parameters must be on the same mesh.")
 
             param_to_state[id(p)] = _muon_state()
-            param_to_state[id(p)].worker_rank =
+            param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
+            param_to_state[id(p)].process_group = process_group
 
-            round_robin = (round_robin + 1) %
+            round_robin = (round_robin + 1) % len(shard_mesh)
 
         return param_to_state, ordered_params
 
@@ -372,7 +390,7 @@ class Muon(torch.optim.Optimizer):
                 p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
             )
 
-        chunk_size = params[0].
+        chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
 
         # Wait grad update
         self.comm_stream.wait_stream(torch.cuda.current_stream())
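To make the new HSDP path easier to follow, here is a short, hedged sketch of the placement dispatch introduced by get_shard_mesh: with a 1-D (Shard(0),) placement (FSDP) the whole device mesh is the shard group, while with a 2-D (Replicate(), Shard(0)) placement (HSDP) the optimizer picks the mesh row containing the current rank and scopes gather/scatter to the mesh-dim-1 process group, handing out worker ranks round-robin within that row. The mesh layout, ranks, and helper name below are illustrative stand-ins, not code from this commit.

import torch

# Illustrative stand-in for an HSDP DeviceMesh laid out as (replicate, shard):
# each row is one shard group that holds a full copy of the parameters.
hsdp_mesh = torch.tensor([[0, 1, 2, 3],
                          [4, 5, 6, 7]])

def pick_shard_row(mesh, rank):
    # Mirrors the HSDP branch of Muon.get_shard_mesh: scan the mesh rows and
    # return the one containing rank; the real method also returns the
    # process group for mesh_dim=1 so gather/scatter stay within that row.
    for row in mesh:
        if rank in row:
            return row
    raise ValueError(f"rank {rank} not found in mesh {mesh.tolist()}")

row = pick_shard_row(hsdp_mesh, rank=5)  # tensor([4, 5, 6, 7])
worker = row[0].item()                   # first worker_rank handed out round-robin
print(row.tolist(), worker)              # [4, 5, 6, 7] 4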
build/torch27-cxx11-cu126-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc
CHANGED
Binary files differ

build/torch27-cxx11-cu126-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc
CHANGED
Binary files differ
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
Same change as build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above (each build directory carries an identical copy).
build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so}
RENAMED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:0449cd352f44c3e848d1f9c847b00bf576673b4fef2a954ec8bd8d2524b8353a
 size 1824256
build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED
Same changes as build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above (each build directory carries an identical copy).
build/torch27-cxx11-cu128-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc
CHANGED
Binary files differ

build/torch27-cxx11-cu128-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc
CHANGED
Binary files differ
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py
CHANGED
Same change as build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above (identical copy).
build/{torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so}
RENAMED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:2e6bab72b965f42d466cd74bbda49851549f2810278e642cef8738e40de4fdc5
 size 1883352
build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py
CHANGED
Same changes as build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above (identical copy).
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc
CHANGED
Binary files differ

build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc
CHANGED
Binary files differ
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py
CHANGED
Same change as build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above (identical copy).
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_1f13dae_dirty.abi3.so → _optimizer_2dc97a1_dirty.abi3.so}
RENAMED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:bdcf9e3d8bf13aa01bf1ae7a94a12dd05c50702a24b57e4cfcc2e54ca5ed21c3
 size 1749840
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py
CHANGED
Same changes as build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above (identical copy).
build/torch28-cxx11-cu126-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc
CHANGED
Binary files differ

build/torch28-cxx11-cu126-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc
CHANGED
Binary files differ
build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
Same change as build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above (identical copy).
build/{torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so}
RENAMED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:a423eb4ab3a31c53a3326c71e34fa59fc661f8d432701e41a7de900a9c23e37c
 size 1824256
build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED
Same changes as build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above (identical copy).
build/torch28-cxx11-cu128-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc
CHANGED
Binary files differ

build/torch28-cxx11-cu128-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc
CHANGED
Binary files differ
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py
CHANGED
Same change as build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above (identical copy).
build/{torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so}
RENAMED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:86d98863cc7ef0b271808b0ef7b1082603cfb5a76986481df37431527aaaf27b
 size 1883352
build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py
CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
 
 import torch
 import torch.distributed as dist
-from torch.distributed._tensor import DTensor, Replicate
+from torch.distributed._tensor import DTensor, Replicate, Shard
 
 
 # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -50,15 +50,16 @@ class _muon_state:
     computed_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
+    process_group = None
 
 
 @torch.no_grad()
 def _gather(p, state, rank, comm_stream, none_grad):
     g = p.grad
-    mesh = g.device_mesh
 
     if rank == state.worker_rank:
-
+        num_ranks = dist.get_world_size(group=state.process_group)
+        gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
     else:
         gather_list = None
 
@@ -67,7 +68,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
             g.to_local(),
             dst=state.worker_rank,
             gather_list=gather_list,
-            group=
+            group=state.process_group,
         )
         if rank == state.worker_rank:
             if state.gathered_grad is not None:
@@ -105,14 +106,14 @@ def _compute_u(state, steps, rank, compute_stream):
 @torch.no_grad()
 def _scatter(p, state, lr, weight_decay, rank, comm_stream):
     u = state.computed_u
-    mesh = p.device_mesh
 
     with torch.cuda.stream(comm_stream):
         if rank == state.worker_rank:
+            num_ranks = dist.get_world_size(group=state.process_group)
             if state.compute_event is None:
                 raise RuntimeError("Compute event must be set before scatter.")
             comm_stream.wait_event(state.compute_event)
-            scatter_list = list(torch.split(u, p.size(0) //
+            scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
         else:
             scatter_list = None
 
@@ -121,7 +122,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
             u,
             scatter_list=scatter_list,
             src=state.worker_rank,
-            group=
+            group=state.process_group,
         )
         if rank == state.worker_rank:
             # Clear u to free memory
@@ -129,7 +130,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
             u = DTensor.from_local(
                 u,
                 placements=p.placements,
-                device_mesh=
+                device_mesh=p.device_mesh,
             )
             p.data.mul_(1 - lr * weight_decay)
             p.data.add_(u, alpha=-lr)
@@ -235,6 +236,23 @@ class Muon(torch.optim.Optimizer):
         adjusted_lr = lr * adjusted_ratio
         return adjusted_lr
 
+    def get_shard_mesh(self, p, rank):
+        """
+        Get the shard mesh for a parameter p on the given rank.
+        """
+        assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
+
+        if p.placements == (Shard(dim=0),):
+            # Case for FSDP
+            return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
+        elif p.placements == (Replicate(), Shard(dim=0)):
+            # Case for HSDP
+            for i, shard_mesh in enumerate(p.device_mesh.mesh):
+                if rank in shard_mesh:
+                    return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
+        else:
+            raise ValueError(f"Unsupported placements ({p.placements}).")
+
     def init_state_and_assign_params(self, params, group):
         param_to_state = {}
         param_to_flops = {}
@@ -259,20 +277,20 @@ class Muon(torch.optim.Optimizer):
 
         round_robin = 0
         mesh = None
+        shard_mesh = None
+        process_group = None
         for p in ordered_params:
             if mesh is None:
                 mesh = p.device_mesh
-
-                raise NotImplementedError(
-                    "Muon requires a 1D mesh for distributed training yet."
-                )
+                shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
             elif mesh != p.device_mesh:
                 raise ValueError("All parameters must be on the same mesh.")
 
             param_to_state[id(p)] = _muon_state()
-            param_to_state[id(p)].worker_rank =
+            param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
+            param_to_state[id(p)].process_group = process_group
 
-            round_robin = (round_robin + 1) %
+            round_robin = (round_robin + 1) % len(shard_mesh)
 
         return param_to_state, ordered_params
 
@@ -372,7 +390,7 @@ class Muon(torch.optim.Optimizer):
             p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
         )
 
-        chunk_size = params[0].
+        chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
 
         # Wait grad update
         self.comm_stream.wait_stream(torch.cuda.current_stream())
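
The new get_shard_mesh helper above distinguishes pure FSDP parameters, whose placements are (Shard(dim=0),) on a 1D mesh, from HSDP parameters, which are (Replicate(), Shard(dim=0)) on a 2D mesh; only the ranks and process group of the shard dimension take part in the gather/compute/scatter path. Below is a minimal sketch of how such a parameter looks, assuming a 4-GPU torchrun launch and an illustrative 2x2 mesh; the mesh shape, tensor size, and variable names are illustrative and not taken from this repository:

# Sketch only: assumes `torchrun --nproc-per-node=4`, 4 CUDA devices, and a recent
# torch with DeviceMesh/DTensor support; shapes and names here are illustrative.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import Replicate, Shard, distribute_tensor

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# HSDP-style 2D mesh: the outer dim replicates, the inner dim shards.
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("replicate", "shard"))

# A parameter laid out the way the HSDP branch expects: replicated across mesh
# dim 0, sharded on tensor dim 0 across mesh dim 1.
weight = distribute_tensor(torch.randn(8, 8), mesh_2d, placements=[Replicate(), Shard(0)])
assert weight.placements == (Replicate(), Shard(dim=0))

# This is the group the optimizer would route dist.gather/dist.scatter through.
shard_group = weight.device_mesh.get_group(mesh_dim=1)
print(rank, dist.get_world_size(group=shard_group))  # each shard group has 2 ranks

dist.destroy_process_group()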
build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc
CHANGED
Binary files a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc and b/build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc differ

build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc
CHANGED
Binary files a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc and b/build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc differ
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_1f13dae_dirty
-ops = torch.ops._optimizer_1f13dae_dirty
+from . import _optimizer_2dc97a1_dirty
+ops = torch.ops._optimizer_2dc97a1_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_1f13dae_dirty::{op_name}"
+    return f"_optimizer_2dc97a1_dirty::{op_name}"
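
The _ops.py shim above simply re-points the Python-side bindings at the renamed extension. A quick sketch of what the helper returns, assuming the built optimizer package and its compiled .so are importable; the op name "foo" is a placeholder, not an op this extension necessarily registers:

from optimizer._ops import add_op_namespace_prefix, ops

# The helper just prefixes an op name with the extension's namespace.
assert add_op_namespace_prefix("foo") == "_optimizer_2dc97a1_dirty::foo"

# `ops` is the matching torch.ops namespace; any custom ops the .so registers
# would be resolved as attributes on it (e.g. `ops.foo`) after import.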
build/{torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_1f13dae_dirty.abi3.so → torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:f8daaad69e6958850f848fab60c9acb938c3a5e54e3ec34a1bec03a3d32653cb
 size 1883352
build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py
CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
 
 import torch
 import torch.distributed as dist
-from torch.distributed._tensor import DTensor, Replicate
+from torch.distributed._tensor import DTensor, Replicate, Shard
 
 
 # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -50,15 +50,16 @@ class _muon_state:
     computed_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
+    process_group = None
 
 
 @torch.no_grad()
 def _gather(p, state, rank, comm_stream, none_grad):
     g = p.grad
-    mesh = g.device_mesh
 
     if rank == state.worker_rank:
-
+        num_ranks = dist.get_world_size(group=state.process_group)
+        gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
     else:
         gather_list = None
 
@@ -67,7 +68,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
            g.to_local(),
            dst=state.worker_rank,
            gather_list=gather_list,
-            group=
+            group=state.process_group,
        )
        if rank == state.worker_rank:
            if state.gathered_grad is not None:
@@ -105,14 +106,14 @@ def _compute_u(state, steps, rank, compute_stream):
 @torch.no_grad()
 def _scatter(p, state, lr, weight_decay, rank, comm_stream):
     u = state.computed_u
-    mesh = p.device_mesh
 
     with torch.cuda.stream(comm_stream):
         if rank == state.worker_rank:
+            num_ranks = dist.get_world_size(group=state.process_group)
             if state.compute_event is None:
                 raise RuntimeError("Compute event must be set before scatter.")
             comm_stream.wait_event(state.compute_event)
-            scatter_list = list(torch.split(u, p.size(0) //
+            scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
         else:
             scatter_list = None
 
@@ -121,7 +122,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
             u,
             scatter_list=scatter_list,
             src=state.worker_rank,
-            group=
+            group=state.process_group,
         )
         if rank == state.worker_rank:
             # Clear u to free memory
@@ -129,7 +130,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
             u = DTensor.from_local(
                 u,
                 placements=p.placements,
-                device_mesh=
+                device_mesh=p.device_mesh,
             )
             p.data.mul_(1 - lr * weight_decay)
             p.data.add_(u, alpha=-lr)
@@ -235,6 +236,23 @@ class Muon(torch.optim.Optimizer):
         adjusted_lr = lr * adjusted_ratio
         return adjusted_lr
 
+    def get_shard_mesh(self, p, rank):
+        """
+        Get the shard mesh for a parameter p on the given rank.
+        """
+        assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
+
+        if p.placements == (Shard(dim=0),):
+            # Case for FSDP
+            return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
+        elif p.placements == (Replicate(), Shard(dim=0)):
+            # Case for HSDP
+            for i, shard_mesh in enumerate(p.device_mesh.mesh):
+                if rank in shard_mesh:
+                    return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
+        else:
+            raise ValueError(f"Unsupported placements ({p.placements}).")
+
     def init_state_and_assign_params(self, params, group):
         param_to_state = {}
         param_to_flops = {}
@@ -259,20 +277,20 @@ class Muon(torch.optim.Optimizer):
 
         round_robin = 0
         mesh = None
+        shard_mesh = None
+        process_group = None
         for p in ordered_params:
             if mesh is None:
                 mesh = p.device_mesh
-
-                raise NotImplementedError(
-                    "Muon requires a 1D mesh for distributed training yet."
-                )
+                shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
             elif mesh != p.device_mesh:
                 raise ValueError("All parameters must be on the same mesh.")
 
             param_to_state[id(p)] = _muon_state()
-            param_to_state[id(p)].worker_rank =
+            param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
+            param_to_state[id(p)].process_group = process_group
 
-            round_robin = (round_robin + 1) %
+            round_robin = (round_robin + 1) % len(shard_mesh)
 
         return param_to_state, ordered_params
 
@@ -372,7 +390,7 @@ class Muon(torch.optim.Optimizer):
             p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
         )
 
-        chunk_size = params[0].
+        chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
 
         # Wait grad update
         self.comm_stream.wait_stream(torch.cuda.current_stream())
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc
CHANGED
Binary files a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc and b/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc differ

build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc
CHANGED
Binary files a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc and b/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc differ
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_1f13dae_dirty
-ops = torch.ops._optimizer_1f13dae_dirty
+from . import _optimizer_2dc97a1_dirty
+ops = torch.ops._optimizer_2dc97a1_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_1f13dae_dirty::{op_name}"
+    return f"_optimizer_2dc97a1_dirty::{op_name}"
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_1f13dae_dirty.abi3.so → _optimizer_2dc97a1_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:76910ba81e2c95c83207118725c4379db636346c4ccf05010e2ee00c41dff1ce
 size 1750000
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py
CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
 
 import torch
 import torch.distributed as dist
-from torch.distributed._tensor import DTensor, Replicate
+from torch.distributed._tensor import DTensor, Replicate, Shard
 
 
 # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -50,15 +50,16 @@ class _muon_state:
     computed_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
+    process_group = None
 
 
 @torch.no_grad()
 def _gather(p, state, rank, comm_stream, none_grad):
     g = p.grad
-    mesh = g.device_mesh
 
     if rank == state.worker_rank:
-
+        num_ranks = dist.get_world_size(group=state.process_group)
+        gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
     else:
         gather_list = None
 
@@ -67,7 +68,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
            g.to_local(),
            dst=state.worker_rank,
            gather_list=gather_list,
-            group=
+            group=state.process_group,
        )
        if rank == state.worker_rank:
            if state.gathered_grad is not None:
@@ -105,14 +106,14 @@ def _compute_u(state, steps, rank, compute_stream):
 @torch.no_grad()
 def _scatter(p, state, lr, weight_decay, rank, comm_stream):
     u = state.computed_u
-    mesh = p.device_mesh
 
     with torch.cuda.stream(comm_stream):
         if rank == state.worker_rank:
+            num_ranks = dist.get_world_size(group=state.process_group)
             if state.compute_event is None:
                 raise RuntimeError("Compute event must be set before scatter.")
             comm_stream.wait_event(state.compute_event)
-            scatter_list = list(torch.split(u, p.size(0) //
+            scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
         else:
             scatter_list = None
 
@@ -121,7 +122,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
             u,
             scatter_list=scatter_list,
             src=state.worker_rank,
-            group=
+            group=state.process_group,
         )
         if rank == state.worker_rank:
             # Clear u to free memory
@@ -129,7 +130,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
             u = DTensor.from_local(
                 u,
                 placements=p.placements,
-                device_mesh=
+                device_mesh=p.device_mesh,
             )
             p.data.mul_(1 - lr * weight_decay)
             p.data.add_(u, alpha=-lr)
@@ -235,6 +236,23 @@ class Muon(torch.optim.Optimizer):
         adjusted_lr = lr * adjusted_ratio
         return adjusted_lr
 
+    def get_shard_mesh(self, p, rank):
+        """
+        Get the shard mesh for a parameter p on the given rank.
+        """
+        assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
+
+        if p.placements == (Shard(dim=0),):
+            # Case for FSDP
+            return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
+        elif p.placements == (Replicate(), Shard(dim=0)):
+            # Case for HSDP
+            for i, shard_mesh in enumerate(p.device_mesh.mesh):
+                if rank in shard_mesh:
+                    return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
+        else:
+            raise ValueError(f"Unsupported placements ({p.placements}).")
+
     def init_state_and_assign_params(self, params, group):
         param_to_state = {}
         param_to_flops = {}
@@ -259,20 +277,20 @@ class Muon(torch.optim.Optimizer):
 
         round_robin = 0
         mesh = None
+        shard_mesh = None
+        process_group = None
         for p in ordered_params:
             if mesh is None:
                 mesh = p.device_mesh
-
-                raise NotImplementedError(
-                    "Muon requires a 1D mesh for distributed training yet."
-                )
+                shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
             elif mesh != p.device_mesh:
                 raise ValueError("All parameters must be on the same mesh.")
 
             param_to_state[id(p)] = _muon_state()
-            param_to_state[id(p)].worker_rank =
+            param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
+            param_to_state[id(p)].process_group = process_group
 
-            round_robin = (round_robin + 1) %
+            round_robin = (round_robin + 1) % len(shard_mesh)
 
         return param_to_state, ordered_params
 
@@ -372,7 +390,7 @@ class Muon(torch.optim.Optimizer):
             p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
         )
 
-        chunk_size = params[0].
+        chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
 
         # Wait grad update
         self.comm_stream.wait_stream(torch.cuda.current_stream())
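
Across all of these copies, the substantive change in _gather and _scatter is that the collectives now run on the parameter's shard process group (state.process_group) rather than on a mesh assumed to be 1D. The stand-alone sketch below shows that pattern with the CUDA streams, events, and the _compute_u step stripped out; the normalization is only a placeholder and the function name is invented for illustration (assumes a torchrun launch with a backend that supports gather/scatter):

import torch
import torch.distributed as dist

def gather_transform_scatter(local_shard, worker_rank, group):
    # Gather every rank's shard to `worker_rank`, transform the full tensor
    # there, then scatter equal chunks back; every collective uses `group`,
    # the shard process group, exactly as in _gather/_scatter above.
    rank = dist.get_rank()
    num_ranks = dist.get_world_size(group=group)

    gather_list = (
        [torch.empty_like(local_shard) for _ in range(num_ranks)]
        if rank == worker_rank
        else None
    )
    dist.gather(local_shard, gather_list=gather_list, dst=worker_rank, group=group)

    if rank == worker_rank:
        full = torch.cat(gather_list, dim=0)
        full = full / (full.norm() + 1e-7)  # placeholder for the real _compute_u work
        scatter_list = list(torch.split(full, full.size(0) // num_ranks, dim=0))
    else:
        scatter_list = None

    out = torch.empty_like(local_shard)
    dist.scatter(out, scatter_list=scatter_list, src=worker_rank, group=group)
    return out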
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc
CHANGED
Binary files a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc and b/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc differ

build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc
CHANGED
Binary files a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc and b/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc differ
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_1f13dae_dirty
-ops = torch.ops._optimizer_1f13dae_dirty
+from . import _optimizer_2dc97a1_dirty
+ops = torch.ops._optimizer_2dc97a1_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_1f13dae_dirty::{op_name}"
+    return f"_optimizer_2dc97a1_dirty::{op_name}"
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/{_optimizer_1f13dae_dirty.abi3.so → _optimizer_2dc97a1_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:dd0a35a6f846a075a8f4561cfc66ef17c6358dd4a0062e63057b02625d9d6af7
 size 1750088
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py
CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
 
 import torch
 import torch.distributed as dist
-from torch.distributed._tensor import DTensor, Replicate
+from torch.distributed._tensor import DTensor, Replicate, Shard
 
 
 # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -50,15 +50,16 @@ class _muon_state:
     computed_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
+    process_group = None
 
 
 @torch.no_grad()
 def _gather(p, state, rank, comm_stream, none_grad):
     g = p.grad
-    mesh = g.device_mesh
 
     if rank == state.worker_rank:
-
+        num_ranks = dist.get_world_size(group=state.process_group)
+        gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
    else:
        gather_list = None
 
@@ -67,7 +68,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
            g.to_local(),
            dst=state.worker_rank,
            gather_list=gather_list,
-            group=
+            group=state.process_group,
        )
        if rank == state.worker_rank:
            if state.gathered_grad is not None:
@@ -105,14 +106,14 @@ def _compute_u(state, steps, rank, compute_stream):
 @torch.no_grad()
 def _scatter(p, state, lr, weight_decay, rank, comm_stream):
     u = state.computed_u
-    mesh = p.device_mesh
 
     with torch.cuda.stream(comm_stream):
         if rank == state.worker_rank:
+            num_ranks = dist.get_world_size(group=state.process_group)
             if state.compute_event is None:
                 raise RuntimeError("Compute event must be set before scatter.")
             comm_stream.wait_event(state.compute_event)
-            scatter_list = list(torch.split(u, p.size(0) //
+            scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
         else:
             scatter_list = None
 
@@ -121,7 +122,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
             u,
             scatter_list=scatter_list,
             src=state.worker_rank,
-            group=
+            group=state.process_group,
         )
         if rank == state.worker_rank:
             # Clear u to free memory
@@ -129,7 +130,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
             u = DTensor.from_local(
                 u,
                 placements=p.placements,
-                device_mesh=
+                device_mesh=p.device_mesh,
             )
             p.data.mul_(1 - lr * weight_decay)
             p.data.add_(u, alpha=-lr)
@@ -235,6 +236,23 @@ class Muon(torch.optim.Optimizer):
         adjusted_lr = lr * adjusted_ratio
         return adjusted_lr
 
+    def get_shard_mesh(self, p, rank):
+        """
+        Get the shard mesh for a parameter p on the given rank.
+        """
+        assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
+
+        if p.placements == (Shard(dim=0),):
+            # Case for FSDP
+            return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
+        elif p.placements == (Replicate(), Shard(dim=0)):
+            # Case for HSDP
+            for i, shard_mesh in enumerate(p.device_mesh.mesh):
+                if rank in shard_mesh:
+                    return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
+        else:
+            raise ValueError(f"Unsupported placements ({p.placements}).")
+
     def init_state_and_assign_params(self, params, group):
         param_to_state = {}
         param_to_flops = {}
@@ -259,20 +277,20 @@ class Muon(torch.optim.Optimizer):
 
         round_robin = 0
         mesh = None
+        shard_mesh = None
+        process_group = None
         for p in ordered_params:
             if mesh is None:
                 mesh = p.device_mesh
-
-                raise NotImplementedError(
-                    "Muon requires a 1D mesh for distributed training yet."
-                )
+                shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
             elif mesh != p.device_mesh:
                 raise ValueError("All parameters must be on the same mesh.")
 
             param_to_state[id(p)] = _muon_state()
-            param_to_state[id(p)].worker_rank =
+            param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
+            param_to_state[id(p)].process_group = process_group
 
-            round_robin = (round_robin + 1) %
+            round_robin = (round_robin + 1) % len(shard_mesh)
 
         return param_to_state, ordered_params
 
@@ -372,7 +390,7 @@ class Muon(torch.optim.Optimizer):
             p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
         )
 
-        chunk_size = params[0].
+        chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
 
         # Wait grad update
         self.comm_stream.wait_stream(torch.cuda.current_stream())
torch-ext/optimizer/muon.py
CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
 
 import torch
 import torch.distributed as dist
-from torch.distributed._tensor import DTensor, Replicate
+from torch.distributed._tensor import DTensor, Replicate, Shard
 
 
 # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -50,15 +50,16 @@ class _muon_state:
     computed_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
+    process_group = None
 
 
 @torch.no_grad()
 def _gather(p, state, rank, comm_stream, none_grad):
     g = p.grad
-    mesh = g.device_mesh
 
     if rank == state.worker_rank:
-
+        num_ranks = dist.get_world_size(group=state.process_group)
+        gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
    else:
        gather_list = None
 
@@ -67,7 +68,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
            g.to_local(),
            dst=state.worker_rank,
            gather_list=gather_list,
-            group=
+            group=state.process_group,
        )
        if rank == state.worker_rank:
            if state.gathered_grad is not None:
@@ -105,14 +106,14 @@ def _compute_u(state, steps, rank, compute_stream):
 @torch.no_grad()
 def _scatter(p, state, lr, weight_decay, rank, comm_stream):
     u = state.computed_u
-    mesh = p.device_mesh
 
     with torch.cuda.stream(comm_stream):
         if rank == state.worker_rank:
+            num_ranks = dist.get_world_size(group=state.process_group)
             if state.compute_event is None:
                 raise RuntimeError("Compute event must be set before scatter.")
             comm_stream.wait_event(state.compute_event)
-            scatter_list = list(torch.split(u, p.size(0) //
+            scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
         else:
             scatter_list = None
 
@@ -121,7 +122,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
             u,
             scatter_list=scatter_list,
             src=state.worker_rank,
-            group=
+            group=state.process_group,
         )
         if rank == state.worker_rank:
             # Clear u to free memory
@@ -129,7 +130,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
             u = DTensor.from_local(
                 u,
                 placements=p.placements,
-                device_mesh=
+                device_mesh=p.device_mesh,
             )
             p.data.mul_(1 - lr * weight_decay)
             p.data.add_(u, alpha=-lr)
@@ -235,6 +236,23 @@ class Muon(torch.optim.Optimizer):
         adjusted_lr = lr * adjusted_ratio
         return adjusted_lr
 
+    def get_shard_mesh(self, p, rank):
+        """
+        Get the shard mesh for a parameter p on the given rank.
+        """
+        assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
+
+        if p.placements == (Shard(dim=0),):
+            # Case for FSDP
+            return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
+        elif p.placements == (Replicate(), Shard(dim=0)):
+            # Case for HSDP
+            for i, shard_mesh in enumerate(p.device_mesh.mesh):
+                if rank in shard_mesh:
+                    return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
+        else:
+            raise ValueError(f"Unsupported placements ({p.placements}).")
+
     def init_state_and_assign_params(self, params, group):
         param_to_state = {}
         param_to_flops = {}
@@ -259,20 +277,20 @@ class Muon(torch.optim.Optimizer):
 
         round_robin = 0
         mesh = None
+        shard_mesh = None
+        process_group = None
         for p in ordered_params:
             if mesh is None:
                 mesh = p.device_mesh
-
-                raise NotImplementedError(
-                    "Muon requires a 1D mesh for distributed training yet."
-                )
+                shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
             elif mesh != p.device_mesh:
                 raise ValueError("All parameters must be on the same mesh.")
 
             param_to_state[id(p)] = _muon_state()
-            param_to_state[id(p)].worker_rank =
+            param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
+            param_to_state[id(p)].process_group = process_group
 
-            round_robin = (round_robin + 1) %
+            round_robin = (round_robin + 1) % len(shard_mesh)
 
         return param_to_state, ordered_params
 
@@ -372,7 +390,7 @@ class Muon(torch.optim.Optimizer):
             p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
        )
 
-        chunk_size = params[0].
+        chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
 
        # Wait grad update
        self.comm_stream.wait_stream(torch.cuda.current_stream())
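
The same change lands in torch-ext/optimizer/muon.py and in every build/ copy above. In init_state_and_assign_params, each parameter's worker is now picked round-robin from the ranks of its shard mesh, and chunk_size in the step loop becomes the size of that shard group. Below is a tiny self-contained sketch of just the assignment logic, with a plain tensor standing in for a real DeviceMesh row so it runs without any distributed setup; all names and values are illustrative:

import torch

# Ranks of one shard-mesh row, e.g. the HSDP row that contains the local rank.
shard_mesh = torch.tensor([4, 5, 6, 7])
params = [f"param{i}" for i in range(6)]  # stand-ins for real parameters

worker_rank = {}
round_robin = 0
for name in params:
    # Same pattern as the loop above: cycle through the row's ranks.
    worker_rank[name] = shard_mesh[round_robin].item()
    round_robin = (round_robin + 1) % len(shard_mesh)

print(worker_rank)
# {'param0': 4, 'param1': 5, 'param2': 6, 'param3': 7, 'param4': 4, 'param5': 5}

# chunk_size in step(): the number of ranks in the shard group, i.e. how many
# local shards each gathered parameter is split into before the scatter.
chunk_size = len(shard_mesh)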