Commit db36e39
Parent(s): 80187d6

applied lint

Files changed:
- .pre-commit-config.yaml (+1 -0)
- optimizer/dummy.cu (+1 -1)
- torch-ext/optimizer/muon.py (+31 -24)
.pre-commit-config.yaml CHANGED

@@ -15,6 +15,7 @@ repos:
     rev: v1.34.0
     hooks:
       - id: typos
+        exclude: '.gitattributes'
   - repo: https://github.com/PyCQA/isort
     rev: 6.0.1
     hooks:
optimizer/dummy.cu CHANGED

@@ -3,4 +3,4 @@ namespace {
 __global__ void dummy() {
     // This kernel does nothing but serves as a placeholder
 }
-}
+}  // namespace
torch-ext/optimizer/muon.py CHANGED

@@ -59,7 +59,9 @@ def _gather(p, state, rank, comm_stream, none_grad):
 
     if rank == state.worker_rank:
         num_ranks = dist.get_world_size(group=state.process_group)
-        gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
+        gather_list = [
+            torch.empty_like(g.to_local()) for _ in range(num_ranks)
+        ]
     else:
         gather_list = None
 
@@ -73,8 +75,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
         if rank == state.worker_rank:
            if state.gathered_grad is not None:
                raise RuntimeError(
-                    "Gather event already exists, which should not happen."
-                )
+                    "Gather event already exists, which should not happen.")
            state.gathered_grad = torch.cat(gather_list, dim=0)
            state.gather_event = torch.cuda.Event()
            state.gather_event.record()
@@ -240,9 +241,10 @@ class Muon(torch.optim.Optimizer):
         """
         Get the shard mesh for a parameter p on the given rank.
         """
-        assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
+        assert isinstance(
+            p, DTensor), "Parallel Muon only supports DTensor parameters."
 
-        if p.placements == (Shard(dim=0),):
+        if p.placements == (Shard(dim=0), ):
             # Case for FSDP
             return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
         elif p.placements == (Replicate(), Shard(dim=0)):
@@ -269,11 +271,12 @@ class Muon(torch.optim.Optimizer):
             total_flops += flops
 
         if self.debug:
-            print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
+            print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
+                  flush=True)
 
-        ordered_params = sorted(
-            params, key=lambda p: param_to_flops[id(p)], reverse=True
-        )
+        ordered_params = sorted(params,
+                                key=lambda p: param_to_flops[id(p)],
+                                reverse=True)
 
         round_robin = 0
         mesh = None
@@ -369,28 +372,29 @@ class Muon(torch.optim.Optimizer):
                 p.grad = g
 
         param_to_state, ordered_params = self.init_state_and_assign_params(
-            params, group
-        )
+            params, group)
 
         def enqueue_gathers(start_idx, chunk_size):
-            for p in ordered_params[start_idx : start_idx + chunk_size]:
+            for p in ordered_params[start_idx:start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-                _gather(p, state, self.rank, self.comm_stream, group["none_grad"])
+                _gather(p, state, self.rank, self.comm_stream,
+                        group["none_grad"])
 
         def enqueue_computes(start_idx, chunk_size):
-            for p in ordered_params[start_idx : start_idx + chunk_size]:
+            for p in ordered_params[start_idx:start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-                _compute_u(state, group["ns_steps"], self.rank, self.compute_stream)
+                _compute_u(state, group["ns_steps"], self.rank,
+                           self.compute_stream)
 
         def enqueue_scatters(start_idx, chunk_size):
-            for p in ordered_params[start_idx : start_idx + chunk_size]:
+            for p in ordered_params[start_idx:start_idx + chunk_size]:
                 state = param_to_state[id(p)]
                 adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
-                _scatter(
-                    p, state, lr, adjusted_lr, weight_decay, self.rank, self.comm_stream
-                )
+                _scatter(p, state, lr, adjusted_lr, weight_decay, self.rank,
+                         self.comm_stream)
 
-        chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
+        chunk_size = dist.get_world_size(param_to_state[id(
+            params[0])].process_group)
 
         # Wait grad update
         self.comm_stream.wait_stream(torch.cuda.current_stream())
@@ -436,15 +440,16 @@ class Muon(torch.optim.Optimizer):
                     continue
                 if isinstance(p.data, DTensor):
                     if all(
-                        isinstance(placement, Replicate) for placement in p.placements
-                    ):
+                            isinstance(placement, Replicate)
+                            for placement in p.placements):
                         param_tensors.append(p)
                     else:
                         param_dtensors.append(p)
                 elif isinstance(p.data, torch.Tensor):
                     param_tensors.append(p)
                 else:
-                    raise TypeError(f"Unsupported parameter type: {type(p.data)}")
+                    raise TypeError(
+                        f"Unsupported parameter type: {type(p.data)}")
 
                 if self.debug:
                     print(
@@ -479,7 +484,9 @@ class Muon(torch.optim.Optimizer):
         # AdamW backup #
         ############################
 
-        params = [p for p in group["params"] if not self.state[p]["use_muon"]]
+        params = [
+            p for p in group["params"] if not self.state[p]["use_muon"]
+        ]
         lr = group["lr"]
         beta1, beta2 = group["adamw_betas"]
         eps = group["adamw_eps"]