TaehyunKim committed
Commit 9d8a41d · unverified · 2 Parent(s): 15b5d41 db36e39

Merge pull request #1 from MotifTechnologies/pre-commit_test_and_apply_lint

.github/workflows/pre-commit.yml ADDED
@@ -0,0 +1,30 @@
+ name: pre-commit
+
+ on:
+   pull_request:
+   push:
+     branches: [ main, master ]
+
+ jobs:
+   run-pre-commit:
+     runs-on: ubuntu-latest
+     permissions:
+       contents: read
+       pull-requests: read
+     steps:
+       - uses: actions/checkout@v4
+
+       - uses: actions/setup-python@v5
+         with:
+           python-version: "3.11"
+
+       - name: Cache pre-commit
+         uses: actions/cache@v4
+         with:
+           path: ~/.cache/pre-commit
+           key: pre-commit-${{ runner.os }}-${{ hashFiles('.pre-commit-config.yaml') }}
+           restore-keys: |
+             pre-commit-${{ runner.os }}-
+
+       - name: Run pre-commit
+         uses: pre-commit/[email protected]
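The cache step keys the pre-commit environment cache on a hash of `.pre-commit-config.yaml`, so bumping any hook `rev` invalidates the cache and rebuilds the hook environments. A rough Python illustration of how that key is formed (GitHub's `hashFiles` uses SHA-256 internally, but its exact digest construction differs, so treat this as a sketch only):

```python
# Sketch only: approximates "pre-commit-${{ runner.os }}-${{ hashFiles('.pre-commit-config.yaml') }}".
import hashlib
from pathlib import Path

config_digest = hashlib.sha256(
    Path(".pre-commit-config.yaml").read_bytes()).hexdigest()
runner_os = "Linux"  # value of runner.os on ubuntu-latest
cache_key = f"pre-commit-{runner_os}-{config_digest}"
print(cache_key)
```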
.pre-commit-config.yaml ADDED
@@ -0,0 +1,37 @@
+ default_install_hook_types:
+   - pre-commit
+   - commit-msg
+ default_stages:
+   - pre-commit # Run locally
+   - manual # Run in CI
+ exclude: '(build|result)/.*'
+ repos:
+   - repo: https://github.com/google/yapf
+     rev: v0.43.0
+     hooks:
+       - id: yapf
+         args: [--in-place, --verbose]
+   - repo: https://github.com/crate-ci/typos
+     rev: v1.34.0
+     hooks:
+       - id: typos
+         exclude: '.gitattributes'
+   - repo: https://github.com/PyCQA/isort
+     rev: 6.0.1
+     hooks:
+       - id: isort
+   - repo: https://github.com/pre-commit/mirrors-clang-format
+     rev: v20.1.3
+     hooks:
+       - id: clang-format
+         types_or: [c++, cuda]
+         args: [--style=file, --verbose]
+   - repo: https://github.com/jackdewinter/pymarkdown
+     rev: v0.9.29
+     hooks:
+       - id: pymarkdown
+         args: [fix]
+   - repo: https://github.com/rhysd/actionlint
+     rev: v1.7.7
+     hooks:
+       - id: actionlint
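The top-level `exclude` entry is a Python regular expression that pre-commit matches against each file's repository-relative path (search semantics); matching files are skipped by every hook. A small illustration of how the pattern above behaves, with hypothetical example paths:

```python
# Sketch of how pre-commit's `exclude: '(build|result)/.*'` filters paths.
import re

EXCLUDE = re.compile(r"(build|result)/.*")

for path in ["build/lib/dummy.cu", "result/run.log", "torch-ext/optimizer/muon.py"]:
    skipped = EXCLUDE.search(path) is not None
    print(f"{path}: {'skipped' if skipped else 'checked by hooks'}")
```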
README.md CHANGED
@@ -10,7 +10,7 @@ Optimizer is a python package that provides:
  - PyTorch implementation of recent optimizer algorithms
  - with support for parallelism techniques for efficient large-scale training.
 
- ### Currently implemented
+ ## Currently implemented
  - [Parallel Muon with FSDP2](./docs/muon/parallel_muon.pdf)
 
  ## Usage
@@ -31,4 +31,49 @@ optim = optimizer.Muon(
      momentum=0.9,
      weight_decay=1e-4,
  )
- ```
+ ```
+
+ ## Pre-commit Hooks
+
+ This project uses [pre-commit](https://pre-commit.com/) to automatically check and format code before commits.
+
+ ### Setup
+
+ 1. Install pre-commit:
+
+ ```bash
+ pip install pre-commit
+ ```
+
+ 2. Install the git hooks:
+
+ ```bash
+ pre-commit install
+ ```
+
+ Once installed, the configured hooks will run automatically on each commit.
+
+ ### Included Hooks
+
+ The following tools are run via pre-commit:
+
+ - **[yapf](https://github.com/google/yapf)** – Python code formatter
+ - **[typos](https://github.com/crate-ci/typos)** – Spell checker for common typos
+ - **[isort](https://github.com/PyCQA/isort)** – Organizes and sorts Python imports
+ - **[clang-format](https://clang.llvm.org/docs/ClangFormat.html)** – Formats C++/CUDA code (`--style=file`)
+ - **[pymarkdown](https://github.com/jackdewinter/pymarkdown)** – Lints and auto-fixes Markdown files
+ - **[actionlint](https://github.com/rhysd/actionlint)** – Validates GitHub Actions workflows
+
+ ### Usage
+
+ - Run all checks on the entire codebase:
+
+ ```bash
+ pre-commit run --all-files
+ ```
+
+ - Run a specific hook (example: isort):
+
+ ```bash
+ pre-commit run isort --all-files
+ ```
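For a sense of what the yapf and isort hooks listed above change, here is a small before/after sketch (the function and variable names are made up for illustration; exact output depends on each tool's configuration):

```python
# Before the hooks run, imports may be out of order and a long call may sit on
# one overlong line:
#
#   import torch
#   import os
#   value = compute(os.path.abspath("."), torch.zeros(1), "a fairly long third argument")
#
# After `pre-commit run --all-files`, isort places standard-library imports
# before third-party ones, and yapf wraps the long call with aligned
# continuation arguments, in the same style as the muon.py changes in this
# merge:
import os

import torch


def compute(first_argument, second_argument, third_argument):
    """Hypothetical helper, present only to make the example self-contained."""
    return (first_argument, second_argument, third_argument)


value = compute(os.path.abspath("."), torch.zeros(1),
                "a fairly long third argument")
```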
optimizer/dummy.cu CHANGED
@@ -3,4 +3,4 @@ namespace {
  __global__ void dummy() {
      // This kernel does nothing but serves as a placeholder
  }
- }
+ } // namespace
torch-ext/optimizer/muon.py CHANGED
@@ -59,7 +59,9 @@ def _gather(p, state, rank, comm_stream, none_grad):
 
      if rank == state.worker_rank:
          num_ranks = dist.get_world_size(group=state.process_group)
-         gather_list = [torch.empty_like(g.to_local()) for _ in range(num_ranks)]
+         gather_list = [
+             torch.empty_like(g.to_local()) for _ in range(num_ranks)
+         ]
      else:
          gather_list = None
 
@@ -73,8 +75,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
      if rank == state.worker_rank:
          if state.gathered_grad is not None:
              raise RuntimeError(
-                 "Gather event already exists, which should not happen."
-             )
+                 "Gather event already exists, which should not happen.")
          state.gathered_grad = torch.cat(gather_list, dim=0)
          state.gather_event = torch.cuda.Event()
          state.gather_event.record()
@@ -240,9 +241,10 @@ class Muon(torch.optim.Optimizer):
          """
          Get the shard mesh for a parameter p on the given rank.
          """
-         assert isinstance(p, DTensor), "Parallel Muon only supports DTensor parameters."
+         assert isinstance(
+             p, DTensor), "Parallel Muon only supports DTensor parameters."
 
-         if p.placements == (Shard(dim=0),):
+         if p.placements == (Shard(dim=0), ):
              # Case for FSDP
              return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
          elif p.placements == (Replicate(), Shard(dim=0)):
@@ -269,11 +271,12 @@ class Muon(torch.optim.Optimizer):
              total_flops += flops
 
          if self.debug:
-             print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
+             print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
+                   flush=True)
 
-         ordered_params = sorted(
-             params, key=lambda p: param_to_flops[id(p)], reverse=True
-         )
+         ordered_params = sorted(params,
+                                 key=lambda p: param_to_flops[id(p)],
+                                 reverse=True)
 
          round_robin = 0
          mesh = None
@@ -369,28 +372,29 @@ class Muon(torch.optim.Optimizer):
                  p.grad = g
 
          param_to_state, ordered_params = self.init_state_and_assign_params(
-             params, group
-         )
+             params, group)
 
          def enqueue_gathers(start_idx, chunk_size):
-             for p in ordered_params[start_idx : start_idx + chunk_size]:
+             for p in ordered_params[start_idx:start_idx + chunk_size]:
                  state = param_to_state[id(p)]
-                 _gather(p, state, self.rank, self.comm_stream, group["none_grad"])
+                 _gather(p, state, self.rank, self.comm_stream,
+                         group["none_grad"])
 
          def enqueue_computes(start_idx, chunk_size):
-             for p in ordered_params[start_idx : start_idx + chunk_size]:
+             for p in ordered_params[start_idx:start_idx + chunk_size]:
                  state = param_to_state[id(p)]
-                 _compute_u(state, group["ns_steps"], self.rank, self.compute_stream)
+                 _compute_u(state, group["ns_steps"], self.rank,
+                            self.compute_stream)
 
          def enqueue_scatters(start_idx, chunk_size):
-             for p in ordered_params[start_idx : start_idx + chunk_size]:
+             for p in ordered_params[start_idx:start_idx + chunk_size]:
                  state = param_to_state[id(p)]
                  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
-                 _scatter(
-                     p, state, lr, adjusted_lr, weight_decay, self.rank, self.comm_stream
-                 )
+                 _scatter(p, state, lr, adjusted_lr, weight_decay, self.rank,
+                          self.comm_stream)
 
-         chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
+         chunk_size = dist.get_world_size(param_to_state[id(
+             params[0])].process_group)
 
          # Wait grad update
          self.comm_stream.wait_stream(torch.cuda.current_stream())
@@ -436,15 +440,16 @@ class Muon(torch.optim.Optimizer):
                      continue
                  if isinstance(p.data, DTensor):
                      if all(
-                         isinstance(placement, Replicate) for placement in p.placements
-                     ):
+                             isinstance(placement, Replicate)
+                             for placement in p.placements):
                          param_tensors.append(p)
                      else:
                          param_dtensors.append(p)
                  elif isinstance(p.data, torch.Tensor):
                      param_tensors.append(p)
                  else:
-                     raise TypeError(f"Unsupported parameter type: {type(p.data)}")
+                     raise TypeError(
+                         f"Unsupported parameter type: {type(p.data)}")
 
              if self.debug:
                  print(
@@ -479,7 +484,9 @@ class Muon(torch.optim.Optimizer):
              # AdamW backup #
              ############################
 
-             params = [p for p in group["params"] if not self.state[p]["use_muon"]]
+             params = [
+                 p for p in group["params"] if not self.state[p]["use_muon"]
+             ]
              lr = group["lr"]
              beta1, beta2 = group["adamw_betas"]
              eps = group["adamw_eps"]
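The shard-mesh hunk above reformats an assert and a placement check: a parameter placed as `(Shard(dim=0),)` is treated as plain FSDP on a 1-D device mesh, while `(Replicate(), Shard(dim=0))` indicates a 2-D, HSDP-style mesh that replicates across the outer dimension and shards along the inner one. Below is a minimal standalone sketch of that dispatch, assuming a recent PyTorch with the public `torch.distributed.tensor` API; the helper name `shard_group_for` and the HSDP branch returning the inner mesh dimension are illustrative assumptions, not the optimizer's actual API:

```python
# Sketch only: mirrors the placement dispatch shown in the muon.py diff.
from torch.distributed.tensor import DTensor, Replicate, Shard


def shard_group_for(p):
    """Return the process group that holds shards of DTensor parameter `p`."""
    assert isinstance(
        p, DTensor), "Parallel Muon only supports DTensor parameters."

    if p.placements == (Shard(dim=0), ):
        # FSDP: 1-D mesh, parameter sharded along dim 0.
        return p.device_mesh.get_group(mesh_dim=0)
    elif p.placements == (Replicate(), Shard(dim=0)):
        # HSDP (assumption): replicated on the outer mesh dim, sharded on the
        # inner one, so the shard group lives on mesh dim 1.
        return p.device_mesh.get_group(mesh_dim=1)
    raise ValueError(f"Unsupported placements: {p.placements}")
```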