from typing import Optional, Tuple

import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig

from segformer_plusplus.model import create_model


class SegformerPlusPlusConfig(PretrainedConfig):
    """Configuration for SegFormer++ semantic-segmentation models.

    Args:
        backbone: Backbone variant name forwarded to ``create_model``.
        tome_strategy: Token-merging strategy name forwarded to
            ``create_model``; may be ``None``.
        num_labels: Number of classes, used only to build a default
            ``id2label`` (``{i: "class_i"}``) when none is given. When
            ``id2label`` is supplied, the class count is ``len(id2label)``.
        id2label: Mapping from class index to class name. String keys (as
            produced by JSON round-tripping) are converted to ``int``.
        label2id: Mapping from class name to class index. Defaults to the
            inverse of ``id2label``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "segformer_plusplus"

    def __init__(
        self,
        backbone: str = "b5",
        tome_strategy: Optional[str] = "bsm_hq",
        num_labels: int = 19,
        id2label: Optional[dict] = None,
        label2id: Optional[dict] = None,
        **kwargs,
    ):
        self.backbone = backbone
        self.tome_strategy = tome_strategy

        if id2label is None:
            id2label = {i: f"class_{i}" for i in range(num_labels)}
        else:
            # JSON object keys are always strings; normalise to int ids so a
            # derived label2id maps names to ints rather than to "0", "1", ...
            id2label = {int(k): v for k, v in id2label.items()}
        if label2id is None:
            label2id = {v: k for k, v in id2label.items()}

        # The label maps must go through PretrainedConfig.__init__: it resets
        # self.id2label / self.label2id from its kwargs (falling back to a
        # 2-class map when absent), so assigning them on self beforehand would
        # be silently overwritten. num_labels is exposed by PretrainedConfig
        # as a property derived from len(id2label), so it is not stored here.
        super().__init__(id2label=id2label, label2id=label2id, **kwargs)


class SegformerPlusPlusForSemanticSegmentation(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a SegFormer++ network.

    The underlying network is built by ``segformer_plusplus.model.create_model``
    from the backbone, token-merging strategy and class count in ``config``,
    and stored as ``self.segformer``.
    """

    config_class = SegformerPlusPlusConfig

    # NOTE(review): no forward() is visible in this file; confirm one is
    # defined before using this class for inference/training.

    def __init__(self, config: SegformerPlusPlusConfig):
        super().__init__(config)
        # pretrained=False: create_model is not asked to load weights itself;
        # presumably weights arrive via from_pretrained — TODO confirm.
        self.segformer = create_model(
            backbone=config.backbone,
            tome_strategy=config.tome_strategy,
            out_channels=config.num_labels,
            pretrained=False,
        )