File size: 433 Bytes
2483c4f 85ebba9 2483c4f 85ebba9 2483c4f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
from transformers import PreTrainedModel
from .configuration_segformer_plusplus import SegformerPlusPlusConfig
from .build_model import SegFormer
class SegformerPlusPlusModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the project's ``SegFormer`` network.

    The actual network is built by ``SegFormer(config)`` and stored as the
    ``model`` submodule. This class only adapts it to the ``PreTrainedModel``
    interface.
    """

    # Ties this model to its configuration class for the transformers
    # ``PreTrainedModel`` machinery (e.g. ``from_pretrained``).
    config_class = SegformerPlusPlusConfig

    def __init__(self, config: SegformerPlusPlusConfig):
        """Register *config* with the base class, then build the wrapped network.

        Args:
            config: Model configuration. It is handed unchanged to ``SegFormer``.
        """
        super().__init__(config)
        segformer = SegFormer(config)
        # Assigned to the ``model`` attribute, which is the name callers and
        # saved checkpoints refer to.
        self.model = segformer

    def forward(self, x, **kwargs):
        """Run the wrapped ``SegFormer`` network on *x* and return its output unchanged.

        Args:
            x: Network input. Presumably an image batch tensor; confirm the
                expected shape against ``SegFormer.forward``.
            **kwargs: Accepted so callers may pass extra keyword arguments,
                but ignored. Only *x* is forwarded.

        Returns:
            Whatever ``self.model(x)`` returns.
        """
        segformer = self.model
        output = segformer(x)
        return output
|