# animationInterpolation / eval / pytorch_v0.py
# Uploaded with huggingface_hub by pineappleSoup (commit 57db94b, verified).
try:
import torch
import torch.nn as nn
except:
pass
try:
import torchvision as tv
import torchvision.transforms as T
import torchvision.transforms.functional as F
except:
pass
try:
import pytorch_lightning as pl
except:
pass
try:
import torchmetrics
import lpips
except:
from argparse import Namespace
torchmetrics = Namespace(Metric=object)
try:
import wandb
except:
pass
try:
import kornia
except:
pass
try:
import detectron2
from detectron2 import model_zoo as _
from detectron2 import engine as _
from detectron2 import config as _
from detectron2 import data as _
from detectron2.utils import visualizer as _
except:
pass
try:
from nvidia import dali
from nvidia.dali.plugin import pytorch as _
except:
pass
try:
import cupy
except:
pass
try:
import skimage
from skimage import measure as _
from skimage import color as _
from skimage import segmentation as _
from skimage import filters as _
from scipy.spatial.transform import Rotation
except:
pass
# from pytorch_msssim import ssim as calc_ssim
import math
from twodee_v0 import *
from PIL import Image
# from torchmetrics.functional import structural_similarity_index_measure as calc_ssim
#################### UTILITIES ####################
# BUG FIX: the original wrapped only the *definition* of cupy_launch in
# try/except, but defining a function never evaluates `cupy`, so the except
# branch was unreachable and a missing cupy surfaced as a NameError at call
# time. Select the implementation based on whether the optional import above
# actually succeeded.
if 'cupy' in globals():
    # @cupy.memoize(for_each_device=True)
    def cupy_launch(func, kernel):
        """Compile CUDA `kernel` source with cupy and return the kernel named `func`."""
        return cupy.cuda.compile_with_cache(kernel).get_function(func)
else:
    # cupy unavailable: keep the original no-op fallback contract (returns None)
    cupy_launch = lambda func, kernel: None
def reset_parameters(model):
    """Re-initialize every direct child layer of `model` that supports it.

    Only immediate children are visited (not the full module tree).
    Returns `model` for call chaining.
    """
    for child in model.children():
        reset = getattr(child, 'reset_parameters', None)
        if reset is not None:
            reset()
    return model
def channel_squeeze(x, dim=1):
    """Merge dimensions `dim` and `dim+1` of `x` into one flattened dimension."""
    lead = x.shape[:dim]
    trail = x.shape[dim + 2:]
    return x.reshape(*lead, -1, *trail)
def channel_unsqueeze(x, shape, dim=1):
    """Expand dimension `dim` of `x` into the given `shape` (inverse of channel_squeeze)."""
    lead = x.shape[:dim]
    trail = x.shape[dim + 1:]
    return x.reshape(*lead, *shape, *trail)
def default_collate(items, device=None):
    """Collate `items` with torch's default collate, then move the batch to `device`.

    The collated batch is coerced to a plain dict before the device transfer
    (so `items` is expected to be a sequence of mappings).
    """
    batch = torch.utils.data.dataloader.default_collate(items)
    return to(dict(batch), device)
def to(x, device):
    """Move `x` onto `device`.

    Accepts a Tensor, a dict (tensor values are moved, everything else is
    passed through), or a numpy array (converted to a tensor first).
    `device=None` is a no-op. Anything else trips the assertion below.
    """
    if device is None:
        return x
    if issubclass(x.__class__, dict):
        moved = {}
        for key, val in x.items():
            moved[key] = val.to(device) if isinstance(val, torch.Tensor) else val
        return moved
    if isinstance(x, torch.Tensor):
        return x.to(device)
    if isinstance(x, np.ndarray):
        return torch.tensor(x).to(device)
    assert 0, 'data not understood'
