# animationInterpolation / generate_eval.py
# Uploaded by pineappleSoup ("Upload folder using huggingface_hub"), commit 57db94b (verified), 6.65 kB.
import os
import torch
import cv2
import numpy as np
import shutil
from combined import IFNet, warp
import torchvision.transforms as transforms
import time
# Select the compute device: prefer CUDA whenever PyTorch can see a GPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Output directory for the generated (io, gt) image pairs. Any contents left
# over from a previous run are discarded so the numbering starts fresh.
eval_dir = "eval_data"
if os.path.exists(eval_dir):
    shutil.rmtree(eval_dir)  # wipe stale pairs from an earlier run
os.makedirs(eval_dir)
print(f"Created directory: {eval_dir}")
# Initialize model
model = IFNet().to(device)
# Load checkpoint if available
checkpoint_path = "save_checkpoints/model_epoch_50.pth"
if os.path.exists(checkpoint_path):
    # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code. Only load checkpoints from a trusted
    # source; consider weights_only=True if the file holds only tensors and
    # plain Python types.
    checkpoint = torch.load(checkpoint_path, weights_only=False, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # 'epoch' and 'psnr' are informational only. Read both with .get() so a
    # checkpoint missing either key does not abort after the weights loaded.
    print(f"Loaded model from epoch {checkpoint.get('epoch', 'N/A')} with PSNR: {checkpoint.get('psnr', 'N/A')} dB")
else:
    print("No checkpoint found, using uninitialized model")
# Switch to evaluation mode for the inference passes below.
model.eval()
# Preprocessing pipeline, matching the test script. ToTensor converts an
# HWC uint8 array into a CHW float tensor in [0, 1]. Normalize with
# mean = std = 0.5 on every channel then maps that range onto [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
# Function to preprocess images (same as in your test.py)
def preprocess_images(img0_path, img1_path, gt_path=None, size=(256, 256)):
    """Load two input frames (and optionally a ground-truth frame) and build the model input.

    Args:
        img0_path: Path to the first input frame.
        img1_path: Path to the second input frame.
        gt_path: Optional path to the ground-truth frame. If it is given but
            the file is missing or unreadable, a warning is printed and
            ``None`` is returned in its place.
        size: ``(width, height)`` that both inputs are resized to before the
            transform (``cv2.resize`` argument order). Defaults to 256x256.

    Returns:
        Tuple ``(input_tensor, original_size, orig_img0, orig_img1, gt)``:
        ``input_tensor`` holds the two transformed frames concatenated along
        dim 0, with a leading batch dim of 1, moved to ``device``. For the
        default 3-channel reads its shape is (1, 6, height, width).
        ``original_size`` is ``(height, width)`` of the first frame.
        ``orig_img0`` and ``orig_img1`` are full-size RGB copies of the inputs.
        ``gt`` is the RGB ground truth, or ``None``.

    Raises:
        ValueError: If either input frame cannot be read.
    """
    img0 = cv2.imread(img0_path)
    img1 = cv2.imread(img1_path)
    # cv2.imread signals a missing/unreadable file by returning None, not by raising.
    if img0 is None or img1 is None:
        raise ValueError(f"Could not read images: {img0_path}, {img1_path}")

    gt = None
    if gt_path:
        if not os.path.exists(gt_path):
            # Previously this case returned None silently; make it visible.
            print(f"Warning: Ground truth image not found: {gt_path}")
        else:
            gt = cv2.imread(gt_path)
            if gt is None:
                print(f"Warning: Could not read ground truth image: {gt_path}")
            else:
                gt = cv2.cvtColor(gt, cv2.COLOR_BGR2RGB)

    # OpenCV loads BGR; everything downstream works in RGB.
    img0 = cv2.cvtColor(img0, cv2.COLOR_BGR2RGB)
    img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)

    # (height, width) of the first frame, used later to resize the output back.
    original_size = (img0.shape[0], img0.shape[1])
    # Keep full-size copies for display/comparison by the caller.
    orig_img0 = img0.copy()
    orig_img1 = img1.copy()

    # Resize to the model's input size, then normalize to [-1, 1].
    target_size = tuple(size)
    img0_tensor = transform(cv2.resize(img0, target_size))
    img1_tensor = transform(cv2.resize(img1, target_size))

    # Stack both frames channel-wise, add a batch dim, and move to the model's device.
    input_tensor = torch.cat((img0_tensor, img1_tensor), 0).unsqueeze(0).to(device)
    return input_tensor, original_size, orig_img0, orig_img1, gt
