|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from ...utils import MODELS |
|
from ...utils import resize |
|
from ..base_module import BaseModule |
|
from ...utils.activation import ConvModule |
|
|
|
|
|
@MODELS.register_module() |
|
class SegformerHead(BaseModule): |
|
"""The all mlp Head of segformer. |
|
|
|
This head is the implementation of |
|
`Segformer <https://arxiv.org/abs/2105.15203>` _. |
|
|
|
Args: |
|
interpolate_mode: The interpolate mode of MLP head upsample operation. |
|
Default: 'bilinear'. |
|
use_conv_bias_in_convmodules (bool | str): If True, ConvModules will use bias. |
|
If False, they won't. If 'auto', they follow ConvModule's default. |
|
This is added for compatibility with models trained with no conv bias |
|
when followed by BatchNorm, while keeping default local behavior. |
|
""" |
|
|
|
def __init__(self, |
|
in_channels=[32, 64, 160, 256], |
|
in_index=[0, 1, 2, 3], |
|
channels=256, |
|
dropout_ratio=0.1, |
|
out_channels=19, |
|
norm_cfg=None, |
|
align_corners=False, |
|
interpolate_mode='bilinear', |
|
use_conv_bias_in_convmodules: bool | str = 'auto' |
|
): |
|
super().__init__() |
|
|
|
self.in_channels = in_channels |
|
self.in_index = in_index |
|
self.channels = channels |
|
self.dropout_ratio = dropout_ratio |
|
self.out_channels = out_channels |
|
self.norm_cfg = norm_cfg |
|
self.align_corners = align_corners |
|
self.interpolate_mode = interpolate_mode |
|
self.use_conv_bias_in_convmodules = use_conv_bias_in_convmodules |
|
|
|
self.act_cfg = dict(type='ReLU') |
|
self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1) |
|
if dropout_ratio > 0: |
|
self.dropout = nn.Dropout2d(dropout_ratio) |
|
else: |
|
self.dropout = None |
|
|
|
num_inputs = len(self.in_channels) |
|
|
|
conv_module_bias_setting = use_conv_bias_in_convmodules |
|
if use_conv_bias_in_convmodules == 'auto': |
|
pass |
|
elif isinstance(use_conv_bias_in_convmodules, bool): |
|
|
|
conv_module_bias_setting = use_conv_bias_in_convmodules |
|
else: |
|
raise ValueError("use_conv_bias_in_convmodules must be 'auto', True, or False") |
|
|
|
|
|
self.convs = nn.ModuleList() |
|
for i in range(num_inputs): |
|
self.convs.append( |
|
ConvModule( |
|
in_channels=self.in_channels[i], |
|
out_channels=self.channels, |
|
kernel_size=1, |
|
stride=1, |
|
norm_cfg=self.norm_cfg, |
|
act_cfg=self.act_cfg, |
|
bias=conv_module_bias_setting |
|
)) |
|
|
|
self.fusion_conv = ConvModule( |
|
in_channels=self.channels * num_inputs, |
|
out_channels=self.channels, |
|
kernel_size=1, |
|
norm_cfg=self.norm_cfg, |
|
bias=conv_module_bias_setting |
|
) |
|
|
|
def cls_seg(self, feat): |
|
"""Classify each pixel.""" |
|
if self.dropout is not None: |
|
feat = self.dropout(feat) |
|
output = self.conv_seg(feat) |
|
return output |
|
|
|
def forward(self, inputs): |
|
|
|
outs = [] |
|
for idx in range(len(inputs)): |
|
x = inputs[idx] |
|
conv = self.convs[idx] |
|
outs.append( |
|
resize( |
|
input=conv(x), |
|
size=inputs[0].shape[2:], |
|
mode=self.interpolate_mode, |
|
align_corners=self.align_corners)) |
|
|
|
out = self.fusion_conv(torch.cat(outs, dim=1)) |
|
|
|
out = self.cls_seg(out) |
|
|
|
return out |