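# Hugging Face file-viewer residue, kept as a comment so the module parses:
# uploader "pineappleSoup"; commit message "Upload folder using huggingface_hub";
# revision 57db94b (verified).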
import torch
import torch.nn.functional as F
import cv2
import numpy as np
from torchvision import transforms
import os
import matplotlib.pyplot as plt
from PIL import Image
import time
from skimage.metrics import structural_similarity as ssim
from skimage.color import rgb2lab
from combined import IFNet, warp
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Build the interpolation network and restore its trained weights.
model = IFNet().to(device)
checkpoint_path = "save_checkpoints/model_epoch_50.pth"
# SECURITY NOTE: weights_only=False makes torch.load unpickle arbitrary Python
# objects, which can execute code embedded in the file. Only load checkpoints
# that come from a trusted source.
checkpoint = torch.load(checkpoint_path, weights_only=False, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # evaluation mode: this script only runs inference
# 'epoch' is read with .get() exactly like 'psnr': both are bookkeeping
# entries, and a checkpoint that lacks one should not abort the run after the
# weights have already been loaded successfully.
print(
    f"Loaded model from epoch {checkpoint.get('epoch', 'N/A')} "
    f"with PSNR: {checkpoint.get('psnr', 'N/A')} dB"
)

# Convert an image to a tensor, then normalize every channel with mean 0.5 and
# std 0.5; tensor_to_image() undoes this with `* 0.5 + 0.5`.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
def preprocess_images(img0_path, img1_path, gt_path=None):
    """Load a frame pair (plus optional ground truth) and build the model input.

    Args:
        img0_path: Path of the first input frame.
        img1_path: Path of the second input frame.
        gt_path: Optional path of the ground-truth middle frame.

    Returns:
        A tuple ``(input_tensor, original_size, orig_img0, orig_img1, gt)``:
        ``input_tensor`` is both frames resized to 256x256, passed through the
        module-level ``transform``, concatenated along dimension 0, given a
        leading batch dimension and moved to ``device``; ``original_size`` is
        ``(height, width)`` of the first frame; ``orig_img0`` / ``orig_img1``
        are the full-resolution RGB frames; ``gt`` is the RGB ground truth, or
        ``None`` when no path was given, the file does not exist, or it could
        not be decoded.

    Raises:
        ValueError: If either input frame cannot be read.
    """
    # cv2.imread signals failure by returning None instead of raising.
    frames_bgr = [cv2.imread(img0_path), cv2.imread(img1_path)]
    if any(frame is None for frame in frames_bgr):
        raise ValueError(f"Could not read images: {img0_path}, {img1_path}")

    gt = None
    if gt_path and os.path.exists(gt_path):
        gt_bgr = cv2.imread(gt_path)
        if gt_bgr is not None:
            gt = cv2.cvtColor(gt_bgr, cv2.COLOR_BGR2RGB)
        else:
            print(f"Warning: Could not read ground truth image: {gt_path}")

    # OpenCV decodes to BGR; everything downstream works in RGB.
    rgb0, rgb1 = (cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in frames_bgr)
    original_size = rgb0.shape[:2]
    orig_img0 = rgb0.copy()
    orig_img1 = rgb1.copy()

    # The network input is fixed at 256x256 regardless of the source size.
    network_inputs = [transform(cv2.resize(rgb, (256, 256))) for rgb in (rgb0, rgb1)]
    input_tensor = torch.cat(network_inputs, 0).unsqueeze(0).to(device)

    return input_tensor, original_size, orig_img0, orig_img1, gt
def tensor_to_image(tensor):
    """Convert a normalized channel-first tensor into a channel-last uint8 image.

    The tensor is de-normalized with ``x * 0.5 + 0.5`` (the inverse of a
    mean-0.5 / std-0.5 normalization), clamped to ``[0, 1]``, scaled to
    ``[0, 255]`` and rounded to the nearest integer.

    Args:
        tensor: A 3-D tensor laid out as (channels, height, width).

    Returns:
        A ``numpy.ndarray`` of dtype ``uint8`` laid out as
        (height, width, channels).
    """
    # detach(): Tensor.numpy() raises on tensors that still require grad.
    # float(): gives NumPy a dtype it can always represent (e.g. for bfloat16).
    values = tensor.detach().cpu().float()
    values = (values * 0.5 + 0.5).clamp(0, 1)
    pixels = values.numpy().transpose(1, 2, 0) * 255.0
    # Round to nearest instead of truncating: a bare astype(uint8) floors the
    # value and biases every pixel downward by up to one intensity level.
    return np.rint(pixels).astype(np.uint8)
def calculate_psnr(img1, img2):
    """Return the peak signal-to-noise ratio in dB between two images.

    Both images are converted to float32 before subtraction, and the peak
    value is taken to be 255. Identical images yield ``float('inf')``.
    """
    residual = img1.astype(np.float32) - img2.astype(np.float32)
    mse = np.mean(np.square(residual))
    if mse != 0:
        peak_squared = 255.0 ** 2
        return 10 * np.log10(peak_squared / mse)
    # Zero error would divide by zero; report an unbounded PSNR instead.
    return float('inf')
def calculate_ssim(img1, img2):
    """Return the structural similarity between two images.

    When ``img1`` is a 3-channel image, both inputs are converted from RGB to
    grayscale before comparison; any other layout is handed to ``ssim``
    unchanged.
    """
    is_three_channel = img1.ndim == 3 and img1.shape[2] == 3
    if not is_three_channel:
        return ssim(img1, img2)

    gray_first, gray_second = (
        cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) for image in (img1, img2)
    )
    return ssim(gray_first, gray_second)
def calculate_cd(img1, img2):
    """Return the mean per-pixel colour difference between two RGB images.

    Each image is scaled by 1/255 and converted with ``rgb2lab``; the
    difference at a pixel is the Euclidean distance between the two converted
    colour vectors (taken over axis 2), and the result is the mean of those
    distances.
    """
    lab_first, lab_second = (rgb2lab(image / 255.0) for image in (img1, img2))
    squared_gap = np.square(lab_first - lab_second)
    per_pixel_distance = np.sqrt(squared_gap.sum(axis=2))
    return per_pixel_distance.mean()
def calculate_ie(interpolated, gt):
    """Return the interpolation error: the mean absolute pixel difference.

    Both images are converted to float32 first, so subtracting unsigned
    integer inputs cannot wrap around.
    """
    predicted = interpolated.astype(np.float32)
    reference = gt.astype(np.float32)
    return np.abs(predicted - reference).mean()
def interpolate_frames(img0_path, img1_path, output_path, gt_path=None):
    """Interpolate the frame between two images, save it, and score it.

    Args:
        img0_path: Path of the first input frame.
        img1_path: Path of the second input frame.
        output_path: Where the interpolated frame is written.
        gt_path: Optional path of the ground-truth middle frame.

    Returns:
        A tuple ``(img0, img1, interpolated_img, gt, metrics)``. The images are
        RGB arrays at the resolution of the first frame. ``gt`` is ``None``
        when no usable ground truth was loaded, in which case ``metrics`` is an
        empty dict; otherwise ``metrics`` holds ``'psnr'``, ``'ssim'``, ``'cd'``
        and ``'ie'``.

    Raises:
        ValueError: Propagated from ``preprocess_images`` when an input frame
            cannot be read.
    """
    input_tensor, original_size, img0, img1, gt = preprocess_images(img0_path, img1_path, gt_path)
    out_height, out_width = original_size

    # CUDA kernels are launched asynchronously, so without synchronizing the
    # timer would stop before the GPU has actually finished the forward pass.
    on_cuda = input_tensor.is_cuda
    if on_cuda:
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    with torch.no_grad():
        # Only the interpolated frame is used; flow and mask are discarded.
        _flow, _mask, interpolated = model(input_tensor)
    if on_cuda:
        torch.cuda.synchronize()
    inference_time = time.perf_counter() - start_time
    print(f"Inference time: {inference_time:.4f} seconds")

    # The network runs at 256x256; scale the result back to the source size.
    interpolated_img = tensor_to_image(interpolated[0])
    interpolated_img = cv2.resize(interpolated_img, (out_width, out_height))
    interpolated_img_bgr = cv2.cvtColor(interpolated_img, cv2.COLOR_RGB2BGR)
    # cv2.imwrite reports failure through its return value rather than raising.
    if not cv2.imwrite(output_path, interpolated_img_bgr):
        print(f"Warning: Could not write interpolated image: {output_path}")

    metrics = {}
    if gt is not None:
        if gt.shape[:2] != interpolated_img.shape[:2]:
            # The metrics are element-wise, so a size mismatch would crash
            # them; bring the ground truth to the output resolution instead.
            print(
                f"Warning: Ground truth size {gt.shape[1]}x{gt.shape[0]} differs from "
                f"output size {out_width}x{out_height}; resizing ground truth"
            )
            gt = cv2.resize(gt, (out_width, out_height))
        metrics['psnr'] = calculate_psnr(interpolated_img, gt)
        metrics['ssim'] = calculate_ssim(interpolated_img, gt)
        metrics['cd'] = calculate_cd(interpolated_img, gt)
        metrics['ie'] = calculate_ie(interpolated_img, gt)
        print("Metrics (compared to ground truth):")
        print(f" PSNR: {metrics['psnr']:.4f} dB")
        print(f" SSIM: {metrics['ssim']:.4f}")
        print(f" Color Difference (CD): {metrics['cd']:.4f}")
        print(f" Interpolation Error (IE): {metrics['ie']:.4f}")

    return img0, img1, interpolated_img, gt, metrics
def display_results(img0, img1, interpolated, gt, metrics, output_path):
    """Plot the frames (and ground-truth comparison), save the figure, show it.

    The top row shows the first frame, the interpolated frame and the second
    frame. When ``gt`` is given, a second row adds the ground truth, the
    absolute difference image and the metric values.

    Args:
        img0: First input frame (RGB).
        img1: Second input frame (RGB).
        interpolated: Interpolated frame (RGB).
        gt: Ground-truth frame (RGB), or ``None`` to draw only the top row.
        metrics: Dict with ``'psnr'``, ``'ssim'``, ``'cd'`` and ``'ie'``;
            only read when ``gt`` is not ``None``.
        output_path: Path of the interpolated image; the figure is saved next
            to it as ``<name>_comparison.png``.
    """
    has_gt = gt is not None
    n_rows = 2 if has_gt else 1
    fig = plt.figure(figsize=(15, 10 if has_gt else 5))

    top_row = (
        ('Frame 1', img0),
        ('Interpolated Frame', interpolated),
        ('Frame 2', img1),
    )
    for position, (title, image) in enumerate(top_row, start=1):
        plt.subplot(n_rows, 3, position)
        plt.imshow(image)
        plt.title(title)
        plt.axis('off')

    if has_gt:
        plt.subplot(2, 3, 4)
        plt.imshow(gt)
        plt.title('Ground Truth')
        plt.axis('off')

        plt.subplot(2, 3, 5)
        # Subtract in float so unsigned pixel values cannot wrap around.
        diff = np.abs(interpolated.astype(np.float32) - gt.astype(np.float32))
        plt.imshow(diff.astype(np.uint8))
        plt.title('Difference')
        plt.axis('off')

        plt.subplot(2, 3, 6)
        plt.axis('off')
        metrics_text = "\n".join([
            f"PSNR: {metrics['psnr']:.2f} dB",
            f"SSIM: {metrics['ssim']:.4f}",
            f"CD: {metrics['cd']:.2f}",
            f"IE: {metrics['ie']:.2f}",
        ])
        plt.text(0.1, 0.5, metrics_text, fontsize=12)
        plt.title('Metrics')

    plt.tight_layout()
    # Derive the name from the path's stem rather than str.replace('.png', ...):
    # replace() leaves a non-".png" path unchanged, so the figure would
    # overwrite the interpolated image, and it rewrites ".png" anywhere in the
    # path. For ".png" outputs the resulting name is the same as before.
    stem, _extension = os.path.splitext(output_path)
    plt.savefig(f"{stem}_comparison.png")
    plt.show()
    # Release the figure; otherwise every call leaves one more figure open.
    plt.close(fig)
# CHANGE FILE PATH
# Each entry is (first frame, second frame, output image[, ground-truth frame]).
test_pairs = [
    ("test_frames/frame1.png", "test_frames/frame3.png", "results/scene1_interpolated.png", "test_frames/frame2.png"),
]
os.makedirs("results", exist_ok=True)
for test_item in test_pairs:
    # The ground-truth path is optional, so collect anything past the third field.
    img0_path, img1_path, output_path, *optional_fields = test_item
    gt_path = optional_fields[0] if optional_fields else None
    print(f"Processing: {img0_path} and {img1_path}")
    try:
        # cv2.imwrite does not create missing directories, so make sure the
        # output's own parent exists, not just "results".
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        img0, img1, interpolated, gt, metrics = interpolate_frames(img0_path, img1_path, output_path, gt_path)
        display_results(img0, img1, interpolated, gt, metrics, output_path)
    except Exception as e:
        # Top-level boundary: report the failure and move on to the next pair.
        print(f"Error processing frames: {e}")
        import traceback
        traceback.print_exc()