Tim77777767 committed
Commit 02508fb · 1 Parent(s): 245b43a

Created Files for HF compatibility
config.json ADDED
@@ -0,0 +1,16 @@
+ {
+   "model_type": "my_segformer",
+   "embed_dims": 64,
+   "num_stages": 4,
+   "num_layers": [3, 4, 6, 3],
+   "num_heads": [1, 2, 4, 8],
+   "patch_sizes": [7, 3, 3, 3],
+   "strides": [4, 2, 2, 2],
+   "sr_ratios": [8, 4, 2, 1],
+   "mlp_ratio": 4,
+   "qkv_bias": true,
+   "drop_rate": 0.0,
+   "attn_drop_rate": 0.0,
+   "drop_path_rate": 0.0,
+   "out_indices": [0, 1, 2, 3]
+ }
mix_vision_transformer_config.py ADDED
@@ -0,0 +1,36 @@
+ from transformers import PretrainedConfig
+
+ class MySegformerConfig(PretrainedConfig):
+     model_type = "my_segformer"
+
+     def __init__(
+         self,
+         embed_dims=64,
+         num_stages=4,
+         num_layers=[3, 4, 6, 3],
+         num_heads=[1, 2, 4, 8],
+         patch_sizes=[7, 3, 3, 3],
+         strides=[4, 2, 2, 2],
+         sr_ratios=[8, 4, 2, 1],
+         mlp_ratio=4,
+         qkv_bias=True,
+         drop_rate=0.0,
+         attn_drop_rate=0.0,
+         drop_path_rate=0.0,
+         out_indices=(0, 1, 2, 3),
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.embed_dims = embed_dims
+         self.num_stages = num_stages
+         self.num_layers = num_layers
+         self.num_heads = num_heads
+         self.patch_sizes = patch_sizes
+         self.strides = strides
+         self.sr_ratios = sr_ratios
+         self.mlp_ratio = mlp_ratio
+         self.qkv_bias = qkv_bias
+         self.drop_rate = drop_rate
+         self.attn_drop_rate = attn_drop_rate
+         self.drop_path_rate = drop_path_rate
+         self.out_indices = out_indices
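As a quick sanity check (a sketch, not part of this commit), the config class can regenerate a config.json carrying the same hyperparameters as the file added above via the standard PretrainedConfig API; the output directory name is a placeholder:

from mix_vision_transformer_config import MySegformerConfig

# Instantiate with the defaults above and write config.json to a local folder.
# "./my_segformer_export" is a placeholder path, not part of this repository.
config = MySegformerConfig()
config.save_pretrained("./my_segformer_export")

# Round trip: reading it back restores the same hyperparameters as config.json above.
reloaded = MySegformerConfig.from_pretrained("./my_segformer_export")
print(reloaded.num_layers)  # [3, 4, 6, 3]

Keeping the JSON values in sync with the Python defaults, as this commit does, means from_pretrained still works even if individual keys are missing from config.json.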
modeling_my_segformer.py ADDED
@@ -0,0 +1,114 @@
+ from transformers import PreTrainedModel
+ import torch
+ import torch.nn as nn
+ from segformer_plusplus.utils import resize
+ from segformer_plusplus.model.backbone.mit import MixVisionTransformer  # backbone import
+ from mix_vision_transformer_config import MySegformerConfig  # config import
+
+ # Head implementation (slightly simplified and adapted)
+ class SegformerHead(nn.Module):
+     def __init__(self,
+                  in_channels=[64, 128, 256, 512],  # adjust to match the backbone outputs!
+                  in_index=[0, 1, 2, 3],
+                  channels=256,
+                  dropout_ratio=0.1,
+                  out_channels=19,  # number of classes, adjust!
+                  norm_cfg=None,
+                  align_corners=False,
+                  interpolate_mode='bilinear'):
+         super().__init__()
+         self.in_channels = in_channels
+         self.in_index = in_index
+         self.channels = channels
+         self.dropout_ratio = dropout_ratio
+         self.out_channels = out_channels
+         self.norm_cfg = norm_cfg
+         self.align_corners = align_corners
+         self.interpolate_mode = interpolate_mode
+
+         self.act_cfg = dict(type='ReLU')
+         self.conv_seg = nn.Conv2d(channels, out_channels, kernel_size=1)
+         self.dropout = nn.Dropout2d(dropout_ratio) if dropout_ratio > 0 else None
+
+         num_inputs = len(in_channels)
+         assert num_inputs == len(in_index)
+
+         from segformer_plusplus.utils.activation import ConvModule
+
+         self.convs = nn.ModuleList()
+         for i in range(num_inputs):
+             self.convs.append(
+                 ConvModule(
+                     in_channels=in_channels[i],
+                     out_channels=channels,
+                     kernel_size=1,
+                     stride=1,
+                     norm_cfg=norm_cfg,
+                     act_cfg=self.act_cfg))
+
+         self.fusion_conv = ConvModule(
+             in_channels=channels * num_inputs,
+             out_channels=channels,
+             kernel_size=1,
+             norm_cfg=norm_cfg)
+
+     def cls_seg(self, feat):
+         if self.dropout is not None:
+             feat = self.dropout(feat)
+         return self.conv_seg(feat)
+
+     def forward(self, inputs):
+         outs = []
+         for idx in range(len(inputs)):
+             x = inputs[idx]
+             conv = self.convs[idx]
+             outs.append(
+                 resize(
+                     input=conv(x),
+                     size=inputs[0].shape[2:],
+                     mode=self.interpolate_mode,
+                     align_corners=self.align_corners))
+
+         out = self.fusion_conv(torch.cat(outs, dim=1))
+         out = self.cls_seg(out)
+         return out
+
+
+ class MySegformerForSemanticSegmentation(PreTrainedModel):
+     config_class = MySegformerConfig
+     base_model_prefix = "my_segformer"
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         # Initialize the backbone with parameters from the config
+         self.backbone = MixVisionTransformer(
+             embed_dims=config.embed_dims,
+             num_stages=config.num_stages,
+             num_layers=config.num_layers,
+             num_heads=config.num_heads,
+             patch_sizes=config.patch_sizes,
+             strides=config.strides,
+             sr_ratios=config.sr_ratios,
+             mlp_ratio=config.mlp_ratio,
+             qkv_bias=config.qkv_bias,
+             drop_rate=config.drop_rate,
+             attn_drop_rate=config.attn_drop_rate,
+             drop_path_rate=config.drop_path_rate,
+             out_indices=config.out_indices
+         )
+
+         # Initialize the head; take out_channels from the config or set it explicitly
+         self.segmentation_head = SegformerHead(
+             in_channels=[64, 128, 256, 512],  # <- adjust depending on what the backbone outputs!
+             out_channels=config.num_classes if hasattr(config, 'num_classes') else 19,
+             dropout_ratio=0.1,
+             align_corners=False
+         )
+
+         self.post_init()
+
+     def forward(self, x):
+         features = self.backbone(x)
+         segmentation_output = self.segmentation_head(features)
+         return segmentation_output
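Together with config.json and pytorch_model.bin below, this is what from_pretrained needs: config_class ties the model to MySegformerConfig, and the checkpoint keys are matched against the backbone and segmentation_head module names. A minimal loading sketch, not part of this commit; the local path and the Auto-class registration are assumptions:

import torch
from transformers import AutoConfig, AutoModel

from mix_vision_transformer_config import MySegformerConfig
from modeling_my_segformer import MySegformerForSemanticSegmentation

# Optional: make the Auto* factories aware of the custom "my_segformer" type.
AutoConfig.register("my_segformer", MySegformerConfig)
AutoModel.register(MySegformerConfig, MySegformerForSemanticSegmentation)

# "./" assumes this runs inside a local clone containing config.json and pytorch_model.bin.
model = MySegformerForSemanticSegmentation.from_pretrained("./")
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 512, 512))  # dummy input; 19 output classes by default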
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:eea970e0387b05e22ec603d4d2f4a3f73b38fd84bcf104f451b79043009339a3
+ size 328287283
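This entry is only the Git LFS pointer; the actual weight file (about 328 MB) is fetched when the repository is cloned with LFS. A small sketch, not part of the commit, for checking a locally resolved copy against the recorded oid:

import hashlib

# Compare a local pytorch_model.bin against the sha256 oid in the LFS pointer above.
expected = "eea970e0387b05e22ec603d4d2f4a3f73b38fd84bcf104f451b79043009339a3"
sha = hashlib.sha256()
with open("pytorch_model.bin", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha.update(chunk)
assert sha.hexdigest() == expected, "pytorch_model.bin does not match the LFS pointer"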
segformer_plusplus/start_cityscape_benchmark.py CHANGED
@@ -19,6 +19,10 @@ if args.checkpoint:
      print(f"Loading checkpoint: {checkpoint_path}")
      checkpoint = torch.load(checkpoint_path)
      model.load_state_dict(checkpoint)
+
+     # Save the state_dict right after loading the checkpoint
+     # save_path = os.path.join(os.getcwd(), "pytorch_model.bin")
+     # torch.save(model.state_dict(), save_path)
  else:
      print("No checkpoint provided – using model as initialized.")
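For reference, the commented-out lines mirror the export step that presumably produced the pytorch_model.bin added in this commit. A self-contained sketch; the helper name export_state_dict is hypothetical:

import os

import torch
import torch.nn as nn


def export_state_dict(model: nn.Module, out_dir: str = ".") -> str:
    """Write the plain state_dict as pytorch_model.bin so it can sit next to
    config.json for the HF-compatible model files added in this commit."""
    save_path = os.path.join(out_dir, "pytorch_model.bin")
    torch.save(model.state_dict(), save_path)
    return save_path

Called right after model.load_state_dict(checkpoint) in the benchmark script, this reproduces the commented-out lines above.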