Tim77777767 commited on
Commit
e5b4342
·
1 Parent(s): 85ebba9

Anpassungen an der modeling

Browse files
segformer_plusplus/modeling_segformer_plusplus.py CHANGED
@@ -1,13 +1,6 @@
1
- # modeling_segformer_plusplus.py
2
-
3
  from typing import Optional, Tuple
4
- import torch
5
  import torch.nn as nn
6
  from transformers import PreTrainedModel, PretrainedConfig
7
- from transformers.modeling_outputs import SemanticSegmenterOutput
8
-
9
- # Falls du SegFormer direkt importieren willst, musst du sicherstellen,
10
- # dass diese Klasse im selben Repo verfügbar ist.
11
  from segformer_plusplus.model import create_model
12
 
13
 
@@ -47,23 +40,5 @@ class SegformerPlusPlusForSemanticSegmentation(PreTrainedModel):
47
  backbone=config.backbone,
48
  tome_strategy=config.tome_strategy,
49
  out_channels=config.num_labels,
50
- pretrained=False, # Kein Pretrained hier – wird über .from_pretrained geladen
51
- )
52
-
53
- def forward(
54
- self,
55
- pixel_values: torch.FloatTensor,
56
- labels: Optional[torch.LongTensor] = None,
57
- ) -> SemanticSegmenterOutput:
58
-
59
- logits = self.segformer(pixel_values)
60
-
61
- loss = None
62
- if labels is not None:
63
- loss_fct = nn.CrossEntropyLoss(ignore_index=255)
64
- loss = loss_fct(logits, labels.long())
65
-
66
- return SemanticSegmenterOutput(
67
- loss=loss,
68
- logits=logits,
69
  )
 
 
 
1
  from typing import Optional, Tuple
 
2
  import torch.nn as nn
3
  from transformers import PreTrainedModel, PretrainedConfig
 
 
 
 
4
  from segformer_plusplus.model import create_model
5
 
6
 
 
40
  backbone=config.backbone,
41
  tome_strategy=config.tome_strategy,
42
  out_channels=config.num_labels,
43
+ pretrained=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  )