feat(muon): add tuned-abc-values & bfloat16 communication
Browse files- build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_20250911094409.abi3.so → _optimizer_ee6ed44_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +158 -50
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so → torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +158 -50
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_20250911094409.abi3.so → _optimizer_ee6ed44_dirty.abi3.so} +2 -2
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +158 -50
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_20250911094409.abi3.so → _optimizer_ee6ed44_dirty.abi3.so} +2 -2
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +158 -50
- build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/{torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so → torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so} +1 -1
- build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py +158 -50
- build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so +0 -3
- build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so +3 -0
- build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py +158 -50
- build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so +0 -3
- build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so +3 -0
- build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py +158 -50
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so +0 -3
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so +3 -0
- build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py +158 -50
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so +0 -3
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so +3 -0
- build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py +158 -50
- torch-ext/optimizer/muon.py +158 -50
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import torch
|
2 |
-
from . import
|
3 |
-
ops = torch.ops.
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
-
return f"
|
|
|
1 |
import torch
|
2 |
+
from . import _optimizer_ee6ed44_dirty
|
3 |
+
ops = torch.ops._optimizer_ee6ed44_dirty
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
+
return f"_optimizer_ee6ed44_dirty::{op_name}"
|
build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_20250911094409.abi3.so → _optimizer_ee6ed44_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 1787376
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:55f17ad6ecdd22d84ea5b776a317fa9fbb6b81f622fa8fc80b78e0ef80bd4ea6
|
3 |
size 1787376
|
build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py
CHANGED
@@ -2,6 +2,7 @@ import logging
|
|
2 |
import math
|
3 |
import types
|
4 |
from dataclasses import dataclass
|
|
|
5 |
|
6 |
import torch
|
7 |
import torch.distributed as dist
|
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
|
|
12 |
|
13 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
14 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
|
|
|
|
15 |
@torch.no_grad()
|
16 |
def _zeropower_via_newtonschulz5(G, steps):
|
17 |
"""
|
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
24 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
25 |
"""
|
26 |
assert len(G.shape) == 2
|
27 |
-
|
28 |
X = G # no manual typecast
|
|
|
29 |
if G.size(0) > G.size(1):
|
30 |
X = X.T
|
31 |
# Ensure spectral norm is at most 1
|
32 |
X = X / (X.norm() + 1e-7)
|
33 |
-
X = X.bfloat16()
|
34 |
# Perform the NS iterations
|
35 |
-
for
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
A = X @ X.T
|
37 |
# B = (
|
38 |
# b * A + c * A @ A
|
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
43 |
|
44 |
if G.size(0) > G.size(1):
|
45 |
X = X.T
|
46 |
-
return X
|
47 |
|
48 |
|
49 |
@dataclass
|
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
65 |
Gather the gradients to worker_rank.
|
66 |
If none_grad is True, free p.grad after the gather.
|
67 |
"""
|
68 |
-
g = p.grad
|
69 |
-
|
70 |
-
if rank == state.worker_rank:
|
71 |
-
num_ranks = dist.get_world_size(group=state.process_group)
|
72 |
-
gather_list = [
|
73 |
-
torch.empty_like(g.to_local()) for _ in range(num_ranks)
|
74 |
-
]
|
75 |
-
else:
|
76 |
-
gather_list = None
|
77 |
-
|
78 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
torch.distributed.gather(
|
80 |
g.to_local(),
|
81 |
dst=state.worker_rank,
|
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
92 |
else:
|
93 |
state.gathered_grad = None
|
94 |
state.gather_event = None
|
|
|
95 |
if none_grad:
|
96 |
# We can safely free p.grad without calling record_stream:
|
97 |
# p.grad.to_local().record_stream(comm_stream)
|
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
104 |
|
105 |
|
106 |
@torch.no_grad()
|
107 |
-
def _compute_u(state, steps, rank, compute_stream):
|
108 |
"""
|
109 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
110 |
"""
|
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
115 |
compute_stream.wait_event(state.gather_event)
|
116 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
117 |
state.computed_u = u
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
|
124 |
|
125 |
@torch.no_grad()
|
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
|
|
129 |
"""
|
130 |
|
131 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
132 |
if rank == state.worker_rank:
|
133 |
num_ranks = dist.get_world_size(group=state.process_group)
|
134 |
-
if state.compute_event is None:
|
135 |
-
raise RuntimeError("Compute event must be set before scatter.")
|
136 |
-
comm_stream.wait_event(state.compute_event)
|
137 |
-
|
138 |
# Clear the gathered gradient to free memory
|
139 |
state.gathered_grad = None
|
140 |
|
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
|
|
144 |
else:
|
145 |
scatter_list = None
|
146 |
|
147 |
-
u_received = torch.empty_like(p.to_local())
|
148 |
torch.distributed.scatter(
|
149 |
-
|
150 |
scatter_list=scatter_list,
|
151 |
src=state.worker_rank,
|
152 |
group=state.process_group,
|
153 |
)
|
154 |
-
u_dtensor = DTensor.from_local(
|
155 |
-
u_received,
|
156 |
-
placements=p.placements,
|
157 |
-
device_mesh=p.device_mesh,
|
158 |
-
)
|
159 |
-
|
160 |
-
state.scattered_u = u_dtensor
|
161 |
state.scatter_event = torch.cuda.Event()
|
162 |
state.scatter_event.record()
|
|
|
163 |
|
164 |
|
165 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
172 |
if state.scatter_event is None:
|
173 |
raise RuntimeError("Scatter event must be set before update")
|
174 |
compute_stream.wait_event(state.scatter_event)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
if rank == state.worker_rank:
|
176 |
# Free computed_u
|
177 |
state.computed_u = None
|
178 |
|
179 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
|
|
|
|
180 |
|
181 |
|
182 |
def default_is_muon(name, x):
|
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
|
|
375 |
else:
|
376 |
g = buf
|
377 |
|
378 |
-
u = _zeropower_via_newtonschulz5(g,
|
|
|
379 |
|
380 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
381 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
|
|
433 |
def enqueue_computes(start_idx, chunk_size):
|
434 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
435 |
state = param_to_state[id(p)]
|
436 |
-
_compute_u(state, group["ns_steps"], self.rank,
|
437 |
self.compute_stream)
|
438 |
|
439 |
def enqueue_scatters(start_idx, chunk_size):
|
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
|
|
466 |
# Wait the last update_param to finish
|
467 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
469 |
def step(self, closure=None):
|
470 |
"""Perform a single optimization step.
|
471 |
|
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
|
|
542 |
# AdamW backup #
|
543 |
############################
|
544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
545 |
lr = group["lr"]
|
546 |
beta1, beta2 = group["adamw_betas"]
|
547 |
eps = group["adamw_eps"]
|
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
|
|
552 |
if g is None:
|
553 |
continue
|
554 |
state = self.state[p]
|
|
|
|
|
555 |
if "step" not in state:
|
556 |
-
state["step"] =
|
|
|
|
|
557 |
state["moment1"] = torch.zeros_like(g)
|
558 |
state["moment2"] = torch.zeros_like(g)
|
559 |
-
state["
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
573 |
|
574 |
return loss
|
|
|
2 |
import math
|
3 |
import types
|
4 |
from dataclasses import dataclass
|
5 |
+
from typing import Optional, Union, cast
|
6 |
|
7 |
import torch
|
8 |
import torch.distributed as dist
|
|
|
13 |
|
14 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
15 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
16 |
+
# Muon's Newton–Schulz iteration causes high variance in singular values
|
17 |
+
# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
|
18 |
@torch.no_grad()
|
19 |
def _zeropower_via_newtonschulz5(G, steps):
|
20 |
"""
|
|
|
27 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
28 |
"""
|
29 |
assert len(G.shape) == 2
|
30 |
+
assert G.dtype == torch.bfloat16
|
31 |
X = G # no manual typecast
|
32 |
+
|
33 |
if G.size(0) > G.size(1):
|
34 |
X = X.T
|
35 |
# Ensure spectral norm is at most 1
|
36 |
X = X / (X.norm() + 1e-7)
|
|
|
37 |
# Perform the NS iterations
|
38 |
+
for a, b, c in [
|
39 |
+
(4.0848, -6.8946, 2.9270),
|
40 |
+
(3.9505, -6.3029, 2.6377),
|
41 |
+
(3.7418, -5.5913, 2.3037),
|
42 |
+
(2.8769, -3.1427, 1.2046),
|
43 |
+
(2.8366, -3.0525, 1.2012),
|
44 |
+
]:
|
45 |
A = X @ X.T
|
46 |
# B = (
|
47 |
# b * A + c * A @ A
|
|
|
52 |
|
53 |
if G.size(0) > G.size(1):
|
54 |
X = X.T
|
55 |
+
return X
|
56 |
|
57 |
|
58 |
@dataclass
|
|
|
74 |
Gather the gradients to worker_rank.
|
75 |
If none_grad is True, free p.grad after the gather.
|
76 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
with torch.cuda.stream(comm_stream):
|
78 |
+
g = p.grad
|
79 |
+
|
80 |
+
if rank == state.worker_rank:
|
81 |
+
num_ranks = dist.get_world_size(group=state.process_group)
|
82 |
+
gather_list = [
|
83 |
+
torch.empty_like(g.to_local(), dtype=torch.bfloat16)
|
84 |
+
for _ in range(num_ranks)
|
85 |
+
]
|
86 |
+
else:
|
87 |
+
gather_list = None
|
88 |
+
|
89 |
+
g = g.to(torch.bfloat16)
|
90 |
torch.distributed.gather(
|
91 |
g.to_local(),
|
92 |
dst=state.worker_rank,
|
|
|
103 |
else:
|
104 |
state.gathered_grad = None
|
105 |
state.gather_event = None
|
106 |
+
gather_list = None
|
107 |
if none_grad:
|
108 |
# We can safely free p.grad without calling record_stream:
|
109 |
# p.grad.to_local().record_stream(comm_stream)
|
|
|
116 |
|
117 |
|
118 |
@torch.no_grad()
|
119 |
+
def _compute_u(p, state, steps, rank, compute_stream):
|
120 |
"""
|
121 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
122 |
"""
|
|
|
127 |
compute_stream.wait_event(state.gather_event)
|
128 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
129 |
state.computed_u = u
|
130 |
+
state.scattered_u = torch.empty_like(p.to_local(),
|
131 |
+
dtype=torch.bfloat16)
|
132 |
+
state.compute_event = torch.cuda.Event()
|
133 |
+
state.compute_event.record()
|
134 |
+
u = None
|
135 |
|
136 |
|
137 |
@torch.no_grad()
|
|
|
141 |
"""
|
142 |
|
143 |
with torch.cuda.stream(comm_stream):
|
144 |
+
if state.compute_event is None:
|
145 |
+
raise RuntimeError("Compute event must be set before scatter.")
|
146 |
+
comm_stream.wait_event(state.compute_event)
|
147 |
+
|
148 |
if rank == state.worker_rank:
|
149 |
num_ranks = dist.get_world_size(group=state.process_group)
|
|
|
|
|
|
|
|
|
150 |
# Clear the gathered gradient to free memory
|
151 |
state.gathered_grad = None
|
152 |
|
|
|
156 |
else:
|
157 |
scatter_list = None
|
158 |
|
|
|
159 |
torch.distributed.scatter(
|
160 |
+
state.scattered_u,
|
161 |
scatter_list=scatter_list,
|
162 |
src=state.worker_rank,
|
163 |
group=state.process_group,
|
164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
state.scatter_event = torch.cuda.Event()
|
166 |
state.scatter_event.record()
|
167 |
+
scatter_list = None
|
168 |
|
169 |
|
170 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
|
177 |
if state.scatter_event is None:
|
178 |
raise RuntimeError("Scatter event must be set before update")
|
179 |
compute_stream.wait_event(state.scatter_event)
|
180 |
+
u_dtensor = DTensor.from_local(
|
181 |
+
state.scattered_u,
|
182 |
+
placements=p.placements,
|
183 |
+
device_mesh=p.device_mesh,
|
184 |
+
)
|
185 |
+
|
186 |
+
state.scattered_u = u_dtensor
|
187 |
+
|
188 |
if rank == state.worker_rank:
|
189 |
# Free computed_u
|
190 |
state.computed_u = None
|
191 |
|
192 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
193 |
+
state.scattered_u = None
|
194 |
+
u_dtensor = None
|
195 |
|
196 |
|
197 |
def default_is_muon(name, x):
|
|
|
390 |
else:
|
391 |
g = buf
|
392 |
|
393 |
+
u = _zeropower_via_newtonschulz5(g.bfloat16(),
|
394 |
+
steps=group["ns_steps"])
|
395 |
|
396 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
397 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
|
449 |
def enqueue_computes(start_idx, chunk_size):
|
450 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
451 |
state = param_to_state[id(p)]
|
452 |
+
_compute_u(p, state, group["ns_steps"], self.rank,
|
453 |
self.compute_stream)
|
454 |
|
455 |
def enqueue_scatters(start_idx, chunk_size):
|
|
|
482 |
# Wait the last update_param to finish
|
483 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
484 |
|
485 |
+
@staticmethod
|
486 |
+
def _fused_adamw(
|
487 |
+
params: list[torch.Tensor],
|
488 |
+
grads: list[torch.Tensor],
|
489 |
+
exp_avgs: list[torch.Tensor],
|
490 |
+
exp_avg_sqs: list[torch.Tensor],
|
491 |
+
max_exp_avg_sqs: list[torch.Tensor],
|
492 |
+
state_steps: list[torch.Tensor],
|
493 |
+
amsgrad: bool,
|
494 |
+
beta1: float,
|
495 |
+
beta2: float,
|
496 |
+
lr: Union[float, torch.Tensor],
|
497 |
+
weight_decay: float,
|
498 |
+
eps: float,
|
499 |
+
maximize: bool,
|
500 |
+
) -> None:
|
501 |
+
if not params:
|
502 |
+
return
|
503 |
+
|
504 |
+
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
|
505 |
+
# treating it as a scalar.
|
506 |
+
lr_dict: Optional[DeviceDict] = ({
|
507 |
+
lr.device: lr
|
508 |
+
} if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
|
509 |
+
None)
|
510 |
+
grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
|
511 |
+
[
|
512 |
+
params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
|
513 |
+
state_steps
|
514 |
+
] # type: ignore[list-item]
|
515 |
+
)
|
516 |
+
for (device, _), (
|
517 |
+
(
|
518 |
+
device_params_,
|
519 |
+
device_grads_,
|
520 |
+
device_exp_avgs_,
|
521 |
+
device_exp_avg_sqs_,
|
522 |
+
device_max_exp_avg_sqs,
|
523 |
+
device_state_steps_,
|
524 |
+
),
|
525 |
+
_,
|
526 |
+
) in grouped_tensors.items():
|
527 |
+
device_params = cast(list[torch.Tensor], device_params_)
|
528 |
+
device_grads = cast(list[torch.Tensor], device_grads_)
|
529 |
+
device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
|
530 |
+
device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
|
531 |
+
device_state_steps = cast(list[torch.Tensor], device_state_steps_)
|
532 |
+
|
533 |
+
if lr_dict is not None and device not in lr_dict:
|
534 |
+
lr_dict[device] = lr.to(
|
535 |
+
device=device,
|
536 |
+
non_blocking=True) # type: ignore[union-attr]
|
537 |
+
lr = lr_dict[device]
|
538 |
+
torch._foreach_add_(device_state_steps, 1)
|
539 |
+
func = torch._fused_adamw_
|
540 |
+
func(
|
541 |
+
device_params,
|
542 |
+
device_grads,
|
543 |
+
device_exp_avgs,
|
544 |
+
device_exp_avg_sqs,
|
545 |
+
device_max_exp_avg_sqs, # type: ignore[arg-type]
|
546 |
+
device_state_steps,
|
547 |
+
amsgrad=amsgrad,
|
548 |
+
lr=lr, # type: ignore[arg-type]
|
549 |
+
beta1=beta1,
|
550 |
+
beta2=beta2,
|
551 |
+
weight_decay=weight_decay,
|
552 |
+
eps=eps,
|
553 |
+
maximize=maximize,
|
554 |
+
)
|
555 |
+
|
556 |
def step(self, closure=None):
|
557 |
"""Perform a single optimization step.
|
558 |
|
|
|
629 |
# AdamW backup #
|
630 |
############################
|
631 |
|
632 |
+
params_with_grads = []
|
633 |
+
grads = []
|
634 |
+
moment1 = []
|
635 |
+
moment2 = []
|
636 |
+
max_exp_avg_sqs = []
|
637 |
+
state_steps = []
|
638 |
lr = group["lr"]
|
639 |
beta1, beta2 = group["adamw_betas"]
|
640 |
eps = group["adamw_eps"]
|
|
|
645 |
if g is None:
|
646 |
continue
|
647 |
state = self.state[p]
|
648 |
+
params_with_grads.append(p)
|
649 |
+
grads.append(g)
|
650 |
if "step" not in state:
|
651 |
+
state["step"] = (torch.zeros((),
|
652 |
+
dtype=torch.float32,
|
653 |
+
device=p.device))
|
654 |
state["moment1"] = torch.zeros_like(g)
|
655 |
state["moment2"] = torch.zeros_like(g)
|
656 |
+
moment1.append(state["moment1"])
|
657 |
+
moment2.append(state["moment2"])
|
658 |
+
if not isinstance(state["step"], torch.Tensor):
|
659 |
+
step_tensor = torch.tensor(state["step"],
|
660 |
+
dtype=torch.float32,
|
661 |
+
device=p.device)
|
662 |
+
else:
|
663 |
+
step_tensor = state["step"]
|
664 |
+
state_steps.append(step_tensor)
|
665 |
+
|
666 |
+
self._fused_adamw(
|
667 |
+
params_with_grads,
|
668 |
+
grads,
|
669 |
+
moment1,
|
670 |
+
moment2,
|
671 |
+
max_exp_avg_sqs,
|
672 |
+
state_steps,
|
673 |
+
amsgrad=False,
|
674 |
+
beta1=beta1,
|
675 |
+
beta2=beta2,
|
676 |
+
lr=lr,
|
677 |
+
weight_decay=weight_decay,
|
678 |
+
eps=eps,
|
679 |
+
maximize=False,
|
680 |
+
)
|
681 |
|
682 |
return loss
|
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import torch
|
2 |
-
from . import
|
3 |
-
ops = torch.ops.
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
-
return f"
|
|
|
1 |
import torch
|
2 |
+
from . import _optimizer_ee6ed44_dirty
|
3 |
+
ops = torch.ops._optimizer_ee6ed44_dirty
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
+
return f"_optimizer_ee6ed44_dirty::{op_name}"
|
build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so → torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 1824264
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f37c80a535a081e997c1973902a010c48b33ca40085a7f267a5278e56cff26f3
|
3 |
size 1824264
|
build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED
@@ -2,6 +2,7 @@ import logging
|
|
2 |
import math
|
3 |
import types
|
4 |
from dataclasses import dataclass
|
|
|
5 |
|
6 |
import torch
|
7 |
import torch.distributed as dist
|
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
|
|
12 |
|
13 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
14 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
|
|
|
|
15 |
@torch.no_grad()
|
16 |
def _zeropower_via_newtonschulz5(G, steps):
|
17 |
"""
|
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
24 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
25 |
"""
|
26 |
assert len(G.shape) == 2
|
27 |
-
|
28 |
X = G # no manual typecast
|
|
|
29 |
if G.size(0) > G.size(1):
|
30 |
X = X.T
|
31 |
# Ensure spectral norm is at most 1
|
32 |
X = X / (X.norm() + 1e-7)
|
33 |
-
X = X.bfloat16()
|
34 |
# Perform the NS iterations
|
35 |
-
for
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
A = X @ X.T
|
37 |
# B = (
|
38 |
# b * A + c * A @ A
|
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
43 |
|
44 |
if G.size(0) > G.size(1):
|
45 |
X = X.T
|
46 |
-
return X
|
47 |
|
48 |
|
49 |
@dataclass
|
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
65 |
Gather the gradients to worker_rank.
|
66 |
If none_grad is True, free p.grad after the gather.
|
67 |
"""
|
68 |
-
g = p.grad
|
69 |
-
|
70 |
-
if rank == state.worker_rank:
|
71 |
-
num_ranks = dist.get_world_size(group=state.process_group)
|
72 |
-
gather_list = [
|
73 |
-
torch.empty_like(g.to_local()) for _ in range(num_ranks)
|
74 |
-
]
|
75 |
-
else:
|
76 |
-
gather_list = None
|
77 |
-
|
78 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
torch.distributed.gather(
|
80 |
g.to_local(),
|
81 |
dst=state.worker_rank,
|
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
92 |
else:
|
93 |
state.gathered_grad = None
|
94 |
state.gather_event = None
|
|
|
95 |
if none_grad:
|
96 |
# We can safely free p.grad without calling record_stream:
|
97 |
# p.grad.to_local().record_stream(comm_stream)
|
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
104 |
|
105 |
|
106 |
@torch.no_grad()
|
107 |
-
def _compute_u(state, steps, rank, compute_stream):
|
108 |
"""
|
109 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
110 |
"""
|
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
115 |
compute_stream.wait_event(state.gather_event)
|
116 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
117 |
state.computed_u = u
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
|
124 |
|
125 |
@torch.no_grad()
|
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
|
|
129 |
"""
|
130 |
|
131 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
132 |
if rank == state.worker_rank:
|
133 |
num_ranks = dist.get_world_size(group=state.process_group)
|
134 |
-
if state.compute_event is None:
|
135 |
-
raise RuntimeError("Compute event must be set before scatter.")
|
136 |
-
comm_stream.wait_event(state.compute_event)
|
137 |
-
|
138 |
# Clear the gathered gradient to free memory
|
139 |
state.gathered_grad = None
|
140 |
|
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
|
|
144 |
else:
|
145 |
scatter_list = None
|
146 |
|
147 |
-
u_received = torch.empty_like(p.to_local())
|
148 |
torch.distributed.scatter(
|
149 |
-
|
150 |
scatter_list=scatter_list,
|
151 |
src=state.worker_rank,
|
152 |
group=state.process_group,
|
153 |
)
|
154 |
-
u_dtensor = DTensor.from_local(
|
155 |
-
u_received,
|
156 |
-
placements=p.placements,
|
157 |
-
device_mesh=p.device_mesh,
|
158 |
-
)
|
159 |
-
|
160 |
-
state.scattered_u = u_dtensor
|
161 |
state.scatter_event = torch.cuda.Event()
|
162 |
state.scatter_event.record()
|
|
|
163 |
|
164 |
|
165 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
172 |
if state.scatter_event is None:
|
173 |
raise RuntimeError("Scatter event must be set before update")
|
174 |
compute_stream.wait_event(state.scatter_event)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
if rank == state.worker_rank:
|
176 |
# Free computed_u
|
177 |
state.computed_u = None
|
178 |
|
179 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
|
|
|
|
180 |
|
181 |
|
182 |
def default_is_muon(name, x):
|
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
|
|
375 |
else:
|
376 |
g = buf
|
377 |
|
378 |
-
u = _zeropower_via_newtonschulz5(g,
|
|
|
379 |
|
380 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
381 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
|
|
433 |
def enqueue_computes(start_idx, chunk_size):
|
434 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
435 |
state = param_to_state[id(p)]
|
436 |
-
_compute_u(state, group["ns_steps"], self.rank,
|
437 |
self.compute_stream)
|
438 |
|
439 |
def enqueue_scatters(start_idx, chunk_size):
|
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
|
|
466 |
# Wait the last update_param to finish
|
467 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
469 |
def step(self, closure=None):
|
470 |
"""Perform a single optimization step.
|
471 |
|
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
|
|
542 |
# AdamW backup #
|
543 |
############################
|
544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
545 |
lr = group["lr"]
|
546 |
beta1, beta2 = group["adamw_betas"]
|
547 |
eps = group["adamw_eps"]
|
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
|
|
552 |
if g is None:
|
553 |
continue
|
554 |
state = self.state[p]
|
|
|
|
|
555 |
if "step" not in state:
|
556 |
-
state["step"] =
|
|
|
|
|
557 |
state["moment1"] = torch.zeros_like(g)
|
558 |
state["moment2"] = torch.zeros_like(g)
|
559 |
-
state["
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
573 |
|
574 |
return loss
|
|
|
2 |
import math
|
3 |
import types
|
4 |
from dataclasses import dataclass
|
5 |
+
from typing import Optional, Union, cast
|
6 |
|
7 |
import torch
|
8 |
import torch.distributed as dist
|
|
|
13 |
|
14 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
15 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
16 |
+
# Muon's Newton–Schulz iteration causes high variance in singular values
|
17 |
+
# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
|
18 |
@torch.no_grad()
|
19 |
def _zeropower_via_newtonschulz5(G, steps):
|
20 |
"""
|
|
|
27 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
28 |
"""
|
29 |
assert len(G.shape) == 2
|
30 |
+
assert G.dtype == torch.bfloat16
|
31 |
X = G # no manual typecast
|
32 |
+
|
33 |
if G.size(0) > G.size(1):
|
34 |
X = X.T
|
35 |
# Ensure spectral norm is at most 1
|
36 |
X = X / (X.norm() + 1e-7)
|
|
|
37 |
# Perform the NS iterations
|
38 |
+
for a, b, c in [
|
39 |
+
(4.0848, -6.8946, 2.9270),
|
40 |
+
(3.9505, -6.3029, 2.6377),
|
41 |
+
(3.7418, -5.5913, 2.3037),
|
42 |
+
(2.8769, -3.1427, 1.2046),
|
43 |
+
(2.8366, -3.0525, 1.2012),
|
44 |
+
]:
|
45 |
A = X @ X.T
|
46 |
# B = (
|
47 |
# b * A + c * A @ A
|
|
|
52 |
|
53 |
if G.size(0) > G.size(1):
|
54 |
X = X.T
|
55 |
+
return X
|
56 |
|
57 |
|
58 |
@dataclass
|
|
|
74 |
Gather the gradients to worker_rank.
|
75 |
If none_grad is True, free p.grad after the gather.
|
76 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
with torch.cuda.stream(comm_stream):
|
78 |
+
g = p.grad
|
79 |
+
|
80 |
+
if rank == state.worker_rank:
|
81 |
+
num_ranks = dist.get_world_size(group=state.process_group)
|
82 |
+
gather_list = [
|
83 |
+
torch.empty_like(g.to_local(), dtype=torch.bfloat16)
|
84 |
+
for _ in range(num_ranks)
|
85 |
+
]
|
86 |
+
else:
|
87 |
+
gather_list = None
|
88 |
+
|
89 |
+
g = g.to(torch.bfloat16)
|
90 |
torch.distributed.gather(
|
91 |
g.to_local(),
|
92 |
dst=state.worker_rank,
|
|
|
103 |
else:
|
104 |
state.gathered_grad = None
|
105 |
state.gather_event = None
|
106 |
+
gather_list = None
|
107 |
if none_grad:
|
108 |
# We can safely free p.grad without calling record_stream:
|
109 |
# p.grad.to_local().record_stream(comm_stream)
|
|
|
116 |
|
117 |
|
118 |
@torch.no_grad()
|
119 |
+
def _compute_u(p, state, steps, rank, compute_stream):
|
120 |
"""
|
121 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
122 |
"""
|
|
|
127 |
compute_stream.wait_event(state.gather_event)
|
128 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
129 |
state.computed_u = u
|
130 |
+
state.scattered_u = torch.empty_like(p.to_local(),
|
131 |
+
dtype=torch.bfloat16)
|
132 |
+
state.compute_event = torch.cuda.Event()
|
133 |
+
state.compute_event.record()
|
134 |
+
u = None
|
135 |
|
136 |
|
137 |
@torch.no_grad()
|
|
|
141 |
"""
|
142 |
|
143 |
with torch.cuda.stream(comm_stream):
|
144 |
+
if state.compute_event is None:
|
145 |
+
raise RuntimeError("Compute event must be set before scatter.")
|
146 |
+
comm_stream.wait_event(state.compute_event)
|
147 |
+
|
148 |
if rank == state.worker_rank:
|
149 |
num_ranks = dist.get_world_size(group=state.process_group)
|
|
|
|
|
|
|
|
|
150 |
# Clear the gathered gradient to free memory
|
151 |
state.gathered_grad = None
|
152 |
|
|
|
156 |
else:
|
157 |
scatter_list = None
|
158 |
|
|
|
159 |
torch.distributed.scatter(
|
160 |
+
state.scattered_u,
|
161 |
scatter_list=scatter_list,
|
162 |
src=state.worker_rank,
|
163 |
group=state.process_group,
|
164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
state.scatter_event = torch.cuda.Event()
|
166 |
state.scatter_event.record()
|
167 |
+
scatter_list = None
|
168 |
|
169 |
|
170 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
|
177 |
if state.scatter_event is None:
|
178 |
raise RuntimeError("Scatter event must be set before update")
|
179 |
compute_stream.wait_event(state.scatter_event)
|
180 |
+
u_dtensor = DTensor.from_local(
|
181 |
+
state.scattered_u,
|
182 |
+
placements=p.placements,
|
183 |
+
device_mesh=p.device_mesh,
|
184 |
+
)
|
185 |
+
|
186 |
+
state.scattered_u = u_dtensor
|
187 |
+
|
188 |
if rank == state.worker_rank:
|
189 |
# Free computed_u
|
190 |
state.computed_u = None
|
191 |
|
192 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
193 |
+
state.scattered_u = None
|
194 |
+
u_dtensor = None
|
195 |
|
196 |
|
197 |
def default_is_muon(name, x):
|
|
|
390 |
else:
|
391 |
g = buf
|
392 |
|
393 |
+
u = _zeropower_via_newtonschulz5(g.bfloat16(),
|
394 |
+
steps=group["ns_steps"])
|
395 |
|
396 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
397 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
|
449 |
def enqueue_computes(start_idx, chunk_size):
|
450 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
451 |
state = param_to_state[id(p)]
|
452 |
+
_compute_u(p, state, group["ns_steps"], self.rank,
|
453 |
self.compute_stream)
|
454 |
|
455 |
def enqueue_scatters(start_idx, chunk_size):
|
|
|
482 |
# Wait the last update_param to finish
|
483 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
484 |
|
485 |
+
@staticmethod
|
486 |
+
def _fused_adamw(
|
487 |
+
params: list[torch.Tensor],
|
488 |
+
grads: list[torch.Tensor],
|
489 |
+
exp_avgs: list[torch.Tensor],
|
490 |
+
exp_avg_sqs: list[torch.Tensor],
|
491 |
+
max_exp_avg_sqs: list[torch.Tensor],
|
492 |
+
state_steps: list[torch.Tensor],
|
493 |
+
amsgrad: bool,
|
494 |
+
beta1: float,
|
495 |
+
beta2: float,
|
496 |
+
lr: Union[float, torch.Tensor],
|
497 |
+
weight_decay: float,
|
498 |
+
eps: float,
|
499 |
+
maximize: bool,
|
500 |
+
) -> None:
|
501 |
+
if not params:
|
502 |
+
return
|
503 |
+
|
504 |
+
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
|
505 |
+
# treating it as a scalar.
|
506 |
+
lr_dict: Optional[DeviceDict] = ({
|
507 |
+
lr.device: lr
|
508 |
+
} if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
|
509 |
+
None)
|
510 |
+
grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
|
511 |
+
[
|
512 |
+
params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
|
513 |
+
state_steps
|
514 |
+
] # type: ignore[list-item]
|
515 |
+
)
|
516 |
+
for (device, _), (
|
517 |
+
(
|
518 |
+
device_params_,
|
519 |
+
device_grads_,
|
520 |
+
device_exp_avgs_,
|
521 |
+
device_exp_avg_sqs_,
|
522 |
+
device_max_exp_avg_sqs,
|
523 |
+
device_state_steps_,
|
524 |
+
),
|
525 |
+
_,
|
526 |
+
) in grouped_tensors.items():
|
527 |
+
device_params = cast(list[torch.Tensor], device_params_)
|
528 |
+
device_grads = cast(list[torch.Tensor], device_grads_)
|
529 |
+
device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
|
530 |
+
device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
|
531 |
+
device_state_steps = cast(list[torch.Tensor], device_state_steps_)
|
532 |
+
|
533 |
+
if lr_dict is not None and device not in lr_dict:
|
534 |
+
lr_dict[device] = lr.to(
|
535 |
+
device=device,
|
536 |
+
non_blocking=True) # type: ignore[union-attr]
|
537 |
+
lr = lr_dict[device]
|
538 |
+
torch._foreach_add_(device_state_steps, 1)
|
539 |
+
func = torch._fused_adamw_
|
540 |
+
func(
|
541 |
+
device_params,
|
542 |
+
device_grads,
|
543 |
+
device_exp_avgs,
|
544 |
+
device_exp_avg_sqs,
|
545 |
+
device_max_exp_avg_sqs, # type: ignore[arg-type]
|
546 |
+
device_state_steps,
|
547 |
+
amsgrad=amsgrad,
|
548 |
+
lr=lr, # type: ignore[arg-type]
|
549 |
+
beta1=beta1,
|
550 |
+
beta2=beta2,
|
551 |
+
weight_decay=weight_decay,
|
552 |
+
eps=eps,
|
553 |
+
maximize=maximize,
|
554 |
+
)
|
555 |
+
|
556 |
def step(self, closure=None):
|
557 |
"""Perform a single optimization step.
|
558 |
|
|
|
629 |
# AdamW backup #
|
630 |
############################
|
631 |
|
632 |
+
params_with_grads = []
|
633 |
+
grads = []
|
634 |
+
moment1 = []
|
635 |
+
moment2 = []
|
636 |
+
max_exp_avg_sqs = []
|
637 |
+
state_steps = []
|
638 |
lr = group["lr"]
|
639 |
beta1, beta2 = group["adamw_betas"]
|
640 |
eps = group["adamw_eps"]
|
|
|
645 |
if g is None:
|
646 |
continue
|
647 |
state = self.state[p]
|
648 |
+
params_with_grads.append(p)
|
649 |
+
grads.append(g)
|
650 |
if "step" not in state:
|
651 |
+
state["step"] = (torch.zeros((),
|
652 |
+
dtype=torch.float32,
|
653 |
+
device=p.device))
|
654 |
state["moment1"] = torch.zeros_like(g)
|
655 |
state["moment2"] = torch.zeros_like(g)
|
656 |
+
moment1.append(state["moment1"])
|
657 |
+
moment2.append(state["moment2"])
|
658 |
+
if not isinstance(state["step"], torch.Tensor):
|
659 |
+
step_tensor = torch.tensor(state["step"],
|
660 |
+
dtype=torch.float32,
|
661 |
+
device=p.device)
|
662 |
+
else:
|
663 |
+
step_tensor = state["step"]
|
664 |
+
state_steps.append(step_tensor)
|
665 |
+
|
666 |
+
self._fused_adamw(
|
667 |
+
params_with_grads,
|
668 |
+
grads,
|
669 |
+
moment1,
|
670 |
+
moment2,
|
671 |
+
max_exp_avg_sqs,
|
672 |
+
state_steps,
|
673 |
+
amsgrad=False,
|
674 |
+
beta1=beta1,
|
675 |
+
beta2=beta2,
|
676 |
+
lr=lr,
|
677 |
+
weight_decay=weight_decay,
|
678 |
+
eps=eps,
|
679 |
+
maximize=False,
|
680 |
+
)
|
681 |
|
682 |
return loss
|
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import torch
|
2 |
-
from . import
|
3 |
-
ops = torch.ops.
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
-
return f"
|
|
|
1 |
import torch
|
2 |
+
from . import _optimizer_ee6ed44_dirty
|
3 |
+
ops = torch.ops._optimizer_ee6ed44_dirty
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
+
return f"_optimizer_ee6ed44_dirty::{op_name}"
|
build/torch27-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_20250911094409.abi3.so → _optimizer_ee6ed44_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5f8bf16b0ae5af74852e8c890183c8c32175886c3d0366cfc776fb3e1ee15906
|
3 |
+
size 1883352
|
build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py
CHANGED
@@ -2,6 +2,7 @@ import logging
|
|
2 |
import math
|
3 |
import types
|
4 |
from dataclasses import dataclass
|
|
|
5 |
|
6 |
import torch
|
7 |
import torch.distributed as dist
|
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
|
|
12 |
|
13 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
14 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
|
|
|
|
15 |
@torch.no_grad()
|
16 |
def _zeropower_via_newtonschulz5(G, steps):
|
17 |
"""
|
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
24 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
25 |
"""
|
26 |
assert len(G.shape) == 2
|
27 |
-
|
28 |
X = G # no manual typecast
|
|
|
29 |
if G.size(0) > G.size(1):
|
30 |
X = X.T
|
31 |
# Ensure spectral norm is at most 1
|
32 |
X = X / (X.norm() + 1e-7)
|
33 |
-
X = X.bfloat16()
|
34 |
# Perform the NS iterations
|
35 |
-
for
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
A = X @ X.T
|
37 |
# B = (
|
38 |
# b * A + c * A @ A
|
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
43 |
|
44 |
if G.size(0) > G.size(1):
|
45 |
X = X.T
|
46 |
-
return X
|
47 |
|
48 |
|
49 |
@dataclass
|
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
65 |
Gather the gradients to worker_rank.
|
66 |
If none_grad is True, free p.grad after the gather.
|
67 |
"""
|
68 |
-
g = p.grad
|
69 |
-
|
70 |
-
if rank == state.worker_rank:
|
71 |
-
num_ranks = dist.get_world_size(group=state.process_group)
|
72 |
-
gather_list = [
|
73 |
-
torch.empty_like(g.to_local()) for _ in range(num_ranks)
|
74 |
-
]
|
75 |
-
else:
|
76 |
-
gather_list = None
|
77 |
-
|
78 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
torch.distributed.gather(
|
80 |
g.to_local(),
|
81 |
dst=state.worker_rank,
|
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
92 |
else:
|
93 |
state.gathered_grad = None
|
94 |
state.gather_event = None
|
|
|
95 |
if none_grad:
|
96 |
# We can safely free p.grad without calling record_stream:
|
97 |
# p.grad.to_local().record_stream(comm_stream)
|
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
104 |
|
105 |
|
106 |
@torch.no_grad()
|
107 |
-
def _compute_u(state, steps, rank, compute_stream):
|
108 |
"""
|
109 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
110 |
"""
|
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
115 |
compute_stream.wait_event(state.gather_event)
|
116 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
117 |
state.computed_u = u
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
|
124 |
|
125 |
@torch.no_grad()
|
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
|
|
129 |
"""
|
130 |
|
131 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
132 |
if rank == state.worker_rank:
|
133 |
num_ranks = dist.get_world_size(group=state.process_group)
|
134 |
-
if state.compute_event is None:
|
135 |
-
raise RuntimeError("Compute event must be set before scatter.")
|
136 |
-
comm_stream.wait_event(state.compute_event)
|
137 |
-
|
138 |
# Clear the gathered gradient to free memory
|
139 |
state.gathered_grad = None
|
140 |
|
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
|
|
144 |
else:
|
145 |
scatter_list = None
|
146 |
|
147 |
-
u_received = torch.empty_like(p.to_local())
|
148 |
torch.distributed.scatter(
|
149 |
-
|
150 |
scatter_list=scatter_list,
|
151 |
src=state.worker_rank,
|
152 |
group=state.process_group,
|
153 |
)
|
154 |
-
u_dtensor = DTensor.from_local(
|
155 |
-
u_received,
|
156 |
-
placements=p.placements,
|
157 |
-
device_mesh=p.device_mesh,
|
158 |
-
)
|
159 |
-
|
160 |
-
state.scattered_u = u_dtensor
|
161 |
state.scatter_event = torch.cuda.Event()
|
162 |
state.scatter_event.record()
|
|
|
163 |
|
164 |
|
165 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
172 |
if state.scatter_event is None:
|
173 |
raise RuntimeError("Scatter event must be set before update")
|
174 |
compute_stream.wait_event(state.scatter_event)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
if rank == state.worker_rank:
|
176 |
# Free computed_u
|
177 |
state.computed_u = None
|
178 |
|
179 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
|
|
|
|
180 |
|
181 |
|
182 |
def default_is_muon(name, x):
|
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
|
|
375 |
else:
|
376 |
g = buf
|
377 |
|
378 |
-
u = _zeropower_via_newtonschulz5(g,
|
|
|
379 |
|
380 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
381 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
|
|
433 |
def enqueue_computes(start_idx, chunk_size):
|
434 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
435 |
state = param_to_state[id(p)]
|
436 |
-
_compute_u(state, group["ns_steps"], self.rank,
|
437 |
self.compute_stream)
|
438 |
|
439 |
def enqueue_scatters(start_idx, chunk_size):
|
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
|
|
466 |
# Wait the last update_param to finish
|
467 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
469 |
def step(self, closure=None):
|
470 |
"""Perform a single optimization step.
|
471 |
|
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
|
|
542 |
# AdamW backup #
|
543 |
############################
|
544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
545 |
lr = group["lr"]
|
546 |
beta1, beta2 = group["adamw_betas"]
|
547 |
eps = group["adamw_eps"]
|
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
|
|
552 |
if g is None:
|
553 |
continue
|
554 |
state = self.state[p]
|
|
|
|
|
555 |
if "step" not in state:
|
556 |
-
state["step"] =
|
|
|
|
|
557 |
state["moment1"] = torch.zeros_like(g)
|
558 |
state["moment2"] = torch.zeros_like(g)
|
559 |
-
state["
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
573 |
|
574 |
return loss
|
|
|
2 |
import math
|
3 |
import types
|
4 |
from dataclasses import dataclass
|
5 |
+
from typing import Optional, Union, cast
|
6 |
|
7 |
import torch
|
8 |
import torch.distributed as dist
|
|
|
13 |
|
14 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
15 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
16 |
+
# Muon's Newton–Schulz iteration causes high variance in singular values
|
17 |
+
# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
|
18 |
@torch.no_grad()
|
19 |
def _zeropower_via_newtonschulz5(G, steps):
|
20 |
"""
|
|
|
27 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
28 |
"""
|
29 |
assert len(G.shape) == 2
|
30 |
+
assert G.dtype == torch.bfloat16
|
31 |
X = G # no manual typecast
|
32 |
+
|
33 |
if G.size(0) > G.size(1):
|
34 |
X = X.T
|
35 |
# Ensure spectral norm is at most 1
|
36 |
X = X / (X.norm() + 1e-7)
|
|
|
37 |
# Perform the NS iterations
|
38 |
+
for a, b, c in [
|
39 |
+
(4.0848, -6.8946, 2.9270),
|
40 |
+
(3.9505, -6.3029, 2.6377),
|
41 |
+
(3.7418, -5.5913, 2.3037),
|
42 |
+
(2.8769, -3.1427, 1.2046),
|
43 |
+
(2.8366, -3.0525, 1.2012),
|
44 |
+
]:
|
45 |
A = X @ X.T
|
46 |
# B = (
|
47 |
# b * A + c * A @ A
|
|
|
52 |
|
53 |
if G.size(0) > G.size(1):
|
54 |
X = X.T
|
55 |
+
return X
|
56 |
|
57 |
|
58 |
@dataclass
|
|
|
74 |
Gather the gradients to worker_rank.
|
75 |
If none_grad is True, free p.grad after the gather.
|
76 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
with torch.cuda.stream(comm_stream):
|
78 |
+
g = p.grad
|
79 |
+
|
80 |
+
if rank == state.worker_rank:
|
81 |
+
num_ranks = dist.get_world_size(group=state.process_group)
|
82 |
+
gather_list = [
|
83 |
+
torch.empty_like(g.to_local(), dtype=torch.bfloat16)
|
84 |
+
for _ in range(num_ranks)
|
85 |
+
]
|
86 |
+
else:
|
87 |
+
gather_list = None
|
88 |
+
|
89 |
+
g = g.to(torch.bfloat16)
|
90 |
torch.distributed.gather(
|
91 |
g.to_local(),
|
92 |
dst=state.worker_rank,
|
|
|
103 |
else:
|
104 |
state.gathered_grad = None
|
105 |
state.gather_event = None
|
106 |
+
gather_list = None
|
107 |
if none_grad:
|
108 |
# We can safely free p.grad without calling record_stream:
|
109 |
# p.grad.to_local().record_stream(comm_stream)
|
|
|
116 |
|
117 |
|
118 |
@torch.no_grad()
|
119 |
+
def _compute_u(p, state, steps, rank, compute_stream):
|
120 |
"""
|
121 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
122 |
"""
|
|
|
127 |
compute_stream.wait_event(state.gather_event)
|
128 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
129 |
state.computed_u = u
|
130 |
+
state.scattered_u = torch.empty_like(p.to_local(),
|
131 |
+
dtype=torch.bfloat16)
|
132 |
+
state.compute_event = torch.cuda.Event()
|
133 |
+
state.compute_event.record()
|
134 |
+
u = None
|
135 |
|
136 |
|
137 |
@torch.no_grad()
|
|
|
141 |
"""
|
142 |
|
143 |
with torch.cuda.stream(comm_stream):
|
144 |
+
if state.compute_event is None:
|
145 |
+
raise RuntimeError("Compute event must be set before scatter.")
|
146 |
+
comm_stream.wait_event(state.compute_event)
|
147 |
+
|
148 |
if rank == state.worker_rank:
|
149 |
num_ranks = dist.get_world_size(group=state.process_group)
|
|
|
|
|
|
|
|
|
150 |
# Clear the gathered gradient to free memory
|
151 |
state.gathered_grad = None
|
152 |
|
|
|
156 |
else:
|
157 |
scatter_list = None
|
158 |
|
|
|
159 |
torch.distributed.scatter(
|
160 |
+
state.scattered_u,
|
161 |
scatter_list=scatter_list,
|
162 |
src=state.worker_rank,
|
163 |
group=state.process_group,
|
164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
state.scatter_event = torch.cuda.Event()
|
166 |
state.scatter_event.record()
|
167 |
+
scatter_list = None
|
168 |
|
169 |
|
170 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
|
177 |
if state.scatter_event is None:
|
178 |
raise RuntimeError("Scatter event must be set before update")
|
179 |
compute_stream.wait_event(state.scatter_event)
|
180 |
+
u_dtensor = DTensor.from_local(
|
181 |
+
state.scattered_u,
|
182 |
+
placements=p.placements,
|
183 |
+
device_mesh=p.device_mesh,
|
184 |
+
)
|
185 |
+
|
186 |
+
state.scattered_u = u_dtensor
|
187 |
+
|
188 |
if rank == state.worker_rank:
|
189 |
# Free computed_u
|
190 |
state.computed_u = None
|
191 |
|
192 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
193 |
+
state.scattered_u = None
|
194 |
+
u_dtensor = None
|
195 |
|
196 |
|
197 |
def default_is_muon(name, x):
|
|
|
390 |
else:
|
391 |
g = buf
|
392 |
|
393 |
+
u = _zeropower_via_newtonschulz5(g.bfloat16(),
|
394 |
+
steps=group["ns_steps"])
|
395 |
|
396 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
397 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
|
449 |
def enqueue_computes(start_idx, chunk_size):
|
450 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
451 |
state = param_to_state[id(p)]
|
452 |
+
_compute_u(p, state, group["ns_steps"], self.rank,
|
453 |
self.compute_stream)
|
454 |
|
455 |
def enqueue_scatters(start_idx, chunk_size):
|
|
|
482 |
# Wait the last update_param to finish
|
483 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
484 |
|
485 |
+
@staticmethod
|
486 |
+
def _fused_adamw(
|
487 |
+
params: list[torch.Tensor],
|
488 |
+
grads: list[torch.Tensor],
|
489 |
+
exp_avgs: list[torch.Tensor],
|
490 |
+
exp_avg_sqs: list[torch.Tensor],
|
491 |
+
max_exp_avg_sqs: list[torch.Tensor],
|
492 |
+
state_steps: list[torch.Tensor],
|
493 |
+
amsgrad: bool,
|
494 |
+
beta1: float,
|
495 |
+
beta2: float,
|
496 |
+
lr: Union[float, torch.Tensor],
|
497 |
+
weight_decay: float,
|
498 |
+
eps: float,
|
499 |
+
maximize: bool,
|
500 |
+
) -> None:
|
501 |
+
if not params:
|
502 |
+
return
|
503 |
+
|
504 |
+
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
|
505 |
+
# treating it as a scalar.
|
506 |
+
lr_dict: Optional[DeviceDict] = ({
|
507 |
+
lr.device: lr
|
508 |
+
} if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
|
509 |
+
None)
|
510 |
+
grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
|
511 |
+
[
|
512 |
+
params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
|
513 |
+
state_steps
|
514 |
+
] # type: ignore[list-item]
|
515 |
+
)
|
516 |
+
for (device, _), (
|
517 |
+
(
|
518 |
+
device_params_,
|
519 |
+
device_grads_,
|
520 |
+
device_exp_avgs_,
|
521 |
+
device_exp_avg_sqs_,
|
522 |
+
device_max_exp_avg_sqs,
|
523 |
+
device_state_steps_,
|
524 |
+
),
|
525 |
+
_,
|
526 |
+
) in grouped_tensors.items():
|
527 |
+
device_params = cast(list[torch.Tensor], device_params_)
|
528 |
+
device_grads = cast(list[torch.Tensor], device_grads_)
|
529 |
+
device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
|
530 |
+
device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
|
531 |
+
device_state_steps = cast(list[torch.Tensor], device_state_steps_)
|
532 |
+
|
533 |
+
if lr_dict is not None and device not in lr_dict:
|
534 |
+
lr_dict[device] = lr.to(
|
535 |
+
device=device,
|
536 |
+
non_blocking=True) # type: ignore[union-attr]
|
537 |
+
lr = lr_dict[device]
|
538 |
+
torch._foreach_add_(device_state_steps, 1)
|
539 |
+
func = torch._fused_adamw_
|
540 |
+
func(
|
541 |
+
device_params,
|
542 |
+
device_grads,
|
543 |
+
device_exp_avgs,
|
544 |
+
device_exp_avg_sqs,
|
545 |
+
device_max_exp_avg_sqs, # type: ignore[arg-type]
|
546 |
+
device_state_steps,
|
547 |
+
amsgrad=amsgrad,
|
548 |
+
lr=lr, # type: ignore[arg-type]
|
549 |
+
beta1=beta1,
|
550 |
+
beta2=beta2,
|
551 |
+
weight_decay=weight_decay,
|
552 |
+
eps=eps,
|
553 |
+
maximize=maximize,
|
554 |
+
)
|
555 |
+
|
556 |
def step(self, closure=None):
|
557 |
"""Perform a single optimization step.
|
558 |
|
|
|
629 |
# AdamW backup #
|
630 |
############################
|
631 |
|
632 |
+
params_with_grads = []
|
633 |
+
grads = []
|
634 |
+
moment1 = []
|
635 |
+
moment2 = []
|
636 |
+
max_exp_avg_sqs = []
|
637 |
+
state_steps = []
|
638 |
lr = group["lr"]
|
639 |
beta1, beta2 = group["adamw_betas"]
|
640 |
eps = group["adamw_eps"]
|
|
|
645 |
if g is None:
|
646 |
continue
|
647 |
state = self.state[p]
|
648 |
+
params_with_grads.append(p)
|
649 |
+
grads.append(g)
|
650 |
if "step" not in state:
|
651 |
+
state["step"] = (torch.zeros((),
|
652 |
+
dtype=torch.float32,
|
653 |
+
device=p.device))
|
654 |
state["moment1"] = torch.zeros_like(g)
|
655 |
state["moment2"] = torch.zeros_like(g)
|
656 |
+
moment1.append(state["moment1"])
|
657 |
+
moment2.append(state["moment2"])
|
658 |
+
if not isinstance(state["step"], torch.Tensor):
|
659 |
+
step_tensor = torch.tensor(state["step"],
|
660 |
+
dtype=torch.float32,
|
661 |
+
device=p.device)
|
662 |
+
else:
|
663 |
+
step_tensor = state["step"]
|
664 |
+
state_steps.append(step_tensor)
|
665 |
+
|
666 |
+
self._fused_adamw(
|
667 |
+
params_with_grads,
|
668 |
+
grads,
|
669 |
+
moment1,
|
670 |
+
moment2,
|
671 |
+
max_exp_avg_sqs,
|
672 |
+
state_steps,
|
673 |
+
amsgrad=False,
|
674 |
+
beta1=beta1,
|
675 |
+
beta2=beta2,
|
676 |
+
lr=lr,
|
677 |
+
weight_decay=weight_decay,
|
678 |
+
eps=eps,
|
679 |
+
maximize=False,
|
680 |
+
)
|
681 |
|
682 |
return loss
|
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import torch
|
2 |
-
from . import
|
3 |
-
ops = torch.ops.
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
-
return f"
|
|
|
1 |
import torch
|
2 |
+
from . import _optimizer_ee6ed44_dirty
|
3 |
+
ops = torch.ops._optimizer_ee6ed44_dirty
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
+
return f"_optimizer_ee6ed44_dirty::{op_name}"
|
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_20250911094409.abi3.so → _optimizer_ee6ed44_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d50267ec23db9512ae1d82c99012901d58e50dee9bf34346702561a5d3e6d9e7
|
3 |
+
size 1749840
|
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py
CHANGED
@@ -2,6 +2,7 @@ import logging
|
|
2 |
import math
|
3 |
import types
|
4 |
from dataclasses import dataclass
|
|
|
5 |
|
6 |
import torch
|
7 |
import torch.distributed as dist
|
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
|
|
12 |
|
13 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
14 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
|
|
|
|
15 |
@torch.no_grad()
|
16 |
def _zeropower_via_newtonschulz5(G, steps):
|
17 |
"""
|
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
24 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
25 |
"""
|
26 |
assert len(G.shape) == 2
|
27 |
-
|
28 |
X = G # no manual typecast
|
|
|
29 |
if G.size(0) > G.size(1):
|
30 |
X = X.T
|
31 |
# Ensure spectral norm is at most 1
|
32 |
X = X / (X.norm() + 1e-7)
|
33 |
-
X = X.bfloat16()
|
34 |
# Perform the NS iterations
|
35 |
-
for
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
A = X @ X.T
|
37 |
# B = (
|
38 |
# b * A + c * A @ A
|
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
43 |
|
44 |
if G.size(0) > G.size(1):
|
45 |
X = X.T
|
46 |
-
return X
|
47 |
|
48 |
|
49 |
@dataclass
|
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
65 |
Gather the gradients to worker_rank.
|
66 |
If none_grad is True, free p.grad after the gather.
|
67 |
"""
|
68 |
-
g = p.grad
|
69 |
-
|
70 |
-
if rank == state.worker_rank:
|
71 |
-
num_ranks = dist.get_world_size(group=state.process_group)
|
72 |
-
gather_list = [
|
73 |
-
torch.empty_like(g.to_local()) for _ in range(num_ranks)
|
74 |
-
]
|
75 |
-
else:
|
76 |
-
gather_list = None
|
77 |
-
|
78 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
torch.distributed.gather(
|
80 |
g.to_local(),
|
81 |
dst=state.worker_rank,
|
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
92 |
else:
|
93 |
state.gathered_grad = None
|
94 |
state.gather_event = None
|
|
|
95 |
if none_grad:
|
96 |
# We can safely free p.grad without calling record_stream:
|
97 |
# p.grad.to_local().record_stream(comm_stream)
|
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
104 |
|
105 |
|
106 |
@torch.no_grad()
|
107 |
-
def _compute_u(state, steps, rank, compute_stream):
|
108 |
"""
|
109 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
110 |
"""
|
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
115 |
compute_stream.wait_event(state.gather_event)
|
116 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
117 |
state.computed_u = u
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
|
124 |
|
125 |
@torch.no_grad()
|
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
|
|
129 |
"""
|
130 |
|
131 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
132 |
if rank == state.worker_rank:
|
133 |
num_ranks = dist.get_world_size(group=state.process_group)
|
134 |
-
if state.compute_event is None:
|
135 |
-
raise RuntimeError("Compute event must be set before scatter.")
|
136 |
-
comm_stream.wait_event(state.compute_event)
|
137 |
-
|
138 |
# Clear the gathered gradient to free memory
|
139 |
state.gathered_grad = None
|
140 |
|
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
|
|
144 |
else:
|
145 |
scatter_list = None
|
146 |
|
147 |
-
u_received = torch.empty_like(p.to_local())
|
148 |
torch.distributed.scatter(
|
149 |
-
|
150 |
scatter_list=scatter_list,
|
151 |
src=state.worker_rank,
|
152 |
group=state.process_group,
|
153 |
)
|
154 |
-
u_dtensor = DTensor.from_local(
|
155 |
-
u_received,
|
156 |
-
placements=p.placements,
|
157 |
-
device_mesh=p.device_mesh,
|
158 |
-
)
|
159 |
-
|
160 |
-
state.scattered_u = u_dtensor
|
161 |
state.scatter_event = torch.cuda.Event()
|
162 |
state.scatter_event.record()
|
|
|
163 |
|
164 |
|
165 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
172 |
if state.scatter_event is None:
|
173 |
raise RuntimeError("Scatter event must be set before update")
|
174 |
compute_stream.wait_event(state.scatter_event)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
if rank == state.worker_rank:
|
176 |
# Free computed_u
|
177 |
state.computed_u = None
|
178 |
|
179 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
|
|
|
|
180 |
|
181 |
|
182 |
def default_is_muon(name, x):
|
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
|
|
375 |
else:
|
376 |
g = buf
|
377 |
|
378 |
-
u = _zeropower_via_newtonschulz5(g,
|
|
|
379 |
|
380 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
381 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
|
|
433 |
def enqueue_computes(start_idx, chunk_size):
|
434 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
435 |
state = param_to_state[id(p)]
|
436 |
-
_compute_u(state, group["ns_steps"], self.rank,
|
437 |
self.compute_stream)
|
438 |
|
439 |
def enqueue_scatters(start_idx, chunk_size):
|
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
|
|
466 |
# Wait the last update_param to finish
|
467 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
469 |
def step(self, closure=None):
|
470 |
"""Perform a single optimization step.
|
471 |
|
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
|
|
542 |
# AdamW backup #
|
543 |
############################
|
544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
545 |
lr = group["lr"]
|
546 |
beta1, beta2 = group["adamw_betas"]
|
547 |
eps = group["adamw_eps"]
|
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
|
|
552 |
if g is None:
|
553 |
continue
|
554 |
state = self.state[p]
|
|
|
|
|
555 |
if "step" not in state:
|
556 |
-
state["step"] =
|
|
|
|
|
557 |
state["moment1"] = torch.zeros_like(g)
|
558 |
state["moment2"] = torch.zeros_like(g)
|
559 |
-
state["
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
573 |
|
574 |
return loss
|
|
|
2 |
import math
|
3 |
import types
|
4 |
from dataclasses import dataclass
|
5 |
+
from typing import Optional, Union, cast
|
6 |
|
7 |
import torch
|
8 |
import torch.distributed as dist
|
|
|
13 |
|
14 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
15 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
16 |
+
# Muon's Newton–Schulz iteration causes high variance in singular values
|
17 |
+
# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
|
18 |
@torch.no_grad()
|
19 |
def _zeropower_via_newtonschulz5(G, steps):
|
20 |
"""
|
|
|
27 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
28 |
"""
|
29 |
assert len(G.shape) == 2
|
30 |
+
assert G.dtype == torch.bfloat16
|
31 |
X = G # no manual typecast
|
32 |
+
|
33 |
if G.size(0) > G.size(1):
|
34 |
X = X.T
|
35 |
# Ensure spectral norm is at most 1
|
36 |
X = X / (X.norm() + 1e-7)
|
|
|
37 |
# Perform the NS iterations
|
38 |
+
for a, b, c in [
|
39 |
+
(4.0848, -6.8946, 2.9270),
|
40 |
+
(3.9505, -6.3029, 2.6377),
|
41 |
+
(3.7418, -5.5913, 2.3037),
|
42 |
+
(2.8769, -3.1427, 1.2046),
|
43 |
+
(2.8366, -3.0525, 1.2012),
|
44 |
+
]:
|
45 |
A = X @ X.T
|
46 |
# B = (
|
47 |
# b * A + c * A @ A
|
|
|
52 |
|
53 |
if G.size(0) > G.size(1):
|
54 |
X = X.T
|
55 |
+
return X
|
56 |
|
57 |
|
58 |
@dataclass
|
|
|
74 |
Gather the gradients to worker_rank.
|
75 |
If none_grad is True, free p.grad after the gather.
|
76 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
with torch.cuda.stream(comm_stream):
|
78 |
+
g = p.grad
|
79 |
+
|
80 |
+
if rank == state.worker_rank:
|
81 |
+
num_ranks = dist.get_world_size(group=state.process_group)
|
82 |
+
gather_list = [
|
83 |
+
torch.empty_like(g.to_local(), dtype=torch.bfloat16)
|
84 |
+
for _ in range(num_ranks)
|
85 |
+
]
|
86 |
+
else:
|
87 |
+
gather_list = None
|
88 |
+
|
89 |
+
g = g.to(torch.bfloat16)
|
90 |
torch.distributed.gather(
|
91 |
g.to_local(),
|
92 |
dst=state.worker_rank,
|
|
|
103 |
else:
|
104 |
state.gathered_grad = None
|
105 |
state.gather_event = None
|
106 |
+
gather_list = None
|
107 |
if none_grad:
|
108 |
# We can safely free p.grad without calling record_stream:
|
109 |
# p.grad.to_local().record_stream(comm_stream)
|
|
|
116 |
|
117 |
|
118 |
@torch.no_grad()
|
119 |
+
def _compute_u(p, state, steps, rank, compute_stream):
|
120 |
"""
|
121 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
122 |
"""
|
|
|
127 |
compute_stream.wait_event(state.gather_event)
|
128 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
129 |
state.computed_u = u
|
130 |
+
state.scattered_u = torch.empty_like(p.to_local(),
|
131 |
+
dtype=torch.bfloat16)
|
132 |
+
state.compute_event = torch.cuda.Event()
|
133 |
+
state.compute_event.record()
|
134 |
+
u = None
|
135 |
|
136 |
|
137 |
@torch.no_grad()
|
|
|
141 |
"""
|
142 |
|
143 |
with torch.cuda.stream(comm_stream):
|
144 |
+
if state.compute_event is None:
|
145 |
+
raise RuntimeError("Compute event must be set before scatter.")
|
146 |
+
comm_stream.wait_event(state.compute_event)
|
147 |
+
|
148 |
if rank == state.worker_rank:
|
149 |
num_ranks = dist.get_world_size(group=state.process_group)
|
|
|
|
|
|
|
|
|
150 |
# Clear the gathered gradient to free memory
|
151 |
state.gathered_grad = None
|
152 |
|
|
|
156 |
else:
|
157 |
scatter_list = None
|
158 |
|
|
|
159 |
torch.distributed.scatter(
|
160 |
+
state.scattered_u,
|
161 |
scatter_list=scatter_list,
|
162 |
src=state.worker_rank,
|
163 |
group=state.process_group,
|
164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
state.scatter_event = torch.cuda.Event()
|
166 |
state.scatter_event.record()
|
167 |
+
scatter_list = None
|
168 |
|
169 |
|
170 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
|
177 |
if state.scatter_event is None:
|
178 |
raise RuntimeError("Scatter event must be set before update")
|
179 |
compute_stream.wait_event(state.scatter_event)
|
180 |
+
u_dtensor = DTensor.from_local(
|
181 |
+
state.scattered_u,
|
182 |
+
placements=p.placements,
|
183 |
+
device_mesh=p.device_mesh,
|
184 |
+
)
|
185 |
+
|
186 |
+
state.scattered_u = u_dtensor
|
187 |
+
|
188 |
if rank == state.worker_rank:
|
189 |
# Free computed_u
|
190 |
state.computed_u = None
|
191 |
|
192 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
193 |
+
state.scattered_u = None
|
194 |
+
u_dtensor = None
|
195 |
|
196 |
|
197 |
def default_is_muon(name, x):
|
|
|
390 |
else:
|
391 |
g = buf
|
392 |
|
393 |
+
u = _zeropower_via_newtonschulz5(g.bfloat16(),
|
394 |
+
steps=group["ns_steps"])
|
395 |
|
396 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
397 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
|
449 |
def enqueue_computes(start_idx, chunk_size):
|
450 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
451 |
state = param_to_state[id(p)]
|
452 |
+
_compute_u(p, state, group["ns_steps"], self.rank,
|
453 |
self.compute_stream)
|
454 |
|
455 |
def enqueue_scatters(start_idx, chunk_size):
|
|
|
482 |
# Wait the last update_param to finish
|
483 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
484 |
|
485 |
+
@staticmethod
|
486 |
+
def _fused_adamw(
|
487 |
+
params: list[torch.Tensor],
|
488 |
+
grads: list[torch.Tensor],
|
489 |
+
exp_avgs: list[torch.Tensor],
|
490 |
+
exp_avg_sqs: list[torch.Tensor],
|
491 |
+
max_exp_avg_sqs: list[torch.Tensor],
|
492 |
+
state_steps: list[torch.Tensor],
|
493 |
+
amsgrad: bool,
|
494 |
+
beta1: float,
|
495 |
+
beta2: float,
|
496 |
+
lr: Union[float, torch.Tensor],
|
497 |
+
weight_decay: float,
|
498 |
+
eps: float,
|
499 |
+
maximize: bool,
|
500 |
+
) -> None:
|
501 |
+
if not params:
|
502 |
+
return
|
503 |
+
|
504 |
+
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
|
505 |
+
# treating it as a scalar.
|
506 |
+
lr_dict: Optional[DeviceDict] = ({
|
507 |
+
lr.device: lr
|
508 |
+
} if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
|
509 |
+
None)
|
510 |
+
grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
|
511 |
+
[
|
512 |
+
params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
|
513 |
+
state_steps
|
514 |
+
] # type: ignore[list-item]
|
515 |
+
)
|
516 |
+
for (device, _), (
|
517 |
+
(
|
518 |
+
device_params_,
|
519 |
+
device_grads_,
|
520 |
+
device_exp_avgs_,
|
521 |
+
device_exp_avg_sqs_,
|
522 |
+
device_max_exp_avg_sqs,
|
523 |
+
device_state_steps_,
|
524 |
+
),
|
525 |
+
_,
|
526 |
+
) in grouped_tensors.items():
|
527 |
+
device_params = cast(list[torch.Tensor], device_params_)
|
528 |
+
device_grads = cast(list[torch.Tensor], device_grads_)
|
529 |
+
device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
|
530 |
+
device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
|
531 |
+
device_state_steps = cast(list[torch.Tensor], device_state_steps_)
|
532 |
+
|
533 |
+
if lr_dict is not None and device not in lr_dict:
|
534 |
+
lr_dict[device] = lr.to(
|
535 |
+
device=device,
|
536 |
+
non_blocking=True) # type: ignore[union-attr]
|
537 |
+
lr = lr_dict[device]
|
538 |
+
torch._foreach_add_(device_state_steps, 1)
|
539 |
+
func = torch._fused_adamw_
|
540 |
+
func(
|
541 |
+
device_params,
|
542 |
+
device_grads,
|
543 |
+
device_exp_avgs,
|
544 |
+
device_exp_avg_sqs,
|
545 |
+
device_max_exp_avg_sqs, # type: ignore[arg-type]
|
546 |
+
device_state_steps,
|
547 |
+
amsgrad=amsgrad,
|
548 |
+
lr=lr, # type: ignore[arg-type]
|
549 |
+
beta1=beta1,
|
550 |
+
beta2=beta2,
|
551 |
+
weight_decay=weight_decay,
|
552 |
+
eps=eps,
|
553 |
+
maximize=maximize,
|
554 |
+
)
|
555 |
+
|
556 |
def step(self, closure=None):
|
557 |
"""Perform a single optimization step.
|
558 |
|
|
|
629 |
# AdamW backup #
|
630 |
############################
|
631 |
|
632 |
+
params_with_grads = []
|
633 |
+
grads = []
|
634 |
+
moment1 = []
|
635 |
+
moment2 = []
|
636 |
+
max_exp_avg_sqs = []
|
637 |
+
state_steps = []
|
638 |
lr = group["lr"]
|
639 |
beta1, beta2 = group["adamw_betas"]
|
640 |
eps = group["adamw_eps"]
|
|
|
645 |
if g is None:
|
646 |
continue
|
647 |
state = self.state[p]
|
648 |
+
params_with_grads.append(p)
|
649 |
+
grads.append(g)
|
650 |
if "step" not in state:
|
651 |
+
state["step"] = (torch.zeros((),
|
652 |
+
dtype=torch.float32,
|
653 |
+
device=p.device))
|
654 |
state["moment1"] = torch.zeros_like(g)
|
655 |
state["moment2"] = torch.zeros_like(g)
|
656 |
+
moment1.append(state["moment1"])
|
657 |
+
moment2.append(state["moment2"])
|
658 |
+
if not isinstance(state["step"], torch.Tensor):
|
659 |
+
step_tensor = torch.tensor(state["step"],
|
660 |
+
dtype=torch.float32,
|
661 |
+
device=p.device)
|
662 |
+
else:
|
663 |
+
step_tensor = state["step"]
|
664 |
+
state_steps.append(step_tensor)
|
665 |
+
|
666 |
+
self._fused_adamw(
|
667 |
+
params_with_grads,
|
668 |
+
grads,
|
669 |
+
moment1,
|
670 |
+
moment2,
|
671 |
+
max_exp_avg_sqs,
|
672 |
+
state_steps,
|
673 |
+
amsgrad=False,
|
674 |
+
beta1=beta1,
|
675 |
+
beta2=beta2,
|
676 |
+
lr=lr,
|
677 |
+
weight_decay=weight_decay,
|
678 |
+
eps=eps,
|
679 |
+
maximize=False,
|
680 |
+
)
|
681 |
|
682 |
return loss
|
build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import torch
|
2 |
-
from . import
|
3 |
-
ops = torch.ops.
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
-
return f"
|
|
|
1 |
import torch
|
2 |
+
from . import _optimizer_ee6ed44_dirty
|
3 |
+
ops = torch.ops._optimizer_ee6ed44_dirty
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
+
return f"_optimizer_ee6ed44_dirty::{op_name}"
|
build/{torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so → torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 1824264
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:80ce6b0d62167a8ea10b6e2a1f90df70aa108997570c0ed210f458debd26f32f
|
3 |
size 1824264
|
build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED
@@ -2,6 +2,7 @@ import logging
|
|
2 |
import math
|
3 |
import types
|
4 |
from dataclasses import dataclass
|
|
|
5 |
|
6 |
import torch
|
7 |
import torch.distributed as dist
|
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
|
|
12 |
|
13 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
14 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
|
|
|
|
15 |
@torch.no_grad()
|
16 |
def _zeropower_via_newtonschulz5(G, steps):
|
17 |
"""
|
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
24 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
25 |
"""
|
26 |
assert len(G.shape) == 2
|
27 |
-
|
28 |
X = G # no manual typecast
|
|
|
29 |
if G.size(0) > G.size(1):
|
30 |
X = X.T
|
31 |
# Ensure spectral norm is at most 1
|
32 |
X = X / (X.norm() + 1e-7)
|
33 |
-
X = X.bfloat16()
|
34 |
# Perform the NS iterations
|
35 |
-
for
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
A = X @ X.T
|
37 |
# B = (
|
38 |
# b * A + c * A @ A
|
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
43 |
|
44 |
if G.size(0) > G.size(1):
|
45 |
X = X.T
|
46 |
-
return X
|
47 |
|
48 |
|
49 |
@dataclass
|
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
65 |
Gather the gradients to worker_rank.
|
66 |
If none_grad is True, free p.grad after the gather.
|
67 |
"""
|
68 |
-
g = p.grad
|
69 |
-
|
70 |
-
if rank == state.worker_rank:
|
71 |
-
num_ranks = dist.get_world_size(group=state.process_group)
|
72 |
-
gather_list = [
|
73 |
-
torch.empty_like(g.to_local()) for _ in range(num_ranks)
|
74 |
-
]
|
75 |
-
else:
|
76 |
-
gather_list = None
|
77 |
-
|
78 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
torch.distributed.gather(
|
80 |
g.to_local(),
|
81 |
dst=state.worker_rank,
|
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
92 |
else:
|
93 |
state.gathered_grad = None
|
94 |
state.gather_event = None
|
|
|
95 |
if none_grad:
|
96 |
# We can safely free p.grad without calling record_stream:
|
97 |
# p.grad.to_local().record_stream(comm_stream)
|
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
104 |
|
105 |
|
106 |
@torch.no_grad()
|
107 |
-
def _compute_u(state, steps, rank, compute_stream):
|
108 |
"""
|
109 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
110 |
"""
|
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
115 |
compute_stream.wait_event(state.gather_event)
|
116 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
117 |
state.computed_u = u
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
|
124 |
|
125 |
@torch.no_grad()
|
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
|
|
129 |
"""
|
130 |
|
131 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
132 |
if rank == state.worker_rank:
|
133 |
num_ranks = dist.get_world_size(group=state.process_group)
|
134 |
-
if state.compute_event is None:
|
135 |
-
raise RuntimeError("Compute event must be set before scatter.")
|
136 |
-
comm_stream.wait_event(state.compute_event)
|
137 |
-
|
138 |
# Clear the gathered gradient to free memory
|
139 |
state.gathered_grad = None
|
140 |
|
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
|
|
144 |
else:
|
145 |
scatter_list = None
|
146 |
|
147 |
-
u_received = torch.empty_like(p.to_local())
|
148 |
torch.distributed.scatter(
|
149 |
-
|
150 |
scatter_list=scatter_list,
|
151 |
src=state.worker_rank,
|
152 |
group=state.process_group,
|
153 |
)
|
154 |
-
u_dtensor = DTensor.from_local(
|
155 |
-
u_received,
|
156 |
-
placements=p.placements,
|
157 |
-
device_mesh=p.device_mesh,
|
158 |
-
)
|
159 |
-
|
160 |
-
state.scattered_u = u_dtensor
|
161 |
state.scatter_event = torch.cuda.Event()
|
162 |
state.scatter_event.record()
|
|
|
163 |
|
164 |
|
165 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
172 |
if state.scatter_event is None:
|
173 |
raise RuntimeError("Scatter event must be set before update")
|
174 |
compute_stream.wait_event(state.scatter_event)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
if rank == state.worker_rank:
|
176 |
# Free computed_u
|
177 |
state.computed_u = None
|
178 |
|
179 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
|
|
|
|
180 |
|
181 |
|
182 |
def default_is_muon(name, x):
|
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
|
|
375 |
else:
|
376 |
g = buf
|
377 |
|
378 |
-
u = _zeropower_via_newtonschulz5(g,
|
|
|
379 |
|
380 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
381 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
|
|
433 |
def enqueue_computes(start_idx, chunk_size):
|
434 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
435 |
state = param_to_state[id(p)]
|
436 |
-
_compute_u(state, group["ns_steps"], self.rank,
|
437 |
self.compute_stream)
|
438 |
|
439 |
def enqueue_scatters(start_idx, chunk_size):
|
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
|
|
466 |
# Wait the last update_param to finish
|
467 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
469 |
def step(self, closure=None):
|
470 |
"""Perform a single optimization step.
|
471 |
|
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
|
|
542 |
# AdamW backup #
|
543 |
############################
|
544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
545 |
lr = group["lr"]
|
546 |
beta1, beta2 = group["adamw_betas"]
|
547 |
eps = group["adamw_eps"]
|
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
|
|
552 |
if g is None:
|
553 |
continue
|
554 |
state = self.state[p]
|
|
|
|
|
555 |
if "step" not in state:
|
556 |
-
state["step"] =
|
|
|
|
|
557 |
state["moment1"] = torch.zeros_like(g)
|
558 |
state["moment2"] = torch.zeros_like(g)
|
559 |
-
state["
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
573 |
|
574 |
return loss
|
|
|
2 |
import math
|
3 |
import types
|
4 |
from dataclasses import dataclass
|
5 |
+
from typing import Optional, Union, cast
|
6 |
|
7 |
import torch
|
8 |
import torch.distributed as dist
|
|
|
13 |
|
14 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
15 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
16 |
+
# Muon's Newton–Schulz iteration causes high variance in singular values
|
17 |
+
# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
|
18 |
@torch.no_grad()
|
19 |
def _zeropower_via_newtonschulz5(G, steps):
|
20 |
"""
|
|
|
27 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
28 |
"""
|
29 |
assert len(G.shape) == 2
|
30 |
+
assert G.dtype == torch.bfloat16
|
31 |
X = G # no manual typecast
|
32 |
+
|
33 |
if G.size(0) > G.size(1):
|
34 |
X = X.T
|
35 |
# Ensure spectral norm is at most 1
|
36 |
X = X / (X.norm() + 1e-7)
|
|
|
37 |
# Perform the NS iterations
|
38 |
+
for a, b, c in [
|
39 |
+
(4.0848, -6.8946, 2.9270),
|
40 |
+
(3.9505, -6.3029, 2.6377),
|
41 |
+
(3.7418, -5.5913, 2.3037),
|
42 |
+
(2.8769, -3.1427, 1.2046),
|
43 |
+
(2.8366, -3.0525, 1.2012),
|
44 |
+
]:
|
45 |
A = X @ X.T
|
46 |
# B = (
|
47 |
# b * A + c * A @ A
|
|
|
52 |
|
53 |
if G.size(0) > G.size(1):
|
54 |
X = X.T
|
55 |
+
return X
|
56 |
|
57 |
|
58 |
@dataclass
|
|
|
74 |
Gather the gradients to worker_rank.
|
75 |
If none_grad is True, free p.grad after the gather.
|
76 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
with torch.cuda.stream(comm_stream):
|
78 |
+
g = p.grad
|
79 |
+
|
80 |
+
if rank == state.worker_rank:
|
81 |
+
num_ranks = dist.get_world_size(group=state.process_group)
|
82 |
+
gather_list = [
|
83 |
+
torch.empty_like(g.to_local(), dtype=torch.bfloat16)
|
84 |
+
for _ in range(num_ranks)
|
85 |
+
]
|
86 |
+
else:
|
87 |
+
gather_list = None
|
88 |
+
|
89 |
+
g = g.to(torch.bfloat16)
|
90 |
torch.distributed.gather(
|
91 |
g.to_local(),
|
92 |
dst=state.worker_rank,
|
|
|
103 |
else:
|
104 |
state.gathered_grad = None
|
105 |
state.gather_event = None
|
106 |
+
gather_list = None
|
107 |
if none_grad:
|
108 |
# We can safely free p.grad without calling record_stream:
|
109 |
# p.grad.to_local().record_stream(comm_stream)
|
|
|
116 |
|
117 |
|
118 |
@torch.no_grad()
|
119 |
+
def _compute_u(p, state, steps, rank, compute_stream):
|
120 |
"""
|
121 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
122 |
"""
|
|
|
127 |
compute_stream.wait_event(state.gather_event)
|
128 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
129 |
state.computed_u = u
|
130 |
+
state.scattered_u = torch.empty_like(p.to_local(),
|
131 |
+
dtype=torch.bfloat16)
|
132 |
+
state.compute_event = torch.cuda.Event()
|
133 |
+
state.compute_event.record()
|
134 |
+
u = None
|
135 |
|
136 |
|
137 |
@torch.no_grad()
|
|
|
141 |
"""
|
142 |
|
143 |
with torch.cuda.stream(comm_stream):
|
144 |
+
if state.compute_event is None:
|
145 |
+
raise RuntimeError("Compute event must be set before scatter.")
|
146 |
+
comm_stream.wait_event(state.compute_event)
|
147 |
+
|
148 |
if rank == state.worker_rank:
|
149 |
num_ranks = dist.get_world_size(group=state.process_group)
|
|
|
|
|
|
|
|
|
150 |
# Clear the gathered gradient to free memory
|
151 |
state.gathered_grad = None
|
152 |
|
|
|
156 |
else:
|
157 |
scatter_list = None
|
158 |
|
|
|
159 |
torch.distributed.scatter(
|
160 |
+
state.scattered_u,
|
161 |
scatter_list=scatter_list,
|
162 |
src=state.worker_rank,
|
163 |
group=state.process_group,
|
164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
state.scatter_event = torch.cuda.Event()
|
166 |
state.scatter_event.record()
|
167 |
+
scatter_list = None
|
168 |
|
169 |
|
170 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
|
177 |
if state.scatter_event is None:
|
178 |
raise RuntimeError("Scatter event must be set before update")
|
179 |
compute_stream.wait_event(state.scatter_event)
|
180 |
+
u_dtensor = DTensor.from_local(
|
181 |
+
state.scattered_u,
|
182 |
+
placements=p.placements,
|
183 |
+
device_mesh=p.device_mesh,
|
184 |
+
)
|
185 |
+
|
186 |
+
state.scattered_u = u_dtensor
|
187 |
+
|
188 |
if rank == state.worker_rank:
|
189 |
# Free computed_u
|
190 |
state.computed_u = None
|
191 |
|
192 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
193 |
+
state.scattered_u = None
|
194 |
+
u_dtensor = None
|
195 |
|
196 |
|
197 |
def default_is_muon(name, x):
|
|
|
390 |
else:
|
391 |
g = buf
|
392 |
|
393 |
+
u = _zeropower_via_newtonschulz5(g.bfloat16(),
|
394 |
+
steps=group["ns_steps"])
|
395 |
|
396 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
397 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
|
449 |
def enqueue_computes(start_idx, chunk_size):
|
450 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
451 |
state = param_to_state[id(p)]
|
452 |
+
_compute_u(p, state, group["ns_steps"], self.rank,
|
453 |
self.compute_stream)
|
454 |
|
455 |
def enqueue_scatters(start_idx, chunk_size):
|
|
|
482 |
# Wait the last update_param to finish
|
483 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
484 |
|
485 |
+
@staticmethod
|
486 |
+
def _fused_adamw(
|
487 |
+
params: list[torch.Tensor],
|
488 |
+
grads: list[torch.Tensor],
|
489 |
+
exp_avgs: list[torch.Tensor],
|
490 |
+
exp_avg_sqs: list[torch.Tensor],
|
491 |
+
max_exp_avg_sqs: list[torch.Tensor],
|
492 |
+
state_steps: list[torch.Tensor],
|
493 |
+
amsgrad: bool,
|
494 |
+
beta1: float,
|
495 |
+
beta2: float,
|
496 |
+
lr: Union[float, torch.Tensor],
|
497 |
+
weight_decay: float,
|
498 |
+
eps: float,
|
499 |
+
maximize: bool,
|
500 |
+
) -> None:
|
501 |
+
if not params:
|
502 |
+
return
|
503 |
+
|
504 |
+
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
|
505 |
+
# treating it as a scalar.
|
506 |
+
lr_dict: Optional[DeviceDict] = ({
|
507 |
+
lr.device: lr
|
508 |
+
} if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
|
509 |
+
None)
|
510 |
+
grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
|
511 |
+
[
|
512 |
+
params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
|
513 |
+
state_steps
|
514 |
+
] # type: ignore[list-item]
|
515 |
+
)
|
516 |
+
for (device, _), (
|
517 |
+
(
|
518 |
+
device_params_,
|
519 |
+
device_grads_,
|
520 |
+
device_exp_avgs_,
|
521 |
+
device_exp_avg_sqs_,
|
522 |
+
device_max_exp_avg_sqs,
|
523 |
+
device_state_steps_,
|
524 |
+
),
|
525 |
+
_,
|
526 |
+
) in grouped_tensors.items():
|
527 |
+
device_params = cast(list[torch.Tensor], device_params_)
|
528 |
+
device_grads = cast(list[torch.Tensor], device_grads_)
|
529 |
+
device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
|
530 |
+
device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
|
531 |
+
device_state_steps = cast(list[torch.Tensor], device_state_steps_)
|
532 |
+
|
533 |
+
if lr_dict is not None and device not in lr_dict:
|
534 |
+
lr_dict[device] = lr.to(
|
535 |
+
device=device,
|
536 |
+
non_blocking=True) # type: ignore[union-attr]
|
537 |
+
lr = lr_dict[device]
|
538 |
+
torch._foreach_add_(device_state_steps, 1)
|
539 |
+
func = torch._fused_adamw_
|
540 |
+
func(
|
541 |
+
device_params,
|
542 |
+
device_grads,
|
543 |
+
device_exp_avgs,
|
544 |
+
device_exp_avg_sqs,
|
545 |
+
device_max_exp_avg_sqs, # type: ignore[arg-type]
|
546 |
+
device_state_steps,
|
547 |
+
amsgrad=amsgrad,
|
548 |
+
lr=lr, # type: ignore[arg-type]
|
549 |
+
beta1=beta1,
|
550 |
+
beta2=beta2,
|
551 |
+
weight_decay=weight_decay,
|
552 |
+
eps=eps,
|
553 |
+
maximize=maximize,
|
554 |
+
)
|
555 |
+
|
556 |
def step(self, closure=None):
|
557 |
"""Perform a single optimization step.
|
558 |
|
|
|
629 |
# AdamW backup #
|
630 |
############################
|
631 |
|
632 |
+
params_with_grads = []
|
633 |
+
grads = []
|
634 |
+
moment1 = []
|
635 |
+
moment2 = []
|
636 |
+
max_exp_avg_sqs = []
|
637 |
+
state_steps = []
|
638 |
lr = group["lr"]
|
639 |
beta1, beta2 = group["adamw_betas"]
|
640 |
eps = group["adamw_eps"]
|
|
|
645 |
if g is None:
|
646 |
continue
|
647 |
state = self.state[p]
|
648 |
+
params_with_grads.append(p)
|
649 |
+
grads.append(g)
|
650 |
if "step" not in state:
|
651 |
+
state["step"] = (torch.zeros((),
|
652 |
+
dtype=torch.float32,
|
653 |
+
device=p.device))
|
654 |
state["moment1"] = torch.zeros_like(g)
|
655 |
state["moment2"] = torch.zeros_like(g)
|
656 |
+
moment1.append(state["moment1"])
|
657 |
+
moment2.append(state["moment2"])
|
658 |
+
if not isinstance(state["step"], torch.Tensor):
|
659 |
+
step_tensor = torch.tensor(state["step"],
|
660 |
+
dtype=torch.float32,
|
661 |
+
device=p.device)
|
662 |
+
else:
|
663 |
+
step_tensor = state["step"]
|
664 |
+
state_steps.append(step_tensor)
|
665 |
+
|
666 |
+
self._fused_adamw(
|
667 |
+
params_with_grads,
|
668 |
+
grads,
|
669 |
+
moment1,
|
670 |
+
moment2,
|
671 |
+
max_exp_avg_sqs,
|
672 |
+
state_steps,
|
673 |
+
amsgrad=False,
|
674 |
+
beta1=beta1,
|
675 |
+
beta2=beta2,
|
676 |
+
lr=lr,
|
677 |
+
weight_decay=weight_decay,
|
678 |
+
eps=eps,
|
679 |
+
maximize=False,
|
680 |
+
)
|
681 |
|
682 |
return loss
|
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import torch
|
2 |
-
from . import
|
3 |
-
ops = torch.ops.
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
-
return f"
|
|
|
1 |
import torch
|
2 |
+
from . import _optimizer_ee6ed44_dirty
|
3 |
+
ops = torch.ops._optimizer_ee6ed44_dirty
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
+
return f"_optimizer_ee6ed44_dirty::{op_name}"
|
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:22dc3ab77ab74837126281f79f417c5d55b2cc9885388fd9d3a1c7c824ece2bd
|
3 |
-
size 1883360
|
|
|
|
|
|
|
|
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3487612a8f022a1df1353945fc6d65bbd6797179b06c5d3202dc6e2aa6afb27a
|
3 |
+
size 1883352
|
build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py
CHANGED
@@ -2,6 +2,7 @@ import logging
|
|
2 |
import math
|
3 |
import types
|
4 |
from dataclasses import dataclass
|
|
|
5 |
|
6 |
import torch
|
7 |
import torch.distributed as dist
|
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
|
|
12 |
|
13 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
14 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
|
|
|
|
15 |
@torch.no_grad()
|
16 |
def _zeropower_via_newtonschulz5(G, steps):
|
17 |
"""
|
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
24 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
25 |
"""
|
26 |
assert len(G.shape) == 2
|
27 |
-
|
28 |
X = G # no manual typecast
|
|
|
29 |
if G.size(0) > G.size(1):
|
30 |
X = X.T
|
31 |
# Ensure spectral norm is at most 1
|
32 |
X = X / (X.norm() + 1e-7)
|
33 |
-
X = X.bfloat16()
|
34 |
# Perform the NS iterations
|
35 |
-
for
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
A = X @ X.T
|
37 |
# B = (
|
38 |
# b * A + c * A @ A
|
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
43 |
|
44 |
if G.size(0) > G.size(1):
|
45 |
X = X.T
|
46 |
-
return X
|
47 |
|
48 |
|
49 |
@dataclass
|
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
65 |
Gather the gradients to worker_rank.
|
66 |
If none_grad is True, free p.grad after the gather.
|
67 |
"""
|
68 |
-
g = p.grad
|
69 |
-
|
70 |
-
if rank == state.worker_rank:
|
71 |
-
num_ranks = dist.get_world_size(group=state.process_group)
|
72 |
-
gather_list = [
|
73 |
-
torch.empty_like(g.to_local()) for _ in range(num_ranks)
|
74 |
-
]
|
75 |
-
else:
|
76 |
-
gather_list = None
|
77 |
-
|
78 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
torch.distributed.gather(
|
80 |
g.to_local(),
|
81 |
dst=state.worker_rank,
|
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
92 |
else:
|
93 |
state.gathered_grad = None
|
94 |
state.gather_event = None
|
|
|
95 |
if none_grad:
|
96 |
# We can safely free p.grad without calling record_stream:
|
97 |
# p.grad.to_local().record_stream(comm_stream)
|
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
104 |
|
105 |
|
106 |
@torch.no_grad()
|
107 |
-
def _compute_u(state, steps, rank, compute_stream):
|
108 |
"""
|
109 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
110 |
"""
|
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
115 |
compute_stream.wait_event(state.gather_event)
|
116 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
117 |
state.computed_u = u
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
|
124 |
|
125 |
@torch.no_grad()
|
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
|
|
129 |
"""
|
130 |
|
131 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
132 |
if rank == state.worker_rank:
|
133 |
num_ranks = dist.get_world_size(group=state.process_group)
|
134 |
-
if state.compute_event is None:
|
135 |
-
raise RuntimeError("Compute event must be set before scatter.")
|
136 |
-
comm_stream.wait_event(state.compute_event)
|
137 |
-
|
138 |
# Clear the gathered gradient to free memory
|
139 |
state.gathered_grad = None
|
140 |
|
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
|
|
144 |
else:
|
145 |
scatter_list = None
|
146 |
|
147 |
-
u_received = torch.empty_like(p.to_local())
|
148 |
torch.distributed.scatter(
|
149 |
-
|
150 |
scatter_list=scatter_list,
|
151 |
src=state.worker_rank,
|
152 |
group=state.process_group,
|
153 |
)
|
154 |
-
u_dtensor = DTensor.from_local(
|
155 |
-
u_received,
|
156 |
-
placements=p.placements,
|
157 |
-
device_mesh=p.device_mesh,
|
158 |
-
)
|
159 |
-
|
160 |
-
state.scattered_u = u_dtensor
|
161 |
state.scatter_event = torch.cuda.Event()
|
162 |
state.scatter_event.record()
|
|
|
163 |
|
164 |
|
165 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
172 |
if state.scatter_event is None:
|
173 |
raise RuntimeError("Scatter event must be set before update")
|
174 |
compute_stream.wait_event(state.scatter_event)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
if rank == state.worker_rank:
|
176 |
# Free computed_u
|
177 |
state.computed_u = None
|
178 |
|
179 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
|
|
|
|
180 |
|
181 |
|
182 |
def default_is_muon(name, x):
|
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
|
|
375 |
else:
|
376 |
g = buf
|
377 |
|
378 |
-
u = _zeropower_via_newtonschulz5(g,
|
|
|
379 |
|
380 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
381 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
|
|
433 |
def enqueue_computes(start_idx, chunk_size):
|
434 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
435 |
state = param_to_state[id(p)]
|
436 |
-
_compute_u(state, group["ns_steps"], self.rank,
|
437 |
self.compute_stream)
|
438 |
|
439 |
def enqueue_scatters(start_idx, chunk_size):
|
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
|
|
466 |
# Wait the last update_param to finish
|
467 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
469 |
def step(self, closure=None):
|
470 |
"""Perform a single optimization step.
|
471 |
|
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
|
|
542 |
# AdamW backup #
|
543 |
############################
|
544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
545 |
lr = group["lr"]
|
546 |
beta1, beta2 = group["adamw_betas"]
|
547 |
eps = group["adamw_eps"]
|
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
|
|
552 |
if g is None:
|
553 |
continue
|
554 |
state = self.state[p]
|
|
|
|
|
555 |
if "step" not in state:
|
556 |
-
state["step"] =
|
|
|
|
|
557 |
state["moment1"] = torch.zeros_like(g)
|
558 |
state["moment2"] = torch.zeros_like(g)
|
559 |
-
state["
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
573 |
|
574 |
return loss
|
|
|
2 |
import math
|
3 |
import types
|
4 |
from dataclasses import dataclass
|
5 |
+
from typing import Optional, Union, cast
|
6 |
|
7 |
import torch
|
8 |
import torch.distributed as dist
|
|
|
13 |
|
14 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
15 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
16 |
+
# Muon's Newton–Schulz iteration causes high variance in singular values
|
17 |
+
# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
|
18 |
@torch.no_grad()
|
19 |
def _zeropower_via_newtonschulz5(G, steps):
|
20 |
"""
|
|
|
27 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
28 |
"""
|
29 |
assert len(G.shape) == 2
|
30 |
+
assert G.dtype == torch.bfloat16
|
31 |
X = G # no manual typecast
|
32 |
+
|
33 |
if G.size(0) > G.size(1):
|
34 |
X = X.T
|
35 |
# Ensure spectral norm is at most 1
|
36 |
X = X / (X.norm() + 1e-7)
|
|
|
37 |
# Perform the NS iterations
|
38 |
+
for a, b, c in [
|
39 |
+
(4.0848, -6.8946, 2.9270),
|
40 |
+
(3.9505, -6.3029, 2.6377),
|
41 |
+
(3.7418, -5.5913, 2.3037),
|
42 |
+
(2.8769, -3.1427, 1.2046),
|
43 |
+
(2.8366, -3.0525, 1.2012),
|
44 |
+
]:
|
45 |
A = X @ X.T
|
46 |
# B = (
|
47 |
# b * A + c * A @ A
|
|
|
52 |
|
53 |
if G.size(0) > G.size(1):
|
54 |
X = X.T
|
55 |
+
return X
|
56 |
|
57 |
|
58 |
@dataclass
|
|
|
74 |
Gather the gradients to worker_rank.
|
75 |
If none_grad is True, free p.grad after the gather.
|
76 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
with torch.cuda.stream(comm_stream):
|
78 |
+
g = p.grad
|
79 |
+
|
80 |
+
if rank == state.worker_rank:
|
81 |
+
num_ranks = dist.get_world_size(group=state.process_group)
|
82 |
+
gather_list = [
|
83 |
+
torch.empty_like(g.to_local(), dtype=torch.bfloat16)
|
84 |
+
for _ in range(num_ranks)
|
85 |
+
]
|
86 |
+
else:
|
87 |
+
gather_list = None
|
88 |
+
|
89 |
+
g = g.to(torch.bfloat16)
|
90 |
torch.distributed.gather(
|
91 |
g.to_local(),
|
92 |
dst=state.worker_rank,
|
|
|
103 |
else:
|
104 |
state.gathered_grad = None
|
105 |
state.gather_event = None
|
106 |
+
gather_list = None
|
107 |
if none_grad:
|
108 |
# We can safely free p.grad without calling record_stream:
|
109 |
# p.grad.to_local().record_stream(comm_stream)
|
|
|
116 |
|
117 |
|
118 |
@torch.no_grad()
|
119 |
+
def _compute_u(p, state, steps, rank, compute_stream):
|
120 |
"""
|
121 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
122 |
"""
|
|
|
127 |
compute_stream.wait_event(state.gather_event)
|
128 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
129 |
state.computed_u = u
|
130 |
+
state.scattered_u = torch.empty_like(p.to_local(),
|
131 |
+
dtype=torch.bfloat16)
|
132 |
+
state.compute_event = torch.cuda.Event()
|
133 |
+
state.compute_event.record()
|
134 |
+
u = None
|
135 |
|
136 |
|
137 |
@torch.no_grad()
|
|
|
141 |
"""
|
142 |
|
143 |
with torch.cuda.stream(comm_stream):
|
144 |
+
if state.compute_event is None:
|
145 |
+
raise RuntimeError("Compute event must be set before scatter.")
|
146 |
+
comm_stream.wait_event(state.compute_event)
|
147 |
+
|
148 |
if rank == state.worker_rank:
|
149 |
num_ranks = dist.get_world_size(group=state.process_group)
|
|
|
|
|
|
|
|
|
150 |
# Clear the gathered gradient to free memory
|
151 |
state.gathered_grad = None
|
152 |
|
|
|
156 |
else:
|
157 |
scatter_list = None
|
158 |
|
|
|
159 |
torch.distributed.scatter(
|
160 |
+
state.scattered_u,
|
161 |
scatter_list=scatter_list,
|
162 |
src=state.worker_rank,
|
163 |
group=state.process_group,
|
164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
state.scatter_event = torch.cuda.Event()
|
166 |
state.scatter_event.record()
|
167 |
+
scatter_list = None
|
168 |
|
169 |
|
170 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
|
177 |
if state.scatter_event is None:
|
178 |
raise RuntimeError("Scatter event must be set before update")
|
179 |
compute_stream.wait_event(state.scatter_event)
|
180 |
+
u_dtensor = DTensor.from_local(
|
181 |
+
state.scattered_u,
|
182 |
+
placements=p.placements,
|
183 |
+
device_mesh=p.device_mesh,
|
184 |
+
)
|
185 |
+
|
186 |
+
state.scattered_u = u_dtensor
|
187 |
+
|
188 |
if rank == state.worker_rank:
|
189 |
# Free computed_u
|
190 |
state.computed_u = None
|
191 |
|
192 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
193 |
+
state.scattered_u = None
|
194 |
+
u_dtensor = None
|
195 |
|
196 |
|
197 |
def default_is_muon(name, x):
|
|
|
390 |
else:
|
391 |
g = buf
|
392 |
|
393 |
+
u = _zeropower_via_newtonschulz5(g.bfloat16(),
|
394 |
+
steps=group["ns_steps"])
|
395 |
|
396 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
397 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
|
449 |
def enqueue_computes(start_idx, chunk_size):
|
450 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
451 |
state = param_to_state[id(p)]
|
452 |
+
_compute_u(p, state, group["ns_steps"], self.rank,
|
453 |
self.compute_stream)
|
454 |
|
455 |
def enqueue_scatters(start_idx, chunk_size):
|
|
|
482 |
# Wait the last update_param to finish
|
483 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
484 |
|
485 |
+
@staticmethod
|
486 |
+
def _fused_adamw(
|
487 |
+
params: list[torch.Tensor],
|
488 |
+
grads: list[torch.Tensor],
|
489 |
+
exp_avgs: list[torch.Tensor],
|
490 |
+
exp_avg_sqs: list[torch.Tensor],
|
491 |
+
max_exp_avg_sqs: list[torch.Tensor],
|
492 |
+
state_steps: list[torch.Tensor],
|
493 |
+
amsgrad: bool,
|
494 |
+
beta1: float,
|
495 |
+
beta2: float,
|
496 |
+
lr: Union[float, torch.Tensor],
|
497 |
+
weight_decay: float,
|
498 |
+
eps: float,
|
499 |
+
maximize: bool,
|
500 |
+
) -> None:
|
501 |
+
if not params:
|
502 |
+
return
|
503 |
+
|
504 |
+
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
|
505 |
+
# treating it as a scalar.
|
506 |
+
lr_dict: Optional[DeviceDict] = ({
|
507 |
+
lr.device: lr
|
508 |
+
} if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
|
509 |
+
None)
|
510 |
+
grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
|
511 |
+
[
|
512 |
+
params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
|
513 |
+
state_steps
|
514 |
+
] # type: ignore[list-item]
|
515 |
+
)
|
516 |
+
for (device, _), (
|
517 |
+
(
|
518 |
+
device_params_,
|
519 |
+
device_grads_,
|
520 |
+
device_exp_avgs_,
|
521 |
+
device_exp_avg_sqs_,
|
522 |
+
device_max_exp_avg_sqs,
|
523 |
+
device_state_steps_,
|
524 |
+
),
|
525 |
+
_,
|
526 |
+
) in grouped_tensors.items():
|
527 |
+
device_params = cast(list[torch.Tensor], device_params_)
|
528 |
+
device_grads = cast(list[torch.Tensor], device_grads_)
|
529 |
+
device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
|
530 |
+
device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
|
531 |
+
device_state_steps = cast(list[torch.Tensor], device_state_steps_)
|
532 |
+
|
533 |
+
if lr_dict is not None and device not in lr_dict:
|
534 |
+
lr_dict[device] = lr.to(
|
535 |
+
device=device,
|
536 |
+
non_blocking=True) # type: ignore[union-attr]
|
537 |
+
lr = lr_dict[device]
|
538 |
+
torch._foreach_add_(device_state_steps, 1)
|
539 |
+
func = torch._fused_adamw_
|
540 |
+
func(
|
541 |
+
device_params,
|
542 |
+
device_grads,
|
543 |
+
device_exp_avgs,
|
544 |
+
device_exp_avg_sqs,
|
545 |
+
device_max_exp_avg_sqs, # type: ignore[arg-type]
|
546 |
+
device_state_steps,
|
547 |
+
amsgrad=amsgrad,
|
548 |
+
lr=lr, # type: ignore[arg-type]
|
549 |
+
beta1=beta1,
|
550 |
+
beta2=beta2,
|
551 |
+
weight_decay=weight_decay,
|
552 |
+
eps=eps,
|
553 |
+
maximize=maximize,
|
554 |
+
)
|
555 |
+
|
556 |
def step(self, closure=None):
|
557 |
"""Perform a single optimization step.
|
558 |
|
|
|
629 |
# AdamW backup #
|
630 |
############################
|
631 |
|
632 |
+
params_with_grads = []
|
633 |
+
grads = []
|
634 |
+
moment1 = []
|
635 |
+
moment2 = []
|
636 |
+
max_exp_avg_sqs = []
|
637 |
+
state_steps = []
|
638 |
lr = group["lr"]
|
639 |
beta1, beta2 = group["adamw_betas"]
|
640 |
eps = group["adamw_eps"]
|
|
|
645 |
if g is None:
|
646 |
continue
|
647 |
state = self.state[p]
|
648 |
+
params_with_grads.append(p)
|
649 |
+
grads.append(g)
|
650 |
if "step" not in state:
|
651 |
+
state["step"] = (torch.zeros((),
|
652 |
+
dtype=torch.float32,
|
653 |
+
device=p.device))
|
654 |
state["moment1"] = torch.zeros_like(g)
|
655 |
state["moment2"] = torch.zeros_like(g)
|
656 |
+
moment1.append(state["moment1"])
|
657 |
+
moment2.append(state["moment2"])
|
658 |
+
if not isinstance(state["step"], torch.Tensor):
|
659 |
+
step_tensor = torch.tensor(state["step"],
|
660 |
+
dtype=torch.float32,
|
661 |
+
device=p.device)
|
662 |
+
else:
|
663 |
+
step_tensor = state["step"]
|
664 |
+
state_steps.append(step_tensor)
|
665 |
+
|
666 |
+
self._fused_adamw(
|
667 |
+
params_with_grads,
|
668 |
+
grads,
|
669 |
+
moment1,
|
670 |
+
moment2,
|
671 |
+
max_exp_avg_sqs,
|
672 |
+
state_steps,
|
673 |
+
amsgrad=False,
|
674 |
+
beta1=beta1,
|
675 |
+
beta2=beta2,
|
676 |
+
lr=lr,
|
677 |
+
weight_decay=weight_decay,
|
678 |
+
eps=eps,
|
679 |
+
maximize=False,
|
680 |
+
)
|
681 |
|
682 |
return loss
|
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import torch
|
2 |
-
from . import
|
3 |
-
ops = torch.ops.
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
-
return f"
|
|
|
1 |
import torch
|
2 |
+
from . import _optimizer_ee6ed44_dirty
|
3 |
+
ops = torch.ops._optimizer_ee6ed44_dirty
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
+
return f"_optimizer_ee6ed44_dirty::{op_name}"
|
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:62ecfc7e6a1ab0c4ada19ed7aea40fc0a431c4ceb1729666efa98ac0e407f9c8
|
3 |
-
size 1883360
|
|
|
|
|
|
|
|
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f5e375def39d93758b60534cef504ae75d9c13e0d86da5dcf7642f1f90b77f52
|
3 |
+
size 1883352
|
build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py
CHANGED
@@ -2,6 +2,7 @@ import logging
|
|
2 |
import math
|
3 |
import types
|
4 |
from dataclasses import dataclass
|
|
|
5 |
|
6 |
import torch
|
7 |
import torch.distributed as dist
|
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
|
|
12 |
|
13 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
14 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
|
|
|
|
15 |
@torch.no_grad()
|
16 |
def _zeropower_via_newtonschulz5(G, steps):
|
17 |
"""
|
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
24 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
25 |
"""
|
26 |
assert len(G.shape) == 2
|
27 |
-
|
28 |
X = G # no manual typecast
|
|
|
29 |
if G.size(0) > G.size(1):
|
30 |
X = X.T
|
31 |
# Ensure spectral norm is at most 1
|
32 |
X = X / (X.norm() + 1e-7)
|
33 |
-
X = X.bfloat16()
|
34 |
# Perform the NS iterations
|
35 |
-
for
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
A = X @ X.T
|
37 |
# B = (
|
38 |
# b * A + c * A @ A
|
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
43 |
|
44 |
if G.size(0) > G.size(1):
|
45 |
X = X.T
|
46 |
-
return X
|
47 |
|
48 |
|
49 |
@dataclass
|
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
65 |
Gather the gradients to worker_rank.
|
66 |
If none_grad is True, free p.grad after the gather.
|
67 |
"""
|
68 |
-
g = p.grad
|
69 |
-
|
70 |
-
if rank == state.worker_rank:
|
71 |
-
num_ranks = dist.get_world_size(group=state.process_group)
|
72 |
-
gather_list = [
|
73 |
-
torch.empty_like(g.to_local()) for _ in range(num_ranks)
|
74 |
-
]
|
75 |
-
else:
|
76 |
-
gather_list = None
|
77 |
-
|
78 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
torch.distributed.gather(
|
80 |
g.to_local(),
|
81 |
dst=state.worker_rank,
|
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
92 |
else:
|
93 |
state.gathered_grad = None
|
94 |
state.gather_event = None
|
|
|
95 |
if none_grad:
|
96 |
# We can safely free p.grad without calling record_stream:
|
97 |
# p.grad.to_local().record_stream(comm_stream)
|
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
104 |
|
105 |
|
106 |
@torch.no_grad()
|
107 |
-
def _compute_u(state, steps, rank, compute_stream):
|
108 |
"""
|
109 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
110 |
"""
|
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
115 |
compute_stream.wait_event(state.gather_event)
|
116 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
117 |
state.computed_u = u
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
|
124 |
|
125 |
@torch.no_grad()
|
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
|
|
129 |
"""
|
130 |
|
131 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
132 |
if rank == state.worker_rank:
|
133 |
num_ranks = dist.get_world_size(group=state.process_group)
|
134 |
-
if state.compute_event is None:
|
135 |
-
raise RuntimeError("Compute event must be set before scatter.")
|
136 |
-
comm_stream.wait_event(state.compute_event)
|
137 |
-
|
138 |
# Clear the gathered gradient to free memory
|
139 |
state.gathered_grad = None
|
140 |
|
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
|
|
144 |
else:
|
145 |
scatter_list = None
|
146 |
|
147 |
-
u_received = torch.empty_like(p.to_local())
|
148 |
torch.distributed.scatter(
|
149 |
-
|
150 |
scatter_list=scatter_list,
|
151 |
src=state.worker_rank,
|
152 |
group=state.process_group,
|
153 |
)
|
154 |
-
u_dtensor = DTensor.from_local(
|
155 |
-
u_received,
|
156 |
-
placements=p.placements,
|
157 |
-
device_mesh=p.device_mesh,
|
158 |
-
)
|
159 |
-
|
160 |
-
state.scattered_u = u_dtensor
|
161 |
state.scatter_event = torch.cuda.Event()
|
162 |
state.scatter_event.record()
|
|
|
163 |
|
164 |
|
165 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
172 |
if state.scatter_event is None:
|
173 |
raise RuntimeError("Scatter event must be set before update")
|
174 |
compute_stream.wait_event(state.scatter_event)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
if rank == state.worker_rank:
|
176 |
# Free computed_u
|
177 |
state.computed_u = None
|
178 |
|
179 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
|
|
|
|
180 |
|
181 |
|
182 |
def default_is_muon(name, x):
|
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
|
|
375 |
else:
|
376 |
g = buf
|
377 |
|
378 |
-
u = _zeropower_via_newtonschulz5(g,
|
|
|
379 |
|
380 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
381 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
|
|
433 |
def enqueue_computes(start_idx, chunk_size):
|
434 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
435 |
state = param_to_state[id(p)]
|
436 |
-
_compute_u(state, group["ns_steps"], self.rank,
|
437 |
self.compute_stream)
|
438 |
|
439 |
def enqueue_scatters(start_idx, chunk_size):
|
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
|
|
466 |
# Wait the last update_param to finish
|
467 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
469 |
def step(self, closure=None):
|
470 |
"""Perform a single optimization step.
|
471 |
|
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
|
|
542 |
# AdamW backup #
|
543 |
############################
|
544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
545 |
lr = group["lr"]
|
546 |
beta1, beta2 = group["adamw_betas"]
|
547 |
eps = group["adamw_eps"]
|
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
|
|
552 |
if g is None:
|
553 |
continue
|
554 |
state = self.state[p]
|
|
|
|
|
555 |
if "step" not in state:
|
556 |
-
state["step"] =
|
|
|
|
|
557 |
state["moment1"] = torch.zeros_like(g)
|
558 |
state["moment2"] = torch.zeros_like(g)
|
559 |
-
state["
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
573 |
|
574 |
return loss
|
|
|
2 |
import math
|
3 |
import types
|
4 |
from dataclasses import dataclass
|
5 |
+
from typing import Optional, Union, cast
|
6 |
|
7 |
import torch
|
8 |
import torch.distributed as dist
|
|
|
13 |
|
14 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
15 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
16 |
+
# Muon's Newton–Schulz iteration causes high variance in singular values
|
17 |
+
# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
|
18 |
@torch.no_grad()
|
19 |
def _zeropower_via_newtonschulz5(G, steps):
|
20 |
"""
|
|
|
27 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
28 |
"""
|
29 |
assert len(G.shape) == 2
|
30 |
+
assert G.dtype == torch.bfloat16
|
31 |
X = G # no manual typecast
|
32 |
+
|
33 |
if G.size(0) > G.size(1):
|
34 |
X = X.T
|
35 |
# Ensure spectral norm is at most 1
|
36 |
X = X / (X.norm() + 1e-7)
|
|
|
37 |
# Perform the NS iterations
|
38 |
+
for a, b, c in [
|
39 |
+
(4.0848, -6.8946, 2.9270),
|
40 |
+
(3.9505, -6.3029, 2.6377),
|
41 |
+
(3.7418, -5.5913, 2.3037),
|
42 |
+
(2.8769, -3.1427, 1.2046),
|
43 |
+
(2.8366, -3.0525, 1.2012),
|
44 |
+
]:
|
45 |
A = X @ X.T
|
46 |
# B = (
|
47 |
# b * A + c * A @ A
|
|
|
52 |
|
53 |
if G.size(0) > G.size(1):
|
54 |
X = X.T
|
55 |
+
return X
|
56 |
|
57 |
|
58 |
@dataclass
|
|
|
74 |
Gather the gradients to worker_rank.
|
75 |
If none_grad is True, free p.grad after the gather.
|
76 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
with torch.cuda.stream(comm_stream):
|
78 |
+
g = p.grad
|
79 |
+
|
80 |
+
if rank == state.worker_rank:
|
81 |
+
num_ranks = dist.get_world_size(group=state.process_group)
|
82 |
+
gather_list = [
|
83 |
+
torch.empty_like(g.to_local(), dtype=torch.bfloat16)
|
84 |
+
for _ in range(num_ranks)
|
85 |
+
]
|
86 |
+
else:
|
87 |
+
gather_list = None
|
88 |
+
|
89 |
+
g = g.to(torch.bfloat16)
|
90 |
torch.distributed.gather(
|
91 |
g.to_local(),
|
92 |
dst=state.worker_rank,
|
|
|
103 |
else:
|
104 |
state.gathered_grad = None
|
105 |
state.gather_event = None
|
106 |
+
gather_list = None
|
107 |
if none_grad:
|
108 |
# We can safely free p.grad without calling record_stream:
|
109 |
# p.grad.to_local().record_stream(comm_stream)
|
|
|
116 |
|
117 |
|
118 |
@torch.no_grad()
|
119 |
+
def _compute_u(p, state, steps, rank, compute_stream):
|
120 |
"""
|
121 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
122 |
"""
|
|
|
127 |
compute_stream.wait_event(state.gather_event)
|
128 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
129 |
state.computed_u = u
|
130 |
+
state.scattered_u = torch.empty_like(p.to_local(),
|
131 |
+
dtype=torch.bfloat16)
|
132 |
+
state.compute_event = torch.cuda.Event()
|
133 |
+
state.compute_event.record()
|
134 |
+
u = None
|
135 |
|
136 |
|
137 |
@torch.no_grad()
|
|
|
141 |
"""
|
142 |
|
143 |
with torch.cuda.stream(comm_stream):
|
144 |
+
if state.compute_event is None:
|
145 |
+
raise RuntimeError("Compute event must be set before scatter.")
|
146 |
+
comm_stream.wait_event(state.compute_event)
|
147 |
+
|
148 |
if rank == state.worker_rank:
|
149 |
num_ranks = dist.get_world_size(group=state.process_group)
|
|
|
|
|
|
|
|
|
150 |
# Clear the gathered gradient to free memory
|
151 |
state.gathered_grad = None
|
152 |
|
|
|
156 |
else:
|
157 |
scatter_list = None
|
158 |
|
|
|
159 |
torch.distributed.scatter(
|
160 |
+
state.scattered_u,
|
161 |
scatter_list=scatter_list,
|
162 |
src=state.worker_rank,
|
163 |
group=state.process_group,
|
164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
state.scatter_event = torch.cuda.Event()
|
166 |
state.scatter_event.record()
|
167 |
+
scatter_list = None
|
168 |
|
169 |
|
170 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
|
177 |
if state.scatter_event is None:
|
178 |
raise RuntimeError("Scatter event must be set before update")
|
179 |
compute_stream.wait_event(state.scatter_event)
|
180 |
+
u_dtensor = DTensor.from_local(
|
181 |
+
state.scattered_u,
|
182 |
+
placements=p.placements,
|
183 |
+
device_mesh=p.device_mesh,
|
184 |
+
)
|
185 |
+
|
186 |
+
state.scattered_u = u_dtensor
|
187 |
+
|
188 |
if rank == state.worker_rank:
|
189 |
# Free computed_u
|
190 |
state.computed_u = None
|
191 |
|
192 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
193 |
+
state.scattered_u = None
|
194 |
+
u_dtensor = None
|
195 |
|
196 |
|
197 |
def default_is_muon(name, x):
|
|
|
390 |
else:
|
391 |
g = buf
|
392 |
|
393 |
+
u = _zeropower_via_newtonschulz5(g.bfloat16(),
|
394 |
+
steps=group["ns_steps"])
|
395 |
|
396 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
397 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
|
449 |
def enqueue_computes(start_idx, chunk_size):
|
450 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
451 |
state = param_to_state[id(p)]
|
452 |
+
_compute_u(p, state, group["ns_steps"], self.rank,
|
453 |
self.compute_stream)
|
454 |
|
455 |
def enqueue_scatters(start_idx, chunk_size):
|
|
|
482 |
# Wait the last update_param to finish
|
483 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
484 |
|
485 |
+
@staticmethod
|
486 |
+
def _fused_adamw(
|
487 |
+
params: list[torch.Tensor],
|
488 |
+
grads: list[torch.Tensor],
|
489 |
+
exp_avgs: list[torch.Tensor],
|
490 |
+
exp_avg_sqs: list[torch.Tensor],
|
491 |
+
max_exp_avg_sqs: list[torch.Tensor],
|
492 |
+
state_steps: list[torch.Tensor],
|
493 |
+
amsgrad: bool,
|
494 |
+
beta1: float,
|
495 |
+
beta2: float,
|
496 |
+
lr: Union[float, torch.Tensor],
|
497 |
+
weight_decay: float,
|
498 |
+
eps: float,
|
499 |
+
maximize: bool,
|
500 |
+
) -> None:
|
501 |
+
if not params:
|
502 |
+
return
|
503 |
+
|
504 |
+
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
|
505 |
+
# treating it as a scalar.
|
506 |
+
lr_dict: Optional[DeviceDict] = ({
|
507 |
+
lr.device: lr
|
508 |
+
} if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
|
509 |
+
None)
|
510 |
+
grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
|
511 |
+
[
|
512 |
+
params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
|
513 |
+
state_steps
|
514 |
+
] # type: ignore[list-item]
|
515 |
+
)
|
516 |
+
for (device, _), (
|
517 |
+
(
|
518 |
+
device_params_,
|
519 |
+
device_grads_,
|
520 |
+
device_exp_avgs_,
|
521 |
+
device_exp_avg_sqs_,
|
522 |
+
device_max_exp_avg_sqs,
|
523 |
+
device_state_steps_,
|
524 |
+
),
|
525 |
+
_,
|
526 |
+
) in grouped_tensors.items():
|
527 |
+
device_params = cast(list[torch.Tensor], device_params_)
|
528 |
+
device_grads = cast(list[torch.Tensor], device_grads_)
|
529 |
+
device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
|
530 |
+
device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
|
531 |
+
device_state_steps = cast(list[torch.Tensor], device_state_steps_)
|
532 |
+
|
533 |
+
if lr_dict is not None and device not in lr_dict:
|
534 |
+
lr_dict[device] = lr.to(
|
535 |
+
device=device,
|
536 |
+
non_blocking=True) # type: ignore[union-attr]
|
537 |
+
lr = lr_dict[device]
|
538 |
+
torch._foreach_add_(device_state_steps, 1)
|
539 |
+
func = torch._fused_adamw_
|
540 |
+
func(
|
541 |
+
device_params,
|
542 |
+
device_grads,
|
543 |
+
device_exp_avgs,
|
544 |
+
device_exp_avg_sqs,
|
545 |
+
device_max_exp_avg_sqs, # type: ignore[arg-type]
|
546 |
+
device_state_steps,
|
547 |
+
amsgrad=amsgrad,
|
548 |
+
lr=lr, # type: ignore[arg-type]
|
549 |
+
beta1=beta1,
|
550 |
+
beta2=beta2,
|
551 |
+
weight_decay=weight_decay,
|
552 |
+
eps=eps,
|
553 |
+
maximize=maximize,
|
554 |
+
)
|
555 |
+
|
556 |
def step(self, closure=None):
|
557 |
"""Perform a single optimization step.
|
558 |
|
|
|
629 |
# AdamW backup #
|
630 |
############################
|
631 |
|
632 |
+
params_with_grads = []
|
633 |
+
grads = []
|
634 |
+
moment1 = []
|
635 |
+
moment2 = []
|
636 |
+
max_exp_avg_sqs = []
|
637 |
+
state_steps = []
|
638 |
lr = group["lr"]
|
639 |
beta1, beta2 = group["adamw_betas"]
|
640 |
eps = group["adamw_eps"]
|
|
|
645 |
if g is None:
|
646 |
continue
|
647 |
state = self.state[p]
|
648 |
+
params_with_grads.append(p)
|
649 |
+
grads.append(g)
|
650 |
if "step" not in state:
|
651 |
+
state["step"] = (torch.zeros((),
|
652 |
+
dtype=torch.float32,
|
653 |
+
device=p.device))
|
654 |
state["moment1"] = torch.zeros_like(g)
|
655 |
state["moment2"] = torch.zeros_like(g)
|
656 |
+
moment1.append(state["moment1"])
|
657 |
+
moment2.append(state["moment2"])
|
658 |
+
if not isinstance(state["step"], torch.Tensor):
|
659 |
+
step_tensor = torch.tensor(state["step"],
|
660 |
+
dtype=torch.float32,
|
661 |
+
device=p.device)
|
662 |
+
else:
|
663 |
+
step_tensor = state["step"]
|
664 |
+
state_steps.append(step_tensor)
|
665 |
+
|
666 |
+
self._fused_adamw(
|
667 |
+
params_with_grads,
|
668 |
+
grads,
|
669 |
+
moment1,
|
670 |
+
moment2,
|
671 |
+
max_exp_avg_sqs,
|
672 |
+
state_steps,
|
673 |
+
amsgrad=False,
|
674 |
+
beta1=beta1,
|
675 |
+
beta2=beta2,
|
676 |
+
lr=lr,
|
677 |
+
weight_decay=weight_decay,
|
678 |
+
eps=eps,
|
679 |
+
maximize=False,
|
680 |
+
)
|
681 |
|
682 |
return loss
|
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import torch
|
2 |
-
from . import
|
3 |
-
ops = torch.ops.
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
-
return f"
|
|
|
1 |
import torch
|
2 |
+
from . import _optimizer_ee6ed44_dirty
|
3 |
+
ops = torch.ops._optimizer_ee6ed44_dirty
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
+
return f"_optimizer_ee6ed44_dirty::{op_name}"
|
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:37e389c650fc1fcbc9fbd68f1e7c1a768b08e90509fd8a5d87879655726f2db2
|
3 |
-
size 1750040
|
|
|
|
|
|
|
|
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:33e0d50fbf340612b0e1129717e4116197c8562592e5920f2dedc718ce9a0585
|
3 |
+
size 1750000
|
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py
CHANGED
@@ -2,6 +2,7 @@ import logging
|
|
2 |
import math
|
3 |
import types
|
4 |
from dataclasses import dataclass
|
|
|
5 |
|
6 |
import torch
|
7 |
import torch.distributed as dist
|
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
|
|
12 |
|
13 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
14 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
|
|
|
|
15 |
@torch.no_grad()
|
16 |
def _zeropower_via_newtonschulz5(G, steps):
|
17 |
"""
|
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
24 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
25 |
"""
|
26 |
assert len(G.shape) == 2
|
27 |
-
|
28 |
X = G # no manual typecast
|
|
|
29 |
if G.size(0) > G.size(1):
|
30 |
X = X.T
|
31 |
# Ensure spectral norm is at most 1
|
32 |
X = X / (X.norm() + 1e-7)
|
33 |
-
X = X.bfloat16()
|
34 |
# Perform the NS iterations
|
35 |
-
for
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
A = X @ X.T
|
37 |
# B = (
|
38 |
# b * A + c * A @ A
|
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
43 |
|
44 |
if G.size(0) > G.size(1):
|
45 |
X = X.T
|
46 |
-
return X
|
47 |
|
48 |
|
49 |
@dataclass
|
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
65 |
Gather the gradients to worker_rank.
|
66 |
If none_grad is True, free p.grad after the gather.
|
67 |
"""
|
68 |
-
g = p.grad
|
69 |
-
|
70 |
-
if rank == state.worker_rank:
|
71 |
-
num_ranks = dist.get_world_size(group=state.process_group)
|
72 |
-
gather_list = [
|
73 |
-
torch.empty_like(g.to_local()) for _ in range(num_ranks)
|
74 |
-
]
|
75 |
-
else:
|
76 |
-
gather_list = None
|
77 |
-
|
78 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
torch.distributed.gather(
|
80 |
g.to_local(),
|
81 |
dst=state.worker_rank,
|
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
92 |
else:
|
93 |
state.gathered_grad = None
|
94 |
state.gather_event = None
|
|
|
95 |
if none_grad:
|
96 |
# We can safely free p.grad without calling record_stream:
|
97 |
# p.grad.to_local().record_stream(comm_stream)
|
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
104 |
|
105 |
|
106 |
@torch.no_grad()
|
107 |
-
def _compute_u(state, steps, rank, compute_stream):
|
108 |
"""
|
109 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
110 |
"""
|
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
115 |
compute_stream.wait_event(state.gather_event)
|
116 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
117 |
state.computed_u = u
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
|
124 |
|
125 |
@torch.no_grad()
|
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
|
|
129 |
"""
|
130 |
|
131 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
132 |
if rank == state.worker_rank:
|
133 |
num_ranks = dist.get_world_size(group=state.process_group)
|
134 |
-
if state.compute_event is None:
|
135 |
-
raise RuntimeError("Compute event must be set before scatter.")
|
136 |
-
comm_stream.wait_event(state.compute_event)
|
137 |
-
|
138 |
# Clear the gathered gradient to free memory
|
139 |
state.gathered_grad = None
|
140 |
|
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
|
|
144 |
else:
|
145 |
scatter_list = None
|
146 |
|
147 |
-
u_received = torch.empty_like(p.to_local())
|
148 |
torch.distributed.scatter(
|
149 |
-
|
150 |
scatter_list=scatter_list,
|
151 |
src=state.worker_rank,
|
152 |
group=state.process_group,
|
153 |
)
|
154 |
-
u_dtensor = DTensor.from_local(
|
155 |
-
u_received,
|
156 |
-
placements=p.placements,
|
157 |
-
device_mesh=p.device_mesh,
|
158 |
-
)
|
159 |
-
|
160 |
-
state.scattered_u = u_dtensor
|
161 |
state.scatter_event = torch.cuda.Event()
|
162 |
state.scatter_event.record()
|
|
|
163 |
|
164 |
|
165 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
172 |
if state.scatter_event is None:
|
173 |
raise RuntimeError("Scatter event must be set before update")
|
174 |
compute_stream.wait_event(state.scatter_event)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
if rank == state.worker_rank:
|
176 |
# Free computed_u
|
177 |
state.computed_u = None
|
178 |
|
179 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
|
|
|
|
180 |
|
181 |
|
182 |
def default_is_muon(name, x):
|
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
|
|
375 |
else:
|
376 |
g = buf
|
377 |
|
378 |
-
u = _zeropower_via_newtonschulz5(g,
|
|
|
379 |
|
380 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
381 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
|
|
433 |
def enqueue_computes(start_idx, chunk_size):
|
434 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
435 |
state = param_to_state[id(p)]
|
436 |
-
_compute_u(state, group["ns_steps"], self.rank,
|
437 |
self.compute_stream)
|
438 |
|
439 |
def enqueue_scatters(start_idx, chunk_size):
|
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
|
|
466 |
# Wait the last update_param to finish
|
467 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
469 |
def step(self, closure=None):
|
470 |
"""Perform a single optimization step.
|
471 |
|
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
|
|
542 |
# AdamW backup #
|
543 |
############################
|
544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
545 |
lr = group["lr"]
|
546 |
beta1, beta2 = group["adamw_betas"]
|
547 |
eps = group["adamw_eps"]
|
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
|
|
552 |
if g is None:
|
553 |
continue
|
554 |
state = self.state[p]
|
|
|
|
|
555 |
if "step" not in state:
|
556 |
-
state["step"] =
|
|
|
|
|
557 |
state["moment1"] = torch.zeros_like(g)
|
558 |
state["moment2"] = torch.zeros_like(g)
|
559 |
-
state["
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
573 |
|
574 |
return loss
|
|
|
2 |
import math
|
3 |
import types
|
4 |
from dataclasses import dataclass
|
5 |
+
from typing import Optional, Union, cast
|
6 |
|
7 |
import torch
|
8 |
import torch.distributed as dist
|
|
|
13 |
|
14 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
15 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
16 |
+
# Muon's Newton–Schulz iteration causes high variance in singular values
|
17 |
+
# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
|
18 |
@torch.no_grad()
|
19 |
def _zeropower_via_newtonschulz5(G, steps):
|
20 |
"""
|
|
|
27 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
28 |
"""
|
29 |
assert len(G.shape) == 2
|
30 |
+
assert G.dtype == torch.bfloat16
|
31 |
X = G # no manual typecast
|
32 |
+
|
33 |
if G.size(0) > G.size(1):
|
34 |
X = X.T
|
35 |
# Ensure spectral norm is at most 1
|
36 |
X = X / (X.norm() + 1e-7)
|
|
|
37 |
# Perform the NS iterations
|
38 |
+
for a, b, c in [
|
39 |
+
(4.0848, -6.8946, 2.9270),
|
40 |
+
(3.9505, -6.3029, 2.6377),
|
41 |
+
(3.7418, -5.5913, 2.3037),
|
42 |
+
(2.8769, -3.1427, 1.2046),
|
43 |
+
(2.8366, -3.0525, 1.2012),
|
44 |
+
]:
|
45 |
A = X @ X.T
|
46 |
# B = (
|
47 |
# b * A + c * A @ A
|
|
|
52 |
|
53 |
if G.size(0) > G.size(1):
|
54 |
X = X.T
|
55 |
+
return X
|
56 |
|
57 |
|
58 |
@dataclass
|
|
|
74 |
Gather the gradients to worker_rank.
|
75 |
If none_grad is True, free p.grad after the gather.
|
76 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
with torch.cuda.stream(comm_stream):
|
78 |
+
g = p.grad
|
79 |
+
|
80 |
+
if rank == state.worker_rank:
|
81 |
+
num_ranks = dist.get_world_size(group=state.process_group)
|
82 |
+
gather_list = [
|
83 |
+
torch.empty_like(g.to_local(), dtype=torch.bfloat16)
|
84 |
+
for _ in range(num_ranks)
|
85 |
+
]
|
86 |
+
else:
|
87 |
+
gather_list = None
|
88 |
+
|
89 |
+
g = g.to(torch.bfloat16)
|
90 |
torch.distributed.gather(
|
91 |
g.to_local(),
|
92 |
dst=state.worker_rank,
|
|
|
103 |
else:
|
104 |
state.gathered_grad = None
|
105 |
state.gather_event = None
|
106 |
+
gather_list = None
|
107 |
if none_grad:
|
108 |
# We can safely free p.grad without calling record_stream:
|
109 |
# p.grad.to_local().record_stream(comm_stream)
|
|
|
116 |
|
117 |
|
118 |
@torch.no_grad()
|
119 |
+
def _compute_u(p, state, steps, rank, compute_stream):
|
120 |
"""
|
121 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
122 |
"""
|
|
|
127 |
compute_stream.wait_event(state.gather_event)
|
128 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
129 |
state.computed_u = u
|
130 |
+
state.scattered_u = torch.empty_like(p.to_local(),
|
131 |
+
dtype=torch.bfloat16)
|
132 |
+
state.compute_event = torch.cuda.Event()
|
133 |
+
state.compute_event.record()
|
134 |
+
u = None
|
135 |
|
136 |
|
137 |
@torch.no_grad()
|
|
|
141 |
"""
|
142 |
|
143 |
with torch.cuda.stream(comm_stream):
|
144 |
+
if state.compute_event is None:
|
145 |
+
raise RuntimeError("Compute event must be set before scatter.")
|
146 |
+
comm_stream.wait_event(state.compute_event)
|
147 |
+
|
148 |
if rank == state.worker_rank:
|
149 |
num_ranks = dist.get_world_size(group=state.process_group)
|
|
|
|
|
|
|
|
|
150 |
# Clear the gathered gradient to free memory
|
151 |
state.gathered_grad = None
|
152 |
|
|
|
156 |
else:
|
157 |
scatter_list = None
|
158 |
|
|
|
159 |
torch.distributed.scatter(
|
160 |
+
state.scattered_u,
|
161 |
scatter_list=scatter_list,
|
162 |
src=state.worker_rank,
|
163 |
group=state.process_group,
|
164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
state.scatter_event = torch.cuda.Event()
|
166 |
state.scatter_event.record()
|
167 |
+
scatter_list = None
|
168 |
|
169 |
|
170 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
|
177 |
if state.scatter_event is None:
|
178 |
raise RuntimeError("Scatter event must be set before update")
|
179 |
compute_stream.wait_event(state.scatter_event)
|
180 |
+
u_dtensor = DTensor.from_local(
|
181 |
+
state.scattered_u,
|
182 |
+
placements=p.placements,
|
183 |
+
device_mesh=p.device_mesh,
|
184 |
+
)
|
185 |
+
|
186 |
+
state.scattered_u = u_dtensor
|
187 |
+
|
188 |
if rank == state.worker_rank:
|
189 |
# Free computed_u
|
190 |
state.computed_u = None
|
191 |
|
192 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
193 |
+
state.scattered_u = None
|
194 |
+
u_dtensor = None
|
195 |
|
196 |
|
197 |
def default_is_muon(name, x):
|
|
|
390 |
else:
|
391 |
g = buf
|
392 |
|
393 |
+
u = _zeropower_via_newtonschulz5(g.bfloat16(),
|
394 |
+
steps=group["ns_steps"])
|
395 |
|
396 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
397 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
|
449 |
def enqueue_computes(start_idx, chunk_size):
|
450 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
451 |
state = param_to_state[id(p)]
|
452 |
+
_compute_u(p, state, group["ns_steps"], self.rank,
|
453 |
self.compute_stream)
|
454 |
|
455 |
def enqueue_scatters(start_idx, chunk_size):
|
|
|
482 |
# Wait the last update_param to finish
|
483 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
484 |
|
485 |
+
@staticmethod
|
486 |
+
def _fused_adamw(
|
487 |
+
params: list[torch.Tensor],
|
488 |
+
grads: list[torch.Tensor],
|
489 |
+
exp_avgs: list[torch.Tensor],
|
490 |
+
exp_avg_sqs: list[torch.Tensor],
|
491 |
+
max_exp_avg_sqs: list[torch.Tensor],
|
492 |
+
state_steps: list[torch.Tensor],
|
493 |
+
amsgrad: bool,
|
494 |
+
beta1: float,
|
495 |
+
beta2: float,
|
496 |
+
lr: Union[float, torch.Tensor],
|
497 |
+
weight_decay: float,
|
498 |
+
eps: float,
|
499 |
+
maximize: bool,
|
500 |
+
) -> None:
|
501 |
+
if not params:
|
502 |
+
return
|
503 |
+
|
504 |
+
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
|
505 |
+
# treating it as a scalar.
|
506 |
+
lr_dict: Optional[DeviceDict] = ({
|
507 |
+
lr.device: lr
|
508 |
+
} if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
|
509 |
+
None)
|
510 |
+
grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
|
511 |
+
[
|
512 |
+
params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
|
513 |
+
state_steps
|
514 |
+
] # type: ignore[list-item]
|
515 |
+
)
|
516 |
+
for (device, _), (
|
517 |
+
(
|
518 |
+
device_params_,
|
519 |
+
device_grads_,
|
520 |
+
device_exp_avgs_,
|
521 |
+
device_exp_avg_sqs_,
|
522 |
+
device_max_exp_avg_sqs,
|
523 |
+
device_state_steps_,
|
524 |
+
),
|
525 |
+
_,
|
526 |
+
) in grouped_tensors.items():
|
527 |
+
device_params = cast(list[torch.Tensor], device_params_)
|
528 |
+
device_grads = cast(list[torch.Tensor], device_grads_)
|
529 |
+
device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
|
530 |
+
device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
|
531 |
+
device_state_steps = cast(list[torch.Tensor], device_state_steps_)
|
532 |
+
|
533 |
+
if lr_dict is not None and device not in lr_dict:
|
534 |
+
lr_dict[device] = lr.to(
|
535 |
+
device=device,
|
536 |
+
non_blocking=True) # type: ignore[union-attr]
|
537 |
+
lr = lr_dict[device]
|
538 |
+
torch._foreach_add_(device_state_steps, 1)
|
539 |
+
func = torch._fused_adamw_
|
540 |
+
func(
|
541 |
+
device_params,
|
542 |
+
device_grads,
|
543 |
+
device_exp_avgs,
|
544 |
+
device_exp_avg_sqs,
|
545 |
+
device_max_exp_avg_sqs, # type: ignore[arg-type]
|
546 |
+
device_state_steps,
|
547 |
+
amsgrad=amsgrad,
|
548 |
+
lr=lr, # type: ignore[arg-type]
|
549 |
+
beta1=beta1,
|
550 |
+
beta2=beta2,
|
551 |
+
weight_decay=weight_decay,
|
552 |
+
eps=eps,
|
553 |
+
maximize=maximize,
|
554 |
+
)
|
555 |
+
|
556 |
def step(self, closure=None):
|
557 |
"""Perform a single optimization step.
|
558 |
|
|
|
629 |
# AdamW backup #
|
630 |
############################
|
631 |
|
632 |
+
params_with_grads = []
|
633 |
+
grads = []
|
634 |
+
moment1 = []
|
635 |
+
moment2 = []
|
636 |
+
max_exp_avg_sqs = []
|
637 |
+
state_steps = []
|
638 |
lr = group["lr"]
|
639 |
beta1, beta2 = group["adamw_betas"]
|
640 |
eps = group["adamw_eps"]
|
|
|
645 |
if g is None:
|
646 |
continue
|
647 |
state = self.state[p]
|
648 |
+
params_with_grads.append(p)
|
649 |
+
grads.append(g)
|
650 |
if "step" not in state:
|
651 |
+
state["step"] = (torch.zeros((),
|
652 |
+
dtype=torch.float32,
|
653 |
+
device=p.device))
|
654 |
state["moment1"] = torch.zeros_like(g)
|
655 |
state["moment2"] = torch.zeros_like(g)
|
656 |
+
moment1.append(state["moment1"])
|
657 |
+
moment2.append(state["moment2"])
|
658 |
+
if not isinstance(state["step"], torch.Tensor):
|
659 |
+
step_tensor = torch.tensor(state["step"],
|
660 |
+
dtype=torch.float32,
|
661 |
+
device=p.device)
|
662 |
+
else:
|
663 |
+
step_tensor = state["step"]
|
664 |
+
state_steps.append(step_tensor)
|
665 |
+
|
666 |
+
self._fused_adamw(
|
667 |
+
params_with_grads,
|
668 |
+
grads,
|
669 |
+
moment1,
|
670 |
+
moment2,
|
671 |
+
max_exp_avg_sqs,
|
672 |
+
state_steps,
|
673 |
+
amsgrad=False,
|
674 |
+
beta1=beta1,
|
675 |
+
beta2=beta2,
|
676 |
+
lr=lr,
|
677 |
+
weight_decay=weight_decay,
|
678 |
+
eps=eps,
|
679 |
+
maximize=False,
|
680 |
+
)
|
681 |
|
682 |
return loss
|
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import torch
|
2 |
-
from . import
|
3 |
-
ops = torch.ops.
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
-
return f"
|
|
|
1 |
import torch
|
2 |
+
from . import _optimizer_ee6ed44_dirty
|
3 |
+
ops = torch.ops._optimizer_ee6ed44_dirty
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
+
return f"_optimizer_ee6ed44_dirty::{op_name}"
|
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_20250911094409.abi3.so
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:e62682b711f002505bb17c170b2bb233f8d389510ff8e2e0a753ee96d11d0746
|
3 |
-
size 1750128
|
|
|
|
|
|
|
|
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_ee6ed44_dirty.abi3.so
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5eedf56e661a7d314727e40f192236dbd9696f62ba21f11e366643f2662c03a4
|
3 |
+
size 1750088
|
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py
CHANGED
@@ -2,6 +2,7 @@ import logging
|
|
2 |
import math
|
3 |
import types
|
4 |
from dataclasses import dataclass
|
|
|
5 |
|
6 |
import torch
|
7 |
import torch.distributed as dist
|
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
|
|
12 |
|
13 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
14 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
|
|
|
|
15 |
@torch.no_grad()
|
16 |
def _zeropower_via_newtonschulz5(G, steps):
|
17 |
"""
|
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
24 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
25 |
"""
|
26 |
assert len(G.shape) == 2
|
27 |
-
|
28 |
X = G # no manual typecast
|
|
|
29 |
if G.size(0) > G.size(1):
|
30 |
X = X.T
|
31 |
# Ensure spectral norm is at most 1
|
32 |
X = X / (X.norm() + 1e-7)
|
33 |
-
X = X.bfloat16()
|
34 |
# Perform the NS iterations
|
35 |
-
for
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
A = X @ X.T
|
37 |
# B = (
|
38 |
# b * A + c * A @ A
|
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
43 |
|
44 |
if G.size(0) > G.size(1):
|
45 |
X = X.T
|
46 |
-
return X
|
47 |
|
48 |
|
49 |
@dataclass
|
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
65 |
Gather the gradients to worker_rank.
|
66 |
If none_grad is True, free p.grad after the gather.
|
67 |
"""
|
68 |
-
g = p.grad
|
69 |
-
|
70 |
-
if rank == state.worker_rank:
|
71 |
-
num_ranks = dist.get_world_size(group=state.process_group)
|
72 |
-
gather_list = [
|
73 |
-
torch.empty_like(g.to_local()) for _ in range(num_ranks)
|
74 |
-
]
|
75 |
-
else:
|
76 |
-
gather_list = None
|
77 |
-
|
78 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
torch.distributed.gather(
|
80 |
g.to_local(),
|
81 |
dst=state.worker_rank,
|
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
92 |
else:
|
93 |
state.gathered_grad = None
|
94 |
state.gather_event = None
|
|
|
95 |
if none_grad:
|
96 |
# We can safely free p.grad without calling record_stream:
|
97 |
# p.grad.to_local().record_stream(comm_stream)
|
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
104 |
|
105 |
|
106 |
@torch.no_grad()
|
107 |
-
def _compute_u(state, steps, rank, compute_stream):
|
108 |
"""
|
109 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
110 |
"""
|
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
115 |
compute_stream.wait_event(state.gather_event)
|
116 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
117 |
state.computed_u = u
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
|
124 |
|
125 |
@torch.no_grad()
|
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
|
|
129 |
"""
|
130 |
|
131 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
132 |
if rank == state.worker_rank:
|
133 |
num_ranks = dist.get_world_size(group=state.process_group)
|
134 |
-
if state.compute_event is None:
|
135 |
-
raise RuntimeError("Compute event must be set before scatter.")
|
136 |
-
comm_stream.wait_event(state.compute_event)
|
137 |
-
|
138 |
# Clear the gathered gradient to free memory
|
139 |
state.gathered_grad = None
|
140 |
|
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
|
|
144 |
else:
|
145 |
scatter_list = None
|
146 |
|
147 |
-
u_received = torch.empty_like(p.to_local())
|
148 |
torch.distributed.scatter(
|
149 |
-
|
150 |
scatter_list=scatter_list,
|
151 |
src=state.worker_rank,
|
152 |
group=state.process_group,
|
153 |
)
|
154 |
-
u_dtensor = DTensor.from_local(
|
155 |
-
u_received,
|
156 |
-
placements=p.placements,
|
157 |
-
device_mesh=p.device_mesh,
|
158 |
-
)
|
159 |
-
|
160 |
-
state.scattered_u = u_dtensor
|
161 |
state.scatter_event = torch.cuda.Event()
|
162 |
state.scatter_event.record()
|
|
|
163 |
|
164 |
|
165 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
172 |
if state.scatter_event is None:
|
173 |
raise RuntimeError("Scatter event must be set before update")
|
174 |
compute_stream.wait_event(state.scatter_event)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
if rank == state.worker_rank:
|
176 |
# Free computed_u
|
177 |
state.computed_u = None
|
178 |
|
179 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
|
|
|
|
180 |
|
181 |
|
182 |
def default_is_muon(name, x):
|
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
|
|
375 |
else:
|
376 |
g = buf
|
377 |
|
378 |
-
u = _zeropower_via_newtonschulz5(g,
|
|
|
379 |
|
380 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
381 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
|
|
433 |
def enqueue_computes(start_idx, chunk_size):
|
434 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
435 |
state = param_to_state[id(p)]
|
436 |
-
_compute_u(state, group["ns_steps"], self.rank,
|
437 |
self.compute_stream)
|
438 |
|
439 |
def enqueue_scatters(start_idx, chunk_size):
|
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
|
|
466 |
# Wait the last update_param to finish
|
467 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
469 |
def step(self, closure=None):
|
470 |
"""Perform a single optimization step.
|
471 |
|
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
|
|
542 |
# AdamW backup #
|
543 |
############################
|
544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
545 |
lr = group["lr"]
|
546 |
beta1, beta2 = group["adamw_betas"]
|
547 |
eps = group["adamw_eps"]
|
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
|
|
552 |
if g is None:
|
553 |
continue
|
554 |
state = self.state[p]
|
|
|
|
|
555 |
if "step" not in state:
|
556 |
-
state["step"] =
|
|
|
|
|
557 |
state["moment1"] = torch.zeros_like(g)
|
558 |
state["moment2"] = torch.zeros_like(g)
|
559 |
-
state["
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
573 |
|
574 |
return loss
|
|
|
2 |
import math
|
3 |
import types
|
4 |
from dataclasses import dataclass
|
5 |
+
from typing import Optional, Union, cast
|
6 |
|
7 |
import torch
|
8 |
import torch.distributed as dist
|
|
|
13 |
|
14 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
15 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
16 |
+
# Muon's Newton–Schulz iteration causes high variance in singular values
|
17 |
+
# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
|
18 |
@torch.no_grad()
|
19 |
def _zeropower_via_newtonschulz5(G, steps):
|
20 |
"""
|
|
|
27 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
28 |
"""
|
29 |
assert len(G.shape) == 2
|
30 |
+
assert G.dtype == torch.bfloat16
|
31 |
X = G # no manual typecast
|
32 |
+
|
33 |
if G.size(0) > G.size(1):
|
34 |
X = X.T
|
35 |
# Ensure spectral norm is at most 1
|
36 |
X = X / (X.norm() + 1e-7)
|
|
|
37 |
# Perform the NS iterations
|
38 |
+
for a, b, c in [
|
39 |
+
(4.0848, -6.8946, 2.9270),
|
40 |
+
(3.9505, -6.3029, 2.6377),
|
41 |
+
(3.7418, -5.5913, 2.3037),
|
42 |
+
(2.8769, -3.1427, 1.2046),
|
43 |
+
(2.8366, -3.0525, 1.2012),
|
44 |
+
]:
|
45 |
A = X @ X.T
|
46 |
# B = (
|
47 |
# b * A + c * A @ A
|
|
|
52 |
|
53 |
if G.size(0) > G.size(1):
|
54 |
X = X.T
|
55 |
+
return X
|
56 |
|
57 |
|
58 |
@dataclass
|
|
|
74 |
Gather the gradients to worker_rank.
|
75 |
If none_grad is True, free p.grad after the gather.
|
76 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
with torch.cuda.stream(comm_stream):
|
78 |
+
g = p.grad
|
79 |
+
|
80 |
+
if rank == state.worker_rank:
|
81 |
+
num_ranks = dist.get_world_size(group=state.process_group)
|
82 |
+
gather_list = [
|
83 |
+
torch.empty_like(g.to_local(), dtype=torch.bfloat16)
|
84 |
+
for _ in range(num_ranks)
|
85 |
+
]
|
86 |
+
else:
|
87 |
+
gather_list = None
|
88 |
+
|
89 |
+
g = g.to(torch.bfloat16)
|
90 |
torch.distributed.gather(
|
91 |
g.to_local(),
|
92 |
dst=state.worker_rank,
|
|
|
103 |
else:
|
104 |
state.gathered_grad = None
|
105 |
state.gather_event = None
|
106 |
+
gather_list = None
|
107 |
if none_grad:
|
108 |
# We can safely free p.grad without calling record_stream:
|
109 |
# p.grad.to_local().record_stream(comm_stream)
|
|
|
116 |
|
117 |
|
118 |
@torch.no_grad()
|
119 |
+
def _compute_u(p, state, steps, rank, compute_stream):
|
120 |
"""
|
121 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
122 |
"""
|
|
|
127 |
compute_stream.wait_event(state.gather_event)
|
128 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
129 |
state.computed_u = u
|
130 |
+
state.scattered_u = torch.empty_like(p.to_local(),
|
131 |
+
dtype=torch.bfloat16)
|
132 |
+
state.compute_event = torch.cuda.Event()
|
133 |
+
state.compute_event.record()
|
134 |
+
u = None
|
135 |
|
136 |
|
137 |
@torch.no_grad()
|
|
|
141 |
"""
|
142 |
|
143 |
with torch.cuda.stream(comm_stream):
|
144 |
+
if state.compute_event is None:
|
145 |
+
raise RuntimeError("Compute event must be set before scatter.")
|
146 |
+
comm_stream.wait_event(state.compute_event)
|
147 |
+
|
148 |
if rank == state.worker_rank:
|
149 |
num_ranks = dist.get_world_size(group=state.process_group)
|
|
|
|
|
|
|
|
|
150 |
# Clear the gathered gradient to free memory
|
151 |
state.gathered_grad = None
|
152 |
|
|
|
156 |
else:
|
157 |
scatter_list = None
|
158 |
|
|
|
159 |
torch.distributed.scatter(
|
160 |
+
state.scattered_u,
|
161 |
scatter_list=scatter_list,
|
162 |
src=state.worker_rank,
|
163 |
group=state.process_group,
|
164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
state.scatter_event = torch.cuda.Event()
|
166 |
state.scatter_event.record()
|
167 |
+
scatter_list = None
|
168 |
|
169 |
|
170 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
|
177 |
if state.scatter_event is None:
|
178 |
raise RuntimeError("Scatter event must be set before update")
|
179 |
compute_stream.wait_event(state.scatter_event)
|
180 |
+
u_dtensor = DTensor.from_local(
|
181 |
+
state.scattered_u,
|
182 |
+
placements=p.placements,
|
183 |
+
device_mesh=p.device_mesh,
|
184 |
+
)
|
185 |
+
|
186 |
+
state.scattered_u = u_dtensor
|
187 |
+
|
188 |
if rank == state.worker_rank:
|
189 |
# Free computed_u
|
190 |
state.computed_u = None
|
191 |
|
192 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
193 |
+
state.scattered_u = None
|
194 |
+
u_dtensor = None
|
195 |
|
196 |
|
197 |
def default_is_muon(name, x):
|
|
|
390 |
else:
|
391 |
g = buf
|
392 |
|
393 |
+
u = _zeropower_via_newtonschulz5(g.bfloat16(),
|
394 |
+
steps=group["ns_steps"])
|
395 |
|
396 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
397 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
|
449 |
def enqueue_computes(start_idx, chunk_size):
|
450 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
451 |
state = param_to_state[id(p)]
|
452 |
+
_compute_u(p, state, group["ns_steps"], self.rank,
|
453 |
self.compute_stream)
|
454 |
|
455 |
def enqueue_scatters(start_idx, chunk_size):
|
|
|
482 |
# Wait the last update_param to finish
|
483 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
484 |
|
485 |
+
@staticmethod
|
486 |
+
def _fused_adamw(
|
487 |
+
params: list[torch.Tensor],
|
488 |
+
grads: list[torch.Tensor],
|
489 |
+
exp_avgs: list[torch.Tensor],
|
490 |
+
exp_avg_sqs: list[torch.Tensor],
|
491 |
+
max_exp_avg_sqs: list[torch.Tensor],
|
492 |
+
state_steps: list[torch.Tensor],
|
493 |
+
amsgrad: bool,
|
494 |
+
beta1: float,
|
495 |
+
beta2: float,
|
496 |
+
lr: Union[float, torch.Tensor],
|
497 |
+
weight_decay: float,
|
498 |
+
eps: float,
|
499 |
+
maximize: bool,
|
500 |
+
) -> None:
|
501 |
+
if not params:
|
502 |
+
return
|
503 |
+
|
504 |
+
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
|
505 |
+
# treating it as a scalar.
|
506 |
+
lr_dict: Optional[DeviceDict] = ({
|
507 |
+
lr.device: lr
|
508 |
+
} if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
|
509 |
+
None)
|
510 |
+
grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
|
511 |
+
[
|
512 |
+
params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
|
513 |
+
state_steps
|
514 |
+
] # type: ignore[list-item]
|
515 |
+
)
|
516 |
+
for (device, _), (
|
517 |
+
(
|
518 |
+
device_params_,
|
519 |
+
device_grads_,
|
520 |
+
device_exp_avgs_,
|
521 |
+
device_exp_avg_sqs_,
|
522 |
+
device_max_exp_avg_sqs,
|
523 |
+
device_state_steps_,
|
524 |
+
),
|
525 |
+
_,
|
526 |
+
) in grouped_tensors.items():
|
527 |
+
device_params = cast(list[torch.Tensor], device_params_)
|
528 |
+
device_grads = cast(list[torch.Tensor], device_grads_)
|
529 |
+
device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
|
530 |
+
device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
|
531 |
+
device_state_steps = cast(list[torch.Tensor], device_state_steps_)
|
532 |
+
|
533 |
+
if lr_dict is not None and device not in lr_dict:
|
534 |
+
lr_dict[device] = lr.to(
|
535 |
+
device=device,
|
536 |
+
non_blocking=True) # type: ignore[union-attr]
|
537 |
+
lr = lr_dict[device]
|
538 |
+
torch._foreach_add_(device_state_steps, 1)
|
539 |
+
func = torch._fused_adamw_
|
540 |
+
func(
|
541 |
+
device_params,
|
542 |
+
device_grads,
|
543 |
+
device_exp_avgs,
|
544 |
+
device_exp_avg_sqs,
|
545 |
+
device_max_exp_avg_sqs, # type: ignore[arg-type]
|
546 |
+
device_state_steps,
|
547 |
+
amsgrad=amsgrad,
|
548 |
+
lr=lr, # type: ignore[arg-type]
|
549 |
+
beta1=beta1,
|
550 |
+
beta2=beta2,
|
551 |
+
weight_decay=weight_decay,
|
552 |
+
eps=eps,
|
553 |
+
maximize=maximize,
|
554 |
+
)
|
555 |
+
|
556 |
def step(self, closure=None):
|
557 |
"""Perform a single optimization step.
|
558 |
|
|
|
629 |
# AdamW backup #
|
630 |
############################
|
631 |
|
632 |
+
params_with_grads = []
|
633 |
+
grads = []
|
634 |
+
moment1 = []
|
635 |
+
moment2 = []
|
636 |
+
max_exp_avg_sqs = []
|
637 |
+
state_steps = []
|
638 |
lr = group["lr"]
|
639 |
beta1, beta2 = group["adamw_betas"]
|
640 |
eps = group["adamw_eps"]
|
|
|
645 |
if g is None:
|
646 |
continue
|
647 |
state = self.state[p]
|
648 |
+
params_with_grads.append(p)
|
649 |
+
grads.append(g)
|
650 |
if "step" not in state:
|
651 |
+
state["step"] = (torch.zeros((),
|
652 |
+
dtype=torch.float32,
|
653 |
+
device=p.device))
|
654 |
state["moment1"] = torch.zeros_like(g)
|
655 |
state["moment2"] = torch.zeros_like(g)
|
656 |
+
moment1.append(state["moment1"])
|
657 |
+
moment2.append(state["moment2"])
|
658 |
+
if not isinstance(state["step"], torch.Tensor):
|
659 |
+
step_tensor = torch.tensor(state["step"],
|
660 |
+
dtype=torch.float32,
|
661 |
+
device=p.device)
|
662 |
+
else:
|
663 |
+
step_tensor = state["step"]
|
664 |
+
state_steps.append(step_tensor)
|
665 |
+
|
666 |
+
self._fused_adamw(
|
667 |
+
params_with_grads,
|
668 |
+
grads,
|
669 |
+
moment1,
|
670 |
+
moment2,
|
671 |
+
max_exp_avg_sqs,
|
672 |
+
state_steps,
|
673 |
+
amsgrad=False,
|
674 |
+
beta1=beta1,
|
675 |
+
beta2=beta2,
|
676 |
+
lr=lr,
|
677 |
+
weight_decay=weight_decay,
|
678 |
+
eps=eps,
|
679 |
+
maximize=False,
|
680 |
+
)
|
681 |
|
682 |
return loss
|
torch-ext/optimizer/muon.py
CHANGED
@@ -2,6 +2,7 @@ import logging
|
|
2 |
import math
|
3 |
import types
|
4 |
from dataclasses import dataclass
|
|
|
5 |
|
6 |
import torch
|
7 |
import torch.distributed as dist
|
@@ -12,6 +13,8 @@ logger = logging.getLogger(__name__)
|
|
12 |
|
13 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
14 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
|
|
|
|
15 |
@torch.no_grad()
|
16 |
def _zeropower_via_newtonschulz5(G, steps):
|
17 |
"""
|
@@ -24,15 +27,21 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
24 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
25 |
"""
|
26 |
assert len(G.shape) == 2
|
27 |
-
|
28 |
X = G # no manual typecast
|
|
|
29 |
if G.size(0) > G.size(1):
|
30 |
X = X.T
|
31 |
# Ensure spectral norm is at most 1
|
32 |
X = X / (X.norm() + 1e-7)
|
33 |
-
X = X.bfloat16()
|
34 |
# Perform the NS iterations
|
35 |
-
for
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
A = X @ X.T
|
37 |
# B = (
|
38 |
# b * A + c * A @ A
|
@@ -43,7 +52,7 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
43 |
|
44 |
if G.size(0) > G.size(1):
|
45 |
X = X.T
|
46 |
-
return X
|
47 |
|
48 |
|
49 |
@dataclass
|
@@ -65,17 +74,19 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
65 |
Gather the gradients to worker_rank.
|
66 |
If none_grad is True, free p.grad after the gather.
|
67 |
"""
|
68 |
-
g = p.grad
|
69 |
-
|
70 |
-
if rank == state.worker_rank:
|
71 |
-
num_ranks = dist.get_world_size(group=state.process_group)
|
72 |
-
gather_list = [
|
73 |
-
torch.empty_like(g.to_local()) for _ in range(num_ranks)
|
74 |
-
]
|
75 |
-
else:
|
76 |
-
gather_list = None
|
77 |
-
|
78 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
torch.distributed.gather(
|
80 |
g.to_local(),
|
81 |
dst=state.worker_rank,
|
@@ -92,6 +103,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
92 |
else:
|
93 |
state.gathered_grad = None
|
94 |
state.gather_event = None
|
|
|
95 |
if none_grad:
|
96 |
# We can safely free p.grad without calling record_stream:
|
97 |
# p.grad.to_local().record_stream(comm_stream)
|
@@ -104,7 +116,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
104 |
|
105 |
|
106 |
@torch.no_grad()
|
107 |
-
def _compute_u(state, steps, rank, compute_stream):
|
108 |
"""
|
109 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
110 |
"""
|
@@ -115,11 +127,11 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
115 |
compute_stream.wait_event(state.gather_event)
|
116 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
117 |
state.computed_u = u
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
|
124 |
|
125 |
@torch.no_grad()
|
@@ -129,12 +141,12 @@ def _scatter(p, state, rank, comm_stream):
|
|
129 |
"""
|
130 |
|
131 |
with torch.cuda.stream(comm_stream):
|
|
|
|
|
|
|
|
|
132 |
if rank == state.worker_rank:
|
133 |
num_ranks = dist.get_world_size(group=state.process_group)
|
134 |
-
if state.compute_event is None:
|
135 |
-
raise RuntimeError("Compute event must be set before scatter.")
|
136 |
-
comm_stream.wait_event(state.compute_event)
|
137 |
-
|
138 |
# Clear the gathered gradient to free memory
|
139 |
state.gathered_grad = None
|
140 |
|
@@ -144,22 +156,15 @@ def _scatter(p, state, rank, comm_stream):
|
|
144 |
else:
|
145 |
scatter_list = None
|
146 |
|
147 |
-
u_received = torch.empty_like(p.to_local())
|
148 |
torch.distributed.scatter(
|
149 |
-
|
150 |
scatter_list=scatter_list,
|
151 |
src=state.worker_rank,
|
152 |
group=state.process_group,
|
153 |
)
|
154 |
-
u_dtensor = DTensor.from_local(
|
155 |
-
u_received,
|
156 |
-
placements=p.placements,
|
157 |
-
device_mesh=p.device_mesh,
|
158 |
-
)
|
159 |
-
|
160 |
-
state.scattered_u = u_dtensor
|
161 |
state.scatter_event = torch.cuda.Event()
|
162 |
state.scatter_event.record()
|
|
|
163 |
|
164 |
|
165 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
@@ -172,11 +177,21 @@ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
172 |
if state.scatter_event is None:
|
173 |
raise RuntimeError("Scatter event must be set before update")
|
174 |
compute_stream.wait_event(state.scatter_event)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
if rank == state.worker_rank:
|
176 |
# Free computed_u
|
177 |
state.computed_u = None
|
178 |
|
179 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
|
|
|
|
180 |
|
181 |
|
182 |
def default_is_muon(name, x):
|
@@ -375,7 +390,8 @@ class Muon(torch.optim.Optimizer):
|
|
375 |
else:
|
376 |
g = buf
|
377 |
|
378 |
-
u = _zeropower_via_newtonschulz5(g,
|
|
|
379 |
|
380 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
381 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
@@ -433,7 +449,7 @@ class Muon(torch.optim.Optimizer):
|
|
433 |
def enqueue_computes(start_idx, chunk_size):
|
434 |
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
435 |
state = param_to_state[id(p)]
|
436 |
-
_compute_u(state, group["ns_steps"], self.rank,
|
437 |
self.compute_stream)
|
438 |
|
439 |
def enqueue_scatters(start_idx, chunk_size):
|
@@ -466,6 +482,77 @@ class Muon(torch.optim.Optimizer):
|
|
466 |
# Wait for the last update_param to finish
|
467 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
469 |
def step(self, closure=None):
|
470 |
"""Perform a single optimization step.
|
471 |
|
@@ -542,6 +629,12 @@ class Muon(torch.optim.Optimizer):
|
|
542 |
# AdamW backup #
|
543 |
############################
|
544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
545 |
lr = group["lr"]
|
546 |
beta1, beta2 = group["adamw_betas"]
|
547 |
eps = group["adamw_eps"]
|
@@ -552,23 +645,38 @@ class Muon(torch.optim.Optimizer):
|
|
552 |
if g is None:
|
553 |
continue
|
554 |
state = self.state[p]
|
|
|
|
|
555 |
if "step" not in state:
|
556 |
-
state["step"] =
|
|
|
|
|
557 |
state["moment1"] = torch.zeros_like(g)
|
558 |
state["moment2"] = torch.zeros_like(g)
|
559 |
-
state["
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
573 |
|
574 |
return loss
|
|
|
2 |
import math
|
3 |
import types
|
4 |
from dataclasses import dataclass
|
5 |
+
from typing import Optional, Union, cast
|
6 |
|
7 |
import torch
|
8 |
import torch.distributed as dist
|
|
|
13 |
|
14 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
15 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
16 |
+
# Muon's Newton–Schulz iteration causes high variance in singular values
|
17 |
+
# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
|
18 |
@torch.no_grad()
|
19 |
def _zeropower_via_newtonschulz5(G, steps):
|
20 |
"""
|
|
|
27 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
28 |
"""
|
29 |
assert len(G.shape) == 2
|
30 |
+
assert G.dtype == torch.bfloat16
|
31 |
X = G # no manual typecast
|
32 |
+
|
33 |
if G.size(0) > G.size(1):
|
34 |
X = X.T
|
35 |
# Ensure spectral norm is at most 1
|
36 |
X = X / (X.norm() + 1e-7)
|
|
|
37 |
# Perform the NS iterations
|
38 |
+
for a, b, c in [
|
39 |
+
(4.0848, -6.8946, 2.9270),
|
40 |
+
(3.9505, -6.3029, 2.6377),
|
41 |
+
(3.7418, -5.5913, 2.3037),
|
42 |
+
(2.8769, -3.1427, 1.2046),
|
43 |
+
(2.8366, -3.0525, 1.2012),
|
44 |
+
]:
|
45 |
A = X @ X.T
|
46 |
# B = (
|
47 |
# b * A + c * A @ A
|
|
|
52 |
|
53 |
if G.size(0) > G.size(1):
|
54 |
X = X.T
|
55 |
+
return X
|
56 |
|
57 |
|
58 |
@dataclass
|
|
|
74 |
Gather the gradients to worker_rank.
|
75 |
If none_grad is True, free p.grad after the gather.
|
76 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
with torch.cuda.stream(comm_stream):
|
78 |
+
g = p.grad
|
79 |
+
|
80 |
+
if rank == state.worker_rank:
|
81 |
+
num_ranks = dist.get_world_size(group=state.process_group)
|
82 |
+
gather_list = [
|
83 |
+
torch.empty_like(g.to_local(), dtype=torch.bfloat16)
|
84 |
+
for _ in range(num_ranks)
|
85 |
+
]
|
86 |
+
else:
|
87 |
+
gather_list = None
|
88 |
+
|
89 |
+
g = g.to(torch.bfloat16)
|
90 |
torch.distributed.gather(
|
91 |
g.to_local(),
|
92 |
dst=state.worker_rank,
|
|
|
103 |
else:
|
104 |
state.gathered_grad = None
|
105 |
state.gather_event = None
|
106 |
+
gather_list = None
|
107 |
if none_grad:
|
108 |
# We can safely free p.grad without calling record_stream:
|
109 |
# p.grad.to_local().record_stream(comm_stream)
|
|
|
116 |
|
117 |
|
118 |
@torch.no_grad()
|
119 |
+
def _compute_u(p, state, steps, rank, compute_stream):
|
120 |
"""
|
121 |
On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
|
122 |
"""
|
|
|
127 |
compute_stream.wait_event(state.gather_event)
|
128 |
u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
|
129 |
state.computed_u = u
|
130 |
+
state.scattered_u = torch.empty_like(p.to_local(),
|
131 |
+
dtype=torch.bfloat16)
|
132 |
+
state.compute_event = torch.cuda.Event()
|
133 |
+
state.compute_event.record()
|
134 |
+
u = None
|
135 |
|
136 |
|
137 |
@torch.no_grad()
|
|
|
141 |
"""
|
142 |
|
143 |
with torch.cuda.stream(comm_stream):
|
144 |
+
if state.compute_event is None:
|
145 |
+
raise RuntimeError("Compute event must be set before scatter.")
|
146 |
+
comm_stream.wait_event(state.compute_event)
|
147 |
+
|
148 |
if rank == state.worker_rank:
|
149 |
num_ranks = dist.get_world_size(group=state.process_group)
|
|
|
|
|
|
|
|
|
150 |
# Clear the gathered gradient to free memory
|
151 |
state.gathered_grad = None
|
152 |
|
|
|
156 |
else:
|
157 |
scatter_list = None
|
158 |
|
|
|
159 |
torch.distributed.scatter(
|
160 |
+
state.scattered_u,
|
161 |
scatter_list=scatter_list,
|
162 |
src=state.worker_rank,
|
163 |
group=state.process_group,
|
164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
state.scatter_event = torch.cuda.Event()
|
166 |
state.scatter_event.record()
|
167 |
+
scatter_list = None
|
168 |
|
169 |
|
170 |
def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
|
|
|
177 |
if state.scatter_event is None:
|
178 |
raise RuntimeError("Scatter event must be set before update")
|
179 |
compute_stream.wait_event(state.scatter_event)
|
180 |
+
u_dtensor = DTensor.from_local(
|
181 |
+
state.scattered_u,
|
182 |
+
placements=p.placements,
|
183 |
+
device_mesh=p.device_mesh,
|
184 |
+
)
|
185 |
+
|
186 |
+
state.scattered_u = u_dtensor
|
187 |
+
|
188 |
if rank == state.worker_rank:
|
189 |
# Free computed_u
|
190 |
state.computed_u = None
|
191 |
|
192 |
Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
|
193 |
+
state.scattered_u = None
|
194 |
+
u_dtensor = None
|
195 |
|
196 |
|
197 |
def default_is_muon(name, x):
|
|
|
390 |
else:
|
391 |
g = buf
|
392 |
|
393 |
+
u = _zeropower_via_newtonschulz5(g.bfloat16(),
|
394 |
+
steps=group["ns_steps"])
|
395 |
|
396 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
397 |
Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
|
|
|
449 |
def enqueue_computes(start_idx, chunk_size):
    """Call ``_compute_u`` for each parameter in one chunk.

    The chunk is ``ordered_params[start_idx:start_idx + chunk_size]``;
    every call is given ``self.compute_stream``.
    """
    stop_idx = start_idx + chunk_size
    for p in ordered_params[start_idx:stop_idx]:
        _compute_u(
            p,
            param_to_state[id(p)],
            group["ns_steps"],
            self.rank,
            self.compute_stream,
        )
|
454 |
|
455 |
def enqueue_scatters(start_idx, chunk_size):
|
|
|
482 |
# Wait for the last update_param to finish
|
483 |
torch.cuda.current_stream().wait_stream(self.compute_stream)
|
484 |
|
485 |
+
@staticmethod
|
486 |
+
def _fused_adamw(
|
487 |
+
params: list[torch.Tensor],
|
488 |
+
grads: list[torch.Tensor],
|
489 |
+
exp_avgs: list[torch.Tensor],
|
490 |
+
exp_avg_sqs: list[torch.Tensor],
|
491 |
+
max_exp_avg_sqs: list[torch.Tensor],
|
492 |
+
state_steps: list[torch.Tensor],
|
493 |
+
amsgrad: bool,
|
494 |
+
beta1: float,
|
495 |
+
beta2: float,
|
496 |
+
lr: Union[float, torch.Tensor],
|
497 |
+
weight_decay: float,
|
498 |
+
eps: float,
|
499 |
+
maximize: bool,
|
500 |
+
) -> None:
|
501 |
+
if not params:
|
502 |
+
return
|
503 |
+
|
504 |
+
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
|
505 |
+
# treating it as a scalar.
|
506 |
+
lr_dict: Optional[DeviceDict] = ({
|
507 |
+
lr.device: lr
|
508 |
+
} if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
|
509 |
+
None)
|
510 |
+
grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
|
511 |
+
[
|
512 |
+
params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
|
513 |
+
state_steps
|
514 |
+
] # type: ignore[list-item]
|
515 |
+
)
|
516 |
+
for (device, _), (
|
517 |
+
(
|
518 |
+
device_params_,
|
519 |
+
device_grads_,
|
520 |
+
device_exp_avgs_,
|
521 |
+
device_exp_avg_sqs_,
|
522 |
+
device_max_exp_avg_sqs,
|
523 |
+
device_state_steps_,
|
524 |
+
),
|
525 |
+
_,
|
526 |
+
) in grouped_tensors.items():
|
527 |
+
device_params = cast(list[torch.Tensor], device_params_)
|
528 |
+
device_grads = cast(list[torch.Tensor], device_grads_)
|
529 |
+
device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
|
530 |
+
device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
|
531 |
+
device_state_steps = cast(list[torch.Tensor], device_state_steps_)
|
532 |
+
|
533 |
+
if lr_dict is not None and device not in lr_dict:
|
534 |
+
lr_dict[device] = lr.to(
|
535 |
+
device=device,
|
536 |
+
non_blocking=True) # type: ignore[union-attr]
|
537 |
+
lr = lr_dict[device]
|
538 |
+
torch._foreach_add_(device_state_steps, 1)
|
539 |
+
func = torch._fused_adamw_
|
540 |
+
func(
|
541 |
+
device_params,
|
542 |
+
device_grads,
|
543 |
+
device_exp_avgs,
|
544 |
+
device_exp_avg_sqs,
|
545 |
+
device_max_exp_avg_sqs, # type: ignore[arg-type]
|
546 |
+
device_state_steps,
|
547 |
+
amsgrad=amsgrad,
|
548 |
+
lr=lr, # type: ignore[arg-type]
|
549 |
+
beta1=beta1,
|
550 |
+
beta2=beta2,
|
551 |
+
weight_decay=weight_decay,
|
552 |
+
eps=eps,
|
553 |
+
maximize=maximize,
|
554 |
+
)
|
555 |
+
|
556 |
def step(self, closure=None):
|
557 |
"""Perform a single optimization step.
|
558 |
|
|
|
629 |
# AdamW backup #
|
630 |
############################
|
631 |
|
632 |
+
params_with_grads = []
|
633 |
+
grads = []
|
634 |
+
moment1 = []
|
635 |
+
moment2 = []
|
636 |
+
max_exp_avg_sqs = []
|
637 |
+
state_steps = []
|
638 |
lr = group["lr"]
|
639 |
beta1, beta2 = group["adamw_betas"]
|
640 |
eps = group["adamw_eps"]
|
|
|
645 |
if g is None:
|
646 |
continue
|
647 |
state = self.state[p]
|
648 |
+
params_with_grads.append(p)
|
649 |
+
grads.append(g)
|
650 |
if "step" not in state:
|
651 |
+
state["step"] = (torch.zeros((),
|
652 |
+
dtype=torch.float32,
|
653 |
+
device=p.device))
|
654 |
state["moment1"] = torch.zeros_like(g)
|
655 |
state["moment2"] = torch.zeros_like(g)
|
656 |
+
moment1.append(state["moment1"])
|
657 |
+
moment2.append(state["moment2"])
|
658 |
+
if not isinstance(state["step"], torch.Tensor):
|
659 |
+
step_tensor = torch.tensor(state["step"],
|
660 |
+
dtype=torch.float32,
|
661 |
+
device=p.device)
|
662 |
+
else:
|
663 |
+
step_tensor = state["step"]
|
664 |
+
state_steps.append(step_tensor)
|
665 |
+
|
666 |
+
self._fused_adamw(
|
667 |
+
params_with_grads,
|
668 |
+
grads,
|
669 |
+
moment1,
|
670 |
+
moment2,
|
671 |
+
max_exp_avg_sqs,
|
672 |
+
state_steps,
|
673 |
+
amsgrad=False,
|
674 |
+
beta1=beta1,
|
675 |
+
beta2=beta2,
|
676 |
+
lr=lr,
|
677 |
+
weight_decay=weight_decay,
|
678 |
+
eps=eps,
|
679 |
+
maximize=False,
|
680 |
+
)
|
681 |
|
682 |
return loss
|