# animationInterpolation/combined.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import os
import cv2
import numpy as np
from torch.amp import autocast, GradScaler
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import CosineAnnealingLR
import time
from datetime import timedelta
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint_dir = 'save_checkpoints/'
os.makedirs(checkpoint_dir, exist_ok=True)
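# Cache of identity sampling grids for warp(), keyed by (device, flow size)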
backwarp_tenGrid = {}
def warp(tenInput, tenFlow):
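    """Backward-warp tenInput by tenFlow using bilinear grid_sample.

    Flow is given in pixels and rescaled to grid_sample's normalized [-1, 1]
    coordinate convention; identity grids are cached in backwarp_tenGrid per
    (device, size) so they are built only once per shape.
    """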
if tenFlow.dim() == 3:
tenFlow = tenFlow.unsqueeze(1)
if tenFlow.size(1) != 2:
raise ValueError(f"tenFlow must have 2 channels. Got {tenFlow.size(1)} channels.")
k = (str(tenFlow.device), str(tenFlow.size()))
if k not in backwarp_tenGrid:
        tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=tenFlow.device).view(
            1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
        tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=tenFlow.device).view(
            1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
        # Build the grid on the flow's own device so it matches the (device, size) cache key.
        backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1)
tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
return F.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=True),
nn.LeakyReLU(0.1, inplace=True) # LeakyReLU instead of ReLU
)
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=True),
        nn.LeakyReLU(0.1, inplace=True)  # LeakyReLU instead of ReLU
    )
# IFBlock: 6 conv layers total (2 strided downsampling + 4 residual), plus a transposed-conv head
class IFBlock(nn.Module):
def __init__(self, in_planes, c=48): # Increased channel count
super(IFBlock, self).__init__()
self.conv0 = nn.Sequential(
conv(in_planes, c, 3, 2, 1),
conv(c, 2*c, 3, 2, 1),
)
self.convblock = nn.Sequential(
conv(2*c, 2*c),
conv(2*c, 2*c),
conv(2*c, 2*c),
conv(2*c, 2*c),
)
self.lastconv = nn.ConvTranspose2d(2*c, 5, 4, 2, 1)
def forward(self, x, flow=None, scale=1):
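        # Estimate at 1/scale resolution, then upsample the 4-channel flow and
        # 1-channel mask back up; flow values are in pixels, so they are
        # rescaled by the same factor as the spatial resolution.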
if scale != 1:
x = F.interpolate(x, scale_factor=1./scale, mode="bilinear", align_corners=False)
if flow is not None:
flow = F.interpolate(flow, scale_factor=1./scale, mode="bilinear", align_corners=False) * 1./scale
x = torch.cat((x, flow), 1)
x = self.conv0(x)
x = self.convblock(x) + x
tmp = self.lastconv(x)
tmp = F.interpolate(tmp, scale_factor=scale*2, mode="bilinear", align_corners=False)
flow = tmp[:, :4] * scale*2
mask = tmp[:, 4:5]
return flow, mask
# Base channel width; 8, 16, 32, 48, 64 were tried (no change beyond 64)
c = 48
# Context Layers
class Contextnet(nn.Module):
def __init__(self):
super(Contextnet, self).__init__()
self.conv1 = conv(3, c)
self.conv2 = conv(c, 2*c)
self.conv3 = conv(2*c, 4*c)
def forward(self, x, flow):
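        # Extract features at three depths, warping each by a progressively
        # halved (and correspondingly rescaled) copy of the flow, which yields
        # warped features at 1/2, 1/4, and 1/8 resolution.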
x = self.conv1(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
f1 = warp(x, flow)
x = self.conv2(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
f2 = warp(x, flow)
x = self.conv3(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
f3 = warp(x, flow)
return [f1, f2, f3]
class Unet(nn.Module):
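    """Encoder-decoder refinement head: fuses the input frames, warped
    frames, mask, flow, and the two sets of context features into a
    residual correction image."""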
def __init__(self):
super(Unet, self).__init__()
        # Input has 17 channels: 2 RGB images (6) + 2 warped images (6) + mask (1) + flow (4)
        # Stride-2 downsampling so the three stride-2 deconvs below return to input resolution
        self.down0 = conv(17, 2*c, 3, 2, 1)
        self.down1 = conv(4*c, 4*c, 3, 2, 1)
        self.down2 = conv(8*c, 8*c, 3, 2, 1)
#3c 3d
self.up0 = deconv(8*c, 4*c)
self.up1 = deconv(4*c, 2*c)
self.up2 = deconv(2*c, c)
self.conv = nn.Conv2d(c, 3, 3, 1, 1)
def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1):
s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow), 1))
c0_0 = F.interpolate(c0[0], size=(s0.size(2), s0.size(3)), mode="bilinear", align_corners=False)
c1_0 = F.interpolate(c1[0], size=(s0.size(2), s0.size(3)), mode="bilinear", align_corners=False)
s1 = self.down1(torch.cat((s0, c0_0, c1_0), 1))
        c0_1 = F.interpolate(c0[1], size=(s1.size(2), s1.size(3)), mode="bilinear", align_corners=False)
        c1_1 = F.interpolate(c1[1], size=(s1.size(2), s1.size(3)), mode="bilinear", align_corners=False)
        s2 = self.down2(torch.cat((s1, c0_1, c1_1), 1))
x = self.up0(s2)
x = self.up1(x)
x = self.up2(x)
x = self.conv(x)
return x
class IFNet(nn.Module):
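    """RIFE-style interpolator: coarse-to-fine flow via two IFBlocks,
    mask-blended backward warping, and a context-guided Unet refinement."""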
def __init__(self):
super(IFNet, self).__init__()
self.block0 = IFBlock(6, c=64)
self.block1 = IFBlock(13+4, c=64)
self.contextnet = Contextnet()
self.unet = Unet()
    def forward(self, x, scale=(4, 2, 1), timestep=0.5):
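        # x stacks img0 and img1 on the channel axis; during training the
        # dataset appends the ground-truth middle frame as channels 6:9.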
img0 = x[:, :3]
img1 = x[:, 3:6]
gt = x[:, 6:] if x.shape[1] > 6 else None
# First stage flow estimation
flow, mask = self.block0(torch.cat((img0, img1), 1), None, scale=scale[0])
# Second stage flow refinement
warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, flow[:, 2:4])
flow_d, mask_d = self.block1(torch.cat((img0, img1, warped_img0, warped_img1, mask), 1), flow, scale=scale[1])
flow = flow + flow_d
mask = mask + mask_d
# Final warping
warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, flow[:, 2:4])
# Apply mask
mask_final = torch.sigmoid(mask)
merged = warped_img0 * mask_final + warped_img1 * (1 - mask_final)
# Apply contextual enhancement
c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4])
# Apply Unet refinement
refined = self.unet(img0, img1, warped_img0, warped_img1, mask_final, flow, c0, c1)
refined = refined[:, :3] * 2 - 1
if refined.size(2) != merged.size(2) or refined.size(3) != merged.size(3):
refined = F.interpolate(refined, size=(merged.size(2), merged.size(3)), mode='bilinear', align_corners=False)
final_output = torch.clamp(merged + refined, 0, 1)
return flow, mask_final, final_output
class FrameInterpolationDataset(Dataset):
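    """Triplets of consecutive frames: each item stacks (frame_t, frame_t+2,
    frame_t+1) into a 9-channel tensor, with the middle frame as the target."""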
def __init__(self, data_dir, transform=None, resize=None, cache_size=100, augment=True):
self.data_dir = data_dir
self.transform = transform
self.resize = resize
self.frame_pairs = self._load_frame_pairs()
self.cache = {}
self.cache_size = cache_size
self.augment = augment
def _load_frame_pairs(self):
frame_pairs = []
if not os.path.exists(self.data_dir):
raise ValueError(f"Dataset directory does not exist: {self.data_dir}")
for seq in os.listdir(self.data_dir):
seq_dir = os.path.join(self.data_dir, seq)
if not os.path.isdir(seq_dir):
continue
frames = sorted([f for f in os.listdir(seq_dir) if f.endswith(('.png', '.jpg', '.jpeg'))])
if len(frames) < 3:
continue
for i in range(len(frames) - 2):
frame_pairs.append((
os.path.join(seq_dir, frames[i]),
os.path.join(seq_dir, frames[i+2]),
os.path.join(seq_dir, frames[i+1])
))
if not frame_pairs:
raise ValueError(f"No valid frame pairs found in {self.data_dir}")
return frame_pairs
def __len__(self):
return len(self.frame_pairs)
    def __getitem__(self, idx):
        # Cached tensors are only valid when augmentation is off; otherwise
        # every epoch would replay the same "random" flips and brightness.
        if not self.augment and idx in self.cache:
            return self.cache[idx]
img0_path, img1_path, gt_path = self.frame_pairs[idx]
img0 = cv2.imread(img0_path)
img1 = cv2.imread(img1_path)
gt = cv2.imread(gt_path)
if img0 is None or img1 is None or gt is None:
raise ValueError(f"Could not read one of the images: {self.frame_pairs[idx]}")
img0 = cv2.cvtColor(img0, cv2.COLOR_BGR2RGB)
img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
gt = cv2.cvtColor(gt, cv2.COLOR_BGR2RGB)
if self.resize:
img0 = cv2.resize(img0, self.resize, interpolation=cv2.INTER_AREA)
img1 = cv2.resize(img1, self.resize, interpolation=cv2.INTER_AREA)
gt = cv2.resize(gt, self.resize, interpolation=cv2.INTER_AREA)
if self.augment:
if np.random.random() > 0.5:
img0 = np.flip(img0, axis=1).copy()
img1 = np.flip(img1, axis=1).copy()
gt = np.flip(gt, axis=1).copy()
if np.random.random() > 0.5:
img0 = np.flip(img0, axis=0).copy()
img1 = np.flip(img1, axis=0).copy()
gt = np.flip(gt, axis=0).copy()
if np.random.random() > 0.5:
brightness = 0.9 + np.random.random() * 0.2
img0 = np.clip(img0 * brightness, 0, 255).astype(np.uint8)
img1 = np.clip(img1 * brightness, 0, 255).astype(np.uint8)
gt = np.clip(gt * brightness, 0, 255).astype(np.uint8)
if self.transform:
img0 = self.transform(img0)
img1 = self.transform(img1)
gt = self.transform(gt)
result = torch.cat((img0, img1, gt), 0)
        if not self.augment and len(self.cache) < self.cache_size:
            self.cache[idx] = result
return result
def train_with_amp(model, train_dataloader, val_dataloader, optimizer, scheduler, criterion,
num_epochs=10, patience=5, start_epoch=0, best_val_loss=float('inf'), best_val_psnr=0.0):
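    """AMP training loop: autocast forward, gradient-scaled backward, a cosine
    LR step per epoch, per-epoch checkpoints, and early stopping once neither
    validation PSNR nor validation loss improves for `patience` epochs."""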
scaler = GradScaler()
patience_counter = 0
# Runtime tracking
total_start_time = time.time()
epoch_times = []
for epoch in range(start_epoch, num_epochs):
epoch_start_time = time.time()
model.train()
train_loss = 0.0
for i, data in enumerate(train_dataloader):
data = data.to(device, non_blocking=True)
with autocast(device_type='cuda'):
flow, mask, final_output = model(data)
loss = criterion(final_output, data[:, 6:9])
optimizer.zero_grad(set_to_none=True)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
train_loss += loss.item()
if i % 10 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i}/{len(train_dataloader)}], Loss: {loss.item():.6f}, LR: {scheduler.get_last_lr()[0]:.6f}")
scheduler.step()
avg_train_loss = train_loss / len(train_dataloader)
val_loss, val_psnr = validate_with_amp(model, val_dataloader, criterion)
epoch_end_time = time.time()
epoch_time = epoch_end_time - epoch_start_time
epoch_times.append(epoch_time)
avg_epoch_time = sum(epoch_times) / len(epoch_times)
epochs_remaining = num_epochs - (epoch + 1)
est_time_remaining = avg_epoch_time * epochs_remaining
epoch_time_str = str(timedelta(seconds=int(epoch_time)))
est_remaining_str = str(timedelta(seconds=int(est_time_remaining)))
total_elapsed_str = str(timedelta(seconds=int(time.time() - total_start_time)))
print(f"Epoch [{epoch+1}/{num_epochs}] completed. Train Loss: {avg_train_loss:.6f}, "
f"Validation Loss: {val_loss:.6f}, Validation PSNR: {val_psnr:.4f} dB")
print(f"Time: {epoch_time_str} | Total: {total_elapsed_str} | Remaining: {est_remaining_str}")
        checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch+1}.pth")
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'loss': val_loss,
'psnr': val_psnr,
}, checkpoint_path)
        # Track both metrics independently (an elif here would skip the loss
        # update whenever PSNR improved in the same epoch)
        improved = False
        if val_psnr > best_val_psnr:
            best_val_psnr = val_psnr
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, "best_psnr_model.pth"))
            print(f"Model saved with improved validation PSNR: {best_val_psnr:.4f} dB")
            improved = True
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, "best_loss_model.pth"))
            print(f"Model saved with improved validation loss: {best_val_loss:.6f}")
            improved = True
        if improved:
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs")
                break
total_training_time = time.time() - total_start_time
total_time_str = str(timedelta(seconds=int(total_training_time)))
    # Average over epochs actually run this session (start_epoch may be > 0 when resuming)
    avg_epoch_time = total_training_time / max(len(epoch_times), 1)
avg_epoch_time_str = str(timedelta(seconds=int(avg_epoch_time)))
print(f"Training completed in {total_time_str} ({avg_epoch_time_str} per epoch)")
print(f"Best validation PSNR: {best_val_psnr:.4f} dB")
    with open(os.path.join(checkpoint_dir, "training_time_stats.txt"), "w") as f:
f.write(f"Total training time: {total_time_str}\n")
f.write(f"Average epoch time: {avg_epoch_time_str}\n")
f.write(f"Total epochs: {epoch+1}\n")
f.write(f"Best validation PSNR: {best_val_psnr:.4f} dB\n")
f.write(f"Final learning rate: {scheduler.get_last_lr()[0]:.8f}\n")
f.write("\nEpoch times:\n")
for i, e_time in enumerate(epoch_times):
e_time_str = str(timedelta(seconds=int(e_time)))
f.write(f"Epoch {i+1}: {e_time_str}\n")
return best_val_psnr, total_time_str
def validate_with_amp(model, dataloader, criterion):
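    """Average L1 loss and PSNR (computed against a 1.0 peak) over the
    validation set, under autocast and no_grad."""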
model.eval()
total_loss = 0.0
total_psnr = 0.0
with torch.no_grad():
for data in dataloader:
data = data.to(device, non_blocking=True)
with autocast(device_type='cuda'):
flow, mask, final_output = model(data)
gt = data[:, 6:9]
loss = criterion(final_output, gt)
total_loss += loss.item()
mse = F.mse_loss(final_output, gt).item()
if mse > 0:
psnr = 10 * np.log10(1.0 / mse)
else:
psnr = float('inf')
total_psnr += psnr
avg_loss = total_loss / len(dataloader)
avg_psnr = total_psnr / len(dataloader)
return avg_loss, avg_psnr
def load_checkpoint_and_resume(checkpoint_path, model, optimizer, scheduler):
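    """Restore model/optimizer/scheduler state from a full training checkpoint
    (weights_only=False because optimizer state is stored alongside weights)."""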
print(f"Loading checkpoint from {checkpoint_path}...")
try:
checkpoint = torch.load(checkpoint_path, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if 'scheduler_state_dict' in checkpoint:
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
start_epoch = checkpoint['epoch'] + 1
best_val_loss = checkpoint.get('loss', float('inf'))
best_val_psnr = checkpoint.get('psnr', 0.0)
print(f"Successfully loaded checkpoint from epoch {checkpoint['epoch']}")
print(f"Resuming training from epoch {start_epoch}")
print(f"Best validation loss: {best_val_loss:.6f}, Best PSNR: {best_val_psnr:.4f} dB")
return model, optimizer, scheduler, start_epoch, best_val_loss, best_val_psnr
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        raise  # re-raise with the original traceback
if __name__ == "__main__":
data_dir_train = "datasets/train_10k"
data_dir_val = "datasets/test_2k"
batch_size = 16
resize = (256, 256)
world_size = torch.cuda.device_count()
load_from_checkpoint = True
    checkpoint_path = os.path.join(checkpoint_dir, "model_epoch_49.pth")
# For single GPU
if world_size <= 1:
model = IFNet().to(device)
optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-5)
scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)
start_epoch = 0
best_val_loss = float('inf')
best_val_psnr = 0.0
if load_from_checkpoint:
try:
model, optimizer, scheduler, start_epoch, best_val_loss, best_val_psnr = load_checkpoint_and_resume(
checkpoint_path, model, optimizer, scheduler
)
except Exception as e:
print(f"Failed to load checkpoint: {e}")
print("Starting training from scratch instead.")
model = IFNet().to(device)
optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-5)
scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)
start_epoch = 0
best_val_loss = float('inf')
best_val_psnr = 0.0
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
train_dataset = FrameInterpolationDataset(
data_dir=data_dir_train,
transform=transform,
resize=resize,
augment=True
)
val_dataset = FrameInterpolationDataset(
data_dir=data_dir_val,
transform=transform,
resize=resize,
augment=False
)
train_dataloader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=4,
pin_memory=True,
drop_last=True
)
val_dataloader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=4,
pin_memory=True
)
criterion = nn.L1Loss()
best_psnr, total_time = train_with_amp(
model,
train_dataloader,
val_dataloader,
optimizer,
scheduler,
criterion,
num_epochs=50,
patience=10,
start_epoch=start_epoch,
best_val_loss=best_val_loss,
best_val_psnr=best_val_psnr
)
print(f"Training completed in {total_time} with best PSNR of {best_psnr:.4f} dB")
else:
print("Distributed training not implemented in this example")