"""
Contains an implementation of the U-Net architecture.
U-Net paper by Ronneberger et al. (2015): https://arxiv.org/abs/1505.04597

This implementation is based on the original U-Net architecture, with options for
normalization (batch normalization or layer normalization), bilinear upsampling,
and padding in the convolution layers.

Author: Ole-Christian Galbo Engstrøm
E-mail: [email protected]
"""

from typing import Iterable

import torch
from torch import nn
from torch.nn import functional as F


def conv3x3(in_channels: int, out_channels: int, bias: bool, pad: bool) -> nn.Conv2d:
    """
    Applies a convolution with a 3x3 kernel.
    """
    if pad:
        padding = 1
    else:
        padding = "valid"
    layer = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        padding=padding,
        bias=bias,
    )
    return layer


def conv_block(
    in_channels: int,
    out_channels: int,
    non_linearity: nn.Module,
    normalization: None | str,
    bias: bool,
    pad: bool,
) -> nn.Sequential:
    """
    A block of two convolutional layers, each followed by a non-linearity
    and optionally a normalization layer.

    In the U-Net architecture illustration in the U-Net paper,
    this corresponds to two blue arrows.
    """
    layers = []
    for _ in range(2):
        layers.append(
            conv3x3(
                in_channels=in_channels, out_channels=out_channels, bias=bias, pad=pad
            )
        )
        layers.append(non_linearity)
        layers.append(
            get_norm_layer(normalization=normalization, in_channels=out_channels)
        )
        # The second convolution operates on the output channels of the first.
        in_channels = out_channels
    return nn.Sequential(*layers)


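# Illustrative sketch (not part of the original module): with pad=False, each
# 3x3 "valid" convolution trims one pixel from every border, so a conv_block
# reduces each spatial dimension by 4. The helper below is hypothetical and
# only demonstrates the expected shapes.
def _conv_block_shape_example() -> None:
    block = conv_block(
        in_channels=1,
        out_channels=64,
        non_linearity=nn.ReLU(inplace=True),
        normalization=None,
        bias=True,
        pad=False,
    )
    x = torch.randn(1, 1, 572, 572)
    assert block(x).shape == (1, 64, 568, 568)

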
def batch_norm(in_channels: int) -> nn.BatchNorm2d:
    """
    Apply Batch Normalization over the channel dimension.
    Batch Normalization paper by Ioffe and Szegedy (2015): https://arxiv.org/abs/1502.03167
    """
    return nn.BatchNorm2d(in_channels, momentum=0.01)


class Permute(nn.Module):
    """
    Permute the dimensions of a tensor.
    """

    def __init__(self, dims: Iterable[int]):
        super().__init__()
        self.dims = tuple(dims)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.permute(self.dims)

    def __repr__(self):
        return f'{self.__class__.__name__}({", ".join(map(str, self.dims))})'


def layer_norm(in_channels: int) -> nn.Sequential:
    """
    Apply Layer Normalization over the channel dimension.
    Layer Normalization paper by Ba et al. (2016): https://arxiv.org/abs/1607.06450
    """
    layers = [
        # nn.LayerNorm normalizes over the trailing dimension(s), so move the
        # channel dimension last (NCHW -> NHWC), normalize, and move it back.
        Permute((0, 2, 3, 1)),
        nn.LayerNorm(in_channels),
        Permute((0, 3, 1, 2)),
    ]
    return nn.Sequential(*layers)


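# Hypothetical sanity check (not part of the original module): the permute
# round-trip keeps the NCHW layout, and every (n, h, w) position ends up with
# approximately zero mean across its channels.
def _layer_norm_example() -> None:
    x = torch.randn(2, 32, 8, 8)
    out = layer_norm(32)(x)
    assert out.shape == x.shape
    assert torch.allclose(out.mean(dim=1), torch.zeros(2, 8, 8), atol=1e-5)

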
def get_norm_layer(normalization: None | str, in_channels: int) -> nn.Module:
    """
    Get the normalization layer based on the specified type.
    Either 'bn' for batch normalization, 'ln' for layer normalization,
    or None for no normalization layer.
    """
    if normalization == "bn":
        return batch_norm(in_channels)
    if normalization == "ln":
        return layer_norm(in_channels)
    return nn.Identity()


def copy_and_crop(large: torch.Tensor, small: torch.Tensor) -> torch.Tensor:
    """
    Implementation of a copy-and-crop block in the U-Net architecture.
    Copy the large image and crop it to the size of the small image.
    The large image is cropped in the middle, and then the two images are
    concatenated along the channel dimension.

    In the U-Net architecture illustration in the U-Net paper,
    this corresponds to a gray arrow.
    """
    large_height, large_width = large.shape[-2:]
    small_height, small_width = small.shape[-2:]
    start_h = (large_height - small_height) // 2
    start_w = (large_width - small_width) // 2
    cropped_large = large[
        ..., start_h : start_h + small_height, start_w : start_w + small_width
    ]
    return torch.cat([cropped_large, small], dim=-3)


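# Illustrative sketch (not part of the original module): the encoder feature map
# ("large") is centre-cropped to the spatial size of the upsampled decoder map
# ("small"), and the two are concatenated along the channel dimension. The helper
# below is hypothetical and only demonstrates the expected shapes.
def _copy_and_crop_example() -> None:
    large = torch.randn(1, 64, 64, 64)
    small = torch.randn(1, 64, 56, 56)
    out = copy_and_crop(large, small)
    assert out.shape == (1, 128, 56, 56)

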
class ContractionBlock(nn.Module):
    """
    Implementation of a contraction block in the U-Net architecture.
    This block consists of a max pooling layer followed by a convolution block.

    In the U-Net architecture illustration in the U-Net paper, this corresponds to
    one red arrow followed by the subsequent two blue arrows.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        non_linearity: nn.Module,
        normalization: None | str,
        bias: bool,
        pad: bool,
    ):
        super().__init__()
        self.max_pool = self._max_pool()
        self.conv_block = conv_block(
            in_channels=in_channels,
            out_channels=out_channels,
            non_linearity=non_linearity,
            normalization=normalization,
            bias=bias,
            pad=pad,
        )

    def _max_pool(self) -> nn.MaxPool2d:
        layer = nn.MaxPool2d(kernel_size=2, stride=2)
        return layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.max_pool(x)
        x = self.conv_block(x)
        return x


class Upsample(nn.Module):
    """
    Implementation of an upsampling block in the U-Net architecture.
    This block doubles the spatial size of its input, using either a transposed
    convolution or bilinear upsampling followed by a 1x1 convolution.

    In the U-Net architecture illustration in the U-Net paper, this corresponds to
    one green arrow.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        non_linearity: nn.Module,
        normalization: None | str,
        bias: bool,
        bilinear: bool,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.non_linearity = non_linearity
        self.normalization = normalization
        self.bias = bias
        self.bilinear = bilinear
        self.up = self._upsample(in_channels, out_channels)

    def _upsample(self, in_channels: int, out_channels: int) -> nn.Sequential:
        if self.bilinear:
            up = self._up_bilinear(in_channels, out_channels)
        else:
            up = self._up_trans_conv2x2(in_channels, out_channels)
        return up

    def _up_trans_conv2x2(self, in_channels: int, out_channels: int) -> nn.Sequential:
        layers = [
            nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size=2, stride=2, bias=self.bias
            ),
            self.non_linearity,
        ]
        layers.append(get_norm_layer(self.normalization, out_channels))
        return nn.Sequential(*layers)

    def _up_bilinear(self, in_channels: int, out_channels: int) -> nn.Sequential:
        layers = [
            nn.Upsample(mode="bilinear", scale_factor=2, align_corners=True),
            # A 1x1 convolution reduces the number of channels after the
            # parameter-free bilinear upsampling.
            nn.Conv2d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=1
            ),
            self.non_linearity,
        ]
        layers.append(get_norm_layer(self.normalization, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.up(x)


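# Hypothetical shape check (not part of the original module): both upsampling
# variants double the spatial size and map in_channels to out_channels.
def _upsample_shape_example() -> None:
    x = torch.randn(1, 128, 32, 32)
    for bilinear in (True, False):
        up = Upsample(
            in_channels=128,
            out_channels=64,
            non_linearity=nn.ReLU(inplace=True),
            normalization=None,
            bias=True,
            bilinear=bilinear,
        )
        assert up(x).shape == (1, 64, 64, 64)

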
class ExpansionBlock(nn.Module):
    """
    Implementation of an expansion block in the U-Net architecture.
    This block consists of an upsampling block followed by a copy-and-crop block and
    a convolution block.

    In the U-Net architecture illustration in the U-Net paper, this corresponds to
    one green arrow followed by a gray arrow and then two blue arrows.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        non_linearity: nn.Module,
        normalization: None | str,
        bias: bool,
        bilinear: bool,
        pad: bool,
    ):
        super().__init__()
        self.pad = pad
        self.upsample = Upsample(
            in_channels=in_channels,
            out_channels=out_channels,
            non_linearity=non_linearity,
            normalization=normalization,
            bias=bias,
            bilinear=bilinear,
        )
        # After upsampling, the skip connection is concatenated along the channel
        # dimension, so the convolution block again receives in_channels channels.
        self.conv_block = conv_block(
            in_channels=in_channels,
            out_channels=out_channels,
            non_linearity=non_linearity,
            normalization=normalization,
            bias=bias,
            pad=pad,
        )

    def forward(self, large: torch.Tensor, small: torch.Tensor) -> torch.Tensor:
        x = self.upsample(small)
        if self.pad:
            # With padded convolutions, odd-sized inputs can make the upsampled map
            # slightly smaller than the skip connection; zero-pad to match it.
            diff_h = large.shape[-2] - x.shape[-2]
            diff_w = large.shape[-1] - x.shape[-1]
            pad_left = diff_w // 2
            pad_right = diff_w - pad_left
            pad_top = diff_h // 2
            pad_bottom = diff_h - pad_top
            x = F.pad(
                x,
                (pad_left, pad_right, pad_top, pad_bottom),
                mode="constant",
                value=0.0,
            )
        x = copy_and_crop(large, x)
        x = self.conv_block(x)
        return x


class UNet(nn.Module):
    """
    U-Net model with options for padding, bilinear upsampling, and normalization.

    in_channels : int\\
        Number of input channels.

    out_channels : int\\
        Number of output channels.

    pad : bool, default=True\\
        If True use padding in the convolution layers, preserving the input size.
        If False, the output size will be reduced compared to the input size.

    bilinear : bool, default=True\\
        If True use bilinear upsampling.
        If False use transposed convolution.

    normalization : None | str, default=None\\
        If None use no normalization.
        If 'bn' use batch normalization.
        If 'ln' use layer normalization.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        pad: bool = True,
        bilinear: bool = True,
        normalization: None | str = None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pad = pad
        self.bilinear = bilinear
        self.normalization = normalization
        if self.normalization not in [None, "bn", "ln"]:
            raise ValueError(
                "Normalization must be None, 'bn' for batch normalization, "
                "or 'ln' for layer normalization"
            )

        # Convolution layers use a bias only when no normalization layer follows them.
        self.bias_conv = normalization is None
        self.non_linearity = nn.ReLU(inplace=True)
        # Number of feature channels at each depth: 64, 128, 256, 512, 1024.
        self.intermediate_channels = [64 * 2**i for i in range(5)]
        self.first_convs = conv_block(
            in_channels=in_channels,
            out_channels=self.intermediate_channels[0],
            non_linearity=self.non_linearity,
            normalization=self.normalization,
            bias=self.bias_conv,
            pad=self.pad,
        )
        self.last_conv = nn.Conv2d(
            self.intermediate_channels[0], out_channels, kernel_size=1
        )

        self.contraction1 = self._get_contraction_block(
            in_channels=self.intermediate_channels[0],
            out_channels=self.intermediate_channels[1],
        )
        self.contraction2 = self._get_contraction_block(
            in_channels=self.intermediate_channels[1],
            out_channels=self.intermediate_channels[2],
        )
        self.contraction3 = self._get_contraction_block(
            in_channels=self.intermediate_channels[2],
            out_channels=self.intermediate_channels[3],
        )
        self.contraction4 = self._get_contraction_block(
            in_channels=self.intermediate_channels[3],
            out_channels=self.intermediate_channels[4],
        )
        self.expansion4 = self._get_expansion_block(
            in_channels=self.intermediate_channels[4],
            out_channels=self.intermediate_channels[3],
        )
        self.expansion3 = self._get_expansion_block(
            in_channels=self.intermediate_channels[3],
            out_channels=self.intermediate_channels[2],
        )
        self.expansion2 = self._get_expansion_block(
            in_channels=self.intermediate_channels[2],
            out_channels=self.intermediate_channels[1],
        )
        self.expansion1 = self._get_expansion_block(
            in_channels=self.intermediate_channels[1],
            out_channels=self.intermediate_channels[0],
        )

        # Initialize convolution weights with Kaiming initialization and
        # normalization layers with unit weight and zero bias.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.LayerNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _get_contraction_block(
        self, in_channels: int, out_channels: int
    ) -> ContractionBlock:
        return ContractionBlock(
            in_channels=in_channels,
            out_channels=out_channels,
            non_linearity=self.non_linearity,
            normalization=self.normalization,
            bias=self.bias_conv,
            pad=self.pad,
        )

    def _get_expansion_block(
        self, in_channels: int, out_channels: int
    ) -> ExpansionBlock:
        return ExpansionBlock(
            in_channels=in_channels,
            out_channels=out_channels,
            non_linearity=self.non_linearity,
            normalization=self.normalization,
            bias=self.bias_conv,
            bilinear=self.bilinear,
            pad=self.pad,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Contracting path.
        x1 = self.first_convs(x)
        x2 = self.contraction1(x1)
        x3 = self.contraction2(x2)
        x4 = self.contraction3(x3)
        x5 = self.contraction4(x4)
        # Expanding path with skip connections from the contracting path.
        x = self.expansion4(x4, x5)
        x = self.expansion3(x3, x)
        x = self.expansion2(x2, x)
        x = self.expansion1(x1, x)
        x = self.last_conv(x)
        return x
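

# Hypothetical usage sketch (not part of the original module). With pad=True the
# output keeps the input's spatial size; with pad=False a 572x572 input yields a
# 388x388 output, matching the sizes reported in the U-Net paper.
if __name__ == "__main__":
    model = UNet(in_channels=1, out_channels=2, pad=True, bilinear=True)
    with torch.no_grad():
        out = model(torch.randn(1, 1, 256, 256))
    print(out.shape)  # torch.Size([1, 2, 256, 256])

    model = UNet(in_channels=1, out_channels=2, pad=False, bilinear=False)
    with torch.no_grad():
        out = model(torch.randn(1, 1, 572, 572))
    print(out.shape)  # torch.Size([1, 2, 388, 388])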