from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import AutoConfig, AutoModel, AutoProcessor


class VisionTransformer(nn.Module):
    """Huggingface AutoModel to generate token embeddings.
    Loads the correct class, e.g. BERT / RoBERTa etc.

    Args:
        model_name_or_path: Huggingface models name
            (https://huggingface.co/models)
        model_args: Keyword arguments passed to the Huggingface
            Transformers model
        tokenizer_args: Keyword arguments passed to the Huggingface
            Transformers processor
        config_args: Keyword arguments passed to the Huggingface
            Transformers config
        cache_dir: Cache dir for Huggingface Transformers to store/load
            models
    """

    def __init__(
        self,
        model_name_or_path: str,
        model_args: Optional[Dict[str, Any]] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        config_args: Optional[Dict[str, Any]] = None,
        cache_dir: Optional[str] = None,
    ) -> None:
        super().__init__()
        if model_args is None:
            model_args = {}
        if tokenizer_args is None:
            tokenizer_args = {}
        if config_args is None:
            config_args = {}
        # Attribute names serialized by get_config_dict(); this module stores no
        # extra configuration beyond the Huggingface config.
        self.config_keys: List[str] = []

        self.config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
        self.model = AutoModel.from_pretrained(model_name_or_path, config=self.config, **model_args, cache_dir=cache_dir)
        self.processor = AutoProcessor.from_pretrained(model_name_or_path, config=self.config, **tokenizer_args, cache_dir=cache_dir)

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Returns token_embeddings, cls_token"""
        output_states = self.model(pixel_values=features["pixel_values"], return_dict=False)[0]
        features.update({"token_embeddings": output_states})
        return features

    def get_word_embedding_dimension(self) -> int:
        return self.config.hidden_size

    def tokenize(
        self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]], padding: Union[str, bool] = True
    ) -> Dict[str, torch.Tensor]:
        # Despite the text-oriented signature, the processor receives images here;
        # `padding` is unused by image-only processors.
        return self.processor(texts, return_tensors="pt")

    def get_config_dict(self) -> Dict[str, Any]:
        return {key: self.__dict__[key] for key in self.config_keys}

    def save(self, output_path: str, safe_serialization: bool = True) -> None:
        self.model.save_pretrained(output_path, safe_serialization=safe_serialization)
        self.processor.save_pretrained(output_path)
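

if __name__ == "__main__":
    # Minimal usage sketch (an illustration, not part of the original module):
    # load a ViT checkpoint, encode one image, and inspect the patch embeddings.
    # The checkpoint name "google/vit-base-patch16-224-in21k" and the random PIL
    # image are placeholder assumptions.
    import numpy as np
    from PIL import Image

    module = VisionTransformer("google/vit-base-patch16-224-in21k")
    image = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))

    features = module.tokenize([image])  # -> features containing "pixel_values"
    with torch.no_grad():
        features = module(features)

    # For ViT-base/16 at 224x224: (1, 197, 768) = 196 patch tokens + 1 CLS token
    print(features["token_embeddings"].shape)
    print(module.get_word_embedding_dimension())  # hidden size, e.g. 768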