SegformerPlusPlus / segformer_plusplus /modeling_segformer_plusplus.py
Tim77777767
Anpassungen für HF-Upload
2483c4f
raw
history blame
433 Bytes
from transformers import PreTrainedModel
from .configuration_segformer_plusplus import SegformerPlusPlusConfig
from .build_model import SegFormer
class SegformerPlusPlusModel(PreTrainedModel):
    """Expose the ``SegFormer`` network through the ``PreTrainedModel`` API.

    The class is a thin wrapper: it builds a ``SegFormer`` instance from the
    supplied configuration, stores it as ``self.model``, and delegates the
    forward pass to it.
    """

    # Configuration class paired with this model.
    config_class = SegformerPlusPlusConfig

    def __init__(self, config: SegformerPlusPlusConfig):
        """Build the wrapped ``SegFormer`` network from *config*.

        Args:
            config: Configuration object; it is handed to both the
                ``PreTrainedModel`` base initializer and the ``SegFormer``
                constructor.
        """
        super().__init__(config)
        network = SegFormer(config)
        self.model = network

    def forward(self, x, **kwargs):
        """Run the wrapped network on *x* and return its output unchanged.

        Args:
            x: Input forwarded as the single positional argument to
                ``self.model``. Presumably a batch of image tensors --
                TODO confirm against ``SegFormer``.
            **kwargs: Accepted so callers may pass extra keyword arguments,
                but ignored; none of them reach the wrapped network.

        Returns:
            Whatever ``self.model(x)`` returns.
        """
        prediction = self.model(x)
        return prediction