feat(muon): add test for muon
Browse files- README.md +4 -0
- test/test_muon/README.md +21 -0
- test/test_muon/__init__.py +0 -0
- test/test_muon/muon.py +1 -0
- test/test_muon/run_test.sh +2 -0
- test/test_muon/test.py +88 -0
README.md
CHANGED
@@ -77,3 +77,7 @@ The following tools are run via pre-commit:
|
|
77 |
```bash
|
78 |
pre-commit run isort --all-files
|
79 |
```
|
|
|
|
|
|
|
|
|
|
77 |
```bash
|
78 |
pre-commit run isort --all-files
|
79 |
```
|
80 |
+
|
81 |
+
### Test
|
82 |
+
|
83 |
+
- There is a [simple unittest for Parallel Muon](./test/test_muon/README.md)
|
test/test_muon/README.md
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Muon Optimizer Test
|
2 |
+
|
3 |
+
This directory contains a test script for the **Muon optimizer**.
|
4 |
+
|
5 |
+
To execute the test, simply run:
|
6 |
+
|
7 |
+
```bash
|
8 |
+
# By default, the test will use 8 GPUs.
|
9 |
+
./run_test.sh
|
10 |
+
```
|
11 |
+
|
12 |
+
The number of GPUs can be controlled with the NGPU environment variable.
|
13 |
+
For example, to run with 4 GPUs:
|
14 |
+
|
15 |
+
```bash
|
16 |
+
NGPU=4 ./run_test.sh
|
17 |
+
```
|
18 |
+
|
19 |
+
## Limitations:
|
20 |
+
- Multi-node execution is not supported yet.
|
21 |
+
- Ensure that the specified number of GPUs is available on your machine before running.
|
test/test_muon/__init__.py
ADDED
File without changes
|
test/test_muon/muon.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
../../torch-ext/optimizer/muon.py
|
test/test_muon/run_test.sh
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
NGPU=${NGPU:-"8"}
|
2 |
+
torchrun --nproc-per-node=8 test.py
|
test/test_muon/test.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.distributed as dist
|
5 |
+
from muon import Muon
|
6 |
+
from torch.distributed.fsdp import FSDPModule, fully_shard
|
7 |
+
from torch.distributed.tensor import DTensor
|
8 |
+
from torch.distributed.tensor.placement_types import Replicate
|
9 |
+
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
10 |
+
PreTrainedTokenizerBase)
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
logging.basicConfig(level=logging.INFO)
|
14 |
+
|
15 |
+
|
16 |
+
def load_model(fsdp: bool) -> torch.nn.Module:
|
17 |
+
model_name = "Motif-Technologies/Motif-2.6B"
|
18 |
+
model = AutoModelForCausalLM.from_pretrained(
|
19 |
+
model_name,
|
20 |
+
trust_remote_code=True,
|
21 |
+
).bfloat16().cuda()
|
22 |
+
|
23 |
+
torch.manual_seed(0)
|
24 |
+
random_grads = []
|
25 |
+
for param in model.parameters():
|
26 |
+
random_grad = torch.randn_like(param,
|
27 |
+
device=param.device,
|
28 |
+
dtype=param.dtype)
|
29 |
+
random_grads.append(random_grad)
|
30 |
+
|
31 |
+
if fsdp:
|
32 |
+
for layer in model.model.layers:
|
33 |
+
fully_shard(layer)
|
34 |
+
layer.reshard()
|
35 |
+
fully_shard(model)
|
36 |
+
model.reshard()
|
37 |
+
|
38 |
+
for i, param in enumerate(model.parameters()):
|
39 |
+
if isinstance(param.data, DTensor):
|
40 |
+
unsharded_grad = DTensor.from_local(
|
41 |
+
random_grads[i],
|
42 |
+
device_mesh=param.data.device_mesh,
|
43 |
+
placements=[Replicate()] * param.data.device_mesh.ndim,
|
44 |
+
)
|
45 |
+
sharded_grad = unsharded_grad.redistribute(
|
46 |
+
device_mesh=param.data.device_mesh,
|
47 |
+
placements=param.data.placements)
|
48 |
+
param.grad = sharded_grad
|
49 |
+
else:
|
50 |
+
param.grad = random_grads[i]
|
51 |
+
|
52 |
+
return model
|
53 |
+
|
54 |
+
|
55 |
+
def run_muon(fsdp: bool) -> torch.nn.Module:
|
56 |
+
model = load_model(fsdp=fsdp)
|
57 |
+
optim = Muon(model)
|
58 |
+
optim.step()
|
59 |
+
|
60 |
+
return model
|
61 |
+
|
62 |
+
|
63 |
+
def compare_results(parallel_muon_result: torch.nn.Module,
|
64 |
+
sequential_muon_result: torch.nn.Module) -> None:
|
65 |
+
for (name_p, p), (name_s,
|
66 |
+
s) in zip(parallel_muon_result.named_parameters(),
|
67 |
+
sequential_muon_result.named_parameters()):
|
68 |
+
p = p.data.full_tensor()
|
69 |
+
s = s.data
|
70 |
+
# Parallel Muon should exactly match Sequential Muon
|
71 |
+
if torch.abs(p - s).max() > 0:
|
72 |
+
max_diff_index = torch.argmax(torch.abs(p - s))
|
73 |
+
logger.error(f"Models differ at parameter {name_p}")
|
74 |
+
return
|
75 |
+
logger.info("Models match!")
|
76 |
+
|
77 |
+
|
78 |
+
def test_muon():
|
79 |
+
parallel_muon_result = run_muon(fsdp=True)
|
80 |
+
sequential_muon_result = run_muon(fsdp=False)
|
81 |
+
|
82 |
+
compare_results(parallel_muon_result, sequential_muon_result)
|
83 |
+
|
84 |
+
|
85 |
+
if __name__ == "__main__":
|
86 |
+
dist.init_process_group(backend="nccl")
|
87 |
+
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
|
88 |
+
test_muon()
|