wyldecat commited on
Commit
3261444
·
1 Parent(s): 79fc8ba

feat(muon): add test for muon

Browse files
README.md CHANGED
@@ -77,3 +77,7 @@ The following tools are run via pre-commit:
77
  ```bash
78
  pre-commit run isort --all-files
79
  ```
 
 
 
 
 
77
  ```bash
78
  pre-commit run isort --all-files
79
  ```
80
+
81
+ ### Test
82
+
83
+ - There is a [simple unittest for Parallel Muon](./test/test_muon/README.md)
test/test_muon/README.md ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Muon Optimizer Test
2
+
3
+ This directory contains a test script for the **Muon optimizer**.
4
+
5
+ To execute the test, simply run:
6
+
7
+ ```bash
8
+ # By default, the test will use 8 GPUs.
9
+ ./run_test.sh
10
+ ```
11
+
12
+ The number of GPUs can be controlled with the NGPU environment variable.
13
+ For example, to run with 4 GPUs:
14
+
15
+ ```bash
16
+ NGPU=4 ./run_test.sh
17
+ ```
18
+
19
+ ## Limitations:
20
+ - Multi-node execution is not supported yet.
21
+ - Ensure that the specified number of GPUs is available on your machine before running.
test/test_muon/__init__.py ADDED
File without changes
test/test_muon/muon.py ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../torch-ext/optimizer/muon.py
test/test_muon/run_test.sh ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ NGPU=${NGPU:-"8"}
2
+ torchrun --nproc-per-node=8 test.py
test/test_muon/test.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import torch
4
+ import torch.distributed as dist
5
+ from muon import Muon
6
+ from torch.distributed.fsdp import FSDPModule, fully_shard
7
+ from torch.distributed.tensor import DTensor
8
+ from torch.distributed.tensor.placement_types import Replicate
9
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
10
+ PreTrainedTokenizerBase)
11
+
12
+ logger = logging.getLogger(__name__)
13
+ logging.basicConfig(level=logging.INFO)
14
+
15
+
16
+ def load_model(fsdp: bool) -> torch.nn.Module:
17
+ model_name = "Motif-Technologies/Motif-2.6B"
18
+ model = AutoModelForCausalLM.from_pretrained(
19
+ model_name,
20
+ trust_remote_code=True,
21
+ ).bfloat16().cuda()
22
+
23
+ torch.manual_seed(0)
24
+ random_grads = []
25
+ for param in model.parameters():
26
+ random_grad = torch.randn_like(param,
27
+ device=param.device,
28
+ dtype=param.dtype)
29
+ random_grads.append(random_grad)
30
+
31
+ if fsdp:
32
+ for layer in model.model.layers:
33
+ fully_shard(layer)
34
+ layer.reshard()
35
+ fully_shard(model)
36
+ model.reshard()
37
+
38
+ for i, param in enumerate(model.parameters()):
39
+ if isinstance(param.data, DTensor):
40
+ unsharded_grad = DTensor.from_local(
41
+ random_grads[i],
42
+ device_mesh=param.data.device_mesh,
43
+ placements=[Replicate()] * param.data.device_mesh.ndim,
44
+ )
45
+ sharded_grad = unsharded_grad.redistribute(
46
+ device_mesh=param.data.device_mesh,
47
+ placements=param.data.placements)
48
+ param.grad = sharded_grad
49
+ else:
50
+ param.grad = random_grads[i]
51
+
52
+ return model
53
+
54
+
55
+ def run_muon(fsdp: bool) -> torch.nn.Module:
56
+ model = load_model(fsdp=fsdp)
57
+ optim = Muon(model)
58
+ optim.step()
59
+
60
+ return model
61
+
62
+
63
+ def compare_results(parallel_muon_result: torch.nn.Module,
64
+ sequential_muon_result: torch.nn.Module) -> None:
65
+ for (name_p, p), (name_s,
66
+ s) in zip(parallel_muon_result.named_parameters(),
67
+ sequential_muon_result.named_parameters()):
68
+ p = p.data.full_tensor()
69
+ s = s.data
70
+ # Parallel Muon should exactly match Sequential Muon
71
+ if torch.abs(p - s).max() > 0:
72
+ max_diff_index = torch.argmax(torch.abs(p - s))
73
+ logger.error(f"Models differ at parameter {name_p}")
74
+ return
75
+ logger.info("Models match!")
76
+
77
+
78
+ def test_muon():
79
+ parallel_muon_result = run_muon(fsdp=True)
80
+ sequential_muon_result = run_muon(fsdp=False)
81
+
82
+ compare_results(parallel_muon_result, sequential_muon_result)
83
+
84
+
85
+ if __name__ == "__main__":
86
+ dist.init_process_group(backend="nccl")
87
+ torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
88
+ test_muon()