wyldecat committed
Commit 99e7c0c · 1 Parent(s): 3261444

fix(muon): add update_p stage and dealloc tensors properly
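In short, the commit splits the distributed Muon path into an explicit scatter stage followed by a new update_p stage, frees the gathered gradient and the computed update as soon as they are consumed, and factors the final parameter update into a shared Muon._update_p helper used by both the single-device and the parallel code path. A rough standalone sketch of what that shared helper does, assuming adjusted_lr has already been scaled via adjust_lr_for_muon (the real version in the diff is a @staticmethod on Muon):

import torch

def update_p(p: torch.Tensor, u: torch.Tensor, lr: float,
             adjusted_lr: float, weight_decay: float) -> None:
    # Decoupled weight decay, scaled by the base learning rate.
    p.data.mul_(1 - lr * weight_decay)
    # Apply the orthogonalized update with the shape-adjusted learning rate.
    p.data.add_(u, alpha=-adjusted_lr)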

Files changed (47)
  1. .gitignore +0 -1
  2. build/torch27-cxx11-cu118-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
  3. build/torch27-cxx11-cu118-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
  4. build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
  5. build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_2dc97a1_dirty.abi3.so → _optimizer_0c12ced_dirty.abi3.so} +1 -1
  6. build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +100 -51
  7. build/torch27-cxx11-cu126-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
  8. build/torch27-cxx11-cu126-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
  9. build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  10. build/torch27-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_2dc97a1_dirty.abi3.so → _optimizer_0c12ced_dirty.abi3.so} +1 -1
  11. build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +100 -51
  12. build/torch27-cxx11-cu128-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
  13. build/torch27-cxx11-cu128-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
  14. build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
  15. build/{torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so → torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_0c12ced_dirty.abi3.so} +1 -1
  16. build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +100 -51
  17. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
  18. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
  19. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
  20. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_2dc97a1_dirty.abi3.so → _optimizer_0c12ced_dirty.abi3.so} +1 -1
  21. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +100 -51
  22. build/torch28-cxx11-cu126-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
  23. build/torch28-cxx11-cu126-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
  24. build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  25. build/torch28-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_2dc97a1_dirty.abi3.so → _optimizer_0c12ced_dirty.abi3.so} +1 -1
  26. build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py +100 -51
  27. build/torch28-cxx11-cu128-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
  28. build/torch28-cxx11-cu128-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
  29. build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
  30. build/{torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so → torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_0c12ced_dirty.abi3.so} +1 -1
  31. build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py +100 -51
  32. build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
  33. build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
  34. build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py +3 -3
  35. build/torch28-cxx11-cu129-x86_64-linux/optimizer/{_optimizer_2dc97a1_dirty.abi3.so → _optimizer_0c12ced_dirty.abi3.so} +1 -1
  36. build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py +100 -51
  37. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
  38. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
  39. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
  40. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_2dc97a1_dirty.abi3.so → _optimizer_0c12ced_dirty.abi3.so} +1 -1
  41. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py +100 -51
  42. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc +0 -0
  43. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc +0 -0
  44. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py +3 -3
  45. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/{_optimizer_2dc97a1_dirty.abi3.so → _optimizer_0c12ced_dirty.abi3.so} +1 -1
  46. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py +100 -51
  47. torch-ext/optimizer/muon.py +65 -28
.gitignore CHANGED
@@ -2,7 +2,6 @@ __pycache__
 .idea
 .DS_Store
 *.egg-info
-build
 outputs
 dist/*
 .vscode
build/torch27-cxx11-cu118-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc DELETED
Binary file (307 Bytes)
 
build/torch27-cxx11-cu118-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc DELETED
Binary file (23.4 kB)
 
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_2dc97a1_dirty
-ops = torch.ops._optimizer_2dc97a1_dirty
+from . import _optimizer_0c12ced_dirty
+ops = torch.ops._optimizer_0c12ced_dirty

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_2dc97a1_dirty::{op_name}"
+    return f"_optimizer_0c12ced_dirty::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_2dc97a1_dirty.abi3.so → _optimizer_0c12ced_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9112c8dde01baefa0e3130e143288cd3073ccbab47369a6dc925ce0d35400c6d
+oid sha256:1c2963bea474b130d3b22e507692b42c1926d0b93c20495789602da2caff5ef3
 size 1787368
build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py CHANGED
@@ -47,19 +47,27 @@ class _muon_state:
     # TODO: use Optional
     worker_rank: int | None = None
     gathered_grad: torch.Tensor | None = None
+    scattered_u: DTensor | None = None
     computed_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
+    scatter_event: torch.cuda.Event | None = None
     process_group = None


 @torch.no_grad()
 def _gather(p, state, rank, comm_stream, none_grad):
+    """
+    Gather the gradients to worker_rank.
+    If none_grad is True, free p.grad after the gather.
+    """
     g = p.grad

     if rank == state.worker_rank:
         num_ranks = dist.get_world_size(group=state.process_group)
-        gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
+        gather_list = [
+            torch.empty_like(g.to_local()) for _ in range(num_ranks)
+        ]
     else:
         gather_list = None

@@ -73,8 +81,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
         if rank == state.worker_rank:
             if state.gathered_grad is not None:
                 raise RuntimeError(
-                    "Gather event already exists, which should not happen."
-                )
+                    "Gather event already exists, which should not happen.")
             state.gathered_grad = torch.cat(gather_list, dim=0)
             state.gather_event = torch.cuda.Event()
             state.gather_event.record()
@@ -82,11 +89,21 @@ def _gather(p, state, rank, comm_stream, none_grad):
             state.gathered_grad = None
             state.gather_event = None
         if none_grad:
+            # We can safely free p.grad without calling record_stream:
+            #   p.grad.to_local().record_stream(comm_stream)
+            # Explanation:
+            #   1. p.grad is created on the default stream, but the default stream
+            #      is synchronized with the comm stream later.
+            #   2. There is no further activity on the default stream before the optimizer finishes.
+            # Therefore, it is safe to free p.grad directly on the comm stream.
             p.grad = None


 @torch.no_grad()
 def _compute_u(state, steps, rank, compute_stream):
+    """
+    On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
+    """
     with torch.cuda.stream(compute_stream):
         if rank == state.worker_rank:
             if state.gather_event is None:
@@ -96,16 +113,16 @@ def _compute_u(state, steps, rank, compute_stream):
             state.computed_u = u
             state.compute_event = torch.cuda.Event()
             state.compute_event.record()
-            # Clear the gathered gradient to free memory
-            state.gathered_grad = None
         else:
             state.computed_u = None
             state.compute_event = None


 @torch.no_grad()
-def _scatter(p, state, lr, weight_decay, rank, comm_stream):
-    u = state.computed_u
+def _scatter(p, state, rank, comm_stream):
+    """
+    Scatter the computed_u from worker_rank to all ranks.
+    """

     with torch.cuda.stream(comm_stream):
         if rank == state.worker_rank:
@@ -113,27 +130,49 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
             if state.compute_event is None:
                 raise RuntimeError("Compute event must be set before scatter.")
             comm_stream.wait_event(state.compute_event)
+
+            # Clear the gathered gradient to free memory
+            state.gathered_grad = None
+
+            u = state.computed_u
             scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
+            scatter_list = [s.contiguous() for s in scatter_list]
         else:
             scatter_list = None

-        u = torch.empty_like(p.to_local())
+        u_received = torch.empty_like(p.to_local())
         torch.distributed.scatter(
-            u,
+            u_received,
             scatter_list=scatter_list,
             src=state.worker_rank,
             group=state.process_group,
         )
-        if rank == state.worker_rank:
-            # Clear u to free memory
-            state.computed_u = None
-        u = DTensor.from_local(
-            u,
+        u_dtensor = DTensor.from_local(
+            u_received,
             placements=p.placements,
             device_mesh=p.device_mesh,
         )
-        p.data.mul_(1 - lr * weight_decay)
-        p.data.add_(u, alpha=-lr)
+
+        state.scattered_u = u_dtensor
+        state.scatter_event = torch.cuda.Event()
+        state.scatter_event.record()
+
+
+def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
+                  compute_stream):
+    """
+    Update sharded parameter p with the scattered_u.
+    Only worker_rank frees computed_u.
+    """
+    with torch.cuda.stream(compute_stream):
+        if state.scatter_event is None:
+            raise RuntimeError("Scatter event must be set before update")
+        compute_stream.wait_event(state.scatter_event)
+        if rank == state.worker_rank:
+            # Free computed_u
+            state.computed_u = None
+
+        Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)


 def default_is_muon(x, name):
@@ -154,17 +193,19 @@ class Muon(torch.optim.Optimizer):
     - We believe it may not work well for finetuning pretrained models, but we haven't tested this.

     Arguments:
-        muon_params: The parameters to be optimized by Muon.
+        model: The model to be optimized by Muon.
+        is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon.
         lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
         momentum: The momentum used by the internal SGD. (0.95 is a good default)
         nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
         ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
-        adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are
+        weight_decay: The weight decay for Muon and AdamW.
         {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
         adamw_lr: The learning rate for the internal AdamW.
         adamw_betas: The betas for the internal AdamW.
         adamw_eps: The epsilon for the internal AdamW.
-        adamw_weight_decay: The weight decay for the internal AdamW.
+        none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
+        debug: Whether to print debug information.
     """

     def __init__(
@@ -240,9 +281,10 @@ class Muon(torch.optim.Optimizer):
         """
         Get the shard mesh for a parameter p on the given rank.
         """
-        assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
+        assert isinstance(
+            p, DTensor), "Parallel Muon only supports DTensor parameters."

-        if p.placements == (Shard(dim=0),):
+        if p.placements == (Shard(dim=0), ):
             # Case for FSDP
             return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
         elif p.placements == (Replicate(), Shard(dim=0)):
@@ -269,11 +311,12 @@ class Muon(torch.optim.Optimizer):
             total_flops += flops

         if self.debug:
-            print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
+            print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
+                  flush=True)

-        ordered_params = sorted(
-            params, key=lambda p: param_to_flops[id(p)], reverse=True
-        )
+        ordered_params = sorted(params,
+                                key=lambda p: param_to_flops[id(p)],
+                                reverse=True)

         round_robin = 0
         mesh = None
@@ -317,14 +360,8 @@ class Muon(torch.optim.Optimizer):

         u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])

-        # scale update
         adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
-
-        # apply weight decay
-        p.data.mul_(1 - lr * weight_decay)
-
-        # apply update
-        p.data.add_(u, alpha=-adjusted_lr)
+        Muon._update_p(p, u, lr, adjusted_lr, weight_decay)

     def _update_g(self, p, g, group, momentum):
         # calc update
@@ -339,9 +376,8 @@ class Muon(torch.optim.Optimizer):
             g = buf
         return g

-    def _update_p(self, p, u, lr, weight_decay):
-        # scale update
-        adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+    @staticmethod
+    def _update_p(p, u, lr, adjusted_lr, weight_decay):
         # apply weight decay
         p.data.mul_(1 - lr * weight_decay)
         # apply update
@@ -369,28 +405,34 @@ class Muon(torch.optim.Optimizer):
             p.grad = g

         param_to_state, ordered_params = self.init_state_and_assign_params(
-            params, group
-        )
+            params, group)

         def enqueue_gathers(start_idx, chunk_size):
-            for p in ordered_params[start_idx : start_idx + chunk_size]:
+            for p in ordered_params[start_idx:start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-                _gather(p, state, self.rank, self.comm_stream, group["none_grad"])
+                _gather(p, state, self.rank, self.comm_stream,
+                        group["none_grad"])

         def enqueue_computes(start_idx, chunk_size):
-            for p in ordered_params[start_idx : start_idx + chunk_size]:
+            for p in ordered_params[start_idx:start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-                _compute_u(state, group["ns_steps"], self.rank, self.compute_stream)
+                _compute_u(state, group["ns_steps"], self.rank,
+                           self.compute_stream)

         def enqueue_scatters(start_idx, chunk_size):
-            for p in ordered_params[start_idx : start_idx + chunk_size]:
+            for p in ordered_params[start_idx:start_idx + chunk_size]:
+                state = param_to_state[id(p)]
+                _scatter(p, state, self.rank, self.comm_stream)
+
+        def enqueue_update_param(start_idx, chunk_size):
+            for p in ordered_params[start_idx:start_idx + chunk_size]:
                 state = param_to_state[id(p)]
                 adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
-                _scatter(
-                    p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
-                )
+                _update_param(p, state, lr, adjusted_lr, weight_decay,
+                              self.rank, self.compute_stream)

-        chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
+        chunk_size = dist.get_world_size(param_to_state[id(
+            params[0])].process_group)

         # Wait grad update
         self.comm_stream.wait_stream(torch.cuda.current_stream())
@@ -398,10 +440,14 @@ class Muon(torch.optim.Optimizer):
         enqueue_gathers(0, chunk_size)
         for i in range(0, len(params) + chunk_size - 1, chunk_size):
             enqueue_computes(i, chunk_size)
+            if i > 0:
+                enqueue_update_param(i - chunk_size, chunk_size)
             enqueue_gathers(i + chunk_size, chunk_size)
             enqueue_scatters(i, chunk_size)
+        enqueue_update_param(i, chunk_size)

-        torch.cuda.current_stream().wait_stream(self.comm_stream)
+        # Wait the last update_param to finish
+        torch.cuda.current_stream().wait_stream(self.compute_stream)

     def step(self, closure=None):
         """Perform a single optimization step.
@@ -436,15 +482,16 @@ class Muon(torch.optim.Optimizer):
                     continue
                 if isinstance(p.data, DTensor):
                     if all(
-                        isinstance(placement, Replicate) for placement in p.placements
-                    ):
+                            isinstance(placement, Replicate)
+                            for placement in p.placements):
                         param_tensors.append(p)
                     else:
                         param_dtensors.append(p)
                 elif isinstance(p.data, torch.Tensor):
                     param_tensors.append(p)
                 else:
-                    raise TypeError(f"Unsupported parameter type: {type(p.data)}")
+                    raise TypeError(
+                        f"Unsupported parameter type: {type(p.data)}")

                 if self.debug:
                     print(
@@ -479,7 +526,9 @@ class Muon(torch.optim.Optimizer):
             # AdamW backup #
             ############################

-            params = [p for p in group["params"] if not self.state[p]["use_muon"]]
+            params = [
+                p for p in group["params"] if not self.state[p]["use_muon"]
+            ]
             lr = group["lr"]
             beta1, beta2 = group["adamw_betas"]
             eps = group["adamw_eps"]
build/torch27-cxx11-cu126-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc DELETED
Binary file (307 Bytes)
 
build/torch27-cxx11-cu126-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc DELETED
Binary file (23.4 kB)
 
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_2dc97a1_dirty
-ops = torch.ops._optimizer_2dc97a1_dirty
+from . import _optimizer_0c12ced_dirty
+ops = torch.ops._optimizer_0c12ced_dirty

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_2dc97a1_dirty::{op_name}"
+    return f"_optimizer_0c12ced_dirty::{op_name}"
build/torch27-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_2dc97a1_dirty.abi3.so → _optimizer_0c12ced_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:0449cd352f44c3e848d1f9c847b00bf576673b4fef2a954ec8bd8d2524b8353a
+oid sha256:a55c3d0aba4548dc74a08d66987307bd381c2d93b149702fbdc60da19e03e5fc
 size 1824256
build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py CHANGED
Identical changes to build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py above (+100 -51).
build/torch27-cxx11-cu128-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc DELETED
Binary file (307 Bytes)
 
build/torch27-cxx11-cu128-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc DELETED
Binary file (23.4 kB)
 
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_2dc97a1_dirty
-ops = torch.ops._optimizer_2dc97a1_dirty
+from . import _optimizer_0c12ced_dirty
+ops = torch.ops._optimizer_0c12ced_dirty

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_2dc97a1_dirty::{op_name}"
+    return f"_optimizer_0c12ced_dirty::{op_name}"
build/{torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so → torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_0c12ced_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:86d98863cc7ef0b271808b0ef7b1082603cfb5a76986481df37431527aaaf27b
+oid sha256:c319d0fb497363746229fbabed6d14b82090a660de602125fb67135117c53f5a
 size 1883352
build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py CHANGED
@@ -47,19 +47,27 @@ class _muon_state:
47
  # TODO: use Optional
48
  worker_rank: int | None = None
49
  gathered_grad: torch.Tensor | None = None
 
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
 
53
  process_group = None
54
 
55
 
56
  @torch.no_grad()
57
  def _gather(p, state, rank, comm_stream, none_grad):
 
 
 
 
58
  g = p.grad
59
 
60
  if rank == state.worker_rank:
61
  num_ranks = dist.get_world_size(group=state.process_group)
62
- gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
 
 
63
  else:
64
  gather_list = None
65
 
@@ -73,8 +81,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
73
  if rank == state.worker_rank:
74
  if state.gathered_grad is not None:
75
  raise RuntimeError(
76
- "Gather event already exists, which should not happen."
77
- )
78
  state.gathered_grad = torch.cat(gather_list, dim=0)
79
  state.gather_event = torch.cuda.Event()
80
  state.gather_event.record()
@@ -82,11 +89,21 @@ def _gather(p, state, rank, comm_stream, none_grad):
82
  state.gathered_grad = None
83
  state.gather_event = None
84
  if none_grad:
 
 
 
 
 
 
 
85
  p.grad = None
86
 
87
 
88
  @torch.no_grad()
89
  def _compute_u(state, steps, rank, compute_stream):
 
 
 
90
  with torch.cuda.stream(compute_stream):
91
  if rank == state.worker_rank:
92
  if state.gather_event is None:
@@ -96,16 +113,16 @@ def _compute_u(state, steps, rank, compute_stream):
96
  state.computed_u = u
97
  state.compute_event = torch.cuda.Event()
98
  state.compute_event.record()
99
- # Clear the gathered gradient to free memory
100
- state.gathered_grad = None
101
  else:
102
  state.computed_u = None
103
  state.compute_event = None
104
 
105
 
106
  @torch.no_grad()
107
- def _scatter(p, state, lr, weight_decay, rank, comm_stream):
108
- u = state.computed_u
 
 
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
@@ -113,27 +130,49 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
113
  if state.compute_event is None:
114
  raise RuntimeError("Compute event must be set before scatter.")
115
  comm_stream.wait_event(state.compute_event)
 
 
 
 
 
116
  scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
 
117
  else:
118
  scatter_list = None
119
 
120
- u = torch.empty_like(p.to_local())
121
  torch.distributed.scatter(
122
- u,
123
  scatter_list=scatter_list,
124
  src=state.worker_rank,
125
  group=state.process_group,
126
  )
127
- if rank == state.worker_rank:
128
- # Clear u to free memory
129
- state.computed_u = None
130
- u = DTensor.from_local(
131
- u,
132
  placements=p.placements,
133
  device_mesh=p.device_mesh,
134
  )
135
- p.data.mul_(1 - lr * weight_decay)
136
- p.data.add_(u, alpha=-lr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
 
139
  def default_is_muon(x, name):
@@ -154,17 +193,19 @@ class Muon(torch.optim.Optimizer):
154
  - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
155
 
156
  Arguments:
157
- muon_params: The parameters to be optimized by Muon.
 
158
  lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
159
  momentum: The momentum used by the internal SGD. (0.95 is a good default)
160
  nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
161
  ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
162
- adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are
163
  {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
164
  adamw_lr: The learning rate for the internal AdamW.
165
  adamw_betas: The betas for the internal AdamW.
166
  adamw_eps: The epsilon for the internal AdamW.
167
- adamw_weight_decay: The weight decay for the internal AdamW.
 
168
  """
169
 
170
  def __init__(
@@ -240,9 +281,10 @@ class Muon(torch.optim.Optimizer):
240
  """
241
  Get the shard mesh for a parameter p on the given rank.
242
  """
243
- assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
 
244
 
245
- if p.placements == (Shard(dim=0),):
246
  # Case for FSDP
247
  return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
248
  elif p.placements == (Replicate(), Shard(dim=0)):
@@ -269,11 +311,12 @@ class Muon(torch.optim.Optimizer):
269
  total_flops += flops
270
 
271
  if self.debug:
272
- print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
 
273
 
274
- ordered_params = sorted(
275
- params, key=lambda p: param_to_flops[id(p)], reverse=True
276
- )
277
 
278
  round_robin = 0
279
  mesh = None
@@ -317,14 +360,8 @@ class Muon(torch.optim.Optimizer):
317
 
318
  u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
319
 
320
- # scale update
321
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
322
-
323
- # apply weight decay
324
- p.data.mul_(1 - lr * weight_decay)
325
-
326
- # apply update
327
- p.data.add_(u, alpha=-adjusted_lr)
328
 
329
  def _update_g(self, p, g, group, momentum):
330
  # calc update
@@ -339,9 +376,8 @@ class Muon(torch.optim.Optimizer):
339
  g = buf
340
  return g
341
 
342
- def _update_p(self, p, u, lr, weight_decay):
343
- # scale update
344
- adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
345
  # apply weight decay
346
  p.data.mul_(1 - lr * weight_decay)
347
  # apply update
@@ -369,28 +405,34 @@ class Muon(torch.optim.Optimizer):
369
  p.grad = g
370
 
371
  param_to_state, ordered_params = self.init_state_and_assign_params(
372
- params, group
373
- )
374
 
375
  def enqueue_gathers(start_idx, chunk_size):
376
- for p in ordered_params[start_idx : start_idx + chunk_size]:
377
  state = param_to_state[id(p)]
378
- _gather(p, state, self.rank, self.comm_stream, group["none_grad"])
 
379
 
380
  def enqueue_computes(start_idx, chunk_size):
381
- for p in ordered_params[start_idx : start_idx + chunk_size]:
382
  state = param_to_state[id(p)]
383
- _compute_u(state, group["ns_steps"], self.rank, self.compute_stream)
 
384
 
385
  def enqueue_scatters(start_idx, chunk_size):
386
- for p in ordered_params[start_idx : start_idx + chunk_size]:
 
 
 
 
 
387
  state = param_to_state[id(p)]
388
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
389
- _scatter(
390
- p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
391
- )
392
 
393
- chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
 
394
 
395
  # Wait grad update
396
  self.comm_stream.wait_stream(torch.cuda.current_stream())
@@ -398,10 +440,14 @@ class Muon(torch.optim.Optimizer):
398
  enqueue_gathers(0, chunk_size)
399
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
400
  enqueue_computes(i, chunk_size)
 
 
401
  enqueue_gathers(i + chunk_size, chunk_size)
402
  enqueue_scatters(i, chunk_size)
 
403
 
404
- torch.cuda.current_stream().wait_stream(self.comm_stream)
 
405
 
406
  def step(self, closure=None):
407
  """Perform a single optimization step.
@@ -436,15 +482,16 @@ class Muon(torch.optim.Optimizer):
436
  continue
437
  if isinstance(p.data, DTensor):
438
  if all(
439
- isinstance(placement, Replicate) for placement in p.placements
440
- ):
441
  param_tensors.append(p)
442
  else:
443
  param_dtensors.append(p)
444
  elif isinstance(p.data, torch.Tensor):
445
  param_tensors.append(p)
446
  else:
447
- raise TypeError(f"Unsupported parameter type: {type(p.data)}")
 
448
 
449
  if self.debug:
450
  print(
@@ -479,7 +526,9 @@ class Muon(torch.optim.Optimizer):
479
  # AdamW backup #
480
  ############################
481
 
482
- params = [p for p in group["params"] if not self.state[p]["use_muon"]]
 
 
483
  lr = group["lr"]
484
  beta1, beta2 = group["adamw_betas"]
485
  eps = group["adamw_eps"]
 
47
  # TODO: use Optional
48
  worker_rank: int | None = None
49
  gathered_grad: torch.Tensor | None = None
50
+ scattered_u: DTensor | None = None
51
  computed_u: torch.Tensor | None = None
52
  gather_event: torch.cuda.Event | None = None
53
  compute_event: torch.cuda.Event | None = None
54
+ scatter_event: torch.cuda.Event | None = None
55
  process_group = None
56
 
57
 
58
  @torch.no_grad()
59
  def _gather(p, state, rank, comm_stream, none_grad):
60
+ """
61
+ Gather the gradients to worker_rank.
62
+ If none_grad is True, free p.grad after the gather.
63
+ """
64
  g = p.grad
65
 
66
  if rank == state.worker_rank:
67
  num_ranks = dist.get_world_size(group=state.process_group)
68
+ gather_list = [
69
+ torch.empty_like(g.to_local()) for _ in range(num_ranks)
70
+ ]
71
  else:
72
  gather_list = None
73
 
 
81
  if rank == state.worker_rank:
82
  if state.gathered_grad is not None:
83
  raise RuntimeError(
84
+ "Gather event already exists, which should not happen.")
 
85
  state.gathered_grad = torch.cat(gather_list, dim=0)
86
  state.gather_event = torch.cuda.Event()
87
  state.gather_event.record()
 
89
  state.gathered_grad = None
90
  state.gather_event = None
91
  if none_grad:
92
+ # We can safely free p.grad without calling record_stream:
93
+ # p.grad.to_local().record_stream(comm_stream)
94
+ # Explanation:
95
+ # 1. p.grad is created on the default stream, but the default stream
96
+ # is synchronized with the comm stream later.
97
+ # 2. There is no further activity on the default stream before the optimizer finishes.
98
+ # Therefore, it is safe to free p.grad directly on the comm stream.
99
  p.grad = None
100
 
101
 
102
  @torch.no_grad()
103
  def _compute_u(state, steps, rank, compute_stream):
104
+ """
105
+ On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
106
+ """
107
  with torch.cuda.stream(compute_stream):
108
  if rank == state.worker_rank:
109
  if state.gather_event is None:
 
113
  state.computed_u = u
114
  state.compute_event = torch.cuda.Event()
115
  state.compute_event.record()
 
 
116
  else:
117
  state.computed_u = None
118
  state.compute_event = None
119
 
120
 
121
  @torch.no_grad()
122
+ def _scatter(p, state, rank, comm_stream):
123
+ """
124
+ Scatter the computed_u from worker_rank to all ranks.
125
+ """
126
 
127
  with torch.cuda.stream(comm_stream):
128
  if rank == state.worker_rank:
 
130
  if state.compute_event is None:
131
  raise RuntimeError("Compute event must be set before scatter.")
132
  comm_stream.wait_event(state.compute_event)
133
+
134
+ # Clear the gathered gradient to free memory
135
+ state.gathered_grad = None
136
+
137
+ u = state.computed_u
138
  scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
139
+ scatter_list = [s.contiguous() for s in scatter_list]
140
  else:
141
  scatter_list = None
142
 
143
+ u_received = torch.empty_like(p.to_local())
144
  torch.distributed.scatter(
145
+ u_received,
146
  scatter_list=scatter_list,
147
  src=state.worker_rank,
148
  group=state.process_group,
149
  )
150
+ u_dtensor = DTensor.from_local(
151
+ u_received,
 
 
 
152
  placements=p.placements,
153
  device_mesh=p.device_mesh,
154
  )
155
+
156
+ state.scattered_u = u_dtensor
157
+ state.scatter_event = torch.cuda.Event()
158
+ state.scatter_event.record()
159
+
160
+
161
+ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
162
+ compute_stream):
163
+ """
164
+ Update sharded parameter p with the scattered_u.
165
+ Only worker_rank frees computed_u.
166
+ """
167
+ with torch.cuda.stream(compute_stream):
168
+ if state.scatter_event is None:
169
+ raise RuntimeError("Scatter event must be set before update")
170
+ compute_stream.wait_event(state.scatter_event)
171
+ if rank == state.worker_rank:
172
+ # Free computed_u
173
+ state.computed_u = None
174
+
175
+ Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
176
 
177
 
178
  def default_is_muon(x, name):
 
193
  - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
194
 
195
  Arguments:
196
+ model: The model to be optimized by Muon.
197
+ is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon.
198
  lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
199
  momentum: The momentum used by the internal SGD. (0.95 is a good default)
200
  nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
201
  ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
202
+ weight_decay: The weight decay for Muon and AdamW.
203
  {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
204
  adamw_lr: The learning rate for the internal AdamW.
205
  adamw_betas: The betas for the internal AdamW.
206
  adamw_eps: The epsilon for the internal AdamW.
207
+ none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
208
+ debug: Whether to print debug information.
209
  """
210
 
211
  def __init__(
 
281
  """
282
  Get the shard mesh for a parameter p on the given rank.
283
  """
284
+ assert isinstance(
285
+ p, DTensor), "Parallel Muon only supports DTensor parameters."
286
 
287
+ if p.placements == (Shard(dim=0), ):
288
  # Case for FSDP
289
  return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
290
  elif p.placements == (Replicate(), Shard(dim=0)):
 
311
  total_flops += flops
312
 
313
  if self.debug:
314
+ print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
315
+ flush=True)
316
 
317
+ ordered_params = sorted(params,
318
+ key=lambda p: param_to_flops[id(p)],
319
+ reverse=True)
320
 
321
  round_robin = 0
322
  mesh = None
 
360
 
361
  u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
362
 
 
363
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
364
+ Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
 
 
 
 
 
365
 
366
  def _update_g(self, p, g, group, momentum):
367
  # calc update
 
376
  g = buf
377
  return g
378
 
379
+ @staticmethod
380
+ def _update_p(p, u, lr, adjusted_lr, weight_decay):
 
381
  # apply weight decay
382
  p.data.mul_(1 - lr * weight_decay)
383
  # apply update
 
405
  p.grad = g
406
 
407
  param_to_state, ordered_params = self.init_state_and_assign_params(
408
+ params, group)
 
409
 
410
  def enqueue_gathers(start_idx, chunk_size):
411
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
412
  state = param_to_state[id(p)]
413
+ _gather(p, state, self.rank, self.comm_stream,
414
+ group["none_grad"])
415
 
416
  def enqueue_computes(start_idx, chunk_size):
417
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
418
  state = param_to_state[id(p)]
419
+ _compute_u(state, group["ns_steps"], self.rank,
420
+ self.compute_stream)
421
 
422
  def enqueue_scatters(start_idx, chunk_size):
423
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
424
+ state = param_to_state[id(p)]
425
+ _scatter(p, state, self.rank, self.comm_stream)
426
+
427
+ def enqueue_update_param(start_idx, chunk_size):
428
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
429
  state = param_to_state[id(p)]
430
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
431
+ _update_param(p, state, lr, adjusted_lr, weight_decay,
432
+ self.rank, self.compute_stream)
 
433
 
434
+ chunk_size = dist.get_world_size(param_to_state[id(
435
+ params[0])].process_group)
436
 
437
  # Wait grad update
438
  self.comm_stream.wait_stream(torch.cuda.current_stream())
 
440
  enqueue_gathers(0, chunk_size)
441
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
442
  enqueue_computes(i, chunk_size)
443
+ if i > 0:
444
+ enqueue_update_param(i - chunk_size, chunk_size)
445
  enqueue_gathers(i + chunk_size, chunk_size)
446
  enqueue_scatters(i, chunk_size)
447
+ enqueue_update_param(i, chunk_size)
448
 
449
+ # Wait for the last update_param to finish
450
+ torch.cuda.current_stream().wait_stream(self.compute_stream)
451
 
452
  def step(self, closure=None):
453
  """Perform a single optimization step.
 
482
  continue
483
  if isinstance(p.data, DTensor):
484
  if all(
485
+ isinstance(placement, Replicate)
486
+ for placement in p.placements):
487
  param_tensors.append(p)
488
  else:
489
  param_dtensors.append(p)
490
  elif isinstance(p.data, torch.Tensor):
491
  param_tensors.append(p)
492
  else:
493
+ raise TypeError(
494
+ f"Unsupported parameter type: {type(p.data)}")
495
 
496
  if self.debug:
497
  print(
 
526
  # AdamW backup #
527
  ############################
528
 
529
+ params = [
530
+ p for p in group["params"] if not self.state[p]["use_muon"]
531
+ ]
532
  lr = group["lr"]
533
  beta1, beta2 = group["adamw_betas"]
534
  eps = group["adamw_eps"]
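
A minimal usage sketch of the constructor documented in the diff above may help orient readers. Everything in it is an assumption rather than a tested API: the import path (`optimizer`), the keyword names, the single-process fallback behaviour, and the `is_muon_param` helper are inferred from the docstring and the package layout in this commit, not verified against the built wheel.

import torch
import torch.nn as nn

from optimizer import Muon  # hypothetical import path based on the build layout

def is_muon_param(param, name):
    # Assumption: route 2-D weights to Muon, embeddings/heads and 1-D params to AdamW.
    return param.ndim == 2 and "embed" not in name and "lm_head" not in name

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10))
opt = Muon(
    model,
    is_muon_func=is_muon_param,
    lr=0.02,            # spectral-norm learning rate suggested in the docstring
    momentum=0.95,
    weight_decay=0.01,
    none_grad=True,     # free p.grad after the gather to save memory
)

x, y = torch.randn(8, 64), torch.randint(0, 10, (8,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
opt.step()
opt.zero_grad()
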
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc DELETED
Binary file (308 Bytes)
 
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc DELETED
Binary file (23.4 kB)
 
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_2dc97a1_dirty
3
- ops = torch.ops._optimizer_2dc97a1_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_2dc97a1_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_0c12ced_dirty
3
+ ops = torch.ops._optimizer_0c12ced_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_0c12ced_dirty::{op_name}"
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_2dc97a1_dirty.abi3.so → _optimizer_0c12ced_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:bdcf9e3d8bf13aa01bf1ae7a94a12dd05c50702a24b57e4cfcc2e54ca5ed21c3
3
  size 1749840
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8bda6399291a15b5bcba88214ffd3d0291b10d1cdfb0ab668436d176a9396ec
3
  size 1749840
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py CHANGED
@@ -47,19 +47,27 @@ class _muon_state:
47
  # TODO: use Optional
48
  worker_rank: int | None = None
49
  gathered_grad: torch.Tensor | None = None
 
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
 
53
  process_group = None
54
 
55
 
56
  @torch.no_grad()
57
  def _gather(p, state, rank, comm_stream, none_grad):
 
 
 
 
58
  g = p.grad
59
 
60
  if rank == state.worker_rank:
61
  num_ranks = dist.get_world_size(group=state.process_group)
62
- gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
 
 
63
  else:
64
  gather_list = None
65
 
@@ -73,8 +81,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
73
  if rank == state.worker_rank:
74
  if state.gathered_grad is not None:
75
  raise RuntimeError(
76
- "Gather event already exists, which should not happen."
77
- )
78
  state.gathered_grad = torch.cat(gather_list, dim=0)
79
  state.gather_event = torch.cuda.Event()
80
  state.gather_event.record()
@@ -82,11 +89,21 @@ def _gather(p, state, rank, comm_stream, none_grad):
82
  state.gathered_grad = None
83
  state.gather_event = None
84
  if none_grad:
 
 
 
 
 
 
 
85
  p.grad = None
86
 
87
 
88
  @torch.no_grad()
89
  def _compute_u(state, steps, rank, compute_stream):
 
 
 
90
  with torch.cuda.stream(compute_stream):
91
  if rank == state.worker_rank:
92
  if state.gather_event is None:
@@ -96,16 +113,16 @@ def _compute_u(state, steps, rank, compute_stream):
96
  state.computed_u = u
97
  state.compute_event = torch.cuda.Event()
98
  state.compute_event.record()
99
- # Clear the gathered gradient to free memory
100
- state.gathered_grad = None
101
  else:
102
  state.computed_u = None
103
  state.compute_event = None
104
 
105
 
106
  @torch.no_grad()
107
- def _scatter(p, state, lr, weight_decay, rank, comm_stream):
108
- u = state.computed_u
 
 
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
@@ -113,27 +130,49 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
113
  if state.compute_event is None:
114
  raise RuntimeError("Compute event must be set before scatter.")
115
  comm_stream.wait_event(state.compute_event)
 
 
 
 
 
116
  scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
 
117
  else:
118
  scatter_list = None
119
 
120
- u = torch.empty_like(p.to_local())
121
  torch.distributed.scatter(
122
- u,
123
  scatter_list=scatter_list,
124
  src=state.worker_rank,
125
  group=state.process_group,
126
  )
127
- if rank == state.worker_rank:
128
- # Clear u to free memory
129
- state.computed_u = None
130
- u = DTensor.from_local(
131
- u,
132
  placements=p.placements,
133
  device_mesh=p.device_mesh,
134
  )
135
- p.data.mul_(1 - lr * weight_decay)
136
- p.data.add_(u, alpha=-lr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
 
139
  def default_is_muon(x, name):
@@ -154,17 +193,19 @@ class Muon(torch.optim.Optimizer):
154
  - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
155
 
156
  Arguments:
157
- muon_params: The parameters to be optimized by Muon.
 
158
  lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
159
  momentum: The momentum used by the internal SGD. (0.95 is a good default)
160
  nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
161
  ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
162
- adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are
163
  {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
164
  adamw_lr: The learning rate for the internal AdamW.
165
  adamw_betas: The betas for the internal AdamW.
166
  adamw_eps: The epsilon for the internal AdamW.
167
- adamw_weight_decay: The weight decay for the internal AdamW.
 
168
  """
169
 
170
  def __init__(
@@ -240,9 +281,10 @@ class Muon(torch.optim.Optimizer):
240
  """
241
  Get the shard mesh for a parameter p on the given rank.
242
  """
243
- assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
 
244
 
245
- if p.placements == (Shard(dim=0),):
246
  # Case for FSDP
247
  return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
248
  elif p.placements == (Replicate(), Shard(dim=0)):
@@ -269,11 +311,12 @@ class Muon(torch.optim.Optimizer):
269
  total_flops += flops
270
 
271
  if self.debug:
272
- print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
 
273
 
274
- ordered_params = sorted(
275
- params, key=lambda p: param_to_flops[id(p)], reverse=True
276
- )
277
 
278
  round_robin = 0
279
  mesh = None
@@ -317,14 +360,8 @@ class Muon(torch.optim.Optimizer):
317
 
318
  u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
319
 
320
- # scale update
321
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
322
-
323
- # apply weight decay
324
- p.data.mul_(1 - lr * weight_decay)
325
-
326
- # apply update
327
- p.data.add_(u, alpha=-adjusted_lr)
328
 
329
  def _update_g(self, p, g, group, momentum):
330
  # calc update
@@ -339,9 +376,8 @@ class Muon(torch.optim.Optimizer):
339
  g = buf
340
  return g
341
 
342
- def _update_p(self, p, u, lr, weight_decay):
343
- # scale update
344
- adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
345
  # apply weight decay
346
  p.data.mul_(1 - lr * weight_decay)
347
  # apply update
@@ -369,28 +405,34 @@ class Muon(torch.optim.Optimizer):
369
  p.grad = g
370
 
371
  param_to_state, ordered_params = self.init_state_and_assign_params(
372
- params, group
373
- )
374
 
375
  def enqueue_gathers(start_idx, chunk_size):
376
- for p in ordered_params[start_idx : start_idx + chunk_size]:
377
  state = param_to_state[id(p)]
378
- _gather(p, state, self.rank, self.comm_stream, group["none_grad"])
 
379
 
380
  def enqueue_computes(start_idx, chunk_size):
381
- for p in ordered_params[start_idx : start_idx + chunk_size]:
382
  state = param_to_state[id(p)]
383
- _compute_u(state, group["ns_steps"], self.rank, self.compute_stream)
 
384
 
385
  def enqueue_scatters(start_idx, chunk_size):
386
- for p in ordered_params[start_idx : start_idx + chunk_size]:
 
 
 
 
 
387
  state = param_to_state[id(p)]
388
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
389
- _scatter(
390
- p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
391
- )
392
 
393
- chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
 
394
 
395
  # Wait grad update
396
  self.comm_stream.wait_stream(torch.cuda.current_stream())
@@ -398,10 +440,14 @@ class Muon(torch.optim.Optimizer):
398
  enqueue_gathers(0, chunk_size)
399
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
400
  enqueue_computes(i, chunk_size)
 
 
401
  enqueue_gathers(i + chunk_size, chunk_size)
402
  enqueue_scatters(i, chunk_size)
 
403
 
404
- torch.cuda.current_stream().wait_stream(self.comm_stream)
 
405
 
406
  def step(self, closure=None):
407
  """Perform a single optimization step.
@@ -436,15 +482,16 @@ class Muon(torch.optim.Optimizer):
436
  continue
437
  if isinstance(p.data, DTensor):
438
  if all(
439
- isinstance(placement, Replicate) for placement in p.placements
440
- ):
441
  param_tensors.append(p)
442
  else:
443
  param_dtensors.append(p)
444
  elif isinstance(p.data, torch.Tensor):
445
  param_tensors.append(p)
446
  else:
447
- raise TypeError(f"Unsupported parameter type: {type(p.data)}")
 
448
 
449
  if self.debug:
450
  print(
@@ -479,7 +526,9 @@ class Muon(torch.optim.Optimizer):
479
  # AdamW backup #
480
  ############################
481
 
482
- params = [p for p in group["params"] if not self.state[p]["use_muon"]]
 
 
483
  lr = group["lr"]
484
  beta1, beta2 = group["adamw_betas"]
485
  eps = group["adamw_eps"]
 
47
  # TODO: use Optional
48
  worker_rank: int | None = None
49
  gathered_grad: torch.Tensor | None = None
50
+ scattered_u: DTensor | None = None
51
  computed_u: torch.Tensor | None = None
52
  gather_event: torch.cuda.Event | None = None
53
  compute_event: torch.cuda.Event | None = None
54
+ scatter_event: torch.cuda.Event | None = None
55
  process_group = None
56
 
57
 
58
  @torch.no_grad()
59
  def _gather(p, state, rank, comm_stream, none_grad):
60
+ """
61
+ Gather the gradients to worker_rank.
62
+ If none_grad is True, free p.grad after the gather.
63
+ """
64
  g = p.grad
65
 
66
  if rank == state.worker_rank:
67
  num_ranks = dist.get_world_size(group=state.process_group)
68
+ gather_list = [
69
+ torch.empty_like(g.to_local()) for _ in range(num_ranks)
70
+ ]
71
  else:
72
  gather_list = None
73
 
 
81
  if rank == state.worker_rank:
82
  if state.gathered_grad is not None:
83
  raise RuntimeError(
84
+ "Gather event already exists, which should not happen.")
 
85
  state.gathered_grad = torch.cat(gather_list, dim=0)
86
  state.gather_event = torch.cuda.Event()
87
  state.gather_event.record()
 
89
  state.gathered_grad = None
90
  state.gather_event = None
91
  if none_grad:
92
+ # We can safely free p.grad without calling record_stream:
93
+ # p.grad.to_local().record_stream(comm_stream)
94
+ # Explanation:
95
+ # 1. p.grad is created on the default stream, but the default stream
96
+ # is synchronized with the comm stream later.
97
+ # 2. There is no further activity on the default stream before the optimizer finishes.
98
+ # Therefore, it is safe to free p.grad directly on the comm stream.
99
  p.grad = None
100
 
101
 
102
  @torch.no_grad()
103
  def _compute_u(state, steps, rank, compute_stream):
104
+ """
105
+ On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
106
+ """
107
  with torch.cuda.stream(compute_stream):
108
  if rank == state.worker_rank:
109
  if state.gather_event is None:
 
113
  state.computed_u = u
114
  state.compute_event = torch.cuda.Event()
115
  state.compute_event.record()
 
 
116
  else:
117
  state.computed_u = None
118
  state.compute_event = None
119
 
120
 
121
  @torch.no_grad()
122
+ def _scatter(p, state, rank, comm_stream):
123
+ """
124
+ Scatter the computed_u from worker_rank to all ranks.
125
+ """
126
 
127
  with torch.cuda.stream(comm_stream):
128
  if rank == state.worker_rank:
 
130
  if state.compute_event is None:
131
  raise RuntimeError("Compute event must be set before scatter.")
132
  comm_stream.wait_event(state.compute_event)
133
+
134
+ # Clear the gathered gradient to free memory
135
+ state.gathered_grad = None
136
+
137
+ u = state.computed_u
138
  scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
139
+ scatter_list = [s.contiguous() for s in scatter_list]
140
  else:
141
  scatter_list = None
142
 
143
+ u_received = torch.empty_like(p.to_local())
144
  torch.distributed.scatter(
145
+ u_received,
146
  scatter_list=scatter_list,
147
  src=state.worker_rank,
148
  group=state.process_group,
149
  )
150
+ u_dtensor = DTensor.from_local(
151
+ u_received,
 
 
 
152
  placements=p.placements,
153
  device_mesh=p.device_mesh,
154
  )
155
+
156
+ state.scattered_u = u_dtensor
157
+ state.scatter_event = torch.cuda.Event()
158
+ state.scatter_event.record()
159
+
160
+
161
+ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
162
+ compute_stream):
163
+ """
164
+ Update sharded parameter p with the scattered_u.
165
+ Only worker_rank frees computed_u.
166
+ """
167
+ with torch.cuda.stream(compute_stream):
168
+ if state.scatter_event is None:
169
+ raise RuntimeError("Scatter event must be set before update")
170
+ compute_stream.wait_event(state.scatter_event)
171
+ if rank == state.worker_rank:
172
+ # Free computed_u
173
+ state.computed_u = None
174
+
175
+ Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
176
 
177
 
178
  def default_is_muon(x, name):
 
193
  - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
194
 
195
  Arguments:
196
+ model: The model to be optimized by Muon.
197
+ is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon.
198
  lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
199
  momentum: The momentum used by the internal SGD. (0.95 is a good default)
200
  nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
201
  ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
202
+ weight_decay: The weight decay for Muon and AdamW.
203
  Parameters that are {0, 1}-D, or detected as being the embed or lm_head, will be optimized by AdamW as well.
204
  adamw_lr: The learning rate for the internal AdamW.
205
  adamw_betas: The betas for the internal AdamW.
206
  adamw_eps: The epsilon for the internal AdamW.
207
+ none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
208
+ debug: Whether to print debug information.
209
  """
210
 
211
  def __init__(
 
281
  """
282
  Get the shard mesh for a parameter p on the given rank.
283
  """
284
+ assert isinstance(
285
+ p, DTensor), "Parallel Muon only supports DTensor parameters."
286
 
287
+ if p.placements == (Shard(dim=0), ):
288
  # Case for FSDP
289
  return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
290
  elif p.placements == (Replicate(), Shard(dim=0)):
 
311
  total_flops += flops
312
 
313
  if self.debug:
314
+ print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
315
+ flush=True)
316
 
317
+ ordered_params = sorted(params,
318
+ key=lambda p: param_to_flops[id(p)],
319
+ reverse=True)
320
 
321
  round_robin = 0
322
  mesh = None
 
360
 
361
  u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
362
 
 
363
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
364
+ Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
 
 
 
 
 
365
 
366
  def _update_g(self, p, g, group, momentum):
367
  # calc update
 
376
  g = buf
377
  return g
378
 
379
+ @staticmethod
380
+ def _update_p(p, u, lr, adjusted_lr, weight_decay):
 
381
  # apply weight decay
382
  p.data.mul_(1 - lr * weight_decay)
383
  # apply update
 
405
  p.grad = g
406
 
407
  param_to_state, ordered_params = self.init_state_and_assign_params(
408
+ params, group)
 
409
 
410
  def enqueue_gathers(start_idx, chunk_size):
411
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
412
  state = param_to_state[id(p)]
413
+ _gather(p, state, self.rank, self.comm_stream,
414
+ group["none_grad"])
415
 
416
  def enqueue_computes(start_idx, chunk_size):
417
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
418
  state = param_to_state[id(p)]
419
+ _compute_u(state, group["ns_steps"], self.rank,
420
+ self.compute_stream)
421
 
422
  def enqueue_scatters(start_idx, chunk_size):
423
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
424
+ state = param_to_state[id(p)]
425
+ _scatter(p, state, self.rank, self.comm_stream)
426
+
427
+ def enqueue_update_param(start_idx, chunk_size):
428
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
429
  state = param_to_state[id(p)]
430
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
431
+ _update_param(p, state, lr, adjusted_lr, weight_decay,
432
+ self.rank, self.compute_stream)
 
433
 
434
+ chunk_size = dist.get_world_size(param_to_state[id(
435
+ params[0])].process_group)
436
 
437
  # Wait grad update
438
  self.comm_stream.wait_stream(torch.cuda.current_stream())
 
440
  enqueue_gathers(0, chunk_size)
441
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
442
  enqueue_computes(i, chunk_size)
443
+ if i > 0:
444
+ enqueue_update_param(i - chunk_size, chunk_size)
445
  enqueue_gathers(i + chunk_size, chunk_size)
446
  enqueue_scatters(i, chunk_size)
447
+ enqueue_update_param(i, chunk_size)
448
 
449
+ # Wait for the last update_param to finish
450
+ torch.cuda.current_stream().wait_stream(self.compute_stream)
451
 
452
  def step(self, closure=None):
453
  """Perform a single optimization step.
 
482
  continue
483
  if isinstance(p.data, DTensor):
484
  if all(
485
+ isinstance(placement, Replicate)
486
+ for placement in p.placements):
487
  param_tensors.append(p)
488
  else:
489
  param_dtensors.append(p)
490
  elif isinstance(p.data, torch.Tensor):
491
  param_tensors.append(p)
492
  else:
493
+ raise TypeError(
494
+ f"Unsupported parameter type: {type(p.data)}")
495
 
496
  if self.debug:
497
  print(
 
526
  # AdamW backup #
527
  ############################
528
 
529
+ params = [
530
+ p for p in group["params"] if not self.state[p]["use_muon"]
531
+ ]
532
  lr = group["lr"]
533
  beta1, beta2 = group["adamw_betas"]
534
  eps = group["adamw_eps"]
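
The new update_p stage turns the per-parameter pipeline into a four-step, event-ordered chain: gather on the comm stream, Newton-Schulz compute on the compute stream after the gather event, scatter back on the comm stream after the compute event, and now the parameter update on the compute stream after the scatter event. The sketch below reproduces only that ordering pattern with plain tensors and the real torch.cuda stream/event API; the stage bodies are stand-ins, not the functions from muon.py, and it assumes a CUDA device is available.

import torch

def ordered_stages():
    comm = torch.cuda.Stream()
    compute = torch.cuda.Stream()
    x = torch.randn(1024, 1024, device="cuda")

    # Stage 1: "gather" on the comm stream, then record an event.
    with torch.cuda.stream(comm):
        gathered = x * 2
        gather_event = torch.cuda.Event()
        gather_event.record()

    # Stage 2: "compute" waits for the gather before running the heavy math (here: a matmul).
    with torch.cuda.stream(compute):
        compute.wait_event(gather_event)
        u = gathered @ gathered.t()
        compute_event = torch.cuda.Event()
        compute_event.record()

    # Stage 3: "scatter" on the comm stream waits for the compute.
    with torch.cuda.stream(comm):
        comm.wait_event(compute_event)
        scattered = u[:512]
        scatter_event = torch.cuda.Event()
        scatter_event.record()

    # Stage 4: the new update step waits for the scatter before touching the parameter.
    with torch.cuda.stream(compute):
        compute.wait_event(scatter_event)
        scattered.mul_(0.999)

    # Finally the default stream waits for the compute stream, as parallel() now does.
    torch.cuda.current_stream().wait_stream(compute)

if torch.cuda.is_available():
    ordered_stages()
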
build/torch28-cxx11-cu126-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc DELETED
Binary file (307 Bytes)
 
build/torch28-cxx11-cu126-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc DELETED
Binary file (23.4 kB)
 
build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_2dc97a1_dirty
3
- ops = torch.ops._optimizer_2dc97a1_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_2dc97a1_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_0c12ced_dirty
3
+ ops = torch.ops._optimizer_0c12ced_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_0c12ced_dirty::{op_name}"
build/torch28-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_2dc97a1_dirty.abi3.so → _optimizer_0c12ced_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a423eb4ab3a31c53a3326c71e34fa59fc661f8d432701e41a7de900a9c23e37c
3
  size 1824256
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df5044ffb45124dfe7088ed991123724405b00285e4d8d1ba2961802f521aa0f
3
  size 1824256
build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py CHANGED
@@ -47,19 +47,27 @@ class _muon_state:
47
  # TODO: use Optional
48
  worker_rank: int | None = None
49
  gathered_grad: torch.Tensor | None = None
 
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
 
53
  process_group = None
54
 
55
 
56
  @torch.no_grad()
57
  def _gather(p, state, rank, comm_stream, none_grad):
 
 
 
 
58
  g = p.grad
59
 
60
  if rank == state.worker_rank:
61
  num_ranks = dist.get_world_size(group=state.process_group)
62
- gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
 
 
63
  else:
64
  gather_list = None
65
 
@@ -73,8 +81,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
73
  if rank == state.worker_rank:
74
  if state.gathered_grad is not None:
75
  raise RuntimeError(
76
- "Gather event already exists, which should not happen."
77
- )
78
  state.gathered_grad = torch.cat(gather_list, dim=0)
79
  state.gather_event = torch.cuda.Event()
80
  state.gather_event.record()
@@ -82,11 +89,21 @@ def _gather(p, state, rank, comm_stream, none_grad):
82
  state.gathered_grad = None
83
  state.gather_event = None
84
  if none_grad:
 
 
 
 
 
 
 
85
  p.grad = None
86
 
87
 
88
  @torch.no_grad()
89
  def _compute_u(state, steps, rank, compute_stream):
 
 
 
90
  with torch.cuda.stream(compute_stream):
91
  if rank == state.worker_rank:
92
  if state.gather_event is None:
@@ -96,16 +113,16 @@ def _compute_u(state, steps, rank, compute_stream):
96
  state.computed_u = u
97
  state.compute_event = torch.cuda.Event()
98
  state.compute_event.record()
99
- # Clear the gathered gradient to free memory
100
- state.gathered_grad = None
101
  else:
102
  state.computed_u = None
103
  state.compute_event = None
104
 
105
 
106
  @torch.no_grad()
107
- def _scatter(p, state, lr, weight_decay, rank, comm_stream):
108
- u = state.computed_u
 
 
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
@@ -113,27 +130,49 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
113
  if state.compute_event is None:
114
  raise RuntimeError("Compute event must be set before scatter.")
115
  comm_stream.wait_event(state.compute_event)
 
 
 
 
 
116
  scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
 
117
  else:
118
  scatter_list = None
119
 
120
- u = torch.empty_like(p.to_local())
121
  torch.distributed.scatter(
122
- u,
123
  scatter_list=scatter_list,
124
  src=state.worker_rank,
125
  group=state.process_group,
126
  )
127
- if rank == state.worker_rank:
128
- # Clear u to free memory
129
- state.computed_u = None
130
- u = DTensor.from_local(
131
- u,
132
  placements=p.placements,
133
  device_mesh=p.device_mesh,
134
  )
135
- p.data.mul_(1 - lr * weight_decay)
136
- p.data.add_(u, alpha=-lr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
 
139
  def default_is_muon(x, name):
@@ -154,17 +193,19 @@ class Muon(torch.optim.Optimizer):
154
  - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
155
 
156
  Arguments:
157
- muon_params: The parameters to be optimized by Muon.
 
158
  lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
159
  momentum: The momentum used by the internal SGD. (0.95 is a good default)
160
  nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
161
  ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
162
- adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are
163
  {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
164
  adamw_lr: The learning rate for the internal AdamW.
165
  adamw_betas: The betas for the internal AdamW.
166
  adamw_eps: The epsilon for the internal AdamW.
167
- adamw_weight_decay: The weight decay for the internal AdamW.
 
168
  """
169
 
170
  def __init__(
@@ -240,9 +281,10 @@ class Muon(torch.optim.Optimizer):
240
  """
241
  Get the shard mesh for a parameter p on the given rank.
242
  """
243
- assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
 
244
 
245
- if p.placements == (Shard(dim=0),):
246
  # Case for FSDP
247
  return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
248
  elif p.placements == (Replicate(), Shard(dim=0)):
@@ -269,11 +311,12 @@ class Muon(torch.optim.Optimizer):
269
  total_flops += flops
270
 
271
  if self.debug:
272
- print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
 
273
 
274
- ordered_params = sorted(
275
- params, key=lambda p: param_to_flops[id(p)], reverse=True
276
- )
277
 
278
  round_robin = 0
279
  mesh = None
@@ -317,14 +360,8 @@ class Muon(torch.optim.Optimizer):
317
 
318
  u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
319
 
320
- # scale update
321
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
322
-
323
- # apply weight decay
324
- p.data.mul_(1 - lr * weight_decay)
325
-
326
- # apply update
327
- p.data.add_(u, alpha=-adjusted_lr)
328
 
329
  def _update_g(self, p, g, group, momentum):
330
  # calc update
@@ -339,9 +376,8 @@ class Muon(torch.optim.Optimizer):
339
  g = buf
340
  return g
341
 
342
- def _update_p(self, p, u, lr, weight_decay):
343
- # scale update
344
- adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
345
  # apply weight decay
346
  p.data.mul_(1 - lr * weight_decay)
347
  # apply update
@@ -369,28 +405,34 @@ class Muon(torch.optim.Optimizer):
369
  p.grad = g
370
 
371
  param_to_state, ordered_params = self.init_state_and_assign_params(
372
- params, group
373
- )
374
 
375
  def enqueue_gathers(start_idx, chunk_size):
376
- for p in ordered_params[start_idx : start_idx + chunk_size]:
377
  state = param_to_state[id(p)]
378
- _gather(p, state, self.rank, self.comm_stream, group["none_grad"])
 
379
 
380
  def enqueue_computes(start_idx, chunk_size):
381
- for p in ordered_params[start_idx : start_idx + chunk_size]:
382
  state = param_to_state[id(p)]
383
- _compute_u(state, group["ns_steps"], self.rank, self.compute_stream)
 
384
 
385
  def enqueue_scatters(start_idx, chunk_size):
386
- for p in ordered_params[start_idx : start_idx + chunk_size]:
 
 
 
 
 
387
  state = param_to_state[id(p)]
388
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
389
- _scatter(
390
- p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
391
- )
392
 
393
- chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
 
394
 
395
  # Wait grad update
396
  self.comm_stream.wait_stream(torch.cuda.current_stream())
@@ -398,10 +440,14 @@ class Muon(torch.optim.Optimizer):
398
  enqueue_gathers(0, chunk_size)
399
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
400
  enqueue_computes(i, chunk_size)
 
 
401
  enqueue_gathers(i + chunk_size, chunk_size)
402
  enqueue_scatters(i, chunk_size)
 
403
 
404
- torch.cuda.current_stream().wait_stream(self.comm_stream)
 
405
 
406
  def step(self, closure=None):
407
  """Perform a single optimization step.
@@ -436,15 +482,16 @@ class Muon(torch.optim.Optimizer):
436
  continue
437
  if isinstance(p.data, DTensor):
438
  if all(
439
- isinstance(placement, Replicate) for placement in p.placements
440
- ):
441
  param_tensors.append(p)
442
  else:
443
  param_dtensors.append(p)
444
  elif isinstance(p.data, torch.Tensor):
445
  param_tensors.append(p)
446
  else:
447
- raise TypeError(f"Unsupported parameter type: {type(p.data)}")
 
448
 
449
  if self.debug:
450
  print(
@@ -479,7 +526,9 @@ class Muon(torch.optim.Optimizer):
479
  # AdamW backup #
480
  ############################
481
 
482
- params = [p for p in group["params"] if not self.state[p]["use_muon"]]
 
 
483
  lr = group["lr"]
484
  beta1, beta2 = group["adamw_betas"]
485
  eps = group["adamw_eps"]
 
47
  # TODO: use Optional
48
  worker_rank: int | None = None
49
  gathered_grad: torch.Tensor | None = None
50
+ scattered_u: DTensor | None = None
51
  computed_u: torch.Tensor | None = None
52
  gather_event: torch.cuda.Event | None = None
53
  compute_event: torch.cuda.Event | None = None
54
+ scatter_event: torch.cuda.Event | None = None
55
  process_group = None
56
 
57
 
58
  @torch.no_grad()
59
  def _gather(p, state, rank, comm_stream, none_grad):
60
+ """
61
+ Gather the gradients to worker_rank.
62
+ If none_grad is True, free p.grad after the gather.
63
+ """
64
  g = p.grad
65
 
66
  if rank == state.worker_rank:
67
  num_ranks = dist.get_world_size(group=state.process_group)
68
+ gather_list = [
69
+ torch.empty_like(g.to_local()) for _ in range(num_ranks)
70
+ ]
71
  else:
72
  gather_list = None
73
 
 
81
  if rank == state.worker_rank:
82
  if state.gathered_grad is not None:
83
  raise RuntimeError(
84
+ "Gather event already exists, which should not happen.")
 
85
  state.gathered_grad = torch.cat(gather_list, dim=0)
86
  state.gather_event = torch.cuda.Event()
87
  state.gather_event.record()
 
89
  state.gathered_grad = None
90
  state.gather_event = None
91
  if none_grad:
92
+ # We can safely free p.grad without calling record_stream:
93
+ # p.grad.to_local().record_stream(comm_stream)
94
+ # Explanation:
95
+ # 1. p.grad is created on the default stream, but the default stream
96
+ # is synchronized with the comm stream later.
97
+ # 2. There is no further activity on the default stream before the optimizer finishes.
98
+ # Therefore, it is safe to free p.grad directly on the comm stream.
99
  p.grad = None
100
 
101
 
102
  @torch.no_grad()
103
  def _compute_u(state, steps, rank, compute_stream):
104
+ """
105
+ On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
106
+ """
107
  with torch.cuda.stream(compute_stream):
108
  if rank == state.worker_rank:
109
  if state.gather_event is None:
 
113
  state.computed_u = u
114
  state.compute_event = torch.cuda.Event()
115
  state.compute_event.record()
 
 
116
  else:
117
  state.computed_u = None
118
  state.compute_event = None
119
 
120
 
121
  @torch.no_grad()
122
+ def _scatter(p, state, rank, comm_stream):
123
+ """
124
+ Scatter the computed_u from worker_rank to all ranks.
125
+ """
126
 
127
  with torch.cuda.stream(comm_stream):
128
  if rank == state.worker_rank:
 
130
  if state.compute_event is None:
131
  raise RuntimeError("Compute event must be set before scatter.")
132
  comm_stream.wait_event(state.compute_event)
133
+
134
+ # Clear the gathered gradient to free memory
135
+ state.gathered_grad = None
136
+
137
+ u = state.computed_u
138
  scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
139
+ scatter_list = [s.contiguous() for s in scatter_list]
140
  else:
141
  scatter_list = None
142
 
143
+ u_received = torch.empty_like(p.to_local())
144
  torch.distributed.scatter(
145
+ u_received,
146
  scatter_list=scatter_list,
147
  src=state.worker_rank,
148
  group=state.process_group,
149
  )
150
+ u_dtensor = DTensor.from_local(
151
+ u_received,
 
 
 
152
  placements=p.placements,
153
  device_mesh=p.device_mesh,
154
  )
155
+
156
+ state.scattered_u = u_dtensor
157
+ state.scatter_event = torch.cuda.Event()
158
+ state.scatter_event.record()
159
+
160
+
161
+ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
162
+ compute_stream):
163
+ """
164
+ Update sharded parameter p with the scattered_u.
165
+ Only worker_rank frees computed_u.
166
+ """
167
+ with torch.cuda.stream(compute_stream):
168
+ if state.scatter_event is None:
169
+ raise RuntimeError("Scatter event must be set before update")
170
+ compute_stream.wait_event(state.scatter_event)
171
+ if rank == state.worker_rank:
172
+ # Free computed_u
173
+ state.computed_u = None
174
+
175
+ Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
176
 
177
 
178
  def default_is_muon(x, name):
 
193
  - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
194
 
195
  Arguments:
196
+ model: The model to be optimized by Muon.
197
+ is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon.
198
  lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
199
  momentum: The momentum used by the internal SGD. (0.95 is a good default)
200
  nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
201
  ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
202
+ weight_decay: The weight decay for Muon and AdamW.
203
  Parameters that are {0, 1}-D, or detected as being the embed or lm_head, will be optimized by AdamW as well.
204
  adamw_lr: The learning rate for the internal AdamW.
205
  adamw_betas: The betas for the internal AdamW.
206
  adamw_eps: The epsilon for the internal AdamW.
207
+ none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
208
+ debug: Whether to print debug information.
209
  """
210
 
211
  def __init__(
 
281
  """
282
  Get the shard mesh for a parameter p on the given rank.
283
  """
284
+ assert isinstance(
285
+ p, DTensor), "Parallel Muon only supports DTensor parameters."
286
 
287
+ if p.placements == (Shard(dim=0), ):
288
  # Case for FSDP
289
  return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
290
  elif p.placements == (Replicate(), Shard(dim=0)):
 
311
  total_flops += flops
312
 
313
  if self.debug:
314
+ print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
315
+ flush=True)
316
 
317
+ ordered_params = sorted(params,
318
+ key=lambda p: param_to_flops[id(p)],
319
+ reverse=True)
320
 
321
  round_robin = 0
322
  mesh = None
 
360
 
361
  u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
362
 
 
363
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
364
+ Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
 
 
 
 
 
365
 
366
  def _update_g(self, p, g, group, momentum):
367
  # calc update
 
376
  g = buf
377
  return g
378
 
379
+ @staticmethod
380
+ def _update_p(p, u, lr, adjusted_lr, weight_decay):
 
381
  # apply weight decay
382
  p.data.mul_(1 - lr * weight_decay)
383
  # apply update
 
405
  p.grad = g
406
 
407
  param_to_state, ordered_params = self.init_state_and_assign_params(
408
+ params, group)
 
409
 
410
  def enqueue_gathers(start_idx, chunk_size):
411
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
412
  state = param_to_state[id(p)]
413
+ _gather(p, state, self.rank, self.comm_stream,
414
+ group["none_grad"])
415
 
416
  def enqueue_computes(start_idx, chunk_size):
417
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
418
  state = param_to_state[id(p)]
419
+ _compute_u(state, group["ns_steps"], self.rank,
420
+ self.compute_stream)
421
 
422
  def enqueue_scatters(start_idx, chunk_size):
423
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
424
+ state = param_to_state[id(p)]
425
+ _scatter(p, state, self.rank, self.comm_stream)
426
+
427
+ def enqueue_update_param(start_idx, chunk_size):
428
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
429
  state = param_to_state[id(p)]
430
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
431
+ _update_param(p, state, lr, adjusted_lr, weight_decay,
432
+ self.rank, self.compute_stream)
 
433
 
434
+ chunk_size = dist.get_world_size(param_to_state[id(
435
+ params[0])].process_group)
436
 
437
  # Wait grad update
438
  self.comm_stream.wait_stream(torch.cuda.current_stream())
 
440
  enqueue_gathers(0, chunk_size)
441
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
442
  enqueue_computes(i, chunk_size)
443
+ if i > 0:
444
+ enqueue_update_param(i - chunk_size, chunk_size)
445
  enqueue_gathers(i + chunk_size, chunk_size)
446
  enqueue_scatters(i, chunk_size)
447
+ enqueue_update_param(i, chunk_size)
448
 
449
+ # Wait for the last update_param to finish
450
+ torch.cuda.current_stream().wait_stream(self.compute_stream)
451
 
452
  def step(self, closure=None):
453
  """Perform a single optimization step.
 
482
  continue
483
  if isinstance(p.data, DTensor):
484
  if all(
485
+ isinstance(placement, Replicate)
486
+ for placement in p.placements):
487
  param_tensors.append(p)
488
  else:
489
  param_dtensors.append(p)
490
  elif isinstance(p.data, torch.Tensor):
491
  param_tensors.append(p)
492
  else:
493
+ raise TypeError(
494
+ f"Unsupported parameter type: {type(p.data)}")
495
 
496
  if self.debug:
497
  print(
 
526
  # AdamW backup #
527
  ############################
528
 
529
+ params = [
530
+ p for p in group["params"] if not self.state[p]["use_muon"]
531
+ ]
532
  lr = group["lr"]
533
  beta1, beta2 = group["adamw_betas"]
534
  eps = group["adamw_eps"]
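
The refactored Muon._update_p shown above boils down to decoupled weight decay followed by the scaled orthogonalized update, p ← p·(1 − lr·λ) − adjusted_lr·u. Here is a tiny CPU-only check of that arithmetic, with made-up values for lr, adjusted_lr and weight_decay (the real adjusted_lr comes from adjust_lr_for_muon, whose formula is not part of this diff):

import torch

def update_p(p, u, lr, adjusted_lr, weight_decay):
    # Mirrors the two in-place ops in Muon._update_p.
    p.mul_(1 - lr * weight_decay)   # decoupled weight decay
    p.add_(u, alpha=-adjusted_lr)   # apply the orthogonalized update

p = torch.ones(2, 3)
u = torch.full((2, 3), 0.5)
update_p(p, u, lr=0.02, adjusted_lr=0.01, weight_decay=0.1)

# 1 * (1 - 0.02 * 0.1) - 0.01 * 0.5 = 0.998 - 0.005 = 0.993
print(torch.allclose(p, torch.full((2, 3), 0.993)))  # True
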
build/torch28-cxx11-cu128-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc DELETED
Binary file (307 Bytes)
 
build/torch28-cxx11-cu128-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc DELETED
Binary file (23.4 kB)
 
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_2dc97a1_dirty
3
- ops = torch.ops._optimizer_2dc97a1_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_2dc97a1_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_0c12ced_dirty
3
+ ops = torch.ops._optimizer_0c12ced_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_0c12ced_dirty::{op_name}"
build/{torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_2dc97a1_dirty.abi3.so → torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_0c12ced_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:2e6bab72b965f42d466cd74bbda49851549f2810278e642cef8738e40de4fdc5
3
  size 1883352
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:80cb3ac21d3afafe368f31318c31a4c6356b53bbc2186ae81b79e1eb3ff441f5
3
  size 1883352
build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py CHANGED
@@ -47,19 +47,27 @@ class _muon_state:
47
  # TODO: use Optional
48
  worker_rank: int | None = None
49
  gathered_grad: torch.Tensor | None = None
 
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
 
53
  process_group = None
54
 
55
 
56
  @torch.no_grad()
57
  def _gather(p, state, rank, comm_stream, none_grad):
 
 
 
 
58
  g = p.grad
59
 
60
  if rank == state.worker_rank:
61
  num_ranks = dist.get_world_size(group=state.process_group)
62
- gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
 
 
63
  else:
64
  gather_list = None
65
 
@@ -73,8 +81,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
73
  if rank == state.worker_rank:
74
  if state.gathered_grad is not None:
75
  raise RuntimeError(
76
- "Gather event already exists, which should not happen."
77
- )
78
  state.gathered_grad = torch.cat(gather_list, dim=0)
79
  state.gather_event = torch.cuda.Event()
80
  state.gather_event.record()
@@ -82,11 +89,21 @@ def _gather(p, state, rank, comm_stream, none_grad):
82
  state.gathered_grad = None
83
  state.gather_event = None
84
  if none_grad:
 
 
 
 
 
 
 
85
  p.grad = None
86
 
87
 
88
  @torch.no_grad()
89
  def _compute_u(state, steps, rank, compute_stream):
 
 
 
90
  with torch.cuda.stream(compute_stream):
91
  if rank == state.worker_rank:
92
  if state.gather_event is None:
@@ -96,16 +113,16 @@ def _compute_u(state, steps, rank, compute_stream):
96
  state.computed_u = u
97
  state.compute_event = torch.cuda.Event()
98
  state.compute_event.record()
99
- # Clear the gathered gradient to free memory
100
- state.gathered_grad = None
101
  else:
102
  state.computed_u = None
103
  state.compute_event = None
104
 
105
 
106
  @torch.no_grad()
107
- def _scatter(p, state, lr, weight_decay, rank, comm_stream):
108
- u = state.computed_u
 
 
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
@@ -113,27 +130,49 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
113
  if state.compute_event is None:
114
  raise RuntimeError("Compute event must be set before scatter.")
115
  comm_stream.wait_event(state.compute_event)
 
 
 
 
 
116
  scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
 
117
  else:
118
  scatter_list = None
119
 
120
- u = torch.empty_like(p.to_local())
121
  torch.distributed.scatter(
122
- u,
123
  scatter_list=scatter_list,
124
  src=state.worker_rank,
125
  group=state.process_group,
126
  )
127
- if rank == state.worker_rank:
128
- # Clear u to free memory
129
- state.computed_u = None
130
- u = DTensor.from_local(
131
- u,
132
  placements=p.placements,
133
  device_mesh=p.device_mesh,
134
  )
135
- p.data.mul_(1 - lr * weight_decay)
136
- p.data.add_(u, alpha=-lr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
 
139
  def default_is_muon(x, name):
@@ -154,17 +193,19 @@ class Muon(torch.optim.Optimizer):
154
  - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
155
 
156
  Arguments:
157
- muon_params: The parameters to be optimized by Muon.
 
158
  lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
159
  momentum: The momentum used by the internal SGD. (0.95 is a good default)
160
  nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
161
  ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
162
- adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are
163
  {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
164
  adamw_lr: The learning rate for the internal AdamW.
165
  adamw_betas: The betas for the internal AdamW.
166
  adamw_eps: The epsilon for the internal AdamW.
167
- adamw_weight_decay: The weight decay for the internal AdamW.
 
168
  """
169
 
170
  def __init__(
@@ -240,9 +281,10 @@ class Muon(torch.optim.Optimizer):
240
  """
241
  Get the shard mesh for a parameter p on the given rank.
242
  """
243
- assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
 
244
 
245
- if p.placements == (Shard(dim=0),):
246
  # Case for FSDP
247
  return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
248
  elif p.placements == (Replicate(), Shard(dim=0)):
@@ -269,11 +311,12 @@ class Muon(torch.optim.Optimizer):
269
  total_flops += flops
270
 
271
  if self.debug:
272
- print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
 
273
 
274
- ordered_params = sorted(
275
- params, key=lambda p: param_to_flops[id(p)], reverse=True
276
- )
277
 
278
  round_robin = 0
279
  mesh = None
@@ -317,14 +360,8 @@ class Muon(torch.optim.Optimizer):
317
 
318
  u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
319
 
320
- # scale update
321
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
322
-
323
- # apply weight decay
324
- p.data.mul_(1 - lr * weight_decay)
325
-
326
- # apply update
327
- p.data.add_(u, alpha=-adjusted_lr)
328
 
329
  def _update_g(self, p, g, group, momentum):
330
  # calc update
@@ -339,9 +376,8 @@ class Muon(torch.optim.Optimizer):
339
  g = buf
340
  return g
341
 
342
- def _update_p(self, p, u, lr, weight_decay):
343
- # scale update
344
- adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
345
  # apply weight decay
346
  p.data.mul_(1 - lr * weight_decay)
347
  # apply update
@@ -369,28 +405,34 @@ class Muon(torch.optim.Optimizer):
369
  p.grad = g
370
 
371
  param_to_state, ordered_params = self.init_state_and_assign_params(
372
- params, group
373
- )
374
 
375
  def enqueue_gathers(start_idx, chunk_size):
376
- for p in ordered_params[start_idx : start_idx + chunk_size]:
377
  state = param_to_state[id(p)]
378
- _gather(p, state, self.rank, self.comm_stream, group["none_grad"])
 
379
 
380
  def enqueue_computes(start_idx, chunk_size):
381
- for p in ordered_params[start_idx : start_idx + chunk_size]:
382
  state = param_to_state[id(p)]
383
- _compute_u(state, group["ns_steps"], self.rank, self.compute_stream)
 
384
 
385
  def enqueue_scatters(start_idx, chunk_size):
386
- for p in ordered_params[start_idx : start_idx + chunk_size]:
 
 
 
 
 
387
  state = param_to_state[id(p)]
388
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
389
- _scatter(
390
- p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
391
- )
392
 
393
- chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
 
394
 
395
  # Wait grad update
396
  self.comm_stream.wait_stream(torch.cuda.current_stream())
@@ -398,10 +440,14 @@ class Muon(torch.optim.Optimizer):
398
  enqueue_gathers(0, chunk_size)
399
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
400
  enqueue_computes(i, chunk_size)
 
 
401
  enqueue_gathers(i + chunk_size, chunk_size)
402
  enqueue_scatters(i, chunk_size)
 
403
 
404
- torch.cuda.current_stream().wait_stream(self.comm_stream)
 
405
 
406
  def step(self, closure=None):
407
  """Perform a single optimization step.
@@ -436,15 +482,16 @@ class Muon(torch.optim.Optimizer):
436
  continue
437
  if isinstance(p.data, DTensor):
438
  if all(
439
- isinstance(placement, Replicate) for placement in p.placements
440
- ):
441
  param_tensors.append(p)
442
  else:
443
  param_dtensors.append(p)
444
  elif isinstance(p.data, torch.Tensor):
445
  param_tensors.append(p)
446
  else:
447
- raise TypeError(f"Unsupported parameter type: {type(p.data)}")
 
448
 
449
  if self.debug:
450
  print(
@@ -479,7 +526,9 @@ class Muon(torch.optim.Optimizer):
479
  # AdamW backup #
480
  ############################
481
 
482
- params = [p for p in group["params"] if not self.state[p]["use_muon"]]
 
 
483
  lr = group["lr"]
484
  beta1, beta2 = group["adamw_betas"]
485
  eps = group["adamw_eps"]
 
47
  # TODO: use Optional
48
  worker_rank: int | None = None
49
  gathered_grad: torch.Tensor | None = None
50
+ scattered_u: DTensor | None = None
51
  computed_u: torch.Tensor | None = None
52
  gather_event: torch.cuda.Event | None = None
53
  compute_event: torch.cuda.Event | None = None
54
+ scatter_event: torch.cuda.Event | None = None
55
  process_group = None
56
 
57
 
58
  @torch.no_grad()
59
  def _gather(p, state, rank, comm_stream, none_grad):
60
+ """
61
+ Gather the gradients to worker_rank.
62
+ If none_grad is True, free p.grad after the gather.
63
+ """
64
  g = p.grad
65
 
66
  if rank == state.worker_rank:
67
  num_ranks = dist.get_world_size(group=state.process_group)
68
+ gather_list = [
69
+ torch.empty_like(g.to_local()) for _ in range(num_ranks)
70
+ ]
71
  else:
72
  gather_list = None
73
 
 
81
  if rank == state.worker_rank:
82
  if state.gathered_grad is not None:
83
  raise RuntimeError(
84
+ "Gather event already exists, which should not happen.")
 
85
  state.gathered_grad = torch.cat(gather_list, dim=0)
86
  state.gather_event = torch.cuda.Event()
87
  state.gather_event.record()
 
89
  state.gathered_grad = None
90
  state.gather_event = None
91
  if none_grad:
92
+ # We can safely free p.grad without calling record_stream:
93
+ # p.grad.to_local().record_stream(comm_stream)
94
+ # Explanation:
95
+ # 1. p.grad is created on the default stream, but the default stream
96
+ # is synchronized with the comm stream later.
97
+ # 2. There is no further activity on the default stream before the optimizer finishes.
98
+ # Therefore, it is safe to free p.grad directly on the comm stream.
99
  p.grad = None
100
 
101
 
102
  @torch.no_grad()
103
  def _compute_u(state, steps, rank, compute_stream):
104
+ """
105
+ On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
106
+ """
107
  with torch.cuda.stream(compute_stream):
108
  if rank == state.worker_rank:
109
  if state.gather_event is None:
 
113
  state.computed_u = u
114
  state.compute_event = torch.cuda.Event()
115
  state.compute_event.record()
 
 
116
  else:
117
  state.computed_u = None
118
  state.compute_event = None
119
 
120
 
121
  @torch.no_grad()
122
+ def _scatter(p, state, rank, comm_stream):
123
+ """
124
+ Scatter the computed_u from worker_rank to all ranks.
125
+ """
126
 
127
  with torch.cuda.stream(comm_stream):
128
  if rank == state.worker_rank:
 
130
  if state.compute_event is None:
131
  raise RuntimeError("Compute event must be set before scatter.")
132
  comm_stream.wait_event(state.compute_event)
133
+
134
+ # Clear the gathered gradient to free memory
135
+ state.gathered_grad = None
136
+
137
+ u = state.computed_u
138
  scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
139
+ scatter_list = [s.contiguous() for s in scatter_list]
140
  else:
141
  scatter_list = None
142
 
143
+ u_received = torch.empty_like(p.to_local())
144
  torch.distributed.scatter(
145
+ u_received,
146
  scatter_list=scatter_list,
147
  src=state.worker_rank,
148
  group=state.process_group,
149
  )
150
+ u_dtensor = DTensor.from_local(
151
+ u_received,
 
 
 
152
  placements=p.placements,
153
  device_mesh=p.device_mesh,
154
  )
155
+
156
+ state.scattered_u = u_dtensor
157
+ state.scatter_event = torch.cuda.Event()
158
+ state.scatter_event.record()
159
+
160
+
161
+ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
162
+ compute_stream):
163
+ """
164
+ Update sharded parameter p with the scattered_u.
165
+ Only worker_rank frees computed_u.
166
+ """
167
+ with torch.cuda.stream(compute_stream):
168
+ if state.scatter_event is None:
169
+ raise RuntimeError("Scatter event must be set before update")
170
+ compute_stream.wait_event(state.scatter_event)
171
+ if rank == state.worker_rank:
172
+ # Free computed_u
173
+ state.computed_u = None
174
+
175
+ Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
176
 
177
 
178
  def default_is_muon(x, name):
 
193
  - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
194
 
195
  Arguments:
196
+ model: The model to be optimized by Muon.
197
+ is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon.
198
  lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
199
  momentum: The momentum used by the internal SGD. (0.95 is a good default)
200
  nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
201
  ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
202
+ weight_decay: The weight decay for Muon and AdamW.
203
  Parameters that are {0, 1}-D, or detected as being the embed or lm_head, will be optimized by AdamW as well.
204
  adamw_lr: The learning rate for the internal AdamW.
205
  adamw_betas: The betas for the internal AdamW.
206
  adamw_eps: The epsilon for the internal AdamW.
207
+ none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
208
+ debug: Whether to print debug information.
209
  """
210
 
211
  def __init__(
 
281
  """
282
  Get the shard mesh for a parameter p on the given rank.
283
  """
284
+ assert isinstance(
285
+ p, DTensor), "Parallel Muon only supports DTensor parameters."
286
 
287
+ if p.placements == (Shard(dim=0), ):
288
  # Case for FSDP
289
  return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
290
  elif p.placements == (Replicate(), Shard(dim=0)):
 
311
  total_flops += flops
312
 
313
  if self.debug:
314
+ print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
315
+ flush=True)
316
 
317
+ ordered_params = sorted(params,
318
+ key=lambda p: param_to_flops[id(p)],
319
+ reverse=True)
320
 
321
  round_robin = 0
322
  mesh = None
 
360
 
361
  u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
362
 
 
363
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
364
+ Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
 
 
 
 
 
365
 
366
  def _update_g(self, p, g, group, momentum):
367
  # calc update
 
376
  g = buf
377
  return g
378
 
379
+ @staticmethod
380
+ def _update_p(p, u, lr, adjusted_lr, weight_decay):
 
381
  # apply weight decay
382
  p.data.mul_(1 - lr * weight_decay)
383
  # apply update
 
405
  p.grad = g
406
 
407
  param_to_state, ordered_params = self.init_state_and_assign_params(
408
+ params, group)
 
409
 
410
  def enqueue_gathers(start_idx, chunk_size):
411
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
412
  state = param_to_state[id(p)]
413
+ _gather(p, state, self.rank, self.comm_stream,
414
+ group["none_grad"])
415
 
416
  def enqueue_computes(start_idx, chunk_size):
417
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
418
  state = param_to_state[id(p)]
419
+ _compute_u(state, group["ns_steps"], self.rank,
420
+ self.compute_stream)
421
 
422
  def enqueue_scatters(start_idx, chunk_size):
423
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
424
+ state = param_to_state[id(p)]
425
+ _scatter(p, state, self.rank, self.comm_stream)
426
+
427
+ def enqueue_update_param(start_idx, chunk_size):
428
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
429
  state = param_to_state[id(p)]
430
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
431
+ _update_param(p, state, lr, adjusted_lr, weight_decay,
432
+ self.rank, self.compute_stream)
 
433
 
434
+ chunk_size = dist.get_world_size(param_to_state[id(
435
+ params[0])].process_group)
436
 
437
  # Wait grad update
438
  self.comm_stream.wait_stream(torch.cuda.current_stream())
 
440
  enqueue_gathers(0, chunk_size)
441
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
442
  enqueue_computes(i, chunk_size)
443
+ if i > 0:
444
+ enqueue_update_param(i - chunk_size, chunk_size)
445
  enqueue_gathers(i + chunk_size, chunk_size)
446
  enqueue_scatters(i, chunk_size)
447
+ enqueue_update_param(i, chunk_size)
448
 
449
+ # Wait for the last update_param to finish
450
+ torch.cuda.current_stream().wait_stream(self.compute_stream)
451
 
452
  def step(self, closure=None):
453
  """Perform a single optimization step.
 
482
  continue
483
  if isinstance(p.data, DTensor):
484
  if all(
485
+ isinstance(placement, Replicate)
486
+ for placement in p.placements):
487
  param_tensors.append(p)
488
  else:
489
  param_dtensors.append(p)
490
  elif isinstance(p.data, torch.Tensor):
491
  param_tensors.append(p)
492
  else:
493
+ raise TypeError(
494
+ f"Unsupported parameter type: {type(p.data)}")
495
 
496
  if self.debug:
497
  print(
 
526
  # AdamW backup #
527
  ############################
528
 
529
+ params = [
530
+ p for p in group["params"] if not self.state[p]["use_muon"]
531
+ ]
532
  lr = group["lr"]
533
  beta1, beta2 = group["adamw_betas"]
534
  eps = group["adamw_eps"]
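
Reading the new driver loop in parallel() (indentation is collapsed in this view, so this is an interpretation): gathers for the next chunk are enqueued while the current chunk is scattered, the update for the previous chunk overlaps the current compute, and a final update after the loop handles the last chunk. Below is a CPU-only dry run of that schedule, with prints standing in for the real stage functions and an invented parameter count and chunk size:

def dry_run(num_params=7, chunk_size=2):
    params = [f"p{i}" for i in range(num_params)]

    def enqueue(stage, start):
        for p in params[start:start + chunk_size]:
            print(f"{stage:>8} {p}")

    # Same loop structure as Muon.parallel() after this commit (as read from the diff).
    enqueue("gather", 0)
    for i in range(0, len(params) + chunk_size - 1, chunk_size):
        enqueue("compute", i)
        if i > 0:
            enqueue("update", i - chunk_size)   # update the previous chunk
        enqueue("gather", i + chunk_size)       # prefetch the next chunk
        enqueue("scatter", i)
    enqueue("update", i)                        # update the final chunk

dry_run()
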
build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc DELETED
Binary file (307 Bytes)
 
build/torch28-cxx11-cu129-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc DELETED
Binary file (23.4 kB)
 
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_2dc97a1_dirty
3
- ops = torch.ops._optimizer_2dc97a1_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_2dc97a1_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_0c12ced_dirty
3
+ ops = torch.ops._optimizer_0c12ced_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_0c12ced_dirty::{op_name}"
build/torch28-cxx11-cu129-x86_64-linux/optimizer/{_optimizer_2dc97a1_dirty.abi3.so → _optimizer_0c12ced_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f8daaad69e6958850f848fab60c9acb938c3a5e54e3ec34a1bec03a3d32653cb
3
  size 1883352
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:32af855517484e2695b6d83c29a03d85fcbaaea559d95cbb62fd9fa67cc3ccac
3
  size 1883352
build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py CHANGED
@@ -47,19 +47,27 @@ class _muon_state:
47
  # TODO: use Optional
48
  worker_rank: int | None = None
49
  gathered_grad: torch.Tensor | None = None
 
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
 
53
  process_group = None
54
 
55
 
56
  @torch.no_grad()
57
  def _gather(p, state, rank, comm_stream, none_grad):
 
 
 
 
58
  g = p.grad
59
 
60
  if rank == state.worker_rank:
61
  num_ranks = dist.get_world_size(group=state.process_group)
62
- gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
 
 
63
  else:
64
  gather_list = None
65
 
@@ -73,8 +81,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
73
  if rank == state.worker_rank:
74
  if state.gathered_grad is not None:
75
  raise RuntimeError(
76
- "Gather event already exists, which should not happen."
77
- )
78
  state.gathered_grad = torch.cat(gather_list, dim=0)
79
  state.gather_event = torch.cuda.Event()
80
  state.gather_event.record()
@@ -82,11 +89,21 @@ def _gather(p, state, rank, comm_stream, none_grad):
82
  state.gathered_grad = None
83
  state.gather_event = None
84
  if none_grad:
 
 
 
 
 
 
 
85
  p.grad = None
86
 
87
 
88
  @torch.no_grad()
89
  def _compute_u(state, steps, rank, compute_stream):
 
 
 
90
  with torch.cuda.stream(compute_stream):
91
  if rank == state.worker_rank:
92
  if state.gather_event is None:
@@ -96,16 +113,16 @@ def _compute_u(state, steps, rank, compute_stream):
96
  state.computed_u = u
97
  state.compute_event = torch.cuda.Event()
98
  state.compute_event.record()
99
- # Clear the gathered gradient to free memory
100
- state.gathered_grad = None
101
  else:
102
  state.computed_u = None
103
  state.compute_event = None
104
 
105
 
106
  @torch.no_grad()
107
- def _scatter(p, state, lr, weight_decay, rank, comm_stream):
108
- u = state.computed_u
 
 
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
@@ -113,27 +130,49 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
113
  if state.compute_event is None:
114
  raise RuntimeError("Compute event must be set before scatter.")
115
  comm_stream.wait_event(state.compute_event)
 
 
 
 
 
116
  scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
 
117
  else:
118
  scatter_list = None
119
 
120
- u = torch.empty_like(p.to_local())
121
  torch.distributed.scatter(
122
- u,
123
  scatter_list=scatter_list,
124
  src=state.worker_rank,
125
  group=state.process_group,
126
  )
127
- if rank == state.worker_rank:
128
- # Clear u to free memory
129
- state.computed_u = None
130
- u = DTensor.from_local(
131
- u,
132
  placements=p.placements,
133
  device_mesh=p.device_mesh,
134
  )
135
- p.data.mul_(1 - lr * weight_decay)
136
- p.data.add_(u, alpha=-lr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
 
139
  def default_is_muon(x, name):
@@ -154,17 +193,19 @@ class Muon(torch.optim.Optimizer):
154
  - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
155
 
156
  Arguments:
157
- muon_params: The parameters to be optimized by Muon.
 
158
  lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
159
  momentum: The momentum used by the internal SGD. (0.95 is a good default)
160
  nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
161
  ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
162
- adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are
163
  {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
164
  adamw_lr: The learning rate for the internal AdamW.
165
  adamw_betas: The betas for the internal AdamW.
166
  adamw_eps: The epsilon for the internal AdamW.
167
- adamw_weight_decay: The weight decay for the internal AdamW.
 
168
  """
169
 
170
  def __init__(
@@ -240,9 +281,10 @@ class Muon(torch.optim.Optimizer):
240
  """
241
  Get the shard mesh for a parameter p on the given rank.
242
  """
243
- assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
 
244
 
245
- if p.placements == (Shard(dim=0),):
246
  # Case for FSDP
247
  return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
248
  elif p.placements == (Replicate(), Shard(dim=0)):
@@ -269,11 +311,12 @@ class Muon(torch.optim.Optimizer):
269
  total_flops += flops
270
 
271
  if self.debug:
272
- print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
 
273
 
274
- ordered_params = sorted(
275
- params, key=lambda p: param_to_flops[id(p)], reverse=True
276
- )
277
 
278
  round_robin = 0
279
  mesh = None
@@ -317,14 +360,8 @@ class Muon(torch.optim.Optimizer):
317
 
318
  u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
319
 
320
- # scale update
321
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
322
-
323
- # apply weight decay
324
- p.data.mul_(1 - lr * weight_decay)
325
-
326
- # apply update
327
- p.data.add_(u, alpha=-adjusted_lr)
328
 
329
  def _update_g(self, p, g, group, momentum):
330
  # calc update
@@ -339,9 +376,8 @@ class Muon(torch.optim.Optimizer):
339
  g = buf
340
  return g
341
 
342
- def _update_p(self, p, u, lr, weight_decay):
343
- # scale update
344
- adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
345
  # apply weight decay
346
  p.data.mul_(1 - lr * weight_decay)
347
  # apply update
@@ -369,28 +405,34 @@ class Muon(torch.optim.Optimizer):
369
  p.grad = g
370
 
371
  param_to_state, ordered_params = self.init_state_and_assign_params(
372
- params, group
373
- )
374
 
375
  def enqueue_gathers(start_idx, chunk_size):
376
- for p in ordered_params[start_idx : start_idx + chunk_size]:
377
  state = param_to_state[id(p)]
378
- _gather(p, state, self.rank, self.comm_stream, group["none_grad"])
 
379
 
380
  def enqueue_computes(start_idx, chunk_size):
381
- for p in ordered_params[start_idx : start_idx + chunk_size]:
382
  state = param_to_state[id(p)]
383
- _compute_u(state, group["ns_steps"], self.rank, self.compute_stream)
 
384
 
385
  def enqueue_scatters(start_idx, chunk_size):
386
- for p in ordered_params[start_idx : start_idx + chunk_size]:
 
 
 
 
 
387
  state = param_to_state[id(p)]
388
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
389
- _scatter(
390
- p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
391
- )
392
 
393
- chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
 
394
 
395
  # Wait grad update
396
  self.comm_stream.wait_stream(torch.cuda.current_stream())
@@ -398,10 +440,14 @@ class Muon(torch.optim.Optimizer):
398
  enqueue_gathers(0, chunk_size)
399
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
400
  enqueue_computes(i, chunk_size)
 
 
401
  enqueue_gathers(i + chunk_size, chunk_size)
402
  enqueue_scatters(i, chunk_size)
 
403
 
404
- torch.cuda.current_stream().wait_stream(self.comm_stream)
 
405
 
406
  def step(self, closure=None):
407
  """Perform a single optimization step.
@@ -436,15 +482,16 @@ class Muon(torch.optim.Optimizer):
436
  continue
437
  if isinstance(p.data, DTensor):
438
  if all(
439
- isinstance(placement, Replicate) for placement in p.placements
440
- ):
441
  param_tensors.append(p)
442
  else:
443
  param_dtensors.append(p)
444
  elif isinstance(p.data, torch.Tensor):
445
  param_tensors.append(p)
446
  else:
447
- raise TypeError(f"Unsupported parameter type: {type(p.data)}")
 
448
 
449
  if self.debug:
450
  print(
@@ -479,7 +526,9 @@ class Muon(torch.optim.Optimizer):
479
  # AdamW backup #
480
  ############################
481
 
482
- params = [p for p in group["params"] if not self.state[p]["use_muon"]]
 
 
483
  lr = group["lr"]
484
  beta1, beta2 = group["adamw_betas"]
485
  eps = group["adamw_eps"]
 
47
  # TODO: use Optional
48
  worker_rank: int | None = None
49
  gathered_grad: torch.Tensor | None = None
50
+ scattered_u: DTensor | None = None
51
  computed_u: torch.Tensor | None = None
52
  gather_event: torch.cuda.Event | None = None
53
  compute_event: torch.cuda.Event | None = None
54
+ scatter_event: torch.cuda.Event | None = None
55
  process_group = None
56
 
57
 
58
  @torch.no_grad()
59
  def _gather(p, state, rank, comm_stream, none_grad):
60
+ """
61
+ Gather the gradients to worker_rank.
62
+ If none_grad is True, free p.grad after the gather.
63
+ """
64
  g = p.grad
65
 
66
  if rank == state.worker_rank:
67
  num_ranks = dist.get_world_size(group=state.process_group)
68
+ gather_list = [
69
+ torch.empty_like(g.to_local()) for _ in range(num_ranks)
70
+ ]
71
  else:
72
  gather_list = None
73
 
 
81
  if rank == state.worker_rank:
82
  if state.gathered_grad is not None:
83
  raise RuntimeError(
84
+ "Gather event already exists, which should not happen.")
 
85
  state.gathered_grad = torch.cat(gather_list, dim=0)
86
  state.gather_event = torch.cuda.Event()
87
  state.gather_event.record()
 
89
  state.gathered_grad = None
90
  state.gather_event = None
91
  if none_grad:
92
+ # We can safely free p.grad without calling record_stream:
93
+ # p.grad.to_local().record_stream(comm_stream)
94
+ # Explanation:
95
+ # 1. p.grad is created on the default stream, but the default stream
96
+ # is synchronized with the comm stream later.
97
+ # 2. There is no further activity on the default stream before the optimizer finishes.
98
+ # Therefore, it is safe to free p.grad directly on the comm stream.
99
  p.grad = None
100
 
101
 
102
  @torch.no_grad()
103
  def _compute_u(state, steps, rank, compute_stream):
104
+ """
105
+ On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
106
+ """
107
  with torch.cuda.stream(compute_stream):
108
  if rank == state.worker_rank:
109
  if state.gather_event is None:
 
113
  state.computed_u = u
114
  state.compute_event = torch.cuda.Event()
115
  state.compute_event.record()
 
 
116
  else:
117
  state.computed_u = None
118
  state.compute_event = None
119
 
120
 
121
  @torch.no_grad()
122
+ def _scatter(p, state, rank, comm_stream):
123
+ """
124
+ Scatter the computed_u from worker_rank to all ranks.
125
+ """
126
 
127
  with torch.cuda.stream(comm_stream):
128
  if rank == state.worker_rank:
 
130
  if state.compute_event is None:
131
  raise RuntimeError("Compute event must be set before scatter.")
132
  comm_stream.wait_event(state.compute_event)
133
+
134
+ # Clear the gathered gradient to free memory
135
+ state.gathered_grad = None
136
+
137
+ u = state.computed_u
138
  scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
139
+ scatter_list = [s.contiguous() for s in scatter_list]
140
  else:
141
  scatter_list = None
142
 
143
+ u_received = torch.empty_like(p.to_local())
144
  torch.distributed.scatter(
145
+ u_received,
146
  scatter_list=scatter_list,
147
  src=state.worker_rank,
148
  group=state.process_group,
149
  )
150
+ u_dtensor = DTensor.from_local(
151
+ u_received,
 
 
 
152
  placements=p.placements,
153
  device_mesh=p.device_mesh,
154
  )
155
+
156
+ state.scattered_u = u_dtensor
157
+ state.scatter_event = torch.cuda.Event()
158
+ state.scatter_event.record()
159
+
160
+
161
+ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
162
+ compute_stream):
163
+ """
164
+ Update sharded parameter p with the scattered_u.
165
+ Only worker_rank frees computed_u.
166
+ """
167
+ with torch.cuda.stream(compute_stream):
168
+ if state.scatter_event is None:
169
+ raise RuntimeError("Scatter event must be set before update")
170
+ compute_stream.wait_event(state.scatter_event)
171
+ if rank == state.worker_rank:
172
+ # Free computed_u
173
+ state.computed_u = None
174
+
175
+ Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
176
 
177
 
178
  def default_is_muon(x, name):
 
193
  - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
194
 
195
  Arguments:
196
+ model: The model to be optimized by Muon.
197
+ is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon.
198
  lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
199
  momentum: The momentum used by the internal SGD. (0.95 is a good default)
200
  nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
201
  ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
202
+ weight_decay: The weight decay for Muon and AdamW.
203
  {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
204
  adamw_lr: The learning rate for the internal AdamW.
205
  adamw_betas: The betas for the internal AdamW.
206
  adamw_eps: The epsilon for the internal AdamW.
207
+ none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
208
+ debug: Whether to print debug information.
209
  """
210
 
211
  def __init__(
 
281
  """
282
  Get the shard mesh for a parameter p on the given rank.
283
  """
284
+ assert isinstance(
285
+ p, DTensor), "Parallel Muon only supports DTensor parameters."
286
 
287
+ if p.placements == (Shard(dim=0), ):
288
  # Case for FSDP
289
  return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
290
  elif p.placements == (Replicate(), Shard(dim=0)):
 
311
  total_flops += flops
312
 
313
  if self.debug:
314
+ print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
315
+ flush=True)
316
 
317
+ ordered_params = sorted(params,
318
+ key=lambda p: param_to_flops[id(p)],
319
+ reverse=True)
320
 
321
  round_robin = 0
322
  mesh = None
 
360
 
361
  u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
362
 
 
363
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
364
+ Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
 
 
 
 
 
365
 
366
  def _update_g(self, p, g, group, momentum):
367
  # calc update
 
376
  g = buf
377
  return g
378
 
379
+ @staticmethod
380
+ def _update_p(p, u, lr, adjusted_lr, weight_decay):
 
381
  # apply weight decay
382
  p.data.mul_(1 - lr * weight_decay)
383
  # apply update
 
405
  p.grad = g
406
 
407
  param_to_state, ordered_params = self.init_state_and_assign_params(
408
+ params, group)
 
409
 
410
  def enqueue_gathers(start_idx, chunk_size):
411
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
412
  state = param_to_state[id(p)]
413
+ _gather(p, state, self.rank, self.comm_stream,
414
+ group["none_grad"])
415
 
416
  def enqueue_computes(start_idx, chunk_size):
417
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
418
  state = param_to_state[id(p)]
419
+ _compute_u(state, group["ns_steps"], self.rank,
420
+ self.compute_stream)
421
 
422
  def enqueue_scatters(start_idx, chunk_size):
423
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
424
+ state = param_to_state[id(p)]
425
+ _scatter(p, state, self.rank, self.comm_stream)
426
+
427
+ def enqueue_update_param(start_idx, chunk_size):
428
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
429
  state = param_to_state[id(p)]
430
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
431
+ _update_param(p, state, lr, adjusted_lr, weight_decay,
432
+ self.rank, self.compute_stream)
 
433
 
434
+ chunk_size = dist.get_world_size(param_to_state[id(
435
+ params[0])].process_group)
436
 
437
  # Wait grad update
438
  self.comm_stream.wait_stream(torch.cuda.current_stream())
 
440
  enqueue_gathers(0, chunk_size)
441
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
442
  enqueue_computes(i, chunk_size)
443
+ if i > 0:
444
+ enqueue_update_param(i - chunk_size, chunk_size)
445
  enqueue_gathers(i + chunk_size, chunk_size)
446
  enqueue_scatters(i, chunk_size)
447
+ enqueue_update_param(i, chunk_size)
448
 
449
+ # Wait for the last update_param to finish
450
+ torch.cuda.current_stream().wait_stream(self.compute_stream)
451
 
452
  def step(self, closure=None):
453
  """Perform a single optimization step.
 
482
  continue
483
  if isinstance(p.data, DTensor):
484
  if all(
485
+ isinstance(placement, Replicate)
486
+ for placement in p.placements):
487
  param_tensors.append(p)
488
  else:
489
  param_dtensors.append(p)
490
  elif isinstance(p.data, torch.Tensor):
491
  param_tensors.append(p)
492
  else:
493
+ raise TypeError(
494
+ f"Unsupported parameter type: {type(p.data)}")
495
 
496
  if self.debug:
497
  print(
 
526
  # AdamW backup #
527
  ############################
528
 
529
+ params = [
530
+ p for p in group["params"] if not self.state[p]["use_muon"]
531
+ ]
532
  lr = group["lr"]
533
  beta1, beta2 = group["adamw_betas"]
534
  eps = group["adamw_eps"]
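Note on the stream handoff between _scatter and _update_param shown above: the comm stream records a scatter_event once the received shard is wrapped, and the compute stream waits on that event before applying the update, so only the producer/consumer dependency is synchronized instead of the whole stream. A minimal sketch of the pattern, assuming a CUDA device; tensor and variable names are illustrative, not from this repo.

import torch

assert torch.cuda.is_available()

comm_stream = torch.cuda.Stream()
compute_stream = torch.cuda.Stream()

p = torch.randn(1024, 1024, device="cuda")
scatter_done = torch.cuda.Event()

with torch.cuda.stream(comm_stream):
    # Stand-in for the scatter: produce the update shard on the comm stream
    # and record an event so consumers can order against it.
    u = torch.randn_like(p)
    scatter_done.record()

with torch.cuda.stream(compute_stream):
    # Stand-in for the parameter update: wait on the event (not a full
    # stream sync) before touching the tensor produced on the comm stream.
    compute_stream.wait_event(scatter_done)
    p.mul_(1 - 0.02 * 0.1).add_(u, alpha=-0.01)

# Finally order the default stream after the compute stream, mirroring the
# wait at the end of the parallel path.
torch.cuda.current_stream().wait_stream(compute_stream)
torch.cuda.synchronize()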
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc DELETED
Binary file (308 Bytes)
 
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc DELETED
Binary file (23.4 kB)
 
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_2dc97a1_dirty
3
- ops = torch.ops._optimizer_2dc97a1_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_2dc97a1_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_0c12ced_dirty
3
+ ops = torch.ops._optimizer_0c12ced_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_0c12ced_dirty::{op_name}"
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_2dc97a1_dirty.abi3.so → _optimizer_0c12ced_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:76910ba81e2c95c83207118725c4379db636346c4ccf05010e2ee00c41dff1ce
3
  size 1750000
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2dd72f3b9f513dc8bd0724fede9b668761b1d701dfdf3a294979706d803b0800
3
  size 1750000
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py CHANGED
@@ -47,19 +47,27 @@ class _muon_state:
47
  # TODO: use Optional
48
  worker_rank: int | None = None
49
  gathered_grad: torch.Tensor | None = None
 
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
 
53
  process_group = None
54
 
55
 
56
  @torch.no_grad()
57
  def _gather(p, state, rank, comm_stream, none_grad):
 
 
 
 
58
  g = p.grad
59
 
60
  if rank == state.worker_rank:
61
  num_ranks = dist.get_world_size(group=state.process_group)
62
- gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
 
 
63
  else:
64
  gather_list = None
65
 
@@ -73,8 +81,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
73
  if rank == state.worker_rank:
74
  if state.gathered_grad is not None:
75
  raise RuntimeError(
76
- "Gather event already exists, which should not happen."
77
- )
78
  state.gathered_grad = torch.cat(gather_list, dim=0)
79
  state.gather_event = torch.cuda.Event()
80
  state.gather_event.record()
@@ -82,11 +89,21 @@ def _gather(p, state, rank, comm_stream, none_grad):
82
  state.gathered_grad = None
83
  state.gather_event = None
84
  if none_grad:
 
 
 
 
 
 
 
85
  p.grad = None
86
 
87
 
88
  @torch.no_grad()
89
  def _compute_u(state, steps, rank, compute_stream):
 
 
 
90
  with torch.cuda.stream(compute_stream):
91
  if rank == state.worker_rank:
92
  if state.gather_event is None:
@@ -96,16 +113,16 @@ def _compute_u(state, steps, rank, compute_stream):
96
  state.computed_u = u
97
  state.compute_event = torch.cuda.Event()
98
  state.compute_event.record()
99
- # Clear the gathered gradient to free memory
100
- state.gathered_grad = None
101
  else:
102
  state.computed_u = None
103
  state.compute_event = None
104
 
105
 
106
  @torch.no_grad()
107
- def _scatter(p, state, lr, weight_decay, rank, comm_stream):
108
- u = state.computed_u
 
 
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
@@ -113,27 +130,49 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
113
  if state.compute_event is None:
114
  raise RuntimeError("Compute event must be set before scatter.")
115
  comm_stream.wait_event(state.compute_event)
 
 
 
 
 
116
  scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
 
117
  else:
118
  scatter_list = None
119
 
120
- u = torch.empty_like(p.to_local())
121
  torch.distributed.scatter(
122
- u,
123
  scatter_list=scatter_list,
124
  src=state.worker_rank,
125
  group=state.process_group,
126
  )
127
- if rank == state.worker_rank:
128
- # Clear u to free memory
129
- state.computed_u = None
130
- u = DTensor.from_local(
131
- u,
132
  placements=p.placements,
133
  device_mesh=p.device_mesh,
134
  )
135
- p.data.mul_(1 - lr * weight_decay)
136
- p.data.add_(u, alpha=-lr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
 
139
  def default_is_muon(x, name):
@@ -154,17 +193,19 @@ class Muon(torch.optim.Optimizer):
154
  - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
155
 
156
  Arguments:
157
- muon_params: The parameters to be optimized by Muon.
 
158
  lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
159
  momentum: The momentum used by the internal SGD. (0.95 is a good default)
160
  nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
161
  ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
162
- adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are
163
  {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
164
  adamw_lr: The learning rate for the internal AdamW.
165
  adamw_betas: The betas for the internal AdamW.
166
  adamw_eps: The epsilon for the internal AdamW.
167
- adamw_weight_decay: The weight decay for the internal AdamW.
 
168
  """
169
 
170
  def __init__(
@@ -240,9 +281,10 @@ class Muon(torch.optim.Optimizer):
240
  """
241
  Get the shard mesh for a parameter p on the given rank.
242
  """
243
- assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
 
244
 
245
- if p.placements == (Shard(dim=0),):
246
  # Case for FSDP
247
  return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
248
  elif p.placements == (Replicate(), Shard(dim=0)):
@@ -269,11 +311,12 @@ class Muon(torch.optim.Optimizer):
269
  total_flops += flops
270
 
271
  if self.debug:
272
- print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
 
273
 
274
- ordered_params = sorted(
275
- params, key=lambda p: param_to_flops[id(p)], reverse=True
276
- )
277
 
278
  round_robin = 0
279
  mesh = None
@@ -317,14 +360,8 @@ class Muon(torch.optim.Optimizer):
317
 
318
  u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
319
 
320
- # scale update
321
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
322
-
323
- # apply weight decay
324
- p.data.mul_(1 - lr * weight_decay)
325
-
326
- # apply update
327
- p.data.add_(u, alpha=-adjusted_lr)
328
 
329
  def _update_g(self, p, g, group, momentum):
330
  # calc update
@@ -339,9 +376,8 @@ class Muon(torch.optim.Optimizer):
339
  g = buf
340
  return g
341
 
342
- def _update_p(self, p, u, lr, weight_decay):
343
- # scale update
344
- adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
345
  # apply weight decay
346
  p.data.mul_(1 - lr * weight_decay)
347
  # apply update
@@ -369,28 +405,34 @@ class Muon(torch.optim.Optimizer):
369
  p.grad = g
370
 
371
  param_to_state, ordered_params = self.init_state_and_assign_params(
372
- params, group
373
- )
374
 
375
  def enqueue_gathers(start_idx, chunk_size):
376
- for p in ordered_params[start_idx : start_idx + chunk_size]:
377
  state = param_to_state[id(p)]
378
- _gather(p, state, self.rank, self.comm_stream, group["none_grad"])
 
379
 
380
  def enqueue_computes(start_idx, chunk_size):
381
- for p in ordered_params[start_idx : start_idx + chunk_size]:
382
  state = param_to_state[id(p)]
383
- _compute_u(state, group["ns_steps"], self.rank, self.compute_stream)
 
384
 
385
  def enqueue_scatters(start_idx, chunk_size):
386
- for p in ordered_params[start_idx : start_idx + chunk_size]:
 
 
 
 
 
387
  state = param_to_state[id(p)]
388
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
389
- _scatter(
390
- p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
391
- )
392
 
393
- chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
 
394
 
395
  # Wait grad update
396
  self.comm_stream.wait_stream(torch.cuda.current_stream())
@@ -398,10 +440,14 @@ class Muon(torch.optim.Optimizer):
398
  enqueue_gathers(0, chunk_size)
399
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
400
  enqueue_computes(i, chunk_size)
 
 
401
  enqueue_gathers(i + chunk_size, chunk_size)
402
  enqueue_scatters(i, chunk_size)
 
403
 
404
- torch.cuda.current_stream().wait_stream(self.comm_stream)
 
405
 
406
  def step(self, closure=None):
407
  """Perform a single optimization step.
@@ -436,15 +482,16 @@ class Muon(torch.optim.Optimizer):
436
  continue
437
  if isinstance(p.data, DTensor):
438
  if all(
439
- isinstance(placement, Replicate) for placement in p.placements
440
- ):
441
  param_tensors.append(p)
442
  else:
443
  param_dtensors.append(p)
444
  elif isinstance(p.data, torch.Tensor):
445
  param_tensors.append(p)
446
  else:
447
- raise TypeError(f"Unsupported parameter type: {type(p.data)}")
 
448
 
449
  if self.debug:
450
  print(
@@ -479,7 +526,9 @@ class Muon(torch.optim.Optimizer):
479
  # AdamW backup #
480
  ############################
481
 
482
- params = [p for p in group["params"] if not self.state[p]["use_muon"]]
 
 
483
  lr = group["lr"]
484
  beta1, beta2 = group["adamw_betas"]
485
  eps = group["adamw_eps"]
 
47
  # TODO: use Optional
48
  worker_rank: int | None = None
49
  gathered_grad: torch.Tensor | None = None
50
+ scattered_u: DTensor | None = None
51
  computed_u: torch.Tensor | None = None
52
  gather_event: torch.cuda.Event | None = None
53
  compute_event: torch.cuda.Event | None = None
54
+ scatter_event: torch.cuda.Event | None = None
55
  process_group = None
56
 
57
 
58
  @torch.no_grad()
59
  def _gather(p, state, rank, comm_stream, none_grad):
60
+ """
61
+ Gather the gradients to worker_rank.
62
+ If none_grad is True, free p.grad after the gather.
63
+ """
64
  g = p.grad
65
 
66
  if rank == state.worker_rank:
67
  num_ranks = dist.get_world_size(group=state.process_group)
68
+ gather_list = [
69
+ torch.empty_like(g.to_local()) for _ in range(num_ranks)
70
+ ]
71
  else:
72
  gather_list = None
73
 
 
81
  if rank == state.worker_rank:
82
  if state.gathered_grad is not None:
83
  raise RuntimeError(
84
+ "Gather event already exists, which should not happen.")
 
85
  state.gathered_grad = torch.cat(gather_list, dim=0)
86
  state.gather_event = torch.cuda.Event()
87
  state.gather_event.record()
 
89
  state.gathered_grad = None
90
  state.gather_event = None
91
  if none_grad:
92
+ # We can safely free p.grad without calling record_stream:
93
+ # p.grad.to_local().record_stream(comm_stream)
94
+ # Explanation:
95
+ # 1. p.grad is created on the default stream, but the default stream
96
+ # is synchronized with the comm stream later.
97
+ # 2. There is no further activity on the default stream before the optimizer finishes.
98
+ # Therefore, it is safe to free p.grad directly on the comm stream.
99
  p.grad = None
100
 
101
 
102
  @torch.no_grad()
103
  def _compute_u(state, steps, rank, compute_stream):
104
+ """
105
+ On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
106
+ """
107
  with torch.cuda.stream(compute_stream):
108
  if rank == state.worker_rank:
109
  if state.gather_event is None:
 
113
  state.computed_u = u
114
  state.compute_event = torch.cuda.Event()
115
  state.compute_event.record()
 
 
116
  else:
117
  state.computed_u = None
118
  state.compute_event = None
119
 
120
 
121
  @torch.no_grad()
122
+ def _scatter(p, state, rank, comm_stream):
123
+ """
124
+ Scatter the computed_u from worker_rank to all ranks.
125
+ """
126
 
127
  with torch.cuda.stream(comm_stream):
128
  if rank == state.worker_rank:
 
130
  if state.compute_event is None:
131
  raise RuntimeError("Compute event must be set before scatter.")
132
  comm_stream.wait_event(state.compute_event)
133
+
134
+ # Clear the gathered gradient to free memory
135
+ state.gathered_grad = None
136
+
137
+ u = state.computed_u
138
  scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
139
+ scatter_list = [s.contiguous() for s in scatter_list]
140
  else:
141
  scatter_list = None
142
 
143
+ u_received = torch.empty_like(p.to_local())
144
  torch.distributed.scatter(
145
+ u_received,
146
  scatter_list=scatter_list,
147
  src=state.worker_rank,
148
  group=state.process_group,
149
  )
150
+ u_dtensor = DTensor.from_local(
151
+ u_received,
 
 
 
152
  placements=p.placements,
153
  device_mesh=p.device_mesh,
154
  )
155
+
156
+ state.scattered_u = u_dtensor
157
+ state.scatter_event = torch.cuda.Event()
158
+ state.scatter_event.record()
159
+
160
+
161
+ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
162
+ compute_stream):
163
+ """
164
+ Update sharded parameter p with the scattered_u.
165
+ Only worker_rank frees computed_u.
166
+ """
167
+ with torch.cuda.stream(compute_stream):
168
+ if state.scatter_event is None:
169
+ raise RuntimeError("Scatter event must be set before update")
170
+ compute_stream.wait_event(state.scatter_event)
171
+ if rank == state.worker_rank:
172
+ # Free computed_u
173
+ state.computed_u = None
174
+
175
+ Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
176
 
177
 
178
  def default_is_muon(x, name):
 
193
  - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
194
 
195
  Arguments:
196
+ model: The model to be optimized by Muon.
197
+ is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon.
198
  lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
199
  momentum: The momentum used by the internal SGD. (0.95 is a good default)
200
  nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
201
  ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
202
+ weight_decay: The weight decay for Muon and AdamW.
203
  {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
204
  adamw_lr: The learning rate for the internal AdamW.
205
  adamw_betas: The betas for the internal AdamW.
206
  adamw_eps: The epsilon for the internal AdamW.
207
+ none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
208
+ debug: Whether to print debug information.
209
  """
210
 
211
  def __init__(
 
281
  """
282
  Get the shard mesh for a parameter p on the given rank.
283
  """
284
+ assert isinstance(
285
+ p, DTensor), "Parallel Muon only supports DTensor parameters."
286
 
287
+ if p.placements == (Shard(dim=0), ):
288
  # Case for FSDP
289
  return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
290
  elif p.placements == (Replicate(), Shard(dim=0)):
 
311
  total_flops += flops
312
 
313
  if self.debug:
314
+ print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
315
+ flush=True)
316
 
317
+ ordered_params = sorted(params,
318
+ key=lambda p: param_to_flops[id(p)],
319
+ reverse=True)
320
 
321
  round_robin = 0
322
  mesh = None
 
360
 
361
  u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
362
 
 
363
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
364
+ Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
 
 
 
 
 
365
 
366
  def _update_g(self, p, g, group, momentum):
367
  # calc update
 
376
  g = buf
377
  return g
378
 
379
+ @staticmethod
380
+ def _update_p(p, u, lr, adjusted_lr, weight_decay):
 
381
  # apply weight decay
382
  p.data.mul_(1 - lr * weight_decay)
383
  # apply update
 
405
  p.grad = g
406
 
407
  param_to_state, ordered_params = self.init_state_and_assign_params(
408
+ params, group)
 
409
 
410
  def enqueue_gathers(start_idx, chunk_size):
411
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
412
  state = param_to_state[id(p)]
413
+ _gather(p, state, self.rank, self.comm_stream,
414
+ group["none_grad"])
415
 
416
  def enqueue_computes(start_idx, chunk_size):
417
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
418
  state = param_to_state[id(p)]
419
+ _compute_u(state, group["ns_steps"], self.rank,
420
+ self.compute_stream)
421
 
422
  def enqueue_scatters(start_idx, chunk_size):
423
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
424
+ state = param_to_state[id(p)]
425
+ _scatter(p, state, self.rank, self.comm_stream)
426
+
427
+ def enqueue_update_param(start_idx, chunk_size):
428
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
429
  state = param_to_state[id(p)]
430
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
431
+ _update_param(p, state, lr, adjusted_lr, weight_decay,
432
+ self.rank, self.compute_stream)
 
433
 
434
+ chunk_size = dist.get_world_size(param_to_state[id(
435
+ params[0])].process_group)
436
 
437
  # Wait grad update
438
  self.comm_stream.wait_stream(torch.cuda.current_stream())
 
440
  enqueue_gathers(0, chunk_size)
441
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
442
  enqueue_computes(i, chunk_size)
443
+ if i > 0:
444
+ enqueue_update_param(i - chunk_size, chunk_size)
445
  enqueue_gathers(i + chunk_size, chunk_size)
446
  enqueue_scatters(i, chunk_size)
447
+ enqueue_update_param(i, chunk_size)
448
 
449
+ # Wait for the last update_param to finish
450
+ torch.cuda.current_stream().wait_stream(self.compute_stream)
451
 
452
  def step(self, closure=None):
453
  """Perform a single optimization step.
 
482
  continue
483
  if isinstance(p.data, DTensor):
484
  if all(
485
+ isinstance(placement, Replicate)
486
+ for placement in p.placements):
487
  param_tensors.append(p)
488
  else:
489
  param_dtensors.append(p)
490
  elif isinstance(p.data, torch.Tensor):
491
  param_tensors.append(p)
492
  else:
493
+ raise TypeError(
494
+ f"Unsupported parameter type: {type(p.data)}")
495
 
496
  if self.debug:
497
  print(
 
526
  # AdamW backup #
527
  ############################
528
 
529
+ params = [
530
+ p for p in group["params"] if not self.state[p]["use_muon"]
531
+ ]
532
  lr = group["lr"]
533
  beta1, beta2 = group["adamw_betas"]
534
  eps = group["adamw_eps"]
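Note on the scatter_list construction above: on the worker rank the orthogonalized update is split row-wise into one slice per rank and each slice is made contiguous before torch.distributed.scatter. A minimal single-process sketch of that slicing; the shapes and rank count are illustrative.

import torch

num_ranks = 4
# Stand-in for the gathered, orthogonalized update on the worker rank.
full_u = torch.randn(16, 8)

# One row slice per rank, matching a dim-0 sharding of the parameter.
scatter_list = list(torch.split(full_u, full_u.size(0) // num_ranks, dim=0))
# torch.split returns views; collectives expect contiguous buffers,
# hence the explicit .contiguous() pass before the scatter call.
scatter_list = [s.contiguous() for s in scatter_list]

assert len(scatter_list) == num_ranks
assert all(s.shape == (4, 8) for s in scatter_list)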
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/__init__.cpython-313.pyc DELETED
Binary file (308 Bytes)
 
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__pycache__/muon.cpython-313.pyc DELETED
Binary file (23.4 kB)
 
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_2dc97a1_dirty
3
- ops = torch.ops._optimizer_2dc97a1_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_2dc97a1_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_0c12ced_dirty
3
+ ops = torch.ops._optimizer_0c12ced_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_0c12ced_dirty::{op_name}"
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/{_optimizer_2dc97a1_dirty.abi3.so → _optimizer_0c12ced_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:dd0a35a6f846a075a8f4561cfc66ef17c6358dd4a0062e63057b02625d9d6af7
3
  size 1750088
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a49b0225ecf27b33bbbe55936811ecf443ce97be97ccb7237b3b66eb46c0ad8
3
  size 1750088
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py CHANGED
@@ -47,19 +47,27 @@ class _muon_state:
47
  # TODO: use Optional
48
  worker_rank: int | None = None
49
  gathered_grad: torch.Tensor | None = None
 
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
 
53
  process_group = None
54
 
55
 
56
  @torch.no_grad()
57
  def _gather(p, state, rank, comm_stream, none_grad):
 
 
 
 
58
  g = p.grad
59
 
60
  if rank == state.worker_rank:
61
  num_ranks = dist.get_world_size(group=state.process_group)
62
- gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
 
 
63
  else:
64
  gather_list = None
65
 
@@ -73,8 +81,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
73
  if rank == state.worker_rank:
74
  if state.gathered_grad is not None:
75
  raise RuntimeError(
76
- "Gather event already exists, which should not happen."
77
- )
78
  state.gathered_grad = torch.cat(gather_list, dim=0)
79
  state.gather_event = torch.cuda.Event()
80
  state.gather_event.record()
@@ -82,11 +89,21 @@ def _gather(p, state, rank, comm_stream, none_grad):
82
  state.gathered_grad = None
83
  state.gather_event = None
84
  if none_grad:
 
 
 
 
 
 
 
85
  p.grad = None
86
 
87
 
88
  @torch.no_grad()
89
  def _compute_u(state, steps, rank, compute_stream):
 
 
 
90
  with torch.cuda.stream(compute_stream):
91
  if rank == state.worker_rank:
92
  if state.gather_event is None:
@@ -96,16 +113,16 @@ def _compute_u(state, steps, rank, compute_stream):
96
  state.computed_u = u
97
  state.compute_event = torch.cuda.Event()
98
  state.compute_event.record()
99
- # Clear the gathered gradient to free memory
100
- state.gathered_grad = None
101
  else:
102
  state.computed_u = None
103
  state.compute_event = None
104
 
105
 
106
  @torch.no_grad()
107
- def _scatter(p, state, lr, weight_decay, rank, comm_stream):
108
- u = state.computed_u
 
 
109
 
110
  with torch.cuda.stream(comm_stream):
111
  if rank == state.worker_rank:
@@ -113,27 +130,49 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
113
  if state.compute_event is None:
114
  raise RuntimeError("Compute event must be set before scatter.")
115
  comm_stream.wait_event(state.compute_event)
 
 
 
 
 
116
  scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
 
117
  else:
118
  scatter_list = None
119
 
120
- u = torch.empty_like(p.to_local())
121
  torch.distributed.scatter(
122
- u,
123
  scatter_list=scatter_list,
124
  src=state.worker_rank,
125
  group=state.process_group,
126
  )
127
- if rank == state.worker_rank:
128
- # Clear u to free memory
129
- state.computed_u = None
130
- u = DTensor.from_local(
131
- u,
132
  placements=p.placements,
133
  device_mesh=p.device_mesh,
134
  )
135
- p.data.mul_(1 - lr * weight_decay)
136
- p.data.add_(u, alpha=-lr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
 
139
  def default_is_muon(x, name):
@@ -154,17 +193,19 @@ class Muon(torch.optim.Optimizer):
154
  - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
155
 
156
  Arguments:
157
- muon_params: The parameters to be optimized by Muon.
 
158
  lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
159
  momentum: The momentum used by the internal SGD. (0.95 is a good default)
160
  nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
161
  ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
162
- adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are
163
  {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
164
  adamw_lr: The learning rate for the internal AdamW.
165
  adamw_betas: The betas for the internal AdamW.
166
  adamw_eps: The epsilon for the internal AdamW.
167
- adamw_weight_decay: The weight decay for the internal AdamW.
 
168
  """
169
 
170
  def __init__(
@@ -240,9 +281,10 @@ class Muon(torch.optim.Optimizer):
240
  """
241
  Get the shard mesh for a parameter p on the given rank.
242
  """
243
- assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
 
244
 
245
- if p.placements == (Shard(dim=0),):
246
  # Case for FSDP
247
  return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
248
  elif p.placements == (Replicate(), Shard(dim=0)):
@@ -269,11 +311,12 @@ class Muon(torch.optim.Optimizer):
269
  total_flops += flops
270
 
271
  if self.debug:
272
- print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
 
273
 
274
- ordered_params = sorted(
275
- params, key=lambda p: param_to_flops[id(p)], reverse=True
276
- )
277
 
278
  round_robin = 0
279
  mesh = None
@@ -317,14 +360,8 @@ class Muon(torch.optim.Optimizer):
317
 
318
  u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
319
 
320
- # scale update
321
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
322
-
323
- # apply weight decay
324
- p.data.mul_(1 - lr * weight_decay)
325
-
326
- # apply update
327
- p.data.add_(u, alpha=-adjusted_lr)
328
 
329
  def _update_g(self, p, g, group, momentum):
330
  # calc update
@@ -339,9 +376,8 @@ class Muon(torch.optim.Optimizer):
339
  g = buf
340
  return g
341
 
342
- def _update_p(self, p, u, lr, weight_decay):
343
- # scale update
344
- adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
345
  # apply weight decay
346
  p.data.mul_(1 - lr * weight_decay)
347
  # apply update
@@ -369,28 +405,34 @@ class Muon(torch.optim.Optimizer):
369
  p.grad = g
370
 
371
  param_to_state, ordered_params = self.init_state_and_assign_params(
372
- params, group
373
- )
374
 
375
  def enqueue_gathers(start_idx, chunk_size):
376
- for p in ordered_params[start_idx : start_idx + chunk_size]:
377
  state = param_to_state[id(p)]
378
- _gather(p, state, self.rank, self.comm_stream, group["none_grad"])
 
379
 
380
  def enqueue_computes(start_idx, chunk_size):
381
- for p in ordered_params[start_idx : start_idx + chunk_size]:
382
  state = param_to_state[id(p)]
383
- _compute_u(state, group["ns_steps"], self.rank, self.compute_stream)
 
384
 
385
  def enqueue_scatters(start_idx, chunk_size):
386
- for p in ordered_params[start_idx : start_idx + chunk_size]:
 
 
 
 
 
387
  state = param_to_state[id(p)]
388
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
389
- _scatter(
390
- p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
391
- )
392
 
393
- chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
 
394
 
395
  # Wait grad update
396
  self.comm_stream.wait_stream(torch.cuda.current_stream())
@@ -398,10 +440,14 @@ class Muon(torch.optim.Optimizer):
398
  enqueue_gathers(0, chunk_size)
399
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
400
  enqueue_computes(i, chunk_size)
 
 
401
  enqueue_gathers(i + chunk_size, chunk_size)
402
  enqueue_scatters(i, chunk_size)
 
403
 
404
- torch.cuda.current_stream().wait_stream(self.comm_stream)
 
405
 
406
  def step(self, closure=None):
407
  """Perform a single optimization step.
@@ -436,15 +482,16 @@ class Muon(torch.optim.Optimizer):
436
  continue
437
  if isinstance(p.data, DTensor):
438
  if all(
439
- isinstance(placement, Replicate) for placement in p.placements
440
- ):
441
  param_tensors.append(p)
442
  else:
443
  param_dtensors.append(p)
444
  elif isinstance(p.data, torch.Tensor):
445
  param_tensors.append(p)
446
  else:
447
- raise TypeError(f"Unsupported parameter type: {type(p.data)}")
 
448
 
449
  if self.debug:
450
  print(
@@ -479,7 +526,9 @@ class Muon(torch.optim.Optimizer):
479
  # AdamW backup #
480
  ############################
481
 
482
- params = [p for p in group["params"] if not self.state[p]["use_muon"]]
 
 
483
  lr = group["lr"]
484
  beta1, beta2 = group["adamw_betas"]
485
  eps = group["adamw_eps"]
 
47
  # TODO: use Optional
48
  worker_rank: int | None = None
49
  gathered_grad: torch.Tensor | None = None
50
+ scattered_u: DTensor | None = None
51
  computed_u: torch.Tensor | None = None
52
  gather_event: torch.cuda.Event | None = None
53
  compute_event: torch.cuda.Event | None = None
54
+ scatter_event: torch.cuda.Event | None = None
55
  process_group = None
56
 
57
 
58
  @torch.no_grad()
59
  def _gather(p, state, rank, comm_stream, none_grad):
60
+ """
61
+ Gather the gradients to worker_rank.
62
+ If none_grad is True, free p.grad after the gather.
63
+ """
64
  g = p.grad
65
 
66
  if rank == state.worker_rank:
67
  num_ranks = dist.get_world_size(group=state.process_group)
68
+ gather_list = [
69
+ torch.empty_like(g.to_local()) for _ in range(num_ranks)
70
+ ]
71
  else:
72
  gather_list = None
73
 
 
81
  if rank == state.worker_rank:
82
  if state.gathered_grad is not None:
83
  raise RuntimeError(
84
+ "Gather event already exists, which should not happen.")
 
85
  state.gathered_grad = torch.cat(gather_list, dim=0)
86
  state.gather_event = torch.cuda.Event()
87
  state.gather_event.record()
 
89
  state.gathered_grad = None
90
  state.gather_event = None
91
  if none_grad:
92
+ # We can safely free p.grad without calling record_stream:
93
+ # p.grad.to_local().record_stream(comm_stream)
94
+ # Explanation:
95
+ # 1. p.grad is created on the default stream, but the default stream
96
+ # is synchronized with the comm stream later.
97
+ # 2. There is no further activity on the default stream before the optimizer finishes.
98
+ # Therefore, it is safe to free p.grad directly on the comm stream.
99
  p.grad = None
100
 
101
 
102
  @torch.no_grad()
103
  def _compute_u(state, steps, rank, compute_stream):
104
+ """
105
+ On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
106
+ """
107
  with torch.cuda.stream(compute_stream):
108
  if rank == state.worker_rank:
109
  if state.gather_event is None:
 
113
  state.computed_u = u
114
  state.compute_event = torch.cuda.Event()
115
  state.compute_event.record()
 
 
116
  else:
117
  state.computed_u = None
118
  state.compute_event = None
119
 
120
 
121
  @torch.no_grad()
122
+ def _scatter(p, state, rank, comm_stream):
123
+ """
124
+ Scatter the computed_u from worker_rank to all ranks.
125
+ """
126
 
127
  with torch.cuda.stream(comm_stream):
128
  if rank == state.worker_rank:
 
130
  if state.compute_event is None:
131
  raise RuntimeError("Compute event must be set before scatter.")
132
  comm_stream.wait_event(state.compute_event)
133
+
134
+ # Clear the gathered gradient to free memory
135
+ state.gathered_grad = None
136
+
137
+ u = state.computed_u
138
  scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
139
+ scatter_list = [s.contiguous() for s in scatter_list]
140
  else:
141
  scatter_list = None
142
 
143
+ u_received = torch.empty_like(p.to_local())
144
  torch.distributed.scatter(
145
+ u_received,
146
  scatter_list=scatter_list,
147
  src=state.worker_rank,
148
  group=state.process_group,
149
  )
150
+ u_dtensor = DTensor.from_local(
151
+ u_received,
 
 
 
152
  placements=p.placements,
153
  device_mesh=p.device_mesh,
154
  )
155
+
156
+ state.scattered_u = u_dtensor
157
+ state.scatter_event = torch.cuda.Event()
158
+ state.scatter_event.record()
159
+
160
+
161
+ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
162
+ compute_stream):
163
+ """
164
+ Update sharded parameter p with the scattered_u.
165
+ Only worker_rank frees computed_u.
166
+ """
167
+ with torch.cuda.stream(compute_stream):
168
+ if state.scatter_event is None:
169
+ raise RuntimeError("Scatter event must be set before update")
170
+ compute_stream.wait_event(state.scatter_event)
171
+ if rank == state.worker_rank:
172
+ # Free computed_u
173
+ state.computed_u = None
174
+
175
+ Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
176
 
177
 
178
  def default_is_muon(x, name):
 
193
  - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
194
 
195
  Arguments:
196
+ model: The model to be optimized by Muon.
197
+ is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon.
198
  lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
199
  momentum: The momentum used by the internal SGD. (0.95 is a good default)
200
  nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
201
  ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
202
+ weight_decay: The weight decay for Muon and AdamW.
203
  {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
204
  adamw_lr: The learning rate for the internal AdamW.
205
  adamw_betas: The betas for the internal AdamW.
206
  adamw_eps: The epsilon for the internal AdamW.
207
+ none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
208
+ debug: Whether to print debug information.
209
  """
210
 
211
  def __init__(
 
281
  """
282
  Get the shard mesh for a parameter p on the given rank.
283
  """
284
+ assert isinstance(
285
+ p, DTensor), "Parallel Muon only supports DTensor parameters."
286
 
287
+ if p.placements == (Shard(dim=0), ):
288
  # Case for FSDP
289
  return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
290
  elif p.placements == (Replicate(), Shard(dim=0)):
 
311
  total_flops += flops
312
 
313
  if self.debug:
314
+ print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
315
+ flush=True)
316
 
317
+ ordered_params = sorted(params,
318
+ key=lambda p: param_to_flops[id(p)],
319
+ reverse=True)
320
 
321
  round_robin = 0
322
  mesh = None
 
360
 
361
  u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
362
 
 
363
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
364
+ Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
 
 
 
 
 
365
 
366
  def _update_g(self, p, g, group, momentum):
367
  # calc update
 
376
  g = buf
377
  return g
378
 
379
+ @staticmethod
380
+ def _update_p(p, u, lr, adjusted_lr, weight_decay):
 
381
  # apply weight decay
382
  p.data.mul_(1 - lr * weight_decay)
383
  # apply update
 
405
  p.grad = g
406
 
407
  param_to_state, ordered_params = self.init_state_and_assign_params(
408
+ params, group)
 
409
 
410
  def enqueue_gathers(start_idx, chunk_size):
411
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
412
  state = param_to_state[id(p)]
413
+ _gather(p, state, self.rank, self.comm_stream,
414
+ group["none_grad"])
415
 
416
  def enqueue_computes(start_idx, chunk_size):
417
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
418
  state = param_to_state[id(p)]
419
+ _compute_u(state, group["ns_steps"], self.rank,
420
+ self.compute_stream)
421
 
422
  def enqueue_scatters(start_idx, chunk_size):
423
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
424
+ state = param_to_state[id(p)]
425
+ _scatter(p, state, self.rank, self.comm_stream)
426
+
427
+ def enqueue_update_param(start_idx, chunk_size):
428
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
429
  state = param_to_state[id(p)]
430
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
431
+ _update_param(p, state, lr, adjusted_lr, weight_decay,
432
+ self.rank, self.compute_stream)
 
433
 
434
+ chunk_size = dist.get_world_size(param_to_state[id(
435
+ params[0])].process_group)
436
 
437
  # Wait grad update
438
  self.comm_stream.wait_stream(torch.cuda.current_stream())
 
440
  enqueue_gathers(0, chunk_size)
441
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
442
  enqueue_computes(i, chunk_size)
443
+ if i > 0:
444
+ enqueue_update_param(i - chunk_size, chunk_size)
445
  enqueue_gathers(i + chunk_size, chunk_size)
446
  enqueue_scatters(i, chunk_size)
447
+ enqueue_update_param(i, chunk_size)
448
 
449
+ # Wait for the last update_param to finish
450
+ torch.cuda.current_stream().wait_stream(self.compute_stream)
451
 
452
  def step(self, closure=None):
453
  """Perform a single optimization step.
 
482
  continue
483
  if isinstance(p.data, DTensor):
484
  if all(
485
+ isinstance(placement, Replicate)
486
+ for placement in p.placements):
487
  param_tensors.append(p)
488
  else:
489
  param_dtensors.append(p)
490
  elif isinstance(p.data, torch.Tensor):
491
  param_tensors.append(p)
492
  else:
493
+ raise TypeError(
494
+ f"Unsupported parameter type: {type(p.data)}")
495
 
496
  if self.debug:
497
  print(
 
526
  # AdamW backup #
527
  ############################
528
 
529
+ params = [
530
+ p for p in group["params"] if not self.state[p]["use_muon"]
531
+ ]
532
  lr = group["lr"]
533
  beta1, beta2 = group["adamw_betas"]
534
  eps = group["adamw_eps"]
torch-ext/optimizer/muon.py CHANGED
@@ -47,14 +47,20 @@ class _muon_state:
47
  # TODO: use Optional
48
  worker_rank: int | None = None
49
  gathered_grad: torch.Tensor | None = None
 
50
  computed_u: torch.Tensor | None = None
51
  gather_event: torch.cuda.Event | None = None
52
  compute_event: torch.cuda.Event | None = None
 
53
  process_group = None
54
 
55
 
56
  @torch.no_grad()
57
  def _gather(p, state, rank, comm_stream, none_grad):
 
 
 
 
58
  g = p.grad
59
 
60
  if rank == state.worker_rank:
@@ -83,12 +89,21 @@ def _gather(p, state, rank, comm_stream, none_grad):
83
  state.gathered_grad = None
84
  state.gather_event = None
85
  if none_grad:
86
- p.grad.record_stream(comm_stream)
 
 
 
 
 
 
87
  p.grad = None
88
 
89
 
90
  @torch.no_grad()
91
  def _compute_u(state, steps, rank, compute_stream):
 
 
 
92
  with torch.cuda.stream(compute_stream):
93
  if rank == state.worker_rank:
94
  if state.gather_event is None:
@@ -98,16 +113,16 @@ def _compute_u(state, steps, rank, compute_stream):
98
  state.computed_u = u
99
  state.compute_event = torch.cuda.Event()
100
  state.compute_event.record()
101
- # Clear the gathered gradient to free memory
102
- state.gathered_grad.record_stream(compute_stream)
103
- state.gathered_grad = None
104
  else:
105
  state.computed_u = None
106
  state.compute_event = None
107
 
108
 
109
  @torch.no_grad()
110
- def _scatter(p, state, lr, adjusted_lr, weight_decay, rank, comm_stream):
 
 
 
111
 
112
  with torch.cuda.stream(comm_stream):
113
  if rank == state.worker_rank:
@@ -115,6 +130,10 @@ def _scatter(p, state, lr, adjusted_lr, weight_decay, rank, comm_stream):
115
  if state.compute_event is None:
116
  raise RuntimeError("Compute event must be set before scatter.")
117
  comm_stream.wait_event(state.compute_event)
 
 
 
 
118
  u = state.computed_u
119
  scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
120
  scatter_list = [s.contiguous() for s in scatter_list]
@@ -128,18 +147,32 @@ def _scatter(p, state, lr, adjusted_lr, weight_decay, rank, comm_stream):
128
  src=state.worker_rank,
129
  group=state.process_group,
130
  )
131
- if rank == state.worker_rank:
132
- # Clear u to free memory
133
- state.computed_u.record_stream(comm_stream)
134
- state.computed_u = None
135
-
136
  u_dtensor = DTensor.from_local(
137
  u_received,
138
  placements=p.placements,
139
  device_mesh=p.device_mesh,
140
  )
141
- p.data.mul_(1 - lr * weight_decay)
142
- p.data.add_(u_dtensor, alpha=-adjusted_lr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
 
145
  def default_is_muon(x, name):
@@ -160,17 +193,19 @@ class Muon(torch.optim.Optimizer):
     - We believe it may not work well for finetuning pretrained models, but we haven't tested this.

     Arguments:
-        muon_params: The parameters to be optimized by Muon.
+        model: The model to be optimized by Muon.
+        is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon.
         lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
         momentum: The momentum used by the internal SGD. (0.95 is a good default)
         nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
         ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
-        adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are
+        weight_decay: The weight decay for Muon and AdamW.
         {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
         adamw_lr: The learning rate for the internal AdamW.
         adamw_betas: The betas for the internal AdamW.
         adamw_eps: The epsilon for the internal AdamW.
-        adamw_weight_decay: The weight decay for the internal AdamW.
+        none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
+        debug: Whether to print debug information.
     """

     def __init__(
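A minimal construction sketch based on the documented arguments (hypothetical call: the full __init__ signature and its defaults are not shown in this diff, and the weight_decay value below is only a placeholder):

optimizer = Muon(
    model,                         # parameters are taken from the model
    is_muon_func=default_is_muon,  # route eligible weights to Muon, the rest to AdamW
    lr=0.02,                       # "0.02 is a good default"
    momentum=0.95,                 # "0.95 is a good default"
    ns_steps=6,
    weight_decay=0.01,             # shared by Muon and AdamW after this change
    none_grad=True,                # free p.grad after gathering to save memory
)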
@@ -325,14 +360,8 @@ class Muon(torch.optim.Optimizer):

             u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])

-            # scale update
             adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
-
-            # apply weight decay
-            p.data.mul_(1 - lr * weight_decay)
-
-            # apply update
-            p.data.add_(u, alpha=-adjusted_lr)
+            Muon._update_p(p, u, lr, adjusted_lr, weight_decay)

     def _update_g(self, p, g, group, momentum):
         # calc update
@@ -347,9 +376,8 @@ class Muon(torch.optim.Optimizer):
             g = buf
         return g

-    def _update_p(self, p, u, lr, weight_decay):
-        # scale update
-        adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+    @staticmethod
+    def _update_p(p, u, lr, adjusted_lr, weight_decay):
         # apply weight decay
         p.data.mul_(1 - lr * weight_decay)
         # apply update
@@ -392,11 +420,16 @@ class Muon(torch.optim.Optimizer):
                                 self.compute_stream)

         def enqueue_scatters(start_idx, chunk_size):
+            for p in ordered_params[start_idx:start_idx + chunk_size]:
+                state = param_to_state[id(p)]
+                _scatter(p, state, self.rank, self.comm_stream)
+
+        def enqueue_update_param(start_idx, chunk_size):
             for p in ordered_params[start_idx:start_idx + chunk_size]:
                 state = param_to_state[id(p)]
                 adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
-                _scatter(p, state, lr, adjusted_lr, weight_decay, self.rank,
-                         self.comm_stream)
+                _update_param(p, state, lr, adjusted_lr, weight_decay,
+                              self.rank, self.compute_stream)

         chunk_size = dist.get_world_size(param_to_state[id(
             params[0])].process_group)
@@ -407,10 +440,14 @@ class Muon(torch.optim.Optimizer):
         enqueue_gathers(0, chunk_size)
         for i in range(0, len(params) + chunk_size - 1, chunk_size):
             enqueue_computes(i, chunk_size)
+            if i > 0:
+                enqueue_update_param(i - chunk_size, chunk_size)
             enqueue_gathers(i + chunk_size, chunk_size)
             enqueue_scatters(i, chunk_size)
+        enqueue_update_param(i, chunk_size)

-        torch.cuda.current_stream().wait_stream(self.comm_stream)
+        # Wait the last update_param to finish
+        torch.cuda.current_stream().wait_stream(self.compute_stream)

     def step(self, closure=None):
         """Perform a single optimization step.
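The scheduling change above adds one more stage of software pipelining: parameter updates for chunk i - 1 are enqueued while the gathers for chunk i + 1 and the scatters for chunk i are still in flight, and a final enqueue_update_param drains the last chunk. A restatement of the loop with explanatory comments (illustrative; run_pipeline is a made-up wrapper around the closures defined in the hunk above):

def run_pipeline(params, chunk_size):
    enqueue_gathers(0, chunk_size)                             # prefetch the first chunk
    for i in range(0, len(params) + chunk_size - 1, chunk_size):
        enqueue_computes(i, chunk_size)                        # Newton-Schulz for chunk i
        if i > 0:
            enqueue_update_param(i - chunk_size, chunk_size)   # finish the previous chunk
        enqueue_gathers(i + chunk_size, chunk_size)            # prefetch the next chunk
        enqueue_scatters(i, chunk_size)                        # ship chunk i back to its ranks
    enqueue_update_param(i, chunk_size)                        # drain the last chunk
    # The caller then blocks the default stream on the compute stream, so every
    # parameter update has been applied before step() returns.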