Tim77777767 committed on
Commit
e4634c2
·
1 Parent(s): 1a260cd

Anpassungen für HF

Browse files
modeling_my_segformer.py CHANGED
@@ -120,3 +120,12 @@ class MySegformerForSemanticSegmentation(PreTrainedModel):
120
  )
121
 
122
  self.post_init()
 
 
 
 
 
 
 
 
 
 
120
  )
121
 
122
  self.post_init()
123
+
124
+ def forward(self, x):
125
+ # Backbone liefert eine Liste von Features (Multi-Scale Features)
126
+ features = self.backbone(x) # z.B. List[Tensor]
127
+
128
+ # Übergabe an den Segmentation Head
129
+ output = self.segmentation_head(features) # Tensor: logits oder Segmentationsmasken
130
+
131
+ return output
preTrainedTest.py CHANGED
@@ -1,13 +1,42 @@
 
 
 
 
 
 
1
  from modeling_my_segformer import MySegformerForSemanticSegmentation
2
  from mix_vision_transformer_config import MySegformerConfig
3
 
4
- # Der Pfad zu deinem HF-Repo (kann auch einfach als String benutzt werden)
5
- model_name_or_path = "TimM77/SegformerPlusPlus"
 
6
 
7
- # Config laden (automatisch aus config.json im Repo)
 
8
  config = MySegformerConfig.from_pretrained(model_name_or_path)
9
-
10
- # Modell laden (Gewichte aus pytorch_model.bin + Config)
11
  model = MySegformerForSemanticSegmentation.from_pretrained(model_name_or_path, config=config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- print(model, config)
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ import torchvision.transforms as T
4
+ import numpy as np
5
+ import os
6
+
7
  from modeling_my_segformer import MySegformerForSemanticSegmentation
8
  from mix_vision_transformer_config import MySegformerConfig
9
 
10
+ # Gerät auswählen
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+ print(f"Using device: {device}")
13
 
14
+ # Modell laden
15
+ model_name_or_path = "TimM77/SegformerPlusPlus"
16
  config = MySegformerConfig.from_pretrained(model_name_or_path)
 
 
17
  model = MySegformerForSemanticSegmentation.from_pretrained(model_name_or_path, config=config)
18
+ model.to(device).eval()
19
+
20
+ # Bild laden
21
+ image_path = "segformer_plusplus/cityscape/berlin_000543_000019_leftImg8bit.png"
22
+ image = Image.open(image_path).convert("RGB")
23
+
24
+ # Preprocessing
25
+ transform = T.Compose([
26
+ T.Resize((512, 512)),
27
+ T.ToTensor(),
28
+ T.Normalize(mean=[0.485, 0.456, 0.406],
29
+ std=[0.229, 0.224, 0.225])
30
+ ])
31
+ input_tensor = transform(image).unsqueeze(0).to(device)
32
+
33
+ # Inferenz
34
+ with torch.no_grad():
35
+ output = model(input_tensor)
36
+ logits = output.logits if hasattr(output, "logits") else output
37
+ pred = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()
38
 
39
+ # Ergebnis als Textdatei speichern
40
+ output_path = os.path.join("segformer_plusplus", "cityscapes_prediction_output_overHF.txt")
41
+ np.savetxt(output_path, pred, fmt="%d")
42
+ print(f"Prediction saved as {output_path}")
segformer_plusplus/model/backbone/mit.py CHANGED
@@ -415,7 +415,7 @@ class MixVisionTransformer(BaseModule):
415
  cur = 0
416
  self.layers = ModuleList()
417
  for i, num_layer in enumerate(num_layers):
418
- embed_dims_i = embed_dims[i]
419
  patch_embed = PatchEmbed(
420
  in_channels=in_channels,
421
  embed_dims=embed_dims_i,
 
415
  cur = 0
416
  self.layers = ModuleList()
417
  for i, num_layer in enumerate(num_layers):
418
+ embed_dims_i = embed_dims * num_heads[i]
419
  patch_embed = PatchEmbed(
420
  in_channels=in_channels,
421
  embed_dims=embed_dims_i,