Tim77777767
pthchecker hinzugefügt um Checkpoint Layer Names auszulesen und SegFormer head angepasst, sodass die fehlenden Layers nicht erstellt werden
c76cec1
import os

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

from mix_vision_transformer_config import MySegformerConfig
from modeling_my_segformer import MySegformerForSemanticSegmentation

# Single-image inference script: loads the pretrained model, preprocesses one
# Cityscapes image, runs a forward pass, and writes the per-pixel class indices
# to a text file.

# Select the device: use CUDA when available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load the configuration and the model weights from the same identifier.
model_name_or_path = "TimM77/SegformerPlusPlus"
print("Starte config_load")
config = MySegformerConfig.from_pretrained(model_name_or_path)
print("Starte Model_load")
model = MySegformerForSemanticSegmentation.from_pretrained(model_name_or_path, config=config)
# Move the model to the selected device and switch it to evaluation mode.
model.to(device).eval()

# Load the input image and force three RGB channels.
image_path = "segformer_plusplus/cityscape/berlin_000543_000019_leftImg8bit.png"
image = Image.open(image_path).convert("RGB")

# --- Start of changes to match cityscape_benchmark ---
# Define the target image size as in cityscape_benchmark's default
# cityscape_benchmark uses (3, 1024, 1024), so spatial size is 1024x1024
target_image_height = 1024
target_image_width = 1024

# Calculate mean and std dynamically from the image as done in cityscape_benchmark
# Note: This is usually done over the entire training dataset for consistent normalization
# For a single image, this just normalizes to its own mean/std.
# The statistics are taken per channel from the image at its original
# resolution, i.e. before the resize below is applied.
img_tensor_temp = T.ToTensor()(image)
mean = img_tensor_temp.mean(dim=(1, 2)).tolist()
std = img_tensor_temp.std(dim=(1, 2)).tolist()
print(f"Calculated Mean (for this image): {mean}")
print(f"Calculated Std (for this image): {std}")

# Preprocessing - Adjusted to match cityscape_benchmark's T.Resize and T.Normalize
transform = T.Compose([
    T.Resize((target_image_height, target_image_width)),  # Resize to 1024x1024
    T.ToTensor(),
    T.Normalize(mean=mean, std=std),  # Use dynamically calculated mean/std
])
# Add a leading batch dimension and move the tensor to the model's device.
input_tensor = transform(image).unsqueeze(0).to(device)
# --- End of changes ---
print("Modell geladen, Bild geladen, Preprocessing abgeschlossen")

# Inference: gradients are not needed, so disable autograd tracking.
with torch.no_grad():
    output = model(input_tensor)
    # This ensures you're always getting the raw logits if the model returns an object
    logits = output.logits if hasattr(output, "logits") else output
    # Pick the highest-scoring index along dim 1, drop the batch dimension,
    # and convert the result to a NumPy array on the CPU.
    pred = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()

# Save the prediction as a text file of integers.
output_path = os.path.join("segformer_plusplus", "cityscapes_prediction_output_overHF.txt")
np.savetxt(output_path, pred, fmt="%d")
print(f"Prediction saved as {output_path}")