wyldecat committed
Commit 79fc8ba · 1 Parent(s): 268d190

chore: add .gitignore

Files changed (2)
  1. .gitignore +22 -0
  2. torch-ext/optimizer/muon.py +11 -6
.gitignore ADDED
@@ -0,0 +1,22 @@
+__pycache__
+.idea
+.DS_Store
+*.egg-info
+build
+outputs
+dist/*
+.vscode
+
+# data
+data
+out
+wandb
+
+torchtitan/datasets/**/*.model
+torchtitan/experiments/flux/assets/*
+
+# temp files
+*.log
+error.json
+_remote_module_non_scriptable.py
+.git_disabled/
torch-ext/optimizer/muon.py CHANGED
@@ -83,6 +83,7 @@ def _gather(p, state, rank, comm_stream, none_grad):
             state.gathered_grad = None
             state.gather_event = None
         if none_grad:
+            p.grad.record_stream(comm_stream)
             p.grad = None
 
 
@@ -98,6 +99,7 @@ def _compute_u(state, steps, rank, compute_stream):
             state.compute_event = torch.cuda.Event()
             state.compute_event.record()
             # Clear the gathered gradient to free memory
+            state.gathered_grad.record_stream(compute_stream)
             state.gathered_grad = None
         else:
             state.computed_u = None
@@ -106,7 +108,6 @@ def _compute_u(state, steps, rank, compute_stream):
 
 @torch.no_grad()
 def _scatter(p, state, lr, adjusted_lr, weight_decay, rank, comm_stream):
-    u = state.computed_u
 
     with torch.cuda.stream(comm_stream):
         if rank == state.worker_rank:
@@ -114,27 +115,31 @@ def _scatter(p, state, lr, adjusted_lr, weight_decay, rank, comm_stream):
             if state.compute_event is None:
                 raise RuntimeError("Compute event must be set before scatter.")
             comm_stream.wait_event(state.compute_event)
+            u = state.computed_u
             scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
+            scatter_list = [s.contiguous() for s in scatter_list]
         else:
             scatter_list = None
 
-        u = torch.empty_like(p.to_local())
+        u_received = torch.empty_like(p.to_local())
         torch.distributed.scatter(
-            u,
+            u_received,
             scatter_list=scatter_list,
             src=state.worker_rank,
             group=state.process_group,
         )
         if rank == state.worker_rank:
             # Clear u to free memory
+            state.computed_u.record_stream(comm_stream)
             state.computed_u = None
-        u = DTensor.from_local(
-            u,
+
+        u_dtensor = DTensor.from_local(
+            u_received,
             placements=p.placements,
             device_mesh=p.device_mesh,
         )
         p.data.mul_(1 - lr * weight_decay)
-        p.data.add_(u, alpha=-adjusted_lr)
+        p.data.add_(u_dtensor, alpha=-adjusted_lr)
 
 
 def default_is_muon(x, name):
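Why the new record_stream calls: PyTorch's caching allocator tracks a tensor's lifetime against the stream it was allocated on, so dropping the last reference to a tensor that was last read or written on a different stream (here comm_stream or compute_stream) can let its memory be handed to a new allocation before that stream has finished with it. Tensor.record_stream(stream) tells the allocator to delay reuse until the work already queued on that stream completes. A minimal sketch of the pattern, assuming a CUDA device is available; the function and tensor names below are illustrative and not taken from muon.py:

    import torch

    def consume_on_side_stream(x: torch.Tensor, side_stream: torch.cuda.Stream) -> torch.Tensor:
        # x was allocated on the current (default) stream; consume it on side_stream.
        with torch.cuda.stream(side_stream):
            y = x * 2.0
        # Before the caller drops its last reference to x, tell the caching
        # allocator that side_stream still has pending work reading x, so x's
        # memory is not recycled for a later default-stream allocation too early.
        x.record_stream(side_stream)
        return y

    if torch.cuda.is_available():
        side = torch.cuda.Stream()
        x = torch.randn(4096, 4096, device="cuda")
        y = consume_on_side_stream(x, side)
        del x  # safe: reuse of x's memory waits for the work recorded on `side`
        torch.cuda.current_stream().wait_stream(side)  # make y safe to read here
        print(y.sum().item())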
 
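The _scatter changes also make each split chunk contiguous before handing it to torch.distributed.scatter, and keep the received local shard (u_received) distinct from the DTensor wrapper (u_dtensor) used for the parameter update. A rough sketch of that scatter step, assuming an already-initialized process group (e.g. a script launched with torchrun); scatter_update_shard, p_local, and full_update are illustrative names, not part of the repository's API:

    import torch
    import torch.distributed as dist

    def scatter_update_shard(p_local, full_update, worker_rank, group=None):
        # full_update holds the whole update matrix and is only meaningful on
        # worker_rank; other ranks may pass None. Every rank receives its own
        # row shard into a freshly allocated buffer shaped like its local param.
        rank = dist.get_rank(group)
        if rank == worker_rank:
            chunks = list(torch.split(full_update, p_local.size(0), dim=0))
            # As in the commit, materialize each split view as a contiguous
            # tensor before passing it to the collective.
            chunks = [c.contiguous() for c in chunks]
        else:
            chunks = None
        shard = torch.empty_like(p_local)
        dist.scatter(shard, scatter_list=chunks, src=worker_rank, group=group)
        return shard

In the optimizer itself, the received shard is then wrapped with DTensor.from_local(..., placements=p.placements, device_mesh=p.device_mesh) and applied as p.data.add_(u_dtensor, alpha=-adjusted_lr) after the decoupled weight-decay step, as the diff above shows.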