feat: update muon to receive paramgroups, not model (#4)
* feat: update muon to receive paramgroups, not model
* feat: update message formats
* feat: remove boilerplate assertion
* chore: run pre-commits
* test: fix testscript
* fix: fix misc bugs
* feat: add get_default_muon_param_groups helper function
* fix: fix readme
* fix: raise error if parametergroup does not follow instructions
* chore: upload binary
---------
Co-authored-by: junhyeok.lee <[email protected]>
Co-authored-by: WyldeCat <[email protected]>
- README.md +7 -1
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_0c12ced_dirty.abi3.so → _optimizer_20250911094409.abi3.so} +2 -2
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +125 -112
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_0c12ced_dirty.abi3.so → _optimizer_20250911094409.abi3.so} +2 -2
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +125 -112
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_0c12ced_dirty.abi3.so → _optimizer_20250911094409.abi3.so} +2 -2
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +125 -112
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_0c12ced_dirty.abi3.so → _optimizer_20250911094409.abi3.so} +2 -2
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +125 -112
- build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_0c12ced_dirty.abi3.so +0 -3
- build/torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so +3 -0
- build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py +125 -112
- build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_0c12ced_dirty.abi3.so +0 -3
- build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so +3 -0
- build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py +125 -112
- build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_0c12ced_dirty.abi3.so +0 -3
- build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so +3 -0
- build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py +125 -112
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_0c12ced_dirty.abi3.so +0 -3
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so +3 -0
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py +125 -112
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_0c12ced_dirty.abi3.so +0 -3
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so +3 -0
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py +125 -112
- test/test_muon/test.py +3 -2
- torch-ext/optimizer/muon.py +125 -112
README.md
CHANGED
@@ -21,12 +21,18 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from kernels import get_kernel
 
 optimizer = get_kernel("motif-technologies/optimizer")
+get_default_muon_param_groups = optimizer.muon.get_default_muon_param_groups
 
 model = None # your model here
 fsdp_model = FSDP(model)
 
+# muon, in nature, cannot use 1-d tensor
+# we provide helper function to group such tensors
+# you can use your own function, if necessary
+params = get_default_muon_param_groups(model)  # user can write own is_muon_func, if necessary
+
 optim = optimizer.Muon(
-    (removed argument, collapsed in the original diff view; Muon previously received the model here)
+    params,
     lr=0.01,
     momentum=0.9,
     weight_decay=1e-4,
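Pieced together from the updated snippet above, here is a minimal sketch of the new calling convention. The tiny `nn.Linear` stand-in model is an assumption for illustration (the README wraps a real model in FSDP), and only `get_kernel`, `get_default_muon_param_groups`, and `Muon(params, ...)` are taken from this commit; note that the optimizer allocates a CUDA stream on construction, so it assumes a CUDA-capable environment.

```python
import torch.nn as nn
from kernels import get_kernel  # resolves the kernel package from the Hugging Face Hub

optimizer = get_kernel("motif-technologies/optimizer")
get_default_muon_param_groups = optimizer.muon.get_default_muon_param_groups

model = nn.Linear(64, 64)  # stand-in model; substitute your own (optionally wrapped in FSDP)

# Muon cannot update 1-D tensors (biases, norms), so the helper splits the parameters
# into a Muon group (>= 2-D, excluding embed_tokens/lm_head) and an AdamW-backup group.
params = get_default_muon_param_groups(model)

optim = optimizer.Muon(
    params,          # parameter groups, each carrying a "use_muon" flag
    lr=0.01,
    momentum=0.9,
    weight_decay=1e-4,
)
# From here on it is used like any torch.optim optimizer: backward(), then optim.step().
```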
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_0c12ced_dirty
-ops = torch.ops._optimizer_0c12ced_dirty
+from . import _optimizer_20250911094409
+ops = torch.ops._optimizer_20250911094409
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_0c12ced_dirty::{op_name}"
+    return f"_optimizer_20250911094409::{op_name}"
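The `_ops.py` change is mechanical: the compiled extension is now named by a build timestamp instead of the old `_optimizer_0c12ced_dirty` git-hash name, so the op namespace prefix changes with it. A tiny sketch of the effect; the op name used here is purely hypothetical.

```python
# Mirrors the updated helper in _ops.py; "newton_schulz" is a made-up op name for illustration.
def add_op_namespace_prefix(op_name: str) -> str:
    """Prefix op by namespace."""
    return f"_optimizer_20250911094409::{op_name}"

assert add_op_namespace_prefix("newton_schulz") == "_optimizer_20250911094409::newton_schulz"
```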
build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_0c12ced_dirty.abi3.so → _optimizer_20250911094409.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:…
-size …
+oid sha256:48cd88108696ba8ed7487e637b785445bb5ff6075a3ae0c15355698958ad340a
+size 1787376
build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py
CHANGED
@@ -1,10 +1,14 @@
+import logging
 import math
+import types
 from dataclasses import dataclass
 
 import torch
 import torch.distributed as dist
 from torch.distributed._tensor import DTensor, Replicate, Shard
 
+logger = logging.getLogger(__name__)
+
 
 # This code snippet is a modified version adapted from the following GitHub repositories:
 # https://github.com/KellerJordan/Muon/blob/master/muon.py
@@ -175,10 +179,31 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
     Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
 
 
-def default_is_muon(x, name):
+def default_is_muon(name, x):
     return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
 
 
+def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
+    return [
+        {
+            "params": [
+                p for n, p in model.named_parameters()
+                if (is_muon_func(n, p) and p.requires_grad)
+            ],
+            "use_muon":
+            True
+        },
+        {
+            "params": [
+                p for n, p in model.named_parameters()
+                if (not is_muon_func(n, p) and p.requires_grad)
+            ],
+            "use_muon":
+            False
+        },
+    ]
+
+
 class Muon(torch.optim.Optimizer):
     """
     Muon - MomentUm Orthogonalized by Newton-schulz
@@ -210,8 +235,7 @@ class Muon(torch.optim.Optimizer):
 
     def __init__(
         self,
-        model,
-        is_muon_func=default_is_muon,
+        params,
         lr=1e-3,
         momentum=0.95,
         nesterov=True,
@@ -231,11 +255,19 @@ class Muon(torch.optim.Optimizer):
            adamw_betas=adamw_betas,
            adamw_eps=adamw_eps,
            none_grad=none_grad,
+           use_muon=True,
        )
+       error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior."
+       instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```"
 
-       (three removed lines are collapsed in the original diff view)
+       if isinstance(params, types.GeneratorType):
+           raise ValueError(error_message.format(idx=0) + instruction_code)
+       for _idx, param_group in enumerate(params):
+           if param_group.get("use_muon", None) is None:
+               raise ValueError(
+                   error_message.format(idx=_idx) + instruction_code)
+
+       super().__init__(params, defaults)
 
        if dist.is_initialized():
            self.rank = dist.get_rank()
@@ -246,21 +278,6 @@ class Muon(torch.optim.Optimizer):
         self.compute_stream = torch.cuda.Stream()
         self.debug = debug
 
-    def __setstate__(self, state):
-        # Sort parameters into those for which we will use Muon, and those for which we will not
-        super().__setstate__(state)
-        self._init_state()
-
-    def _init_state(self):
-        for name, p in self.model.named_parameters():
-            if self.is_muon_func(p, name):
-                # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
-                assert p.ndim == 2, p.ndim
-                self.state[p]["use_muon"] = True
-            else:
-                # Do not use Muon for parameters in adamw_params
-                self.state[p]["use_muon"] = False
-
     def _calc_flops(self, G, steps):
         assert len(G.shape) == 2
         M, N = G.shape
@@ -462,100 +479,96 @@ class Muon(torch.optim.Optimizer):
             loss = closure()
 
         for group in self.param_groups:
-            (the previous step() body is removed; most of its lines are collapsed in the original diff view)
+            params = group["params"]
+
+            if group["use_muon"]:
+                ############################
+                #           Muon           #
+                ############################
+                lr = group["lr"]
+                weight_decay = group["weight_decay"]
+                momentum = group["momentum"]
+
+                param_dtensors = []
+                param_tensors = []
+
+                for p in params:
+                    if p is None or p.grad is None:
+                        continue
+                    if isinstance(p.data, DTensor):
+                        if all(
+                                isinstance(placement, Replicate)
+                                for placement in p.placements):
+                            param_tensors.append(p)
+                        else:
+                            param_dtensors.append(p)
+                    elif isinstance(p.data, torch.Tensor):
                         param_tensors.append(p)
                     else:
+                        raise TypeError(
+                            f"Unsupported parameter type: {type(p.data)}")
+
+                if self.debug:
+                    print(
+                        f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
+                        flush=True,
+                    )
+
+                if len(param_dtensors) > 0:
+                    if not dist.is_initialized():
+                        raise RuntimeError(
+                            "Parallel Muon requires torch.distributed to be initialized."
+                        )
+
+                    self.parallel(
+                        param_dtensors,
+                        group,
+                        lr=lr,
+                        weight_decay=weight_decay,
+                        momentum=momentum,
                     )
 
+                if len(param_tensors) > 0:
+                    self.base(
+                        param_tensors,
+                        group,
+                        lr=lr,
+                        weight_decay=weight_decay,
+                        momentum=momentum,
+                    )
+
+            else:
+                ############################
+                #       AdamW backup       #
+                ############################
+
+                lr = group["lr"]
+                beta1, beta2 = group["adamw_betas"]
+                eps = group["adamw_eps"]
+                weight_decay = group["weight_decay"]
+
+                for p in params:
+                    g = p.grad
+                    if g is None:
+                        continue
+                    state = self.state[p]
+                    if "step" not in state:
+                        state["step"] = 0
+                        state["moment1"] = torch.zeros_like(g)
+                        state["moment2"] = torch.zeros_like(g)
+                    state["step"] += 1
+                    step = state["step"]
+                    buf1 = state["moment1"]
+                    buf2 = state["moment2"]
+                    buf1.lerp_(g, 1 - beta1)
+                    buf2.lerp_(g.square(), 1 - beta2)
+
+                    g = buf1 / (eps + buf2.sqrt())
+
+                    bias_correction1 = 1 - beta1**step
+                    bias_correction2 = 1 - beta2**step
+                    scale = bias_correction1 / bias_correction2**0.5
+                    p.data.mul_(1 - lr * weight_decay)
+                    p.data.add_(g, alpha=-lr / scale)
 
         return loss
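To make the new constructor contract concrete, here is a hand-rolled sketch of the parameter groups that `get_default_muon_param_groups` builds and that `Muon.__init__` now validates. The `nn.Sequential` stand-in model is an assumption for illustration; the predicate and the group layout are copied from the diff above. The optimizer itself is not instantiated here, since it allocates a CUDA stream on construction and, for sharded DTensor parameters, requires an initialized process group.

```python
import torch.nn as nn

def default_is_muon(name, x):
    # Default predicate from the diff: 2-D+ weights that are not embeddings or the LM head.
    return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name

model = nn.Sequential(nn.Linear(32, 32), nn.LayerNorm(32))  # stand-in model

# Equivalent to what get_default_muon_param_groups(model) returns:
# one group routed to Muon, one routed to the AdamW backup path.
param_groups = [
    {
        "params": [p for n, p in model.named_parameters()
                   if default_is_muon(n, p) and p.requires_grad],
        "use_muon": True,   # mandatory: __init__ raises ValueError if the key is missing
    },
    {
        "params": [p for n, p in model.named_parameters()
                   if not default_is_muon(n, p) and p.requires_grad],
        "use_muon": False,
    },
]

# A bare generator such as model.parameters() is rejected outright,
# because there is nowhere to attach the per-group "use_muon" flag.
```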
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
(identical to the diff for build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above)
build/torch27-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_0c12ced_dirty.abi3.so → _optimizer_20250911094409.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:…
-size …
+oid sha256:a5908748e60a61c59e315fbba8b32e3867a4b673b587a2a9606ddde5b4f67da5
+size 1824264
build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED
(identical to the diff for build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above)
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py
CHANGED
(identical to the diff for build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above)
build/torch27-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_0c12ced_dirty.abi3.so → _optimizer_20250911094409.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:…
-size …
+oid sha256:b1729faaee0dd55134348a0d775c147cf3aaba106e0475e1389159d48dfc1ebe
+size 1883360
build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py
CHANGED
(identical to the diff for build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above)
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py
CHANGED
(identical to the diff for build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above)
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_0c12ced_dirty.abi3.so → _optimizer_20250911094409.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:…
-size …
+oid sha256:0857945a1ebfdbb6c7219d0b96c8ab47649aa3b47b65fa800c84b51ddbda9c19
+size 1749880
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py
CHANGED
(identical to the diff for build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above)
build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
(identical to the diff for build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above)
build/torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_0c12ced_dirty.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:df5044ffb45124dfe7088ed991123724405b00285e4d8d1ba2961802f521aa0f
-size 1824256
build/torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cebddf4b9cb794ad3cd7b88affd011160f7fb9a16257fcb4d942604839b31b37
+size 1824264
build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED
(identical to the diff for build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above)
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_0c12ced_dirty
-ops = torch.ops._optimizer_0c12ced_dirty
+from . import _optimizer_20250911094409
+ops = torch.ops._optimizer_20250911094409
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_0c12ced_dirty::{op_name}"
+    return f"_optimizer_20250911094409::{op_name}"
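The `_ops.py` change only re-points the Python shim at the rebuilt binary: the op namespace now matches the new `.so` name. A minimal sketch of what the prefix helper resolves to after this commit (the op name `foo` is purely illustrative, not an op registered by this kernel):

```python
# Illustrative sketch: how _ops.py namespaces an op after the rebuild.
# "foo" is a made-up op name used only to show the string that comes back.
def add_op_namespace_prefix(op_name: str) -> str:
    """Prefix op by namespace."""
    return f"_optimizer_20250911094409::{op_name}"

print(add_op_namespace_prefix("foo"))  # -> _optimizer_20250911094409::foo
```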
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_0c12ced_dirty.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:80cb3ac21d3afafe368f31318c31a4c6356b53bbc2186ae81b79e1eb3ff441f5
-size 1883352
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:22dc3ab77ab74837126281f79f417c5d55b2cc9885388fd9d3a1c7c824ece2bd
+size 1883360
build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py
CHANGED
(identical changes to torch-ext/optimizer/muon.py; full diff shown at the bottom of this page)
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py
CHANGED
(identical changes to build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py above)
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_0c12ced_dirty.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:32af855517484e2695b6d83c29a03d85fcbaaea559d95cbb62fd9fa67cc3ccac
-size 1883352
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:62ecfc7e6a1ab0c4ada19ed7aea40fc0a431c4ceb1729666efa98ac0e407f9c8
+size 1883360
build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py
CHANGED
(identical changes to torch-ext/optimizer/muon.py; full diff shown at the bottom of this page)
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py
CHANGED
(identical changes to build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py above)
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_0c12ced_dirty.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:2dd72f3b9f513dc8bd0724fede9b668761b1d701dfdf3a294979706d803b0800
-size 1750000
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:37e389c650fc1fcbc9fbd68f1e7c1a768b08e90509fd8a5d87879655726f2db2
+size 1750040
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py
CHANGED
(identical changes to torch-ext/optimizer/muon.py; full diff shown at the bottom of this page)
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py
CHANGED
(identical changes to build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py above)
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_0c12ced_dirty.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:2a49b0225ecf27b33bbbe55936811ecf443ce97be97ccb7237b3b66eb46c0ad8
-size 1750088
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e62682b711f002505bb17c170b2bb233f8d389510ff8e2e0a753ee96d11d0746
+size 1750128
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py
CHANGED
(identical changes to torch-ext/optimizer/muon.py; full diff shown at the bottom of this page)
test/test_muon/test.py
CHANGED
@@ -2,7 +2,7 @@ import logging
 
 import torch
 import torch.distributed as dist
-from muon import Muon
+from muon import Muon, get_default_muon_param_groups
 from torch.distributed.fsdp import FSDPModule, fully_shard
 from torch.distributed.tensor import DTensor
 from torch.distributed.tensor.placement_types import Replicate
@@ -54,7 +54,8 @@ def load_model(fsdp: bool) -> torch.nn.Module:
 
 def run_muon(fsdp: bool) -> torch.nn.Module:
     model = load_model(fsdp=fsdp)
-
+    params = get_default_muon_param_groups(model)
+    optim = Muon(params=params)
     optim.step()
 
     return model
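The updated test exercises the new construction path end to end. Below is a minimal sketch of the same flow outside the test harness, assuming the `muon` module from this commit is importable and a CUDA device is available (the `Muon` constructor allocates a `torch.cuda.Stream`); the toy `nn.Linear` model and the hyperparameters are illustrative, not part of the commit:

```python
import torch
from muon import Muon, get_default_muon_param_groups

# Toy model: the 2-D weight lands in the use_muon=True group,
# the 1-D bias falls back to the AdamW group (use_muon=False).
model = torch.nn.Linear(16, 16).cuda()

params = get_default_muon_param_groups(model)
optim = Muon(params, lr=1e-3, momentum=0.95)  # lr/momentum mirror the defaults in muon.py

loss = model(torch.randn(4, 16, device="cuda")).sum()
loss.backward()
optim.step()
optim.zero_grad()
```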
torch-ext/optimizer/muon.py
CHANGED
@@ -1,10 +1,14 @@
+import logging
 import math
+import types
 from dataclasses import dataclass
 
 import torch
 import torch.distributed as dist
 from torch.distributed._tensor import DTensor, Replicate, Shard
 
+logger = logging.getLogger(__name__)
+
 
 # This code snippet is a modified version adapted from the following GitHub repositories:
 # https://github.com/KellerJordan/Muon/blob/master/muon.py
@@ -175,10 +179,31 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
     Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
 
 
-def default_is_muon(
+def default_is_muon(name, x):
     return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
 
 
+def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
+    return [
+        {
+            "params": [
+                p for n, p in model.named_parameters()
+                if (is_muon_func(n, p) and p.requires_grad)
+            ],
+            "use_muon":
+            True
+        },
+        {
+            "params": [
+                p for n, p in model.named_parameters()
+                if (not is_muon_func(n, p) and p.requires_grad)
+            ],
+            "use_muon":
+            False
+        },
+    ]
+
+
 class Muon(torch.optim.Optimizer):
     """
     Muon - MomentUm Orthogonalized by Newton-schulz
@@ -210,8 +235,7 @@ class Muon(torch.optim.Optimizer):
 
     def __init__(
         self,
-
-        is_muon_func=default_is_muon,
+        params,
         lr=1e-3,
         momentum=0.95,
         nesterov=True,
@@ -231,11 +255,19 @@ class Muon(torch.optim.Optimizer):
             adamw_betas=adamw_betas,
             adamw_eps=adamw_eps,
             none_grad=none_grad,
+            use_muon=True,
         )
+        error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior."
+        instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```"
 
-
-
-
+        if isinstance(params, types.GeneratorType):
+            raise ValueError(error_message.format(idx=0) + instruction_code)
+        for _idx, param_group in enumerate(params):
+            if param_group.get("use_muon", None) is None:
+                raise ValueError(
+                    error_message.format(idx=_idx) + instruction_code)
+
+        super().__init__(params, defaults)
 
         if dist.is_initialized():
             self.rank = dist.get_rank()
@@ -246,21 +278,6 @@ class Muon(torch.optim.Optimizer):
         self.compute_stream = torch.cuda.Stream()
         self.debug = debug
 
-    def __setstate__(self, state):
-        # Sort parameters into those for which we will use Muon, and those for which we will not
-        super().__setstate__(state)
-        self._init_state()
-
-    def _init_state(self):
-        for name, p in self.model.named_parameters():
-            if self.is_muon_func(p, name):
-                # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
-                assert p.ndim == 2, p.ndim
-                self.state[p]["use_muon"] = True
-            else:
-                # Do not use Muon for parameters in adamw_params
-                self.state[p]["use_muon"] = False
-
     def _calc_flops(self, G, steps):
         assert len(G.shape) == 2
         M, N = G.shape
@@ -462,100 +479,96 @@ class Muon(torch.optim.Optimizer):
         loss = closure()
 
         for group in self.param_groups:
+            params = group["params"]
+
+            if group["use_muon"]:
+                ############################
+                #           Muon           #
+                ############################
+                lr = group["lr"]
+                weight_decay = group["weight_decay"]
+                momentum = group["momentum"]
+
+                param_dtensors = []
+                param_tensors = []
+
+                for p in params:
+                    if p is None or p.grad is None:
+                        continue
+                    if isinstance(p.data, DTensor):
+                        if all(
+                                isinstance(placement, Replicate)
+                                for placement in p.placements):
+                            param_tensors.append(p)
+                        else:
+                            param_dtensors.append(p)
+                    elif isinstance(p.data, torch.Tensor):
                         param_tensors.append(p)
                     else:
+                        raise TypeError(
+                            f"Unsupported parameter type: {type(p.data)}")
+
+                if self.debug:
+                    print(
+                        f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
+                        flush=True,
+                    )
+
+                if len(param_dtensors) > 0:
+                    if not dist.is_initialized():
+                        raise RuntimeError(
+                            "Parallel Muon requires torch.distributed to be initialized."
+                        )
+
+                    self.parallel(
+                        param_dtensors,
+                        group,
+                        lr=lr,
+                        weight_decay=weight_decay,
+                        momentum=momentum,
                     )
 
+                if len(param_tensors) > 0:
+                    self.base(
+                        param_tensors,
+                        group,
+                        lr=lr,
+                        weight_decay=weight_decay,
+                        momentum=momentum,
+                    )
+
+            else:
+                ############################
+                #       AdamW backup       #
+                ############################
+
+                lr = group["lr"]
+                beta1, beta2 = group["adamw_betas"]
+                eps = group["adamw_eps"]
+                weight_decay = group["weight_decay"]
+
+                for p in params:
+                    g = p.grad
+                    if g is None:
+                        continue
+                    state = self.state[p]
+                    if "step" not in state:
+                        state["step"] = 0
+                        state["moment1"] = torch.zeros_like(g)
+                        state["moment2"] = torch.zeros_like(g)
+                    state["step"] += 1
+                    step = state["step"]
+                    buf1 = state["moment1"]
+                    buf2 = state["moment2"]
+                    buf1.lerp_(g, 1 - beta1)
+                    buf2.lerp_(g.square(), 1 - beta2)
+
+                    g = buf1 / (eps + buf2.sqrt())
+
+                    bias_correction1 = 1 - beta1**step
+                    bias_correction2 = 1 - beta2**step
+                    scale = bias_correction1 / bias_correction2**0.5
+                    p.data.mul_(1 - lr * weight_decay)
+                    p.data.add_(g, alpha=-lr / scale)
 
         return loss
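The new constructor contract is strict about how parameters arrive: a bare generator such as `model.parameters()` is rejected, and every explicit param group must carry a `use_muon` key so the step loop knows which branch to take. A short sketch of that behavior, assuming the `Muon` class from this commit and a CUDA-capable environment (the toy model and group contents are illustrative):

```python
import torch
from muon import Muon

model = torch.nn.Linear(8, 8).cuda()

# 1) Generators are rejected outright; the ValueError carries the instruction snippet.
try:
    Muon(model.parameters())
except ValueError as err:
    print(err)

# 2) Hand-built groups work as long as each one declares use_muon.
groups = [
    {"params": [model.weight], "use_muon": True},   # orthogonalized Muon update
    {"params": [model.bias], "use_muon": False},    # AdamW backup path
]
optim = Muon(groups, lr=1e-3)
```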