|
|
|
try: |
|
import torch |
|
import torch.nn as nn |
|
except: |
|
pass |
|
|
|
try: |
|
import torchvision as tv |
|
import torchvision.transforms as T |
|
import torchvision.transforms.functional as F |
|
except: |
|
pass |
|
|
|
try: |
|
import pytorch_lightning as pl |
|
except: |
|
pass |
|
try: |
|
import torchmetrics |
|
import lpips |
|
except: |
|
from argparse import Namespace |
|
torchmetrics = Namespace(Metric=object) |
|
try: |
|
import wandb |
|
except: |
|
pass |
|
|
|
try: |
|
import kornia |
|
except: |
|
pass |
|
|
|
try: |
|
import detectron2 |
|
from detectron2 import model_zoo as _ |
|
from detectron2 import engine as _ |
|
from detectron2 import config as _ |
|
from detectron2 import data as _ |
|
from detectron2.utils import visualizer as _ |
|
except: |
|
pass |
|
|
|
try: |
|
from nvidia import dali |
|
from nvidia.dali.plugin import pytorch as _ |
|
except: |
|
pass |
|
|
|
try: |
|
import cupy |
|
except: |
|
pass |
|
|
|
try: |
|
import skimage |
|
from skimage import measure as _ |
|
from skimage import color as _ |
|
from skimage import segmentation as _ |
|
from skimage import filters as _ |
|
from scipy.spatial.transform import Rotation |
|
except: |
|
pass |
|
|
|
|
|
import math |
|
from twodee_v0 import * |
|
from PIL import Image |
|
|
|
|
|
|
|
|
|
def cupy_launch(func, kernel):
    """Compile a CUDA `kernel` source string with cupy and return the compiled
    kernel function named `func`, or None when cupy is unavailable.

    BUG FIX: the original wrapped only the `def` statement in try/except, which
    can never fail — the `cupy` name is resolved at *call* time, not at
    definition time — so the fallback lambda was dead code and a missing cupy
    raised NameError at the call site. The availability check belongs here.
    """
    try:
        return cupy.cuda.compile_with_cache(kernel).get_function(func)
    except NameError:
        # cupy was never imported (the guarded import at the top of the file failed).
        return None
|
|
|
def reset_parameters(model):
    """Re-initialize every direct child layer of `model` that exposes a
    `reset_parameters` method, then return `model` for chaining.

    Only immediate children are visited (not the full module tree).
    """
    for child in model.children():
        reset = getattr(child, 'reset_parameters', None)
        if reset is not None:
            reset()
    return model
|
|
|
def channel_squeeze(x, dim=1):
    """Merge dimensions `dim` and `dim+1` of `x` into a single dimension,
    e.g. (B, C, D, H, W) -> (B, C*D, H, W) for the default dim=1."""
    lead = x.shape[:dim]
    trail = x.shape[dim + 2:]
    return x.reshape(*lead, -1, *trail)
|
def channel_unsqueeze(x, shape, dim=1):
    """Split dimension `dim` of `x` into the sizes given by `shape`
    (inverse of `channel_squeeze`), e.g. (B, C*D, H, W) -> (B, C, D, H, W)."""
    lead = x.shape[:dim]
    trail = x.shape[dim + 1:]
    return x.reshape(*lead, *shape, *trail)
|
|
|
def default_collate(items, device=None):
    """Collate `items` with torch's default collator, then move the resulting
    dict to `device` via `to` (a no-op when `device` is None)."""
    batch = torch.utils.data.dataloader.default_collate(items)
    return to(dict(batch), device)
|
def to(x, device):
    """Move `x` to `device`.

    Supports: dicts (tensor values are moved, everything else is kept as-is),
    tensors, and numpy arrays (converted to tensors first). `device=None` is a
    no-op that returns `x` unchanged. Any other type fails an assertion.
    """
    if device is None:
        return x
    if issubclass(x.__class__, dict):
        moved = {}
        for key, val in x.items():
            moved[key] = val.to(device) if isinstance(val, torch.Tensor) else val
        return moved
    if isinstance(x, torch.Tensor):
        return x.to(device)
    if isinstance(x, np.ndarray):
        return torch.tensor(x).to(device)
    assert 0, 'data not understood'
