import os from .utils import MODELS, imagenet_weights from .utils import tome_presets from .model.base_module import BaseModule from .configs.config.config import Config from .utils.build_functions import build_model_from_cfg class SegFormer(BaseModule): """ This class represents a SegFormer model that allows for the application of token merging. Attributes: backbone (BaseModule): MixVisionTransformer backbone decode_head (BaseModule): SegFormer head """ def __init__(self, cfg): """ Initialize the SegFormer model. Args: cfg (Config): an mmengine Config object, which defines the backbone, head and token merging strategy used. """ super().__init__() self.backbone = build_model_from_cfg(cfg.backbone, registry=MODELS) self.decode_head = build_model_from_cfg(cfg.decode_head, registry=MODELS) def forward(self, x): """ Forward pass of the model. Args: x (torch.Tensor): input tensor of shape [B, C, H, W] Returns: torch.Tensor: output tensor """ x = self.backbone(x) x = self.decode_head(x) return x def create_model( backbone: str = 'b0', tome_strategy: str = None, out_channels: int = 19, pretrained: bool = False, ): """ Create a SegFormer model using the predefined SegFormer backbones from the MiT series (b0-b5). Args: backbone (str): backbone name (e.g. 'b0') tome_strategy (str | list(dict)): select strategy from presets ('bsm_hq', 'bsm_fast', 'n2d_2x2') or define a custom strategy using a list, that contains of dictionaries, in which the strategies for the stage are defined out_channels (int): number of output channels (e.g. 19 for the cityscapes semantic segmentation task) pretrained: use pretrained (imagenet) weights Returns: BaseModule: SegFormer model """ backbone = backbone.lower() assert backbone in [f'b{i}' for i in range(6)] wd = os.path.dirname(os.path.abspath(__file__)) cfg = Config.fromfile(os.path.join(wd, 'configs', f'segformer_mit_{backbone}.py')) cfg.decode_head.out_channels = out_channels if tome_strategy is not None: if tome_strategy not in list(tome_presets.keys()): print("Using custom merging strategy.") cfg.backbone.tome_cfg = tome_presets[tome_strategy] # load imagenet weights if pretrained: cfg.backbone.init_cfg = dict(type='Pretrained', checkpoint=imagenet_weights[backbone]) return SegFormer(cfg) def create_custom_model( model_cfg: Config, tome_strategy: list[dict] = None, ): """ Create a SegFormer model with customizable backbone and head. Args: model_cfg (Config): backbone name (e.g. 'b0') tome_strategy (list(dict)): custom token merging strategy Returns: BaseModule: SegFormer model """ if tome_strategy is not None: model_cfg.backbone.tome_cfg = tome_strategy return SegFormer(model_cfg)