import torch

from mix_vision_transformer_config import MySegformerConfig
from modeling_my_segformer import MySegformerForSemanticSegmentation


def convert_mmengine_checkpoint_to_hf(mm_checkpoint_path, hf_save_dir):
    """Convert an mmengine/mmsegmentation SegFormer checkpoint into a
    HuggingFace-compatible model directory (plus an extra ``.pth`` copy).

    Args:
        mm_checkpoint_path: Path to the mmengine checkpoint file. The file may
            either be a raw state dict or a dict with a ``'state_dict'`` entry.
        hf_save_dir: Directory where the HF model and config are written via
            ``save_pretrained``. A sibling ``<hf_save_dir>.pth`` file is also
            written with the plain state dict.

    Side effects:
        Prints progress / warnings to stdout and writes files to disk.
    """
    # 1. Load the mmengine checkpoint.
    # SECURITY NOTE(review): torch.load unpickles arbitrary objects — only run
    # this on trusted checkpoint files (consider weights_only=True on
    # PyTorch >= 1.13 if the checkpoint contains tensors only).
    mm_ckpt = torch.load(mm_checkpoint_path, map_location="cpu")
    # Checkpoints saved by mmengine wrap the weights in a 'state_dict' entry;
    # fall back to treating the whole object as the state dict otherwise.
    mm_state_dict = mm_ckpt.get("state_dict", mm_ckpt)

    # 2. Build config & model (config parameters must match the checkpoint).
    config = MySegformerConfig(
        embed_dims=[64, 128, 320, 512],  # list of 4 values, one per stage
        num_stages=4,
        num_layers=[3, 4, 6, 3],
        num_heads=[1, 2, 4, 8],
        patch_sizes=[7, 3, 3, 3],
        strides=[4, 2, 2, 2],
        sr_ratios=[8, 4, 2, 1],
        mlp_ratio=4,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        out_indices=(0, 1, 2, 3),
        num_classes=19,
    )
    model = MySegformerForSemanticSegmentation(config)

    # Hoisted out of the loop: state_dict() rebuilds its dict on every call,
    # so membership tests against a prebuilt key set are O(1) instead of
    # O(num_params) per checkpoint key.
    hf_keys = set(model.state_dict())

    # 3. Map mmengine keys onto HF keys.
    hf_state_dict = {}
    for k, v in mm_state_dict.items():
        new_k = k
        # Strip a DataParallel "module." prefix if present.
        if new_k.startswith("module."):
            new_k = new_k[len("module."):]
        # Map decode_head.* -> segmentation_head.*. Prefix slicing (not
        # str.replace) so only the leading component is rewritten.
        if new_k.startswith("decode_head."):
            new_k = "segmentation_head." + new_k[len("decode_head."):]
        # Normalize BatchNorm naming.
        new_k = new_k.replace(".bn.", ".")
        # Keep only keys that exist in the HF model.
        if new_k not in hf_keys:
            print(f"⚠️ Ignoriere {new_k} (nicht im HF-Modell)")
            continue
        hf_state_dict[new_k] = v

    # 4. Load the weights into the model (strict=False: report rather than
    # raise on partial matches).
    missing_keys, unexpected_keys = model.load_state_dict(hf_state_dict, strict=False)
    print("Missing keys:", missing_keys)
    print("Unexpected keys:", unexpected_keys)

    # 5. Save the HF-compatible model & config.
    model.save_pretrained(hf_save_dir)
    config.save_pretrained(hf_save_dir)
    print(f"✅ Model und Config erfolgreich gespeichert in {hf_save_dir}")

    # 5b. Also save as a classic .pth file.
    pth_path = hf_save_dir.rstrip("/") + ".pth"
    torch.save(model.state_dict(), pth_path)
    print(f"✅ Zusätzlich als .pth gespeichert unter {pth_path}")


if __name__ == "__main__":
    mm_checkpoint_path = "./segformer-b5-bsm_hq.pth"
    hf_save_dir = "hf_segformer_converted"
    convert_mmengine_checkpoint_to_hf(mm_checkpoint_path, hf_save_dir)