# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn

from ...utils import MODELS
from ...utils import resize
from ..base_module import BaseModule
from ...utils.activation import ConvModule


@MODELS.register_module()
class SegformerHead(BaseModule):
    """The all-MLP head of SegFormer.

    This head is the implementation of
    `Segformer <https://arxiv.org/abs/2105.15203>`_.

    Every selected input feature map is projected to ``channels`` channels
    with a 1x1 ``ConvModule``, resized to the spatial size of the first
    selected feature map, concatenated along the channel dimension, fused by
    another 1x1 ``ConvModule`` and finally classified per pixel by a 1x1
    convolution.

    Args:
        in_channels: Number of channels of each selected input feature map.
            One projection conv is created per entry.
            Default: [32, 64, 160, 256].
        in_index: Indices into the ``inputs`` sequence given to
            :meth:`forward` that pick the feature maps to use. Must have the
            same length as ``in_channels``. Default: [0, 1, 2, 3].
        channels: Output channels of every projection conv and of the
            fusion conv. Default: 256.
        dropout_ratio: Probability passed to ``nn.Dropout2d``, applied right
            before the classifier. Dropout is disabled when the value is
            not greater than 0. Default: 0.1.
        out_channels: Number of output channels of the classifier conv
            (presumably the number of classes). Default: 19.
        norm_cfg: Normalization config forwarded to each ``ConvModule``.
            Default: None.
        align_corners: ``align_corners`` argument forwarded to ``resize``.
            Default: False.
        interpolate_mode: The interpolate mode of MLP head upsample
            operation. Default: 'bilinear'.
    """

    def __init__(self,
                 in_channels=[32, 64, 160, 256],
                 in_index=[0, 1, 2, 3],
                 channels=256,
                 dropout_ratio=0.1,
                 out_channels=19,
                 norm_cfg=None,
                 align_corners=False,
                 interpolate_mode='bilinear'):
        super().__init__()

        # Copy the sequences: the defaults are shared list objects, so
        # storing them by reference would let a mutation on one instance
        # leak into every other instance built with the defaults.
        self.in_channels = list(in_channels)
        self.in_index = list(in_index)
        self.channels = channels
        self.dropout_ratio = dropout_ratio
        self.out_channels = out_channels
        self.norm_cfg = norm_cfg
        self.align_corners = align_corners
        self.interpolate_mode = interpolate_mode
        self.act_cfg = dict(type='ReLU')

        # Per-pixel classifier applied on top of the fused feature map.
        self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None

        num_inputs = len(self.in_channels)
        assert num_inputs == len(self.in_index), (
            'in_channels and in_index must have the same length, got '
            f'{num_inputs} and {len(self.in_index)}')

        # One 1x1 projection per selected feature level.
        self.convs = nn.ModuleList(
            ConvModule(
                in_channels=level_channels,
                out_channels=self.channels,
                kernel_size=1,
                stride=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg) for level_channels in self.in_channels)

        # Fuses the channel-wise concatenation of all projected levels.
        self.fusion_conv = ConvModule(
            in_channels=self.channels * num_inputs,
            out_channels=self.channels,
            kernel_size=1,
            norm_cfg=self.norm_cfg)

    def cls_seg(self, feat):
        """Classify each pixel.

        Args:
            feat: Fused feature map with ``channels`` channels.

        Returns:
            Output of the 1x1 classifier conv, with dropout applied to
            ``feat`` first when it is enabled.
        """
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg(feat)
        return output

    def forward(self, inputs):
        """Fuse the selected feature maps and predict per-pixel logits.

        Args:
            inputs: Indexable sequence of feature maps. Typically the 4
                backbone stages at 1/4, 1/8, 1/16 and 1/32 resolution;
                entries are assumed to be 4-D tensors whose last two
                dimensions are spatial.

        Returns:
            The classifier output at the spatial size of the first selected
            feature map.
        """
        # Honour ``in_index`` instead of blindly iterating over every input:
        # this keeps each projection conv paired with the feature level it
        # was built for, even if ``inputs`` holds extra feature maps.
        selected = [inputs[i] for i in self.in_index]
        target_size = selected[0].shape[2:]

        outs = [
            resize(
                input=conv(x),
                size=target_size,
                mode=self.interpolate_mode,
                align_corners=self.align_corners)
            for x, conv in zip(selected, self.convs)
        ]

        out = self.fusion_conv(torch.cat(outs, dim=1))
        out = self.cls_seg(out)
        return out