ehartford committed
Commit ea3f081 · verified · parent: 3bd1be9

Create qwen2to3_diagnostic.py

Files changed (1): qwen2to3_diagnostic.py (+222, -0)
qwen2to3_diagnostic.py ADDED
@@ -0,0 +1,222 @@
+ # file: qwen2to3_diagnostic.py
+
+ import torch
+ import os
+ import json
+ import re
+ from datetime import datetime
+ from tqdm import tqdm
+ from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
+ from transformers import Qwen3Config, Qwen3ForCausalLM
+ from collections import Counter
+
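+ # Overview: assemble a Qwen3-architecture checkpoint from Qwen2.5-72B-Instruct
+ # weights, graft the q_norm/k_norm tensors (absent in Qwen2.5) from a Qwen3-32B
+ # donor, remap the vocabulary to the donor tokenizer, and log diagnostics at
+ # every step.
+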
+ # --- DIAGNOSTIC HELPERS ---
+
+ def log_tensor_stats(tensor, name):
+     """Prints statistics for a given tensor."""
+     if tensor.numel() == 0:
+         print(f" - DIAGNOSTIC STATS for '{name}': Tensor is empty.")
+         return
+     print(
+         f" - DIAGNOSTIC STATS for '{name}':\n"
+         f" - Shape: {tensor.shape}, Dtype: {tensor.dtype}\n"
+         f" - Mean: {tensor.float().mean().item():.4f}, Std: {tensor.float().std().item():.4f}\n"
+         f" - Min: {tensor.float().min().item():.4f}, Max: {tensor.float().max().item():.4f}\n"
+         f" - Has NaN: {torch.isnan(tensor).any().item()}, Has Inf: {torch.isinf(tensor).any().item()}"
+     )
+
+ def verify_embedding_transfer(s_embeds, t_embeds, mapping, t_tokenizer, num_samples=5):
+     """Verifies that some shared token embeddings were copied correctly."""
+     print("\n - DIAGNOSTIC: Verifying embedding transfer for sample tokens...")
+     verified_count = 0
+     for t_id, s_id in mapping.items():
+         if verified_count >= num_samples:
+             break
+         if s_id != -1:
+             token = t_tokenizer.convert_ids_to_tokens(t_id)
+             source_vec = s_embeds[s_id]
+             target_vec = t_embeds[t_id]
+             diff = torch.sum(torch.abs(source_vec - target_vec)).item()
+             if diff < 1e-6:
+                 print(f" - ✓ Token '{token}' (ID {t_id}) transferred successfully (diff: {diff:.2e}).")
+             else:
+                 print(f" - ✗ FAILED: Token '{token}' (ID {t_id}) has a large difference after transfer (diff: {diff:.2e}).")
+             verified_count += 1
+
+ def verify_grafted_layer(target_state_dict, donor_state_dict, target_layer_idx, donor_layer_idx):
+     """Verifies that cyclical grafting for q_norm/k_norm worked."""
+     print(f"\n - DIAGNOSTIC: Verifying cyclical graft for target layer {target_layer_idx} from donor layer {donor_layer_idx}...")
+     for norm_type in ['q_norm', 'k_norm']:
+         target_key = f'model.layers.{target_layer_idx}.self_attn.{norm_type}.weight'
+         donor_key = f'model.layers.{donor_layer_idx}.self_attn.{norm_type}.weight'
+         diff = torch.sum(torch.abs(target_state_dict[target_key] - donor_state_dict[donor_key])).item()
+         if diff < 1e-6:
+             print(f" - ✓ {norm_type} weights match (diff: {diff:.2e}).")
+         else:
+             print(f" - ✗ FAILED: {norm_type} weights DO NOT match (diff: {diff:.2e}).")
+
+ def check_for_nan_inf(state_dict):
+     """Scans the entire state_dict for NaN or Inf values."""
+     print("\n - DIAGNOSTIC: Scanning final state dictionary for NaN/Inf values...")
+     found_issue = False
+     for key, tensor in tqdm(state_dict.items(), desc="Scanning tensors"):
+         if torch.isnan(tensor).any() or torch.isinf(tensor).any():
+             print(f" - ✗ FAILED: Found NaN or Inf in tensor '{key}'!")
+             found_issue = True
+     if not found_issue:
+         print(" - ✓ All tensors in the final state dictionary are clean.")
+     return not found_issue
+
+ # --- STANDARD HELPERS ---
+
+ def create_vocab_mapping(s_tok, t_tok):
+     # ... (code is unchanged)
+     s_vocab, t_vocab = s_tok.get_vocab(), t_tok.get_vocab()
+     s_tok_to_id = {t: i for t, i in s_vocab.items()}
+     mapping = {t_id: s_tok_to_id.get(t, -1) for t, t_id in t_vocab.items()}
+     matches = sum(1 for v in mapping.values() if v != -1)
+     print(f"Vocabulary overlap: {matches}/{len(t_vocab)} tokens ({matches/len(t_vocab)*100:.1f}%) will be transferred.")
+     return mapping
+
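+ # (The mapping runs target-token-id -> source-token-id; -1 marks tokens missing
+ # from the source vocabulary, which are later initialized with the mean source
+ # embedding in create_hybrid_matrix.)
+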
+ def verify_special_tokens(s_tok, t_tok, mapping):
+     # ... (code is unchanged)
+     print("\nVerifying special token mappings...")
+     for name, token_value in t_tok.special_tokens_map.items():
+         def _process_token(token_str):
+             if token_str and token_str in t_tok.get_vocab():
+                 t_id = t_tok.convert_tokens_to_ids(token_str)
+                 s_id = mapping.get(t_id, -1)
+                 status = f"Mapped (T: {t_id} -> S: {s_id})" if s_id != -1 else "NOT FOUND in source (initialized with mean)"
+                 print(f" ✓ ('{token_str}'): {status}")
+         if isinstance(token_value, str): _process_token(token_value)
+         elif isinstance(token_value, list):
+             for token_str_in_list in token_value: _process_token(token_str_in_list)
+
+ def create_hybrid_matrix(s_matrix, mapping, shape):
+     # ... (code is unchanged, but we'll add logging inside the main function)
+     mean_embedding = s_matrix.mean(dim=0, keepdim=True)
+     hybrid = torch.zeros(shape, dtype=s_matrix.dtype, device='cpu')
+     for t_id, s_id in mapping.items():
+         hybrid[t_id] = s_matrix[s_id] if s_id != -1 else mean_embedding
+     return hybrid.to(s_matrix.device)
+
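+ # If the configured vocab_size is padded beyond the tokenizer's real vocabulary,
+ # any rows past the highest mapped token id simply remain zero-initialized.
+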
+ def validate_model_diagnostic(path):
+     print("\n[Step 6/6] Running DIAGNOSTIC validation...")
+     try:
+         tokenizer = AutoTokenizer.from_pretrained(path)
+         model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", torch_dtype=torch.bfloat16)
+         model.eval()
+         prompt = "The theory of relativity states that"
+         print(f"\nValidation Prompt: '{prompt}'")
+         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+         with torch.no_grad():
+             outputs = model.generate(**inputs, max_new_tokens=25, do_sample=False, pad_token_id=tokenizer.eos_token_id)
+
+         print("\n--- DIAGNOSTIC: RAW TOKEN IDs ---")
+         print(outputs[0].tolist())
+
+         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         print("\n--- DIAGNOSTIC: Decoded Response ---")
+         print(f"'{response}'")
+
+         if '�' in response:
+             print("\n - ✗ VALIDATION FAILED: Found replacement character '�' in output. This indicates a tokenization/decoding issue.")
+             return
+
+         # A more robust check for coherence
+         if "states that states that" in response or "the the the" in response or len(set(response.split())) < 5:
+             print("\n - ✗ VALIDATION FAILED: Output appears repetitive or incoherent.")
+         else:
+             print("\n - ✓ Validation check passed: Model generated non-repetitive text.")
+
+     except Exception as e:
+         print(f"\n ✗ Validation FAILED with an exception: {e}")
+
+ # --- Main Conversion Logic ---
+ def convert_qwen2_to_qwen3_diagnostic():
+     source_model_id, donor_model_id = "Qwen/Qwen2.5-72B-Instruct", "Qwen/Qwen3-32B"
+     target_model_path = "./Qwen3-72B"
+     print("Starting DIAGNOSTIC conversion process (v6.0)...")
+
+     # --- Step 1 & 2: Load everything ---
+     s_config = AutoConfig.from_pretrained(source_model_id)
+     d_config = AutoConfig.from_pretrained(donor_model_id)
+     dtype = torch.bfloat16
+     s_model = AutoModelForCausalLM.from_pretrained(source_model_id, torch_dtype=dtype, device_map="auto")
+     d_model = AutoModelForCausalLM.from_pretrained(donor_model_id, torch_dtype=dtype, device_map="auto")
+     s_tokenizer = AutoTokenizer.from_pretrained(source_model_id)
+     t_tokenizer = AutoTokenizer.from_pretrained(donor_model_id)
+
+     # --- Step 3: Create Target ---
+     t_config = Qwen3Config(
+         hidden_size=s_config.hidden_size,
+         intermediate_size=s_config.intermediate_size,
+         num_hidden_layers=s_config.num_hidden_layers,
+         num_attention_heads=s_config.num_attention_heads,
+         num_key_value_heads=s_config.num_key_value_heads,
+         max_position_embeddings=s_config.max_position_embeddings,
+         max_window_layers=s_config.max_window_layers,
+         sliding_window=s_config.sliding_window,
+         attention_bias=d_config.attention_bias,
+         hidden_act=d_config.hidden_act,
+         initializer_range=d_config.initializer_range,
+         rms_norm_eps=d_config.rms_norm_eps,
+         rope_theta=d_config.rope_theta,
+         vocab_size=d_config.vocab_size,
+         tie_word_embeddings=True,
+     )
+     with torch.device("meta"):
+         t_model = Qwen3ForCausalLM(t_config)
+
+     # --- Step 4: Convert and DIAGNOSE Weights ---
+     print("\n[Step 4/6] Converting weights (DIAGNOSTIC mode)...")
+     s_state_dict = {k: v.cpu() for k, v in tqdm(s_model.state_dict().items(), desc="Source state dict to CPU")}
+     d_state_dict = {k: v.cpu() for k, v in tqdm(d_model.state_dict().items(), desc="Donor state dict to CPU")}
+     vocab_mapping = create_vocab_mapping(s_tokenizer, t_tokenizer)
+     verify_special_tokens(s_tokenizer, t_tokenizer, vocab_mapping)
+
+     new_state_dict = {}
+     num_donor_layers = d_config.num_hidden_layers
+
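+     # Staging both state dicts on CPU keeps checkpoint assembly in system RAM
+     # instead of GPU memory; with a ~72B-parameter source, ample host RAM is assumed.
+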
+     # --- Create and Diagnose Hybrid Embeddings ---
+     print("\n--- Creating and Diagnosing Embedding and LM Head ---")
+
+     # Embeddings
+     print("Processing model.embed_tokens.weight...")
+     s_embeds = s_state_dict['model.embed_tokens.weight']
+     mean_embedding = s_embeds.mean(dim=0, keepdim=True)
+     log_tensor_stats(mean_embedding, "Mean Initializer Vector")
+     new_embed_matrix = create_hybrid_matrix(s_embeds, vocab_mapping, (t_config.vocab_size, t_config.hidden_size))
+     log_tensor_stats(new_embed_matrix, "Final Hybrid Embedding Matrix")
+     new_state_dict['model.embed_tokens.weight'] = new_embed_matrix
+     verify_embedding_transfer(s_embeds, new_embed_matrix, vocab_mapping, t_tokenizer)
+
+     # LM Head
+     print("\nProcessing lm_head.weight...")
+     s_lm_head = s_state_dict['lm_head.weight']
+     new_lm_head_matrix = create_hybrid_matrix(s_lm_head, vocab_mapping, (t_config.vocab_size, t_config.hidden_size))
+     log_tensor_stats(new_lm_head_matrix, "Final Hybrid LM Head Matrix")
+     new_state_dict['lm_head.weight'] = new_lm_head_matrix
+
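+     # NOTE: t_config sets tie_word_embeddings=True, so a model reloaded from the
+     # saved checkpoint may tie lm_head back to embed_tokens and ignore this matrix;
+     # it is still built and logged here so the pre-save diagnostics can inspect it.
+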
+     # --- Process remaining layers ---
+     print("\n--- Processing Transformer Layers ---")
+     # Get all keys except the ones we already handled
+     remaining_keys = [k for k in t_model.state_dict().keys() if 'embed_tokens' not in k and 'lm_head' not in k]
+
+     for key in tqdm(remaining_keys, desc="Transferring layer weights"):
+         if "q_norm" in key or "k_norm" in key:
+             match = re.search(r'layers\.(\d+)\.', key)
+             if match:
+                 target_layer_idx = int(match.group(1))
+                 donor_layer_idx = target_layer_idx % num_donor_layers
+                 donor_key = key.replace(f'layers.{target_layer_idx}.', f'layers.{donor_layer_idx}.')
+                 new_state_dict[key] = d_state_dict[donor_key].clone()
+         elif key in s_state_dict:
+             new_state_dict[key] = s_state_dict[key].clone()
+         else:
+             print(f" ⚠️ Unhandled key: {key} (not in source, skipping)")
+
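+     # q_norm/k_norm are new in Qwen3-style attention and have no Qwen2.5 source
+     # counterpart, so the loop above borrows them from the donor, wrapping
+     # cyclically (target layer i takes donor layer i % num_donor_layers) when the
+     # target has more layers than the donor.
+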
+     # --- Final Diagnostic Checks ---
+     verify_grafted_layer(new_state_dict, d_state_dict, target_layer_idx=num_donor_layers, donor_layer_idx=0)
+     all_clean = check_for_nan_inf(new_state_dict)
+     if not all_clean:
+         print("\nCRITICAL ERROR: NaN/Inf values detected. Aborting before save.")
+         return  # Stop the process
+
+     # --- Step 5: Save ---
+     print("\n[Step 5/6] Saving final model and supporting files...")
+     # assign=True is required to load into the meta-model: a plain copy-based
+     # load_state_dict cannot copy into tensors on the meta device.
+     t_model.load_state_dict(new_state_dict, assign=True)
+     t_model.save_pretrained(target_model_path, safe_serialization=True, state_dict=new_state_dict)
+     t_tokenizer.save_pretrained(target_model_path)
+     print(f"✅ Model saved to: {target_model_path}")
+
+     # --- Step 6: Validate ---
+     del s_model, d_model, s_state_dict, d_state_dict, new_state_dict, t_model
+     torch.cuda.empty_cache()
+     validate_model_diagnostic(path=target_model_path)
+
+ if __name__ == "__main__":
+     convert_qwen2_to_qwen3_diagnostic()