import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, Qwen2_5_VLForConditionalGeneration


def copy_qwen2_5_coder_weights_to_vl(coder_model_id, vl_model_id, output_path):
    """
    Copy the language-model weights from Qwen2.5-Coder-3B-Instruct into
    Qwen2.5-VL-3B-Instruct, preserving the VL model's vision components.
    """
    print(f"Loading Qwen2.5-Coder-3B-Instruct model from {coder_model_id}...")
    coder_model = AutoModelForCausalLM.from_pretrained(
        coder_model_id,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )

    print(f"Loading Qwen2.5-VL-3B-Instruct model from {vl_model_id}...")
    # Qwen2.5-VL is a vision-language model, so it is loaded with its dedicated
    # ForConditionalGeneration class rather than AutoModelForCausalLM.
    vl_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        vl_model_id,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )

    coder_state = coder_model.state_dict()
    vl_state = vl_model.state_dict()

    print("Copying language weights from Coder model to VL model...")
    updated_keys = 0
    skipped_keys = []

    for key in tqdm(coder_state.keys(), desc="Copying weights"):
        # Qwen2.5-Coder keys look like "model.layers.*" and "lm_head.weight".
        # Depending on the transformers version, the VL model stores the same
        # tensors either under the identical key or nested under
        # "model.language_model.*", so try both candidates.
        candidates = [key]
        if key.startswith("model."):
            candidates.append(key.replace("model.", "model.language_model.", 1))

        vl_key = next((k for k in candidates if k in vl_state), None)
        if vl_key is not None and coder_state[key].shape == vl_state[vl_key].shape:
            vl_state[vl_key] = coder_state[key].clone()
            updated_keys += 1
        else:
            skipped_keys.append(key)

    print(f"✅ Updated {updated_keys} keys from Coder to VL.")
    if skipped_keys:
        print(f"⚠️ Skipped {len(skipped_keys)} keys due to shape mismatch or missing keys.")
        for key in skipped_keys[:5]:
            print(f"  - Skipped: {key} (showing up to 5...)")

    print("Saving updated Qwen2.5-VL-3B-Instruct model...")
    vl_model.load_state_dict(vl_state)
    vl_model.save_pretrained(output_path, safe_serialization=True)
    print(f"✅ Model saved to: {output_path}")


if __name__ == "__main__":
    coder_model_id = "Qwen/Qwen2.5-Coder-3B-Instruct"
    vl_model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
    output_path = "./Qwen2.5-VL-3B-Instruct-CoderMerged"

    copy_qwen2_5_coder_weights_to_vl(coder_model_id, vl_model_id, output_path)
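

# --- Optional sanity check (a minimal sketch, not part of the original script) ---
# Reloads the merged checkpoint and compares one sampled language-model tensor
# against the Coder checkpoint. The sample_key below is an assumed example name
# using the "model.layers.*" layout; adjust it if your transformers version nests
# language weights under "model.language_model.*".
def verify_merge(coder_model_id, merged_path,
                 sample_key="model.layers.0.self_attn.q_proj.weight"):
    coder = AutoModelForCausalLM.from_pretrained(
        coder_model_id, torch_dtype=torch.bfloat16, device_map="cpu"
    )
    merged = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        merged_path, torch_dtype=torch.bfloat16, device_map="cpu"
    )
    coder_state = coder.state_dict()
    merged_state = merged.state_dict()

    # Fall back to the nested key layout if the flat one is not present.
    merged_key = sample_key if sample_key in merged_state else \
        sample_key.replace("model.", "model.language_model.", 1)

    if sample_key in coder_state and merged_key in merged_state:
        match = torch.equal(coder_state[sample_key], merged_state[merged_key])
        print(f"{sample_key} matches Coder weights: {match}")
    else:
        print("Sample key not found in both state dicts; check the key prefixes.")


# Example usage (assumes the merge script above has already been run):
# verify_merge("Qwen/Qwen2.5-Coder-3B-Instruct", "./Qwen2.5-VL-3B-Instruct-CoderMerged")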