Tim77777767
commited on
Commit
·
c76cec1
1
Parent(s):
d21bb56
pthchecker hinzugefügt um Checkpoint Layer Names auszulesen und SegFormer head angepasst, sodass die fehlenden Layers nicht erstellt werden
Browse files- preTrainedTest.py +22 -5
- pthchecker.py +56 -0
preTrainedTest.py
CHANGED
@@ -23,24 +23,41 @@ model.to(device).eval()
|
|
23 |
image_path = "segformer_plusplus/cityscape/berlin_000543_000019_leftImg8bit.png"
|
24 |
image = Image.open(image_path).convert("RGB")
|
25 |
|
26 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
transform = T.Compose([
|
28 |
-
T.Resize((
|
29 |
T.ToTensor(),
|
30 |
-
T.Normalize(mean=
|
31 |
-
std=[0.229, 0.224, 0.225])
|
32 |
])
|
33 |
input_tensor = transform(image).unsqueeze(0).to(device)
|
|
|
34 |
|
35 |
print("Modell geladen, Bild geladen, Preprocessing abgeschlossen")
|
36 |
|
37 |
# Inferenz
|
38 |
with torch.no_grad():
|
39 |
output = model(input_tensor)
|
|
|
40 |
logits = output.logits if hasattr(output, "logits") else output
|
41 |
pred = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()
|
42 |
|
43 |
# Ergebnis als Textdatei speichern
|
44 |
output_path = os.path.join("segformer_plusplus", "cityscapes_prediction_output_overHF.txt")
|
45 |
np.savetxt(output_path, pred, fmt="%d")
|
46 |
-
print(f"Prediction saved as {output_path}")
|
|
|
23 |
image_path = "segformer_plusplus/cityscape/berlin_000543_000019_leftImg8bit.png"
|
24 |
image = Image.open(image_path).convert("RGB")
|
25 |
|
26 |
+
# --- Start of changes to match cityscape_benchmark ---
|
27 |
+
# Define the target image size as in cityscape_benchmark's default
|
28 |
+
# cityscape_benchmark uses (3, 1024, 1024), so spatial size is 1024x1024
|
29 |
+
target_image_height = 1024
|
30 |
+
target_image_width = 1024
|
31 |
+
|
32 |
+
# Calculate mean and std dynamically from the image as done in cityscape_benchmark
|
33 |
+
# Note: This is usually done over the entire training dataset for consistent normalization
|
34 |
+
# For a single image, this just normalizes to its own mean/std.
|
35 |
+
img_tensor_temp = T.ToTensor()(image)
|
36 |
+
mean = img_tensor_temp.mean(dim=(1, 2)).tolist()
|
37 |
+
std = img_tensor_temp.std(dim=(1, 2)).tolist()
|
38 |
+
|
39 |
+
print(f"Calculated Mean (for this image): {mean}")
|
40 |
+
print(f"Calculated Std (for this image): {std}")
|
41 |
+
|
42 |
+
# Preprocessing - Adjusted to match cityscape_benchmark's T.Resize and T.Normalize
|
43 |
transform = T.Compose([
|
44 |
+
T.Resize((target_image_height, target_image_width)), # Resize to 1024x1024
|
45 |
T.ToTensor(),
|
46 |
+
T.Normalize(mean=mean, std=std) # Use dynamically calculated mean/std
|
|
|
47 |
])
|
48 |
input_tensor = transform(image).unsqueeze(0).to(device)
|
49 |
+
# --- End of changes ---
|
50 |
|
51 |
print("Modell geladen, Bild geladen, Preprocessing abgeschlossen")
|
52 |
|
53 |
# Inferenz
|
54 |
with torch.no_grad():
|
55 |
output = model(input_tensor)
|
56 |
+
# This ensures you're always getting the raw logits if the model returns an object
|
57 |
logits = output.logits if hasattr(output, "logits") else output
|
58 |
pred = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()
|
59 |
|
60 |
# Ergebnis als Textdatei speichern
|
61 |
output_path = os.path.join("segformer_plusplus", "cityscapes_prediction_output_overHF.txt")
|
62 |
np.savetxt(output_path, pred, fmt="%d")
|
63 |
+
print(f"Prediction saved as {output_path}")
|
pthchecker.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
|
4 |
+
# --- Konfiguration ---
|
5 |
+
# Pfad zu Ihrer originalen .pth-Datei
|
6 |
+
pth_checkpoint_path = "./segformer-b5-bsm_hq.pth"
|
7 |
+
# Name der Ausgabedatei für die Schlüssel
|
8 |
+
output_filename = "original_pth_keys.txt"
|
9 |
+
|
10 |
+
# --- Laden und Speichern der Schlüssel ---
|
11 |
+
try:
|
12 |
+
# Laden des Checkpoints
|
13 |
+
# map_location='cpu' ist gut, um Probleme zu vermeiden, wenn kein GPU verfügbar ist
|
14 |
+
checkpoint = torch.load(pth_checkpoint_path, map_location='cpu')
|
15 |
+
|
16 |
+
print(f"Erfolgreich geladen: {pth_checkpoint_path}")
|
17 |
+
print(f"Typ des geladenen Objekts: {type(checkpoint)}")
|
18 |
+
|
19 |
+
# Extrahieren des state_dict
|
20 |
+
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
21 |
+
# Oft ist das state_dict in einem Schlüssel wie 'state_dict' oder 'model'
|
22 |
+
state_dict = checkpoint['state_dict']
|
23 |
+
print("\nCheckpoint ist ein Dictionary und enthält 'state_dict'.")
|
24 |
+
elif isinstance(checkpoint, dict):
|
25 |
+
# Manchmal ist das gesamte Dictionary das state_dict selbst
|
26 |
+
state_dict = checkpoint
|
27 |
+
print("\nCheckpoint ist ein Dictionary (wahrscheinlich das state_dict).")
|
28 |
+
else:
|
29 |
+
# Falls es direkt das state_dict ist (z.B. nur ein OrderedDict)
|
30 |
+
state_dict = checkpoint
|
31 |
+
print("\nCheckpoint ist direkt das state_dict.")
|
32 |
+
|
33 |
+
# Sammeln und Sortieren der Schlüssel
|
34 |
+
state_dict_keys = list(state_dict.keys())
|
35 |
+
state_dict_keys.sort() # Sortiert die Schlüssel alphabetisch für bessere Übersicht
|
36 |
+
|
37 |
+
# Schreiben der Schlüssel in die Textdatei
|
38 |
+
with open(output_filename, 'w') as f:
|
39 |
+
for key in state_dict_keys:
|
40 |
+
f.write(key + '\n') # Jeder Schlüssel in einer neuen Zeile
|
41 |
+
|
42 |
+
print(f"\nAlle {len(state_dict_keys)} Schlüssel wurden erfolgreich in '{output_filename}' gespeichert.")
|
43 |
+
|
44 |
+
# Optional: Beispiel-Layer-Informationen weiterhin auf der Konsole ausgeben
|
45 |
+
if state_dict_keys:
|
46 |
+
example_key = state_dict_keys[0] # Nimmt den ersten sortierten Schlüssel
|
47 |
+
print(f"\nBeispiel-Layer: '{example_key}'")
|
48 |
+
print(f"Shape: {state_dict[example_key].shape}")
|
49 |
+
print(f"Datentyp: {state_dict[example_key].dtype}")
|
50 |
+
else:
|
51 |
+
print("\nKeine Schlüssel im state_dict gefunden.")
|
52 |
+
|
53 |
+
except FileNotFoundError:
|
54 |
+
print(f"FEHLER: Die Datei '{pth_checkpoint_path}' wurde nicht gefunden. Bitte den Pfad überprüfen.")
|
55 |
+
except Exception as e:
|
56 |
+
print(f"Ein unerwarteter Fehler ist aufgetreten: {e}")
|