import logging

import torch
import torch.nn as nn
from transformers import PreTrainedModel

from mix_vision_transformer_config import MySegformerConfig
from segformer_plusplus.model.backbone.mit import MixVisionTransformer
from segformer_plusplus.model.head.segformer_head import SegformerHead

logger = logging.getLogger(__name__)


class MySegformerForSemanticSegmentation(PreTrainedModel):
    """Semantic segmentation model combining a MixVisionTransformer backbone
    with a SegFormer decode head.

    Wraps the SegFormer++ components in a Hugging Face ``PreTrainedModel`` so
    the model can be saved/loaded with the ``transformers`` API.
    """

    config_class = MySegformerConfig
    base_model_prefix = "my_segformer"

    def __init__(self, config):
        """Build backbone and decode head from ``config``.

        Args:
            config: A ``MySegformerConfig`` carrying backbone hyperparameters
                (embed dims, stage/layer counts, attention settings) and a
                ``decode_head`` dict with the head's construction arguments.
        """
        super().__init__(config)

        # Backbone: hierarchical MixVisionTransformer encoder. Only the first
        # element of ``embed_dims`` is passed; the backbone derives per-stage
        # widths internally from it.
        self.backbone = MixVisionTransformer(
            embed_dims=config.embed_dims[0],
            num_stages=config.num_stages,
            num_layers=config.num_layers,
            num_heads=config.num_heads,
            patch_sizes=config.patch_sizes,
            strides=config.strides,
            sr_ratios=config.sr_ratios,
            mlp_ratio=config.mlp_ratio,
            qkv_bias=config.qkv_bias,
            drop_rate=config.drop_rate,
            attn_drop_rate=config.attn_drop_rate,
            drop_path_rate=config.drop_path_rate,
            out_indices=config.out_indices,
        )

        # Decode head: fuses the multi-scale backbone features into logits.
        # ``in_index`` mirrors ``out_indices`` so the head consumes exactly
        # the feature maps the backbone emits.
        in_channels = config.decode_head["in_channels"]
        self.segmentation_head = SegformerHead(
            in_channels=in_channels,
            in_index=list(config.out_indices),
            channels=config.decode_head["channels"],
            dropout_ratio=config.decode_head["dropout_ratio"],
            align_corners=config.decode_head["align_corners"],
            interpolate_mode=config.decode_head["interpolate_mode"],
            use_conv_bias_in_convmodules=False,
        )

        # Hugging Face weight init / final setup hook.
        self.post_init()

    def forward(self, x):
        """Run a forward pass.

        Args:
            x: Input image batch tensor (shape as expected by the backbone,
                presumably NCHW — confirm against the backbone's contract).

        Returns:
            Segmentation logits produced by the decode head.
        """
        features = self.backbone(x)

        # Shape diagnostics at DEBUG level only (previously unconditional
        # ``print`` calls that polluted stdout on every forward pass).
        if logger.isEnabledFor(logging.DEBUG):
            for i, f in enumerate(features):
                logger.debug("Feature %d: shape = %s", i, f.shape)

        logits = self.segmentation_head(features)
        return logits