# Function to denormalize and convert tensor to image (same as in your test.py)
def tensor_to_image(tensor):
    """Convert a normalized (channels, height, width) tensor to an HWC ``uint8`` image.

    Inverts ``Normalize(mean=0.5, std=0.5)``: values are mapped from [-1, 1]
    back to [0, 1], clamped, scaled to [0, 255], and rounded to the nearest
    integer.

    Args:
        tensor: 3-D tensor laid out (channels, height, width). It may be on
            any device and may require grad.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` with shape (height, width, channels).
    """
    # detach() so tensors attached to an autograd graph can be converted;
    # cpu() because .numpy() only works on host memory.
    tensor = tensor.detach().cpu()
    # Undo the [-1, 1] normalization and guard against out-of-range model output.
    tensor = (tensor * 0.5 + 0.5).clamp(0, 1)
    # Round rather than truncate. After the float normalize/denormalize round
    # trip, a pixel that started as an integer can come back a hair below it.
    # A plain astype() would floor that to the next lower value.
    img = np.rint(tensor.numpy().transpose(1, 2, 0) * 255)
    return img.astype(np.uint8)
# Counter for output filename indexing (io1.png/gt1.png, io2.png/gt2.png, ...)
counter = 1
# Process all subdirectories in the dataset
dataset_dir = "datasets/test_2k"
print(f"Looking for frames in: {dataset_dir}")
# Collect the immediate subdirectories of the dataset folder. They are sorted
# because os.listdir returns entries in arbitrary, filesystem-dependent order,
# which would otherwise make the io{N}/gt{N} numbering irreproducible.
subdirs = []
try:
    subdirs = sorted(
        item
        for item in os.listdir(dataset_dir)
        if os.path.isdir(os.path.join(dataset_dir, item))
    )
    print(f"Found {len(subdirs)} subdirectories in {dataset_dir}")
except OSError as e:
    # Best-effort: report the problem and carry on with no subdirectories.
    print(f"Error listing subdirectories: {e}")
# Candidate extensions for frame files, in lookup order.
_FRAME_EXTENSIONS = (".png", ".jpg", ".jpeg", "")


def _find_frame(directory, stem):
    """Return the first existing ``directory/<stem><ext>`` path, or ``None``.

    Extensions are tried in ``_FRAME_EXTENSIONS`` order: ``.png``, ``.jpg``,
    ``.jpeg``, then no extension.
    """
    for ext in _FRAME_EXTENSIONS:
        candidate = os.path.join(directory, f"{stem}{ext}")
        if os.path.exists(candidate):
            return candidate
    return None


def _sync_device():
    """Wait for queued CUDA work to finish so wall-clock timings are meaningful."""
    if device.type == "cuda":
        torch.cuda.synchronize()


# Process each subdirectory. frame1 and frame3 are passed as the model inputs
# and frame2 as the ground truth. Each successful sample writes io{N}.png
# (model output) and gt{N}.png (ground truth) into eval_dir.
for subdir in subdirs:
    subdir_path = os.path.join(dataset_dir, subdir)
    frame1_path = _find_frame(subdir_path, "frame1")
    frame2_path = _find_frame(subdir_path, "frame2")
    frame3_path = _find_frame(subdir_path, "frame3")
    if frame1_path is None or frame2_path is None or frame3_path is None:
        print(f"Skipping {subdir_path} - missing required frames")
        continue
    print(f"Processing {subdir_path}")
    try:
        input_tensor, original_size, orig_img0, orig_img1, gt = preprocess_images(
            frame1_path, frame3_path, frame2_path
        )
        if gt is None:
            # Without a ground truth there is no pair to save. Skipping here
            # also avoids leaving an orphan io{N}.png behind.
            print(f"Skipping {subdir_path} - ground truth frame could not be loaded")
            continue

        with torch.no_grad():
            _sync_device()
            start_time = time.perf_counter()
            _flow, _mask, interpolated = model(input_tensor)
            # CUDA kernels run asynchronously. Wait for them so the measured
            # time covers the forward pass itself, not just its launch.
            _sync_device()
            inference_time = time.perf_counter() - start_time
        print(f"Inference time: {inference_time:.4f} seconds")

        interpolated_img = tensor_to_image(interpolated[0])
        # Inputs were resized before inference; restore the first frame's size.
        if interpolated_img.shape[:2] != original_size:
            interpolated_img = cv2.resize(interpolated_img, (original_size[1], original_size[0]))

        output_io = os.path.join(eval_dir, f"io{counter}.png")
        output_gt = os.path.join(eval_dir, f"gt{counter}.png")
        # Images are RGB in memory and OpenCV writes BGR. imwrite reports
        # failure through its return value rather than raising, so check it.
        wrote_io = cv2.imwrite(output_io, cv2.cvtColor(interpolated_img, cv2.COLOR_RGB2BGR))
        wrote_gt = cv2.imwrite(output_gt, cv2.cvtColor(gt, cv2.COLOR_RGB2BGR))
        if not (wrote_io and wrote_gt):
            raise OSError(f"Failed to write {output_io} and/or {output_gt}")
        print(f"Saved pair {counter}: {output_io} and {output_gt}")
        counter += 1
    except Exception as e:
        # Per-sample boundary: report the failure and move on to the next subdirectory.
        print(f"Error processing {subdir_path}: {e}")
        import traceback
        traceback.print_exc()
print(f"Processing complete. Generated {counter - 1} pairs of images in {eval_dir}")