wyldecat committed
Commit f7faa93 · 1 Parent(s): b0f46c7

feat(muon): add tuned abc values & bfloat16 communication
Files changed (32)
  1. build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
  2. build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_20250911094409.abi3.so → _optimizer_ee6ed44_dirty.abi3.so} +1 -1
  3. build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +158 -50
  4. build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  5. build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so → torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so} +1 -1
  6. build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +158 -50
  7. build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
  8. build/torch27-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_20250911094409.abi3.so → _optimizer_ee6ed44_dirty.abi3.so} +2 -2
  9. build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +158 -50
  10. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
  11. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_20250911094409.abi3.so → _optimizer_ee6ed44_dirty.abi3.so} +2 -2
  12. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +158 -50
  13. build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  14. build/{torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so → torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so} +1 -1
  15. build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py +158 -50
  16. build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
  17. build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so +0 -3
  18. build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so +3 -0
  19. build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py +158 -50
  20. build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py +3 -3
  21. build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so +0 -3
  22. build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so +3 -0
  23. build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py +158 -50
  24. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
  25. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so +0 -3
  26. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so +3 -0
  27. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py +158 -50
  28. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py +3 -3
  29. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so +0 -3
  30. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so +3 -0
  31. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py +158 -50
  32. torch-ext/optimizer/muon.py +158 -50
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_20250911094409
-ops = torch.ops._optimizer_20250911094409
+from . import _optimizer_ee6ed44_dirty
+ops = torch.ops._optimizer_ee6ed44_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_20250911094409::{op_name}"
+    return f"_optimizer_ee6ed44_dirty::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_20250911094409.abi3.so → _optimizer_ee6ed44_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:48cd88108696ba8ed7487e637b785445bb5ff6075a3ae0c15355698958ad340a
+oid sha256:55f17ad6ecdd22d84ea5b776a317fa9fbb6b81f622fa8fc80b78e0ef80bd4ea6
 size 1787376
build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py CHANGED
@@ -2,6 +2,7 @@ import logging
 import math
 import types
 from dataclasses import dataclass
+from typing import Optional, Union, cast
 
 import torch
 import torch.distributed as dist
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
 
 # This code snippet is a modified version adapted from the following GitHub repositories:
 # https://github.com/KellerJordan/Muon/blob/master/muon.py
+# Muon's Newton–Schulz iteration causes high variance in singular values
+# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
 @torch.no_grad()
 def _zeropower_via_newtonschulz5(G, steps):
     """
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
     performance at all relative to UV^T, where USV^T = G is the SVD.
     """
     assert len(G.shape) == 2
-    a, b, c = (3.4445, -4.7750, 2.0315)
+    assert G.dtype == torch.bfloat16
     X = G  # no manual typecast
+
     if G.size(0) > G.size(1):
         X = X.T
     # Ensure spectral norm is at most 1
     X = X / (X.norm() + 1e-7)
-    X = X.bfloat16()
     # Perform the NS iterations
-    for _ in range(steps):
+    for a, b, c in [
+        (4.0848, -6.8946, 2.9270),
+        (3.9505, -6.3029, 2.6377),
+        (3.7418, -5.5913, 2.3037),
+        (2.8769, -3.1427, 1.2046),
+        (2.8366, -3.0525, 1.2012),
+    ]:
         A = X @ X.T
         # B = (
         #     b * A + c * A @ A
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
 
     if G.size(0) > G.size(1):
         X = X.T
-    return X.to(G.dtype)
+    return X
 
 
 @dataclass
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
     Gather the gradients to worker_rank.
     If none_grad is True, free p.grad after the gather.
     """
-    g = p.grad
-
-    if rank == state.worker_rank:
-        num_ranks = dist.get_world_size(group=state.process_group)
-        gather_list = [
-            torch.empty_like(g.to_local()) for _ in range(num_ranks)
-        ]
-    else:
-        gather_list = None
-
     with torch.cuda.stream(comm_stream):
+        g = p.grad
+
+        if rank == state.worker_rank:
+            num_ranks = dist.get_world_size(group=state.process_group)
+            gather_list = [
+                torch.empty_like(g.to_local(), dtype=torch.bfloat16)
+                for _ in range(num_ranks)
+            ]
+        else:
+            gather_list = None
+
+        g = g.to(torch.bfloat16)
         torch.distributed.gather(
             g.to_local(),
             dst=state.worker_rank,
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
         else:
             state.gathered_grad = None
             state.gather_event = None
+        gather_list = None
        if none_grad:
             # We can safely free p.grad without calling record_stream:
             # p.grad.to_local().record_stream(comm_stream)
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
 
 
 @torch.no_grad()
-def _compute_u(state, steps, rank, compute_stream):
+def _compute_u(p, state, steps, rank, compute_stream):
     """
     On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
     """
@@ -115,11 +127,11 @@ def _compute_u(p, state, steps, rank, compute_stream):
         compute_stream.wait_event(state.gather_event)
         u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
         state.computed_u = u
-        state.compute_event = torch.cuda.Event()
-        state.compute_event.record()
-    else:
-        state.computed_u = None
-        state.compute_event = None
+        state.scattered_u = torch.empty_like(p.to_local(),
+                                             dtype=torch.bfloat16)
+        state.compute_event = torch.cuda.Event()
+        state.compute_event.record()
+        u = None
 
 
 @torch.no_grad()
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
     """
 
     with torch.cuda.stream(comm_stream):
+        if state.compute_event is None:
+            raise RuntimeError("Compute event must be set before scatter.")
+        comm_stream.wait_event(state.compute_event)
+
         if rank == state.worker_rank:
             num_ranks = dist.get_world_size(group=state.process_group)
-            if state.compute_event is None:
-                raise RuntimeError("Compute event must be set before scatter.")
-            comm_stream.wait_event(state.compute_event)
-
             # Clear the gathered gradient to free memory
             state.gathered_grad = None
 
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
         else:
             scatter_list = None
 
-        u_received = torch.empty_like(p.to_local())
         torch.distributed.scatter(
-            u_received,
+            state.scattered_u,
             scatter_list=scatter_list,
             src=state.worker_rank,
             group=state.process_group,
         )
-        u_dtensor = DTensor.from_local(
-            u_received,
-            placements=p.placements,
-            device_mesh=p.device_mesh,
-        )
-
-        state.scattered_u = u_dtensor
         state.scatter_event = torch.cuda.Event()
         state.scatter_event.record()
+        scatter_list = None
 
 
 def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
     if state.scatter_event is None:
         raise RuntimeError("Scatter event must be set before update")
     compute_stream.wait_event(state.scatter_event)
+    u_dtensor = DTensor.from_local(
+        state.scattered_u,
+        placements=p.placements,
+        device_mesh=p.device_mesh,
+    )
+
+    state.scattered_u = u_dtensor
+
     if rank == state.worker_rank:
         # Free computed_u
         state.computed_u = None
 
     Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
+    state.scattered_u = None
+    u_dtensor = None
 
 
 def default_is_muon(name, x):
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
             else:
                 g = buf
 
-            u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
+            u = _zeropower_via_newtonschulz5(g.bfloat16(),
+                                             steps=group["ns_steps"])
 
             adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
             Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
         def enqueue_computes(start_idx, chunk_size):
             for p in ordered_params[start_idx:start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-                _compute_u(state, group["ns_steps"], self.rank,
+                _compute_u(p, state, group["ns_steps"], self.rank,
                            self.compute_stream)
 
         def enqueue_scatters(start_idx, chunk_size):
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
         # Wait the last update_param to finish
         torch.cuda.current_stream().wait_stream(self.compute_stream)
 
+    @staticmethod
+    def _fused_adamw(
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor],
+        exp_avgs: list[torch.Tensor],
+        exp_avg_sqs: list[torch.Tensor],
+        max_exp_avg_sqs: list[torch.Tensor],
+        state_steps: list[torch.Tensor],
+        amsgrad: bool,
+        beta1: float,
+        beta2: float,
+        lr: Union[float, torch.Tensor],
+        weight_decay: float,
+        eps: float,
+        maximize: bool,
+    ) -> None:
+        if not params:
+            return
+
+        # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
+        # treating it as a scalar.
+        lr_dict: Optional[DeviceDict] = ({
+            lr.device: lr
+        } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
+                                         None)
+        grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
+            [
+                params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
+                state_steps
+            ]  # type: ignore[list-item]
+        )
+        for (device, _), (
+            (
+                device_params_,
+                device_grads_,
+                device_exp_avgs_,
+                device_exp_avg_sqs_,
+                device_max_exp_avg_sqs,
+                device_state_steps_,
+            ),
+            _,
+        ) in grouped_tensors.items():
+            device_params = cast(list[torch.Tensor], device_params_)
+            device_grads = cast(list[torch.Tensor], device_grads_)
+            device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
+            device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
+            device_state_steps = cast(list[torch.Tensor], device_state_steps_)
+
+            if lr_dict is not None and device not in lr_dict:
+                lr_dict[device] = lr.to(
+                    device=device,
+                    non_blocking=True)  # type: ignore[union-attr]
+                lr = lr_dict[device]
+            torch._foreach_add_(device_state_steps, 1)
+            func = torch._fused_adamw_
+            func(
+                device_params,
+                device_grads,
+                device_exp_avgs,
+                device_exp_avg_sqs,
+                device_max_exp_avg_sqs,  # type: ignore[arg-type]
+                device_state_steps,
+                amsgrad=amsgrad,
+                lr=lr,  # type: ignore[arg-type]
+                beta1=beta1,
+                beta2=beta2,
+                weight_decay=weight_decay,
+                eps=eps,
+                maximize=maximize,
+            )
+
     def step(self, closure=None):
         """Perform a single optimization step.
 
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
             # AdamW backup #
             ############################
 
+            params_with_grads = []
+            grads = []
+            moment1 = []
+            moment2 = []
+            max_exp_avg_sqs = []
+            state_steps = []
             lr = group["lr"]
             beta1, beta2 = group["adamw_betas"]
             eps = group["adamw_eps"]
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
                 if g is None:
                     continue
                 state = self.state[p]
+                params_with_grads.append(p)
+                grads.append(g)
                 if "step" not in state:
-                    state["step"] = 0
+                    state["step"] = (torch.zeros((),
+                                                 dtype=torch.float32,
+                                                 device=p.device))
                     state["moment1"] = torch.zeros_like(g)
                     state["moment2"] = torch.zeros_like(g)
-                state["step"] += 1
-                step = state["step"]
-                buf1 = state["moment1"]
-                buf2 = state["moment2"]
-                buf1.lerp_(g, 1 - beta1)
-                buf2.lerp_(g.square(), 1 - beta2)
-
-                g = buf1 / (eps + buf2.sqrt())
-
-                bias_correction1 = 1 - beta1**step
-                bias_correction2 = 1 - beta2**step
-                scale = bias_correction1 / bias_correction2**0.5
-                p.data.mul_(1 - lr * weight_decay)
-                p.data.add_(g, alpha=-lr / scale)
+                moment1.append(state["moment1"])
+                moment2.append(state["moment2"])
+                if not isinstance(state["step"], torch.Tensor):
+                    step_tensor = torch.tensor(state["step"],
+                                               dtype=torch.float32,
+                                               device=p.device)
+                else:
+                    step_tensor = state["step"]
+                state_steps.append(step_tensor)
+
+            self._fused_adamw(
+                params_with_grads,
+                grads,
+                moment1,
+                moment2,
+                max_exp_avg_sqs,
+                state_steps,
+                amsgrad=False,
+                beta1=beta1,
+                beta2=beta2,
+                lr=lr,
+                weight_decay=weight_decay,
+                eps=eps,
+                maximize=False,
+            )
 
         return loss
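Note on the "tuned abc values" half of the patch: the fixed Newton-Schulz triple (3.4445, -4.7750, 2.0315) is replaced by one tuned (a, b, c) per iteration, pinning the iteration count at 5. Below is a minimal standalone sketch of the patched routine, assembled from the hunks above. Two assumptions: the inner quintic update (B = b*A + c*A@A, then X = a*X + B@X) is taken from the upstream KellerJordan/Muon code that the diff leaves as unchanged context, and the now-unused steps argument is dropped here for clarity even though the real function keeps it.

# Sketch only; names other than the coefficients are illustrative.
import torch

@torch.no_grad()
def ns_orthogonalize(G: torch.Tensor) -> torch.Tensor:
    coeffs = [  # one tuned (a, b, c) triple per Newton-Schulz iteration
        (4.0848, -6.8946, 2.9270),
        (3.9505, -6.3029, 2.6377),
        (3.7418, -5.5913, 2.3037),
        (2.8769, -3.1427, 1.2046),
        (2.8366, -3.0525, 1.2012),
    ]
    assert G.ndim == 2 and G.dtype == torch.bfloat16  # callers now cast to bf16 up front
    X = G.T if G.size(0) > G.size(1) else G  # work on the wide orientation
    X = X / (X.norm() + 1e-7)                # bound the spectral norm near 1
    for a, b, c in coeffs:
        A = X @ X.T
        B = b * A + c * A @ A                # quintic update, as in upstream Muon
        X = a * X + B @ X
    return X.T if G.size(0) > G.size(1) else X  # stays in bf16; no cast back

Because the old return X.to(G.dtype) is gone, the result stays in bfloat16, which is why the non-distributed path in the diff now calls _zeropower_via_newtonschulz5(g.bfloat16(), ...).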
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py CHANGED
(Identical change to build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py above: the op namespace is renamed from _optimizer_20250911094409 to _optimizer_ee6ed44_dirty.)
build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so → torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:cebddf4b9cb794ad3cd7b88affd011160f7fb9a16257fcb4d942604839b31b37
+oid sha256:f37c80a535a081e997c1973902a010c48b33ca40085a7f267a5278e56cff26f3
 size 1824264
build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py CHANGED
(Identical patch to build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above; every build variant's muon.py receives the same +158/-50 change.)
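Note on the bfloat16 communication half of the patch: gradients are cast to bf16 before the gather, and the scatter writes into a bf16 buffer preallocated in _compute_u, roughly halving gather/scatter traffic versus fp32. A stripped-down sketch of the pattern follows; it is a hypothetical standalone helper that omits the dedicated CUDA streams, events, and DTensor to_local/from_local handling of the real code, and assumes plain tensors row-sharded across ranks for concreteness.

# Sketch only; assumes torch.distributed is already initialized.
import torch
import torch.distributed as dist

def bf16_roundtrip(p: torch.nn.Parameter, worker_rank: int, orthogonalize, group=None):
    rank = dist.get_rank(group)
    world = dist.get_world_size(group)

    # Gather: every rank ships its gradient shard in bfloat16.
    g = p.grad.to(torch.bfloat16)
    gather_list = ([torch.empty_like(g) for _ in range(world)]
                   if rank == worker_rank else None)
    dist.gather(g, gather_list=gather_list, dst=worker_rank, group=group)

    # Compute on the worker, then split the result back into per-rank shards.
    if rank == worker_rank:
        full = orthogonalize(torch.cat(gather_list, dim=0))
        scatter_list = list(full.chunk(world, dim=0))
    else:
        scatter_list = None

    # Scatter into a preallocated bf16 buffer (state.scattered_u in the patch).
    u = torch.empty_like(g)
    dist.scatter(u, scatter_list=scatter_list, src=worker_rank, group=group)
    return u

In the real code the DTensor.from_local wrapping now happens later, in _update_param, so the scatter target can be the plain preallocated bf16 buffer.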
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py CHANGED
(Identical namespace rename to the _ops.py diffs above.)
build/torch27-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_20250911094409.abi3.so → _optimizer_ee6ed44_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:b1729faaee0dd55134348a0d775c147cf3aaba106e0475e1389159d48dfc1ebe
-size 1883360
+oid sha256:5f8bf16b0ae5af74852e8c890183c8c32175886c3d0366cfc776fb3e1ee15906
+size 1883352
build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py CHANGED
(Identical patch to the muon.py diff above.)
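Note on the AdamW backup path: the per-parameter Python loop with explicit bias correction is replaced by batching parameters and calling PyTorch's fused kernel, which is why step counters become float32 scalar tensors instead of Python ints. The sketch below shows one fused step over same-device CUDA tensors; the call signature mirrors what the new _fused_adamw helper in the diff uses, but torch._fused_adamw_ is a private op, so treat this as a sketch rather than a stable API.

# Sketch only; params/grads/moments must be CUDA tensors of matching shapes.
import torch

def fused_adamw_step(params, grads, exp_avgs, exp_avg_sqs, state_steps,
                     lr=1e-3, beta1=0.9, beta2=0.95, eps=1e-8, weight_decay=0.1):
    # state_steps: float32 scalar tensors on the params' device, as in the patch.
    torch._foreach_add_(state_steps, 1)  # advance every step counter in one call
    torch._fused_adamw_(
        params, grads, exp_avgs, exp_avg_sqs,
        [],            # max_exp_avg_sqs stays empty because amsgrad=False
        state_steps,
        amsgrad=False, lr=lr, beta1=beta1, beta2=beta2,
        weight_decay=weight_decay, eps=eps, maximize=False,
    )

The helper in the diff additionally routes tensors through torch.optim.Optimizer._group_tensors_by_device_and_dtype, so each device/dtype bucket gets its own fused call.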
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py CHANGED
(Identical namespace rename to the _ops.py diffs above.)
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_20250911094409.abi3.so → _optimizer_ee6ed44_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:0857945a1ebfdbb6c7219d0b96c8ab47649aa3b47b65fa800c84b51ddbda9c19
-size 1749880
+oid sha256:d50267ec23db9512ae1d82c99012901d58e50dee9bf34346702561a5d3e6d9e7
+size 1749840
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py CHANGED
@@ -2,6 +2,7 @@ import logging
2
  import math
3
  import types
4
  from dataclasses import dataclass
 
5
 
6
  import torch
7
  import torch.distributed as dist
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
12
 
13
  # This code snippet is a modified version adapted from the following GitHub repositories:
14
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
 
 
15
  @torch.no_grad()
16
  def _zeropower_via_newtonschulz5(G, steps):
17
  """
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
24
  performance at all relative to UV^T, where USV^T = G is the SVD.
25
  """
26
  assert len(G.shape) == 2
27
- a, b, c = (3.4445, -4.7750, 2.0315)
28
  X = G # no manual typecast
 
29
  if G.size(0) > G.size(1):
30
  X = X.T
31
  # Ensure spectral norm is at most 1
32
  X = X / (X.norm() + 1e-7)
33
- X = X.bfloat16()
34
  # Perform the NS iterations
35
- for _ in range(steps):
 
 
 
 
 
 
36
  A = X @ X.T
37
  # B = (
38
  # b * A + c * A @ A
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
43
 
44
  if G.size(0) > G.size(1):
45
  X = X.T
46
- return X.to(G.dtype)
47
 
48
 
49
  @dataclass
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
65
  Gather the gradients to worker_rank.
66
  If none_grad is True, free p.grad after the gather.
67
  """
68
- g = p.grad
69
-
70
- if rank == state.worker_rank:
71
- num_ranks = dist.get_world_size(group=state.process_group)
72
- gather_list = [
73
- torch.empty_like(g.to_local()) for _ in range(num_ranks)
74
- ]
75
- else:
76
- gather_list = None
77
-
78
  with torch.cuda.stream(comm_stream):
 
 
 
 
 
 
 
 
 
 
 
 
79
  torch.distributed.gather(
80
  g.to_local(),
81
  dst=state.worker_rank,
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
92
  else:
93
  state.gathered_grad = None
94
  state.gather_event = None
 
95
  if none_grad:
96
  # We can safely free p.grad without calling record_stream:
97
  # p.grad.to_local().record_stream(comm_stream)
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
104
 
105
 
106
  @torch.no_grad()
107
- def _compute_u(state, steps, rank, compute_stream):
108
  """
109
  On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
110
  """
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
115
  compute_stream.wait_event(state.gather_event)
116
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
117
  state.computed_u = u
118
- state.compute_event = torch.cuda.Event()
119
- state.compute_event.record()
120
- else:
121
- state.computed_u = None
122
- state.compute_event = None
123
 
124
 
125
  @torch.no_grad()
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
129
  """
130
 
131
  with torch.cuda.stream(comm_stream):
 
 
 
 
132
  if rank == state.worker_rank:
133
  num_ranks = dist.get_world_size(group=state.process_group)
134
- if state.compute_event is None:
135
- raise RuntimeError("Compute event must be set before scatter.")
136
- comm_stream.wait_event(state.compute_event)
137
-
138
  # Clear the gathered gradient to free memory
139
  state.gathered_grad = None
140
 
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
144
  else:
145
  scatter_list = None
146
 
147
- u_received = torch.empty_like(p.to_local())
148
  torch.distributed.scatter(
149
- u_received,
150
  scatter_list=scatter_list,
151
  src=state.worker_rank,
152
  group=state.process_group,
153
  )
154
- u_dtensor = DTensor.from_local(
155
- u_received,
156
- placements=p.placements,
157
- device_mesh=p.device_mesh,
158
- )
159
-
160
- state.scattered_u = u_dtensor
161
  state.scatter_event = torch.cuda.Event()
162
  state.scatter_event.record()
 
163
 
164
 
165
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
172
  if state.scatter_event is None:
173
  raise RuntimeError("Scatter event must be set before update")
174
  compute_stream.wait_event(state.scatter_event)
 
 
 
 
 
 
 
 
175
  if rank == state.worker_rank:
176
  # Free computed_u
177
  state.computed_u = None
178
 
179
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
 
 
180
 
181
 
182
  def default_is_muon(name, x):
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
375
  else:
376
  g = buf
377
 
378
- u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
 
379
 
380
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
381
  Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
433
  def enqueue_computes(start_idx, chunk_size):
434
  for p in ordered_params[start_idx:start_idx + chunk_size]:
435
  state = param_to_state[id(p)]
436
- _compute_u(state, group["ns_steps"], self.rank,
437
  self.compute_stream)
438
 
439
  def enqueue_scatters(start_idx, chunk_size):
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
466
  # Wait the last update_param to finish
467
  torch.cuda.current_stream().wait_stream(self.compute_stream)
468
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
  def step(self, closure=None):
470
  """Perform a single optimization step.
471
 
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
542
  # AdamW backup #
543
  ############################
544
 
 
 
 
 
 
 
545
  lr = group["lr"]
546
  beta1, beta2 = group["adamw_betas"]
547
  eps = group["adamw_eps"]
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
552
  if g is None:
553
  continue
554
  state = self.state[p]
 
 
555
  if "step" not in state:
556
- state["step"] = 0
 
 
557
  state["moment1"] = torch.zeros_like(g)
558
  state["moment2"] = torch.zeros_like(g)
559
- state["step"] += 1
560
- step = state["step"]
561
- buf1 = state["moment1"]
562
- buf2 = state["moment2"]
563
- buf1.lerp_(g, 1 - beta1)
564
- buf2.lerp_(g.square(), 1 - beta2)
565
-
566
- g = buf1 / (eps + buf2.sqrt())
567
-
568
- bias_correction1 = 1 - beta1**step
569
- bias_correction2 = 1 - beta2**step
570
- scale = bias_correction1 / bias_correction2**0.5
571
- p.data.mul_(1 - lr * weight_decay)
572
- p.data.add_(g, alpha=-lr / scale)
 
 
 
 
 
 
 
 
 
 
 
573
 
574
  return loss
 
2
  import math
3
  import types
4
  from dataclasses import dataclass
5
+ from typing import Optional, Union, cast
6
 
7
  import torch
8
  import torch.distributed as dist
 
13
 
14
  # This code snippet is a modified version adapted from the following GitHub repositories:
15
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
16
+ # Muon's Newton–Schulz iteration causes high variance in the singular values.
17
+ # Idea: give each iteration its own three coefficients and optimize them via gradient descent.
18
  @torch.no_grad()
19
  def _zeropower_via_newtonschulz5(G, steps):
20
  """
 
27
  performance at all relative to UV^T, where USV^T = G is the SVD.
28
  """
29
  assert len(G.shape) == 2
30
+ assert G.dtype == torch.bfloat16
31
  X = G # no manual typecast
32
+
33
  if G.size(0) > G.size(1):
34
  X = X.T
35
  # Ensure spectral norm is at most 1
36
  X = X / (X.norm() + 1e-7)
 
37
  # Perform the NS iterations
38
+ for a, b, c in [
39
+ (4.0848, -6.8946, 2.9270),
40
+ (3.9505, -6.3029, 2.6377),
41
+ (3.7418, -5.5913, 2.3037),
42
+ (2.8769, -3.1427, 1.2046),
43
+ (2.8366, -3.0525, 1.2012),
44
+ ]:
45
  A = X @ X.T
46
  # B = (
47
  # b * A + c * A @ A
 
52
 
53
  if G.size(0) > G.size(1):
54
  X = X.T
55
+ return X
56
 
57
 
58
  @dataclass
 
74
  Gather the gradients to worker_rank.
75
  If none_grad is True, free p.grad after the gather.
76
  """
 
 
 
 
 
 
 
 
 
 
77
  with torch.cuda.stream(comm_stream):
78
+ g = p.grad
79
+
80
+ if rank == state.worker_rank:
81
+ num_ranks = dist.get_world_size(group=state.process_group)
82
+ gather_list = [
83
+ torch.empty_like(g.to_local(), dtype=torch.bfloat16)
84
+ for _ in range(num_ranks)
85
+ ]
86
+ else:
87
+ gather_list = None
88
+
89
+ g = g.to(torch.bfloat16)
90
  torch.distributed.gather(
91
  g.to_local(),
92
  dst=state.worker_rank,
 
103
  else:
104
  state.gathered_grad = None
105
  state.gather_event = None
106
+ gather_list = None
107
  if none_grad:
108
  # We can safely free p.grad without calling record_stream:
109
  # p.grad.to_local().record_stream(comm_stream)
 
116
 
117
 
118
  @torch.no_grad()
119
+ def _compute_u(p, state, steps, rank, compute_stream):
120
  """
121
  On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
122
  """
 
127
  compute_stream.wait_event(state.gather_event)
128
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
129
  state.computed_u = u
130
+ state.scattered_u = torch.empty_like(p.to_local(),
131
+ dtype=torch.bfloat16)
132
+ state.compute_event = torch.cuda.Event()
133
+ state.compute_event.record()
134
+ u = None
135
 
136
 
137
  @torch.no_grad()
 
141
  """
142
 
143
  with torch.cuda.stream(comm_stream):
144
+ if state.compute_event is None:
145
+ raise RuntimeError("Compute event must be set before scatter.")
146
+ comm_stream.wait_event(state.compute_event)
147
+
148
  if rank == state.worker_rank:
149
  num_ranks = dist.get_world_size(group=state.process_group)
 
 
 
 
150
  # Clear the gathered gradient to free memory
151
  state.gathered_grad = None
152
 
 
156
  else:
157
  scatter_list = None
158
 
 
159
  torch.distributed.scatter(
160
+ state.scattered_u,
161
  scatter_list=scatter_list,
162
  src=state.worker_rank,
163
  group=state.process_group,
164
  )
 
 
 
 
 
 
 
165
  state.scatter_event = torch.cuda.Event()
166
  state.scatter_event.record()
167
+ scatter_list = None
168
 
169
 
170
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
 
177
  if state.scatter_event is None:
178
  raise RuntimeError("Scatter event must be set before update")
179
  compute_stream.wait_event(state.scatter_event)
180
+ u_dtensor = DTensor.from_local(
181
+ state.scattered_u,
182
+ placements=p.placements,
183
+ device_mesh=p.device_mesh,
184
+ )
185
+
186
+ state.scattered_u = u_dtensor
187
+
188
  if rank == state.worker_rank:
189
  # Free computed_u
190
  state.computed_u = None
191
 
192
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
193
+ state.scattered_u = None
194
+ u_dtensor = None
195
 
196
 
197
  def default_is_muon(name, x):
 
390
  else:
391
  g = buf
392
 
393
+ u = _zeropower_via_newtonschulz5(g.bfloat16(),
394
+ steps=group["ns_steps"])
395
 
396
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
397
  Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
 
449
  def enqueue_computes(start_idx, chunk_size):
450
  for p in ordered_params[start_idx:start_idx + chunk_size]:
451
  state = param_to_state[id(p)]
452
+ _compute_u(p, state, group["ns_steps"], self.rank,
453
  self.compute_stream)
454
 
455
  def enqueue_scatters(start_idx, chunk_size):
 
482
  # Wait for the last update_param to finish
483
  torch.cuda.current_stream().wait_stream(self.compute_stream)
484
 
485
+ @staticmethod
486
+ def _fused_adamw(
487
+ params: list[torch.Tensor],
488
+ grads: list[torch.Tensor],
489
+ exp_avgs: list[torch.Tensor],
490
+ exp_avg_sqs: list[torch.Tensor],
491
+ max_exp_avg_sqs: list[torch.Tensor],
492
+ state_steps: list[torch.Tensor],
493
+ amsgrad: bool,
494
+ beta1: float,
495
+ beta2: float,
496
+ lr: Union[float, torch.Tensor],
497
+ weight_decay: float,
498
+ eps: float,
499
+ maximize: bool,
500
+ ) -> None:
501
+ if not params:
502
+ return
503
+
504
+ # We only shuffle around the lr when it is a Tensor and on CUDA; otherwise, we prefer
505
+ # treating it as a scalar.
506
+ lr_dict: Optional[DeviceDict] = ({
507
+ lr.device: lr
508
+ } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
509
+ None)
510
+ grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
511
+ [
512
+ params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
513
+ state_steps
514
+ ] # type: ignore[list-item]
515
+ )
516
+ for (device, _), (
517
+ (
518
+ device_params_,
519
+ device_grads_,
520
+ device_exp_avgs_,
521
+ device_exp_avg_sqs_,
522
+ device_max_exp_avg_sqs,
523
+ device_state_steps_,
524
+ ),
525
+ _,
526
+ ) in grouped_tensors.items():
527
+ device_params = cast(list[torch.Tensor], device_params_)
528
+ device_grads = cast(list[torch.Tensor], device_grads_)
529
+ device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
530
+ device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
531
+ device_state_steps = cast(list[torch.Tensor], device_state_steps_)
532
+
533
+ if lr_dict is not None and device not in lr_dict:
534
+ lr_dict[device] = lr.to(
535
+ device=device,
536
+ non_blocking=True) # type: ignore[union-attr]
537
+ lr = lr_dict[device]
538
+ torch._foreach_add_(device_state_steps, 1)
539
+ func = torch._fused_adamw_
540
+ func(
541
+ device_params,
542
+ device_grads,
543
+ device_exp_avgs,
544
+ device_exp_avg_sqs,
545
+ device_max_exp_avg_sqs, # type: ignore[arg-type]
546
+ device_state_steps,
547
+ amsgrad=amsgrad,
548
+ lr=lr, # type: ignore[arg-type]
549
+ beta1=beta1,
550
+ beta2=beta2,
551
+ weight_decay=weight_decay,
552
+ eps=eps,
553
+ maximize=maximize,
554
+ )
555
+
556
  def step(self, closure=None):
557
  """Perform a single optimization step.
558
 
 
629
  # AdamW backup #
630
  ############################
631
 
632
+ params_with_grads = []
633
+ grads = []
634
+ moment1 = []
635
+ moment2 = []
636
+ max_exp_avg_sqs = []
637
+ state_steps = []
638
  lr = group["lr"]
639
  beta1, beta2 = group["adamw_betas"]
640
  eps = group["adamw_eps"]
 
645
  if g is None:
646
  continue
647
  state = self.state[p]
648
+ params_with_grads.append(p)
649
+ grads.append(g)
650
  if "step" not in state:
651
+ state["step"] = (torch.zeros((),
652
+ dtype=torch.float32,
653
+ device=p.device))
654
  state["moment1"] = torch.zeros_like(g)
655
  state["moment2"] = torch.zeros_like(g)
656
+ moment1.append(state["moment1"])
657
+ moment2.append(state["moment2"])
658
+ if not isinstance(state["step"], torch.Tensor):
659
+ step_tensor = torch.tensor(state["step"],
660
+ dtype=torch.float32,
661
+ device=p.device)
662
+ else:
663
+ step_tensor = state["step"]
664
+ state_steps.append(step_tensor)
665
+
666
+ self._fused_adamw(
667
+ params_with_grads,
668
+ grads,
669
+ moment1,
670
+ moment2,
671
+ max_exp_avg_sqs,
672
+ state_steps,
673
+ amsgrad=False,
674
+ beta1=beta1,
675
+ beta2=beta2,
676
+ lr=lr,
677
+ weight_decay=weight_decay,
678
+ eps=eps,
679
+ maximize=False,
680
+ )
681
 
682
  return loss
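
The `_zeropower_via_newtonschulz5` hunks above make three coupled changes: the single coefficient triple `(3.4445, -4.7750, 2.0315)` looped `steps` times becomes five tuned `(a, b, c)` triples (one per iteration, so the table, not the `steps` argument, now fixes the iteration count), the internal `X.bfloat16()` cast becomes a caller-side precondition (`assert G.dtype == torch.bfloat16`), and the result is returned in bfloat16 rather than cast back to `G.dtype`. A minimal sketch of the resulting function, spelling out the standard quintic update that the commented-out `B = ...` lines in the hunk abbreviate:

```python
import torch

# The five tuned (a, b, c) triples from the hunk; one triple per iteration.
NS_COEFFS = [
    (4.0848, -6.8946, 2.9270),
    (3.9505, -6.3029, 2.6377),
    (3.7418, -5.5913, 2.3037),
    (2.8769, -3.1427, 1.2046),
    (2.8366, -3.0525, 1.2012),
]

@torch.no_grad()
def orthogonalize_bf16(G: torch.Tensor) -> torch.Tensor:
    # The cast was removed from the function, so bf16 is now a precondition.
    assert G.ndim == 2 and G.dtype == torch.bfloat16
    X = G.T if G.size(0) > G.size(1) else G
    X = X / (X.norm() + 1e-7)  # Frobenius normalization bounds the spectral norm near 1
    for a, b, c in NS_COEFFS:
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X  # quintic Newton-Schulz step
    return X.T if G.size(0) > G.size(1) else X

# e.g. U = orthogonalize_bf16(torch.randn(256, 512).bfloat16())
# then U @ U.T is approximately the identity, up to bf16/NS error.
```

Since the gathered gradient is already bfloat16 after the communication change in `_gather`, no cast remains on the hot path.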
build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_20250911094409
3
- ops = torch.ops._optimizer_20250911094409
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_20250911094409::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_ee6ed44_dirty
3
+ ops = torch.ops._optimizer_ee6ed44_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_ee6ed44_dirty::{op_name}"
build/{torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so → torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a5908748e60a61c59e315fbba8b32e3867a4b673b587a2a9606ddde5b4f67da5
3
  size 1824264
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:80ce6b0d62167a8ea10b6e2a1f90df70aa108997570c0ed210f458debd26f32f
3
  size 1824264
build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py CHANGED
@@ -2,6 +2,7 @@ import logging
2
  import math
3
  import types
4
  from dataclasses import dataclass
 
5
 
6
  import torch
7
  import torch.distributed as dist
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
12
 
13
  # This code snippet is a modified version adapted from the following GitHub repositories:
14
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
 
 
15
  @torch.no_grad()
16
  def _zeropower_via_newtonschulz5(G, steps):
17
  """
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
24
  performance at all relative to UV^T, where USV^T = G is the SVD.
25
  """
26
  assert len(G.shape) == 2
27
- a, b, c = (3.4445, -4.7750, 2.0315)
28
  X = G # no manual typecast
 
29
  if G.size(0) > G.size(1):
30
  X = X.T
31
  # Ensure spectral norm is at most 1
32
  X = X / (X.norm() + 1e-7)
33
- X = X.bfloat16()
34
  # Perform the NS iterations
35
- for _ in range(steps):
 
 
 
 
 
 
36
  A = X @ X.T
37
  # B = (
38
  # b * A + c * A @ A
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
43
 
44
  if G.size(0) > G.size(1):
45
  X = X.T
46
- return X.to(G.dtype)
47
 
48
 
49
  @dataclass
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
65
  Gather the gradients to worker_rank.
66
  If none_grad is True, free p.grad after the gather.
67
  """
68
- g = p.grad
69
-
70
- if rank == state.worker_rank:
71
- num_ranks = dist.get_world_size(group=state.process_group)
72
- gather_list = [
73
- torch.empty_like(g.to_local()) for _ in range(num_ranks)
74
- ]
75
- else:
76
- gather_list = None
77
-
78
  with torch.cuda.stream(comm_stream):
 
 
 
 
 
 
 
 
 
 
 
 
79
  torch.distributed.gather(
80
  g.to_local(),
81
  dst=state.worker_rank,
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
92
  else:
93
  state.gathered_grad = None
94
  state.gather_event = None
 
95
  if none_grad:
96
  # We can safely free p.grad without calling record_stream:
97
  # p.grad.to_local().record_stream(comm_stream)
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
104
 
105
 
106
  @torch.no_grad()
107
- def _compute_u(state, steps, rank, compute_stream):
108
  """
109
  On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
110
  """
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
115
  compute_stream.wait_event(state.gather_event)
116
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
117
  state.computed_u = u
118
- state.compute_event = torch.cuda.Event()
119
- state.compute_event.record()
120
- else:
121
- state.computed_u = None
122
- state.compute_event = None
123
 
124
 
125
  @torch.no_grad()
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
129
  """
130
 
131
  with torch.cuda.stream(comm_stream):
 
 
 
 
132
  if rank == state.worker_rank:
133
  num_ranks = dist.get_world_size(group=state.process_group)
134
- if state.compute_event is None:
135
- raise RuntimeError("Compute event must be set before scatter.")
136
- comm_stream.wait_event(state.compute_event)
137
-
138
  # Clear the gathered gradient to free memory
139
  state.gathered_grad = None
140
 
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
144
  else:
145
  scatter_list = None
146
 
147
- u_received = torch.empty_like(p.to_local())
148
  torch.distributed.scatter(
149
- u_received,
150
  scatter_list=scatter_list,
151
  src=state.worker_rank,
152
  group=state.process_group,
153
  )
154
- u_dtensor = DTensor.from_local(
155
- u_received,
156
- placements=p.placements,
157
- device_mesh=p.device_mesh,
158
- )
159
-
160
- state.scattered_u = u_dtensor
161
  state.scatter_event = torch.cuda.Event()
162
  state.scatter_event.record()
 
163
 
164
 
165
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
172
  if state.scatter_event is None:
173
  raise RuntimeError("Scatter event must be set before update")
174
  compute_stream.wait_event(state.scatter_event)
 
 
 
 
 
 
 
 
175
  if rank == state.worker_rank:
176
  # Free computed_u
177
  state.computed_u = None
178
 
179
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
 
 
180
 
181
 
182
  def default_is_muon(name, x):
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
375
  else:
376
  g = buf
377
 
378
- u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
 
379
 
380
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
381
  Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
433
  def enqueue_computes(start_idx, chunk_size):
434
  for p in ordered_params[start_idx:start_idx + chunk_size]:
435
  state = param_to_state[id(p)]
436
- _compute_u(state, group["ns_steps"], self.rank,
437
  self.compute_stream)
438
 
439
  def enqueue_scatters(start_idx, chunk_size):
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
466
  # Wait for the last update_param to finish
467
  torch.cuda.current_stream().wait_stream(self.compute_stream)
468
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
  def step(self, closure=None):
470
  """Perform a single optimization step.
471
 
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
542
  # AdamW backup #
543
  ############################
544
 
 
 
 
 
 
 
545
  lr = group["lr"]
546
  beta1, beta2 = group["adamw_betas"]
547
  eps = group["adamw_eps"]
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
552
  if g is None:
553
  continue
554
  state = self.state[p]
 
 
555
  if "step" not in state:
556
- state["step"] = 0
 
 
557
  state["moment1"] = torch.zeros_like(g)
558
  state["moment2"] = torch.zeros_like(g)
559
- state["step"] += 1
560
- step = state["step"]
561
- buf1 = state["moment1"]
562
- buf2 = state["moment2"]
563
- buf1.lerp_(g, 1 - beta1)
564
- buf2.lerp_(g.square(), 1 - beta2)
565
-
566
- g = buf1 / (eps + buf2.sqrt())
567
-
568
- bias_correction1 = 1 - beta1**step
569
- bias_correction2 = 1 - beta2**step
570
- scale = bias_correction1 / bias_correction2**0.5
571
- p.data.mul_(1 - lr * weight_decay)
572
- p.data.add_(g, alpha=-lr / scale)
 
 
 
 
 
 
 
 
 
 
 
573
 
574
  return loss
 
2
  import math
3
  import types
4
  from dataclasses import dataclass
5
+ from typing import Optional, Union, cast
6
 
7
  import torch
8
  import torch.distributed as dist
 
13
 
14
  # This code snippet is a modified version adapted from the following GitHub repositories:
15
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
16
+ # Muon's Newton–Schulz iteration causes high variance in the singular values.
17
+ # Idea: give each iteration its own three coefficients and optimize them via gradient descent.
18
  @torch.no_grad()
19
  def _zeropower_via_newtonschulz5(G, steps):
20
  """
 
27
  performance at all relative to UV^T, where USV^T = G is the SVD.
28
  """
29
  assert len(G.shape) == 2
30
+ assert G.dtype == torch.bfloat16
31
  X = G # no manual typecast
32
+
33
  if G.size(0) > G.size(1):
34
  X = X.T
35
  # Ensure spectral norm is at most 1
36
  X = X / (X.norm() + 1e-7)
 
37
  # Perform the NS iterations
38
+ for a, b, c in [
39
+ (4.0848, -6.8946, 2.9270),
40
+ (3.9505, -6.3029, 2.6377),
41
+ (3.7418, -5.5913, 2.3037),
42
+ (2.8769, -3.1427, 1.2046),
43
+ (2.8366, -3.0525, 1.2012),
44
+ ]:
45
  A = X @ X.T
46
  # B = (
47
  # b * A + c * A @ A
 
52
 
53
  if G.size(0) > G.size(1):
54
  X = X.T
55
+ return X
56
 
57
 
58
  @dataclass
 
74
  Gather the gradients to worker_rank.
75
  If none_grad is True, free p.grad after the gather.
76
  """
 
 
 
 
 
 
 
 
 
 
77
  with torch.cuda.stream(comm_stream):
78
+ g = p.grad
79
+
80
+ if rank == state.worker_rank:
81
+ num_ranks = dist.get_world_size(group=state.process_group)
82
+ gather_list = [
83
+ torch.empty_like(g.to_local(), dtype=torch.bfloat16)
84
+ for _ in range(num_ranks)
85
+ ]
86
+ else:
87
+ gather_list = None
88
+
89
+ g = g.to(torch.bfloat16)
90
  torch.distributed.gather(
91
  g.to_local(),
92
  dst=state.worker_rank,
 
103
  else:
104
  state.gathered_grad = None
105
  state.gather_event = None
106
+ gather_list = None
107
  if none_grad:
108
  # We can safely free p.grad without calling record_stream:
109
  # p.grad.to_local().record_stream(comm_stream)
 
116
 
117
 
118
  @torch.no_grad()
119
+ def _compute_u(p, state, steps, rank, compute_stream):
120
  """
121
  On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
122
  """
 
127
  compute_stream.wait_event(state.gather_event)
128
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
129
  state.computed_u = u
130
+ state.scattered_u = torch.empty_like(p.to_local(),
131
+ dtype=torch.bfloat16)
132
+ state.compute_event = torch.cuda.Event()
133
+ state.compute_event.record()
134
+ u = None
135
 
136
 
137
  @torch.no_grad()
 
141
  """
142
 
143
  with torch.cuda.stream(comm_stream):
144
+ if state.compute_event is None:
145
+ raise RuntimeError("Compute event must be set before scatter.")
146
+ comm_stream.wait_event(state.compute_event)
147
+
148
  if rank == state.worker_rank:
149
  num_ranks = dist.get_world_size(group=state.process_group)
 
 
 
 
150
  # Clear the gathered gradient to free memory
151
  state.gathered_grad = None
152
 
 
156
  else:
157
  scatter_list = None
158
 
 
159
  torch.distributed.scatter(
160
+ state.scattered_u,
161
  scatter_list=scatter_list,
162
  src=state.worker_rank,
163
  group=state.process_group,
164
  )
 
 
 
 
 
 
 
165
  state.scatter_event = torch.cuda.Event()
166
  state.scatter_event.record()
167
+ scatter_list = None
168
 
169
 
170
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
 
177
  if state.scatter_event is None:
178
  raise RuntimeError("Scatter event must be set before update")
179
  compute_stream.wait_event(state.scatter_event)
180
+ u_dtensor = DTensor.from_local(
181
+ state.scattered_u,
182
+ placements=p.placements,
183
+ device_mesh=p.device_mesh,
184
+ )
185
+
186
+ state.scattered_u = u_dtensor
187
+
188
  if rank == state.worker_rank:
189
  # Free computed_u
190
  state.computed_u = None
191
 
192
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
193
+ state.scattered_u = None
194
+ u_dtensor = None
195
 
196
 
197
  def default_is_muon(name, x):
 
390
  else:
391
  g = buf
392
 
393
+ u = _zeropower_via_newtonschulz5(g.bfloat16(),
394
+ steps=group["ns_steps"])
395
 
396
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
397
  Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
 
449
  def enqueue_computes(start_idx, chunk_size):
450
  for p in ordered_params[start_idx:start_idx + chunk_size]:
451
  state = param_to_state[id(p)]
452
+ _compute_u(p, state, group["ns_steps"], self.rank,
453
  self.compute_stream)
454
 
455
  def enqueue_scatters(start_idx, chunk_size):
 
482
  # Wait for the last update_param to finish
483
  torch.cuda.current_stream().wait_stream(self.compute_stream)
484
 
485
+ @staticmethod
486
+ def _fused_adamw(
487
+ params: list[torch.Tensor],
488
+ grads: list[torch.Tensor],
489
+ exp_avgs: list[torch.Tensor],
490
+ exp_avg_sqs: list[torch.Tensor],
491
+ max_exp_avg_sqs: list[torch.Tensor],
492
+ state_steps: list[torch.Tensor],
493
+ amsgrad: bool,
494
+ beta1: float,
495
+ beta2: float,
496
+ lr: Union[float, torch.Tensor],
497
+ weight_decay: float,
498
+ eps: float,
499
+ maximize: bool,
500
+ ) -> None:
501
+ if not params:
502
+ return
503
+
504
+ # We only shuffle around the lr when it is a Tensor and on CUDA; otherwise, we prefer
505
+ # treating it as a scalar.
506
+ lr_dict: Optional[DeviceDict] = ({
507
+ lr.device: lr
508
+ } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
509
+ None)
510
+ grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
511
+ [
512
+ params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
513
+ state_steps
514
+ ] # type: ignore[list-item]
515
+ )
516
+ for (device, _), (
517
+ (
518
+ device_params_,
519
+ device_grads_,
520
+ device_exp_avgs_,
521
+ device_exp_avg_sqs_,
522
+ device_max_exp_avg_sqs,
523
+ device_state_steps_,
524
+ ),
525
+ _,
526
+ ) in grouped_tensors.items():
527
+ device_params = cast(list[torch.Tensor], device_params_)
528
+ device_grads = cast(list[torch.Tensor], device_grads_)
529
+ device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
530
+ device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
531
+ device_state_steps = cast(list[torch.Tensor], device_state_steps_)
532
+
533
+ if lr_dict is not None and device not in lr_dict:
534
+ lr_dict[device] = lr.to(
535
+ device=device,
536
+ non_blocking=True) # type: ignore[union-attr]
537
+ lr = lr_dict[device]
538
+ torch._foreach_add_(device_state_steps, 1)
539
+ func = torch._fused_adamw_
540
+ func(
541
+ device_params,
542
+ device_grads,
543
+ device_exp_avgs,
544
+ device_exp_avg_sqs,
545
+ device_max_exp_avg_sqs, # type: ignore[arg-type]
546
+ device_state_steps,
547
+ amsgrad=amsgrad,
548
+ lr=lr, # type: ignore[arg-type]
549
+ beta1=beta1,
550
+ beta2=beta2,
551
+ weight_decay=weight_decay,
552
+ eps=eps,
553
+ maximize=maximize,
554
+ )
555
+
556
  def step(self, closure=None):
557
  """Perform a single optimization step.
558
 
 
629
  # AdamW backup #
630
  ############################
631
 
632
+ params_with_grads = []
633
+ grads = []
634
+ moment1 = []
635
+ moment2 = []
636
+ max_exp_avg_sqs = []
637
+ state_steps = []
638
  lr = group["lr"]
639
  beta1, beta2 = group["adamw_betas"]
640
  eps = group["adamw_eps"]
 
645
  if g is None:
646
  continue
647
  state = self.state[p]
648
+ params_with_grads.append(p)
649
+ grads.append(g)
650
  if "step" not in state:
651
+ state["step"] = (torch.zeros((),
652
+ dtype=torch.float32,
653
+ device=p.device))
654
  state["moment1"] = torch.zeros_like(g)
655
  state["moment2"] = torch.zeros_like(g)
656
+ moment1.append(state["moment1"])
657
+ moment2.append(state["moment2"])
658
+ if not isinstance(state["step"], torch.Tensor):
659
+ step_tensor = torch.tensor(state["step"],
660
+ dtype=torch.float32,
661
+ device=p.device)
662
+ else:
663
+ step_tensor = state["step"]
664
+ state_steps.append(step_tensor)
665
+
666
+ self._fused_adamw(
667
+ params_with_grads,
668
+ grads,
669
+ moment1,
670
+ moment2,
671
+ max_exp_avg_sqs,
672
+ state_steps,
673
+ amsgrad=False,
674
+ beta1=beta1,
675
+ beta2=beta2,
676
+ lr=lr,
677
+ weight_decay=weight_decay,
678
+ eps=eps,
679
+ maximize=False,
680
+ )
681
 
682
  return loss
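
The `_gather` rewrite in this file (identical across the build variants) does two things: all buffer allocation and the dtype cast now happen inside the `comm_stream` context, and the gradient is converted to bfloat16 before the collective, halving gather (and later scatter) traffic relative to float32. A minimal sketch with plain tensors on the default process group; the real code operates on DTensor local shards with a per-state `process_group` and records a CUDA event afterwards:

```python
from typing import Optional

import torch
import torch.distributed as dist

def gather_grad_bf16(local_grad: torch.Tensor,
                     worker_rank: int) -> Optional[torch.Tensor]:
    """Gather bfloat16 gradient shards onto worker_rank (default group).

    Casting before the collective halves bytes on the wire vs float32;
    Newton-Schulz then runs directly on the bf16 result, so nothing is
    cast back until the update is applied.
    """
    g = local_grad.to(torch.bfloat16)
    if dist.get_rank() == worker_rank:
        gather_list = [torch.empty_like(g)
                       for _ in range(dist.get_world_size())]
    else:
        gather_list = None
    dist.gather(g, gather_list=gather_list, dst=worker_rank)
    if gather_list is not None:
        # Assumes equally sized dim-0 shards; the real code leaves the
        # layout to the DTensor machinery.
        return torch.cat(gather_list, dim=0)
    return None
```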
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_20250911094409
3
- ops = torch.ops._optimizer_20250911094409
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_20250911094409::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_ee6ed44_dirty
3
+ ops = torch.ops._optimizer_ee6ed44_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_ee6ed44_dirty::{op_name}"
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:22dc3ab77ab74837126281f79f417c5d55b2cc9885388fd9d3a1c7c824ece2bd
3
- size 1883360
 
 
 
 
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3487612a8f022a1df1353945fc6d65bbd6797179b06c5d3202dc6e2aa6afb27a
3
+ size 1883352
build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py CHANGED
@@ -2,6 +2,7 @@ import logging
2
  import math
3
  import types
4
  from dataclasses import dataclass
 
5
 
6
  import torch
7
  import torch.distributed as dist
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
12
 
13
  # This code snippet is a modified version adapted from the following GitHub repositories:
14
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
 
 
15
  @torch.no_grad()
16
  def _zeropower_via_newtonschulz5(G, steps):
17
  """
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
24
  performance at all relative to UV^T, where USV^T = G is the SVD.
25
  """
26
  assert len(G.shape) == 2
27
- a, b, c = (3.4445, -4.7750, 2.0315)
28
  X = G # no manual typecast
 
29
  if G.size(0) > G.size(1):
30
  X = X.T
31
  # Ensure spectral norm is at most 1
32
  X = X / (X.norm() + 1e-7)
33
- X = X.bfloat16()
34
  # Perform the NS iterations
35
- for _ in range(steps):
 
 
 
 
 
 
36
  A = X @ X.T
37
  # B = (
38
  # b * A + c * A @ A
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
43
 
44
  if G.size(0) > G.size(1):
45
  X = X.T
46
- return X.to(G.dtype)
47
 
48
 
49
  @dataclass
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
65
  Gather the gradients to worker_rank.
66
  If none_grad is True, free p.grad after the gather.
67
  """
68
- g = p.grad
69
-
70
- if rank == state.worker_rank:
71
- num_ranks = dist.get_world_size(group=state.process_group)
72
- gather_list = [
73
- torch.empty_like(g.to_local()) for _ in range(num_ranks)
74
- ]
75
- else:
76
- gather_list = None
77
-
78
  with torch.cuda.stream(comm_stream):
 
 
 
 
 
 
 
 
 
 
 
 
79
  torch.distributed.gather(
80
  g.to_local(),
81
  dst=state.worker_rank,
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
92
  else:
93
  state.gathered_grad = None
94
  state.gather_event = None
 
95
  if none_grad:
96
  # We can safely free p.grad without calling record_stream:
97
  # p.grad.to_local().record_stream(comm_stream)
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
104
 
105
 
106
  @torch.no_grad()
107
- def _compute_u(state, steps, rank, compute_stream):
108
  """
109
  On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
110
  """
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
115
  compute_stream.wait_event(state.gather_event)
116
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
117
  state.computed_u = u
118
- state.compute_event = torch.cuda.Event()
119
- state.compute_event.record()
120
- else:
121
- state.computed_u = None
122
- state.compute_event = None
123
 
124
 
125
  @torch.no_grad()
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
129
  """
130
 
131
  with torch.cuda.stream(comm_stream):
 
 
 
 
132
  if rank == state.worker_rank:
133
  num_ranks = dist.get_world_size(group=state.process_group)
134
- if state.compute_event is None:
135
- raise RuntimeError("Compute event must be set before scatter.")
136
- comm_stream.wait_event(state.compute_event)
137
-
138
  # Clear the gathered gradient to free memory
139
  state.gathered_grad = None
140
 
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
144
  else:
145
  scatter_list = None
146
 
147
- u_received = torch.empty_like(p.to_local())
148
  torch.distributed.scatter(
149
- u_received,
150
  scatter_list=scatter_list,
151
  src=state.worker_rank,
152
  group=state.process_group,
153
  )
154
- u_dtensor = DTensor.from_local(
155
- u_received,
156
- placements=p.placements,
157
- device_mesh=p.device_mesh,
158
- )
159
-
160
- state.scattered_u = u_dtensor
161
  state.scatter_event = torch.cuda.Event()
162
  state.scatter_event.record()
 
163
 
164
 
165
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
172
  if state.scatter_event is None:
173
  raise RuntimeError("Scatter event must be set before update")
174
  compute_stream.wait_event(state.scatter_event)
 
 
 
 
 
 
 
 
175
  if rank == state.worker_rank:
176
  # Free computed_u
177
  state.computed_u = None
178
 
179
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
 
 
180
 
181
 
182
  def default_is_muon(name, x):
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
375
  else:
376
  g = buf
377
 
378
- u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
 
379
 
380
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
381
  Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
433
  def enqueue_computes(start_idx, chunk_size):
434
  for p in ordered_params[start_idx:start_idx + chunk_size]:
435
  state = param_to_state[id(p)]
436
- _compute_u(state, group["ns_steps"], self.rank,
437
  self.compute_stream)
438
 
439
  def enqueue_scatters(start_idx, chunk_size):
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
466
  # Wait for the last update_param to finish
467
  torch.cuda.current_stream().wait_stream(self.compute_stream)
468
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
  def step(self, closure=None):
470
  """Perform a single optimization step.
471
 
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
542
  # AdamW backup #
543
  ############################
544
 
 
 
 
 
 
 
545
  lr = group["lr"]
546
  beta1, beta2 = group["adamw_betas"]
547
  eps = group["adamw_eps"]
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
552
  if g is None:
553
  continue
554
  state = self.state[p]
 
 
555
  if "step" not in state:
556
- state["step"] = 0
 
 
557
  state["moment1"] = torch.zeros_like(g)
558
  state["moment2"] = torch.zeros_like(g)
559
- state["step"] += 1
560
- step = state["step"]
561
- buf1 = state["moment1"]
562
- buf2 = state["moment2"]
563
- buf1.lerp_(g, 1 - beta1)
564
- buf2.lerp_(g.square(), 1 - beta2)
565
-
566
- g = buf1 / (eps + buf2.sqrt())
567
-
568
- bias_correction1 = 1 - beta1**step
569
- bias_correction2 = 1 - beta2**step
570
- scale = bias_correction1 / bias_correction2**0.5
571
- p.data.mul_(1 - lr * weight_decay)
572
- p.data.add_(g, alpha=-lr / scale)
 
 
 
 
 
 
 
 
 
 
 
573
 
574
  return loss
 
2
  import math
3
  import types
4
  from dataclasses import dataclass
5
+ from typing import Optional, Union, cast
6
 
7
  import torch
8
  import torch.distributed as dist
 
13
 
14
  # This code snippet is a modified version adapted from the following GitHub repositories:
15
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
16
+ # Muon's Newton–Schulz iteration causes high variance in the singular values.
17
+ # Idea: give each iteration its own three coefficients and optimize them via gradient descent.
18
  @torch.no_grad()
19
  def _zeropower_via_newtonschulz5(G, steps):
20
  """
 
27
  performance at all relative to UV^T, where USV^T = G is the SVD.
28
  """
29
  assert len(G.shape) == 2
30
+ assert G.dtype == torch.bfloat16
31
  X = G # no manual typecast
32
+
33
  if G.size(0) > G.size(1):
34
  X = X.T
35
  # Ensure spectral norm is at most 1
36
  X = X / (X.norm() + 1e-7)
 
37
  # Perform the NS iterations
38
+ for a, b, c in [
39
+ (4.0848, -6.8946, 2.9270),
40
+ (3.9505, -6.3029, 2.6377),
41
+ (3.7418, -5.5913, 2.3037),
42
+ (2.8769, -3.1427, 1.2046),
43
+ (2.8366, -3.0525, 1.2012),
44
+ ]:
45
  A = X @ X.T
46
  # B = (
47
  # b * A + c * A @ A
 
52
 
53
  if G.size(0) > G.size(1):
54
  X = X.T
55
+ return X
56
 
57
 
58
  @dataclass
 
74
  Gather the gradients to worker_rank.
75
  If none_grad is True, free p.grad after the gather.
76
  """
 
 
 
 
 
 
 
 
 
 
77
  with torch.cuda.stream(comm_stream):
78
+ g = p.grad
79
+
80
+ if rank == state.worker_rank:
81
+ num_ranks = dist.get_world_size(group=state.process_group)
82
+ gather_list = [
83
+ torch.empty_like(g.to_local(), dtype=torch.bfloat16)
84
+ for _ in range(num_ranks)
85
+ ]
86
+ else:
87
+ gather_list = None
88
+
89
+ g = g.to(torch.bfloat16)
90
  torch.distributed.gather(
91
  g.to_local(),
92
  dst=state.worker_rank,
 
103
  else:
104
  state.gathered_grad = None
105
  state.gather_event = None
106
+ gather_list = None
107
  if none_grad:
108
  # We can safely free p.grad without calling record_stream:
109
  # p.grad.to_local().record_stream(comm_stream)
 
116
 
117
 
118
  @torch.no_grad()
119
+ def _compute_u(p, state, steps, rank, compute_stream):
120
  """
121
  On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
122
  """
 
127
  compute_stream.wait_event(state.gather_event)
128
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
129
  state.computed_u = u
130
+ state.scattered_u = torch.empty_like(p.to_local(),
131
+ dtype=torch.bfloat16)
132
+ state.compute_event = torch.cuda.Event()
133
+ state.compute_event.record()
134
+ u = None
135
 
136
 
137
  @torch.no_grad()
 
141
  """
142
 
143
  with torch.cuda.stream(comm_stream):
144
+ if state.compute_event is None:
145
+ raise RuntimeError("Compute event must be set before scatter.")
146
+ comm_stream.wait_event(state.compute_event)
147
+
148
  if rank == state.worker_rank:
149
  num_ranks = dist.get_world_size(group=state.process_group)
 
 
 
 
150
  # Clear the gathered gradient to free memory
151
  state.gathered_grad = None
152
 
 
156
  else:
157
  scatter_list = None
158
 
 
159
  torch.distributed.scatter(
160
+ state.scattered_u,
161
  scatter_list=scatter_list,
162
  src=state.worker_rank,
163
  group=state.process_group,
164
  )
 
 
 
 
 
 
 
165
  state.scatter_event = torch.cuda.Event()
166
  state.scatter_event.record()
167
+ scatter_list = None
168
 
169
 
170
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
 
177
  if state.scatter_event is None:
178
  raise RuntimeError("Scatter event must be set before update")
179
  compute_stream.wait_event(state.scatter_event)
180
+ u_dtensor = DTensor.from_local(
181
+ state.scattered_u,
182
+ placements=p.placements,
183
+ device_mesh=p.device_mesh,
184
+ )
185
+
186
+ state.scattered_u = u_dtensor
187
+
188
  if rank == state.worker_rank:
189
  # Free computed_u
190
  state.computed_u = None
191
 
192
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
193
+ state.scattered_u = None
194
+ u_dtensor = None
195
 
196
 
197
  def default_is_muon(name, x):
 
390
  else:
391
  g = buf
392
 
393
+ u = _zeropower_via_newtonschulz5(g.bfloat16(),
394
+ steps=group["ns_steps"])
395
 
396
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
397
  Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
 
449
  def enqueue_computes(start_idx, chunk_size):
450
  for p in ordered_params[start_idx:start_idx + chunk_size]:
451
  state = param_to_state[id(p)]
452
+ _compute_u(p, state, group["ns_steps"], self.rank,
453
  self.compute_stream)
454
 
455
  def enqueue_scatters(start_idx, chunk_size):
 
482
  # Wait for the last update_param to finish
483
  torch.cuda.current_stream().wait_stream(self.compute_stream)
484
 
485
+ @staticmethod
486
+ def _fused_adamw(
487
+ params: list[torch.Tensor],
488
+ grads: list[torch.Tensor],
489
+ exp_avgs: list[torch.Tensor],
490
+ exp_avg_sqs: list[torch.Tensor],
491
+ max_exp_avg_sqs: list[torch.Tensor],
492
+ state_steps: list[torch.Tensor],
493
+ amsgrad: bool,
494
+ beta1: float,
495
+ beta2: float,
496
+ lr: Union[float, torch.Tensor],
497
+ weight_decay: float,
498
+ eps: float,
499
+ maximize: bool,
500
+ ) -> None:
501
+ if not params:
502
+ return
503
+
504
+ # We only shuffle around the lr when it is a Tensor and on CUDA; otherwise, we prefer
505
+ # treating it as a scalar.
506
+ lr_dict: Optional[DeviceDict] = ({
507
+ lr.device: lr
508
+ } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
509
+ None)
510
+ grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
511
+ [
512
+ params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
513
+ state_steps
514
+ ] # type: ignore[list-item]
515
+ )
516
+ for (device, _), (
517
+ (
518
+ device_params_,
519
+ device_grads_,
520
+ device_exp_avgs_,
521
+ device_exp_avg_sqs_,
522
+ device_max_exp_avg_sqs,
523
+ device_state_steps_,
524
+ ),
525
+ _,
526
+ ) in grouped_tensors.items():
527
+ device_params = cast(list[torch.Tensor], device_params_)
528
+ device_grads = cast(list[torch.Tensor], device_grads_)
529
+ device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
530
+ device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
531
+ device_state_steps = cast(list[torch.Tensor], device_state_steps_)
532
+
533
+ if lr_dict is not None and device not in lr_dict:
534
+ lr_dict[device] = lr.to(
535
+ device=device,
536
+ non_blocking=True) # type: ignore[union-attr]
537
+ lr = lr_dict[device]
538
+ torch._foreach_add_(device_state_steps, 1)
539
+ func = torch._fused_adamw_
540
+ func(
541
+ device_params,
542
+ device_grads,
543
+ device_exp_avgs,
544
+ device_exp_avg_sqs,
545
+ device_max_exp_avg_sqs, # type: ignore[arg-type]
546
+ device_state_steps,
547
+ amsgrad=amsgrad,
548
+ lr=lr, # type: ignore[arg-type]
549
+ beta1=beta1,
550
+ beta2=beta2,
551
+ weight_decay=weight_decay,
552
+ eps=eps,
553
+ maximize=maximize,
554
+ )
555
+
556
  def step(self, closure=None):
557
  """Perform a single optimization step.
558
 
 
629
  # AdamW backup #
630
  ############################
631
 
632
+ params_with_grads = []
633
+ grads = []
634
+ moment1 = []
635
+ moment2 = []
636
+ max_exp_avg_sqs = []
637
+ state_steps = []
638
  lr = group["lr"]
639
  beta1, beta2 = group["adamw_betas"]
640
  eps = group["adamw_eps"]
 
645
  if g is None:
646
  continue
647
  state = self.state[p]
648
+ params_with_grads.append(p)
649
+ grads.append(g)
650
  if "step" not in state:
651
+ state["step"] = (torch.zeros((),
652
+ dtype=torch.float32,
653
+ device=p.device))
654
  state["moment1"] = torch.zeros_like(g)
655
  state["moment2"] = torch.zeros_like(g)
656
+ moment1.append(state["moment1"])
657
+ moment2.append(state["moment2"])
658
+ if not isinstance(state["step"], torch.Tensor):
659
+ step_tensor = torch.tensor(state["step"],
660
+ dtype=torch.float32,
661
+ device=p.device)
662
+ else:
663
+ step_tensor = state["step"]
664
+ state_steps.append(step_tensor)
665
+
666
+ self._fused_adamw(
667
+ params_with_grads,
668
+ grads,
669
+ moment1,
670
+ moment2,
671
+ max_exp_avg_sqs,
672
+ state_steps,
673
+ amsgrad=False,
674
+ beta1=beta1,
675
+ beta2=beta2,
676
+ lr=lr,
677
+ weight_decay=weight_decay,
678
+ eps=eps,
679
+ maximize=False,
680
+ )
681
 
682
  return loss
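
Across these hunks the event chain is rearranged: `_compute_u` records `compute_event` unconditionally and preallocates the bf16 receive buffer, `_scatter` now waits on `compute_event` on every rank (not just the worker), and `_update_param` waits on `scatter_event` before rewrapping and applying the update. The gating pattern, sketched with stand-ins for the collectives (requires a CUDA device; the gather/scatter stand-ins are placeholders, not the real communication):

```python
import torch

comm_stream = torch.cuda.Stream()
compute_stream = torch.cuda.Stream()

def pipeline_step(grad: torch.Tensor, orthogonalize, apply_update):
    with torch.cuda.stream(comm_stream):
        gathered = grad.to(torch.bfloat16)        # stand-in for dist.gather
        gather_event = torch.cuda.Event()
        gather_event.record()
    with torch.cuda.stream(compute_stream):
        compute_stream.wait_event(gather_event)   # compute gated on gather
        u = orthogonalize(gathered)
        compute_event = torch.cuda.Event()
        compute_event.record()
    with torch.cuda.stream(comm_stream):
        comm_stream.wait_event(compute_event)     # scatter gated on compute
        scattered = u.clone()                     # stand-in for dist.scatter
        scatter_event = torch.cuda.Event()
        scatter_event.record()
    with torch.cuda.stream(compute_stream):
        compute_stream.wait_event(scatter_event)  # update gated on scatter
        apply_update(scattered)
    # The step() loop finally joins back with:
    # torch.cuda.current_stream().wait_stream(compute_stream)
```

Each stage records an event on its own stream and the next stage waits on that event, so the two streams overlap across parameters instead of serializing on the whole device.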
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_20250911094409
3
- ops = torch.ops._optimizer_20250911094409
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_20250911094409::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_ee6ed44_dirty
3
+ ops = torch.ops._optimizer_ee6ed44_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_ee6ed44_dirty::{op_name}"
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:62ecfc7e6a1ab0c4ada19ed7aea40fc0a431c4ceb1729666efa98ac0e407f9c8
3
- size 1883360
 
 
 
 
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5e375def39d93758b60534cef504ae75d9c13e0d86da5dcf7642f1f90b77f52
3
+ size 1883352
build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py CHANGED
@@ -2,6 +2,7 @@ import logging
2
  import math
3
  import types
4
  from dataclasses import dataclass
 
5
 
6
  import torch
7
  import torch.distributed as dist
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
12
 
13
  # This code snippet is a modified version adapted from the following GitHub repositories:
14
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
 
 
15
  @torch.no_grad()
16
  def _zeropower_via_newtonschulz5(G, steps):
17
  """
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
24
  performance at all relative to UV^T, where USV^T = G is the SVD.
25
  """
26
  assert len(G.shape) == 2
27
- a, b, c = (3.4445, -4.7750, 2.0315)
28
  X = G # no manual typecast
 
29
  if G.size(0) > G.size(1):
30
  X = X.T
31
  # Ensure spectral norm is at most 1
32
  X = X / (X.norm() + 1e-7)
33
- X = X.bfloat16()
34
  # Perform the NS iterations
35
- for _ in range(steps):
 
 
 
 
 
 
36
  A = X @ X.T
37
  # B = (
38
  # b * A + c * A @ A
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
43
 
44
  if G.size(0) > G.size(1):
45
  X = X.T
46
- return X.to(G.dtype)
47
 
48
 
49
  @dataclass
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
65
  Gather the gradients to worker_rank.
66
  If none_grad is True, free p.grad after the gather.
67
  """
68
- g = p.grad
69
-
70
- if rank == state.worker_rank:
71
- num_ranks = dist.get_world_size(group=state.process_group)
72
- gather_list = [
73
- torch.empty_like(g.to_local()) for _ in range(num_ranks)
74
- ]
75
- else:
76
- gather_list = None
77
-
78
  with torch.cuda.stream(comm_stream):
 
 
 
 
 
 
 
 
 
 
 
 
79
  torch.distributed.gather(
80
  g.to_local(),
81
  dst=state.worker_rank,
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
92
  else:
93
  state.gathered_grad = None
94
  state.gather_event = None
 
95
  if none_grad:
96
  # We can safely free p.grad without calling record_stream:
97
  # p.grad.to_local().record_stream(comm_stream)
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
104
 
105
 
106
  @torch.no_grad()
107
- def _compute_u(state, steps, rank, compute_stream):
108
  """
109
  On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
110
  """
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
115
  compute_stream.wait_event(state.gather_event)
116
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
117
  state.computed_u = u
118
- state.compute_event = torch.cuda.Event()
119
- state.compute_event.record()
120
- else:
121
- state.computed_u = None
122
- state.compute_event = None
123
 
124
 
125
  @torch.no_grad()
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
129
  """
130
 
131
  with torch.cuda.stream(comm_stream):
 
 
 
 
132
  if rank == state.worker_rank:
133
  num_ranks = dist.get_world_size(group=state.process_group)
134
- if state.compute_event is None:
135
- raise RuntimeError("Compute event must be set before scatter.")
136
- comm_stream.wait_event(state.compute_event)
137
-
138
  # Clear the gathered gradient to free memory
139
  state.gathered_grad = None
140
 
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
144
  else:
145
  scatter_list = None
146
 
147
- u_received = torch.empty_like(p.to_local())
148
  torch.distributed.scatter(
149
- u_received,
150
  scatter_list=scatter_list,
151
  src=state.worker_rank,
152
  group=state.process_group,
153
  )
154
- u_dtensor = DTensor.from_local(
155
- u_received,
156
- placements=p.placements,
157
- device_mesh=p.device_mesh,
158
- )
159
-
160
- state.scattered_u = u_dtensor
161
  state.scatter_event = torch.cuda.Event()
162
  state.scatter_event.record()
 
163
 
164
 
165
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
172
  if state.scatter_event is None:
173
  raise RuntimeError("Scatter event must be set before update")
174
  compute_stream.wait_event(state.scatter_event)
 
 
 
 
 
 
 
 
175
  if rank == state.worker_rank:
176
  # Free computed_u
177
  state.computed_u = None
178
 
179
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
 
 
180
 
181
 
182
  def default_is_muon(name, x):
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
375
  else:
376
  g = buf
377
 
378
- u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
 
379
 
380
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
381
  Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
433
  def enqueue_computes(start_idx, chunk_size):
434
  for p in ordered_params[start_idx:start_idx + chunk_size]:
435
  state = param_to_state[id(p)]
436
- _compute_u(state, group["ns_steps"], self.rank,
437
  self.compute_stream)
438
 
439
  def enqueue_scatters(start_idx, chunk_size):
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
466
  # Wait for the last update_param to finish
467
  torch.cuda.current_stream().wait_stream(self.compute_stream)
468
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
  def step(self, closure=None):
470
  """Perform a single optimization step.
471
 
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
542
  # AdamW backup #
543
  ############################
544
 
 
 
 
 
 
 
545
  lr = group["lr"]
546
  beta1, beta2 = group["adamw_betas"]
547
  eps = group["adamw_eps"]
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
552
  if g is None:
553
  continue
554
  state = self.state[p]
 
 
555
  if "step" not in state:
556
- state["step"] = 0
 
 
557
  state["moment1"] = torch.zeros_like(g)
558
  state["moment2"] = torch.zeros_like(g)
559
- state["step"] += 1
560
- step = state["step"]
561
- buf1 = state["moment1"]
562
- buf2 = state["moment2"]
563
- buf1.lerp_(g, 1 - beta1)
564
- buf2.lerp_(g.square(), 1 - beta2)
565
-
566
- g = buf1 / (eps + buf2.sqrt())
567
-
568
- bias_correction1 = 1 - beta1**step
569
- bias_correction2 = 1 - beta2**step
570
- scale = bias_correction1 / bias_correction2**0.5
571
- p.data.mul_(1 - lr * weight_decay)
572
- p.data.add_(g, alpha=-lr / scale)
 
 
 
 
 
 
 
 
 
 
 
573
 
574
  return loss
 
2
  import math
3
  import types
4
  from dataclasses import dataclass
5
+ from typing import Optional, Union, cast
6
 
7
  import torch
8
  import torch.distributed as dist
 
13
 
14
  # This code snippet is a modified version adapted from the following GitHub repositories:
15
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
16
+ # Muon's Newton–Schulz iteration causes high variance in the singular values.
17
+ # Idea: give each iteration its own three coefficients and optimize them via gradient descent.
18
  @torch.no_grad()
19
  def _zeropower_via_newtonschulz5(G, steps):
20
  """
 
27
  performance at all relative to UV^T, where USV^T = G is the SVD.
28
  """
29
  assert len(G.shape) == 2
30
+ assert G.dtype == torch.bfloat16
31
  X = G # no manual typecast
32
+
33
  if G.size(0) > G.size(1):
34
  X = X.T
35
  # Ensure spectral norm is at most 1
36
  X = X / (X.norm() + 1e-7)
 
37
  # Perform the NS iterations
38
+ for a, b, c in [
39
+ (4.0848, -6.8946, 2.9270),
40
+ (3.9505, -6.3029, 2.6377),
41
+ (3.7418, -5.5913, 2.3037),
42
+ (2.8769, -3.1427, 1.2046),
43
+ (2.8366, -3.0525, 1.2012),
44
+ ]:
45
  A = X @ X.T
46
  # B = (
47
  # b * A + c * A @ A
 
52
 
53
  if G.size(0) > G.size(1):
54
  X = X.T
55
+ return X
56
 
57
 
58
  @dataclass
 
74
  Gather the gradients to worker_rank.
75
  If none_grad is True, free p.grad after the gather.
76
  """
 
 
 
 
 
 
 
 
 
 
77
  with torch.cuda.stream(comm_stream):
78
+ g = p.grad
79
+
80
+ if rank == state.worker_rank:
81
+ num_ranks = dist.get_world_size(group=state.process_group)
82
+ gather_list = [
83
+ torch.empty_like(g.to_local(), dtype=torch.bfloat16)
84
+ for _ in range(num_ranks)
85
+ ]
86
+ else:
87
+ gather_list = None
88
+
89
+ g = g.to(torch.bfloat16)
90
  torch.distributed.gather(
91
  g.to_local(),
92
  dst=state.worker_rank,
 
103
  else:
104
  state.gathered_grad = None
105
  state.gather_event = None
106
+ gather_list = None
107
  if none_grad:
108
  # We can safely free p.grad without calling record_stream:
109
  # p.grad.to_local().record_stream(comm_stream)
 
116
 
117
 
118
  @torch.no_grad()
119
+ def _compute_u(p, state, steps, rank, compute_stream):
120
  """
121
  On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
122
  """
 
127
  compute_stream.wait_event(state.gather_event)
128
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
129
  state.computed_u = u
130
+ state.scattered_u = torch.empty_like(p.to_local(),
131
+ dtype=torch.bfloat16)
132
+ state.compute_event = torch.cuda.Event()
133
+ state.compute_event.record()
134
+ u = None
135
 
136
 
137
  @torch.no_grad()
 
141
  """
142
 
143
  with torch.cuda.stream(comm_stream):
144
+ if state.compute_event is None:
145
+ raise RuntimeError("Compute event must be set before scatter.")
146
+ comm_stream.wait_event(state.compute_event)
147
+
148
  if rank == state.worker_rank:
149
  num_ranks = dist.get_world_size(group=state.process_group)
 
 
 
 
150
  # Clear the gathered gradient to free memory
151
  state.gathered_grad = None
152
 
 
156
  else:
157
  scatter_list = None
158
 
 
159
  torch.distributed.scatter(
160
+ state.scattered_u,
161
  scatter_list=scatter_list,
162
  src=state.worker_rank,
163
  group=state.process_group,
164
  )
 
 
 
 
 
 
 
165
  state.scatter_event = torch.cuda.Event()
166
  state.scatter_event.record()
167
+ scatter_list = None
168
 
169
 
170
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
 
177
  if state.scatter_event is None:
178
  raise RuntimeError("Scatter event must be set before update")
179
  compute_stream.wait_event(state.scatter_event)
180
+ u_dtensor = DTensor.from_local(
181
+ state.scattered_u,
182
+ placements=p.placements,
183
+ device_mesh=p.device_mesh,
184
+ )
185
+
186
+ state.scattered_u = u_dtensor
187
+
188
  if rank == state.worker_rank:
189
  # Free computed_u
190
  state.computed_u = None
191
 
192
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
193
+ state.scattered_u = None
194
+ u_dtensor = None
195
 
196
 
197
  def default_is_muon(name, x):
 
390
  else:
391
  g = buf
392
 
393
+ u = _zeropower_via_newtonschulz5(g.bfloat16(),
394
+ steps=group["ns_steps"])
395
 
396
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
397
  Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
 
449
  def enqueue_computes(start_idx, chunk_size):
450
  for p in ordered_params[start_idx:start_idx + chunk_size]:
451
  state = param_to_state[id(p)]
452
+ _compute_u(p, state, group["ns_steps"], self.rank,
453
  self.compute_stream)
454
 
455
  def enqueue_scatters(start_idx, chunk_size):
 
482
  # Wait the last update_param to finish
483
  torch.cuda.current_stream().wait_stream(self.compute_stream)
484
 
485
+ @staticmethod
486
+ def _fused_adamw(
487
+ params: list[torch.Tensor],
488
+ grads: list[torch.Tensor],
489
+ exp_avgs: list[torch.Tensor],
490
+ exp_avg_sqs: list[torch.Tensor],
491
+ max_exp_avg_sqs: list[torch.Tensor],
492
+ state_steps: list[torch.Tensor],
493
+ amsgrad: bool,
494
+ beta1: float,
495
+ beta2: float,
496
+ lr: Union[float, torch.Tensor],
497
+ weight_decay: float,
498
+ eps: float,
499
+ maximize: bool,
500
+ ) -> None:
501
+ if not params:
502
+ return
503
+
504
+ # We only shuffle around the lr when it is a Tensor and on CUDA; otherwise, we prefer
505
+ # treating it as a scalar.
506
+ lr_dict: Optional[DeviceDict] = ({
507
+ lr.device: lr
508
+ } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
509
+ None)
510
+ grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
511
+ [
512
+ params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
513
+ state_steps
514
+ ] # type: ignore[list-item]
515
+ )
516
+ for (device, _), (
517
+ (
518
+ device_params_,
519
+ device_grads_,
520
+ device_exp_avgs_,
521
+ device_exp_avg_sqs_,
522
+ device_max_exp_avg_sqs,
523
+ device_state_steps_,
524
+ ),
525
+ _,
526
+ ) in grouped_tensors.items():
527
+ device_params = cast(list[torch.Tensor], device_params_)
528
+ device_grads = cast(list[torch.Tensor], device_grads_)
529
+ device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
530
+ device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
531
+ device_state_steps = cast(list[torch.Tensor], device_state_steps_)
532
+
533
+ if lr_dict is not None and device not in lr_dict:
534
+ lr_dict[device] = lr.to(
535
+ device=device,
536
+ non_blocking=True) # type: ignore[union-attr]
537
+ lr = lr_dict[device]
538
+ torch._foreach_add_(device_state_steps, 1)
539
+ func = torch._fused_adamw_
540
+ func(
541
+ device_params,
542
+ device_grads,
543
+ device_exp_avgs,
544
+ device_exp_avg_sqs,
545
+ device_max_exp_avg_sqs, # type: ignore[arg-type]
546
+ device_state_steps,
547
+ amsgrad=amsgrad,
548
+ lr=lr, # type: ignore[arg-type]
549
+ beta1=beta1,
550
+ beta2=beta2,
551
+ weight_decay=weight_decay,
552
+ eps=eps,
553
+ maximize=maximize,
554
+ )
555
+
556
  def step(self, closure=None):
557
  """Perform a single optimization step.
558
 
 
629
  # AdamW backup #
630
  ############################
631
 
632
+ params_with_grads = []
633
+ grads = []
634
+ moment1 = []
635
+ moment2 = []
636
+ max_exp_avg_sqs = []
637
+ state_steps = []
638
  lr = group["lr"]
639
  beta1, beta2 = group["adamw_betas"]
640
  eps = group["adamw_eps"]
 
645
  if g is None:
646
  continue
647
  state = self.state[p]
648
+ params_with_grads.append(p)
649
+ grads.append(g)
650
  if "step" not in state:
651
+ state["step"] = (torch.zeros((),
652
+ dtype=torch.float32,
653
+ device=p.device))
654
  state["moment1"] = torch.zeros_like(g)
655
  state["moment2"] = torch.zeros_like(g)
656
+ moment1.append(state["moment1"])
657
+ moment2.append(state["moment2"])
658
+ if not isinstance(state["step"], torch.Tensor):
659
+ step_tensor = torch.tensor(state["step"],
660
+ dtype=torch.float32,
661
+ device=p.device)
662
+ else:
663
+ step_tensor = state["step"]
664
+ state_steps.append(step_tensor)
665
+
666
+ self._fused_adamw(
667
+ params_with_grads,
668
+ grads,
669
+ moment1,
670
+ moment2,
671
+ max_exp_avg_sqs,
672
+ state_steps,
673
+ amsgrad=False,
674
+ beta1=beta1,
675
+ beta2=beta2,
676
+ lr=lr,
677
+ weight_decay=weight_decay,
678
+ eps=eps,
679
+ maximize=False,
680
+ )
681
 
682
  return loss
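
The AdamW backup path is rewritten from per-parameter Python math (`lerp_`, `sqrt`, explicit bias correction) to a single `torch._fused_adamw_` kernel launch over lists of tensors, with `state["step"]` migrated from an int to a float32 scalar tensor on the parameter's device so bias correction happens in-kernel without a host sync. (`DeviceDict` in the hunk presumably comes from `torch.optim.optimizer`; its import is outside the visible context.) A minimal sketch assuming every tensor lives on one CUDA device with one dtype; the real `_fused_adamw` additionally groups tensors by device/dtype and moves a tensor `lr` between devices:

```python
import torch

def fused_adamw_step(params, grads, exp_avgs, exp_avg_sqs, state_steps,
                     lr=1e-3, beta1=0.9, beta2=0.999,
                     weight_decay=0.01, eps=1e-8):
    # One kernel launch for the whole list; steps are float32 scalar
    # tensors on the same CUDA device as the params.
    torch._foreach_add_(state_steps, 1)
    torch._fused_adamw_(
        params, grads, exp_avgs, exp_avg_sqs,
        [],            # max_exp_avg_sqs: unused with amsgrad=False
        state_steps,
        amsgrad=False, lr=lr, beta1=beta1, beta2=beta2,
        weight_decay=weight_decay, eps=eps, maximize=False,
    )

# Usage sketch (CUDA required):
# p  = torch.randn(64, 64, device="cuda")
# g  = torch.randn_like(p)
# m1, m2 = torch.zeros_like(p), torch.zeros_like(p)
# step = torch.zeros((), dtype=torch.float32, device="cuda")
# fused_adamw_step([p], [g], [m1], [m2], [step])
```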
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_20250911094409
3
- ops = torch.ops._optimizer_20250911094409
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_20250911094409::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_ee6ed44_dirty
3
+ ops = torch.ops._optimizer_ee6ed44_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_ee6ed44_dirty::{op_name}"
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:37e389c650fc1fcbc9fbd68f1e7c1a768b08e90509fd8a5d87879655726f2db2
3
- size 1750040
 
 
 
 
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33e0d50fbf340612b0e1129717e4116197c8562592e5920f2dedc718ce9a0585
3
+ size 1750000
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py CHANGED
@@ -2,6 +2,7 @@ import logging
2
  import math
3
  import types
4
  from dataclasses import dataclass
 
5
 
6
  import torch
7
  import torch.distributed as dist
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
12
 
13
  # This code snippet is a modified version adapted from the following GitHub repository:
14
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
 
 
15
  @torch.no_grad()
16
  def _zeropower_via_newtonschulz5(G, steps):
17
  """
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
24
  performance at all relative to UV^T, where USV^T = G is the SVD.
25
  """
26
  assert len(G.shape) == 2
27
- a, b, c = (3.4445, -4.7750, 2.0315)
28
  X = G # no manual typecast
 
29
  if G.size(0) > G.size(1):
30
  X = X.T
31
  # Ensure spectral norm is at most 1
32
  X = X / (X.norm() + 1e-7)
33
- X = X.bfloat16()
34
  # Perform the NS iterations
35
- for _ in range(steps):
36
  A = X @ X.T
37
  # B = (
38
  # b * A + c * A @ A
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
43
 
44
  if G.size(0) > G.size(1):
45
  X = X.T
46
- return X.to(G.dtype)
47
 
48
 
49
  @dataclass
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
65
  Gather the gradients to worker_rank.
66
  If none_grad is True, free p.grad after the gather.
67
  """
68
- g = p.grad
69
-
70
- if rank == state.worker_rank:
71
- num_ranks = dist.get_world_size(group=state.process_group)
72
- gather_list = [
73
- torch.empty_like(g.to_local()) for _ in range(num_ranks)
74
- ]
75
- else:
76
- gather_list = None
77
-
78
  with torch.cuda.stream(comm_stream):
79
  torch.distributed.gather(
80
  g.to_local(),
81
  dst=state.worker_rank,
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
92
  else:
93
  state.gathered_grad = None
94
  state.gather_event = None
 
95
  if none_grad:
96
  # We can safely free p.grad without calling record_stream:
97
  # p.grad.to_local().record_stream(comm_stream)
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
104
 
105
 
106
  @torch.no_grad()
107
- def _compute_u(state, steps, rank, compute_stream):
108
  """
109
  On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
110
  """
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
115
  compute_stream.wait_event(state.gather_event)
116
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
117
  state.computed_u = u
118
- state.compute_event = torch.cuda.Event()
119
- state.compute_event.record()
120
- else:
121
- state.computed_u = None
122
- state.compute_event = None
123
 
124
 
125
  @torch.no_grad()
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
129
  """
130
 
131
  with torch.cuda.stream(comm_stream):
 
 
 
 
132
  if rank == state.worker_rank:
133
  num_ranks = dist.get_world_size(group=state.process_group)
134
- if state.compute_event is None:
135
- raise RuntimeError("Compute event must be set before scatter.")
136
- comm_stream.wait_event(state.compute_event)
137
-
138
  # Clear the gathered gradient to free memory
139
  state.gathered_grad = None
140
 
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
144
  else:
145
  scatter_list = None
146
 
147
- u_received = torch.empty_like(p.to_local())
148
  torch.distributed.scatter(
149
- u_received,
150
  scatter_list=scatter_list,
151
  src=state.worker_rank,
152
  group=state.process_group,
153
  )
154
- u_dtensor = DTensor.from_local(
155
- u_received,
156
- placements=p.placements,
157
- device_mesh=p.device_mesh,
158
- )
159
-
160
- state.scattered_u = u_dtensor
161
  state.scatter_event = torch.cuda.Event()
162
  state.scatter_event.record()
 
163
 
164
 
165
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
172
  if state.scatter_event is None:
173
  raise RuntimeError("Scatter event must be set before update")
174
  compute_stream.wait_event(state.scatter_event)
175
  if rank == state.worker_rank:
176
  # Free computed_u
177
  state.computed_u = None
178
 
179
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
 
 
180
 
181
 
182
  def default_is_muon(name, x):
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
375
  else:
376
  g = buf
377
 
378
- u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
 
379
 
380
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
381
  Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
433
  def enqueue_computes(start_idx, chunk_size):
434
  for p in ordered_params[start_idx:start_idx + chunk_size]:
435
  state = param_to_state[id(p)]
436
- _compute_u(state, group["ns_steps"], self.rank,
437
  self.compute_stream)
438
 
439
  def enqueue_scatters(start_idx, chunk_size):
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
466
  # Wait for the last update_param to finish
467
  torch.cuda.current_stream().wait_stream(self.compute_stream)
468

469
  def step(self, closure=None):
470
  """Perform a single optimization step.
471
 
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
542
  # AdamW backup #
543
  ############################
544
545
  lr = group["lr"]
546
  beta1, beta2 = group["adamw_betas"]
547
  eps = group["adamw_eps"]
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
552
  if g is None:
553
  continue
554
  state = self.state[p]
 
 
555
  if "step" not in state:
556
- state["step"] = 0
 
 
557
  state["moment1"] = torch.zeros_like(g)
558
  state["moment2"] = torch.zeros_like(g)
559
- state["step"] += 1
560
- step = state["step"]
561
- buf1 = state["moment1"]
562
- buf2 = state["moment2"]
563
- buf1.lerp_(g, 1 - beta1)
564
- buf2.lerp_(g.square(), 1 - beta2)
565
-
566
- g = buf1 / (eps + buf2.sqrt())
567
-
568
- bias_correction1 = 1 - beta1**step
569
- bias_correction2 = 1 - beta2**step
570
- scale = bias_correction1 / bias_correction2**0.5
571
- p.data.mul_(1 - lr * weight_decay)
572
- p.data.add_(g, alpha=-lr / scale)
573
 
574
  return loss
 
2
  import math
3
  import types
4
  from dataclasses import dataclass
5
+ from typing import Optional, Union, cast
6
 
7
  import torch
8
  import torch.distributed as dist
 
13
 
14
  # This code snippet is a modified version adapted from the following GitHub repository:
15
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
16
+ # Muon's Newton–Schulz iteration causes high variance in the singular values.
17
+ # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
18
  @torch.no_grad()
19
  def _zeropower_via_newtonschulz5(G, steps):
20
  """
 
27
  performance at all relative to UV^T, where USV^T = G is the SVD.
28
  """
29
  assert len(G.shape) == 2
30
+ assert G.dtype == torch.bfloat16
31
  X = G # no manual typecast
32
+
33
  if G.size(0) > G.size(1):
34
  X = X.T
35
  # Ensure spectral norm is at most 1
36
  X = X / (X.norm() + 1e-7)
 
37
  # Perform the NS iterations
38
+ for a, b, c in [
39
+ (4.0848, -6.8946, 2.9270),
40
+ (3.9505, -6.3029, 2.6377),
41
+ (3.7418, -5.5913, 2.3037),
42
+ (2.8769, -3.1427, 1.2046),
43
+ (2.8366, -3.0525, 1.2012),
44
+ ]:
45
  A = X @ X.T
46
  # B = (
47
  # b * A + c * A @ A
 
52
 
53
  if G.size(0) > G.size(1):
54
  X = X.T
55
+ return X
56
 
57
 
58
  @dataclass
 
74
  Gather the gradients to worker_rank.
75
  If none_grad is True, free p.grad after the gather.
76
  """
 
 
 
 
 
 
 
 
 
 
77
  with torch.cuda.stream(comm_stream):
78
+ g = p.grad
79
+
80
+ if rank == state.worker_rank:
81
+ num_ranks = dist.get_world_size(group=state.process_group)
82
+ gather_list = [
83
+ torch.empty_like(g.to_local(), dtype=torch.bfloat16)
84
+ for _ in range(num_ranks)
85
+ ]
86
+ else:
87
+ gather_list = None
88
+
89
+ g = g.to(torch.bfloat16)
90
  torch.distributed.gather(
91
  g.to_local(),
92
  dst=state.worker_rank,
 
103
  else:
104
  state.gathered_grad = None
105
  state.gather_event = None
106
+ gather_list = None
107
  if none_grad:
108
  # We can safely free p.grad without calling record_stream:
109
  # p.grad.to_local().record_stream(comm_stream)
 
116
 
117
 
118
  @torch.no_grad()
119
+ def _compute_u(p, state, steps, rank, compute_stream):
120
  """
121
  On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
122
  """
 
127
  compute_stream.wait_event(state.gather_event)
128
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
129
  state.computed_u = u
130
+ state.scattered_u = torch.empty_like(p.to_local(),
131
+ dtype=torch.bfloat16)
132
+ state.compute_event = torch.cuda.Event()
133
+ state.compute_event.record()
134
+ u = None
135
 
136
 
137
  @torch.no_grad()
 
141
  """
142
 
143
  with torch.cuda.stream(comm_stream):
144
+ if state.compute_event is None:
145
+ raise RuntimeError("Compute event must be set before scatter.")
146
+ comm_stream.wait_event(state.compute_event)
147
+
148
  if rank == state.worker_rank:
149
  num_ranks = dist.get_world_size(group=state.process_group)
 
 
 
 
150
  # Clear the gathered gradient to free memory
151
  state.gathered_grad = None
152
 
 
156
  else:
157
  scatter_list = None
158
 
 
159
  torch.distributed.scatter(
160
+ state.scattered_u,
161
  scatter_list=scatter_list,
162
  src=state.worker_rank,
163
  group=state.process_group,
164
  )
165
  state.scatter_event = torch.cuda.Event()
166
  state.scatter_event.record()
167
+ scatter_list = None
168
 
169
 
170
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
 
177
  if state.scatter_event is None:
178
  raise RuntimeError("Scatter event must be set before update")
179
  compute_stream.wait_event(state.scatter_event)
180
+ u_dtensor = DTensor.from_local(
181
+ state.scattered_u,
182
+ placements=p.placements,
183
+ device_mesh=p.device_mesh,
184
+ )
185
+
186
+ state.scattered_u = u_dtensor
187
+
188
  if rank == state.worker_rank:
189
  # Free computed_u
190
  state.computed_u = None
191
 
192
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
193
+ state.scattered_u = None
194
+ u_dtensor = None
195
 
196
 
197
  def default_is_muon(name, x):
 
390
  else:
391
  g = buf
392
 
393
+ u = _zeropower_via_newtonschulz5(g.bfloat16(),
394
+ steps=group["ns_steps"])
395
 
396
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
397
  Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
 
449
  def enqueue_computes(start_idx, chunk_size):
450
  for p in ordered_params[start_idx:start_idx + chunk_size]:
451
  state = param_to_state[id(p)]
452
+ _compute_u(p, state, group["ns_steps"], self.rank,
453
  self.compute_stream)
454
 
455
  def enqueue_scatters(start_idx, chunk_size):
 
482
  # Wait for the last update_param to finish
483
  torch.cuda.current_stream().wait_stream(self.compute_stream)
484
 
485
+ @staticmethod
486
+ def _fused_adamw(
487
+ params: list[torch.Tensor],
488
+ grads: list[torch.Tensor],
489
+ exp_avgs: list[torch.Tensor],
490
+ exp_avg_sqs: list[torch.Tensor],
491
+ max_exp_avg_sqs: list[torch.Tensor],
492
+ state_steps: list[torch.Tensor],
493
+ amsgrad: bool,
494
+ beta1: float,
495
+ beta2: float,
496
+ lr: Union[float, torch.Tensor],
497
+ weight_decay: float,
498
+ eps: float,
499
+ maximize: bool,
500
+ ) -> None:
501
+ if not params:
502
+ return
503
+
504
+ # We only shuffle the lr around when it is a Tensor and not on the CPU; otherwise we prefer
505
+ # treating it as a scalar.
506
+ lr_dict: Optional[DeviceDict] = ({
507
+ lr.device: lr
508
+ } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
509
+ None)
510
+ grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
511
+ [
512
+ params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
513
+ state_steps
514
+ ] # type: ignore[list-item]
515
+ )
516
+ for (device, _), (
517
+ (
518
+ device_params_,
519
+ device_grads_,
520
+ device_exp_avgs_,
521
+ device_exp_avg_sqs_,
522
+ device_max_exp_avg_sqs,
523
+ device_state_steps_,
524
+ ),
525
+ _,
526
+ ) in grouped_tensors.items():
527
+ device_params = cast(list[torch.Tensor], device_params_)
528
+ device_grads = cast(list[torch.Tensor], device_grads_)
529
+ device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
530
+ device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
531
+ device_state_steps = cast(list[torch.Tensor], device_state_steps_)
532
+
533
+ if lr_dict is not None and device not in lr_dict:
534
+ lr_dict[device] = lr.to(
535
+ device=device,
536
+ non_blocking=True) # type: ignore[union-attr]
537
+ lr = lr_dict[device]
538
+ torch._foreach_add_(device_state_steps, 1)
539
+ func = torch._fused_adamw_
540
+ func(
541
+ device_params,
542
+ device_grads,
543
+ device_exp_avgs,
544
+ device_exp_avg_sqs,
545
+ device_max_exp_avg_sqs, # type: ignore[arg-type]
546
+ device_state_steps,
547
+ amsgrad=amsgrad,
548
+ lr=lr, # type: ignore[arg-type]
549
+ beta1=beta1,
550
+ beta2=beta2,
551
+ weight_decay=weight_decay,
552
+ eps=eps,
553
+ maximize=maximize,
554
+ )
555
+
556
  def step(self, closure=None):
557
  """Perform a single optimization step.
558
 
 
629
  # AdamW backup #
630
  ############################
631
 
632
+ params_with_grads = []
633
+ grads = []
634
+ moment1 = []
635
+ moment2 = []
636
+ max_exp_avg_sqs = []
637
+ state_steps = []
638
  lr = group["lr"]
639
  beta1, beta2 = group["adamw_betas"]
640
  eps = group["adamw_eps"]
 
645
  if g is None:
646
  continue
647
  state = self.state[p]
648
+ params_with_grads.append(p)
649
+ grads.append(g)
650
  if "step" not in state:
651
+ state["step"] = (torch.zeros((),
652
+ dtype=torch.float32,
653
+ device=p.device))
654
  state["moment1"] = torch.zeros_like(g)
655
  state["moment2"] = torch.zeros_like(g)
656
+ moment1.append(state["moment1"])
657
+ moment2.append(state["moment2"])
658
+ if not isinstance(state["step"], torch.Tensor):
659
+ step_tensor = torch.tensor(state["step"],
660
+ dtype=torch.float32,
661
+ device=p.device)
662
+ else:
663
+ step_tensor = state["step"]
664
+ state_steps.append(step_tensor)
665
+
666
+ self._fused_adamw(
667
+ params_with_grads,
668
+ grads,
669
+ moment1,
670
+ moment2,
671
+ max_exp_avg_sqs,
672
+ state_steps,
673
+ amsgrad=False,
674
+ beta1=beta1,
675
+ beta2=beta2,
676
+ lr=lr,
677
+ weight_decay=weight_decay,
678
+ eps=eps,
679
+ maximize=False,
680
+ )
681
 
682
  return loss
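
The functional core of this file's change is easiest to read in isolation: the fixed (a, b, c) = (3.4445, -4.7750, 2.0315) repeated for ns_steps rounds becomes five per-iteration tuned triples, and the routine now expects and returns bfloat16 rather than casting internally. A self-contained sketch of the tuned iteration (illustrative, not the shipped kernel):

import torch

TUNED_ABC = [
    (4.0848, -6.8946, 2.9270),
    (3.9505, -6.3029, 2.6377),
    (3.7418, -5.5913, 2.3037),
    (2.8769, -3.1427, 1.2046),
    (2.8366, -3.0525, 1.2012),
]

@torch.no_grad()
def orthogonalize(G: torch.Tensor) -> torch.Tensor:
    X = G.bfloat16()
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    X = X / (X.norm() + 1e-7)  # spectral norm <= 1 so the iteration converges
    for a, b, c in TUNED_ABC:
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X  # quintic polynomial update
    return X.T if transposed else X

U = orthogonalize(torch.randn(64, 128)).float()
print((U @ U.T - torch.eye(64)).norm())  # small: rows are near-orthonormal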
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_20250911094409
3
- ops = torch.ops._optimizer_20250911094409
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_20250911094409::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_ee6ed44_dirty
3
+ ops = torch.ops._optimizer_ee6ed44_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_ee6ed44_dirty::{op_name}"
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e62682b711f002505bb17c170b2bb233f8d389510ff8e2e0a753ee96d11d0746
3
- size 1750128
 
 
 
 
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5eedf56e661a7d314727e40f192236dbd9696f62ba21f11e366643f2662c03a4
3
+ size 1750088
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py CHANGED
@@ -2,6 +2,7 @@ import logging
2
  import math
3
  import types
4
  from dataclasses import dataclass
 
5
 
6
  import torch
7
  import torch.distributed as dist
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
12
 
13
  # This code snippet is a modified version adapted from the following GitHub repository:
14
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
 
 
15
  @torch.no_grad()
16
  def _zeropower_via_newtonschulz5(G, steps):
17
  """
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
24
  performance at all relative to UV^T, where USV^T = G is the SVD.
25
  """
26
  assert len(G.shape) == 2
27
- a, b, c = (3.4445, -4.7750, 2.0315)
28
  X = G # no manual typecast
 
29
  if G.size(0) > G.size(1):
30
  X = X.T
31
  # Ensure spectral norm is at most 1
32
  X = X / (X.norm() + 1e-7)
33
- X = X.bfloat16()
34
  # Perform the NS iterations
35
- for _ in range(steps):
36
  A = X @ X.T
37
  # B = (
38
  # b * A + c * A @ A
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
43
 
44
  if G.size(0) > G.size(1):
45
  X = X.T
46
- return X.to(G.dtype)
47
 
48
 
49
  @dataclass
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
65
  Gather the gradients to worker_rank.
66
  If none_grad is True, free p.grad after the gather.
67
  """
68
- g = p.grad
69
-
70
- if rank == state.worker_rank:
71
- num_ranks = dist.get_world_size(group=state.process_group)
72
- gather_list = [
73
- torch.empty_like(g.to_local()) for _ in range(num_ranks)
74
- ]
75
- else:
76
- gather_list = None
77
-
78
  with torch.cuda.stream(comm_stream):
79
  torch.distributed.gather(
80
  g.to_local(),
81
  dst=state.worker_rank,
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
92
  else:
93
  state.gathered_grad = None
94
  state.gather_event = None
 
95
  if none_grad:
96
  # We can safely free p.grad without calling record_stream:
97
  # p.grad.to_local().record_stream(comm_stream)
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
104
 
105
 
106
  @torch.no_grad()
107
- def _compute_u(state, steps, rank, compute_stream):
108
  """
109
  On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
110
  """
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
115
  compute_stream.wait_event(state.gather_event)
116
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
117
  state.computed_u = u
118
- state.compute_event = torch.cuda.Event()
119
- state.compute_event.record()
120
- else:
121
- state.computed_u = None
122
- state.compute_event = None
123
 
124
 
125
  @torch.no_grad()
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
129
  """
130
 
131
  with torch.cuda.stream(comm_stream):
 
 
 
 
132
  if rank == state.worker_rank:
133
  num_ranks = dist.get_world_size(group=state.process_group)
134
- if state.compute_event is None:
135
- raise RuntimeError("Compute event must be set before scatter.")
136
- comm_stream.wait_event(state.compute_event)
137
-
138
  # Clear the gathered gradient to free memory
139
  state.gathered_grad = None
140
 
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
144
  else:
145
  scatter_list = None
146
 
147
- u_received = torch.empty_like(p.to_local())
148
  torch.distributed.scatter(
149
- u_received,
150
  scatter_list=scatter_list,
151
  src=state.worker_rank,
152
  group=state.process_group,
153
  )
154
- u_dtensor = DTensor.from_local(
155
- u_received,
156
- placements=p.placements,
157
- device_mesh=p.device_mesh,
158
- )
159
-
160
- state.scattered_u = u_dtensor
161
  state.scatter_event = torch.cuda.Event()
162
  state.scatter_event.record()
 
163
 
164
 
165
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
172
  if state.scatter_event is None:
173
  raise RuntimeError("Scatter event must be set before update")
174
  compute_stream.wait_event(state.scatter_event)
175
  if rank == state.worker_rank:
176
  # Free computed_u
177
  state.computed_u = None
178
 
179
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
 
 
180
 
181
 
182
  def default_is_muon(name, x):
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
375
  else:
376
  g = buf
377
 
378
- u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
 
379
 
380
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
381
  Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
433
  def enqueue_computes(start_idx, chunk_size):
434
  for p in ordered_params[start_idx:start_idx + chunk_size]:
435
  state = param_to_state[id(p)]
436
- _compute_u(state, group["ns_steps"], self.rank,
437
  self.compute_stream)
438
 
439
  def enqueue_scatters(start_idx, chunk_size):
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
466
  # Wait for the last update_param to finish
467
  torch.cuda.current_stream().wait_stream(self.compute_stream)
468

469
  def step(self, closure=None):
470
  """Perform a single optimization step.
471
 
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
542
  # AdamW backup #
543
  ############################
544
545
  lr = group["lr"]
546
  beta1, beta2 = group["adamw_betas"]
547
  eps = group["adamw_eps"]
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
552
  if g is None:
553
  continue
554
  state = self.state[p]
 
 
555
  if "step" not in state:
556
- state["step"] = 0
 
 
557
  state["moment1"] = torch.zeros_like(g)
558
  state["moment2"] = torch.zeros_like(g)
559
- state["step"] += 1
560
- step = state["step"]
561
- buf1 = state["moment1"]
562
- buf2 = state["moment2"]
563
- buf1.lerp_(g, 1 - beta1)
564
- buf2.lerp_(g.square(), 1 - beta2)
565
-
566
- g = buf1 / (eps + buf2.sqrt())
567
-
568
- bias_correction1 = 1 - beta1**step
569
- bias_correction2 = 1 - beta2**step
570
- scale = bias_correction1 / bias_correction2**0.5
571
- p.data.mul_(1 - lr * weight_decay)
572
- p.data.add_(g, alpha=-lr / scale)
573
 
574
  return loss
 
2
  import math
3
  import types
4
  from dataclasses import dataclass
5
+ from typing import Optional, Union, cast
6
 
7
  import torch
8
  import torch.distributed as dist
 
13
 
14
  # This code snippet is a modified version adapted from the following GitHub repository:
15
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
16
+ # Muon's Newton–Schulz iteration causes high variance in the singular values.
17
+ # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
18
  @torch.no_grad()
19
  def _zeropower_via_newtonschulz5(G, steps):
20
  """
 
27
  performance at all relative to UV^T, where USV^T = G is the SVD.
28
  """
29
  assert len(G.shape) == 2
30
+ assert G.dtype == torch.bfloat16
31
  X = G # no manual typecast
32
+
33
  if G.size(0) > G.size(1):
34
  X = X.T
35
  # Ensure spectral norm is at most 1
36
  X = X / (X.norm() + 1e-7)
 
37
  # Perform the NS iterations
38
+ for a, b, c in [
39
+ (4.0848, -6.8946, 2.9270),
40
+ (3.9505, -6.3029, 2.6377),
41
+ (3.7418, -5.5913, 2.3037),
42
+ (2.8769, -3.1427, 1.2046),
43
+ (2.8366, -3.0525, 1.2012),
44
+ ]:
45
  A = X @ X.T
46
  # B = (
47
  # b * A + c * A @ A
 
52
 
53
  if G.size(0) > G.size(1):
54
  X = X.T
55
+ return X
56
 
57
 
58
  @dataclass
 
74
  Gather the gradients to worker_rank.
75
  If none_grad is True, free p.grad after the gather.
76
  """
 
 
 
 
 
 
 
 
 
 
77
  with torch.cuda.stream(comm_stream):
78
+ g = p.grad
79
+
80
+ if rank == state.worker_rank:
81
+ num_ranks = dist.get_world_size(group=state.process_group)
82
+ gather_list = [
83
+ torch.empty_like(g.to_local(), dtype=torch.bfloat16)
84
+ for _ in range(num_ranks)
85
+ ]
86
+ else:
87
+ gather_list = None
88
+
89
+ g = g.to(torch.bfloat16)
90
  torch.distributed.gather(
91
  g.to_local(),
92
  dst=state.worker_rank,
 
103
  else:
104
  state.gathered_grad = None
105
  state.gather_event = None
106
+ gather_list = None
107
  if none_grad:
108
  # We can safely free p.grad without calling record_stream:
109
  # p.grad.to_local().record_stream(comm_stream)
 
116
 
117
 
118
  @torch.no_grad()
119
+ def _compute_u(p, state, steps, rank, compute_stream):
120
  """
121
  On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
122
  """
 
127
  compute_stream.wait_event(state.gather_event)
128
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
129
  state.computed_u = u
130
+ state.scattered_u = torch.empty_like(p.to_local(),
131
+ dtype=torch.bfloat16)
132
+ state.compute_event = torch.cuda.Event()
133
+ state.compute_event.record()
134
+ u = None
135
 
136
 
137
  @torch.no_grad()
 
141
  """
142
 
143
  with torch.cuda.stream(comm_stream):
144
+ if state.compute_event is None:
145
+ raise RuntimeError("Compute event must be set before scatter.")
146
+ comm_stream.wait_event(state.compute_event)
147
+
148
  if rank == state.worker_rank:
149
  num_ranks = dist.get_world_size(group=state.process_group)
 
 
 
 
150
  # Clear the gathered gradient to free memory
151
  state.gathered_grad = None
152
 
 
156
  else:
157
  scatter_list = None
158
 
 
159
  torch.distributed.scatter(
160
+ state.scattered_u,
161
  scatter_list=scatter_list,
162
  src=state.worker_rank,
163
  group=state.process_group,
164
  )
165
  state.scatter_event = torch.cuda.Event()
166
  state.scatter_event.record()
167
+ scatter_list = None
168
 
169
 
170
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
 
177
  if state.scatter_event is None:
178
  raise RuntimeError("Scatter event must be set before update")
179
  compute_stream.wait_event(state.scatter_event)
180
+ u_dtensor = DTensor.from_local(
181
+ state.scattered_u,
182
+ placements=p.placements,
183
+ device_mesh=p.device_mesh,
184
+ )
185
+
186
+ state.scattered_u = u_dtensor
187
+
188
  if rank == state.worker_rank:
189
  # Free computed_u
190
  state.computed_u = None
191
 
192
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
193
+ state.scattered_u = None
194
+ u_dtensor = None
195
 
196
 
197
  def default_is_muon(name, x):
 
390
  else:
391
  g = buf
392
 
393
+ u = _zeropower_via_newtonschulz5(g.bfloat16(),
394
+ steps=group["ns_steps"])
395
 
396
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
397
  Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
 
449
  def enqueue_computes(start_idx, chunk_size):
450
  for p in ordered_params[start_idx:start_idx + chunk_size]:
451
  state = param_to_state[id(p)]
452
+ _compute_u(p, state, group["ns_steps"], self.rank,
453
  self.compute_stream)
454
 
455
  def enqueue_scatters(start_idx, chunk_size):
 
482
  # Wait for the last update_param to finish
483
  torch.cuda.current_stream().wait_stream(self.compute_stream)
484
 
485
+ @staticmethod
486
+ def _fused_adamw(
487
+ params: list[torch.Tensor],
488
+ grads: list[torch.Tensor],
489
+ exp_avgs: list[torch.Tensor],
490
+ exp_avg_sqs: list[torch.Tensor],
491
+ max_exp_avg_sqs: list[torch.Tensor],
492
+ state_steps: list[torch.Tensor],
493
+ amsgrad: bool,
494
+ beta1: float,
495
+ beta2: float,
496
+ lr: Union[float, torch.Tensor],
497
+ weight_decay: float,
498
+ eps: float,
499
+ maximize: bool,
500
+ ) -> None:
501
+ if not params:
502
+ return
503
+
504
+ # We only shuffle the lr around when it is a Tensor and not on the CPU; otherwise we prefer
505
+ # treating it as a scalar.
506
+ lr_dict: Optional[DeviceDict] = ({
507
+ lr.device: lr
508
+ } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
509
+ None)
510
+ grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
511
+ [
512
+ params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
513
+ state_steps
514
+ ] # type: ignore[list-item]
515
+ )
516
+ for (device, _), (
517
+ (
518
+ device_params_,
519
+ device_grads_,
520
+ device_exp_avgs_,
521
+ device_exp_avg_sqs_,
522
+ device_max_exp_avg_sqs,
523
+ device_state_steps_,
524
+ ),
525
+ _,
526
+ ) in grouped_tensors.items():
527
+ device_params = cast(list[torch.Tensor], device_params_)
528
+ device_grads = cast(list[torch.Tensor], device_grads_)
529
+ device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
530
+ device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
531
+ device_state_steps = cast(list[torch.Tensor], device_state_steps_)
532
+
533
+ if lr_dict is not None and device not in lr_dict:
534
+ lr_dict[device] = lr.to(
535
+ device=device,
536
+ non_blocking=True) # type: ignore[union-attr]
537
+ lr = lr_dict[device]
538
+ torch._foreach_add_(device_state_steps, 1)
539
+ func = torch._fused_adamw_
540
+ func(
541
+ device_params,
542
+ device_grads,
543
+ device_exp_avgs,
544
+ device_exp_avg_sqs,
545
+ device_max_exp_avg_sqs, # type: ignore[arg-type]
546
+ device_state_steps,
547
+ amsgrad=amsgrad,
548
+ lr=lr, # type: ignore[arg-type]
549
+ beta1=beta1,
550
+ beta2=beta2,
551
+ weight_decay=weight_decay,
552
+ eps=eps,
553
+ maximize=maximize,
554
+ )
555
+
556
  def step(self, closure=None):
557
  """Perform a single optimization step.
558
 
 
629
  # AdamW backup #
630
  ############################
631
 
632
+ params_with_grads = []
633
+ grads = []
634
+ moment1 = []
635
+ moment2 = []
636
+ max_exp_avg_sqs = []
637
+ state_steps = []
638
  lr = group["lr"]
639
  beta1, beta2 = group["adamw_betas"]
640
  eps = group["adamw_eps"]
 
645
  if g is None:
646
  continue
647
  state = self.state[p]
648
+ params_with_grads.append(p)
649
+ grads.append(g)
650
  if "step" not in state:
651
+ state["step"] = (torch.zeros((),
652
+ dtype=torch.float32,
653
+ device=p.device))
654
  state["moment1"] = torch.zeros_like(g)
655
  state["moment2"] = torch.zeros_like(g)
656
+ moment1.append(state["moment1"])
657
+ moment2.append(state["moment2"])
658
+ if not isinstance(state["step"], torch.Tensor):
659
+ step_tensor = torch.tensor(state["step"],
660
+ dtype=torch.float32,
661
+ device=p.device)
662
+ else:
663
+ step_tensor = state["step"]
664
+ state_steps.append(step_tensor)
665
+
666
+ self._fused_adamw(
667
+ params_with_grads,
668
+ grads,
669
+ moment1,
670
+ moment2,
671
+ max_exp_avg_sqs,
672
+ state_steps,
673
+ amsgrad=False,
674
+ beta1=beta1,
675
+ beta2=beta2,
676
+ lr=lr,
677
+ weight_decay=weight_decay,
678
+ eps=eps,
679
+ maximize=False,
680
+ )
681
 
682
  return loss
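
The communication half of the change, pulled out of the diff: gradients are cast to bfloat16 before the gather, the receive buffer for the scatter is preallocated in bfloat16 (in _compute_u above), and only _update_param wraps the local shard back into a DTensor. A single-rank gloo sketch that exercises the same collective calls (illustrative; the shipped path runs on CUDA streams and typically NCCL, and bfloat16 over gloo assumes a recent PyTorch):

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

worker_rank = 0
grad = torch.randn(8, 8)

# Gather: every rank sends its bf16 shard to worker_rank.
gather_list = ([torch.empty_like(grad, dtype=torch.bfloat16)
                for _ in range(dist.get_world_size())]
               if dist.get_rank() == worker_rank else None)
dist.gather(grad.to(torch.bfloat16), gather_list=gather_list, dst=worker_rank)

# ... the worker would orthogonalize the gathered gradient here ...

# Scatter: the worker sends each rank its bf16 slice of the update,
# into a buffer preallocated while the compute was still in flight.
recv = torch.empty_like(grad, dtype=torch.bfloat16)
scatter_list = gather_list if dist.get_rank() == worker_rank else None
dist.scatter(recv, scatter_list=scatter_list, src=worker_rank)

dist.destroy_process_group()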
torch-ext/optimizer/muon.py CHANGED
@@ -2,6 +2,7 @@ import logging
2
  import math
3
  import types
4
  from dataclasses import dataclass
 
5
 
6
  import torch
7
  import torch.distributed as dist
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
12
 
13
  # This code snippet is a modified version adapted from the following GitHub repository:
14
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
 
 
15
  @torch.no_grad()
16
  def _zeropower_via_newtonschulz5(G, steps):
17
  """
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
24
  performance at all relative to UV^T, where USV^T = G is the SVD.
25
  """
26
  assert len(G.shape) == 2
27
- a, b, c = (3.4445, -4.7750, 2.0315)
28
  X = G # no manual typecast
 
29
  if G.size(0) > G.size(1):
30
  X = X.T
31
  # Ensure spectral norm is at most 1
32
  X = X / (X.norm() + 1e-7)
33
- X = X.bfloat16()
34
  # Perform the NS iterations
35
- for _ in range(steps):
36
  A = X @ X.T
37
  # B = (
38
  # b * A + c * A @ A
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
43
 
44
  if G.size(0) > G.size(1):
45
  X = X.T
46
- return X.to(G.dtype)
47
 
48
 
49
  @dataclass
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
65
  Gather the gradients to worker_rank.
66
  If none_grad is True, free p.grad after the gather.
67
  """
68
- g = p.grad
69
-
70
- if rank == state.worker_rank:
71
- num_ranks = dist.get_world_size(group=state.process_group)
72
- gather_list = [
73
- torch.empty_like(g.to_local()) for _ in range(num_ranks)
74
- ]
75
- else:
76
- gather_list = None
77
-
78
  with torch.cuda.stream(comm_stream):
79
  torch.distributed.gather(
80
  g.to_local(),
81
  dst=state.worker_rank,
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
92
  else:
93
  state.gathered_grad = None
94
  state.gather_event = None
 
95
  if none_grad:
96
  # We can safely free p.grad without calling record_stream:
97
  # p.grad.to_local().record_stream(comm_stream)
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
104
 
105
 
106
  @torch.no_grad()
107
- def _compute_u(state, steps, rank, compute_stream):
108
  """
109
  On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
110
  """
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
115
  compute_stream.wait_event(state.gather_event)
116
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
117
  state.computed_u = u
118
- state.compute_event = torch.cuda.Event()
119
- state.compute_event.record()
120
- else:
121
- state.computed_u = None
122
- state.compute_event = None
123
 
124
 
125
  @torch.no_grad()
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
129
  """
130
 
131
  with torch.cuda.stream(comm_stream):
 
 
 
 
132
  if rank == state.worker_rank:
133
  num_ranks = dist.get_world_size(group=state.process_group)
134
- if state.compute_event is None:
135
- raise RuntimeError("Compute event must be set before scatter.")
136
- comm_stream.wait_event(state.compute_event)
137
-
138
  # Clear the gathered gradient to free memory
139
  state.gathered_grad = None
140
 
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
144
  else:
145
  scatter_list = None
146
 
147
- u_received = torch.empty_like(p.to_local())
148
  torch.distributed.scatter(
149
- u_received,
150
  scatter_list=scatter_list,
151
  src=state.worker_rank,
152
  group=state.process_group,
153
  )
154
- u_dtensor = DTensor.from_local(
155
- u_received,
156
- placements=p.placements,
157
- device_mesh=p.device_mesh,
158
- )
159
-
160
- state.scattered_u = u_dtensor
161
  state.scatter_event = torch.cuda.Event()
162
  state.scatter_event.record()
 
163
 
164
 
165
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
172
  if state.scatter_event is None:
173
  raise RuntimeError("Scatter event must be set before update")
174
  compute_stream.wait_event(state.scatter_event)
175
  if rank == state.worker_rank:
176
  # Free computed_u
177
  state.computed_u = None
178
 
179
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
 
 
180
 
181
 
182
  def default_is_muon(name, x):
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
375
  else:
376
  g = buf
377
 
378
- u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
 
379
 
380
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
381
  Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
433
  def enqueue_computes(start_idx, chunk_size):
434
  for p in ordered_params[start_idx:start_idx + chunk_size]:
435
  state = param_to_state[id(p)]
436
- _compute_u(state, group["ns_steps"], self.rank,
437
  self.compute_stream)
438
 
439
  def enqueue_scatters(start_idx, chunk_size):
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
466
  # Wait for the last update_param to finish
467
  torch.cuda.current_stream().wait_stream(self.compute_stream)
468

469
  def step(self, closure=None):
470
  """Perform a single optimization step.
471
 
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
542
  # AdamW backup #
543
  ############################
544
545
  lr = group["lr"]
546
  beta1, beta2 = group["adamw_betas"]
547
  eps = group["adamw_eps"]
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
552
  if g is None:
553
  continue
554
  state = self.state[p]
 
 
555
  if "step" not in state:
556
- state["step"] = 0
 
 
557
  state["moment1"] = torch.zeros_like(g)
558
  state["moment2"] = torch.zeros_like(g)
559
- state["step"] += 1
560
- step = state["step"]
561
- buf1 = state["moment1"]
562
- buf2 = state["moment2"]
563
- buf1.lerp_(g, 1 - beta1)
564
- buf2.lerp_(g.square(), 1 - beta2)
565
-
566
- g = buf1 / (eps + buf2.sqrt())
567
-
568
- bias_correction1 = 1 - beta1**step
569
- bias_correction2 = 1 - beta2**step
570
- scale = bias_correction1 / bias_correction2**0.5
571
- p.data.mul_(1 - lr * weight_decay)
572
- p.data.add_(g, alpha=-lr / scale)
573
 
574
  return loss
 
2
  import math
3
  import types
4
  from dataclasses import dataclass
5
+ from typing import Optional, Union, cast
6
 
7
  import torch
8
  import torch.distributed as dist
 
13
 
14
  # This code snippet is a modified version adapted from the following GitHub repository:
15
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
16
+ # Muon's Newton–Schulz iteration causes high variance in the singular values.
17
+ # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
18
  @torch.no_grad()
19
  def _zeropower_via_newtonschulz5(G, steps):
20
  """
 
27
  performance at all relative to UV^T, where USV^T = G is the SVD.
28
  """
29
  assert len(G.shape) == 2
30
+ assert G.dtype == torch.bfloat16
31
  X = G # no manual typecast
32
+
33
  if G.size(0) > G.size(1):
34
  X = X.T
35
  # Ensure spectral norm is at most 1
36
  X = X / (X.norm() + 1e-7)
 
37
  # Perform the NS iterations
38
+ for a, b, c in [
39
+ (4.0848, -6.8946, 2.9270),
40
+ (3.9505, -6.3029, 2.6377),
41
+ (3.7418, -5.5913, 2.3037),
42
+ (2.8769, -3.1427, 1.2046),
43
+ (2.8366, -3.0525, 1.2012),
44
+ ]:
45
  A = X @ X.T
46
  # B = (
47
  # b * A + c * A @ A
 
52
 
53
  if G.size(0) > G.size(1):
54
  X = X.T
55
+ return X
56
 
57
 
58
  @dataclass
 
74
  Gather the gradients to worker_rank.
75
  If none_grad is True, free p.grad after the gather.
76
  """
 
 
 
 
 
 
 
 
 
 
77
  with torch.cuda.stream(comm_stream):
78
+ g = p.grad
79
+
80
+ if rank == state.worker_rank:
81
+ num_ranks = dist.get_world_size(group=state.process_group)
82
+ gather_list = [
83
+ torch.empty_like(g.to_local(), dtype=torch.bfloat16)
84
+ for _ in range(num_ranks)
85
+ ]
86
+ else:
87
+ gather_list = None
88
+
89
+ g = g.to(torch.bfloat16)
90
  torch.distributed.gather(
91
  g.to_local(),
92
  dst=state.worker_rank,
 
103
  else:
104
  state.gathered_grad = None
105
  state.gather_event = None
106
+ gather_list = None
107
  if none_grad:
108
  # We can safely free p.grad without calling record_stream:
109
  # p.grad.to_local().record_stream(comm_stream)
 
116
 
117
 
118
  @torch.no_grad()
119
+ def _compute_u(p, state, steps, rank, compute_stream):
120
  """
121
  On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
122
  """
 
127
  compute_stream.wait_event(state.gather_event)
128
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
129
  state.computed_u = u
130
+ state.scattered_u = torch.empty_like(p.to_local(),
131
+ dtype=torch.bfloat16)
132
+ state.compute_event = torch.cuda.Event()
133
+ state.compute_event.record()
134
+ u = None
135
 
136
 
137
  @torch.no_grad()
 
141
  """
142
 
143
  with torch.cuda.stream(comm_stream):
144
+ if state.compute_event is None:
145
+ raise RuntimeError("Compute event must be set before scatter.")
146
+ comm_stream.wait_event(state.compute_event)
147
+
148
  if rank == state.worker_rank:
149
  num_ranks = dist.get_world_size(group=state.process_group)
 
 
 
 
150
  # Clear the gathered gradient to free memory
151
  state.gathered_grad = None
152
 
 
156
  else:
157
  scatter_list = None
158
 
 
159
  torch.distributed.scatter(
160
+ state.scattered_u,
161
  scatter_list=scatter_list,
162
  src=state.worker_rank,
163
  group=state.process_group,
164
  )
165
  state.scatter_event = torch.cuda.Event()
166
  state.scatter_event.record()
167
+ scatter_list = None
168
 
169
 
170
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
 
177
  if state.scatter_event is None:
178
  raise RuntimeError("Scatter event must be set before update")
179
  compute_stream.wait_event(state.scatter_event)
180
+ u_dtensor = DTensor.from_local(
181
+ state.scattered_u,
182
+ placements=p.placements,
183
+ device_mesh=p.device_mesh,
184
+ )
185
+
186
+ state.scattered_u = u_dtensor
187
+
188
  if rank == state.worker_rank:
189
  # Free computed_u
190
  state.computed_u = None
191
 
192
  Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
193
+ state.scattered_u = None
194
+ u_dtensor = None
195
 
196
 
197
  def default_is_muon(name, x):
 
390
  else:
391
  g = buf
392
 
393
+ u = _zeropower_via_newtonschulz5(g.bfloat16(),
394
+ steps=group["ns_steps"])
395
 
396
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
397
  Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
 
449
  def enqueue_computes(start_idx, chunk_size):
450
  for p in ordered_params[start_idx:start_idx + chunk_size]:
451
  state = param_to_state[id(p)]
452
+ _compute_u(p, state, group["ns_steps"], self.rank,
453
  self.compute_stream)
454
 
455
  def enqueue_scatters(start_idx, chunk_size):
 
482
  # Wait for the last update_param to finish
483
  torch.cuda.current_stream().wait_stream(self.compute_stream)
484
 
485
+ @staticmethod
486
+ def _fused_adamw(
487
+ params: list[torch.Tensor],
488
+ grads: list[torch.Tensor],
489
+ exp_avgs: list[torch.Tensor],
490
+ exp_avg_sqs: list[torch.Tensor],
491
+ max_exp_avg_sqs: list[torch.Tensor],
492
+ state_steps: list[torch.Tensor],
493
+ amsgrad: bool,
494
+ beta1: float,
495
+ beta2: float,
496
+ lr: Union[float, torch.Tensor],
497
+ weight_decay: float,
498
+ eps: float,
499
+ maximize: bool,
500
+ ) -> None:
501
+ if not params:
502
+ return
503
+
504
+ # We only shuffle the lr around when it is a Tensor and not on the CPU; otherwise we prefer
505
+ # treating it as a scalar.
506
+ lr_dict: Optional[DeviceDict] = ({
507
+ lr.device: lr
508
+ } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
509
+ None)
510
+ grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
511
+ [
512
+ params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
513
+ state_steps
514
+ ] # type: ignore[list-item]
515
+ )
516
+ for (device, _), (
517
+ (
518
+ device_params_,
519
+ device_grads_,
520
+ device_exp_avgs_,
521
+ device_exp_avg_sqs_,
522
+ device_max_exp_avg_sqs,
523
+ device_state_steps_,
524
+ ),
525
+ _,
526
+ ) in grouped_tensors.items():
527
+ device_params = cast(list[torch.Tensor], device_params_)
528
+ device_grads = cast(list[torch.Tensor], device_grads_)
529
+ device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
530
+ device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
531
+ device_state_steps = cast(list[torch.Tensor], device_state_steps_)
532
+
533
+ if lr_dict is not None and device not in lr_dict:
534
+ lr_dict[device] = lr.to(
535
+ device=device,
536
+ non_blocking=True) # type: ignore[union-attr]
537
+ lr = lr_dict[device]
538
+ torch._foreach_add_(device_state_steps, 1)
539
+ func = torch._fused_adamw_
540
+ func(
541
+ device_params,
542
+ device_grads,
543
+ device_exp_avgs,
544
+ device_exp_avg_sqs,
545
+ device_max_exp_avg_sqs, # type: ignore[arg-type]
546
+ device_state_steps,
547
+ amsgrad=amsgrad,
548
+ lr=lr, # type: ignore[arg-type]
549
+ beta1=beta1,
550
+ beta2=beta2,
551
+ weight_decay=weight_decay,
552
+ eps=eps,
553
+ maximize=maximize,
554
+ )
555
+
556
  def step(self, closure=None):
557
  """Perform a single optimization step.
558
 
 
629
  # AdamW backup #
630
  ############################
631
 
632
+ params_with_grads = []
633
+ grads = []
634
+ moment1 = []
635
+ moment2 = []
636
+ max_exp_avg_sqs = []
637
+ state_steps = []
638
  lr = group["lr"]
639
  beta1, beta2 = group["adamw_betas"]
640
  eps = group["adamw_eps"]
 
645
  if g is None:
646
  continue
647
  state = self.state[p]
648
+ params_with_grads.append(p)
649
+ grads.append(g)
650
  if "step" not in state:
651
+ state["step"] = (torch.zeros((),
652
+ dtype=torch.float32,
653
+ device=p.device))
654
  state["moment1"] = torch.zeros_like(g)
655
  state["moment2"] = torch.zeros_like(g)
656
+ moment1.append(state["moment1"])
657
+ moment2.append(state["moment2"])
658
+ if not isinstance(state["step"], torch.Tensor):
659
+ step_tensor = torch.tensor(state["step"],
660
+ dtype=torch.float32,
661
+ device=p.device)
662
+ else:
663
+ step_tensor = state["step"]
664
+ state_steps.append(step_tensor)
665
+
666
+ self._fused_adamw(
667
+ params_with_grads,
668
+ grads,
669
+ moment1,
670
+ moment2,
671
+ max_exp_avg_sqs,
672
+ state_steps,
673
+ amsgrad=False,
674
+ beta1=beta1,
675
+ beta2=beta2,
676
+ lr=lr,
677
+ weight_decay=weight_decay,
678
+ eps=eps,
679
+ maximize=False,
680
+ )
681
 
682
  return loss
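
One more note on _fused_adamw above: DeviceDict is referenced in the annotation but is not among the imports shown, so the file presumably relies on it being defined elsewhere (torch.optim uses dict[torch.device, torch.Tensor]); since it only appears in a local variable annotation it is never evaluated at runtime. The lr-caching trick itself, as a hedged standalone sketch (names are illustrative):

from typing import Optional, Union
import torch

DeviceDict = dict[torch.device, torch.Tensor]  # assumed alias

def resolve_lr(lr: Union[float, torch.Tensor],
               device: torch.device,
               lr_dict: Optional[DeviceDict]):
    if lr_dict is None:        # float lr, or a CPU tensor: use as a scalar
        return lr
    if device not in lr_dict:  # first parameter on this device: cache a copy
        lr_dict[device] = lr.to(device=device, non_blocking=True)
    return lr_dict[device]     # reused for every later parameter there

lr = torch.tensor(3e-4)
lr_dict: Optional[DeviceDict] = ({lr.device: lr}
                                 if str(lr.device) != "cpu" else None)
print(resolve_lr(lr, torch.device("cpu"), lr_dict))  # tensor(0.0003)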