import torch
from mix_vision_transformer_config import MySegformerConfig
from modeling_my_segformer import MySegformerForSemanticSegmentation

def convert_mmengine_checkpoint_to_hf(mm_checkpoint_path, hf_save_dir):
    # 1. Load the mmengine checkpoint (mapped to CPU)
    mm_ckpt = torch.load(mm_checkpoint_path, map_location="cpu")
    # mmengine checkpoints usually wrap the weights in a 'state_dict' entry
    if 'state_dict' in mm_ckpt:
        mm_state_dict = mm_ckpt['state_dict']
    else:
        mm_state_dict = mm_ckpt

    # 2. Build the config & model (make sure the config parameters match the checkpoint)
    config = MySegformerConfig(
        embed_dims=[64, 128, 320, 512],  # <--- correct list with 4 values (one per stage)
        num_stages=4,
        num_layers=[3, 4, 6, 3],
        num_heads=[1, 2, 4, 8],
        patch_sizes=[7, 3, 3, 3],
        strides=[4, 2, 2, 2],
        sr_ratios=[8, 4, 2, 1],
        mlp_ratio=4,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        out_indices=(0, 1, 2, 3),
        num_classes=19
    )

    model = MySegformerForSemanticSegmentation(config)

    # 3. Map mmengine keys to HF keys
    hf_state_dict = {}

    for k, v in mm_state_dict.items():
        new_k = k

        # Falls "module." als Prefix da ist (DataParallel), entfernen
        if new_k.startswith("module."):
            new_k = new_k[len("module."):]

        # Map decode_head.* -> segmentation_head.*
        if new_k.startswith("decode_head."):
            new_k = new_k.replace("decode_head.", "segmentation_head.")

        # Normalize BatchNorm parameter names
        new_k = new_k.replace(".bn.", ".")

        # Only keep keys that exist in the HF model
        if new_k not in model.state_dict():
            print(f"⚠️ Ignoring {new_k} (not present in the HF model)")
            continue

        hf_state_dict[new_k] = v

    # 4. Load the weights into the model
    missing_keys, unexpected_keys = model.load_state_dict(hf_state_dict, strict=False)

    print("Missing keys:", missing_keys)
    print("Unexpected keys:", unexpected_keys)

    # 5. Save the HF-compatible model & config
    model.save_pretrained(hf_save_dir)
    config.save_pretrained(hf_save_dir)

    print(f"✅ Model und Config erfolgreich gespeichert in {hf_save_dir}")

    # 5b. Also save the weights as a plain .pth file
    pth_path = hf_save_dir.rstrip("/") + ".pth"
    torch.save(model.state_dict(), pth_path)
    print(f"✅ Zusätzlich als .pth gespeichert unter {pth_path}")


if __name__ == "__main__":
    mm_checkpoint_path = "./segformer-b5-bsm_hq.pth"
    hf_save_dir = "hf_segformer_converted"

    convert_mmengine_checkpoint_to_hf(mm_checkpoint_path, hf_save_dir)
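
    # Optional sanity check (a minimal sketch, assuming MySegformerForSemanticSegmentation
    # inherits from transformers.PreTrainedModel, as the save_pretrained call above suggests):
    # reload the converted model via from_pretrained and confirm its weights match the
    # .pth export written next to the save directory.
    reloaded_state = MySegformerForSemanticSegmentation.from_pretrained(hf_save_dir).state_dict()
    exported_state = torch.load(hf_save_dir.rstrip("/") + ".pth", map_location="cpu")
    for name, tensor in exported_state.items():
        assert torch.equal(tensor, reloaded_state[name]), f"Mismatch in {name}"
    print("Round-trip check passed: reloaded weights match the .pth export.")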