| | |
| | |
| | |
| | |
| | |
| |
|
| | from __future__ import absolute_import, division, print_function |
| |
|
| | import warnings |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from torch import nn |
| | from torch.nn.init import constant_, xavier_uniform_ |
| |
|
| | from .dcnv3_func import DCNv3Function, dcnv3_core_pytorch, has_cuda_kernel |
| |
|
| |
|
class to_channels_first(nn.Module):
    """Layout adapter: permute a channels-last (N, H, W, C) tensor to
    channels-first (N, C, H, W)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # NHWC -> NCHW
        return torch.permute(x, (0, 3, 1, 2))
| |
|
| |
|
class to_channels_last(nn.Module):
    """Layout adapter: permute a channels-first (N, C, H, W) tensor to
    channels-last (N, H, W, C)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # NCHW -> NHWC
        return torch.permute(x, (0, 2, 3, 1))
| |
|
| |
|
def build_norm_layer(dim,
                     norm_layer,
                     in_format='channels_last',
                     out_format='channels_last',
                     eps=1e-6):
    """Build a normalization layer wrapped with any layout converters needed
    to bridge *in_format* / *out_format*.

    :param dim: number of normalized channels
    :param norm_layer: 'BN' (BatchNorm2d, channels-first) or 'LN'
        (LayerNorm, channels-last)
    :param in_format: layout of the incoming tensor
    :param out_format: layout expected by the caller
    :param eps: epsilon for LayerNorm
    :raises NotImplementedError: for any other *norm_layer* name
    """
    if norm_layer == 'BN':
        # BatchNorm2d works on NCHW, so convert around it as required.
        pre = [to_channels_first()] if in_format == 'channels_last' else []
        post = [to_channels_last()] if out_format == 'channels_last' else []
        return nn.Sequential(*pre, nn.BatchNorm2d(dim), *post)
    if norm_layer == 'LN':
        # LayerNorm normalizes the trailing dim, so it wants NHWC.
        pre = [to_channels_last()] if in_format == 'channels_first' else []
        post = [to_channels_first()] if out_format == 'channels_first' else []
        return nn.Sequential(*pre, nn.LayerNorm(dim, eps=eps), *post)
    raise NotImplementedError(
        f'build_norm_layer does not support {norm_layer}')
| |
|
| |
|
def build_act_layer(act_layer):
    """Instantiate an activation module by name.

    :param act_layer: one of 'ReLU', 'SiLU', 'GELU'
    :raises NotImplementedError: for any other name
    """
    factories = {
        'ReLU': lambda: nn.ReLU(inplace=True),
        'SiLU': lambda: nn.SiLU(inplace=True),
        'GELU': nn.GELU,
    }
    factory = factories.get(act_layer)
    if factory is None:
        raise NotImplementedError(
            f'build_act_layer does not support {act_layer}')
    return factory()
| |
|
| |
|
| | def _is_power_of_2(n): |
| | if (not isinstance(n, int)) or (n < 0): |
| | raise ValueError( |
| | 'invalid input for _is_power_of_2: {} (type: {})'.format(n, type(n))) |
| |
|
| | return (n & (n - 1) == 0) and n != 0 |
| |
|
| |
|
class CenterFeatureScaleModule(nn.Module):
    """Map query features to per-group scales in (0, 1) via a linear
    projection followed by a sigmoid."""

    def forward(self,
                query,
                center_feature_scale_proj_weight,
                center_feature_scale_proj_bias):
        logits = F.linear(query,
                          weight=center_feature_scale_proj_weight,
                          bias=center_feature_scale_proj_bias)
        return torch.sigmoid(logits)
