Tim77777767
committed on
Commit
·
b7716fe
1
Parent(s):
762b1e6
Anpassungen preTrained
Browse files
- preTrainedTest.py +0 -9
preTrainedTest.py
CHANGED
@@ -23,15 +23,9 @@ model.to(device).eval()
|
|
23 |
image_path = "segformer_plusplus/cityscape/berlin_000543_000019_leftImg8bit.png"
|
24 |
image = Image.open(image_path).convert("RGB")
|
25 |
|
26 |
-
# --- Start of changes to match cityscape_benchmark ---
|
27 |
-
# Define the target image size as in cityscape_benchmark's default
|
28 |
-
# cityscape_benchmark uses (3, 1024, 1024), so spatial size is 1024x1024
|
29 |
target_image_height = 1024
|
30 |
target_image_width = 1024
|
31 |
|
32 |
-
# Calculate mean and std dynamically from the image as done in cityscape_benchmark
|
33 |
-
# Note: This is usually done over the entire training dataset for consistent normalization
|
34 |
-
# For a single image, this just normalizes to its own mean/std.
|
35 |
img_tensor_temp = T.ToTensor()(image)
|
36 |
mean = img_tensor_temp.mean(dim=(1, 2)).tolist()
|
37 |
std = img_tensor_temp.std(dim=(1, 2)).tolist()
|
@@ -39,21 +33,18 @@ std = img_tensor_temp.std(dim=(1, 2)).tolist()
|
|
39 |
print(f"Calculated Mean (for this image): {mean}")
|
40 |
print(f"Calculated Std (for this image): {std}")
|
41 |
|
42 |
-
# Preprocessing - Adjusted to match cityscape_benchmark's T.Resize and T.Normalize
|
43 |
transform = T.Compose([
|
44 |
T.Resize((target_image_height, target_image_width)), # Resize to 1024x1024
|
45 |
T.ToTensor(),
|
46 |
T.Normalize(mean=mean, std=std) # Use dynamically calculated mean/std
|
47 |
])
|
48 |
input_tensor = transform(image).unsqueeze(0).to(device)
|
49 |
-
# --- End of changes ---
|
50 |
|
51 |
print("Modell geladen, Bild geladen, Preprocessing abgeschlossen")
|
52 |
|
53 |
# Inferenz
|
54 |
with torch.no_grad():
|
55 |
output = model(input_tensor)
|
56 |
-
# This ensures you're always getting the raw logits if the model returns an object
|
57 |
logits = output.logits if hasattr(output, "logits") else output
|
58 |
pred = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()
|
59 |
|
|
|
# Load the test image and force three channels (drops alpha if present).
# NOTE(review): relies on `Image` (PIL), `T` (torchvision.transforms),
# `model` and `device` being set up earlier in the file — confirm against
# the top of preTrainedTest.py.
image_path = "segformer_plusplus/cityscape/berlin_000543_000019_leftImg8bit.png"
image = Image.open(image_path).convert("RGB")

# Spatial size expected by the model input.
target_image_height = 1024
target_image_width = 1024

# Per-image channel statistics. NOTE(review): normalization constants are
# normally computed over the whole training set; here each channel of this
# single image is normalized to its own mean/std.
img_tensor_temp = T.ToTensor()(image)
mean = img_tensor_temp.mean(dim=(1, 2)).tolist()
std = img_tensor_temp.std(dim=(1, 2)).tolist()

print(f"Calculated Mean (for this image): {mean}")
print(f"Calculated Std (for this image): {std}")

# Preprocessing pipeline: resize -> tensor -> per-image normalization.
preprocessing_steps = [
    T.Resize((target_image_height, target_image_width)),
    T.ToTensor(),
    T.Normalize(mean=mean, std=std),
]
transform = T.Compose(preprocessing_steps)
input_tensor = transform(image).unsqueeze(0).to(device)

print("Modell geladen, Bild geladen, Preprocessing abgeschlossen")

# Inference — no autograd bookkeeping needed for a forward-only pass.
with torch.no_grad():
    output = model(input_tensor)

# Some model wrappers return an output object; unwrap to the raw logits.
if hasattr(output, "logits"):
    logits = output.logits
else:
    logits = output

# Per-pixel class prediction as a (H, W) numpy array.
pred = logits.argmax(dim=1).squeeze(0).cpu().numpy()