|
import os |
|
import math |
|
from skimage.metrics import peak_signal_noise_ratio as psnr |
|
from skimage.metrics import mean_squared_error as mse |
|
import lpips |
|
from PIL import Image |
|
import numpy as np |
|
import torchvision.transforms as transforms |
|
from eval.distance_transform_v0 import ChamferDistance2dMetric |
|
import torch |
|
from pytorch_msssim import ssim_matlab as ssim |
|
import glob |
|
import re |
|
|
|
|
|
# Select the compute device once at start-up and report the choice.
# NOTE(review): `device` is not referenced by the evaluation code in this
# script; the image tensors are never moved to it and stay on the CPU.
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_available else torch.device("cpu")
print(f"Using device: {device}")
|
|
|
def calc_psnr(pred, gt, data_range=1.0):
    """Return the peak signal-to-noise ratio (in dB) between ``pred`` and ``gt``.

    Parameters
    ----------
    pred, gt:
        Same-shaped arrays or tensors that support ``-``, ``*`` and
        ``.mean()`` (for example NumPy arrays or torch tensors). They are
        assumed to be floating point. Unsigned integer inputs such as uint8
        would wrap around on subtraction, so convert them first.
    data_range:
        Span of possible pixel values. The default ``1.0`` corresponds to
        images scaled to [0, 1] and reproduces ``-10 * log10(MSE)`` exactly.

    Returns
    -------
    float
        PSNR in decibels, or ``math.inf`` when the inputs are identical.
    """
    diff = pred - gt
    mse_val = float((diff * diff).mean())
    if mse_val == 0:
        # log10(0) raises ValueError; identical images have unbounded PSNR.
        return math.inf
    # 10*log10(R**2 / MSE) == 20*log10(R) - 10*log10(MSE). With R == 1 the
    # first term is exactly 0.0, so the default path matches the old formula.
    return 20 * math.log10(data_range) - 10 * math.log10(mse_val)
|
|
|
|
|
# NOTE(review): loss_fn_alex is constructed but no LPIPS score is computed or
# reported anywhere in this script. Building it presumably still loads the
# pretrained LPIPS/AlexNet weights. Confirm whether LPIPS was meant to be reported.
loss_fn_alex = lpips.LPIPS(net="alex")

# Convert a PIL image into a tensor of the same dtype, laid out as (C, H, W).
transform = transforms.Compose([transforms.PILToTensor()])

# Chamfer-distance metric object; its .calc() is used per image pair.
cd = ChamferDistance2dMetric()

# Print tensors with 4 decimal places.
torch.set_printoptions(precision=4)
|
|
|
|
|
# Per-metric result lists, appended to by the evaluation loop.
# Each name is bound to its own, distinct empty list.
psnrlisto, ssimlisto, ielisto, cdlisto = [], [], [], []

# Folder expected to contain the ioN.png (interpolated) / gtN.png
# (ground-truth) image pairs.
eval_dir = "eval_data"
|
|
|
|
|
# Score every ioN.png in eval_dir against its gtN.png and record the results
# in psnrlisto / ielisto / cdlisto / ssimlisto.

# (width, height) that both images of a pair are resized to before scoring.
EVAL_IMAGE_SIZE = (384, 192)

# Interpolated frames must be named exactly io<digits>.png. The digits (kept
# as text, so leading zeros survive) select the ground truth gt<digits>.png.
_IO_NAME_RE = re.compile(r"io(\d+)\.png")


def _load_eval_image(path):
    """Load *path* as RGB, resized to EVAL_IMAGE_SIZE.

    Returns ``(tensor, array)``:
    - tensor: the float image scaled to [0, 1], with a leading batch axis,
      produced by ``transform``;
    - array: the resized 8-bit RGB image as a NumPy array (0-255 range).
    """
    # `with` closes the underlying file handle deterministically.
    with Image.open(path) as img:
        rgb = img.convert("RGB").resize(EVAL_IMAGE_SIZE)
    tensor = transform(rgb).unsqueeze(0).float() / 255.
    return tensor, np.asarray(rgb)


def _score(label, error_label, fn, *args):
    """Return ``float(fn(*args))``, or None if it raises.

    On success this prints ``"<label>: <value>"``. On failure it prints
    ``"Error calculating <error_label>: <exc>"`` and returns None, so one
    failing metric does not stop the others.
    """
    try:
        # float() normalizes 0-dim tensors and NumPy scalars to plain floats.
        value = float(fn(*args))
    except Exception as e:
        print(f"Error calculating {error_label}: {e}")
        return None
    print(f"{label}: {value:.4f}")
    return value


