Tim77777767
commited on
Commit
·
2483c4f
1
Parent(s):
10e4799
Anpassungen für HF-Upload
Browse files
segformer_plusplus/configuration_segformer_plusplus.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
|
3 |
+
class SegformerPlusPlusConfig(PretrainedConfig):
|
4 |
+
model_type = "segformer_plus_plus"
|
5 |
+
|
6 |
+
def __init__(self, backbone_cfg=None, decode_head_cfg=None, out_channels=19, **kwargs):
|
7 |
+
super().__init__(**kwargs)
|
8 |
+
self.backbone_cfg = backbone_cfg
|
9 |
+
self.decode_head_cfg = decode_head_cfg
|
10 |
+
self.out_channels = out_channels
|
segformer_plusplus/modeling_segformer_plusplus.py
CHANGED
@@ -1,44 +1,13 @@
|
|
1 |
-
from
|
2 |
-
|
3 |
-
from
|
4 |
-
from segformer_plusplus.model import create_model
|
5 |
|
6 |
-
|
7 |
-
class SegformerPlusPlusConfig(PretrainedConfig):
|
8 |
-
model_type = "segformer_plusplus"
|
9 |
-
|
10 |
-
def __init__(
|
11 |
-
self,
|
12 |
-
backbone: str = "b5",
|
13 |
-
tome_strategy: Optional[str] = "bsm_hq",
|
14 |
-
num_labels: int = 19,
|
15 |
-
id2label: Optional[dict] = None,
|
16 |
-
label2id: Optional[dict] = None,
|
17 |
-
**kwargs,
|
18 |
-
):
|
19 |
-
self.backbone = backbone
|
20 |
-
self.tome_strategy = tome_strategy
|
21 |
-
self.num_labels = num_labels
|
22 |
-
|
23 |
-
if id2label is None:
|
24 |
-
id2label = {i: f"class_{i}" for i in range(num_labels)}
|
25 |
-
if label2id is None:
|
26 |
-
label2id = {v: k for k, v in id2label.items()}
|
27 |
-
|
28 |
-
self.id2label = id2label
|
29 |
-
self.label2id = label2id
|
30 |
-
|
31 |
-
super().__init__(**kwargs)
|
32 |
-
|
33 |
-
|
34 |
-
class SegformerPlusPlusForSemanticSegmentation(PreTrainedModel):
|
35 |
config_class = SegformerPlusPlusConfig
|
36 |
|
37 |
def __init__(self, config: SegformerPlusPlusConfig):
|
38 |
super().__init__(config)
|
39 |
-
self.
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
pretrained=False,
|
44 |
-
)
|
|
|
1 |
+
from transformers import PreTrainedModel
|
2 |
+
from .configuration_segformer_plusplus import SegformerPlusPlusConfig
|
3 |
+
from .build_model import SegFormer
|
|
|
4 |
|
5 |
+
class SegformerPlusPlusModel(PreTrainedModel):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
config_class = SegformerPlusPlusConfig
|
7 |
|
8 |
def __init__(self, config: SegformerPlusPlusConfig):
|
9 |
super().__init__(config)
|
10 |
+
self.model = SegFormer(config)
|
11 |
+
|
12 |
+
def forward(self, x, **kwargs):
|
13 |
+
return self.model(x)
|
|
|
|