Commit 1f13dae
Parent(s): 02ac540
fix(muon): handle un-distributed env
- build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_64757cb_dirty.abi3.so → _optimizer_02ac540_dirty.abi3.so} +1 -1
- build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py +15 -7
- build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-cu124-x86_64-linux/optimizer/{_optimizer_64757cb_dirty.abi3.so → _optimizer_02ac540_dirty.abi3.so} +1 -1
- build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py +15 -7
- build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_64757cb_dirty.abi3.so → _optimizer_02ac540_dirty.abi3.so} +1 -1
- build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py +15 -7
- build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-rocm62-x86_64-linux/optimizer/{_optimizer_64757cb_dirty.abi3.so → _optimizer_02ac540_dirty.abi3.so} +1 -1
- build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py +15 -7
- build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx98-cu118-x86_64-linux/optimizer/{_optimizer_64757cb_dirty.abi3.so → _optimizer_02ac540_dirty.abi3.so} +1 -1
- build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py +15 -7
- build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx98-cu124-x86_64-linux/optimizer/{_optimizer_64757cb_dirty.abi3.so → _optimizer_02ac540_dirty.abi3.so} +1 -1
- build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py +15 -7
- build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx98-cu126-x86_64-linux/optimizer/{_optimizer_64757cb_dirty.abi3.so → _optimizer_02ac540_dirty.abi3.so} +1 -1
- build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py +15 -7
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_64757cb_dirty.abi3.so → _optimizer_02ac540_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +15 -7
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_64757cb_dirty.abi3.so → _optimizer_02ac540_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +15 -7
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_64757cb_dirty.abi3.so → _optimizer_02ac540_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +15 -7
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-312.pyc +0 -0
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-312.pyc +0 -0
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_64757cb_dirty.abi3.so → _optimizer_02ac540_dirty.abi3.so} +1 -1
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +15 -7
- torch-ext/optimizer/muon.py +15 -7
build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_64757cb_dirty
-ops = torch.ops._optimizer_64757cb_dirty
+from . import _optimizer_02ac540_dirty
+ops = torch.ops._optimizer_02ac540_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_64757cb_dirty::{op_name}"
+    return f"_optimizer_02ac540_dirty::{op_name}"
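For context, a short sketch of how this regenerated _ops.py is consumed, assuming the built optimizer package is importable; "some_kernel" is a hypothetical op name used only for illustration, since the diff does not list the ops the extension registers.

from optimizer import _ops

# The helper returns the fully qualified custom-op name for the rebuilt
# extension namespace shown in the diff above.
qualified = _ops.add_op_namespace_prefix("some_kernel")
assert qualified == "_optimizer_02ac540_dirty::some_kernel"

# Registered kernels are reachable through the matching torch.ops namespace,
# e.g. getattr(_ops.ops, "some_kernel"), once the shared object is loaded.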
build/torch26-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_64757cb_dirty.abi3.so → _optimizer_02ac540_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:829533f24bccb220101238fcbafa1343d2ec3ba3922a91a836b8a05813b44672
 size 1787272
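The .abi3.so binaries are stored with Git LFS, so the diff only shows the pointer (content hash and byte size) changing after the rename. A minimal sketch for checking a locally built binary against such a pointer; the path is a placeholder for your checkout:

import hashlib
from pathlib import Path

so_path = Path("build/torch26-cxx11-cu118-x86_64-linux/optimizer/_optimizer_02ac540_dirty.abi3.so")

# A Git LFS pointer records the SHA-256 of the file contents and its size in bytes.
digest = hashlib.sha256(so_path.read_bytes()).hexdigest()
size = so_path.stat().st_size

assert digest == "829533f24bccb220101238fcbafa1343d2ec3ba3922a91a836b8a05813b44672"
assert size == 1787272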
build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py
CHANGED
@@ -195,12 +195,10 @@ class Muon(torch.optim.Optimizer):
         self.is_muon_func = is_muon_func
         self.model = model
 
-        if
-
-
-
-
-        self.rank = dist.get_rank()
+        if dist.is_initialized():
+            self.rank = dist.get_rank()
+        else:
+            self.rank = None
 
         self.comm_stream = torch.cuda.Stream()
         self.compute_stream = torch.cuda.Stream()
@@ -209,12 +207,14 @@ class Muon(torch.optim.Optimizer):
     def __setstate__(self, state):
         # Sort parameters into those for which we will use Muon, and those for which we will not
         super().__setstate__(state)
+        self._init_state()
+
+    def _init_state(self):
         for name, p in self.model.named_parameters():
             if self.is_muon_func(p, name):
                 # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
                 assert p.ndim == 2, p.ndim
                 self.state[p]["use_muon"] = True
-                self.state[p]["orig_shape"] = p.shape
             else:
                 # Do not use Muon for parameters in adamw_params
                 self.state[p]["use_muon"] = False
@@ -402,6 +402,9 @@ class Muon(torch.optim.Optimizer):
         # Muon #
         ############################
 
+        if "use_muon" not in self.state[group["params"][0]]:
+            self._init_state()
+
         params = [p for p in group["params"] if self.state[p]["use_muon"]]
         lr = group["lr"]
         weight_decay = group["weight_decay"]
@@ -432,6 +435,11 @@ class Muon(torch.optim.Optimizer):
             )
 
         if len(param_dtensors) > 0:
+            if not dist.is_initialized():
+                raise RuntimeError(
+                    "Parallel Muon requires torch.distributed to be initialized."
+                )
+
             self.parallel(
                 param_dtensors,
                 group,
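In short: __init__ now records the rank only when a process group exists, the parameter-state setup moves into a reusable _init_state() that step() can call lazily, and the parallel (DTensor) path fails fast outside a distributed run. A minimal sketch of the resulting behaviour, assuming the Muon class above; the toy model, selector, and constructor arguments are illustrative, not taken from this repository:

import torch
import torch.distributed as dist
import torch.nn as nn

# Hypothetical model with only 2-D weights, plus a selector mirroring the
# assumption asserted in _init_state (Muon only for 2-D parameters).
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))

def is_muon_func(param: torch.Tensor, name: str) -> bool:
    return param.ndim == 2

# In a plain single-process run no process group exists, so the patched
# __init__ records rank=None instead of failing inside dist.get_rank():
assert not dist.is_initialized()

# Constructor arguments are assumptions; see muon.py for the real signature.
# Construction also creates CUDA streams, so it needs a CUDA device.
# opt = Muon(model=model, is_muon_func=is_muon_func, lr=0.02, weight_decay=0.01)
# opt.step()  # regular tensors take the non-parallel path without a process group

# Only the DTensor branch requires torch.distributed; without it, step() raises:
#   RuntimeError: Parallel Muon requires torch.distributed to be initialized.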
build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py
CHANGED
(diff identical to build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above)
build/torch26-cxx11-cu124-x86_64-linux/optimizer/{_optimizer_64757cb_dirty.abi3.so → _optimizer_02ac540_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:880f65ca04a52278892cdcb40dac073f21552ac16b69903f2b8026894a81e35d
 size 1824224
build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py
CHANGED
(diff identical to build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above)
build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
(diff identical to build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above)
build/torch26-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_64757cb_dirty.abi3.so → _optimizer_02ac540_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:c478b90b83052c5931cb3d872adad7811663e28bd3447f12ac412f15b1d0ffc5
 size 1824224
build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED
(diff identical to build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above)
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py
CHANGED
(diff identical to build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above)
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/{_optimizer_64757cb_dirty.abi3.so → _optimizer_02ac540_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:ec46d147914be5998dfc62d4b87eb6730be7f012700d49543a318cadab3820db
 size 1749744
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py
CHANGED
(diff identical to build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above)
build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py
CHANGED
(diff identical to build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above)
build/torch26-cxx98-cu118-x86_64-linux/optimizer/{_optimizer_64757cb_dirty.abi3.so → _optimizer_02ac540_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:9e09882858886be06e8ac48d184b320c57624d9c85165ce8b56640b022838e44
 size 1787192
build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py
CHANGED
(diff identical to build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above)
build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py
CHANGED
(diff identical to build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above)
build/torch26-cxx98-cu124-x86_64-linux/optimizer/{_optimizer_64757cb_dirty.abi3.so → _optimizer_02ac540_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:6f63b2cd2c67b44f5e54837a0a4f26d94d3e6e8bfa4964bd99fc7e38494e2d52
 size 1824184
build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py
CHANGED
(diff identical to build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above)
build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
(diff identical to build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above)
build/torch26-cxx98-cu126-x86_64-linux/optimizer/{_optimizer_64757cb_dirty.abi3.so → _optimizer_02ac540_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:48795cb66a740b14266d757ac70a6b43fb11df6662970bb4040650d237e6cbc5
 size 1824184
build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py
CHANGED
(diff identical to build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above)
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py
CHANGED
(diff identical to build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above)
build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_64757cb_dirty.abi3.so → _optimizer_02ac540_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:ec1f34fd4ead50eb51db63f51afc0751d6bf0c64a46c44c713ab245f150979cc
 size 1787368
build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py
CHANGED
(diff identical to build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above)
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
(diff identical to build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above)
build/torch27-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_64757cb_dirty.abi3.so → _optimizer_02ac540_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:bdb8ab38f72351ae88307560aca5e1af7b2dcb63a39627dbd4c806cad3f83442
 size 1824256
build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED
(diff identical to build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above)
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py
CHANGED
(diff identical to build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above)
build/torch27-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_64757cb_dirty.abi3.so → _optimizer_02ac540_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:0652d611e00b1bcbece47da13dffb28396ae0831dc4be43c7ae9be27ad9a10fe
 size 1883352
build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py
CHANGED
(diff identical to build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above)
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-312.pyc
CHANGED
Binary files a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-312.pyc and b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-312.pyc differ

build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-312.pyc
CHANGED
Binary files a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-312.pyc and b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-312.pyc differ
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py
CHANGED
(diff identical to build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py above)
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_64757cb_dirty.abi3.so → _optimizer_02ac540_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:a96bfd1f461d7cd029dd39d142d2999dcc86dd7f56fb40f045e00f3fb2c400bd
 size 1749648
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py
CHANGED
(diff identical to build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above)
torch-ext/optimizer/muon.py
CHANGED
(diff identical to build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py above; this is the canonical source copied into each build directory)
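For completeness, a sketch of the distributed setup that the new RuntimeError guard expects before the parallel (DTensor) path runs; the launch command and backend choice are the usual torchrun conventions, shown here as assumptions rather than requirements of this repository:

import torch
import torch.distributed as dist

# Typically launched as: torchrun --nproc-per-node=N train.py
if not dist.is_initialized():
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend)

rank = dist.get_rank()        # the value Muon.__init__ stores when a group exists
world = dist.get_world_size()
print(f"rank {rank} of {world}: parallel Muon path is allowed")

dist.destroy_process_group()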