|
import torch |
|
from mix_vision_transformer_config import MySegformerConfig |
|
from modeling_my_segformer import MySegformerForSemanticSegmentation |
|
|
|
def convert_mmengine_checkpoint_to_hf(mm_checkpoint_path, hf_save_dir):
    """Convert an mmengine/mmsegmentation checkpoint to a HuggingFace model.

    Loads the checkpoint at *mm_checkpoint_path*, remaps its parameter names
    to the naming scheme of ``MySegformerForSemanticSegmentation``, loads the
    remapped weights (non-strict), and saves model + config to *hf_save_dir*.
    A plain ``.pth`` state dict is additionally written next to that directory.

    Args:
        mm_checkpoint_path: Path to the source ``.pth`` checkpoint file.
        hf_save_dir: Directory to write the converted HF model/config into.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only run this on
    # trusted checkpoints (weights_only=True would be safer on recent torch,
    # if the checkpoint contains tensors only — verify before switching).
    mm_ckpt = torch.load(mm_checkpoint_path, map_location="cpu")

    # mmengine checkpoints usually wrap the weights under 'state_dict';
    # otherwise treat the loaded object itself as the state dict.
    if 'state_dict' in mm_ckpt:
        mm_state_dict = mm_ckpt['state_dict']
    else:
        mm_state_dict = mm_ckpt

    # Hard-coded architecture the weights are expected to match.
    # NOTE(review): num_layers=[3, 4, 6, 3] corresponds to MiT-b2, but the
    # default checkpoint below is named "b5" (which uses [3, 6, 40, 3]) —
    # confirm the config matches the checkpoint being converted.
    config = MySegformerConfig(
        embed_dims=[64, 128, 320, 512],
        num_stages=4,
        num_layers=[3, 4, 6, 3],
        num_heads=[1, 2, 4, 8],
        patch_sizes=[7, 3, 3, 3],
        strides=[4, 2, 2, 2],
        sr_ratios=[8, 4, 2, 1],
        mlp_ratio=4,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        out_indices=(0, 1, 2, 3),
        num_classes=19,
    )

    model = MySegformerForSemanticSegmentation(config)

    # Hoisted out of the loop: model.state_dict() rebuilds the whole dict on
    # every call, so membership tests against it inside the loop were O(n^2).
    model_keys = set(model.state_dict().keys())

    hf_state_dict = {}
    for k, v in mm_state_dict.items():
        new_k = k

        # Strip a DataParallel/DDP "module." wrapper prefix, if present.
        if new_k.startswith("module."):
            new_k = new_k[len("module."):]

        # mmseg's decode head corresponds to the HF segmentation head.
        # count=1 so only the checked prefix is replaced, never a later
        # occurrence inside the key.
        if new_k.startswith("decode_head."):
            new_k = new_k.replace("decode_head.", "segmentation_head.", 1)

        # Flatten ".bn." sub-module names (e.g. "x.bn.weight" -> "x.weight").
        # NOTE(review): this folds batch-norm params into the parent module's
        # names — confirm the HF model really stores them that way.
        new_k = new_k.replace(".bn.", ".")

        if new_k not in model_keys:
            print(f"⚠️ Ignoriere {new_k} (nicht im HF-Modell)")
            continue

        hf_state_dict[new_k] = v

    # Non-strict load: report what did not line up instead of raising.
    missing_keys, unexpected_keys = model.load_state_dict(hf_state_dict, strict=False)
    print("Missing keys:", missing_keys)
    print("Unexpected keys:", unexpected_keys)

    model.save_pretrained(hf_save_dir)
    config.save_pretrained(hf_save_dir)
    print(f"✅ Model und Config erfolgreich gespeichert in {hf_save_dir}")

    # Additionally dump a raw state dict next to the save directory.
    pth_path = hf_save_dir.rstrip("/") + ".pth"
    torch.save(model.state_dict(), pth_path)
    print(f"✅ Zusätzlich als .pth gespeichert unter {pth_path}")
|
|
|
|
|
if __name__ == "__main__":
    # Default paths used when running this module directly as a script.
    checkpoint_file = "./segformer-b5-bsm_hq.pth"
    output_dir = "hf_segformer_converted"

    convert_mmengine_checkpoint_to_hf(checkpoint_file, output_dir)
|
|