SegformerPlusPlus / preTrainedTest.py
Tim77777767
pthchecker hinzugefügt, um Checkpoint Layer Names auszulesen, und SegFormer head angepasst, sodass die fehlenden Layers nicht erstellt werden
c76cec1
raw
history blame
2.45 kB
import torch
from PIL import Image
import torchvision.transforms as T
import numpy as np
import os
from modeling_my_segformer import MySegformerForSemanticSegmentation
from mix_vision_transformer_config import MySegformerConfig
# Select the compute device: prefer CUDA when it is available, otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Load the model. The same identifier (presumably a Hugging Face Hub repo id,
# or a local path) is used for both the config and the weights; the config is
# loaded explicitly and passed to the model loader.
model_name_or_path = "TimM77/SegformerPlusPlus"
print("Starte config_load")
config = MySegformerConfig.from_pretrained(model_name_or_path)
print("Starte Model_load")
model = MySegformerForSemanticSegmentation.from_pretrained(
    model_name_or_path,
    config=config,
)
# Move the weights to the selected device and switch to evaluation mode
# (affects layers such as Dropout / BatchNorm). Both calls return the module itself.
model = model.to(device).eval()
# Load the input image. The context manager releases the underlying file
# handle; convert() returns an independent in-memory RGB copy.
image_path = "segformer_plusplus/cityscape/berlin_000543_000019_leftImg8bit.png"
with Image.open(image_path) as _source_image:
    image = _source_image.convert("RGB")

# --- Preprocessing chosen to match cityscape_benchmark ---
# cityscape_benchmark uses an input of shape (3, 1024, 1024), so the image is
# resized to 1024x1024. NOTE: this resize does not preserve the source aspect ratio.
target_image_height = 1024
target_image_width = 1024

# Per-channel mean/std computed from this single image, mirroring
# cityscape_benchmark. Normally these statistics come from the whole training
# set; here the image is effectively normalized to its own statistics.
# NOTE(review): the stats are measured on the image at its original
# resolution, while normalization is applied after resizing, so the resized
# tensor is only approximately zero-mean / unit-std.
img_tensor_temp = T.ToTensor()(image)
mean = img_tensor_temp.mean(dim=(1, 2)).tolist()
# A constant-colour channel has std == 0, which T.Normalize rejects with a
# ValueError (division by zero). Clamp to a tiny floor; for any normal image
# the std is far above this value and the result is unchanged.
min_std = 1e-6
std = img_tensor_temp.std(dim=(1, 2)).clamp(min=min_std).tolist()
print(f"Calculated Mean (for this image): {mean}")
print(f"Calculated Std (for this image): {std}")

transform = T.Compose([
    T.Resize((target_image_height, target_image_width)),  # resize to 1024x1024
    T.ToTensor(),
    T.Normalize(mean=mean, std=std),  # per-image statistics computed above
])
# Add the batch dimension -> (1, 3, 1024, 1024) and move to the inference device.
input_tensor = transform(image).unsqueeze(0).to(device)
# --- End of preprocessing ---
print("Modell geladen, Bild geladen, Preprocessing abgeschlossen")
# Inference: forward pass only, so no autograd graph is needed.
with torch.no_grad():
    output = model(input_tensor)

# Reduce the model output to a raw logits tensor. Supported forms:
#   - an object exposing `.logits` (presumably a transformers ModelOutput),
#   - a tuple/list whose first element is the logits (return_dict=False style),
#   - the logits tensor itself.
if hasattr(output, "logits"):
    logits = output.logits
elif isinstance(output, (tuple, list)):
    logits = output[0]
else:
    logits = output

# Per-pixel class index: argmax over dim 1, then drop the batch dimension.
# Assumes logits are (1, num_classes, H', W') -- TODO confirm. H'/W' may be
# smaller than the 1024x1024 input if the decode head does not upsample.
pred = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()

# Save the prediction as a whitespace-separated text matrix of integer class ids.
output_path = os.path.join("segformer_plusplus", "cityscapes_prediction_output_overHF.txt")
# Make sure the target directory exists; np.savetxt does not create it.
os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
np.savetxt(output_path, pred, fmt="%d")
print(f"Prediction saved as {output_path}")