File size: 3,142 Bytes
e98bd8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os

from .utils import MODELS, imagenet_weights
from .utils import tome_presets
from .model.base_module import BaseModule
from .configs.config.config import Config
from .utils.build_functions import build_model_from_cfg


class SegFormer(BaseModule):
    """
    SegFormer segmentation network that supports token merging in its backbone.

    The model is a two-stage pipeline: the input goes through the backbone, and
    whatever the backbone returns is handed to the decode head.

    Attributes:
        backbone (BaseModule): MixVisionTransformer backbone, built from ``cfg.backbone``
        decode_head (BaseModule): SegFormer head, built from ``cfg.decode_head``

    """

    def __init__(self, cfg):
        """
        Build the backbone and the decode head from a config.

        Args:
            cfg (Config): an mmengine Config object whose ``backbone`` and ``decode_head`` entries
                describe the two submodules (including the token merging strategy used).

        """
        super().__init__()
        # Both submodules are looked up in the shared MODELS registry.
        self.backbone = build_model_from_cfg(
            cfg.backbone,
            registry=MODELS,
        )
        self.decode_head = build_model_from_cfg(
            cfg.decode_head,
            registry=MODELS,
        )

    def forward(self, x):
        """
        Run the backbone, then the decode head.

        Args:
            x (torch.Tensor): input tensor of shape [B, C, H, W]

        Returns:
            torch.Tensor: output of the decode head

        """
        features = self.backbone(x)
        return self.decode_head(features)


def create_model(
        backbone: str = 'b0',
        tome_strategy: str = None,
        out_channels: int = 19,
        pretrained: bool = False,
):
    """
    Create a SegFormer model using the predefined SegFormer backbones from the MiT series (b0-b5).

    Args:
        backbone (str): backbone name, 'b0' to 'b5' (case-insensitive)
        tome_strategy (str | list(dict) | None): either the name of a preset in ``tome_presets``
            (e.g. 'bsm_hq', 'bsm_fast', 'n2d_2x2') or a custom strategy given as a list of
            dictionaries, each defining the merging strategy for one stage. When ``None``, the
            backbone's ``tome_cfg`` from the config file is left unchanged.
        out_channels (int): number of output channels (e.g. 19 for the cityscapes semantic segmentation task)
        pretrained (bool): if True, initialize the backbone from the ImageNet checkpoint
            listed in ``imagenet_weights`` for this backbone

    Returns:
        BaseModule: SegFormer model

    Raises:
        ValueError: if ``backbone`` is not one of 'b0'-'b5', or if ``tome_strategy`` is a string
            that does not name a known preset.

    """
    backbone = backbone.lower()
    valid_backbones = [f'b{i}' for i in range(6)]
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if backbone not in valid_backbones:
        raise ValueError(f"Unknown backbone '{backbone}', expected one of {valid_backbones}.")

    wd = os.path.dirname(os.path.abspath(__file__))

    cfg = Config.fromfile(os.path.join(wd, 'configs', f'segformer_mit_{backbone}.py'))

    cfg.decode_head.out_channels = out_channels

    if tome_strategy is not None:
        if isinstance(tome_strategy, str):
            # A string can only refer to a preset; reject typos with a helpful message.
            if tome_strategy not in tome_presets:
                raise ValueError(
                    f"Unknown token merging preset '{tome_strategy}', "
                    f"expected one of {list(tome_presets)} or a custom list of dicts."
                )
            cfg.backbone.tome_cfg = tome_presets[tome_strategy]
        else:
            # Custom strategy: use it as-is. Indexing tome_presets with it (as before)
            # raised KeyError/TypeError and made custom strategies unusable.
            print("Using custom merging strategy.")
            cfg.backbone.tome_cfg = tome_strategy

    # load imagenet weights
    if pretrained:
        cfg.backbone.init_cfg = dict(type='Pretrained', checkpoint=imagenet_weights[backbone])

    return SegFormer(cfg)


def create_custom_model(
        model_cfg: Config,
        tome_strategy: list[dict] = None,
):
    """
    Create a SegFormer model with customizable backbone and head.

    Args:
        model_cfg (Config): an mmengine Config object with ``backbone`` and ``decode_head``
            entries describing the two submodules of the model
        tome_strategy (list(dict) | None): custom token merging strategy. If given, it is written
            to ``model_cfg.backbone.tome_cfg``; note that this modifies ``model_cfg`` in place.
            If ``None``, the strategy already present in ``model_cfg`` (if any) is used.

    Returns:
        BaseModule: SegFormer model

    """
    if tome_strategy is not None:
        # Mutates the caller's config object.
        model_cfg.backbone.tome_cfg = tome_strategy

    return SegFormer(model_cfg)