Tim77777767
pthchecker hinzugefügt um Checkpoint Layer Names auszulesen und SegFormer head angepasst, sodass die fehlenden Layers nicht erstellt werden
c76cec1
import torch | |
import os | |
# --- Konfiguration --- | |
# Pfad zu Ihrer originalen .pth-Datei | |
pth_checkpoint_path = "./segformer-b5-bsm_hq.pth" | |
# Name der Ausgabedatei für die Schlüssel | |
output_filename = "original_pth_keys.txt" | |
# --- Laden und Speichern der Schlüssel --- | |
try: | |
# Laden des Checkpoints | |
# map_location='cpu' ist gut, um Probleme zu vermeiden, wenn kein GPU verfügbar ist | |
checkpoint = torch.load(pth_checkpoint_path, map_location='cpu') | |
print(f"Erfolgreich geladen: {pth_checkpoint_path}") | |
print(f"Typ des geladenen Objekts: {type(checkpoint)}") | |
# Extrahieren des state_dict | |
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: | |
# Oft ist das state_dict in einem Schlüssel wie 'state_dict' oder 'model' | |
state_dict = checkpoint['state_dict'] | |
print("\nCheckpoint ist ein Dictionary und enthält 'state_dict'.") | |
elif isinstance(checkpoint, dict): | |
# Manchmal ist das gesamte Dictionary das state_dict selbst | |
state_dict = checkpoint | |
print("\nCheckpoint ist ein Dictionary (wahrscheinlich das state_dict).") | |
else: | |
# Falls es direkt das state_dict ist (z.B. nur ein OrderedDict) | |
state_dict = checkpoint | |
print("\nCheckpoint ist direkt das state_dict.") | |
# Sammeln und Sortieren der Schlüssel | |
state_dict_keys = list(state_dict.keys()) | |
state_dict_keys.sort() # Sortiert die Schlüssel alphabetisch für bessere Übersicht | |
# Schreiben der Schlüssel in die Textdatei | |
with open(output_filename, 'w') as f: | |
for key in state_dict_keys: | |
f.write(key + '\n') # Jeder Schlüssel in einer neuen Zeile | |
print(f"\nAlle {len(state_dict_keys)} Schlüssel wurden erfolgreich in '{output_filename}' gespeichert.") | |
# Optional: Beispiel-Layer-Informationen weiterhin auf der Konsole ausgeben | |
if state_dict_keys: | |
example_key = state_dict_keys[0] # Nimmt den ersten sortierten Schlüssel | |
print(f"\nBeispiel-Layer: '{example_key}'") | |
print(f"Shape: {state_dict[example_key].shape}") | |
print(f"Datentyp: {state_dict[example_key].dtype}") | |
else: | |
print("\nKeine Schlüssel im state_dict gefunden.") | |
except FileNotFoundError: | |
print(f"FEHLER: Die Datei '{pth_checkpoint_path}' wurde nicht gefunden. Bitte den Pfad überprüfen.") | |
except Exception as e: | |
print(f"Ein unerwarteter Fehler ist aufgetreten: {e}") |