|
import torch |
|
from transformers import AutoModelForCausalLM |
|
from tqdm import tqdm |
|
|
|
def copy_qwen2_5_coder_weights_to_vl(coder_model_id, vl_model_id, output_path): |
|
""" |
|
Copy the language model weights from Qwen2.5-Coder-3B-Instruct into |
|
Qwen2.5-VL-3B-Instruct, preserving its vision-language components. |
|
""" |
|
|
|
print(f"Loading Qwen2.5-Coder-3B-Instruct model from {coder_model_id}...") |
|
coder_model = AutoModelForCausalLM.from_pretrained( |
|
coder_model_id, |
|
torch_dtype=torch.bfloat16, |
|
device_map="cpu" |
|
) |
|
|
|
print(f"Loading Qwen2.5-VL-3B-Instruct model from {vl_model_id}...") |
|
vl_model = AutoModelForCausalLM.from_pretrained( |
|
vl_model_id, |
|
torch_dtype=torch.bfloat16, |
|
device_map="cpu" |
|
) |
|
|
|
coder_state = coder_model.state_dict() |
|
vl_state = vl_model.state_dict() |
|
|
|
print("Copying language weights from Coder model to VL model...") |
|
|
|
updated_keys = 0 |
|
skipped_keys = [] |
|
|
|
for key in coder_state.keys(): |
|
|
|
if key.startswith("transformer."): |
|
if key in vl_state and coder_state[key].shape == vl_state[key].shape: |
|
vl_state[key] = coder_state[key].clone() |
|
updated_keys += 1 |
|
else: |
|
skipped_keys.append(key) |
|
|
|
print(f"✅ Updated {updated_keys} keys from Coder to VL.") |
|
if skipped_keys: |
|
print(f"⚠️ Skipped {len(skipped_keys)} keys due to shape mismatch or missing keys.") |
|
for key in skipped_keys[:5]: |
|
print(f" - Skipped: {key} (showing up to 5...)") |
|
|
|
print("Saving updated Qwen2.5-VL-3B-Instruct model...") |
|
vl_model.load_state_dict(vl_state) |
|
vl_model.save_pretrained(output_path, safe_serialization=True) |
|
|
|
print(f"✅ Model saved to: {output_path}") |
|
|
|
if __name__ == "__main__":
    # Donor checkpoint: the code-specialized language model.
    source_repo = "Qwen/Qwen2.5-Coder-3B-Instruct"
    # Recipient checkpoint: the VL model whose language weights get replaced.
    target_repo = "Qwen/Qwen2.5-VL-3B-Instruct"
    # Directory the merged checkpoint is written to.
    merged_dir = "./Qwen2.5-VL-3B-Instruct-CoderMerged"

    copy_qwen2_5_coder_weights_to_vl(source_repo, target_repo, merged_dir)
|
|