Tim77777767 committed
Commit c76cec1 · Parent: d21bb56

Added pthchecker to read out the checkpoint layer names, and adjusted the SegFormer head so that the missing layers are not created.

Files changed (2)
  1. preTrainedTest.py +22 -5
  2. pthchecker.py +56 -0
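
The SegFormer-head change described in the commit message is not visible in the two files below. As a rough illustration only, here is a minimal sketch (not the author's code) of the general technique: load the checkpoint, drop every entry whose name or shape does not exist in the current model, and load the rest with strict=False so missing head layers are simply left at their initial values. The model choice (transformers' SegformerForSemanticSegmentation built from "nvidia/mit-b5" with 19 Cityscapes classes) is an assumption and will not match the checkpoint's key naming without remapping.

import torch
from transformers import SegformerForSemanticSegmentation

# Assumed reference model (hypothetical): MiT-B5 encoder with a 19-class Cityscapes head
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b5", num_labels=19)

# Checkpoint path as used in pthchecker.py below
checkpoint = torch.load("./segformer-b5-bsm_hq.pth", map_location="cpu")
state_dict = checkpoint.get("state_dict", checkpoint)

# Keep only entries whose name and shape match the current model, so weights for
# layers that do not exist in the model are skipped instead of raising an error
model_state = model.state_dict()
filtered = {k: v for k, v in state_dict.items()
            if k in model_state and v.shape == model_state[k].shape}

missing, unexpected = model.load_state_dict(filtered, strict=False)
print(f"Loaded {len(filtered)} of {len(state_dict)} checkpoint tensors; "
      f"{len(missing)} model parameters keep their initial values.")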
preTrainedTest.py CHANGED
@@ -23,24 +23,41 @@ model.to(device).eval()
 image_path = "segformer_plusplus/cityscape/berlin_000543_000019_leftImg8bit.png"
 image = Image.open(image_path).convert("RGB")
 
-# Preprocessing
+# --- Start of changes to match cityscape_benchmark ---
+# Define the target image size as in cityscape_benchmark's default
+# cityscape_benchmark uses (3, 1024, 1024), so the spatial size is 1024x1024
+target_image_height = 1024
+target_image_width = 1024
+
+# Calculate mean and std dynamically from the image, as done in cityscape_benchmark
+# Note: this is usually done over the entire training dataset for consistent normalization;
+# for a single image, this just normalizes it to its own mean/std.
+img_tensor_temp = T.ToTensor()(image)
+mean = img_tensor_temp.mean(dim=(1, 2)).tolist()
+std = img_tensor_temp.std(dim=(1, 2)).tolist()
+
+print(f"Calculated Mean (for this image): {mean}")
+print(f"Calculated Std (for this image): {std}")
+
+# Preprocessing - adjusted to match cityscape_benchmark's T.Resize and T.Normalize
 transform = T.Compose([
-    T.Resize((512, 512)),
+    T.Resize((target_image_height, target_image_width)),  # resize to 1024x1024
     T.ToTensor(),
-    T.Normalize(mean=[0.485, 0.456, 0.406],
-                std=[0.229, 0.224, 0.225])
+    T.Normalize(mean=mean, std=std)  # use the dynamically calculated mean/std
 ])
 input_tensor = transform(image).unsqueeze(0).to(device)
+# --- End of changes ---
 
 print("Model loaded, image loaded, preprocessing complete")
 
 # Inference
 with torch.no_grad():
     output = model(input_tensor)
+# This ensures we always get the raw logits, even if the model returns an output object
 logits = output.logits if hasattr(output, "logits") else output
 pred = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()
 
 # Save the result as a text file
 output_path = os.path.join("segformer_plusplus", "cityscapes_prediction_output_overHF.txt")
 np.savetxt(output_path, pred, fmt="%d")
-print(f"Prediction saved as {output_path}")
+print(f"Prediction saved as {output_path}")
pthchecker.py ADDED
@@ -0,0 +1,56 @@
+import torch
+import os
+
+# --- Configuration ---
+# Path to your original .pth file
+pth_checkpoint_path = "./segformer-b5-bsm_hq.pth"
+# Name of the output file for the keys
+output_filename = "original_pth_keys.txt"
+
+# --- Load the checkpoint and save its keys ---
+try:
+    # Load the checkpoint
+    # map_location='cpu' avoids problems when no GPU is available
+    checkpoint = torch.load(pth_checkpoint_path, map_location='cpu')
+
+    print(f"Successfully loaded: {pth_checkpoint_path}")
+    print(f"Type of the loaded object: {type(checkpoint)}")
+
+    # Extract the state_dict
+    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
+        # The state_dict is often nested under a key such as 'state_dict' or 'model'
+        state_dict = checkpoint['state_dict']
+        print("\nCheckpoint is a dictionary and contains 'state_dict'.")
+    elif isinstance(checkpoint, dict):
+        # Sometimes the whole dictionary is the state_dict itself
+        state_dict = checkpoint
+        print("\nCheckpoint is a dictionary (probably the state_dict itself).")
+    else:
+        # The checkpoint may also be the state_dict directly (e.g. just an OrderedDict)
+        state_dict = checkpoint
+        print("\nCheckpoint is the state_dict directly.")
+
+    # Collect and sort the keys
+    state_dict_keys = list(state_dict.keys())
+    state_dict_keys.sort()  # sort the keys alphabetically for a clearer overview
+
+    # Write the keys to the text file
+    with open(output_filename, 'w') as f:
+        for key in state_dict_keys:
+            f.write(key + '\n')  # one key per line
+
+    print(f"\nAll {len(state_dict_keys)} keys were successfully written to '{output_filename}'.")
+
+    # Optional: also print information about an example layer to the console
+    if state_dict_keys:
+        example_key = state_dict_keys[0]  # take the first key in sorted order
+        print(f"\nExample layer: '{example_key}'")
+        print(f"Shape: {state_dict[example_key].shape}")
+        print(f"Dtype: {state_dict[example_key].dtype}")
+    else:
+        print("\nNo keys found in the state_dict.")
+
+except FileNotFoundError:
+    print(f"ERROR: The file '{pth_checkpoint_path}' was not found. Please check the path.")
+except Exception as e:
+    print(f"An unexpected error occurred: {e}")