File size: 1,971 Bytes
02508fb
 
66c5431
02508fb
d21bb56
 
 
02508fb
 
 
 
 
 
 
 
 
66c5431
02508fb
d21bb56
02508fb
 
 
 
 
 
 
 
 
 
 
 
 
 
d21bb56
eacc0a7
1a260cd
02508fb
d21bb56
 
 
eacc0a7
d21bb56
827f017
 
02508fb
 
 
e4634c2
 
66c5431
e4634c2
eacc0a7
9659a3a
 
 
66c5431
e4634c2
eacc0a7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import logging

import torch
import torch.nn as nn
from transformers import PreTrainedModel

from segformer_plusplus.model.backbone.mit import MixVisionTransformer
from mix_vision_transformer_config import MySegformerConfig
from segformer_plusplus.model.head.segformer_head import SegformerHead


class MySegformerForSemanticSegmentation(PreTrainedModel):
    """Semantic-segmentation model: MixVisionTransformer backbone + SegformerHead.

    Both components are built from a ``MySegformerConfig`` and wrapped in a
    Hugging Face ``PreTrainedModel``.

    NOTE: ``forward`` returns the raw output of ``SegformerHead`` directly
    (no loss computation, no ``ModelOutput`` wrapper).
    """

    config_class = MySegformerConfig
    base_model_prefix = "my_segformer"

    def __init__(self, config):
        """Build the backbone and decode head from ``config``.

        Args:
            config: A ``MySegformerConfig``. Fields read here: ``embed_dims``,
                ``num_stages``, ``num_layers``, ``num_heads``, ``patch_sizes``,
                ``strides``, ``sr_ratios``, ``mlp_ratio``, ``qkv_bias``,
                ``drop_rate``, ``attn_drop_rate``, ``drop_path_rate``,
                ``out_indices``, and the ``decode_head`` mapping (keys
                ``in_channels``, ``channels``, ``dropout_ratio``,
                ``align_corners``, ``interpolate_mode``).
        """
        super().__init__(config)

        # Backbone (MixVisionTransformer). Only the first entry of
        # ``config.embed_dims`` is passed; presumably the backbone derives the
        # later-stage widths from this base dim — TODO confirm.
        self.backbone = MixVisionTransformer(
            embed_dims=config.embed_dims[0],
            num_stages=config.num_stages,
            num_layers=config.num_layers,
            num_heads=config.num_heads,
            patch_sizes=config.patch_sizes,
            strides=config.strides,
            sr_ratios=config.sr_ratios,
            mlp_ratio=config.mlp_ratio,
            qkv_bias=config.qkv_bias,
            drop_rate=config.drop_rate,
            attn_drop_rate=config.attn_drop_rate,
            drop_path_rate=config.drop_path_rate,
            out_indices=config.out_indices
        )

        # Decode head. It consumes the same feature indices the backbone is
        # asked to output (``out_indices``).
        in_channels = config.decode_head["in_channels"]

        self.segmentation_head = SegformerHead(
            in_channels=in_channels,
            in_index=list(config.out_indices),
            channels=config.decode_head["channels"],
            dropout_ratio=config.decode_head["dropout_ratio"],
            align_corners=config.decode_head["align_corners"],
            interpolate_mode=config.decode_head["interpolate_mode"],
            use_conv_bias_in_convmodules=False
        )

        # Hugging Face post-construction hook (e.g. weight initialisation).
        self.post_init()

    def forward(self, x):
        """Run the backbone and the decode head on ``x``.

        Args:
            x: Input passed straight to the backbone; presumably a
                ``(batch, channels, height, width)`` image tensor — TODO confirm.

        Returns:
            The output of ``SegformerHead`` for the backbone features
            (the segmentation logits).
        """
        features = self.backbone(x)

        # Debug aid: report the shape of each backbone feature. This goes
        # through logging at DEBUG level instead of print() so normal training
        # and inference produce no output; the guard skips the loop entirely
        # when DEBUG is disabled.
        logger = logging.getLogger(__name__)
        if logger.isEnabledFor(logging.DEBUG):
            for i, feature in enumerate(features):
                logger.debug("Feature %d: shape = %s", i, feature.shape)

        logits = self.segmentation_head(features)

        return logits