Qwen2.5-VL-3B-Instruct-CoderMerged / merge_qwen_coder_to_vision.py
import torch
from transformers import AutoModelForCausalLM, Qwen2_5_VLForConditionalGeneration
from tqdm import tqdm


def copy_qwen2_5_coder_weights_to_vl(coder_model_id, vl_model_id, output_path):
    """
    Copy the language-model weights from Qwen2.5-Coder-3B-Instruct into
    Qwen2.5-VL-3B-Instruct, preserving the VL model's vision components.
    """
print(f"Loading Qwen2.5-Coder-3B-Instruct model from {coder_model_id}...")
coder_model = AutoModelForCausalLM.from_pretrained(
coder_model_id,
torch_dtype=torch.bfloat16,
device_map="cpu"
)
print(f"Loading Qwen2.5-VL-3B-Instruct model from {vl_model_id}...")
vl_model = AutoModelForCausalLM.from_pretrained(
vl_model_id,
torch_dtype=torch.bfloat16,
device_map="cpu"
)
    coder_state = coder_model.state_dict()
    vl_state = vl_model.state_dict()

    print("Copying language weights from Coder model to VL model...")
    updated_keys = 0
    skipped_keys = []
    for key in tqdm(coder_state.keys(), desc="Copying weights"):
        # Qwen2 causal LMs keep their language weights under the "model." / "lm_head."
        # prefixes (not "transformer." as in GPT-2-style checkpoints).
        if key.startswith(("model.", "lm_head")):
            if key in vl_state and coder_state[key].shape == vl_state[key].shape:
                vl_state[key] = coder_state[key].clone()
                updated_keys += 1
            else:
                # Depending on the transformers version, the VL checkpoint may nest its
                # language weights under a different prefix (e.g. "model.language_model.");
                # any key that does not line up lands here.
                skipped_keys.append(key)

    print(f"✅ Updated {updated_keys} keys from Coder to VL.")
    if skipped_keys:
        print(f"⚠️ Skipped {len(skipped_keys)} keys due to shape mismatch or missing keys.")
        for key in skipped_keys[:5]:  # show at most the first five
            print(f"  - Skipped: {key}")
print("Saving updated Qwen2.5-VL-3B-Instruct model...")
vl_model.load_state_dict(vl_state)
vl_model.save_pretrained(output_path, safe_serialization=True)
print(f"✅ Model saved to: {output_path}")


if __name__ == "__main__":
    coder_model_id = "Qwen/Qwen2.5-Coder-3B-Instruct"
    vl_model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
    output_path = "./Qwen2.5-VL-3B-Instruct-CoderMerged"

    copy_qwen2_5_coder_weights_to_vl(coder_model_id, vl_model_id, output_path)