from transformers import PretrainedConfig


class MySegformerConfig(PretrainedConfig):
    model_type = "my_segformer"

    def __init__(
        self,
        in_channels=3,
        # Corrected for SegFormer-B5: list of embedding dimensions for each stage
        embed_dims=[64, 128, 320, 512],
        num_stages=4,
        # Corrected for SegFormer-B5: number of transformer layers in each stage
        num_layers=[3, 6, 40, 3],
        num_heads=[1, 2, 5, 8],
        patch_sizes=[7, 3, 3, 3],
        strides=[4, 2, 2, 2],
        sr_ratios=[8, 4, 2, 1],
        mlp_ratio=4,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
        out_indices=(0, 1, 2, 3),
        num_classes=19,
        decode_head=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.embed_dims = embed_dims  # A list, which is correct for SegFormer
        self.num_stages = num_stages
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.patch_sizes = patch_sizes
        self.strides = strides
        self.sr_ratios = sr_ratios
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.drop_path_rate = drop_path_rate
        self.out_indices = out_indices
        self.num_classes = num_classes

        # Optional default for the decode-head config (used if decode_head is not passed)
        if decode_head is None:
            decode_head = {
                # Corrected for SegFormer-B5: input channels for the decode head from each stage
                "in_channels": [64, 128, 320, 512],
                "in_index": list(range(self.num_stages)),
                "channels": 256,
                "dropout_ratio": 0.1,
                "out_channels": self.num_classes,
                "align_corners": False,
                "interpolate_mode": "bilinear",
            }
        self.decode_head = decode_head
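
# A minimal usage sketch (an addition, not part of the original snippet): register the
# custom config under its model_type so AutoConfig can resolve it, then round-trip it
# through save_pretrained/from_pretrained. The directory path here is illustrative.
from transformers import AutoConfig

AutoConfig.register("my_segformer", MySegformerConfig)

config = MySegformerConfig()  # instantiates with the B5 defaults defined above
config.save_pretrained("./my_segformer_b5")  # writes config.json to the directory

# After registration, AutoConfig resolves model_type "my_segformer" to our class
reloaded = AutoConfig.from_pretrained("./my_segformer_b5")
assert reloaded.embed_dims == [64, 128, 320, 512]
print(reloaded.model_type)  # "my_segformer"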