# modeling_segformer_plusplus.py
"""Hugging Face ``transformers`` wrapper around the SegFormer++ model.

Exposes a :class:`PretrainedConfig` subclass and a
:class:`PreTrainedModel` subclass so SegFormer++ checkpoints can be
loaded/saved with ``from_pretrained`` / ``save_pretrained``.
"""

from typing import Optional

import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import SemanticSegmenterOutput

# If you want to use SegFormer++ directly, make sure this package is
# available in the same repo.
from segformer_plusplus.model import create_model


class SegformerPlusPlusConfig(PretrainedConfig):
    """Configuration for SegFormer++ semantic segmentation.

    Args:
        backbone: MiT backbone variant identifier (e.g. ``"b0"`` .. ``"b5"``).
        tome_strategy: Token-merging strategy name, or ``None`` to disable.
        num_labels: Number of segmentation classes.
        id2label: Optional mapping from class index to class name; defaults
            to ``{i: f"class_{i}"}``.
        label2id: Optional inverse mapping; derived from ``id2label`` when
            not given.
    """

    model_type = "segformer_plusplus"

    def __init__(
        self,
        backbone: str = "b5",
        tome_strategy: Optional[str] = "bsm_hq",
        num_labels: int = 19,
        id2label: Optional[dict] = None,
        label2id: Optional[dict] = None,
        **kwargs,
    ):
        self.backbone = backbone
        self.tome_strategy = tome_strategy
        if id2label is None:
            id2label = {i: f"class_{i}" for i in range(num_labels)}
        if label2id is None:
            label2id = {v: k for k, v in id2label.items()}
        # Route label metadata through PretrainedConfig instead of assigning
        # it before calling super(): PretrainedConfig.__init__ pops
        # id2label/label2id/num_labels from kwargs and (re)sets them, so a
        # trailing super().__init__(**kwargs) would clobber direct
        # assignments. Passing them here also gets the canonical handling
        # (int-key normalization, num_labels consistency check).
        super().__init__(
            num_labels=num_labels,
            id2label=id2label,
            label2id=label2id,
            **kwargs,
        )


class SegformerPlusPlusForSemanticSegmentation(PreTrainedModel):
    """SegFormer++ wrapped as a HF model producing per-pixel class logits."""

    config_class = SegformerPlusPlusConfig

    def __init__(self, config: SegformerPlusPlusConfig):
        super().__init__(config)
        self.segformer = create_model(
            backbone=config.backbone,
            tome_strategy=config.tome_strategy,
            out_channels=config.num_labels,
            # No pretrained weights here -- they are loaded via
            # .from_pretrained() instead.
            pretrained=False,
        )
        # Standard HF hook: weight init / backward-compat setup after all
        # submodules exist (no-op _init_weights by default, so safe).
        self.post_init()

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
    ) -> SemanticSegmenterOutput:
        """Run segmentation; compute cross-entropy loss when labels given.

        Args:
            pixel_values: Input images, shape ``(batch, channels, H, W)``.
            labels: Optional ground-truth index maps, shape ``(batch, h, w)``;
                pixels valued 255 are ignored in the loss.

        Returns:
            SemanticSegmenterOutput with ``logits`` (and ``loss`` if
            ``labels`` was provided).
        """
        logits = self.segformer(pixel_values)

        loss = None
        if labels is not None:
            # SegFormer heads typically emit logits at a reduced spatial
            # resolution; upsample to the label size before the loss, as
            # HF's SegformerForSemanticSegmentation does. Without this the
            # loss raises on the H/W mismatch.
            if logits.shape[-2:] != labels.shape[-2:]:
                upsampled_logits = nn.functional.interpolate(
                    logits,
                    size=labels.shape[-2:],
                    mode="bilinear",
                    align_corners=False,
                )
            else:
                upsampled_logits = logits
            loss_fct = nn.CrossEntropyLoss(ignore_index=255)
            loss = loss_fct(upsampled_logits, labels.long())

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
        )