Commit
·
bdd2678
1
Parent(s):
8535e80
fix(muon): delete intermediate tensors immediately to lower peak mem usage
Browse files
- build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-cu118-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py +10 -19
- build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-cu124-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py +10 -19
- build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-cu126-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py +10 -19
- build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py +10 -19
- build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx98-cu118-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py +10 -19
- build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx98-cu124-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py +10 -19
- build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx98-cu126-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py +10 -19
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +10 -19
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +10 -19
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +10 -19
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-310.pyc +0 -0
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-310.pyc +0 -0
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +10 -19
- torch-ext/optimizer/muon.py +10 -19
build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import torch
|
2 |
-
from . import
|
3 |
-
ops = torch.ops.
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
-
return f"
|
|
|
1 |
import torch
|
2 |
+
from . import _optimizer_8535e80_dirty
|
3 |
+
ops = torch.ops._optimizer_8535e80_dirty
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
+
return f"_optimizer_8535e80_dirty::{op_name}"
|
build/torch26-cxx11-cu118-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a46d9e65efcfa82522950d9ebf2b2b4594d9ed5abc28704352a1f7de2dae707a
|
3 |
+
size 1787272
|
build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py
CHANGED
@@ -48,7 +48,6 @@ class _muon_state:
|
|
48 |
worker_rank: int | None = None
|
49 |
gathered_grad: torch.Tensor | None = None
|
50 |
computed_u: torch.Tensor | None = None
|
51 |
-
scattered_u: torch.Tensor | None = None
|
52 |
gather_event: torch.cuda.Event | None = None
|
53 |
compute_event: torch.cuda.Event | None = None
|
54 |
|
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
93 |
state.computed_u = u
|
94 |
state.compute_event = torch.cuda.Event()
|
95 |
state.compute_event.record()
|
|
|
|
|
96 |
else:
|
97 |
state.computed_u = None
|
98 |
state.compute_event = None
|
99 |
|
100 |
|
101 |
-
def _scatter(p, state, rank, comm_stream):
|
102 |
u = state.computed_u
|
103 |
mesh = p.device_mesh
|
104 |
|
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
|
|
118 |
src=state.worker_rank,
|
119 |
group=mesh.get_group(),
|
120 |
)
|
|
|
|
|
|
|
121 |
u = DTensor.from_local(
|
122 |
u,
|
123 |
placements=p.placements,
|
124 |
device_mesh=mesh,
|
125 |
)
|
126 |
-
|
127 |
-
|
128 |
|
129 |
|
130 |
class Muon(torch.optim.Optimizer):
|
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
|
|
353 |
def enqueue_scatters(start_idx, chunk_size):
|
354 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
355 |
state = param_to_state[id(p)]
|
356 |
-
|
|
|
357 |
|
358 |
chunk_size = params[0].device_mesh.mesh.numel()
|
359 |
|
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
|
|
368 |
|
369 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
370 |
|
371 |
-
for p in params:
|
372 |
-
g = p.grad
|
373 |
-
if g is None:
|
374 |
-
continue
|
375 |
-
|
376 |
-
# Update p with sharded u
|
377 |
-
state = param_to_state[id(p)]
|
378 |
-
self._update_p(
|
379 |
-
p,
|
380 |
-
state.scattered_u,
|
381 |
-
lr=lr,
|
382 |
-
wd=wd,
|
383 |
-
)
|
384 |
-
|
385 |
def step(self, closure=None):
|
386 |
"""Perform a single optimization step.
|
387 |
|
|
|
48 |
worker_rank: int | None = None
|
49 |
gathered_grad: torch.Tensor | None = None
|
50 |
computed_u: torch.Tensor | None = None
|
|
|
51 |
gather_event: torch.cuda.Event | None = None
|
52 |
compute_event: torch.cuda.Event | None = None
|
53 |
|
|
|
92 |
state.computed_u = u
|
93 |
state.compute_event = torch.cuda.Event()
|
94 |
state.compute_event.record()
|
95 |
+
state.gathered_grad.record_stream(compute_stream)
|
96 |
+
del state.gathered_grad
|
97 |
else:
|
98 |
state.computed_u = None
|
99 |
state.compute_event = None
|
100 |
|
101 |
|
102 |
+
def _scatter(p, state, lr, wd, rank, comm_stream):
|
103 |
u = state.computed_u
|
104 |
mesh = p.device_mesh
|
105 |
|
|
|
119 |
src=state.worker_rank,
|
120 |
group=mesh.get_group(),
|
121 |
)
|
122 |
+
if rank == state.worker_rank:
|
123 |
+
state.computed_u.record_stream(comm_stream)
|
124 |
+
del state.computed_u
|
125 |
u = DTensor.from_local(
|
126 |
u,
|
127 |
placements=p.placements,
|
128 |
device_mesh=mesh,
|
129 |
)
|
130 |
+
p.data.mul_(1 - lr * wd)
|
131 |
+
p.data.add_(u, alpha=-lr)
|
132 |
|
133 |
|
134 |
class Muon(torch.optim.Optimizer):
|
|
|
357 |
def enqueue_scatters(start_idx, chunk_size):
|
358 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
359 |
state = param_to_state[id(p)]
|
360 |
+
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
361 |
+
_scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
|
362 |
|
363 |
chunk_size = params[0].device_mesh.mesh.numel()
|
364 |
|
|
|
373 |
|
374 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
def step(self, closure=None):
|
377 |
"""Perform a single optimization step.
|
378 |
|
build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import torch
|
2 |
-
from . import
|
3 |
-
ops = torch.ops.
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
-
return f"
|
|
|
1 |
import torch
|
2 |
+
from . import _optimizer_8535e80_dirty
|
3 |
+
ops = torch.ops._optimizer_8535e80_dirty
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
+
return f"_optimizer_8535e80_dirty::{op_name}"
|
build/torch26-cxx11-cu124-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d351a600884b7378f546a345afe65c176e1399bb42fb7dfe4333b0e90975803b
|
3 |
+
size 1824224
|
build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py
CHANGED
@@ -48,7 +48,6 @@ class _muon_state:
|
|
48 |
worker_rank: int | None = None
|
49 |
gathered_grad: torch.Tensor | None = None
|
50 |
computed_u: torch.Tensor | None = None
|
51 |
-
scattered_u: torch.Tensor | None = None
|
52 |
gather_event: torch.cuda.Event | None = None
|
53 |
compute_event: torch.cuda.Event | None = None
|
54 |
|
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
93 |
state.computed_u = u
|
94 |
state.compute_event = torch.cuda.Event()
|
95 |
state.compute_event.record()
|
|
|
|
|
96 |
else:
|
97 |
state.computed_u = None
|
98 |
state.compute_event = None
|
99 |
|
100 |
|
101 |
-
def _scatter(p, state, rank, comm_stream):
|
102 |
u = state.computed_u
|
103 |
mesh = p.device_mesh
|
104 |
|
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
|
|
118 |
src=state.worker_rank,
|
119 |
group=mesh.get_group(),
|
120 |
)
|
|
|
|
|
|
|
121 |
u = DTensor.from_local(
|
122 |
u,
|
123 |
placements=p.placements,
|
124 |
device_mesh=mesh,
|
125 |
)
|
126 |
-
|
127 |
-
|
128 |
|
129 |
|
130 |
class Muon(torch.optim.Optimizer):
|
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
|
|
353 |
def enqueue_scatters(start_idx, chunk_size):
|
354 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
355 |
state = param_to_state[id(p)]
|
356 |
-
|
|
|
357 |
|
358 |
chunk_size = params[0].device_mesh.mesh.numel()
|
359 |
|
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
|
|
368 |
|
369 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
370 |
|
371 |
-
for p in params:
|
372 |
-
g = p.grad
|
373 |
-
if g is None:
|
374 |
-
continue
|
375 |
-
|
376 |
-
# Update p with sharded u
|
377 |
-
state = param_to_state[id(p)]
|
378 |
-
self._update_p(
|
379 |
-
p,
|
380 |
-
state.scattered_u,
|
381 |
-
lr=lr,
|
382 |
-
wd=wd,
|
383 |
-
)
|
384 |
-
|
385 |
def step(self, closure=None):
|
386 |
"""Perform a single optimization step.
|
387 |
|
|
|
48 |
worker_rank: int | None = None
|
49 |
gathered_grad: torch.Tensor | None = None
|
50 |
computed_u: torch.Tensor | None = None
|
|
|
51 |
gather_event: torch.cuda.Event | None = None
|
52 |
compute_event: torch.cuda.Event | None = None
|
53 |
|
|
|
92 |
state.computed_u = u
|
93 |
state.compute_event = torch.cuda.Event()
|
94 |
state.compute_event.record()
|
95 |
+
state.gathered_grad.record_stream(compute_stream)
|
96 |
+
del state.gathered_grad
|
97 |
else:
|
98 |
state.computed_u = None
|
99 |
state.compute_event = None
|
100 |
|
101 |
|
102 |
+
def _scatter(p, state, lr, wd, rank, comm_stream):
|
103 |
u = state.computed_u
|
104 |
mesh = p.device_mesh
|
105 |
|
|
|
119 |
src=state.worker_rank,
|
120 |
group=mesh.get_group(),
|
121 |
)
|
122 |
+
if rank == state.worker_rank:
|
123 |
+
state.computed_u.record_stream(comm_stream)
|
124 |
+
del state.computed_u
|
125 |
u = DTensor.from_local(
|
126 |
u,
|
127 |
placements=p.placements,
|
128 |
device_mesh=mesh,
|
129 |
)
|
130 |
+
p.data.mul_(1 - lr * wd)
|
131 |
+
p.data.add_(u, alpha=-lr)
|
132 |
|
133 |
|
134 |
class Muon(torch.optim.Optimizer):
|
|
|
357 |
def enqueue_scatters(start_idx, chunk_size):
|
358 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
359 |
state = param_to_state[id(p)]
|
360 |
+
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
361 |
+
_scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
|
362 |
|
363 |
chunk_size = params[0].device_mesh.mesh.numel()
|
364 |
|
|
|
373 |
|
374 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
def step(self, closure=None):
|
377 |
"""Perform a single optimization step.
|
378 |
|
build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import torch
|
2 |
-
from . import
|
3 |
-
ops = torch.ops.
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
-
return f"
|
|
|
1 |
import torch
|
2 |
+
from . import _optimizer_8535e80_dirty
|
3 |
+
ops = torch.ops._optimizer_8535e80_dirty
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
+
return f"_optimizer_8535e80_dirty::{op_name}"
|
build/torch26-cxx11-cu126-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2c0843f38cee494b7a5939eb62d27039d76dc3f69401d411efbacaa25cb0d67a
|
3 |
+
size 1824224
|
build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED
@@ -48,7 +48,6 @@ class _muon_state:
|
|
48 |
worker_rank: int | None = None
|
49 |
gathered_grad: torch.Tensor | None = None
|
50 |
computed_u: torch.Tensor | None = None
|
51 |
-
scattered_u: torch.Tensor | None = None
|
52 |
gather_event: torch.cuda.Event | None = None
|
53 |
compute_event: torch.cuda.Event | None = None
|
54 |
|
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
93 |
state.computed_u = u
|
94 |
state.compute_event = torch.cuda.Event()
|
95 |
state.compute_event.record()
|
|
|
|
|
96 |
else:
|
97 |
state.computed_u = None
|
98 |
state.compute_event = None
|
99 |
|
100 |
|
101 |
-
def _scatter(p, state, rank, comm_stream):
|
102 |
u = state.computed_u
|
103 |
mesh = p.device_mesh
|
104 |
|
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
|
|
118 |
src=state.worker_rank,
|
119 |
group=mesh.get_group(),
|
120 |
)
|
|
|
|
|
|
|
121 |
u = DTensor.from_local(
|
122 |
u,
|
123 |
placements=p.placements,
|
124 |
device_mesh=mesh,
|
125 |
)
|
126 |
-
|
127 |
-
|
128 |
|
129 |
|
130 |
class Muon(torch.optim.Optimizer):
|
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
|
|
353 |
def enqueue_scatters(start_idx, chunk_size):
|
354 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
355 |
state = param_to_state[id(p)]
|
356 |
-
|
|
|
357 |
|
358 |
chunk_size = params[0].device_mesh.mesh.numel()
|
359 |
|
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
|
|
368 |
|
369 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
370 |
|
371 |
-
for p in params:
|
372 |
-
g = p.grad
|
373 |
-
if g is None:
|
374 |
-
continue
|
375 |
-
|
376 |
-
# Update p with sharded u
|
377 |
-
state = param_to_state[id(p)]
|
378 |
-
self._update_p(
|
379 |
-
p,
|
380 |
-
state.scattered_u,
|
381 |
-
lr=lr,
|
382 |
-
wd=wd,
|
383 |
-
)
|
384 |
-
|
385 |
def step(self, closure=None):
|
386 |
"""Perform a single optimization step.
|
387 |
|
|
|
48 |
worker_rank: int | None = None
|
49 |
gathered_grad: torch.Tensor | None = None
|
50 |
computed_u: torch.Tensor | None = None
|
|
|
51 |
gather_event: torch.cuda.Event | None = None
|
52 |
compute_event: torch.cuda.Event | None = None
|
53 |
|
|
|
92 |
state.computed_u = u
|
93 |
state.compute_event = torch.cuda.Event()
|
94 |
state.compute_event.record()
|
95 |
+
state.gathered_grad.record_stream(compute_stream)
|
96 |
+
del state.gathered_grad
|
97 |
else:
|
98 |
state.computed_u = None
|
99 |
state.compute_event = None
|
100 |
|
101 |
|
102 |
+
def _scatter(p, state, lr, wd, rank, comm_stream):
|
103 |
u = state.computed_u
|
104 |
mesh = p.device_mesh
|
105 |
|
|
|
119 |
src=state.worker_rank,
|
120 |
group=mesh.get_group(),
|
121 |
)
|
122 |
+
if rank == state.worker_rank:
|
123 |
+
state.computed_u.record_stream(comm_stream)
|
124 |
+
del state.computed_u
|
125 |
u = DTensor.from_local(
|
126 |
u,
|
127 |
placements=p.placements,
|
128 |
device_mesh=mesh,
|
129 |
)
|
130 |
+
p.data.mul_(1 - lr * wd)
|
131 |
+
p.data.add_(u, alpha=-lr)
|
132 |
|
133 |
|
134 |
class Muon(torch.optim.Optimizer):
|
|
|
357 |
def enqueue_scatters(start_idx, chunk_size):
|
358 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
359 |
state = param_to_state[id(p)]
|
360 |
+
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
361 |
+
_scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
|
362 |
|
363 |
chunk_size = params[0].device_mesh.mesh.numel()
|
364 |
|
|
|
373 |
|
374 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
def step(self, closure=None):
|
377 |
"""Perform a single optimization step.
|
378 |
|
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import torch
|
2 |
-
from . import
|
3 |
-
ops = torch.ops.
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
-
return f"
|
|
|
1 |
import torch
|
2 |
+
from . import _optimizer_8535e80_dirty
|
3 |
+
ops = torch.ops._optimizer_8535e80_dirty
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
+
return f"_optimizer_8535e80_dirty::{op_name}"
|
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:acdba99ce95532a9ca6a8987a7ab61a257657872f2cc672c91e8e5fe809aa24e
|
3 |
+
size 1749744
|
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py
CHANGED
@@ -48,7 +48,6 @@ class _muon_state:
|
|
48 |
worker_rank: int | None = None
|
49 |
gathered_grad: torch.Tensor | None = None
|
50 |
computed_u: torch.Tensor | None = None
|
51 |
-
scattered_u: torch.Tensor | None = None
|
52 |
gather_event: torch.cuda.Event | None = None
|
53 |
compute_event: torch.cuda.Event | None = None
|
54 |
|
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
93 |
state.computed_u = u
|
94 |
state.compute_event = torch.cuda.Event()
|
95 |
state.compute_event.record()
|
|
|
|
|
96 |
else:
|
97 |
state.computed_u = None
|
98 |
state.compute_event = None
|
99 |
|
100 |
|
101 |
-
def _scatter(p, state, rank, comm_stream):
|
102 |
u = state.computed_u
|
103 |
mesh = p.device_mesh
|
104 |
|
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
|
|
118 |
src=state.worker_rank,
|
119 |
group=mesh.get_group(),
|
120 |
)
|
|
|
|
|
|
|
121 |
u = DTensor.from_local(
|
122 |
u,
|
123 |
placements=p.placements,
|
124 |
device_mesh=mesh,
|
125 |
)
|
126 |
-
|
127 |
-
|
128 |
|
129 |
|
130 |
class Muon(torch.optim.Optimizer):
|
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
|
|
353 |
def enqueue_scatters(start_idx, chunk_size):
|
354 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
355 |
state = param_to_state[id(p)]
|
356 |
-
|
|
|
357 |
|
358 |
chunk_size = params[0].device_mesh.mesh.numel()
|
359 |
|
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
|
|
368 |
|
369 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
370 |
|
371 |
-
for p in params:
|
372 |
-
g = p.grad
|
373 |
-
if g is None:
|
374 |
-
continue
|
375 |
-
|
376 |
-
# Update p with sharded u
|
377 |
-
state = param_to_state[id(p)]
|
378 |
-
self._update_p(
|
379 |
-
p,
|
380 |
-
state.scattered_u,
|
381 |
-
lr=lr,
|
382 |
-
wd=wd,
|
383 |
-
)
|
384 |
-
|
385 |
def step(self, closure=None):
|
386 |
"""Perform a single optimization step.
|
387 |
|
|
|
48 |
worker_rank: int | None = None
|
49 |
gathered_grad: torch.Tensor | None = None
|
50 |
computed_u: torch.Tensor | None = None
|
|
|
51 |
gather_event: torch.cuda.Event | None = None
|
52 |
compute_event: torch.cuda.Event | None = None
|
53 |
|
|
|
92 |
state.computed_u = u
|
93 |
state.compute_event = torch.cuda.Event()
|
94 |
state.compute_event.record()
|
95 |
+
state.gathered_grad.record_stream(compute_stream)
|
96 |
+
del state.gathered_grad
|
97 |
else:
|
98 |
state.computed_u = None
|
99 |
state.compute_event = None
|
100 |
|
101 |
|
102 |
+
def _scatter(p, state, lr, wd, rank, comm_stream):
|
103 |
u = state.computed_u
|
104 |
mesh = p.device_mesh
|
105 |
|
|
|
119 |
src=state.worker_rank,
|
120 |
group=mesh.get_group(),
|
121 |
)
|
122 |
+
if rank == state.worker_rank:
|
123 |
+
state.computed_u.record_stream(comm_stream)
|
124 |
+
del state.computed_u
|
125 |
u = DTensor.from_local(
|
126 |
u,
|
127 |
placements=p.placements,
|
128 |
device_mesh=mesh,
|
129 |
)
|
130 |
+
p.data.mul_(1 - lr * wd)
|
131 |
+
p.data.add_(u, alpha=-lr)
|
132 |
|
133 |
|
134 |
class Muon(torch.optim.Optimizer):
|
|
|
357 |
def enqueue_scatters(start_idx, chunk_size):
|
358 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
359 |
state = param_to_state[id(p)]
|
360 |
+
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
361 |
+
_scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
|
362 |
|
363 |
chunk_size = params[0].device_mesh.mesh.numel()
|
364 |
|
|
|
373 |
|
374 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
def step(self, closure=None):
|
377 |
"""Perform a single optimization step.
|
378 |
|
build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import torch
|
2 |
-
from . import
|
3 |
-
ops = torch.ops.
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
-
return f"
|
|
|
1 |
import torch
|
2 |
+
from . import _optimizer_8535e80_dirty
|
3 |
+
ops = torch.ops._optimizer_8535e80_dirty
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
+
return f"_optimizer_8535e80_dirty::{op_name}"
|
build/torch26-cxx98-cu118-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f7d5e76c002507f66f2a227d02c2b11aa3fdc3f07a2a0b82faaa34133adb77ef
|
3 |
+
size 1787192
|
build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py
CHANGED
@@ -48,7 +48,6 @@ class _muon_state:
|
|
48 |
worker_rank: int | None = None
|
49 |
gathered_grad: torch.Tensor | None = None
|
50 |
computed_u: torch.Tensor | None = None
|
51 |
-
scattered_u: torch.Tensor | None = None
|
52 |
gather_event: torch.cuda.Event | None = None
|
53 |
compute_event: torch.cuda.Event | None = None
|
54 |
|
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
93 |
state.computed_u = u
|
94 |
state.compute_event = torch.cuda.Event()
|
95 |
state.compute_event.record()
|
|
|
|
|
96 |
else:
|
97 |
state.computed_u = None
|
98 |
state.compute_event = None
|
99 |
|
100 |
|
101 |
-
def _scatter(p, state, rank, comm_stream):
|
102 |
u = state.computed_u
|
103 |
mesh = p.device_mesh
|
104 |
|
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
|
|
118 |
src=state.worker_rank,
|
119 |
group=mesh.get_group(),
|
120 |
)
|
|
|
|
|
|
|
121 |
u = DTensor.from_local(
|
122 |
u,
|
123 |
placements=p.placements,
|
124 |
device_mesh=mesh,
|
125 |
)
|
126 |
-
|
127 |
-
|
128 |
|
129 |
|
130 |
class Muon(torch.optim.Optimizer):
|
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
|
|
353 |
def enqueue_scatters(start_idx, chunk_size):
|
354 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
355 |
state = param_to_state[id(p)]
|
356 |
-
|
|
|
357 |
|
358 |
chunk_size = params[0].device_mesh.mesh.numel()
|
359 |
|
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
|
|
368 |
|
369 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
370 |
|
371 |
-
for p in params:
|
372 |
-
g = p.grad
|
373 |
-
if g is None:
|
374 |
-
continue
|
375 |
-
|
376 |
-
# Update p with sharded u
|
377 |
-
state = param_to_state[id(p)]
|
378 |
-
self._update_p(
|
379 |
-
p,
|
380 |
-
state.scattered_u,
|
381 |
-
lr=lr,
|
382 |
-
wd=wd,
|
383 |
-
)
|
384 |
-
|
385 |
def step(self, closure=None):
|
386 |
"""Perform a single optimization step.
|
387 |
|
|
|
48 |
worker_rank: int | None = None
|
49 |
gathered_grad: torch.Tensor | None = None
|
50 |
computed_u: torch.Tensor | None = None
|
|
|
51 |
gather_event: torch.cuda.Event | None = None
|
52 |
compute_event: torch.cuda.Event | None = None
|
53 |
|
|
|
92 |
state.computed_u = u
|
93 |
state.compute_event = torch.cuda.Event()
|
94 |
state.compute_event.record()
|
95 |
+
state.gathered_grad.record_stream(compute_stream)
|
96 |
+
del state.gathered_grad
|
97 |
else:
|
98 |
state.computed_u = None
|
99 |
state.compute_event = None
|
100 |
|
101 |
|
102 |
+
def _scatter(p, state, lr, wd, rank, comm_stream):
|
103 |
u = state.computed_u
|
104 |
mesh = p.device_mesh
|
105 |
|
|
|
119 |
src=state.worker_rank,
|
120 |
group=mesh.get_group(),
|
121 |
)
|
122 |
+
if rank == state.worker_rank:
|
123 |
+
state.computed_u.record_stream(comm_stream)
|
124 |
+
del state.computed_u
|
125 |
u = DTensor.from_local(
|
126 |
u,
|
127 |
placements=p.placements,
|
128 |
device_mesh=mesh,
|
129 |
)
|
130 |
+
p.data.mul_(1 - lr * wd)
|
131 |
+
p.data.add_(u, alpha=-lr)
|
132 |
|
133 |
|
134 |
class Muon(torch.optim.Optimizer):
|
|
|
357 |
def enqueue_scatters(start_idx, chunk_size):
|
358 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
359 |
state = param_to_state[id(p)]
|
360 |
+
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
361 |
+
_scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
|
362 |
|
363 |
chunk_size = params[0].device_mesh.mesh.numel()
|
364 |
|
|
|
373 |
|
374 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
def step(self, closure=None):
|
377 |
"""Perform a single optimization step.
|
378 |
|
build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import torch
|
2 |
-
from . import
|
3 |
-
ops = torch.ops.
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
-
return f"
|
|
|
1 |
import torch
|
2 |
+
from . import _optimizer_8535e80_dirty
|
3 |
+
ops = torch.ops._optimizer_8535e80_dirty
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
+
return f"_optimizer_8535e80_dirty::{op_name}"
|
build/torch26-cxx98-cu124-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:becccd250f38a84803350cfb5fac3a6682b1e594968a714642724cbc71246b4a
|
3 |
+
size 1824184
|
build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py
CHANGED
@@ -48,7 +48,6 @@ class _muon_state:
|
|
48 |
worker_rank: int | None = None
|
49 |
gathered_grad: torch.Tensor | None = None
|
50 |
computed_u: torch.Tensor | None = None
|
51 |
-
scattered_u: torch.Tensor | None = None
|
52 |
gather_event: torch.cuda.Event | None = None
|
53 |
compute_event: torch.cuda.Event | None = None
|
54 |
|
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
93 |
state.computed_u = u
|
94 |
state.compute_event = torch.cuda.Event()
|
95 |
state.compute_event.record()
|
|
|
|
|
96 |
else:
|
97 |
state.computed_u = None
|
98 |
state.compute_event = None
|
99 |
|
100 |
|
101 |
-
def _scatter(p, state, rank, comm_stream):
|
102 |
u = state.computed_u
|
103 |
mesh = p.device_mesh
|
104 |
|
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
|
|
118 |
src=state.worker_rank,
|
119 |
group=mesh.get_group(),
|
120 |
)
|
|
|
|
|
|
|
121 |
u = DTensor.from_local(
|
122 |
u,
|
123 |
placements=p.placements,
|
124 |
device_mesh=mesh,
|
125 |
)
|
126 |
-
|
127 |
-
|
128 |
|
129 |
|
130 |
class Muon(torch.optim.Optimizer):
|
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
|
|
353 |
def enqueue_scatters(start_idx, chunk_size):
|
354 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
355 |
state = param_to_state[id(p)]
|
356 |
-
|
|
|
357 |
|
358 |
chunk_size = params[0].device_mesh.mesh.numel()
|
359 |
|
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
|
|
368 |
|
369 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
370 |
|
371 |
-
for p in params:
|
372 |
-
g = p.grad
|
373 |
-
if g is None:
|
374 |
-
continue
|
375 |
-
|
376 |
-
# Update p with sharded u
|
377 |
-
state = param_to_state[id(p)]
|
378 |
-
self._update_p(
|
379 |
-
p,
|
380 |
-
state.scattered_u,
|
381 |
-
lr=lr,
|
382 |
-
wd=wd,
|
383 |
-
)
|
384 |
-
|
385 |
def step(self, closure=None):
|
386 |
"""Perform a single optimization step.
|
387 |
|
|
|
48 |
worker_rank: int | None = None
|
49 |
gathered_grad: torch.Tensor | None = None
|
50 |
computed_u: torch.Tensor | None = None
|
|
|
51 |
gather_event: torch.cuda.Event | None = None
|
52 |
compute_event: torch.cuda.Event | None = None
|
53 |
|
|
|
92 |
state.computed_u = u
|
93 |
state.compute_event = torch.cuda.Event()
|
94 |
state.compute_event.record()
|
95 |
+
state.gathered_grad.record_stream(compute_stream)
|
96 |
+
del state.gathered_grad
|
97 |
else:
|
98 |
state.computed_u = None
|
99 |
state.compute_event = None
|
100 |
|
101 |
|
102 |
+
def _scatter(p, state, lr, wd, rank, comm_stream):
|
103 |
u = state.computed_u
|
104 |
mesh = p.device_mesh
|
105 |
|
|
|
119 |
src=state.worker_rank,
|
120 |
group=mesh.get_group(),
|
121 |
)
|
122 |
+
if rank == state.worker_rank:
|
123 |
+
state.computed_u.record_stream(comm_stream)
|
124 |
+
del state.computed_u
|
125 |
u = DTensor.from_local(
|
126 |
u,
|
127 |
placements=p.placements,
|
128 |
device_mesh=mesh,
|
129 |
)
|
130 |
+
p.data.mul_(1 - lr * wd)
|
131 |
+
p.data.add_(u, alpha=-lr)
|
132 |
|
133 |
|
134 |
class Muon(torch.optim.Optimizer):
|
|
|
357 |
def enqueue_scatters(start_idx, chunk_size):
|
358 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
359 |
state = param_to_state[id(p)]
|
360 |
+
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
361 |
+
_scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
|
362 |
|
363 |
chunk_size = params[0].device_mesh.mesh.numel()
|
364 |
|
|
|
373 |
|
374 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
def step(self, closure=None):
|
377 |
"""Perform a single optimization step.
|
378 |
|
build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import torch
|
2 |
-
from . import
|
3 |
-
ops = torch.ops.
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
-
return f"
|
|
|
1 |
import torch
|
2 |
+
from . import _optimizer_8535e80_dirty
|
3 |
+
ops = torch.ops._optimizer_8535e80_dirty
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
+
return f"_optimizer_8535e80_dirty::{op_name}"
|
build/torch26-cxx98-cu126-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:34215ecc274ef516967962c8457dad214e9bbf618bf5eee8f467371f4f620284
|
3 |
+
size 1824184
|
build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py
CHANGED
@@ -48,7 +48,6 @@ class _muon_state:
|
|
48 |
worker_rank: int | None = None
|
49 |
gathered_grad: torch.Tensor | None = None
|
50 |
computed_u: torch.Tensor | None = None
|
51 |
-
scattered_u: torch.Tensor | None = None
|
52 |
gather_event: torch.cuda.Event | None = None
|
53 |
compute_event: torch.cuda.Event | None = None
|
54 |
|
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
93 |
state.computed_u = u
|
94 |
state.compute_event = torch.cuda.Event()
|
95 |
state.compute_event.record()
|
|
|
|
|
96 |
else:
|
97 |
state.computed_u = None
|
98 |
state.compute_event = None
|
99 |
|
100 |
|
101 |
-
def _scatter(p, state, rank, comm_stream):
|
102 |
u = state.computed_u
|
103 |
mesh = p.device_mesh
|
104 |
|
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
|
|
118 |
src=state.worker_rank,
|
119 |
group=mesh.get_group(),
|
120 |
)
|
|
|
|
|
|
|
121 |
u = DTensor.from_local(
|
122 |
u,
|
123 |
placements=p.placements,
|
124 |
device_mesh=mesh,
|
125 |
)
|
126 |
-
|
127 |
-
|
128 |
|
129 |
|
130 |
class Muon(torch.optim.Optimizer):
|
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
|
|
353 |
def enqueue_scatters(start_idx, chunk_size):
|
354 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
355 |
state = param_to_state[id(p)]
|
356 |
-
|
|
|
357 |
|
358 |
chunk_size = params[0].device_mesh.mesh.numel()
|
359 |
|
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
|
|
368 |
|
369 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
370 |
|
371 |
-
for p in params:
|
372 |
-
g = p.grad
|
373 |
-
if g is None:
|
374 |
-
continue
|
375 |
-
|
376 |
-
# Update p with sharded u
|
377 |
-
state = param_to_state[id(p)]
|
378 |
-
self._update_p(
|
379 |
-
p,
|
380 |
-
state.scattered_u,
|
381 |
-
lr=lr,
|
382 |
-
wd=wd,
|
383 |
-
)
|
384 |
-
|
385 |
def step(self, closure=None):
|
386 |
"""Perform a single optimization step.
|
387 |
|
|
|
48 |
worker_rank: int | None = None
|
49 |
gathered_grad: torch.Tensor | None = None
|
50 |
computed_u: torch.Tensor | None = None
|
|
|
51 |
gather_event: torch.cuda.Event | None = None
|
52 |
compute_event: torch.cuda.Event | None = None
|
53 |
|
|
|
92 |
state.computed_u = u
|
93 |
state.compute_event = torch.cuda.Event()
|
94 |
state.compute_event.record()
|
95 |
+
state.gathered_grad.record_stream(compute_stream)
|
96 |
+
del state.gathered_grad
|
97 |
else:
|
98 |
state.computed_u = None
|
99 |
state.compute_event = None
|
100 |
|
101 |
|
102 |
+
def _scatter(p, state, lr, wd, rank, comm_stream):
|
103 |
u = state.computed_u
|
104 |
mesh = p.device_mesh
|
105 |
|
|
|
119 |
src=state.worker_rank,
|
120 |
group=mesh.get_group(),
|
121 |
)
|
122 |
+
if rank == state.worker_rank:
|
123 |
+
state.computed_u.record_stream(comm_stream)
|
124 |
+
del state.computed_u
|
125 |
u = DTensor.from_local(
|
126 |
u,
|
127 |
placements=p.placements,
|
128 |
device_mesh=mesh,
|
129 |
)
|
130 |
+
p.data.mul_(1 - lr * wd)
|
131 |
+
p.data.add_(u, alpha=-lr)
|
132 |
|
133 |
|
134 |
class Muon(torch.optim.Optimizer):
|
|
|
357 |
def enqueue_scatters(start_idx, chunk_size):
|
358 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
359 |
state = param_to_state[id(p)]
|
360 |
+
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
361 |
+
_scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
|
362 |
|
363 |
chunk_size = params[0].device_mesh.mesh.numel()
|
364 |
|
|
|
373 |
|
374 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
def step(self, closure=None):
|
377 |
"""Perform a single optimization step.
|
378 |
|
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import torch
|
2 |
-
from . import
|
3 |
-
ops = torch.ops.
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
-
return f"
|
|
|
1 |
import torch
|
2 |
+
from . import _optimizer_8535e80_dirty
|
3 |
+
ops = torch.ops._optimizer_8535e80_dirty
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
+
return f"_optimizer_8535e80_dirty::{op_name}"
|
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c23a3adbe4dc1a64b4851a9f8e4aed0e3e1eeeded27322c54f5b942282a2a332
|
3 |
+
size 1787368
|
build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py
CHANGED
@@ -48,7 +48,6 @@ class _muon_state:
|
|
48 |
worker_rank: int | None = None
|
49 |
gathered_grad: torch.Tensor | None = None
|
50 |
computed_u: torch.Tensor | None = None
|
51 |
-
scattered_u: torch.Tensor | None = None
|
52 |
gather_event: torch.cuda.Event | None = None
|
53 |
compute_event: torch.cuda.Event | None = None
|
54 |
|
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
93 |
state.computed_u = u
|
94 |
state.compute_event = torch.cuda.Event()
|
95 |
state.compute_event.record()
|
|
|
|
|
96 |
else:
|
97 |
state.computed_u = None
|
98 |
state.compute_event = None
|
99 |
|
100 |
|
101 |
-
def _scatter(p, state, rank, comm_stream):
|
102 |
u = state.computed_u
|
103 |
mesh = p.device_mesh
|
104 |
|
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
|
|
118 |
src=state.worker_rank,
|
119 |
group=mesh.get_group(),
|
120 |
)
|
|
|
|
|
|
|
121 |
u = DTensor.from_local(
|
122 |
u,
|
123 |
placements=p.placements,
|
124 |
device_mesh=mesh,
|
125 |
)
|
126 |
-
|
127 |
-
|
128 |
|
129 |
|
130 |
class Muon(torch.optim.Optimizer):
|
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
|
|
353 |
def enqueue_scatters(start_idx, chunk_size):
|
354 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
355 |
state = param_to_state[id(p)]
|
356 |
-
|
|
|
357 |
|
358 |
chunk_size = params[0].device_mesh.mesh.numel()
|
359 |
|
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
|
|
368 |
|
369 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
370 |
|
371 |
-
for p in params:
|
372 |
-
g = p.grad
|
373 |
-
if g is None:
|
374 |
-
continue
|
375 |
-
|
376 |
-
# Update p with sharded u
|
377 |
-
state = param_to_state[id(p)]
|
378 |
-
self._update_p(
|
379 |
-
p,
|
380 |
-
state.scattered_u,
|
381 |
-
lr=lr,
|
382 |
-
wd=wd,
|
383 |
-
)
|
384 |
-
|
385 |
def step(self, closure=None):
|
386 |
"""Perform a single optimization step.
|
387 |
|
|
|
48 |
worker_rank: int | None = None
|
49 |
gathered_grad: torch.Tensor | None = None
|
50 |
computed_u: torch.Tensor | None = None
|
|
|
51 |
gather_event: torch.cuda.Event | None = None
|
52 |
compute_event: torch.cuda.Event | None = None
|
53 |
|
|
|
92 |
state.computed_u = u
|
93 |
state.compute_event = torch.cuda.Event()
|
94 |
state.compute_event.record()
|
95 |
+
state.gathered_grad.record_stream(compute_stream)
|
96 |
+
del state.gathered_grad
|
97 |
else:
|
98 |
state.computed_u = None
|
99 |
state.compute_event = None
|
100 |
|
101 |
|
102 |
+
def _scatter(p, state, lr, wd, rank, comm_stream):
|
103 |
u = state.computed_u
|
104 |
mesh = p.device_mesh
|
105 |
|
|
|
119 |
src=state.worker_rank,
|
120 |
group=mesh.get_group(),
|
121 |
)
|
122 |
+
if rank == state.worker_rank:
|
123 |
+
state.computed_u.record_stream(comm_stream)
|
124 |
+
del state.computed_u
|
125 |
u = DTensor.from_local(
|
126 |
u,
|
127 |
placements=p.placements,
|
128 |
device_mesh=mesh,
|
129 |
)
|
130 |
+
p.data.mul_(1 - lr * wd)
|
131 |
+
p.data.add_(u, alpha=-lr)
|
132 |
|
133 |
|
134 |
class Muon(torch.optim.Optimizer):
|
|
|
357 |
def enqueue_scatters(start_idx, chunk_size):
|
358 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
359 |
state = param_to_state[id(p)]
|
360 |
+
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
361 |
+
_scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
|
362 |
|
363 |
chunk_size = params[0].device_mesh.mesh.numel()
|
364 |
|
|
|
373 |
|
374 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
def step(self, closure=None):
|
377 |
"""Perform a single optimization step.
|
378 |
|
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import torch
|
2 |
-
from . import
|
3 |
-
ops = torch.ops.
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
-
return f"
|
|
|
1 |
import torch
|
2 |
+
from . import _optimizer_8535e80_dirty
|
3 |
+
ops = torch.ops._optimizer_8535e80_dirty
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
+
return f"_optimizer_8535e80_dirty::{op_name}"
|
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d4aa09c22745d5efe1ef0669c4ca05615f67595dc90cabeee6e878301fa9bd22
|
3 |
+
size 1824256
|
build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED
@@ -48,7 +48,6 @@ class _muon_state:
|
|
48 |
worker_rank: int | None = None
|
49 |
gathered_grad: torch.Tensor | None = None
|
50 |
computed_u: torch.Tensor | None = None
|
51 |
-
scattered_u: torch.Tensor | None = None
|
52 |
gather_event: torch.cuda.Event | None = None
|
53 |
compute_event: torch.cuda.Event | None = None
|
54 |
|
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
93 |
state.computed_u = u
|
94 |
state.compute_event = torch.cuda.Event()
|
95 |
state.compute_event.record()
|
|
|
|
|
96 |
else:
|
97 |
state.computed_u = None
|
98 |
state.compute_event = None
|
99 |
|
100 |
|
101 |
-
def _scatter(p, state, rank, comm_stream):
|
102 |
u = state.computed_u
|
103 |
mesh = p.device_mesh
|
104 |
|
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
|
|
118 |
src=state.worker_rank,
|
119 |
group=mesh.get_group(),
|
120 |
)
|
|
|
|
|
|
|
121 |
u = DTensor.from_local(
|
122 |
u,
|
123 |
placements=p.placements,
|
124 |
device_mesh=mesh,
|
125 |
)
|
126 |
-
|
127 |
-
|
128 |
|
129 |
|
130 |
class Muon(torch.optim.Optimizer):
|
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
|
|
353 |
def enqueue_scatters(start_idx, chunk_size):
|
354 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
355 |
state = param_to_state[id(p)]
|
356 |
-
|
|
|
357 |
|
358 |
chunk_size = params[0].device_mesh.mesh.numel()
|
359 |
|
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
|
|
368 |
|
369 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
370 |
|
371 |
-
for p in params:
|
372 |
-
g = p.grad
|
373 |
-
if g is None:
|
374 |
-
continue
|
375 |
-
|
376 |
-
# Update p with sharded u
|
377 |
-
state = param_to_state[id(p)]
|
378 |
-
self._update_p(
|
379 |
-
p,
|
380 |
-
state.scattered_u,
|
381 |
-
lr=lr,
|
382 |
-
wd=wd,
|
383 |
-
)
|
384 |
-
|
385 |
def step(self, closure=None):
|
386 |
"""Perform a single optimization step.
|
387 |
|
|
|
48 |
worker_rank: int | None = None
|
49 |
gathered_grad: torch.Tensor | None = None
|
50 |
computed_u: torch.Tensor | None = None
|
|
|
51 |
gather_event: torch.cuda.Event | None = None
|
52 |
compute_event: torch.cuda.Event | None = None
|
53 |
|
|
|
92 |
state.computed_u = u
|
93 |
state.compute_event = torch.cuda.Event()
|
94 |
state.compute_event.record()
|
95 |
+
state.gathered_grad.record_stream(compute_stream)
|
96 |
+
del state.gathered_grad
|
97 |
else:
|
98 |
state.computed_u = None
|
99 |
state.compute_event = None
|
100 |
|
101 |
|
102 |
+
def _scatter(p, state, lr, wd, rank, comm_stream):
|
103 |
u = state.computed_u
|
104 |
mesh = p.device_mesh
|
105 |
|
|
|
119 |
src=state.worker_rank,
|
120 |
group=mesh.get_group(),
|
121 |
)
|
122 |
+
if rank == state.worker_rank:
|
123 |
+
state.computed_u.record_stream(comm_stream)
|
124 |
+
del state.computed_u
|
125 |
u = DTensor.from_local(
|
126 |
u,
|
127 |
placements=p.placements,
|
128 |
device_mesh=mesh,
|
129 |
)
|
130 |
+
p.data.mul_(1 - lr * wd)
|
131 |
+
p.data.add_(u, alpha=-lr)
|
132 |
|
133 |
|
134 |
class Muon(torch.optim.Optimizer):
|
|
|
357 |
def enqueue_scatters(start_idx, chunk_size):
|
358 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
359 |
state = param_to_state[id(p)]
|
360 |
+
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
361 |
+
_scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
|
362 |
|
363 |
chunk_size = params[0].device_mesh.mesh.numel()
|
364 |
|
|
|
373 |
|
374 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
def step(self, closure=None):
|
377 |
"""Perform a single optimization step.
|
378 |
|
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import torch
|
2 |
-
from . import
|
3 |
-
ops = torch.ops.
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
-
return f"
|
|
|
1 |
import torch
|
2 |
+
from . import _optimizer_8535e80_dirty
|
3 |
+
ops = torch.ops._optimizer_8535e80_dirty
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
+
return f"_optimizer_8535e80_dirty::{op_name}"
|
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b4baf569b70749c4657062fb0f56943fc486adb0c482e50c7aa8e31ddf5cc870
|
3 |
+
size 1883352
|
build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py
CHANGED
@@ -48,7 +48,6 @@ class _muon_state:
|
|
48 |
worker_rank: int | None = None
|
49 |
gathered_grad: torch.Tensor | None = None
|
50 |
computed_u: torch.Tensor | None = None
|
51 |
-
scattered_u: torch.Tensor | None = None
|
52 |
gather_event: torch.cuda.Event | None = None
|
53 |
compute_event: torch.cuda.Event | None = None
|
54 |
|
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
93 |
state.computed_u = u
|
94 |
state.compute_event = torch.cuda.Event()
|
95 |
state.compute_event.record()
|
|
|
|
|
96 |
else:
|
97 |
state.computed_u = None
|
98 |
state.compute_event = None
|
99 |
|
100 |
|
101 |
-
def _scatter(p, state, rank, comm_stream):
|
102 |
u = state.computed_u
|
103 |
mesh = p.device_mesh
|
104 |
|
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
|
|
118 |
src=state.worker_rank,
|
119 |
group=mesh.get_group(),
|
120 |
)
|
|
|
|
|
|
|
121 |
u = DTensor.from_local(
|
122 |
u,
|
123 |
placements=p.placements,
|
124 |
device_mesh=mesh,
|
125 |
)
|
126 |
-
|
127 |
-
|
128 |
|
129 |
|
130 |
class Muon(torch.optim.Optimizer):
|
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
|
|
353 |
def enqueue_scatters(start_idx, chunk_size):
|
354 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
355 |
state = param_to_state[id(p)]
|
356 |
-
|
|
|
357 |
|
358 |
chunk_size = params[0].device_mesh.mesh.numel()
|
359 |
|
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
|
|
368 |
|
369 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
370 |
|
371 |
-
for p in params:
|
372 |
-
g = p.grad
|
373 |
-
if g is None:
|
374 |
-
continue
|
375 |
-
|
376 |
-
# Update p with sharded u
|
377 |
-
state = param_to_state[id(p)]
|
378 |
-
self._update_p(
|
379 |
-
p,
|
380 |
-
state.scattered_u,
|
381 |
-
lr=lr,
|
382 |
-
wd=wd,
|
383 |
-
)
|
384 |
-
|
385 |
def step(self, closure=None):
|
386 |
"""Perform a single optimization step.
|
387 |
|
|
|
48 |
worker_rank: int | None = None
|
49 |
gathered_grad: torch.Tensor | None = None
|
50 |
computed_u: torch.Tensor | None = None
|
|
|
51 |
gather_event: torch.cuda.Event | None = None
|
52 |
compute_event: torch.cuda.Event | None = None
|
53 |
|
|
|
92 |
state.computed_u = u
|
93 |
state.compute_event = torch.cuda.Event()
|
94 |
state.compute_event.record()
|
95 |
+
state.gathered_grad.record_stream(compute_stream)
|
96 |
+
del state.gathered_grad
|
97 |
else:
|
98 |
state.computed_u = None
|
99 |
state.compute_event = None
|
100 |
|
101 |
|
102 |
+
def _scatter(p, state, lr, wd, rank, comm_stream):
|
103 |
u = state.computed_u
|
104 |
mesh = p.device_mesh
|
105 |
|
|
|
119 |
src=state.worker_rank,
|
120 |
group=mesh.get_group(),
|
121 |
)
|
122 |
+
if rank == state.worker_rank:
|
123 |
+
state.computed_u.record_stream(comm_stream)
|
124 |
+
del state.computed_u
|
125 |
u = DTensor.from_local(
|
126 |
u,
|
127 |
placements=p.placements,
|
128 |
device_mesh=mesh,
|
129 |
)
|
130 |
+
p.data.mul_(1 - lr * wd)
|
131 |
+
p.data.add_(u, alpha=-lr)
|
132 |
|
133 |
|
134 |
class Muon(torch.optim.Optimizer):
|
|
|
357 |
def enqueue_scatters(start_idx, chunk_size):
|
358 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
359 |
state = param_to_state[id(p)]
|
360 |
+
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
361 |
+
_scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
|
362 |
|
363 |
chunk_size = params[0].device_mesh.mesh.numel()
|
364 |
|
|
|
373 |
|
374 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
def step(self, closure=None):
|
377 |
"""Perform a single optimization step.
|
378 |
|
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-310.pyc
CHANGED
Binary files a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-310.pyc and b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-310.pyc differ
|
|
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-310.pyc
CHANGED
Binary files a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-310.pyc and b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-310.pyc differ
|
|
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import torch
|
2 |
-
from . import
|
3 |
-
ops = torch.ops.
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
-
return f"
|
|
|
1 |
import torch
|
2 |
+
from . import _optimizer_8535e80_dirty
|
3 |
+
ops = torch.ops._optimizer_8535e80_dirty
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
+
return f"_optimizer_8535e80_dirty::{op_name}"
|
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8566c9bc05e13c9394572f9f9c6bac24c31932548be485f49eb49fb249880832
|
3 |
+
size 1749648
|
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py
CHANGED
@@ -48,7 +48,6 @@ class _muon_state:
|
|
48 |
worker_rank: int | None = None
|
49 |
gathered_grad: torch.Tensor | None = None
|
50 |
computed_u: torch.Tensor | None = None
|
51 |
-
scattered_u: torch.Tensor | None = None
|
52 |
gather_event: torch.cuda.Event | None = None
|
53 |
compute_event: torch.cuda.Event | None = None
|
54 |
|
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
93 |
state.computed_u = u
|
94 |
state.compute_event = torch.cuda.Event()
|
95 |
state.compute_event.record()
|
|
|
|
|
96 |
else:
|
97 |
state.computed_u = None
|
98 |
state.compute_event = None
|
99 |
|
100 |
|
101 |
-
def _scatter(p, state, rank, comm_stream):
|
102 |
u = state.computed_u
|
103 |
mesh = p.device_mesh
|
104 |
|
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
|
|
118 |
src=state.worker_rank,
|
119 |
group=mesh.get_group(),
|
120 |
)
|
|
|
|
|
|
|
121 |
u = DTensor.from_local(
|
122 |
u,
|
123 |
placements=p.placements,
|
124 |
device_mesh=mesh,
|
125 |
)
|
126 |
-
|
127 |
-
|
128 |
|
129 |
|
130 |
class Muon(torch.optim.Optimizer):
|
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
|
|
353 |
def enqueue_scatters(start_idx, chunk_size):
|
354 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
355 |
state = param_to_state[id(p)]
|
356 |
-
|
|
|
357 |
|
358 |
chunk_size = params[0].device_mesh.mesh.numel()
|
359 |
|
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
|
|
368 |
|
369 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
370 |
|
371 |
-
for p in params:
|
372 |
-
g = p.grad
|
373 |
-
if g is None:
|
374 |
-
continue
|
375 |
-
|
376 |
-
# Update p with sharded u
|
377 |
-
state = param_to_state[id(p)]
|
378 |
-
self._update_p(
|
379 |
-
p,
|
380 |
-
state.scattered_u,
|
381 |
-
lr=lr,
|
382 |
-
wd=wd,
|
383 |
-
)
|
384 |
-
|
385 |
def step(self, closure=None):
|
386 |
"""Perform a single optimization step.
|
387 |
|
|
|
48 |
worker_rank: int | None = None
|
49 |
gathered_grad: torch.Tensor | None = None
|
50 |
computed_u: torch.Tensor | None = None
|
|
|
51 |
gather_event: torch.cuda.Event | None = None
|
52 |
compute_event: torch.cuda.Event | None = None
|
53 |
|
|
|
92 |
state.computed_u = u
|
93 |
state.compute_event = torch.cuda.Event()
|
94 |
state.compute_event.record()
|
95 |
+
state.gathered_grad.record_stream(compute_stream)
|
96 |
+
del state.gathered_grad
|
97 |
else:
|
98 |
state.computed_u = None
|
99 |
state.compute_event = None
|
100 |
|
101 |
|
102 |
+
def _scatter(p, state, lr, wd, rank, comm_stream):
|
103 |
u = state.computed_u
|
104 |
mesh = p.device_mesh
|
105 |
|
|
|
119 |
src=state.worker_rank,
|
120 |
group=mesh.get_group(),
|
121 |
)
|
122 |
+
if rank == state.worker_rank:
|
123 |
+
state.computed_u.record_stream(comm_stream)
|
124 |
+
del state.computed_u
|
125 |
u = DTensor.from_local(
|
126 |
u,
|
127 |
placements=p.placements,
|
128 |
device_mesh=mesh,
|
129 |
)
|
130 |
+
p.data.mul_(1 - lr * wd)
|
131 |
+
p.data.add_(u, alpha=-lr)
|
132 |
|
133 |
|
134 |
class Muon(torch.optim.Optimizer):
|
|
|
357 |
def enqueue_scatters(start_idx, chunk_size):
|
358 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
359 |
state = param_to_state[id(p)]
|
360 |
+
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
361 |
+
_scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
|
362 |
|
363 |
chunk_size = params[0].device_mesh.mesh.numel()
|
364 |
|
|
|
373 |
|
374 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
def step(self, closure=None):
|
377 |
"""Perform a single optimization step.
|
378 |
|
torch-ext/optimizer/muon.py
CHANGED
@@ -48,7 +48,6 @@ class _muon_state:
|
|
48 |
worker_rank: int | None = None
|
49 |
gathered_grad: torch.Tensor | None = None
|
50 |
computed_u: torch.Tensor | None = None
|
51 |
-
scattered_u: torch.Tensor | None = None
|
52 |
gather_event: torch.cuda.Event | None = None
|
53 |
compute_event: torch.cuda.Event | None = None
|
54 |
|
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
93 |
state.computed_u = u
|
94 |
state.compute_event = torch.cuda.Event()
|
95 |
state.compute_event.record()
|
|
|
|
|
96 |
else:
|
97 |
state.computed_u = None
|
98 |
state.compute_event = None
|
99 |
|
100 |
|
101 |
-
def _scatter(p, state, rank, comm_stream):
|
102 |
u = state.computed_u
|
103 |
mesh = p.device_mesh
|
104 |
|
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
|
|
118 |
src=state.worker_rank,
|
119 |
group=mesh.get_group(),
|
120 |
)
|
|
|
|
|
|
|
121 |
u = DTensor.from_local(
|
122 |
u,
|
123 |
placements=p.placements,
|
124 |
device_mesh=mesh,
|
125 |
)
|
126 |
-
|
127 |
-
|
128 |
|
129 |
|
130 |
class Muon(torch.optim.Optimizer):
|
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
|
|
353 |
def enqueue_scatters(start_idx, chunk_size):
|
354 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
355 |
state = param_to_state[id(p)]
|
356 |
-
|
|
|
357 |
|
358 |
chunk_size = params[0].device_mesh.mesh.numel()
|
359 |
|
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
|
|
368 |
|
369 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
370 |
|
371 |
-
for p in params:
|
372 |
-
g = p.grad
|
373 |
-
if g is None:
|
374 |
-
continue
|
375 |
-
|
376 |
-
# Update p with sharded u
|
377 |
-
state = param_to_state[id(p)]
|
378 |
-
self._update_p(
|
379 |
-
p,
|
380 |
-
state.scattered_u,
|
381 |
-
lr=lr,
|
382 |
-
wd=wd,
|
383 |
-
)
|
384 |
-
|
385 |
def step(self, closure=None):
|
386 |
"""Perform a single optimization step.
|
387 |
|
|
|
48 |
worker_rank: int | None = None
|
49 |
gathered_grad: torch.Tensor | None = None
|
50 |
computed_u: torch.Tensor | None = None
|
|
|
51 |
gather_event: torch.cuda.Event | None = None
|
52 |
compute_event: torch.cuda.Event | None = None
|
53 |
|
|
|
92 |
state.computed_u = u
|
93 |
state.compute_event = torch.cuda.Event()
|
94 |
state.compute_event.record()
|
95 |
+
state.gathered_grad.record_stream(compute_stream)
|
96 |
+
del state.gathered_grad
|
97 |
else:
|
98 |
state.computed_u = None
|
99 |
state.compute_event = None
|
100 |
|
101 |
|
102 |
+
def _scatter(p, state, lr, wd, rank, comm_stream):
|
103 |
u = state.computed_u
|
104 |
mesh = p.device_mesh
|
105 |
|
|
|
119 |
src=state.worker_rank,
|
120 |
group=mesh.get_group(),
|
121 |
)
|
122 |
+
if rank == state.worker_rank:
|
123 |
+
state.computed_u.record_stream(comm_stream)
|
124 |
+
del state.computed_u
|
125 |
u = DTensor.from_local(
|
126 |
u,
|
127 |
placements=p.placements,
|
128 |
device_mesh=mesh,
|
129 |
)
|
130 |
+
p.data.mul_(1 - lr * wd)
|
131 |
+
p.data.add_(u, alpha=-lr)
|
132 |
|
133 |
|
134 |
class Muon(torch.optim.Optimizer):
|
|
|
357 |
def enqueue_scatters(start_idx, chunk_size):
|
358 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
359 |
state = param_to_state[id(p)]
|
360 |
+
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
361 |
+
_scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
|
362 |
|
363 |
chunk_size = params[0].device_mesh.mesh.numel()
|
364 |
|
|
|
373 |
|
374 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
def step(self, closure=None):
|
377 |
"""Perform a single optimization step.
|
378 |
|