iamwyldecat committed
Commit bdd2678 · 1 Parent(s): 8535e80

fix(muon): delete intermediate tensors immediately to lower peak mem usage

Files changed (36)
  1. build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
  2. build/torch26-cxx11-cu118-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
  3. build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py +10 -19
  4. build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py +3 -3
  5. build/torch26-cxx11-cu124-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
  6. build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py +10 -19
  7. build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  8. build/torch26-cxx11-cu126-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
  9. build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py +10 -19
  10. build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py +3 -3
  11. build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
  12. build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py +10 -19
  13. build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py +3 -3
  14. build/torch26-cxx98-cu118-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
  15. build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py +10 -19
  16. build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py +3 -3
  17. build/torch26-cxx98-cu124-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
  18. build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py +10 -19
  19. build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  20. build/torch26-cxx98-cu126-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
  21. build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py +10 -19
  22. build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
  23. build/torch27-cxx11-cu118-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
  24. build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +10 -19
  25. build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  26. build/torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
  27. build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +10 -19
  28. build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
  29. build/torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
  30. build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +10 -19
  31. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-310.pyc +0 -0
  32. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-310.pyc +0 -0
  33. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
  34. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
  35. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +10 -19
  36. torch-ext/optimizer/muon.py +10 -19
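
The memory saving comes from one pattern applied at two points in muon.py: an intermediate tensor (the gathered gradient in _compute_u, the computed update in _scatter) is marked with Tensor.record_stream() for the side stream that last uses it, and the reference is dropped immediately instead of being kept on the _muon_state object until the end of step(). The sketch below is illustrative only (not code from this repository; the names are made up) and shows the same record_stream-then-del idiom in isolation, assuming a CUDA device is available.

import torch
from dataclasses import dataclass


@dataclass
class _state:
    # Large intermediate we want to release as early as possible.
    gathered: torch.Tensor | None = None
    result: torch.Tensor | None = None


def compute_on_stream(state: _state, compute_stream: torch.cuda.Stream) -> None:
    # Make the side stream wait for the producer of `gathered` (default stream).
    compute_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(compute_stream):
        state.result = state.gathered * 2.0  # stand-in for the real kernel
    # The intermediate was last used on compute_stream: record that fact so the
    # caching allocator keeps its block reserved until that stream catches up...
    state.gathered.record_stream(compute_stream)
    # ...then drop the reference right away instead of holding it until step() ends.
    del state.gathered


if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    st = _state(gathered=torch.randn(4096, 4096, device="cuda"))
    compute_on_stream(st, stream)
    torch.cuda.current_stream().wait_stream(stream)  # sync before reading the result
    print(st.result.norm().item())
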
build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_b4b3752_dirty
-ops = torch.ops._optimizer_b4b3752_dirty
+from . import _optimizer_8535e80_dirty
+ops = torch.ops._optimizer_8535e80_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_b4b3752_dirty::{op_name}"
+    return f"_optimizer_8535e80_dirty::{op_name}"
build/torch26-cxx11-cu118-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a46d9e65efcfa82522950d9ebf2b2b4594d9ed5abc28704352a1f7de2dae707a
+size 1787272
build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py CHANGED
@@ -48,7 +48,6 @@ class _muon_state:
     worker_rank: int | None = None
     gathered_grad: torch.Tensor | None = None
     computed_u: torch.Tensor | None = None
-    scattered_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
 
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
         state.computed_u = u
         state.compute_event = torch.cuda.Event()
         state.compute_event.record()
+        state.gathered_grad.record_stream(compute_stream)
+        del state.gathered_grad
     else:
         state.computed_u = None
         state.compute_event = None
 
 
-def _scatter(p, state, rank, comm_stream):
+def _scatter(p, state, lr, wd, rank, comm_stream):
     u = state.computed_u
     mesh = p.device_mesh
 
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
             src=state.worker_rank,
             group=mesh.get_group(),
         )
+        if rank == state.worker_rank:
+            state.computed_u.record_stream(comm_stream)
+        del state.computed_u
         u = DTensor.from_local(
             u,
             placements=p.placements,
             device_mesh=mesh,
         )
-
-    state.scattered_u = u
+        p.data.mul_(1 - lr * wd)
+        p.data.add_(u, alpha=-lr)
 
 
 class Muon(torch.optim.Optimizer):
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
         def enqueue_scatters(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-                _scatter(p, state, self.rank, self.comm_stream)
+                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+                _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
 
         chunk_size = params[0].device_mesh.mesh.numel()
 
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
 
         torch.cuda.current_stream().wait_stream(self.comm_stream)
 
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-
-            # Update p with sharded u
-            state = param_to_state[id(p)]
-            self._update_p(
-                p,
-                state.scattered_u,
-                lr=lr,
-                wd=wd,
-            )
-
     def step(self, closure=None):
         """Perform a single optimization step.
 
build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_b4b3752_dirty
-ops = torch.ops._optimizer_b4b3752_dirty
+from . import _optimizer_8535e80_dirty
+ops = torch.ops._optimizer_8535e80_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_b4b3752_dirty::{op_name}"
+    return f"_optimizer_8535e80_dirty::{op_name}"
build/torch26-cxx11-cu124-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d351a600884b7378f546a345afe65c176e1399bb42fb7dfe4333b0e90975803b
+size 1824224
build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py CHANGED
@@ -48,7 +48,6 @@ class _muon_state:
     worker_rank: int | None = None
     gathered_grad: torch.Tensor | None = None
    computed_u: torch.Tensor | None = None
-    scattered_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
 
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
         state.computed_u = u
         state.compute_event = torch.cuda.Event()
         state.compute_event.record()
+        state.gathered_grad.record_stream(compute_stream)
+        del state.gathered_grad
     else:
         state.computed_u = None
         state.compute_event = None
 
 
-def _scatter(p, state, rank, comm_stream):
+def _scatter(p, state, lr, wd, rank, comm_stream):
     u = state.computed_u
     mesh = p.device_mesh
 
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
             src=state.worker_rank,
             group=mesh.get_group(),
         )
+        if rank == state.worker_rank:
+            state.computed_u.record_stream(comm_stream)
+        del state.computed_u
         u = DTensor.from_local(
             u,
             placements=p.placements,
             device_mesh=mesh,
         )
-
-    state.scattered_u = u
+        p.data.mul_(1 - lr * wd)
+        p.data.add_(u, alpha=-lr)
 
 
 class Muon(torch.optim.Optimizer):
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
         def enqueue_scatters(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-                _scatter(p, state, self.rank, self.comm_stream)
+                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+                _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
 
         chunk_size = params[0].device_mesh.mesh.numel()
 
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
 
         torch.cuda.current_stream().wait_stream(self.comm_stream)
 
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-
-            # Update p with sharded u
-            state = param_to_state[id(p)]
-            self._update_p(
-                p,
-                state.scattered_u,
-                lr=lr,
-                wd=wd,
-            )
-
     def step(self, closure=None):
         """Perform a single optimization step.
 
build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_b4b3752_dirty
-ops = torch.ops._optimizer_b4b3752_dirty
+from . import _optimizer_8535e80_dirty
+ops = torch.ops._optimizer_8535e80_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_b4b3752_dirty::{op_name}"
+    return f"_optimizer_8535e80_dirty::{op_name}"
build/torch26-cxx11-cu126-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2c0843f38cee494b7a5939eb62d27039d76dc3f69401d411efbacaa25cb0d67a
+size 1824224
build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py CHANGED
@@ -48,7 +48,6 @@ class _muon_state:
     worker_rank: int | None = None
     gathered_grad: torch.Tensor | None = None
     computed_u: torch.Tensor | None = None
-    scattered_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
 
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
         state.computed_u = u
         state.compute_event = torch.cuda.Event()
         state.compute_event.record()
+        state.gathered_grad.record_stream(compute_stream)
+        del state.gathered_grad
     else:
         state.computed_u = None
         state.compute_event = None
 
 
-def _scatter(p, state, rank, comm_stream):
+def _scatter(p, state, lr, wd, rank, comm_stream):
     u = state.computed_u
     mesh = p.device_mesh
 
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
             src=state.worker_rank,
             group=mesh.get_group(),
         )
+        if rank == state.worker_rank:
+            state.computed_u.record_stream(comm_stream)
+        del state.computed_u
         u = DTensor.from_local(
             u,
             placements=p.placements,
             device_mesh=mesh,
         )
-
-    state.scattered_u = u
+        p.data.mul_(1 - lr * wd)
+        p.data.add_(u, alpha=-lr)
 
 
 class Muon(torch.optim.Optimizer):
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
         def enqueue_scatters(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-                _scatter(p, state, self.rank, self.comm_stream)
+                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+                _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
 
         chunk_size = params[0].device_mesh.mesh.numel()
 
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
 
         torch.cuda.current_stream().wait_stream(self.comm_stream)
 
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-
-            # Update p with sharded u
-            state = param_to_state[id(p)]
-            self._update_p(
-                p,
-                state.scattered_u,
-                lr=lr,
-                wd=wd,
-            )
-
     def step(self, closure=None):
         """Perform a single optimization step.
 
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_b4b3752_dirty
-ops = torch.ops._optimizer_b4b3752_dirty
+from . import _optimizer_8535e80_dirty
+ops = torch.ops._optimizer_8535e80_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_b4b3752_dirty::{op_name}"
+    return f"_optimizer_8535e80_dirty::{op_name}"
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:acdba99ce95532a9ca6a8987a7ab61a257657872f2cc672c91e8e5fe809aa24e
+size 1749744
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py CHANGED
@@ -48,7 +48,6 @@ class _muon_state:
     worker_rank: int | None = None
     gathered_grad: torch.Tensor | None = None
     computed_u: torch.Tensor | None = None
-    scattered_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
 
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
         state.computed_u = u
         state.compute_event = torch.cuda.Event()
         state.compute_event.record()
+        state.gathered_grad.record_stream(compute_stream)
+        del state.gathered_grad
     else:
         state.computed_u = None
         state.compute_event = None
 
 
-def _scatter(p, state, rank, comm_stream):
+def _scatter(p, state, lr, wd, rank, comm_stream):
     u = state.computed_u
     mesh = p.device_mesh
 
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
             src=state.worker_rank,
             group=mesh.get_group(),
         )
+        if rank == state.worker_rank:
+            state.computed_u.record_stream(comm_stream)
+        del state.computed_u
         u = DTensor.from_local(
             u,
             placements=p.placements,
             device_mesh=mesh,
         )
-
-    state.scattered_u = u
+        p.data.mul_(1 - lr * wd)
+        p.data.add_(u, alpha=-lr)
 
 
 class Muon(torch.optim.Optimizer):
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
         def enqueue_scatters(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-                _scatter(p, state, self.rank, self.comm_stream)
+                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+                _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
 
         chunk_size = params[0].device_mesh.mesh.numel()
 
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
 
         torch.cuda.current_stream().wait_stream(self.comm_stream)
 
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-
-            # Update p with sharded u
-            state = param_to_state[id(p)]
-            self._update_p(
-                p,
-                state.scattered_u,
-                lr=lr,
-                wd=wd,
-            )
-
     def step(self, closure=None):
         """Perform a single optimization step.
 
build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_b4b3752_dirty
-ops = torch.ops._optimizer_b4b3752_dirty
+from . import _optimizer_8535e80_dirty
+ops = torch.ops._optimizer_8535e80_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_b4b3752_dirty::{op_name}"
+    return f"_optimizer_8535e80_dirty::{op_name}"
build/torch26-cxx98-cu118-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f7d5e76c002507f66f2a227d02c2b11aa3fdc3f07a2a0b82faaa34133adb77ef
+size 1787192
build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py CHANGED
@@ -48,7 +48,6 @@ class _muon_state:
     worker_rank: int | None = None
     gathered_grad: torch.Tensor | None = None
     computed_u: torch.Tensor | None = None
-    scattered_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
 
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
         state.computed_u = u
         state.compute_event = torch.cuda.Event()
         state.compute_event.record()
+        state.gathered_grad.record_stream(compute_stream)
+        del state.gathered_grad
     else:
         state.computed_u = None
         state.compute_event = None
 
 
-def _scatter(p, state, rank, comm_stream):
+def _scatter(p, state, lr, wd, rank, comm_stream):
     u = state.computed_u
     mesh = p.device_mesh
 
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
             src=state.worker_rank,
             group=mesh.get_group(),
         )
+        if rank == state.worker_rank:
+            state.computed_u.record_stream(comm_stream)
+        del state.computed_u
         u = DTensor.from_local(
             u,
             placements=p.placements,
             device_mesh=mesh,
         )
-
-    state.scattered_u = u
+        p.data.mul_(1 - lr * wd)
+        p.data.add_(u, alpha=-lr)
 
 
 class Muon(torch.optim.Optimizer):
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
         def enqueue_scatters(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-                _scatter(p, state, self.rank, self.comm_stream)
+                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+                _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
 
         chunk_size = params[0].device_mesh.mesh.numel()
 
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
 
         torch.cuda.current_stream().wait_stream(self.comm_stream)
 
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-
-            # Update p with sharded u
-            state = param_to_state[id(p)]
-            self._update_p(
-                p,
-                state.scattered_u,
-                lr=lr,
-                wd=wd,
-            )
-
     def step(self, closure=None):
         """Perform a single optimization step.
 
build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_b4b3752_dirty
-ops = torch.ops._optimizer_b4b3752_dirty
+from . import _optimizer_8535e80_dirty
+ops = torch.ops._optimizer_8535e80_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_b4b3752_dirty::{op_name}"
+    return f"_optimizer_8535e80_dirty::{op_name}"
build/torch26-cxx98-cu124-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:becccd250f38a84803350cfb5fac3a6682b1e594968a714642724cbc71246b4a
+size 1824184
build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py CHANGED
@@ -48,7 +48,6 @@ class _muon_state:
     worker_rank: int | None = None
     gathered_grad: torch.Tensor | None = None
     computed_u: torch.Tensor | None = None
-    scattered_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
 
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
         state.computed_u = u
         state.compute_event = torch.cuda.Event()
         state.compute_event.record()
+        state.gathered_grad.record_stream(compute_stream)
+        del state.gathered_grad
     else:
         state.computed_u = None
        state.compute_event = None
 
 
-def _scatter(p, state, rank, comm_stream):
+def _scatter(p, state, lr, wd, rank, comm_stream):
     u = state.computed_u
     mesh = p.device_mesh
 
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
             src=state.worker_rank,
             group=mesh.get_group(),
         )
+        if rank == state.worker_rank:
+            state.computed_u.record_stream(comm_stream)
+        del state.computed_u
         u = DTensor.from_local(
             u,
             placements=p.placements,
             device_mesh=mesh,
         )
-
-    state.scattered_u = u
+        p.data.mul_(1 - lr * wd)
+        p.data.add_(u, alpha=-lr)
 
 
 class Muon(torch.optim.Optimizer):
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
         def enqueue_scatters(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-                _scatter(p, state, self.rank, self.comm_stream)
+                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+                _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
 
         chunk_size = params[0].device_mesh.mesh.numel()
 
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
 
         torch.cuda.current_stream().wait_stream(self.comm_stream)
 
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-
-            # Update p with sharded u
-            state = param_to_state[id(p)]
-            self._update_p(
-                p,
-                state.scattered_u,
-                lr=lr,
-                wd=wd,
-            )
-
     def step(self, closure=None):
         """Perform a single optimization step.
 
build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_b4b3752_dirty
-ops = torch.ops._optimizer_b4b3752_dirty
+from . import _optimizer_8535e80_dirty
+ops = torch.ops._optimizer_8535e80_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_b4b3752_dirty::{op_name}"
+    return f"_optimizer_8535e80_dirty::{op_name}"
build/torch26-cxx98-cu126-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:34215ecc274ef516967962c8457dad214e9bbf618bf5eee8f467371f4f620284
+size 1824184
build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py CHANGED
@@ -48,7 +48,6 @@ class _muon_state:
     worker_rank: int | None = None
     gathered_grad: torch.Tensor | None = None
     computed_u: torch.Tensor | None = None
-    scattered_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
 
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
         state.computed_u = u
         state.compute_event = torch.cuda.Event()
         state.compute_event.record()
+        state.gathered_grad.record_stream(compute_stream)
+        del state.gathered_grad
     else:
         state.computed_u = None
         state.compute_event = None
 
 
-def _scatter(p, state, rank, comm_stream):
+def _scatter(p, state, lr, wd, rank, comm_stream):
     u = state.computed_u
     mesh = p.device_mesh
 
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
             src=state.worker_rank,
             group=mesh.get_group(),
         )
+        if rank == state.worker_rank:
+            state.computed_u.record_stream(comm_stream)
+        del state.computed_u
         u = DTensor.from_local(
             u,
             placements=p.placements,
             device_mesh=mesh,
         )
-
-    state.scattered_u = u
+        p.data.mul_(1 - lr * wd)
+        p.data.add_(u, alpha=-lr)
 
 
 class Muon(torch.optim.Optimizer):
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
         def enqueue_scatters(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-                _scatter(p, state, self.rank, self.comm_stream)
+                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+                _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
 
         chunk_size = params[0].device_mesh.mesh.numel()
 
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
 
         torch.cuda.current_stream().wait_stream(self.comm_stream)
 
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-
-            # Update p with sharded u
-            state = param_to_state[id(p)]
-            self._update_p(
-                p,
-                state.scattered_u,
-                lr=lr,
-                wd=wd,
-            )
-
     def step(self, closure=None):
         """Perform a single optimization step.
 
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_b4b3752_dirty
-ops = torch.ops._optimizer_b4b3752_dirty
+from . import _optimizer_8535e80_dirty
+ops = torch.ops._optimizer_8535e80_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_b4b3752_dirty::{op_name}"
+    return f"_optimizer_8535e80_dirty::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c23a3adbe4dc1a64b4851a9f8e4aed0e3e1eeeded27322c54f5b942282a2a332
+size 1787368
build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py CHANGED
@@ -48,7 +48,6 @@ class _muon_state:
     worker_rank: int | None = None
     gathered_grad: torch.Tensor | None = None
     computed_u: torch.Tensor | None = None
-    scattered_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
 
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
         state.computed_u = u
         state.compute_event = torch.cuda.Event()
         state.compute_event.record()
+        state.gathered_grad.record_stream(compute_stream)
+        del state.gathered_grad
     else:
         state.computed_u = None
         state.compute_event = None
 
 
-def _scatter(p, state, rank, comm_stream):
+def _scatter(p, state, lr, wd, rank, comm_stream):
     u = state.computed_u
     mesh = p.device_mesh
 
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
             src=state.worker_rank,
             group=mesh.get_group(),
         )
+        if rank == state.worker_rank:
+            state.computed_u.record_stream(comm_stream)
+        del state.computed_u
         u = DTensor.from_local(
             u,
             placements=p.placements,
             device_mesh=mesh,
         )
-
-    state.scattered_u = u
+        p.data.mul_(1 - lr * wd)
+        p.data.add_(u, alpha=-lr)
 
 
 class Muon(torch.optim.Optimizer):
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
         def enqueue_scatters(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-                _scatter(p, state, self.rank, self.comm_stream)
+                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+                _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
 
         chunk_size = params[0].device_mesh.mesh.numel()
 
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
 
         torch.cuda.current_stream().wait_stream(self.comm_stream)
 
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-
-            # Update p with sharded u
-            state = param_to_state[id(p)]
-            self._update_p(
-                p,
-                state.scattered_u,
-                lr=lr,
-                wd=wd,
-            )
-
     def step(self, closure=None):
         """Perform a single optimization step.
 
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_b4b3752_dirty
-ops = torch.ops._optimizer_b4b3752_dirty
+from . import _optimizer_8535e80_dirty
+ops = torch.ops._optimizer_8535e80_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_b4b3752_dirty::{op_name}"
+    return f"_optimizer_8535e80_dirty::{op_name}"
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d4aa09c22745d5efe1ef0669c4ca05615f67595dc90cabeee6e878301fa9bd22
+size 1824256
build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py CHANGED
@@ -48,7 +48,6 @@ class _muon_state:
     worker_rank: int | None = None
     gathered_grad: torch.Tensor | None = None
     computed_u: torch.Tensor | None = None
-    scattered_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
 
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
         state.computed_u = u
         state.compute_event = torch.cuda.Event()
         state.compute_event.record()
+        state.gathered_grad.record_stream(compute_stream)
+        del state.gathered_grad
     else:
         state.computed_u = None
         state.compute_event = None
 
 
-def _scatter(p, state, rank, comm_stream):
+def _scatter(p, state, lr, wd, rank, comm_stream):
     u = state.computed_u
     mesh = p.device_mesh
 
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
             src=state.worker_rank,
             group=mesh.get_group(),
         )
+        if rank == state.worker_rank:
+            state.computed_u.record_stream(comm_stream)
+        del state.computed_u
         u = DTensor.from_local(
             u,
             placements=p.placements,
             device_mesh=mesh,
         )
-
-    state.scattered_u = u
+        p.data.mul_(1 - lr * wd)
+        p.data.add_(u, alpha=-lr)
 
 
 class Muon(torch.optim.Optimizer):
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
         def enqueue_scatters(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-                _scatter(p, state, self.rank, self.comm_stream)
+                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+                _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
 
         chunk_size = params[0].device_mesh.mesh.numel()
 
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
 
         torch.cuda.current_stream().wait_stream(self.comm_stream)
 
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-
-            # Update p with sharded u
-            state = param_to_state[id(p)]
-            self._update_p(
-                p,
-                state.scattered_u,
-                lr=lr,
-                wd=wd,
-            )
-
     def step(self, closure=None):
         """Perform a single optimization step.
 
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_b4b3752_dirty
-ops = torch.ops._optimizer_b4b3752_dirty
+from . import _optimizer_8535e80_dirty
+ops = torch.ops._optimizer_8535e80_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_b4b3752_dirty::{op_name}"
+    return f"_optimizer_8535e80_dirty::{op_name}"
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b4baf569b70749c4657062fb0f56943fc486adb0c482e50c7aa8e31ddf5cc870
+size 1883352
build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py CHANGED
@@ -48,7 +48,6 @@ class _muon_state:
     worker_rank: int | None = None
     gathered_grad: torch.Tensor | None = None
     computed_u: torch.Tensor | None = None
-    scattered_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
 
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
         state.computed_u = u
         state.compute_event = torch.cuda.Event()
         state.compute_event.record()
+        state.gathered_grad.record_stream(compute_stream)
+        del state.gathered_grad
     else:
         state.computed_u = None
         state.compute_event = None
 
 
-def _scatter(p, state, rank, comm_stream):
+def _scatter(p, state, lr, wd, rank, comm_stream):
     u = state.computed_u
     mesh = p.device_mesh
 
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
             src=state.worker_rank,
             group=mesh.get_group(),
         )
+        if rank == state.worker_rank:
+            state.computed_u.record_stream(comm_stream)
+        del state.computed_u
         u = DTensor.from_local(
             u,
             placements=p.placements,
             device_mesh=mesh,
         )
-
-    state.scattered_u = u
+        p.data.mul_(1 - lr * wd)
+        p.data.add_(u, alpha=-lr)
 
 
 class Muon(torch.optim.Optimizer):
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
         def enqueue_scatters(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-                _scatter(p, state, self.rank, self.comm_stream)
+                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+                _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
 
         chunk_size = params[0].device_mesh.mesh.numel()
 
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
 
         torch.cuda.current_stream().wait_stream(self.comm_stream)
 
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-
-            # Update p with sharded u
-            state = param_to_state[id(p)]
-            self._update_p(
-                p,
-                state.scattered_u,
-                lr=lr,
-                wd=wd,
-            )
-
     def step(self, closure=None):
         """Perform a single optimization step.
 
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-310.pyc and b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-310.pyc differ
 
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-310.pyc CHANGED
Binary files a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-310.pyc and b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-310.pyc differ
 
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_b4b3752_dirty
-ops = torch.ops._optimizer_b4b3752_dirty
+from . import _optimizer_8535e80_dirty
+ops = torch.ops._optimizer_8535e80_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_b4b3752_dirty::{op_name}"
+    return f"_optimizer_8535e80_dirty::{op_name}"
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8566c9bc05e13c9394572f9f9c6bac24c31932548be485f49eb49fb249880832
+size 1749648
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py CHANGED
@@ -48,7 +48,6 @@ class _muon_state:
     worker_rank: int | None = None
     gathered_grad: torch.Tensor | None = None
     computed_u: torch.Tensor | None = None
-    scattered_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
 
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
         state.computed_u = u
         state.compute_event = torch.cuda.Event()
         state.compute_event.record()
+        state.gathered_grad.record_stream(compute_stream)
+        del state.gathered_grad
     else:
         state.computed_u = None
         state.compute_event = None
 
 
-def _scatter(p, state, rank, comm_stream):
+def _scatter(p, state, lr, wd, rank, comm_stream):
     u = state.computed_u
     mesh = p.device_mesh
 
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
             src=state.worker_rank,
             group=mesh.get_group(),
        )
+        if rank == state.worker_rank:
+            state.computed_u.record_stream(comm_stream)
+        del state.computed_u
         u = DTensor.from_local(
             u,
             placements=p.placements,
             device_mesh=mesh,
         )
-
-    state.scattered_u = u
+        p.data.mul_(1 - lr * wd)
+        p.data.add_(u, alpha=-lr)
 
 
 class Muon(torch.optim.Optimizer):
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
         def enqueue_scatters(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-                _scatter(p, state, self.rank, self.comm_stream)
+                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+                _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
 
         chunk_size = params[0].device_mesh.mesh.numel()
 
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
 
         torch.cuda.current_stream().wait_stream(self.comm_stream)
 
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-
-            # Update p with sharded u
-            state = param_to_state[id(p)]
-            self._update_p(
-                p,
-                state.scattered_u,
-                lr=lr,
-                wd=wd,
-            )
-
     def step(self, closure=None):
         """Perform a single optimization step.
 
torch-ext/optimizer/muon.py CHANGED
@@ -48,7 +48,6 @@ class _muon_state:
     worker_rank: int | None = None
     gathered_grad: torch.Tensor | None = None
     computed_u: torch.Tensor | None = None
-    scattered_u: torch.Tensor | None = None
     gather_event: torch.cuda.Event | None = None
     compute_event: torch.cuda.Event | None = None
 
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
         state.computed_u = u
         state.compute_event = torch.cuda.Event()
         state.compute_event.record()
+        state.gathered_grad.record_stream(compute_stream)
+        del state.gathered_grad
     else:
         state.computed_u = None
         state.compute_event = None
 
 
-def _scatter(p, state, rank, comm_stream):
+def _scatter(p, state, lr, wd, rank, comm_stream):
     u = state.computed_u
     mesh = p.device_mesh
 
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
             src=state.worker_rank,
             group=mesh.get_group(),
         )
+        if rank == state.worker_rank:
+            state.computed_u.record_stream(comm_stream)
+        del state.computed_u
         u = DTensor.from_local(
             u,
             placements=p.placements,
             device_mesh=mesh,
         )
-
-    state.scattered_u = u
+        p.data.mul_(1 - lr * wd)
+        p.data.add_(u, alpha=-lr)
 
 
 class Muon(torch.optim.Optimizer):
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
         def enqueue_scatters(start_idx, chunk_size):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-                _scatter(p, state, self.rank, self.comm_stream)
+                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
+                _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
 
         chunk_size = params[0].device_mesh.mesh.numel()
 
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
 
         torch.cuda.current_stream().wait_stream(self.comm_stream)
 
-        for p in params:
-            g = p.grad
-            if g is None:
-                continue
-
-            # Update p with sharded u
-            state = param_to_state[id(p)]
-            self._update_p(
-                p,
-                state.scattered_u,
-                lr=lr,
-                wd=wd,
-            )
-
     def step(self, closure=None):
         """Perform a single optimization step.
 
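
The second half of the change removes the post-scatter update loop: the decoupled weight decay and the Muon update are now applied in place inside _scatter as soon as the sharded u arrives, so nothing needs to be parked on state.scattered_u. Below is a standalone sketch of that fused update (illustrative only, not this repository's code; lr here stands for the already-adjusted learning rate returned by adjust_lr_for_muon).

import torch


def apply_muon_update(p: torch.Tensor, u: torch.Tensor, lr: float, wd: float) -> None:
    # Same two in-place ops the diff adds inside _scatter():
    #   p <- p * (1 - lr * wd) - lr * u   (decoupled weight decay, then the update)
    p.mul_(1 - lr * wd)
    p.add_(u, alpha=-lr)


w = torch.randn(8, 8)
update = torch.randn(8, 8)
apply_muon_update(w, update, lr=0.02, wd=0.1)
print(w[0, :3])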