import argparse
import os

import numpy as np
import torch

from .build_model import create_model
from .cityscape_benchmark import cityscape_benchmark
|
parser = argparse.ArgumentParser(description="Segformer Benchmarking Script")
parser.add_argument('--backbone', type=str, default='b5',
                    choices=['b0', 'b1', 'b2', 'b3', 'b4', 'b5'],
                    help='Model backbone version')
parser.add_argument('--head', type=str, default='bsm_hq',
                    choices=['bsm_hq', 'bsm_fast', 'n2d_2x2'],
                    help='Model head type')
parser.add_argument('--checkpoint', type=str, default=None,
                    help='Path to .pth checkpoint file (optional)')
args = parser.parse_args()
|
# Build the model with the requested backbone/head combination.
model = create_model(args.backbone, args.head, pretrained=True)
|
if args.checkpoint:
    checkpoint_path = os.path.expanduser(args.checkpoint)
    print(f"Loading checkpoint: {checkpoint_path}")
    # Load onto CPU so the script also runs on machines without a GPU.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # Some checkpoints nest the weights under a 'state_dict' key; fall back to
    # the loaded object itself otherwise.
    state_dict = checkpoint.get('state_dict', checkpoint)
    model.load_state_dict(state_dict)
else:
    print("No checkpoint provided – using model as initialized.")
|
cwd = os.getcwd()

# Run the benchmark on a single Cityscapes image.
image_path = os.path.join(cwd, 'segformer_plusplus', 'cityscape', 'berlin_000543_000019_leftImg8bit.png')
result = cityscape_benchmark(model, image_path)

print("Cityscapes Benchmark Results:", result)
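
# Compare the generated output with a stored reference to detect regressions.
# This assumes that cityscape_benchmark writes its per-pixel class predictions
# to 'cityscapes_prediction_output.txt' inside the 'segformer_plusplus' directory.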
|
|
|
reference_txt_path = os.path.join(cwd, 'segformer_plusplus', 'cityscapes_prediction_output_reference_b05_bsm_hq_nocheckpoint.txt')
generated_txt_path = os.path.join(cwd, 'segformer_plusplus', 'cityscapes_prediction_output.txt')

if os.path.exists(reference_txt_path) and os.path.exists(generated_txt_path):
    ref_arr = np.loadtxt(reference_txt_path, dtype=int)
    gen_arr = np.loadtxt(generated_txt_path, dtype=int)

    if ref_arr.shape != gen_arr.shape:
        print(f"Files have different shapes: {ref_arr.shape} vs. {gen_arr.shape}")
    else:
        # Fraction of positions where both predictions assign the same class.
        total_elements = ref_arr.size
        equal_elements = np.sum(ref_arr == gen_arr)
        similarity = equal_elements / total_elements
|
        threshold = 0.999
        if similarity >= threshold:
            print(f"Outputs are {similarity * 100:.4f}% identical (>= {threshold * 100:.1f}%).")
        else:
            print(f"Outputs differ in {100 - similarity * 100:.4f}% of elements (below the {threshold * 100:.1f}% threshold).")
else:
    if not os.path.exists(reference_txt_path):
        print(f"Reference file not found: {reference_txt_path}")
    if not os.path.exists(generated_txt_path):
        print(f"Generated output file not found: {generated_txt_path}")
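
# Example invocation (module path is illustrative; adjust to where this file
# lives in the package). Because of the relative imports above, run it as a
# module from the repository root rather than as a plain script, e.g.:
#   python -m segformer_plusplus.benchmark --backbone b5 --head bsm_hq --checkpoint path/to/weights.pth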
|
|