Commit
·
db36e39
1
Parent(s):
80187d6
applied lint
Browse files- .pre-commit-config.yaml +1 -0
- optimizer/dummy.cu +1 -1
- torch-ext/optimizer/muon.py +31 -24
.pre-commit-config.yaml
CHANGED
|
@@ -15,6 +15,7 @@ repos:
|
|
| 15 |
rev: v1.34.0
|
| 16 |
hooks:
|
| 17 |
- id: typos
|
|
|
|
| 18 |
- repo: https://github.com/PyCQA/isort
|
| 19 |
rev: 6.0.1
|
| 20 |
hooks:
|
|
|
|
| 15 |
rev: v1.34.0
|
| 16 |
hooks:
|
| 17 |
- id: typos
|
| 18 |
+
exclude: '.gitattributes'
|
| 19 |
- repo: https://github.com/PyCQA/isort
|
| 20 |
rev: 6.0.1
|
| 21 |
hooks:
|
optimizer/dummy.cu
CHANGED
|
@@ -3,4 +3,4 @@ namespace {
|
|
| 3 |
__global__ void dummy() {
|
| 4 |
// This kernel does nothing but serves as a placeholder
|
| 5 |
}
|
| 6 |
-
}
|
|
|
|
| 3 |
__global__ void dummy() {
|
| 4 |
// This kernel does nothing but serves as a placeholder
|
| 5 |
}
|
| 6 |
+
} // namespace
|
torch-ext/optimizer/muon.py
CHANGED
|
@@ -59,7 +59,9 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 59 |
|
| 60 |
if rank == state.worker_rank:
|
| 61 |
num_ranks = dist.get_world_size(group=state.process_group)
|
| 62 |
-
gather_list = [
|
|
|
|
|
|
|
| 63 |
else:
|
| 64 |
gather_list = None
|
| 65 |
|
|
@@ -73,8 +75,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
| 73 |
if rank == state.worker_rank:
|
| 74 |
if state.gathered_grad is not None:
|
| 75 |
raise RuntimeError(
|
| 76 |
-
"Gather event already exists, which should not happen."
|
| 77 |
-
)
|
| 78 |
state.gathered_grad = torch.cat(gather_list, dim=0)
|
| 79 |
state.gather_event = torch.cuda.Event()
|
| 80 |
state.gather_event.record()
|
|
@@ -240,9 +241,10 @@ class Muon(torch.optim.Optimizer):
|
|
| 240 |
"""
|
| 241 |
Get the shard mesh for a parameter p on the given rank.
|
| 242 |
"""
|
| 243 |
-
assert isinstance(
|
|
|
|
| 244 |
|
| 245 |
-
if p.placements == (Shard(dim=0),):
|
| 246 |
# Case for FSDP
|
| 247 |
return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
|
| 248 |
elif p.placements == (Replicate(), Shard(dim=0)):
|
|
@@ -269,11 +271,12 @@ class Muon(torch.optim.Optimizer):
|
|
| 269 |
total_flops += flops
|
| 270 |
|
| 271 |
if self.debug:
|
| 272 |
-
print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
|
|
|
|
| 273 |
|
| 274 |
-
ordered_params = sorted(
|
| 275 |
-
|
| 276 |
-
|
| 277 |
|
| 278 |
round_robin = 0
|
| 279 |
mesh = None
|
|
@@ -369,28 +372,29 @@ class Muon(torch.optim.Optimizer):
|
|
| 369 |
p.grad = g
|
| 370 |
|
| 371 |
param_to_state, ordered_params = self.init_state_and_assign_params(
|
| 372 |
-
params, group
|
| 373 |
-
)
|
| 374 |
|
| 375 |
def enqueue_gathers(start_idx, chunk_size):
|
| 376 |
-
for p in ordered_params[start_idx
|
| 377 |
state = param_to_state[id(p)]
|
| 378 |
-
_gather(p, state, self.rank, self.comm_stream,
|
|
|
|
| 379 |
|
| 380 |
def enqueue_computes(start_idx, chunk_size):
|
| 381 |
-
for p in ordered_params[start_idx
|
| 382 |
state = param_to_state[id(p)]
|
| 383 |
-
_compute_u(state, group["ns_steps"], self.rank,
|
|
|
|
| 384 |
|
| 385 |
def enqueue_scatters(start_idx, chunk_size):
|
| 386 |
-
for p in ordered_params[start_idx
|
| 387 |
state = param_to_state[id(p)]
|
| 388 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 389 |
-
_scatter(
|
| 390 |
-
|
| 391 |
-
)
|
| 392 |
|
| 393 |
-
chunk_size = dist.get_world_size(param_to_state[id(
|
|
|
|
| 394 |
|
| 395 |
# Wait grad update
|
| 396 |
self.comm_stream.wait_stream(torch.cuda.current_stream())
|
|
@@ -436,15 +440,16 @@ class Muon(torch.optim.Optimizer):
|
|
| 436 |
continue
|
| 437 |
if isinstance(p.data, DTensor):
|
| 438 |
if all(
|
| 439 |
-
|
| 440 |
-
|
| 441 |
param_tensors.append(p)
|
| 442 |
else:
|
| 443 |
param_dtensors.append(p)
|
| 444 |
elif isinstance(p.data, torch.Tensor):
|
| 445 |
param_tensors.append(p)
|
| 446 |
else:
|
| 447 |
-
raise TypeError(
|
|
|
|
| 448 |
|
| 449 |
if self.debug:
|
| 450 |
print(
|
|
@@ -479,7 +484,9 @@ class Muon(torch.optim.Optimizer):
|
|
| 479 |
# AdamW backup #
|
| 480 |
############################
|
| 481 |
|
| 482 |
-
params = [
|
|
|
|
|
|
|
| 483 |
lr = group["lr"]
|
| 484 |
beta1, beta2 = group["adamw_betas"]
|
| 485 |
eps = group["adamw_eps"]
|
|
|
|
| 59 |
|
| 60 |
if rank == state.worker_rank:
|
| 61 |
num_ranks = dist.get_world_size(group=state.process_group)
|
| 62 |
+
gather_list = [
|
| 63 |
+
torch.empty_like(g.to_local()) for _ in range(num_ranks)
|
| 64 |
+
]
|
| 65 |
else:
|
| 66 |
gather_list = None
|
| 67 |
|
|
|
|
| 75 |
if rank == state.worker_rank:
|
| 76 |
if state.gathered_grad is not None:
|
| 77 |
raise RuntimeError(
|
| 78 |
+
"Gather event already exists, which should not happen.")
|
|
|
|
| 79 |
state.gathered_grad = torch.cat(gather_list, dim=0)
|
| 80 |
state.gather_event = torch.cuda.Event()
|
| 81 |
state.gather_event.record()
|
|
|
|
| 241 |
"""
|
| 242 |
Get the shard mesh for a parameter p on the given rank.
|
| 243 |
"""
|
| 244 |
+
assert isinstance(
|
| 245 |
+
p, DTensor), "Parallel Muon only supports DTensor parameters."
|
| 246 |
|
| 247 |
+
if p.placements == (Shard(dim=0), ):
|
| 248 |
# Case for FSDP
|
| 249 |
return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
|
| 250 |
elif p.placements == (Replicate(), Shard(dim=0)):
|
|
|
|
| 271 |
total_flops += flops
|
| 272 |
|
| 273 |
if self.debug:
|
| 274 |
+
print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
|
| 275 |
+
flush=True)
|
| 276 |
|
| 277 |
+
ordered_params = sorted(params,
|
| 278 |
+
key=lambda p: param_to_flops[id(p)],
|
| 279 |
+
reverse=True)
|
| 280 |
|
| 281 |
round_robin = 0
|
| 282 |
mesh = None
|
|
|
|
| 372 |
p.grad = g
|
| 373 |
|
| 374 |
param_to_state, ordered_params = self.init_state_and_assign_params(
|
| 375 |
+
params, group)
|
|
|
|
| 376 |
|
| 377 |
def enqueue_gathers(start_idx, chunk_size):
|
| 378 |
+
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
| 379 |
state = param_to_state[id(p)]
|
| 380 |
+
_gather(p, state, self.rank, self.comm_stream,
|
| 381 |
+
group["none_grad"])
|
| 382 |
|
| 383 |
def enqueue_computes(start_idx, chunk_size):
|
| 384 |
+
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
| 385 |
state = param_to_state[id(p)]
|
| 386 |
+
_compute_u(state, group["ns_steps"], self.rank,
|
| 387 |
+
self.compute_stream)
|
| 388 |
|
| 389 |
def enqueue_scatters(start_idx, chunk_size):
|
| 390 |
+
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
| 391 |
state = param_to_state[id(p)]
|
| 392 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 393 |
+
_scatter(p, state, lr, adjusted_lr, weight_decay, self.rank,
|
| 394 |
+
self.comm_stream)
|
|
|
|
| 395 |
|
| 396 |
+
chunk_size = dist.get_world_size(param_to_state[id(
|
| 397 |
+
params[0])].process_group)
|
| 398 |
|
| 399 |
# Wait grad update
|
| 400 |
self.comm_stream.wait_stream(torch.cuda.current_stream())
|
|
|
|
| 440 |
continue
|
| 441 |
if isinstance(p.data, DTensor):
|
| 442 |
if all(
|
| 443 |
+
isinstance(placement, Replicate)
|
| 444 |
+
for placement in p.placements):
|
| 445 |
param_tensors.append(p)
|
| 446 |
else:
|
| 447 |
param_dtensors.append(p)
|
| 448 |
elif isinstance(p.data, torch.Tensor):
|
| 449 |
param_tensors.append(p)
|
| 450 |
else:
|
| 451 |
+
raise TypeError(
|
| 452 |
+
f"Unsupported parameter type: {type(p.data)}")
|
| 453 |
|
| 454 |
if self.debug:
|
| 455 |
print(
|
|
|
|
| 484 |
# AdamW backup #
|
| 485 |
############################
|
| 486 |
|
| 487 |
+
params = [
|
| 488 |
+
p for p in group["params"] if not self.state[p]["use_muon"]
|
| 489 |
+
]
|
| 490 |
lr = group["lr"]
|
| 491 |
beta1, beta2 = group["adamw_betas"]
|
| 492 |
eps = group["adamw_eps"]
|