if not os.path.isdir(eval_dir):
    print(f"Error: Directory '{eval_dir}' not found")
    # Stop here. Falling through would let the reporting code that follows
    # overwrite the results file with placeholder zeros.
    raise SystemExit(1)

io_files = glob.glob(os.path.join(eval_dir, "io*.png"))
print(f"Found {len(io_files)} interpolated images in {eval_dir}")

# Resolve each file's index before sorting. The glob also matches names such
# as "io_extra.png"; those are skipped instead of crashing the sort key.
indexed_io_files = []
for io_path in io_files:
    index_match = _IO_NAME_RE.fullmatch(os.path.basename(io_path))
    if index_match is None:
        print(f"Skipping {io_path} - unable to extract index")
        continue
    indexed_io_files.append((int(index_match.group(1)), index_match.group(1), io_path))
indexed_io_files.sort(key=lambda item: item[0])

for _, index, io_path in indexed_io_files:
    gt_path = os.path.join(eval_dir, f"gt{index}.png")
    if not os.path.exists(gt_path):
        print(f"Warning: Ground truth image '{gt_path}' not found for '{io_path}'")
        continue

    print(f"Processing pair {index}: {io_path} and {gt_path}")

    try:
        pred1o, predo = _load_eval_image(io_path)
        gt1, gt = _load_eval_image(gt_path)

        # PSNR and IE (root-mean-square error) work on the uint8 arrays;
        # Chamfer distance and SSIM work on the [0, 1] float tensors.
        psnr_val = _score("PSNR", "PSNR", psnr, predo, gt)
        ie_val = _score("IE", "IE", lambda a, b: math.sqrt(mse(a, b)), predo, gt)
        cd_val = _score("CD", "Chamfer Distance", cd.calc, pred1o, gt1)
        ssim_val = _score("SSIM", "SSIM", ssim, pred1o, gt1)

        if any(v is None for v in (psnr_val, ie_val, cd_val, ssim_val)):
            # Record only fully scored pairs. This keeps the four lists
            # index-aligned, so per-pair output and every average refer to
            # the same set of image pairs.
            print(f"Warning: pair {index} not recorded because a metric failed")
            continue

        psnrlisto.append(psnr_val)
        ielisto.append(ie_val)
        cdlisto.append(cd_val)
        ssimlisto.append(ssim_val)

    except Exception as e:
        print(f"Error processing images: {e}")
        import traceback
        traceback.print_exc()
|
|
|
|
|
def _avg_or_zero(values):
    """Return np.average(values), or 0 when no values were collected."""
    if not values:
        return 0
    return np.average(values)


# Mean of each metric over the recorded pairs (0 if nothing was recorded).
psnr_avgo = _avg_or_zero(psnrlisto)
ie_avgo = _avg_or_zero(ielisto)
cd_avgo = _avg_or_zero(cdlisto)
ssim_avgo = _avg_or_zero(ssimlisto)
|
|
|
|
|
# Per-pair breakdown. zip() walks the four metric lists in lockstep and stops
# at the shortest one. Rows are only meaningful when every list holds one
# entry per pair, in the same order.
print("\nMetrics for each processed pair:")
for pair_no, row in enumerate(zip(psnrlisto, ssimlisto, cdlisto, ielisto), start=1):
    psnr_v, ssim_v, cd_v, ie_v = row
    print(f"Pair {pair_no}:")
    print(f" PSNR: {psnr_v:.4f}")
    print(f" SSIM: {ssim_v:.4f}")
    print(f" CD: {cd_v:.4f}")
    print(f" IE: {ie_v:.4f}")
|
|
|
|
|
# Console summary of the averaged metrics. The pair count is taken from the
# length of the PSNR list.
print("\nFinal Results (Averages):")
print(f"Number of image pairs processed: {len(psnrlisto)}")
for template, avg_value in (
    ("cdval: {:.4f}", cd_avgo),
    ("ssim: {:.4f}", ssim_avgo),
    ("psnr: {:.4f}", psnr_avgo),
    ("ie: {:.4f}", ie_avgo),
):
    print(template.format(avg_value))
|
|
|
|
|
# Persist the same summary to a text file in the working directory,
# overwriting any previous run's results.
results_file = "eval_results.txt"

with open(results_file, "w") as out_file:
    out_file.write("Evaluation Results\n")
    out_file.write(f"Number of image pairs processed: {len(psnrlisto)}\n")
    for line_fmt, avg_value in (
        ("cdval: {:.4f}\n", cd_avgo),
        ("ssim: {:.4f}\n", ssim_avgo),
        ("psnr: {:.4f}\n", psnr_avgo),
        ("ie: {:.4f}\n", ie_avgo),
    ):
        out_file.write(line_fmt.format(avg_value))

print(f"\nResults saved to {results_file}")