SegformerPlusPlus / pthchecker.py
Tim77777767
pthchecker hinzugefügt um Checkpoint Layer Names auszulesen und SegFormer head angepasst, sodass die fehlenden Layers nicht erstellt werden
c76cec1
import torch
import os
# --- Konfiguration ---
# Pfad zu Ihrer originalen .pth-Datei
pth_checkpoint_path = "./segformer-b5-bsm_hq.pth"
# Name der Ausgabedatei für die Schlüssel
output_filename = "original_pth_keys.txt"
# --- Laden und Speichern der Schlüssel ---
try:
# Laden des Checkpoints
# map_location='cpu' ist gut, um Probleme zu vermeiden, wenn kein GPU verfügbar ist
checkpoint = torch.load(pth_checkpoint_path, map_location='cpu')
print(f"Erfolgreich geladen: {pth_checkpoint_path}")
print(f"Typ des geladenen Objekts: {type(checkpoint)}")
# Extrahieren des state_dict
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
# Oft ist das state_dict in einem Schlüssel wie 'state_dict' oder 'model'
state_dict = checkpoint['state_dict']
print("\nCheckpoint ist ein Dictionary und enthält 'state_dict'.")
elif isinstance(checkpoint, dict):
# Manchmal ist das gesamte Dictionary das state_dict selbst
state_dict = checkpoint
print("\nCheckpoint ist ein Dictionary (wahrscheinlich das state_dict).")
else:
# Falls es direkt das state_dict ist (z.B. nur ein OrderedDict)
state_dict = checkpoint
print("\nCheckpoint ist direkt das state_dict.")
# Sammeln und Sortieren der Schlüssel
state_dict_keys = list(state_dict.keys())
state_dict_keys.sort() # Sortiert die Schlüssel alphabetisch für bessere Übersicht
# Schreiben der Schlüssel in die Textdatei
with open(output_filename, 'w') as f:
for key in state_dict_keys:
f.write(key + '\n') # Jeder Schlüssel in einer neuen Zeile
print(f"\nAlle {len(state_dict_keys)} Schlüssel wurden erfolgreich in '{output_filename}' gespeichert.")
# Optional: Beispiel-Layer-Informationen weiterhin auf der Konsole ausgeben
if state_dict_keys:
example_key = state_dict_keys[0] # Nimmt den ersten sortierten Schlüssel
print(f"\nBeispiel-Layer: '{example_key}'")
print(f"Shape: {state_dict[example_key].shape}")
print(f"Datentyp: {state_dict[example_key].dtype}")
else:
print("\nKeine Schlüssel im state_dict gefunden.")
except FileNotFoundError:
print(f"FEHLER: Die Datei '{pth_checkpoint_path}' wurde nicht gefunden. Bitte den Pfad überprüfen.")
except Exception as e:
print(f"Ein unerwarteter Fehler ist aufgetreten: {e}")