File size: 4,007 Bytes
e98bd8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
827f017
 
 
 
e98bd8c
 
 
 
 
 
 
 
 
 
827f017
 
 
e98bd8c
 
 
 
 
 
 
 
 
 
827f017
e98bd8c
 
 
 
 
 
 
 
 
 
827f017
 
 
 
 
 
 
 
 
3899963
e98bd8c
 
 
 
 
 
 
 
 
827f017
 
 
e98bd8c
 
 
 
 
827f017
 
 
e98bd8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
827f017
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn

from ...utils import MODELS
from ...utils import resize
from ..base_module import BaseModule
from ...utils.activation import ConvModule


@MODELS.register_module()
class SegformerHead(BaseModule):
    """The all-MLP decode head of SegFormer.

    This head is the implementation of
    `Segformer <https://arxiv.org/abs/2105.15203>`_.

    Each input feature map is projected to ``channels`` by a 1x1
    ``ConvModule``, resized to the spatial size of the first input, the
    projections are concatenated along dim 1, fused by another 1x1
    ``ConvModule`` and finally classified per pixel by a 1x1 ``nn.Conv2d``.

    Args:
        in_channels (list[int]): Channel count of each input feature map.
            One projection ConvModule is built per entry, so ``forward``
            expects exactly ``len(in_channels)`` inputs.
            Default: [32, 64, 160, 256].
        in_index (list[int] | int): Indices of the backbone outputs this
            head is configured for. It is stored on the instance but NOT
            used by ``forward``, which consumes ``inputs`` positionally.
            Default: [0, 1, 2, 3].
        channels (int): Width of every projection / fusion ConvModule and
            the input width of the classifier. Default: 256.
        dropout_ratio (float): ``nn.Dropout2d`` probability applied before
            the classifier. No dropout layer is created when <= 0.
            Default: 0.1.
        out_channels (int): Output channels of the final 1x1 classifier
            (presumably the number of classes). Default: 19.
        norm_cfg (dict | None): Norm config forwarded to every ConvModule.
            Default: None.
        align_corners (bool): ``align_corners`` passed to ``resize``.
            Default: False.
        interpolate_mode (str): The interpolate mode of MLP head upsample
            operation. Default: 'bilinear'.
        use_conv_bias_in_convmodules (bool | str): If True, ConvModules
            will use bias. If False, they won't. If 'auto', they follow
            ConvModule's default (the value 'auto' is passed through). This
            is added for compatibility with models trained with no conv
            bias when followed by BatchNorm, while keeping default local
            behavior. Default: 'auto'.

    Raises:
        ValueError: If ``use_conv_bias_in_convmodules`` is not 'auto',
            True or False.
    """

    def __init__(self,
                 in_channels=[32, 64, 160, 256],
                 in_index=[0, 1, 2, 3],
                 channels=256,
                 dropout_ratio=0.1,
                 out_channels=19,
                 norm_cfg=None,
                 align_corners=False,
                 interpolate_mode='bilinear',
                 use_conv_bias_in_convmodules: 'bool | str' = 'auto'):
        # The annotation above is a string on purpose: a bare ``bool | str``
        # is evaluated at class-definition time and raises TypeError on
        # Python < 3.10.
        super().__init__()

        # Validate before allocating any submodule so bad configs fail fast.
        # The bool check comes first so True/False short-circuit.
        if not (isinstance(use_conv_bias_in_convmodules, bool)
                or use_conv_bias_in_convmodules == 'auto'):
            raise ValueError("use_conv_bias_in_convmodules must be 'auto', True, or False")

        # Copy list arguments so instances never share the mutable default
        # lists from the signature (or a list the caller mutates later).
        # Non-list values (tuples, ints) are stored unchanged.
        self.in_channels = (list(in_channels)
                            if isinstance(in_channels, list) else in_channels)
        self.in_index = (list(in_index)
                         if isinstance(in_index, list) else in_index)
        self.channels = channels
        self.dropout_ratio = dropout_ratio
        self.out_channels = out_channels
        self.norm_cfg = norm_cfg
        self.align_corners = align_corners
        self.interpolate_mode = interpolate_mode
        # Stored as given ('auto', True or False) and forwarded verbatim as
        # the ``bias`` argument of every ConvModule below.
        self.use_conv_bias_in_convmodules = use_conv_bias_in_convmodules

        self.act_cfg = dict(type='ReLU')

        # NOTE: submodule creation order (conv_seg, dropout, convs,
        # fusion_conv) is kept as-is so parameter initialisation consumes
        # the RNG in the same order as before.
        self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
        self.dropout = (nn.Dropout2d(dropout_ratio)
                        if dropout_ratio > 0 else None)

        num_inputs = len(self.in_channels)

        # One 1x1 projection per input feature map, each to ``channels``.
        self.convs = nn.ModuleList([
            ConvModule(
                in_channels=in_ch,
                out_channels=self.channels,
                kernel_size=1,
                stride=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                bias=use_conv_bias_in_convmodules)
            for in_ch in self.in_channels
        ])

        # Fuses the concatenated projections back down to ``channels``.
        # act_cfg is not passed, so ConvModule's own default applies.
        self.fusion_conv = ConvModule(
            in_channels=self.channels * num_inputs,
            out_channels=self.channels,
            kernel_size=1,
            norm_cfg=self.norm_cfg,
            bias=use_conv_bias_in_convmodules)

    def cls_seg(self, feat):
        """Classify each pixel.

        Applies dropout (when configured) and then the 1x1 ``conv_seg``
        classifier to ``feat``.
        """
        if self.dropout is not None:
            feat = self.dropout(feat)
        return self.conv_seg(feat)

    def forward(self, inputs):
        """Run the head on a sequence of feature maps.

        Args:
            inputs (Sequence[Tensor]): One feature map per ``in_channels``
                entry, consumed positionally. Assumed to be NCHW (the
                spatial size is read from ``inputs[0].shape[2:]``); per the
                original comment, presumably backbone stages at 1/4, 1/8,
                1/16 and 1/32 resolution.

        Returns:
            Tensor: Output of ``cls_seg`` at the spatial size of
            ``inputs[0]``.

        Raises:
            ValueError: If the number of inputs does not match the number
                of projection convs (``len(in_channels)``).
        """
        # Without this check, extra inputs raise IndexError and missing
        # ones surface later as a channel mismatch in fusion_conv.
        if len(inputs) != len(self.convs):
            raise ValueError(
                f'SegformerHead expects {len(self.convs)} input feature '
                f'maps (one per in_channels entry), got {len(inputs)}')

        # Every branch is resized to the spatial size of the first input.
        target_size = inputs[0].shape[2:]
        outs = [
            resize(
                input=conv(x),
                size=target_size,
                mode=self.interpolate_mode,
                align_corners=self.align_corners)
            for x, conv in zip(inputs, self.convs)
        ]

        out = self.fusion_conv(torch.cat(outs, dim=1))
        return self.cls_seg(out)