|
import torch |
|
import torch.nn as nn |
|
from transformers import PreTrainedModel |
|
|
|
from segformer_plusplus.model.backbone.mit import MixVisionTransformer |
|
from mix_vision_transformer_config import MySegformerConfig |
|
from segformer_plusplus.model.head.segformer_head import SegformerHead |
|
|
|
|
|
class MySegformerForSemanticSegmentation(PreTrainedModel):
    """SegFormer-style semantic segmentation model for the HuggingFace API.

    Wraps a MixVisionTransformer (MiT) backbone and a SegFormer decode head
    behind the ``PreTrainedModel`` interface so the model can be saved/loaded
    with ``save_pretrained`` / ``from_pretrained``.

    Note: the returned logits are at the backbone's output resolution
    (typically 1/4 of the input); upsampling to the input size is left to the
    caller — TODO confirm against the head implementation.
    """

    config_class = MySegformerConfig
    base_model_prefix = "my_segformer"

    def __init__(self, config):
        """Build backbone and decode head from ``config``.

        Args:
            config: A ``MySegformerConfig`` carrying the MiT backbone
                hyperparameters and a ``decode_head`` dict with the head
                settings (in_channels, channels, dropout_ratio,
                align_corners, interpolate_mode).
        """
        super().__init__(config)

        # Hierarchical MiT encoder; emits one feature map per stage listed
        # in ``config.out_indices``.
        self.backbone = MixVisionTransformer(
            embed_dims=config.embed_dims[0],
            num_stages=config.num_stages,
            num_layers=config.num_layers,
            num_heads=config.num_heads,
            patch_sizes=config.patch_sizes,
            strides=config.strides,
            sr_ratios=config.sr_ratios,
            mlp_ratio=config.mlp_ratio,
            qkv_bias=config.qkv_bias,
            drop_rate=config.drop_rate,
            attn_drop_rate=config.attn_drop_rate,
            drop_path_rate=config.drop_path_rate,
            out_indices=config.out_indices,
        )

        in_channels = config.decode_head["in_channels"]

        # All-MLP SegFormer head; ``in_index`` must match the backbone's
        # ``out_indices`` so each head branch receives the right stage.
        self.segmentation_head = SegformerHead(
            in_channels=in_channels,
            in_index=list(config.out_indices),
            channels=config.decode_head["channels"],
            dropout_ratio=config.decode_head["dropout_ratio"],
            align_corners=config.decode_head["align_corners"],
            interpolate_mode=config.decode_head["interpolate_mode"],
            use_conv_bias_in_convmodules=False,
        )

        # HuggingFace hook: initializes weights and applies any final setup.
        self.post_init()

    def forward(self, x):
        """Run a forward pass.

        Args:
            x: Input image batch — presumably (N, 3, H, W); confirm against
                the backbone's patch embedding.

        Returns:
            Segmentation logits produced by the decode head.
        """
        # Multi-scale feature maps, one per selected backbone stage.
        features = self.backbone(x)
        # Fuse features and project to per-class logits.
        # (Removed leftover debug prints that logged every feature shape on
        # each forward pass.)
        logits = self.segmentation_head(features)
        return logits