ehartford committed
Commit e50ecc3 · verified · 1 Parent(s): fbc91b3

Upload init_opengr00t_zero.py

Files changed (1)
  1. init_opengr00t_zero.py +313 -0
init_opengr00t_zero.py ADDED
#!/usr/bin/env python3
# ─────────────────── make local repo override any wheel ────────────────────
import sys, os; sys.path.insert(0, os.path.abspath("."))

# ─────────────────── Flash-Attention + CUDA stubs ──────────────────────────
import types, torch, torch.nn.functional as F, importlib.machinery as im
flash_pkg = types.ModuleType("flash_attn"); flash_pkg.__spec__ = im.ModuleSpec("flash_attn", loader=None, is_package=True); flash_pkg.__path__ = []
sys.modules["flash_attn"] = flash_pkg
fa = types.ModuleType("flash_attn.flash_attn_interface"); fa.__spec__ = im.ModuleSpec("flash_attn.flash_attn_interface", loader=None)
def _sdpa(qkv, *_, causal=False, **__):
    # Unbind the packed QKV tensor and fall back to PyTorch's built-in SDPA.
    q, k, v = qkv.unbind(1)
    q, k, v = (t.unsqueeze(0) for t in (q, k, v))
    return F.scaled_dot_product_attention(q, k, v, is_causal=causal).squeeze(0)
for s in ("flash_attn_unpadded_qkvpacked_func", "flash_attn_unpadded_kvpacked_func",
          "flash_attn_varlen_qkvpacked_func", "flash_attn_varlen_kvpacked_func"):
    setattr(fa, s, _sdpa)
sys.modules["flash_attn.flash_attn_interface"] = fa; flash_pkg.flash_attn_interface = fa
pad = types.ModuleType("flash_attn.bert_padding"); pad.__spec__ = im.ModuleSpec("flash_attn.bert_padding", loader=None)
pad.pad_input = lambda x, *a, **k: (x, None); pad.unpad_input = lambda x, *a, **k: x
sys.modules["flash_attn.bert_padding"] = pad; flash_pkg.bert_padding = pad

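# Note on the stubs above: they exist only so that `import flash_attn` and its
# submodules resolve on machines without the wheel. The SDPA fallback is a
# rough stand-in (packed-QKV layout and varlen arguments are not faithfully
# reproduced) and should never be exercised here, since this script only
# builds and saves the model without running a forward pass.
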
if not torch.cuda.is_available():
    torch.cuda.is_available = lambda: False
    torch.cuda.get_device_capability = lambda dev=None: (0, 0)
    torch.cuda.current_device = lambda: 0
    torch.cuda.get_device_properties = lambda dev=None: types.SimpleNamespace(major=0, minor=0)

import importlib.metadata as _im
if "flash_attn" not in _im.packages_distributions():
    rv, rd = _im.version, _im.distribution
    _im.version = lambda p: "0.0.0" if p == "flash_attn" else rv(p)
    _im.distribution = lambda p: types.SimpleNamespace(version="0.0.0") if p == "flash_attn" else rd(p)

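# Why the metadata spoof: transformers-style availability checks typically ask
# importlib.metadata for a "flash_attn" version, so the fake module alone is
# not enough — the package must also look like an installed distribution.
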
# ─────────────────── std imports ───────────────────────────────────────────
from pathlib import Path
import argparse, json, shutil
from huggingface_hub import hf_hub_download
from transformers import AutoConfig
from gr00t.model.gr00t_n1 import GR00T_N1_5

# ─────────────────── helpers ───────────────────────────────────────────────
def patched_cfg():
    # Download the upstream config and force model_type to the locally
    # registered "gr00t_n1_5" so AutoConfig resolves the right class.
    p = hf_hub_download("nvidia/GR00T-N1.5-3B", "config.json")
    d = json.load(open(p))
    if d.get("model_type") != "gr00t_n1_5":
        d["model_type"] = "gr00t_n1_5"
        patched = Path(p).with_name("config_patched.json")
        patched.write_text(json.dumps(d))
        return str(patched)
    return p

def build_blank():
    cfg = AutoConfig.from_pretrained(patched_cfg(),
                                     trust_remote_code=True,
                                     local_files_only=True)
    cfg.backbone_cfg.update(dict(tune_llm=True))   # enable training of the LLM tower
    cfg.backbone_cfg.pop("checkpoint_path", None)
    cfg.backbone_cfg.pop("use_pretrained", None)
    cfg.action_head_cfg.pop("checkpoint_path", None)
    torch.manual_seed(0)
    return GR00T_N1_5(cfg, local_model_path="")    # random weights

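# With checkpoint_path/use_pretrained stripped from the config, the model is
# expected to come up with freshly initialized (random) weights; manual_seed(0)
# keeps that initialization reproducible across runs.
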
def maybe_add_lm_head(model):
    """Ensure lm_head is properly initialized with weights."""
    # Navigate to the language model
    lm = model.backbone.eagle_model.language_model

    # Get dimensions from embed_tokens
    embed_tokens = lm.model.embed_tokens
    vocab_size = embed_tokens.num_embeddings
    hidden_size = embed_tokens.embedding_dim

    print(f"Embedding dimensions: vocab_size={vocab_size}, hidden_size={hidden_size}")

    # Expected shape based on the architecture: [151680, 2048]
    if vocab_size != 151680 or hidden_size != 2048:
        print("⚠️ Warning: unexpected dimensions. Expected vocab=151680, hidden=2048")

    # Check if lm_head exists
    if hasattr(lm, "lm_head"):
        print(f"lm_head attribute exists: {lm.lm_head is not None}")
        # Even if lm_head exists, its weights may not be properly initialized,
        # so just replace it with a properly initialized one.
        print("Creating new lm_head with proper initialization...")
    else:
        print("lm_head attribute missing, creating...")

    # Note: nn.Linear takes (in_features, out_features), hence (hidden_size, vocab_size)
    new_lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

    # Initialize weights with a normal distribution (std=0.02 is standard for LM heads)
    torch.nn.init.normal_(new_lm_head.weight, mean=0.0, std=0.02)

    # Convert to bfloat16 to match the backbone
    new_lm_head.weight.data = new_lm_head.weight.data.to(torch.bfloat16)

    # Replace the lm_head
    lm.lm_head = new_lm_head

    print(f"✓ Created lm_head: Linear({hidden_size}, {vocab_size}, bias=False)")
    print(f"  Weight shape: {lm.lm_head.weight.shape}")
    print(f"  Weight dtype: {lm.lm_head.weight.dtype}")
    print(f"  Parameters: {lm.lm_head.weight.numel() / 1e6:.1f}M")

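# One assumption worth flagging: if the underlying language model ties lm_head
# to embed_tokens, installing a fresh Linear here unties them, and the two
# tensors will train independently from now on.
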
def set_mixed(model):
    """Set mixed precision: backbone in bf16, action head in fp32"""
    for n, p in model.named_parameters():
        if n.startswith("backbone.") or "lm_head" in n:
            p.data = p.data.to(torch.bfloat16)
        else:
            p.data = p.data.to(torch.float32)

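# Design note: the bf16 backbone halves memory, while the action head is kept
# in fp32 on the assumption that its regression/flow-matching losses are more
# sensitive to reduced precision.
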
def copy_tokenizer(out):
    for f in ("tokenizer.json", "tokenizer_config.json", "vocab.txt", "special_tokens_map.json"):
        try: shutil.copy(hf_hub_download("nvidia/GR00T-N1.5-3B", f), out / f)
        except Exception: pass

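# Best-effort by design: not every tokenizer ships all four files (vocab.txt is
# BERT-style, for instance), so missing ones are silently skipped.
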
def diagnose_model(model):
    """Print diagnostic info about the model"""
    print("\nModel diagnostics:")
    total_params = sum(p.numel() for p in model.parameters())
    print(f"  Total params: {total_params/1e6:,.0f}M")

    # Check for key components
    has_lm_head = False
    lm_head_params = 0
    lm_head_location = None

    for name, param in model.named_parameters():
        if "lm_head" in name:
            has_lm_head = True
            lm_head_params += param.numel()
            lm_head_location = name

    print(f"  Has lm_head: {'✓' if has_lm_head else '✗'}")
    if has_lm_head:
        print(f"  lm_head params: {lm_head_params/1e6:,.0f}M")
        print(f"  lm_head location: {lm_head_location}")

    # Check if the params are actually counted in the total
    lm = model.backbone.eagle_model.language_model
    if hasattr(lm, 'lm_head') and lm.lm_head is not None:
        actual_params = lm.lm_head.weight.numel()
        print(f"  lm_head actual params: {actual_params/1e6:,.0f}M")
        print(f"  lm_head weight shape: {lm.lm_head.weight.shape}")
        print(f"  lm_head weight dtype: {lm.lm_head.weight.dtype}")

def validate_model_architecture(model):
    """Validate the model against the architecture specification."""
    print("\n" + "="*60)
    print("ARCHITECTURE VALIDATION")
    print("="*60)

    # Expected architecture based on the spec
    expected_shapes = {
        # Key layers to check — parameter names carry the .weight/.bias suffix
        "backbone.eagle_model.language_model.lm_head.weight": (151680, 2048),
        "backbone.eagle_model.language_model.model.embed_tokens.weight": (151680, 2048),
        "backbone.eagle_model.language_model.model.norm.weight": (2048,),
        "backbone.eagle_model.mlp1.0.weight": (2048, 1152),
        "backbone.eagle_model.mlp1.0.bias": (2048,),
        "action_head.position_embedding.weight": (1024, 1536),
        "action_head.vlln.weight": (2048,),
        "action_head.vlln.bias": (2048,),
    }

    errors = []
    warnings = []

    # Get all parameters
    param_dict = dict(model.named_parameters())

    # Debug: print actual action_head parameter names to see the pattern
    action_head_params = [name for name in param_dict.keys() if name.startswith("action_head.position")]
    if action_head_params:
        print("\nFound position embedding parameters:")
        for name in action_head_params[:5]:
            print(f"  {name}: {param_dict[name].shape}")

    # Check key shapes
    for name, expected_shape in expected_shapes.items():
        if name in param_dict:
            actual_shape = tuple(param_dict[name].shape)
            if actual_shape != expected_shape:
                errors.append(f"Shape mismatch for {name}: expected {expected_shape}, got {actual_shape}")
            else:
                print(f"✓ {name}: {actual_shape}")
        else:
            errors.append(f"Missing parameter: {name}")

    # Check dtypes
    dtype_issues = []
    for name, param in param_dict.items():
        if name.startswith("backbone."):
            if param.dtype != torch.bfloat16:
                dtype_issues.append(f"{name}: expected bfloat16, got {param.dtype}")
        elif name.startswith("action_head."):
            if param.dtype != torch.float32:
                dtype_issues.append(f"{name}: expected float32, got {param.dtype}")

    if dtype_issues:
        warnings.extend(dtype_issues[:5])  # Only keep the first 5

    # Count parameters by component
    component_params = {
        "backbone": 0,
        "action_head": 0,
        "other": 0
    }

    for name, param in param_dict.items():
        count = param.numel()
        if name.startswith("backbone."):
            component_params["backbone"] += count
        elif name.startswith("action_head."):
            component_params["action_head"] += count
        else:
            component_params["other"] += count

    # Special check for lm_head
    lm_head_found = False
    lm_head_params = 0
    for name, param in param_dict.items():
        if "lm_head" in name:
            lm_head_found = True
            lm_head_params += param.numel()

    # Report results
    print("\nValidation Results:")
    print(f"  Errors: {len(errors)}")
    print(f"  Warnings: {len(warnings)}")

    if errors:
        print("\n❌ ERRORS:")
        for error in errors:
            print(f"  - {error}")

    if warnings:
        print("\n⚠️ WARNINGS (showing first 5):")
        for warning in warnings[:5]:
            print(f"  - {warning}")
        if len(warnings) > 5:
            print(f"  ... and {len(warnings) - 5} more")

    print("\n📊 Parameter Summary:")
    total = sum(component_params.values())
    print(f"  Total: {total/1e6:,.1f}M")
    print(f"  Backbone: {component_params['backbone']/1e6:,.1f}M")
    print(f"  Action Head: {component_params['action_head']/1e6:,.1f}M")
    if component_params['other'] > 0:
        print(f"  Other: {component_params['other']/1e6:,.1f}M")

    print(f"\n  lm_head found: {'✓' if lm_head_found else '✗'}")
    if lm_head_found:
        print(f"  lm_head params: {lm_head_params/1e6:.1f}M (expected: 311.1M)")

    # Expected totals based on the NVIDIA model
    expected_total = 2724  # million params
    actual_total = total / 1e6
    diff = actual_total - expected_total

    print(f"\n  Expected total: {expected_total}M")
    print(f"  Actual total: {actual_total:.1f}M")
    print(f"  Difference: {diff:+.1f}M")

    if abs(diff) < 1:  # within 1M params
        print("\n✅ Model architecture matches expected specification!")
    else:
        print("\n❌ Model architecture does NOT match specification!")

    return len(errors) == 0

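# Note: main() below only prints this report and ignores the returned boolean;
# callers needing a hard gate can assert on validate_model_architecture(model).
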
# ─────────────────── main ──────────────────────────────────────────────────
def main(device: str, out_dir: str):
    print("="*60)
    print("Creating blank GR00T-N1.5-3B model")
    print("="*60)

    model = build_blank()

    # Diagnostics before adding lm_head
    print("\nBefore adding lm_head:")
    diagnose_model(model)

    maybe_add_lm_head(model)

    # Diagnostics after adding lm_head
    print("\nAfter adding lm_head:")
    diagnose_model(model)

    set_mixed(model)
    model = model.to(device)

    # Validate against the architecture spec
    validate_model_architecture(model)

    out = Path(out_dir).expanduser(); out.mkdir(parents=True, exist_ok=True)

    print(f"\nSaving model to {out}...")
    model.save_pretrained(out, max_shard_size="2GB")
    copy_tokenizer(out)
    (out/"README.md").write_text("Random GR00T-N1.5-3B | backbone bf16 | action_head fp32 | Apache-2.0\n")

    # Final summary
    print("\n" + "="*60)
    print("FINAL SUMMARY")
    print("="*60)
    print(f"✅ Saved blank model ({sum(p.numel() for p in model.parameters())/1e6:,.0f}M params) → {out}")
    print(f"✅ Model has lm_head with {model.backbone.eagle_model.language_model.lm_head.weight.numel()/1e6:.1f}M params")
    print("✅ Ready for training with Apache-2.0 license")

# ─────────────────── CLI ───────────────────────────────────────────────────
if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("--device", default="cpu")
    ap.add_argument("--out_dir", default="OpenGR00T-N1.5-3B-Zero")
    args = ap.parse_args(); main(args.device, args.out_dir)
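
# Usage sketch (defaults shown; CPU suffices since no pretrained weights load):
#   python init_opengr00t_zero.py --device cpu --out_dir OpenGR00T-N1.5-3B-Zero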