|
|
|
|
|
|
|
class SSIMMetric(torchmetrics.Metric):
    """Running mean SSIM over every sample seen.

    Each sample is converted to a PIL image and scored with `calc_ssim`
    (not defined in this file — presumably from `twodee_v0 import *`; confirm).
    State is accumulated as a sum and a count so distributed reduction
    (`dist_reduce_fx='sum'`) averages correctly across processes.
    """

    def __init__(self, window_size=11, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): window_size is stored but never passed to calc_ssim
        # below — confirm calc_ssim applies its own window.
        self.window_size = window_size
        self.add_state('running_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('running_count', default=torch.tensor(0.0), dist_reduce_fx='sum')
        # Plain attribute (not a metric state): per-process sample counter,
        # not reset by reset() and not synced across processes.
        self.idd = 0
        self.transform = T.ToPILImage()
        return

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # Score each (C, H, W) sample of the batch individually.
        for i in range(preds.size()[0]):
            pp = self.transform(preds[i])
            tt = self.transform(target[i])
            self.idd += 1
            ssss = calc_ssim(pp, tt)
            self.running_sum += ssss
        # One increment per update call: add the whole batch size.
        self.running_count += preds.size()[0]
        return

    def compute(self):
        # Mean SSIM across everything accumulated so far.
        return self.running_sum.float() / self.running_count
|
|
|
class SSIMMetricCPU(torchmetrics.Metric):
    """Running mean SSIM computed with kornia's windowed SSIM.

    Accumulates a sum and a count (both reduced with 'sum') so the mean is
    correct under distributed training.
    """

    full_state_update = False

    def __init__(self, window_size=11, **kwargs):
        super().__init__(**kwargs)
        self.window_size = window_size
        self.add_state('running_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('running_count', default=torch.tensor(0.0), dist_reduce_fx='sum')
        return

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # kornia returns a per-pixel SSIM map; averaging over (C, H, W)
        # yields one score per batch element.
        per_sample = kornia.metrics.ssim(target, preds, self.window_size).mean((1, 2, 3))
        self.running_sum += per_sample.sum()
        self.running_count += len(per_sample)
        return

    def compute(self):
        return self.running_sum / self.running_count
|
|
|
class PSNRMetric(torchmetrics.Metric):
    """Running mean PSNR over every sample seen.

    Per sample: PSNR = 20*log10(data_range) - 10*log10(MSE over (C, H, W)).
    Sum and count are accumulated with 'sum' reduction for distributed use.
    """

    def __init__(self, data_range=1.0, **kwargs):
        super().__init__(**kwargs)
        # Kept as a tensor so torch.log10 applies directly.
        self.data_range = torch.tensor(data_range)
        self.add_state('running_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('running_count', default=torch.tensor(0.0), dist_reduce_fx='sum')
        return

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # Per-sample -10*log10(MSE).
        ans = -10 * torch.log10((target - preds).pow(2).mean((1, 2, 3)))
        # BUG FIX: the 20*log10(data_range) term belongs to EVERY sample, so it
        # must be scaled by the batch size; the old code added it once per
        # update() call, under-counting it for batches larger than one.
        # (Invisible at the default data_range=1.0, where the term is zero.)
        self.running_sum += len(ans) * 20 * torch.log10(self.data_range) + ans.sum()
        self.running_count += len(ans)
        return

    def compute(self):
        return self.running_sum.float() / self.running_count
|
class PSNRMetricCPU(torchmetrics.Metric):
    """Running mean PSNR computed on CPU with skimage.

    Tensors are moved to CPU and reordered (C, H, W) -> (H, W, C) before
    scoring. The data range is inferred by skimage from the dtype.
    """

    full_state_update = False

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.add_state('running_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('running_count', default=torch.tensor(0.0), dist_reduce_fx='sum')
        return

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # BUG FIX: skimage's signature is (image_true, image_test) — target
        # must come FIRST. PSNR is symmetric in the difference, but skimage
        # infers the data range from the first argument's dtype, so the old
        # swapped order was only coincidentally correct when both inputs
        # shared a dtype.
        ans = [
            skimage.metrics.peak_signal_noise_ratio(
                t.permute(1, 2, 0).cpu().numpy(),
                p.permute(1, 2, 0).cpu().numpy(),
            )
            for p, t in zip(preds, target)
        ]
        self.running_sum += sum(ans)
        self.running_count += len(ans)
        return

    def compute(self):
        return self.running_sum / self.running_count
|
|
|
class LPIPSMetric(torchmetrics.Metric):
    """Running mean LPIPS perceptual distance over every sample seen."""

    full_state_update = False

    def __init__(self, net_type='alex', **kwargs):
        super().__init__(**kwargs)
        self.net_type = net_type
        assert self.net_type in ['alex', 'vgg', 'squeeze']
        self.model = lpips.LPIPS(net=self.net_type)
        self.add_state('running_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('running_count', default=torch.tensor(0.0), dist_reduce_fx='sum')
        return

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # Skip autograd bookkeeping when the inputs carry no gradients;
        # otherwise run under the ambient grad mode.
        if preds.requires_grad:
            scores = self.model(preds, target).mean((1, 2, 3))
        else:
            with torch.no_grad():
                scores = self.model(preds, target).mean((1, 2, 3))
        self.running_sum += scores.sum()
        self.running_count += len(scores)
        return

    def compute(self):
        return self.running_sum.float() / self.running_count
|
class LPIPSLoss(nn.Module):
    """LPIPS perceptual distance as a differentiable loss module.

    Returns one distance value per batch element.
    """

    def __init__(self, net_type='alex', **kwargs):
        super().__init__()
        self.net_type = net_type
        assert self.net_type in ['alex', 'vgg', 'squeeze']
        # Extra kwargs are forwarded to the underlying LPIPS model.
        self.model = lpips.LPIPS(net=self.net_type, **kwargs)
        return

    def forward(self, preds: torch.Tensor, target: torch.Tensor):
        # Collapse the (1, 1, 1) spatial output to a per-sample scalar.
        return self.model(preds, target).mean((1, 2, 3))
|
|
|
class LaplacianPyramidLoss(nn.Module):
    """L1 or L2 loss averaged across the levels of an image pyramid.

    Optionally converts both inputs to CIELAB before building the pyramids.
    `forward` returns one loss value per batch element.
    """

    def __init__(self, n_levels=3, colorspace=None, mode='l1'):
        super().__init__()
        self.n_levels = n_levels
        self.colorspace = colorspace
        self.mode = mode
        assert self.mode in ['l1', 'l2']
        return

    def forward(self, preds, target, force_levels=None, force_mode=None):
        # Optionally compare in CIELAB instead of RGB.
        if self.colorspace == 'lab':
            preds = kornia.color.rgb_to_lab(preds.float())
            target = kornia.color.rgb_to_lab(target.float())

        # force_levels / force_mode override the constructor defaults per call.
        levels = force_levels if force_levels is not None else self.n_levels
        pyr_p = kornia.geometry.transform.build_pyramid(preds, levels)
        pyr_t = kornia.geometry.transform.build_pyramid(target, levels)

        mode = force_mode if force_mode is not None else self.mode
        if mode == 'l1':
            per_level = [(p - t).abs().mean((1, 2, 3)) for p, t in zip(pyr_p, pyr_t)]
        elif mode == 'l2':
            per_level = [(p - t).norm(dim=1, keepdim=True).mean((1, 2, 3)) for p, t in zip(pyr_p, pyr_t)]
        else:
            assert 0
        # Average the per-sample losses over pyramid levels.
        return torch.stack(per_level).mean(0)
|
|
|
def make_grid(tensor, nrow=8, padding=2):
    """
    Given a 4D mini-batch Tensor of shape (B x C x H x W),
    or a list of images all of the same size,
    makes a grid of images.

    Args:
        tensor: 4D batch tensor, a list of equally-sized 3D image tensors,
            a single 3D (C x H x W) image, or a 2D (H x W) map.
        nrow: maximum number of tiles per grid row.
        padding: pixels of padding around each tile.

    Returns:
        A 3 x H' x W' tensor; 2D/3D inputs are returned as a single
        3-channel image instead of a grid.
    """
    if isinstance(tensor, list):
        # Stack the list into one batch tensor of matching dtype/device.
        numImages = len(tensor)
        # BUG FIX: Python 2's `long` no longer exists — any list input used to
        # raise NameError on Python 3; `int` is the correct spelling.
        size = torch.Size(torch.Size([int(numImages)]) + tensor[0].size())
        batch = tensor[0].new(size)
        for i in range(numImages):
            batch[i].copy_(tensor[i])
        tensor = batch
    if tensor.dim() == 2:
        # Promote (H, W) -> (1, H, W).
        tensor = tensor.view(1, tensor.size(0), tensor.size(1))
    if tensor.dim() == 3:
        if tensor.size(0) == 1:
            # Grayscale -> 3-channel by repetition.
            tensor = torch.cat((tensor, tensor, tensor), 0)
        return tensor
    if tensor.dim() == 4 and tensor.size(1) == 1:
        tensor = torch.cat((tensor, tensor, tensor), 1)

    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(nmaps / xmaps))
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    # Background is filled with the batch maximum, which visually separates tiles.
    grid = tensor.new(3, height * ymaps, width * xmaps).fill_(tensor.max())
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            if k >= nmaps:
                break
            grid.narrow(1, y*height+1+padding//2, height-padding)\
                .narrow(2, x*width+1+padding//2, width-padding)\
                .copy_(tensor[k])
            k = k + 1
    return grid
|
|
|
def save_image(tensor, filename, nrow=8, padding=2):
    """
    Saves a given Tensor into an image file.
    If given a mini-batch tensor, will save the tensor as a grid of images.
    """
    # Tile the (possibly batched) tensor, map values from [-1, 1] to [0, 255],
    # and reorder from (C, H, W) to (H, W, C) for PIL.
    grid = make_grid(tensor.cpu(), nrow=nrow, padding=padding)
    arr = grid.mul(0.5).add(0.5).mul(255).byte().transpose(0, 2).transpose(0, 1).numpy()
    Image.fromarray(arr).save(filename)
|
|
|
|
|
|