| |
|
| |
|
class DCNv3_pytorch(nn.Module):
    """DCNv3 (deformable convolution v3) block backed by the pure-PyTorch
    sampling routine ``dcnv3_core_pytorch``; operates on channels-last
    (N, H, W, C) tensors."""

    def __init__(
            self,
            channels=64,
            kernel_size=3,
            dw_kernel_size=None,
            stride=1,
            pad=1,
            dilation=1,
            group=4,
            offset_scale=1.0,
            act_layer='GELU',
            norm_layer='LN',
            center_feature_scale=False,
            remove_center=False,
    ):
        """
        DCNv3 Module
        :param channels: embedding channels; must be divisible by `group`
        :param kernel_size: deformable sampling kernel size
        :param dw_kernel_size: depthwise conv kernel size (defaults to
            `kernel_size` when None)
        :param stride: sampling stride
        :param pad: sampling padding
        :param dilation: sampling dilation
        :param group: number of deformable groups
        :param offset_scale: multiplier applied to predicted offsets
        :param act_layer: activation name for the depthwise branch
        :param norm_layer: norm name for the depthwise branch
        :param center_feature_scale: enable learned blending between the
            deformable output and the plain input projection
        :param remove_center: drop the center sampling point (odd kernel only)
        :raises ValueError: if channels is not divisible by group, or
            remove_center is set with an even kernel size
        """
        super().__init__()
        if channels % group != 0:
            raise ValueError(
                f'channels must be divisible by group, but got {channels} and {group}')
        _d_per_group = channels // group
        dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size
        # A power-of-2 per-group width is faster in the companion CUDA kernel.
        if not _is_power_of_2(_d_per_group):
            warnings.warn(
                "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
                'which is more efficient in our CUDA implementation.')

        self.offset_scale = offset_scale
        self.channels = channels
        self.kernel_size = kernel_size
        self.dw_kernel_size = dw_kernel_size
        self.stride = stride
        self.dilation = dilation
        self.pad = pad
        self.group = group
        self.group_channels = channels // group
        self.center_feature_scale = center_feature_scale
        self.remove_center = int(remove_center)

        # Keep validation consistent with the CUDA-backed DCNv3 class:
        # without a center point an even kernel has no valid layout.
        if self.remove_center and self.kernel_size % 2 == 0:
            raise ValueError('remove_center is only compatible with odd kernel size.')

        # Depthwise branch: predicts offsets/masks from a local neighborhood.
        self.dw_conv = nn.Sequential(
            nn.Conv2d(
                channels,
                channels,
                kernel_size=dw_kernel_size,
                stride=1,
                padding=(dw_kernel_size - 1) // 2,
                groups=channels),
            build_norm_layer(
                channels,
                norm_layer,
                'channels_first',
                'channels_last'),
            build_act_layer(act_layer))
        # Per group: (K*K - remove_center) sampling points, 2 coords each.
        self.offset = nn.Linear(
            channels,
            group * (kernel_size * kernel_size - remove_center) * 2)
        self.mask = nn.Linear(
            channels,
            group * (kernel_size * kernel_size - remove_center))
        self.input_proj = nn.Linear(channels, channels)
        self.output_proj = nn.Linear(channels, channels)
        self._reset_parameters()

        if center_feature_scale:
            self.center_feature_scale_proj_weight = nn.Parameter(
                torch.zeros((group, channels), dtype=torch.float))
            self.center_feature_scale_proj_bias = nn.Parameter(
                torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, ))
            self.center_feature_scale_module = CenterFeatureScaleModule()

    def _reset_parameters(self):
        """Zero-init offset/mask heads (identity sampling at start);
        Xavier-init the input/output projections."""
        constant_(self.offset.weight.data, 0.)
        constant_(self.offset.bias.data, 0.)
        constant_(self.mask.weight.data, 0.)
        constant_(self.mask.bias.data, 0.)
        xavier_uniform_(self.input_proj.weight.data)
        constant_(self.input_proj.bias.data, 0.)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.)

    def forward(self, input):
        """
        :param input (N, H, W, C)
        :return output (N, H, W, C)
        """
        N, H, W, _ = input.shape

        x = self.input_proj(input)
        x_proj = x  # kept for optional center-feature-scale blending

        # Depthwise branch produces offsets and modulation masks.
        x1 = input.permute(0, 3, 1, 2)  # NHWC -> NCHW for the conv
        x1 = self.dw_conv(x1)  # norm layer converts back to NHWC
        offset = self.offset(x1)
        mask = self.mask(x1).reshape(N, H, W, self.group, -1)
        mask = F.softmax(mask, -1).reshape(N, H, W, -1)  # normalize per group

        x = dcnv3_core_pytorch(
            x, offset, mask,
            self.kernel_size, self.kernel_size,
            self.stride, self.stride,
            self.pad, self.pad,
            self.dilation, self.dilation,
            self.group, self.group_channels,
            self.offset_scale, self.remove_center)
        if self.center_feature_scale:
            center_feature_scale = self.center_feature_scale_module(
                x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
            # (N, H, W, groups) -> (N, H, W, channels): repeat per group width.
            center_feature_scale = center_feature_scale[..., None].repeat(
                1, 1, 1, 1, self.channels // self.group).flatten(-2)
            x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
        x = self.output_proj(x)

        return x
| |
|
| |
|
class DCNv3(nn.Module):
    """DCNv3 (deformable convolution v3) block backed by the custom CUDA
    kernel via ``DCNv3Function``; operates on channels-last (N, H, W, C)
    tensors."""

    def __init__(
            self,
            channels=64,
            kernel_size=3,
            dw_kernel_size=None,
            stride=1,
            pad=1,
            dilation=1,
            group=4,
            offset_scale=1.0,
            act_layer='GELU',
            norm_layer='LN',
            center_feature_scale=False,
            remove_center=False,
    ):
        """
        DCNv3 Module
        :param channels: embedding channels; must be divisible by `group`
        :param kernel_size: deformable sampling kernel size
        :param dw_kernel_size: depthwise conv kernel size (defaults to
            `kernel_size` when None)
        :param stride: sampling stride
        :param pad: sampling padding
        :param dilation: sampling dilation
        :param group: number of deformable groups
        :param offset_scale: multiplier applied to predicted offsets
        :param act_layer: activation name for the depthwise branch
        :param norm_layer: norm name for the depthwise branch
        :param center_feature_scale: enable learned blending between the
            deformable output and the plain input projection
        :param remove_center: drop the center sampling point (odd kernel only)
        :raises ValueError: if channels is not divisible by group, or
            remove_center is set with an even kernel size
        """
        super().__init__()
        if channels % group != 0:
            raise ValueError(
                f'channels must be divisible by group, but got {channels} and {group}')
        _d_per_group = channels // group
        dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size
        # A power-of-2 per-group width is faster in the CUDA kernel.
        if not _is_power_of_2(_d_per_group):
            warnings.warn(
                "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
                'which is more efficient in our CUDA implementation.')

        self.offset_scale = offset_scale
        self.channels = channels
        self.kernel_size = kernel_size
        self.dw_kernel_size = dw_kernel_size
        self.stride = stride
        self.dilation = dilation
        self.pad = pad
        self.group = group
        self.group_channels = channels // group
        self.center_feature_scale = center_feature_scale
        self.remove_center = int(remove_center)

        # Without a center point an even kernel has no valid layout.
        if self.remove_center and self.kernel_size % 2 == 0:
            raise ValueError('remove_center is only compatible with odd kernel size.')

        # Depthwise branch: predicts offsets/masks from a local neighborhood.
        self.dw_conv = nn.Sequential(
            nn.Conv2d(
                channels,
                channels,
                kernel_size=dw_kernel_size,
                stride=1,
                padding=(dw_kernel_size - 1) // 2,
                groups=channels),
            build_norm_layer(
                channels,
                norm_layer,
                'channels_first',
                'channels_last'),
            build_act_layer(act_layer))
        # Per group: (K*K - remove_center) sampling points, 2 coords each.
        self.offset = nn.Linear(
            channels,
            group * (kernel_size * kernel_size - remove_center) * 2)
        self.mask = nn.Linear(
            channels,
            group * (kernel_size * kernel_size - remove_center))
        self.input_proj = nn.Linear(channels, channels)
        self.output_proj = nn.Linear(channels, channels)
        self._reset_parameters()

        if center_feature_scale:
            self.center_feature_scale_proj_weight = nn.Parameter(
                torch.zeros((group, channels), dtype=torch.float))
            self.center_feature_scale_proj_bias = nn.Parameter(
                torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, ))
            self.center_feature_scale_module = CenterFeatureScaleModule()

    def _reset_parameters(self):
        """Zero-init offset/mask heads (identity sampling at start);
        Xavier-init the input/output projections."""
        constant_(self.offset.weight.data, 0.)
        constant_(self.offset.bias.data, 0.)
        constant_(self.mask.weight.data, 0.)
        constant_(self.mask.bias.data, 0.)
        xavier_uniform_(self.input_proj.weight.data)
        constant_(self.input_proj.bias.data, 0.)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.)

    def forward(self, input):
        """
        :param input (N, H, W, C)
        :return output (N, H, W, C)
        """
        N, H, W, _ = input.shape

        x = self.input_proj(input)
        x_proj = x  # kept for optional center-feature-scale blending
        dtype = x.dtype

        # Depthwise branch produces offsets and modulation masks.
        x1 = input.permute(0, 3, 1, 2)  # NHWC -> NCHW for the conv
        x1 = self.dw_conv(x1)  # norm layer converts back to NHWC
        offset = self.offset(x1)
        mask = self.mask(x1).reshape(N, H, W, self.group, -1)
        mask = F.softmax(mask, -1)
        # Keep the mask in the same dtype as x before handing to the kernel.
        mask = mask.reshape(N, H, W, -1).type(dtype)

        x = DCNv3Function.apply(
            x, offset, mask,
            self.kernel_size, self.kernel_size,
            self.stride, self.stride,
            self.pad, self.pad,
            self.dilation, self.dilation,
            self.group, self.group_channels,
            self.offset_scale,
            256,  # presumably the kernel's im2col step — confirm against DCNv3Function
            self.remove_center)

        if self.center_feature_scale:
            center_feature_scale = self.center_feature_scale_module(
                x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
            # (N, H, W, groups) -> (N, H, W, channels): repeat per group width.
            center_feature_scale = center_feature_scale[..., None].repeat(
                1, 1, 1, 1, self.channels // self.group).flatten(-2)
            x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
        x = self.output_proj(x)

        return x
| |
|