Commit 64757cb
1 Parent(s): 036642a
fix(muon): free tensors that are no longer needed
- build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
- build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py +10 -7
- build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-cu124-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
- build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py +10 -7
- build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
- build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py +10 -7
- build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-rocm62-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
- build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py +10 -7
- build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx98-cu118-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
- build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py +10 -7
- build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx98-cu124-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
- build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py +10 -7
- build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx98-cu126-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
- build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py +10 -7
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +10 -7
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +10 -7
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +10 -7
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so} +1 -1
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +10 -7
- torch-ext/optimizer/muon.py +10 -7
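
In short, this commit teaches the distributed Muon optimizer to release large intermediate tensors as soon as they are no longer needed: a new none_grad option (default True) lets _gather set p.grad to None once the gradient has been gathered, _compute_u now drops the gathered gradient once the orthogonalized update u has been computed, and _scatter drops u once the update is being written back. Below is a usage sketch showing only how the new flag is passed; the helper build_sharded_model(), the data loader, and the distributed launch (for example torchrun with DTensor-sharded parameters, which this optimizer expects) are assumptions for illustration and not part of the commit.

# Sketch only: assumes an initialized distributed environment and a module
# whose parameters are DTensor-sharded; build_sharded_model() and loader are
# hypothetical placeholders, not part of this repository.
import torch
from optimizer import Muon  # assumed re-export of the Muon class defined in muon.py

model = build_sharded_model()

# none_grad=True is the new default added by this commit: the optimizer sets
# p.grad = None right after a parameter's gradient has been gathered, instead
# of keeping every gradient alive until the end of the step.
opt = Muon(model, none_grad=True)

for batch in loader:
    loss = model(batch).sum()
    loss.backward()
    opt.step()
    # Gradients handled by the Muon path were already set to None when
    # none_grad=True; zero_grad(set_to_none=True) covers the rest.
    opt.zero_grad(set_to_none=True)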
build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_febdf5b_dirty
-ops = torch.ops._optimizer_febdf5b_dirty
+from . import _optimizer_036642a_dirty
+ops = torch.ops._optimizer_036642a_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_febdf5b_dirty::{op_name}"
+    return f"_optimizer_036642a_dirty::{op_name}"
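
For context, _ops.py is the thin shim that imports whichever compiled extension sits next to it and exposes its torch.ops namespace; the build id baked into those names is why this file has to change whenever the shared object is rebuilt. A small sketch of how it is typically consumed (the op name foo is a made-up placeholder):

# Sketch, not from the commit: foo is a placeholder op name.
from optimizer._ops import ops, add_op_namespace_prefix

print(add_op_namespace_prefix("foo"))  # -> "_optimizer_036642a_dirty::foo"
# ops.foo(...) would dispatch to whatever kernel the .so registered under that namespace.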
build/torch26-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:9c77e5647b6056bfaee25050cca7948c40859db0a88fa4fcf40b67a85c947d8c
 size 1787272
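
Each of these .so entries is a Git LFS pointer file, so the committed change is only the new oid (the size stays the same); the binary itself lives in LFS storage and is fetched separately, for example with git lfs pull --include="build/torch26-cxx11-cu118-x86_64-linux/optimizer/*.so" once Git LFS is installed.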
build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py
CHANGED
@@ -53,7 +53,7 @@ class _muon_state:
 
 
 @torch.no_grad()
-def _gather(p, state, rank, comm_stream):
+def _gather(p, state, rank, comm_stream, none_grad):
     g = p.grad
     mesh = g.device_mesh
 
@@ -70,7 +70,6 @@ def _gather(p, state, rank, comm_stream):
             group=mesh.get_group(),
         )
     if rank == state.worker_rank:
-        # TODO: Consider ,,,
         if state.gathered_grad is not None:
             raise RuntimeError(
                 "Gather event already exists, which should not happen."
@@ -81,6 +80,8 @@ def _gather(p, state, rank, comm_stream):
     else:
         state.gathered_grad = None
         state.gather_event = None
+    if none_grad:
+        p.grad = None
 
 
 @torch.no_grad()
@@ -94,8 +95,8 @@ def _compute_u(state, steps, rank, compute_stream):
         state.computed_u = u
         state.compute_event = torch.cuda.Event()
         state.compute_event.record()
-
-
+        # Clear the gathered gradient to free memory
+        state.gathered_grad = None
     else:
         state.computed_u = None
         state.compute_event = None
@@ -123,8 +124,8 @@ def _scatter(p, state, lr, wd, rank, comm_stream):
             group=mesh.get_group(),
         )
     if rank == state.worker_rank:
-
-
+        # Clear u to free memory
+        state.computed_u = None
         u = DTensor.from_local(
             u,
             placements=p.placements,
@@ -172,6 +173,7 @@ class Muon(torch.optim.Optimizer):
         adamw_wd=0.1,
         adamw_betas=(0.9, 0.95),
         adamw_eps=1e-8,
+        none_grad=True,
         debug=False,
     ):
         defaults = dict(
@@ -182,6 +184,7 @@ class Muon(torch.optim.Optimizer):
             ns_steps=ns_steps,
             adamw_betas=adamw_betas,
             adamw_eps=adamw_eps,
+            none_grad=none_grad,
         )
 
         super().__init__(model.parameters(), defaults)
@@ -350,7 +353,7 @@ class Muon(torch.optim.Optimizer):
         def enqueue_gathers(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-                _gather(p, state, self.rank, self.comm_stream)
+                _gather(p, state, self.rank, self.comm_stream, group["none_grad"])
 
         def enqueue_computes(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
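
To make the effect of the three new deletions concrete, here is a small self-contained illustration (not from the repository, and deliberately single-process) of the pattern: drop the reference to each large intermediate as soon as the next stage has consumed it, so PyTorch's caching allocator can reuse that memory for the next parameter instead of keeping the gradient, the gathered gradient, and u all alive until the end of the step.

# Illustration only: the shapes and the .sign() stand-in for the Newton-Schulz
# iteration are made up; the point is that dropping the Python reference is
# what lets the caching allocator reuse the memory.
import torch

def report(tag):
    if torch.cuda.is_available():
        mib = torch.cuda.memory_allocated() / 2**20
        print(f"{tag}: {mib:.1f} MiB allocated")

device = "cuda" if torch.cuda.is_available() else "cpu"

gathered_grad = torch.randn(4096, 4096, device=device)  # stand-in for the all-gathered gradient
report("after gather")

computed_u = gathered_grad.sign()  # stand-in for the orthogonalized update u
gathered_grad = None               # what _compute_u now does once u exists
report("after compute, gathered_grad dropped")

update = 0.1 * computed_u          # stand-in for forming the final update
computed_u = None                  # what _scatter now does once u has been consumed
report("after scatter, computed_u dropped")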
build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py
CHANGED: same namespace update as build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above (+3 -3).

build/torch26-cxx11-cu124-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so}
RENAMED: LFS pointer updated; new oid sha256:94ea66089cc8d9eda72b017733a9e05e4fee5a2f04c50658b690d2c19f0d3068, size unchanged at 1824224.

build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py
CHANGED: same changes as build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above (+10 -7).
build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED: same namespace update as build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above (+3 -3).

build/torch26-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so}
RENAMED: LFS pointer updated; new oid sha256:46e01e1d957ada2d485b30cd60bc3ef7230b8857dffc59f2e7924339761ec577, size unchanged at 1824224.

build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED: same changes as build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above (+10 -7).
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py
CHANGED: same namespace update as build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above (+3 -3).

build/torch26-cxx11-rocm62-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so}
RENAMED: LFS pointer updated; new oid sha256:a825a0cd31d8c1b91aa9db4b24248d7fc0a506615f625a385b40e6002025c7dd, size unchanged at 1749744.

build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py
CHANGED: same changes as build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above (+10 -7).
build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py
CHANGED: same namespace update as build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above (+3 -3).

build/torch26-cxx98-cu118-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so}
RENAMED: LFS pointer updated; new oid sha256:579e9ddf66a4f17ead9232c2f32e6327fe6a3f16dd235e2e73e6cb282de1797e, size unchanged at 1787192.

build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py
CHANGED: same changes as build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above (+10 -7).
build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py
CHANGED: same namespace update as build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above (+3 -3).

build/torch26-cxx98-cu124-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so}
RENAMED: LFS pointer updated; new oid sha256:beacb4ba2d56463b6d444875728b3462cb3ff6c1449e3c9693cd665bfbbbbb73, size unchanged at 1824184.

build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py
CHANGED: same changes as build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above (+10 -7).
build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py
CHANGED: same namespace update as build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above (+3 -3).

build/torch26-cxx98-cu126-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so}
RENAMED: LFS pointer updated; new oid sha256:9b04b011803d328d8dcd2edcf4c3840ddbb1bb2f093464c208f0ba2faf4f16bc, size unchanged at 1824184.

build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py
CHANGED: same changes as build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above (+10 -7).
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py
CHANGED: same namespace update as build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above (+3 -3).

build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so}
RENAMED: LFS pointer updated; new oid sha256:ad6c725009f2e776b99d3134c75f15e11dd7fe75fe4ba1fa94779018c7871f8c, size unchanged at 1787368.

build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py
CHANGED: same changes as build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above (+10 -7).
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED: same namespace update as build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above (+3 -3).

build/torch27-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so}
RENAMED: LFS pointer updated; new oid sha256:50cb5819ff08a2179d78cd98164d07fd3cef1b66ee7703d599a310dfb140b9d1, size unchanged at 1824256.

build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED: same changes as build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above (+10 -7).
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py
CHANGED: same namespace update as build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above (+3 -3).

build/torch27-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so}
RENAMED: LFS pointer updated; new oid sha256:9c75e42265f382addc71327ad5628e8a2414da5872791c975e384708c4acd549, size unchanged at 1883352.

build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py
CHANGED: same changes as build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above (+10 -7).
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py
CHANGED: same namespace update as build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above (+3 -3).

build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_febdf5b_dirty.abi3.so → _optimizer_036642a_dirty.abi3.so}
RENAMED: LFS pointer updated; new oid sha256:9a2363d4311d6a75fbcc03e6d4a71c73dae4d54e00a30135d25198d4078c6b0f, size unchanged at 1749648.

build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py
CHANGED: same changes as build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above (+10 -7).
torch-ext/optimizer/muon.py
CHANGED: identical to the muon.py diff shown above for build/torch26-cxx11-cu118-x86_64-linux (+10 -7).
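
A closing note on why these frees sit next to torch.cuda.Event bookkeeping in muon.py: the gathered gradient and u are produced and consumed on different streams (a communication stream and a compute stream), and the code records events such as gather_event and compute_event around those stages so the consuming side can be ordered after the producing side. The sketch below shows that record/wait pattern on its own; the work items are placeholders, and nothing here is taken from the optimizer beyond the pattern itself.

# Sketch only: demonstrates the event record / wait_event pattern used to keep
# a producer stream and a consumer stream ordered.
import torch

if torch.cuda.is_available():
    comm_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.Stream()

    with torch.cuda.stream(comm_stream):
        g = torch.randn(1024, 1024, device="cuda")  # stand-in for the gathered gradient
        gather_event = torch.cuda.Event()
        gather_event.record()                        # marks "gather finished" on the comm stream

    with torch.cuda.stream(compute_stream):
        compute_stream.wait_event(gather_event)      # do not start before the gather is done
        u = g.sign()                                 # stand-in for the Newton-Schulz iteration
        compute_event = torch.cuda.Event()
        compute_event.record()

    torch.cuda.default_stream().wait_event(compute_event)
    torch.cuda.synchronize()
    print("handoff complete:", tuple(u.shape))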