import os
import torch
import numpy as np
import argparse

from .build_model import create_model
from .cityscape_benchmark import cityscape_benchmark

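# CLI options: backbone variant (b0-b5), decode head type, and an optional checkpoint path.
# Note: because of the relative imports above, this file must be run as a module,
# e.g. `python -m <package>.<this_module> --backbone b5 --head bsm_hq`
# (the exact package and module names depend on the repository layout).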
parser = argparse.ArgumentParser(description="Segformer Benchmarking Script")
parser.add_argument('--backbone', type=str, default='b5', choices=['b0', 'b1', 'b2', 'b3', 'b4', 'b5'], help='Model backbone version')
parser.add_argument('--head', type=str, default='bsm_hq', choices=['bsm_hq', 'bsm_fast', 'n2d_2x2'], help='Model head type')
parser.add_argument('--checkpoint', type=str, default=None, help='Path to .pth checkpoint file (optional)')
args = parser.parse_args()

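# Build the Segformer model with pretrained weights for the chosen backbone/head combination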
model = create_model(args.backbone, args.head, pretrained=True)

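# If a checkpoint is given, overwrite the pretrained weights with its state_dict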
if args.checkpoint:
    checkpoint_path = os.path.expanduser(args.checkpoint)
    print(f"Loading checkpoint: {checkpoint_path}")
    # map_location="cpu" keeps loading working on machines without a GPU
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint)

    # save the state_dict after loading the checkpoint
    # save_path = os.path.join(os.getcwd(), "pytorch_model.bin")
    # torch.save(model.state_dict(), save_path)
else:
    print("No checkpoint provided – using model as initialized.")

cwd = os.getcwd()

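# Run the benchmark on a single Cityscapes sample image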
image_path = os.path.join(cwd, 'segformer_plusplus', 'cityscape', 'berlin_000543_000019_leftImg8bit.png')
result = cityscape_benchmark(model, image_path)

print("Cityscapes Benchmark Results:", result)

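# Regression check: compare the generated prediction file against the stored reference output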
reference_txt_path = os.path.join(cwd, 'segformer_plusplus', 'cityscapes_prediction_output_reference_b05_bsm_hq_nocheckpoint.txt')
generated_txt_path = os.path.join(cwd, 'segformer_plusplus', 'cityscapes_prediction_output.txt')

if os.path.exists(reference_txt_path) and os.path.exists(generated_txt_path):
    ref_arr = np.loadtxt(reference_txt_path, dtype=int)
    gen_arr = np.loadtxt(generated_txt_path, dtype=int)

    if ref_arr.shape != gen_arr.shape:
        print(f"Files have different shapes: {ref_arr.shape} vs. {gen_arr.shape}")
    else:
        total_elements = ref_arr.size
        equal_elements = np.sum(ref_arr == gen_arr)
        similarity = equal_elements / total_elements

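        # Accept the run if at least 99.9% of the elements match the reference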
        threshold = 0.999
        if similarity >= threshold:
            print(f"Outputs are {similarity*100:.4f}% identical (>= {threshold*100}%).")
        else:
            print(f"Outputs differ in {100 - similarity*100:.4f}% of elements (below the {threshold*100}% threshold).")
else:
    if not os.path.exists(reference_txt_path):
        print(f"Reference file not found: {reference_txt_path}")
    if not os.path.exists(generated_txt_path):
        print(f"Generated output file not found: {generated_txt_path}")