Move inputs to model's device
Browse files- src/rad_dino/__init__.py +6 -2
src/rad_dino/__init__.py
CHANGED
@@ -6,7 +6,7 @@ from torch import Tensor
|
|
6 |
from torch import nn
|
7 |
from transformers import AutoImageProcessor
|
8 |
from transformers import AutoModel
|
9 |
-
from transformers.
|
10 |
|
11 |
|
12 |
__version__ = "0.1.0"
|
@@ -25,6 +25,10 @@ class RadDino(nn.Module):
|
|
25 |
self.model = AutoModel.from_pretrained(self._REPO).eval()
|
26 |
self.processor = AutoImageProcessor.from_pretrained(self._REPO, use_fast=False)
|
27 |
|
|
|
|
|
|
|
|
|
28 |
def preprocess(self, image_or_images: TypeInputImages) -> BatchFeature:
|
29 |
return self.processor(image_or_images, return_tensors="pt")
|
30 |
|
@@ -53,7 +57,7 @@ class RadDino(nn.Module):
|
|
53 |
self,
|
54 |
image_or_images: TypeInputImages,
|
55 |
) -> tuple[TypeClsToken, TypePatchTokens]:
|
56 |
-
inputs = self.preprocess(image_or_images)
|
57 |
cls_token, patch_tokens_flat = self.encode(inputs)
|
58 |
patch_tokens = self.reshape_patch_tokens(patch_tokens_flat)
|
59 |
return cls_token, patch_tokens
|
|
|
6 |
from torch import nn
|
7 |
from transformers import AutoImageProcessor
|
8 |
from transformers import AutoModel
|
9 |
+
from transformers.feature_extraction_utils import BatchFeature
|
10 |
|
11 |
|
12 |
__version__ = "0.1.0"
|
|
|
25 |
self.model = AutoModel.from_pretrained(self._REPO).eval()
|
26 |
self.processor = AutoImageProcessor.from_pretrained(self._REPO, use_fast=False)
|
27 |
|
28 |
+
@property
|
29 |
+
def device(self) -> torch.device:
|
30 |
+
return next(self.model.parameters()).device
|
31 |
+
|
32 |
def preprocess(self, image_or_images: TypeInputImages) -> BatchFeature:
|
33 |
return self.processor(image_or_images, return_tensors="pt")
|
34 |
|
|
|
57 |
self,
|
58 |
image_or_images: TypeInputImages,
|
59 |
) -> tuple[TypeClsToken, TypePatchTokens]:
|
60 |
+
inputs = self.preprocess(image_or_images).to(self.device)
|
61 |
cls_token, patch_tokens_flat = self.encode(inputs)
|
62 |
patch_tokens = self.reshape_patch_tokens(patch_tokens_flat)
|
63 |
return cls_token, patch_tokens
|