TaehyunKimMotif committed
Commit db36e39 · 1 Parent(s): 80187d6

applied lint

.pre-commit-config.yaml CHANGED
@@ -15,6 +15,7 @@ repos:
     rev: v1.34.0
     hooks:
       - id: typos
+        exclude: '.gitattributes'
   - repo: https://github.com/PyCQA/isort
     rev: 6.0.1
     hooks:
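The new `exclude` key stops the typos hook from scanning `.gitattributes`. As I understand pre-commit's behavior, `exclude` is a Python regular expression matched against each candidate file path (so the unescaped dot also matches any single character). A one-line check of the pattern, for illustration only:

import re

# pre-commit applies `exclude` as a regex search over file paths.
assert re.search(r'.gitattributes', '.gitattributes') is not None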
optimizer/dummy.cu CHANGED
@@ -3,4 +3,4 @@ namespace {
 __global__ void dummy() {
   // This kernel does nothing but serves as a placeholder
 }
-}
+}  // namespace
torch-ext/optimizer/muon.py CHANGED
@@ -59,7 +59,9 @@ def _gather(p, state, rank, comm_stream, none_grad):
 
     if rank == state.worker_rank:
         num_ranks = dist.get_world_size(group=state.process_group)
-        gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
+        gather_list = [
+            torch.empty_like(g.to_local()) for _ in range(num_ranks)
+        ]
     else:
         gather_list = None
 
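The reflowed comprehension above builds the receive buffers for a rank-to-worker gather. A minimal sketch of that pattern, assuming an initialized default process group; `WORKER_RANK` and `gather_to_worker` are illustrative names, not from this repo:

import torch
import torch.distributed as dist

WORKER_RANK = 0

def gather_to_worker(local_shard):
    # Only the destination rank allocates receive buffers; every other
    # rank passes gather_list=None.
    if dist.get_rank() == WORKER_RANK:
        num_ranks = dist.get_world_size()
        gather_list = [
            torch.empty_like(local_shard) for _ in range(num_ranks)
        ]
    else:
        gather_list = None
    dist.gather(local_shard, gather_list, dst=WORKER_RANK)
    return gather_list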
@@ -73,8 +75,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
     if rank == state.worker_rank:
         if state.gathered_grad is not None:
             raise RuntimeError(
-                "Gather event already exists, which should not happen."
-            )
+                "Gather event already exists, which should not happen.")
         state.gathered_grad = torch.cat(gather_list, dim=0)
         state.gather_event = torch.cuda.Event()
         state.gather_event.record()
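The `gather_event` recorded here is what lets work queued on another stream wait for the gathered gradient without blocking the host. A hedged sketch of that handshake; the tensors and the matmul are stand-ins, not the repo's actual gather or Newton-Schulz compute:

import torch

comm_stream = torch.cuda.Stream()
compute_stream = torch.cuda.Stream()

with torch.cuda.stream(comm_stream):
    gathered = torch.randn(1024, 1024, device="cuda")  # stands in for the gather
    gather_event = torch.cuda.Event()
    gather_event.record()  # records on comm_stream (the current stream)

with torch.cuda.stream(compute_stream):
    compute_stream.wait_event(gather_event)  # order compute after the gather
    u = gathered @ gathered  # stands in for the Newton-Schulz iteration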
@@ -240,9 +241,10 @@ class Muon(torch.optim.Optimizer):
         """
         Get the shard mesh for a parameter p on the given rank.
         """
-        assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
+        assert isinstance(
+            p, DTensor), "Parallel Muon only supports DTensor parameters."
 
-        if p.placements == (Shard(dim=0),):
+        if p.placements == (Shard(dim=0), ):
             # Case for FSDP
             return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
         elif p.placements == (Replicate(), Shard(dim=0)):
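The placement tuples compared here encode how a DTensor is distributed: `(Shard(dim=0),)` on a 1-D mesh is the plain FSDP layout, while `(Replicate(), Shard(dim=0))` on a 2-D mesh is the HSDP layout. A minimal sketch, assuming a recent PyTorch with the public torch.distributed.tensor API and an initialized 8-GPU environment; sizes are illustrative:

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh = init_device_mesh("cuda", (8,))  # 1-D mesh -> FSDP-style sharding
p = distribute_tensor(torch.randn(1024, 512), mesh, [Shard(0)])
assert p.placements == (Shard(dim=0), )
group = p.device_mesh.get_group(mesh_dim=0)  # process group along mesh dim 0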
@@ -269,11 +271,12 @@ class Muon(torch.optim.Optimizer):
             total_flops += flops
 
         if self.debug:
-            print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
+            print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
+                  flush=True)
 
-        ordered_params = sorted(
-            params, key=lambda p: param_to_flops[id(p)], reverse=True
-        )
+        ordered_params = sorted(params,
+                                key=lambda p: param_to_flops[id(p)],
+                                reverse=True)
 
         round_robin = 0
         mesh = None
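Sorting by `param_to_flops` puts the most expensive Newton-Schulz problems first, which the round-robin assignment below then spreads across ranks; largest-first round-robin is a standard greedy load-balancing heuristic. A self-contained sketch of that idea, where the cost model is my assumption for illustration, not the repo's:

def order_and_assign(shapes, num_ranks, ns_steps=5):
    # Rough per-matrix cost: ns_steps matmul-dominated iterations.
    def est_flops(shape):
        m, n = shape
        return ns_steps * 2 * m * n * min(m, n)

    ordered = sorted(shapes, key=est_flops, reverse=True)
    # Largest-first round-robin: the heaviest matrices land on distinct ranks.
    return [(shape, i % num_ranks) for i, shape in enumerate(ordered)]

print(order_and_assign([(4096, 4096), (1024, 1024), (4096, 1024)], 2))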
@@ -369,28 +372,29 @@ class Muon(torch.optim.Optimizer):
                 p.grad = g
 
         param_to_state, ordered_params = self.init_state_and_assign_params(
-            params, group
-        )
+            params, group)
 
         def enqueue_gathers(start_idx, chunk_size):
-            for p in ordered_params[start_idx : start_idx + chunk_size]:
+            for p in ordered_params[start_idx:start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-                _gather(p, state, self.rank, self.comm_stream, group["none_grad"])
+                _gather(p, state, self.rank, self.comm_stream,
+                        group["none_grad"])
 
         def enqueue_computes(start_idx, chunk_size):
-            for p in ordered_params[start_idx : start_idx + chunk_size]:
+            for p in ordered_params[start_idx:start_idx + chunk_size]:
                 state = param_to_state[id(p)]
-                _compute_u(state, group["ns_steps"], self.rank, self.compute_stream)
+                _compute_u(state, group["ns_steps"], self.rank,
+                           self.compute_stream)
 
         def enqueue_scatters(start_idx, chunk_size):
-            for p in ordered_params[start_idx : start_idx + chunk_size]:
+            for p in ordered_params[start_idx:start_idx + chunk_size]:
                 state = param_to_state[id(p)]
                 adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
-                _scatter(
-                    p, state, lr, adjusted_lr, weight_decay, self.rank, self.comm_stream
-                )
+                _scatter(p, state, lr, adjusted_lr, weight_decay, self.rank,
+                         self.comm_stream)
 
-        chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
+        chunk_size = dist.get_world_size(param_to_state[id(
+            params[0])].process_group)
 
         # Wait grad update
         self.comm_stream.wait_stream(torch.cuda.current_stream())
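The three `enqueue_*` helpers are issued in chunks of one world size so that gathers, Newton-Schulz computes, and scatters for different chunks overlap across the comm and compute streams. The driving loop is not part of this hunk; the following is a generic software-pipelining sketch consistent with those helpers, not the repo's exact scheduler:

def run_pipeline(num_params, chunk_size, gather, compute, scatter):
    # Each stage trails the previous one by one chunk; the range runs two
    # chunks past the end so compute and scatter drain after the last gather.
    for start in range(0, num_params + 2 * chunk_size, chunk_size):
        if start < num_params:
            gather(start, chunk_size)
        if 0 <= start - chunk_size < num_params:
            compute(start - chunk_size, chunk_size)
        if 0 <= start - 2 * chunk_size < num_params:
            scatter(start - 2 * chunk_size, chunk_size)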
@@ -436,15 +440,16 @@ class Muon(torch.optim.Optimizer):
                     continue
                 if isinstance(p.data, DTensor):
                     if all(
-                        isinstance(placement, Replicate) for placement in p.placements
-                    ):
+                            isinstance(placement, Replicate)
+                            for placement in p.placements):
                         param_tensors.append(p)
                     else:
                         param_dtensors.append(p)
                 elif isinstance(p.data, torch.Tensor):
                     param_tensors.append(p)
                 else:
-                    raise TypeError(f"Unsupported parameter type: {type(p.data)}")
+                    raise TypeError(
+                        f"Unsupported parameter type: {type(p.data)}")
 
             if self.debug:
                 print(
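The `all(...)` test above routes fully replicated DTensors to the plain-tensor path, since each rank already holds the whole tensor, while anything carrying a `Shard` placement needs the parallel gather/compute/scatter path. The same dispatch distilled into a predicate; the helper name is illustrative:

from torch.distributed.tensor import Replicate

def needs_parallel_path(p) -> bool:
    # Replicated-only placements mean every rank has the full tensor.
    return not all(isinstance(pl, Replicate) for pl in p.placements)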
@@ -479,7 +484,9 @@ class Muon(torch.optim.Optimizer):
             # AdamW backup #
             ############################
 
-            params = [p for p in group["params"] if not self.state[p]["use_muon"]]
+            params = [
+                p for p in group["params"] if not self.state[p]["use_muon"]
+            ]
             lr = group["lr"]
             beta1, beta2 = group["adamw_betas"]
             eps = group["adamw_eps"]
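The reflowed comprehension selects every parameter whose state is tagged `use_muon=False`; those parameters take the AdamW fallback path. A hedged sketch of how such tagging is commonly done; the rule below (2-D weight matrices only) is an assumption for illustration, not necessarily this repo's policy:

def mark_use_muon(optimizer):
    for group in optimizer.param_groups:
        for p in group["params"]:
            # Muon orthogonalizes matrices; biases, norms, and embeddings
            # typically fall back to AdamW.
            optimizer.state[p]["use_muon"] = p.ndim == 2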