File size: 433 Bytes
2483c4f
 
 
85ebba9
2483c4f
85ebba9
 
 
 
2483c4f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from transformers import PreTrainedModel
from .configuration_segformer_plusplus import SegformerPlusPlusConfig
from .build_model import SegFormer

class SegformerPlusPlusModel(PreTrainedModel):
    """``transformers.PreTrainedModel`` wrapper around the project's ``SegFormer`` network.

    ``config_class`` binds this model to ``SegformerPlusPlusConfig``. The
    actual network is built from that config and stored as ``self.model``.
    ``forward`` delegates to it unchanged.
    """

    config_class = SegformerPlusPlusConfig

    def __init__(self, config: SegformerPlusPlusConfig):
        """Build the wrapped ``SegFormer`` from ``config``.

        Args:
            config: Configuration passed both to ``PreTrainedModel`` and to the
                ``SegFormer`` constructor.
        """
        super().__init__(config)
        # NOTE: assuming SegFormer is an nn.Module, the attribute name "model"
        # becomes the prefix of its parameter names (state-dict keys), so it
        # must not be renamed without breaking existing checkpoints.
        self.model = SegFormer(config)

    def forward(self, x=None, pixel_values=None, **kwargs):
        """Run the wrapped ``SegFormer`` on one input batch.

        Args:
            x: Input handed straight to ``SegFormer``. Presumably an image
                tensor shaped (batch, channels, height, width); TODO confirm
                against ``build_model.SegFormer``.
            pixel_values: Keyword alias for ``x``. It lets the model be called
                as ``model(**inputs)`` with the ``pixel_values`` key that
                Hugging Face image processors emit. It is used only when ``x``
                is not given, so existing positional/``x=`` callers behave
                exactly as before.
            **kwargs: Accepted for call-site compatibility and ignored; they
                are NOT forwarded to ``SegFormer``.

        Returns:
            Whatever ``self.model`` returns. Its structure is defined by
            ``SegFormer``.

        Raises:
            TypeError: If neither ``x`` nor ``pixel_values`` is provided. This
                mirrors the TypeError Python raised before ``x`` had a default.
        """
        # Identity checks rather than truthiness: the truth value of a
        # multi-element tensor is ambiguous and would raise.
        inputs = x if x is not None else pixel_values
        if inputs is None:
            raise TypeError(
                "forward() missing required input: pass it positionally as "
                "`x` or by keyword as `pixel_values`"
            )
        return self.model(inputs)