Tim77777767 commited on
Commit
d21bb56
·
1 Parent(s): eacc0a7

ANpassung modeling

Browse files
Files changed (1) hide show
  1. modeling_my_segformer.py +13 -14
modeling_my_segformer.py CHANGED
@@ -2,9 +2,9 @@ import torch
2
  import torch.nn as nn
3
  from transformers import PreTrainedModel
4
 
5
- from segformer_plusplus.model.backbone.mit import MixVisionTransformer # Backbone
6
- from mix_vision_transformer_config import MySegformerConfig # Config
7
- from segformer_plusplus.model.head.segformer_head import SegformerHead # <-- dein Head
8
 
9
 
10
  class MySegformerForSemanticSegmentation(PreTrainedModel):
@@ -16,8 +16,7 @@ class MySegformerForSemanticSegmentation(PreTrainedModel):
16
 
17
  # Backbone (MixVisionTransformer)
18
  self.backbone = MixVisionTransformer(
19
- # Pass only the first element of embed_dims for the initial patch embedding
20
- embed_dims=config.embed_dims[0], # <--- KORRIGIERTE ZEILE
21
  num_stages=config.num_stages,
22
  num_layers=config.num_layers,
23
  num_heads=config.num_heads,
@@ -32,31 +31,31 @@ class MySegformerForSemanticSegmentation(PreTrainedModel):
32
  out_indices=config.out_indices
33
  )
34
 
35
- # Head direkt importieren
36
  # Use config.decode_head.in_channels directly, as it's defined in the config.
37
- # This ensures consistency with the backbone's output channels for the head.
38
  in_channels = config.decode_head["in_channels"]
39
 
40
  self.segmentation_head = SegformerHead(
41
- in_channels=in_channels, # Liste der Embeddings aus Backbone
42
- in_index=list(config.out_indices), # welche Feature Maps genutzt werden
43
- channels=config.decode_head["channels"], # channels parameter for SegformerHead itself
44
  dropout_ratio=config.decode_head["dropout_ratio"],
45
- num_classes=getattr(config, "num_classes", 19), # Ensure num_classes is passed if not directly in decode_head config
46
- align_corners=config.decode_head["align_corners"]
 
 
 
47
  )
48
 
49
  self.post_init()
50
 
51
  def forward(self, x):
52
- # Backbone → Features (Liste von Tensors)
53
  features = self.backbone(x)
54
 
55
  # Debug: Ausgabe der Shapes der Backbone-Features
56
  for i, f in enumerate(features):
57
  print(f"Feature {i}: shape = {f.shape}")
58
 
59
- # Head → logits
60
  logits = self.segmentation_head(features)
61
 
62
  return logits
 
2
  import torch.nn as nn
3
  from transformers import PreTrainedModel
4
 
5
+ from segformer_plusplus.model.backbone.mit import MixVisionTransformer
6
+ from mix_vision_transformer_config import MySegformerConfig
7
+ from segformer_plusplus.model.head.segformer_head import SegformerHead
8
 
9
 
10
  class MySegformerForSemanticSegmentation(PreTrainedModel):
 
16
 
17
  # Backbone (MixVisionTransformer)
18
  self.backbone = MixVisionTransformer(
19
+ embed_dims=config.embed_dims[0],
 
20
  num_stages=config.num_stages,
21
  num_layers=config.num_layers,
22
  num_heads=config.num_heads,
 
31
  out_indices=config.out_indices
32
  )
33
 
34
+ # Head initialization
35
  # Use config.decode_head.in_channels directly, as it's defined in the config.
 
36
  in_channels = config.decode_head["in_channels"]
37
 
38
  self.segmentation_head = SegformerHead(
39
+ in_channels=in_channels,
40
+ in_index=list(config.out_indices),
41
+ channels=config.decode_head["channels"],
42
  dropout_ratio=config.decode_head["dropout_ratio"],
43
+ # REMOVED: num_classes=getattr(config, "num_classes", 19), <--- Entfernen Sie diese Zeile
44
+ align_corners=config.decode_head["align_corners"],
45
+ # Fügen Sie hier interpolate_mode hinzu, falls SegformerHead dies explizit erwartet
46
+ # und es in config.decode_head definiert ist (was es in Ihrer config.json ist)
47
+ interpolate_mode=config.decode_head["interpolate_mode"] # <-- Hinzugefügt
48
  )
49
 
50
  self.post_init()
51
 
52
  def forward(self, x):
 
53
  features = self.backbone(x)
54
 
55
  # Debug: Ausgabe der Shapes der Backbone-Features
56
  for i, f in enumerate(features):
57
  print(f"Feature {i}: shape = {f.shape}")
58
 
 
59
  logits = self.segmentation_head(features)
60
 
61
  return logits