Image Feature Extraction
Transformers
Safetensors
dinov2
fepegar commited on
Commit
d9eb255
·
1 Parent(s): 201611f

Move inputs to model's device

Browse files
Files changed (1) hide show
  1. src/rad_dino/__init__.py +6 -2
src/rad_dino/__init__.py CHANGED
@@ -6,7 +6,7 @@ from torch import Tensor
6
  from torch import nn
7
  from transformers import AutoImageProcessor
8
  from transformers import AutoModel
9
- from transformers.image_processing_base import BatchFeature
10
 
11
 
12
  __version__ = "0.1.0"
@@ -25,6 +25,10 @@ class RadDino(nn.Module):
25
  self.model = AutoModel.from_pretrained(self._REPO).eval()
26
  self.processor = AutoImageProcessor.from_pretrained(self._REPO, use_fast=False)
27
 
 
 
 
 
28
  def preprocess(self, image_or_images: TypeInputImages) -> BatchFeature:
29
  return self.processor(image_or_images, return_tensors="pt")
30
 
@@ -53,7 +57,7 @@ class RadDino(nn.Module):
53
  self,
54
  image_or_images: TypeInputImages,
55
  ) -> tuple[TypeClsToken, TypePatchTokens]:
56
- inputs = self.preprocess(image_or_images)
57
  cls_token, patch_tokens_flat = self.encode(inputs)
58
  patch_tokens = self.reshape_patch_tokens(patch_tokens_flat)
59
  return cls_token, patch_tokens
 
6
  from torch import nn
7
  from transformers import AutoImageProcessor
8
  from transformers import AutoModel
9
+ from transformers.feature_extraction_utils import BatchFeature
10
 
11
 
12
  __version__ = "0.1.0"
 
25
  self.model = AutoModel.from_pretrained(self._REPO).eval()
26
  self.processor = AutoImageProcessor.from_pretrained(self._REPO, use_fast=False)
27
 
28
+ @property
29
+ def device(self) -> torch.device:
30
+ return next(self.model.parameters()).device
31
+
32
  def preprocess(self, image_or_images: TypeInputImages) -> BatchFeature:
33
  return self.processor(image_or_images, return_tensors="pt")
34
 
 
57
  self,
58
  image_or_images: TypeInputImages,
59
  ) -> tuple[TypeClsToken, TypePatchTokens]:
60
+ inputs = self.preprocess(image_or_images).to(self.device)
61
  cls_token, patch_tokens_flat = self.encode(inputs)
62
  patch_tokens = self.reshape_patch_tokens(patch_tokens_flat)
63
  return cls_token, patch_tokens