import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import os
import cv2
import numpy as np
from torch.amp import autocast, GradScaler
from torch.optim.lr_scheduler import CosineAnnealingLR
import time
from datetime import timedelta
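

# Training script for a simplified two-stage RIFE-style frame interpolation
# model: IFNet estimates bidirectional optical flow and a blending mask
# coarse-to-fine, and a context-guided U-Net refines the blended middle frame.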
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

checkpoint_dir = "save_checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
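

# Backward warping: sample an image at positions displaced by a flow field.
# F.grid_sample expects coordinates in [-1, 1], so a normalized base grid is
# cached per (device, flow shape) and the pixel-space flow is rescaled onto it.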
backwarp_tenGrid = {}


def warp(tenInput, tenFlow):
    if tenFlow.dim() == 3:
        tenFlow = tenFlow.unsqueeze(1)
    if tenFlow.size(1) != 2:
        raise ValueError(f"tenFlow must have 2 channels. Got {tenFlow.size(1)} channels.")

    k = (str(tenFlow.device), str(tenFlow.size()))
    if k not in backwarp_tenGrid:
        # Build the grid on the flow's own device so the cache key and the
        # cached tensor always agree.
        tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=tenFlow.device).view(
            1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
        tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=tenFlow.device).view(
            1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
        backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1)

    # Convert the flow from pixel offsets into grid_sample's [-1, 1] range.
    tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
                         tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)

    g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
    return F.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
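
# Sanity check (hypothetical shapes): a zero flow leaves the sampling grid at
# the identity, so the input comes back unchanged, e.g.
#   img = torch.randn(1, 3, 64, 64, device=device)
#   out = warp(img, torch.zeros(1, 2, 64, 64, device=device))  # out ≈ img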


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=True),
        nn.LeakyReLU(0.1, inplace=True)
    )


def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=kernel_size,
                           stride=stride, padding=padding, bias=True),
        nn.LeakyReLU(0.1, inplace=True)
    )
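

# IFBlock: one stage of the coarse-to-fine flow estimator. The input (and any
# incoming flow estimate) is resized by 1/scale, encoded at 1/4 of that
# resolution, passed through a residual conv stack, and decoded back to full
# resolution, producing 4 flow channels (two bidirectional 2-channel flows)
# plus 1 mask logit channel.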
class IFBlock(nn.Module):
    def __init__(self, in_planes, c=48):
        super(IFBlock, self).__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c, 3, 2, 1),
            conv(c, 2*c, 3, 2, 1),
        )
        self.convblock = nn.Sequential(
            conv(2*c, 2*c),
            conv(2*c, 2*c),
            conv(2*c, 2*c),
            conv(2*c, 2*c),
        )
        self.lastconv = nn.ConvTranspose2d(2*c, 5, 4, 2, 1)

    def forward(self, x, flow=None, scale=1):
        if scale != 1:
            x = F.interpolate(x, scale_factor=1./scale, mode="bilinear", align_corners=False)
        if flow is not None:
            # Flow vectors are in pixel units, so they shrink with the resize.
            flow = F.interpolate(flow, scale_factor=1./scale, mode="bilinear", align_corners=False) / scale
            x = torch.cat((x, flow), 1)
        x = self.conv0(x)
        x = self.convblock(x) + x
        tmp = self.lastconv(x)
        # Undo the remaining 2x downsampling and the initial 1/scale resize,
        # rescaling the flow magnitudes accordingly.
        tmp = F.interpolate(tmp, scale_factor=scale*2, mode="bilinear", align_corners=False)
        flow = tmp[:, :4] * scale*2
        mask = tmp[:, 4:5]
        return flow, mask


c = 48
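

# Contextnet: builds a per-image feature pyramid and warps each level by the
# (rescaled) flow, giving the refinement U-Net motion-aligned context features
# from both input frames.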
class Contextnet(nn.Module):
    def __init__(self):
        super(Contextnet, self).__init__()
        # Stride-2 convolutions keep the features at the same resolution as
        # the progressively downsampled flow they are warped by.
        self.conv1 = conv(3, c, 3, 2, 1)
        self.conv2 = conv(c, 2*c, 3, 2, 1)
        self.conv3 = conv(2*c, 4*c, 3, 2, 1)

    def forward(self, x, flow):
        x = self.conv1(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
        f1 = warp(x, flow)

        x = self.conv2(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
        f2 = warp(x, flow)

        x = self.conv3(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
        f3 = warp(x, flow)

        return [f1, f2, f3]
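

# Unet: a small encoder-decoder that fuses the raw frames, warped frames,
# blending mask, flow, and warped context features into a bounded residual
# correction for the blended frame.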
class Unet(nn.Module):
    def __init__(self):
        super(Unet, self).__init__()
        # Stride-2 downsampling mirrors the three stride-2 deconv stages, so
        # the decoder returns to the input resolution.
        self.down0 = conv(17, 2*c, 3, 2, 1)
        self.down1 = conv(4*c, 4*c, 3, 2, 1)
        self.down2 = conv(8*c, 8*c, 3, 2, 1)

        self.up0 = deconv(8*c, 4*c)
        self.up1 = deconv(4*c, 2*c)
        self.up2 = deconv(2*c, c)
        self.conv = nn.Conv2d(c, 3, 3, 1, 1)

    def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1):
        s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow), 1))

        c0_0 = F.interpolate(c0[0], size=(s0.size(2), s0.size(3)), mode="bilinear", align_corners=False)
        c1_0 = F.interpolate(c1[0], size=(s0.size(2), s0.size(3)), mode="bilinear", align_corners=False)
        s1 = self.down1(torch.cat((s0, c0_0, c1_0), 1))

        c0_1 = F.interpolate(c0[1], size=(s1.size(2), s1.size(3)), mode="bilinear", align_corners=False)
        c1_1 = F.interpolate(c1[1], size=(s1.size(2), s1.size(3)), mode="bilinear", align_corners=False)
        s2 = self.down2(torch.cat((s1, c0_1, c1_1), 1))

        x = self.up0(s2)
        x = self.up1(x)
        x = self.up2(x)
        x = self.conv(x)
        # The sigmoid bounds the output to [0, 1] so the caller can map it to
        # a residual in [-1, 1]; the third context level (c0[2]/c1[2]) is left
        # unused by this simplified decoder.
        return torch.sigmoid(x)
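

# IFNet: the full interpolation model. block0 predicts an initial flow/mask
# from the raw frame pair; block1 refines them at a finer scale using the
# warped frames; the two warped frames are blended by the sigmoid mask and
# corrected by the U-Net residual.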
class IFNet(nn.Module):
    def __init__(self):
        super(IFNet, self).__init__()
        self.block0 = IFBlock(6, c=64)
        self.block1 = IFBlock(13+4, c=64)
        self.contextnet = Contextnet()
        self.unet = Unet()

    def forward(self, x, scale=(4, 2), timestep=0.5):
        img0 = x[:, :3]
        img1 = x[:, 3:6]
        # The ground-truth middle frame may be appended to the input; it is
        # not used in the forward pass itself.
        gt = x[:, 6:] if x.shape[1] > 6 else None

        # Stage 0: coarse flow and mask from the raw frame pair.
        flow, mask = self.block0(torch.cat((img0, img1), 1), None, scale=scale[0])

        warped_img0 = warp(img0, flow[:, :2])
        warped_img1 = warp(img1, flow[:, 2:4])

        # Stage 1: residual refinement of flow and mask at a finer scale.
        flow_d, mask_d = self.block1(torch.cat((img0, img1, warped_img0, warped_img1, mask), 1), flow, scale=scale[1])
        flow = flow + flow_d
        mask = mask + mask_d

        warped_img0 = warp(img0, flow[:, :2])
        warped_img1 = warp(img1, flow[:, 2:4])

        mask_final = torch.sigmoid(mask)
        merged = warped_img0 * mask_final + warped_img1 * (1 - mask_final)

        c0 = self.contextnet(img0, flow[:, :2])
        c1 = self.contextnet(img1, flow[:, 2:4])

        # Map the U-Net's sigmoid output to a residual in [-1, 1].
        refined = self.unet(img0, img1, warped_img0, warped_img1, mask_final, flow, c0, c1)
        refined = refined[:, :3] * 2 - 1

        if refined.size(2) != merged.size(2) or refined.size(3) != merged.size(3):
            refined = F.interpolate(refined, size=(merged.size(2), merged.size(3)), mode='bilinear', align_corners=False)

        final_output = torch.clamp(merged + refined, 0, 1)

        return flow, mask_final, final_output
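

# Dataset of (frame_t, frame_t+2, frame_t+1) triplets: the two outer frames
# are the model input and the middle frame is the interpolation ground truth.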
class FrameInterpolationDataset(Dataset):
    def __init__(self, data_dir, transform=None, resize=None, cache_size=100, augment=True):
        self.data_dir = data_dir
        self.transform = transform
        self.resize = resize
        self.frame_pairs = self._load_frame_pairs()
        self.cache = {}
        self.cache_size = cache_size
        self.augment = augment

    def _load_frame_pairs(self):
        frame_pairs = []

        if not os.path.exists(self.data_dir):
            raise ValueError(f"Dataset directory does not exist: {self.data_dir}")

        for seq in os.listdir(self.data_dir):
            seq_dir = os.path.join(self.data_dir, seq)

            if not os.path.isdir(seq_dir):
                continue

            frames = sorted([f for f in os.listdir(seq_dir) if f.endswith(('.png', '.jpg', '.jpeg'))])

            if len(frames) < 3:
                continue

            # (img0, img1, ground-truth middle frame) triplets.
            for i in range(len(frames) - 2):
                frame_pairs.append((
                    os.path.join(seq_dir, frames[i]),
                    os.path.join(seq_dir, frames[i+2]),
                    os.path.join(seq_dir, frames[i+1])
                ))

        if not frame_pairs:
            raise ValueError(f"No valid frame pairs found in {self.data_dir}")

        return frame_pairs

    def __len__(self):
        return len(self.frame_pairs)

    def __getitem__(self, idx):
        if idx in self.cache:
            return self.cache[idx]

        img0_path, img1_path, gt_path = self.frame_pairs[idx]

        img0 = cv2.imread(img0_path)
        img1 = cv2.imread(img1_path)
        gt = cv2.imread(gt_path)

        if img0 is None or img1 is None or gt is None:
            raise ValueError(f"Could not read one of the images: {self.frame_pairs[idx]}")

        img0 = cv2.cvtColor(img0, cv2.COLOR_BGR2RGB)
        img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
        gt = cv2.cvtColor(gt, cv2.COLOR_BGR2RGB)

        if self.resize:
            img0 = cv2.resize(img0, self.resize, interpolation=cv2.INTER_AREA)
            img1 = cv2.resize(img1, self.resize, interpolation=cv2.INTER_AREA)
            gt = cv2.resize(gt, self.resize, interpolation=cv2.INTER_AREA)

        if self.augment:
            # Horizontal flip, vertical flip, and a mild brightness jitter,
            # applied identically to all three frames.
            if np.random.random() > 0.5:
                img0 = np.flip(img0, axis=1).copy()
                img1 = np.flip(img1, axis=1).copy()
                gt = np.flip(gt, axis=1).copy()

            if np.random.random() > 0.5:
                img0 = np.flip(img0, axis=0).copy()
                img1 = np.flip(img1, axis=0).copy()
                gt = np.flip(gt, axis=0).copy()

            if np.random.random() > 0.5:
                brightness = 0.9 + np.random.random() * 0.2
                img0 = np.clip(img0 * brightness, 0, 255).astype(np.uint8)
                img1 = np.clip(img1 * brightness, 0, 255).astype(np.uint8)
                gt = np.clip(gt * brightness, 0, 255).astype(np.uint8)

        if self.transform:
            img0 = self.transform(img0)
            img1 = self.transform(img1)
            gt = self.transform(gt)

        result = torch.cat((img0, img1, gt), 0)

        # Only cache deterministic samples; caching augmented tensors would
        # freeze one random augmentation per index for the rest of training.
        if not self.augment and len(self.cache) < self.cache_size:
            self.cache[idx] = result

        return result
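

# Training loop with automatic mixed precision: forward passes run under
# autocast, and GradScaler scales the loss so fp16 gradients do not underflow.
# Saves a checkpoint every epoch and tracks the best PSNR and best loss, with
# early stopping once neither improves for `patience` epochs.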
def train_with_amp(model, train_dataloader, val_dataloader, optimizer, scheduler, criterion,
                   num_epochs=10, patience=5, start_epoch=0, best_val_loss=float('inf'), best_val_psnr=0.0):
    scaler = GradScaler(enabled=(device.type == "cuda"))
    patience_counter = 0

    total_start_time = time.time()
    epoch_times = []

    for epoch in range(start_epoch, num_epochs):
        epoch_start_time = time.time()
        model.train()
        train_loss = 0.0

        for i, data in enumerate(train_dataloader):
            data = data.to(device, non_blocking=True)

            with autocast(device_type=device.type):
                flow, mask, final_output = model(data)
                loss = criterion(final_output, data[:, 6:9])

            optimizer.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.item()

            if i % 10 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i}/{len(train_dataloader)}], Loss: {loss.item():.6f}, LR: {scheduler.get_last_lr()[0]:.6f}")

        scheduler.step()

        avg_train_loss = train_loss / len(train_dataloader)

        val_loss, val_psnr = validate_with_amp(model, val_dataloader, criterion)

        epoch_end_time = time.time()
        epoch_time = epoch_end_time - epoch_start_time
        epoch_times.append(epoch_time)

        avg_epoch_time = sum(epoch_times) / len(epoch_times)
        epochs_remaining = num_epochs - (epoch + 1)
        est_time_remaining = avg_epoch_time * epochs_remaining

        epoch_time_str = str(timedelta(seconds=int(epoch_time)))
        est_remaining_str = str(timedelta(seconds=int(est_time_remaining)))
        total_elapsed_str = str(timedelta(seconds=int(time.time() - total_start_time)))

        print(f"Epoch [{epoch+1}/{num_epochs}] completed. Train Loss: {avg_train_loss:.6f}, "
              f"Validation Loss: {val_loss:.6f}, Validation PSNR: {val_psnr:.4f} dB")
        print(f"Time: {epoch_time_str} | Total: {total_elapsed_str} | Remaining: {est_remaining_str}")

        checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch+1}.pth")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': val_loss,
            'psnr': val_psnr,
        }, checkpoint_path)

        # Track PSNR and loss improvements independently; either one resets
        # the early-stopping counter.
        improved = False
        if val_psnr > best_val_psnr:
            best_val_psnr = val_psnr
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, "best_psnr_model.pth"))
            print(f"Model saved with improved validation PSNR: {best_val_psnr:.4f} dB")
            improved = True
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, "best_loss_model.pth"))
            print(f"Model saved with improved validation loss: {best_val_loss:.6f}")
            improved = True

        if improved:
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs")
                break

    total_training_time = time.time() - total_start_time
    total_time_str = str(timedelta(seconds=int(total_training_time)))
    avg_epoch_time = total_training_time / max(len(epoch_times), 1)
    avg_epoch_time_str = str(timedelta(seconds=int(avg_epoch_time)))

    print(f"Training completed in {total_time_str} ({avg_epoch_time_str} per epoch)")
    print(f"Best validation PSNR: {best_val_psnr:.4f} dB")

    with open(os.path.join(checkpoint_dir, "training_time_stats.txt"), "w") as f:
        f.write(f"Total training time: {total_time_str}\n")
        f.write(f"Average epoch time: {avg_epoch_time_str}\n")
        f.write(f"Total epochs: {epoch+1}\n")
        f.write(f"Best validation PSNR: {best_val_psnr:.4f} dB\n")
        f.write(f"Final learning rate: {scheduler.get_last_lr()[0]:.8f}\n")

        f.write("\nEpoch times:\n")
        for i, e_time in enumerate(epoch_times):
            e_time_str = str(timedelta(seconds=int(e_time)))
            f.write(f"Epoch {i+1}: {e_time_str}\n")

    return best_val_psnr, total_time_str
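

# Validation: averages the criterion loss and a PSNR computed from per-batch
# MSE, assuming pixel values in [0, 1] (peak signal of 1.0).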
def validate_with_amp(model, dataloader, criterion):
    model.eval()
    total_loss = 0.0
    total_psnr = 0.0

    with torch.no_grad():
        for data in dataloader:
            data = data.to(device, non_blocking=True)

            with autocast(device_type=device.type):
                flow, mask, final_output = model(data)
                gt = data[:, 6:9]
                loss = criterion(final_output, gt)

            total_loss += loss.item()

            # Compute the MSE in fp32 so the PSNR is not degraded by fp16.
            mse = F.mse_loss(final_output.float(), gt.float()).item()
            if mse > 0:
                psnr = 10 * np.log10(1.0 / mse)
            else:
                psnr = float('inf')
            total_psnr += psnr

    avg_loss = total_loss / len(dataloader)
    avg_psnr = total_psnr / len(dataloader)

    return avg_loss, avg_psnr
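

# Resume helper: restores model/optimizer/scheduler state plus the bookkeeping
# needed to continue training from a saved epoch checkpoint.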
def load_checkpoint_and_resume(checkpoint_path, model, optimizer, scheduler):
    print(f"Loading checkpoint from {checkpoint_path}...")
    try:
        # map_location lets a checkpoint saved on GPU load on the current device.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        start_epoch = checkpoint['epoch'] + 1

        best_val_loss = checkpoint.get('loss', float('inf'))
        best_val_psnr = checkpoint.get('psnr', 0.0)

        print(f"Successfully loaded checkpoint from epoch {checkpoint['epoch']}")
        print(f"Resuming training from epoch {start_epoch}")
        print(f"Best validation loss: {best_val_loss:.6f}, Best PSNR: {best_val_psnr:.4f} dB")

        return model, optimizer, scheduler, start_epoch, best_val_loss, best_val_psnr

    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        raise
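

# Single-process entry point. Note that ToTensor already scales images to
# [0, 1], which matches IFNet's clamp to [0, 1] and the PSNR peak used in
# validation.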
if __name__ == "__main__":
    data_dir_train = "datasets/train_10k"
    data_dir_val = "datasets/test_2k"
    batch_size = 16
    resize = (256, 256)
    world_size = torch.cuda.device_count()

    load_from_checkpoint = True
    checkpoint_path = os.path.join(checkpoint_dir, "model_epoch_49.pth")

    if world_size <= 1:
        model = IFNet().to(device)
        optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-5)
        scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)

        start_epoch = 0
        best_val_loss = float('inf')
        best_val_psnr = 0.0

        if load_from_checkpoint:
            try:
                model, optimizer, scheduler, start_epoch, best_val_loss, best_val_psnr = load_checkpoint_and_resume(
                    checkpoint_path, model, optimizer, scheduler
                )
            except Exception as e:
                print(f"Failed to load checkpoint: {e}")
                print("Starting training from scratch instead.")

                model = IFNet().to(device)
                optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-5)
                scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)
                start_epoch = 0
                best_val_loss = float('inf')
                best_val_psnr = 0.0

        # ToTensor scales to [0, 1]; no further normalization, so the targets
        # stay compatible with IFNet's clamp to [0, 1] and the PSNR computation.
        transform = transforms.Compose([
            transforms.ToTensor()
        ])

        train_dataset = FrameInterpolationDataset(
            data_dir=data_dir_train,
            transform=transform,
            resize=resize,
            augment=True
        )

        val_dataset = FrameInterpolationDataset(
            data_dir=data_dir_val,
            transform=transform,
            resize=resize,
            augment=False
        )

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
            drop_last=True
        )

        val_dataloader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=True
        )

        criterion = nn.L1Loss()

        best_psnr, total_time = train_with_amp(
            model,
            train_dataloader,
            val_dataloader,
            optimizer,
            scheduler,
            criterion,
            num_epochs=50,
            patience=10,
            start_epoch=start_epoch,
            best_val_loss=best_val_loss,
            best_val_psnr=best_val_psnr
        )
        print(f"Training completed in {total_time} with best PSNR of {best_psnr:.4f} dB")
    else:
        print("Distributed training not implemented in this example")