Tim77777767
committed on
Commit
·
d21bb56
1
Parent(s):
eacc0a7
Anpassung modeling
Browse files - modeling_my_segformer.py +13 -14
modeling_my_segformer.py
CHANGED
@@ -2,9 +2,9 @@ import torch
|
|
2 |
import torch.nn as nn
|
3 |
from transformers import PreTrainedModel
|
4 |
|
5 |
-
from segformer_plusplus.model.backbone.mit import MixVisionTransformer
|
6 |
-
from mix_vision_transformer_config import MySegformerConfig
|
7 |
-
from segformer_plusplus.model.head.segformer_head import SegformerHead
|
8 |
|
9 |
|
10 |
class MySegformerForSemanticSegmentation(PreTrainedModel):
|
@@ -16,8 +16,7 @@ class MySegformerForSemanticSegmentation(PreTrainedModel):
|
|
16 |
|
17 |
# Backbone (MixVisionTransformer)
|
18 |
self.backbone = MixVisionTransformer(
|
19 |
-
|
20 |
-
embed_dims=config.embed_dims[0], # <--- KORRIGIERTE ZEILE
|
21 |
num_stages=config.num_stages,
|
22 |
num_layers=config.num_layers,
|
23 |
num_heads=config.num_heads,
|
@@ -32,31 +31,31 @@ class MySegformerForSemanticSegmentation(PreTrainedModel):
|
|
32 |
out_indices=config.out_indices
|
33 |
)
|
34 |
|
35 |
-
# Head
|
36 |
# Use config.decode_head.in_channels directly, as it's defined in the config.
|
37 |
-
# This ensures consistency with the backbone's output channels for the head.
|
38 |
in_channels = config.decode_head["in_channels"]
|
39 |
|
40 |
self.segmentation_head = SegformerHead(
|
41 |
-
in_channels=in_channels,
|
42 |
-
in_index=list(config.out_indices),
|
43 |
-
channels=config.decode_head["channels"],
|
44 |
dropout_ratio=config.decode_head["dropout_ratio"],
|
45 |
-
num_classes=getattr(config, "num_classes", 19),
|
46 |
-
align_corners=config.decode_head["align_corners"]
|
|
|
|
|
|
|
47 |
)
|
48 |
|
49 |
self.post_init()
|
50 |
|
51 |
def forward(self, x):
|
52 |
-
# Backbone → Features (Liste von Tensors)
|
53 |
features = self.backbone(x)
|
54 |
|
55 |
# Debug: Ausgabe der Shapes der Backbone-Features
|
56 |
for i, f in enumerate(features):
|
57 |
print(f"Feature {i}: shape = {f.shape}")
|
58 |
|
59 |
-
# Head → logits
|
60 |
logits = self.segmentation_head(features)
|
61 |
|
62 |
return logits
|
|
|
2 |
import torch.nn as nn
|
3 |
from transformers import PreTrainedModel
|
4 |
|
5 |
+
from segformer_plusplus.model.backbone.mit import MixVisionTransformer
|
6 |
+
from mix_vision_transformer_config import MySegformerConfig
|
7 |
+
from segformer_plusplus.model.head.segformer_head import SegformerHead
|
8 |
|
9 |
|
10 |
class MySegformerForSemanticSegmentation(PreTrainedModel):
|
|
|
16 |
|
17 |
# Backbone (MixVisionTransformer)
|
18 |
self.backbone = MixVisionTransformer(
|
19 |
+
embed_dims=config.embed_dims[0],
|
|
|
20 |
num_stages=config.num_stages,
|
21 |
num_layers=config.num_layers,
|
22 |
num_heads=config.num_heads,
|
|
|
31 |
out_indices=config.out_indices
|
32 |
)
|
33 |
|
34 |
+
# Head initialization
|
35 |
# Use config.decode_head.in_channels directly, as it's defined in the config.
|
|
|
36 |
in_channels = config.decode_head["in_channels"]
|
37 |
|
38 |
self.segmentation_head = SegformerHead(
|
39 |
+
in_channels=in_channels,
|
40 |
+
in_index=list(config.out_indices),
|
41 |
+
channels=config.decode_head["channels"],
|
42 |
dropout_ratio=config.decode_head["dropout_ratio"],
|
43 |
+
# REMOVED: num_classes=getattr(config, "num_classes", 19), <--- Entfernen Sie diese Zeile
|
44 |
+
align_corners=config.decode_head["align_corners"],
|
45 |
+
# Fügen Sie hier interpolate_mode hinzu, falls SegformerHead dies explizit erwartet
|
46 |
+
# und es in config.decode_head definiert ist (was es in Ihrer config.json ist)
|
47 |
+
interpolate_mode=config.decode_head["interpolate_mode"] # <-- Hinzugefügt
|
48 |
)
|
49 |
|
50 |
self.post_init()
|
51 |
|
52 |
def forward(self, x):
|
|
|
53 |
features = self.backbone(x)
|
54 |
|
55 |
# Debug: Ausgabe der Shapes der Backbone-Features
|
56 |
for i, f in enumerate(features):
|
57 |
print(f"Feature {i}: shape = {f.shape}")
|
58 |
|
|
|
59 |
logits = self.segmentation_head(features)
|
60 |
|
61 |
return logits
|