Tim77777767 commited on
Commit
2483c4f
·
1 Parent(s): 10e4799

Anpassungen für HF-Upload

Browse files
segformer_plusplus/configuration_segformer_plusplus.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+ class SegformerPlusPlusConfig(PretrainedConfig):
4
+ model_type = "segformer_plus_plus"
5
+
6
+ def __init__(self, backbone_cfg=None, decode_head_cfg=None, out_channels=19, **kwargs):
7
+ super().__init__(**kwargs)
8
+ self.backbone_cfg = backbone_cfg
9
+ self.decode_head_cfg = decode_head_cfg
10
+ self.out_channels = out_channels
segformer_plusplus/modeling_segformer_plusplus.py CHANGED
@@ -1,44 +1,13 @@
1
- from typing import Optional, Tuple
2
- import torch.nn as nn
3
- from transformers import PreTrainedModel, PretrainedConfig
4
- from segformer_plusplus.model import create_model
5
 
6
-
7
- class SegformerPlusPlusConfig(PretrainedConfig):
8
- model_type = "segformer_plusplus"
9
-
10
- def __init__(
11
- self,
12
- backbone: str = "b5",
13
- tome_strategy: Optional[str] = "bsm_hq",
14
- num_labels: int = 19,
15
- id2label: Optional[dict] = None,
16
- label2id: Optional[dict] = None,
17
- **kwargs,
18
- ):
19
- self.backbone = backbone
20
- self.tome_strategy = tome_strategy
21
- self.num_labels = num_labels
22
-
23
- if id2label is None:
24
- id2label = {i: f"class_{i}" for i in range(num_labels)}
25
- if label2id is None:
26
- label2id = {v: k for k, v in id2label.items()}
27
-
28
- self.id2label = id2label
29
- self.label2id = label2id
30
-
31
- super().__init__(**kwargs)
32
-
33
-
34
- class SegformerPlusPlusForSemanticSegmentation(PreTrainedModel):
35
  config_class = SegformerPlusPlusConfig
36
 
37
  def __init__(self, config: SegformerPlusPlusConfig):
38
  super().__init__(config)
39
- self.segformer = create_model(
40
- backbone=config.backbone,
41
- tome_strategy=config.tome_strategy,
42
- out_channels=config.num_labels,
43
- pretrained=False,
44
- )
 
1
+ from transformers import PreTrainedModel
2
+ from .configuration_segformer_plusplus import SegformerPlusPlusConfig
3
+ from .build_model import SegFormer
 
4
 
5
+ class SegformerPlusPlusModel(PreTrainedModel):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  config_class = SegformerPlusPlusConfig
7
 
8
  def __init__(self, config: SegformerPlusPlusConfig):
9
  super().__init__(config)
10
+ self.model = SegFormer(config)
11
+
12
+ def forward(self, x, **kwargs):
13
+ return self.model(x)