chore: add .gitignore
Browse files- .gitignore +22 -0
- torch-ext/optimizer/muon.py +11 -6
.gitignore
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__
|
2 |
+
.idea
|
3 |
+
.DS_Store
|
4 |
+
*.egg-info
|
5 |
+
build
|
6 |
+
outputs
|
7 |
+
dist/*
|
8 |
+
.vscode
|
9 |
+
|
10 |
+
# data
|
11 |
+
data
|
12 |
+
out
|
13 |
+
wandb
|
14 |
+
|
15 |
+
torchtitan/datasets/**/*.model
|
16 |
+
torchtitan/experiments/flux/assets/*
|
17 |
+
|
18 |
+
# temp files
|
19 |
+
*.log
|
20 |
+
error.json
|
21 |
+
_remote_module_non_scriptable.py
|
22 |
+
.git_disabled/
|
torch-ext/optimizer/muon.py
CHANGED
@@ -83,6 +83,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
|
|
83 |
state.gathered_grad = None
|
84 |
state.gather_event = None
|
85 |
if none_grad:
|
|
|
86 |
p.grad = None
|
87 |
|
88 |
|
@@ -98,6 +99,7 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
98 |
state.compute_event = torch.cuda.Event()
|
99 |
state.compute_event.record()
|
100 |
# Clear the gathered gradient to free memory
|
|
|
101 |
state.gathered_grad = None
|
102 |
else:
|
103 |
state.computed_u = None
|
@@ -106,7 +108,6 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
106 |
|
107 |
@torch.no_grad()
|
108 |
def _scatter(p, state, lr, adjusted_lr, weight_decay, rank, comm_stream):
|
109 |
-
u = state.computed_u
|
110 |
|
111 |
with torch.cuda.stream(comm_stream):
|
112 |
if rank == state.worker_rank:
|
@@ -114,27 +115,31 @@ def _scatter(p, state, lr, adjusted_lr, weight_decay, rank, comm_stream):
|
|
114 |
if state.compute_event is None:
|
115 |
raise RuntimeError("Compute event must be set before scatter.")
|
116 |
comm_stream.wait_event(state.compute_event)
|
|
|
117 |
scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
|
|
|
118 |
else:
|
119 |
scatter_list = None
|
120 |
|
121 |
-
|
122 |
torch.distributed.scatter(
|
123 |
-
|
124 |
scatter_list=scatter_list,
|
125 |
src=state.worker_rank,
|
126 |
group=state.process_group,
|
127 |
)
|
128 |
if rank == state.worker_rank:
|
129 |
# Clear u to free memory
|
|
|
130 |
state.computed_u = None
|
131 |
-
|
132 |
-
|
|
|
133 |
placements=p.placements,
|
134 |
device_mesh=p.device_mesh,
|
135 |
)
|
136 |
p.data.mul_(1 - lr * weight_decay)
|
137 |
-
p.data.add_(
|
138 |
|
139 |
|
140 |
def default_is_muon(x, name):
|
|
|
83 |
state.gathered_grad = None
|
84 |
state.gather_event = None
|
85 |
if none_grad:
|
86 |
+
p.grad.record_stream(comm_stream)
|
87 |
p.grad = None
|
88 |
|
89 |
|
|
|
99 |
state.compute_event = torch.cuda.Event()
|
100 |
state.compute_event.record()
|
101 |
# Clear the gathered gradient to free memory
|
102 |
+
state.gathered_grad.record_stream(compute_stream)
|
103 |
state.gathered_grad = None
|
104 |
else:
|
105 |
state.computed_u = None
|
|
|
108 |
|
109 |
@torch.no_grad()
|
110 |
def _scatter(p, state, lr, adjusted_lr, weight_decay, rank, comm_stream):
|
|
|
111 |
|
112 |
with torch.cuda.stream(comm_stream):
|
113 |
if rank == state.worker_rank:
|
|
|
115 |
if state.compute_event is None:
|
116 |
raise RuntimeError("Compute event must be set before scatter.")
|
117 |
comm_stream.wait_event(state.compute_event)
|
118 |
+
u = state.computed_u
|
119 |
scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
|
120 |
+
scatter_list = [s.contiguous() for s in scatter_list]
|
121 |
else:
|
122 |
scatter_list = None
|
123 |
|
124 |
+
u_received = torch.empty_like(p.to_local())
|
125 |
torch.distributed.scatter(
|
126 |
+
u_received,
|
127 |
scatter_list=scatter_list,
|
128 |
src=state.worker_rank,
|
129 |
group=state.process_group,
|
130 |
)
|
131 |
if rank == state.worker_rank:
|
132 |
# Clear u to free memory
|
133 |
+
state.computed_u.record_stream(comm_stream)
|
134 |
state.computed_u = None
|
135 |
+
|
136 |
+
u_dtensor = DTensor.from_local(
|
137 |
+
u_received,
|
138 |
placements=p.placements,
|
139 |
device_mesh=p.device_mesh,
|
140 |
)
|
141 |
p.data.mul_(1 - lr * weight_decay)
|
142 |
+
p.data.add_(u_dtensor, alpha=-adjusted_lr)
|
143 |
|
144 |
|
145 |
def default_is_muon(x, name):
|