Merge pull request #1 from MotifTechnologies/pre-commit_test_and_apply_lint
Browse files- .github/workflows/pre-commit.yml +30 -0
- .pre-commit-config.yaml +37 -0
- README.md +47 -2
- optimizer/dummy.cu +1 -1
- torch-ext/optimizer/muon.py +31 -24
.github/workflows/pre-commit.yml
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: pre-commit
|
2 |
+
|
3 |
+
on:
|
4 |
+
pull_request:
|
5 |
+
push:
|
6 |
+
branches: [ main, master ]
|
7 |
+
|
8 |
+
jobs:
|
9 |
+
run-pre-commit:
|
10 |
+
runs-on: ubuntu-latest
|
11 |
+
permissions:
|
12 |
+
contents: read
|
13 |
+
pull-requests: read
|
14 |
+
steps:
|
15 |
+
- uses: actions/checkout@v4
|
16 |
+
|
17 |
+
- uses: actions/setup-python@v5
|
18 |
+
with:
|
19 |
+
python-version: "3.11"
|
20 |
+
|
21 |
+
- name: Cache pre-commit
|
22 |
+
uses: actions/cache@v4
|
23 |
+
with:
|
24 |
+
path: ~/.cache/pre-commit
|
25 |
+
key: pre-commit-${{ runner.os }}-${{ hashFiles('.pre-commit-config.yaml') }}
|
26 |
+
restore-keys: |
|
27 |
+
pre-commit-${{ runner.os }}-
|
28 |
+
|
29 |
+
- name: Run pre-commit
|
30 |
+
uses: pre-commit/[email protected]
|
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
default_install_hook_types:
|
2 |
+
- pre-commit
|
3 |
+
- commit-msg
|
4 |
+
default_stages:
|
5 |
+
- pre-commit # Run locally
|
6 |
+
- manual # Run in CI
|
7 |
+
exclude: '(build|result)/.*'
|
8 |
+
repos:
|
9 |
+
- repo: https://github.com/google/yapf
|
10 |
+
rev: v0.43.0
|
11 |
+
hooks:
|
12 |
+
- id: yapf
|
13 |
+
args: [--in-place, --verbose]
|
14 |
+
- repo: https://github.com/crate-ci/typos
|
15 |
+
rev: v1.34.0
|
16 |
+
hooks:
|
17 |
+
- id: typos
|
18 |
+
exclude: '.gitattributes'
|
19 |
+
- repo: https://github.com/PyCQA/isort
|
20 |
+
rev: 6.0.1
|
21 |
+
hooks:
|
22 |
+
- id: isort
|
23 |
+
- repo: https://github.com/pre-commit/mirrors-clang-format
|
24 |
+
rev: v20.1.3
|
25 |
+
hooks:
|
26 |
+
- id: clang-format
|
27 |
+
types_or: [c++, cuda]
|
28 |
+
args: [--style=file, --verbose]
|
29 |
+
- repo: https://github.com/jackdewinter/pymarkdown
|
30 |
+
rev: v0.9.29
|
31 |
+
hooks:
|
32 |
+
- id: pymarkdown
|
33 |
+
args: [fix]
|
34 |
+
- repo: https://github.com/rhysd/actionlint
|
35 |
+
rev: v1.7.7
|
36 |
+
hooks:
|
37 |
+
- id: actionlint
|
README.md
CHANGED
@@ -10,7 +10,7 @@ Optimizer is a python package that provides:
|
|
10 |
- PyTorch implementation of recent optimizer algorithms
|
11 |
- with support for parallelism techniques for efficient large-scale training.
|
12 |
|
13 |
-
|
14 |
- [Parallel Muon with FSDP2](./docs/muon/parallel_muon.pdf)
|
15 |
|
16 |
## Usage
|
@@ -31,4 +31,49 @@ optim = optimizer.Muon(
|
|
31 |
momentum=0.9,
|
32 |
weight_decay=1e-4,
|
33 |
)
|
34 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
- PyTorch implementation of recent optimizer algorithms
|
11 |
- with support for parallelism techniques for efficient large-scale training.
|
12 |
|
13 |
+
## Currently implemented
|
14 |
- [Parallel Muon with FSDP2](./docs/muon/parallel_muon.pdf)
|
15 |
|
16 |
## Usage
|
|
|
31 |
momentum=0.9,
|
32 |
weight_decay=1e-4,
|
33 |
)
|
34 |
+
```
|
35 |
+
|
36 |
+
## Pre-commit Hooks
|
37 |
+
|
38 |
+
This project uses [pre-commit](https://pre-commit.com/) to automatically check and format code before commits.
|
39 |
+
|
40 |
+
### Setup
|
41 |
+
|
42 |
+
1. Install pre-commit:
|
43 |
+
|
44 |
+
```bash
|
45 |
+
pip install pre-commit
|
46 |
+
```
|
47 |
+
|
48 |
+
2. Install the git hooks:
|
49 |
+
|
50 |
+
```bash
|
51 |
+
pre-commit install
|
52 |
+
```
|
53 |
+
|
54 |
+
Once installed, the configured hooks will run automatically on each commit.
|
55 |
+
|
56 |
+
### Included Hooks
|
57 |
+
|
58 |
+
The following tools are run via pre-commit:
|
59 |
+
|
60 |
+
- **[yapf](https://github.com/google/yapf)** – Python code formatter
|
61 |
+
- **[typos](https://github.com/crate-ci/typos)** – Spell checker for common typos
|
62 |
+
- **[isort](https://github.com/PyCQA/isort)** – Organizes and sorts Python imports
|
63 |
+
- **[clang-format](https://clang.llvm.org/docs/ClangFormat.html)** – Formats C++/CUDA code (`--style=file`)
|
64 |
+
- **[pymarkdown](https://github.com/jackdewinter/pymarkdown)** – Lints and auto-fixes Markdown files
|
65 |
+
- **[actionlint](https://github.com/rhysd/actionlint)** – Validates GitHub Actions workflows
|
66 |
+
|
67 |
+
### Usage
|
68 |
+
|
69 |
+
- Run all checks on the entire codebase:
|
70 |
+
|
71 |
+
```bash
|
72 |
+
pre-commit run --all-files
|
73 |
+
```
|
74 |
+
|
75 |
+
- Run a specific hook (example: isort):
|
76 |
+
|
77 |
+
```bash
|
78 |
+
pre-commit run isort --all-files
|
79 |
+
```
|
optimizer/dummy.cu
CHANGED
@@ -3,4 +3,4 @@ namespace {
|
|
3 |
__global__ void dummy() {
|
4 |
// This kernel does nothing but serves as a placeholder
|
5 |
}
|
6 |
-
}
|
|
|
3 |
__global__ void dummy() {
|
4 |
// This kernel does nothing but serves as a placeholder
|
5 |
}
|
6 |
+
} // namespace
|
torch-ext/optimizer/muon.py
CHANGED
@@ -59,7 +59,9 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
59 |
|
60 |
if rank == state.worker_rank:
|
61 |
num_ranks = dist.get_world_size(group=state.process_group)
|
62 |
-
gather_list = [
|
|
|
|
|
63 |
else:
|
64 |
gather_list = None
|
65 |
|
@@ -73,8 +75,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
73 |
if rank == state.worker_rank:
|
74 |
if state.gathered_grad is not None:
|
75 |
raise RuntimeError(
|
76 |
-
"Gather event already exists, which should not happen."
|
77 |
-
)
|
78 |
state.gathered_grad = torch.cat(gather_list, dim=0)
|
79 |
state.gather_event = torch.cuda.Event()
|
80 |
state.gather_event.record()
|
@@ -240,9 +241,10 @@ class Muon(torch.optim.Optimizer):
|
|
240 |
"""
|
241 |
Get the shard mesh for a parameter p on the given rank.
|
242 |
"""
|
243 |
-
assert isinstance(
|
|
|
244 |
|
245 |
-
if p.placements == (Shard(dim=0),):
|
246 |
# Case for FSDP
|
247 |
return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
|
248 |
elif p.placements == (Replicate(), Shard(dim=0)):
|
@@ -269,11 +271,12 @@ class Muon(torch.optim.Optimizer):
|
|
269 |
total_flops += flops
|
270 |
|
271 |
if self.debug:
|
272 |
-
print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
|
|
|
273 |
|
274 |
-
ordered_params = sorted(
|
275 |
-
|
276 |
-
|
277 |
|
278 |
round_robin = 0
|
279 |
mesh = None
|
@@ -369,28 +372,29 @@ class Muon(torch.optim.Optimizer):
|
|
369 |
p.grad = g
|
370 |
|
371 |
param_to_state, ordered_params = self.init_state_and_assign_params(
|
372 |
-
params, group
|
373 |
-
)
|
374 |
|
375 |
def enqueue_gathers(start_idx, chunk_size):
|
376 |
-
for p in ordered_params[start_idx
|
377 |
state = param_to_state[id(p)]
|
378 |
-
_gather(p, state, self.rank, self.comm_stream,
|
|
|
379 |
|
380 |
def enqueue_computes(start_idx, chunk_size):
|
381 |
-
for p in ordered_params[start_idx
|
382 |
state = param_to_state[id(p)]
|
383 |
-
_compute_u(state, group["ns_steps"], self.rank,
|
|
|
384 |
|
385 |
def enqueue_scatters(start_idx, chunk_size):
|
386 |
-
for p in ordered_params[start_idx
|
387 |
state = param_to_state[id(p)]
|
388 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
389 |
-
_scatter(
|
390 |
-
|
391 |
-
)
|
392 |
|
393 |
-
chunk_size = dist.get_world_size(param_to_state[id(
|
|
|
394 |
|
395 |
# Wait grad update
|
396 |
self.comm_stream.wait_stream(torch.cuda.current_stream())
|
@@ -436,15 +440,16 @@ class Muon(torch.optim.Optimizer):
|
|
436 |
continue
|
437 |
if isinstance(p.data, DTensor):
|
438 |
if all(
|
439 |
-
|
440 |
-
|
441 |
param_tensors.append(p)
|
442 |
else:
|
443 |
param_dtensors.append(p)
|
444 |
elif isinstance(p.data, torch.Tensor):
|
445 |
param_tensors.append(p)
|
446 |
else:
|
447 |
-
raise TypeError(
|
|
|
448 |
|
449 |
if self.debug:
|
450 |
print(
|
@@ -479,7 +484,9 @@ class Muon(torch.optim.Optimizer):
|
|
479 |
# AdamW backup #
|
480 |
############################
|
481 |
|
482 |
-
params = [
|
|
|
|
|
483 |
lr = group["lr"]
|
484 |
beta1, beta2 = group["adamw_betas"]
|
485 |
eps = group["adamw_eps"]
|
|
|
59 |
|
60 |
if rank == state.worker_rank:
|
61 |
num_ranks = dist.get_world_size(group=state.process_group)
|
62 |
+
gather_list = [
|
63 |
+
torch.empty_like(g.to_local()) for _ in range(num_ranks)
|
64 |
+
]
|
65 |
else:
|
66 |
gather_list = None
|
67 |
|
|
|
75 |
if rank == state.worker_rank:
|
76 |
if state.gathered_grad is not None:
|
77 |
raise RuntimeError(
|
78 |
+
"Gather event already exists, which should not happen.")
|
|
|
79 |
state.gathered_grad = torch.cat(gather_list, dim=0)
|
80 |
state.gather_event = torch.cuda.Event()
|
81 |
state.gather_event.record()
|
|
|
241 |
"""
|
242 |
Get the shard mesh for a parameter p on the given rank.
|
243 |
"""
|
244 |
+
assert isinstance(
|
245 |
+
p, DTensor), "Parallel Muon only supports DTensor parameters."
|
246 |
|
247 |
+
if p.placements == (Shard(dim=0), ):
|
248 |
# Case for FSDP
|
249 |
return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
|
250 |
elif p.placements == (Replicate(), Shard(dim=0)):
|
|
|
271 |
total_flops += flops
|
272 |
|
273 |
if self.debug:
|
274 |
+
print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
|
275 |
+
flush=True)
|
276 |
|
277 |
+
ordered_params = sorted(params,
|
278 |
+
key=lambda p: param_to_flops[id(p)],
|
279 |
+
reverse=True)
|
280 |
|
281 |
round_robin = 0
|
282 |
mesh = None
|
|
|
372 |
p.grad = g
|
373 |
|
374 |
param_to_state, ordered_params = self.init_state_and_assign_params(
|
375 |
+
params, group)
|
|
|
376 |
|
377 |
def enqueue_gathers(start_idx, chunk_size):
|
378 |
+
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
379 |
state = param_to_state[id(p)]
|
380 |
+
_gather(p, state, self.rank, self.comm_stream,
|
381 |
+
group["none_grad"])
|
382 |
|
383 |
def enqueue_computes(start_idx, chunk_size):
|
384 |
+
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
385 |
state = param_to_state[id(p)]
|
386 |
+
_compute_u(state, group["ns_steps"], self.rank,
|
387 |
+
self.compute_stream)
|
388 |
|
389 |
def enqueue_scatters(start_idx, chunk_size):
|
390 |
+
for p in ordered_params[start_idx:start_idx + chunk_size]:
|
391 |
state = param_to_state[id(p)]
|
392 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
393 |
+
_scatter(p, state, lr, adjusted_lr, weight_decay, self.rank,
|
394 |
+
self.comm_stream)
|
|
|
395 |
|
396 |
+
chunk_size = dist.get_world_size(param_to_state[id(
|
397 |
+
params[0])].process_group)
|
398 |
|
399 |
# Wait grad update
|
400 |
self.comm_stream.wait_stream(torch.cuda.current_stream())
|
|
|
440 |
continue
|
441 |
if isinstance(p.data, DTensor):
|
442 |
if all(
|
443 |
+
isinstance(placement, Replicate)
|
444 |
+
for placement in p.placements):
|
445 |
param_tensors.append(p)
|
446 |
else:
|
447 |
param_dtensors.append(p)
|
448 |
elif isinstance(p.data, torch.Tensor):
|
449 |
param_tensors.append(p)
|
450 |
else:
|
451 |
+
raise TypeError(
|
452 |
+
f"Unsupported parameter type: {type(p.data)}")
|
453 |
|
454 |
if self.debug:
|
455 |
print(
|
|
|
484 |
# AdamW backup #
|
485 |
############################
|
486 |
|
487 |
+
params = [
|
488 |
+
p for p in group["params"] if not self.state[p]["use_muon"]
|
489 |
+
]
|
490 |
lr = group["lr"]
|
491 |
beta1, beta2 = group["adamw_betas"]
|
492 |
eps = group["adamw_eps"]
|