Tim77777767
committed on
Commit
·
eacc0a7
1
Parent(s):
c2cafb7
Anpassung an der modeling für b5 nutzung
Browse files- modeling_my_segformer.py +14 -10
modeling_my_segformer.py
CHANGED
@@ -2,7 +2,7 @@ import torch
|
|
2 |
import torch.nn as nn
|
3 |
from transformers import PreTrainedModel
|
4 |
|
5 |
-
from segformer_plusplus.model.backbone.mit import MixVisionTransformer
|
6 |
from mix_vision_transformer_config import MySegformerConfig # Config
|
7 |
from segformer_plusplus.model.head.segformer_head import SegformerHead # <-- dein Head
|
8 |
|
@@ -16,7 +16,8 @@ class MySegformerForSemanticSegmentation(PreTrainedModel):
|
|
16 |
|
17 |
# Backbone (MixVisionTransformer)
|
18 |
self.backbone = MixVisionTransformer(
|
19 |
-
|
|
|
20 |
num_stages=config.num_stages,
|
21 |
num_layers=config.num_layers,
|
22 |
num_heads=config.num_heads,
|
@@ -32,14 +33,17 @@ class MySegformerForSemanticSegmentation(PreTrainedModel):
|
|
32 |
)
|
33 |
|
34 |
# Head direkt importieren
|
35 |
-
|
|
|
|
|
36 |
|
37 |
self.segmentation_head = SegformerHead(
|
38 |
-
in_channels=in_channels,
|
39 |
-
in_index=list(config.out_indices),
|
40 |
-
|
41 |
-
dropout_ratio=
|
42 |
-
|
|
|
43 |
)
|
44 |
|
45 |
self.post_init()
|
@@ -48,11 +52,11 @@ class MySegformerForSemanticSegmentation(PreTrainedModel):
|
|
48 |
# Backbone → Features (Liste von Tensors)
|
49 |
features = self.backbone(x)
|
50 |
|
51 |
-
|
52 |
for i, f in enumerate(features):
|
53 |
print(f"Feature {i}: shape = {f.shape}")
|
54 |
|
55 |
# Head → logits
|
56 |
logits = self.segmentation_head(features)
|
57 |
|
58 |
-
return logits
|
|
|
2 |
import torch.nn as nn
|
3 |
from transformers import PreTrainedModel
|
4 |
|
5 |
+
from segformer_plusplus.model.backbone.mit import MixVisionTransformer # Backbone
|
6 |
from mix_vision_transformer_config import MySegformerConfig # Config
|
7 |
from segformer_plusplus.model.head.segformer_head import SegformerHead # <-- dein Head
|
8 |
|
|
|
16 |
|
17 |
# Backbone (MixVisionTransformer)
|
18 |
self.backbone = MixVisionTransformer(
|
19 |
+
# Pass only the first element of embed_dims for the initial patch embedding
|
20 |
+
embed_dims=config.embed_dims[0], # <--- KORRIGIERTE ZEILE
|
21 |
num_stages=config.num_stages,
|
22 |
num_layers=config.num_layers,
|
23 |
num_heads=config.num_heads,
|
|
|
33 |
)
|
34 |
|
35 |
# Head direkt importieren
|
36 |
+
# Use config.decode_head.in_channels directly, as it's defined in the config.
|
37 |
+
# This ensures consistency with the backbone's output channels for the head.
|
38 |
+
in_channels = config.decode_head["in_channels"]
|
39 |
|
40 |
self.segmentation_head = SegformerHead(
|
41 |
+
in_channels=in_channels, # Liste der Embeddings aus Backbone
|
42 |
+
in_index=list(config.out_indices), # welche Feature Maps genutzt werden
|
43 |
+
channels=config.decode_head["channels"], # channels parameter for SegformerHead itself
|
44 |
+
dropout_ratio=config.decode_head["dropout_ratio"],
|
45 |
+
num_classes=getattr(config, "num_classes", 19), # Ensure num_classes is passed if not directly in decode_head config
|
46 |
+
align_corners=config.decode_head["align_corners"]
|
47 |
)
|
48 |
|
49 |
self.post_init()
|
|
|
52 |
# Backbone → Features (Liste von Tensors)
|
53 |
features = self.backbone(x)
|
54 |
|
55 |
+
# Debug: Ausgabe der Shapes der Backbone-Features
|
56 |
for i, f in enumerate(features):
|
57 |
print(f"Feature {i}: shape = {f.shape}")
|
58 |
|
59 |
# Head → logits
|
60 |
logits = self.segmentation_head(features)
|
61 |
|
62 |
+
return logits
|