junhyeok-motech, leejunhyeok, and wyldecat committed
Commit b0f46c7 · unverified · 1 parent: 99e7c0c

feat: update muon to receive paramgroups, not model (#4)


* feat: update muon to receive paramgroups, not model

* feat: update message formats

* feat: remove boilerplate assertion

* chore: run pre-commits

* test: fix testscript

* fix: fix misc bugs

* feat: add get_default_muon_param_groups helper function

* fix: fix readme

* fix: raise error if parametergroup does not follow instructions

* chore: upload binary

---------

Co-authored-by: junhyeok.lee <[email protected]>
Co-authored-by: WyldeCat <[email protected]>

Files changed (35)
  1. README.md +7 -1
  2. build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
  3. build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_0c12ced_dirty.abi3.so → _optimizer_20250911094409.abi3.so} +2 -2
  4. build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +125 -112
  5. build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  6. build/torch27-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_0c12ced_dirty.abi3.so → _optimizer_20250911094409.abi3.so} +2 -2
  7. build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +125 -112
  8. build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
  9. build/torch27-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_0c12ced_dirty.abi3.so → _optimizer_20250911094409.abi3.so} +2 -2
  10. build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +125 -112
  11. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
  12. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_0c12ced_dirty.abi3.so → _optimizer_20250911094409.abi3.so} +2 -2
  13. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +125 -112
  14. build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  15. build/torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_0c12ced_dirty.abi3.so +0 -3
  16. build/torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so +3 -0
  17. build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py +125 -112
  18. build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
  19. build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_0c12ced_dirty.abi3.so +0 -3
  20. build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so +3 -0
  21. build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py +125 -112
  22. build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py +3 -3
  23. build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_0c12ced_dirty.abi3.so +0 -3
  24. build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so +3 -0
  25. build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py +125 -112
  26. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
  27. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_0c12ced_dirty.abi3.so +0 -3
  28. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so +3 -0
  29. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py +125 -112
  30. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py +3 -3
  31. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_0c12ced_dirty.abi3.so +0 -3
  32. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so +3 -0
  33. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py +125 -112
  34. test/test_muon/test.py +3 -2
  35. torch-ext/optimizer/muon.py +125 -112
README.md CHANGED
@@ -21,12 +21,18 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from kernels import get_kernel
 
 optimizer = get_kernel("motif-technologies/optimizer")
+get_default_muon_param_groups = optimizer.muon.get_default_muon_param_groups
 
 model = None # your model here
 fsdp_model = FSDP(model)
 
+# muon, in nature, cannot use 1-d tensor
+# we provide helper function to group such tensors
+# you can use your own function, if necessary
+params = get_default_muon_param_groups(model) # user can write own is_muon_func, if necessary
+
 optim = optimizer.Muon(
-    fsdp_model.parameters(),
+    params,
     lr=0.01,
     momentum=0.9,
     weight_decay=1e-4,
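
Putting the pieces together, the intended usage after this change looks roughly like the sketch below. It assumes the kernels hub can fetch motif-technologies/optimizer as in the README and a CUDA-capable environment (the optimizer allocates CUDA streams at construction time); the tiny Sequential model is only a stand-in for "your model here".

import torch.nn as nn
from kernels import get_kernel

optimizer = get_kernel("motif-technologies/optimizer")
get_default_muon_param_groups = optimizer.muon.get_default_muon_param_groups

# A small stand-in model: 2-D weights are routed to Muon, while 1-D tensors
# (biases, norms) fall back to the AdamW path inside the same optimizer.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16)).cuda()

# The helper returns two parameter groups, one tagged use_muon=True and one
# tagged use_muon=False; a custom is_muon_func can be passed instead.
params = get_default_muon_param_groups(model)

optim = optimizer.Muon(
    params,
    lr=0.01,
    momentum=0.9,
    weight_decay=1e-4,
)
# Training then proceeds as usual: loss.backward(); optim.step(); optim.zero_grad()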
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_0c12ced_dirty
-ops = torch.ops._optimizer_0c12ced_dirty
+from . import _optimizer_20250911094409
+ops = torch.ops._optimizer_20250911094409
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_0c12ced_dirty::{op_name}"
+    return f"_optimizer_20250911094409::{op_name}"
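
The only substantive change in _ops.py is the embedded build identifier: the native extension module is now keyed by a build timestamp (_optimizer_20250911094409) rather than a dirty git hash (_optimizer_0c12ced_dirty). A minimal sketch of how that namespace is consumed, assuming the built package is importable as optimizer; the op name shown is purely hypothetical:

from optimizer._ops import ops, add_op_namespace_prefix

# Ops registered by the bundled shared library live under the versioned
# namespace, i.e. torch.ops._optimizer_20250911094409.<op_name>.
qualified = add_op_namespace_prefix("newton_schulz")  # hypothetical op name
print(qualified)  # -> "_optimizer_20250911094409::newton_schulz"

# `ops` is the same namespace object, so an op (if it exists) would be reached as
# ops.newton_schulz(...), equivalent to torch.ops._optimizer_20250911094409.newton_schulz(...)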
build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_0c12ced_dirty.abi3.so → _optimizer_20250911094409.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:1c2963bea474b130d3b22e507692b42c1926d0b93c20495789602da2caff5ef3
-size 1787368
+oid sha256:48cd88108696ba8ed7487e637b785445bb5ff6075a3ae0c15355698958ad340a
+size 1787376
build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py CHANGED
@@ -1,10 +1,14 @@
+import logging
 import math
+import types
 from dataclasses import dataclass
 
 import torch
 import torch.distributed as dist
 from torch.distributed._tensor import DTensor, Replicate, Shard
 
+logger = logging.getLogger(__name__)
+
 
 # This code snippet is a modified version adapted from the following GitHub repositories:
 # https://github.com/KellerJordan/Muon/blob/master/muon.py
@@ -175,10 +179,31 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
     Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
 
 
-def default_is_muon(x, name):
+def default_is_muon(name, x):
     return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
 
 
+def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
+    return [
+        {
+            "params": [
+                p for n, p in model.named_parameters()
+                if (is_muon_func(n, p) and p.requires_grad)
+            ],
+            "use_muon":
+            True
+        },
+        {
+            "params": [
+                p for n, p in model.named_parameters()
+                if (not is_muon_func(n, p) and p.requires_grad)
+            ],
+            "use_muon":
+            False
+        },
+    ]
+
+
 class Muon(torch.optim.Optimizer):
     """
     Muon - MomentUm Orthogonalized by Newton-schulz
@@ -210,8 +235,7 @@ class Muon(torch.optim.Optimizer):
 
     def __init__(
         self,
-        model,
-        is_muon_func=default_is_muon,
+        params,
         lr=1e-3,
         momentum=0.95,
         nesterov=True,
@@ -231,11 +255,19 @@ class Muon(torch.optim.Optimizer):
             adamw_betas=adamw_betas,
             adamw_eps=adamw_eps,
             none_grad=none_grad,
+            use_muon=True,
         )
+        error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior."
+        instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```"
 
-        super().__init__(model.parameters(), defaults)
-        self.is_muon_func = is_muon_func
-        self.model = model
+        if isinstance(params, types.GeneratorType):
+            raise ValueError(error_message.format(idx=0) + instruction_code)
+        for _idx, param_group in enumerate(params):
+            if param_group.get("use_muon", None) is None:
+                raise ValueError(
+                    error_message.format(idx=_idx) + instruction_code)
+
+        super().__init__(params, defaults)
 
         if dist.is_initialized():
             self.rank = dist.get_rank()
@@ -246,21 +278,6 @@ class Muon(torch.optim.Optimizer):
         self.compute_stream = torch.cuda.Stream()
         self.debug = debug
 
-    def __setstate__(self, state):
-        # Sort parameters into those for which we will use Muon, and those for which we will not
-        super().__setstate__(state)
-        self._init_state()
-
-    def _init_state(self):
-        for name, p in self.model.named_parameters():
-            if self.is_muon_func(p, name):
-                # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
-                assert p.ndim == 2, p.ndim
-                self.state[p]["use_muon"] = True
-            else:
-                # Do not use Muon for parameters in adamw_params
-                self.state[p]["use_muon"] = False
-
     def _calc_flops(self, G, steps):
         assert len(G.shape) == 2
         M, N = G.shape
@@ -462,100 +479,96 @@ class Muon(torch.optim.Optimizer):
             loss = closure()
 
         for group in self.param_groups:
-            ############################
-            #           Muon           #
-            ############################
-
-            if "use_muon" not in self.state[group["params"][0]]:
-                self._init_state()
-
-            params = [p for p in group["params"] if self.state[p]["use_muon"]]
-            lr = group["lr"]
-            weight_decay = group["weight_decay"]
-            momentum = group["momentum"]
-
-            param_dtensors = []
-            param_tensors = []
-
-            for p in params:
-                if p is None or p.grad is None:
-                    continue
-                if isinstance(p.data, DTensor):
-                    if all(
-                            isinstance(placement, Replicate)
-                            for placement in p.placements):
-                        param_tensors.append(p)
-                    else:
-                        param_dtensors.append(p)
-                elif isinstance(p.data, torch.Tensor):
-                    param_tensors.append(p)
-                else:
-                    raise TypeError(
-                        f"Unsupported parameter type: {type(p.data)}")
-
-            if self.debug:
-                print(
-                    f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
-                    flush=True,
-                )
-
-            if len(param_dtensors) > 0:
-                if not dist.is_initialized():
-                    raise RuntimeError(
-                        "Parallel Muon requires torch.distributed to be initialized."
-                    )
-
-                self.parallel(
-                    param_dtensors,
-                    group,
-                    lr=lr,
-                    weight_decay=weight_decay,
-                    momentum=momentum,
-                )
-
-            if len(param_tensors) > 0:
-                self.base(
-                    param_tensors,
-                    group,
-                    lr=lr,
-                    weight_decay=weight_decay,
-                    momentum=momentum,
-                )
-
-            ############################
-            #       AdamW backup       #
-            ############################
-
-            params = [
-                p for p in group["params"] if not self.state[p]["use_muon"]
-            ]
-            lr = group["lr"]
-            beta1, beta2 = group["adamw_betas"]
-            eps = group["adamw_eps"]
-            weight_decay = group["weight_decay"]
-
-            for p in params:
-                g = p.grad
-                if g is None:
-                    continue
-                state = self.state[p]
-                if "step" not in state:
-                    state["step"] = 0
-                    state["moment1"] = torch.zeros_like(g)
-                    state["moment2"] = torch.zeros_like(g)
-                state["step"] += 1
-                step = state["step"]
-                buf1 = state["moment1"]
-                buf2 = state["moment2"]
-                buf1.lerp_(g, 1 - beta1)
-                buf2.lerp_(g.square(), 1 - beta2)
-
-                g = buf1 / (eps + buf2.sqrt())
-
-                bias_correction1 = 1 - beta1**step
-                bias_correction2 = 1 - beta2**step
-                scale = bias_correction1 / bias_correction2**0.5
-                p.data.mul_(1 - lr * weight_decay)
-                p.data.add_(g, alpha=-lr / scale)
+            params = group["params"]
+
+            if group["use_muon"]:
+                ############################
+                #           Muon           #
+                ############################
+                lr = group["lr"]
+                weight_decay = group["weight_decay"]
+                momentum = group["momentum"]
+
+                param_dtensors = []
+                param_tensors = []
+
+                for p in params:
+                    if p is None or p.grad is None:
+                        continue
+                    if isinstance(p.data, DTensor):
+                        if all(
+                                isinstance(placement, Replicate)
+                                for placement in p.placements):
+                            param_tensors.append(p)
+                        else:
+                            param_dtensors.append(p)
+                    elif isinstance(p.data, torch.Tensor):
+                        param_tensors.append(p)
+                    else:
+                        raise TypeError(
+                            f"Unsupported parameter type: {type(p.data)}")
+
+                if self.debug:
+                    print(
+                        f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
+                        flush=True,
+                    )
+
+                if len(param_dtensors) > 0:
+                    if not dist.is_initialized():
+                        raise RuntimeError(
+                            "Parallel Muon requires torch.distributed to be initialized."
+                        )
+
+                    self.parallel(
+                        param_dtensors,
+                        group,
+                        lr=lr,
+                        weight_decay=weight_decay,
+                        momentum=momentum,
+                    )
+
+                if len(param_tensors) > 0:
+                    self.base(
+                        param_tensors,
+                        group,
+                        lr=lr,
+                        weight_decay=weight_decay,
+                        momentum=momentum,
+                    )
+
+            else:
+                ############################
+                #       AdamW backup       #
+                ############################
+
+                lr = group["lr"]
+                beta1, beta2 = group["adamw_betas"]
+                eps = group["adamw_eps"]
+                weight_decay = group["weight_decay"]
+
+                for p in params:
+                    g = p.grad
+                    if g is None:
+                        continue
+                    state = self.state[p]
+                    if "step" not in state:
+                        state["step"] = 0
+                        state["moment1"] = torch.zeros_like(g)
+                        state["moment2"] = torch.zeros_like(g)
+                    state["step"] += 1
+                    step = state["step"]
+                    buf1 = state["moment1"]
+                    buf2 = state["moment2"]
+                    buf1.lerp_(g, 1 - beta1)
+                    buf2.lerp_(g.square(), 1 - beta2)
+
+                    g = buf1 / (eps + buf2.sqrt())
+
+                    bias_correction1 = 1 - beta1**step
+                    bias_correction2 = 1 - beta2**step
+                    scale = bias_correction1 / bias_correction2**0.5
+                    p.data.mul_(1 - lr * weight_decay)
+                    p.data.add_(g, alpha=-lr / scale)
 
         return loss
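
The behavioral core of this commit is the new parameter-group contract: Muon no longer walks a model, so every group handed to the constructor must carry a use_muon key, and a bare generator such as model.parameters() is rejected. A minimal sketch of that contract, with hand-built groups that mirror (for this tiny model) what get_default_muon_param_groups returns:

import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))

# 2-D weights go to the Muon group; 1-D tensors go to the AdamW-backup group.
muon_params = [p for n, p in model.named_parameters() if p.ndim >= 2]
other_params = [p for n, p in model.named_parameters() if p.ndim < 2]

param_groups = [
    {"params": muon_params, "use_muon": True},
    {"params": other_params, "use_muon": False},
]

# With the updated constructor (optimizer.Muon loaded via the kernels hub as above):
#   Muon(model.parameters(), ...)        -> ValueError: generators are refused
#   Muon([{"params": muon_params}], ...) -> ValueError: 'use_muon' key missing
#   Muon(param_groups, ...)              -> OK; step() runs Muon or the AdamW
#                                           backup per group based on use_muon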
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py CHANGED
Same change as build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above: references to _optimizer_0c12ced_dirty are replaced with _optimizer_20250911094409.
build/torch27-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_0c12ced_dirty.abi3.so → _optimizer_20250911094409.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:a55c3d0aba4548dc74a08d66987307bd381c2d93b149702fbdc60da19e03e5fc
-size 1824256
+oid sha256:a5908748e60a61c59e315fbba8b32e3867a4b673b587a2a9606ddde5b4f67da5
+size 1824264
build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py CHANGED
Identical diff to build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above.
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py CHANGED
Same change as build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above: references to _optimizer_0c12ced_dirty are replaced with _optimizer_20250911094409.
build/torch27-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_0c12ced_dirty.abi3.so → _optimizer_20250911094409.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c319d0fb497363746229fbabed6d14b82090a660de602125fb67135117c53f5a
-size 1883352
+oid sha256:b1729faaee0dd55134348a0d775c147cf3aaba106e0475e1389159d48dfc1ebe
+size 1883360
build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py CHANGED
Identical diff to build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above.
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py CHANGED
Same change as build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above: references to _optimizer_0c12ced_dirty are replaced with _optimizer_20250911094409.
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_0c12ced_dirty.abi3.so → _optimizer_20250911094409.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:e8bda6399291a15b5bcba88214ffd3d0291b10d1cdfb0ab668436d176a9396ec
-size 1749840
+oid sha256:0857945a1ebfdbb6c7219d0b96c8ab47649aa3b47b65fa800c84b51ddbda9c19
+size 1749880
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py CHANGED
Identical diff to build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above.
build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py CHANGED
Same change as build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above: references to _optimizer_0c12ced_dirty are replaced with _optimizer_20250911094409.
build/torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_0c12ced_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:df5044ffb45124dfe7088ed991123724405b00285e4d8d1ba2961802f521aa0f
-size 1824256
build/torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cebddf4b9cb794ad3cd7b88affd011160f7fb9a16257fcb4d942604839b31b37
+size 1824264
build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py CHANGED
@@ -1,10 +1,14 @@
 
1
  import math
 
2
  from dataclasses import dataclass
3
 
4
  import torch
5
  import torch.distributed as dist
6
  from torch.distributed._tensor import DTensor, Replicate, Shard
7
 
 
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
10
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
@@ -175,10 +179,31 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
175
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
176
 
177
 
178
- def default_is_muon(x, name):
179
  return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
180
 
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  class Muon(torch.optim.Optimizer):
183
  """
184
  Muon - MomentUm Orthogonalized by Newton-schulz
@@ -210,8 +235,7 @@ class Muon(torch.optim.Optimizer):
210
 
211
  def __init__(
212
  self,
213
- model,
214
- is_muon_func=default_is_muon,
215
  lr=1e-3,
216
  momentum=0.95,
217
  nesterov=True,
@@ -231,11 +255,19 @@ class Muon(torch.optim.Optimizer):
231
  adamw_betas=adamw_betas,
232
  adamw_eps=adamw_eps,
233
  none_grad=none_grad,
 
234
  )
 
 
235
 
236
- super().__init__(model.parameters(), defaults)
237
- self.is_muon_func = is_muon_func
238
- self.model = model
 
 
 
 
 
239
 
240
  if dist.is_initialized():
241
  self.rank = dist.get_rank()
@@ -246,21 +278,6 @@ class Muon(torch.optim.Optimizer):
246
  self.compute_stream = torch.cuda.Stream()
247
  self.debug = debug
248
 
249
- def __setstate__(self, state):
250
- # Sort parameters into those for which we will use Muon, and those for which we will not
251
- super().__setstate__(state)
252
- self._init_state()
253
-
254
- def _init_state(self):
255
- for name, p in self.model.named_parameters():
256
- if self.is_muon_func(p, name):
257
- # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
258
- assert p.ndim == 2, p.ndim
259
- self.state[p]["use_muon"] = True
260
- else:
261
- # Do not use Muon for parameters in adamw_params
262
- self.state[p]["use_muon"] = False
263
-
264
  def _calc_flops(self, G, steps):
265
  assert len(G.shape) == 2
266
  M, N = G.shape
@@ -462,100 +479,96 @@ class Muon(torch.optim.Optimizer):
462
  loss = closure()
463
 
464
  for group in self.param_groups:
465
- ############################
466
- # Muon #
467
- ############################
468
-
469
- if "use_muon" not in self.state[group["params"][0]]:
470
- self._init_state()
471
-
472
- params = [p for p in group["params"] if self.state[p]["use_muon"]]
473
- lr = group["lr"]
474
- weight_decay = group["weight_decay"]
475
- momentum = group["momentum"]
476
-
477
- param_dtensors = []
478
- param_tensors = []
479
-
480
- for p in params:
481
- if p is None or p.grad is None:
482
- continue
483
- if isinstance(p.data, DTensor):
484
- if all(
485
- isinstance(placement, Replicate)
486
- for placement in p.placements):
 
 
487
  param_tensors.append(p)
488
  else:
489
- param_dtensors.append(p)
490
- elif isinstance(p.data, torch.Tensor):
491
- param_tensors.append(p)
492
- else:
493
- raise TypeError(
494
- f"Unsupported parameter type: {type(p.data)}")
495
-
496
- if self.debug:
497
- print(
498
- f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
499
- flush=True,
500
- )
501
-
502
- if len(param_dtensors) > 0:
503
- if not dist.is_initialized():
504
- raise RuntimeError(
505
- "Parallel Muon requires torch.distributed to be initialized."
 
 
 
 
506
  )
507
 
508
- self.parallel(
509
- param_dtensors,
510
- group,
511
- lr=lr,
512
- weight_decay=weight_decay,
513
- momentum=momentum,
514
- )
515
-
516
- if len(param_tensors) > 0:
517
- self.base(
518
- param_tensors,
519
- group,
520
- lr=lr,
521
- weight_decay=weight_decay,
522
- momentum=momentum,
523
- )
524
-
525
- ############################
526
- # AdamW backup #
527
- ############################
528
-
529
- params = [
530
- p for p in group["params"] if not self.state[p]["use_muon"]
531
- ]
532
- lr = group["lr"]
533
- beta1, beta2 = group["adamw_betas"]
534
- eps = group["adamw_eps"]
535
- weight_decay = group["weight_decay"]
536
-
537
- for p in params:
538
- g = p.grad
539
- if g is None:
540
- continue
541
- state = self.state[p]
542
- if "step" not in state:
543
- state["step"] = 0
544
- state["moment1"] = torch.zeros_like(g)
545
- state["moment2"] = torch.zeros_like(g)
546
- state["step"] += 1
547
- step = state["step"]
548
- buf1 = state["moment1"]
549
- buf2 = state["moment2"]
550
- buf1.lerp_(g, 1 - beta1)
551
- buf2.lerp_(g.square(), 1 - beta2)
552
-
553
- g = buf1 / (eps + buf2.sqrt())
554
-
555
- bias_correction1 = 1 - beta1**step
556
- bias_correction2 = 1 - beta2**step
557
- scale = bias_correction1 / bias_correction2**0.5
558
- p.data.mul_(1 - lr * weight_decay)
559
- p.data.add_(g, alpha=-lr / scale)
560
 
561
  return loss
 
1
+ import logging
2
  import math
3
+ import types
4
  from dataclasses import dataclass
5
 
6
  import torch
7
  import torch.distributed as dist
8
  from torch.distributed._tensor import DTensor, Replicate, Shard
9
 
10
+ logger = logging.getLogger(__name__)
11
+
12
 
13
  # This code snippet is a modified version adapted from the following GitHub repositories:
14
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
 
179
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
180
 
181
 
182
+ def default_is_muon(name, x):
183
  return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
184
 
185
 
186
+ def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
187
+ return [
188
+ {
189
+ "params": [
190
+ p for n, p in model.named_parameters()
191
+ if (is_muon_func(n, p) and p.requires_grad)
192
+ ],
193
+ "use_muon":
194
+ True
195
+ },
196
+ {
197
+ "params": [
198
+ p for n, p in model.named_parameters()
199
+ if (not is_muon_func(n, p) and p.requires_grad)
200
+ ],
201
+ "use_muon":
202
+ False
203
+ },
204
+ ]
205
+
206
+
207
  class Muon(torch.optim.Optimizer):
208
  """
209
  Muon - MomentUm Orthogonalized by Newton-schulz
 
235
 
236
  def __init__(
237
  self,
238
+ params,
 
239
  lr=1e-3,
240
  momentum=0.95,
241
  nesterov=True,
 
255
  adamw_betas=adamw_betas,
256
  adamw_eps=adamw_eps,
257
  none_grad=none_grad,
258
+ use_muon=True,
259
  )
260
+ error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior."
261
+ instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```"
262
 
263
+ if isinstance(params, types.GeneratorType):
264
+ raise ValueError(error_message.format(idx=0) + instruction_code)
265
+ for _idx, param_group in enumerate(params):
266
+ if param_group.get("use_muon", None) is None:
267
+ raise ValueError(
268
+ error_message.format(idx=_idx) + instruction_code)
269
+
270
+ super().__init__(params, defaults)
271
 
272
  if dist.is_initialized():
273
  self.rank = dist.get_rank()
 
278
  self.compute_stream = torch.cuda.Stream()
279
  self.debug = debug
280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  def _calc_flops(self, G, steps):
282
  assert len(G.shape) == 2
283
  M, N = G.shape
 
479
  loss = closure()
480
 
481
  for group in self.param_groups:
482
+ params = group["params"]
483
+
484
+ if group["use_muon"]:
485
+ ############################
486
+ # Muon #
487
+ ############################
488
+ lr = group["lr"]
489
+ weight_decay = group["weight_decay"]
490
+ momentum = group["momentum"]
491
+
492
+ param_dtensors = []
493
+ param_tensors = []
494
+
495
+ for p in params:
496
+ if p is None or p.grad is None:
497
+ continue
498
+ if isinstance(p.data, DTensor):
499
+ if all(
500
+ isinstance(placement, Replicate)
501
+ for placement in p.placements):
502
+ param_tensors.append(p)
503
+ else:
504
+ param_dtensors.append(p)
505
+ elif isinstance(p.data, torch.Tensor):
506
  param_tensors.append(p)
507
  else:
508
+ raise TypeError(
509
+ f"Unsupported parameter type: {type(p.data)}")
510
+
511
+ if self.debug:
512
+ print(
513
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
514
+ flush=True,
515
+ )
516
+
517
+ if len(param_dtensors) > 0:
518
+ if not dist.is_initialized():
519
+ raise RuntimeError(
520
+ "Parallel Muon requires torch.distributed to be initialized."
521
+ )
522
+
523
+ self.parallel(
524
+ param_dtensors,
525
+ group,
526
+ lr=lr,
527
+ weight_decay=weight_decay,
528
+ momentum=momentum,
529
  )
530
 
531
+ if len(param_tensors) > 0:
532
+ self.base(
533
+ param_tensors,
534
+ group,
535
+ lr=lr,
536
+ weight_decay=weight_decay,
537
+ momentum=momentum,
538
+ )
539
+
540
+ else:
541
+ ############################
542
+ # AdamW backup #
543
+ ############################
544
+
545
+ lr = group["lr"]
546
+ beta1, beta2 = group["adamw_betas"]
547
+ eps = group["adamw_eps"]
548
+ weight_decay = group["weight_decay"]
549
+
550
+ for p in params:
551
+ g = p.grad
552
+ if g is None:
553
+ continue
554
+ state = self.state[p]
555
+ if "step" not in state:
556
+ state["step"] = 0
557
+ state["moment1"] = torch.zeros_like(g)
558
+ state["moment2"] = torch.zeros_like(g)
559
+ state["step"] += 1
560
+ step = state["step"]
561
+ buf1 = state["moment1"]
562
+ buf2 = state["moment2"]
563
+ buf1.lerp_(g, 1 - beta1)
564
+ buf2.lerp_(g.square(), 1 - beta2)
565
+
566
+ g = buf1 / (eps + buf2.sqrt())
567
+
568
+ bias_correction1 = 1 - beta1**step
569
+ bias_correction2 = 1 - beta2**step
570
+ scale = bias_correction1 / bias_correction2**0.5
571
+ p.data.mul_(1 - lr * weight_decay)
572
+ p.data.add_(g, alpha=-lr / scale)
 
 
 
 
 
 
 
 
 
 
573
 
574
  return loss
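
The diff above replaces the model-based constructor with explicit parameter groups tagged by `use_muon`. As a reading aid, here is a minimal sketch of the new call pattern with a local import of the updated module; the `Net` model, tensor shapes, and hyperparameters are illustrative assumptions rather than part of the commit, and a CUDA device is assumed because the optimizer allocates CUDA streams at construction.

```python
# Hedged sketch of the new param-group API (toy model; assumes CUDA and the built extension).
import torch
from muon import Muon, get_default_muon_param_groups


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(100, 32)  # excluded from Muon by name
        self.proj = torch.nn.Linear(32, 32)              # 2-D weight -> Muon group
        self.lm_head = torch.nn.Linear(32, 100)          # excluded from Muon by name

    def forward(self, idx):
        return self.lm_head(self.proj(self.embed_tokens(idx)))


model = Net().cuda()
param_groups = get_default_muon_param_groups(model)  # [{'use_muon': True, ...}, {'use_muon': False, ...}]
optim = Muon(param_groups, lr=1e-3)

loss = model(torch.randint(0, 100, (4, 8), device="cuda")).mean()
loss.backward()
optim.step()  # 2-D projection weights take the Muon path; the rest falls back to AdamW
```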
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_0c12ced_dirty
3
- ops = torch.ops._optimizer_0c12ced_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_0c12ced_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_20250911094409
3
+ ops = torch.ops._optimizer_20250911094409
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_20250911094409::{op_name}"
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_0c12ced_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:80cb3ac21d3afafe368f31318c31a4c6356b53bbc2186ae81b79e1eb3ff441f5
3
- size 1883352
 
 
 
 
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22dc3ab77ab74837126281f79f417c5d55b2cc9885388fd9d3a1c7c824ece2bd
3
+ size 1883360
build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py CHANGED
@@ -1,10 +1,14 @@
 
1
  import math
 
2
  from dataclasses import dataclass
3
 
4
  import torch
5
  import torch.distributed as dist
6
  from torch.distributed._tensor import DTensor, Replicate, Shard
7
 
 
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
10
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
@@ -175,10 +179,31 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
175
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
176
 
177
 
178
- def default_is_muon(x, name):
179
  return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
180
 
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  class Muon(torch.optim.Optimizer):
183
  """
184
  Muon - MomentUm Orthogonalized by Newton-schulz
@@ -210,8 +235,7 @@ class Muon(torch.optim.Optimizer):
210
 
211
  def __init__(
212
  self,
213
- model,
214
- is_muon_func=default_is_muon,
215
  lr=1e-3,
216
  momentum=0.95,
217
  nesterov=True,
@@ -231,11 +255,19 @@ class Muon(torch.optim.Optimizer):
231
  adamw_betas=adamw_betas,
232
  adamw_eps=adamw_eps,
233
  none_grad=none_grad,
 
234
  )
 
 
235
 
236
- super().__init__(model.parameters(), defaults)
237
- self.is_muon_func = is_muon_func
238
- self.model = model
 
 
 
 
 
239
 
240
  if dist.is_initialized():
241
  self.rank = dist.get_rank()
@@ -246,21 +278,6 @@ class Muon(torch.optim.Optimizer):
246
  self.compute_stream = torch.cuda.Stream()
247
  self.debug = debug
248
 
249
- def __setstate__(self, state):
250
- # Sort parameters into those for which we will use Muon, and those for which we will not
251
- super().__setstate__(state)
252
- self._init_state()
253
-
254
- def _init_state(self):
255
- for name, p in self.model.named_parameters():
256
- if self.is_muon_func(p, name):
257
- # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
258
- assert p.ndim == 2, p.ndim
259
- self.state[p]["use_muon"] = True
260
- else:
261
- # Do not use Muon for parameters in adamw_params
262
- self.state[p]["use_muon"] = False
263
-
264
  def _calc_flops(self, G, steps):
265
  assert len(G.shape) == 2
266
  M, N = G.shape
@@ -462,100 +479,96 @@ class Muon(torch.optim.Optimizer):
462
  loss = closure()
463
 
464
  for group in self.param_groups:
465
- ############################
466
- # Muon #
467
- ############################
468
-
469
- if "use_muon" not in self.state[group["params"][0]]:
470
- self._init_state()
471
-
472
- params = [p for p in group["params"] if self.state[p]["use_muon"]]
473
- lr = group["lr"]
474
- weight_decay = group["weight_decay"]
475
- momentum = group["momentum"]
476
-
477
- param_dtensors = []
478
- param_tensors = []
479
-
480
- for p in params:
481
- if p is None or p.grad is None:
482
- continue
483
- if isinstance(p.data, DTensor):
484
- if all(
485
- isinstance(placement, Replicate)
486
- for placement in p.placements):
 
 
487
  param_tensors.append(p)
488
  else:
489
- param_dtensors.append(p)
490
- elif isinstance(p.data, torch.Tensor):
491
- param_tensors.append(p)
492
- else:
493
- raise TypeError(
494
- f"Unsupported parameter type: {type(p.data)}")
495
-
496
- if self.debug:
497
- print(
498
- f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
499
- flush=True,
500
- )
501
-
502
- if len(param_dtensors) > 0:
503
- if not dist.is_initialized():
504
- raise RuntimeError(
505
- "Parallel Muon requires torch.distributed to be initialized."
 
 
 
 
506
  )
507
 
508
- self.parallel(
509
- param_dtensors,
510
- group,
511
- lr=lr,
512
- weight_decay=weight_decay,
513
- momentum=momentum,
514
- )
515
-
516
- if len(param_tensors) > 0:
517
- self.base(
518
- param_tensors,
519
- group,
520
- lr=lr,
521
- weight_decay=weight_decay,
522
- momentum=momentum,
523
- )
524
-
525
- ############################
526
- # AdamW backup #
527
- ############################
528
-
529
- params = [
530
- p for p in group["params"] if not self.state[p]["use_muon"]
531
- ]
532
- lr = group["lr"]
533
- beta1, beta2 = group["adamw_betas"]
534
- eps = group["adamw_eps"]
535
- weight_decay = group["weight_decay"]
536
-
537
- for p in params:
538
- g = p.grad
539
- if g is None:
540
- continue
541
- state = self.state[p]
542
- if "step" not in state:
543
- state["step"] = 0
544
- state["moment1"] = torch.zeros_like(g)
545
- state["moment2"] = torch.zeros_like(g)
546
- state["step"] += 1
547
- step = state["step"]
548
- buf1 = state["moment1"]
549
- buf2 = state["moment2"]
550
- buf1.lerp_(g, 1 - beta1)
551
- buf2.lerp_(g.square(), 1 - beta2)
552
-
553
- g = buf1 / (eps + buf2.sqrt())
554
-
555
- bias_correction1 = 1 - beta1**step
556
- bias_correction2 = 1 - beta2**step
557
- scale = bias_correction1 / bias_correction2**0.5
558
- p.data.mul_(1 - lr * weight_decay)
559
- p.data.add_(g, alpha=-lr / scale)
560
 
561
  return loss
 
1
+ import logging
2
  import math
3
+ import types
4
  from dataclasses import dataclass
5
 
6
  import torch
7
  import torch.distributed as dist
8
  from torch.distributed._tensor import DTensor, Replicate, Shard
9
 
10
+ logger = logging.getLogger(__name__)
11
+
12
 
13
  # This code snippet is a modified version adapted from the following GitHub repositories:
14
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
 
179
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
180
 
181
 
182
+ def default_is_muon(name, x):
183
  return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
184
 
185
 
186
+ def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
187
+ return [
188
+ {
189
+ "params": [
190
+ p for n, p in model.named_parameters()
191
+ if (is_muon_func(n, p) and p.requires_grad)
192
+ ],
193
+ "use_muon":
194
+ True
195
+ },
196
+ {
197
+ "params": [
198
+ p for n, p in model.named_parameters()
199
+ if (not is_muon_func(n, p) and p.requires_grad)
200
+ ],
201
+ "use_muon":
202
+ False
203
+ },
204
+ ]
205
+
206
+
207
  class Muon(torch.optim.Optimizer):
208
  """
209
  Muon - MomentUm Orthogonalized by Newton-schulz
 
235
 
236
  def __init__(
237
  self,
238
+ params,
 
239
  lr=1e-3,
240
  momentum=0.95,
241
  nesterov=True,
 
255
  adamw_betas=adamw_betas,
256
  adamw_eps=adamw_eps,
257
  none_grad=none_grad,
258
+ use_muon=True,
259
  )
260
+ error_message = "The key 'use_muon' is not set in parameter group {idx}. Every parameter group passed to Muon must explicitly set 'use_muon' to True or False."
261
+ instruction_code = "\n\nPlease follow this code snippet:\n```\noptimizer = get_kernel('motif-technologies/optimizer')\nparams = optimizer.muon.get_default_muon_param_groups(model)\noptim = optimizer.Muon(params, ...)\n```"
262
 
263
+ if isinstance(params, types.GeneratorType):
264
+ raise ValueError(error_message.format(idx=0) + instruction_code)
265
+ for _idx, param_group in enumerate(params):
266
+ if param_group.get("use_muon", None) is None:
267
+ raise ValueError(
268
+ error_message.format(idx=_idx) + instruction_code)
269
+
270
+ super().__init__(params, defaults)
271
 
272
  if dist.is_initialized():
273
  self.rank = dist.get_rank()
 
278
  self.compute_stream = torch.cuda.Stream()
279
  self.debug = debug
280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  def _calc_flops(self, G, steps):
282
  assert len(G.shape) == 2
283
  M, N = G.shape
 
479
  loss = closure()
480
 
481
  for group in self.param_groups:
482
+ params = group["params"]
483
+
484
+ if group["use_muon"]:
485
+ ############################
486
+ # Muon #
487
+ ############################
488
+ lr = group["lr"]
489
+ weight_decay = group["weight_decay"]
490
+ momentum = group["momentum"]
491
+
492
+ param_dtensors = []
493
+ param_tensors = []
494
+
495
+ for p in params:
496
+ if p is None or p.grad is None:
497
+ continue
498
+ if isinstance(p.data, DTensor):
499
+ if all(
500
+ isinstance(placement, Replicate)
501
+ for placement in p.placements):
502
+ param_tensors.append(p)
503
+ else:
504
+ param_dtensors.append(p)
505
+ elif isinstance(p.data, torch.Tensor):
506
  param_tensors.append(p)
507
  else:
508
+ raise TypeError(
509
+ f"Unsupported parameter type: {type(p.data)}")
510
+
511
+ if self.debug:
512
+ print(
513
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
514
+ flush=True,
515
+ )
516
+
517
+ if len(param_dtensors) > 0:
518
+ if not dist.is_initialized():
519
+ raise RuntimeError(
520
+ "Parallel Muon requires torch.distributed to be initialized."
521
+ )
522
+
523
+ self.parallel(
524
+ param_dtensors,
525
+ group,
526
+ lr=lr,
527
+ weight_decay=weight_decay,
528
+ momentum=momentum,
529
  )
530
 
531
+ if len(param_tensors) > 0:
532
+ self.base(
533
+ param_tensors,
534
+ group,
535
+ lr=lr,
536
+ weight_decay=weight_decay,
537
+ momentum=momentum,
538
+ )
539
+
540
+ else:
541
+ ############################
542
+ # AdamW backup #
543
+ ############################
544
+
545
+ lr = group["lr"]
546
+ beta1, beta2 = group["adamw_betas"]
547
+ eps = group["adamw_eps"]
548
+ weight_decay = group["weight_decay"]
549
+
550
+ for p in params:
551
+ g = p.grad
552
+ if g is None:
553
+ continue
554
+ state = self.state[p]
555
+ if "step" not in state:
556
+ state["step"] = 0
557
+ state["moment1"] = torch.zeros_like(g)
558
+ state["moment2"] = torch.zeros_like(g)
559
+ state["step"] += 1
560
+ step = state["step"]
561
+ buf1 = state["moment1"]
562
+ buf2 = state["moment2"]
563
+ buf1.lerp_(g, 1 - beta1)
564
+ buf2.lerp_(g.square(), 1 - beta2)
565
+
566
+ g = buf1 / (eps + buf2.sqrt())
567
+
568
+ bias_correction1 = 1 - beta1**step
569
+ bias_correction2 = 1 - beta2**step
570
+ scale = bias_correction1 / bias_correction2**0.5
571
+ p.data.mul_(1 - lr * weight_decay)
572
+ p.data.add_(g, alpha=-lr / scale)
 
 
 
 
 
 
 
 
 
 
573
 
574
  return loss
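
Note that the predicate signature flipped from `default_is_muon(x, name)` to `default_is_muon(name, x)`, and `get_default_muon_param_groups` calls it as `is_muon_func(n, p)`. Any custom predicate should follow the same name-first order; the extra `"norm"` exclusion below is purely an illustrative choice, not something the commit prescribes.

```python
import torch
from muon import Muon, get_default_muon_param_groups


def my_is_muon(name: str, p: torch.Tensor) -> bool:
    # Same contract as default_is_muon: parameter name first, tensor second.
    return p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name and "norm" not in name


# Assuming `model` is an nn.Module as in the earlier sketch:
# param_groups = get_default_muon_param_groups(model, is_muon_func=my_is_muon)
# optim = Muon(param_groups, lr=1e-3)
```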
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_0c12ced_dirty
3
- ops = torch.ops._optimizer_0c12ced_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_0c12ced_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_20250911094409
3
+ ops = torch.ops._optimizer_20250911094409
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_20250911094409::{op_name}"
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_0c12ced_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:32af855517484e2695b6d83c29a03d85fcbaaea559d95cbb62fd9fa67cc3ccac
3
- size 1883352
 
 
 
 
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:62ecfc7e6a1ab0c4ada19ed7aea40fc0a431c4ceb1729666efa98ac0e407f9c8
3
+ size 1883360
build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py CHANGED
@@ -1,10 +1,14 @@
 
1
  import math
 
2
  from dataclasses import dataclass
3
 
4
  import torch
5
  import torch.distributed as dist
6
  from torch.distributed._tensor import DTensor, Replicate, Shard
7
 
 
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
10
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
@@ -175,10 +179,31 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
175
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
176
 
177
 
178
- def default_is_muon(x, name):
179
  return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
180
 
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  class Muon(torch.optim.Optimizer):
183
  """
184
  Muon - MomentUm Orthogonalized by Newton-schulz
@@ -210,8 +235,7 @@ class Muon(torch.optim.Optimizer):
210
 
211
  def __init__(
212
  self,
213
- model,
214
- is_muon_func=default_is_muon,
215
  lr=1e-3,
216
  momentum=0.95,
217
  nesterov=True,
@@ -231,11 +255,19 @@ class Muon(torch.optim.Optimizer):
231
  adamw_betas=adamw_betas,
232
  adamw_eps=adamw_eps,
233
  none_grad=none_grad,
 
234
  )
 
 
235
 
236
- super().__init__(model.parameters(), defaults)
237
- self.is_muon_func = is_muon_func
238
- self.model = model
 
 
 
 
 
239
 
240
  if dist.is_initialized():
241
  self.rank = dist.get_rank()
@@ -246,21 +278,6 @@ class Muon(torch.optim.Optimizer):
246
  self.compute_stream = torch.cuda.Stream()
247
  self.debug = debug
248
 
249
- def __setstate__(self, state):
250
- # Sort parameters into those for which we will use Muon, and those for which we will not
251
- super().__setstate__(state)
252
- self._init_state()
253
-
254
- def _init_state(self):
255
- for name, p in self.model.named_parameters():
256
- if self.is_muon_func(p, name):
257
- # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
258
- assert p.ndim == 2, p.ndim
259
- self.state[p]["use_muon"] = True
260
- else:
261
- # Do not use Muon for parameters in adamw_params
262
- self.state[p]["use_muon"] = False
263
-
264
  def _calc_flops(self, G, steps):
265
  assert len(G.shape) == 2
266
  M, N = G.shape
@@ -462,100 +479,96 @@ class Muon(torch.optim.Optimizer):
462
  loss = closure()
463
 
464
  for group in self.param_groups:
465
- ############################
466
- # Muon #
467
- ############################
468
-
469
- if "use_muon" not in self.state[group["params"][0]]:
470
- self._init_state()
471
-
472
- params = [p for p in group["params"] if self.state[p]["use_muon"]]
473
- lr = group["lr"]
474
- weight_decay = group["weight_decay"]
475
- momentum = group["momentum"]
476
-
477
- param_dtensors = []
478
- param_tensors = []
479
-
480
- for p in params:
481
- if p is None or p.grad is None:
482
- continue
483
- if isinstance(p.data, DTensor):
484
- if all(
485
- isinstance(placement, Replicate)
486
- for placement in p.placements):
 
 
487
  param_tensors.append(p)
488
  else:
489
- param_dtensors.append(p)
490
- elif isinstance(p.data, torch.Tensor):
491
- param_tensors.append(p)
492
- else:
493
- raise TypeError(
494
- f"Unsupported parameter type: {type(p.data)}")
495
-
496
- if self.debug:
497
- print(
498
- f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
499
- flush=True,
500
- )
501
-
502
- if len(param_dtensors) > 0:
503
- if not dist.is_initialized():
504
- raise RuntimeError(
505
- "Parallel Muon requires torch.distributed to be initialized."
 
 
 
 
506
  )
507
 
508
- self.parallel(
509
- param_dtensors,
510
- group,
511
- lr=lr,
512
- weight_decay=weight_decay,
513
- momentum=momentum,
514
- )
515
-
516
- if len(param_tensors) > 0:
517
- self.base(
518
- param_tensors,
519
- group,
520
- lr=lr,
521
- weight_decay=weight_decay,
522
- momentum=momentum,
523
- )
524
-
525
- ############################
526
- # AdamW backup #
527
- ############################
528
-
529
- params = [
530
- p for p in group["params"] if not self.state[p]["use_muon"]
531
- ]
532
- lr = group["lr"]
533
- beta1, beta2 = group["adamw_betas"]
534
- eps = group["adamw_eps"]
535
- weight_decay = group["weight_decay"]
536
-
537
- for p in params:
538
- g = p.grad
539
- if g is None:
540
- continue
541
- state = self.state[p]
542
- if "step" not in state:
543
- state["step"] = 0
544
- state["moment1"] = torch.zeros_like(g)
545
- state["moment2"] = torch.zeros_like(g)
546
- state["step"] += 1
547
- step = state["step"]
548
- buf1 = state["moment1"]
549
- buf2 = state["moment2"]
550
- buf1.lerp_(g, 1 - beta1)
551
- buf2.lerp_(g.square(), 1 - beta2)
552
-
553
- g = buf1 / (eps + buf2.sqrt())
554
-
555
- bias_correction1 = 1 - beta1**step
556
- bias_correction2 = 1 - beta2**step
557
- scale = bias_correction1 / bias_correction2**0.5
558
- p.data.mul_(1 - lr * weight_decay)
559
- p.data.add_(g, alpha=-lr / scale)
560
 
561
  return loss
 
1
+ import logging
2
  import math
3
+ import types
4
  from dataclasses import dataclass
5
 
6
  import torch
7
  import torch.distributed as dist
8
  from torch.distributed._tensor import DTensor, Replicate, Shard
9
 
10
+ logger = logging.getLogger(__name__)
11
+
12
 
13
  # This code snippet is a modified version adapted from the following GitHub repositories:
14
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
 
179
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
180
 
181
 
182
+ def default_is_muon(name, x):
183
  return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
184
 
185
 
186
+ def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
187
+ return [
188
+ {
189
+ "params": [
190
+ p for n, p in model.named_parameters()
191
+ if (is_muon_func(n, p) and p.requires_grad)
192
+ ],
193
+ "use_muon":
194
+ True
195
+ },
196
+ {
197
+ "params": [
198
+ p for n, p in model.named_parameters()
199
+ if (not is_muon_func(n, p) and p.requires_grad)
200
+ ],
201
+ "use_muon":
202
+ False
203
+ },
204
+ ]
205
+
206
+
207
  class Muon(torch.optim.Optimizer):
208
  """
209
  Muon - MomentUm Orthogonalized by Newton-schulz
 
235
 
236
  def __init__(
237
  self,
238
+ params,
 
239
  lr=1e-3,
240
  momentum=0.95,
241
  nesterov=True,
 
255
  adamw_betas=adamw_betas,
256
  adamw_eps=adamw_eps,
257
  none_grad=none_grad,
258
+ use_muon=True,
259
  )
260
+ error_message = "The key 'use_muon' is not set in parameter group {idx}. Every parameter group passed to Muon must explicitly set 'use_muon' to True or False."
261
+ instruction_code = "\n\nPlease follow this code snippet:\n```\noptimizer = get_kernel('motif-technologies/optimizer')\nparams = optimizer.muon.get_default_muon_param_groups(model)\noptim = optimizer.Muon(params, ...)\n```"
262
 
263
+ if isinstance(params, types.GeneratorType):
264
+ raise ValueError(error_message.format(idx=0) + instruction_code)
265
+ for _idx, param_group in enumerate(params):
266
+ if param_group.get("use_muon", None) is None:
267
+ raise ValueError(
268
+ error_message.format(idx=_idx) + instruction_code)
269
+
270
+ super().__init__(params, defaults)
271
 
272
  if dist.is_initialized():
273
  self.rank = dist.get_rank()
 
278
  self.compute_stream = torch.cuda.Stream()
279
  self.debug = debug
280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  def _calc_flops(self, G, steps):
282
  assert len(G.shape) == 2
283
  M, N = G.shape
 
479
  loss = closure()
480
 
481
  for group in self.param_groups:
482
+ params = group["params"]
483
+
484
+ if group["use_muon"]:
485
+ ############################
486
+ # Muon #
487
+ ############################
488
+ lr = group["lr"]
489
+ weight_decay = group["weight_decay"]
490
+ momentum = group["momentum"]
491
+
492
+ param_dtensors = []
493
+ param_tensors = []
494
+
495
+ for p in params:
496
+ if p is None or p.grad is None:
497
+ continue
498
+ if isinstance(p.data, DTensor):
499
+ if all(
500
+ isinstance(placement, Replicate)
501
+ for placement in p.placements):
502
+ param_tensors.append(p)
503
+ else:
504
+ param_dtensors.append(p)
505
+ elif isinstance(p.data, torch.Tensor):
506
  param_tensors.append(p)
507
  else:
508
+ raise TypeError(
509
+ f"Unsupported parameter type: {type(p.data)}")
510
+
511
+ if self.debug:
512
+ print(
513
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
514
+ flush=True,
515
+ )
516
+
517
+ if len(param_dtensors) > 0:
518
+ if not dist.is_initialized():
519
+ raise RuntimeError(
520
+ "Parallel Muon requires torch.distributed to be initialized."
521
+ )
522
+
523
+ self.parallel(
524
+ param_dtensors,
525
+ group,
526
+ lr=lr,
527
+ weight_decay=weight_decay,
528
+ momentum=momentum,
529
  )
530
 
531
+ if len(param_tensors) > 0:
532
+ self.base(
533
+ param_tensors,
534
+ group,
535
+ lr=lr,
536
+ weight_decay=weight_decay,
537
+ momentum=momentum,
538
+ )
539
+
540
+ else:
541
+ ############################
542
+ # AdamW backup #
543
+ ############################
544
+
545
+ lr = group["lr"]
546
+ beta1, beta2 = group["adamw_betas"]
547
+ eps = group["adamw_eps"]
548
+ weight_decay = group["weight_decay"]
549
+
550
+ for p in params:
551
+ g = p.grad
552
+ if g is None:
553
+ continue
554
+ state = self.state[p]
555
+ if "step" not in state:
556
+ state["step"] = 0
557
+ state["moment1"] = torch.zeros_like(g)
558
+ state["moment2"] = torch.zeros_like(g)
559
+ state["step"] += 1
560
+ step = state["step"]
561
+ buf1 = state["moment1"]
562
+ buf2 = state["moment2"]
563
+ buf1.lerp_(g, 1 - beta1)
564
+ buf2.lerp_(g.square(), 1 - beta2)
565
+
566
+ g = buf1 / (eps + buf2.sqrt())
567
+
568
+ bias_correction1 = 1 - beta1**step
569
+ bias_correction2 = 1 - beta2**step
570
+ scale = bias_correction1 / bias_correction2**0.5
571
+ p.data.mul_(1 - lr * weight_decay)
572
+ p.data.add_(g, alpha=-lr / scale)
 
 
 
 
 
 
 
 
 
 
573
 
574
  return loss
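
Because `model.parameters()` is a generator and carries no `use_muon` flag, the constructor now rejects it (and any group missing the flag) up front instead of silently optimizing everything with Muon. A minimal sketch of the failure mode; the toy model is an assumption for illustration.

```python
import torch
from muon import Muon, get_default_muon_param_groups

model = torch.nn.Sequential(torch.nn.Linear(8, 8))  # toy model

# Passing the raw generator raises before any optimizer state is created.
try:
    Muon(model.parameters(), lr=1e-3)
except ValueError as err:
    print(err)  # the 'use_muon' message plus the suggested code snippet

# Supported path: explicit groups tagged with use_muon.
# (Constructing Muon itself assumes a CUDA device, since __init__ allocates CUDA streams.)
# optim = Muon(get_default_muon_param_groups(model), lr=1e-3)
```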
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_0c12ced_dirty
3
- ops = torch.ops._optimizer_0c12ced_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_0c12ced_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_20250911094409
3
+ ops = torch.ops._optimizer_20250911094409
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_20250911094409::{op_name}"
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_0c12ced_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:2dd72f3b9f513dc8bd0724fede9b668761b1d701dfdf3a294979706d803b0800
3
- size 1750000
 
 
 
 
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37e389c650fc1fcbc9fbd68f1e7c1a768b08e90509fd8a5d87879655726f2db2
3
+ size 1750040
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py CHANGED
@@ -1,10 +1,14 @@
 
1
  import math
 
2
  from dataclasses import dataclass
3
 
4
  import torch
5
  import torch.distributed as dist
6
  from torch.distributed._tensor import DTensor, Replicate, Shard
7
 
 
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
10
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
@@ -175,10 +179,31 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
175
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
176
 
177
 
178
- def default_is_muon(x, name):
179
  return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
180
 
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  class Muon(torch.optim.Optimizer):
183
  """
184
  Muon - MomentUm Orthogonalized by Newton-schulz
@@ -210,8 +235,7 @@ class Muon(torch.optim.Optimizer):
210
 
211
  def __init__(
212
  self,
213
- model,
214
- is_muon_func=default_is_muon,
215
  lr=1e-3,
216
  momentum=0.95,
217
  nesterov=True,
@@ -231,11 +255,19 @@ class Muon(torch.optim.Optimizer):
231
  adamw_betas=adamw_betas,
232
  adamw_eps=adamw_eps,
233
  none_grad=none_grad,
 
234
  )
 
 
235
 
236
- super().__init__(model.parameters(), defaults)
237
- self.is_muon_func = is_muon_func
238
- self.model = model
 
 
 
 
 
239
 
240
  if dist.is_initialized():
241
  self.rank = dist.get_rank()
@@ -246,21 +278,6 @@ class Muon(torch.optim.Optimizer):
246
  self.compute_stream = torch.cuda.Stream()
247
  self.debug = debug
248
 
249
- def __setstate__(self, state):
250
- # Sort parameters into those for which we will use Muon, and those for which we will not
251
- super().__setstate__(state)
252
- self._init_state()
253
-
254
- def _init_state(self):
255
- for name, p in self.model.named_parameters():
256
- if self.is_muon_func(p, name):
257
- # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
258
- assert p.ndim == 2, p.ndim
259
- self.state[p]["use_muon"] = True
260
- else:
261
- # Do not use Muon for parameters in adamw_params
262
- self.state[p]["use_muon"] = False
263
-
264
  def _calc_flops(self, G, steps):
265
  assert len(G.shape) == 2
266
  M, N = G.shape
@@ -462,100 +479,96 @@ class Muon(torch.optim.Optimizer):
462
  loss = closure()
463
 
464
  for group in self.param_groups:
465
- ############################
466
- # Muon #
467
- ############################
468
-
469
- if "use_muon" not in self.state[group["params"][0]]:
470
- self._init_state()
471
-
472
- params = [p for p in group["params"] if self.state[p]["use_muon"]]
473
- lr = group["lr"]
474
- weight_decay = group["weight_decay"]
475
- momentum = group["momentum"]
476
-
477
- param_dtensors = []
478
- param_tensors = []
479
-
480
- for p in params:
481
- if p is None or p.grad is None:
482
- continue
483
- if isinstance(p.data, DTensor):
484
- if all(
485
- isinstance(placement, Replicate)
486
- for placement in p.placements):
 
 
487
  param_tensors.append(p)
488
  else:
489
- param_dtensors.append(p)
490
- elif isinstance(p.data, torch.Tensor):
491
- param_tensors.append(p)
492
- else:
493
- raise TypeError(
494
- f"Unsupported parameter type: {type(p.data)}")
495
-
496
- if self.debug:
497
- print(
498
- f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
499
- flush=True,
500
- )
501
-
502
- if len(param_dtensors) > 0:
503
- if not dist.is_initialized():
504
- raise RuntimeError(
505
- "Parallel Muon requires torch.distributed to be initialized."
 
 
 
 
506
  )
507
 
508
- self.parallel(
509
- param_dtensors,
510
- group,
511
- lr=lr,
512
- weight_decay=weight_decay,
513
- momentum=momentum,
514
- )
515
-
516
- if len(param_tensors) > 0:
517
- self.base(
518
- param_tensors,
519
- group,
520
- lr=lr,
521
- weight_decay=weight_decay,
522
- momentum=momentum,
523
- )
524
-
525
- ############################
526
- # AdamW backup #
527
- ############################
528
-
529
- params = [
530
- p for p in group["params"] if not self.state[p]["use_muon"]
531
- ]
532
- lr = group["lr"]
533
- beta1, beta2 = group["adamw_betas"]
534
- eps = group["adamw_eps"]
535
- weight_decay = group["weight_decay"]
536
-
537
- for p in params:
538
- g = p.grad
539
- if g is None:
540
- continue
541
- state = self.state[p]
542
- if "step" not in state:
543
- state["step"] = 0
544
- state["moment1"] = torch.zeros_like(g)
545
- state["moment2"] = torch.zeros_like(g)
546
- state["step"] += 1
547
- step = state["step"]
548
- buf1 = state["moment1"]
549
- buf2 = state["moment2"]
550
- buf1.lerp_(g, 1 - beta1)
551
- buf2.lerp_(g.square(), 1 - beta2)
552
-
553
- g = buf1 / (eps + buf2.sqrt())
554
-
555
- bias_correction1 = 1 - beta1**step
556
- bias_correction2 = 1 - beta2**step
557
- scale = bias_correction1 / bias_correction2**0.5
558
- p.data.mul_(1 - lr * weight_decay)
559
- p.data.add_(g, alpha=-lr / scale)
560
 
561
  return loss
 
1
+ import logging
2
  import math
3
+ import types
4
  from dataclasses import dataclass
5
 
6
  import torch
7
  import torch.distributed as dist
8
  from torch.distributed._tensor import DTensor, Replicate, Shard
9
 
10
+ logger = logging.getLogger(__name__)
11
+
12
 
13
  # This code snippet is a modified version adapted from the following GitHub repositories:
14
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
 
179
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
180
 
181
 
182
+ def default_is_muon(name, x):
183
  return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
184
 
185
 
186
+ def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
187
+ return [
188
+ {
189
+ "params": [
190
+ p for n, p in model.named_parameters()
191
+ if (is_muon_func(n, p) and p.requires_grad)
192
+ ],
193
+ "use_muon":
194
+ True
195
+ },
196
+ {
197
+ "params": [
198
+ p for n, p in model.named_parameters()
199
+ if (not is_muon_func(n, p) and p.requires_grad)
200
+ ],
201
+ "use_muon":
202
+ False
203
+ },
204
+ ]
205
+
206
+
207
  class Muon(torch.optim.Optimizer):
208
  """
209
  Muon - MomentUm Orthogonalized by Newton-schulz
 
235
 
236
  def __init__(
237
  self,
238
+ params,
 
239
  lr=1e-3,
240
  momentum=0.95,
241
  nesterov=True,
 
255
  adamw_betas=adamw_betas,
256
  adamw_eps=adamw_eps,
257
  none_grad=none_grad,
258
+ use_muon=True,
259
  )
260
+ error_message = "The key 'use_muon' is not set in parameter group {idx}. Every parameter group passed to Muon must explicitly set 'use_muon' to True or False."
261
+ instruction_code = "\n\nPlease follow this code snippet:\n```\noptimizer = get_kernel('motif-technologies/optimizer')\nparams = optimizer.muon.get_default_muon_param_groups(model)\noptim = optimizer.Muon(params, ...)\n```"
262
 
263
+ if isinstance(params, types.GeneratorType):
264
+ raise ValueError(error_message.format(idx=0) + instruction_code)
265
+ for _idx, param_group in enumerate(params):
266
+ if param_group.get("use_muon", None) is None:
267
+ raise ValueError(
268
+ error_message.format(idx=_idx) + instruction_code)
269
+
270
+ super().__init__(params, defaults)
271
 
272
  if dist.is_initialized():
273
  self.rank = dist.get_rank()
 
278
  self.compute_stream = torch.cuda.Stream()
279
  self.debug = debug
280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  def _calc_flops(self, G, steps):
282
  assert len(G.shape) == 2
283
  M, N = G.shape
 
479
  loss = closure()
480
 
481
  for group in self.param_groups:
482
+ params = group["params"]
483
+
484
+ if group["use_muon"]:
485
+ ############################
486
+ # Muon #
487
+ ############################
488
+ lr = group["lr"]
489
+ weight_decay = group["weight_decay"]
490
+ momentum = group["momentum"]
491
+
492
+ param_dtensors = []
493
+ param_tensors = []
494
+
495
+ for p in params:
496
+ if p is None or p.grad is None:
497
+ continue
498
+ if isinstance(p.data, DTensor):
499
+ if all(
500
+ isinstance(placement, Replicate)
501
+ for placement in p.placements):
502
+ param_tensors.append(p)
503
+ else:
504
+ param_dtensors.append(p)
505
+ elif isinstance(p.data, torch.Tensor):
506
  param_tensors.append(p)
507
  else:
508
+ raise TypeError(
509
+ f"Unsupported parameter type: {type(p.data)}")
510
+
511
+ if self.debug:
512
+ print(
513
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
514
+ flush=True,
515
+ )
516
+
517
+ if len(param_dtensors) > 0:
518
+ if not dist.is_initialized():
519
+ raise RuntimeError(
520
+ "Parallel Muon requires torch.distributed to be initialized."
521
+ )
522
+
523
+ self.parallel(
524
+ param_dtensors,
525
+ group,
526
+ lr=lr,
527
+ weight_decay=weight_decay,
528
+ momentum=momentum,
529
  )
530
 
531
+ if len(param_tensors) > 0:
532
+ self.base(
533
+ param_tensors,
534
+ group,
535
+ lr=lr,
536
+ weight_decay=weight_decay,
537
+ momentum=momentum,
538
+ )
539
+
540
+ else:
541
+ ############################
542
+ # AdamW backup #
543
+ ############################
544
+
545
+ lr = group["lr"]
546
+ beta1, beta2 = group["adamw_betas"]
547
+ eps = group["adamw_eps"]
548
+ weight_decay = group["weight_decay"]
549
+
550
+ for p in params:
551
+ g = p.grad
552
+ if g is None:
553
+ continue
554
+ state = self.state[p]
555
+ if "step" not in state:
556
+ state["step"] = 0
557
+ state["moment1"] = torch.zeros_like(g)
558
+ state["moment2"] = torch.zeros_like(g)
559
+ state["step"] += 1
560
+ step = state["step"]
561
+ buf1 = state["moment1"]
562
+ buf2 = state["moment2"]
563
+ buf1.lerp_(g, 1 - beta1)
564
+ buf2.lerp_(g.square(), 1 - beta2)
565
+
566
+ g = buf1 / (eps + buf2.sqrt())
567
+
568
+ bias_correction1 = 1 - beta1**step
569
+ bias_correction2 = 1 - beta2**step
570
+ scale = bias_correction1 / bias_correction2**0.5
571
+ p.data.mul_(1 - lr * weight_decay)
572
+ p.data.add_(g, alpha=-lr / scale)
 
 
 
 
 
 
 
 
 
 
573
 
574
  return loss
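
For groups with `use_muon=False`, the branch above is a hand-rolled AdamW: `lerp_` maintains the exponential moments, the two bias corrections are folded into a single scale, and weight decay is applied in decoupled form. Transcribed directly from that code (the symbols map to `lr`, `weight_decay`, `adamw_betas`, `adamw_eps`, and `step`), the per-parameter update is:

```latex
% Transcription of the AdamW fallback above; note that \varepsilon sits next to \sqrt{v_t}
% before bias correction, a slight difference from torch.optim.AdamW's placement.
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad
v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2, \\
p &\leftarrow (1 - \eta\lambda)\, p \;-\; \eta\,\frac{\sqrt{1-\beta_2^{\,t}}}{1-\beta_1^{\,t}}\cdot\frac{m_t}{\varepsilon + \sqrt{v_t}}.
\end{aligned}
```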
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_0c12ced_dirty
3
- ops = torch.ops._optimizer_0c12ced_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_0c12ced_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_20250911094409
3
+ ops = torch.ops._optimizer_20250911094409
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_20250911094409::{op_name}"
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_0c12ced_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:2a49b0225ecf27b33bbbe55936811ecf443ce97be97ccb7237b3b66eb46c0ad8
3
- size 1750088
 
 
 
 
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e62682b711f002505bb17c170b2bb233f8d389510ff8e2e0a753ee96d11d0746
3
+ size 1750128
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py CHANGED
@@ -1,10 +1,14 @@
 
1
  import math
 
2
  from dataclasses import dataclass
3
 
4
  import torch
5
  import torch.distributed as dist
6
  from torch.distributed._tensor import DTensor, Replicate, Shard
7
 
 
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
10
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
@@ -175,10 +179,31 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
175
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
176
 
177
 
178
- def default_is_muon(x, name):
179
  return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
180
 
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  class Muon(torch.optim.Optimizer):
183
  """
184
  Muon - MomentUm Orthogonalized by Newton-schulz
@@ -210,8 +235,7 @@ class Muon(torch.optim.Optimizer):
210
 
211
  def __init__(
212
  self,
213
- model,
214
- is_muon_func=default_is_muon,
215
  lr=1e-3,
216
  momentum=0.95,
217
  nesterov=True,
@@ -231,11 +255,19 @@ class Muon(torch.optim.Optimizer):
231
  adamw_betas=adamw_betas,
232
  adamw_eps=adamw_eps,
233
  none_grad=none_grad,
 
234
  )
 
 
235
 
236
- super().__init__(model.parameters(), defaults)
237
- self.is_muon_func = is_muon_func
238
- self.model = model
 
 
 
 
 
239
 
240
  if dist.is_initialized():
241
  self.rank = dist.get_rank()
@@ -246,21 +278,6 @@ class Muon(torch.optim.Optimizer):
246
  self.compute_stream = torch.cuda.Stream()
247
  self.debug = debug
248
 
249
- def __setstate__(self, state):
250
- # Sort parameters into those for which we will use Muon, and those for which we will not
251
- super().__setstate__(state)
252
- self._init_state()
253
-
254
- def _init_state(self):
255
- for name, p in self.model.named_parameters():
256
- if self.is_muon_func(p, name):
257
- # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
258
- assert p.ndim == 2, p.ndim
259
- self.state[p]["use_muon"] = True
260
- else:
261
- # Do not use Muon for parameters in adamw_params
262
- self.state[p]["use_muon"] = False
263
-
264
  def _calc_flops(self, G, steps):
265
  assert len(G.shape) == 2
266
  M, N = G.shape
@@ -462,100 +479,96 @@ class Muon(torch.optim.Optimizer):
462
  loss = closure()
463
 
464
  for group in self.param_groups:
465
- ############################
466
- # Muon #
467
- ############################
468
-
469
- if "use_muon" not in self.state[group["params"][0]]:
470
- self._init_state()
471
-
472
- params = [p for p in group["params"] if self.state[p]["use_muon"]]
473
- lr = group["lr"]
474
- weight_decay = group["weight_decay"]
475
- momentum = group["momentum"]
476
-
477
- param_dtensors = []
478
- param_tensors = []
479
-
480
- for p in params:
481
- if p is None or p.grad is None:
482
- continue
483
- if isinstance(p.data, DTensor):
484
- if all(
485
- isinstance(placement, Replicate)
486
- for placement in p.placements):
 
 
487
  param_tensors.append(p)
488
  else:
489
- param_dtensors.append(p)
490
- elif isinstance(p.data, torch.Tensor):
491
- param_tensors.append(p)
492
- else:
493
- raise TypeError(
494
- f"Unsupported parameter type: {type(p.data)}")
495
-
496
- if self.debug:
497
- print(
498
- f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
499
- flush=True,
500
- )
501
-
502
- if len(param_dtensors) > 0:
503
- if not dist.is_initialized():
504
- raise RuntimeError(
505
- "Parallel Muon requires torch.distributed to be initialized."
 
 
 
 
506
  )
507
 
508
- self.parallel(
509
- param_dtensors,
510
- group,
511
- lr=lr,
512
- weight_decay=weight_decay,
513
- momentum=momentum,
514
- )
515
-
516
- if len(param_tensors) > 0:
517
- self.base(
518
- param_tensors,
519
- group,
520
- lr=lr,
521
- weight_decay=weight_decay,
522
- momentum=momentum,
523
- )
524
-
525
- ############################
526
- # AdamW backup #
527
- ############################
528
-
529
- params = [
530
- p for p in group["params"] if not self.state[p]["use_muon"]
531
- ]
532
- lr = group["lr"]
533
- beta1, beta2 = group["adamw_betas"]
534
- eps = group["adamw_eps"]
535
- weight_decay = group["weight_decay"]
536
-
537
- for p in params:
538
- g = p.grad
539
- if g is None:
540
- continue
541
- state = self.state[p]
542
- if "step" not in state:
543
- state["step"] = 0
544
- state["moment1"] = torch.zeros_like(g)
545
- state["moment2"] = torch.zeros_like(g)
546
- state["step"] += 1
547
- step = state["step"]
548
- buf1 = state["moment1"]
549
- buf2 = state["moment2"]
550
- buf1.lerp_(g, 1 - beta1)
551
- buf2.lerp_(g.square(), 1 - beta2)
552
-
553
- g = buf1 / (eps + buf2.sqrt())
554
-
555
- bias_correction1 = 1 - beta1**step
556
- bias_correction2 = 1 - beta2**step
557
- scale = bias_correction1 / bias_correction2**0.5
558
- p.data.mul_(1 - lr * weight_decay)
559
- p.data.add_(g, alpha=-lr / scale)
560
 
561
  return loss
 
1
+ import logging
2
  import math
3
+ import types
4
  from dataclasses import dataclass
5
 
6
  import torch
7
  import torch.distributed as dist
8
  from torch.distributed._tensor import DTensor, Replicate, Shard
9
 
10
+ logger = logging.getLogger(__name__)
11
+
12
 
13
  # This code snippet is a modified version adapted from the following GitHub repositories:
14
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
 
179
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
180
 
181
 
182
+ def default_is_muon(name, x):
183
  return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
184
 
185
 
186
+ def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
187
+ return [
188
+ {
189
+ "params": [
190
+ p for n, p in model.named_parameters()
191
+ if (is_muon_func(n, p) and p.requires_grad)
192
+ ],
193
+ "use_muon":
194
+ True
195
+ },
196
+ {
197
+ "params": [
198
+ p for n, p in model.named_parameters()
199
+ if (not is_muon_func(n, p) and p.requires_grad)
200
+ ],
201
+ "use_muon":
202
+ False
203
+ },
204
+ ]
205
+
206
+
207
  class Muon(torch.optim.Optimizer):
208
  """
209
  Muon - MomentUm Orthogonalized by Newton-schulz
 
235
 
236
  def __init__(
237
  self,
238
+ params,
 
239
  lr=1e-3,
240
  momentum=0.95,
241
  nesterov=True,
 
255
  adamw_betas=adamw_betas,
256
  adamw_eps=adamw_eps,
257
  none_grad=none_grad,
258
+ use_muon=True,
259
  )
260
+ error_message = "The key 'use_muon' is not set in parameter group {idx}. Every parameter group passed to Muon must explicitly set 'use_muon' to True or False."
261
+ instruction_code = "\n\nPlease follow this code snippet:\n```\noptimizer = get_kernel('motif-technologies/optimizer')\nparams = optimizer.muon.get_default_muon_param_groups(model)\noptim = optimizer.Muon(params, ...)\n```"
262
 
263
+ if isinstance(params, types.GeneratorType):
264
+ raise ValueError(error_message.format(idx=0) + instruction_code)
265
+ for _idx, param_group in enumerate(params):
266
+ if param_group.get("use_muon", None) is None:
267
+ raise ValueError(
268
+ error_message.format(idx=_idx) + instruction_code)
269
+
270
+ super().__init__(params, defaults)
271
 
272
  if dist.is_initialized():
273
  self.rank = dist.get_rank()
 
278
  self.compute_stream = torch.cuda.Stream()
279
  self.debug = debug
280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  def _calc_flops(self, G, steps):
282
  assert len(G.shape) == 2
283
  M, N = G.shape
 
479
  loss = closure()
480
 
481
  for group in self.param_groups:
482
+ params = group["params"]
483
+
484
+ if group["use_muon"]:
485
+ ############################
486
+ # Muon #
487
+ ############################
488
+ lr = group["lr"]
489
+ weight_decay = group["weight_decay"]
490
+ momentum = group["momentum"]
491
+
492
+ param_dtensors = []
493
+ param_tensors = []
494
+
495
+ for p in params:
496
+ if p is None or p.grad is None:
497
+ continue
498
+ if isinstance(p.data, DTensor):
499
+ if all(
500
+ isinstance(placement, Replicate)
501
+ for placement in p.placements):
502
+ param_tensors.append(p)
503
+ else:
504
+ param_dtensors.append(p)
505
+ elif isinstance(p.data, torch.Tensor):
506
  param_tensors.append(p)
507
  else:
508
+ raise TypeError(
509
+ f"Unsupported parameter type: {type(p.data)}")
510
+
511
+ if self.debug:
512
+ print(
513
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
514
+ flush=True,
515
+ )
516
+
517
+ if len(param_dtensors) > 0:
518
+ if not dist.is_initialized():
519
+ raise RuntimeError(
520
+ "Parallel Muon requires torch.distributed to be initialized."
521
+ )
522
+
523
+ self.parallel(
524
+ param_dtensors,
525
+ group,
526
+ lr=lr,
527
+ weight_decay=weight_decay,
528
+ momentum=momentum,
529
  )
530
 
531
+ if len(param_tensors) > 0:
532
+ self.base(
533
+ param_tensors,
534
+ group,
535
+ lr=lr,
536
+ weight_decay=weight_decay,
537
+ momentum=momentum,
538
+ )
539
+
540
+ else:
541
+ ############################
542
+ # AdamW backup #
543
+ ############################
544
+
545
+ lr = group["lr"]
546
+ beta1, beta2 = group["adamw_betas"]
547
+ eps = group["adamw_eps"]
548
+ weight_decay = group["weight_decay"]
549
+
550
+ for p in params:
551
+ g = p.grad
552
+ if g is None:
553
+ continue
554
+ state = self.state[p]
555
+ if "step" not in state:
556
+ state["step"] = 0
557
+ state["moment1"] = torch.zeros_like(g)
558
+ state["moment2"] = torch.zeros_like(g)
559
+ state["step"] += 1
560
+ step = state["step"]
561
+ buf1 = state["moment1"]
562
+ buf2 = state["moment2"]
563
+ buf1.lerp_(g, 1 - beta1)
564
+ buf2.lerp_(g.square(), 1 - beta2)
565
+
566
+ g = buf1 / (eps + buf2.sqrt())
567
+
568
+ bias_correction1 = 1 - beta1**step
569
+ bias_correction2 = 1 - beta2**step
570
+ scale = bias_correction1 / bias_correction2**0.5
571
+ p.data.mul_(1 - lr * weight_decay)
572
+ p.data.add_(g, alpha=-lr / scale)
 
 
 
 
 
 
 
 
 
 
573
 
574
  return loss
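
With the `__setstate__`/`_init_state` hooks removed, the `use_muon` flag now lives in the param groups themselves, so checkpointing goes through the standard `torch.optim` state-dict round trip. A hedged sketch, assuming `model` and `optim` built as in the first example and a checkpoint path of your choosing:

```python
import torch
from muon import Muon, get_default_muon_param_groups

# Save: param groups (including use_muon) and per-parameter state serialize together.
torch.save({"model": model.state_dict(), "optim": optim.state_dict()}, "checkpoint.pt")

# Load: rebuild the optimizer from freshly derived param groups, then restore its state.
ckpt = torch.load("checkpoint.pt")
model.load_state_dict(ckpt["model"])
optim = Muon(get_default_muon_param_groups(model), lr=1e-3)
optim.load_state_dict(ckpt["optim"])
```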
test/test_muon/test.py CHANGED
@@ -2,7 +2,7 @@ import logging
2
 
3
  import torch
4
  import torch.distributed as dist
5
- from muon import Muon
6
  from torch.distributed.fsdp import FSDPModule, fully_shard
7
  from torch.distributed.tensor import DTensor
8
  from torch.distributed.tensor.placement_types import Replicate
@@ -54,7 +54,8 @@ def load_model(fsdp: bool) -> torch.nn.Module:
54
 
55
  def run_muon(fsdp: bool) -> torch.nn.Module:
56
  model = load_model(fsdp=fsdp)
57
- optim = Muon(model)
 
58
  optim.step()
59
 
60
  return model
 
2
 
3
  import torch
4
  import torch.distributed as dist
5
+ from muon import Muon, get_default_muon_param_groups
6
  from torch.distributed.fsdp import FSDPModule, fully_shard
7
  from torch.distributed.tensor import DTensor
8
  from torch.distributed.tensor.placement_types import Replicate
 
54
 
55
  def run_muon(fsdp: bool) -> torch.nn.Module:
56
  model = load_model(fsdp=fsdp)
57
+ params = get_default_muon_param_groups(model)
58
+ optim = Muon(params=params)
59
  optim.step()
60
 
61
  return model
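
The updated test builds the groups with the helper before constructing the optimizer. As a quick sanity check on that default split (not part of the test suite), the sketch below verifies that plain 2-D weights land in the `use_muon=True` group while embedding and head weights fall into the AdamW group; it reuses the toy `Net` model from the first sketch.

```python
from muon import get_default_muon_param_groups

groups = get_default_muon_param_groups(model)  # model as in the first sketch
muon_group, adamw_group = groups               # [0] use_muon=True, [1] use_muon=False

muon_ids = {id(p) for p in muon_group["params"]}
adamw_ids = {id(p) for p in adamw_group["params"]}

assert id(model.proj.weight) in muon_ids           # 2-D and not excluded by name
assert id(model.proj.bias) in adamw_ids            # 1-D -> AdamW group
assert id(model.embed_tokens.weight) in adamw_ids  # excluded by name
assert id(model.lm_head.weight) in adamw_ids       # excluded by name
```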
torch-ext/optimizer/muon.py CHANGED
@@ -1,10 +1,14 @@
 
1
  import math
 
2
  from dataclasses import dataclass
3
 
4
  import torch
5
  import torch.distributed as dist
6
  from torch.distributed._tensor import DTensor, Replicate, Shard
7
 
 
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
10
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
@@ -175,10 +179,31 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
175
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
176
 
177
 
178
- def default_is_muon(x, name):
179
  return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
180
 
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  class Muon(torch.optim.Optimizer):
183
  """
184
  Muon - MomentUm Orthogonalized by Newton-schulz
@@ -210,8 +235,7 @@ class Muon(torch.optim.Optimizer):
210
 
211
  def __init__(
212
  self,
213
- model,
214
- is_muon_func=default_is_muon,
215
  lr=1e-3,
216
  momentum=0.95,
217
  nesterov=True,
@@ -231,11 +255,19 @@ class Muon(torch.optim.Optimizer):
231
  adamw_betas=adamw_betas,
232
  adamw_eps=adamw_eps,
233
  none_grad=none_grad,
 
234
  )
 
 
235
 
236
- super().__init__(model.parameters(), defaults)
237
- self.is_muon_func = is_muon_func
238
- self.model = model
 
 
 
 
 
239
 
240
  if dist.is_initialized():
241
  self.rank = dist.get_rank()
@@ -246,21 +278,6 @@ class Muon(torch.optim.Optimizer):
246
  self.compute_stream = torch.cuda.Stream()
247
  self.debug = debug
248
 
249
- def __setstate__(self, state):
250
- # Sort parameters into those for which we will use Muon, and those for which we will not
251
- super().__setstate__(state)
252
- self._init_state()
253
-
254
- def _init_state(self):
255
- for name, p in self.model.named_parameters():
256
- if self.is_muon_func(p, name):
257
- # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
258
- assert p.ndim == 2, p.ndim
259
- self.state[p]["use_muon"] = True
260
- else:
261
- # Do not use Muon for parameters in adamw_params
262
- self.state[p]["use_muon"] = False
263
-
264
  def _calc_flops(self, G, steps):
265
  assert len(G.shape) == 2
266
  M, N = G.shape
@@ -462,100 +479,96 @@ class Muon(torch.optim.Optimizer):
462
  loss = closure()
463
 
464
  for group in self.param_groups:
465
- ############################
466
- # Muon #
467
- ############################
468
-
469
- if "use_muon" not in self.state[group["params"][0]]:
470
- self._init_state()
471
-
472
- params = [p for p in group["params"] if self.state[p]["use_muon"]]
473
- lr = group["lr"]
474
- weight_decay = group["weight_decay"]
475
- momentum = group["momentum"]
476
-
477
- param_dtensors = []
478
- param_tensors = []
479
-
480
- for p in params:
481
- if p is None or p.grad is None:
482
- continue
483
- if isinstance(p.data, DTensor):
484
- if all(
485
- isinstance(placement, Replicate)
486
- for placement in p.placements):
 
 
487
  param_tensors.append(p)
488
  else:
489
- param_dtensors.append(p)
490
- elif isinstance(p.data, torch.Tensor):
491
- param_tensors.append(p)
492
- else:
493
- raise TypeError(
494
- f"Unsupported parameter type: {type(p.data)}")
495
-
496
- if self.debug:
497
- print(
498
- f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
499
- flush=True,
500
- )
501
-
502
- if len(param_dtensors) > 0:
503
- if not dist.is_initialized():
504
- raise RuntimeError(
505
- "Parallel Muon requires torch.distributed to be initialized."
 
 
 
 
506
  )
507
 
508
- self.parallel(
509
- param_dtensors,
510
- group,
511
- lr=lr,
512
- weight_decay=weight_decay,
513
- momentum=momentum,
514
- )
515
-
516
- if len(param_tensors) > 0:
517
- self.base(
518
- param_tensors,
519
- group,
520
- lr=lr,
521
- weight_decay=weight_decay,
522
- momentum=momentum,
523
- )
524
-
525
- ############################
526
- # AdamW backup #
527
- ############################
528
-
529
- params = [
530
- p for p in group["params"] if not self.state[p]["use_muon"]
531
- ]
532
- lr = group["lr"]
533
- beta1, beta2 = group["adamw_betas"]
534
- eps = group["adamw_eps"]
535
- weight_decay = group["weight_decay"]
536
-
537
- for p in params:
538
- g = p.grad
539
- if g is None:
540
- continue
541
- state = self.state[p]
542
- if "step" not in state:
543
- state["step"] = 0
544
- state["moment1"] = torch.zeros_like(g)
545
- state["moment2"] = torch.zeros_like(g)
546
- state["step"] += 1
547
- step = state["step"]
548
- buf1 = state["moment1"]
549
- buf2 = state["moment2"]
550
- buf1.lerp_(g, 1 - beta1)
551
- buf2.lerp_(g.square(), 1 - beta2)
552
-
553
- g = buf1 / (eps + buf2.sqrt())
554
-
555
- bias_correction1 = 1 - beta1**step
556
- bias_correction2 = 1 - beta2**step
557
- scale = bias_correction1 / bias_correction2**0.5
558
- p.data.mul_(1 - lr * weight_decay)
559
- p.data.add_(g, alpha=-lr / scale)
560
 
561
  return loss
 
1
+ import logging
2
  import math
3
+ import types
4
  from dataclasses import dataclass
5
 
6
  import torch
7
  import torch.distributed as dist
8
  from torch.distributed._tensor import DTensor, Replicate, Shard
9
 
10
+ logger = logging.getLogger(__name__)
11
+
12
 
13
  # This code snippet is a modified version adapted from the following GitHub repositories:
14
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
 
179
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
180
 
181
 
182
+ def default_is_muon(name, x):
183
  return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
184
 
185
 
186
+ def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
187
+ return [
188
+ {
189
+ "params": [
190
+ p for n, p in model.named_parameters()
191
+ if (is_muon_func(n, p) and p.requires_grad)
192
+ ],
193
+ "use_muon":
194
+ True
195
+ },
196
+ {
197
+ "params": [
198
+ p for n, p in model.named_parameters()
199
+ if (not is_muon_func(n, p) and p.requires_grad)
200
+ ],
201
+ "use_muon":
202
+ False
203
+ },
204
+ ]
205
+
206
+
207
  class Muon(torch.optim.Optimizer):
208
  """
209
  Muon - MomentUm Orthogonalized by Newton-schulz
 
235
 
236
  def __init__(
237
  self,
238
+ params,
 
239
  lr=1e-3,
240
  momentum=0.95,
241
  nesterov=True,
 
255
  adamw_betas=adamw_betas,
256
  adamw_eps=adamw_eps,
257
  none_grad=none_grad,
258
+ use_muon=True,
259
  )
260
+ error_message = "The key 'use_muon' is not set in parameter group {idx}. Every parameter group passed to Muon must explicitly set 'use_muon' to True or False."
261
+ instruction_code = "\n\nPlease follow this code snippet:\n```\noptimizer = get_kernel('motif-technologies/optimizer')\nparams = optimizer.muon.get_default_muon_param_groups(model)\noptim = optimizer.Muon(params, ...)\n```"
262
 
263
+ if isinstance(params, types.GeneratorType):
264
+ raise ValueError(error_message.format(idx=0) + instruction_code)
265
+ for _idx, param_group in enumerate(params):
266
+ if param_group.get("use_muon", None) is None:
267
+ raise ValueError(
268
+ error_message.format(idx=_idx) + instruction_code)
269
+
270
+ super().__init__(params, defaults)
271
 
272
  if dist.is_initialized():
273
  self.rank = dist.get_rank()
 
278
  self.compute_stream = torch.cuda.Stream()
279
  self.debug = debug
280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  def _calc_flops(self, G, steps):
282
  assert len(G.shape) == 2
283
  M, N = G.shape
 
479
  loss = closure()
480
 
481
  for group in self.param_groups:
482
+ params = group["params"]
483
+
484
+ if group["use_muon"]:
485
+ ############################
486
+ # Muon #
487
+ ############################
488
+ lr = group["lr"]
489
+ weight_decay = group["weight_decay"]
490
+ momentum = group["momentum"]
491
+
492
+ param_dtensors = []
493
+ param_tensors = []
494
+
495
+ for p in params:
496
+ if p is None or p.grad is None:
497
+ continue
498
+ if isinstance(p.data, DTensor):
499
+ if all(
500
+ isinstance(placement, Replicate)
501
+ for placement in p.placements):
502
+ param_tensors.append(p)
503
+ else:
504
+ param_dtensors.append(p)
505
+ elif isinstance(p.data, torch.Tensor):
506
  param_tensors.append(p)
507
  else:
508
+ raise TypeError(
509
+ f"Unsupported parameter type: {type(p.data)}")
510
+
511
+ if self.debug:
512
+ print(
513
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
514
+ flush=True,
515
+ )
516
+
517
+ if len(param_dtensors) > 0:
518
+ if not dist.is_initialized():
519
+ raise RuntimeError(
520
+ "Parallel Muon requires torch.distributed to be initialized."
521
+ )
522
+
523
+ self.parallel(
524
+ param_dtensors,
525
+ group,
526
+ lr=lr,
527
+ weight_decay=weight_decay,
528
+ momentum=momentum,
529
  )
530
 
531
+ if len(param_tensors) > 0:
532
+ self.base(
533
+ param_tensors,
534
+ group,
535
+ lr=lr,
536
+ weight_decay=weight_decay,
537
+ momentum=momentum,
538
+ )
539
+
540
+ else:
541
+ ############################
542
+ # AdamW backup #
543
+ ############################
544
+
545
+ lr = group["lr"]
546
+ beta1, beta2 = group["adamw_betas"]
547
+ eps = group["adamw_eps"]
548
+ weight_decay = group["weight_decay"]
549
+
550
+ for p in params:
551
+ g = p.grad
552
+ if g is None:
553
+ continue
554
+ state = self.state[p]
555
+ if "step" not in state:
556
+ state["step"] = 0
557
+ state["moment1"] = torch.zeros_like(g)
558
+ state["moment2"] = torch.zeros_like(g)
559
+ state["step"] += 1
560
+ step = state["step"]
561
+ buf1 = state["moment1"]
562
+ buf2 = state["moment2"]
563
+ buf1.lerp_(g, 1 - beta1)
564
+ buf2.lerp_(g.square(), 1 - beta2)
565
+
566
+ g = buf1 / (eps + buf2.sqrt())
567
+
568
+ bias_correction1 = 1 - beta1**step
569
+ bias_correction2 = 1 - beta2**step
570
+ scale = bias_correction1 / bias_correction2**0.5
571
+ p.data.mul_(1 - lr * weight_decay)
572
+ p.data.add_(g, alpha=-lr / scale)
 
 
 
 
 
 
 
 
 
 
573
 
574
  return loss
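
Finally, in sharded setups the new step routes DTensor parameters through the parallel path, which requires an initialized process group, while fully replicated DTensors and plain tensors take the single-device path. A hedged sketch of that flow, loosely mirroring the test above; the launcher, backend, and toy module are assumptions rather than something the commit prescribes.

```python
# Launch with torchrun so that env:// initialization works, e.g.:
#   torchrun --nproc-per-node=2 this_script.py
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard
from muon import Muon, get_default_muon_param_groups

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(32, 32, device="cuda")  # toy module, illustration only
fully_shard(model)                              # parameters become sharded DTensors

optim = Muon(get_default_muon_param_groups(model), lr=1e-3)

model(torch.randn(4, 32, device="cuda")).sum().backward()
optim.step()  # sharded DTensor weights go through parallel(); the bias falls back to AdamW

dist.destroy_process_group()
```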