from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import AutoConfig, AutoModel, AutoProcessor


class VisionTransformer(nn.Module):
    """Huggingface AutoModel to generate token (patch) embeddings for images.

    Loads the correct vision model class, e.g. ViT / BEiT etc.

    Args:
        model_name_or_path: Huggingface model name
            (https://huggingface.co/models)
        model_args: Keyword arguments passed to the Huggingface
            Transformers model
        tokenizer_args: Keyword arguments passed to the Huggingface
            Transformers processor
        config_args: Keyword arguments passed to the Huggingface
            Transformers config
        cache_dir: Cache dir for Huggingface Transformers to store/load
            models
    """

    def __init__(
        self,
        model_name_or_path: str,
        model_args: Optional[Dict[str, Any]] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        config_args: Optional[Dict[str, Any]] = None,
        cache_dir: Optional[str] = None,
    ) -> None:
        super().__init__()
        if model_args is None:
            model_args = {}
        if tokenizer_args is None:
            tokenizer_args = {}
        if config_args is None:
            config_args = {}
        # Keys exported by get_config_dict(); empty here because all other
        # state is restored from the saved Huggingface config and processor.
        self.config_keys: List[str] = []
        self.config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
        self.model = AutoModel.from_pretrained(model_name_or_path, config=self.config, **model_args, cache_dir=cache_dir)
        self.processor = AutoProcessor.from_pretrained(model_name_or_path, config=self.config, **tokenizer_args, cache_dir=cache_dir)

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Computes patch embeddings and adds them to features as "token_embeddings"."""
        # Last hidden state of the vision model: (batch, num_patches [+ class token], hidden_size).
        output_states = self.model(pixel_values=features["pixel_values"], return_dict=False)[0]
        features.update({"token_embeddings": output_states})
        return features

    def get_word_embedding_dimension(self) -> int:
        return self.config.hidden_size

    def tokenize(
        self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]], padding: Union[str, bool] = True
    ) -> Dict[str, torch.Tensor]:
        # For a vision model the inputs are images and the processor returns
        # "pixel_values"; `padding` is kept for interface compatibility but is
        # not forwarded to the image processor.
        return self.processor(texts, return_tensors="pt")

    def get_config_dict(self) -> Dict[str, Any]:
        return {key: self.__dict__[key] for key in self.config_keys}

    def save(self, output_path: str, safe_serialization: bool = True) -> None:
        self.model.save_pretrained(output_path, safe_serialization=safe_serialization)
        self.processor.save_pretrained(output_path)
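

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module). It assumes the
# "google/vit-base-patch16-224-in21k" checkpoint and Pillow for image loading;
# the checkpoint choice and file path are hypothetical, and any AutoModel
# vision checkpoint with a matching AutoProcessor should behave the same way.
if __name__ == "__main__":
    from PIL import Image

    vit = VisionTransformer("google/vit-base-patch16-224-in21k")
    image = Image.open("example.jpg")  # hypothetical local image
    features = vit.tokenize([image])   # processor output: {"pixel_values": FloatTensor}
    with torch.no_grad():
        features = vit(features)
    # One embedding per patch, plus the class token for ViT-style models.
    print(features["token_embeddings"].shape)   # e.g. torch.Size([1, 197, 768])
    print(vit.get_word_embedding_dimension())   # e.g. 768
    vit.save("vit-module")  # writes model weights and processor config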