Tim77777767 committed on
Commit
66c5431
·
1 Parent(s): c620883

Changes to the modeling code so that the head is now imported directly instead of being implemented locally

Files changed (1): modeling_my_segformer.py (+15 −85)
modeling_my_segformer.py CHANGED
@@ -1,79 +1,10 @@
- from transformers import PreTrainedModel
  import torch
  import torch.nn as nn
- from segformer_plusplus.utils import resize
- from segformer_plusplus.model.backbone.mit import MixVisionTransformer  # backbone import
- from mix_vision_transformer_config import MySegformerConfig  # config import
-
- # Head implementation (simplified)
- class SegformerHead(nn.Module):
-     def __init__(self,
-                  in_channels=[64, 128, 256, 512],  # adjust to match the backbone outputs!
-                  in_index=[0, 1, 2, 3],
-                  channels=256,
-                  dropout_ratio=0.1,
-                  out_channels=19,  # adjust to the number of classes!
-                  norm_cfg=None,
-                  align_corners=False,
-                  interpolate_mode='bilinear'):
-         super().__init__()
-         self.in_channels = in_channels
-         self.in_index = in_index
-         self.channels = channels
-         self.dropout_ratio = dropout_ratio
-         self.out_channels = out_channels
-         self.norm_cfg = norm_cfg
-         self.align_corners = align_corners
-         self.interpolate_mode = interpolate_mode
-
-         self.act_cfg = dict(type='ReLU')
-         self.conv_seg = nn.Conv2d(channels, out_channels, kernel_size=1)
-         self.dropout = nn.Dropout2d(dropout_ratio) if dropout_ratio > 0 else None
-
-         num_inputs = len(in_channels)
-         assert num_inputs == len(in_index)
-
-         from segformer_plusplus.utils.activation import ConvModule
-
-         self.convs = nn.ModuleList()
-         for i in range(num_inputs):
-             self.convs.append(
-                 ConvModule(
-                     in_channels=in_channels[i],
-                     out_channels=channels,
-                     kernel_size=1,
-                     stride=1,
-                     bias=False,
-                     norm_cfg=norm_cfg,
-                     act_cfg=self.act_cfg))
-
-         self.fusion_conv = ConvModule(
-             in_channels=channels * num_inputs,
-             out_channels=channels,
-             kernel_size=1,
-             bias=False,
-             norm_cfg=norm_cfg)
-
-     def cls_seg(self, feat):
-         if self.dropout is not None:
-             feat = self.dropout(feat)
-         return self.conv_seg(feat)
-
-     def forward(self, inputs):
-         outs = []
-         for idx in range(len(inputs)):
-             x = inputs[idx]
-             conv = self.convs[idx]
-             outs.append(
-                 resize(
-                     input=conv(x),
-                     size=inputs[0].shape[2:],
-                     mode=self.interpolate_mode,
-                     align_corners=self.align_corners))
-
-         out = self.fusion_conv(torch.cat(outs, dim=1))
-         out = self.cls_seg(out)
-         return out
+ from transformers import PreTrainedModel
+
+ from segformer_plusplus.model.backbone.mit import MixVisionTransformer  # backbone
+ from mix_vision_transformer_config import MySegformerConfig  # config
+ from segformer_plusplus.model.head.segformer_head import SegformerHead  # <-- your head


  class MySegformerForSemanticSegmentation(PreTrainedModel):
@@ -83,9 +14,9 @@ class MySegformerForSemanticSegmentation(PreTrainedModel):
      def __init__(self, config):
          super().__init__(config)

-         # Important: pass the whole list, not just the first element
+         # Backbone (MixVisionTransformer)
          self.backbone = MixVisionTransformer(
-             embed_dims=config.embed_dims,  # the WHOLE list, e.g. [64, 128, 320, 512]
+             embed_dims=config.embed_dims,  # e.g. [64, 128, 320, 512]
              num_stages=config.num_stages,
              num_layers=config.num_layers,
              num_heads=config.num_heads,
@@ -100,16 +31,15 @@ class MySegformerForSemanticSegmentation(PreTrainedModel):
              out_indices=config.out_indices
          )

-         # Make sure in_channels is a list
+         # Use the directly imported head
          in_channels = config.embed_dims
          if isinstance(in_channels, int):
              in_channels = [in_channels]

-         print(f"config.embed_dims: {config.embed_dims}, type: {type(config.embed_dims)}")
          self.segmentation_head = SegformerHead(
-             in_channels=config.embed_dims,  # e.g. [64, 128, 320, 512]
-             in_index=list(config.out_indices),  # e.g. [0, 1, 2, 3]
-             out_channels=config.num_classes if hasattr(config, 'num_classes') else 19,
+             in_channels=in_channels,  # embedding dims coming out of the backbone
+             in_index=list(config.out_indices),  # which feature maps are used
+             out_channels=getattr(config, "num_classes", 19),  # number of classes
              dropout_ratio=0.1,
              align_corners=False
          )
@@ -117,10 +47,10 @@ class MySegformerForSemanticSegmentation(PreTrainedModel):
          self.post_init()

      def forward(self, x):
-         # The backbone returns a list of multi-scale features
-         features = self.backbone(x)  # e.g. List[Tensor]
+         # Backbone features (list of tensors)
+         features = self.backbone(x)

-         # Hand off to the segmentation head
-         output = self.segmentation_head(features)  # tensor: logits or segmentation masks
+         # Head logits
+         logits = self.segmentation_head(features)

-         return output
+         return {"logits": logits}
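For context, a minimal usage sketch of the refactored class. The config field names (embed_dims, num_stages, num_layers, num_heads, out_indices, num_classes) are taken from the diff; the concrete values are MiT-B0-style assumptions, MySegformerConfig is assumed to accept them as keyword arguments, and the backbone likely needs further fields that the hunks elide.

# Minimal usage sketch. Values are illustrative assumptions, not from the
# commit; MySegformerConfig is assumed to accept these keyword arguments.
import torch

from mix_vision_transformer_config import MySegformerConfig
from modeling_my_segformer import MySegformerForSemanticSegmentation

config = MySegformerConfig(
    embed_dims=[64, 128, 320, 512],  # per-stage embedding dims (assumed)
    num_stages=4,
    num_layers=[3, 4, 6, 3],         # transformer blocks per stage (assumed)
    num_heads=[1, 2, 5, 8],          # attention heads per stage (assumed)
    out_indices=(0, 1, 2, 3),        # feed all four stages to the head
    num_classes=19,                  # e.g. Cityscapes
)
model = MySegformerForSemanticSegmentation(config)

x = torch.randn(1, 3, 512, 512)  # NCHW image batch
out = model(x)                   # forward now returns a dict
print(out["logits"].shape)       # typically [1, 19, H/4, W/4] for SegFormer-style heads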
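Because a SegFormer-style head usually emits logits at a quarter of the input resolution, a training step would upsample them before computing the loss. A sketch under the same assumptions, using only standard PyTorch ops:

import torch
import torch.nn.functional as F

logits = out["logits"]  # [N, num_classes, h, w], lower resolution than the input
logits = F.interpolate(logits, size=x.shape[2:], mode="bilinear", align_corners=False)

labels = torch.randint(0, 19, (1, 512, 512))  # dummy ground-truth mask (assumed shape)
loss = F.cross_entropy(logits, labels, ignore_index=255)  # 255 = common "ignore" label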