Tim77777767 committed on
Commit
eacc0a7
·
1 Parent(s): c2cafb7

Anpassung der modeling für B5-Nutzung

Browse files
Files changed (1) hide show
  1. modeling_my_segformer.py +14 -10
modeling_my_segformer.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
  import torch.nn as nn
3
  from transformers import PreTrainedModel
4
 
5
- from segformer_plusplus.model.backbone.mit import MixVisionTransformer # Backbone
6
  from mix_vision_transformer_config import MySegformerConfig # Config
7
  from segformer_plusplus.model.head.segformer_head import SegformerHead # <-- dein Head
8
 
@@ -16,7 +16,8 @@ class MySegformerForSemanticSegmentation(PreTrainedModel):
16
 
17
  # Backbone (MixVisionTransformer)
18
  self.backbone = MixVisionTransformer(
19
- embed_dims=config.embed_dims, # z.B. [64, 128, 320, 512]
 
20
  num_stages=config.num_stages,
21
  num_layers=config.num_layers,
22
  num_heads=config.num_heads,
@@ -32,14 +33,17 @@ class MySegformerForSemanticSegmentation(PreTrainedModel):
32
  )
33
 
34
  # Head direkt importieren
35
- in_channels = [64, 128, 320, 512]
 
 
36
 
37
  self.segmentation_head = SegformerHead(
38
- in_channels=in_channels, # Liste der Embeddings aus Backbone
39
- in_index=list(config.out_indices), # welche Feature Maps genutzt werden
40
- out_channels=getattr(config, "num_classes", 19), # Anzahl Klassen
41
- dropout_ratio=0.1,
42
- align_corners=False
 
43
  )
44
 
45
  self.post_init()
@@ -48,11 +52,11 @@ class MySegformerForSemanticSegmentation(PreTrainedModel):
48
  # Backbone → Features (Liste von Tensors)
49
  features = self.backbone(x)
50
 
51
- # Debug: Ausgabe der Shapes der Backbone-Features
52
  for i, f in enumerate(features):
53
  print(f"Feature {i}: shape = {f.shape}")
54
 
55
  # Head → logits
56
  logits = self.segmentation_head(features)
57
 
58
- return logits
 
2
  import torch.nn as nn
3
  from transformers import PreTrainedModel
4
 
5
+ from segformer_plusplus.model.backbone.mit import MixVisionTransformer # Backbone
6
  from mix_vision_transformer_config import MySegformerConfig # Config
7
  from segformer_plusplus.model.head.segformer_head import SegformerHead # <-- dein Head
8
 
 
16
 
17
  # Backbone (MixVisionTransformer)
18
  self.backbone = MixVisionTransformer(
19
+ # Pass only the first element of embed_dims for the initial patch embedding
20
+ embed_dims=config.embed_dims[0], # <--- KORRIGIERTE ZEILE
21
  num_stages=config.num_stages,
22
  num_layers=config.num_layers,
23
  num_heads=config.num_heads,
 
33
  )
34
 
35
  # Head direkt importieren
36
+ # Use config.decode_head.in_channels directly, as it's defined in the config.
37
+ # This ensures consistency with the backbone's output channels for the head.
38
+ in_channels = config.decode_head["in_channels"]
39
 
40
  self.segmentation_head = SegformerHead(
41
+ in_channels=in_channels, # Liste der Embeddings aus Backbone
42
+ in_index=list(config.out_indices), # welche Feature Maps genutzt werden
43
+ channels=config.decode_head["channels"], # channels parameter for SegformerHead itself
44
+ dropout_ratio=config.decode_head["dropout_ratio"],
45
+ num_classes=getattr(config, "num_classes", 19), # Ensure num_classes is passed if not directly in decode_head config
46
+ align_corners=config.decode_head["align_corners"]
47
  )
48
 
49
  self.post_init()
 
52
  # Backbone → Features (Liste von Tensors)
53
  features = self.backbone(x)
54
 
55
+ # Debug: Ausgabe der Shapes der Backbone-Features
56
  for i, f in enumerate(features):
57
  print(f"Feature {i}: shape = {f.shape}")
58
 
59
  # Head → logits
60
  logits = self.segmentation_head(features)
61
 
62
+ return logits