# SegformerPlusPlus/mmengineToHFCheckpoint.py
import torch
from mix_vision_transformer_config import MySegformerConfig
from modeling_my_segformer import MySegformerForSemanticSegmentation
def convert_mmengine_checkpoint_to_hf(mm_checkpoint_path, hf_save_dir):
    """Convert an mmengine Segformer checkpoint into a Hugging Face checkpoint."""
    # 1. Load the mmengine checkpoint
    # (weights_only=False because mmengine checkpoints carry metadata beyond raw tensors)
    mm_ckpt = torch.load(mm_checkpoint_path, map_location="cpu", weights_only=False)
    # mmengine usually wraps the weights under 'state_dict'; fall back to the raw dict
    mm_state_dict = mm_ckpt.get("state_dict", mm_ckpt)
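    # Debug aid (optional addition, not in the original script): mmengine checkpoints
    # are dicts that typically hold 'meta' and 'state_dict' entries; printing the
    # top-level keys helps diagnose unexpectedly structured files.
    if isinstance(mm_ckpt, dict):
        print("Checkpoint top-level keys:", list(mm_ckpt))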
    # 2. Build config & model (make sure the config parameters match the checkpoint)
    config = MySegformerConfig(
        embed_dims=[64, 128, 320, 512],  # must be a list of 4 values, one per stage
        num_stages=4,
        num_layers=[3, 4, 6, 3],
        num_heads=[1, 2, 4, 8],
        patch_sizes=[7, 3, 3, 3],
        strides=[4, 2, 2, 2],
        sr_ratios=[8, 4, 2, 1],
        mlp_ratio=4,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        out_indices=(0, 1, 2, 3),
        num_classes=19,
    )
    model = MySegformerForSemanticSegmentation(config)
    # 3. Map mmengine keys to HF keys
    hf_state_dict = {}
    hf_model_keys = set(model.state_dict().keys())  # computed once instead of per key
    for k, v in mm_state_dict.items():
        new_k = k
        # Strip a leading "module." prefix (left over from DataParallel)
        if new_k.startswith("module."):
            new_k = new_k[len("module."):]
        # Map decode_head.* -> segmentation_head.*
        if new_k.startswith("decode_head."):
            new_k = new_k.replace("decode_head.", "segmentation_head.")
        # Unify BatchNorm parameter names
        new_k = new_k.replace(".bn.", ".")
        # Keep only keys that exist in the HF model
        if new_k not in hf_model_keys:
            print(f"⚠️ Ignoring {new_k} (not in the HF model)")
            continue
        hf_state_dict[new_k] = v
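    # Optional safeguard (not in the original script): load_state_dict(strict=False)
    # tolerates missing and unexpected keys but still raises on shape mismatches,
    # so drop any mapped tensor whose shape differs from the target parameter.
    model_state = model.state_dict()
    for name in list(hf_state_dict):
        if hf_state_dict[name].shape != model_state[name].shape:
            print(f"⚠️ Dropping {name}: checkpoint shape {tuple(hf_state_dict[name].shape)} "
                  f"!= model shape {tuple(model_state[name].shape)}")
            del hf_state_dict[name]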
    # 4. Load the weights into the model
    missing_keys, unexpected_keys = model.load_state_dict(hf_state_dict, strict=False)
    print("Missing keys:", missing_keys)
    print("Unexpected keys:", unexpected_keys)
    # 5. Save the HF-compatible model & config
    model.save_pretrained(hf_save_dir)
    config.save_pretrained(hf_save_dir)
    print(f"✅ Model and config successfully saved to {hf_save_dir}")
    # 5b. Also save as a plain .pth file
    pth_path = hf_save_dir.rstrip("/") + ".pth"
    torch.save(model.state_dict(), pth_path)
    print(f"✅ Additionally saved as .pth at {pth_path}")
if __name__ == "__main__":
    mm_checkpoint_path = "./segformer-b5-bsm_hq.pth"
    hf_save_dir = "hf_segformer_converted"
    convert_mmengine_checkpoint_to_hf(mm_checkpoint_path, hf_save_dir)
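    # Optional: uncomment to run the round-trip sketch defined above.
    # verify_conversion(hf_save_dir)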