#################### LOSSES + METRICS ####################
class SSIMMetric(torchmetrics.Metric):
    """Streaming SSIM metric accumulated sample-by-sample on PIL images.

    Hand-rolled because, per the original author's note, torchmetrics' SSIM
    implementation has a memory leak. Relies on `calc_ssim`, which is not
    imported here explicitly — presumably it comes in via the
    `from twodee_v0 import *` star-import; verify against that module.
    """
    # torchmetrics has memory leak
    def __init__(self, window_size=11, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): window_size is stored but never used by update();
        # kept for API parity with SSIMMetricCPU — confirm intent.
        self.window_size = window_size
        # dist_reduce_fx='sum' so distributed ranks combine by summation
        self.add_state('running_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('running_count', default=torch.tensor(0.0), dist_reduce_fx='sum')
        # running sample index; only used by (removed) debug image dumps
        self.idd = 0
        # converts a CHW tensor sample to a PIL image for calc_ssim
        self.transform = T.ToPILImage()
        return
    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # Accumulate per-sample SSIM over the batch.
        for i in range(preds.size()[0]):
            pp = self.transform(preds[i])
            tt = self.transform(target[i])
            self.idd += 1
            ssss = calc_ssim(pp, tt)
            self.running_sum += ssss
        # NOTE(review): source indentation was lost; this increment is placed
        # outside the loop, since incrementing by the batch size once per
        # sample would make compute() wrong for batch > 1 — confirm.
        self.running_count += preds.size()[0]
        return
    def compute(self):
        # mean SSIM over all samples seen so far
        return self.running_sum.float() / self.running_count
class SSIMMetricCPU(torchmetrics.Metric):
    """Streaming SSIM metric backed by kornia's batched implementation."""
    full_state_update = False

    def __init__(self, window_size=11, **kwargs):
        super().__init__(**kwargs)
        self.window_size = window_size
        self.add_state('running_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('running_count', default=torch.tensor(0.0), dist_reduce_fx='sum')

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # kornia returns a per-pixel SSIM map; reduce to one score per sample
        per_sample = kornia.metrics.ssim(target, preds, self.window_size).mean((1, 2, 3))
        self.running_sum += per_sample.sum()
        self.running_count += len(per_sample)

    def compute(self):
        return self.running_sum / self.running_count
class PSNRMetric(torchmetrics.Metric):
    """Streaming PSNR, averaged per sample.

    Hand-rolled because torchmetrics averages the per-sample MSEs before
    taking the log, which is not the same as the mean of per-sample PSNRs.
    """
    def __init__(self, data_range=1.0, **kwargs):
        super().__init__(**kwargs)
        # peak signal value of the inputs (1.0 for normalized images)
        self.data_range = torch.tensor(data_range)
        self.add_state('running_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('running_count', default=torch.tensor(0.0), dist_reduce_fx='sum')
        return
    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # per-sample PSNR = 20*log10(data_range) - 10*log10(MSE)
        ans = -10 * torch.log10( (target-preds).pow(2).mean((1,2,3)) )
        # BUG FIX: the data_range offset belongs to every sample, so it must be
        # scaled by the batch size; the original added it only once per
        # update() call, which was coincidentally harmless only at the default
        # data_range=1.0 (where the term is 0).
        self.running_sum += len(ans) * 20 * torch.log10(self.data_range) + ans.sum()
        self.running_count += len(ans)
        return
    def compute(self):
        # mean PSNR over all samples seen so far
        return self.running_sum.float() / self.running_count
class PSNRMetricCPU(torchmetrics.Metric):
    """Streaming PSNR computed per-sample on CPU via scikit-image."""
    full_state_update = False

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.add_state('running_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('running_count', default=torch.tensor(0.0), dist_reduce_fx='sum')

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        scores = []
        # skimage expects HWC numpy arrays; data_range is inferred from dtype
        for p, t in zip(preds, target):
            scores.append(
                skimage.metrics.peak_signal_noise_ratio(
                    p.permute(1, 2, 0).cpu().numpy(),
                    t.permute(1, 2, 0).cpu().numpy(),
                )
            )
        self.running_sum += sum(scores)
        self.running_count += len(scores)

    def compute(self):
        return self.running_sum / self.running_count
class LPIPSMetric(torchmetrics.Metric):
    """Streaming LPIPS perceptual distance using the `lpips` package."""
    full_state_update = False

    def __init__(self, net_type='alex', **kwargs):
        super().__init__(**kwargs)
        assert net_type in ['alex', 'vgg', 'squeeze']
        self.net_type = net_type
        self.model = lpips.LPIPS(net=self.net_type)
        self.add_state('running_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('running_count', default=torch.tensor(0.0), dist_reduce_fx='sum')

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # skip the autograd graph when the caller isn't differentiating
        if preds.requires_grad:
            scores = self.model(preds, target).mean((1, 2, 3))
        else:
            with torch.no_grad():
                scores = self.model(preds, target).mean((1, 2, 3))
        self.running_sum += scores.sum()
        self.running_count += len(scores)

    def compute(self):
        return self.running_sum.float() / self.running_count
class LPIPSLoss(nn.Module):
    """LPIPS perceptual distance as an nn.Module loss (one scalar per sample)."""

    def __init__(self, net_type='alex', **kwargs):
        super().__init__()
        assert net_type in ['alex', 'vgg', 'squeeze']
        self.net_type = net_type
        self.model = lpips.LPIPS(net=self.net_type, **kwargs)

    def forward(self, preds: torch.Tensor, target: torch.Tensor):
        # reduce the LPIPS map to a per-sample scalar
        return self.model(preds, target).mean((1, 2, 3))
class LaplacianPyramidLoss(nn.Module):
    """L1/L2 loss averaged across the levels of a Gaussian/Laplacian pyramid.

    Optionally converts RGB inputs to CIE-Lab before building the pyramid.
    `force_levels` / `force_mode` override the constructor settings per call.
    """

    def __init__(self, n_levels=3, colorspace=None, mode='l1'):
        super().__init__()
        assert mode in ['l1', 'l2']
        self.n_levels = n_levels
        self.colorspace = colorspace
        self.mode = mode

    def forward(self, preds, target, force_levels=None, force_mode=None):
        if self.colorspace == 'lab':
            preds = kornia.color.rgb_to_lab(preds.float())
            target = kornia.color.rgb_to_lab(target.float())
        levels = force_levels if force_levels is not None else self.n_levels
        pyr_pred = kornia.geometry.transform.build_pyramid(preds, levels)
        pyr_targ = kornia.geometry.transform.build_pyramid(target, levels)
        mode = force_mode if force_mode is not None else self.mode
        if mode == 'l1':
            per_level = [
                (p - t).abs().mean((1, 2, 3))
                for p, t in zip(pyr_pred, pyr_targ)
            ]
        elif mode == 'l2':
            per_level = [
                (p - t).norm(dim=1, keepdim=True).mean((1, 2, 3))
                for p, t in zip(pyr_pred, pyr_targ)
            ]
        else:
            assert 0
        # average across pyramid levels, keeping the per-sample dimension
        return torch.stack(per_level).mean(0)
def make_grid(tensor, nrow=8, padding=2):
    """
    Arrange a 4D mini-batch tensor (B x C x H x W), or a list of equally-sized
    images, into a single 3 x H' x W' image grid.

    Args:
        tensor: 4D batch tensor, 3D single image, 2D single-channel image,
            or a list of 3D images of identical size.
        nrow: maximum number of images per grid row.
        padding: pixels of spacing allotted around each cell.

    Returns:
        For a single (3D/2D) input, a 3-channel version of that image is
        returned as-is; for a batch, a new 3 x (H+padding)*rows x
        (W+padding)*cols tensor whose background is filled with tensor.max().
    """
    if isinstance(tensor, list):
        # stack the list into one batch tensor
        # BUG FIX: the original called the Python-2-only builtin `long()`
        # here, which raises NameError on Python 3.
        num_images = len(tensor)
        size = torch.Size(torch.Size([int(num_images)]) + tensor[0].size())
        batch = tensor[0].new(size)
        for i in range(num_images):
            batch[i].copy_(tensor[i])
        tensor = batch
    if tensor.dim() == 2:  # single image H x W -> 1 x H x W
        tensor = tensor.view(1, tensor.size(0), tensor.size(1))
    if tensor.dim() == 3:  # single image: ensure 3 channels, no grid layout
        if tensor.size(0) == 1:
            tensor = torch.cat((tensor, tensor, tensor), 0)
        return tensor
    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel batch -> 3 channels
        tensor = torch.cat((tensor, tensor, tensor), 1)
    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(nmaps / xmaps))
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    # background filled with the batch maximum (original behavior)
    grid = tensor.new(3, height * ymaps, width * xmaps).fill_(tensor.max())
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            if k >= nmaps:
                break
            grid.narrow(1, y*height+1+padding//2, height-padding)\
                .narrow(2, x*width+1+padding//2, width-padding)\
                .copy_(tensor[k])
            k = k + 1
    return grid
def save_image(tensor, filename, nrow=8, padding=2):
    """
    Save a tensor to an image file; a mini-batch is laid out as a grid first.

    NOTE(review): the mul(0.5).add(0.5) denormalization assumes values in
    [-1, 1] — confirm against callers.
    """
    grid = make_grid(tensor.cpu(), nrow=nrow, padding=padding)
    # CHW float in [-1, 1] -> HWC uint8 in [0, 255]
    pixels = grid.mul(0.5).add(0.5).mul(255).byte().permute(1, 2, 0).numpy()
    Image.fromarray(pixels).save(filename)