File size: 5,014 Bytes
57db94b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import os
import math
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import mean_squared_error as mse
import lpips
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from eval.distance_transform_v0 import ChamferDistance2dMetric
import torch
from pytorch_msssim import ssim_matlab as ssim
import glob
import re

# Select the compute device: CUDA when available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

def calc_psnr(pred, gt):
    """Return PSNR in dB between two arrays with values normalized to [0, 1].

    NOTE(review): currently unused in this script (skimage's ``psnr`` is used
    in the loop below); presumably kept for parity with sibling eval scripts
    — confirm before removing.

    Args:
        pred: predicted image array (values assumed in [0, 1] — TODO confirm).
        gt: ground-truth array broadcastable against ``pred``.

    Returns:
        ``-10 * log10(MSE)``, i.e. PSNR with a peak signal value of 1.0.
        Returns ``float("inf")`` for identical inputs (MSE == 0), where the
        original raised ``ValueError: math domain error`` from ``log10(0)``.
    """
    mse_val = ((pred - gt) ** 2).mean()
    if mse_val == 0:
        return float("inf")
    return -10 * math.log10(mse_val)

# Perceptual-similarity model (AlexNet backbone).  NOTE(review): initialized
# here but never invoked below — confirm whether an LPIPS score was meant to
# be reported alongside the other metrics.
loss_fn_alex = lpips.LPIPS(net='alex')

# PIL image -> uint8 CxHxW tensor, applied when loading each image pair.
transform = transforms.Compose([transforms.PILToTensor()])

# Chamfer-distance metric helper.
cd = ChamferDistance2dMetric()

torch.set_printoptions(precision=4)

# Per-pair metric accumulators; one value appended per successfully
# processed image pair.
psnrlisto, ssimlisto, ielisto, cdlisto = [], [], [], []

# Directory expected to contain io<N>.png / gt<N>.png evaluation pairs.
eval_dir = "eval_data"

# Check if directory exists before scanning for evaluation pairs.
if not os.path.exists(eval_dir):
    print(f"Error: Directory '{eval_dir}' not found")
else:
    # Find all io (interpolated) files
    io_files = glob.glob(os.path.join(eval_dir, "io*.png"))
    print(f"Found {len(io_files)} interpolated images in {eval_dir}")

    # Extract each file's numeric index up front.  glob's "io*.png" matches
    # names the regex does not (e.g. "io_tmp.png"); the original sort key
    # indexed re.findall(...)[0] directly and raised IndexError on such
    # names before the per-file skip branch could ever run.  Filtering
    # first makes the skip path actually reachable.
    indexed_files = []
    for io_path in io_files:
        index_match = re.search(r'io(\d+)\.png', io_path)
        if not index_match:
            print(f"Skipping {io_path} - unable to extract index")
            continue
        # Keep the raw digit string (preserves leading zeros for the gt
        # filename lookup) alongside an int form for numeric sorting.
        indexed_files.append((int(index_match.group(1)), index_match.group(1), io_path))
    indexed_files.sort(key=lambda item: item[0])

    # Process each pair of images in ascending index order.
    for _, index, io_path in indexed_files:
        gt_path = os.path.join(eval_dir, f"gt{index}.png")

        # Check if the corresponding ground truth exists
        if not os.path.exists(gt_path):
            print(f"Warning: Ground truth image '{gt_path}' not found for '{io_path}'")
            continue

        print(f"Processing pair {index}: {io_path} and {gt_path}")

        try:
            # Load predicted image; force RGB and a fixed 384x192 size so
            # every pair is compared at the same resolution.
            pio = Image.open(io_path).convert("RGB").resize((384, 192))
            pred1o = transform(pio).unsqueeze(0).float()/255.  # 1xCxHxW float in [0, 1]
            predo = np.asarray(pio)  # HxWxC uint8 for the skimage metrics

            # Load ground truth image the same way.
            gi = Image.open(gt_path).convert("RGB").resize((384, 192))
            gt1 = transform(gi).unsqueeze(0).float()/255.
            gt = np.asarray(gi)

            # Each metric is wrapped separately so one failure does not
            # discard the other metrics for this pair.
            try:
                psnr_val = psnr(predo, gt)
                psnrlisto.append(psnr_val)
                print(f"PSNR: {psnr_val:.4f}")
            except Exception as e:
                print(f"Error calculating PSNR: {e}")

            try:
                # IE (interpolation error) = RMSE over uint8 pixel values.
                ie_val = math.sqrt(mse(predo, gt))
                ielisto.append(ie_val)
                print(f"IE: {ie_val:.4f}")
            except Exception as e:
                print(f"Error calculating IE: {e}")

            try:
                cd_val = cd.calc(pred1o, gt1)
                cdlisto.append(cd_val)
                print(f"CD: {cd_val:.4f}")
            except Exception as e:
                print(f"Error calculating Chamfer Distance: {e}")

            try:
                ssim_val = ssim(pred1o, gt1)
                ssimlisto.append(ssim_val)
                print(f"SSIM: {ssim_val:.4f}")
            except Exception as e:
                print(f"Error calculating SSIM: {e}")

        except Exception as e:
            print(f"Error processing images: {e}")
            import traceback
            traceback.print_exc()

def _avg(values):
    """Mean of a metric list, or 0 when the list is empty (avoids NaN)."""
    return np.average(values) if values else 0

# Average each metric over all successfully processed pairs.
psnr_avgo = _avg(psnrlisto)
ie_avgo = _avg(ielisto)
cd_avgo = _avg(cdlisto)
ssim_avgo = _avg(ssimlisto)

# Report the per-pair metrics (zip stops at the shortest list if any
# metric failed for some pair).
print("\nMetrics for each processed pair:")
for pair_no, metrics in enumerate(zip(psnrlisto, ssimlisto, cdlisto, ielisto), start=1):
    print(f"Pair {pair_no}:")
    for label, value in zip(("PSNR", "SSIM", "CD", "IE"), metrics):
        print(f"  {label}: {value:.4f}")

# Summary lines shared between stdout and the results file.
summary_lines = [
    f"Number of image pairs processed: {len(psnrlisto)}",
    f"cdval: {cd_avgo:.4f}",
    f"ssim: {ssim_avgo:.4f}",
    f"psnr: {psnr_avgo:.4f}",
    f"ie: {ie_avgo:.4f}",
]

# Print average results
print("\nFinal Results (Averages):")
for line in summary_lines:
    print(line)

# Persist the same summary to disk.
results_file = "eval_results.txt"
with open(results_file, "w") as f:
    f.write("Evaluation Results\n")
    f.write("\n".join(summary_lines) + "\n")

print(f"\nResults saved to {results_file}")