statisticalplumber committed on
Commit a326400 · verified · 1 Parent(s): a2eb489

Create merge_qwen_coder_to_vision.py

Files changed (1)
  1. merge_qwen_coder_to_vision.py +59 -0
merge_qwen_coder_to_vision.py ADDED
@@ -0,0 +1,59 @@
+ import torch
+ from tqdm import tqdm
+ from transformers import AutoModelForCausalLM, Qwen2_5_VLForConditionalGeneration
+
+
+ def copy_qwen2_5_coder_weights_to_vl(coder_model_id, vl_model_id, output_path):
+     """
+     Copy the language-model weights from Qwen2.5-Coder-3B-Instruct into
+     Qwen2.5-VL-3B-Instruct, preserving the VL model's vision components.
+     """
+     print(f"Loading Qwen2.5-Coder-3B-Instruct model from {coder_model_id}...")
+     coder_model = AutoModelForCausalLM.from_pretrained(
+         coder_model_id,
+         torch_dtype=torch.bfloat16,
+         device_map="cpu",
+     )
+
+     print(f"Loading Qwen2.5-VL-3B-Instruct model from {vl_model_id}...")
+     # The VL checkpoint is not a plain causal LM, so it is loaded with its
+     # dedicated class rather than AutoModelForCausalLM.
+     vl_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+         vl_model_id,
+         torch_dtype=torch.bfloat16,
+         device_map="cpu",
+     )
+
+     coder_state = coder_model.state_dict()
+     vl_state = vl_model.state_dict()
+
+     print("Copying language weights from Coder model to VL model...")
+
+     updated_keys = 0
+     skipped_keys = []
+
+     # Qwen2.5 decoder weights live under the "model." prefix
+     # (e.g. model.layers.0.self_attn.q_proj.weight); the VL model's vision
+     # tower uses different key names, so it is left untouched. Keys that are
+     # missing from the VL state dict or differ in shape are skipped.
+     for key in tqdm(coder_state.keys()):
+         if key.startswith("model."):
+             if key in vl_state and coder_state[key].shape == vl_state[key].shape:
+                 vl_state[key] = coder_state[key].clone()
+                 updated_keys += 1
+             else:
+                 skipped_keys.append(key)
+
+     print(f"✅ Updated {updated_keys} keys from Coder to VL.")
+     if skipped_keys:
+         print(f"⚠️ Skipped {len(skipped_keys)} keys due to shape mismatch or missing keys (showing up to 5):")
+         for key in skipped_keys[:5]:
+             print(f"   - Skipped: {key}")
+
+     print("Saving updated Qwen2.5-VL-3B-Instruct model...")
+     vl_model.load_state_dict(vl_state)
+     vl_model.save_pretrained(output_path, safe_serialization=True)
+
+     print(f"✅ Model saved to: {output_path}")
+
+
+ if __name__ == "__main__":
+     coder_model_id = "Qwen/Qwen2.5-Coder-3B-Instruct"
+     vl_model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
+     output_path = "./Qwen2.5-VL-3B-Instruct-CoderMerged"
+
+     copy_qwen2_5_coder_weights_to_vl(coder_model_id, vl_model_id, output_path)
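
A minimal follow-up sketch (not part of the committed script) for sanity-checking the merged checkpoint, assuming the output_path used above. Because save_pretrained in the merge script writes only the model weights and config, the processor/tokenizer files are copied here from the original VL repo so the merged folder is self-contained.

import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

# Assumed local path: the output_path produced by the merge script above.
merged_path = "./Qwen2.5-VL-3B-Instruct-CoderMerged"

# The merge script stores only weights + config; copy the processor/tokenizer
# files from the original VL checkpoint alongside them.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
processor.save_pretrained(merged_path)

# Reload the merged model to confirm the saved checkpoint loads cleanly.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    merged_path,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
print(model.config.model_type)  # expected: qwen2_5_vl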