"""Rename the ``decode_head.*`` keys of a PyTorch checkpoint.

Keys listed in ``specific_segmentation_head_layers`` are moved under the
``segmentation_head.`` prefix, every other key starting with ``decode_head.``
is moved under ``segformer_head.``, and all remaining keys are copied
unchanged. The result is written to ``output_checkpoint_path``.
"""

import os

import torch

# Checkpoint that still uses the old layer names.
input_checkpoint_path = "./pytorch_model_oldLayerNames.bin"

# Destination for the checkpoint with the renamed keys.
output_checkpoint_path = "./pytorch_model_renamed_final.bin"

# Exact keys that are renamed to ``segmentation_head.*``. Any other
# ``decode_head.*`` key is renamed to ``segformer_head.*`` instead.
specific_segmentation_head_layers = [
    'decode_head.conv_seg.bias',
    'decode_head.conv_seg.weight',
    'decode_head.convs.0.conv.bias',
    'decode_head.convs.0.conv.weight',
    'decode_head.convs.1.conv.bias',
    'decode_head.convs.1.conv.weight',
    'decode_head.convs.2.conv.bias',
    'decode_head.convs.2.conv.weight',
    'decode_head.convs.3.conv.bias',
    'decode_head.convs.3.conv.weight',
    'decode_head.fusion_conv.conv.bias',
    'decode_head.fusion_conv.conv.weight',
]


def rename_decode_head_keys(state_dict, segmentation_head_layers=None):
    """Return a copy of *state_dict* with its ``decode_head.*`` keys renamed.

    Args:
        state_dict: Mapping from key to value. Values are copied by
            reference and never inspected.
        segmentation_head_layers: Iterable of exact keys that go to
            ``segmentation_head.``. Defaults to the module-level
            ``specific_segmentation_head_layers``.

    Returns:
        A tuple ``(new_state_dict, renamed_count_segmentation,
        renamed_count_segformer, skipped_count)``. ``new_state_dict`` keeps
        the iteration order of *state_dict*; the input is not modified.
    """
    if segmentation_head_layers is None:
        segmentation_head_layers = specific_segmentation_head_layers
    # Build the lookup once so each membership test is O(1).
    segmentation_keys = frozenset(segmentation_head_layers)

    new_state_dict = {}
    renamed_count_segmentation = 0
    renamed_count_segformer = 0
    skipped_count = 0

    # NOTE: if a renamed key equals a key that is already present in the
    # result, the later assignment overwrites the earlier one.
    for old_key, value in state_dict.items():
        if old_key in segmentation_keys:
            # Only the first occurrence is replaced, i.e. the leading prefix
            # for every key in the default list.
            new_key = old_key.replace('decode_head.', 'segmentation_head.', 1)
            new_state_dict[new_key] = value
            renamed_count_segmentation += 1
        elif old_key.startswith('decode_head.'):
            new_key = old_key.replace('decode_head.', 'segformer_head.', 1)
            new_state_dict[new_key] = value
            renamed_count_segformer += 1
        else:
            new_state_dict[old_key] = value
            skipped_count += 1

    return (
        new_state_dict,
        renamed_count_segmentation,
        renamed_count_segformer,
        skipped_count,
    )


def main():
    """Load the input checkpoint, rename its keys, save it and print a summary.

    Prints an error message and returns without writing anything when
    ``input_checkpoint_path`` does not exist.
    """
    if not os.path.exists(input_checkpoint_path):
        print(f"Fehler: Eingabedatei nicht gefunden unter {input_checkpoint_path}. Bitte den Pfad korrigieren. ❌")
        return

    # NOTE(security): torch.load unpickles the file, so only run this on
    # checkpoints from a trusted source. Consider passing weights_only=True
    # where the installed torch version supports it.
    state_dict = torch.load(input_checkpoint_path, map_location="cpu")

    (
        new_state_dict,
        renamed_count_segmentation,
        renamed_count_segformer,
        skipped_count,
    ) = rename_decode_head_keys(state_dict)

    torch.save(new_state_dict, output_checkpoint_path)

    print(f"✅ Fertig! Die umbenannte Datei wurde gespeichert unter: {output_checkpoint_path}")
    print("Zusammenfassung der Umbenennungen:")
    print(f" - '{renamed_count_segmentation}' Layer von 'decode_head.' zu 'segmentation_head.' umbenannt.")
    print(f" - '{renamed_count_segformer}' Layer von 'decode_head.' zu 'segformer_head.' umbenannt.")
    print(f" - '{skipped_count}' Layer behielten ihren ursprünglichen Namen (z.B. Backbone).")


if __name__ == "__main__":
    main()