Image Feature Extraction
import torch
from einops import rearrange
from jaxtyping import Float
from PIL import Image
from torch import Tensor
from torch import nn
from transformers import AutoImageProcessor
from transformers import AutoModel
from transformers.image_processing_base import BatchFeature


__version__ = "0.1.0"

TypeClsToken = Float[Tensor, "batch_size embed_dim"]
# jaxtyping splits shape strings on whitespace, so the einops-style group
# "(height width)" is not a valid dimension; use a single named axis instead.
TypePatchTokensFlat = Float[Tensor, "batch_size num_patches embed_dim"]
TypePatchTokens = Float[Tensor, "batch_size embed_dim height width"]
TypeInputImages = Image.Image | list[Image.Image]


class RadDino(nn.Module):
    """Thin wrapper around the microsoft/rad-dino image encoder."""

    _REPO = "microsoft/rad-dino"

    def __init__(self):
        super().__init__()
        # Load the pretrained backbone in eval mode, plus its matching image processor.
        self.model = AutoModel.from_pretrained(self._REPO).eval()
        self.processor = AutoImageProcessor.from_pretrained(self._REPO, use_fast=False)

    def preprocess(self, image_or_images: TypeInputImages) -> BatchFeature:
        return self.processor(image_or_images, return_tensors="pt")

    def encode(self, inputs: BatchFeature) -> tuple[TypeClsToken, TypePatchTokensFlat]:
        outputs = self.model(**inputs)
        # Token 0 of the last hidden state is the CLS token; the remaining
        # tokens are the flattened patch embeddings.
        cls_token = outputs.last_hidden_state[:, 0]
        patch_tokens = outputs.last_hidden_state[:, 1:]
        return cls_token, patch_tokens

    def reshape_patch_tokens(
        self,
        patch_tokens_flat: TypePatchTokensFlat,
    ) -> TypePatchTokens:
        # The processor crops to a square, so the patch grid is square with
        # input_size // patch_size patches per side.
        input_size = self.processor.crop_size["height"]
        patch_size = self.model.config.patch_size
        embeddings_size = input_size // patch_size
        patches_grid = rearrange(
            patch_tokens_flat,
            "batch (height width) embed_dim -> batch embed_dim height width",
            height=embeddings_size,
        )
        return patches_grid

    @torch.inference_mode()
    def extract_features(
        self,
        image_or_images: TypeInputImages,
    ) -> tuple[TypeClsToken, TypePatchTokens]:
        inputs = self.preprocess(image_or_images)
        cls_token, patch_tokens_flat = self.encode(inputs)
        patch_tokens = self.reshape_patch_tokens(patch_tokens_flat)
        return cls_token, patch_tokens

    def extract_cls_token(self, image_or_images: TypeInputImages) -> TypeClsToken:
        cls_token, _ = self.extract_features(image_or_images)
        return cls_token

    def extract_patch_tokens(self, image_or_images: TypeInputImages) -> TypePatchTokens:
        _, patch_tokens = self.extract_features(image_or_images)
        return patch_tokens

    def forward(self, image_or_images: TypeInputImages) -> tuple[TypeClsToken, TypePatchTokens]:
        return self.extract_features(image_or_images)
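

if __name__ == "__main__":
    # Usage sketch, not part of the original module. "chest_xray.png" is a
    # placeholder path (an assumption); any RGB image will do. Instantiating
    # RadDino() downloads the weights from the Hugging Face Hub on first run.
    rad_dino = RadDino()
    image = Image.open("chest_xray.png").convert("RGB")
    cls_token, patch_tokens = rad_dino.extract_features(image)
    print(cls_token.shape)     # e.g. torch.Size([1, 768]) for the ViT-B backbone
    print(patch_tokens.shape)  # e.g. torch.Size([1, 768, 37, 37]) at 518 px input, patch size 14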