Tim77777767 committed
Commit e5b4342 · 1 Parent(s): 85ebba9
Adjustments to the modeling
segformer_plusplus/modeling_segformer_plusplus.py
CHANGED
@@ -1,13 +1,6 @@
-# modeling_segformer_plusplus.py
-
 from typing import Optional, Tuple
-import torch
 import torch.nn as nn
 from transformers import PreTrainedModel, PretrainedConfig
-from transformers.modeling_outputs import SemanticSegmenterOutput
-
-# If you want to import SegFormer directly, you have to make sure
-# that this class is available in the same repo.
 from segformer_plusplus.model import create_model
 
 
@@ -47,23 +40,5 @@ class SegformerPlusPlusForSemanticSegmentation(PreTrainedModel):
             backbone=config.backbone,
             tome_strategy=config.tome_strategy,
             out_channels=config.num_labels,
-            pretrained=False,
-        )
-
-    def forward(
-        self,
-        pixel_values: torch.FloatTensor,
-        labels: Optional[torch.LongTensor] = None,
-    ) -> SemanticSegmenterOutput:
-
-        logits = self.segformer(pixel_values)
-
-        loss = None
-        if labels is not None:
-            loss_fct = nn.CrossEntropyLoss(ignore_index=255)
-            loss = loss_fct(logits, labels.long())
-
-        return SemanticSegmenterOutput(
-            loss=loss,
-            logits=logits,
+            pretrained=False,
         )
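Taken together, the commit removes the hand-written forward pass (logits from self.segformer, CrossEntropyLoss with ignore_index=255, SemanticSegmenterOutput return) together with the now-unused torch and modeling_outputs imports, leaving the class as a thin wrapper around create_model. A minimal sketch of how the file plausibly reads after this commit, assembled from the post-change side of the diff; the __init__ signature and the self.segformer attribute name are assumptions carried over from the removed code, since the hunks do not show them:

# Sketch only: reconstructed from the post-commit side of the diff.
# The __init__ wrapper and the attribute name self.segformer are assumed,
# not shown in the hunks.
from typing import Optional, Tuple  # retained by the commit, though no longer used below

import torch.nn as nn  # likewise retained
from transformers import PreTrainedModel, PretrainedConfig

from segformer_plusplus.model import create_model


class SegformerPlusPlusForSemanticSegmentation(PreTrainedModel):
    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        # Backbone built by the segformer_plusplus factory; the keyword
        # arguments are exactly those visible in the diff context.
        self.segformer = create_model(
            backbone=config.backbone,
            tome_strategy=config.tome_strategy,
            out_channels=config.num_labels,
            pretrained=False,
        )

With the custom forward and loss gone, loss computation presumably happens outside this wrapper; after the change the class only delegates construction to create_model.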