Upload init_opengr00t_zero.py
init_opengr00t_zero.py  ADDED  (+313 -0)
@@ -0,0 +1,313 @@
#!/usr/bin/env python3
# ─────────────────── make local repo override any wheel ────────────────────
import sys, os; sys.path.insert(0, os.path.abspath("."))

# ─────────────────── Flash-Attention + CUDA stubs ──────────────────────────
# Register fake flash_attn modules so the GR00T code imports cleanly on hosts
# without the CUDA flash-attn wheel; attention falls back to PyTorch SDPA.
import types, torch, torch.nn.functional as F, importlib.machinery as im

flash_pkg = types.ModuleType("flash_attn")
flash_pkg.__spec__ = im.ModuleSpec("flash_attn", loader=None, is_package=True)
flash_pkg.__path__ = []
sys.modules["flash_attn"] = flash_pkg

fa = types.ModuleType("flash_attn.flash_attn_interface")
fa.__spec__ = im.ModuleSpec("flash_attn.flash_attn_interface", loader=None)

def _sdpa(qkv, *_, causal=False, **__):
    # SDPA fallback for flash-attn's qkv-packed kernels.
    # qkv: (total_tokens, 3, num_heads, head_dim); SDPA wants (N, H, T, D),
    # so move heads in front of tokens before attending, then move them back.
    # (The kv-packed variants receive this same rough stub.)
    q, k, v = (t.transpose(0, 1).unsqueeze(0) for t in qkv.unbind(1))  # (1, H, T, D)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    return out.squeeze(0).transpose(0, 1)  # back to (T, H, D)

for s in ("flash_attn_unpadded_qkvpacked_func", "flash_attn_unpadded_kvpacked_func",
          "flash_attn_varlen_qkvpacked_func", "flash_attn_varlen_kvpacked_func"):
    setattr(fa, s, _sdpa)
sys.modules["flash_attn.flash_attn_interface"] = fa
flash_pkg.flash_attn_interface = fa

pad = types.ModuleType("flash_attn.bert_padding")
pad.__spec__ = im.ModuleSpec("flash_attn.bert_padding", loader=None)
pad.pad_input = lambda x, *a, **k: (x, None)
pad.unpad_input = lambda x, *a, **k: x
sys.modules["flash_attn.bert_padding"] = pad
flash_pkg.bert_padding = pad

# On CPU-only hosts, stub the CUDA device queries that libraries probe at import.
if not torch.cuda.is_available():
    torch.cuda.is_available = lambda: False
    torch.cuda.get_device_capability = lambda dev=None: (0, 0)
    torch.cuda.current_device = lambda: 0
    torch.cuda.get_device_properties = lambda dev=None: types.SimpleNamespace(major=0, minor=0)

# Pretend a flash_attn distribution is installed so packaging checks pass.
import importlib.metadata as _im
if "flash_attn" not in _im.packages_distributions():
    rv, rd = _im.version, _im.distribution
    _im.version = lambda p: "0.0.0" if p == "flash_attn" else rv(p)
    _im.distribution = lambda p: types.SimpleNamespace(version="0.0.0") if p == "flash_attn" else rd(p)

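# Quick sanity check (illustrative, not part of the original upload): with the
# stubs above in place, the import below binds to the SDPA fallback instead of
# raising ImportError on machines without the flash-attn wheel.
#
#   from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
#   assert flash_attn_varlen_qkvpacked_func is _sdpa
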
# ─────────────────── std imports ───────────────────────────────────────────
from pathlib import Path
import argparse, json, shutil
from huggingface_hub import hf_hub_download
from transformers import AutoConfig
from gr00t.model.gr00t_n1 import GR00T_N1_5

# ─────────────────── helpers ───────────────────────────────────────────────
def patched_cfg():
    """Download the upstream config, forcing model_type to gr00t_n1_5 if needed."""
    p = hf_hub_download("nvidia/GR00T-N1.5-3B", "config.json")
    with open(p) as f:
        d = json.load(f)
    if d.get("model_type") != "gr00t_n1_5":
        d["model_type"] = "gr00t_n1_5"
        patched = Path(p).with_name("config_patched.json")
        patched.write_text(json.dumps(d))
        return str(patched)
    return p

def build_blank():
    cfg = AutoConfig.from_pretrained(patched_cfg(),
                                     trust_remote_code=True,
                                     local_files_only=True)
    cfg.backbone_cfg.update(dict(tune_llm=True))   # make the LLM tower trainable
    cfg.backbone_cfg.pop("checkpoint_path", None)  # never load pretrained weights
    cfg.backbone_cfg.pop("use_pretrained", None)
    cfg.action_head_cfg.pop("checkpoint_path", None)
    torch.manual_seed(0)
    return GR00T_N1_5(cfg, local_model_path="")    # random weights

def maybe_add_lm_head(model):
    """Ensure lm_head exists and is properly initialized with weights."""
    lm = model.backbone.eagle_model.language_model

    # Derive dimensions from the embedding table.
    embed_tokens = lm.model.embed_tokens
    vocab_size = embed_tokens.num_embeddings
    hidden_size = embed_tokens.embedding_dim

    print(f"Embedding dimensions: vocab_size={vocab_size}, hidden_size={hidden_size}")

    # Expected shape based on the architecture spec: [151680, 2048].
    if vocab_size != 151680 or hidden_size != 2048:
        print("⚠️ Warning: unexpected dimensions; expected vocab=151680, hidden=2048")

    if hasattr(lm, "lm_head"):
        print(f"lm_head attribute exists: {lm.lm_head is not None}")
        # Even if the attribute exists, its weights may not be initialized,
        # so replace it with a freshly initialized head either way.
        print("Creating new lm_head with proper initialization...")
    else:
        print("lm_head attribute missing, creating...")

    # nn.Linear takes (in_features, out_features), i.e. (hidden_size, vocab_size).
    new_lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

    # std=0.02 is the standard init for LM heads.
    torch.nn.init.normal_(new_lm_head.weight, mean=0.0, std=0.02)

    # Match the backbone dtype.
    new_lm_head.weight.data = new_lm_head.weight.data.to(torch.bfloat16)

    lm.lm_head = new_lm_head

    print(f"✓ Created lm_head: Linear({hidden_size}, {vocab_size}, bias=False)")
    print(f"  Weight shape: {lm.lm_head.weight.shape}")
    print(f"  Weight dtype: {lm.lm_head.weight.dtype}")
    print(f"  Parameters: {lm.lm_head.weight.numel() / 1e6:.1f}M")

def set_mixed(model):
    """Mixed precision: backbone and lm_head in bf16, action head in fp32."""
    for n, p in model.named_parameters():
        if n.startswith("backbone.") or "lm_head" in n:
            p.data = p.data.to(torch.bfloat16)
        else:
            p.data = p.data.to(torch.float32)

def copy_tokenizer(out):
    for f in ("tokenizer.json", "tokenizer_config.json", "vocab.txt", "special_tokens_map.json"):
        try:
            shutil.copy(hf_hub_download("nvidia/GR00T-N1.5-3B", f), out / f)
        except Exception:
            pass  # not every tokenizer file exists upstream

def diagnose_model(model):
    """Print diagnostic info about the model."""
    print("\nModel diagnostics:")
    total_params = sum(p.numel() for p in model.parameters())
    print(f"  Total params: {total_params/1e6:,.0f}M")

    # Look for the lm_head among registered parameters.
    has_lm_head = False
    lm_head_params = 0
    lm_head_location = None

    for name, param in model.named_parameters():
        if "lm_head" in name:
            has_lm_head = True
            lm_head_params += param.numel()
            lm_head_location = name

    print(f"  Has lm_head: {'✓' if has_lm_head else '✗'}")
    if has_lm_head:
        print(f"  lm_head params: {lm_head_params/1e6:,.0f}M")
        print(f"  lm_head location: {lm_head_location}")

    # Cross-check against the module attribute, in case the head exists
    # but is not registered as a parameter (and thus missing from the total).
    lm = model.backbone.eagle_model.language_model
    if hasattr(lm, "lm_head") and lm.lm_head is not None:
        actual_params = lm.lm_head.weight.numel()
        print(f"  lm_head actual params: {actual_params/1e6:,.0f}M")
        print(f"  lm_head weight shape: {lm.lm_head.weight.shape}")
        print(f"  lm_head weight dtype: {lm.lm_head.weight.dtype}")

+
def validate_model_architecture(model):
|
145 |
+
"""Validate model against the architecture specification"""
|
146 |
+
print("\n" + "="*60)
|
147 |
+
print("ARCHITECTURE VALIDATION")
|
148 |
+
print("="*60)
|
149 |
+
|
150 |
+
# Expected architecture based on the spec
|
151 |
+
expected_shapes = {
|
152 |
+
# Key layers to check - using actual parameter names with .weight suffix
|
153 |
+
"backbone.eagle_model.language_model.lm_head.weight": (151680, 2048),
|
154 |
+
"backbone.eagle_model.language_model.model.embed_tokens.weight": (151680, 2048),
|
155 |
+
"backbone.eagle_model.language_model.model.norm.weight": (2048,),
|
156 |
+
"backbone.eagle_model.mlp1.0.weight": (2048, 1152),
|
157 |
+
"backbone.eagle_model.mlp1.0.bias": (2048,),
|
158 |
+
"action_head.position_embedding.weight": (1024, 1536), # Fixed: added .weight
|
159 |
+
"action_head.vlln.weight": (2048,),
|
160 |
+
"action_head.vlln.bias": (2048,),
|
161 |
+
}
|
162 |
+
|
163 |
+
errors = []
|
164 |
+
warnings = []
|
165 |
+
|
166 |
+
# Get all parameters
|
167 |
+
param_dict = dict(model.named_parameters())
|
168 |
+
|
169 |
+
# Debug: print actual action_head parameter names to see the pattern
|
170 |
+
action_head_params = [name for name in param_dict.keys() if name.startswith("action_head.position")]
|
171 |
+
if action_head_params:
|
172 |
+
print("\nFound position embedding parameters:")
|
173 |
+
for name in action_head_params[:5]:
|
174 |
+
print(f" {name}: {param_dict[name].shape}")
|
175 |
+
|
176 |
+
# Check key shapes
|
177 |
+
for name, expected_shape in expected_shapes.items():
|
178 |
+
if name in param_dict:
|
179 |
+
actual_shape = tuple(param_dict[name].shape)
|
180 |
+
if actual_shape != expected_shape:
|
181 |
+
errors.append(f"Shape mismatch for {name}: expected {expected_shape}, got {actual_shape}")
|
182 |
+
else:
|
183 |
+
print(f"✓ {name}: {actual_shape}")
|
184 |
+
else:
|
185 |
+
errors.append(f"Missing parameter: {name}")
|
186 |
+
|
187 |
+
# Check dtypes
|
188 |
+
dtype_issues = []
|
189 |
+
for name, param in param_dict.items():
|
190 |
+
if name.startswith("backbone."):
|
191 |
+
if param.dtype != torch.bfloat16:
|
192 |
+
dtype_issues.append(f"{name}: expected bfloat16, got {param.dtype}")
|
193 |
+
elif name.startswith("action_head."):
|
194 |
+
if param.dtype != torch.float32:
|
195 |
+
dtype_issues.append(f"{name}: expected float32, got {param.dtype}")
|
196 |
+
|
197 |
+
if dtype_issues:
|
198 |
+
warnings.extend(dtype_issues[:5]) # Only show first 5
|
199 |
+
|
200 |
+
# Count parameters by component
|
201 |
+
component_params = {
|
202 |
+
"backbone": 0,
|
203 |
+
"action_head": 0,
|
204 |
+
"other": 0
|
205 |
+
}
|
206 |
+
|
207 |
+
for name, param in param_dict.items():
|
208 |
+
count = param.numel()
|
209 |
+
if name.startswith("backbone."):
|
210 |
+
component_params["backbone"] += count
|
211 |
+
elif name.startswith("action_head."):
|
212 |
+
component_params["action_head"] += count
|
213 |
+
else:
|
214 |
+
component_params["other"] += count
|
215 |
+
|
216 |
+
# Special check for lm_head
|
217 |
+
lm_head_found = False
|
218 |
+
lm_head_params = 0
|
219 |
+
for name, param in param_dict.items():
|
220 |
+
if "lm_head" in name:
|
221 |
+
lm_head_found = True
|
222 |
+
lm_head_params += param.numel()
|
223 |
+
|
224 |
+
# Report results
|
225 |
+
print("\nValidation Results:")
|
226 |
+
print(f" Errors: {len(errors)}")
|
227 |
+
print(f" Warnings: {len(warnings)}")
|
228 |
+
|
229 |
+
if errors:
|
230 |
+
print("\n❌ ERRORS:")
|
231 |
+
for error in errors:
|
232 |
+
print(f" - {error}")
|
233 |
+
|
234 |
+
if warnings:
|
235 |
+
print("\n⚠️ WARNINGS (showing first 5):")
|
236 |
+
for warning in warnings[:5]:
|
237 |
+
print(f" - {warning}")
|
238 |
+
if len(warnings) > 5:
|
239 |
+
print(f" ... and {len(warnings) - 5} more")
|
240 |
+
|
241 |
+
print("\n📊 Parameter Summary:")
|
242 |
+
total = sum(component_params.values())
|
243 |
+
print(f" Total: {total/1e6:,.1f}M")
|
244 |
+
print(f" Backbone: {component_params['backbone']/1e6:,.1f}M")
|
245 |
+
print(f" Action Head: {component_params['action_head']/1e6:,.1f}M")
|
246 |
+
if component_params['other'] > 0:
|
247 |
+
print(f" Other: {component_params['other']/1e6:,.1f}M")
|
248 |
+
|
249 |
+
print(f"\n lm_head found: {'✓' if lm_head_found else '✗'}")
|
250 |
+
if lm_head_found:
|
251 |
+
print(f" lm_head params: {lm_head_params/1e6:.1f}M (expected: 311.1M)")
|
252 |
+
|
253 |
+
# Expected totals based on NVIDIA model
|
254 |
+
expected_total = 2724 # Million params
|
255 |
+
actual_total = total / 1e6
|
256 |
+
diff = actual_total - expected_total
|
257 |
+
|
258 |
+
print(f"\n Expected total: {expected_total}M")
|
259 |
+
print(f" Actual total: {actual_total:.1f}M")
|
260 |
+
print(f" Difference: {diff:+.1f}M")
|
261 |
+
|
262 |
+
if abs(diff) < 1: # Within 1M params
|
263 |
+
print("\n✅ Model architecture matches expected specification!")
|
264 |
+
else:
|
265 |
+
print("\n❌ Model architecture does NOT match specification!")
|
266 |
+
|
267 |
+
return len(errors) == 0
|
268 |
+
|
# ─────────────────── main ──────────────────────────────────────────────────
def main(device: str, out_dir: str):
    print("="*60)
    print("Creating blank GR00T-N1.5-3B model")
    print("="*60)

    model = build_blank()

    # Diagnostics before adding lm_head.
    print("\nBefore adding lm_head:")
    diagnose_model(model)

    maybe_add_lm_head(model)

    # Diagnostics after adding lm_head.
    print("\nAfter adding lm_head:")
    diagnose_model(model)

    set_mixed(model)
    model = model.to(device)

    # Validate against the architecture spec.
    validate_model_architecture(model)

    out = Path(out_dir).expanduser()
    out.mkdir(parents=True, exist_ok=True)

    print(f"\nSaving model to {out}...")
    model.save_pretrained(out, max_shard_size="2GB")
    copy_tokenizer(out)
    (out/"README.md").write_text("Random GR00T-N1.5-3B | backbone bf16 | action_head fp32 | Apache-2.0\n")

    # Final summary.
    print("\n" + "="*60)
    print("FINAL SUMMARY")
    print("="*60)
    print(f"✅ Saved blank model ({sum(p.numel() for p in model.parameters())/1e6:,.0f}M params) → {out}")
    print(f"✅ Model has lm_head with {model.backbone.eagle_model.language_model.lm_head.weight.numel()/1e6:.1f}M params")
    print("✅ Ready for training with Apache-2.0 license")

|
308 |
+
# ─────────────────── CLI ───────────────────────────────────────────────────
|
309 |
+
if __name__ == "__main__":
|
310 |
+
ap = argparse.ArgumentParser()
|
311 |
+
ap.add_argument("--device", default="cpu")
|
312 |
+
ap.add_argument("--out_dir", default="OpenGR00T-N1.5-3B-Zero")
|
313 |
+
args = ap.parse_args(); main(args.device, args.out_dir)
|
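
# ─────────────────── usage sketch (not part of the committed file) ─────────
# Assuming the gr00t package from Isaac-GR00T is importable and the hub config
# for nvidia/GR00T-N1.5-3B is reachable, one would initialize and later reload
# the blank checkpoint roughly like this; GR00T_N1_5.from_pretrained is assumed
# to follow the usual transformers from_pretrained convention, mirroring the
# save_pretrained call above.
#
#   $ python init_opengr00t_zero.py --device cpu --out_dir OpenGR00T-N1.5-3B-Zero
#
#   from gr00t.model.gr00t_n1 import GR00T_N1_5
#   model = GR00T_N1_5.from_pretrained("OpenGR00T-N1.5-3B-Zero")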