diff --git a/.gitattributes b/.gitattributes index d0f505a61f41e741902b539b0cb5441e6943ca57..ad76a03cd43014294f4f6a6e9aa01d31de713043 100644 --- a/.gitattributes +++ b/.gitattributes @@ -59,3 +59,4 @@ torchrun.exe filter=lfs diff=lfs merge=lfs -text tqdm.exe filter=lfs diff=lfs merge=lfs -text transformers-cli.exe filter=lfs diff=lfs merge=lfs -text wheel.exe filter=lfs diff=lfs merge=lfs -text +tokenizers.pyd filter=lfs diff=lfs merge=lfs -text diff --git a/__init__.cpython-312.pyc b/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b572178461b799287473fa97edbaee42caf57ca Binary files /dev/null and b/__init__.cpython-312.pyc differ diff --git a/__init__.py b/__init__.py index a8b61af2c4b58ff14ab7e3b24bf22e8ec6a95da0..0094328283214245bae20d1f18b2ba3db716ed2c 100644 --- a/__init__.py +++ b/__init__.py @@ -1,10 +1,8 @@ -"""torchgen +# Generated content DO NOT EDIT +from .. import models -This module contains codegeneration utilities for PyTorch. It is used to -build PyTorch from source, but may also be used for out-of-tree projects -that extend PyTorch. - -Note well that we provide no BC guarantees for torchgen. If you're interested -in using torchgen and want the PyTorch team to be aware, please reach out -on GitHub. -""" +Model = models.Model +BPE = models.BPE +Unigram = models.Unigram +WordLevel = models.WordLevel +WordPiece = models.WordPiece diff --git a/__init__.pyi b/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..e5bff5b080ff41e145752fbebe1e4deb08e10ade --- /dev/null +++ b/__init__.pyi @@ -0,0 +1,591 @@ +# Generated content DO NOT EDIT +class Model: + """ + Base class for all models + + The model represents the actual tokenization algorithm. This is the part that + will contain and manage the learned vocabulary. + + This class cannot be constructed directly. Please use one of the concrete models. + """ + def get_trainer(self): + """ + Get the associated :class:`~tokenizers.trainers.Trainer` + + Retrieve the :class:`~tokenizers.trainers.Trainer` associated to this + :class:`~tokenizers.models.Model`. + + Returns: + :class:`~tokenizers.trainers.Trainer`: The Trainer used to train this model + """ + pass + + def id_to_token(self, id): + """ + Get the token associated to an ID + + Args: + id (:obj:`int`): + An ID to convert to a token + + Returns: + :obj:`str`: The token associated to the ID + """ + pass + + def save(self, folder, prefix): + """ + Save the current model + + Save the current model in the given folder, using the given prefix for the various + files that will get created. + Any file with the same name that already exists in this folder will be overwritten. + + Args: + folder (:obj:`str`): + The path to the target folder in which to save the various files + + prefix (:obj:`str`, `optional`): + An optional prefix, used to prefix each file name + + Returns: + :obj:`List[str]`: The list of saved files + """ + pass + + def token_to_id(self, tokens): + """ + Get the ID associated to a token + + Args: + token (:obj:`str`): + A token to convert to an ID + + Returns: + :obj:`int`: The ID associated to the token + """ + pass + + def tokenize(self, sequence): + """ + Tokenize a sequence + + Args: + sequence (:obj:`str`): + A sequence to tokenize + + Returns: + A :obj:`List` of :class:`~tokenizers.Token`: The generated tokens + """ + pass + +class BPE(Model): + """ + An implementation of the BPE (Byte-Pair Encoding) algorithm + + Args: + vocab (:obj:`Dict[str, int]`, `optional`): + A dictionary of string keys and their ids :obj:`{"am": 0,...}` + + merges (:obj:`List[Tuple[str, str]]`, `optional`): + A list of pairs of tokens (:obj:`Tuple[str, str]`) :obj:`[("a", "b"),...]` + + cache_capacity (:obj:`int`, `optional`): + The number of words that the BPE cache can contain. The cache allows + to speed-up the process by keeping the result of the merge operations + for a number of words. + + dropout (:obj:`float`, `optional`): + A float between 0 and 1 that represents the BPE dropout to use. + + unk_token (:obj:`str`, `optional`): + The unknown token to be used by the model. + + continuing_subword_prefix (:obj:`str`, `optional`): + The prefix to attach to subword units that don't represent a beginning of word. + + end_of_word_suffix (:obj:`str`, `optional`): + The suffix to attach to subword units that represent an end of word. + + fuse_unk (:obj:`bool`, `optional`): + Whether to fuse any subsequent unknown tokens into a single one + + byte_fallback (:obj:`bool`, `optional`): + Whether to use spm byte-fallback trick (defaults to False) + + ignore_merges (:obj:`bool`, `optional`): + Whether or not to match tokens with the vocab before using merges. + """ + def __init__( + self, + vocab=None, + merges=None, + cache_capacity=None, + dropout=None, + unk_token=None, + continuing_subword_prefix=None, + end_of_word_suffix=None, + fuse_unk=None, + byte_fallback=False, + ignore_merges=False, + ): + pass + + @staticmethod + def from_file(cls, vocab, merge, **kwargs): + """ + Instantiate a BPE model from the given files. + + This method is roughly equivalent to doing:: + + vocab, merges = BPE.read_file(vocab_filename, merges_filename) + bpe = BPE(vocab, merges) + + If you don't need to keep the :obj:`vocab, merges` values lying around, + this method is more optimized than manually calling + :meth:`~tokenizers.models.BPE.read_file` to initialize a :class:`~tokenizers.models.BPE` + + Args: + vocab (:obj:`str`): + The path to a :obj:`vocab.json` file + + merges (:obj:`str`): + The path to a :obj:`merges.txt` file + + Returns: + :class:`~tokenizers.models.BPE`: An instance of BPE loaded from these files + """ + pass + + def get_trainer(self): + """ + Get the associated :class:`~tokenizers.trainers.Trainer` + + Retrieve the :class:`~tokenizers.trainers.Trainer` associated to this + :class:`~tokenizers.models.Model`. + + Returns: + :class:`~tokenizers.trainers.Trainer`: The Trainer used to train this model + """ + pass + + def id_to_token(self, id): + """ + Get the token associated to an ID + + Args: + id (:obj:`int`): + An ID to convert to a token + + Returns: + :obj:`str`: The token associated to the ID + """ + pass + + @staticmethod + def read_file(self, vocab, merges): + """ + Read a :obj:`vocab.json` and a :obj:`merges.txt` files + + This method provides a way to read and parse the content of these files, + returning the relevant data structures. If you want to instantiate some BPE models + from memory, this method gives you the expected input from the standard files. + + Args: + vocab (:obj:`str`): + The path to a :obj:`vocab.json` file + + merges (:obj:`str`): + The path to a :obj:`merges.txt` file + + Returns: + A :obj:`Tuple` with the vocab and the merges: + The vocabulary and merges loaded into memory + """ + pass + + def save(self, folder, prefix): + """ + Save the current model + + Save the current model in the given folder, using the given prefix for the various + files that will get created. + Any file with the same name that already exists in this folder will be overwritten. + + Args: + folder (:obj:`str`): + The path to the target folder in which to save the various files + + prefix (:obj:`str`, `optional`): + An optional prefix, used to prefix each file name + + Returns: + :obj:`List[str]`: The list of saved files + """ + pass + + def token_to_id(self, tokens): + """ + Get the ID associated to a token + + Args: + token (:obj:`str`): + A token to convert to an ID + + Returns: + :obj:`int`: The ID associated to the token + """ + pass + + def tokenize(self, sequence): + """ + Tokenize a sequence + + Args: + sequence (:obj:`str`): + A sequence to tokenize + + Returns: + A :obj:`List` of :class:`~tokenizers.Token`: The generated tokens + """ + pass + +class Unigram(Model): + """ + An implementation of the Unigram algorithm + + Args: + vocab (:obj:`List[Tuple[str, float]]`, `optional`, `optional`): + A list of vocabulary items and their relative score [("am", -0.2442),...] + """ + def __init__(self, vocab, unk_id, byte_fallback): + pass + + def get_trainer(self): + """ + Get the associated :class:`~tokenizers.trainers.Trainer` + + Retrieve the :class:`~tokenizers.trainers.Trainer` associated to this + :class:`~tokenizers.models.Model`. + + Returns: + :class:`~tokenizers.trainers.Trainer`: The Trainer used to train this model + """ + pass + + def id_to_token(self, id): + """ + Get the token associated to an ID + + Args: + id (:obj:`int`): + An ID to convert to a token + + Returns: + :obj:`str`: The token associated to the ID + """ + pass + + def save(self, folder, prefix): + """ + Save the current model + + Save the current model in the given folder, using the given prefix for the various + files that will get created. + Any file with the same name that already exists in this folder will be overwritten. + + Args: + folder (:obj:`str`): + The path to the target folder in which to save the various files + + prefix (:obj:`str`, `optional`): + An optional prefix, used to prefix each file name + + Returns: + :obj:`List[str]`: The list of saved files + """ + pass + + def token_to_id(self, tokens): + """ + Get the ID associated to a token + + Args: + token (:obj:`str`): + A token to convert to an ID + + Returns: + :obj:`int`: The ID associated to the token + """ + pass + + def tokenize(self, sequence): + """ + Tokenize a sequence + + Args: + sequence (:obj:`str`): + A sequence to tokenize + + Returns: + A :obj:`List` of :class:`~tokenizers.Token`: The generated tokens + """ + pass + +class WordLevel(Model): + """ + An implementation of the WordLevel algorithm + + Most simple tokenizer model based on mapping tokens to their corresponding id. + + Args: + vocab (:obj:`str`, `optional`): + A dictionary of string keys and their ids :obj:`{"am": 0,...}` + + unk_token (:obj:`str`, `optional`): + The unknown token to be used by the model. + """ + def __init__(self, vocab, unk_token): + pass + + @staticmethod + def from_file(vocab, unk_token): + """ + Instantiate a WordLevel model from the given file + + This method is roughly equivalent to doing:: + + vocab = WordLevel.read_file(vocab_filename) + wordlevel = WordLevel(vocab) + + If you don't need to keep the :obj:`vocab` values lying around, this method is + more optimized than manually calling :meth:`~tokenizers.models.WordLevel.read_file` to + initialize a :class:`~tokenizers.models.WordLevel` + + Args: + vocab (:obj:`str`): + The path to a :obj:`vocab.json` file + + Returns: + :class:`~tokenizers.models.WordLevel`: An instance of WordLevel loaded from file + """ + pass + + def get_trainer(self): + """ + Get the associated :class:`~tokenizers.trainers.Trainer` + + Retrieve the :class:`~tokenizers.trainers.Trainer` associated to this + :class:`~tokenizers.models.Model`. + + Returns: + :class:`~tokenizers.trainers.Trainer`: The Trainer used to train this model + """ + pass + + def id_to_token(self, id): + """ + Get the token associated to an ID + + Args: + id (:obj:`int`): + An ID to convert to a token + + Returns: + :obj:`str`: The token associated to the ID + """ + pass + + @staticmethod + def read_file(vocab): + """ + Read a :obj:`vocab.json` + + This method provides a way to read and parse the content of a vocabulary file, + returning the relevant data structures. If you want to instantiate some WordLevel models + from memory, this method gives you the expected input from the standard files. + + Args: + vocab (:obj:`str`): + The path to a :obj:`vocab.json` file + + Returns: + :obj:`Dict[str, int]`: The vocabulary as a :obj:`dict` + """ + pass + + def save(self, folder, prefix): + """ + Save the current model + + Save the current model in the given folder, using the given prefix for the various + files that will get created. + Any file with the same name that already exists in this folder will be overwritten. + + Args: + folder (:obj:`str`): + The path to the target folder in which to save the various files + + prefix (:obj:`str`, `optional`): + An optional prefix, used to prefix each file name + + Returns: + :obj:`List[str]`: The list of saved files + """ + pass + + def token_to_id(self, tokens): + """ + Get the ID associated to a token + + Args: + token (:obj:`str`): + A token to convert to an ID + + Returns: + :obj:`int`: The ID associated to the token + """ + pass + + def tokenize(self, sequence): + """ + Tokenize a sequence + + Args: + sequence (:obj:`str`): + A sequence to tokenize + + Returns: + A :obj:`List` of :class:`~tokenizers.Token`: The generated tokens + """ + pass + +class WordPiece(Model): + """ + An implementation of the WordPiece algorithm + + Args: + vocab (:obj:`Dict[str, int]`, `optional`): + A dictionary of string keys and their ids :obj:`{"am": 0,...}` + + unk_token (:obj:`str`, `optional`): + The unknown token to be used by the model. + + max_input_chars_per_word (:obj:`int`, `optional`): + The maximum number of characters to authorize in a single word. + """ + def __init__(self, vocab, unk_token, max_input_chars_per_word): + pass + + @staticmethod + def from_file(vocab, **kwargs): + """ + Instantiate a WordPiece model from the given file + + This method is roughly equivalent to doing:: + + vocab = WordPiece.read_file(vocab_filename) + wordpiece = WordPiece(vocab) + + If you don't need to keep the :obj:`vocab` values lying around, this method is + more optimized than manually calling :meth:`~tokenizers.models.WordPiece.read_file` to + initialize a :class:`~tokenizers.models.WordPiece` + + Args: + vocab (:obj:`str`): + The path to a :obj:`vocab.txt` file + + Returns: + :class:`~tokenizers.models.WordPiece`: An instance of WordPiece loaded from file + """ + pass + + def get_trainer(self): + """ + Get the associated :class:`~tokenizers.trainers.Trainer` + + Retrieve the :class:`~tokenizers.trainers.Trainer` associated to this + :class:`~tokenizers.models.Model`. + + Returns: + :class:`~tokenizers.trainers.Trainer`: The Trainer used to train this model + """ + pass + + def id_to_token(self, id): + """ + Get the token associated to an ID + + Args: + id (:obj:`int`): + An ID to convert to a token + + Returns: + :obj:`str`: The token associated to the ID + """ + pass + + @staticmethod + def read_file(vocab): + """ + Read a :obj:`vocab.txt` file + + This method provides a way to read and parse the content of a standard `vocab.txt` + file as used by the WordPiece Model, returning the relevant data structures. If you + want to instantiate some WordPiece models from memory, this method gives you the + expected input from the standard files. + + Args: + vocab (:obj:`str`): + The path to a :obj:`vocab.txt` file + + Returns: + :obj:`Dict[str, int]`: The vocabulary as a :obj:`dict` + """ + pass + + def save(self, folder, prefix): + """ + Save the current model + + Save the current model in the given folder, using the given prefix for the various + files that will get created. + Any file with the same name that already exists in this folder will be overwritten. + + Args: + folder (:obj:`str`): + The path to the target folder in which to save the various files + + prefix (:obj:`str`, `optional`): + An optional prefix, used to prefix each file name + + Returns: + :obj:`List[str]`: The list of saved files + """ + pass + + def token_to_id(self, tokens): + """ + Get the ID associated to a token + + Args: + token (:obj:`str`): + A token to convert to an ID + + Returns: + :obj:`int`: The ID associated to the token + """ + pass + + def tokenize(self, sequence): + """ + Tokenize a sequence + + Args: + sequence (:obj:`str`): + A sequence to tokenize + + Returns: + A :obj:`List` of :class:`~tokenizers.Token`: The generated tokens + """ + pass diff --git a/activations.py b/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..2355fb5fed678d0de6e2c53f52644a35a691a34e --- /dev/null +++ b/activations.py @@ -0,0 +1,239 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections import OrderedDict + +import torch +from packaging import version +from torch import Tensor, nn + +from .utils import logging + + +logger = logging.get_logger(__name__) + + +class PytorchGELUTanh(nn.Module): + """ + A fast C implementation of the tanh approximation of the GeLU activation function. See + https://arxiv.org/abs/1606.08415. + + This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical + match due to rounding errors. + """ + + def __init__(self): + super().__init__() + if version.parse(torch.__version__) < version.parse("1.12.0"): + raise ImportError( + f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use " + "PytorchGELUTanh. Please upgrade torch." + ) + + def forward(self, input: Tensor) -> Tensor: + return nn.functional.gelu(input, approximate="tanh") + + +class NewGELUActivation(nn.Module): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see + the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + """ + + def forward(self, input: Tensor) -> Tensor: + return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) + + +class GELUActivation(nn.Module): + """ + Original Implementation of the GELU activation function in Google BERT repo when initially created. For + information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 + + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional + Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + """ + + def __init__(self, use_gelu_python: bool = False): + super().__init__() + if use_gelu_python: + self.act = self._gelu_python + else: + self.act = nn.functional.gelu + + def _gelu_python(self, input: Tensor) -> Tensor: + return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0))) + + def forward(self, input: Tensor) -> Tensor: + return self.act(input) + + +class FastGELUActivation(nn.Module): + """ + Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs + """ + + def forward(self, input: Tensor) -> Tensor: + return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input))) + + +class QuickGELUActivation(nn.Module): + """ + Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs + """ + + def forward(self, input: Tensor) -> Tensor: + return input * torch.sigmoid(1.702 * input) + + +class ClippedGELUActivation(nn.Module): + """ + Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as + it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to + https://arxiv.org/abs/2004.09602. + + Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when + initially created. + + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 + + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415 + """ + + def __init__(self, min: float, max: float): + if min > max: + raise ValueError(f"min should be < max (got min: {min}, max: {max})") + + super().__init__() + self.min = min + self.max = max + + def forward(self, x: Tensor) -> Tensor: + return torch.clip(gelu(x), self.min, self.max) + + +class AccurateGELUActivation(nn.Module): + """ + Applies GELU approximation that is faster than default and more accurate than QuickGELU. See: + https://github.com/hendrycks/GELUs + + Implemented along with MEGA (Moving Average Equipped Gated Attention) + """ + + def __init__(self): + super().__init__() + self.precomputed_constant = math.sqrt(2 / math.pi) + + def forward(self, input: Tensor) -> Tensor: + return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3)))) + + +class MishActivation(nn.Module): + """ + See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also + visit the official repository for the paper: https://github.com/digantamisra98/Mish + """ + + def __init__(self): + super().__init__() + if version.parse(torch.__version__) < version.parse("1.9.0"): + self.act = self._mish_python + else: + self.act = nn.functional.mish + + def _mish_python(self, input: Tensor) -> Tensor: + return input * torch.tanh(nn.functional.softplus(input)) + + def forward(self, input: Tensor) -> Tensor: + return self.act(input) + + +class LinearActivation(nn.Module): + """ + Applies the linear activation function, i.e. forwarding input directly to output. + """ + + def forward(self, input: Tensor) -> Tensor: + return input + + +class LaplaceActivation(nn.Module): + """ + Applies elementwise activation based on Laplace function, introduced in MEGA as an attention activation. See + https://arxiv.org/abs/2209.10655 + + Inspired by squared relu, but with bounded range and gradient for better stability + """ + + def forward(self, input, mu=0.707107, sigma=0.282095): + input = (input - mu).div(sigma * math.sqrt(2.0)) + return 0.5 * (1.0 + torch.erf(input)) + + +class ReLUSquaredActivation(nn.Module): + """ + Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2 + """ + + def forward(self, input): + relu_applied = nn.functional.relu(input) + squared = torch.square(relu_applied) + return squared + + +class ClassInstantier(OrderedDict): + def __getitem__(self, key): + content = super().__getitem__(key) + cls, kwargs = content if isinstance(content, tuple) else (content, {}) + return cls(**kwargs) + + +ACT2CLS = { + "gelu": GELUActivation, + "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}), + "gelu_fast": FastGELUActivation, + "gelu_new": NewGELUActivation, + "gelu_python": (GELUActivation, {"use_gelu_python": True}), + "gelu_pytorch_tanh": PytorchGELUTanh, + "gelu_accurate": AccurateGELUActivation, + "laplace": LaplaceActivation, + "leaky_relu": nn.LeakyReLU, + "linear": LinearActivation, + "mish": MishActivation, + "quick_gelu": QuickGELUActivation, + "relu": nn.ReLU, + "relu2": ReLUSquaredActivation, + "relu6": nn.ReLU6, + "sigmoid": nn.Sigmoid, + "silu": nn.SiLU, + "swish": nn.SiLU, + "tanh": nn.Tanh, +} +ACT2FN = ClassInstantier(ACT2CLS) + + +def get_activation(activation_string): + if activation_string in ACT2FN: + return ACT2FN[activation_string] + else: + raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}") + + +# For backwards compatibility with: from activations import gelu_python +gelu_python = get_activation("gelu_python") +gelu_new = get_activation("gelu_new") +gelu = get_activation("gelu") +gelu_fast = get_activation("gelu_fast") +quick_gelu = get_activation("quick_gelu") +silu = get_activation("silu") +mish = get_activation("mish") +linear_act = get_activation("linear") diff --git a/activations_tf.py b/activations_tf.py new file mode 100644 index 0000000000000000000000000000000000000000..d12b73ea45176f3a4bc42cdabe8b73078a3b90f2 --- /dev/null +++ b/activations_tf.py @@ -0,0 +1,147 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import tensorflow as tf +from packaging.version import parse + + +try: + import tf_keras as keras +except (ModuleNotFoundError, ImportError): + import keras + + if parse(keras.__version__).major > 2: + raise ValueError( + "Your currently installed version of Keras is Keras 3, but this is not yet supported in " + "Transformers. Please install the backwards-compatible tf-keras package with " + "`pip install tf-keras`." + ) + + +def _gelu(x): + """ + Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when + initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see + https://arxiv.org/abs/1606.08415 + """ + x = tf.convert_to_tensor(x) + cdf = 0.5 * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype))) + + return x * cdf + + +def _gelu_new(x): + """ + Gaussian Error Linear Unit. This is a smoother version of the GELU. Original paper: https://arxiv.org/abs/1606.0841 + + Args: + x: float Tensor to perform activation + + Returns: + `x` with the GELU activation applied. + """ + x = tf.convert_to_tensor(x) + pi = tf.cast(math.pi, x.dtype) + coeff = tf.cast(0.044715, x.dtype) + cdf = 0.5 * (1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3)))) + + return x * cdf + + +def mish(x): + x = tf.convert_to_tensor(x) + + return x * tf.tanh(tf.math.softplus(x)) + + +def gelu_fast(x): + x = tf.convert_to_tensor(x) + coeff1 = tf.cast(0.044715, x.dtype) + coeff2 = tf.cast(0.7978845608, x.dtype) + + return 0.5 * x * (1.0 + tf.tanh(x * coeff2 * (1.0 + coeff1 * x * x))) + + +def quick_gelu(x): + x = tf.convert_to_tensor(x) + coeff = tf.cast(1.702, x.dtype) + return x * tf.math.sigmoid(coeff * x) + + +def gelu_10(x): + """ + Clip the range of possible GeLU outputs between [-10, 10]. This is especially useful for quantization purpose, as + it allows mapping 2 negatives values in the GeLU spectrum. For more information on this trick, please refer to + https://arxiv.org/abs/2004.09602 + + Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when + initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see + https://arxiv.org/abs/1606.08415 :param x: :return: + """ + return tf.clip_by_value(_gelu(x), -10, 10) + + +def glu(x, axis=-1): + """ + Gated Linear Unit. Implementation as defined in the original paper (see https://arxiv.org/abs/1612.08083), where + the input `x` is split in two halves across a dimension (`axis`), A and B, returning A * sigmoid(B). + + Args: + `x`: float Tensor to perform activation + `axis`: dimension across which `x` be split in half + + Returns: + `x` with the GLU activation applied (with its size halved across the dimension `axis`). + """ + a, b = tf.split(x, 2, axis=axis) + return a * tf.math.sigmoid(b) + + +if parse(tf.version.VERSION) >= parse("2.4"): + + def approximate_gelu_wrap(x): + return keras.activations.gelu(x, approximate=True) + + gelu = keras.activations.gelu + gelu_new = approximate_gelu_wrap +else: + gelu = _gelu + gelu_new = _gelu_new + + +ACT2FN = { + "gelu": gelu, + "gelu_10": gelu_10, + "gelu_fast": gelu_fast, + "gelu_new": gelu_new, + "glu": glu, + "mish": mish, + "quick_gelu": quick_gelu, + "relu": keras.activations.relu, + "sigmoid": keras.activations.sigmoid, + "silu": keras.activations.swish, + "swish": keras.activations.swish, + "tanh": keras.activations.tanh, +} + + +def get_tf_activation(activation_string): + if activation_string in ACT2FN: + return ACT2FN[activation_string] + else: + raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}") diff --git a/audio_utils.py b/audio_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b4f11287f309cf1437b84928b6721052b7ba4531 --- /dev/null +++ b/audio_utils.py @@ -0,0 +1,1123 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team and the librosa & torchaudio authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Audio processing functions to extract features from audio waveforms. This code is pure numpy to support all frameworks +and remove unnecessary dependencies. +""" + +import warnings +from typing import List, Optional, Tuple, Union + +import numpy as np + + +def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]: + """ + Convert frequency from hertz to mels. + + Args: + freq (`float` or `np.ndarray`): + The frequency, or multiple frequencies, in hertz (Hz). + mel_scale (`str`, *optional*, defaults to `"htk"`): + The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`. + + Returns: + `float` or `np.ndarray`: The frequencies on the mel scale. + """ + + if mel_scale not in ["slaney", "htk", "kaldi"]: + raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".') + + if mel_scale == "htk": + return 2595.0 * np.log10(1.0 + (freq / 700.0)) + elif mel_scale == "kaldi": + return 1127.0 * np.log(1.0 + (freq / 700.0)) + + min_log_hertz = 1000.0 + min_log_mel = 15.0 + logstep = 27.0 / np.log(6.4) + mels = 3.0 * freq / 200.0 + + if isinstance(freq, np.ndarray): + log_region = freq >= min_log_hertz + mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep + elif freq >= min_log_hertz: + mels = min_log_mel + np.log(freq / min_log_hertz) * logstep + + return mels + + +def mel_to_hertz(mels: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]: + """ + Convert frequency from mels to hertz. + + Args: + mels (`float` or `np.ndarray`): + The frequency, or multiple frequencies, in mels. + mel_scale (`str`, *optional*, `"htk"`): + The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`. + + Returns: + `float` or `np.ndarray`: The frequencies in hertz. + """ + + if mel_scale not in ["slaney", "htk", "kaldi"]: + raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".') + + if mel_scale == "htk": + return 700.0 * (np.power(10, mels / 2595.0) - 1.0) + elif mel_scale == "kaldi": + return 700.0 * (np.exp(mels / 1127.0) - 1.0) + + min_log_hertz = 1000.0 + min_log_mel = 15.0 + logstep = np.log(6.4) / 27.0 + freq = 200.0 * mels / 3.0 + + if isinstance(mels, np.ndarray): + log_region = mels >= min_log_mel + freq[log_region] = min_log_hertz * np.exp(logstep * (mels[log_region] - min_log_mel)) + elif mels >= min_log_mel: + freq = min_log_hertz * np.exp(logstep * (mels - min_log_mel)) + + return freq + + +def hertz_to_octave( + freq: Union[float, np.ndarray], tuning: Optional[float] = 0.0, bins_per_octave: Optional[int] = 12 +): + """ + Convert frequency from hertz to fractional octave numbers. + Adapted from *librosa*. + + Args: + freq (`float` or `np.ndarray`): + The frequency, or multiple frequencies, in hertz (Hz). + tuning (`float`, defaults to `0.`): + Tuning deviation from the Stuttgart pitch (A440) in (fractional) bins per octave. + bins_per_octave (`int`, defaults to `12`): + Number of bins per octave. + + Returns: + `float` or `np.ndarray`: The frequencies on the octave scale. + """ + stuttgart_pitch = 440.0 * 2.0 ** (tuning / bins_per_octave) + octave = np.log2(freq / (float(stuttgart_pitch) / 16)) + return octave + + +def _create_triangular_filter_bank(fft_freqs: np.ndarray, filter_freqs: np.ndarray) -> np.ndarray: + """ + Creates a triangular filter bank. + + Adapted from *torchaudio* and *librosa*. + + Args: + fft_freqs (`np.ndarray` of shape `(num_frequency_bins,)`): + Discrete frequencies of the FFT bins in Hz. + filter_freqs (`np.ndarray` of shape `(num_mel_filters,)`): + Center frequencies of the triangular filters to create, in Hz. + + Returns: + `np.ndarray` of shape `(num_frequency_bins, num_mel_filters)` + """ + filter_diff = np.diff(filter_freqs) + slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1) + down_slopes = -slopes[:, :-2] / filter_diff[:-1] + up_slopes = slopes[:, 2:] / filter_diff[1:] + return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes)) + + +def chroma_filter_bank( + num_frequency_bins: int, + num_chroma: int, + sampling_rate: int, + tuning: float = 0.0, + power: Optional[float] = 2.0, + weighting_parameters: Optional[Tuple[float]] = (5.0, 2), + start_at_c_chroma: Optional[bool] = True, +): + """ + Creates a chroma filter bank, i.e a linear transformation to project spectrogram bins onto chroma bins. + + Adapted from *librosa*. + + Args: + num_frequency_bins (`int`): + Number of frequencies used to compute the spectrogram (should be the same as in `stft`). + num_chroma (`int`): + Number of chroma bins (i.e pitch classes). + sampling_rate (`float`): + Sample rate of the audio waveform. + tuning (`float`): + Tuning deviation from A440 in fractions of a chroma bin. + power (`float`, *optional*, defaults to 2.0): + If 12.0, normalizes each column with their L2 norm. If 1.0, normalizes each column with their L1 norm. + weighting_parameters (`Tuple[float]`, *optional*, defaults to `(5., 2.)`): + If specified, apply a Gaussian weighting parameterized by the first element of the tuple being the center and + the second element being the Gaussian half-width. + start_at_c_chroma (`float`, *optional*, defaults to `True`): + If True, the filter bank will start at the 'C' pitch class. Otherwise, it will start at 'A'. + Returns: + `np.ndarray` of shape `(num_frequency_bins, num_chroma)` + """ + # Get the FFT bins, not counting the DC component + frequencies = np.linspace(0, sampling_rate, num_frequency_bins, endpoint=False)[1:] + + freq_bins = num_chroma * hertz_to_octave(frequencies, tuning=tuning, bins_per_octave=num_chroma) + + # make up a value for the 0 Hz bin = 1.5 octaves below bin 1 + # (so chroma is 50% rotated from bin 1, and bin width is broad) + freq_bins = np.concatenate(([freq_bins[0] - 1.5 * num_chroma], freq_bins)) + + bins_width = np.concatenate((np.maximum(freq_bins[1:] - freq_bins[:-1], 1.0), [1])) + + chroma_filters = np.subtract.outer(freq_bins, np.arange(0, num_chroma, dtype="d")).T + + num_chroma2 = np.round(float(num_chroma) / 2) + + # Project into range -num_chroma/2 .. num_chroma/2 + # add on fixed offset of 10*num_chroma to ensure all values passed to + # rem are positive + chroma_filters = np.remainder(chroma_filters + num_chroma2 + 10 * num_chroma, num_chroma) - num_chroma2 + + # Gaussian bumps - 2*D to make them narrower + chroma_filters = np.exp(-0.5 * (2 * chroma_filters / np.tile(bins_width, (num_chroma, 1))) ** 2) + + # normalize each column + if power is not None: + chroma_filters = chroma_filters / np.sum(chroma_filters**power, axis=0, keepdims=True) ** (1.0 / power) + + # Maybe apply scaling for fft bins + if weighting_parameters is not None: + center, half_width = weighting_parameters + chroma_filters *= np.tile( + np.exp(-0.5 * (((freq_bins / num_chroma - center) / half_width) ** 2)), + (num_chroma, 1), + ) + + if start_at_c_chroma: + chroma_filters = np.roll(chroma_filters, -3 * (num_chroma // 12), axis=0) + + # remove aliasing columns, copy to ensure row-contiguity + return np.ascontiguousarray(chroma_filters[:, : int(1 + num_frequency_bins / 2)]) + + +def mel_filter_bank( + num_frequency_bins: int, + num_mel_filters: int, + min_frequency: float, + max_frequency: float, + sampling_rate: int, + norm: Optional[str] = None, + mel_scale: str = "htk", + triangularize_in_mel_space: bool = False, +) -> np.ndarray: + """ + Creates a frequency bin conversion matrix used to obtain a mel spectrogram. This is called a *mel filter bank*, and + various implementation exist, which differ in the number of filters, the shape of the filters, the way the filters + are spaced, the bandwidth of the filters, and the manner in which the spectrum is warped. The goal of these + features is to approximate the non-linear human perception of the variation in pitch with respect to the frequency. + + Different banks of mel filters were introduced in the literature. The following variations are supported: + + - MFCC FB-20: introduced in 1980 by Davis and Mermelstein, it assumes a sampling frequency of 10 kHz and a speech + bandwidth of `[0, 4600]` Hz. + - MFCC FB-24 HTK: from the Cambridge HMM Toolkit (HTK) (1995) uses a filter bank of 24 filters for a speech + bandwidth of `[0, 8000]` Hz. This assumes sampling rate ≥ 16 kHz. + - MFCC FB-40: from the Auditory Toolbox for MATLAB written by Slaney in 1998, assumes a sampling rate of 16 kHz and + speech bandwidth of `[133, 6854]` Hz. This version also includes area normalization. + - HFCC-E FB-29 (Human Factor Cepstral Coefficients) of Skowronski and Harris (2004), assumes a sampling rate of + 12.5 kHz and speech bandwidth of `[0, 6250]` Hz. + + This code is adapted from *torchaudio* and *librosa*. Note that the default parameters of torchaudio's + `melscale_fbanks` implement the `"htk"` filters while librosa uses the `"slaney"` implementation. + + Args: + num_frequency_bins (`int`): + Number of frequencies used to compute the spectrogram (should be the same as in `stft`). + num_mel_filters (`int`): + Number of mel filters to generate. + min_frequency (`float`): + Lowest frequency of interest in Hz. + max_frequency (`float`): + Highest frequency of interest in Hz. This should not exceed `sampling_rate / 2`. + sampling_rate (`int`): + Sample rate of the audio waveform. + norm (`str`, *optional*): + If `"slaney"`, divide the triangular mel weights by the width of the mel band (area normalization). + mel_scale (`str`, *optional*, defaults to `"htk"`): + The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`. + triangularize_in_mel_space (`bool`, *optional*, defaults to `False`): + If this option is enabled, the triangular filter is applied in mel space rather than frequency space. This + should be set to `true` in order to get the same results as `torchaudio` when computing mel filters. + + Returns: + `np.ndarray` of shape (`num_frequency_bins`, `num_mel_filters`): Triangular filter bank matrix. This is a + projection matrix to go from a spectrogram to a mel spectrogram. + """ + if norm is not None and norm != "slaney": + raise ValueError('norm must be one of None or "slaney"') + + # center points of the triangular mel filters + mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale) + mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale) + mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2) + filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale) + + if triangularize_in_mel_space: + # frequencies of FFT bins in Hz, but filters triangularized in mel space + fft_bin_width = sampling_rate / (num_frequency_bins * 2) + fft_freqs = hertz_to_mel(fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale) + filter_freqs = mel_freqs + else: + # frequencies of FFT bins in Hz + fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins) + + mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs) + + if norm is not None and norm == "slaney": + # Slaney-style mel is scaled to be approx constant energy per channel + enorm = 2.0 / (filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters]) + mel_filters *= np.expand_dims(enorm, 0) + + if (mel_filters.max(axis=0) == 0.0).any(): + warnings.warn( + "At least one mel filter has all zero values. " + f"The value for `num_mel_filters` ({num_mel_filters}) may be set too high. " + f"Or, the value for `num_frequency_bins` ({num_frequency_bins}) may be set too low." + ) + + return mel_filters + + +def optimal_fft_length(window_length: int) -> int: + """ + Finds the best FFT input size for a given `window_length`. This function takes a given window length and, if not + already a power of two, rounds it up to the next power or two. + + The FFT algorithm works fastest when the length of the input is a power of two, which may be larger than the size + of the window or analysis frame. For example, if the window is 400 samples, using an FFT input size of 512 samples + is more optimal than an FFT size of 400 samples. Using a larger FFT size does not affect the detected frequencies, + it simply gives a higher frequency resolution (i.e. the frequency bins are smaller). + """ + return 2 ** int(np.ceil(np.log2(window_length))) + + +def window_function( + window_length: int, + name: str = "hann", + periodic: bool = True, + frame_length: Optional[int] = None, + center: bool = True, +) -> np.ndarray: + """ + Returns an array containing the specified window. This window is intended to be used with `stft`. + + The following window types are supported: + + - `"boxcar"`: a rectangular window + - `"hamming"`: the Hamming window + - `"hann"`: the Hann window + - `"povey"`: the Povey window + + Args: + window_length (`int`): + The length of the window in samples. + name (`str`, *optional*, defaults to `"hann"`): + The name of the window function. + periodic (`bool`, *optional*, defaults to `True`): + Whether the window is periodic or symmetric. + frame_length (`int`, *optional*): + The length of the analysis frames in samples. Provide a value for `frame_length` if the window is smaller + than the frame length, so that it will be zero-padded. + center (`bool`, *optional*, defaults to `True`): + Whether to center the window inside the FFT buffer. Only used when `frame_length` is provided. + + Returns: + `np.ndarray` of shape `(window_length,)` or `(frame_length,)` containing the window. + """ + length = window_length + 1 if periodic else window_length + + if name == "boxcar": + window = np.ones(length) + elif name in ["hamming", "hamming_window"]: + window = np.hamming(length) + elif name in ["hann", "hann_window"]: + window = np.hanning(length) + elif name in ["povey"]: + window = np.power(np.hanning(length), 0.85) + else: + raise ValueError(f"Unknown window function '{name}'") + + if periodic: + window = window[:-1] + + if frame_length is None: + return window + + if window_length > frame_length: + raise ValueError( + f"Length of the window ({window_length}) may not be larger than frame_length ({frame_length})" + ) + + padded_window = np.zeros(frame_length) + offset = (frame_length - window_length) // 2 if center else 0 + padded_window[offset : offset + window_length] = window + return padded_window + + +# TODO This method does not support batching yet as we are mainly focused on inference. +def spectrogram( + waveform: np.ndarray, + window: np.ndarray, + frame_length: int, + hop_length: int, + fft_length: Optional[int] = None, + power: Optional[float] = 1.0, + center: bool = True, + pad_mode: str = "reflect", + onesided: bool = True, + preemphasis: Optional[float] = None, + mel_filters: Optional[np.ndarray] = None, + mel_floor: float = 1e-10, + log_mel: Optional[str] = None, + reference: float = 1.0, + min_value: float = 1e-10, + db_range: Optional[float] = None, + remove_dc_offset: Optional[bool] = None, + dtype: np.dtype = np.float32, +) -> np.ndarray: + """ + Calculates a spectrogram over one waveform using the Short-Time Fourier Transform. + + This function can create the following kinds of spectrograms: + + - amplitude spectrogram (`power = 1.0`) + - power spectrogram (`power = 2.0`) + - complex-valued spectrogram (`power = None`) + - log spectrogram (use `log_mel` argument) + - mel spectrogram (provide `mel_filters`) + - log-mel spectrogram (provide `mel_filters` and `log_mel`) + + How this works: + + 1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `frame_length + - hop_length` samples. + 2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`. + 3. The DFT is taken of each windowed frame. + 4. The results are stacked into a spectrogram. + + We make a distinction between the following "blocks" of sample data, each of which may have a different lengths: + + - The analysis frame. This is the size of the time slices that the input waveform is split into. + - The window. Each analysis frame is multiplied by the window to avoid spectral leakage. + - The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram. + + In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A + padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame, + typically the next power of two. + + Note: This function is not optimized for speed yet. It should be mostly compatible with `librosa.stft` and + `torchaudio.functional.transforms.Spectrogram`, although it is more flexible due to the different ways spectrograms + can be constructed. + + Args: + waveform (`np.ndarray` of shape `(length,)`): + The input waveform. This must be a single real-valued, mono waveform. + window (`np.ndarray` of shape `(frame_length,)`): + The windowing function to apply, including zero-padding if necessary. The actual window length may be + shorter than `frame_length`, but we're assuming the array has already been zero-padded. + frame_length (`int`): + The length of the analysis frames in samples. With librosa this is always equal to `fft_length` but we also + allow smaller sizes. + hop_length (`int`): + The stride between successive analysis frames in samples. + fft_length (`int`, *optional*): + The size of the FFT buffer in samples. This determines how many frequency bins the spectrogram will have. + For optimal speed, this should be a power of two. If `None`, uses `frame_length`. + power (`float`, *optional*, defaults to 1.0): + If 1.0, returns the amplitude spectrogram. If 2.0, returns the power spectrogram. If `None`, returns + complex numbers. + center (`bool`, *optional*, defaults to `True`): + Whether to pad the waveform so that frame `t` is centered around time `t * hop_length`. If `False`, frame + `t` will start at time `t * hop_length`. + pad_mode (`str`, *optional*, defaults to `"reflect"`): + Padding mode used when `center` is `True`. Possible values are: `"constant"` (pad with zeros), `"edge"` + (pad with edge values), `"reflect"` (pads with mirrored values). + onesided (`bool`, *optional*, defaults to `True`): + If True, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1` + frequency bins. If False, also computes the negative frequencies and returns `fft_length` frequency bins. + preemphasis (`float`, *optional*) + Coefficient for a low-pass filter that applies pre-emphasis before the DFT. + mel_filters (`np.ndarray` of shape `(num_freq_bins, num_mel_filters)`, *optional*): + The mel filter bank. If supplied, applies a this filter bank to create a mel spectrogram. + mel_floor (`float`, *optional*, defaults to 1e-10): + Minimum value of mel frequency banks. + log_mel (`str`, *optional*): + How to convert the spectrogram to log scale. Possible options are: `None` (don't convert), `"log"` (take + the natural logarithm) `"log10"` (take the base-10 logarithm), `"dB"` (convert to decibels). Can only be + used when `power` is not `None`. + reference (`float`, *optional*, defaults to 1.0): + Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set + the loudest part to 0 dB. Must be greater than zero. + min_value (`float`, *optional*, defaults to `1e-10`): + The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking + `log(0)`. For a power spectrogram, the default of `1e-10` corresponds to a minimum of -100 dB. For an + amplitude spectrogram, the value `1e-5` corresponds to -100 dB. Must be greater than zero. + db_range (`float`, *optional*): + Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the + peak value and the smallest value will never be more than 80 dB. Must be greater than zero. + remove_dc_offset (`bool`, *optional*): + Subtract mean from waveform on each frame, applied before pre-emphasis. This should be set to `true` in + order to get the same results as `torchaudio.compliance.kaldi.fbank` when computing mel filters. + dtype (`np.dtype`, *optional*, defaults to `np.float32`): + Data type of the spectrogram tensor. If `power` is None, this argument is ignored and the dtype will be + `np.complex64`. + + Returns: + `nd.array` containing a spectrogram of shape `(num_frequency_bins, length)` for a regular spectrogram or shape + `(num_mel_filters, length)` for a mel spectrogram. + """ + window_length = len(window) + + if fft_length is None: + fft_length = frame_length + + if frame_length > fft_length: + raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})") + + if window_length != frame_length: + raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})") + + if hop_length <= 0: + raise ValueError("hop_length must be greater than zero") + + if waveform.ndim != 1: + raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}") + + if np.iscomplexobj(waveform): + raise ValueError("Complex-valued input waveforms are not currently supported") + + if power is None and mel_filters is not None: + raise ValueError( + "You have provided `mel_filters` but `power` is `None`. Mel spectrogram computation is not yet supported for complex-valued spectrogram." + "Specify `power` to fix this issue." + ) + + # center pad the waveform + if center: + padding = [(int(frame_length // 2), int(frame_length // 2))] + waveform = np.pad(waveform, padding, mode=pad_mode) + + # promote to float64, since np.fft uses float64 internally + waveform = waveform.astype(np.float64) + window = window.astype(np.float64) + + # split waveform into frames of frame_length size + num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length)) + + num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length + spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64) + + # rfft is faster than fft + fft_func = np.fft.rfft if onesided else np.fft.fft + buffer = np.zeros(fft_length) + + timestep = 0 + for frame_idx in range(num_frames): + buffer[:frame_length] = waveform[timestep : timestep + frame_length] + + if remove_dc_offset: + buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean() + + if preemphasis is not None: + buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1] + buffer[0] *= 1 - preemphasis + + buffer[:frame_length] *= window + + spectrogram[frame_idx] = fft_func(buffer) + timestep += hop_length + + # note: ** is much faster than np.power + if power is not None: + spectrogram = np.abs(spectrogram, dtype=np.float64) ** power + + spectrogram = spectrogram.T + + if mel_filters is not None: + spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram)) + + if power is not None and log_mel is not None: + if log_mel == "log": + spectrogram = np.log(spectrogram) + elif log_mel == "log10": + spectrogram = np.log10(spectrogram) + elif log_mel == "dB": + if power == 1.0: + spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range) + elif power == 2.0: + spectrogram = power_to_db(spectrogram, reference, min_value, db_range) + else: + raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}") + else: + raise ValueError(f"Unknown log_mel option: {log_mel}") + + spectrogram = np.asarray(spectrogram, dtype) + + return spectrogram + + +def spectrogram_batch( + waveform_list: List[np.ndarray], + window: np.ndarray, + frame_length: int, + hop_length: int, + fft_length: Optional[int] = None, + power: Optional[float] = 1.0, + center: bool = True, + pad_mode: str = "reflect", + onesided: bool = True, + preemphasis: Optional[float] = None, + mel_filters: Optional[np.ndarray] = None, + mel_floor: float = 1e-10, + log_mel: Optional[str] = None, + reference: float = 1.0, + min_value: float = 1e-10, + db_range: Optional[float] = None, + remove_dc_offset: Optional[bool] = None, + dtype: np.dtype = np.float32, +) -> List[np.ndarray]: + """ + Calculates spectrograms for a list of waveforms using the Short-Time Fourier Transform, optimized for batch processing. + This function extends the capabilities of the `spectrogram` function to handle multiple waveforms efficiently by leveraging broadcasting. + + It supports generating various types of spectrograms: + + - amplitude spectrogram (`power = 1.0`) + - power spectrogram (`power = 2.0`) + - complex-valued spectrogram (`power = None`) + - log spectrogram (use `log_mel` argument) + - mel spectrogram (provide `mel_filters`) + - log-mel spectrogram (provide `mel_filters` and `log_mel`) + + How this works: + + 1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `frame_length + - hop_length` samples. + 2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`. + 3. The DFT is taken of each windowed frame. + 4. The results are stacked into a spectrogram. + + We make a distinction between the following "blocks" of sample data, each of which may have a different lengths: + + - The analysis frame. This is the size of the time slices that the input waveform is split into. + - The window. Each analysis frame is multiplied by the window to avoid spectral leakage. + - The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram. + + In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A + padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame, + typically the next power of two. + + Note: This function is designed for efficient batch processing of multiple waveforms but retains compatibility with individual waveform processing methods like `librosa.stft`. + + Args: + waveform_list (`List[np.ndarray]` with arrays of shape `(length,)`): + The list of input waveforms, each a single-channel (mono) signal. + window (`np.ndarray` of shape `(frame_length,)`): + The windowing function to apply, including zero-padding if necessary. + frame_length (`int`): + The length of each frame for analysis. + hop_length (`int`): + The step size between successive frames. + fft_length (`int`, *optional*): + The size of the FFT buffer, defining frequency bin resolution. + power (`float`, *optional*, defaults to 1.0): + Determines the type of spectrogram: 1.0 for amplitude, 2.0 for power, None for complex. + center (`bool`, *optional*, defaults to `True`): + Whether to center-pad the waveform frames. + pad_mode (`str`, *optional*, defaults to `"reflect"`): + The padding strategy when `center` is `True`. + onesided (`bool`, *optional*, defaults to `True`): + If True, returns a one-sided spectrogram for real input signals. + preemphasis (`float`, *optional*): + Applies a pre-emphasis filter to each frame. + mel_filters (`np.ndarray`, *optional*): + Mel filter bank for converting to mel spectrogram. + mel_floor (`float`, *optional*, defaults to 1e-10): + Floor value for mel spectrogram to avoid log(0). + log_mel (`str`, *optional*): + Specifies log scaling strategy; options are None, "log", "log10", "dB". + reference (`float`, *optional*, defaults to 1.0): + Reference value for dB conversion in log_mel. + min_value (`float`, *optional*, defaults to 1e-10): + Minimum floor value for log scale conversions. + db_range (`float`, *optional*): + Dynamic range for dB scale spectrograms. + remove_dc_offset (`bool`, *optional*): + Whether to remove the DC offset from each frame. + dtype (`np.dtype`, *optional*, defaults to `np.float32`): + Data type of the output spectrogram. + + Returns: + List[`np.ndarray`]: A list of spectrogram arrays, one for each input waveform. + """ + window_length = len(window) + + if fft_length is None: + fft_length = frame_length + + if frame_length > fft_length: + raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})") + + if window_length != frame_length: + raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})") + + if hop_length <= 0: + raise ValueError("hop_length must be greater than zero") + + # Check the dimensions of the waveform , and if waveform is complex + for waveform in waveform_list: + if waveform.ndim != 1: + raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}") + if np.iscomplexobj(waveform): + raise ValueError("Complex-valued input waveforms are not currently supported") + # Center pad the waveform + if center: + padding = [(int(frame_length // 2), int(frame_length // 2))] + waveform_list = [ + np.pad( + waveform, + padding, + mode=pad_mode, + ) + for waveform in waveform_list + ] + original_waveform_lengths = [ + len(waveform) for waveform in waveform_list + ] # these lengths will be used to remove padding later + + # Batch pad the waveform + max_length = max(original_waveform_lengths) + padded_waveform_batch = np.array( + [ + np.pad(waveform, (0, max_length - len(waveform)), mode="constant", constant_values=0) + for waveform in waveform_list + ], + dtype=dtype, + ) + + # Promote to float64, since np.fft uses float64 internally + padded_waveform_batch = padded_waveform_batch.astype(np.float64) + window = window.astype(np.float64) + + # Split waveform into frames of frame_length size + num_frames = int(1 + np.floor((padded_waveform_batch.shape[1] - frame_length) / hop_length)) + # these lengths will be used to remove padding later + true_num_frames = [int(1 + np.floor((length - frame_length) / hop_length)) for length in original_waveform_lengths] + num_batches = padded_waveform_batch.shape[0] + + num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length + spectrogram = np.empty((num_batches, num_frames, num_frequency_bins), dtype=np.complex64) + + # rfft is faster than fft + fft_func = np.fft.rfft if onesided else np.fft.fft + buffer = np.zeros((num_batches, fft_length)) + + for frame_idx in range(num_frames): + timestep = frame_idx * hop_length + buffer[:, :frame_length] = padded_waveform_batch[:, timestep : timestep + frame_length] + + if remove_dc_offset: + buffer[:, :frame_length] -= buffer[:, :frame_length].mean(axis=1, keepdims=True) + + if preemphasis is not None: + buffer[:, 1:frame_length] -= preemphasis * buffer[:, : frame_length - 1] + buffer[:, 0] *= 1 - preemphasis + + buffer[:, :frame_length] *= window + + spectrogram[:, frame_idx] = fft_func(buffer) + + # Note: ** is much faster than np.power + if power is not None: + spectrogram = np.abs(spectrogram, dtype=np.float64) ** power + + # Apply mel filters if provided + if mel_filters is not None: + result = np.tensordot(spectrogram, mel_filters.T, axes=([2], [1])) + spectrogram = np.maximum(mel_floor, result) + + # Convert to log scale if specified + if power is not None and log_mel is not None: + if log_mel == "log": + spectrogram = np.log(spectrogram) + elif log_mel == "log10": + spectrogram = np.log10(spectrogram) + elif log_mel == "dB": + if power == 1.0: + spectrogram = amplitude_to_db_batch(spectrogram, reference, min_value, db_range) + elif power == 2.0: + spectrogram = power_to_db_batch(spectrogram, reference, min_value, db_range) + else: + raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}") + else: + raise ValueError(f"Unknown log_mel option: {log_mel}") + + spectrogram = np.asarray(spectrogram, dtype) + + spectrogram_list = [spectrogram[i, : true_num_frames[i], :].T for i in range(len(true_num_frames))] + + return spectrogram_list + + +def power_to_db( + spectrogram: np.ndarray, + reference: float = 1.0, + min_value: float = 1e-10, + db_range: Optional[float] = None, +) -> np.ndarray: + """ + Converts a power spectrogram to the decibel scale. This computes `10 * log10(spectrogram / reference)`, using basic + logarithm properties for numerical stability. + + The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a + linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it. + This means that large variations in energy may not sound all that different if the sound is loud to begin with. + This compression operation makes the (mel) spectrogram features match more closely what humans actually hear. + + Based on the implementation of `librosa.power_to_db`. + + Args: + spectrogram (`np.ndarray`): + The input power (mel) spectrogram. Note that a power spectrogram has the amplitudes squared! + reference (`float`, *optional*, defaults to 1.0): + Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set + the loudest part to 0 dB. Must be greater than zero. + min_value (`float`, *optional*, defaults to `1e-10`): + The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking + `log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero. + db_range (`float`, *optional*): + Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the + peak value and the smallest value will never be more than 80 dB. Must be greater than zero. + + Returns: + `np.ndarray`: the spectrogram in decibels + """ + if reference <= 0.0: + raise ValueError("reference must be greater than zero") + if min_value <= 0.0: + raise ValueError("min_value must be greater than zero") + + reference = max(min_value, reference) + + spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None) + spectrogram = 10.0 * (np.log10(spectrogram) - np.log10(reference)) + + if db_range is not None: + if db_range <= 0.0: + raise ValueError("db_range must be greater than zero") + spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None) + + return spectrogram + + +def power_to_db_batch( + spectrogram: np.ndarray, + reference: float = 1.0, + min_value: float = 1e-10, + db_range: Optional[float] = None, +) -> np.ndarray: + """ + Converts a batch of power spectrograms to the decibel scale. This computes `10 * log10(spectrogram / reference)`, + using basic logarithm properties for numerical stability. + + This function supports batch processing, where each item in the batch is an individual power (mel) spectrogram. + + Args: + spectrogram (`np.ndarray`): + The input batch of power (mel) spectrograms. Expected shape is (batch_size, *spectrogram_shape). + Note that a power spectrogram has the amplitudes squared! + reference (`float`, *optional*, defaults to 1.0): + Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set + the loudest part to 0 dB. Must be greater than zero. + min_value (`float`, *optional*, defaults to `1e-10`): + The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking + `log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero. + db_range (`float`, *optional*): + Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the + peak value and the smallest value will never be more than 80 dB. Must be greater than zero. + + Returns: + `np.ndarray`: the batch of spectrograms in decibels + """ + if reference <= 0.0: + raise ValueError("reference must be greater than zero") + if min_value <= 0.0: + raise ValueError("min_value must be greater than zero") + + reference = max(min_value, reference) + + spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None) + spectrogram = 10.0 * (np.log10(spectrogram) - np.log10(reference)) + + if db_range is not None: + if db_range <= 0.0: + raise ValueError("db_range must be greater than zero") + # Apply db_range clipping per batch item + max_values = spectrogram.max(axis=(1, 2), keepdims=True) + spectrogram = np.clip(spectrogram, a_min=max_values - db_range, a_max=None) + + return spectrogram + + +def amplitude_to_db( + spectrogram: np.ndarray, + reference: float = 1.0, + min_value: float = 1e-5, + db_range: Optional[float] = None, +) -> np.ndarray: + """ + Converts an amplitude spectrogram to the decibel scale. This computes `20 * log10(spectrogram / reference)`, using + basic logarithm properties for numerical stability. + + The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a + linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it. + This means that large variations in energy may not sound all that different if the sound is loud to begin with. + This compression operation makes the (mel) spectrogram features match more closely what humans actually hear. + + Args: + spectrogram (`np.ndarray`): + The input amplitude (mel) spectrogram. + reference (`float`, *optional*, defaults to 1.0): + Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set + the loudest part to 0 dB. Must be greater than zero. + min_value (`float`, *optional*, defaults to `1e-5`): + The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking + `log(0)`. The default of `1e-5` corresponds to a minimum of -100 dB. Must be greater than zero. + db_range (`float`, *optional*): + Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the + peak value and the smallest value will never be more than 80 dB. Must be greater than zero. + + Returns: + `np.ndarray`: the spectrogram in decibels + """ + if reference <= 0.0: + raise ValueError("reference must be greater than zero") + if min_value <= 0.0: + raise ValueError("min_value must be greater than zero") + + reference = max(min_value, reference) + + spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None) + spectrogram = 20.0 * (np.log10(spectrogram) - np.log10(reference)) + + if db_range is not None: + if db_range <= 0.0: + raise ValueError("db_range must be greater than zero") + spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None) + + return spectrogram + + +def amplitude_to_db_batch( + spectrogram: np.ndarray, reference: float = 1.0, min_value: float = 1e-5, db_range: Optional[float] = None +) -> np.ndarray: + """ + Converts a batch of amplitude spectrograms to the decibel scale. This computes `20 * log10(spectrogram / reference)`, + using basic logarithm properties for numerical stability. + + The function supports batch processing, where each item in the batch is an individual amplitude (mel) spectrogram. + + Args: + spectrogram (`np.ndarray`): + The input batch of amplitude (mel) spectrograms. Expected shape is (batch_size, *spectrogram_shape). + reference (`float`, *optional*, defaults to 1.0): + Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set + the loudest part to 0 dB. Must be greater than zero. + min_value (`float`, *optional*, defaults to `1e-5`): + The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking + `log(0)`. The default of `1e-5` corresponds to a minimum of -100 dB. Must be greater than zero. + db_range (`float`, *optional*): + Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the + peak value and the smallest value will never be more than 80 dB. Must be greater than zero. + + Returns: + `np.ndarray`: the batch of spectrograms in decibels + """ + if reference <= 0.0: + raise ValueError("reference must be greater than zero") + if min_value <= 0.0: + raise ValueError("min_value must be greater than zero") + + reference = max(min_value, reference) + + spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None) + spectrogram = 20.0 * (np.log10(spectrogram) - np.log10(reference)) + + if db_range is not None: + if db_range <= 0.0: + raise ValueError("db_range must be greater than zero") + # Apply db_range clipping per batch item + max_values = spectrogram.max(axis=(1, 2), keepdims=True) + spectrogram = np.clip(spectrogram, a_min=max_values - db_range, a_max=None) + + return spectrogram + + +### deprecated functions below this line ### + + +def get_mel_filter_banks( + nb_frequency_bins: int, + nb_mel_filters: int, + frequency_min: float, + frequency_max: float, + sample_rate: int, + norm: Optional[str] = None, + mel_scale: str = "htk", +) -> np.array: + warnings.warn( + "The function `get_mel_filter_banks` is deprecated and will be removed in version 4.31.0 of Transformers", + FutureWarning, + ) + return mel_filter_bank( + num_frequency_bins=nb_frequency_bins, + num_mel_filters=nb_mel_filters, + min_frequency=frequency_min, + max_frequency=frequency_max, + sampling_rate=sample_rate, + norm=norm, + mel_scale=mel_scale, + ) + + +def fram_wave(waveform: np.array, hop_length: int = 160, fft_window_size: int = 400, center: bool = True): + """ + In order to compute the short time fourier transform, the waveform needs to be split in overlapping windowed + segments called `frames`. + + The window length (window_length) defines how much of the signal is contained in each frame, while the hop length + defines the step between the beginning of each new frame. + + + Args: + waveform (`np.array` of shape `(sample_length,)`): + The raw waveform which will be split into smaller chunks. + hop_length (`int`, *optional*, defaults to 160): + Step between each window of the waveform. + fft_window_size (`int`, *optional*, defaults to 400): + Defines the size of the window. + center (`bool`, defaults to `True`): + Whether or not to center each frame around the middle of the frame. Centering is done by reflecting the + waveform on the left and on the right. + + Return: + framed_waveform (`np.array` of shape `(waveform.shape // hop_length , fft_window_size)`): + The framed waveforms that can be fed to `np.fft`. + """ + warnings.warn( + "The function `fram_wave` is deprecated and will be removed in version 4.31.0 of Transformers", + FutureWarning, + ) + frames = [] + for i in range(0, waveform.shape[0] + 1, hop_length): + if center: + half_window = (fft_window_size - 1) // 2 + 1 + start = i - half_window if i > half_window else 0 + end = i + half_window if i < waveform.shape[0] - half_window else waveform.shape[0] + frame = waveform[start:end] + if start == 0: + padd_width = (-i + half_window, 0) + frame = np.pad(frame, pad_width=padd_width, mode="reflect") + + elif end == waveform.shape[0]: + padd_width = (0, (i - waveform.shape[0] + half_window)) + frame = np.pad(frame, pad_width=padd_width, mode="reflect") + + else: + frame = waveform[i : i + fft_window_size] + frame_width = frame.shape[0] + if frame_width < waveform.shape[0]: + frame = np.lib.pad( + frame, pad_width=(0, fft_window_size - frame_width), mode="constant", constant_values=0 + ) + frames.append(frame) + + frames = np.stack(frames, 0) + return frames + + +def stft(frames: np.array, windowing_function: np.array, fft_window_size: int = None): + """ + Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. Should give the same results + as `torch.stft`. + + Args: + frames (`np.array` of dimension `(num_frames, fft_window_size)`): + A framed audio signal obtained using `audio_utils.fram_wav`. + windowing_function (`np.array` of dimension `(nb_frequency_bins, nb_mel_filters)`: + A array representing the function that will be used to reduces the amplitude of the discontinuities at the + boundaries of each frame when computing the STFT. Each frame will be multiplied by the windowing_function. + For more information on the discontinuities, called *Spectral leakage*, refer to [this + tutorial]https://download.ni.com/evaluation/pxi/Understanding%20FFTs%20and%20Windowing.pdf + fft_window_size (`int`, *optional*): + Size of the window om which the Fourier transform is applied. This controls the frequency resolution of the + spectrogram. 400 means that the fourrier transform is computed on windows of 400 samples. The number of + frequency bins (`nb_frequency_bins`) used to divide the window into equal strips is equal to + `(1+fft_window_size)//2`. An increase of the fft_window_size slows the calculus time proportionnally. + + Example: + + ```python + >>> from transformers.audio_utils import stft, fram_wave + >>> import numpy as np + + >>> audio = np.random.rand(50) + >>> fft_window_size = 10 + >>> hop_length = 2 + >>> framed_audio = fram_wave(audio, hop_length, fft_window_size) + >>> spectrogram = stft(framed_audio, np.hanning(fft_window_size + 1)) + ``` + + Returns: + spectrogram (`np.ndarray`): + A spectrogram of shape `(num_frames, nb_frequency_bins)` obtained using the STFT algorithm + """ + warnings.warn( + "The function `stft` is deprecated and will be removed in version 4.31.0 of Transformers", + FutureWarning, + ) + frame_size = frames.shape[1] + + if fft_window_size is None: + fft_window_size = frame_size + + if fft_window_size < frame_size: + raise ValueError("FFT size must greater or equal the frame size") + # number of FFT bins to store + nb_frequency_bins = (fft_window_size >> 1) + 1 + + spectrogram = np.empty((len(frames), nb_frequency_bins), dtype=np.complex64) + fft_signal = np.zeros(fft_window_size) + + for f, frame in enumerate(frames): + if windowing_function is not None: + np.multiply(frame, windowing_function, out=fft_signal[:frame_size]) + else: + fft_signal[:frame_size] = frame + spectrogram[f] = np.fft.fft(fft_signal, axis=0)[:nb_frequency_bins] + return spectrogram.T diff --git a/base_tokenizer.cpython-312.pyc b/base_tokenizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ddf8ccb1a207d00d5cbb4aae8bf853a054db5a3 Binary files /dev/null and b/base_tokenizer.cpython-312.pyc differ diff --git a/base_tokenizer.py b/base_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..9bc2503d662b6b580e9e2971e8075d0436f1a56d --- /dev/null +++ b/base_tokenizer.py @@ -0,0 +1,418 @@ +from typing import Dict, List, Optional, Tuple, Union + +from tokenizers import AddedToken, EncodeInput, Encoding, InputSequence, Tokenizer +from tokenizers.decoders import Decoder +from tokenizers.models import Model +from tokenizers.normalizers import Normalizer +from tokenizers.pre_tokenizers import PreTokenizer +from tokenizers.processors import PostProcessor + + +Offsets = Tuple[int, int] + + +class BaseTokenizer: + def __init__(self, tokenizer: Tokenizer, parameters=None): + self._tokenizer = tokenizer + self._parameters = parameters if parameters is not None else {} + + def __repr__(self): + return "Tokenizer(vocabulary_size={}, {})".format( + self._tokenizer.get_vocab_size(), + ", ".join(k + "=" + str(v) for k, v in self._parameters.items()), + ) + + def num_special_tokens_to_add(self, is_pair: bool) -> int: + """ + Return the number of special tokens that would be added for single/pair sentences. + :param is_pair: Boolean indicating if the input would be a single sentence or a pair + :return: + """ + return self._tokenizer.num_special_tokens_to_add(is_pair) + + def get_vocab(self, with_added_tokens: bool = True) -> Dict[str, int]: + """Returns the vocabulary + + Args: + with_added_tokens: boolean: + Whether to include the added tokens in the vocabulary + + Returns: + The vocabulary + """ + return self._tokenizer.get_vocab(with_added_tokens=with_added_tokens) + + def get_added_tokens_decoder(self) -> Dict[int, AddedToken]: + """Returns the added reverse vocabulary + + Returns: + The added vocabulary mapping ints to AddedTokens + """ + return self._tokenizer.get_added_tokens_decoder() + + def get_vocab_size(self, with_added_tokens: bool = True) -> int: + """Return the size of vocabulary, with or without added tokens. + + Args: + with_added_tokens: (`optional`) bool: + Whether to count in added special tokens or not + + Returns: + Size of vocabulary + """ + return self._tokenizer.get_vocab_size(with_added_tokens=with_added_tokens) + + def enable_padding( + self, + direction: Optional[str] = "right", + pad_to_multiple_of: Optional[int] = None, + pad_id: Optional[int] = 0, + pad_type_id: Optional[int] = 0, + pad_token: Optional[str] = "[PAD]", + length: Optional[int] = None, + ): + """Change the padding strategy + + Args: + direction: (`optional`) str: + Can be one of: `right` or `left` + + pad_to_multiple_of: (`optional`) unsigned int: + If specified, the padding length should always snap to the next multiple of + the given value. For example if we were going to pad with a length of 250 but + `pad_to_multiple_of=8` then we will pad to 256. + + pad_id: (`optional`) unsigned int: + The indice to be used when padding + + pad_type_id: (`optional`) unsigned int: + The type indice to be used when padding + + pad_token: (`optional`) str: + The pad token to be used when padding + + length: (`optional`) unsigned int: + If specified, the length at which to pad. If not specified + we pad using the size of the longest sequence in a batch + """ + return self._tokenizer.enable_padding( + direction=direction, + pad_to_multiple_of=pad_to_multiple_of, + pad_id=pad_id, + pad_type_id=pad_type_id, + pad_token=pad_token, + length=length, + ) + + def no_padding(self): + """Disable padding""" + return self._tokenizer.no_padding() + + @property + def padding(self) -> Optional[dict]: + """Get the current padding parameters + + Returns: + None if padding is disabled, a dict with the currently set parameters + if the padding is enabled. + """ + return self._tokenizer.padding + + def enable_truncation(self, max_length: int, stride: Optional[int] = 0, strategy: Optional[str] = "longest_first"): + """Change the truncation options + + Args: + max_length: unsigned int: + The maximum length at which to truncate + + stride: (`optional`) unsigned int: + The length of the previous first sequence to be included + in the overflowing sequence + + strategy: (`optional`) str: + Can be one of `longest_first`, `only_first` or `only_second` + """ + return self._tokenizer.enable_truncation(max_length, stride=stride, strategy=strategy) + + def no_truncation(self): + """Disable truncation""" + return self._tokenizer.no_truncation() + + @property + def truncation(self) -> Optional[dict]: + """Get the current truncation parameters + + Returns: + None if truncation is disabled, a dict with the current truncation parameters if + truncation is enabled + """ + return self._tokenizer.truncation + + def add_tokens(self, tokens: List[Union[str, AddedToken]]) -> int: + """Add the given tokens to the vocabulary + + Args: + tokens: List[Union[str, AddedToken]]: + A list of tokens to add to the vocabulary. Each token can either be + a string, or an instance of AddedToken + + Returns: + The number of tokens that were added to the vocabulary + """ + return self._tokenizer.add_tokens(tokens) + + def add_special_tokens(self, special_tokens: List[Union[str, AddedToken]]) -> int: + """Add the given special tokens to the vocabulary, and treat them as special tokens. + + The special tokens will never be processed by the model, and will be + removed while decoding. + + Args: + tokens: List[Union[str, AddedToken]]: + A list of special tokens to add to the vocabulary. Each token can either be + a string, or an instance of AddedToken + + Returns: + The number of tokens that were added to the vocabulary + """ + return self._tokenizer.add_special_tokens(special_tokens) + + def normalize(self, sequence: str) -> str: + """Normalize the given sequence + + Args: + sequence: str: + The sequence to normalize + + Returns: + The normalized string + """ + return self._tokenizer.normalize(sequence) + + def encode( + self, + sequence: InputSequence, + pair: Optional[InputSequence] = None, + is_pretokenized: bool = False, + add_special_tokens: bool = True, + ) -> Encoding: + """Encode the given sequence and pair. This method can process raw text sequences as well + as already pre-tokenized sequences. + + Args: + sequence: InputSequence: + The sequence we want to encode. This sequence can be either raw text or + pre-tokenized, according to the `is_pretokenized` argument: + + - If `is_pretokenized=False`: `InputSequence` is expected to be `str` + - If `is_pretokenized=True`: `InputSequence` is expected to be + `Union[List[str], Tuple[str]]` + + is_pretokenized: bool: + Whether the input is already pre-tokenized. + + add_special_tokens: bool: + Whether to add the special tokens while encoding. + + Returns: + An Encoding + """ + if sequence is None: + raise ValueError("encode: `sequence` can't be `None`") + + return self._tokenizer.encode(sequence, pair, is_pretokenized, add_special_tokens) + + def encode_batch( + self, + inputs: List[EncodeInput], + is_pretokenized: bool = False, + add_special_tokens: bool = True, + ) -> List[Encoding]: + """Encode the given inputs. This method accept both raw text sequences as well as already + pre-tokenized sequences. + + Args: + inputs: List[EncodeInput]: + A list of single sequences or pair sequences to encode. Each `EncodeInput` is + expected to be of the following form: + `Union[InputSequence, Tuple[InputSequence, InputSequence]]` + + Each `InputSequence` can either be raw text or pre-tokenized, + according to the `is_pretokenized` argument: + + - If `is_pretokenized=False`: `InputSequence` is expected to be `str` + - If `is_pretokenized=True`: `InputSequence` is expected to be + `Union[List[str], Tuple[str]]` + + is_pretokenized: bool: + Whether the input is already pre-tokenized. + + add_special_tokens: bool: + Whether to add the special tokens while encoding. + + Returns: + A list of Encoding + """ + + if inputs is None: + raise ValueError("encode_batch: `inputs` can't be `None`") + + return self._tokenizer.encode_batch(inputs, is_pretokenized, add_special_tokens) + + def decode(self, ids: List[int], skip_special_tokens: Optional[bool] = True) -> str: + """Decode the given list of ids to a string sequence + + Args: + ids: List[unsigned int]: + A list of ids to be decoded + + skip_special_tokens: (`optional`) boolean: + Whether to remove all the special tokens from the output string + + Returns: + The decoded string + """ + if ids is None: + raise ValueError("None input is not valid. Should be a list of integers.") + + return self._tokenizer.decode(ids, skip_special_tokens=skip_special_tokens) + + def decode_batch(self, sequences: List[List[int]], skip_special_tokens: Optional[bool] = True) -> str: + """Decode the list of sequences to a list of string sequences + + Args: + sequences: List[List[unsigned int]]: + A list of sequence of ids to be decoded + + skip_special_tokens: (`optional`) boolean: + Whether to remove all the special tokens from the output strings + + Returns: + A list of decoded strings + """ + if sequences is None: + raise ValueError("None input is not valid. Should be list of list of integers.") + + return self._tokenizer.decode_batch(sequences, skip_special_tokens=skip_special_tokens) + + def token_to_id(self, token: str) -> Optional[int]: + """Convert the given token to its corresponding id + + Args: + token: str: + The token to convert + + Returns: + The corresponding id if it exists, None otherwise + """ + return self._tokenizer.token_to_id(token) + + def id_to_token(self, id: int) -> Optional[str]: + """Convert the given token id to its corresponding string + + Args: + token: id: + The token id to convert + + Returns: + The corresponding string if it exists, None otherwise + """ + return self._tokenizer.id_to_token(id) + + def save_model(self, directory: str, prefix: Optional[str] = None): + """Save the current model to the given directory + + Args: + directory: str: + A path to the destination directory + + prefix: (Optional) str: + An optional prefix, used to prefix each file name + """ + return self._tokenizer.model.save(directory, prefix=prefix) + + def save(self, path: str, pretty: bool = True): + """Save the current Tokenizer at the given path + + Args: + path: str: + A path to the destination Tokenizer file + """ + return self._tokenizer.save(path, pretty) + + def to_str(self, pretty: bool = False): + """Get a serialized JSON version of the Tokenizer as a str + + Args: + pretty: bool: + Whether the JSON string should be prettified + + Returns: + str + """ + return self._tokenizer.to_str(pretty) + + def post_process( + self, encoding: Encoding, pair: Optional[Encoding] = None, add_special_tokens: bool = True + ) -> Encoding: + """Apply all the post-processing steps to the given encodings. + + The various steps are: + 1. Truncate according to global params (provided to `enable_truncation`) + 2. Apply the PostProcessor + 3. Pad according to global params. (provided to `enable_padding`) + + Args: + encoding: Encoding: + The main Encoding to post process + + pair: Optional[Encoding]: + An optional pair Encoding + + add_special_tokens: bool: + Whether to add special tokens + + Returns: + The resulting Encoding + """ + return self._tokenizer.post_process(encoding, pair, add_special_tokens) + + @property + def model(self) -> Model: + return self._tokenizer.model + + @model.setter + def model(self, model: Model): + self._tokenizer.model = model + + @property + def normalizer(self) -> Normalizer: + return self._tokenizer.normalizer + + @normalizer.setter + def normalizer(self, normalizer: Normalizer): + self._tokenizer.normalizer = normalizer + + @property + def pre_tokenizer(self) -> PreTokenizer: + return self._tokenizer.pre_tokenizer + + @pre_tokenizer.setter + def pre_tokenizer(self, pre_tokenizer: PreTokenizer): + self._tokenizer.pre_tokenizer = pre_tokenizer + + @property + def post_processor(self) -> PostProcessor: + return self._tokenizer.post_processor + + @post_processor.setter + def post_processor(self, post_processor: PostProcessor): + self._tokenizer.post_processor = post_processor + + @property + def decoder(self) -> Decoder: + return self._tokenizer.decoder + + @decoder.setter + def decoder(self, decoder: Decoder): + self._tokenizer.decoder = decoder diff --git a/bert_wordpiece.cpython-312.pyc b/bert_wordpiece.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47125a565319d97facae3d41c0497ab25281c2e3 Binary files /dev/null and b/bert_wordpiece.cpython-312.pyc differ diff --git a/bert_wordpiece.py b/bert_wordpiece.py new file mode 100644 index 0000000000000000000000000000000000000000..ed98c626537d2ad7ec881dc237d91fddaffda9ae --- /dev/null +++ b/bert_wordpiece.py @@ -0,0 +1,151 @@ +from typing import Dict, Iterator, List, Optional, Union + +from tokenizers import AddedToken, Tokenizer, decoders, trainers +from tokenizers.models import WordPiece +from tokenizers.normalizers import BertNormalizer +from tokenizers.pre_tokenizers import BertPreTokenizer +from tokenizers.processors import BertProcessing + +from .base_tokenizer import BaseTokenizer + + +class BertWordPieceTokenizer(BaseTokenizer): + """Bert WordPiece Tokenizer""" + + def __init__( + self, + vocab: Optional[Union[str, Dict[str, int]]] = None, + unk_token: Union[str, AddedToken] = "[UNK]", + sep_token: Union[str, AddedToken] = "[SEP]", + cls_token: Union[str, AddedToken] = "[CLS]", + pad_token: Union[str, AddedToken] = "[PAD]", + mask_token: Union[str, AddedToken] = "[MASK]", + clean_text: bool = True, + handle_chinese_chars: bool = True, + strip_accents: Optional[bool] = None, + lowercase: bool = True, + wordpieces_prefix: str = "##", + ): + if vocab is not None: + tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(unk_token))) + else: + tokenizer = Tokenizer(WordPiece(unk_token=str(unk_token))) + + # Let the tokenizer know about special tokens if they are part of the vocab + if tokenizer.token_to_id(str(unk_token)) is not None: + tokenizer.add_special_tokens([str(unk_token)]) + if tokenizer.token_to_id(str(sep_token)) is not None: + tokenizer.add_special_tokens([str(sep_token)]) + if tokenizer.token_to_id(str(cls_token)) is not None: + tokenizer.add_special_tokens([str(cls_token)]) + if tokenizer.token_to_id(str(pad_token)) is not None: + tokenizer.add_special_tokens([str(pad_token)]) + if tokenizer.token_to_id(str(mask_token)) is not None: + tokenizer.add_special_tokens([str(mask_token)]) + + tokenizer.normalizer = BertNormalizer( + clean_text=clean_text, + handle_chinese_chars=handle_chinese_chars, + strip_accents=strip_accents, + lowercase=lowercase, + ) + tokenizer.pre_tokenizer = BertPreTokenizer() + + if vocab is not None: + sep_token_id = tokenizer.token_to_id(str(sep_token)) + if sep_token_id is None: + raise TypeError("sep_token not found in the vocabulary") + cls_token_id = tokenizer.token_to_id(str(cls_token)) + if cls_token_id is None: + raise TypeError("cls_token not found in the vocabulary") + + tokenizer.post_processor = BertProcessing((str(sep_token), sep_token_id), (str(cls_token), cls_token_id)) + tokenizer.decoder = decoders.WordPiece(prefix=wordpieces_prefix) + + parameters = { + "model": "BertWordPiece", + "unk_token": unk_token, + "sep_token": sep_token, + "cls_token": cls_token, + "pad_token": pad_token, + "mask_token": mask_token, + "clean_text": clean_text, + "handle_chinese_chars": handle_chinese_chars, + "strip_accents": strip_accents, + "lowercase": lowercase, + "wordpieces_prefix": wordpieces_prefix, + } + + super().__init__(tokenizer, parameters) + + @staticmethod + def from_file(vocab: str, **kwargs): + vocab = WordPiece.read_file(vocab) + return BertWordPieceTokenizer(vocab, **kwargs) + + def train( + self, + files: Union[str, List[str]], + vocab_size: int = 30000, + min_frequency: int = 2, + limit_alphabet: int = 1000, + initial_alphabet: List[str] = [], + special_tokens: List[Union[str, AddedToken]] = [ + "[PAD]", + "[UNK]", + "[CLS]", + "[SEP]", + "[MASK]", + ], + show_progress: bool = True, + wordpieces_prefix: str = "##", + ): + """Train the model using the given files""" + + trainer = trainers.WordPieceTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + limit_alphabet=limit_alphabet, + initial_alphabet=initial_alphabet, + special_tokens=special_tokens, + show_progress=show_progress, + continuing_subword_prefix=wordpieces_prefix, + ) + if isinstance(files, str): + files = [files] + self._tokenizer.train(files, trainer=trainer) + + def train_from_iterator( + self, + iterator: Union[Iterator[str], Iterator[Iterator[str]]], + vocab_size: int = 30000, + min_frequency: int = 2, + limit_alphabet: int = 1000, + initial_alphabet: List[str] = [], + special_tokens: List[Union[str, AddedToken]] = [ + "[PAD]", + "[UNK]", + "[CLS]", + "[SEP]", + "[MASK]", + ], + show_progress: bool = True, + wordpieces_prefix: str = "##", + length: Optional[int] = None, + ): + """Train the model using the given iterator""" + + trainer = trainers.WordPieceTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + limit_alphabet=limit_alphabet, + initial_alphabet=initial_alphabet, + special_tokens=special_tokens, + show_progress=show_progress, + continuing_subword_prefix=wordpieces_prefix, + ) + self._tokenizer.train_from_iterator( + iterator, + trainer=trainer, + length=length, + ) diff --git a/byte_level_bpe.cpython-312.pyc b/byte_level_bpe.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b94c2345d0e50f3f0194b527c0322fa1742b5b77 Binary files /dev/null and b/byte_level_bpe.cpython-312.pyc differ diff --git a/byte_level_bpe.py b/byte_level_bpe.py new file mode 100644 index 0000000000000000000000000000000000000000..32f44674cbdc1f0407f5202934c832d8c57cf25c --- /dev/null +++ b/byte_level_bpe.py @@ -0,0 +1,122 @@ +from typing import Dict, Iterator, List, Optional, Tuple, Union + +from tokenizers import AddedToken, Tokenizer, decoders, pre_tokenizers, processors, trainers +from tokenizers.models import BPE +from tokenizers.normalizers import Lowercase, Sequence, unicode_normalizer_from_str + +from .base_tokenizer import BaseTokenizer + + +class ByteLevelBPETokenizer(BaseTokenizer): + """ByteLevelBPETokenizer + + Represents a Byte-level BPE as introduced by OpenAI with their GPT-2 model + """ + + def __init__( + self, + vocab: Optional[Union[str, Dict[str, int]]] = None, + merges: Optional[Union[str, Dict[Tuple[int, int], Tuple[int, int]]]] = None, + add_prefix_space: bool = False, + lowercase: bool = False, + dropout: Optional[float] = None, + unicode_normalizer: Optional[str] = None, + continuing_subword_prefix: Optional[str] = None, + end_of_word_suffix: Optional[str] = None, + trim_offsets: bool = False, + ): + if vocab is not None and merges is not None: + tokenizer = Tokenizer( + BPE( + vocab, + merges, + dropout=dropout, + continuing_subword_prefix=continuing_subword_prefix or "", + end_of_word_suffix=end_of_word_suffix or "", + ) + ) + else: + tokenizer = Tokenizer(BPE()) + + # Check for Unicode normalization first (before everything else) + normalizers = [] + + if unicode_normalizer: + normalizers += [unicode_normalizer_from_str(unicode_normalizer)] + + if lowercase: + normalizers += [Lowercase()] + + # Create the normalizer structure + if len(normalizers) > 0: + if len(normalizers) > 1: + tokenizer.normalizer = Sequence(normalizers) + else: + tokenizer.normalizer = normalizers[0] + + tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space) + tokenizer.decoder = decoders.ByteLevel() + tokenizer.post_processor = processors.ByteLevel(trim_offsets=trim_offsets) + + parameters = { + "model": "ByteLevelBPE", + "add_prefix_space": add_prefix_space, + "lowercase": lowercase, + "dropout": dropout, + "unicode_normalizer": unicode_normalizer, + "continuing_subword_prefix": continuing_subword_prefix, + "end_of_word_suffix": end_of_word_suffix, + "trim_offsets": trim_offsets, + } + + super().__init__(tokenizer, parameters) + + @staticmethod + def from_file(vocab_filename: str, merges_filename: str, **kwargs): + vocab, merges = BPE.read_file(vocab_filename, merges_filename) + return ByteLevelBPETokenizer(vocab, merges, **kwargs) + + def train( + self, + files: Union[str, List[str]], + vocab_size: int = 30000, + min_frequency: int = 2, + show_progress: bool = True, + special_tokens: List[Union[str, AddedToken]] = [], + ): + """Train the model using the given files""" + + trainer = trainers.BpeTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + show_progress=show_progress, + special_tokens=special_tokens, + initial_alphabet=pre_tokenizers.ByteLevel.alphabet(), + ) + if isinstance(files, str): + files = [files] + self._tokenizer.train(files, trainer=trainer) + + def train_from_iterator( + self, + iterator: Union[Iterator[str], Iterator[Iterator[str]]], + vocab_size: int = 30000, + min_frequency: int = 2, + show_progress: bool = True, + special_tokens: List[Union[str, AddedToken]] = [], + length: Optional[int] = None, + ): + """Train the model using the given iterator""" + + trainer = trainers.BpeTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + show_progress=show_progress, + special_tokens=special_tokens, + initial_alphabet=pre_tokenizers.ByteLevel.alphabet(), + ) + self._tokenizer.train_from_iterator( + iterator, + trainer=trainer, + length=length, + ) diff --git a/cache_utils.py b/cache_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f38fc8f9824d3b88eb00296451994d16347eab2a --- /dev/null +++ b/cache_utils.py @@ -0,0 +1,2148 @@ +import copy +import importlib.metadata +import json +import os +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from packaging import version + +from .configuration_utils import PretrainedConfig +from .utils import ( + is_hqq_available, + is_optimum_quanto_available, + is_torchdynamo_compiling, + logging, +) +from .utils.deprecation import deprecate_kwarg + + +if is_hqq_available(): + from hqq.core.quantize import Quantizer as HQQQuantizer + +logger = logging.get_logger(__name__) + + +class Cache(torch.nn.Module): + """ + Base, abstract class for all caches. The actual data structure is specific to each subclass. + """ + + def __init__(self): + super().__init__() + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. These are specific to each subclass and allow new types of + cache to be created. + + Return: + A tuple containing the updated key and value states. + """ + raise NotImplementedError("Make sure to implement `update` in a subclass.") + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") + + # Deprecate in favor of max-cache-shape because we want to be specifc by what we mean with "max_length" + # Prev some cache objects didn't have "max_length" (SlidingWindowCache or SinkCache) because the cache object technically handles + # infinite amount of tokens. In the codebase what we really need to check is the max capacity of certain cache instances, so + # we change naming to be more explicit + def get_max_length(self) -> Optional[int]: + logger.warning_once( + "`get_max_cache()` is deprecated for all Cache classes. Use `get_max_cache_shape()` instead. " + "Calling `get_max_cache()` will raise error from v4.48" + ) + return self.get_max_cache_shape() + + def get_max_cache_shape(self) -> Optional[int]: + """Returns the maximum sequence length (i.e. max capacity) of the cache object""" + raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.") + + def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: + """Given the sequence length of the new inputs, returns the usable length of the cache.""" + # Cache without size limit -> all cache is usable + # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache + # length, we will need to evict part of the cache (and thus not all cache is usable) + max_length = self.get_max_cache_shape() + previous_seq_length = self.get_seq_length(layer_idx) + if max_length is not None and previous_seq_length + new_seq_length > max_length: + return max_length - new_seq_length + return previous_seq_length + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx] != []: + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + if self.value_cache[layer_idx] != []: + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + @property + def seen_tokens(self): + logger.warning_once( + "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` " + "model input instead." + ) + if hasattr(self, "_seen_tokens"): + return self._seen_tokens + else: + return None + + +@dataclass +class CacheConfig: + """ + Base class for cache configs + """ + + cache_implementation: None + + @classmethod + def from_dict(cls, config_dict, **kwargs): + """ + Constructs a CacheConfig instance from a dictionary of parameters. + Args: + config_dict (Dict[str, Any]): Dictionary containing configuration parameters. + **kwargs: Additional keyword arguments to override dictionary values. + + Returns: + CacheConfig: Instance of CacheConfig constructed from the dictionary. + """ + config = cls(**config_dict) + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + return config + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. + use_diff (`bool`, *optional*, defaults to `True`): + If set to `True`, only the difference between the config instance and the default + `QuantizationConfig()` is serialized to JSON file. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + config_dict = self.to_dict() + json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + writer.write(json_string) + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + return copy.deepcopy(self.__dict__) + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__ + def __iter__(self): + """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" + for attr, value in copy.deepcopy(self.__dict__).items(): + yield attr, value + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__ + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + def to_json_string(self): + """ + Serializes this instance to a JSON formatted string. + Returns: + str: JSON formatted string representing the configuration instance. + """ + return json.dumps(self.__dict__, indent=2) + "\n" + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update + def update(self, **kwargs): + """ + Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes, + returning all the unused kwargs. + + Args: + kwargs (`Dict[str, Any]`): + Dictionary of attributes to tentatively update this class. + + Returns: + `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. + """ + to_remove = [] + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + to_remove.append(key) + + # Remove all the attributes that were updated, without modifying the input dict + unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} + return unused_kwargs + + +@dataclass +class QuantizedCacheConfig(CacheConfig): + """ + Configuration class for quantized cache settings. + + Attributes: + backend (`str`, *optional*, defaults to `"quanto"`): + Backend to use when performing quantization, Can be one of [`quanto`, `HQQ`] + nbits (`Optional[int]`, *optional*, defaults to 4): + Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend. Defaults to 2. + axis_key (`int`, *optional*, defaults to 0): + Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. + axis_value (`int`, *optional*, defaults to 0): + Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. + q_group_size (`Optional[int]`, *optional*, defaults to 64): + Size of the quantization group, should be a divisor of the model's hidden dimension. + Defaults to 64. + residual_length (`Optional[int]`, *optional*, defaults to 128): + Length of the residual cache which will always be stored in original presicion. + Defaults to 128. + compute_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): + The defualt dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization. + device (`str`, *optional*, defaults to `"cpu"`): + Device on which to perform computations, should be same as the model's device. + """ + + def __init__( + self, + backend: str = "quanto", + nbits: Optional[int] = 4, + axis_key: Optional[int] = 0, + axis_value: Optional[int] = 0, + q_group_size: Optional[int] = 64, + residual_length: Optional[int] = 128, + compute_dtype: Optional[torch.dtype] = torch.float16, + device: Optional[str] = "cpu", + ): + self.backend = backend + self.nbits = nbits + self.axis_key = axis_key + self.axis_value = axis_value + self.q_group_size = q_group_size + self.residual_length = residual_length + self.compute_dtype = compute_dtype + self.device = device + + def validate(self): + """Validates if the arguments passed are correct""" + + incorrect_arg_msg = ( + "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` " + "but found {found_value}" + ) + # Check that the values are reasonable in general (nbits, axis) + # Later in QuantizedCache init we check if they are supported for that particular backend + if self.nbits not in [1, 2, 3, 4, 8]: + raise ValueError( + incorrect_arg_msg.format( + key="nbits", + correct_value="2 or 4 or 8", + found_value=self.nbits, + ), + ) + if self.q_group_size <= 0: + raise ValueError( + incorrect_arg_msg.format( + key="q_group_size", + correct_value="a positive integer", + found_value=self.q_group_size, + ), + ) + if self.residual_length < 0: + raise ValueError( + incorrect_arg_msg.format( + key="residual_length", + correct_value="a positive integer", + found_value=self.residual_length, + ), + ) + + if self.axis_key not in [0, 1, -1]: + raise ValueError( + incorrect_arg_msg.format( + key="axis_key", + correct_value="`1` or `0`, `-1`", + found_value=self.axis_key, + ), + ) + + if self.axis_value not in [0, 1, -1]: + raise ValueError( + incorrect_arg_msg.format( + key="axis_value", + correct_value="`1` or `0` or `-1`", + found_value=self.axis_value, + ), + ) + + +@dataclass +class StaticCacheConfig(CacheConfig): + """ + Configuration class for static cache settings. + """ + + cache_implementation = "static" + + def __init__(self, batch_size: int, max_cache_len: int, device="cpu"): + self.batch_size = batch_size + self.max_cache_len = max_cache_len + self.device = device + + def validate(self): + """Validates if the arguments passed are correct""" + + incorrect_arg_msg = ( + "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` " + "but found {found_value}" + ) + + if self.batch_size <= 0: + raise ValueError( + incorrect_arg_msg.format( + key="batch_size", + correct_value="> 0", + found_value=self.batch_size, + ), + ) + + if self.max_cache_len <= 0: + raise ValueError( + incorrect_arg_msg.format( + key="max_cache_len", + correct_value="> 0", + found_value=self.max_cache_len, + ), + ) + + +class DynamicCache(Cache): + """ + A cache that grows dynamically as more tokens are generated. This is the default for generative models. + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache + + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = DynamicCache() + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + DynamicCache() + ``` + """ + + @deprecate_kwarg("num_hidden_layers", version="4.47.0") + def __init__(self, num_hidden_layers: Optional[int] = None) -> None: + super().__init__() + self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + + def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self): + return (self.key_cache[layer_idx], self.value_cache[layer_idx]) + else: + raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + + def __iter__(self): + """ + Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. + """ + return len(self.key_cache) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the cache + if key_states is not None: + if len(self.key_cache) <= layer_idx: + # There may be skipped layers, fill them with empty lists + for _ in range(len(self.key_cache), layer_idx): + self.key_cache.append([]) + self.value_cache.append([]) + self.key_cache.append(key_states) + self.value_cache.append(value_states) + elif ( + len(self.key_cache[layer_idx]) == 0 + ): # fills previously skipped layers; checking for tensor causes errors + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + is_empty_layer = ( + len(self.key_cache) == 0 # no cache in any layer + or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it + or len(self.key_cache[layer_idx]) == 0 # the layer has no cache + ) + layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 + return layer_seq_length + + def get_max_cache_shape(self) -> Optional[int]: + """Returns the maximum sequence length of the cache object. DynamicCache does not have a maximum length.""" + return None + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for + backward compatibility.""" + legacy_cache = () + for layer_idx in range(len(self)): + legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) + return legacy_cache + + @classmethod + @deprecate_kwarg("num_hidden_layers", version="4.47.0") + def from_legacy_cache( + cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None + ) -> "DynamicCache": + """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for + backward compatibility.""" + cache = cls() + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache + + def crop(self, max_length: int): + """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be + negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search.""" + # In case it is negative + if max_length < 0: + max_length = self.get_seq_length() - abs(max_length) + + if self.get_seq_length() <= max_length: + return + + self._seen_tokens = max_length + for idx in range(len(self.key_cache)): + if self.key_cache[idx] != []: + self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] + self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] + + @deprecate_kwarg("num_hidden_layers", version="4.47.0") + def batch_split( + self, full_batch_size: int, split_size: int, num_hidden_layers: int = None + ) -> List["DynamicCache"]: + """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by + `_split_model_inputs()` in `generation.utils`""" + out = [] + for i in range(0, full_batch_size, split_size): + current_split = DynamicCache() + current_split._seen_tokens = self._seen_tokens + current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] + current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] + out.append(current_split) + return out + + @classmethod + @deprecate_kwarg("num_hidden_layers", version="4.47.0") + def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int = None) -> "DynamicCache": + """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in + `generation.utils`""" + cache = cls() + for idx in range(len(splits[0])): + key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []] + value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx] != []] + if key_cache != []: + layer_keys = torch.cat(key_cache, dim=0) + layer_values = torch.cat(value_cache, dim=0) + cache.update(layer_keys, layer_values, idx) + return cache + + def batch_repeat_interleave(self, repeats: int): + """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" + for layer_idx in range(len(self)): + self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) + self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) + + def batch_select_indices(self, indices: torch.Tensor): + """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" + for layer_idx in range(len(self)): + self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] + self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] + + +class OffloadedCache(DynamicCache): + """ + A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory. + Useful for generating from models with very long context. + + In addition to the default CUDA stream, where all forward() computations happen, + this class uses another stream, the prefetch stream, which it creates itself. + Since scheduling of operations on separate streams happens independently, this class uses + the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing. + The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to + ensure the eviction is scheduled after all computations on that cache are finished. + """ + + def __init__(self) -> None: + if not torch.cuda.is_available(): + raise RuntimeError("OffloadedCache can only be used with a GPU") + super().__init__() + self.original_device = [] + self.prefetch_stream = torch.cuda.Stream() + self.beam_idx = None # used to delay beam search operations + + def prefetch_layer(self, layer_idx: int): + "Starts prefetching the next layer cache" + if layer_idx < len(self): + with torch.cuda.stream(self.prefetch_stream): + # Prefetch next layer tensors to GPU + device = self.original_device[layer_idx] + self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True) + self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True) + + def evict_previous_layer(self, layer_idx: int): + "Moves the previous layer cache to the CPU" + if len(self) > 2: + # We do it on the default stream so it occurs after all earlier computations on these tensors are done + prev_layer_idx = (layer_idx - 1) % len(self) + self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True) + self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True) + + def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: + "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer." + if layer_idx < len(self): + # Evict the previous layer if necessary + torch.cuda.current_stream().synchronize() + self.evict_previous_layer(layer_idx) + # Load current layer cache to its original device if not already there + original_device = self.original_device[layer_idx] + self.prefetch_stream.synchronize() + key_tensor = self.key_cache[layer_idx] + value_tensor = self.value_cache[layer_idx] + # Now deal with beam search ops which were delayed + if self.beam_idx is not None: + self.beam_idx = self.beam_idx.to(original_device) + key_tensor = key_tensor.index_select(0, self.beam_idx) + value_tensor = value_tensor.index_select(0, self.beam_idx) + # Prefetch the next layer + self.prefetch_layer((layer_idx + 1) % len(self)) + return (key_tensor, value_tensor) + else: + raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Saves the beam indices and reorders the cache when the tensor is back to its device.""" + # We delay this operation until the tensors are back to their original + # device because performing torch.index_select on the CPU is very slow + del self.beam_idx + self.beam_idx = beam_idx.clone() + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`. + Return: + A tuple containing the updated key and value states. + """ + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the cache + if len(self.key_cache) < layer_idx: + raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.") + elif len(self.key_cache) == layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + self.original_device.append(key_states.device) + self.evict_previous_layer(layer_idx) + else: + key_tensor, value_tensor = self[layer_idx] + self.key_cache[layer_idx] = torch.cat([key_tensor, key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([value_tensor, value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + # According to https://docs.python.org/3/library/exceptions.html#NotImplementedError + # if a method is not supposed to be supported in a subclass we should set it to None + from_legacy_cache = None + + to_legacy_cache = None + + +class QuantizedCache(DynamicCache): + """ + A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. + + The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the + original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The + quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. + + It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and + Value in original precision states as a list of tensors, one for each layer. The size of each tensor + is `[batch_size, num_heads, seq_len - residual_length, head_dim]` + """ + + def __init__(self, cache_config: QuantizedCacheConfig) -> None: + super().__init__() + self._quantized_key_cache: List[torch.Tensor] = [] + self._quantized_value_cache: List[torch.Tensor] = [] + + self.nbits = cache_config.nbits + self.residual_length = cache_config.residual_length + self.q_group_size = cache_config.q_group_size + self.axis_key = cache_config.axis_key + self.axis_value = cache_config.axis_value + self.compute_dtype = cache_config.compute_dtype + self.device = cache_config.device + + super().__init__() + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + if len(self.key_cache) < layer_idx: + raise ValueError("QuantizedCache does not support model usage where layers are skipped. Use DynamicCache.") + elif len(self.key_cache) == layer_idx: + self._quantized_key_cache.append(self._quantize(key_states.contiguous(), axis=self.axis_key)) + self._quantized_value_cache.append(self._quantize(value_states.contiguous(), axis=self.axis_value)) + self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device)) + self.value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device)) + keys_to_return, values_to_return = key_states, value_states + else: + dequant_key = self._dequantize(self._quantized_key_cache[layer_idx]) + dequant_value = self._dequantize(self._quantized_value_cache[layer_idx]) + keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states] + values_to_return = [dequant_value, self.value_cache[layer_idx], value_states] + + keys_to_return = torch.cat(keys_to_return, dim=-2) + values_to_return = torch.cat(values_to_return, dim=-2) + if ( + self.key_cache[layer_idx].dim() == 4 + and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length + ): + self._quantized_key_cache[layer_idx] = self._quantize(keys_to_return.contiguous(), axis=self.axis_key) + self._quantized_value_cache[layer_idx] = self._quantize( + values_to_return.contiguous(), axis=self.axis_value + ) + self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device) + self.value_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device) + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + return keys_to_return, values_to_return + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + if len(self.key_cache) <= layer_idx: + return 0 + # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is + # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx + # this part of code otherwise fails when used to verify attn_weight shape in some models + return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1 + + def _quantize(self, tensor, axis): + """Quantizes a key/value using a defined quantization method.""" + raise NotImplementedError("Make sure to implement `_quantize` in a subclass.") + + def _dequantize(self, q_tensor): + """Dequantizes back the tensor that was quantized by `self._quantize()`""" + raise NotImplementedError("Make sure to implement `_dequantize` in a subclass.") + + +class QuantoQuantizedCache(QuantizedCache): + """ + Quantized Cache class that uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only. + + Parameters: + cache_config (`QuantizedCacheConfig`): + A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. + + Example: + + ```python + >>> # Run pip install quanto first if you don't have it yet + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig + + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> cache_config = QuantizedCacheConfig(nbits=4) + >>> past_key_values = QuantoQuantizedCache(cache_config=cache_config) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + QuantoQuantizedCache() + ``` + """ + + def __init__(self, cache_config: CacheConfig) -> None: + super().__init__(cache_config) + + if is_optimum_quanto_available(): + optimum_quanto_version = version.parse(importlib.metadata.version("optimum-quanto")) + if optimum_quanto_version <= version.parse("0.2.5"): + raise ImportError( + f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. Detected version {optimum_quanto_version}." + ) + from optimum.quanto import MaxOptimizer, qint2, qint4 + + if self.nbits not in [2, 4]: + raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}") + + if self.axis_key not in [0, -1]: + raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}") + + if self.axis_value not in [0, -1]: + raise ValueError( + f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}" + ) + + self.qtype = qint4 if self.nbits == 4 else qint2 + self.optimizer = MaxOptimizer() # hardcode as it's the only one for per-channel quantization + + def _quantize(self, tensor, axis): + # We have two different API since in optimum-quanto, we don't use AffineQuantizer anymore + if is_optimum_quanto_available(): + from optimum.quanto import quantize_weight + + scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size) + qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size) + return qtensor + + def _dequantize(self, qtensor): + return qtensor.dequantize() + + +class HQQQuantizedCache(QuantizedCache): + """ + Quantized Cache class that uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes. + + Parameters: + cache_config (`QuantizedCacheConfig`): + A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. + + Example: + + ```python + >>> # Run pip install hqq first if you don't have it yet + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig + + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> cache_config = QuantizedCacheConfig(nbits=4, axis_key=1, axis_value=1) + >>> past_key_values = HQQQuantizedCache(cache_config=cache_config) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + HQQQuantizedCache() + ``` + """ + + def __init__(self, cache_config: CacheConfig) -> None: + super().__init__(cache_config) + if self.nbits not in [1, 2, 3, 4, 8]: + raise ValueError( + f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}" + ) + + if self.axis_key not in [0, 1]: + raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}") + + if self.axis_value not in [0, 1]: + raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}") + + self.quantizer = HQQQuantizer + + def _quantize(self, tensor, axis): + qtensor, meta = self.quantizer.quantize( + tensor, + axis=axis, + device=self.device, + compute_dtype=self.compute_dtype, + nbits=self.nbits, + group_size=self.q_group_size, + ) + meta["compute_dtype"] = self.compute_dtype + self.quantizer.cuda(qtensor, meta=meta, device=self.device) # Move to device and cast to dtype + return qtensor, meta + + def _dequantize(self, qtensor): + quant_tensor, meta = qtensor + tensor = self.quantizer.dequantize(quant_tensor, meta) + return tensor + + +class SinkCache(Cache): + """ + A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to + generate beyond the length of its context window, without losing fluency in the conversation. As it discards past + tokens, the model will lose the ability to generate tokens that depend on the context that was discarded. + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]`. + + Parameters: + window_length (`int`): + The length of the context window. + num_sink_tokens (`int`): + The number of sink tokens. See the original paper for more information. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache + + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + SinkCache() + ``` + """ + + is_sliding = True + + def __init__(self, window_length: int, num_sink_tokens: int) -> None: + super().__init__() + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + self.window_length = window_length + self.num_sink_tokens = num_sink_tokens + self.cos_sin_rerotation_cache = {} + self._cos_cache = None + self._sin_cache = None + self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen + + @staticmethod + def _rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _apply_key_rotary_pos_emb( + self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor + ) -> torch.Tensor: + rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin) + return rotated_key_states + + def _get_rerotation_cos_sin( + self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + if key_states.shape[-2] not in self.cos_sin_rerotation_cache: + # Upcast to float32 temporarily for better accuracy + cos = cos.to(torch.float32) + sin = sin.to(torch.float32) + + # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence + original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :] + shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]] + original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :] + shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]] + rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin + rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin + + self.cos_sin_rerotation_cache[key_states.shape[-2]] = ( + rerotation_cos.to(key_states.dtype).unsqueeze(0), + rerotation_sin.to(key_states.dtype).unsqueeze(0), + ) + return self.cos_sin_rerotation_cache[key_states.shape[-2]] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def get_max_cache_shape(self) -> Optional[int]: + """Returns the maximum sequence length of the cache object, in case of SinkCache it is the window length.""" + return self.window_length + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`, + `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the + rotation as the tokens are shifted. + + Return: + A tuple containing the updated key and value states. + """ + # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models + # with partially rotated position embeddings, like Phi or Persimmon. + sin = cache_kwargs.get("sin") + cos = cache_kwargs.get("cos") + partial_rotation_size = cache_kwargs.get("partial_rotation_size") + using_rope = cos is not None and sin is not None + + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the sin/cos cache, which holds sin/cos values for all possible positions + if using_rope and layer_idx == 0: + # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove + # after all RoPE models have a llama-like cache utilization. + if cos.dim() == 2: + self._cos_cache = cos + self._sin_cache = sin + else: + if self._cos_cache is None: + self._cos_cache = cos[0, ...] + self._sin_cache = sin[0, ...] + elif self._cos_cache.shape[0] < self.window_length: + self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0) + self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0) + + # [bsz, num_heads, seq_len, head_dim] + if len(self.key_cache) <= layer_idx: + # Empty cache + self.key_cache.append(key_states) + self.value_cache.append(value_states) + + elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length: + # Growing cache + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + else: + # Shifting cache + keys_to_keep = self.key_cache[layer_idx][ + :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] : + ] + + # On RoPE models, we need to recompute the Key rotation as the tokens are shifted + if using_rope: + rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin( + key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length] + ) + if partial_rotation_size is not None: + keys_to_keep, keys_pass = ( + keys_to_keep[..., :partial_rotation_size], + keys_to_keep[..., partial_rotation_size:], + ) + keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin) + if partial_rotation_size is not None: + keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1) + + # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens + sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens] + self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2) + + sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens] + values_to_keep = self.value_cache[layer_idx][ + :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] : + ] + self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + +class StaticCache(Cache): + """ + Static Cache class to be used with `torch.compile(model)` and `torch.export()`. + + Parameters: + config (`PretrainedConfig`): + The configuration file defining the shape-related attributes required to initialize the static cache. + batch_size (`int`): + The batch size with which the model will be used. Note that a new instance must be instantiated if a + smaller batch size is used. If you are manually setting the batch size, make sure to take into account the number of beams if you are running beam search + max_cache_len (`int`): + The maximum sequence length with which the model will be used. + device (`torch.device` or `str`): + The device on which the cache should be initialized. Should be the same as the layer. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + The default `dtype` to use when initializing the layer. + layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`): + Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between differents gpus. + You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache + + >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + + >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = StaticCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + StaticCache() + ``` + """ + + # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. + def __init__( + self, + config: PretrainedConfig, + batch_size: int = None, + max_cache_len: int = None, + device: torch.device = None, + dtype: torch.dtype = torch.float32, + max_batch_size: Optional[int] = None, + layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, + ) -> None: + super().__init__() + if batch_size is not None: + logger.warning_once( + f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " + "v4.49. Use the more precisely named 'max_batch_size' argument instead." + ) + + self.max_batch_size = batch_size or max_batch_size + self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len + + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads + self.head_dim = ( + config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + ) + + self.dtype = dtype + self.num_key_value_heads = ( + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads + ) + + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + # Note: There will be significant perf decrease if switching to use 5D tensors instead. + cache_shape = (self.batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) + for idx in range(config.num_hidden_layers): + if layer_device_map is not None: + layer_device = layer_device_map[idx] + else: + layer_device = device + new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) + # Notes: + # 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph + # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case + # it is not needed anyway) + # 2. `torch.export()` requires mutations to be registered as buffers. + if not is_torchdynamo_compiling(): + self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device)) + self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device)) + new_layer_key_cache = getattr(self, f"key_cache_{idx}") + new_layer_value_cache = getattr(self, f"value_cache_{idx}") + torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + It is VERY important to index using a tensor, otherwise you introduce a copy to the device. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input + to know how where to write in the cache. + + Return: + A tuple containing the updated key and value states. + """ + + cache_position = cache_kwargs.get("cache_position") + + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) + + if cache_position is None: + k_out.copy_(key_states) + v_out.copy_(value_states) + else: + # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to + # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place + # operation, that avoids copies and uses less memory. + try: + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + return k_out, v_out + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states that were seen by the model.""" + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. + # TODO: deprecate this function in favor of `cache_position` + return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + + def get_max_cache_shape(self) -> Optional[int]: + return self.max_cache_len + + def reset(self): + """Resets the cache values while preserving the objects""" + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + @property + def batch_size(self): + logger.warning_once( + f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in " + "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead." + ) + return self.max_batch_size + + +class SlidingWindowCache(StaticCache): + """ + Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. + Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window - 1`, + if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint), + we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in. + + The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`: + + indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window + tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 0]) + + We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`) + + Parameters: + config (`PretrainedConfig`): + The configuration file defining the shape-related attributes required to initialize the static cache. + batch_size (`int`): + The batch size with which the model will be used. Note that a new instance must be instantiated if a + smaller batch size is used. + max_cache_len (`int`): + The maximum sequence length with which the model will be used. + device (`torch.device` or `str`): + The device on which the cache should be initialized. Should be the same as the layer. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + The default `dtype` to use when initializing the layer. + layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`): + Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between differents gpus. + You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache + + >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") + + >>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = SlidingWindowCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + SlidingWindowCache() + ``` + """ + + is_sliding = True + + # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. + def __init__( + self, + config: PretrainedConfig, + batch_size: int = None, + max_cache_len: int = None, + device: torch.device = None, + dtype: torch.dtype = torch.float32, + max_batch_size: Optional[int] = None, + layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, + ) -> None: + if not hasattr(config, "sliding_window") or config.sliding_window is None: + raise ValueError( + "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " + "sliding window attention, please check if there is a `sliding_window` field in the model " + "config and it's not set to None." + ) + max_cache_len = min(config.sliding_window, max_cache_len) + super().__init__( + config=config, + batch_size=batch_size, + max_cache_len=max_cache_len, + device=device, + dtype=dtype, + max_batch_size=max_batch_size, + layer_device_map=layer_device_map, + ) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor]: + cache_position = cache_kwargs.get("cache_position") + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + + # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len) + if cache_position.shape[0] > self.max_cache_len: + k_out = key_states[:, :, -self.max_cache_len :, :] + v_out = value_states[:, :, -self.max_cache_len :, :] + # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly + self.key_cache[layer_idx] += k_out + self.value_cache[layer_idx] += v_out + # we should return the whole states instead of k_out, v_out to take the whole prompt + # into consideration when building kv cache instead of just throwing away tokens outside of the window + return key_states, value_states + + slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0) + cache_position = cache_position.clamp(0, self.max_cache_len - 1) + to_shift = cache_position >= self.max_cache_len - 1 + indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len + + k_out = k_out[:, :, indices] + v_out = v_out[:, :, indices] + + try: + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + self.key_cache[layer_idx] += k_out + self.value_cache[layer_idx] += v_out + + return k_out, v_out + + def get_max_cache_shape(self) -> Optional[int]: + return self.max_cache_len + + def reset(self): + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + +class EncoderDecoderCache(Cache): + """ + Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and + cross-attention caches. + + Example: + + ```python + >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache + + >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small") + >>> processor = AutoProcessor.from_pretrained("openai/whisper-small") + + >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt") + + >>> # Prepare cache classes for encoder and decoder and pass it to model's forward + >>> self_attention_cache = DynamicCache() + >>> cross_attention_cache = DynamicCache() + >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + EncoderDecoderCache() + ``` + + """ + + def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache): + super().__init__() + self.self_attention_cache = self_attention_cache + self.cross_attention_cache = cross_attention_cache + + self.is_updated = {} + for layer_idx in range(len(cross_attention_cache.key_cache)): + self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0) + + def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self): + return ( + self.self_attention_cache.key_cache[layer_idx], + self.self_attention_cache.value_cache[layer_idx], + self.cross_attention_cache.key_cache[layer_idx], + self.cross_attention_cache.value_cache[layer_idx], + ) + else: + raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. + """ + return len(self.self_attention_cache) + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format.""" + legacy_cache = () + if len(self.cross_attention_cache) > 0: + for self_attn, cross_attn in zip( + self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache() + ): + legacy_cache += (self_attn + cross_attn,) + else: + legacy_cache = self.self_attention_cache.to_legacy_cache() + return legacy_cache + + @classmethod + def from_legacy_cache( + cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + ) -> "EncoderDecoderCache": + """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`.""" + cache = cls( + self_attention_cache=DynamicCache(), + cross_attention_cache=DynamicCache(), + ) + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx][:2] + cache.self_attention_cache.update(key_states, value_states, layer_idx) + if len(past_key_values[layer_idx]) > 2: + key_states, value_states = past_key_values[layer_idx][2:] + cache.cross_attention_cache.update(key_states, value_states, layer_idx) + cache.is_updated[layer_idx] = True + return cache + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # check if empty list because in case of static cache it will be a tensors and we can't check `if not torch.Tensor` + return self.self_attention_cache.get_seq_length(layer_idx) + + def reset(self): + if hasattr(self.self_attention_cache, "reset"): + self.self_attention_cache.reset() + if hasattr(self.cross_attention_cache, "reset"): + self.cross_attention_cache.reset() + elif not hasattr(self.self_attention_cache, "reset") and not hasattr(self.cross_attention_cache, "reset"): + raise ValueError( + "Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should " + "only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. " + f"Got {self.self_attention_cache.__str__()} for the self attention cache and " + f"{self.cross_attention_cache.__str__()} for the cross attention cache." + ) + for layer_idx in self.is_updated: + self.is_updated[layer_idx] = False + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + self.self_attention_cache.reorder_cache(beam_idx) + self.cross_attention_cache.reorder_cache(beam_idx) + + def check_dynamic_cache(self, method: str): + if not ( + isinstance(self.self_attention_cache, DynamicCache) + and isinstance(self.cross_attention_cache, DynamicCache) + ): + raise ValueError( + f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self " + f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache." + ) + + # TODO(gante, sanchit-gandhi): move following functionality into `.generate` + def crop(self, maximum_length: int): + """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be + negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search.""" + self.check_dynamic_cache(self.crop.__name__) + self.self_attention_cache.crop(maximum_length) + + @deprecate_kwarg("num_hidden_layers", version="4.47.0") + def batch_split( + self, full_batch_size: int, split_size: int, num_hidden_layers: int = None + ) -> "List[EncoderDecoderCache]": + """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by + `_split_model_inputs()` in `generation.utils`""" + self.check_dynamic_cache(self.batch_split.__name__) + self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size) + cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size) + + out = [] + for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache): + out.append(EncoderDecoderCache(self_attn, cross_attn)) + return out + + @classmethod + @deprecate_kwarg("num_hidden_layers", version="4.47.0") + def from_batch_splits( + cls, splits: List["EncoderDecoderCache"], num_hidden_layers: int = None + ) -> "EncoderDecoderCache": + """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in + `generation.utils`""" + self_attention_cache = DynamicCache() + cross_attention_cache = DynamicCache() + for idx in range(len(splits[0])): + layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0) + layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0) + self_attention_cache.update(layer_keys, layer_values, idx) + + layer_keys = torch.cat([current.cross_attention_cache.key_cache[idx] for current in splits], dim=0) + layer_values = torch.cat([current.cross_attention_cache.value_cache[idx] for current in splits], dim=0) + cross_attention_cache.update(layer_keys, layer_values, idx) + return cls(self_attention_cache, cross_attention_cache) + + def batch_repeat_interleave(self, repeats: int): + """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" + self.check_dynamic_cache(self.batch_repeat_interleave.__name__) + self.self_attention_cache.batch_repeat_interleave(repeats) + self.cross_attention_cache.batch_repeat_interleave(repeats) + + def batch_select_indices(self, indices: torch.Tensor): + """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" + self.check_dynamic_cache(self.batch_select_indices.__name__) + self.self_attention_cache.batch_select_indices(indices) + self.cross_attention_cache.batch_select_indices(indices) + + +class HybridCache(Cache): + """ + Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window attention + and global attention in every other layer. Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention + and ["StaticCache"] for global attention. For more information, see the documentation of each subcomponeent cache class. + + Parameters: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize the static cache. + batch_size (`int`): + The batch size with which the model will be used. Note that a new instance must be instantiated if a + smaller batch size is used. + max_cache_len (`int`): + The maximum sequence length with which the model will be used. + device (`torch.device` or `str`, *optional*, defaults to `"cpu"`): + The device on which the cache should be initialized. Should be the same as the layer. + dtype (torch.dtype, *optional*, defaults to `torch.float32`): + The default `dtype` to use when initializing the layer. + layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`): + Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between differents gpus. + You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache + + >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") + + >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = HybridCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + HybridCache() + ``` + """ + + # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. + def __init__( + self, + config: PretrainedConfig, + batch_size: int = None, + max_cache_len: int = None, + device: Union[torch.device, str] = "cpu", + dtype: torch.dtype = torch.float32, + max_batch_size: Optional[int] = None, + layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, + ) -> None: + super().__init__() + if batch_size is not None: + logger.warning_once( + f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " + "v4.49. Use the more precisely named 'max_batch_size' argument instead." + ) + if not hasattr(config, "sliding_window") or config.sliding_window is None: + raise ValueError( + "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " + "sliding window attention, please check if there is a `sliding_window` field in the model " + "config and it's not set to None." + ) + self.max_cache_len = max_cache_len + self.max_batch_size = batch_size or max_batch_size + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads + self.head_dim = ( + config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + ) + + self.dtype = dtype + self.num_key_value_heads = ( + config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads + ) + layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC + self.is_sliding = torch.tensor( + [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device + ) + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + global_cache_shape = (self.batch_size, self.num_key_value_heads, max_cache_len, self.head_dim) + sliding_cache_shape = ( + self.batch_size, + self.num_key_value_heads, + min(config.sliding_window, max_cache_len), + self.head_dim, + ) + for i in range(config.num_hidden_layers): + if layer_device_map is not None: + layer_device = layer_device_map[i] + else: + layer_device = device + # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph + # breaks when updating the cache. + cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape + new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) + torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + + def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): + if cache_position.shape[0] > max_cache_len: + k_out = key_states[:, :, -max_cache_len:, :] + v_out = value_states[:, :, -max_cache_len:, :] + # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly + self.key_cache[layer_idx] += k_out + self.value_cache[layer_idx] += v_out + # we should return the whole states instead of k_out, v_out to take the whole prompt + # into consideration when building kv cache instead of just throwing away tokens outside of the window + return key_states, value_states + + slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0) + cache_position = cache_position.clamp(0, max_cache_len - 1) + to_shift = cache_position >= max_cache_len - 1 + indices = (slicing + to_shift[-1].int() - 1) % max_cache_len + k_out = k_out[:, :, indices] + v_out = v_out[:, :, indices] + + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + self.key_cache[layer_idx] += k_out + self.value_cache[layer_idx] += v_out + return k_out, v_out + + def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + self.key_cache[layer_idx] = k_out + self.value_cache[layer_idx] = v_out + return k_out, v_out + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor]: + cache_position = cache_kwargs.get("cache_position") + sliding_window = cache_kwargs.get("sliding_window") + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + if sliding_window: + update_fn = self._sliding_update + else: + update_fn = self._static_update + + return update_fn( + cache_position, + layer_idx, + key_states, + value_states, + k_out, + v_out, + k_out.shape[2], + ) + + def get_max_cache_shape(self) -> Optional[int]: + return self.max_cache_len + + def get_seq_length(self, layer_idx: Optional[int] = 0): + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. + # TODO: deprecate this function in favor of `cache_position` + if layer_idx != 0: + raise ValueError( + "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. " + "Using the `layer_idx` argument is not supported." + ) + return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + + def reset(self): + """Resets the cache values while preserving the objects""" + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + @property + def batch_size(self): + logger.warning_once( + f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in " + "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead." + ) + return self.max_batch_size + + +class MambaCache: + """ + Cache for mamba model which does not have attention mechanism and key value states. + + Arguments: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize the static cache. + batch_size (`int`): + The batch size with which the model will be used. Note that a new instance must be instantiated if a + smaller batch size is used. + dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): + The default `dtype` to use when initializing the layer. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. Should be the same as the layer. + + Attributes: + dtype: (`torch.dtype`): + The default `dtype` used to initializing the cache. + intermediate_size: (`int`): + Model's intermediate_size taken from config. + ssm_state_size: (`int`): + Model's state_size taken from config. + conv_kernel_size: (`int`): + Model's convolution kernel size taken from config + conv_states: (`torch.Tensor`): + A tensor of shape `[layer_idx, batch_size, intermediate_size, conv_kernel_size]` that holds convolutional states. + ssm_states: (`torch.Tensor`): + A tensor of shape `[layer_idx, batch_size, intermediate_size, ssm_state_size]` that holds ssm states + + Example: + + ```python + >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache + + >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf") + + >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = MambaCache(config=model.config, batch_size=1, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values + MambaCache() + ``` + """ + + # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. + def __init__( + self, + config: PretrainedConfig, + batch_size: int = None, + dtype: torch.dtype = torch.float16, + device: Optional[Union[torch.device, str]] = None, + max_batch_size: Optional[int] = None, + ): + if batch_size is not None: + logger.warning_once( + f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " + "v4.49. Use the more precisely named 'max_batch_size' argument instead." + ) + self.dtype = dtype + self.max_batch_size = batch_size or max_batch_size + self.intermediate_size = config.intermediate_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + + self.conv_states: torch.Tensor = torch.zeros( + config.num_hidden_layers, + self.max_batch_size, + self.intermediate_size, + self.conv_kernel_size, + device=device, + dtype=dtype, + ) + self.ssm_states: torch.Tensor = torch.zeros( + config.num_hidden_layers, + self.max_batch_size, + self.intermediate_size, + self.ssm_state_size, + device=device, + dtype=dtype, + ) + + torch._dynamo.mark_static_address(self.conv_states) + torch._dynamo.mark_static_address(self.ssm_states) + + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor + ) -> torch.Tensor: + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + @property + def batch_size(self): + logger.warning_once( + f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in " + "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead." + ) + return self.max_batch_size + + +class OffloadedStaticCache(StaticCache): + """ + Static cache class to be used with `torch.compile(model)` that offloads to the CPU or + another device. + + Args: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize + the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. + max_cache_len (`int`): + The maximum sequence length with which the model will be used. + device (`Union[str, torch.device]`): + The device on which the cache should be initialized. Should be the same as the + layer device. + dtype (`torch.dtype`, *optional*): + The default `dtype` to use when initializing the cache. + offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`): + The device to offload to. Defaults to CPU. + layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*): + Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between differents gpus. + You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`. + + Attributes: + key_cache (`List[torch.Tensor]`): + Off-loaded key cache tensors. First one will be on device, where-as the others are + off-loaded. + value_cache (`List[torch.Tensor]`): + Off-loaded value cache tensors. First one will be on device, where-as the others are + off-loaded. + max_batch_size (`int`): + The maximum batch size with which this cache can be used. + max_cache_len (`int`): + The maximum sequence length with which this cache can be used. + device (`torch.device`): + The device on which the cache is used. + offload_device (`torch.device`): + The device used to offload to. + dtype (`torch.dtype`): + The `dtype` used to initializing the cache. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache + + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + + >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = OffloadedStaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + ``` + """ + + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + max_cache_len: Optional[int], + device: Union[str, torch.device], + dtype: Optional[torch.dtype] = None, + offload_device: Union[str, torch.device] = torch.device("cpu"), + layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, + ) -> None: + self.max_batch_size = max_batch_size + self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len + self.device = torch.device(device) if layer_device_map is None else layer_device_map[0] + self.offload_device = torch.device(offload_device) + self.dtype = dtype if dtype is not None else torch.float32 + + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads + head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + + num_key_value_heads = ( + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads + ) + + cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim) + + # Create offloaded CPU tensors. + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + + for i in range(config.num_hidden_layers): + # First layer is always on-device. + device = self.device if i == 0 else self.offload_device + + key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, device) + + self.key_cache.append(key_cache) + self.value_cache.append(value_cache) + + # Create device tensors. + self._device_key_cache: List[torch.Tensor] = [] + self._device_value_cache: List[torch.Tensor] = [] + + for i in range(2): + key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, self.device) + + self._device_key_cache.append(key_cache) + self._device_value_cache.append(value_cache) + + # For backwards compatibility. + # TODO(gante): Remove this. + self._seen_tokens = 0 + + # Create new CUDA stream for parallel prefetching. + self._prefetch_stream = torch.cuda.Stream() if self.device.type == "cuda" else None + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + It is VERY important to index using a tensor, otherwise you introduce a copy to the device. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, *optional*): + Additional arguments for the cache subclass. The `OffloadedStaticCache` needs the + `cache_position` input to know how where to write in the cache. + + Return: + A tuple containing the updated key and value states. + """ + + if layer_idx == 0: + # Update seen tokens. + # TODO(gante): Remove this. + self._seen_tokens += key_states.shape[-2] + + # Always there. + k_out = self.key_cache[0] + v_out = self.value_cache[0] + else: + # Wait for prefetch stream. + if self._prefetch_stream is not None: + torch.cuda.default_stream(self.device).wait_stream(self._prefetch_stream) + + k_out = self._device_key_cache[layer_idx & 1] + v_out = self._device_value_cache[layer_idx & 1] + + self._prefetch_layer(layer_idx + 1) + + cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None + if cache_position is None: + k_out.copy_(key_states) + v_out.copy_(value_states) + + # Copy the values to the offloaded device as well. + if layer_idx == 0: + self.key_cache[layer_idx].copy_(key_states.to(self.offload_device)) + self.value_cache[layer_idx].copy_(value_states.to(self.offload_device)) + else: + # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to + # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does + # explicitly an in-place operation, that avoids copies and uses less memory. + try: + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS + # device. + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + # Copy the values to the offloaded device as well. + if layer_idx != 0: + cache_position = cache_position.to(self.offload_device) + key_states = key_states.to(self.offload_device) + value_states = value_states.to(self.offload_device) + + try: + self.key_cache[layer_idx].index_copy_(2, cache_position, key_states) + self.value_cache[layer_idx].index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS + # device. + self.key_cache[layer_idx][:, :, cache_position] = key_states + self.value_cache[layer_idx][:, :, cache_position] = value_states + + return k_out, v_out + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states that were seen by the model.""" + + # TODO(gante): Remove this. + return self._seen_tokens + + def get_max_cache_shape(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states.""" + + return self.max_cache_len + + def reset(self) -> None: + """Resets the cache values while preserving the objects.""" + + # For backwards compatibility. + # TODO(gante): Remove this. + self._seen_tokens = 0 + + # Zero out cache. + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address. + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + @property + def seen_tokens(self) -> int: + # For backwards compatibility. + # TODO(gante): Remove this. + return self._seen_tokens + + def _create_key_value_cache_tensors( + self, shape: Tuple[int, ...], device: torch.device + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Creates K/V cache tensors on a device. Pins memory for CPU tensors. Marks them as static + addresses for non-CPU tensors. + + Args: + shape (`Tuple[int, ...]`): Shape. + device (`torch.device`): Device. + + Returns: + Key and value cache tensors as a tuple. + """ + + is_cpu_device = device == torch.device("cpu") + + key_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device) + value_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device) + + # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, + # preventing compiled graph breaks when updating the cache. + torch._dynamo.mark_static_address(key_cache) + torch._dynamo.mark_static_address(value_cache) + + return key_cache, value_cache + + def _prefetch_layer(self, layer_idx: int) -> None: + """Prefetch a layer to the device. Needs to be called in order of layer indices.""" + + # Don't fetch layers that do not exist. + if layer_idx >= len(self.key_cache): + return + + # Alternate between two on-device caches. + if self._prefetch_stream is not None: + with torch.cuda.stream(self._prefetch_stream): + self._prefetch_layer_in_context(layer_idx) + else: + self._prefetch_layer_in_context(layer_idx) + + def _prefetch_layer_in_context(self, layer_idx: int) -> None: + """Performs the actual copy of the layer to device cache.""" + + self._device_key_cache[layer_idx & 1].copy_(self.key_cache[layer_idx], non_blocking=True) + self._device_value_cache[layer_idx & 1].copy_(self.value_cache[layer_idx], non_blocking=True) diff --git a/char_level_bpe.cpython-312.pyc b/char_level_bpe.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4681e707fb2e02df1ef8d9ddd4447101be1d09da Binary files /dev/null and b/char_level_bpe.cpython-312.pyc differ diff --git a/char_level_bpe.py b/char_level_bpe.py new file mode 100644 index 0000000000000000000000000000000000000000..bda731e769891c943653be640be4b132489b9be9 --- /dev/null +++ b/char_level_bpe.py @@ -0,0 +1,150 @@ +from typing import Dict, Iterator, List, Optional, Tuple, Union + +from .. import AddedToken, Tokenizer, decoders, pre_tokenizers, trainers +from ..models import BPE +from ..normalizers import BertNormalizer, Lowercase, Sequence, unicode_normalizer_from_str +from .base_tokenizer import BaseTokenizer + + +class CharBPETokenizer(BaseTokenizer): + """Original BPE Tokenizer + + Represents the BPE algorithm, as introduced by Rico Sennrich + (https://arxiv.org/abs/1508.07909) + + The defaults settings corresponds to OpenAI GPT BPE tokenizers and differs from the original + Sennrich subword-nmt implementation by the following options that you can deactivate: + - adding a normalizer to clean up the text (deactivate with `bert_normalizer=False`) by: + * removing any control characters and replacing all whitespaces by the classic one. + * handle chinese chars by putting spaces around them. + * strip all accents. + - spitting on punctuation in addition to whitespaces (deactivate it with + `split_on_whitespace_only=True`) + """ + + def __init__( + self, + vocab: Optional[Union[str, Dict[str, int]]] = None, + merges: Optional[Union[str, Dict[Tuple[int, int], Tuple[int, int]]]] = None, + unk_token: Union[str, AddedToken] = "", + suffix: str = "", + dropout: Optional[float] = None, + lowercase: bool = False, + unicode_normalizer: Optional[str] = None, + bert_normalizer: bool = True, + split_on_whitespace_only: bool = False, + ): + if vocab is not None and merges is not None: + tokenizer = Tokenizer( + BPE( + vocab, + merges, + dropout=dropout, + unk_token=str(unk_token), + end_of_word_suffix=suffix, + ) + ) + else: + tokenizer = Tokenizer(BPE(unk_token=str(unk_token), dropout=dropout, end_of_word_suffix=suffix)) + + if tokenizer.token_to_id(str(unk_token)) is not None: + tokenizer.add_special_tokens([str(unk_token)]) + + # Check for Unicode normalization first (before everything else) + normalizers = [] + + if unicode_normalizer: + normalizers += [unicode_normalizer_from_str(unicode_normalizer)] + + if bert_normalizer: + normalizers += [BertNormalizer(lowercase=False)] + + if lowercase: + normalizers += [Lowercase()] + + # Create the normalizer structure + if len(normalizers) > 0: + if len(normalizers) > 1: + tokenizer.normalizer = Sequence(normalizers) + else: + tokenizer.normalizer = normalizers[0] + + if split_on_whitespace_only: + tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit() + else: + tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer() + + tokenizer.decoder = decoders.BPEDecoder(suffix=suffix) + + parameters = { + "model": "BPE", + "unk_token": unk_token, + "suffix": suffix, + "dropout": dropout, + "lowercase": lowercase, + "unicode_normalizer": unicode_normalizer, + "bert_normalizer": bert_normalizer, + "split_on_whitespace_only": split_on_whitespace_only, + } + + super().__init__(tokenizer, parameters) + + @staticmethod + def from_file(vocab_filename: str, merges_filename: str, **kwargs): + vocab, merges = BPE.read_file(vocab_filename, merges_filename) + return CharBPETokenizer(vocab, merges, **kwargs) + + def train( + self, + files: Union[str, List[str]], + vocab_size: int = 30000, + min_frequency: int = 2, + special_tokens: List[Union[str, AddedToken]] = [""], + limit_alphabet: int = 1000, + initial_alphabet: List[str] = [], + suffix: Optional[str] = "", + show_progress: bool = True, + ): + """Train the model using the given files""" + + trainer = trainers.BpeTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + special_tokens=special_tokens, + limit_alphabet=limit_alphabet, + initial_alphabet=initial_alphabet, + end_of_word_suffix=suffix, + show_progress=show_progress, + ) + if isinstance(files, str): + files = [files] + self._tokenizer.train(files, trainer=trainer) + + def train_from_iterator( + self, + iterator: Union[Iterator[str], Iterator[Iterator[str]]], + vocab_size: int = 30000, + min_frequency: int = 2, + special_tokens: List[Union[str, AddedToken]] = [""], + limit_alphabet: int = 1000, + initial_alphabet: List[str] = [], + suffix: Optional[str] = "", + show_progress: bool = True, + length: Optional[int] = None, + ): + """Train the model using the given iterator""" + + trainer = trainers.BpeTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + special_tokens=special_tokens, + limit_alphabet=limit_alphabet, + initial_alphabet=initial_alphabet, + end_of_word_suffix=suffix, + show_progress=show_progress, + ) + self._tokenizer.train_from_iterator( + iterator, + trainer=trainer, + length=length, + ) diff --git a/configuration_utils.py b/configuration_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..648877c8dce962d0d0387924ced7320781d9f056 --- /dev/null +++ b/configuration_utils.py @@ -0,0 +1,1187 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Configuration base class and utilities.""" + +import copy +import json +import os +import re +import warnings +from typing import Any, Dict, List, Optional, Tuple, Union + +from packaging import version + +from . import __version__ +from .dynamic_module_utils import custom_object_save +from .modeling_gguf_pytorch_utils import load_gguf_checkpoint +from .utils import ( + CONFIG_NAME, + PushToHubMixin, + add_model_info_to_auto_map, + add_model_info_to_custom_pipelines, + cached_file, + copy_func, + download_url, + extract_commit_hash, + is_remote_url, + is_torch_available, + logging, +) +from .utils.generic import is_timm_config_dict + + +logger = logging.get_logger(__name__) + +_re_configuration_file = re.compile(r"config\.(.*)\.json") + + +class PretrainedConfig(PushToHubMixin): + # no-format + r""" + Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as + methods for loading/downloading/saving configurations. + + + + A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to + initialize a model does **not** load the model weights. It only affects the model's configuration. + + + + Class attributes (overridden by derived classes): + + - **model_type** (`str`) -- An identifier for the model type, serialized into the JSON file, and used to recreate + the correct object in [`~transformers.AutoConfig`]. + - **is_composition** (`bool`) -- Whether the config class is composed of multiple sub-configs. In this case the + config has to be initialized from two or more configs of type [`~transformers.PretrainedConfig`] like: + [`~transformers.EncoderDecoderConfig`] or [`~RagConfig`]. + - **keys_to_ignore_at_inference** (`List[str]`) -- A list of keys to ignore by default when looking at dictionary + outputs of the model during inference. + - **attribute_map** (`Dict[str, str]`) -- A dict that maps model specific attribute names to the standardized + naming of attributes. + - **base_model_tp_plan** (`Dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor + parallel plan applied to the sub-module when `model.tensor_parallel` is called. + + Common attributes (present in all subclasses): + + - **vocab_size** (`int`) -- The number of tokens in the vocabulary, which is also the first dimension of the + embeddings matrix (this attribute may be missing for models that don't have a text modality like ViT). + - **hidden_size** (`int`) -- The hidden size of the model. + - **num_attention_heads** (`int`) -- The number of attention heads used in the multi-head attention layers of the + model. + - **num_hidden_layers** (`int`) -- The number of blocks in the model. + + + + Setting parameters for sequence generation in the model config is deprecated. For backward compatibility, loading + some of them will still be possible, but attempting to overwrite them will throw an exception -- you should set + them in a [~transformers.GenerationConfig]. Check the documentation of [~transformers.GenerationConfig] for more + information about the individual parameters. + + + + Arg: + name_or_path (`str`, *optional*, defaults to `""`): + Store the string that was passed to [`PreTrainedModel.from_pretrained`] or + [`TFPreTrainedModel.from_pretrained`] as `pretrained_model_name_or_path` if the configuration was created + with such a method. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not the model should return all hidden-states. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not the model should returns all attentions. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not the model should return a [`~transformers.utils.ModelOutput`] instead of a plain tuple. + is_encoder_decoder (`bool`, *optional*, defaults to `False`): + Whether the model is used as an encoder/decoder or not. + is_decoder (`bool`, *optional*, defaults to `False`): + Whether the model is used as decoder or not (in which case it's used as an encoder). + cross_attention_hidden_size** (`bool`, *optional*): + The hidden size of the cross-attention layer in case the model is used as a decoder in an encoder-decoder + setting and the cross-attention hidden dimension differs from `self.config.hidden_size`. + add_cross_attention (`bool`, *optional*, defaults to `False`): + Whether cross-attention layers should be added to the model. Note, this option is only relevant for models + that can be used as decoder models within the [`EncoderDecoderModel`] class, which consists of all models + in `AUTO_MODELS_FOR_CAUSAL_LM`. + tie_encoder_decoder (`bool`, *optional*, defaults to `False`): + Whether all encoder weights should be tied to their equivalent decoder weights. This requires the encoder + and decoder model to have the exact same parameter names. + prune_heads (`Dict[int, List[int]]`, *optional*, defaults to `{}`): + Pruned heads of the model. The keys are the selected layer indices and the associated values, the list of + heads to prune in said layer. + + For instance `{1: [0, 2], 2: [2, 3]}` will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2. + chunk_size_feed_forward (`int`, *optional*, defaults to `0`): + The chunk size of all feed forward layers in the residual attention blocks. A chunk size of `0` means that + the feed forward layer is not chunked. A chunk size of n means that the feed forward layer processes `n` < + sequence_length embeddings at a time. For more information on feed forward chunking, see [How does Feed + Forward Chunking work?](../glossary.html#feed-forward-chunking). + + > Parameters for fine-tuning tasks + + architectures (`List[str]`, *optional*): + Model architectures that can be used with the model pretrained weights. + finetuning_task (`str`, *optional*): + Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow + or PyTorch) checkpoint. + id2label (`Dict[int, str]`, *optional*): + A map from index (for instance prediction index, or target index) to label. + label2id (`Dict[str, int]`, *optional*): A map from label to index for the model. + num_labels (`int`, *optional*): + Number of labels to use in the last layer added to the model, typically for a classification task. + task_specific_params (`Dict[str, Any]`, *optional*): + Additional keyword arguments to store for the current task. + problem_type (`str`, *optional*): + Problem type for `XxxForSequenceClassification` models. Can be one of `"regression"`, + `"single_label_classification"` or `"multi_label_classification"`. + + > Parameters linked to the tokenizer + + tokenizer_class (`str`, *optional*): + The name of the associated tokenizer class to use (if none is set, will use the tokenizer associated to the + model by default). + prefix (`str`, *optional*): + A specific prompt that should be added at the beginning of each text before calling the model. + bos_token_id (`int`, *optional*): The id of the _beginning-of-stream_ token. + pad_token_id (`int`, *optional*): The id of the _padding_ token. + eos_token_id (`int`, *optional*): The id of the _end-of-stream_ token. + decoder_start_token_id (`int`, *optional*): + If an encoder-decoder model starts decoding with a different token than _bos_, the id of that token. + sep_token_id (`int`, *optional*): The id of the _separation_ token. + + > PyTorch specific parameters + + torchscript (`bool`, *optional*, defaults to `False`): + Whether or not the model should be used with Torchscript. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the + model has a output word embedding layer. + torch_dtype (`str`, *optional*): + The `dtype` of the weights. This attribute can be used to initialize the model to a non-default `dtype` + (which is normally `float32`) and thus allow for optimal storage allocation. For example, if the saved + model is `float16`, ideally we want to load it back using the minimal amount of memory needed to load + `float16` weights. Since the config object is stored in plain text, this attribute contains just the + floating type string without the `torch.` prefix. For example, for `torch.float16` ``torch_dtype` is the + `"float16"` string. + + This attribute is currently not being used during model loading time, but this may change in the future + versions. But we can already start preparing for the future by saving the dtype with save_pretrained. + + > TensorFlow specific parameters + + use_bfloat16 (`bool`, *optional*, defaults to `False`): + Whether or not the model should use BFloat16 scalars (only used by some TensorFlow models). + tf_legacy_loss (`bool`, *optional*, defaults to `False`): + Whether the model should use legacy TensorFlow losses. Legacy losses have variable output shapes and may + not be XLA-compatible. This option is here for backward compatibility and will be removed in Transformers + v5. + loss_type (`str`, *optional*): + The type of loss that the model should use. It should be in `LOSS_MAPPING`'s keys, otherwise the loss will + be automatically infered from the model architecture. + """ + + model_type: str = "" + base_config_key: str = "" + sub_configs: Dict[str, "PretrainedConfig"] = {} + is_composition: bool = False + attribute_map: Dict[str, str] = {} + base_model_tp_plan: Optional[Dict[str, Any]] = None + _auto_class: Optional[str] = None + + def __setattr__(self, key, value): + if key in super().__getattribute__("attribute_map"): + key = super().__getattribute__("attribute_map")[key] + super().__setattr__(key, value) + + def __getattribute__(self, key): + if key != "attribute_map" and key in super().__getattribute__("attribute_map"): + key = super().__getattribute__("attribute_map")[key] + return super().__getattribute__(key) + + def __init__(self, **kwargs): + # Attributes with defaults + self.return_dict = kwargs.pop("return_dict", True) + self.output_hidden_states = kwargs.pop("output_hidden_states", False) + self.output_attentions = kwargs.pop("output_attentions", False) + self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models + self.torch_dtype = kwargs.pop("torch_dtype", None) # Only used by PyTorch models + self.use_bfloat16 = kwargs.pop("use_bfloat16", False) + self.tf_legacy_loss = kwargs.pop("tf_legacy_loss", False) # Only used by TensorFlow models + self.pruned_heads = kwargs.pop("pruned_heads", {}) + self.tie_word_embeddings = kwargs.pop( + "tie_word_embeddings", True + ) # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models. + self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0) + + # Is decoder is used in encoder-decoder models to differentiate encoder from decoder + self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False) + self.is_decoder = kwargs.pop("is_decoder", False) + self.cross_attention_hidden_size = kwargs.pop("cross_attention_hidden_size", None) + self.add_cross_attention = kwargs.pop("add_cross_attention", False) + self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False) + + # Retrocompatibility: Parameters for sequence generation. While we will keep the ability to load these + # parameters, saving them will be deprecated. In a distant future, we won't need to load them. + for parameter_name, default_value in self._get_global_generation_defaults().items(): + setattr(self, parameter_name, kwargs.pop(parameter_name, default_value)) + + # Fine-tuning task arguments + self.architectures = kwargs.pop("architectures", None) + self.finetuning_task = kwargs.pop("finetuning_task", None) + self.id2label = kwargs.pop("id2label", None) + self.label2id = kwargs.pop("label2id", None) + if self.label2id is not None and not isinstance(self.label2id, dict): + raise ValueError("Argument label2id should be a dictionary.") + if self.id2label is not None: + if not isinstance(self.id2label, dict): + raise ValueError("Argument id2label should be a dictionary.") + num_labels = kwargs.pop("num_labels", None) + if num_labels is not None and len(self.id2label) != num_labels: + logger.warning( + f"You passed along `num_labels={num_labels}` with an incompatible id to label map: " + f"{self.id2label}. The number of labels wil be overwritten to {self.num_labels}." + ) + self.id2label = {int(key): value for key, value in self.id2label.items()} + # Keys are always strings in JSON so convert ids to int here. + else: + self.num_labels = kwargs.pop("num_labels", 2) + + if self.torch_dtype is not None and isinstance(self.torch_dtype, str): + # we will start using self.torch_dtype in v5, but to be consistent with + # from_pretrained's torch_dtype arg convert it to an actual torch.dtype object + if is_torch_available(): + import torch + + self.torch_dtype = getattr(torch, self.torch_dtype) + + # Tokenizer arguments TODO: eventually tokenizer and models should share the same config + self.tokenizer_class = kwargs.pop("tokenizer_class", None) + self.prefix = kwargs.pop("prefix", None) + self.bos_token_id = kwargs.pop("bos_token_id", None) + self.pad_token_id = kwargs.pop("pad_token_id", None) + self.eos_token_id = kwargs.pop("eos_token_id", None) + self.sep_token_id = kwargs.pop("sep_token_id", None) + + self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None) + + # task specific arguments + self.task_specific_params = kwargs.pop("task_specific_params", None) + + # regression / multi-label classification + self.problem_type = kwargs.pop("problem_type", None) + allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification") + if self.problem_type is not None and self.problem_type not in allowed_problem_types: + raise ValueError( + f"The config parameter `problem_type` was not understood: received {self.problem_type} " + "but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid." + ) + + # TPU arguments + if kwargs.pop("xla_device", None) is not None: + logger.warning( + "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can " + "safely remove it from your `config.json` file." + ) + + # Name or path to the pretrained checkpoint + self._name_or_path = str(kwargs.pop("name_or_path", "")) + # Config hash + self._commit_hash = kwargs.pop("_commit_hash", None) + + # Attention implementation to use, if relevant. + self._attn_implementation_internal = kwargs.pop("attn_implementation", None) + self._attn_implementation_autoset = False + + # Drop the transformers version info + self.transformers_version = kwargs.pop("transformers_version", None) + + # Deal with gradient checkpointing + if kwargs.get("gradient_checkpointing", False): + warnings.warn( + "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 " + "Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the " + "`Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`." + ) + + # Additional attributes without default values + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error(f"Can't set {key} with value {value} for {self}") + raise err + + @property + def name_or_path(self) -> str: + return getattr(self, "_name_or_path", None) + + @name_or_path.setter + def name_or_path(self, value): + self._name_or_path = str(value) # Make sure that name_or_path is a string (for JSON encoding) + + @property + def use_return_dict(self) -> bool: + """ + `bool`: Whether or not return [`~utils.ModelOutput`] instead of tuples. + """ + # If torchscript is set, force `return_dict=False` to avoid jit errors + return self.return_dict and not self.torchscript + + @property + def num_labels(self) -> int: + """ + `int`: The number of labels for classification models. + """ + return len(self.id2label) + + @num_labels.setter + def num_labels(self, num_labels: int): + if not hasattr(self, "id2label") or self.id2label is None or len(self.id2label) != num_labels: + self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)} + self.label2id = dict(zip(self.id2label.values(), self.id2label.keys())) + + @property + def _attn_implementation(self): + # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.) + if hasattr(self, "_attn_implementation_internal"): + if self._attn_implementation_internal is None: + # `config.attn_implementation` should never be None, for backward compatibility. + return "eager" + else: + return self._attn_implementation_internal + else: + return "eager" + + @_attn_implementation.setter + def _attn_implementation(self, value): + self._attn_implementation_internal = value + + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the + [`~PretrainedConfig.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file will be saved (will be created if it does not exist). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + self._set_token_in_kwargs(kwargs) + + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + non_default_generation_parameters = self._get_non_default_generation_parameters() + if len(non_default_generation_parameters) > 0: + # TODO (joao): this should be an exception if the user has modified the loaded config. See #33886 + warnings.warn( + "Some non-default generation parameters are set in the model config. These should go into either a) " + "`model.generation_config` (as opposed to `model.config`); OR b) a GenerationConfig file " + "(https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model)." + "This warning will become an exception in the future." + f"\nNon-default generation parameters: {str(non_default_generation_parameters)}", + UserWarning, + ) + + os.makedirs(save_directory, exist_ok=True) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = self._create_repo(repo_id, **kwargs) + files_timestamps = self._get_files_timestamps(save_directory) + + # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be + # loaded from the Hub. + if self._auto_class is not None: + custom_object_save(self, save_directory, config=self) + + # If we save using the predefined names, we can load using `from_pretrained` + output_config_file = os.path.join(save_directory, CONFIG_NAME) + + self.to_json_file(output_config_file, use_diff=True) + logger.info(f"Configuration saved in {output_config_file}") + + if push_to_hub: + self._upload_modified_files( + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=kwargs.get("token"), + ) + + @staticmethod + def _set_token_in_kwargs(kwargs, token=None): + """Temporary method to deal with `token` and `use_auth_token`. + + This method is to avoid apply the same changes in all model config classes that overwrite `from_pretrained`. + + Need to clean up `use_auth_token` in a follow PR. + """ + # Some model config classes like CLIP define their own `from_pretrained` without the new argument `token` yet. + if token is None: + token = kwargs.pop("token", None) + use_auth_token = kwargs.pop("use_auth_token", None) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if token is not None: + kwargs["token"] = token + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + **kwargs, + ) -> "PretrainedConfig": + r""" + Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. + - a path to a *directory* containing a configuration file saved using the + [`~PretrainedConfig.save_pretrained`] method, e.g., `./my_model_directory/`. + - a path or url to a saved configuration JSON *file*, e.g., `./my_model_directory/configuration.json`. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if + they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use + the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`. + + + + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + If `False`, then this function returns just the final configuration object. + + If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a + dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the + part of `kwargs` which has not been used to update `config` and is otherwise ignored. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can + specify the folder name here. + kwargs (`Dict[str, Any]`, *optional*): + The values in kwargs of any keys which are configuration attributes will be used to override the loaded + values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled + by the `return_unused_kwargs` keyword parameter. + + Returns: + [`PretrainedConfig`]: The configuration object instantiated from this pretrained model. + + Examples: + + ```python + # We can't instantiate directly the base class *PretrainedConfig* so let's show the examples on a + # derived class: BertConfig + config = BertConfig.from_pretrained( + "google-bert/bert-base-uncased" + ) # Download configuration from huggingface.co and cache. + config = BertConfig.from_pretrained( + "./test/saved_model/" + ) # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')* + config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json") + config = BertConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False) + assert config.output_attentions == True + config, unused_kwargs = BertConfig.from_pretrained( + "google-bert/bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True + ) + assert config.output_attentions == True + assert unused_kwargs == {"foo": False} + ```""" + kwargs["cache_dir"] = cache_dir + kwargs["force_download"] = force_download + kwargs["local_files_only"] = local_files_only + kwargs["revision"] = revision + + cls._set_token_in_kwargs(kwargs, token) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + if cls.base_config_key and cls.base_config_key in config_dict: + config_dict = config_dict[cls.base_config_key] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + # sometimes the config has no `base_config_key` if the config is used in several composite models + # e.g. LlamaConfig. In that case we try to see if there is match in `model_type` before raising a warning + for k, v in config_dict.items(): + if isinstance(v, dict) and v.get("model_type") == cls.model_type: + config_dict = v + + # raise warning only if we still can't see a match in `model_type` + if config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + @classmethod + def get_config_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a + [`PretrainedConfig`] using `from_dict`. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`): + The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. + + Returns: + `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object. + + """ + cls._set_token_in_kwargs(kwargs) + + original_kwargs = copy.deepcopy(kwargs) + # Get config dict associated with the base config file + config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs) + if config_dict is None: + return {}, kwargs + if "_commit_hash" in config_dict: + original_kwargs["_commit_hash"] = config_dict["_commit_hash"] + + # That config file may point us toward another config file to use. + if "configuration_files" in config_dict: + configuration_file = get_configuration_file(config_dict["configuration_files"]) + config_dict, kwargs = cls._get_config_dict( + pretrained_model_name_or_path, _configuration_file=configuration_file, **original_kwargs + ) + + return config_dict, kwargs + + @classmethod + def _get_config_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", None) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + trust_remote_code = kwargs.pop("trust_remote_code", None) + subfolder = kwargs.pop("subfolder", "") + from_pipeline = kwargs.pop("_from_pipeline", None) + from_auto_class = kwargs.pop("_from_auto", False) + commit_hash = kwargs.pop("_commit_hash", None) + + gguf_file = kwargs.get("gguf_file", None) + + if trust_remote_code is True: + logger.warning( + "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is" + " ignored." + ) + + user_agent = {"file_type": "config", "from_auto_class": from_auto_class} + if from_pipeline is not None: + user_agent["using_pipeline"] = from_pipeline + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + + is_local = os.path.isdir(pretrained_model_name_or_path) + if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): + # Special case when pretrained_model_name_or_path is a local file + resolved_config_file = pretrained_model_name_or_path + is_local = True + elif is_remote_url(pretrained_model_name_or_path): + configuration_file = pretrained_model_name_or_path if gguf_file is None else gguf_file + resolved_config_file = download_url(pretrained_model_name_or_path) + else: + configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) if gguf_file is None else gguf_file + + try: + # Load from local folder or from cache or download from model Hub and cache + resolved_config_file = cached_file( + pretrained_model_name_or_path, + configuration_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _commit_hash=commit_hash, + ) + if resolved_config_file is None: + return None, kwargs + commit_hash = extract_commit_hash(resolved_config_file, commit_hash) + except EnvironmentError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to + # the original exception. + raise + except Exception: + # For any other exception, we throw a generic error. + raise EnvironmentError( + f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it" + " from 'https://huggingface.co/models', make sure you don't have a local directory with the same" + f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory" + f" containing a {configuration_file} file" + ) + + try: + if gguf_file: + config_dict = load_gguf_checkpoint(resolved_config_file, return_tensors=False)["config"] + else: + # Load config dict + config_dict = cls._dict_from_json_file(resolved_config_file) + + config_dict["_commit_hash"] = commit_hash + except (json.JSONDecodeError, UnicodeDecodeError): + raise EnvironmentError( + f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file." + ) + + if is_local: + logger.info(f"loading configuration file {resolved_config_file}") + else: + logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}") + + if "auto_map" in config_dict and not is_local: + config_dict["auto_map"] = add_model_info_to_auto_map( + config_dict["auto_map"], pretrained_model_name_or_path + ) + if "custom_pipelines" in config_dict and not is_local: + config_dict["custom_pipelines"] = add_model_info_to_custom_pipelines( + config_dict["custom_pipelines"], pretrained_model_name_or_path + ) + + # timm models are not saved with the model_type in the config file + if "model_type" not in config_dict and is_timm_config_dict(config_dict): + config_dict["model_type"] = "timm_wrapper" + + return config_dict, kwargs + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig": + """ + Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters. + + Args: + config_dict (`Dict[str, Any]`): + Dictionary that will be used to instantiate the configuration object. Such a dictionary can be + retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method. + kwargs (`Dict[str, Any]`): + Additional parameters from which to initialize the configuration object. + + Returns: + [`PretrainedConfig`]: The configuration object instantiated from those parameters. + """ + return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) + # Those arguments may be passed along for our internal telemetry. + # We remove them so they don't appear in `return_unused_kwargs`. + kwargs.pop("_from_auto", None) + kwargs.pop("_from_pipeline", None) + # The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update. + if "_commit_hash" in kwargs and "_commit_hash" in config_dict: + kwargs["_commit_hash"] = config_dict["_commit_hash"] + + # We remove it from kwargs so that it does not appear in `return_unused_kwargs`. + config_dict["attn_implementation"] = kwargs.pop("attn_implementation", None) + + config = cls(**config_dict) + + if hasattr(config, "pruned_heads"): + config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()} + + # Update config with kwargs if needed + if "num_labels" in kwargs and "id2label" in kwargs: + num_labels = kwargs["num_labels"] + id2label = kwargs["id2label"] if kwargs["id2label"] is not None else [] + if len(id2label) != num_labels: + raise ValueError( + f"You passed along `num_labels={num_labels }` with an incompatible id to label map: " + f"{kwargs['id2label']}. Since those arguments are inconsistent with each other, you should remove " + "one of them." + ) + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + current_attr = getattr(config, key) + # To authorize passing a custom subconfig as kwarg in models that have nested configs. + if isinstance(current_attr, PretrainedConfig) and isinstance(value, dict): + value = current_attr.__class__(**value) + setattr(config, key, value) + if key != "torch_dtype": + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + + logger.info(f"Model config {config}") + if return_unused_kwargs: + return config, kwargs + else: + return config + + @classmethod + def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig": + """ + Instantiates a [`PretrainedConfig`] from the path to a JSON file of parameters. + + Args: + json_file (`str` or `os.PathLike`): + Path to the JSON file containing the parameters. + + Returns: + [`PretrainedConfig`]: The configuration object instantiated from that JSON file. + + """ + config_dict = cls._dict_from_json_file(json_file) + return cls(**config_dict) + + @classmethod + def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]): + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + return json.loads(text) + + def __eq__(self, other): + return isinstance(other, PretrainedConfig) and (self.__dict__ == other.__dict__) + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + def __iter__(self): + for attr in self.__dict__: + yield attr + + def to_diff_dict(self) -> Dict[str, Any]: + """ + Removes all attributes from config which correspond to the default config attributes for better readability and + serializes to a Python dictionary. + + Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, + """ + config_dict = self.to_dict() + + # get the default config dict + default_config_dict = PretrainedConfig().to_dict() + + # get class specific config dict + class_config_dict = self.__class__().to_dict() if not self.is_composition else {} + + serializable_config_dict = {} + + # only serialize values that differ from the default config + for key, value in config_dict.items(): + if ( + isinstance(getattr(self, key, None), PretrainedConfig) + and key in class_config_dict + and isinstance(class_config_dict[key], dict) + ): + # For nested configs we need to clean the diff recursively + diff = recursive_diff_dict(value, class_config_dict[key], config_obj=getattr(self, key, None)) + if "model_type" in value: + # Needs to be set even if it's not in the diff + diff["model_type"] = value["model_type"] + if len(diff) > 0: + serializable_config_dict[key] = diff + elif ( + key not in default_config_dict + or key == "transformers_version" + or value != default_config_dict[key] + or (key in class_config_dict and value != class_config_dict[key]) + ): + serializable_config_dict[key] = value + + if hasattr(self, "quantization_config"): + serializable_config_dict["quantization_config"] = ( + self.quantization_config.to_dict() + if not isinstance(self.quantization_config, dict) + else self.quantization_config + ) + + # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable. + _ = serializable_config_dict.pop("_pre_quantization_dtype", None) + + self.dict_torch_dtype_to_str(serializable_config_dict) + + if "_attn_implementation_internal" in serializable_config_dict: + del serializable_config_dict["_attn_implementation_internal"] + # Do not serialize `base_model_tp_plan` for now + if "base_model_tp_plan" in serializable_config_dict: + del serializable_config_dict["base_model_tp_plan"] + + return serializable_config_dict + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. + + Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + output = copy.deepcopy(self.__dict__) + if hasattr(self.__class__, "model_type"): + output["model_type"] = self.__class__.model_type + if "_auto_class" in output: + del output["_auto_class"] + if "_commit_hash" in output: + del output["_commit_hash"] + if "_attn_implementation_internal" in output: + del output["_attn_implementation_internal"] + # Do not serialize `base_model_tp_plan` for now + if "base_model_tp_plan" in output: + del output["base_model_tp_plan"] + + # Transformers version when serializing the model + output["transformers_version"] = __version__ + + for key, value in output.items(): + # Deal with nested configs like CLIP + if isinstance(value, PretrainedConfig): + value = value.to_dict() + del value["transformers_version"] + + output[key] = value + + if hasattr(self, "quantization_config"): + output["quantization_config"] = ( + self.quantization_config.to_dict() + if not isinstance(self.quantization_config, dict) + else self.quantization_config + ) + + # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable. + _ = output.pop("_pre_quantization_dtype", None) + + self.dict_torch_dtype_to_str(output) + + return output + + def to_json_string(self, use_diff: bool = True) -> str: + """ + Serializes this instance to a JSON string. + + Args: + use_diff (`bool`, *optional*, defaults to `True`): + If set to `True`, only the difference between the config instance and the default `PretrainedConfig()` + is serialized to JSON string. + + Returns: + `str`: String containing all the attributes that make up this configuration instance in JSON format. + """ + if use_diff is True: + config_dict = self.to_diff_dict() + else: + config_dict = self.to_dict() + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. + use_diff (`bool`, *optional*, defaults to `True`): + If set to `True`, only the difference between the config instance and the default `PretrainedConfig()` + is serialized to JSON file. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string(use_diff=use_diff)) + + def update(self, config_dict: Dict[str, Any]): + """ + Updates attributes of this class with attributes from `config_dict`. + + Args: + config_dict (`Dict[str, Any]`): Dictionary of attributes that should be updated for this class. + """ + for key, value in config_dict.items(): + setattr(self, key, value) + + def update_from_string(self, update_str: str): + """ + Updates attributes of this class with attributes from `update_str`. + + The expected format is ints, floats and strings as is, and for booleans use `true` or `false`. For example: + "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" + + The keys to change have to already exist in the config object. + + Args: + update_str (`str`): String with attributes that should be updated for this class. + + """ + + d = dict(x.split("=") for x in update_str.split(",")) + for k, v in d.items(): + if not hasattr(self, k): + raise ValueError(f"key {k} isn't in the original config dict") + + old_v = getattr(self, k) + if isinstance(old_v, bool): + if v.lower() in ["true", "1", "y", "yes"]: + v = True + elif v.lower() in ["false", "0", "n", "no"]: + v = False + else: + raise ValueError(f"can't derive true or false from {v} (key {k})") + elif isinstance(old_v, int): + v = int(v) + elif isinstance(old_v, float): + v = float(v) + elif not isinstance(old_v, str): + raise TypeError( + f"You can only update int, float, bool or string values in the config, got {v} for key {k}" + ) + + setattr(self, k, v) + + def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None: + """ + Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None, + converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"* + string, which can then be stored in the json format. + """ + if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str): + d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1] + for value in d.values(): + if isinstance(value, dict): + self.dict_torch_dtype_to_str(value) + + @classmethod + def register_for_auto_class(cls, auto_class="AutoConfig"): + """ + Register this class with a given auto class. This should only be used for custom configurations as the ones in + the library are already mapped with `AutoConfig`. + + + + This API is experimental and may have some slight breaking changes in the next releases. + + + + Args: + auto_class (`str` or `type`, *optional*, defaults to `"AutoConfig"`): + The auto class to register this new configuration with. + """ + if not isinstance(auto_class, str): + auto_class = auto_class.__name__ + + import transformers.models.auto as auto_module + + if not hasattr(auto_module, auto_class): + raise ValueError(f"{auto_class} is not a valid auto class.") + + cls._auto_class = auto_class + + @staticmethod + def _get_global_generation_defaults() -> Dict[str, Any]: + return { + "max_length": 20, + "min_length": 0, + "do_sample": False, + "early_stopping": False, + "num_beams": 1, + "num_beam_groups": 1, + "diversity_penalty": 0.0, + "temperature": 1.0, + "top_k": 50, + "top_p": 1.0, + "typical_p": 1.0, + "repetition_penalty": 1.0, + "length_penalty": 1.0, + "no_repeat_ngram_size": 0, + "encoder_no_repeat_ngram_size": 0, + "bad_words_ids": None, + "num_return_sequences": 1, + "output_scores": False, + "return_dict_in_generate": False, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "remove_invalid_values": False, + "exponential_decay_length_penalty": None, + "suppress_tokens": None, + "begin_suppress_tokens": None, + } + + def _get_non_default_generation_parameters(self) -> Dict[str, Any]: + """ + Gets the non-default generation parameters on the PretrainedConfig instance + """ + non_default_generation_parameters = {} + decoder_attribute_name = None + + # Composite models don't have a default config, use their decoder config as a fallback for default values + # If no known pattern is matched, then `default_config = None` -> check against the global generation defaults + try: + default_config = self.__class__() + except ValueError: + decoder_config = self.get_text_config(decoder=True) + if decoder_config is not self: + default_config = decoder_config.__class__() + else: + default_config = None + + # If it is a composite model, we want to check the subconfig that will be used for generation + self_decoder_config = self if decoder_attribute_name is None else getattr(self, decoder_attribute_name) + + for parameter_name, default_global_value in self._get_global_generation_defaults().items(): + if hasattr(self_decoder_config, parameter_name): + is_default_in_config = is_default_generation_value = None + parameter_value = getattr(self_decoder_config, parameter_name) + # Three cases in which is okay for the model config to hold generation config parameters: + # 1. The parameter is set to `None`, effectivelly delegating its value to the generation config + if parameter_value is None: + continue + # 2. If we have a default config, then the instance should hold the same generation defaults + if default_config is not None: + is_default_in_config = parameter_value == getattr(default_config, parameter_name) + # 3. if we don't have a default config, then the instance should hold the global generation defaults + else: + is_default_generation_value = parameter_value == default_global_value + + is_non_default = (is_default_in_config is False) or ( + is_default_in_config is None and is_default_generation_value is False + ) + if is_non_default: + non_default_generation_parameters[parameter_name] = getattr(self_decoder_config, parameter_name) + + return non_default_generation_parameters + + def get_text_config(self, decoder=False) -> "PretrainedConfig": + """ + Returns the config that is meant to be used with text IO. On most models, it is the original config instance + itself. On specific composite models, it is under a set of valid names. + + If `decoder` is set to `True`, then only search for decoder config names. + """ + decoder_possible_text_config_names = ("decoder", "generator", "text_config") + encoder_possible_text_config_names = ("text_encoder",) + if decoder: + possible_text_config_names = decoder_possible_text_config_names + else: + possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names + + valid_text_config_names = [] + for text_config_name in possible_text_config_names: + if hasattr(self, text_config_name): + text_config = getattr(self, text_config_name, None) + if text_config is not None: + valid_text_config_names += [text_config_name] + + if len(valid_text_config_names) > 1: + raise ValueError( + f"Multiple valid text configs were found in the model config: {valid_text_config_names}. In this " + "case, using `get_text_config()` would be ambiguous. Please specify the desied text config directly." + ) + elif len(valid_text_config_names) == 1: + return getattr(self, valid_text_config_names[0]) + return self + + +def get_configuration_file(configuration_files: List[str]) -> str: + """ + Get the configuration file to use for this version of transformers. + + Args: + configuration_files (`List[str]`): The list of available configuration files. + + Returns: + `str`: The configuration file to use. + """ + configuration_files_map = {} + for file_name in configuration_files: + search = _re_configuration_file.search(file_name) + if search is not None: + v = search.groups()[0] + configuration_files_map[v] = file_name + available_versions = sorted(configuration_files_map.keys()) + + # Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions. + configuration_file = CONFIG_NAME + transformers_version = version.parse(__version__) + for v in available_versions: + if version.parse(v) <= transformers_version: + configuration_file = configuration_files_map[v] + else: + # No point going further since the versions are sorted. + break + + return configuration_file + + +def recursive_diff_dict(dict_a, dict_b, config_obj=None): + """ + Helper function to recursively take the diff between two nested dictionaries. The resulting diff only contains the + values from `dict_a` that are different from values in `dict_b`. + """ + diff = {} + default = config_obj.__class__().to_dict() if config_obj is not None else {} + for key, value in dict_a.items(): + obj_value = getattr(config_obj, str(key), None) + if isinstance(obj_value, PretrainedConfig) and key in dict_b and isinstance(dict_b[key], dict): + diff_value = recursive_diff_dict(value, dict_b[key], config_obj=obj_value) + if len(diff_value) > 0: + diff[key] = diff_value + elif key not in dict_b or value != dict_b[key] or key not in default or value != default[key]: + diff[key] = value + return diff + + +PretrainedConfig.push_to_hub = copy_func(PretrainedConfig.push_to_hub) +if PretrainedConfig.push_to_hub.__doc__ is not None: + PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format( + object="config", object_class="AutoConfig", object_files="configuration file" + ) diff --git a/convert_graph_to_onnx.py b/convert_graph_to_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..051f1d148a84e29d3e706fc4cd42f3ca7d53db26 --- /dev/null +++ b/convert_graph_to_onnx.py @@ -0,0 +1,551 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from argparse import ArgumentParser +from os import listdir, makedirs +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +from packaging.version import Version, parse + +from transformers.pipelines import Pipeline, pipeline +from transformers.tokenization_utils import BatchEncoding +from transformers.utils import ModelOutput, is_tf_available, is_torch_available + + +# This is the minimal required version to +# support some ONNX Runtime features +ORT_QUANTIZE_MINIMUM_VERSION = parse("1.4.0") + + +SUPPORTED_PIPELINES = [ + "feature-extraction", + "ner", + "sentiment-analysis", + "fill-mask", + "question-answering", + "text-generation", + "translation_en_to_fr", + "translation_en_to_de", + "translation_en_to_ro", +] + + +class OnnxConverterArgumentParser(ArgumentParser): + """ + Wraps all the script arguments supported to export transformers models to ONNX IR + """ + + def __init__(self): + super().__init__("ONNX Converter") + + self.add_argument( + "--pipeline", + type=str, + choices=SUPPORTED_PIPELINES, + default="feature-extraction", + ) + self.add_argument( + "--model", + type=str, + required=True, + help="Model's id or path (ex: google-bert/bert-base-cased)", + ) + self.add_argument("--tokenizer", type=str, help="Tokenizer's id or path (ex: google-bert/bert-base-cased)") + self.add_argument( + "--framework", + type=str, + choices=["pt", "tf"], + help="Framework for loading the model", + ) + self.add_argument("--opset", type=int, default=11, help="ONNX opset to use") + self.add_argument( + "--check-loading", + action="store_true", + help="Check ONNX is able to load the model", + ) + self.add_argument( + "--use-external-format", + action="store_true", + help="Allow exporting model >= than 2Gb", + ) + self.add_argument( + "--quantize", + action="store_true", + help="Quantize the neural network to be run with int8", + ) + self.add_argument("output") + + +def generate_identified_filename(filename: Path, identifier: str) -> Path: + """ + Append a string-identifier at the end (before the extension, if any) to the provided filepath + + Args: + filename: pathlib.Path The actual path object we would like to add an identifier suffix + identifier: The suffix to add + + Returns: String with concatenated identifier at the end of the filename + """ + return filename.parent.joinpath(filename.stem + identifier).with_suffix(filename.suffix) + + +def check_onnxruntime_requirements(minimum_version: Version): + """ + Check onnxruntime is installed and if the installed version match is recent enough + + Raises: + ImportError: If onnxruntime is not installed or too old version is found + """ + try: + import onnxruntime + + # Parse the version of the installed onnxruntime + ort_version = parse(onnxruntime.__version__) + + # We require 1.4.0 minimum + if ort_version < ORT_QUANTIZE_MINIMUM_VERSION: + raise ImportError( + f"We found an older version of onnxruntime ({onnxruntime.__version__}) " + f"but we require onnxruntime to be >= {minimum_version} to enable all the conversions options.\n" + "Please update onnxruntime by running `pip install --upgrade onnxruntime`" + ) + + except ImportError: + raise ImportError( + "onnxruntime doesn't seem to be currently installed. " + "Please install the onnxruntime by running `pip install onnxruntime`" + " and relaunch the conversion." + ) + + +def ensure_valid_input(model, tokens, input_names): + """ + Ensure inputs are presented in the correct order, without any Non + + Args: + model: The model used to forward the input data + tokens: BatchEncoding holding the input data + input_names: The name of the inputs + + Returns: Tuple + + """ + print("Ensuring inputs are in correct order") + + model_args_name = model.forward.__code__.co_varnames + model_args, ordered_input_names = [], [] + for arg_name in model_args_name[1:]: # start at index 1 to skip "self" argument + if arg_name in input_names: + ordered_input_names.append(arg_name) + model_args.append(tokens[arg_name]) + else: + print(f"{arg_name} is not present in the generated input list.") + break + + print(f"Generated inputs order: {ordered_input_names}") + return ordered_input_names, tuple(model_args) + + +def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]: + """ + Attempt to infer the static vs dynamic axes for each input and output tensors for a specific model + + Args: + nlp: The pipeline object holding the model to be exported + framework: The framework identifier to dispatch to the correct inference scheme (pt/tf) + + Returns: + + - List of the inferred input variable names + - List of the inferred output variable names + - Dictionary with input/output variables names as key and shape tensor as value + - a BatchEncoding reference which was used to infer all the above information + """ + + def build_shape_dict(name: str, tensor, is_input: bool, seq_len: int): + if isinstance(tensor, (tuple, list)): + return [build_shape_dict(name, t, is_input, seq_len) for t in tensor] + + else: + # Let's assume batch is the first axis with only 1 element (~~ might not be always true ...) + axes = {[axis for axis, numel in enumerate(tensor.shape) if numel == 1][0]: "batch"} + if is_input: + if len(tensor.shape) == 2: + axes[1] = "sequence" + else: + raise ValueError(f"Unable to infer tensor axes ({len(tensor.shape)})") + else: + seq_axes = [dim for dim, shape in enumerate(tensor.shape) if shape == seq_len] + axes.update({dim: "sequence" for dim in seq_axes}) + + print(f"Found {'input' if is_input else 'output'} {name} with shape: {axes}") + return axes + + tokens = nlp.tokenizer("This is a sample output", return_tensors=framework) + seq_len = tokens.input_ids.shape[-1] + outputs = nlp.model(**tokens) if framework == "pt" else nlp.model(tokens) + if isinstance(outputs, ModelOutput): + outputs = outputs.to_tuple() + if not isinstance(outputs, (list, tuple)): + outputs = (outputs,) + + # Generate input names & axes + input_vars = list(tokens.keys()) + input_dynamic_axes = {k: build_shape_dict(k, v, True, seq_len) for k, v in tokens.items()} + + # flatten potentially grouped outputs (past for gpt2, attentions) + outputs_flat = [] + for output in outputs: + if isinstance(output, (tuple, list)): + outputs_flat.extend(output) + else: + outputs_flat.append(output) + + # Generate output names & axes + output_names = [f"output_{i}" for i in range(len(outputs_flat))] + output_dynamic_axes = {k: build_shape_dict(k, v, False, seq_len) for k, v in zip(output_names, outputs_flat)} + + # Create the aggregated axes representation + dynamic_axes = dict(input_dynamic_axes, **output_dynamic_axes) + return input_vars, output_names, dynamic_axes, tokens + + +def load_graph_from_args( + pipeline_name: str, framework: str, model: str, tokenizer: Optional[str] = None, **models_kwargs +) -> Pipeline: + """ + Convert the set of arguments provided through the CLI to an actual pipeline reference (tokenizer + model + + Args: + pipeline_name: The kind of pipeline to use (ner, question-answering, etc.) + framework: The actual model to convert the pipeline from ("pt" or "tf") + model: The model name which will be loaded by the pipeline + tokenizer: The tokenizer name which will be loaded by the pipeline, default to the model's value + + Returns: Pipeline object + + """ + # If no tokenizer provided + if tokenizer is None: + tokenizer = model + + # Check the wanted framework is available + if framework == "pt" and not is_torch_available(): + raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.") + if framework == "tf" and not is_tf_available(): + raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.") + + print(f"Loading pipeline (model: {model}, tokenizer: {tokenizer})") + + # Allocate tokenizer and model + return pipeline(pipeline_name, model=model, tokenizer=tokenizer, framework=framework, model_kwargs=models_kwargs) + + +def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format: bool): + """ + Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR + + Args: + nlp: The pipeline to be exported + opset: The actual version of the ONNX operator set to use + output: Path where will be stored the generated ONNX model + use_external_format: Split the model definition from its parameters to allow model bigger than 2GB + + Returns: + + """ + if not is_torch_available(): + raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.") + + import torch + from torch.onnx import export + + print(f"Using framework PyTorch: {torch.__version__}") + + with torch.no_grad(): + input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt") + ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names) + + export( + nlp.model, + model_args, + f=output.as_posix(), + input_names=ordered_input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + do_constant_folding=True, + opset_version=opset, + ) + + +def convert_tensorflow(nlp: Pipeline, opset: int, output: Path): + """ + Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR) + + Args: + nlp: The pipeline to be exported + opset: The actual version of the ONNX operator set to use + output: Path where will be stored the generated ONNX model + + Notes: TensorFlow cannot export model bigger than 2GB due to internal constraint from TensorFlow + + """ + if not is_tf_available(): + raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.") + + print("/!\\ Please note TensorFlow doesn't support exporting model > 2Gb /!\\") + + try: + import tensorflow as tf + import tf2onnx + from tf2onnx import __version__ as t2ov + + print(f"Using framework TensorFlow: {tf.version.VERSION}, tf2onnx: {t2ov}") + + # Build + input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf") + + # Forward + nlp.model.predict(tokens.data) + input_signature = [tf.TensorSpec.from_tensor(tensor, name=key) for key, tensor in tokens.items()] + model_proto, _ = tf2onnx.convert.from_keras( + nlp.model, input_signature, opset=opset, output_path=output.as_posix() + ) + + except ImportError as e: + raise Exception( + f"Cannot import {e.name} required to convert TF model to ONNX. Please install {e.name} first. {e}" + ) + + +def convert( + framework: str, + model: str, + output: Path, + opset: int, + tokenizer: Optional[str] = None, + use_external_format: bool = False, + pipeline_name: str = "feature-extraction", + **model_kwargs, +): + """ + Convert the pipeline object to the ONNX Intermediate Representation (IR) format + + Args: + framework: The framework the pipeline is backed by ("pt" or "tf") + model: The name of the model to load for the pipeline + output: The path where the ONNX graph will be stored + opset: The actual version of the ONNX operator set to use + tokenizer: The name of the model to load for the pipeline, default to the model's name if not provided + use_external_format: + Split the model definition from its parameters to allow model bigger than 2GB (PyTorch only) + pipeline_name: The kind of pipeline to instantiate (ner, question-answering, etc.) + model_kwargs: Keyword arguments to be forwarded to the model constructor + + Returns: + + """ + warnings.warn( + "The `transformers.convert_graph_to_onnx` package is deprecated and will be removed in version 5 of" + " Transformers", + FutureWarning, + ) + print(f"ONNX opset version set to: {opset}") + + # Load the pipeline + nlp = load_graph_from_args(pipeline_name, framework, model, tokenizer, **model_kwargs) + + if not output.parent.exists(): + print(f"Creating folder {output.parent}") + makedirs(output.parent.as_posix()) + elif len(listdir(output.parent.as_posix())) > 0: + raise Exception(f"Folder {output.parent.as_posix()} is not empty, aborting conversion") + + # Export the graph + if framework == "pt": + convert_pytorch(nlp, opset, output, use_external_format) + else: + convert_tensorflow(nlp, opset, output) + + +def optimize(onnx_model_path: Path) -> Path: + """ + Load the model at the specified path and let onnxruntime look at transformations on the graph to enable all the + optimizations possible + + Args: + onnx_model_path: filepath where the model binary description is stored + + Returns: Path where the optimized model binary description has been saved + + """ + from onnxruntime import InferenceSession, SessionOptions + + # Generate model name with suffix "optimized" + opt_model_path = generate_identified_filename(onnx_model_path, "-optimized") + sess_option = SessionOptions() + sess_option.optimized_model_filepath = opt_model_path.as_posix() + _ = InferenceSession(onnx_model_path.as_posix(), sess_option) + + print(f"Optimized model has been written at {opt_model_path}: \N{HEAVY CHECK MARK}") + print("/!\\ Optimized model contains hardware specific operators which might not be portable. /!\\") + + return opt_model_path + + +def quantize(onnx_model_path: Path) -> Path: + """ + Quantize the weights of the model from float32 to in8 to allow very efficient inference on modern CPU + + Args: + onnx_model_path: Path to location the exported ONNX model is stored + + Returns: The Path generated for the quantized + """ + import onnx + import onnxruntime + from onnx.onnx_pb import ModelProto + from onnxruntime.quantization import QuantizationMode + from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer + from onnxruntime.quantization.registry import IntegerOpsRegistry + + # Load the ONNX model + onnx_model = onnx.load(onnx_model_path.as_posix()) + + if parse(onnx.__version__) < parse("1.5.0"): + print( + "Models larger than 2GB will fail to quantize due to protobuf constraint.\n" + "Please upgrade to onnxruntime >= 1.5.0." + ) + + # Copy it + copy_model = ModelProto() + copy_model.CopyFrom(onnx_model) + + # Construct quantizer + # onnxruntime renamed input_qType to activation_qType in v1.13.1, so we + # check the onnxruntime version to ensure backward compatibility. + # See also: https://github.com/microsoft/onnxruntime/pull/12873 + if parse(onnxruntime.__version__) < parse("1.13.1"): + quantizer = ONNXQuantizer( + model=copy_model, + per_channel=False, + reduce_range=False, + mode=QuantizationMode.IntegerOps, + static=False, + weight_qType=True, + input_qType=False, + tensors_range=None, + nodes_to_quantize=None, + nodes_to_exclude=None, + op_types_to_quantize=list(IntegerOpsRegistry), + ) + else: + quantizer = ONNXQuantizer( + model=copy_model, + per_channel=False, + reduce_range=False, + mode=QuantizationMode.IntegerOps, + static=False, + weight_qType=True, + activation_qType=False, + tensors_range=None, + nodes_to_quantize=None, + nodes_to_exclude=None, + op_types_to_quantize=list(IntegerOpsRegistry), + ) + + # Quantize and export + quantizer.quantize_model() + + # Append "-quantized" at the end of the model's name + quantized_model_path = generate_identified_filename(onnx_model_path, "-quantized") + + # Save model + print(f"Quantized model has been written at {quantized_model_path}: \N{HEAVY CHECK MARK}") + onnx.save_model(quantizer.model.model, quantized_model_path.as_posix()) + + return quantized_model_path + + +def verify(path: Path): + from onnxruntime import InferenceSession, SessionOptions + from onnxruntime.capi.onnxruntime_pybind11_state import RuntimeException + + print(f"Checking ONNX model loading from: {path} ...") + try: + onnx_options = SessionOptions() + _ = InferenceSession(path.as_posix(), onnx_options, providers=["CPUExecutionProvider"]) + print(f"Model {path} correctly loaded: \N{HEAVY CHECK MARK}") + except RuntimeException as re: + print(f"Error while loading the model {re}: \N{HEAVY BALLOT X}") + + +if __name__ == "__main__": + parser = OnnxConverterArgumentParser() + args = parser.parse_args() + + # Make sure output is absolute path + args.output = Path(args.output).absolute() + + try: + print("\n====== Converting model to ONNX ======") + # Convert + convert( + args.framework, + args.model, + args.output, + args.opset, + args.tokenizer, + args.use_external_format, + args.pipeline, + ) + + if args.quantize: + # Ensure requirements for quantization on onnxruntime is met + check_onnxruntime_requirements(ORT_QUANTIZE_MINIMUM_VERSION) + + # onnxruntime optimizations doesn't provide the same level of performances on TensorFlow than PyTorch + if args.framework == "tf": + print( + "\t Using TensorFlow might not provide the same optimization level compared to PyTorch.\n" + "\t For TensorFlow users you can try optimizing the model directly through onnxruntime_tools.\n" + "\t For more information, please refer to the onnxruntime documentation:\n" + "\t\thttps://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers\n" + ) + + print("\n====== Optimizing ONNX model ======") + + # Quantization works best when using the optimized version of the model + args.optimized_output = optimize(args.output) + + # Do the quantization on the right graph + args.quantized_output = quantize(args.optimized_output) + + # And verify + if args.check_loading: + print("\n====== Check exported ONNX model(s) ======") + verify(args.output) + + if hasattr(args, "optimized_output"): + verify(args.optimized_output) + + if hasattr(args, "quantized_output"): + verify(args.quantized_output) + + except Exception as e: + print(f"Error while converting the model: {e}") + exit(1) diff --git a/convert_pytorch_checkpoint_to_tf2.py b/convert_pytorch_checkpoint_to_tf2.py new file mode 100644 index 0000000000000000000000000000000000000000..c3431ad5b2e0ac6e0969e24b3a00922edb382116 --- /dev/null +++ b/convert_pytorch_checkpoint_to_tf2.py @@ -0,0 +1,446 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert pytorch checkpoints to TensorFlow""" + +import argparse +import os + +from . import ( + AlbertConfig, + BartConfig, + BertConfig, + CamembertConfig, + CTRLConfig, + DistilBertConfig, + DPRConfig, + ElectraConfig, + FlaubertConfig, + GPT2Config, + LayoutLMConfig, + LxmertConfig, + OpenAIGPTConfig, + RobertaConfig, + T5Config, + TFAlbertForPreTraining, + TFBartForConditionalGeneration, + TFBartForSequenceClassification, + TFBertForPreTraining, + TFBertForQuestionAnswering, + TFBertForSequenceClassification, + TFCamembertForMaskedLM, + TFCTRLLMHeadModel, + TFDistilBertForMaskedLM, + TFDistilBertForQuestionAnswering, + TFDPRContextEncoder, + TFDPRQuestionEncoder, + TFDPRReader, + TFElectraForPreTraining, + TFFlaubertWithLMHeadModel, + TFGPT2LMHeadModel, + TFLayoutLMForMaskedLM, + TFLxmertForPreTraining, + TFLxmertVisualFeatureEncoder, + TFOpenAIGPTLMHeadModel, + TFRobertaForCausalLM, + TFRobertaForMaskedLM, + TFRobertaForSequenceClassification, + TFT5ForConditionalGeneration, + TFTransfoXLLMHeadModel, + TFWav2Vec2Model, + TFXLMRobertaForMaskedLM, + TFXLMWithLMHeadModel, + TFXLNetLMHeadModel, + TransfoXLConfig, + Wav2Vec2Config, + Wav2Vec2Model, + XLMConfig, + XLMRobertaConfig, + XLNetConfig, + is_torch_available, + load_pytorch_checkpoint_in_tf2_model, +) +from .utils import CONFIG_NAME, WEIGHTS_NAME, cached_file, logging + + +if is_torch_available(): + import numpy as np + import torch + + from . import ( + AlbertForPreTraining, + BartForConditionalGeneration, + BertForPreTraining, + BertForQuestionAnswering, + BertForSequenceClassification, + CamembertForMaskedLM, + CTRLLMHeadModel, + DistilBertForMaskedLM, + DistilBertForQuestionAnswering, + DPRContextEncoder, + DPRQuestionEncoder, + DPRReader, + ElectraForPreTraining, + FlaubertWithLMHeadModel, + GPT2LMHeadModel, + LayoutLMForMaskedLM, + LxmertForPreTraining, + LxmertVisualFeatureEncoder, + OpenAIGPTLMHeadModel, + RobertaForMaskedLM, + RobertaForSequenceClassification, + T5ForConditionalGeneration, + TransfoXLLMHeadModel, + XLMRobertaForMaskedLM, + XLMWithLMHeadModel, + XLNetLMHeadModel, + ) + + +logging.set_verbosity_info() + +MODEL_CLASSES = { + "bart": ( + BartConfig, + TFBartForConditionalGeneration, + TFBartForSequenceClassification, + BartForConditionalGeneration, + ), + "bert": ( + BertConfig, + TFBertForPreTraining, + BertForPreTraining, + ), + "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad": ( + BertConfig, + TFBertForQuestionAnswering, + BertForQuestionAnswering, + ), + "google-bert/bert-large-cased-whole-word-masking-finetuned-squad": ( + BertConfig, + TFBertForQuestionAnswering, + BertForQuestionAnswering, + ), + "google-bert/bert-base-cased-finetuned-mrpc": ( + BertConfig, + TFBertForSequenceClassification, + BertForSequenceClassification, + ), + "dpr": ( + DPRConfig, + TFDPRQuestionEncoder, + TFDPRContextEncoder, + TFDPRReader, + DPRQuestionEncoder, + DPRContextEncoder, + DPRReader, + ), + "openai-community/gpt2": ( + GPT2Config, + TFGPT2LMHeadModel, + GPT2LMHeadModel, + ), + "xlnet": ( + XLNetConfig, + TFXLNetLMHeadModel, + XLNetLMHeadModel, + ), + "xlm": ( + XLMConfig, + TFXLMWithLMHeadModel, + XLMWithLMHeadModel, + ), + "xlm-roberta": ( + XLMRobertaConfig, + TFXLMRobertaForMaskedLM, + XLMRobertaForMaskedLM, + ), + "transfo-xl": ( + TransfoXLConfig, + TFTransfoXLLMHeadModel, + TransfoXLLMHeadModel, + ), + "openai-community/openai-gpt": ( + OpenAIGPTConfig, + TFOpenAIGPTLMHeadModel, + OpenAIGPTLMHeadModel, + ), + "roberta": ( + RobertaConfig, + TFRobertaForCausalLM, + TFRobertaForMaskedLM, + RobertaForMaskedLM, + ), + "layoutlm": ( + LayoutLMConfig, + TFLayoutLMForMaskedLM, + LayoutLMForMaskedLM, + ), + "FacebookAI/roberta-large-mnli": ( + RobertaConfig, + TFRobertaForSequenceClassification, + RobertaForSequenceClassification, + ), + "camembert": ( + CamembertConfig, + TFCamembertForMaskedLM, + CamembertForMaskedLM, + ), + "flaubert": ( + FlaubertConfig, + TFFlaubertWithLMHeadModel, + FlaubertWithLMHeadModel, + ), + "distilbert": ( + DistilBertConfig, + TFDistilBertForMaskedLM, + DistilBertForMaskedLM, + ), + "distilbert-base-distilled-squad": ( + DistilBertConfig, + TFDistilBertForQuestionAnswering, + DistilBertForQuestionAnswering, + ), + "lxmert": ( + LxmertConfig, + TFLxmertForPreTraining, + LxmertForPreTraining, + ), + "lxmert-visual-feature-encoder": ( + LxmertConfig, + TFLxmertVisualFeatureEncoder, + LxmertVisualFeatureEncoder, + ), + "Salesforce/ctrl": ( + CTRLConfig, + TFCTRLLMHeadModel, + CTRLLMHeadModel, + ), + "albert": ( + AlbertConfig, + TFAlbertForPreTraining, + AlbertForPreTraining, + ), + "t5": ( + T5Config, + TFT5ForConditionalGeneration, + T5ForConditionalGeneration, + ), + "electra": ( + ElectraConfig, + TFElectraForPreTraining, + ElectraForPreTraining, + ), + "wav2vec2": ( + Wav2Vec2Config, + TFWav2Vec2Model, + Wav2Vec2Model, + ), +} + + +def convert_pt_checkpoint_to_tf( + model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True +): + if model_type not in MODEL_CLASSES: + raise ValueError(f"Unrecognized model type, should be one of {list(MODEL_CLASSES.keys())}.") + + config_class, model_class, pt_model_class, aws_config_map = MODEL_CLASSES[model_type] + + # Initialise TF model + if config_file in aws_config_map: + config_file = cached_file(config_file, CONFIG_NAME, force_download=not use_cached_models) + config = config_class.from_json_file(config_file) + config.output_hidden_states = True + config.output_attentions = True + print(f"Building TensorFlow model from configuration: {config}") + tf_model = model_class(config) + + # Load weights from tf checkpoint + if pytorch_checkpoint_path in aws_config_map.keys(): + pytorch_checkpoint_path = cached_file( + pytorch_checkpoint_path, WEIGHTS_NAME, force_download=not use_cached_models + ) + # Load PyTorch checkpoint in tf2 model: + tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path) + + if compare_with_pt_model: + tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network + + weights_only_kwarg = {"weights_only": True} + state_dict = torch.load( + pytorch_checkpoint_path, + map_location="cpu", + **weights_only_kwarg, + ) + pt_model = pt_model_class.from_pretrained( + pretrained_model_name_or_path=None, config=config, state_dict=state_dict + ) + + with torch.no_grad(): + pto = pt_model(**pt_model.dummy_inputs) + + np_pt = pto[0].numpy() + np_tf = tfo[0].numpy() + diff = np.amax(np.abs(np_pt - np_tf)) + print(f"Max absolute difference between models outputs {diff}") + assert diff <= 2e-2, f"Error, model absolute difference is >2e-2: {diff}" + + # Save pytorch-model + print(f"Save TensorFlow model to {tf_dump_path}") + tf_model.save_weights(tf_dump_path, save_format="h5") + + +def convert_all_pt_checkpoints_to_tf( + args_model_type, + tf_dump_path, + model_shortcut_names_or_path=None, + config_shortcut_names_or_path=None, + compare_with_pt_model=False, + use_cached_models=False, + remove_cached_files=False, + only_convert_finetuned_models=False, +): + if args_model_type is None: + model_types = list(MODEL_CLASSES.keys()) + else: + model_types = [args_model_type] + + for j, model_type in enumerate(model_types, start=1): + print("=" * 100) + print(f" Converting model type {j}/{len(model_types)}: {model_type}") + print("=" * 100) + if model_type not in MODEL_CLASSES: + raise ValueError(f"Unrecognized model type {model_type}, should be one of {list(MODEL_CLASSES.keys())}.") + + config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type] + + if model_shortcut_names_or_path is None: + model_shortcut_names_or_path = list(aws_model_maps.keys()) + if config_shortcut_names_or_path is None: + config_shortcut_names_or_path = model_shortcut_names_or_path + + for i, (model_shortcut_name, config_shortcut_name) in enumerate( + zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1 + ): + print("-" * 100) + if "-squad" in model_shortcut_name or "-mrpc" in model_shortcut_name or "-mnli" in model_shortcut_name: + if not only_convert_finetuned_models: + print(f" Skipping finetuned checkpoint {model_shortcut_name}") + continue + model_type = model_shortcut_name + elif only_convert_finetuned_models: + print(f" Skipping not finetuned checkpoint {model_shortcut_name}") + continue + print( + f" Converting checkpoint {i}/{len(aws_config_map)}: {model_shortcut_name} - model_type {model_type}" + ) + print("-" * 100) + + if config_shortcut_name in aws_config_map: + config_file = cached_file(config_shortcut_name, CONFIG_NAME, force_download=not use_cached_models) + else: + config_file = config_shortcut_name + + if model_shortcut_name in aws_model_maps: + model_file = cached_file(model_shortcut_name, WEIGHTS_NAME, force_download=not use_cached_models) + else: + model_file = model_shortcut_name + + if os.path.isfile(model_shortcut_name): + model_shortcut_name = "converted_model" + + convert_pt_checkpoint_to_tf( + model_type=model_type, + pytorch_checkpoint_path=model_file, + config_file=config_file, + tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + "-tf_model.h5"), + compare_with_pt_model=compare_with_pt_model, + ) + if remove_cached_files: + os.remove(config_file) + os.remove(model_file) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_dump_path", default=None, type=str, required=True, help="Path to the output Tensorflow dump file." + ) + parser.add_argument( + "--model_type", + default=None, + type=str, + help=( + f"Model type selected in the list of {list(MODEL_CLASSES.keys())}. If not given, will download and " + "convert all the models from AWS." + ), + ) + parser.add_argument( + "--pytorch_checkpoint_path", + default=None, + type=str, + help=( + "Path to the PyTorch checkpoint path or shortcut name to download from AWS. " + "If not given, will download and convert all the checkpoints from AWS." + ), + ) + parser.add_argument( + "--config_file", + default=None, + type=str, + help=( + "The config json file corresponding to the pre-trained model. \n" + "This specifies the model architecture. If not given and " + "--pytorch_checkpoint_path is not given or is a shortcut name " + "use the configuration associated to the shortcut name on the AWS" + ), + ) + parser.add_argument( + "--compare_with_pt_model", action="store_true", help="Compare Tensorflow and PyTorch model predictions." + ) + parser.add_argument( + "--use_cached_models", + action="store_true", + help="Use cached models if possible instead of updating to latest checkpoint versions.", + ) + parser.add_argument( + "--remove_cached_files", + action="store_true", + help="Remove pytorch models after conversion (save memory when converting in batches).", + ) + parser.add_argument("--only_convert_finetuned_models", action="store_true", help="Only convert finetuned models.") + args = parser.parse_args() + + # if args.pytorch_checkpoint_path is not None: + # convert_pt_checkpoint_to_tf(args.model_type.lower(), + # args.pytorch_checkpoint_path, + # args.config_file if args.config_file is not None else args.pytorch_checkpoint_path, + # args.tf_dump_path, + # compare_with_pt_model=args.compare_with_pt_model, + # use_cached_models=args.use_cached_models) + # else: + convert_all_pt_checkpoints_to_tf( + args.model_type.lower() if args.model_type is not None else None, + args.tf_dump_path, + model_shortcut_names_or_path=[args.pytorch_checkpoint_path] + if args.pytorch_checkpoint_path is not None + else None, + config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None, + compare_with_pt_model=args.compare_with_pt_model, + use_cached_models=args.use_cached_models, + remove_cached_files=args.remove_cached_files, + only_convert_finetuned_models=args.only_convert_finetuned_models, + ) diff --git a/convert_slow_tokenizer.py b/convert_slow_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..030e3a666436308a5bcbf199e466934a50258767 --- /dev/null +++ b/convert_slow_tokenizer.py @@ -0,0 +1,1642 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Utilities to convert slow tokenizers in their fast tokenizers counterparts. + +All the conversions are grouped here to gather SentencePiece dependencies outside of the fast tokenizers files and +allow to make our dependency on SentencePiece optional. +""" + +import warnings +from typing import Dict, List, Tuple + +from packaging import version +from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors +from tokenizers.models import BPE, Unigram, WordPiece + +from .utils import is_protobuf_available, is_sentencepiece_available, logging, requires_backends +from .utils.import_utils import PROTOBUF_IMPORT_ERROR + + +logger = logging.get_logger(__name__) + + +def import_protobuf(error_message=""): + if is_sentencepiece_available(): + from sentencepiece import sentencepiece_model_pb2 + + return sentencepiece_model_pb2 + if is_protobuf_available(): + import google.protobuf + + if version.parse(google.protobuf.__version__) < version.parse("4.0.0"): + from transformers.utils import sentencepiece_model_pb2 + else: + from transformers.utils import sentencepiece_model_pb2_new as sentencepiece_model_pb2 + return sentencepiece_model_pb2 + else: + raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message)) + + +def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str: + if add_prefix_space: + prepend_scheme = "always" + if not getattr(original_tokenizer, "legacy", True): + prepend_scheme = "first" + else: + prepend_scheme = "never" + return prepend_scheme + + +def generate_merges(vocab, vocab_scores): + reverse = vocab_scores is not None + vocab_scores = dict(vocab_scores) if reverse else vocab + + merges = [] + for merge, piece_score in vocab_scores.items(): + local = [] + for index in range(1, len(merge)): + piece_l, piece_r = merge[:index], merge[index:] + if piece_l in vocab and piece_r in vocab: + local.append((piece_l, piece_r, piece_score)) + local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]])) + merges.extend(local) + + merges = sorted(merges, key=lambda val: (val[2], len(val[0]), len(val[1])), reverse=reverse) + merges = [(val[0], val[1]) for val in merges] + return merges + + +class SentencePieceExtractor: + """ + Extractor implementation for SentencePiece trained models. https://github.com/google/sentencepiece + """ + + def __init__(self, model: str): + requires_backends(self, "sentencepiece") + from sentencepiece import SentencePieceProcessor + + self.sp = SentencePieceProcessor() + self.sp.Load(model) + + def extract(self, vocab_scores=None) -> Tuple[Dict[str, int], List[Tuple]]: + """ + By default will return vocab and merges with respect to their order, by sending `vocab_scores` we're going to + order the merges with respect to the piece scores instead. + """ + sp = self.sp + vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())} + + merges = generate_merges(vocab, vocab_scores) + + return vocab, merges + + +class GemmaSentencePieceExtractor(SentencePieceExtractor): + def extract(self, vocab_scores=None) -> Tuple[Dict[str, int], List[Tuple]]: + """ + By default will return vocab and merges with respect to their order, by sending `vocab_scores` we're going to + order the merges with respect to the piece scores instead. + """ + sp = self.sp + vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())} + + # there is a missing token in the vocab. We have to do this to support merges + # "<0x09>" is the bytefallback for `\t` + vocab["\t"] = vocab.get("<0x09>") + + merges = generate_merges(vocab, vocab_scores) + return vocab, merges + + +def check_number_comma(piece: str) -> bool: + return len(piece) < 2 or piece[-1] != "," or not piece[-2].isdigit() + + +class Converter: + def __init__(self, original_tokenizer): + self.original_tokenizer = original_tokenizer + + def converted(self) -> Tokenizer: + raise NotImplementedError() + + +class BertConverter(Converter): + def converted(self) -> Tokenizer: + vocab = self.original_tokenizer.vocab + tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token))) + + tokenize_chinese_chars = False + strip_accents = False + do_lower_case = False + if hasattr(self.original_tokenizer, "basic_tokenizer"): + tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars + strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents + do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case + + tokenizer.normalizer = normalizers.BertNormalizer( + clean_text=True, + handle_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + lowercase=do_lower_case, + ) + tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer() + + cls = str(self.original_tokenizer.cls_token) + sep = str(self.original_tokenizer.sep_token) + cls_token_id = self.original_tokenizer.cls_token_id + sep_token_id = self.original_tokenizer.sep_token_id + + tokenizer.post_processor = processors.TemplateProcessing( + single=f"{cls}:0 $A:0 {sep}:0", + pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1", + special_tokens=[ + (cls, cls_token_id), + (sep, sep_token_id), + ], + ) + tokenizer.decoder = decoders.WordPiece(prefix="##") + + return tokenizer + + +class SplinterConverter(Converter): + def converted(self) -> Tokenizer: + vocab = self.original_tokenizer.vocab + tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token))) + + tokenize_chinese_chars = False + strip_accents = False + do_lower_case = False + if hasattr(self.original_tokenizer, "basic_tokenizer"): + tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars + strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents + do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case + + tokenizer.normalizer = normalizers.BertNormalizer( + clean_text=True, + handle_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + lowercase=do_lower_case, + ) + tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer() + + cls = str(self.original_tokenizer.cls_token) + sep = str(self.original_tokenizer.sep_token) + question = str(self.original_tokenizer.question_token) + dot = "." + cls_token_id = self.original_tokenizer.cls_token_id + sep_token_id = self.original_tokenizer.sep_token_id + question_token_id = self.original_tokenizer.question_token_id + dot_token_id = self.original_tokenizer.convert_tokens_to_ids(".") + + if self.original_tokenizer.padding_side == "right": + pair = f"{cls}:0 $A:0 {question} {dot} {sep}:0 $B:1 {sep}:1" + else: + pair = f"{cls}:0 $A:0 {sep}:0 $B:1 {question} {dot} {sep}:1" + + tokenizer.post_processor = processors.TemplateProcessing( + single=f"{cls}:0 $A:0 {sep}:0", + pair=pair, + special_tokens=[ + (cls, cls_token_id), + (sep, sep_token_id), + (question, question_token_id), + (dot, dot_token_id), + ], + ) + tokenizer.decoder = decoders.WordPiece(prefix="##") + + return tokenizer + + +class FunnelConverter(Converter): + def converted(self) -> Tokenizer: + vocab = self.original_tokenizer.vocab + tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token))) + + tokenize_chinese_chars = False + strip_accents = False + do_lower_case = False + if hasattr(self.original_tokenizer, "basic_tokenizer"): + tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars + strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents + do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case + + tokenizer.normalizer = normalizers.BertNormalizer( + clean_text=True, + handle_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + lowercase=do_lower_case, + ) + tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer() + + cls = str(self.original_tokenizer.cls_token) + sep = str(self.original_tokenizer.sep_token) + cls_token_id = self.original_tokenizer.cls_token_id + sep_token_id = self.original_tokenizer.sep_token_id + + tokenizer.post_processor = processors.TemplateProcessing( + single=f"{cls}:2 $A:0 {sep}:0", # token_type_id is 2 for Funnel transformer + pair=f"{cls}:2 $A:0 {sep}:0 $B:1 {sep}:1", + special_tokens=[ + (cls, cls_token_id), + (sep, sep_token_id), + ], + ) + tokenizer.decoder = decoders.WordPiece(prefix="##") + + return tokenizer + + +class MPNetConverter(Converter): + def converted(self) -> Tokenizer: + vocab = self.original_tokenizer.vocab + tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token))) + + tokenize_chinese_chars = False + strip_accents = False + do_lower_case = False + if hasattr(self.original_tokenizer, "basic_tokenizer"): + tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars + strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents + do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case + + tokenizer.normalizer = normalizers.BertNormalizer( + clean_text=True, + handle_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + lowercase=do_lower_case, + ) + tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer() + + cls = str(self.original_tokenizer.cls_token) + sep = str(self.original_tokenizer.sep_token) + cls_token_id = self.original_tokenizer.cls_token_id + sep_token_id = self.original_tokenizer.sep_token_id + + tokenizer.post_processor = processors.TemplateProcessing( + single=f"{cls}:0 $A:0 {sep}:0", + pair=f"{cls}:0 $A:0 {sep}:0 {sep}:0 $B:1 {sep}:1", # MPNet uses two [SEP] tokens + special_tokens=[ + (cls, cls_token_id), + (sep, sep_token_id), + ], + ) + tokenizer.decoder = decoders.WordPiece(prefix="##") + + return tokenizer + + +class OpenAIGPTConverter(Converter): + def converted(self) -> Tokenizer: + vocab = self.original_tokenizer.encoder + merges = list(self.original_tokenizer.bpe_ranks.keys()) + unk_token = self.original_tokenizer.unk_token + + tokenizer = Tokenizer( + BPE( + vocab=vocab, + merges=merges, + dropout=None, + unk_token=str(unk_token), + end_of_word_suffix="", + fuse_unk=False, + ) + ) + + if tokenizer.token_to_id(str(unk_token)) is not None: + tokenizer.add_special_tokens([str(unk_token)]) + + tokenizer.normalizer = normalizers.BertNormalizer(lowercase=True) + tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer() + tokenizer.decoder = decoders.BPEDecoder(suffix="") + + return tokenizer + + +class GPT2Converter(Converter): + def converted(self, vocab: Dict[str, int] = None, merges: List[Tuple[str, str]] = None) -> Tokenizer: + if not vocab: + vocab = self.original_tokenizer.encoder + if not merges: + merges = list(self.original_tokenizer.bpe_ranks) + + tokenizer = Tokenizer( + BPE( + vocab=vocab, + merges=merges, + dropout=None, + continuing_subword_prefix="", + end_of_word_suffix="", + fuse_unk=False, + ) + ) + + add_prefix_space = getattr(self.original_tokenizer, "add_prefix_space", False) + tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space) + tokenizer.decoder = decoders.ByteLevel() + if getattr(self.original_tokenizer, "add_bos_token", False): + bos = self.original_tokenizer.bos_token + bos_token_id = self.original_tokenizer.bos_token_id + tokenizer.post_processor = processors.TemplateProcessing( + single=f"{bos}:0 $A:0", + pair=f"{bos}:0 $A:0 $B:1", + special_tokens=[ + (bos, bos_token_id), + ], + ) + else: + # XXX trim_offsets=False actually means this post_processor doesn't + # really do anything. + tokenizer.post_processor = processors.ByteLevel(trim_offsets=False) + return tokenizer + + +class HerbertConverter(Converter): + def converted(self) -> Tokenizer: + tokenizer_info_str = "#version:" + token_suffix = "" + + vocab = self.original_tokenizer.encoder + merges = list(self.original_tokenizer.bpe_ranks.keys()) + if tokenizer_info_str in merges[0][0]: + merges = merges[1:] + + tokenizer = Tokenizer( + BPE( + vocab, + merges, + dropout=None, + unk_token=self.original_tokenizer.unk_token, + end_of_word_suffix=token_suffix, + ) + ) + + tokenizer.normalizer = normalizers.BertNormalizer(lowercase=False, strip_accents=False) + tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer() + tokenizer.decoder = decoders.BPEDecoder(suffix=token_suffix) + tokenizer.post_processor = processors.BertProcessing( + sep=(self.original_tokenizer.sep_token, self.original_tokenizer.sep_token_id), + cls=(self.original_tokenizer.cls_token, self.original_tokenizer.cls_token_id), + ) + + return tokenizer + + +class Qwen2Converter(Converter): + def converted(self, vocab: Dict[str, int] = None, merges: List[Tuple[str, str]] = None) -> Tokenizer: + if not vocab: + vocab = self.original_tokenizer.encoder + if not merges: + merges = list(self.original_tokenizer.bpe_ranks.keys()) + + tokenizer = Tokenizer( + BPE( + vocab=vocab, + merges=merges, + dropout=None, + unk_token=None, + continuing_subword_prefix="", + end_of_word_suffix="", + fuse_unk=False, + byte_fallback=False, + ) + ) + + tokenizer.normalizer = normalizers.NFC() + + tokenizer.pre_tokenizer = pre_tokenizers.Sequence( + [ + pre_tokenizers.Split( + Regex( + r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" + ), + behavior="isolated", + invert=False, + ), + pre_tokenizers.ByteLevel( + add_prefix_space=getattr(self.original_tokenizer, "add_prefix_space", False), + use_regex=False, + ), + ] + ) + + tokenizer.decoder = decoders.ByteLevel() + tokenizer.post_processor = processors.ByteLevel(trim_offsets=False) + + return tokenizer + + +class RobertaConverter(Converter): + def converted(self) -> Tokenizer: + ot = self.original_tokenizer + vocab = ot.encoder + merges = list(ot.bpe_ranks.keys()) + + tokenizer = Tokenizer( + BPE( + vocab=vocab, + merges=merges, + dropout=None, + continuing_subword_prefix="", + end_of_word_suffix="", + fuse_unk=False, + ) + ) + + tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space) + tokenizer.decoder = decoders.ByteLevel() + tokenizer.post_processor = processors.RobertaProcessing( + sep=(ot.sep_token, ot.sep_token_id), + cls=(ot.cls_token, ot.cls_token_id), + add_prefix_space=ot.add_prefix_space, + trim_offsets=True, # True by default on Roberta (historical) + ) + + return tokenizer + + +class RoFormerConverter(Converter): + def converted(self) -> Tokenizer: + from .models.roformer.tokenization_utils import JiebaPreTokenizer + + vocab = self.original_tokenizer.vocab + tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token))) + + strip_accents = False + do_lower_case = False + if hasattr(self.original_tokenizer, "basic_tokenizer"): + strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents + do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case + + tokenizer.normalizer = normalizers.BertNormalizer( + clean_text=True, + handle_chinese_chars=False, + strip_accents=strip_accents, + lowercase=do_lower_case, + ) + tokenizer.pre_tokenizer = pre_tokenizers.PreTokenizer.custom(JiebaPreTokenizer(vocab)) + + cls = str(self.original_tokenizer.cls_token) + sep = str(self.original_tokenizer.sep_token) + cls_token_id = self.original_tokenizer.cls_token_id + sep_token_id = self.original_tokenizer.sep_token_id + + tokenizer.post_processor = processors.TemplateProcessing( + single=f"{cls}:0 $A:0 {sep}:0", + pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1", + special_tokens=[ + (cls, cls_token_id), + (sep, sep_token_id), + ], + ) + tokenizer.decoder = decoders.WordPiece(prefix="##") + + return tokenizer + + +class DebertaConverter(Converter): + def converted(self) -> Tokenizer: + ot = self.original_tokenizer + vocab = ot.encoder + merges = list(ot.bpe_ranks.keys()) + + tokenizer = Tokenizer( + BPE( + vocab=vocab, + merges=merges, + dropout=None, + continuing_subword_prefix="", + end_of_word_suffix="", + fuse_unk=False, + ) + ) + + tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space) + tokenizer.decoder = decoders.ByteLevel() + tokenizer.post_processor = processors.TemplateProcessing( + single="[CLS]:0 $A:0 [SEP]:0", + pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1", + special_tokens=[ + ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")), + ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")), + ], + ) + + return tokenizer + + +class SpmConverter(Converter): + handle_byte_fallback = False + SpmExtractor = SentencePieceExtractor + special_tokens = {} + + def __init__(self, *args): + requires_backends(self, "protobuf") + + super().__init__(*args) + + # from .utils import sentencepiece_model_pb2 as model_pb2 + model_pb2 = import_protobuf() + + m = model_pb2.ModelProto() + with open(self.original_tokenizer.vocab_file, "rb") as f: + m.ParseFromString(f.read()) + self.proto = m + + if self.proto.trainer_spec.byte_fallback and not self.handle_byte_fallback: + warnings.warn( + "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option" + " which is not implemented in the fast tokenizers. In practice this means that the fast version of the" + " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these " + "unknown tokens into a sequence of byte tokens matching the original piece of text." + ) + + def vocab(self, proto): + return [(piece.piece, piece.score) for piece in proto.pieces] + + def unk_id(self, proto): + return proto.trainer_spec.unk_id + + def tokenizer(self, proto): + model_type = proto.trainer_spec.model_type + vocab_scores = self.vocab(proto) + + if model_type == 1: + tokenizer = Tokenizer( + Unigram( + vocab_scores, + unk_id=self.unk_id(proto), + byte_fallback=self.handle_byte_fallback, + ) + ) + + elif model_type == 2: + _, merges = self.SpmExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores) + bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)} + tokenizer = Tokenizer( + BPE( + bpe_vocab, + merges, + unk_token=proto.trainer_spec.unk_piece, + fuse_unk=True, + byte_fallback=self.handle_byte_fallback, + dropout=None, + ) + ) + + else: + raise Exception( + "You're trying to run a `Unigram` model but you're file was trained with a different algorithm" + ) + + # control tokens are special + # user defined symbols are not + # both user and control tokens are AddedTokens + # Add user defined symbols (type == 4) from sentencepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33) + spm_added_tokens = [ + (id, p.piece, p.type == 3 or p.piece in self.special_tokens) + for id, p in enumerate(proto.pieces) + if p.type in [3, 4] + ] + tokenizer.add_tokens( + [ + AddedToken(token, normalized=False, special=special) + for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0]) + ] + ) + + return tokenizer + + def normalizer(self, proto): + precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap + _normalizers = [ + normalizers.Strip(left=False, right=True), # stripping is important + normalizers.Replace(Regex(" {2,}"), "▁"), + ] + if not precompiled_charsmap: + return normalizers.Sequence(_normalizers) + else: + return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers) + + def pre_tokenizer(self, replacement, add_prefix_space): + prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer) + return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme) + + def post_processor(self): + return None + + def decoder(self, replacement, add_prefix_space): + prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer) + return decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme) + + def converted(self) -> Tokenizer: + tokenizer = self.tokenizer(self.proto) + + # Tokenizer assemble + normalizer = self.normalizer(self.proto) + if normalizer is not None: + tokenizer.normalizer = normalizer + + replacement = "▁" + add_prefix_space = True + if hasattr(self.original_tokenizer, "add_prefix_space"): + add_prefix_space = self.original_tokenizer.add_prefix_space + + pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space) + if pre_tokenizer is not None: + tokenizer.pre_tokenizer = pre_tokenizer + + tokenizer.decoder = self.decoder(replacement, add_prefix_space) + post_processor = self.post_processor() + if post_processor: + tokenizer.post_processor = post_processor + + return tokenizer + + +class AlbertConverter(SpmConverter): + def vocab(self, proto): + return [ + (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100) + for piece in proto.pieces + ] + + def normalizer(self, proto): + list_normalizers = [ + normalizers.Replace("``", '"'), + normalizers.Replace("''", '"'), + ] + if not self.original_tokenizer.keep_accents: + list_normalizers.append(normalizers.NFKD()) + list_normalizers.append(normalizers.StripAccents()) + if self.original_tokenizer.do_lower_case: + list_normalizers.append(normalizers.Lowercase()) + + precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap + + if precompiled_charsmap: + list_normalizers.append(normalizers.Precompiled(precompiled_charsmap)) + + list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " ")) + return normalizers.Sequence(list_normalizers) + + def post_processor(self): + return processors.TemplateProcessing( + single="[CLS]:0 $A:0 [SEP]:0", + pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1", + special_tokens=[ + ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")), + ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")), + ], + ) + + +class BarthezConverter(SpmConverter): + def unk_id(self, proto): + unk_id = 3 + return unk_id + + def post_processor(self): + return processors.TemplateProcessing( + single=" $A ", + pair=" $A $B ", + special_tokens=[ + ("", self.original_tokenizer.convert_tokens_to_ids("")), + ("", self.original_tokenizer.convert_tokens_to_ids("")), + ], + ) + + +class CamembertConverter(SpmConverter): + def vocab(self, proto): + vocab = [ + ("NOTUSED", 0.0), + ("", 0.0), + ("NOTUSED", 0.0), + ("", 0.0), + ("NOTUSED", -100), + ] + # We down-grade the original SentencePiece by -100 to avoid using it and use our added token instead + vocab += [(piece.piece, piece.score) for piece in proto.pieces[1:]] + vocab += [("", 0.0)] + return vocab + + def unk_id(self, proto): + # See vocab unk position + return 3 + + def post_processor(self): + return processors.TemplateProcessing( + single=" $A ", + pair=" $A $B ", + special_tokens=[ + ("", self.original_tokenizer.convert_tokens_to_ids("")), + ("", self.original_tokenizer.convert_tokens_to_ids("")), + ], + ) + + +class DebertaV2Converter(SpmConverter): + def pre_tokenizer(self, replacement, add_prefix_space): + list_pretokenizers = [] + if self.original_tokenizer.split_by_punct: + list_pretokenizers.append(pre_tokenizers.Punctuation(behavior="isolated")) + prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer) + list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)) + return pre_tokenizers.Sequence(list_pretokenizers) + + def normalizer(self, proto): + list_normalizers = [] + if self.original_tokenizer.do_lower_case: + list_normalizers.append(normalizers.Lowercase()) + list_normalizers.append(normalizers.Strip()) + + precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap + if precompiled_charsmap: + list_normalizers.append(normalizers.Precompiled(precompiled_charsmap)) + list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " ")) + + return normalizers.Sequence(list_normalizers) + + def post_processor(self): + return processors.TemplateProcessing( + single="[CLS]:0 $A:0 [SEP]:0", + pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1", + special_tokens=[ + ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")), + ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")), + ], + ) + + +class MBartConverter(SpmConverter): + def vocab(self, proto): + vocab = [ + ("", 0.0), + ("", 0.0), + ("", 0.0), + ("", 0.0), + ] + vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]] + vocab += [ + ("ar_AR", 0.0), + ("cs_CZ", 0.0), + ("de_DE", 0.0), + ("en_XX", 0.0), + ("es_XX", 0.0), + ("et_EE", 0.0), + ("fi_FI", 0.0), + ("fr_XX", 0.0), + ("gu_IN", 0.0), + ("hi_IN", 0.0), + ("it_IT", 0.0), + ("ja_XX", 0.0), + ("kk_KZ", 0.0), + ("ko_KR", 0.0), + ("lt_LT", 0.0), + ("lv_LV", 0.0), + ("my_MM", 0.0), + ("ne_NP", 0.0), + ("nl_XX", 0.0), + ("ro_RO", 0.0), + ("ru_RU", 0.0), + ("si_LK", 0.0), + ("tr_TR", 0.0), + ("vi_VN", 0.0), + ("zh_CN", 0.0), + ] + vocab += [("", 0.0)] + return vocab + + def unk_id(self, proto): + return 3 + + def post_processor(self): + return processors.TemplateProcessing( + single="$A en_XX", + pair="$A $B en_XX", + special_tokens=[ + ("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")), + ("", self.original_tokenizer.convert_tokens_to_ids("")), + ], + ) + + +class MBart50Converter(SpmConverter): + def vocab(self, proto): + vocab = [ + ("", 0.0), + ("", 0.0), + ("", 0.0), + ("", 0.0), + ] + vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]] + vocab += [("ar_AR", 0.0), ("cs_CZ", 0.0), ("de_DE", 0.0), ("en_XX", 0.0), ("es_XX", 0.0), ("et_EE", 0.0), ("fi_FI", 0.0), ("fr_XX", 0.0), ("gu_IN", 0.0), ("hi_IN", 0.0), ("it_IT", 0.0), ("ja_XX", 0.0), ("kk_KZ", 0.0), ("ko_KR", 0.0), ("lt_LT", 0.0), ("lv_LV", 0.0), ("my_MM", 0.0), ("ne_NP", 0.0), ("nl_XX", 0.0), ("ro_RO", 0.0), ("ru_RU", 0.0), ("si_LK", 0.0), ("tr_TR", 0.0), ("vi_VN", 0.0), ("zh_CN", 0.0), ("af_ZA", 0.0), ("az_AZ", 0.0), ("bn_IN", 0.0), ("fa_IR", 0.0), ("he_IL", 0.0), ("hr_HR", 0.0), ("id_ID", 0.0), ("ka_GE", 0.0), ("km_KH", 0.0), ("mk_MK", 0.0), ("ml_IN", 0.0), ("mn_MN", 0.0), ("mr_IN", 0.0), ("pl_PL", 0.0), ("ps_AF", 0.0), ("pt_XX", 0.0), ("sv_SE", 0.0), ("sw_KE", 0.0), ("ta_IN", 0.0), ("te_IN", 0.0), ("th_TH", 0.0), ("tl_XX", 0.0), ("uk_UA", 0.0), ("ur_PK", 0.0), ("xh_ZA", 0.0), ("gl_ES", 0.0), ("sl_SI", 0.0)] # fmt: skip + vocab += [("", 0.0)] + return vocab + + def unk_id(self, proto): + return 3 + + def post_processor(self): + return processors.TemplateProcessing( + single="en_XX $A ", + pair="en_XX $A $B ", + special_tokens=[ + ("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")), + ("", self.original_tokenizer.convert_tokens_to_ids("")), + ], + ) + + +class NllbConverter(SpmConverter): + def vocab(self, proto): + vocab = [ + ("", 0.0), + ("", 0.0), + ("", 0.0), + ("", 0.0), + ] + vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]] + return vocab + + def unk_id(self, proto): + return 3 + + def post_processor(self): + return processors.TemplateProcessing( + single="eng_Latn $A ", + pair="eng_Latn $A $B ", + special_tokens=[ + ("eng_Latn", self.original_tokenizer.convert_tokens_to_ids("eng_Latn")), + ("", self.original_tokenizer.convert_tokens_to_ids("")), + ], + ) + + +class SeamlessM4TConverter(SpmConverter): + def vocab(self, proto): + vocab = [ + ("", 0.0), + ("", 0.0), + ("", 0.0), + ("", 0.0), + ] + vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]] + return vocab + + def unk_id(self, proto): + return self.original_tokenizer.unk_token_id + + def post_processor(self): + return processors.TemplateProcessing( + single="__eng__ $A ", + pair="__eng__ $A $B ", + special_tokens=[ + ("__eng__", self.original_tokenizer.convert_tokens_to_ids("__eng__")), + ("", self.original_tokenizer.convert_tokens_to_ids("")), + ], + ) + + +class XLMRobertaConverter(SpmConverter): + def vocab(self, proto): + vocab = [ + ("", 0.0), + ("", 0.0), + ("", 0.0), + ("", 0.0), + ] + vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]] + vocab += [("", 0.0)] + return vocab + + def unk_id(self, proto): + unk_id = 3 + return unk_id + + def post_processor(self): + return processors.TemplateProcessing( + single=" $A ", + pair=" $A $B ", + special_tokens=[ + ("", self.original_tokenizer.convert_tokens_to_ids("")), + ("", self.original_tokenizer.convert_tokens_to_ids("")), + ], + ) + + +class XLNetConverter(SpmConverter): + def vocab(self, proto): + return [ + (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100) + for piece in proto.pieces + ] + + def normalizer(self, proto): + list_normalizers = [ + normalizers.Replace("``", '"'), + normalizers.Replace("''", '"'), + ] + if not self.original_tokenizer.keep_accents: + list_normalizers.append(normalizers.NFKD()) + list_normalizers.append(normalizers.StripAccents()) + if self.original_tokenizer.do_lower_case: + list_normalizers.append(normalizers.Lowercase()) + + precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap + + if precompiled_charsmap: + list_normalizers.append(normalizers.Precompiled(precompiled_charsmap)) + + list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " ")) + return normalizers.Sequence(list_normalizers) + + def post_processor(self): + return processors.TemplateProcessing( + single="$A:0 :0 :2", + pair="$A:0 :0 $B:1 :1 :2", + special_tokens=[ + ("", self.original_tokenizer.convert_tokens_to_ids("")), + ("", self.original_tokenizer.convert_tokens_to_ids("")), + ], + ) + + +class ReformerConverter(SpmConverter): + pass + + +class RemBertConverter(SpmConverter): + # Inspired from AlbertConverter + def normalizer(self, proto): + list_normalizers = [ + normalizers.Replace("``", '"'), + normalizers.Replace("''", '"'), + normalizers.Replace(Regex(" {2,}"), " "), + ] + if not self.original_tokenizer.keep_accents: + list_normalizers.append(normalizers.NFKD()) + list_normalizers.append(normalizers.StripAccents()) + if self.original_tokenizer.do_lower_case: + list_normalizers.append(normalizers.Lowercase()) + + precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap + + if precompiled_charsmap: + list_normalizers.append(normalizers.Precompiled(precompiled_charsmap)) + + return normalizers.Sequence(list_normalizers) + + def post_processor(self): + return processors.TemplateProcessing( + single="[CLS]:0 $A:0 [SEP]:0", + pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1", + special_tokens=[ + ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")), + ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")), + ], + ) + + +class BertGenerationConverter(SpmConverter): + pass + + +class PegasusConverter(SpmConverter): + def vocab(self, proto): + vocab = [ + (self.original_tokenizer.pad_token, 0.0), + (self.original_tokenizer.eos_token, 0.0), + ] + + if self.original_tokenizer.mask_token_sent is not None: + vocab += [(self.original_tokenizer.mask_token_sent, 0.0)] + + if ( + self.original_tokenizer.mask_token is not None + and self.original_tokenizer.mask_token_id < self.original_tokenizer.offset + ): + vocab += [(self.original_tokenizer.mask_token, 0.0)] + + vocab += [(f"", -100.0) for i in range(2, self.original_tokenizer.offset)] + vocab += [(piece.piece, piece.score) for piece in proto.pieces[2:]] + return vocab + + def unk_id(self, proto): + return proto.trainer_spec.unk_id + self.original_tokenizer.offset + + def pre_tokenizer(self, replacement, add_prefix_space): + prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer) + return pre_tokenizers.Sequence( + [ + pre_tokenizers.WhitespaceSplit(), + pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme), + ] + ) + + def post_processor(self): + eos = self.original_tokenizer.eos_token + special_tokens = [ + (eos, self.original_tokenizer.eos_token_id), + ] + return processors.TemplateProcessing(single=["$A", eos], pair=["$A", "$B", eos], special_tokens=special_tokens) + + +class T5Converter(SpmConverter): + def vocab(self, proto): + num_extra_ids = self.original_tokenizer._extra_ids + vocab = [(piece.piece, piece.score) for piece in proto.pieces] + vocab += [(f"", 0.0) for i in range(num_extra_ids - 1, -1, -1)] + return vocab + + def post_processor(self): + return processors.TemplateProcessing( + single=["$A", ""], + pair=["$A", "", "$B", ""], + special_tokens=[ + ("", self.original_tokenizer.convert_tokens_to_ids("")), + ], + ) + + +class UdopConverter(SpmConverter): + def post_processor(self): + return processors.TemplateProcessing( + single=["$A", ""], + pair=["$A", "", "$B", ""], + special_tokens=[ + ("", self.original_tokenizer.convert_tokens_to_ids("")), + ], + ) + + +class WhisperConverter(Converter): + def converted(self) -> Tokenizer: + vocab = self.original_tokenizer.encoder + merges = list(self.original_tokenizer.bpe_ranks.keys()) + + tokenizer = Tokenizer( + BPE( + vocab=vocab, + merges=merges, + dropout=None, + continuing_subword_prefix="", + end_of_word_suffix="", + fuse_unk=False, + ) + ) + + tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=self.original_tokenizer.add_prefix_space) + tokenizer.decoder = decoders.ByteLevel() + + prefix_token_ids = self.original_tokenizer.prefix_tokens + prefixes = self.original_tokenizer.convert_ids_to_tokens(prefix_token_ids) + eos = self.original_tokenizer.eos_token + eos_token_id = self.original_tokenizer.eos_token_id + prefix_template = " ".join([f"{token}:0" for token in prefixes]) + tokenizer.post_processor = processors.TemplateProcessing( + single=f"{prefix_template} $A:0 {eos}:0", + pair=f"{prefix_template} $A:0 $B:1 {eos}:1", + special_tokens=[ + (eos, eos_token_id), + *zip(prefixes, prefix_token_ids), + ], + ) + + return tokenizer + + +class BigBirdConverter(SpmConverter): + def post_processor(self): + return processors.TemplateProcessing( + single="[CLS]:0 $A:0 [SEP]:0", + pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1", + special_tokens=[ + ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")), + ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")), + ], + ) + + +class CLIPConverter(Converter): + def converted(self) -> Tokenizer: + vocab = self.original_tokenizer.encoder + merges = list(self.original_tokenizer.bpe_ranks.keys()) + unk_token = self.original_tokenizer.unk_token + + tokenizer = Tokenizer( + BPE( + vocab=vocab, + merges=merges, + dropout=None, + continuing_subword_prefix="", + end_of_word_suffix="", + fuse_unk=False, + unk_token=str(unk_token), + ) + ) + + tokenizer.normalizer = normalizers.Sequence( + [normalizers.NFC(), normalizers.Replace(Regex(r"\s+"), " "), normalizers.Lowercase()] + ) + tokenizer.pre_tokenizer = pre_tokenizers.Sequence( + [ + pre_tokenizers.Split( + Regex(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+"""), + behavior="removed", + invert=True, + ), + pre_tokenizers.ByteLevel(add_prefix_space=False), + ] + ) + tokenizer.decoder = decoders.ByteLevel() + + # Hack to have a ByteLevel and TemplaceProcessor + tokenizer.post_processor = processors.RobertaProcessing( + sep=(self.original_tokenizer.eos_token, self.original_tokenizer.eos_token_id), + cls=(self.original_tokenizer.bos_token, self.original_tokenizer.bos_token_id), + add_prefix_space=False, + trim_offsets=False, + ) + return tokenizer + + +class LayoutLMv2Converter(Converter): + def converted(self) -> Tokenizer: + vocab = self.original_tokenizer.vocab + tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token))) + + tokenize_chinese_chars = False + strip_accents = False + do_lower_case = True + if hasattr(self.original_tokenizer, "basic_tokenizer"): + tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars + strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents + do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case + + tokenizer.normalizer = normalizers.BertNormalizer( + clean_text=True, + handle_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + lowercase=do_lower_case, + ) + tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer() + + cls = str(self.original_tokenizer.cls_token) + sep = str(self.original_tokenizer.sep_token) + cls_token_id = self.original_tokenizer.cls_token_id + sep_token_id = self.original_tokenizer.sep_token_id + + tokenizer.post_processor = processors.TemplateProcessing( + single=f"{cls}:0 $A:0 {sep}:0", + pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1", + special_tokens=[ + (cls, cls_token_id), + (sep, sep_token_id), + ], + ) + tokenizer.decoder = decoders.WordPiece(prefix="##") + + return tokenizer + + +class BlenderbotConverter(Converter): + def converted(self) -> Tokenizer: + ot = self.original_tokenizer + vocab = ot.encoder + merges = list(ot.bpe_ranks.keys()) + + tokenizer = Tokenizer( + BPE( + vocab=vocab, + merges=merges, + dropout=None, + continuing_subword_prefix="", + end_of_word_suffix="", + fuse_unk=False, + ) + ) + + tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space) + tokenizer.decoder = decoders.ByteLevel() + tokenizer.post_processor = processors.TemplateProcessing( + single=f"$A:0 {ot.eos_token}:0", + special_tokens=[ + (ot.eos_token, ot.eos_token_id), + ], + ) + + return tokenizer + + +class XGLMConverter(SpmConverter): + def vocab(self, proto): + vocab = [ + ("", 0.0), + ("", 0.0), + ("", 0.0), + ("", 0.0), + ] + vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]] + vocab += [("", 0.0), ("", 0.0), ("", 0.0), ("", 0.0), ("", 0.0), ("", 0.0), ("", 0.0)] # fmt: skip + return vocab + + def unk_id(self, proto): + unk_id = 3 + return unk_id + + def post_processor(self): + return processors.TemplateProcessing( + single=" $A", + pair=" $A $B", + special_tokens=[ + ("", self.original_tokenizer.convert_tokens_to_ids("")), + ("", self.original_tokenizer.convert_tokens_to_ids("")), + ], + ) + + +class GemmaConverter(SpmConverter): + handle_byte_fallback = True + SpmExtractor = GemmaSentencePieceExtractor + # start and end of turn tokens must be marked as special + special_tokens = {"", ""} + + """" + split_by_unicode_script: true + split_by_number: true + split_by_whitespace: true + treat_whitespace_as_suffix: false + allow_whitespace_only_pieces: true + split_digits: true + byte_fallback: true + """ + + def normalizer(self, proto): + return normalizers.Replace(" ", "▁") + + def vocab(self, proto): + vocab = [ + (self.original_tokenizer.pad_token, 0.0), + (self.original_tokenizer.eos_token, 0.0), + (self.original_tokenizer.bos_token, 0.0), + ] + for piece in proto.pieces[3:]: + if piece.piece == "<0x09>": + vocab += [("\t", piece.score)] + else: + vocab += [(piece.piece, piece.score)] + # vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]] + return vocab + + def pre_tokenizer(self, replacement, add_prefix_space): + return pre_tokenizers.Split(" ", "merged_with_previous") + + def unk_id(self, proto): + unk_id = 3 + return unk_id + + def decoder(self, replacement, add_prefix_space): + return decoders.Sequence( + [ + decoders.Replace("▁", " "), + decoders.ByteFallback(), + decoders.Fuse(), + ] + ) + + +class LlamaConverter(SpmConverter): + handle_byte_fallback = True + + def vocab(self, proto): + vocab = [ + (self.original_tokenizer.convert_ids_to_tokens(0), 0.0), + (self.original_tokenizer.convert_ids_to_tokens(1), 0.0), + (self.original_tokenizer.convert_ids_to_tokens(2), 0.0), + ] + vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]] + return vocab + + def unk_id(self, proto): + unk_id = 0 + return unk_id + + def decoder(self, replacement, add_prefix_space): + sequence = [ + decoders.Replace("▁", " "), + decoders.ByteFallback(), + decoders.Fuse(), + ] + if add_prefix_space: + sequence += [decoders.Strip(content=" ", left=1)] + return decoders.Sequence(sequence) + + def normalizer(self, proto): + if getattr(self.original_tokenizer, "legacy", True): + sequence = [] + if getattr(self.original_tokenizer, "add_prefix_space", True): + sequence += [normalizers.Prepend(prepend="▁")] + sequence += [normalizers.Replace(pattern=" ", content="▁")] + return normalizers.Sequence(sequence) + return None # non-legacy, no normalizer + + def pre_tokenizer(self, replacement, add_prefix_space): + if not getattr(self.original_tokenizer, "legacy", True): # non-legacy, we need a replace + prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer) + return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False) + return None + + def post_processor(self): + # the processor is defined in the LlamaTokenizerFast class. + return None + + +class MarkupLMConverter(Converter): + def converted(self) -> Tokenizer: + ot = self.original_tokenizer + vocab = ot.encoder + merges = list(ot.bpe_ranks.keys()) + + tokenizer = Tokenizer( + BPE( + vocab=vocab, + merges=merges, + dropout=None, + continuing_subword_prefix="", + end_of_word_suffix="", + fuse_unk=False, + unk_token=self.original_tokenizer.unk_token, + ) + ) + + tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space) + tokenizer.decoder = decoders.ByteLevel() + + cls = str(self.original_tokenizer.cls_token) + sep = str(self.original_tokenizer.sep_token) + cls_token_id = self.original_tokenizer.cls_token_id + sep_token_id = self.original_tokenizer.sep_token_id + + tokenizer.post_processor = processors.TemplateProcessing( + single=f"{cls} $A {sep}", + pair=f"{cls} $A {sep} $B {sep}", + special_tokens=[ + (cls, cls_token_id), + (sep, sep_token_id), + ], + ) + + return tokenizer + + +class MoshiConverter(SpmConverter): + handle_byte_fallback = True + + def __init__(self, vocab_file, model_max_length=None, **kwargs): + requires_backends(self, "protobuf") + + Converter.__init__(self, vocab_file) + + # from .utils import sentencepiece_model_pb2 as model_pb2 + model_pb2 = import_protobuf() + + m = model_pb2.ModelProto() + with open(vocab_file, "rb") as f: + m.ParseFromString(f.read()) + self.proto = m + + def normalizer(self, proto): + precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap + _normalizers = [ + normalizers.Replace(" ", "▁"), + ] + if not precompiled_charsmap: + return normalizers.Sequence(_normalizers) + else: + return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers) + + def decoder(self, replacement, add_prefix_space): + sequence = [ + decoders.Replace("▁", " "), + decoders.ByteFallback(), + decoders.Fuse(), + ] + if add_prefix_space: + sequence += [decoders.Strip(content=" ", left=1)] + return decoders.Sequence(sequence) + + def pre_tokenizer(self, replacement, add_prefix_space): + prepend_scheme = "first" + return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False) + + +# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control + characters the bpe code barfs on. + + The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab + if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for + decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup + tables between utf-8 bytes and unicode strings. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +class TikTokenConverter: + """ + A general tiktoken converter. + """ + + def __init__( + self, + vocab_file=None, + pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""", + add_prefix_space=False, + additional_special_tokens=None, + *args, + **kwargs, + ): + super().__init__(*args) + self.vocab_file = vocab_file + self.pattern = pattern + self.add_prefix_space = add_prefix_space + self.additional_special_tokens = additional_special_tokens + + def extract_vocab_merges_from_model(self, tiktoken_url: str): + try: + from tiktoken.load import load_tiktoken_bpe + except Exception: + raise ValueError( + "`tiktoken` is required to read a `tiktoken` file. Install it with " "`pip install tiktoken`." + ) + + bpe_ranks = load_tiktoken_bpe(tiktoken_url) + byte_encoder = bytes_to_unicode() + + def token_bytes_to_string(b): + return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")]) + + merges = [] + vocab = {} + for token, rank in bpe_ranks.items(): + vocab[token_bytes_to_string(token)] = rank + if len(token) == 1: + continue + local = [] + for index in range(1, len(token)): + piece_l, piece_r = token[:index], token[index:] + if piece_l in bpe_ranks and piece_r in bpe_ranks and (piece_l + piece_r) in bpe_ranks: + local.append((piece_l, piece_r, rank)) + local = sorted(local, key=lambda x: (bpe_ranks[x[0]], bpe_ranks[x[1]]), reverse=False) + merges.extend(local) + merges = sorted(merges, key=lambda val: val[2], reverse=False) + merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges] + return vocab, merges + + def tokenizer(self): + vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab_file) + tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False)) + if hasattr(tokenizer.model, "ignore_merges"): + tokenizer.model.ignore_merges = True + return tokenizer + + def converted(self) -> Tokenizer: + tokenizer = self.tokenizer() + tokenizer.pre_tokenizer = pre_tokenizers.Sequence( + [ + pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False), + pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False), + ] + ) + tokenizer.decoder = decoders.ByteLevel() + tokenizer.add_special_tokens(self.additional_special_tokens) + + tokenizer.post_processor = processors.ByteLevel(trim_offsets=False) + + return tokenizer + + +SLOW_TO_FAST_CONVERTERS = { + "AlbertTokenizer": AlbertConverter, + "BartTokenizer": RobertaConverter, + "BarthezTokenizer": BarthezConverter, + "BertTokenizer": BertConverter, + "BigBirdTokenizer": BigBirdConverter, + "BlenderbotTokenizer": BlenderbotConverter, + "CamembertTokenizer": CamembertConverter, + "CLIPTokenizer": CLIPConverter, + "CodeGenTokenizer": GPT2Converter, + "ConvBertTokenizer": BertConverter, + "DebertaTokenizer": DebertaConverter, + "DebertaV2Tokenizer": DebertaV2Converter, + "DistilBertTokenizer": BertConverter, + "DPRReaderTokenizer": BertConverter, + "DPRQuestionEncoderTokenizer": BertConverter, + "DPRContextEncoderTokenizer": BertConverter, + "ElectraTokenizer": BertConverter, + "FNetTokenizer": AlbertConverter, + "FunnelTokenizer": FunnelConverter, + "GPT2Tokenizer": GPT2Converter, + "HerbertTokenizer": HerbertConverter, + "LayoutLMTokenizer": BertConverter, + "LayoutLMv2Tokenizer": BertConverter, + "LayoutLMv3Tokenizer": RobertaConverter, + "LayoutXLMTokenizer": XLMRobertaConverter, + "LongformerTokenizer": RobertaConverter, + "LEDTokenizer": RobertaConverter, + "LxmertTokenizer": BertConverter, + "MarkupLMTokenizer": MarkupLMConverter, + "MBartTokenizer": MBartConverter, + "MBart50Tokenizer": MBart50Converter, + "MPNetTokenizer": MPNetConverter, + "MobileBertTokenizer": BertConverter, + "MvpTokenizer": RobertaConverter, + "NllbTokenizer": NllbConverter, + "OpenAIGPTTokenizer": OpenAIGPTConverter, + "PegasusTokenizer": PegasusConverter, + "Qwen2Tokenizer": Qwen2Converter, + "RealmTokenizer": BertConverter, + "ReformerTokenizer": ReformerConverter, + "RemBertTokenizer": RemBertConverter, + "RetriBertTokenizer": BertConverter, + "RobertaTokenizer": RobertaConverter, + "RoFormerTokenizer": RoFormerConverter, + "SeamlessM4TTokenizer": SeamlessM4TConverter, + "SqueezeBertTokenizer": BertConverter, + "T5Tokenizer": T5Converter, + "UdopTokenizer": UdopConverter, + "WhisperTokenizer": WhisperConverter, + "XLMRobertaTokenizer": XLMRobertaConverter, + "XLNetTokenizer": XLNetConverter, + "SplinterTokenizer": SplinterConverter, + "XGLMTokenizer": XGLMConverter, + "LlamaTokenizer": LlamaConverter, + "CodeLlamaTokenizer": LlamaConverter, + "GemmaTokenizer": GemmaConverter, + "Phi3Tokenizer": LlamaConverter, +} + + +def convert_slow_tokenizer(transformer_tokenizer, from_tiktoken=False) -> Tokenizer: + """ + Utilities to convert a slow tokenizer instance in a fast tokenizer instance. + + Args: + transformer_tokenizer ([`~tokenization_utils_base.PreTrainedTokenizer`]): + Instance of a slow tokenizer to convert in the backend tokenizer for + [`~tokenization_utils_base.PreTrainedTokenizerFast`]. + from_tiktoken (bool, optional): Whether to use the `tiktoken` library to convert the tokenizer instead of sentencepiece. + Defaults to False. + + Return: + A instance of [`~tokenizers.Tokenizer`] to be used as the backend tokenizer of a + [`~tokenization_utils_base.PreTrainedTokenizerFast`] + """ + + tokenizer_class_name = transformer_tokenizer.__class__.__name__ + if tokenizer_class_name in SLOW_TO_FAST_CONVERTERS and not from_tiktoken: + converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name] + return converter_class(transformer_tokenizer).converted() + + else: + try: + logger.info("Converting from Tiktoken") + return TikTokenConverter( + vocab_file=transformer_tokenizer.vocab_file, + additional_special_tokens=transformer_tokenizer.additional_special_tokens, + ).converted() + except Exception: + raise ValueError( + f"Converting from Tiktoken failed, if a converter for SentencePiece is available, provide a model path " + f"with a SentencePiece tokenizer.model file." + f"Currently available slow->fast convertors: {list(SLOW_TO_FAST_CONVERTERS.keys())}" + ) diff --git a/convert_slow_tokenizers_checkpoints_to_fast.py b/convert_slow_tokenizers_checkpoints_to_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..0b93e4c53ff891e70a5ce33a8868237c430b1b18 --- /dev/null +++ b/convert_slow_tokenizers_checkpoints_to_fast.py @@ -0,0 +1,130 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert slow tokenizers checkpoints in fast (serialization format of the `tokenizers` library)""" + +import argparse +import os + +import transformers + +from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS +from .utils import logging + + +logging.set_verbosity_info() + +logger = logging.get_logger(__name__) + + +TOKENIZER_CLASSES = { + # Phi3 uses Llama tokenizer + name: getattr(transformers, "LlamaTokenizerFast" if name == "Phi3Tokenizer" else name + "Fast") + for name in SLOW_TO_FAST_CONVERTERS +} + + +def convert_slow_checkpoint_to_fast(tokenizer_name, checkpoint_name, dump_path, force_download): + if tokenizer_name is not None and tokenizer_name not in TOKENIZER_CLASSES: + raise ValueError(f"Unrecognized tokenizer name, should be one of {list(TOKENIZER_CLASSES.keys())}.") + + if tokenizer_name is None: + tokenizer_names = TOKENIZER_CLASSES + else: + tokenizer_names = {tokenizer_name: getattr(transformers, tokenizer_name + "Fast")} + + logger.info(f"Loading tokenizer classes: {tokenizer_names}") + + for tokenizer_name in tokenizer_names: + tokenizer_class = TOKENIZER_CLASSES[tokenizer_name] + + add_prefix = True + if checkpoint_name is None: + checkpoint_names = list(tokenizer_class.max_model_input_sizes.keys()) + else: + checkpoint_names = [checkpoint_name] + + logger.info(f"For tokenizer {tokenizer_class.__class__.__name__} loading checkpoints: {checkpoint_names}") + + for checkpoint in checkpoint_names: + logger.info(f"Loading {tokenizer_class.__class__.__name__} {checkpoint}") + + # Load tokenizer + tokenizer = tokenizer_class.from_pretrained(checkpoint, force_download=force_download) + + # Save fast tokenizer + logger.info(f"Save fast tokenizer to {dump_path} with prefix {checkpoint} add_prefix {add_prefix}") + + # For organization names we create sub-directories + if "/" in checkpoint: + checkpoint_directory, checkpoint_prefix_name = checkpoint.split("/") + dump_path_full = os.path.join(dump_path, checkpoint_directory) + elif add_prefix: + checkpoint_prefix_name = checkpoint + dump_path_full = dump_path + else: + checkpoint_prefix_name = None + dump_path_full = dump_path + + logger.info(f"=> {dump_path_full} with prefix {checkpoint_prefix_name}, add_prefix {add_prefix}") + + if checkpoint in list(tokenizer.pretrained_vocab_files_map.values())[0]: + file_path = list(tokenizer.pretrained_vocab_files_map.values())[0][checkpoint] + next_char = file_path.split(checkpoint)[-1][0] + if next_char == "/": + dump_path_full = os.path.join(dump_path_full, checkpoint_prefix_name) + checkpoint_prefix_name = None + + logger.info(f"=> {dump_path_full} with prefix {checkpoint_prefix_name}, add_prefix {add_prefix}") + + file_names = tokenizer.save_pretrained( + dump_path_full, legacy_format=False, filename_prefix=checkpoint_prefix_name + ) + logger.info(f"=> File names {file_names}") + + for file_name in file_names: + if not file_name.endswith("tokenizer.json"): + os.remove(file_name) + logger.info(f"=> removing {file_name}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--dump_path", default=None, type=str, required=True, help="Path to output generated fast tokenizer files." + ) + parser.add_argument( + "--tokenizer_name", + default=None, + type=str, + help=( + f"Optional tokenizer type selected in the list of {list(TOKENIZER_CLASSES.keys())}. If not given, will " + "download and convert all the checkpoints from AWS." + ), + ) + parser.add_argument( + "--checkpoint_name", + default=None, + type=str, + help="Optional checkpoint name. If not given, will download and convert the canonical checkpoints from AWS.", + ) + parser.add_argument( + "--force_download", + action="store_true", + help="Re-download checkpoints.", + ) + args = parser.parse_args() + + convert_slow_checkpoint_to_fast(args.tokenizer_name, args.checkpoint_name, args.dump_path, args.force_download) diff --git a/convert_tf_hub_seq_to_seq_bert_to_pytorch.py b/convert_tf_hub_seq_to_seq_bert_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..8ccb033b3df1de87a29bfd608090386c16593c5f --- /dev/null +++ b/convert_tf_hub_seq_to_seq_bert_to_pytorch.py @@ -0,0 +1,87 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Seq2Seq TF Hub checkpoint.""" + +import argparse + +from . import ( + BertConfig, + BertGenerationConfig, + BertGenerationDecoder, + BertGenerationEncoder, + load_tf_weights_in_bert_generation, + logging, +) + + +logging.set_verbosity_info() + + +def convert_tf_checkpoint_to_pytorch(tf_hub_path, pytorch_dump_path, is_encoder_named_decoder, vocab_size, is_encoder): + # Initialise PyTorch model + bert_config = BertConfig.from_pretrained( + "google-bert/bert-large-cased", + vocab_size=vocab_size, + max_position_embeddings=512, + is_decoder=True, + add_cross_attention=True, + ) + bert_config_dict = bert_config.to_dict() + del bert_config_dict["type_vocab_size"] + config = BertGenerationConfig(**bert_config_dict) + if is_encoder: + model = BertGenerationEncoder(config) + else: + model = BertGenerationDecoder(config) + print(f"Building PyTorch model from configuration: {config}") + + # Load weights from tf checkpoint + load_tf_weights_in_bert_generation( + model, + tf_hub_path, + model_class="bert", + is_encoder_named_decoder=is_encoder_named_decoder, + is_encoder=is_encoder, + ) + + # Save pytorch-model + print(f"Save PyTorch model and config to {pytorch_dump_path}") + model.save_pretrained(pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_hub_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + parser.add_argument( + "--is_encoder_named_decoder", + action="store_true", + help="If decoder has to be renamed to encoder in PyTorch model.", + ) + parser.add_argument("--is_encoder", action="store_true", help="If model is an encoder.") + parser.add_argument("--vocab_size", default=50358, type=int, help="Vocab size of model") + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch( + args.tf_hub_path, + args.pytorch_dump_path, + args.is_encoder_named_decoder, + args.vocab_size, + is_encoder=args.is_encoder, + ) diff --git a/debug_utils.py b/debug_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dbceb1d849076999c6821556accaea05e53a9ff9 --- /dev/null +++ b/debug_utils.py @@ -0,0 +1,346 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections + +from .utils import ExplicitEnum, is_torch_available, logging + + +if is_torch_available(): + import torch + + +logger = logging.get_logger(__name__) + + +class DebugUnderflowOverflow: + """ + This debug class helps detect and understand where the model starts getting very large or very small, and more + importantly `nan` or `inf` weight and activation elements. + + There are 2 working modes: + + 1. Underflow/overflow detection (default) + 2. Specific batch absolute min/max tracing without detection + + Mode 1: Underflow/overflow detection + + To activate the underflow/overflow detection, initialize the object with the model : + + ```python + debug_overflow = DebugUnderflowOverflow(model) + ``` + + then run the training as normal and if `nan` or `inf` gets detected in at least one of the weight, input or output + elements this module will throw an exception and will print `max_frames_to_save` frames that lead to this event, + each frame reporting + + 1. the fully qualified module name plus the class name whose `forward` was run + 2. the absolute min and max value of all elements for each module weights, and the inputs and output + + For example, here is the header and the last few frames in detection report for `google/mt5-small` run in fp16 + mixed precision : + + ``` + Detected inf/nan during batch_number=0 + Last 21 forward frames: + abs min abs max metadata + [...] + encoder.block.2.layer.1.DenseReluDense.wi_0 Linear + 2.17e-07 4.50e+00 weight + 1.79e-06 4.65e+00 input[0] + 2.68e-06 3.70e+01 output + encoder.block.2.layer.1.DenseReluDense.wi_1 Linear + 8.08e-07 2.66e+01 weight + 1.79e-06 4.65e+00 input[0] + 1.27e-04 2.37e+02 output + encoder.block.2.layer.1.DenseReluDense.wo Linear + 1.01e-06 6.44e+00 weight + 0.00e+00 9.74e+03 input[0] + 3.18e-04 6.27e+04 output + encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense + 1.79e-06 4.65e+00 input[0] + 3.18e-04 6.27e+04 output + encoder.block.2.layer.1.dropout Dropout + 3.18e-04 6.27e+04 input[0] + 0.00e+00 inf output + ``` + + You can see here, that `T5DenseGatedGeluDense.forward` resulted in output activations, whose absolute max value was + around 62.7K, which is very close to fp16's top limit of 64K. In the next frame we have `Dropout` which + renormalizes the weights, after it zeroed some of the elements, which pushes the absolute max value to more than + 64K, and we get an overlow. + + As you can see it's the previous frames that we need to look into when the numbers start going into very large for + fp16 numbers. + + The tracking is done in a forward hook, which gets invoked immediately after `forward` has completed. + + By default the last 21 frames are printed. You can change the default to adjust for your needs. For example : + + ```python + debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100) + ``` + + To validate that you have set up this debugging feature correctly, and you intend to use it in a training that + may take hours to complete, first run it with normal tracing enabled for one of a few batches as explained in + the next section. + + + Mode 2. Specific batch absolute min/max tracing without detection + + The second work mode is per-batch tracing with the underflow/overflow detection feature turned off. + + Let's say you want to watch the absolute min and max values for all the ingredients of each `forward` call of a + given batch, and only do that for batches 1 and 3. Then you instantiate this class as : + + ```python + debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3]) + ``` + + And now full batches 1 and 3 will be traced using the same format as explained above. Batches are 0-indexed. + + This is helpful if you know that the program starts misbehaving after a certain batch number, so you can + fast-forward right to that area. + + + Early stopping: + + You can also specify the batch number after which to stop the training, with : + + ```python + debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3], abort_after_batch_num=3) + ``` + + This feature is mainly useful in the tracing mode, but you can use it for any mode. + + + **Performance**: + + As this module measures absolute `min`/``max` of each weight of the model on every forward it'll slow the training + down. Therefore remember to turn it off once the debugging needs have been met. + + Args: + model (`nn.Module`): + The model to debug. + max_frames_to_save (`int`, *optional*, defaults to 21): + How many frames back to record + trace_batch_nums(`List[int]`, *optional*, defaults to `[]`): + Which batch numbers to trace (turns detection off) + abort_after_batch_num (`int``, *optional*): + Whether to abort after a certain batch number has finished + """ + + def __init__(self, model, max_frames_to_save=21, trace_batch_nums=[], abort_after_batch_num=None): + self.model = model + self.trace_batch_nums = trace_batch_nums + self.abort_after_batch_num = abort_after_batch_num + + # keep a LIFO buffer of frames to dump as soon as inf/nan is encountered to give context to the problem emergence + self.frames = collections.deque([], max_frames_to_save) + self.frame = [] + self.batch_number = 0 + self.total_calls = 0 + self.detected_overflow = False + self.prefix = " " + + self.analyse_model() + + self.register_forward_hook() + + def save_frame(self, frame=None): + if frame is not None: + self.expand_frame(frame) + self.frames.append("\n".join(self.frame)) + self.frame = [] # start a new frame + + def expand_frame(self, line): + self.frame.append(line) + + def trace_frames(self): + print("\n".join(self.frames)) + self.frames = [] + + def reset_saved_frames(self): + self.frames = [] + + def dump_saved_frames(self): + print(f"\nDetected inf/nan during batch_number={self.batch_number}") + print(f"Last {len(self.frames)} forward frames:") + print(f"{'abs min':8} {'abs max':8} metadata") + print("\n".join(self.frames)) + print("\n\n") + self.frames = [] + + def analyse_model(self): + # extract the fully qualified module names, to be able to report at run time. e.g.: + # encoder.block.2.layer.0.SelfAttention.o + # + # for shared weights only the first shared module name will be registered + self.module_names = {m: name for name, m in self.model.named_modules()} + # self.longest_module_name = max(len(v) for v in self.module_names.values()) + + def analyse_variable(self, var, ctx): + if torch.is_tensor(var): + self.expand_frame(get_abs_min_max(var, ctx)) + if detect_overflow(var, ctx): + self.detected_overflow = True + elif var is None: + self.expand_frame(f"{'None':>17} {ctx}") + else: + self.expand_frame(f"{'not a tensor':>17} {ctx}") + + def batch_start_frame(self): + self.expand_frame(f"\n\n{self.prefix} *** Starting batch number={self.batch_number} ***") + self.expand_frame(f"{'abs min':8} {'abs max':8} metadata") + + def batch_end_frame(self): + self.expand_frame(f"{self.prefix} *** Finished batch number={self.batch_number-1} ***\n\n") + + def create_frame(self, module, input, output): + self.expand_frame(f"{self.prefix} {self.module_names[module]} {module.__class__.__name__}") + + # params + for name, p in module.named_parameters(recurse=False): + self.analyse_variable(p, name) + + # inputs + if isinstance(input, tuple): + for i, x in enumerate(input): + self.analyse_variable(x, f"input[{i}]") + else: + self.analyse_variable(input, "input") + + # outputs + if isinstance(output, tuple): + for i, x in enumerate(output): + # possibly a tuple of tuples + if isinstance(x, tuple): + for j, y in enumerate(x): + self.analyse_variable(y, f"output[{i}][{j}]") + else: + self.analyse_variable(x, f"output[{i}]") + else: + self.analyse_variable(output, "output") + + self.save_frame() + + def register_forward_hook(self): + self.model.apply(self._register_forward_hook) + + def _register_forward_hook(self, module): + module.register_forward_hook(self.forward_hook) + + def forward_hook(self, module, input, output): + # - input is a tuple of packed inputs (could be non-Tensors) + # - output could be a Tensor or a tuple of Tensors and non-Tensors + + last_frame_of_batch = False + + trace_mode = True if self.batch_number in self.trace_batch_nums else False + if trace_mode: + self.reset_saved_frames() + + if self.total_calls == 0: + self.batch_start_frame() + self.total_calls += 1 + + # count batch numbers - the very first forward hook of the batch will be called when the + # batch completes - i.e. it gets called very last - we know this batch has finished + if module == self.model: + self.batch_number += 1 + last_frame_of_batch = True + + self.create_frame(module, input, output) + + # if last_frame_of_batch: + # self.batch_end_frame() + + if trace_mode: + self.trace_frames() + + if last_frame_of_batch: + self.batch_start_frame() + + if self.detected_overflow and not trace_mode: + self.dump_saved_frames() + + # now we can abort, as it's pointless to continue running + raise ValueError( + "DebugUnderflowOverflow: inf/nan detected, aborting as there is no point running further. " + "Please scroll up above this traceback to see the activation values prior to this event." + ) + + # abort after certain batch if requested to do so + if self.abort_after_batch_num is not None and self.batch_number > self.abort_after_batch_num: + raise ValueError( + f"DebugUnderflowOverflow: aborting after {self.batch_number} batches due to" + f" `abort_after_batch_num={self.abort_after_batch_num}` arg" + ) + + +def get_abs_min_max(var, ctx): + abs_var = var.abs() + return f"{abs_var.min():8.2e} {abs_var.max():8.2e} {ctx}" + + +def detect_overflow(var, ctx): + """ + Report whether the tensor contains any `nan` or `inf` entries. + + This is useful for detecting overflows/underflows and best to call right after the function that did some math that + modified the tensor in question. + + This function contains a few other helper features that you can enable and tweak directly if you want to track + various other things. + + Args: + var: the tensor variable to check + ctx: the message to print as a context + + Return: + `True` if `inf` or `nan` was detected, `False` otherwise + """ + detected = False + if torch.isnan(var).any().item(): + detected = True + print(f"{ctx} has nans") + if torch.isinf(var).any().item(): + detected = True + print(f"{ctx} has infs") + + # if needed to monitor large elements can enable the following + if 0: # and detected: + n100 = var[torch.ge(var.abs(), 100)] + if n100.numel() > 0: + print(f"{ctx}: n100={n100.numel()}") + n1000 = var[torch.ge(var.abs(), 1000)] + if n1000.numel() > 0: + print(f"{ctx}: n1000={n1000.numel()}") + n10000 = var[torch.ge(var.abs(), 10000)] + if n10000.numel() > 0: + print(f"{ctx}: n10000={n10000.numel()}") + + if 0: + print(f"min={var.min():9.2e} max={var.max():9.2e}") + + if 0: + print(f"min={var.min():9.2e} max={var.max():9.2e} var={var.var():9.2e} mean={var.mean():9.2e} ({ctx})") + + return detected + + +class DebugOption(ExplicitEnum): + UNDERFLOW_OVERFLOW = "underflow_overflow" + TPU_METRICS_DEBUG = "tpu_metrics_debug" diff --git a/dependency_versions_check.py b/dependency_versions_check.py new file mode 100644 index 0000000000000000000000000000000000000000..82d07850847ec357f36ff51088ddec36aceff093 --- /dev/null +++ b/dependency_versions_check.py @@ -0,0 +1,63 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .dependency_versions_table import deps +from .utils.versions import require_version, require_version_core + + +# define which module versions we always want to check at run time +# (usually the ones defined in `install_requires` in setup.py) +# +# order specific notes: +# - tqdm must be checked before tokenizers + +pkgs_to_check_at_runtime = [ + "python", + "tqdm", + "regex", + "requests", + "packaging", + "filelock", + "numpy", + "tokenizers", + "huggingface-hub", + "safetensors", + "accelerate", + "pyyaml", +] + +for pkg in pkgs_to_check_at_runtime: + if pkg in deps: + if pkg == "tokenizers": + # must be loaded here, or else tqdm check may fail + from .utils import is_tokenizers_available + + if not is_tokenizers_available(): + continue # not required, check version only if installed + elif pkg == "accelerate": + # must be loaded here, or else tqdm check may fail + from .utils import is_accelerate_available + + # Maybe switch to is_torch_available in the future here so that Accelerate is hard dep of + # Transformers with PyTorch + if not is_accelerate_available(): + continue # not required, check version only if installed + + require_version_core(deps[pkg]) + else: + raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py") + + +def dep_version_check(pkg, hint=None): + require_version(deps[pkg], hint) diff --git a/dependency_versions_table.py b/dependency_versions_table.py new file mode 100644 index 0000000000000000000000000000000000000000..26500c22b167b1894d9038bddff08cb949154405 --- /dev/null +++ b/dependency_versions_table.py @@ -0,0 +1,102 @@ +# THIS FILE HAS BEEN AUTOGENERATED. To update: +# 1. modify the `_deps` dict in setup.py +# 2. run `make deps_table_update`` +deps = { + "Pillow": "Pillow>=10.0.1,<=15.0", + "accelerate": "accelerate>=0.26.0", + "av": "av==9.2.0", + "beautifulsoup4": "beautifulsoup4", + "blobfile": "blobfile", + "codecarbon": "codecarbon>=2.8.1", + "cookiecutter": "cookiecutter==1.7.3", + "dataclasses": "dataclasses", + "datasets": "datasets!=2.5.0", + "deepspeed": "deepspeed>=0.9.3", + "diffusers": "diffusers", + "dill": "dill<0.3.5", + "evaluate": "evaluate>=0.2.0", + "faiss-cpu": "faiss-cpu", + "fastapi": "fastapi", + "filelock": "filelock", + "flax": "flax>=0.4.1,<=0.7.0", + "fsspec": "fsspec<2023.10.0", + "ftfy": "ftfy", + "fugashi": "fugashi>=1.0", + "GitPython": "GitPython<3.1.19", + "hf-doc-builder": "hf-doc-builder>=0.3.0", + "huggingface-hub": "huggingface-hub>=0.24.0,<1.0", + "importlib_metadata": "importlib_metadata", + "ipadic": "ipadic>=1.0.0,<2.0", + "isort": "isort>=5.5.4", + "jax": "jax>=0.4.1,<=0.4.13", + "jaxlib": "jaxlib>=0.4.1,<=0.4.13", + "jieba": "jieba", + "jinja2": "jinja2>=3.1.0", + "kenlm": "kenlm", + "keras": "keras>2.9,<2.16", + "keras-nlp": "keras-nlp>=0.3.1,<0.14.0", + "librosa": "librosa", + "nltk": "nltk<=3.8.1", + "natten": "natten>=0.14.6,<0.15.0", + "numpy": "numpy>=1.17", + "onnxconverter-common": "onnxconverter-common", + "onnxruntime-tools": "onnxruntime-tools>=1.4.2", + "onnxruntime": "onnxruntime>=1.4.0", + "opencv-python": "opencv-python", + "optimum-benchmark": "optimum-benchmark>=0.3.0", + "optuna": "optuna", + "optax": "optax>=0.0.8,<=0.1.4", + "packaging": "packaging>=20.0", + "parameterized": "parameterized", + "phonemizer": "phonemizer", + "protobuf": "protobuf", + "psutil": "psutil", + "pyyaml": "pyyaml>=5.1", + "pydantic": "pydantic", + "pytest": "pytest>=7.2.0,<8.0.0", + "pytest-asyncio": "pytest-asyncio", + "pytest-timeout": "pytest-timeout", + "pytest-xdist": "pytest-xdist", + "python": "python>=3.9.0", + "ray[tune]": "ray[tune]>=2.7.0", + "regex": "regex!=2019.12.17", + "requests": "requests", + "rhoknp": "rhoknp>=1.1.0,<1.3.1", + "rjieba": "rjieba", + "rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1", + "ruff": "ruff==0.5.1", + "sacrebleu": "sacrebleu>=1.4.12,<2.0.0", + "sacremoses": "sacremoses", + "safetensors": "safetensors>=0.4.1", + "sagemaker": "sagemaker>=2.31.0", + "schedulefree": "schedulefree>=1.2.6", + "scikit-learn": "scikit-learn", + "scipy": "scipy<1.13.0", + "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", + "sigopt": "sigopt", + "starlette": "starlette", + "sudachipy": "sudachipy>=0.6.6", + "sudachidict_core": "sudachidict_core>=20220729", + "tensorboard": "tensorboard", + "tensorflow-cpu": "tensorflow-cpu>2.9,<2.16", + "tensorflow": "tensorflow>2.9,<2.16", + "tensorflow-text": "tensorflow-text<2.16", + "tensorflow-probability": "tensorflow-probability<0.24", + "tf2onnx": "tf2onnx", + "timeout-decorator": "timeout-decorator", + "tiktoken": "tiktoken", + "timm": "timm<=1.0.11", + "tokenizers": "tokenizers>=0.21,<0.22", + "torch": "torch>=2.0", + "torchaudio": "torchaudio", + "torchvision": "torchvision", + "pyctcdecode": "pyctcdecode>=0.4.0", + "tqdm": "tqdm>=4.27", + "unidic": "unidic>=1.0.2", + "unidic_lite": "unidic_lite>=1.0.7", + "urllib3": "urllib3<2.0.0", + "uvicorn": "uvicorn", + "pytest-rich": "pytest-rich", + "libcst": "libcst", + "rich": "rich", +} diff --git a/dynamic_module_utils.py b/dynamic_module_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bf44d4b427cf7b7bf76e1c550fa08b3dbc56b673 --- /dev/null +++ b/dynamic_module_utils.py @@ -0,0 +1,685 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities to dynamically load objects from the Hub.""" + +import filecmp +import hashlib +import importlib +import importlib.util +import os +import re +import shutil +import signal +import sys +import threading +import typing +import warnings +from pathlib import Path +from types import ModuleType +from typing import Any, Dict, List, Optional, Union + +from huggingface_hub import try_to_load_from_cache + +from .utils import ( + HF_MODULES_CACHE, + TRANSFORMERS_DYNAMIC_MODULE_NAME, + cached_file, + extract_commit_hash, + is_offline_mode, + logging, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name +_HF_REMOTE_CODE_LOCK = threading.Lock() + + +def init_hf_modules(): + """ + Creates the cache directory for modules with an init, and adds it to the Python path. + """ + # This function has already been executed if HF_MODULES_CACHE already is in the Python path. + if HF_MODULES_CACHE in sys.path: + return + + sys.path.append(HF_MODULES_CACHE) + os.makedirs(HF_MODULES_CACHE, exist_ok=True) + init_path = Path(HF_MODULES_CACHE) / "__init__.py" + if not init_path.exists(): + init_path.touch() + importlib.invalidate_caches() + + +def create_dynamic_module(name: Union[str, os.PathLike]) -> None: + """ + Creates a dynamic module in the cache directory for modules. + + Args: + name (`str` or `os.PathLike`): + The name of the dynamic module to create. + """ + init_hf_modules() + dynamic_module_path = (Path(HF_MODULES_CACHE) / name).resolve() + # If the parent module does not exist yet, recursively create it. + if not dynamic_module_path.parent.exists(): + create_dynamic_module(dynamic_module_path.parent) + os.makedirs(dynamic_module_path, exist_ok=True) + init_path = dynamic_module_path / "__init__.py" + if not init_path.exists(): + init_path.touch() + # It is extremely important to invalidate the cache when we change stuff in those modules, or users end up + # with errors about module that do not exist. Same for all other `invalidate_caches` in this file. + importlib.invalidate_caches() + + +def get_relative_imports(module_file: Union[str, os.PathLike]) -> List[str]: + """ + Get the list of modules that are relatively imported in a module file. + + Args: + module_file (`str` or `os.PathLike`): The module file to inspect. + + Returns: + `List[str]`: The list of relative imports in the module. + """ + with open(module_file, "r", encoding="utf-8") as f: + content = f.read() + + # Imports of the form `import .xxx` + relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE) + # Imports of the form `from .xxx import yyy` + relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE) + # Unique-ify + return list(set(relative_imports)) + + +def get_relative_import_files(module_file: Union[str, os.PathLike]) -> List[str]: + """ + Get the list of all files that are needed for a given module. Note that this function recurses through the relative + imports (if a imports b and b imports c, it will return module files for b and c). + + Args: + module_file (`str` or `os.PathLike`): The module file to inspect. + + Returns: + `List[str]`: The list of all relative imports a given module needs (recursively), which will give us the list + of module files a given module needs. + """ + no_change = False + files_to_check = [module_file] + all_relative_imports = [] + + # Let's recurse through all relative imports + while not no_change: + new_imports = [] + for f in files_to_check: + new_imports.extend(get_relative_imports(f)) + + module_path = Path(module_file).parent + new_import_files = [str(module_path / m) for m in new_imports] + new_import_files = [f for f in new_import_files if f not in all_relative_imports] + files_to_check = [f"{f}.py" for f in new_import_files] + + no_change = len(new_import_files) == 0 + all_relative_imports.extend(files_to_check) + + return all_relative_imports + + +def get_imports(filename: Union[str, os.PathLike]) -> List[str]: + """ + Extracts all the libraries (not relative imports this time) that are imported in a file. + + Args: + filename (`str` or `os.PathLike`): The module file to inspect. + + Returns: + `List[str]`: The list of all packages required to use the input module. + """ + with open(filename, "r", encoding="utf-8") as f: + content = f.read() + + # filter out try/except block so in custom code we can have try/except imports + content = re.sub(r"\s*try\s*:.*?except.*?:", "", content, flags=re.DOTALL) + + # filter out imports under is_flash_attn_2_available block for avoid import issues in cpu only environment + content = re.sub( + r"if is_flash_attn[a-zA-Z0-9_]+available\(\):\s*(from flash_attn\s*.*\s*)+", "", content, flags=re.MULTILINE + ) + + # Imports of the form `import xxx` + imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE) + # Imports of the form `from xxx import yyy` + imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE) + # Only keep the top-level module + imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")] + return list(set(imports)) + + +def check_imports(filename: Union[str, os.PathLike]) -> List[str]: + """ + Check if the current Python environment contains all the libraries that are imported in a file. Will raise if a + library is missing. + + Args: + filename (`str` or `os.PathLike`): The module file to check. + + Returns: + `List[str]`: The list of relative imports in the file. + """ + imports = get_imports(filename) + missing_packages = [] + for imp in imports: + try: + importlib.import_module(imp) + except ImportError as exception: + logger.warning(f"Encountered exception while importing {imp}: {exception}") + # Some packages can fail with an ImportError because of a dependency issue. + # This check avoids hiding such errors. + # See https://github.com/huggingface/transformers/issues/33604 + if "No module named" in str(exception): + missing_packages.append(imp) + else: + raise + + if len(missing_packages) > 0: + raise ImportError( + "This modeling file requires the following packages that were not found in your environment: " + f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`" + ) + + return get_relative_imports(filename) + + +def get_class_in_module( + class_name: str, + module_path: Union[str, os.PathLike], + *, + force_reload: bool = False, +) -> typing.Type: + """ + Import a module on the cache directory for modules and extract a class from it. + + Args: + class_name (`str`): The name of the class to import. + module_path (`str` or `os.PathLike`): The path to the module to import. + force_reload (`bool`, *optional*, defaults to `False`): + Whether to reload the dynamic module from file if it already exists in `sys.modules`. + Otherwise, the module is only reloaded if the file has changed. + + Returns: + `typing.Type`: The class looked for. + """ + name = os.path.normpath(module_path) + if name.endswith(".py"): + name = name[:-3] + name = name.replace(os.path.sep, ".") + module_file: Path = Path(HF_MODULES_CACHE) / module_path + with _HF_REMOTE_CODE_LOCK: + if force_reload: + sys.modules.pop(name, None) + importlib.invalidate_caches() + cached_module: Optional[ModuleType] = sys.modules.get(name) + module_spec = importlib.util.spec_from_file_location(name, location=module_file) + + # Hash the module file and all its relative imports to check if we need to reload it + module_files: List[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file))) + module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest() + + module: ModuleType + if cached_module is None: + module = importlib.util.module_from_spec(module_spec) + # insert it into sys.modules before any loading begins + sys.modules[name] = module + else: + module = cached_module + # reload in both cases, unless the module is already imported and the hash hits + if getattr(module, "__transformers_module_hash__", "") != module_hash: + module_spec.loader.exec_module(module) + module.__transformers_module_hash__ = module_hash + return getattr(module, class_name) + + +def get_cached_module_file( + pretrained_model_name_or_path: Union[str, os.PathLike], + module_file: str, + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + resume_download: Optional[bool] = None, + proxies: Optional[Dict[str, str]] = None, + token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, + repo_type: Optional[str] = None, + _commit_hash: Optional[str] = None, + **deprecated_kwargs, +) -> str: + """ + Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached + Transformers module. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. + - a path to a *directory* containing a configuration file saved using the + [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. + + module_file (`str`): + The name of the module file containing the class to look for. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the standard + cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if they + exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, will only try to load the tokenizer configuration from local files. + repo_type (`str`, *optional*): + Specify the repo type (useful when downloading from a space for instance). + + + + Passing `token=True` is required when you want to use a private model. + + + + Returns: + `str`: The path to the module inside the cache. + """ + use_auth_token = deprecated_kwargs.pop("use_auth_token", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") + token = use_auth_token + + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file. + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + is_local = os.path.isdir(pretrained_model_name_or_path) + if is_local: + submodule = os.path.basename(pretrained_model_name_or_path) + else: + submodule = pretrained_model_name_or_path.replace("/", os.path.sep) + cached_module = try_to_load_from_cache( + pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type + ) + + new_files = [] + try: + # Load from URL or cache if already cached + resolved_module_file = cached_file( + pretrained_model_name_or_path, + module_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + revision=revision, + repo_type=repo_type, + _commit_hash=_commit_hash, + ) + if not is_local and cached_module != resolved_module_file: + new_files.append(module_file) + + except EnvironmentError: + logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") + raise + + # Check we have all the requirements in our environment + modules_needed = check_imports(resolved_module_file) + + # Now we move the module inside our cached dynamic modules. + full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule + create_dynamic_module(full_submodule) + submodule_path = Path(HF_MODULES_CACHE) / full_submodule + if submodule == os.path.basename(pretrained_model_name_or_path): + # We copy local files to avoid putting too many folders in sys.path. This copy is done when the file is new or + # has changed since last copy. + if not (submodule_path / module_file).exists() or not filecmp.cmp( + resolved_module_file, str(submodule_path / module_file) + ): + shutil.copy(resolved_module_file, submodule_path / module_file) + importlib.invalidate_caches() + for module_needed in modules_needed: + module_needed = f"{module_needed}.py" + module_needed_file = os.path.join(pretrained_model_name_or_path, module_needed) + if not (submodule_path / module_needed).exists() or not filecmp.cmp( + module_needed_file, str(submodule_path / module_needed) + ): + shutil.copy(module_needed_file, submodule_path / module_needed) + importlib.invalidate_caches() + else: + # Get the commit hash + commit_hash = extract_commit_hash(resolved_module_file, _commit_hash) + + # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the + # benefit of versioning. + submodule_path = submodule_path / commit_hash + full_submodule = full_submodule + os.path.sep + commit_hash + create_dynamic_module(full_submodule) + + if not (submodule_path / module_file).exists(): + shutil.copy(resolved_module_file, submodule_path / module_file) + importlib.invalidate_caches() + # Make sure we also have every file with relative + for module_needed in modules_needed: + if not (submodule_path / f"{module_needed}.py").exists(): + get_cached_module_file( + pretrained_model_name_or_path, + f"{module_needed}.py", + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + revision=revision, + local_files_only=local_files_only, + _commit_hash=commit_hash, + ) + new_files.append(f"{module_needed}.py") + + if len(new_files) > 0 and revision is None: + new_files = "\n".join([f"- {f}" for f in new_files]) + repo_type_str = "" if repo_type is None else f"{repo_type}s/" + url = f"https://huggingface.co/{repo_type_str}{pretrained_model_name_or_path}" + logger.warning( + f"A new version of the following files was downloaded from {url}:\n{new_files}" + "\n. Make sure to double-check they do not contain any added malicious code. To avoid downloading new " + "versions of the code file, you can pin a revision." + ) + + return os.path.join(full_submodule, module_file) + + +def get_class_from_dynamic_module( + class_reference: str, + pretrained_model_name_or_path: Union[str, os.PathLike], + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + resume_download: Optional[bool] = None, + proxies: Optional[Dict[str, str]] = None, + token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, + repo_type: Optional[str] = None, + code_revision: Optional[str] = None, + **kwargs, +) -> typing.Type: + """ + Extracts a class from a module file, present in the local folder or repository of a model. + + + + Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should + therefore only be called on trusted repos. + + + + + + Args: + class_reference (`str`): + The full name of the class to load, including its module and optionally its repo. + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. + - a path to a *directory* containing a configuration file saved using the + [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. + + This is used when `class_reference` does not specify another repo. + module_file (`str`): + The name of the module file containing the class to look for. + class_name (`str`): + The name of the class to import in the module. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the standard + cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if they + exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, will only try to load the tokenizer configuration from local files. + repo_type (`str`, *optional*): + Specify the repo type (useful when downloading from a space for instance). + code_revision (`str`, *optional*, defaults to `"main"`): + The specific revision to use for the code on the Hub, if the code leaves in a different repository than the + rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for + storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. + + + + Passing `token=True` is required when you want to use a private model. + + + + Returns: + `typing.Type`: The class, dynamically imported from the module. + + Examples: + + ```python + # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this + # module. + cls = get_class_from_dynamic_module("modeling.MyBertModel", "sgugger/my-bert-model") + + # Download module `modeling.py` from a given repo and cache then extract the class `MyBertModel` from this + # module. + cls = get_class_from_dynamic_module("sgugger/my-bert-model--modeling.MyBertModel", "sgugger/another-bert-model") + ```""" + use_auth_token = kwargs.pop("use_auth_token", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") + token = use_auth_token + + # Catch the name of the repo if it's specified in `class_reference` + if "--" in class_reference: + repo_id, class_reference = class_reference.split("--") + else: + repo_id = pretrained_model_name_or_path + module_file, class_name = class_reference.split(".") + + if code_revision is None and pretrained_model_name_or_path == repo_id: + code_revision = revision + # And lastly we get the class inside our newly created module + final_module = get_cached_module_file( + repo_id, + module_file + ".py", + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + revision=code_revision, + local_files_only=local_files_only, + repo_type=repo_type, + ) + return get_class_in_module(class_name, final_module, force_reload=force_download) + + +def custom_object_save(obj: Any, folder: Union[str, os.PathLike], config: Optional[Dict] = None) -> List[str]: + """ + Save the modeling files corresponding to a custom model/configuration/tokenizer etc. in a given folder. Optionally + adds the proper fields in a config. + + Args: + obj (`Any`): The object for which to save the module files. + folder (`str` or `os.PathLike`): The folder where to save. + config (`PretrainedConfig` or dictionary, `optional`): + A config in which to register the auto_map corresponding to this custom object. + + Returns: + `List[str]`: The list of files saved. + """ + if obj.__module__ == "__main__": + logger.warning( + f"We can't save the code defining {obj} in {folder} as it's been defined in __main__. You should put " + "this code in a separate module so we can include it in the saved folder and make it easier to share via " + "the Hub." + ) + return + + def _set_auto_map_in_config(_config): + module_name = obj.__class__.__module__ + last_module = module_name.split(".")[-1] + full_name = f"{last_module}.{obj.__class__.__name__}" + # Special handling for tokenizers + if "Tokenizer" in full_name: + slow_tokenizer_class = None + fast_tokenizer_class = None + if obj.__class__.__name__.endswith("Fast"): + # Fast tokenizer: we have the fast tokenizer class and we may have the slow one has an attribute. + fast_tokenizer_class = f"{last_module}.{obj.__class__.__name__}" + if getattr(obj, "slow_tokenizer_class", None) is not None: + slow_tokenizer = getattr(obj, "slow_tokenizer_class") + slow_tok_module_name = slow_tokenizer.__module__ + last_slow_tok_module = slow_tok_module_name.split(".")[-1] + slow_tokenizer_class = f"{last_slow_tok_module}.{slow_tokenizer.__name__}" + else: + # Slow tokenizer: no way to have the fast class + slow_tokenizer_class = f"{last_module}.{obj.__class__.__name__}" + + full_name = (slow_tokenizer_class, fast_tokenizer_class) + + if isinstance(_config, dict): + auto_map = _config.get("auto_map", {}) + auto_map[obj._auto_class] = full_name + _config["auto_map"] = auto_map + elif getattr(_config, "auto_map", None) is not None: + _config.auto_map[obj._auto_class] = full_name + else: + _config.auto_map = {obj._auto_class: full_name} + + # Add object class to the config auto_map + if isinstance(config, (list, tuple)): + for cfg in config: + _set_auto_map_in_config(cfg) + elif config is not None: + _set_auto_map_in_config(config) + + result = [] + # Copy module file to the output folder. + object_file = sys.modules[obj.__module__].__file__ + dest_file = Path(folder) / (Path(object_file).name) + shutil.copy(object_file, dest_file) + result.append(dest_file) + + # Gather all relative imports recursively and make sure they are copied as well. + for needed_file in get_relative_import_files(object_file): + dest_file = Path(folder) / (Path(needed_file).name) + shutil.copy(needed_file, dest_file) + result.append(dest_file) + + return result + + +def _raise_timeout_error(signum, frame): + raise ValueError( + "Loading this model requires you to execute custom code contained in the model repository on your local " + "machine. Please set the option `trust_remote_code=True` to permit loading of this model." + ) + + +TIME_OUT_REMOTE_CODE = 15 + + +def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code): + if trust_remote_code is None: + if has_local_code: + trust_remote_code = False + elif has_remote_code and TIME_OUT_REMOTE_CODE > 0: + prev_sig_handler = None + try: + prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error) + signal.alarm(TIME_OUT_REMOTE_CODE) + while trust_remote_code is None: + answer = input( + f"The repository for {model_name} contains custom code which must be executed to correctly " + f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n" + f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n" + f"Do you wish to run the custom code? [y/N] " + ) + if answer.lower() in ["yes", "y", "1"]: + trust_remote_code = True + elif answer.lower() in ["no", "n", "0", ""]: + trust_remote_code = False + signal.alarm(0) + except Exception: + # OS which does not support signal.SIGALRM + raise ValueError( + f"The repository for {model_name} contains custom code which must be executed to correctly " + f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n" + f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." + ) + finally: + if prev_sig_handler is not None: + signal.signal(signal.SIGALRM, prev_sig_handler) + signal.alarm(0) + elif has_remote_code: + # For the CI which puts the timeout at 0 + _raise_timeout_error(None, None) + + if has_remote_code and not has_local_code and not trust_remote_code: + raise ValueError( + f"Loading {model_name} requires you to execute the configuration file in that" + " repo on your local machine. Make sure you have read the code there to avoid malicious use, then" + " set the option `trust_remote_code=True` to remove this error." + ) + + return trust_remote_code diff --git a/feature_extraction_sequence_utils.py b/feature_extraction_sequence_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f74a3f0c40e28415644b2b2b4b81ad7ed9320a56 --- /dev/null +++ b/feature_extraction_sequence_utils.py @@ -0,0 +1,372 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Sequence feature extraction class for common feature extractors to preprocess sequences. +""" + +from typing import Dict, List, Optional, Union + +import numpy as np + +from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin +from .utils import PaddingStrategy, TensorType, is_tf_tensor, is_torch_tensor, logging, to_numpy + + +logger = logging.get_logger(__name__) + + +class SequenceFeatureExtractor(FeatureExtractionMixin): + """ + This is a general feature extraction class for speech recognition. + + Args: + feature_size (`int`): + The feature dimension of the extracted features. + sampling_rate (`int`): + The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). + padding_value (`float`): + The value that is used to fill the padding values / vectors. + """ + + def __init__(self, feature_size: int, sampling_rate: int, padding_value: float, **kwargs): + self.feature_size = feature_size + self.sampling_rate = sampling_rate + self.padding_value = padding_value + + self.padding_side = kwargs.pop("padding_side", "right") + self.return_attention_mask = kwargs.pop("return_attention_mask", True) + + super().__init__(**kwargs) + + def pad( + self, + processed_features: Union[ + BatchFeature, + List[BatchFeature], + Dict[str, BatchFeature], + Dict[str, List[BatchFeature]], + List[Dict[str, BatchFeature]], + ], + padding: Union[bool, str, PaddingStrategy] = True, + max_length: Optional[int] = None, + truncation: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + ) -> BatchFeature: + """ + Pad input values / input vectors or a batch of input values / input vectors up to predefined length or to the + max sequence length in the batch. + + Padding side (left/right) padding values are defined at the feature extractor level (with `self.padding_side`, + `self.padding_value`) + + + + If the `processed_features` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the + result will use the same type unless you provide a different tensor type with `return_tensors`. In the case of + PyTorch tensors, you will lose the specific device of your tensors however. + + + + Args: + processed_features ([`BatchFeature`], list of [`BatchFeature`], `Dict[str, List[float]]`, `Dict[str, List[List[float]]` or `List[Dict[str, List[float]]]`): + Processed inputs. Can represent one input ([`BatchFeature`] or `Dict[str, List[float]]`) or a batch of + input values / vectors (list of [`BatchFeature`], *Dict[str, List[List[float]]]* or *List[Dict[str, + List[float]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader + collate function. + + Instead of `List[float]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors), + see the note above for the return type. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. + return_attention_mask (`bool`, *optional*): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific feature_extractor's default. + + [What are attention masks?](../glossary#attention-mask) + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + """ + # If we have a list of dicts, let's convert it in a dict of lists + # We do this to allow using this method as a collate_fn function in PyTorch Dataloader + if isinstance(processed_features, (list, tuple)) and isinstance(processed_features[0], (dict, BatchFeature)): + processed_features = { + key: [example[key] for example in processed_features] for key in processed_features[0].keys() + } + + # The model's main input name, usually `input_values`, has be passed for padding + if self.model_input_names[0] not in processed_features: + raise ValueError( + "You should supply an instance of `transformers.BatchFeature` or list of `transformers.BatchFeature`" + f" to this method that includes {self.model_input_names[0]}, but you provided" + f" {list(processed_features.keys())}" + ) + + required_input = processed_features[self.model_input_names[0]] + return_attention_mask = ( + return_attention_mask if return_attention_mask is not None else self.return_attention_mask + ) + + if len(required_input) == 0: + if return_attention_mask: + processed_features["attention_mask"] = [] + return processed_features + + # If we have PyTorch/TF tensors or lists as inputs, we cast them as Numpy arrays + # and rebuild them afterwards if no return_tensors is specified + # Note that we lose the specific device the tensor may be on for PyTorch + + first_element = required_input[0] + if isinstance(first_element, (list, tuple)): + # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element. + index = 0 + while len(required_input[index]) == 0: + index += 1 + if index < len(required_input): + first_element = required_input[index][0] + + if return_tensors is None: + if is_tf_tensor(first_element): + return_tensors = "tf" + elif is_torch_tensor(first_element): + return_tensors = "pt" + elif isinstance(first_element, (int, float, list, tuple, np.ndarray)): + return_tensors = "np" + else: + raise ValueError( + f"type of {first_element} unknown: {type(first_element)}. " + "Should be one of a python, numpy, pytorch or tensorflow object." + ) + + for key, value in processed_features.items(): + if isinstance(value[0], (int, float)): + processed_features[key] = to_numpy(value) + else: + processed_features[key] = [to_numpy(v) for v in value] + + # Convert padding_strategy in PaddingStrategy + padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length) + + required_input = processed_features[self.model_input_names[0]] + + batch_size = len(required_input) + if not all(len(v) == batch_size for v in processed_features.values()): + raise ValueError("Some items in the output dictionary have a different batch size than others.") + + truncated_inputs = [] + for i in range(batch_size): + inputs = {k: v[i] for k, v in processed_features.items()} + # truncation + inputs_slice = self._truncate( + inputs, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + truncation=truncation, + ) + truncated_inputs.append(inputs_slice) + + if padding_strategy == PaddingStrategy.LONGEST: + # make sure that `max_length` cannot be longer than the longest truncated length + max_length = max(len(input_slice[self.model_input_names[0]]) for input_slice in truncated_inputs) + padding_strategy = PaddingStrategy.MAX_LENGTH + + batch_outputs = {} + for i in range(batch_size): + # padding + outputs = self._pad( + truncated_inputs[i], + max_length=max_length, + padding_strategy=padding_strategy, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + if value.dtype is np.dtype(np.float64): + value = value.astype(np.float32) + batch_outputs[key].append(value) + + return BatchFeature(batch_outputs, tensor_type=return_tensors) + + def _pad( + self, + processed_features: Union[Dict[str, np.ndarray], BatchFeature], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad inputs (on left/right and up to predefined length or max length in the batch) + + Args: + processed_features (`Union[Dict[str, np.ndarray], BatchFeature]`): + Dictionary of input values (`np.ndarray[float]`) / input vectors (`List[np.ndarray[float]]`) or batch + of inputs values (`List[np.ndarray[int]]`) / input vectors (`List[np.ndarray[int]]`) + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see below) + padding_strategy (`PaddingStrategy`, *optional*, default to `PaddingStrategy.DO_NOT_PAD`): + PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The feature_extractor padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of (`int`, *optional*): + Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to + enable the use of Tensor Core on NVIDIA hardware with compute capability `>= 7.5` (Volta), or on TPUs + which benefit from having sequence lengths be a multiple of 128. + return_attention_mask (`bool`, *optional*): + Set to False to avoid returning attention mask (default: set to model specifics) + """ + required_input = processed_features[self.model_input_names[0]] + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) < max_length + + if return_attention_mask and "attention_mask" not in processed_features: + processed_features["attention_mask"] = np.ones(len(required_input), dtype=np.int32) + + if needs_to_be_padded: + difference = max_length - len(required_input) + if self.padding_side == "right": + if return_attention_mask: + processed_features["attention_mask"] = np.pad( + processed_features["attention_mask"], (0, difference) + ) + padding_shape = ((0, difference), (0, 0)) if self.feature_size > 1 else (0, difference) + processed_features[self.model_input_names[0]] = np.pad( + required_input, padding_shape, "constant", constant_values=self.padding_value + ) + elif self.padding_side == "left": + if return_attention_mask: + processed_features["attention_mask"] = np.pad( + processed_features["attention_mask"], (difference, 0) + ) + padding_shape = ((difference, 0), (0, 0)) if self.feature_size > 1 else (difference, 0) + processed_features[self.model_input_names[0]] = np.pad( + required_input, padding_shape, "constant", constant_values=self.padding_value + ) + else: + raise ValueError("Invalid padding strategy:" + str(self.padding_side)) + + return processed_features + + def _truncate( + self, + processed_features: Union[Dict[str, np.ndarray], BatchFeature], + max_length: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + truncation: Optional[bool] = None, + ): + """ + Truncate inputs to predefined length or max length in the batch + + Args: + processed_features(`Union[Dict[str, np.ndarray], BatchFeature]`): + Dictionary of input values (`np.ndarray[float]`) / input vectors (`List[np.ndarray[float]]`) or batch + of inputs values (`List[np.ndarray[int]]`) / input vectors (`List[np.ndarray[int]]`) + max_length (`int`, *optional*): + maximum length of the returned list and optionally padding length (see below) + pad_to_multiple_of (`int`, *optional*) : + Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to + enable the use of Tensor Core on NVIDIA hardware with compute capability `>= 7.5` (Volta), or on TPUs + which benefit from having sequence lengths be a multiple of 128. + truncation (`bool`, *optional*): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + """ + if not truncation: + return processed_features + elif truncation and max_length is None: + raise ValueError("When setting ``truncation=True``, make sure that ``max_length`` is defined.") + + required_input = processed_features[self.model_input_names[0]] + + # find `max_length` that fits `pad_to_multiple_of` + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_truncated = len(required_input) > max_length + + if needs_to_be_truncated: + processed_features[self.model_input_names[0]] = processed_features[self.model_input_names[0]][:max_length] + if "attention_mask" in processed_features: + processed_features["attention_mask"] = processed_features["attention_mask"][:max_length] + + return processed_features + + def _get_padding_strategies(self, padding=False, max_length=None): + """ + Find the correct padding strategy + """ + + # Get padding strategy + if padding is not False: + if padding is True: + padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch + elif not isinstance(padding, PaddingStrategy): + padding_strategy = PaddingStrategy(padding) + elif isinstance(padding, PaddingStrategy): + padding_strategy = padding + else: + padding_strategy = PaddingStrategy.DO_NOT_PAD + + # Set max length if needed + if max_length is None: + if padding_strategy == PaddingStrategy.MAX_LENGTH: + raise ValueError( + f"When setting ``padding={PaddingStrategy.MAX_LENGTH}``, make sure that max_length is defined" + ) + + # Test if we have a padding value + if padding_strategy != PaddingStrategy.DO_NOT_PAD and (self.padding_value is None): + raise ValueError( + "Asking to pad but the feature_extractor does not have a padding value. Please select a value to use" + " as `padding_value`. For example: `feature_extractor.padding_value = 0.0`." + ) + + return padding_strategy diff --git a/feature_extraction_utils.py b/feature_extraction_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6e8007edbc0b782970887e83a08ba518de0f264a --- /dev/null +++ b/feature_extraction_utils.py @@ -0,0 +1,702 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Feature extraction saving/loading class for common feature extractors. +""" + +import copy +import json +import os +import warnings +from collections import UserDict +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union + +import numpy as np + +from .dynamic_module_utils import custom_object_save +from .utils import ( + FEATURE_EXTRACTOR_NAME, + PushToHubMixin, + TensorType, + add_model_info_to_auto_map, + add_model_info_to_custom_pipelines, + cached_file, + copy_func, + download_url, + is_flax_available, + is_jax_tensor, + is_numpy_array, + is_offline_mode, + is_remote_url, + is_tf_available, + is_torch_available, + is_torch_device, + is_torch_dtype, + logging, + requires_backends, +) + + +if TYPE_CHECKING: + if is_torch_available(): + import torch # noqa + + +logger = logging.get_logger(__name__) + +PreTrainedFeatureExtractor = Union["SequenceFeatureExtractor"] # noqa: F821 + + +class BatchFeature(UserDict): + r""" + Holds the output of the [`~SequenceFeatureExtractor.pad`] and feature extractor specific `__call__` methods. + + This class is derived from a python dictionary and can be used as a dictionary. + + Args: + data (`dict`, *optional*): + Dictionary of lists/arrays/tensors returned by the __call__/pad methods ('input_values', 'attention_mask', + etc.). + tensor_type (`Union[None, str, TensorType]`, *optional*): + You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at + initialization. + """ + + def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None): + super().__init__(data) + self.convert_to_tensors(tensor_type=tensor_type) + + def __getitem__(self, item: str) -> Union[Any]: + """ + If the key is a string, returns the value of the dict associated to `key` ('input_values', 'attention_mask', + etc.). + """ + if isinstance(item, str): + return self.data[item] + else: + raise KeyError("Indexing with integers is not available when using Python based feature extractors") + + def __getattr__(self, item: str): + try: + return self.data[item] + except KeyError: + raise AttributeError + + def __getstate__(self): + return {"data": self.data} + + def __setstate__(self, state): + if "data" in state: + self.data = state["data"] + + # Copied from transformers.tokenization_utils_base.BatchEncoding.keys + def keys(self): + return self.data.keys() + + # Copied from transformers.tokenization_utils_base.BatchEncoding.values + def values(self): + return self.data.values() + + # Copied from transformers.tokenization_utils_base.BatchEncoding.items + def items(self): + return self.data.items() + + def _get_is_as_tensor_fns(self, tensor_type: Optional[Union[str, TensorType]] = None): + if tensor_type is None: + return None, None + + # Convert to TensorType + if not isinstance(tensor_type, TensorType): + tensor_type = TensorType(tensor_type) + + # Get a function reference for the correct framework + if tensor_type == TensorType.TENSORFLOW: + if not is_tf_available(): + raise ImportError( + "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed." + ) + import tensorflow as tf + + as_tensor = tf.constant + is_tensor = tf.is_tensor + elif tensor_type == TensorType.PYTORCH: + if not is_torch_available(): + raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.") + import torch # noqa + + def as_tensor(value): + if isinstance(value, (list, tuple)) and len(value) > 0: + if isinstance(value[0], np.ndarray): + value = np.array(value) + elif ( + isinstance(value[0], (list, tuple)) + and len(value[0]) > 0 + and isinstance(value[0][0], np.ndarray) + ): + value = np.array(value) + if isinstance(value, np.ndarray): + return torch.from_numpy(value) + else: + return torch.tensor(value) + + is_tensor = torch.is_tensor + elif tensor_type == TensorType.JAX: + if not is_flax_available(): + raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.") + import jax.numpy as jnp # noqa: F811 + + as_tensor = jnp.array + is_tensor = is_jax_tensor + else: + + def as_tensor(value, dtype=None): + if isinstance(value, (list, tuple)) and isinstance(value[0], (list, tuple, np.ndarray)): + value_lens = [len(val) for val in value] + if len(set(value_lens)) > 1 and dtype is None: + # we have a ragged list so handle explicitly + value = as_tensor([np.asarray(val) for val in value], dtype=object) + return np.asarray(value, dtype=dtype) + + is_tensor = is_numpy_array + return is_tensor, as_tensor + + def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None): + """ + Convert the inner content to tensors. + + Args: + tensor_type (`str` or [`~utils.TensorType`], *optional*): + The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If + `None`, no modification is done. + """ + if tensor_type is None: + return self + + is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type) + + # Do the tensor conversion in batch + for key, value in self.items(): + try: + if not is_tensor(value): + tensor = as_tensor(value) + + self[key] = tensor + except: # noqa E722 + if key == "overflowing_values": + raise ValueError("Unable to create tensor returning overflowing values of different lengths. ") + raise ValueError( + "Unable to create tensor, you should probably activate padding " + "with 'padding=True' to have batched tensors with the same length." + ) + + return self + + def to(self, *args, **kwargs) -> "BatchFeature": + """ + Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in + different `dtypes` and sending the `BatchFeature` to a different `device`. + + Args: + args (`Tuple`): + Will be passed to the `to(...)` function of the tensors. + kwargs (`Dict`, *optional*): + Will be passed to the `to(...)` function of the tensors. + To enable asynchronous data transfer, set the `non_blocking` flag in `kwargs` (defaults to `False`). + + Returns: + [`BatchFeature`]: The same instance after modification. + """ + requires_backends(self, ["torch"]) + import torch # noqa + + new_data = {} + device = kwargs.get("device") + non_blocking = kwargs.get("non_blocking", False) + # Check if the args are a device or a dtype + if device is None and len(args) > 0: + # device should be always the first argument + arg = args[0] + if is_torch_dtype(arg): + # The first argument is a dtype + pass + elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int): + device = arg + else: + # it's something else + raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.") + # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor` + for k, v in self.items(): + # check if v is a floating point + if isinstance(v, torch.Tensor) and torch.is_floating_point(v): + # cast and send to device + new_data[k] = v.to(*args, **kwargs) + elif isinstance(v, torch.Tensor) and device is not None: + new_data[k] = v.to(device=device, non_blocking=non_blocking) + else: + new_data[k] = v + self.data = new_data + return self + + +class FeatureExtractionMixin(PushToHubMixin): + """ + This is a feature extraction mixin used to provide saving/loading functionality for sequential and image feature + extractors. + """ + + _auto_class = None + + def __init__(self, **kwargs): + """Set elements of `kwargs` as attributes.""" + # Pop "processor_class" as it should be saved as private attribute + self._processor_class = kwargs.pop("processor_class", None) + # Additional attributes without default values + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error(f"Can't set {key} with value {value} for {self}") + raise err + + def _set_processor_class(self, processor_class: str): + """Sets processor class as an attribute.""" + self._processor_class = processor_class + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + **kwargs, + ): + r""" + Instantiate a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a feature extractor, *e.g.* a + derived class of [`SequenceFeatureExtractor`]. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on + huggingface.co. + - a path to a *directory* containing a feature extractor file saved using the + [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] method, e.g., + `./my_model_directory/`. + - a path or url to a saved feature extractor JSON *file*, e.g., + `./my_model_directory/preprocessor_config.json`. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model feature extractor should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the feature extractor files and override the cached versions + if they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use + the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + + To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`. + + + + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + If `False`, then this function returns just the final feature extractor object. If `True`, then this + functions returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary + consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of + `kwargs` which has not been used to update `feature_extractor` and is otherwise ignored. + kwargs (`Dict[str, Any]`, *optional*): + The values in kwargs of any keys which are feature extractor attributes will be used to override the + loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is + controlled by the `return_unused_kwargs` keyword parameter. + + Returns: + A feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`]. + + Examples: + + ```python + # We can't instantiate directly the base class *FeatureExtractionMixin* nor *SequenceFeatureExtractor* so let's show the examples on a + # derived class: *Wav2Vec2FeatureExtractor* + feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( + "facebook/wav2vec2-base-960h" + ) # Download feature_extraction_config from huggingface.co and cache. + feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( + "./test/saved_model/" + ) # E.g. feature_extractor (or model) was saved using *save_pretrained('./test/saved_model/')* + feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("./test/saved_model/preprocessor_config.json") + feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( + "facebook/wav2vec2-base-960h", return_attention_mask=False, foo=False + ) + assert feature_extractor.return_attention_mask is False + feature_extractor, unused_kwargs = Wav2Vec2FeatureExtractor.from_pretrained( + "facebook/wav2vec2-base-960h", return_attention_mask=False, foo=False, return_unused_kwargs=True + ) + assert feature_extractor.return_attention_mask is False + assert unused_kwargs == {"foo": False} + ```""" + kwargs["cache_dir"] = cache_dir + kwargs["force_download"] = force_download + kwargs["local_files_only"] = local_files_only + kwargs["revision"] = revision + + use_auth_token = kwargs.pop("use_auth_token", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if token is not None: + kwargs["token"] = token + + feature_extractor_dict, kwargs = cls.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs) + + return cls.from_dict(feature_extractor_dict, **kwargs) + + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a feature_extractor object to the directory `save_directory`, so that it can be re-loaded using the + [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the feature extractor JSON file will be saved (will be created if it does not exist). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + use_auth_token = kwargs.pop("use_auth_token", None) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if kwargs.get("token", None) is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + kwargs["token"] = use_auth_token + + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + os.makedirs(save_directory, exist_ok=True) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = self._create_repo(repo_id, **kwargs) + files_timestamps = self._get_files_timestamps(save_directory) + + # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be + # loaded from the Hub. + if self._auto_class is not None: + custom_object_save(self, save_directory, config=self) + + # If we save using the predefined names, we can load using `from_pretrained` + output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME) + + self.to_json_file(output_feature_extractor_file) + logger.info(f"Feature extractor saved in {output_feature_extractor_file}") + + if push_to_hub: + self._upload_modified_files( + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=kwargs.get("token"), + ) + + return [output_feature_extractor_file] + + @classmethod + def get_feature_extractor_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a + feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`] using `from_dict`. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`): + The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. + + Returns: + `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the feature extractor object. + """ + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", None) + proxies = kwargs.pop("proxies", None) + subfolder = kwargs.pop("subfolder", None) + token = kwargs.pop("token", None) + use_auth_token = kwargs.pop("use_auth_token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + from_pipeline = kwargs.pop("_from_pipeline", None) + from_auto_class = kwargs.pop("_from_auto", False) + + user_agent = {"file_type": "feature extractor", "from_auto_class": from_auto_class} + if from_pipeline is not None: + user_agent["using_pipeline"] = from_pipeline + + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + is_local = os.path.isdir(pretrained_model_name_or_path) + if os.path.isdir(pretrained_model_name_or_path): + feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME) + if os.path.isfile(pretrained_model_name_or_path): + resolved_feature_extractor_file = pretrained_model_name_or_path + is_local = True + elif is_remote_url(pretrained_model_name_or_path): + feature_extractor_file = pretrained_model_name_or_path + resolved_feature_extractor_file = download_url(pretrained_model_name_or_path) + else: + feature_extractor_file = FEATURE_EXTRACTOR_NAME + try: + # Load from local folder or from cache or download from model Hub and cache + resolved_feature_extractor_file = cached_file( + pretrained_model_name_or_path, + feature_extractor_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + subfolder=subfolder, + token=token, + user_agent=user_agent, + revision=revision, + ) + except EnvironmentError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to + # the original exception. + raise + except Exception: + # For any other exception, we throw a generic error. + raise EnvironmentError( + f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load" + " it from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" + f" directory containing a {FEATURE_EXTRACTOR_NAME} file" + ) + + try: + # Load feature_extractor dict + with open(resolved_feature_extractor_file, "r", encoding="utf-8") as reader: + text = reader.read() + feature_extractor_dict = json.loads(text) + + except json.JSONDecodeError: + raise EnvironmentError( + f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file." + ) + + if is_local: + logger.info(f"loading configuration file {resolved_feature_extractor_file}") + else: + logger.info( + f"loading configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}" + ) + + if not is_local: + if "auto_map" in feature_extractor_dict: + feature_extractor_dict["auto_map"] = add_model_info_to_auto_map( + feature_extractor_dict["auto_map"], pretrained_model_name_or_path + ) + if "custom_pipelines" in feature_extractor_dict: + feature_extractor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines( + feature_extractor_dict["custom_pipelines"], pretrained_model_name_or_path + ) + + return feature_extractor_dict, kwargs + + @classmethod + def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> PreTrainedFeatureExtractor: + """ + Instantiates a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a Python dictionary of + parameters. + + Args: + feature_extractor_dict (`Dict[str, Any]`): + Dictionary that will be used to instantiate the feature extractor object. Such a dictionary can be + retrieved from a pretrained checkpoint by leveraging the + [`~feature_extraction_utils.FeatureExtractionMixin.to_dict`] method. + kwargs (`Dict[str, Any]`): + Additional parameters from which to initialize the feature extractor object. + + Returns: + [`~feature_extraction_utils.FeatureExtractionMixin`]: The feature extractor object instantiated from those + parameters. + """ + return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) + + # Update feature_extractor with kwargs if needed + to_remove = [] + for key, value in kwargs.items(): + if key in feature_extractor_dict: + feature_extractor_dict[key] = value + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + + feature_extractor = cls(**feature_extractor_dict) + + logger.info(f"Feature extractor {feature_extractor}") + if return_unused_kwargs: + return feature_extractor, kwargs + else: + return feature_extractor + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + output = copy.deepcopy(self.__dict__) + output["feature_extractor_type"] = self.__class__.__name__ + if "mel_filters" in output: + del output["mel_filters"] + if "window" in output: + del output["window"] + return output + + @classmethod + def from_json_file(cls, json_file: Union[str, os.PathLike]) -> PreTrainedFeatureExtractor: + """ + Instantiates a feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`] from the path to + a JSON file of parameters. + + Args: + json_file (`str` or `os.PathLike`): + Path to the JSON file containing the parameters. + + Returns: + A feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`]: The feature_extractor + object instantiated from that JSON file. + """ + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + feature_extractor_dict = json.loads(text) + return cls(**feature_extractor_dict) + + def to_json_string(self) -> str: + """ + Serializes this instance to a JSON string. + + Returns: + `str`: String containing all the attributes that make up this feature_extractor instance in JSON format. + """ + dictionary = self.to_dict() + + for key, value in dictionary.items(): + if isinstance(value, np.ndarray): + dictionary[key] = value.tolist() + + # make sure private name "_processor_class" is correctly + # saved as "processor_class" + _processor_class = dictionary.pop("_processor_class", None) + if _processor_class is not None: + dictionary["processor_class"] = _processor_class + + return json.dumps(dictionary, indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this feature_extractor instance's parameters will be saved. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string()) + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + @classmethod + def register_for_auto_class(cls, auto_class="AutoFeatureExtractor"): + """ + Register this class with a given auto class. This should only be used for custom feature extractors as the ones + in the library are already mapped with `AutoFeatureExtractor`. + + + + This API is experimental and may have some slight breaking changes in the next releases. + + + + Args: + auto_class (`str` or `type`, *optional*, defaults to `"AutoFeatureExtractor"`): + The auto class to register this new feature extractor with. + """ + if not isinstance(auto_class, str): + auto_class = auto_class.__name__ + + import transformers.models.auto as auto_module + + if not hasattr(auto_module, auto_class): + raise ValueError(f"{auto_class} is not a valid auto class.") + + cls._auto_class = auto_class + + +FeatureExtractionMixin.push_to_hub = copy_func(FeatureExtractionMixin.push_to_hub) +if FeatureExtractionMixin.push_to_hub.__doc__ is not None: + FeatureExtractionMixin.push_to_hub.__doc__ = FeatureExtractionMixin.push_to_hub.__doc__.format( + object="feature extractor", object_class="AutoFeatureExtractor", object_files="feature extractor file" + ) diff --git a/file_utils.py b/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c13d506a613795492f73b3b6c9519096f254f16b --- /dev/null +++ b/file_utils.py @@ -0,0 +1,133 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +File utilities: utilities related to download and cache models + +This module should not be update anymore and is only left for backward compatibility. +""" + +from huggingface_hub import get_full_repo_name # for backward compatibility +from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY as DISABLE_TELEMETRY # for backward compatibility + +from . import __version__ + +# Backward compatibility imports, to make sure all those objects can be found in file_utils +from .utils import ( + CLOUDFRONT_DISTRIB_PREFIX, + CONFIG_NAME, + DUMMY_INPUTS, + DUMMY_MASK, + ENV_VARS_TRUE_AND_AUTO_VALUES, + ENV_VARS_TRUE_VALUES, + FEATURE_EXTRACTOR_NAME, + FLAX_WEIGHTS_NAME, + HF_MODULES_CACHE, + HUGGINGFACE_CO_PREFIX, + HUGGINGFACE_CO_RESOLVE_ENDPOINT, + MODEL_CARD_NAME, + MULTIPLE_CHOICE_DUMMY_INPUTS, + PYTORCH_PRETRAINED_BERT_CACHE, + PYTORCH_TRANSFORMERS_CACHE, + S3_BUCKET_PREFIX, + SENTENCEPIECE_UNDERLINE, + SPIECE_UNDERLINE, + TF2_WEIGHTS_NAME, + TF_WEIGHTS_NAME, + TORCH_FX_REQUIRED_VERSION, + TRANSFORMERS_CACHE, + TRANSFORMERS_DYNAMIC_MODULE_NAME, + USE_JAX, + USE_TF, + USE_TORCH, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + ContextManagers, + DummyObject, + EntryNotFoundError, + ExplicitEnum, + ModelOutput, + PaddingStrategy, + PushToHubMixin, + RepositoryNotFoundError, + RevisionNotFoundError, + TensorType, + _LazyModule, + add_code_sample_docstrings, + add_end_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + cached_property, + copy_func, + default_cache_path, + define_sagemaker_information, + get_cached_models, + get_file_from_repo, + get_torch_version, + has_file, + http_user_agent, + is_apex_available, + is_bs4_available, + is_coloredlogs_available, + is_datasets_available, + is_detectron2_available, + is_faiss_available, + is_flax_available, + is_ftfy_available, + is_g2p_en_available, + is_in_notebook, + is_ipex_available, + is_librosa_available, + is_offline_mode, + is_onnx_available, + is_pandas_available, + is_phonemizer_available, + is_protobuf_available, + is_psutil_available, + is_py3nvml_available, + is_pyctcdecode_available, + is_pytesseract_available, + is_pytorch_quantization_available, + is_rjieba_available, + is_sagemaker_dp_enabled, + is_sagemaker_mp_enabled, + is_scipy_available, + is_sentencepiece_available, + is_seqio_available, + is_sklearn_available, + is_soundfile_available, + is_spacy_available, + is_speech_available, + is_tensor, + is_tensorflow_probability_available, + is_tf2onnx_available, + is_tf_available, + is_timm_available, + is_tokenizers_available, + is_torch_available, + is_torch_bf16_available, + is_torch_cuda_available, + is_torch_fx_available, + is_torch_fx_proxy, + is_torch_mps_available, + is_torch_tf32_available, + is_torch_xla_available, + is_torchaudio_available, + is_training_run_on_sagemaker, + is_vision_available, + replace_return_docstrings, + requires_backends, + to_numpy, + to_py_obj, + torch_only_method, +) diff --git a/hf_argparser.py b/hf_argparser.py new file mode 100644 index 0000000000000000000000000000000000000000..d03ff7004f2b6c0f14af55efd1bd4b8336dde305 --- /dev/null +++ b/hf_argparser.py @@ -0,0 +1,437 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +import json +import os +import sys +import types +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError +from copy import copy +from enum import Enum +from inspect import isclass +from pathlib import Path +from typing import Any, Callable, Dict, Iterable, List, Literal, NewType, Optional, Tuple, Union, get_type_hints + +import yaml + + +DataClass = NewType("DataClass", Any) +DataClassType = NewType("DataClassType", Any) + + +# From https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse +def string_to_bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise ArgumentTypeError( + f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)." + ) + + +def make_choice_type_function(choices: list) -> Callable[[str], Any]: + """ + Creates a mapping function from each choices string representation to the actual value. Used to support multiple + value types for a single argument. + + Args: + choices (list): List of choices. + + Returns: + Callable[[str], Any]: Mapping function from string representation to actual value for each choice. + """ + str_to_choice = {str(choice): choice for choice in choices} + return lambda arg: str_to_choice.get(arg, arg) + + +def HfArg( + *, + aliases: Union[str, List[str]] = None, + help: str = None, + default: Any = dataclasses.MISSING, + default_factory: Callable[[], Any] = dataclasses.MISSING, + metadata: dict = None, + **kwargs, +) -> dataclasses.Field: + """Argument helper enabling a concise syntax to create dataclass fields for parsing with `HfArgumentParser`. + + Example comparing the use of `HfArg` and `dataclasses.field`: + ``` + @dataclass + class Args: + regular_arg: str = dataclasses.field(default="Huggingface", metadata={"aliases": ["--example", "-e"], "help": "This syntax could be better!"}) + hf_arg: str = HfArg(default="Huggingface", aliases=["--example", "-e"], help="What a nice syntax!") + ``` + + Args: + aliases (Union[str, List[str]], optional): + Single string or list of strings of aliases to pass on to argparse, e.g. `aliases=["--example", "-e"]`. + Defaults to None. + help (str, optional): Help string to pass on to argparse that can be displayed with --help. Defaults to None. + default (Any, optional): + Default value for the argument. If not default or default_factory is specified, the argument is required. + Defaults to dataclasses.MISSING. + default_factory (Callable[[], Any], optional): + The default_factory is a 0-argument function called to initialize a field's value. It is useful to provide + default values for mutable types, e.g. lists: `default_factory=list`. Mutually exclusive with `default=`. + Defaults to dataclasses.MISSING. + metadata (dict, optional): Further metadata to pass on to `dataclasses.field`. Defaults to None. + + Returns: + Field: A `dataclasses.Field` with the desired properties. + """ + if metadata is None: + # Important, don't use as default param in function signature because dict is mutable and shared across function calls + metadata = {} + if aliases is not None: + metadata["aliases"] = aliases + if help is not None: + metadata["help"] = help + + return dataclasses.field(metadata=metadata, default=default, default_factory=default_factory, **kwargs) + + +class HfArgumentParser(ArgumentParser): + """ + This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments. + + The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed) + arguments to the parser after initialization and you'll get the output back after parsing as an additional + namespace. Optional: To create sub argument groups use the `_argument_group_name` attribute in the dataclass. + """ + + dataclass_types: Iterable[DataClassType] + + def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs): + """ + Args: + dataclass_types: + Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args. + kwargs (`Dict[str, Any]`, *optional*): + Passed to `argparse.ArgumentParser()` in the regular way. + """ + # To make the default appear when using --help + if "formatter_class" not in kwargs: + kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter + super().__init__(**kwargs) + if dataclasses.is_dataclass(dataclass_types): + dataclass_types = [dataclass_types] + self.dataclass_types = list(dataclass_types) + for dtype in self.dataclass_types: + self._add_dataclass_arguments(dtype) + + @staticmethod + def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field): + # Long-option strings are conventionlly separated by hyphens rather + # than underscores, e.g., "--long-format" rather than "--long_format". + # Argparse converts hyphens to underscores so that the destination + # string is a valid attribute name. Hf_argparser should do the same. + long_options = [f"--{field.name}"] + if "_" in field.name: + long_options.append(f"--{field.name.replace('_', '-')}") + + kwargs = field.metadata.copy() + # field.metadata is not used at all by Data Classes, + # it is provided as a third-party extension mechanism. + if isinstance(field.type, str): + raise RuntimeError( + "Unresolved type detected, which should have been done with the help of " + "`typing.get_type_hints` method by default" + ) + + aliases = kwargs.pop("aliases", []) + if isinstance(aliases, str): + aliases = [aliases] + + origin_type = getattr(field.type, "__origin__", field.type) + if origin_type is Union or (hasattr(types, "UnionType") and isinstance(origin_type, types.UnionType)): + if str not in field.type.__args__ and ( + len(field.type.__args__) != 2 or type(None) not in field.type.__args__ + ): + raise ValueError( + "Only `Union[X, NoneType]` (i.e., `Optional[X]`) is allowed for `Union` because" + " the argument parser only supports one type per argument." + f" Problem encountered in field '{field.name}'." + ) + if type(None) not in field.type.__args__: + # filter `str` in Union + field.type = field.type.__args__[0] if field.type.__args__[1] is str else field.type.__args__[1] + origin_type = getattr(field.type, "__origin__", field.type) + elif bool not in field.type.__args__: + # filter `NoneType` in Union (except for `Union[bool, NoneType]`) + field.type = ( + field.type.__args__[0] if isinstance(None, field.type.__args__[1]) else field.type.__args__[1] + ) + origin_type = getattr(field.type, "__origin__", field.type) + + # A variable to store kwargs for a boolean field, if needed + # so that we can init a `no_*` complement argument (see below) + bool_kwargs = {} + if origin_type is Literal or (isinstance(field.type, type) and issubclass(field.type, Enum)): + if origin_type is Literal: + kwargs["choices"] = field.type.__args__ + else: + kwargs["choices"] = [x.value for x in field.type] + + kwargs["type"] = make_choice_type_function(kwargs["choices"]) + + if field.default is not dataclasses.MISSING: + kwargs["default"] = field.default + else: + kwargs["required"] = True + elif field.type is bool or field.type == Optional[bool]: + # Copy the currect kwargs to use to instantiate a `no_*` complement argument below. + # We do not initialize it here because the `no_*` alternative must be instantiated after the real argument + bool_kwargs = copy(kwargs) + + # Hack because type=bool in argparse does not behave as we want. + kwargs["type"] = string_to_bool + if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING): + # Default value is False if we have no default when of type bool. + default = False if field.default is dataclasses.MISSING else field.default + # This is the value that will get picked if we don't include --{field.name} in any way + kwargs["default"] = default + # This tells argparse we accept 0 or 1 value after --{field.name} + kwargs["nargs"] = "?" + # This is the value that will get picked if we do --{field.name} (without value) + kwargs["const"] = True + elif isclass(origin_type) and issubclass(origin_type, list): + kwargs["type"] = field.type.__args__[0] + kwargs["nargs"] = "+" + if field.default_factory is not dataclasses.MISSING: + kwargs["default"] = field.default_factory() + elif field.default is dataclasses.MISSING: + kwargs["required"] = True + else: + kwargs["type"] = field.type + if field.default is not dataclasses.MISSING: + kwargs["default"] = field.default + elif field.default_factory is not dataclasses.MISSING: + kwargs["default"] = field.default_factory() + else: + kwargs["required"] = True + parser.add_argument(*long_options, *aliases, **kwargs) + + # Add a complement `no_*` argument for a boolean field AFTER the initial field has already been added. + # Order is important for arguments with the same destination! + # We use a copy of earlier kwargs because the original kwargs have changed a lot before reaching down + # here and we do not need those changes/additional keys. + if field.default is True and (field.type is bool or field.type == Optional[bool]): + bool_kwargs["default"] = False + parser.add_argument( + f"--no_{field.name}", + f"--no-{field.name.replace('_', '-')}", + action="store_false", + dest=field.name, + **bool_kwargs, + ) + + def _add_dataclass_arguments(self, dtype: DataClassType): + if hasattr(dtype, "_argument_group_name"): + parser = self.add_argument_group(dtype._argument_group_name) + else: + parser = self + + try: + type_hints: Dict[str, type] = get_type_hints(dtype) + except NameError: + raise RuntimeError( + f"Type resolution failed for {dtype}. Try declaring the class in global scope or " + "removing line of `from __future__ import annotations` which opts in Postponed " + "Evaluation of Annotations (PEP 563)" + ) + except TypeError as ex: + # Remove this block when we drop Python 3.9 support + if sys.version_info[:2] < (3, 10) and "unsupported operand type(s) for |" in str(ex): + python_version = ".".join(map(str, sys.version_info[:3])) + raise RuntimeError( + f"Type resolution failed for {dtype} on Python {python_version}. Try removing " + "line of `from __future__ import annotations` which opts in union types as " + "`X | Y` (PEP 604) via Postponed Evaluation of Annotations (PEP 563). To " + "support Python versions that lower than 3.10, you need to use " + "`typing.Union[X, Y]` instead of `X | Y` and `typing.Optional[X]` instead of " + "`X | None`." + ) from ex + raise + + for field in dataclasses.fields(dtype): + if not field.init: + continue + field.type = type_hints[field.name] + self._parse_dataclass_field(parser, field) + + def parse_args_into_dataclasses( + self, + args=None, + return_remaining_strings=False, + look_for_args_file=True, + args_filename=None, + args_file_flag=None, + ) -> Tuple[DataClass, ...]: + """ + Parse command-line args into instances of the specified dataclass types. + + This relies on argparse's `ArgumentParser.parse_known_args`. See the doc at: + docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args + + Args: + args: + List of strings to parse. The default is taken from sys.argv. (same as argparse.ArgumentParser) + return_remaining_strings: + If true, also return a list of remaining argument strings. + look_for_args_file: + If true, will look for a ".args" file with the same base name as the entry point script for this + process, and will append its potential content to the command line args. + args_filename: + If not None, will uses this file instead of the ".args" file specified in the previous argument. + args_file_flag: + If not None, will look for a file in the command-line args specified with this flag. The flag can be + specified multiple times and precedence is determined by the order (last one wins). + + Returns: + Tuple consisting of: + + - the dataclass instances in the same order as they were passed to the initializer.abspath + - if applicable, an additional namespace for more (non-dataclass backed) arguments added to the parser + after initialization. + - The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args) + """ + + if args_file_flag or args_filename or (look_for_args_file and len(sys.argv)): + args_files = [] + + if args_filename: + args_files.append(Path(args_filename)) + elif look_for_args_file and len(sys.argv): + args_files.append(Path(sys.argv[0]).with_suffix(".args")) + + # args files specified via command line flag should overwrite default args files so we add them last + if args_file_flag: + # Create special parser just to extract the args_file_flag values + args_file_parser = ArgumentParser() + args_file_parser.add_argument(args_file_flag, type=str, action="append") + + # Use only remaining args for further parsing (remove the args_file_flag) + cfg, args = args_file_parser.parse_known_args(args=args) + cmd_args_file_paths = vars(cfg).get(args_file_flag.lstrip("-"), None) + + if cmd_args_file_paths: + args_files.extend([Path(p) for p in cmd_args_file_paths]) + + file_args = [] + for args_file in args_files: + if args_file.exists(): + file_args += args_file.read_text().split() + + # in case of duplicate arguments the last one has precedence + # args specified via the command line should overwrite args from files, so we add them last + args = file_args + args if args is not None else file_args + sys.argv[1:] + namespace, remaining_args = self.parse_known_args(args=args) + outputs = [] + for dtype in self.dataclass_types: + keys = {f.name for f in dataclasses.fields(dtype) if f.init} + inputs = {k: v for k, v in vars(namespace).items() if k in keys} + for k in keys: + delattr(namespace, k) + obj = dtype(**inputs) + outputs.append(obj) + if len(namespace.__dict__) > 0: + # additional namespace. + outputs.append(namespace) + if return_remaining_strings: + return (*outputs, remaining_args) + else: + if remaining_args: + raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}") + + return (*outputs,) + + def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]: + """ + Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass + types. + + Args: + args (`dict`): + dict containing config values + allow_extra_keys (`bool`, *optional*, defaults to `False`): + Defaults to False. If False, will raise an exception if the dict contains keys that are not parsed. + + Returns: + Tuple consisting of: + + - the dataclass instances in the same order as they were passed to the initializer. + """ + unused_keys = set(args.keys()) + outputs = [] + for dtype in self.dataclass_types: + keys = {f.name for f in dataclasses.fields(dtype) if f.init} + inputs = {k: v for k, v in args.items() if k in keys} + unused_keys.difference_update(inputs.keys()) + obj = dtype(**inputs) + outputs.append(obj) + if not allow_extra_keys and unused_keys: + raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}") + return tuple(outputs) + + def parse_json_file( + self, json_file: Union[str, os.PathLike], allow_extra_keys: bool = False + ) -> Tuple[DataClass, ...]: + """ + Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the + dataclass types. + + Args: + json_file (`str` or `os.PathLike`): + File name of the json file to parse + allow_extra_keys (`bool`, *optional*, defaults to `False`): + Defaults to False. If False, will raise an exception if the json file contains keys that are not + parsed. + + Returns: + Tuple consisting of: + + - the dataclass instances in the same order as they were passed to the initializer. + """ + with open(Path(json_file), encoding="utf-8") as open_json_file: + data = json.loads(open_json_file.read()) + outputs = self.parse_dict(data, allow_extra_keys=allow_extra_keys) + return tuple(outputs) + + def parse_yaml_file( + self, yaml_file: Union[str, os.PathLike], allow_extra_keys: bool = False + ) -> Tuple[DataClass, ...]: + """ + Alternative helper method that does not use `argparse` at all, instead loading a yaml file and populating the + dataclass types. + + Args: + yaml_file (`str` or `os.PathLike`): + File name of the yaml file to parse + allow_extra_keys (`bool`, *optional*, defaults to `False`): + Defaults to False. If False, will raise an exception if the json file contains keys that are not + parsed. + + Returns: + Tuple consisting of: + + - the dataclass instances in the same order as they were passed to the initializer. + """ + outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys) + return tuple(outputs) diff --git a/hyperparameter_search.py b/hyperparameter_search.py new file mode 100644 index 0000000000000000000000000000000000000000..c14165165ca1f92fb28e27b718c8bd81e1bc3a93 --- /dev/null +++ b/hyperparameter_search.py @@ -0,0 +1,141 @@ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .integrations import ( + is_optuna_available, + is_ray_tune_available, + is_sigopt_available, + is_wandb_available, + run_hp_search_optuna, + run_hp_search_ray, + run_hp_search_sigopt, + run_hp_search_wandb, +) +from .trainer_utils import ( + HPSearchBackend, + default_hp_space_optuna, + default_hp_space_ray, + default_hp_space_sigopt, + default_hp_space_wandb, +) +from .utils import logging + + +logger = logging.get_logger(__name__) + + +class HyperParamSearchBackendBase: + name: str + pip_package: str = None + + @staticmethod + def is_available(): + raise NotImplementedError + + def run(self, trainer, n_trials: int, direction: str, **kwargs): + raise NotImplementedError + + def default_hp_space(self, trial): + raise NotImplementedError + + def ensure_available(self): + if not self.is_available(): + raise RuntimeError( + f"You picked the {self.name} backend, but it is not installed. Run {self.pip_install()}." + ) + + @classmethod + def pip_install(cls): + return f"`pip install {cls.pip_package or cls.name}`" + + +class OptunaBackend(HyperParamSearchBackendBase): + name = "optuna" + + @staticmethod + def is_available(): + return is_optuna_available() + + def run(self, trainer, n_trials: int, direction: str, **kwargs): + return run_hp_search_optuna(trainer, n_trials, direction, **kwargs) + + def default_hp_space(self, trial): + return default_hp_space_optuna(trial) + + +class RayTuneBackend(HyperParamSearchBackendBase): + name = "ray" + pip_package = "'ray[tune]'" + + @staticmethod + def is_available(): + return is_ray_tune_available() + + def run(self, trainer, n_trials: int, direction: str, **kwargs): + return run_hp_search_ray(trainer, n_trials, direction, **kwargs) + + def default_hp_space(self, trial): + return default_hp_space_ray(trial) + + +class SigOptBackend(HyperParamSearchBackendBase): + name = "sigopt" + + @staticmethod + def is_available(): + return is_sigopt_available() + + def run(self, trainer, n_trials: int, direction: str, **kwargs): + return run_hp_search_sigopt(trainer, n_trials, direction, **kwargs) + + def default_hp_space(self, trial): + return default_hp_space_sigopt(trial) + + +class WandbBackend(HyperParamSearchBackendBase): + name = "wandb" + + @staticmethod + def is_available(): + return is_wandb_available() + + def run(self, trainer, n_trials: int, direction: str, **kwargs): + return run_hp_search_wandb(trainer, n_trials, direction, **kwargs) + + def default_hp_space(self, trial): + return default_hp_space_wandb(trial) + + +ALL_HYPERPARAMETER_SEARCH_BACKENDS = { + HPSearchBackend(backend.name): backend for backend in [OptunaBackend, RayTuneBackend, SigOptBackend, WandbBackend] +} + + +def default_hp_search_backend() -> str: + available_backends = [backend for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values() if backend.is_available()] + if len(available_backends) > 0: + name = available_backends[0].name + if len(available_backends) > 1: + logger.info( + f"{len(available_backends)} hyperparameter search backends available. Using {name} as the default." + ) + return name + raise RuntimeError( + "No hyperparameter search backend available.\n" + + "\n".join( + f" - To install {backend.name} run {backend.pip_install()}" + for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values() + ) + ) diff --git a/image_processing_base.py b/image_processing_base.py new file mode 100644 index 0000000000000000000000000000000000000000..a6ce7af3fa8076958d4a4ac87ca3ab13716c7955 --- /dev/null +++ b/image_processing_base.py @@ -0,0 +1,559 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import copy +import json +import os +import warnings +from io import BytesIO +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union + +import numpy as np +import requests + +from .dynamic_module_utils import custom_object_save +from .feature_extraction_utils import BatchFeature as BaseBatchFeature +from .utils import ( + IMAGE_PROCESSOR_NAME, + PushToHubMixin, + add_model_info_to_auto_map, + add_model_info_to_custom_pipelines, + cached_file, + copy_func, + download_url, + is_offline_mode, + is_remote_url, + is_vision_available, + logging, +) + + +if is_vision_available(): + from PIL import Image + + +ImageProcessorType = TypeVar("ImageProcessorType", bound="ImageProcessingMixin") + + +logger = logging.get_logger(__name__) + + +# TODO: Move BatchFeature to be imported by both image_processing_utils and image_processing_utils +# We override the class string here, but logic is the same. +class BatchFeature(BaseBatchFeature): + r""" + Holds the output of the image processor specific `__call__` methods. + + This class is derived from a python dictionary and can be used as a dictionary. + + Args: + data (`dict`): + Dictionary of lists/arrays/tensors returned by the __call__ method ('pixel_values', etc.). + tensor_type (`Union[None, str, TensorType]`, *optional*): + You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at + initialization. + """ + + +# TODO: (Amy) - factor out the common parts of this and the feature extractor +class ImageProcessingMixin(PushToHubMixin): + """ + This is an image processor mixin used to provide saving/loading functionality for sequential and image feature + extractors. + """ + + _auto_class = None + + def __init__(self, **kwargs): + """Set elements of `kwargs` as attributes.""" + # This key was saved while we still used `XXXFeatureExtractor` for image processing. Now we use + # `XXXImageProcessor`, this attribute and its value are misleading. + kwargs.pop("feature_extractor_type", None) + # Pop "processor_class" as it should be saved as private attribute + self._processor_class = kwargs.pop("processor_class", None) + # Additional attributes without default values + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error(f"Can't set {key} with value {value} for {self}") + raise err + + def _set_processor_class(self, processor_class: str): + """Sets processor class as an attribute.""" + self._processor_class = processor_class + + @classmethod + def from_pretrained( + cls: Type[ImageProcessorType], + pretrained_model_name_or_path: Union[str, os.PathLike], + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + **kwargs, + ) -> ImageProcessorType: + r""" + Instantiate a type of [`~image_processing_utils.ImageProcessingMixin`] from an image processor. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained image_processor hosted inside a model repo on + huggingface.co. + - a path to a *directory* containing a image processor file saved using the + [`~image_processing_utils.ImageProcessingMixin.save_pretrained`] method, e.g., + `./my_model_directory/`. + - a path or url to a saved image processor JSON *file*, e.g., + `./my_model_directory/preprocessor_config.json`. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model image processor should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the image processor files and override the cached versions if + they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use + the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + + To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`. + + + + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + If `False`, then this function returns just the final image processor object. If `True`, then this + functions returns a `Tuple(image_processor, unused_kwargs)` where *unused_kwargs* is a dictionary + consisting of the key/value pairs whose keys are not image processor attributes: i.e., the part of + `kwargs` which has not been used to update `image_processor` and is otherwise ignored. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can + specify the folder name here. + kwargs (`Dict[str, Any]`, *optional*): + The values in kwargs of any keys which are image processor attributes will be used to override the + loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is + controlled by the `return_unused_kwargs` keyword parameter. + + Returns: + A image processor of type [`~image_processing_utils.ImageProcessingMixin`]. + + Examples: + + ```python + # We can't instantiate directly the base class *ImageProcessingMixin* so let's show the examples on a + # derived class: *CLIPImageProcessor* + image_processor = CLIPImageProcessor.from_pretrained( + "openai/clip-vit-base-patch32" + ) # Download image_processing_config from huggingface.co and cache. + image_processor = CLIPImageProcessor.from_pretrained( + "./test/saved_model/" + ) # E.g. image processor (or model) was saved using *save_pretrained('./test/saved_model/')* + image_processor = CLIPImageProcessor.from_pretrained("./test/saved_model/preprocessor_config.json") + image_processor = CLIPImageProcessor.from_pretrained( + "openai/clip-vit-base-patch32", do_normalize=False, foo=False + ) + assert image_processor.do_normalize is False + image_processor, unused_kwargs = CLIPImageProcessor.from_pretrained( + "openai/clip-vit-base-patch32", do_normalize=False, foo=False, return_unused_kwargs=True + ) + assert image_processor.do_normalize is False + assert unused_kwargs == {"foo": False} + ```""" + kwargs["cache_dir"] = cache_dir + kwargs["force_download"] = force_download + kwargs["local_files_only"] = local_files_only + kwargs["revision"] = revision + + use_auth_token = kwargs.pop("use_auth_token", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if token is not None: + kwargs["token"] = token + + image_processor_dict, kwargs = cls.get_image_processor_dict(pretrained_model_name_or_path, **kwargs) + + return cls.from_dict(image_processor_dict, **kwargs) + + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save an image processor object to the directory `save_directory`, so that it can be re-loaded using the + [`~image_processing_utils.ImageProcessingMixin.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the image processor JSON file will be saved (will be created if it does not exist). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + use_auth_token = kwargs.pop("use_auth_token", None) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if kwargs.get("token", None) is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + kwargs["token"] = use_auth_token + + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + os.makedirs(save_directory, exist_ok=True) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = self._create_repo(repo_id, **kwargs) + files_timestamps = self._get_files_timestamps(save_directory) + + # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be + # loaded from the Hub. + if self._auto_class is not None: + custom_object_save(self, save_directory, config=self) + + # If we save using the predefined names, we can load using `from_pretrained` + output_image_processor_file = os.path.join(save_directory, IMAGE_PROCESSOR_NAME) + + self.to_json_file(output_image_processor_file) + logger.info(f"Image processor saved in {output_image_processor_file}") + + if push_to_hub: + self._upload_modified_files( + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=kwargs.get("token"), + ) + + return [output_image_processor_file] + + @classmethod + def get_image_processor_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a + image processor of type [`~image_processor_utils.ImageProcessingMixin`] using `from_dict`. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`): + The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can + specify the folder name here. + image_processor_filename (`str`, *optional*, defaults to `"config.json"`): + The name of the file in the model directory to use for the image processor config. + + Returns: + `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the image processor object. + """ + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", None) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + use_auth_token = kwargs.pop("use_auth_token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", "") + image_processor_filename = kwargs.pop("image_processor_filename", IMAGE_PROCESSOR_NAME) + + from_pipeline = kwargs.pop("_from_pipeline", None) + from_auto_class = kwargs.pop("_from_auto", False) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + user_agent = {"file_type": "image processor", "from_auto_class": from_auto_class} + if from_pipeline is not None: + user_agent["using_pipeline"] = from_pipeline + + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + is_local = os.path.isdir(pretrained_model_name_or_path) + if os.path.isdir(pretrained_model_name_or_path): + image_processor_file = os.path.join(pretrained_model_name_or_path, image_processor_filename) + if os.path.isfile(pretrained_model_name_or_path): + resolved_image_processor_file = pretrained_model_name_or_path + is_local = True + elif is_remote_url(pretrained_model_name_or_path): + image_processor_file = pretrained_model_name_or_path + resolved_image_processor_file = download_url(pretrained_model_name_or_path) + else: + image_processor_file = image_processor_filename + try: + # Load from local folder or from cache or download from model Hub and cache + resolved_image_processor_file = cached_file( + pretrained_model_name_or_path, + image_processor_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + ) + except EnvironmentError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to + # the original exception. + raise + except Exception: + # For any other exception, we throw a generic error. + raise EnvironmentError( + f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load" + " it from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" + f" directory containing a {image_processor_filename} file" + ) + + try: + # Load image_processor dict + with open(resolved_image_processor_file, "r", encoding="utf-8") as reader: + text = reader.read() + image_processor_dict = json.loads(text) + + except json.JSONDecodeError: + raise EnvironmentError( + f"It looks like the config file at '{resolved_image_processor_file}' is not a valid JSON file." + ) + + if is_local: + logger.info(f"loading configuration file {resolved_image_processor_file}") + else: + logger.info( + f"loading configuration file {image_processor_file} from cache at {resolved_image_processor_file}" + ) + if "auto_map" in image_processor_dict: + image_processor_dict["auto_map"] = add_model_info_to_auto_map( + image_processor_dict["auto_map"], pretrained_model_name_or_path + ) + if "custom_pipelines" in image_processor_dict: + image_processor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines( + image_processor_dict["custom_pipelines"], pretrained_model_name_or_path + ) + + return image_processor_dict, kwargs + + @classmethod + def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs): + """ + Instantiates a type of [`~image_processing_utils.ImageProcessingMixin`] from a Python dictionary of parameters. + + Args: + image_processor_dict (`Dict[str, Any]`): + Dictionary that will be used to instantiate the image processor object. Such a dictionary can be + retrieved from a pretrained checkpoint by leveraging the + [`~image_processing_utils.ImageProcessingMixin.to_dict`] method. + kwargs (`Dict[str, Any]`): + Additional parameters from which to initialize the image processor object. + + Returns: + [`~image_processing_utils.ImageProcessingMixin`]: The image processor object instantiated from those + parameters. + """ + image_processor_dict = image_processor_dict.copy() + return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) + + # The `size` parameter is a dict and was previously an int or tuple in feature extractors. + # We set `size` here directly to the `image_processor_dict` so that it is converted to the appropriate + # dict within the image processor and isn't overwritten if `size` is passed in as a kwarg. + if "size" in kwargs and "size" in image_processor_dict: + image_processor_dict["size"] = kwargs.pop("size") + if "crop_size" in kwargs and "crop_size" in image_processor_dict: + image_processor_dict["crop_size"] = kwargs.pop("crop_size") + + image_processor = cls(**image_processor_dict) + + # Update image_processor with kwargs if needed + to_remove = [] + for key, value in kwargs.items(): + if hasattr(image_processor, key): + setattr(image_processor, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + + logger.info(f"Image processor {image_processor}") + if return_unused_kwargs: + return image_processor, kwargs + else: + return image_processor + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. + + Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this image processor instance. + """ + output = copy.deepcopy(self.__dict__) + output["image_processor_type"] = self.__class__.__name__ + + return output + + @classmethod + def from_json_file(cls, json_file: Union[str, os.PathLike]): + """ + Instantiates a image processor of type [`~image_processing_utils.ImageProcessingMixin`] from the path to a JSON + file of parameters. + + Args: + json_file (`str` or `os.PathLike`): + Path to the JSON file containing the parameters. + + Returns: + A image processor of type [`~image_processing_utils.ImageProcessingMixin`]: The image_processor object + instantiated from that JSON file. + """ + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + image_processor_dict = json.loads(text) + return cls(**image_processor_dict) + + def to_json_string(self) -> str: + """ + Serializes this instance to a JSON string. + + Returns: + `str`: String containing all the attributes that make up this feature_extractor instance in JSON format. + """ + dictionary = self.to_dict() + + for key, value in dictionary.items(): + if isinstance(value, np.ndarray): + dictionary[key] = value.tolist() + + # make sure private name "_processor_class" is correctly + # saved as "processor_class" + _processor_class = dictionary.pop("_processor_class", None) + if _processor_class is not None: + dictionary["processor_class"] = _processor_class + + return json.dumps(dictionary, indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this image_processor instance's parameters will be saved. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string()) + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + @classmethod + def register_for_auto_class(cls, auto_class="AutoImageProcessor"): + """ + Register this class with a given auto class. This should only be used for custom image processors as the ones + in the library are already mapped with `AutoImageProcessor `. + + + + This API is experimental and may have some slight breaking changes in the next releases. + + + + Args: + auto_class (`str` or `type`, *optional*, defaults to `"AutoImageProcessor "`): + The auto class to register this new image processor with. + """ + if not isinstance(auto_class, str): + auto_class = auto_class.__name__ + + import transformers.models.auto as auto_module + + if not hasattr(auto_module, auto_class): + raise ValueError(f"{auto_class} is not a valid auto class.") + + cls._auto_class = auto_class + + def fetch_images(self, image_url_or_urls: Union[str, List[str]]): + """ + Convert a single or a list of urls into the corresponding `PIL.Image` objects. + + If a single url is passed, the return value will be a single object. If a list is passed a list of objects is + returned. + """ + headers = { + "User-Agent": ( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0" + " Safari/537.36" + ) + } + if isinstance(image_url_or_urls, list): + return [self.fetch_images(x) for x in image_url_or_urls] + elif isinstance(image_url_or_urls, str): + response = requests.get(image_url_or_urls, stream=True, headers=headers) + response.raise_for_status() + return Image.open(BytesIO(response.content)) + else: + raise TypeError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}") + + +ImageProcessingMixin.push_to_hub = copy_func(ImageProcessingMixin.push_to_hub) +if ImageProcessingMixin.push_to_hub.__doc__ is not None: + ImageProcessingMixin.push_to_hub.__doc__ = ImageProcessingMixin.push_to_hub.__doc__.format( + object="image processor", object_class="AutoImageProcessor", object_files="image processor file" + ) diff --git a/image_processing_utils.py b/image_processing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0279f26a963e35fc0d3f74a3b669b8a5e1ccf422 --- /dev/null +++ b/image_processing_utils.py @@ -0,0 +1,287 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Iterable, Optional, Union + +import numpy as np + +from .image_processing_base import BatchFeature, ImageProcessingMixin +from .image_transforms import center_crop, normalize, rescale +from .image_utils import ChannelDimension +from .utils import logging + + +logger = logging.get_logger(__name__) + + +INIT_SERVICE_KWARGS = [ + "processor_class", + "image_processor_type", +] + + +class BaseImageProcessor(ImageProcessingMixin): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def __call__(self, images, **kwargs) -> BatchFeature: + """Preprocess an image or a batch of images.""" + return self.preprocess(images, **kwargs) + + def preprocess(self, images, **kwargs) -> BatchFeature: + raise NotImplementedError("Each image processor must implement its own preprocess method") + + def rescale( + self, + image: np.ndarray, + scale: float, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Rescale an image by a scale factor. image = image * scale. + + Args: + image (`np.ndarray`): + Image to rescale. + scale (`float`): + The scaling factor to rescale pixel values by. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + + Returns: + `np.ndarray`: The rescaled image. + """ + return rescale(image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs) + + def normalize( + self, + image: np.ndarray, + mean: Union[float, Iterable[float]], + std: Union[float, Iterable[float]], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Normalize an image. image = (image - image_mean) / image_std. + + Args: + image (`np.ndarray`): + Image to normalize. + mean (`float` or `Iterable[float]`): + Image mean to use for normalization. + std (`float` or `Iterable[float]`): + Image standard deviation to use for normalization. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + + Returns: + `np.ndarray`: The normalized image. + """ + return normalize( + image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs + ) + + def center_crop( + self, + image: np.ndarray, + size: Dict[str, int], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along + any edge, the image is padded with 0's and then center cropped. + + Args: + image (`np.ndarray`): + Image to center crop. + size (`Dict[str, int]`): + Size of the output image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + """ + size = get_size_dict(size) + if "height" not in size or "width" not in size: + raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}") + return center_crop( + image, + size=(size["height"], size["width"]), + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def to_dict(self): + encoder_dict = super().to_dict() + encoder_dict.pop("_valid_processor_keys", None) + return encoder_dict + + +VALID_SIZE_DICT_KEYS = ( + {"height", "width"}, + {"shortest_edge"}, + {"shortest_edge", "longest_edge"}, + {"longest_edge"}, + {"max_height", "max_width"}, +) + + +def is_valid_size_dict(size_dict): + if not isinstance(size_dict, dict): + return False + + size_dict_keys = set(size_dict.keys()) + for allowed_keys in VALID_SIZE_DICT_KEYS: + if size_dict_keys == allowed_keys: + return True + return False + + +def convert_to_size_dict( + size, max_size: Optional[int] = None, default_to_square: bool = True, height_width_order: bool = True +): + # By default, if size is an int we assume it represents a tuple of (size, size). + if isinstance(size, int) and default_to_square: + if max_size is not None: + raise ValueError("Cannot specify both size as an int, with default_to_square=True and max_size") + return {"height": size, "width": size} + # In other configs, if size is an int and default_to_square is False, size represents the length of + # the shortest edge after resizing. + elif isinstance(size, int) and not default_to_square: + size_dict = {"shortest_edge": size} + if max_size is not None: + size_dict["longest_edge"] = max_size + return size_dict + # Otherwise, if size is a tuple it's either (height, width) or (width, height) + elif isinstance(size, (tuple, list)) and height_width_order: + return {"height": size[0], "width": size[1]} + elif isinstance(size, (tuple, list)) and not height_width_order: + return {"height": size[1], "width": size[0]} + elif size is None and max_size is not None: + if default_to_square: + raise ValueError("Cannot specify both default_to_square=True and max_size") + return {"longest_edge": max_size} + + raise ValueError(f"Could not convert size input to size dict: {size}") + + +def get_size_dict( + size: Union[int, Iterable[int], Dict[str, int]] = None, + max_size: Optional[int] = None, + height_width_order: bool = True, + default_to_square: bool = True, + param_name="size", +) -> dict: + """ + Converts the old size parameter in the config into the new dict expected in the config. This is to ensure backwards + compatibility with the old image processor configs and removes ambiguity over whether the tuple is in (height, + width) or (width, height) format. + + - If `size` is tuple, it is converted to `{"height": size[0], "width": size[1]}` or `{"height": size[1], "width": + size[0]}` if `height_width_order` is `False`. + - If `size` is an int, and `default_to_square` is `True`, it is converted to `{"height": size, "width": size}`. + - If `size` is an int and `default_to_square` is False, it is converted to `{"shortest_edge": size}`. If `max_size` + is set, it is added to the dict as `{"longest_edge": max_size}`. + + Args: + size (`Union[int, Iterable[int], Dict[str, int]]`, *optional*): + The `size` parameter to be cast into a size dictionary. + max_size (`Optional[int]`, *optional*): + The `max_size` parameter to be cast into a size dictionary. + height_width_order (`bool`, *optional*, defaults to `True`): + If `size` is a tuple, whether it's in (height, width) or (width, height) order. + default_to_square (`bool`, *optional*, defaults to `True`): + If `size` is an int, whether to default to a square image or not. + """ + if not isinstance(size, dict): + size_dict = convert_to_size_dict(size, max_size, default_to_square, height_width_order) + logger.info( + f"{param_name} should be a dictionary on of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size}." + f" Converted to {size_dict}.", + ) + else: + size_dict = size + + if not is_valid_size_dict(size_dict): + raise ValueError( + f"{param_name} must have one of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size_dict.keys()}" + ) + return size_dict + + +def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple: + """ + Selects the best resolution from a list of possible resolutions based on the original size. + + This is done by calculating the effective and wasted resolution for each possible resolution. + + The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution. + + Args: + original_size (tuple): + The original size of the image in the format (height, width). + possible_resolutions (list): + A list of possible resolutions in the format [(height1, width1), (height2, width2), ...]. + + Returns: + tuple: The best fit resolution in the format (height, width). + """ + original_height, original_width = original_size + best_fit = None + max_effective_resolution = 0 + min_wasted_resolution = float("inf") + + for height, width in possible_resolutions: + scale = min(width / original_width, height / original_height) + downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) + effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) + wasted_resolution = (width * height) - effective_resolution + + if effective_resolution > max_effective_resolution or ( + effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution + ): + max_effective_resolution = effective_resolution + min_wasted_resolution = wasted_resolution + best_fit = (height, width) + + return best_fit diff --git a/image_processing_utils_fast.py b/image_processing_utils_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..3c1be325b7eb304060e3ed8aa28981619c677129 --- /dev/null +++ b/image_processing_utils_fast.py @@ -0,0 +1,133 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from dataclasses import dataclass +from typing import Any, Iterable, List, Optional, Tuple + +from .image_processing_utils import BaseImageProcessor +from .utils.import_utils import is_torch_available, is_torchvision_available + + +if is_torchvision_available(): + from torchvision.transforms import Compose + +if is_torch_available(): + import torch + + +@dataclass(frozen=True) +class SizeDict: + """ + Hashable dictionary to store image size information. + """ + + height: int = None + width: int = None + longest_edge: int = None + shortest_edge: int = None + max_height: int = None + max_width: int = None + + def __getitem__(self, key): + if hasattr(self, key): + return getattr(self, key) + raise KeyError(f"Key {key} not found in SizeDict.") + + +class BaseImageProcessorFast(BaseImageProcessor): + _transform_params = None + + def _build_transforms(self, **kwargs) -> "Compose": + """ + Given the input settings e.g. do_resize, build the image transforms. + """ + raise NotImplementedError + + def _validate_params(self, **kwargs) -> None: + for k, v in kwargs.items(): + if k not in self._transform_params: + raise ValueError(f"Invalid transform parameter {k}={v}.") + + @functools.lru_cache(maxsize=1) + def get_transforms(self, **kwargs) -> "Compose": + self._validate_params(**kwargs) + return self._build_transforms(**kwargs) + + def to_dict(self): + encoder_dict = super().to_dict() + encoder_dict.pop("_transform_params", None) + return encoder_dict + + +def get_image_size_for_max_height_width( + image_size: Tuple[int, int], + max_height: int, + max_width: int, +) -> Tuple[int, int]: + """ + Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio. + Important, even if image_height < max_height and image_width < max_width, the image will be resized + to at least one of the edges be equal to max_height or max_width. + + For example: + - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50) + - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400) + + Args: + image_size (`Tuple[int, int]`): + The image to resize. + max_height (`int`): + The maximum allowed height. + max_width (`int`): + The maximum allowed width. + """ + height, width = image_size + height_scale = max_height / height + width_scale = max_width / width + min_scale = min(height_scale, width_scale) + new_height = int(height * min_scale) + new_width = int(width * min_scale) + return new_height, new_width + + +def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor": + """ + Squeezes a tensor, but only if the axis specified has dim 1. + """ + if axis is None: + return tensor.squeeze() + + try: + return tensor.squeeze(axis=axis) + except ValueError: + return tensor + + +def max_across_indices(values: Iterable[Any]) -> List[Any]: + """ + Return the maximum value across all indices of an iterable of values. + """ + return [max(values_i) for values_i in zip(*values)] + + +def get_max_height_width(images: List["torch.Tensor"]) -> Tuple[int]: + """ + Get the maximum height and width across all images in a batch. + """ + + _, max_height, max_width = max_across_indices([img.shape for img in images]) + + return (max_height, max_width) diff --git a/image_transforms.py b/image_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..e7d3a5abb7a8db634ef1f1f19ea57219f14457b4 --- /dev/null +++ b/image_transforms.py @@ -0,0 +1,860 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from math import ceil +from typing import Iterable, List, Optional, Tuple, Union + +import numpy as np + +from .image_utils import ( + ChannelDimension, + ImageInput, + get_channel_dimension_axis, + get_image_size, + infer_channel_dimension_format, +) +from .utils import ExplicitEnum, TensorType, is_jax_tensor, is_tf_tensor, is_torch_tensor +from .utils.import_utils import ( + is_flax_available, + is_tf_available, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, + is_vision_available, + requires_backends, +) + + +if is_vision_available(): + import PIL + + from .image_utils import PILImageResampling + +if is_torch_available(): + import torch + +if is_tf_available(): + import tensorflow as tf + +if is_flax_available(): + import jax.numpy as jnp + +if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F +elif is_torchvision_available(): + from torchvision.transforms import functional as F + + +def to_channel_dimension_format( + image: np.ndarray, + channel_dim: Union[ChannelDimension, str], + input_channel_dim: Optional[Union[ChannelDimension, str]] = None, +) -> np.ndarray: + """ + Converts `image` to the channel dimension format specified by `channel_dim`. + + Args: + image (`numpy.ndarray`): + The image to have its channel dimension set. + channel_dim (`ChannelDimension`): + The channel dimension format to use. + input_channel_dim (`ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred from the input image. + + Returns: + `np.ndarray`: The image with the channel dimension set to `channel_dim`. + """ + if not isinstance(image, np.ndarray): + raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}") + + if input_channel_dim is None: + input_channel_dim = infer_channel_dimension_format(image) + + target_channel_dim = ChannelDimension(channel_dim) + if input_channel_dim == target_channel_dim: + return image + + if target_channel_dim == ChannelDimension.FIRST: + image = image.transpose((2, 0, 1)) + elif target_channel_dim == ChannelDimension.LAST: + image = image.transpose((1, 2, 0)) + else: + raise ValueError("Unsupported channel dimension format: {}".format(channel_dim)) + + return image + + +def rescale( + image: np.ndarray, + scale: float, + data_format: Optional[ChannelDimension] = None, + dtype: np.dtype = np.float32, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> np.ndarray: + """ + Rescales `image` by `scale`. + + Args: + image (`np.ndarray`): + The image to rescale. + scale (`float`): + The scale to use for rescaling the image. + data_format (`ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + dtype (`np.dtype`, *optional*, defaults to `np.float32`): + The dtype of the output image. Defaults to `np.float32`. Used for backwards compatibility with feature + extractors. + input_data_format (`ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred from the input image. + + Returns: + `np.ndarray`: The rescaled image. + """ + if not isinstance(image, np.ndarray): + raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}") + + rescaled_image = image.astype(np.float64) * scale # Numpy type promotion has changed, so always upcast first + if data_format is not None: + rescaled_image = to_channel_dimension_format(rescaled_image, data_format, input_data_format) + + rescaled_image = rescaled_image.astype(dtype) # Finally downcast to the desired dtype at the end + + return rescaled_image + + +def _rescale_for_pil_conversion(image): + """ + Detects whether or not the image needs to be rescaled before being converted to a PIL image. + + The assumption is that if the image is of type `np.float` and all values are between 0 and 1, it needs to be + rescaled. + """ + if image.dtype == np.uint8: + do_rescale = False + elif np.allclose(image, image.astype(int)): + if np.all(0 <= image) and np.all(image <= 255): + do_rescale = False + else: + raise ValueError( + "The image to be converted to a PIL image contains values outside the range [0, 255], " + f"got [{image.min()}, {image.max()}] which cannot be converted to uint8." + ) + elif np.all(0 <= image) and np.all(image <= 1): + do_rescale = True + else: + raise ValueError( + "The image to be converted to a PIL image contains values outside the range [0, 1], " + f"got [{image.min()}, {image.max()}] which cannot be converted to uint8." + ) + return do_rescale + + +def to_pil_image( + image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor", "tf.Tensor", "jnp.ndarray"], + do_rescale: Optional[bool] = None, + image_mode: Optional[str] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> "PIL.Image.Image": + """ + Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if + needed. + + Args: + image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor` or `tf.Tensor`): + The image to convert to the `PIL.Image` format. + do_rescale (`bool`, *optional*): + Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default + to `True` if the image type is a floating type and casting to `int` would result in a loss of precision, + and `False` otherwise. + image_mode (`str`, *optional*): + The mode to use for the PIL image. If unset, will use the default mode for the input image type. + input_data_format (`ChannelDimension`, *optional*): + The channel dimension format of the input image. If unset, will use the inferred format from the input. + + Returns: + `PIL.Image.Image`: The converted image. + """ + requires_backends(to_pil_image, ["vision"]) + + if isinstance(image, PIL.Image.Image): + return image + + # Convert all tensors to numpy arrays before converting to PIL image + if is_torch_tensor(image) or is_tf_tensor(image): + image = image.numpy() + elif is_jax_tensor(image): + image = np.array(image) + elif not isinstance(image, np.ndarray): + raise ValueError("Input image type not supported: {}".format(type(image))) + + # If the channel has been moved to first dim, we put it back at the end. + image = to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format) + + # If there is a single channel, we squeeze it, as otherwise PIL can't handle it. + image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image + + # PIL.Image can only store uint8 values so we rescale the image to be between 0 and 255 if needed. + do_rescale = _rescale_for_pil_conversion(image) if do_rescale is None else do_rescale + + if do_rescale: + image = rescale(image, 255) + + image = image.astype(np.uint8) + return PIL.Image.fromarray(image, mode=image_mode) + + +# Logic adapted from torchvision resizing logic: https://github.com/pytorch/vision/blob/511924c1ced4ce0461197e5caa64ce5b9e558aab/torchvision/transforms/functional.py#L366 +def get_resize_output_image_size( + input_image: np.ndarray, + size: Union[int, Tuple[int, int], List[int], Tuple[int]], + default_to_square: bool = True, + max_size: Optional[int] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> tuple: + """ + Find the target (height, width) dimension of the output image after resizing given the input image and the desired + size. + + Args: + input_image (`np.ndarray`): + The image to resize. + size (`int` or `Tuple[int, int]` or List[int] or `Tuple[int]`): + The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be matched to + this. + + If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If + `size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to this + number. i.e, if height > width, then image will be rescaled to (size * height / width, size). + default_to_square (`bool`, *optional*, defaults to `True`): + How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a square + (`size`,`size`). If set to `False`, will replicate + [`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize) + with support for resizing only the smallest edge and providing an optional `max_size`. + max_size (`int`, *optional*): + The maximum allowed for the longer edge of the resized image: if the longer edge of the image is greater + than `max_size` after being resized according to `size`, then the image is resized again so that the longer + edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller edge may be shorter + than `size`. Only used if `default_to_square` is `False`. + input_data_format (`ChannelDimension`, *optional*): + The channel dimension format of the input image. If unset, will use the inferred format from the input. + + Returns: + `tuple`: The target (height, width) dimension of the output image after resizing. + """ + if isinstance(size, (tuple, list)): + if len(size) == 2: + return tuple(size) + elif len(size) == 1: + # Perform same logic as if size was an int + size = size[0] + else: + raise ValueError("size must have 1 or 2 elements if it is a list or tuple") + + if default_to_square: + return (size, size) + + height, width = get_image_size(input_image, input_data_format) + short, long = (width, height) if width <= height else (height, width) + requested_new_short = size + + new_short, new_long = requested_new_short, int(requested_new_short * long / short) + + if max_size is not None: + if max_size <= requested_new_short: + raise ValueError( + f"max_size = {max_size} must be strictly greater than the requested " + f"size for the smaller edge size = {size}" + ) + if new_long > max_size: + new_short, new_long = int(max_size * new_short / new_long), max_size + + return (new_long, new_short) if width <= height else (new_short, new_long) + + +def resize( + image: np.ndarray, + size: Tuple[int, int], + resample: "PILImageResampling" = None, + reducing_gap: Optional[int] = None, + data_format: Optional[ChannelDimension] = None, + return_numpy: bool = True, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> np.ndarray: + """ + Resizes `image` to `(height, width)` specified by `size` using the PIL library. + + Args: + image (`np.ndarray`): + The image to resize. + size (`Tuple[int, int]`): + The size to use for resizing the image. + resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`): + The filter to user for resampling. + reducing_gap (`int`, *optional*): + Apply optimization by resizing the image in two steps. The bigger `reducing_gap`, the closer the result to + the fair resampling. See corresponding Pillow documentation for more details. + data_format (`ChannelDimension`, *optional*): + The channel dimension format of the output image. If unset, will use the inferred format from the input. + return_numpy (`bool`, *optional*, defaults to `True`): + Whether or not to return the resized image as a numpy array. If False a `PIL.Image.Image` object is + returned. + input_data_format (`ChannelDimension`, *optional*): + The channel dimension format of the input image. If unset, will use the inferred format from the input. + + Returns: + `np.ndarray`: The resized image. + """ + requires_backends(resize, ["vision"]) + + resample = resample if resample is not None else PILImageResampling.BILINEAR + + if not len(size) == 2: + raise ValueError("size must have 2 elements") + + # For all transformations, we want to keep the same data format as the input image unless otherwise specified. + # The resized image from PIL will always have channels last, so find the input format first. + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + data_format = input_data_format if data_format is None else data_format + + # To maintain backwards compatibility with the resizing done in previous image feature extractors, we use + # the pillow library to resize the image and then convert back to numpy + do_rescale = False + if not isinstance(image, PIL.Image.Image): + do_rescale = _rescale_for_pil_conversion(image) + image = to_pil_image(image, do_rescale=do_rescale, input_data_format=input_data_format) + height, width = size + # PIL images are in the format (width, height) + resized_image = image.resize((width, height), resample=resample, reducing_gap=reducing_gap) + + if return_numpy: + resized_image = np.array(resized_image) + # If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image + # so we need to add it back if necessary. + resized_image = np.expand_dims(resized_image, axis=-1) if resized_image.ndim == 2 else resized_image + # The image is always in channels last format after converting from a PIL image + resized_image = to_channel_dimension_format( + resized_image, data_format, input_channel_dim=ChannelDimension.LAST + ) + # If an image was rescaled to be in the range [0, 255] before converting to a PIL image, then we need to + # rescale it back to the original range. + resized_image = rescale(resized_image, 1 / 255) if do_rescale else resized_image + return resized_image + + +def normalize( + image: np.ndarray, + mean: Union[float, Iterable[float]], + std: Union[float, Iterable[float]], + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> np.ndarray: + """ + Normalizes `image` using the mean and standard deviation specified by `mean` and `std`. + + image = (image - mean) / std + + Args: + image (`np.ndarray`): + The image to normalize. + mean (`float` or `Iterable[float]`): + The mean to use for normalization. + std (`float` or `Iterable[float]`): + The standard deviation to use for normalization. + data_format (`ChannelDimension`, *optional*): + The channel dimension format of the output image. If unset, will use the inferred format from the input. + input_data_format (`ChannelDimension`, *optional*): + The channel dimension format of the input image. If unset, will use the inferred format from the input. + """ + if not isinstance(image, np.ndarray): + raise ValueError("image must be a numpy array") + + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + + channel_axis = get_channel_dimension_axis(image, input_data_format=input_data_format) + num_channels = image.shape[channel_axis] + + # We cast to float32 to avoid errors that can occur when subtracting uint8 values. + # We preserve the original dtype if it is a float type to prevent upcasting float16. + if not np.issubdtype(image.dtype, np.floating): + image = image.astype(np.float32) + + if isinstance(mean, Iterable): + if len(mean) != num_channels: + raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}") + else: + mean = [mean] * num_channels + mean = np.array(mean, dtype=image.dtype) + + if isinstance(std, Iterable): + if len(std) != num_channels: + raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}") + else: + std = [std] * num_channels + std = np.array(std, dtype=image.dtype) + + if input_data_format == ChannelDimension.LAST: + image = (image - mean) / std + else: + image = ((image.T - mean) / std).T + + image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image + return image + + +def center_crop( + image: np.ndarray, + size: Tuple[int, int], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + return_numpy: Optional[bool] = None, +) -> np.ndarray: + """ + Crops the `image` to the specified `size` using a center crop. Note that if the image is too small to be cropped to + the size given, it will be padded (so the returned result will always be of size `size`). + + Args: + image (`np.ndarray`): + The image to crop. + size (`Tuple[int, int]`): + The target size for the cropped image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use the inferred format of the input image. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use the inferred format of the input image. + return_numpy (`bool`, *optional*): + Whether or not to return the cropped image as a numpy array. Used for backwards compatibility with the + previous ImageFeatureExtractionMixin method. + - Unset: will return the same type as the input image. + - `True`: will return a numpy array. + - `False`: will return a `PIL.Image.Image` object. + Returns: + `np.ndarray`: The cropped image. + """ + requires_backends(center_crop, ["vision"]) + + if return_numpy is not None: + warnings.warn("return_numpy is deprecated and will be removed in v.4.33", FutureWarning) + + return_numpy = True if return_numpy is None else return_numpy + + if not isinstance(image, np.ndarray): + raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}") + + if not isinstance(size, Iterable) or len(size) != 2: + raise ValueError("size must have 2 elements representing the height and width of the output image") + + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + output_data_format = data_format if data_format is not None else input_data_format + + # We perform the crop in (C, H, W) format and then convert to the output format + image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format) + + orig_height, orig_width = get_image_size(image, ChannelDimension.FIRST) + crop_height, crop_width = size + crop_height, crop_width = int(crop_height), int(crop_width) + + # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result. + top = (orig_height - crop_height) // 2 + bottom = top + crop_height + # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result. + left = (orig_width - crop_width) // 2 + right = left + crop_width + + # Check if cropped area is within image boundaries + if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width: + image = image[..., top:bottom, left:right] + image = to_channel_dimension_format(image, output_data_format, ChannelDimension.FIRST) + return image + + # Otherwise, we may need to pad if the image is too small. Oh joy... + new_height = max(crop_height, orig_height) + new_width = max(crop_width, orig_width) + new_shape = image.shape[:-2] + (new_height, new_width) + new_image = np.zeros_like(image, shape=new_shape) + + # If the image is too small, pad it with zeros + top_pad = ceil((new_height - orig_height) / 2) + bottom_pad = top_pad + orig_height + left_pad = ceil((new_width - orig_width) / 2) + right_pad = left_pad + orig_width + new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image + + top += top_pad + bottom += top_pad + left += left_pad + right += left_pad + + new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)] + new_image = to_channel_dimension_format(new_image, output_data_format, ChannelDimension.FIRST) + + if not return_numpy: + new_image = to_pil_image(new_image) + + return new_image + + +def _center_to_corners_format_torch(bboxes_center: "torch.Tensor") -> "torch.Tensor": + center_x, center_y, width, height = bboxes_center.unbind(-1) + bbox_corners = torch.stack( + # top left x, top left y, bottom right x, bottom right y + [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)], + dim=-1, + ) + return bbox_corners + + +def _center_to_corners_format_numpy(bboxes_center: np.ndarray) -> np.ndarray: + center_x, center_y, width, height = bboxes_center.T + bboxes_corners = np.stack( + # top left x, top left y, bottom right x, bottom right y + [center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height], + axis=-1, + ) + return bboxes_corners + + +def _center_to_corners_format_tf(bboxes_center: "tf.Tensor") -> "tf.Tensor": + center_x, center_y, width, height = tf.unstack(bboxes_center, axis=-1) + bboxes_corners = tf.stack( + # top left x, top left y, bottom right x, bottom right y + [center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height], + axis=-1, + ) + return bboxes_corners + + +# 2 functions below inspired by https://github.com/facebookresearch/detr/blob/master/util/box_ops.py +def center_to_corners_format(bboxes_center: TensorType) -> TensorType: + """ + Converts bounding boxes from center format to corners format. + + center format: contains the coordinate for the center of the box and its width, height dimensions + (center_x, center_y, width, height) + corners format: contains the coodinates for the top-left and bottom-right corners of the box + (top_left_x, top_left_y, bottom_right_x, bottom_right_y) + """ + # Function is used during model forward pass, so we use the input framework if possible, without + # converting to numpy + if is_torch_tensor(bboxes_center): + return _center_to_corners_format_torch(bboxes_center) + elif isinstance(bboxes_center, np.ndarray): + return _center_to_corners_format_numpy(bboxes_center) + elif is_tf_tensor(bboxes_center): + return _center_to_corners_format_tf(bboxes_center) + + raise ValueError(f"Unsupported input type {type(bboxes_center)}") + + +def _corners_to_center_format_torch(bboxes_corners: "torch.Tensor") -> "torch.Tensor": + top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.unbind(-1) + b = [ + (top_left_x + bottom_right_x) / 2, # center x + (top_left_y + bottom_right_y) / 2, # center y + (bottom_right_x - top_left_x), # width + (bottom_right_y - top_left_y), # height + ] + return torch.stack(b, dim=-1) + + +def _corners_to_center_format_numpy(bboxes_corners: np.ndarray) -> np.ndarray: + top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.T + bboxes_center = np.stack( + [ + (top_left_x + bottom_right_x) / 2, # center x + (top_left_y + bottom_right_y) / 2, # center y + (bottom_right_x - top_left_x), # width + (bottom_right_y - top_left_y), # height + ], + axis=-1, + ) + return bboxes_center + + +def _corners_to_center_format_tf(bboxes_corners: "tf.Tensor") -> "tf.Tensor": + top_left_x, top_left_y, bottom_right_x, bottom_right_y = tf.unstack(bboxes_corners, axis=-1) + bboxes_center = tf.stack( + [ + (top_left_x + bottom_right_x) / 2, # center x + (top_left_y + bottom_right_y) / 2, # center y + (bottom_right_x - top_left_x), # width + (bottom_right_y - top_left_y), # height + ], + axis=-1, + ) + return bboxes_center + + +def corners_to_center_format(bboxes_corners: TensorType) -> TensorType: + """ + Converts bounding boxes from corners format to center format. + + corners format: contains the coordinates for the top-left and bottom-right corners of the box + (top_left_x, top_left_y, bottom_right_x, bottom_right_y) + center format: contains the coordinate for the center of the box and its the width, height dimensions + (center_x, center_y, width, height) + """ + # Inverse function accepts different input types so implemented here too + if is_torch_tensor(bboxes_corners): + return _corners_to_center_format_torch(bboxes_corners) + elif isinstance(bboxes_corners, np.ndarray): + return _corners_to_center_format_numpy(bboxes_corners) + elif is_tf_tensor(bboxes_corners): + return _corners_to_center_format_tf(bboxes_corners) + + raise ValueError(f"Unsupported input type {type(bboxes_corners)}") + + +# 2 functions below copied from https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py +# Copyright (c) 2018, Alexander Kirillov +# All rights reserved. +def rgb_to_id(color): + """ + Converts RGB color to unique ID. + """ + if isinstance(color, np.ndarray) and len(color.shape) == 3: + if color.dtype == np.uint8: + color = color.astype(np.int32) + return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2] + return int(color[0] + 256 * color[1] + 256 * 256 * color[2]) + + +def id_to_rgb(id_map): + """ + Converts unique ID to RGB color. + """ + if isinstance(id_map, np.ndarray): + id_map_copy = id_map.copy() + rgb_shape = tuple(list(id_map.shape) + [3]) + rgb_map = np.zeros(rgb_shape, dtype=np.uint8) + for i in range(3): + rgb_map[..., i] = id_map_copy % 256 + id_map_copy //= 256 + return rgb_map + color = [] + for _ in range(3): + color.append(id_map % 256) + id_map //= 256 + return color + + +class PaddingMode(ExplicitEnum): + """ + Enum class for the different padding modes to use when padding images. + """ + + CONSTANT = "constant" + REFLECT = "reflect" + REPLICATE = "replicate" + SYMMETRIC = "symmetric" + + +def pad( + image: np.ndarray, + padding: Union[int, Tuple[int, int], Iterable[Tuple[int, int]]], + mode: PaddingMode = PaddingMode.CONSTANT, + constant_values: Union[float, Iterable[float]] = 0.0, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> np.ndarray: + """ + Pads the `image` with the specified (height, width) `padding` and `mode`. + + Args: + image (`np.ndarray`): + The image to pad. + padding (`int` or `Tuple[int, int]` or `Iterable[Tuple[int, int]]`): + Padding to apply to the edges of the height, width axes. Can be one of three formats: + - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis. + - `((before, after),)` yields same before and after pad for height and width. + - `(pad,)` or int is a shortcut for before = after = pad width for all axes. + mode (`PaddingMode`): + The padding mode to use. Can be one of: + - `"constant"`: pads with a constant value. + - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the + vector along each axis. + - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis. + - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array. + constant_values (`float` or `Iterable[float]`, *optional*): + The value to use for the padding if `mode` is `"constant"`. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use same as the input image. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use the inferred format of the input image. + + Returns: + `np.ndarray`: The padded image. + + """ + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + + def _expand_for_data_format(values): + """ + Convert values to be in the format expected by np.pad based on the data format. + """ + if isinstance(values, (int, float)): + values = ((values, values), (values, values)) + elif isinstance(values, tuple) and len(values) == 1: + values = ((values[0], values[0]), (values[0], values[0])) + elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], int): + values = (values, values) + elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], tuple): + values = values + else: + raise ValueError(f"Unsupported format: {values}") + + # add 0 for channel dimension + values = ((0, 0), *values) if input_data_format == ChannelDimension.FIRST else (*values, (0, 0)) + + # Add additional padding if there's a batch dimension + values = (0, *values) if image.ndim == 4 else values + return values + + padding = _expand_for_data_format(padding) + + if mode == PaddingMode.CONSTANT: + constant_values = _expand_for_data_format(constant_values) + image = np.pad(image, padding, mode="constant", constant_values=constant_values) + elif mode == PaddingMode.REFLECT: + image = np.pad(image, padding, mode="reflect") + elif mode == PaddingMode.REPLICATE: + image = np.pad(image, padding, mode="edge") + elif mode == PaddingMode.SYMMETRIC: + image = np.pad(image, padding, mode="symmetric") + else: + raise ValueError(f"Invalid padding mode: {mode}") + + image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image + return image + + +# TODO (Amy): Accept 1/3/4 channel numpy array as input and return np.array as default +def convert_to_rgb(image: ImageInput) -> ImageInput: + """ + Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image + as is. + Args: + image (Image): + The image to convert. + """ + requires_backends(convert_to_rgb, ["vision"]) + + if not isinstance(image, PIL.Image.Image): + return image + + if image.mode == "RGB": + return image + + image = image.convert("RGB") + return image + + +def flip_channel_order( + image: np.ndarray, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> np.ndarray: + """ + Flips the channel order of the image. + + If the image is in RGB format, it will be converted to BGR and vice versa. + + Args: + image (`np.ndarray`): + The image to flip. + data_format (`ChannelDimension`, *optional*): + The channel dimension format for the output image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use same as the input image. + input_data_format (`ChannelDimension`, *optional*): + The channel dimension format for the input image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use the inferred format of the input image. + """ + input_data_format = infer_channel_dimension_format(image) if input_data_format is None else input_data_format + + if input_data_format == ChannelDimension.LAST: + image = image[..., ::-1] + elif input_data_format == ChannelDimension.FIRST: + image = image[::-1, ...] + else: + raise ValueError(f"Unsupported channel dimension: {input_data_format}") + + if data_format is not None: + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + return image + + +def _cast_tensor_to_float(x): + if x.is_floating_point(): + return x + return x.float() + + +class FusedRescaleNormalize: + """ + Rescale and normalize the input image in one step. + """ + + def __init__(self, mean, std, rescale_factor: float = 1.0, inplace: bool = False): + self.mean = torch.tensor(mean) * (1.0 / rescale_factor) + self.std = torch.tensor(std) * (1.0 / rescale_factor) + self.inplace = inplace + + def __call__(self, image: "torch.Tensor"): + image = _cast_tensor_to_float(image) + return F.normalize(image, self.mean, self.std, inplace=self.inplace) + + +class Rescale: + """ + Rescale the input image by rescale factor: image *= rescale_factor. + """ + + def __init__(self, rescale_factor: float = 1.0): + self.rescale_factor = rescale_factor + + def __call__(self, image: "torch.Tensor"): + image = image * self.rescale_factor + return image + + +class NumpyToTensor: + """ + Convert a numpy array to a PyTorch tensor. + """ + + def __call__(self, image: np.ndarray): + # Same as in PyTorch, we assume incoming numpy images are in HWC format + # c.f. https://github.com/pytorch/vision/blob/61d97f41bc209e1407dcfbd685d2ee2da9c1cdad/torchvision/transforms/functional.py#L154 + return torch.from_numpy(image.transpose(2, 0, 1)).contiguous() diff --git a/image_utils.py b/image_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..51199d9f3698fc6212b5f8b3c90144fbf147ad41 --- /dev/null +++ b/image_utils.py @@ -0,0 +1,871 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import os +from io import BytesIO +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union + +import numpy as np +import requests +from packaging import version + +from .utils import ( + ExplicitEnum, + TensorType, + is_jax_tensor, + is_numpy_array, + is_tf_tensor, + is_torch_available, + is_torch_tensor, + is_torchvision_available, + is_vision_available, + logging, + requires_backends, + to_numpy, +) +from .utils.constants import ( # noqa: F401 + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, +) + + +if is_vision_available(): + import PIL.Image + import PIL.ImageOps + + if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PILImageResampling = PIL.Image.Resampling + else: + PILImageResampling = PIL.Image + + if is_torchvision_available(): + from torchvision.transforms import InterpolationMode + + pil_torch_interpolation_mapping = { + PILImageResampling.NEAREST: InterpolationMode.NEAREST, + PILImageResampling.BOX: InterpolationMode.BOX, + PILImageResampling.BILINEAR: InterpolationMode.BILINEAR, + PILImageResampling.HAMMING: InterpolationMode.HAMMING, + PILImageResampling.BICUBIC: InterpolationMode.BICUBIC, + PILImageResampling.LANCZOS: InterpolationMode.LANCZOS, + } + + +if TYPE_CHECKING: + if is_torch_available(): + import torch + + +logger = logging.get_logger(__name__) + + +ImageInput = Union[ + "PIL.Image.Image", np.ndarray, "torch.Tensor", List["PIL.Image.Image"], List[np.ndarray], List["torch.Tensor"] +] # noqa + + +VideoInput = Union[ + List["PIL.Image.Image"], + "np.ndarray", + "torch.Tensor", + List["np.ndarray"], + List["torch.Tensor"], + List[List["PIL.Image.Image"]], + List[List["np.ndarrray"]], + List[List["torch.Tensor"]], +] # noqa + + +class ChannelDimension(ExplicitEnum): + FIRST = "channels_first" + LAST = "channels_last" + + +class AnnotationFormat(ExplicitEnum): + COCO_DETECTION = "coco_detection" + COCO_PANOPTIC = "coco_panoptic" + + +class AnnotionFormat(ExplicitEnum): + COCO_DETECTION = AnnotationFormat.COCO_DETECTION.value + COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value + + +AnnotationType = Dict[str, Union[int, str, List[Dict]]] + + +def is_pil_image(img): + return is_vision_available() and isinstance(img, PIL.Image.Image) + + +class ImageType(ExplicitEnum): + PIL = "pillow" + TORCH = "torch" + NUMPY = "numpy" + TENSORFLOW = "tensorflow" + JAX = "jax" + + +def get_image_type(image): + if is_pil_image(image): + return ImageType.PIL + if is_torch_tensor(image): + return ImageType.TORCH + if is_numpy_array(image): + return ImageType.NUMPY + if is_tf_tensor(image): + return ImageType.TENSORFLOW + if is_jax_tensor(image): + return ImageType.JAX + raise ValueError(f"Unrecognised image type {type(image)}") + + +def is_valid_image(img): + return is_pil_image(img) or is_numpy_array(img) or is_torch_tensor(img) or is_tf_tensor(img) or is_jax_tensor(img) + + +def valid_images(imgs): + # If we have an list of images, make sure every image is valid + if isinstance(imgs, (list, tuple)): + for img in imgs: + if not valid_images(img): + return False + # If not a list of tuple, we have been given a single image or batched tensor of images + elif not is_valid_image(imgs): + return False + return True + + +def is_batched(img): + if isinstance(img, (list, tuple)): + return is_valid_image(img[0]) + return False + + +def is_scaled_image(image: np.ndarray) -> bool: + """ + Checks to see whether the pixel values have already been rescaled to [0, 1]. + """ + if image.dtype == np.uint8: + return False + + # It's possible the image has pixel values in [0, 255] but is of floating type + return np.min(image) >= 0 and np.max(image) <= 1 + + +def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]: + """ + Ensure that the input is a list of images. If the input is a single image, it is converted to a list of length 1. + If the input is a batch of images, it is converted to a list of images. + + Args: + images (`ImageInput`): + Image of images to turn into a list of images. + expected_ndims (`int`, *optional*, defaults to 3): + Expected number of dimensions for a single input image. If the input image has a different number of + dimensions, an error is raised. + """ + if is_batched(images): + return images + + # Either the input is a single image, in which case we create a list of length 1 + if isinstance(images, PIL.Image.Image): + # PIL images are never batched + return [images] + + if is_valid_image(images): + if images.ndim == expected_ndims + 1: + # Batch of images + images = list(images) + elif images.ndim == expected_ndims: + # Single image + images = [images] + else: + raise ValueError( + f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got" + f" {images.ndim} dimensions." + ) + return images + raise ValueError( + "Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or " + f"jax.ndarray, but got {type(images)}." + ) + + +def to_numpy_array(img) -> np.ndarray: + if not is_valid_image(img): + raise ValueError(f"Invalid image type: {type(img)}") + + if is_vision_available() and isinstance(img, PIL.Image.Image): + return np.array(img) + return to_numpy(img) + + +def infer_channel_dimension_format( + image: np.ndarray, num_channels: Optional[Union[int, Tuple[int, ...]]] = None +) -> ChannelDimension: + """ + Infers the channel dimension format of `image`. + + Args: + image (`np.ndarray`): + The image to infer the channel dimension of. + num_channels (`int` or `Tuple[int, ...]`, *optional*, defaults to `(1, 3)`): + The number of channels of the image. + + Returns: + The channel dimension of the image. + """ + num_channels = num_channels if num_channels is not None else (1, 3) + num_channels = (num_channels,) if isinstance(num_channels, int) else num_channels + + if image.ndim == 3: + first_dim, last_dim = 0, 2 + elif image.ndim == 4: + first_dim, last_dim = 1, 3 + else: + raise ValueError(f"Unsupported number of image dimensions: {image.ndim}") + + if image.shape[first_dim] in num_channels and image.shape[last_dim] in num_channels: + logger.warning( + f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension." + ) + return ChannelDimension.FIRST + elif image.shape[first_dim] in num_channels: + return ChannelDimension.FIRST + elif image.shape[last_dim] in num_channels: + return ChannelDimension.LAST + raise ValueError("Unable to infer channel dimension format") + + +def get_channel_dimension_axis( + image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None +) -> int: + """ + Returns the channel dimension axis of the image. + + Args: + image (`np.ndarray`): + The image to get the channel dimension axis of. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the image. If `None`, will infer the channel dimension from the image. + + Returns: + The channel dimension axis of the image. + """ + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + if input_data_format == ChannelDimension.FIRST: + return image.ndim - 3 + elif input_data_format == ChannelDimension.LAST: + return image.ndim - 1 + raise ValueError(f"Unsupported data format: {input_data_format}") + + +def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]: + """ + Returns the (height, width) dimensions of the image. + + Args: + image (`np.ndarray`): + The image to get the dimensions of. + channel_dim (`ChannelDimension`, *optional*): + Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image. + + Returns: + A tuple of the image's height and width. + """ + if channel_dim is None: + channel_dim = infer_channel_dimension_format(image) + + if channel_dim == ChannelDimension.FIRST: + return image.shape[-2], image.shape[-1] + elif channel_dim == ChannelDimension.LAST: + return image.shape[-3], image.shape[-2] + else: + raise ValueError(f"Unsupported data format: {channel_dim}") + + +def is_valid_annotation_coco_detection(annotation: Dict[str, Union[List, Tuple]]) -> bool: + if ( + isinstance(annotation, dict) + and "image_id" in annotation + and "annotations" in annotation + and isinstance(annotation["annotations"], (list, tuple)) + and ( + # an image can have no annotations + len(annotation["annotations"]) == 0 or isinstance(annotation["annotations"][0], dict) + ) + ): + return True + return False + + +def is_valid_annotation_coco_panoptic(annotation: Dict[str, Union[List, Tuple]]) -> bool: + if ( + isinstance(annotation, dict) + and "image_id" in annotation + and "segments_info" in annotation + and "file_name" in annotation + and isinstance(annotation["segments_info"], (list, tuple)) + and ( + # an image can have no segments + len(annotation["segments_info"]) == 0 or isinstance(annotation["segments_info"][0], dict) + ) + ): + return True + return False + + +def valid_coco_detection_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool: + return all(is_valid_annotation_coco_detection(ann) for ann in annotations) + + +def valid_coco_panoptic_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool: + return all(is_valid_annotation_coco_panoptic(ann) for ann in annotations) + + +def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = None) -> "PIL.Image.Image": + """ + Loads `image` to a PIL Image. + + Args: + image (`str` or `PIL.Image.Image`): + The image to convert to the PIL Image format. + timeout (`float`, *optional*): + The timeout value in seconds for the URL request. + + Returns: + `PIL.Image.Image`: A PIL Image. + """ + requires_backends(load_image, ["vision"]) + if isinstance(image, str): + if image.startswith("http://") or image.startswith("https://"): + # We need to actually check for a real protocol, otherwise it's impossible to use a local file + # like http_huggingface_co.png + image = PIL.Image.open(BytesIO(requests.get(image, timeout=timeout).content)) + elif os.path.isfile(image): + image = PIL.Image.open(image) + else: + if image.startswith("data:image/"): + image = image.split(",")[1] + + # Try to load as base64 + try: + b64 = base64.decodebytes(image.encode()) + image = PIL.Image.open(BytesIO(b64)) + except Exception as e: + raise ValueError( + f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}" + ) + elif isinstance(image, PIL.Image.Image): + image = image + else: + raise TypeError( + "Incorrect format used for image. Should be an url linking to an image, a base64 string, a local path, or a PIL image." + ) + image = PIL.ImageOps.exif_transpose(image) + image = image.convert("RGB") + return image + + +def load_images( + images: Union[List, Tuple, str, "PIL.Image.Image"], timeout: Optional[float] = None +) -> Union["PIL.Image.Image", List["PIL.Image.Image"], List[List["PIL.Image.Image"]]]: + """Loads images, handling different levels of nesting. + + Args: + images: A single image, a list of images, or a list of lists of images to load. + timeout: Timeout for loading images. + + Returns: + A single image, a list of images, a list of lists of images. + """ + if isinstance(images, (list, tuple)): + if len(images) and isinstance(images[0], (list, tuple)): + return [[load_image(image, timeout=timeout) for image in image_group] for image_group in images] + else: + return [load_image(image, timeout=timeout) for image in images] + else: + return load_image(images, timeout=timeout) + + +def validate_preprocess_arguments( + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + size_divisibility: Optional[int] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[Dict[str, int]] = None, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + resample: Optional["PILImageResampling"] = None, +): + """ + Checks validity of typically used arguments in an `ImageProcessor` `preprocess` method. + Raises `ValueError` if arguments incompatibility is caught. + Many incompatibilities are model-specific. `do_pad` sometimes needs `size_divisor`, + sometimes `size_divisibility`, and sometimes `size`. New models and processors added should follow + existing arguments when possible. + + """ + if do_rescale and rescale_factor is None: + raise ValueError("`rescale_factor` must be specified if `do_rescale` is `True`.") + + if do_pad and size_divisibility is None: + # Here, size_divisor might be passed as the value of size + raise ValueError( + "Depending on the model, `size_divisibility`, `size_divisor`, `pad_size` or `size` must be specified if `do_pad` is `True`." + ) + + if do_normalize and (image_mean is None or image_std is None): + raise ValueError("`image_mean` and `image_std` must both be specified if `do_normalize` is `True`.") + + if do_center_crop and crop_size is None: + raise ValueError("`crop_size` must be specified if `do_center_crop` is `True`.") + + if do_resize and (size is None or resample is None): + raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.") + + +def validate_fast_preprocess_arguments( + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + size_divisibility: Optional[int] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[Dict[str, int]] = None, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + resample: Optional["PILImageResampling"] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, +): + """ + Checks validity of typically used arguments in an `ImageProcessorFast` `preprocess` method. + Raises `ValueError` if arguments incompatibility is caught. + """ + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + # Extra checks for ImageProcessorFast + if return_tensors != "pt": + raise ValueError("Only returning PyTorch tensors is currently supported.") + + if data_format != ChannelDimension.FIRST: + raise ValueError("Only channel first data format is currently supported.") + + +# In the future we can add a TF implementation here when we have TF models. +class ImageFeatureExtractionMixin: + """ + Mixin that contain utilities for preparing image features. + """ + + def _ensure_format_supported(self, image): + if not isinstance(image, (PIL.Image.Image, np.ndarray)) and not is_torch_tensor(image): + raise ValueError( + f"Got type {type(image)} which is not supported, only `PIL.Image.Image`, `np.array` and " + "`torch.Tensor` are." + ) + + def to_pil_image(self, image, rescale=None): + """ + Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if + needed. + + Args: + image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`): + The image to convert to the PIL Image format. + rescale (`bool`, *optional*): + Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will + default to `True` if the image type is a floating type, `False` otherwise. + """ + self._ensure_format_supported(image) + + if is_torch_tensor(image): + image = image.numpy() + + if isinstance(image, np.ndarray): + if rescale is None: + # rescale default to the array being of floating type. + rescale = isinstance(image.flat[0], np.floating) + # If the channel as been moved to first dim, we put it back at the end. + if image.ndim == 3 and image.shape[0] in [1, 3]: + image = image.transpose(1, 2, 0) + if rescale: + image = image * 255 + image = image.astype(np.uint8) + return PIL.Image.fromarray(image) + return image + + def convert_rgb(self, image): + """ + Converts `PIL.Image.Image` to RGB format. + + Args: + image (`PIL.Image.Image`): + The image to convert. + """ + self._ensure_format_supported(image) + if not isinstance(image, PIL.Image.Image): + return image + + return image.convert("RGB") + + def rescale(self, image: np.ndarray, scale: Union[float, int]) -> np.ndarray: + """ + Rescale a numpy image by scale amount + """ + self._ensure_format_supported(image) + return image * scale + + def to_numpy_array(self, image, rescale=None, channel_first=True): + """ + Converts `image` to a numpy array. Optionally rescales it and puts the channel dimension as the first + dimension. + + Args: + image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`): + The image to convert to a NumPy array. + rescale (`bool`, *optional*): + Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Will + default to `True` if the image is a PIL Image or an array/tensor of integers, `False` otherwise. + channel_first (`bool`, *optional*, defaults to `True`): + Whether or not to permute the dimensions of the image to put the channel dimension first. + """ + self._ensure_format_supported(image) + + if isinstance(image, PIL.Image.Image): + image = np.array(image) + + if is_torch_tensor(image): + image = image.numpy() + + rescale = isinstance(image.flat[0], np.integer) if rescale is None else rescale + + if rescale: + image = self.rescale(image.astype(np.float32), 1 / 255.0) + + if channel_first and image.ndim == 3: + image = image.transpose(2, 0, 1) + + return image + + def expand_dims(self, image): + """ + Expands 2-dimensional `image` to 3 dimensions. + + Args: + image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`): + The image to expand. + """ + self._ensure_format_supported(image) + + # Do nothing if PIL image + if isinstance(image, PIL.Image.Image): + return image + + if is_torch_tensor(image): + image = image.unsqueeze(0) + else: + image = np.expand_dims(image, axis=0) + return image + + def normalize(self, image, mean, std, rescale=False): + """ + Normalizes `image` with `mean` and `std`. Note that this will trigger a conversion of `image` to a NumPy array + if it's a PIL Image. + + Args: + image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`): + The image to normalize. + mean (`List[float]` or `np.ndarray` or `torch.Tensor`): + The mean (per channel) to use for normalization. + std (`List[float]` or `np.ndarray` or `torch.Tensor`): + The standard deviation (per channel) to use for normalization. + rescale (`bool`, *optional*, defaults to `False`): + Whether or not to rescale the image to be between 0 and 1. If a PIL image is provided, scaling will + happen automatically. + """ + self._ensure_format_supported(image) + + if isinstance(image, PIL.Image.Image): + image = self.to_numpy_array(image, rescale=True) + # If the input image is a PIL image, it automatically gets rescaled. If it's another + # type it may need rescaling. + elif rescale: + if isinstance(image, np.ndarray): + image = self.rescale(image.astype(np.float32), 1 / 255.0) + elif is_torch_tensor(image): + image = self.rescale(image.float(), 1 / 255.0) + + if isinstance(image, np.ndarray): + if not isinstance(mean, np.ndarray): + mean = np.array(mean).astype(image.dtype) + if not isinstance(std, np.ndarray): + std = np.array(std).astype(image.dtype) + elif is_torch_tensor(image): + import torch + + if not isinstance(mean, torch.Tensor): + if isinstance(mean, np.ndarray): + mean = torch.from_numpy(mean) + else: + mean = torch.tensor(mean) + if not isinstance(std, torch.Tensor): + if isinstance(std, np.ndarray): + std = torch.from_numpy(std) + else: + std = torch.tensor(std) + + if image.ndim == 3 and image.shape[0] in [1, 3]: + return (image - mean[:, None, None]) / std[:, None, None] + else: + return (image - mean) / std + + def resize(self, image, size, resample=None, default_to_square=True, max_size=None): + """ + Resizes `image`. Enforces conversion of input to PIL.Image. + + Args: + image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`): + The image to resize. + size (`int` or `Tuple[int, int]`): + The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be + matched to this. + + If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If + `size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to + this number. i.e, if height > width, then image will be rescaled to (size * height / width, size). + resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`): + The filter to user for resampling. + default_to_square (`bool`, *optional*, defaults to `True`): + How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a + square (`size`,`size`). If set to `False`, will replicate + [`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize) + with support for resizing only the smallest edge and providing an optional `max_size`. + max_size (`int`, *optional*, defaults to `None`): + The maximum allowed for the longer edge of the resized image: if the longer edge of the image is + greater than `max_size` after being resized according to `size`, then the image is resized again so + that the longer edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller + edge may be shorter than `size`. Only used if `default_to_square` is `False`. + + Returns: + image: A resized `PIL.Image.Image`. + """ + resample = resample if resample is not None else PILImageResampling.BILINEAR + + self._ensure_format_supported(image) + + if not isinstance(image, PIL.Image.Image): + image = self.to_pil_image(image) + + if isinstance(size, list): + size = tuple(size) + + if isinstance(size, int) or len(size) == 1: + if default_to_square: + size = (size, size) if isinstance(size, int) else (size[0], size[0]) + else: + width, height = image.size + # specified size only for the smallest edge + short, long = (width, height) if width <= height else (height, width) + requested_new_short = size if isinstance(size, int) else size[0] + + if short == requested_new_short: + return image + + new_short, new_long = requested_new_short, int(requested_new_short * long / short) + + if max_size is not None: + if max_size <= requested_new_short: + raise ValueError( + f"max_size = {max_size} must be strictly greater than the requested " + f"size for the smaller edge size = {size}" + ) + if new_long > max_size: + new_short, new_long = int(max_size * new_short / new_long), max_size + + size = (new_short, new_long) if width <= height else (new_long, new_short) + + return image.resize(size, resample=resample) + + def center_crop(self, image, size): + """ + Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the + size given, it will be padded (so the returned result has the size asked). + + Args: + image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape (n_channels, height, width) or (height, width, n_channels)): + The image to resize. + size (`int` or `Tuple[int, int]`): + The size to which crop the image. + + Returns: + new_image: A center cropped `PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape: (n_channels, + height, width). + """ + self._ensure_format_supported(image) + + if not isinstance(size, tuple): + size = (size, size) + + # PIL Image.size is (width, height) but NumPy array and torch Tensors have (height, width) + if is_torch_tensor(image) or isinstance(image, np.ndarray): + if image.ndim == 2: + image = self.expand_dims(image) + image_shape = image.shape[1:] if image.shape[0] in [1, 3] else image.shape[:2] + else: + image_shape = (image.size[1], image.size[0]) + + top = (image_shape[0] - size[0]) // 2 + bottom = top + size[0] # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result. + left = (image_shape[1] - size[1]) // 2 + right = left + size[1] # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result. + + # For PIL Images we have a method to crop directly. + if isinstance(image, PIL.Image.Image): + return image.crop((left, top, right, bottom)) + + # Check if image is in (n_channels, height, width) or (height, width, n_channels) format + channel_first = True if image.shape[0] in [1, 3] else False + + # Transpose (height, width, n_channels) format images + if not channel_first: + if isinstance(image, np.ndarray): + image = image.transpose(2, 0, 1) + if is_torch_tensor(image): + image = image.permute(2, 0, 1) + + # Check if cropped area is within image boundaries + if top >= 0 and bottom <= image_shape[0] and left >= 0 and right <= image_shape[1]: + return image[..., top:bottom, left:right] + + # Otherwise, we may need to pad if the image is too small. Oh joy... + new_shape = image.shape[:-2] + (max(size[0], image_shape[0]), max(size[1], image_shape[1])) + if isinstance(image, np.ndarray): + new_image = np.zeros_like(image, shape=new_shape) + elif is_torch_tensor(image): + new_image = image.new_zeros(new_shape) + + top_pad = (new_shape[-2] - image_shape[0]) // 2 + bottom_pad = top_pad + image_shape[0] + left_pad = (new_shape[-1] - image_shape[1]) // 2 + right_pad = left_pad + image_shape[1] + new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image + + top += top_pad + bottom += top_pad + left += left_pad + right += left_pad + + new_image = new_image[ + ..., max(0, top) : min(new_image.shape[-2], bottom), max(0, left) : min(new_image.shape[-1], right) + ] + + return new_image + + def flip_channel_order(self, image): + """ + Flips the channel order of `image` from RGB to BGR, or vice versa. Note that this will trigger a conversion of + `image` to a NumPy array if it's a PIL Image. + + Args: + image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`): + The image whose color channels to flip. If `np.ndarray` or `torch.Tensor`, the channel dimension should + be first. + """ + self._ensure_format_supported(image) + + if isinstance(image, PIL.Image.Image): + image = self.to_numpy_array(image) + + return image[::-1, :, :] + + def rotate(self, image, angle, resample=None, expand=0, center=None, translate=None, fillcolor=None): + """ + Returns a rotated copy of `image`. This method returns a copy of `image`, rotated the given number of degrees + counter clockwise around its centre. + + Args: + image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`): + The image to rotate. If `np.ndarray` or `torch.Tensor`, will be converted to `PIL.Image.Image` before + rotating. + + Returns: + image: A rotated `PIL.Image.Image`. + """ + resample = resample if resample is not None else PIL.Image.NEAREST + + self._ensure_format_supported(image) + + if not isinstance(image, PIL.Image.Image): + image = self.to_pil_image(image) + + return image.rotate( + angle, resample=resample, expand=expand, center=center, translate=translate, fillcolor=fillcolor + ) + + +def validate_annotations( + annotation_format: AnnotationFormat, + supported_annotation_formats: Tuple[AnnotationFormat, ...], + annotations: List[Dict], +) -> None: + if annotation_format not in supported_annotation_formats: + raise ValueError(f"Unsupported annotation format: {format} must be one of {supported_annotation_formats}") + + if annotation_format is AnnotationFormat.COCO_DETECTION: + if not valid_coco_detection_annotations(annotations): + raise ValueError( + "Invalid COCO detection annotations. Annotations must a dict (single image) or list of dicts " + "(batch of images) with the following keys: `image_id` and `annotations`, with the latter " + "being a list of annotations in the COCO format." + ) + + if annotation_format is AnnotationFormat.COCO_PANOPTIC: + if not valid_coco_panoptic_annotations(annotations): + raise ValueError( + "Invalid COCO panoptic annotations. Annotations must a dict (single image) or list of dicts " + "(batch of images) with the following keys: `image_id`, `file_name` and `segments_info`, with " + "the latter being a list of annotations in the COCO format." + ) + + +def validate_kwargs(valid_processor_keys: List[str], captured_kwargs: List[str]): + unused_keys = set(captured_kwargs).difference(set(valid_processor_keys)) + if unused_keys: + unused_key_str = ", ".join(unused_keys) + # TODO raise a warning here instead of simply logging? + logger.warning(f"Unused or unrecognized kwargs: {unused_key_str}.") diff --git a/keras_callbacks.py b/keras_callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..b6e832729a1eeb482d1193753cc2c07ad1f16c2e --- /dev/null +++ b/keras_callbacks.py @@ -0,0 +1,413 @@ +import logging +import os +from pathlib import Path +from time import sleep +from typing import Callable, List, Optional, Union + +import numpy as np +import tensorflow as tf +from huggingface_hub import Repository, create_repo +from packaging.version import parse + +from . import IntervalStrategy, PreTrainedTokenizerBase +from .modelcard import TrainingSummary +from .modeling_tf_utils import keras + + +logger = logging.getLogger(__name__) + + +class KerasMetricCallback(keras.callbacks.Callback): + """ + Callback to compute metrics at the end of every epoch. Unlike normal Keras metrics, these do not need to be + compilable by TF. It is particularly useful for common NLP metrics like BLEU and ROUGE that require string + operations or generation loops that cannot be compiled. Predictions (or generations) will be computed on the + `eval_dataset` before being passed to the `metric_fn` in `np.ndarray` format. The `metric_fn` should compute + metrics and return a dict mapping metric names to metric values. + + We provide an example of a suitable metric_fn that computes ROUGE scores for a summarization model below. Note that + this example skips some post-processing for readability and simplicity, and should probably not be used as-is! + + ```py + from datasets import load_metric + + rouge_metric = load_metric("rouge") + + + def rouge_fn(predictions, labels): + decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True) + decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) + result = rouge_metric.compute(predictions=decoded_predictions, references=decoded_labels) + return {key: value.mid.fmeasure * 100 for key, value in result.items()} + ``` + + The above function will return a dict containing values which will be logged like any other Keras metric: + + ``` + {'rouge1': 37.4199, 'rouge2': 13.9768, 'rougeL': 34.361, 'rougeLsum': 35.0781 + ``` + + Args: + metric_fn (`Callable`): + Metric function provided by the user. It will be called with two arguments - `predictions` and `labels`. + These contain the model's outputs and matching labels from the dataset. It should return a dict mapping + metric names to numerical values. + eval_dataset (`tf.data.Dataset` or `dict` or `tuple` or `np.ndarray` or `tf.Tensor`): + Validation data to be used to generate predictions for the `metric_fn`. + output_cols (`List[str], *optional*): + A list of columns to be retained from the model output as the predictions. Defaults to all. + label_cols ('`List[str]`, *optional*'): + A list of columns to be retained from the input dataset as the labels. Will be autodetected if this is not + supplied. + batch_size (`int`, *optional*): + Batch size. Only used when the data is not a pre-batched `tf.data.Dataset`. + predict_with_generate (`bool`, *optional*, defaults to `False`): + Whether we should use `model.generate()` to get outputs for the model. + use_xla_generation (`bool`, *optional*, defaults to `False`): + If we're generating, whether to compile model generation with XLA. This can massively increase the speed of + generation (up to 100X speedup) but will require a new XLA compilation for each input shape. When using XLA + generation, it's a good idea to pad your inputs to the same size, or to use the `pad_to_multiple_of` + argument in your `tokenizer` or `DataCollator`, which will reduce the number of unique input shapes and + save a lot of compilation time. This option has no effect is `predict_with_generate` is `False`. + generate_kwargs (`dict`, *optional*): + Keyword arguments to pass to `model.generate()` when generating. Has no effect if `predict_with_generate` + is `False`. + + """ + + def __init__( + self, + metric_fn: Callable, + eval_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict], + output_cols: Optional[List[str]] = None, + label_cols: Optional[List[str]] = None, + batch_size: Optional[int] = None, + predict_with_generate: bool = False, + use_xla_generation: bool = False, + generate_kwargs: Optional[dict] = None, + ): + super().__init__() + self.metric_fn = metric_fn + self.batch_size = batch_size + if not isinstance(eval_dataset, tf.data.Dataset): + if batch_size is None: + raise ValueError( + "When passing data to KerasMetricCallback that is not a pre-batched tf.data.Dataset " + "the batch_size argument must be set." + ) + # Wrap a tf.data.Dataset around it + eval_dataset = tf.data.Dataset.from_tensor_slices(eval_dataset).batch(batch_size, drop_remainder=False) + self.eval_dataset = eval_dataset + self.predict_with_generate = predict_with_generate + self.output_cols = output_cols + + # This next block attempts to parse out which elements of the dataset should be appended to the labels list + # that is passed to the metric_fn + if isinstance(eval_dataset.element_spec, tuple) and len(eval_dataset.element_spec) == 2: + input_spec, label_spec = eval_dataset.element_spec + else: + input_spec = eval_dataset.element_spec + label_spec = None + if label_cols is not None: + for label in label_cols: + if label not in input_spec: + raise ValueError(f"Label {label} is in label_cols but could not be found in the dataset inputs!") + self.label_cols = label_cols + self.use_keras_label = False + elif label_spec is not None: + # If the dataset inputs are split into a 2-tuple of inputs and labels, + # assume the second element is the labels + self.label_cols = None + self.use_keras_label = True + elif "labels" in input_spec: + self.label_cols = ["labels"] + self.use_keras_label = False + logging.warning("No label_cols specified for KerasMetricCallback, assuming you want the 'labels' key.") + elif "start_positions" in input_spec and "end_positions" in input_spec: + self.label_cols = ["start_positions", "end_positions"] + self.use_keras_label = False + logging.warning( + "No label_cols specified for KerasMetricCallback, assuming you want the " + "start_positions and end_positions keys." + ) + else: + raise ValueError("Could not autodetect label_cols for KerasMetricCallback, please specify them!") + if parse(tf.__version__) < parse("2.7"): + logging.warning("TF versions less than 2.7 may encounter issues with KerasMetricCallback!") + + self.use_xla_generation = use_xla_generation + self.generate_kwargs = {} if generate_kwargs is None else generate_kwargs + + self.generation_function = None + + @staticmethod + def _concatenate_batches(batches, padding_index=-100): + # If all batches are unidimensional or same length, do a simple concatenation + if batches[0].ndim == 1 or all(batch.shape[1] == batches[0].shape[1] for batch in batches): + return np.concatenate(batches, axis=0) + + # Welp, they're not the same length. Let's do some padding + max_len = max([batch.shape[1] for batch in batches]) + num_samples = sum([batch.shape[0] for batch in batches]) + output = np.full_like( + batches[0], fill_value=padding_index, shape=[num_samples, max_len] + list(batches[0].shape[2:]) + ) + # i keeps track of which part of the concatenated array we're writing the next batch to + i = 0 + for batch in batches: + output[i : i + len(batch), : batch.shape[1]] = batch + i += len(batch) + return output + + def _postprocess_predictions_or_labels(self, inputs): + if isinstance(inputs[0], dict): + outputs = {} + for key in inputs[0].keys(): + outputs[key] = self._concatenate_batches([batch[key] for batch in inputs]) + # If it's a dict with only one key, just return the array + if len(outputs) == 1: + outputs = list(outputs.values())[0] + elif isinstance(inputs[0], list) or isinstance(inputs[0], tuple): + outputs = [] + for input_list in zip(*inputs): + outputs.append(self._concatenate_batches(input_list)) + if len(outputs) == 1: + outputs = outputs[0] # If it's a list with only one element, just return the array + elif isinstance(inputs[0], np.ndarray): + outputs = self._concatenate_batches(inputs) + elif isinstance(inputs[0], tf.Tensor): + outputs = self._concatenate_batches([tensor.numpy() for tensor in inputs]) + else: + raise TypeError(f"Couldn't handle batch of type {type(inputs[0])}!") + return outputs + + def on_epoch_end(self, epoch, logs=None): + if hasattr(self.model, "config"): + ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + main_input_name = None + if self.predict_with_generate: + # This dense conditional recognizes the case where we have an encoder-decoder model, but + # avoids getting tangled up when we just have a model with a layer called 'encoder' + if hasattr(self.model, "encoder") and hasattr(self.model.encoder, "main_input_name"): + main_input_name = self.model.encoder.main_input_name + else: + main_input_name = getattr(self.model, "main_input_name", "input_ids") + + if self.use_xla_generation and self.generation_function is None: + + def generation_function(inputs, attention_mask): + return self.model.generate(inputs, attention_mask=attention_mask, **self.generate_kwargs) + + self.generation_function = tf.function(generation_function, jit_compile=True) + + prediction_list = [] + label_list = [] + + # The whole predict/generate loop is handled inside this method + for batch in self.eval_dataset: + if isinstance(batch, tuple): + batch, labels = batch + else: + labels = None + if self.predict_with_generate: + if isinstance(batch, dict): + generation_inputs = batch[main_input_name] + attention_mask = batch.get("attention_mask", None) + else: + generation_inputs = batch + attention_mask = None + if self.use_xla_generation: + predictions = self.generation_function(generation_inputs, attention_mask=attention_mask) + else: + predictions = self.model.generate( + generation_inputs, attention_mask=attention_mask, **self.generate_kwargs + ) + else: + predictions = self.model.predict_on_batch(batch) + if isinstance(predictions, dict): + # This converts any dict-subclass to a regular dict + # Keras REALLY doesn't like it when we pass around a BatchEncoding or other derived class + predictions = dict(predictions) + if self.output_cols is not None: + predictions = {key: predictions[key] for key in self.output_cols} + else: + predictions = { + key: val for key, val in predictions.items() if key not in ignore_keys + ["loss"] + } + prediction_list.append(predictions) + if not self.use_keras_label: + labels = {key: batch[key].numpy() for key in self.label_cols} + elif isinstance(labels, dict): + labels = {key: array.numpy() for key, array in labels.items()} + elif isinstance(labels, list) or isinstance(labels, tuple): + labels = [array.numpy() for array in labels] + elif isinstance(labels, tf.Tensor): + labels = labels.numpy() + else: + raise TypeError(f"Confused by labels of type {type(labels)}") + label_list.append(labels) + + all_preds = self._postprocess_predictions_or_labels(prediction_list) + all_labels = self._postprocess_predictions_or_labels(label_list) + + metric_output = self.metric_fn((all_preds, all_labels)) + if not isinstance(metric_output, dict): + raise TypeError( + f"metric_fn should return a dict mapping metric names to values but instead returned {metric_output}" + ) + # This is the critical bit - Keras passes a dict containing the loss and standard metric values for this epoch + # in the logs argument. Ordinarily, this is so the callback can read them, but in this case we write a bunch of + # new keys in there, which will then get read by the History callback and treated like any other metric value. + # I promise that I have it in writing from Chollet that this is okay. + logs.update(metric_output) + + +class PushToHubCallback(keras.callbacks.Callback): + """ + Callback that will save and push the model to the Hub regularly. By default, it pushes once per epoch, but this can + be changed with the `save_strategy` argument. Pushed models can be accessed like any other model on the hub, such + as with the `from_pretrained` method. + + ```py + from transformers.keras_callbacks import PushToHubCallback + + push_to_hub_callback = PushToHubCallback( + output_dir="./model_save", + tokenizer=tokenizer, + hub_model_id="gpt5-7xlarge", + ) + + model.fit(train_dataset, callbacks=[push_to_hub_callback]) + ``` + + Args: + output_dir (`str`): + The output directory where the model predictions and checkpoints will be written and synced with the + repository on the Hub. + save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"epoch"`): + The checkpoint save strategy to adopt during training. Possible values are: + + - `"no"`: Save is done at the end of training. + - `"epoch"`: Save is done at the end of each epoch. + - `"steps"`: Save is done every `save_steps` + save_steps (`int`, *optional*): + The number of steps between saves when using the "steps" `save_strategy`. + tokenizer (`PreTrainedTokenizerBase`, *optional*): + The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights. + hub_model_id (`str`, *optional*): + The name of the repository to keep in sync with the local `output_dir`. It can be a simple model ID in + which case the model will be pushed in your namespace. Otherwise it should be the whole repository name, + for instance `"user_name/model"`, which allows you to push to an organization you are a member of with + `"organization_name/model"`. + + Will default to the name of `output_dir`. + hub_token (`str`, *optional*): + The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with + `huggingface-cli login`. + checkpoint (`bool`, *optional*, defaults to `False`): + Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be + resumed. Only usable when `save_strategy` is `"epoch"`. + """ + + def __init__( + self, + output_dir: Union[str, Path], + save_strategy: Union[str, IntervalStrategy] = "epoch", + save_steps: Optional[int] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + hub_model_id: Optional[str] = None, + hub_token: Optional[str] = None, + checkpoint: bool = False, + **model_card_args, + ): + super().__init__() + if checkpoint and save_strategy != "epoch": + raise ValueError("Cannot save checkpoints when save_strategy is not 'epoch'!") + if isinstance(save_strategy, str): + save_strategy = IntervalStrategy(save_strategy.lower()) + self.save_strategy = save_strategy + if self.save_strategy == IntervalStrategy.STEPS and (not isinstance(save_steps, int) or save_steps <= 0): + raise ValueError("Please supply a positive integer argument for save_steps when save_strategy == 'steps'!") + self.save_steps = save_steps + output_dir = Path(output_dir) + + # Create repo and retrieve repo_id + if hub_model_id is None: + hub_model_id = output_dir.absolute().name + self.hub_model_id = create_repo(repo_id=hub_model_id, exist_ok=True, token=hub_token).repo_id + + self.output_dir = output_dir + self.repo = Repository(str(self.output_dir), clone_from=self.hub_model_id, token=hub_token) + + self.tokenizer = tokenizer + self.last_job = None + self.checkpoint = checkpoint + self.training_history = None + self.model_card_args = model_card_args + + def on_train_begin(self, logs=None): + # Although we can access model.history, we have no guarantees that the History callback will fire before this + # one, so we keep track of it here too + self.training_history = [] + + def on_train_batch_end(self, batch, logs=None): + if self.save_strategy == IntervalStrategy.STEPS and (batch + 1) % self.save_steps == 0: + if self.last_job is not None and not self.last_job.is_done: + return # The last upload is still running, don't start another + self.model.save_pretrained(self.output_dir) + if self.tokenizer is not None: + self.tokenizer.save_pretrained(self.output_dir) + _, self.last_job = self.repo.push_to_hub( + commit_message=f"Training in progress steps {batch}", blocking=False + ) + + def on_epoch_end(self, epoch, logs=None): + logs = logs.copy() # Don't accidentally write things that Keras will read later + if "epoch" not in logs: + logs["epoch"] = epoch + self.training_history.append(logs) + if self.save_strategy == IntervalStrategy.EPOCH: + if self.last_job is not None and not self.last_job.is_done: + return # The last upload is still running, don't start another + self.model.save_pretrained(self.output_dir) + if self.tokenizer is not None: + self.tokenizer.save_pretrained(self.output_dir) + if self.checkpoint: + checkpoint_dir = os.path.join(self.output_dir, "checkpoint") + self.model._save_checkpoint(checkpoint_dir, epoch) + train_summary = TrainingSummary.from_keras( + model=self.model, + model_name=self.hub_model_id, + keras_history=self.training_history, + **self.model_card_args, + ) + model_card = train_summary.to_model_card() + with (self.output_dir / "README.md").open("w") as f: + f.write(model_card) + _, self.last_job = self.repo.push_to_hub( + commit_message=f"Training in progress epoch {epoch}", blocking=False + ) + + def on_train_end(self, logs=None): + # Makes sure the latest version of the model is uploaded + if self.last_job is not None and not self.last_job.is_done: + logging.info("Pushing the last epoch to the Hub, this may take a while...") + while not self.last_job.is_done: + sleep(1) + else: + self.model.save_pretrained(self.output_dir) + if self.tokenizer is not None: + self.tokenizer.save_pretrained(self.output_dir) + train_summary = TrainingSummary.from_keras( + model=self.model, + model_name=self.hub_model_id, + keras_history=self.training_history, + **self.model_card_args, + ) + model_card = train_summary.to_model_card() + with (self.output_dir / "README.md").open("w") as f: + f.write(model_card) + self.repo.push_to_hub(commit_message="End of training", blocking=True) diff --git a/modelcard.py b/modelcard.py new file mode 100644 index 0000000000000000000000000000000000000000..57cd971dd79a53fb2acc953ba2e920a0073a7e87 --- /dev/null +++ b/modelcard.py @@ -0,0 +1,908 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Configuration base class and utilities.""" + +import copy +import json +import os +import warnings +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import requests +import yaml +from huggingface_hub import model_info +from huggingface_hub.utils import HFValidationError + +from . import __version__ +from .models.auto.modeling_auto import ( + MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + MODEL_FOR_CTC_MAPPING_NAMES, + MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, + MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, + MODEL_FOR_MASKED_LM_MAPPING_NAMES, + MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, + MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, + MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, + MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, + MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES, + MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, + MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES, +) +from .training_args import ParallelMode +from .utils import ( + MODEL_CARD_NAME, + cached_file, + is_datasets_available, + is_offline_mode, + is_tf_available, + is_tokenizers_available, + is_torch_available, + logging, +) + + +TASK_MAPPING = { + "text-generation": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, + "image-segmentation": MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, + "fill-mask": MODEL_FOR_MASKED_LM_MAPPING_NAMES, + "object-detection": MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, + "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, + "text2text-generation": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, + "text-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, + "table-question-answering": MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES, + "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, + "audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, + "automatic-speech-recognition": {**MODEL_FOR_CTC_MAPPING_NAMES, **MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES}, + "zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES, +} + +logger = logging.get_logger(__name__) + + +class ModelCard: + r""" + Structured Model Card class. Store model card as well as methods for loading/downloading/saving model cards. + + Please read the following paper for details and explanation on the sections: "Model Cards for Model Reporting" by + Margaret Mitchell, Simone Wu, Andrew Zaldivar, Parker Barnes, Lucy Vasserman, Ben Hutchinson, Elena Spitzer, + Inioluwa Deborah Raji and Timnit Gebru for the proposal behind model cards. Link: https://arxiv.org/abs/1810.03993 + + Note: A model card can be loaded and saved to disk. + """ + + def __init__(self, **kwargs): + warnings.warn( + "The class `ModelCard` is deprecated and will be removed in version 5 of Transformers", FutureWarning + ) + # Recommended attributes from https://arxiv.org/abs/1810.03993 (see papers) + self.model_details = kwargs.pop("model_details", {}) + self.intended_use = kwargs.pop("intended_use", {}) + self.factors = kwargs.pop("factors", {}) + self.metrics = kwargs.pop("metrics", {}) + self.evaluation_data = kwargs.pop("evaluation_data", {}) + self.training_data = kwargs.pop("training_data", {}) + self.quantitative_analyses = kwargs.pop("quantitative_analyses", {}) + self.ethical_considerations = kwargs.pop("ethical_considerations", {}) + self.caveats_and_recommendations = kwargs.pop("caveats_and_recommendations", {}) + + # Open additional attributes + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error(f"Can't set {key} with value {value} for {self}") + raise err + + def save_pretrained(self, save_directory_or_file): + """Save a model card object to the directory or file `save_directory_or_file`.""" + if os.path.isdir(save_directory_or_file): + # If we save using the predefined names, we can load using `from_pretrained` + output_model_card_file = os.path.join(save_directory_or_file, MODEL_CARD_NAME) + else: + output_model_card_file = save_directory_or_file + + self.to_json_file(output_model_card_file) + logger.info(f"Model card saved in {output_model_card_file}") + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + r""" + Instantiate a [`ModelCard`] from a pre-trained model model card. + + Parameters: + pretrained_model_name_or_path: either: + + - a string, the *model id* of a pretrained model card hosted inside a model repo on huggingface.co. + - a path to a *directory* containing a model card file saved using the [`~ModelCard.save_pretrained`] + method, e.g.: `./my_model_directory/`. + - a path or url to a saved model card JSON *file*, e.g.: `./my_model_directory/modelcard.json`. + + cache_dir: (*optional*) string: + Path to a directory in which a downloaded pre-trained model card should be cached if the standard cache + should not be used. + + kwargs: (*optional*) dict: key/value pairs with which to update the ModelCard object after loading. + + - The values in kwargs of any keys which are model card attributes will be used to override the loaded + values. + - Behavior concerning key/value pairs whose keys are *not* model card attributes is controlled by the + *return_unused_kwargs* keyword parameter. + + proxies: (*optional*) dict, default None: + A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request. + + return_unused_kwargs: (*optional*) bool: + + - If False, then this function returns just the final model card object. + - If True, then this functions returns a tuple *(model card, unused_kwargs)* where *unused_kwargs* is a + dictionary consisting of the key/value pairs whose keys are not model card attributes: ie the part of + kwargs which has not been used to update *ModelCard* and is otherwise ignored. + + Examples: + + ```python + # Download model card from huggingface.co and cache. + modelcard = ModelCard.from_pretrained("google-bert/bert-base-uncased") + # Model card was saved using *save_pretrained('./test/saved_model/')* + modelcard = ModelCard.from_pretrained("./test/saved_model/") + modelcard = ModelCard.from_pretrained("./test/saved_model/modelcard.json") + modelcard = ModelCard.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False) + ```""" + cache_dir = kwargs.pop("cache_dir", None) + proxies = kwargs.pop("proxies", None) + return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) + from_pipeline = kwargs.pop("_from_pipeline", None) + + user_agent = {"file_type": "model_card"} + if from_pipeline is not None: + user_agent["using_pipeline"] = from_pipeline + + is_local = os.path.isdir(pretrained_model_name_or_path) + if os.path.isfile(pretrained_model_name_or_path): + resolved_model_card_file = pretrained_model_name_or_path + is_local = True + else: + try: + # Load from URL or cache if already cached + resolved_model_card_file = cached_file( + pretrained_model_name_or_path, + filename=MODEL_CARD_NAME, + cache_dir=cache_dir, + proxies=proxies, + user_agent=user_agent, + ) + if is_local: + logger.info(f"loading model card file {resolved_model_card_file}") + else: + logger.info(f"loading model card file {MODEL_CARD_NAME} from cache at {resolved_model_card_file}") + # Load model card + modelcard = cls.from_json_file(resolved_model_card_file) + + except (EnvironmentError, json.JSONDecodeError): + # We fall back on creating an empty model card + modelcard = cls() + + # Update model card with kwargs if needed + to_remove = [] + for key, value in kwargs.items(): + if hasattr(modelcard, key): + setattr(modelcard, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + + logger.info(f"Model card: {modelcard}") + if return_unused_kwargs: + return modelcard, kwargs + else: + return modelcard + + @classmethod + def from_dict(cls, json_object): + """Constructs a `ModelCard` from a Python dictionary of parameters.""" + return cls(**json_object) + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `ModelCard` from a json file of parameters.""" + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + dict_obj = json.loads(text) + return cls(**dict_obj) + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path): + """Save this instance to a json file.""" + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string()) + + +AUTOGENERATED_TRAINER_COMMENT = """ + +""" + +AUTOGENERATED_KERAS_COMMENT = """ + +""" + + +TASK_TAG_TO_NAME_MAPPING = { + "fill-mask": "Masked Language Modeling", + "image-classification": "Image Classification", + "image-segmentation": "Image Segmentation", + "multiple-choice": "Multiple Choice", + "object-detection": "Object Detection", + "question-answering": "Question Answering", + "summarization": "Summarization", + "table-question-answering": "Table Question Answering", + "text-classification": "Text Classification", + "text-generation": "Causal Language Modeling", + "text2text-generation": "Sequence-to-sequence Language Modeling", + "token-classification": "Token Classification", + "translation": "Translation", + "zero-shot-classification": "Zero Shot Classification", + "automatic-speech-recognition": "Automatic Speech Recognition", + "audio-classification": "Audio Classification", +} + + +METRIC_TAGS = [ + "accuracy", + "bleu", + "f1", + "matthews_correlation", + "pearsonr", + "precision", + "recall", + "rouge", + "sacrebleu", + "spearmanr", + "wer", +] + + +def _listify(obj): + if obj is None: + return [] + elif isinstance(obj, str): + return [obj] + else: + return obj + + +def _insert_values_as_list(metadata, name, values): + if values is None: + return metadata + if isinstance(values, str): + values = [values] + values = [v for v in values if v is not None] + if len(values) == 0: + return metadata + metadata[name] = values + return metadata + + +def infer_metric_tags_from_eval_results(eval_results): + if eval_results is None: + return {} + result = {} + for key in eval_results.keys(): + if key.lower().replace(" ", "_") in METRIC_TAGS: + result[key.lower().replace(" ", "_")] = key + elif key.lower() == "rouge1": + result["rouge"] = key + return result + + +def _insert_value(metadata, name, value): + if value is None: + return metadata + metadata[name] = value + return metadata + + +def is_hf_dataset(dataset): + if not is_datasets_available(): + return False + + from datasets import Dataset, IterableDataset + + return isinstance(dataset, (Dataset, IterableDataset)) + + +def _get_mapping_values(mapping): + result = [] + for v in mapping.values(): + if isinstance(v, (tuple, list)): + result += list(v) + else: + result.append(v) + return result + + +@dataclass +class TrainingSummary: + model_name: str + language: Optional[Union[str, List[str]]] = None + license: Optional[str] = None + tags: Optional[Union[str, List[str]]] = None + finetuned_from: Optional[str] = None + tasks: Optional[Union[str, List[str]]] = None + dataset: Optional[Union[str, List[str]]] = None + dataset_tags: Optional[Union[str, List[str]]] = None + dataset_args: Optional[Union[str, List[str]]] = None + dataset_metadata: Optional[Dict[str, Any]] = None + eval_results: Optional[Dict[str, float]] = None + eval_lines: Optional[List[str]] = None + hyperparameters: Optional[Dict[str, Any]] = None + source: Optional[str] = "trainer" + + def __post_init__(self): + # Infer default license from the checkpoint used, if possible. + if ( + self.license is None + and not is_offline_mode() + and self.finetuned_from is not None + and len(self.finetuned_from) > 0 + ): + try: + info = model_info(self.finetuned_from) + for tag in info.tags: + if tag.startswith("license:"): + self.license = tag[8:] + except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError, HFValidationError): + pass + + def create_model_index(self, metric_mapping): + model_index = {"name": self.model_name} + + # Dataset mapping tag -> name + dataset_names = _listify(self.dataset) + dataset_tags = _listify(self.dataset_tags) + dataset_args = _listify(self.dataset_args) + dataset_metadata = _listify(self.dataset_metadata) + if len(dataset_args) < len(dataset_tags): + dataset_args = dataset_args + [None] * (len(dataset_tags) - len(dataset_args)) + dataset_mapping = dict(zip(dataset_tags, dataset_names)) + dataset_arg_mapping = dict(zip(dataset_tags, dataset_args)) + dataset_metadata_mapping = dict(zip(dataset_tags, dataset_metadata)) + + task_mapping = { + task: TASK_TAG_TO_NAME_MAPPING[task] for task in _listify(self.tasks) if task in TASK_TAG_TO_NAME_MAPPING + } + + model_index["results"] = [] + + if len(task_mapping) == 0 and len(dataset_mapping) == 0: + return [model_index] + if len(task_mapping) == 0: + task_mapping = {None: None} + if len(dataset_mapping) == 0: + dataset_mapping = {None: None} + + # One entry per dataset and per task + all_possibilities = [(task_tag, ds_tag) for task_tag in task_mapping for ds_tag in dataset_mapping] + for task_tag, ds_tag in all_possibilities: + result = {} + if task_tag is not None: + result["task"] = {"name": task_mapping[task_tag], "type": task_tag} + + if ds_tag is not None: + metadata = dataset_metadata_mapping.get(ds_tag, {}) + result["dataset"] = { + "name": dataset_mapping[ds_tag], + "type": ds_tag, + **metadata, + } + if dataset_arg_mapping[ds_tag] is not None: + result["dataset"]["args"] = dataset_arg_mapping[ds_tag] + + if len(metric_mapping) > 0: + result["metrics"] = [] + for metric_tag, metric_name in metric_mapping.items(): + result["metrics"].append( + { + "name": metric_name, + "type": metric_tag, + "value": self.eval_results[metric_name], + } + ) + + # Remove partial results to avoid the model card being rejected. + if "task" in result and "dataset" in result and "metrics" in result: + model_index["results"].append(result) + else: + logger.info(f"Dropping the following result as it does not have all the necessary fields:\n{result}") + + return [model_index] + + def create_metadata(self): + metric_mapping = infer_metric_tags_from_eval_results(self.eval_results) + + metadata = {} + metadata = _insert_value(metadata, "library_name", "transformers") + metadata = _insert_values_as_list(metadata, "language", self.language) + metadata = _insert_value(metadata, "license", self.license) + if self.finetuned_from is not None and isinstance(self.finetuned_from, str) and len(self.finetuned_from) > 0: + metadata = _insert_value(metadata, "base_model", self.finetuned_from) + metadata = _insert_values_as_list(metadata, "tags", self.tags) + metadata = _insert_values_as_list(metadata, "datasets", self.dataset_tags) + metadata = _insert_values_as_list(metadata, "metrics", list(metric_mapping.keys())) + metadata["model-index"] = self.create_model_index(metric_mapping) + + return metadata + + def to_model_card(self): + model_card = "" + + metadata = yaml.dump(self.create_metadata(), sort_keys=False) + if len(metadata) > 0: + model_card = f"---\n{metadata}---\n" + + # Now the model card for realsies. + if self.source == "trainer": + model_card += AUTOGENERATED_TRAINER_COMMENT + else: + model_card += AUTOGENERATED_KERAS_COMMENT + + model_card += f"\n# {self.model_name}\n\n" + + if self.finetuned_from is None: + model_card += "This model was trained from scratch on " + else: + model_card += ( + "This model is a fine-tuned version of" + f" [{self.finetuned_from}](https://huggingface.co/{self.finetuned_from}) on " + ) + + if self.dataset is None: + model_card += "an unknown dataset." + else: + if isinstance(self.dataset, str): + model_card += f"the {self.dataset} dataset." + elif isinstance(self.dataset, (tuple, list)) and len(self.dataset) == 1: + model_card += f"the {self.dataset[0]} dataset." + else: + model_card += ( + ", ".join([f"the {ds}" for ds in self.dataset[:-1]]) + f" and the {self.dataset[-1]} datasets." + ) + + if self.eval_results is not None: + model_card += "\nIt achieves the following results on the evaluation set:\n" + model_card += "\n".join([f"- {name}: {_maybe_round(value)}" for name, value in self.eval_results.items()]) + model_card += "\n" + + model_card += "\n## Model description\n\nMore information needed\n" + model_card += "\n## Intended uses & limitations\n\nMore information needed\n" + model_card += "\n## Training and evaluation data\n\nMore information needed\n" + + model_card += "\n## Training procedure\n" + model_card += "\n### Training hyperparameters\n" + if self.hyperparameters is not None: + model_card += "\nThe following hyperparameters were used during training:\n" + model_card += "\n".join([f"- {name}: {value}" for name, value in self.hyperparameters.items()]) + model_card += "\n" + else: + model_card += "\nMore information needed\n" + + if self.eval_lines is not None: + model_card += "\n### Training results\n\n" + model_card += make_markdown_table(self.eval_lines) + model_card += "\n" + + model_card += "\n### Framework versions\n\n" + model_card += f"- Transformers {__version__}\n" + + if self.source == "trainer" and is_torch_available(): + import torch + + model_card += f"- Pytorch {torch.__version__}\n" + elif self.source == "keras" and is_tf_available(): + import tensorflow as tf + + model_card += f"- TensorFlow {tf.__version__}\n" + if is_datasets_available(): + import datasets + + model_card += f"- Datasets {datasets.__version__}\n" + if is_tokenizers_available(): + import tokenizers + + model_card += f"- Tokenizers {tokenizers.__version__}\n" + + return model_card + + @classmethod + def from_trainer( + cls, + trainer, + language=None, + license=None, + tags=None, + model_name=None, + finetuned_from=None, + tasks=None, + dataset_tags=None, + dataset_metadata=None, + dataset=None, + dataset_args=None, + ): + # Infer default from dataset + one_dataset = trainer.eval_dataset if trainer.eval_dataset is not None else trainer.train_dataset + if is_hf_dataset(one_dataset) and (dataset_tags is None or dataset_args is None or dataset_metadata is None): + default_tag = one_dataset.builder_name + # Those are not real datasets from the Hub so we exclude them. + if default_tag not in ["csv", "json", "pandas", "parquet", "text"]: + if dataset_metadata is None: + dataset_metadata = [{"config": one_dataset.config_name, "split": str(one_dataset.split)}] + if dataset_tags is None: + dataset_tags = [default_tag] + if dataset_args is None: + dataset_args = [one_dataset.config_name] + + if dataset is None and dataset_tags is not None: + dataset = dataset_tags + + # Infer default finetuned_from + if ( + finetuned_from is None + and hasattr(trainer.model.config, "_name_or_path") + and not os.path.isdir(trainer.model.config._name_or_path) + ): + finetuned_from = trainer.model.config._name_or_path + + # Infer default task tag: + if tasks is None: + model_class_name = trainer.model.__class__.__name__ + for task, mapping in TASK_MAPPING.items(): + if model_class_name in _get_mapping_values(mapping): + tasks = task + + if model_name is None: + model_name = Path(trainer.args.output_dir).name + if len(model_name) == 0: + model_name = finetuned_from + + # Add `generated_from_trainer` to the tags + if tags is None: + tags = ["generated_from_trainer"] + elif isinstance(tags, str) and tags != "generated_from_trainer": + tags = [tags, "generated_from_trainer"] + elif "generated_from_trainer" not in tags: + tags.append("generated_from_trainer") + + _, eval_lines, eval_results = parse_log_history(trainer.state.log_history) + hyperparameters = extract_hyperparameters_from_trainer(trainer) + + return cls( + language=language, + license=license, + tags=tags, + model_name=model_name, + finetuned_from=finetuned_from, + tasks=tasks, + dataset=dataset, + dataset_tags=dataset_tags, + dataset_args=dataset_args, + dataset_metadata=dataset_metadata, + eval_results=eval_results, + eval_lines=eval_lines, + hyperparameters=hyperparameters, + ) + + @classmethod + def from_keras( + cls, + model, + model_name, + keras_history=None, + language=None, + license=None, + tags=None, + finetuned_from=None, + tasks=None, + dataset_tags=None, + dataset=None, + dataset_args=None, + ): + # Infer default from dataset + if dataset is not None: + if is_hf_dataset(dataset) and (dataset_tags is None or dataset_args is None): + default_tag = dataset.builder_name + # Those are not real datasets from the Hub so we exclude them. + if default_tag not in ["csv", "json", "pandas", "parquet", "text"]: + if dataset_tags is None: + dataset_tags = [default_tag] + if dataset_args is None: + dataset_args = [dataset.config_name] + + if dataset is None and dataset_tags is not None: + dataset = dataset_tags + + # Infer default finetuned_from + if ( + finetuned_from is None + and hasattr(model.config, "_name_or_path") + and not os.path.isdir(model.config._name_or_path) + ): + finetuned_from = model.config._name_or_path + + # Infer default task tag: + if tasks is None: + model_class_name = model.__class__.__name__ + for task, mapping in TASK_MAPPING.items(): + if model_class_name in _get_mapping_values(mapping): + tasks = task + + # Add `generated_from_keras_callback` to the tags + if tags is None: + tags = ["generated_from_keras_callback"] + elif isinstance(tags, str) and tags != "generated_from_keras_callback": + tags = [tags, "generated_from_keras_callback"] + elif "generated_from_keras_callback" not in tags: + tags.append("generated_from_keras_callback") + + if keras_history is not None: + _, eval_lines, eval_results = parse_keras_history(keras_history) + else: + eval_lines = [] + eval_results = {} + hyperparameters = extract_hyperparameters_from_keras(model) + + return cls( + language=language, + license=license, + tags=tags, + model_name=model_name, + finetuned_from=finetuned_from, + tasks=tasks, + dataset_tags=dataset_tags, + dataset=dataset, + dataset_args=dataset_args, + eval_results=eval_results, + eval_lines=eval_lines, + hyperparameters=hyperparameters, + source="keras", + ) + + +def parse_keras_history(logs): + """ + Parse the `logs` of either a `keras.History` object returned by `model.fit()` or an accumulated logs `dict` + passed to the `PushToHubCallback`. Returns lines and logs compatible with those returned by `parse_log_history`. + """ + if hasattr(logs, "history"): + # This looks like a `History` object + if not hasattr(logs, "epoch"): + # This history looks empty, return empty results + return None, [], {} + logs.history["epoch"] = logs.epoch + logs = logs.history + else: + # Training logs is a list of dicts, let's invert it to a dict of lists to match a History object + logs = {log_key: [single_dict[log_key] for single_dict in logs] for log_key in logs[0]} + + lines = [] + for i in range(len(logs["epoch"])): + epoch_dict = {log_key: log_value_list[i] for log_key, log_value_list in logs.items()} + values = {} + for k, v in epoch_dict.items(): + if k.startswith("val_"): + k = "validation_" + k[4:] + elif k != "epoch": + k = "train_" + k + splits = k.split("_") + name = " ".join([part.capitalize() for part in splits]) + values[name] = v + lines.append(values) + + eval_results = lines[-1] + + return logs, lines, eval_results + + +def parse_log_history(log_history): + """ + Parse the `log_history` of a Trainer to get the intermediate and final evaluation results. + """ + idx = 0 + while idx < len(log_history) and "train_runtime" not in log_history[idx]: + idx += 1 + + # If there are no training logs + if idx == len(log_history): + idx -= 1 + while idx >= 0 and "eval_loss" not in log_history[idx]: + idx -= 1 + + if idx >= 0: + return None, None, log_history[idx] + else: + return None, None, None + + # From now one we can assume we have training logs: + train_log = log_history[idx] + lines = [] + training_loss = "No log" + for i in range(idx): + if "loss" in log_history[i]: + training_loss = log_history[i]["loss"] + if "eval_loss" in log_history[i]: + metrics = log_history[i].copy() + _ = metrics.pop("total_flos", None) + epoch = metrics.pop("epoch", None) + step = metrics.pop("step", None) + _ = metrics.pop("eval_runtime", None) + _ = metrics.pop("eval_samples_per_second", None) + _ = metrics.pop("eval_steps_per_second", None) + _ = metrics.pop("eval_jit_compilation_time", None) + values = {"Training Loss": training_loss, "Epoch": epoch, "Step": step} + for k, v in metrics.items(): + if k == "eval_loss": + values["Validation Loss"] = v + else: + splits = k.split("_") + name = " ".join([part.capitalize() for part in splits[1:]]) + values[name] = v + lines.append(values) + + idx = len(log_history) - 1 + while idx >= 0 and "eval_loss" not in log_history[idx]: + idx -= 1 + + if idx > 0: + eval_results = {} + for key, value in log_history[idx].items(): + if key.startswith("eval_"): + key = key[5:] + if key not in ["runtime", "samples_per_second", "steps_per_second", "epoch", "step"]: + camel_cased_key = " ".join([part.capitalize() for part in key.split("_")]) + eval_results[camel_cased_key] = value + return train_log, lines, eval_results + else: + return train_log, lines, None + + +def extract_hyperparameters_from_keras(model): + from .modeling_tf_utils import keras + + hyperparameters = {} + if hasattr(model, "optimizer") and model.optimizer is not None: + hyperparameters["optimizer"] = model.optimizer.get_config() + else: + hyperparameters["optimizer"] = None + hyperparameters["training_precision"] = keras.mixed_precision.global_policy().name + + return hyperparameters + + +def _maybe_round(v, decimals=4): + if isinstance(v, float) and len(str(v).split(".")) > 1 and len(str(v).split(".")[1]) > decimals: + return f"{v:.{decimals}f}" + return str(v) + + +def _regular_table_line(values, col_widths): + values_with_space = [f"| {v}" + " " * (w - len(v) + 1) for v, w in zip(values, col_widths)] + return "".join(values_with_space) + "|\n" + + +def _second_table_line(col_widths): + values = ["|:" + "-" * w + ":" for w in col_widths] + return "".join(values) + "|\n" + + +def make_markdown_table(lines): + """ + Create a nice Markdown table from the results in `lines`. + """ + if lines is None or len(lines) == 0: + return "" + col_widths = {key: len(str(key)) for key in lines[0].keys()} + for line in lines: + for key, value in line.items(): + if col_widths[key] < len(_maybe_round(value)): + col_widths[key] = len(_maybe_round(value)) + + table = _regular_table_line(list(lines[0].keys()), list(col_widths.values())) + table += _second_table_line(list(col_widths.values())) + for line in lines: + table += _regular_table_line([_maybe_round(v) for v in line.values()], list(col_widths.values())) + return table + + +_TRAINING_ARGS_KEYS = [ + "learning_rate", + "train_batch_size", + "eval_batch_size", + "seed", +] + + +def extract_hyperparameters_from_trainer(trainer): + hyperparameters = {k: getattr(trainer.args, k) for k in _TRAINING_ARGS_KEYS} + + if trainer.args.parallel_mode not in [ParallelMode.NOT_PARALLEL, ParallelMode.NOT_DISTRIBUTED]: + hyperparameters["distributed_type"] = ( + "multi-GPU" if trainer.args.parallel_mode == ParallelMode.DISTRIBUTED else trainer.args.parallel_mode.value + ) + if trainer.args.world_size > 1: + hyperparameters["num_devices"] = trainer.args.world_size + if trainer.args.gradient_accumulation_steps > 1: + hyperparameters["gradient_accumulation_steps"] = trainer.args.gradient_accumulation_steps + + total_train_batch_size = ( + trainer.args.train_batch_size * trainer.args.world_size * trainer.args.gradient_accumulation_steps + ) + if total_train_batch_size != hyperparameters["train_batch_size"]: + hyperparameters["total_train_batch_size"] = total_train_batch_size + total_eval_batch_size = trainer.args.eval_batch_size * trainer.args.world_size + if total_eval_batch_size != hyperparameters["eval_batch_size"]: + hyperparameters["total_eval_batch_size"] = total_eval_batch_size + + if trainer.args.optim: + optimizer_name = trainer.args.optim + optimizer_args = trainer.args.optim_args if trainer.args.optim_args else "No additional optimizer arguments" + + if "adam" in optimizer_name.lower(): + hyperparameters["optimizer"] = ( + f"Use {optimizer_name} with betas=({trainer.args.adam_beta1},{trainer.args.adam_beta2}) and" + f" epsilon={trainer.args.adam_epsilon} and optimizer_args={optimizer_args}" + ) + else: + hyperparameters["optimizer"] = f"Use {optimizer_name} and the args are:\n{optimizer_args}" + + hyperparameters["lr_scheduler_type"] = trainer.args.lr_scheduler_type.value + if trainer.args.warmup_ratio != 0.0: + hyperparameters["lr_scheduler_warmup_ratio"] = trainer.args.warmup_ratio + if trainer.args.warmup_steps != 0.0: + hyperparameters["lr_scheduler_warmup_steps"] = trainer.args.warmup_steps + if trainer.args.max_steps != -1: + hyperparameters["training_steps"] = trainer.args.max_steps + else: + hyperparameters["num_epochs"] = trainer.args.num_train_epochs + + if trainer.args.fp16: + if trainer.use_apex: + hyperparameters["mixed_precision_training"] = f"Apex, opt level {trainer.args.fp16_opt_level}" + else: + hyperparameters["mixed_precision_training"] = "Native AMP" + + if trainer.args.label_smoothing_factor != 0.0: + hyperparameters["label_smoothing_factor"] = trainer.args.label_smoothing_factor + + return hyperparameters diff --git a/modeling_attn_mask_utils.py b/modeling_attn_mask_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..09fc77e46b07ed613a6f390c2dea6cc9703089dd --- /dev/null +++ b/modeling_attn_mask_utils.py @@ -0,0 +1,481 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + +from .utils.import_utils import is_torchdynamo_compiling + + +@dataclass +class AttentionMaskConverter: + """ + A utility attention mask class that allows one to: + - Create a causal 4d mask + - Create a causal 4d mask with slided window + - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, + key_value_length) that can be multiplied with attention scores + + Examples: + + ```python + >>> import torch + >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter + + >>> converter = AttentionMaskConverter(True) + >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32) + tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]]) + ``` + + Parameters: + is_causal (`bool`): + Whether the attention mask should be a uni-directional (causal) or bi-directional mask. + + sliding_window (`int`, *optional*): + Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. + """ + + is_causal: bool + sliding_window: int + + def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): + self.is_causal = is_causal + self.sliding_window = sliding_window + + if self.sliding_window is not None and self.sliding_window <= 0: + raise ValueError( + f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`" + ) + + def to_causal_4d( + self, + batch_size: int, + query_length: int, + key_value_length: int, + dtype: torch.dtype, + device: Union[torch.device, "str"] = "cpu", + ) -> Optional[torch.Tensor]: + """ + Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative + bias to upper right hand triangular matrix (causal mask). + """ + if not self.is_causal: + raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.") + + # If shape is not cached, create a new causal mask and cache it + input_shape = (batch_size, query_length) + past_key_values_length = key_value_length - query_length + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if input_shape[-1] > 1 or self.sliding_window is not None: + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + + return causal_4d_mask + + def to_4d( + self, + attention_mask_2d: torch.Tensor, + query_length: int, + dtype: torch.dtype, + key_value_length: Optional[int] = None, + ) -> torch.Tensor: + """ + Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, + key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is + causal, a causal mask will be added. + """ + input_shape = (attention_mask_2d.shape[0], query_length) + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: + if key_value_length is None: + raise ValueError( + "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask." + ) + + past_key_values_length = key_value_length - query_length + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=attention_mask_2d.device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + elif self.sliding_window is not None: + raise NotImplementedError("Sliding window is currently only implemented for causal masking") + + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( + attention_mask_2d.device + ) + + if causal_4d_mask is not None: + expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min) + + # expanded_attn_mask + causal_4d_mask can cause some overflow + expanded_4d_mask = expanded_attn_mask + + return expanded_4d_mask + + @staticmethod + def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, + ): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + + # add lower triangular sliding window mask if necessary + if sliding_window is not None: + diagonal = past_key_values_length - sliding_window - 1 + + context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal) + # Recent changes in PyTorch prevent mutations on tensors converted with aten::_to_copy + # See https://github.com/pytorch/pytorch/issues/127571 + if is_torchdynamo_compiling(): + mask = mask.clone() + mask.masked_fill_(context_mask, torch.finfo(dtype).min) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + @staticmethod + def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + @staticmethod + def _unmask_unattended( + expanded_mask: torch.FloatTensor, + min_dtype: float, + ): + # fmt: off + """ + Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when + using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + Details: https://github.com/pytorch/pytorch/issues/110213 + + `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len]. + `attention_mask` is [bsz, src_seq_len]. + + The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias. + + For example, if `expanded_mask` is (e.g. here left-padding case) + ``` + [[[[0, 0, 0], + [0, 0, 0], + [0, 0, 1]]], + [[[1, 0, 0], + [1, 1, 0], + [1, 1, 1]]], + [[[0, 0, 0], + [0, 1, 0], + [0, 1, 1]]]] + ``` + then the modified `expanded_mask` will be + ``` + [[[[1, 1, 1], <-- modified + [1, 1, 1], <-- modified + [0, 0, 1]]], + [[[1, 0, 0], + [1, 1, 0], + [1, 1, 1]]], + [[[1, 1, 1], <-- modified + [0, 1, 0], + [0, 1, 1]]]] + ``` + """ + # fmt: on + if expanded_mask.dtype == torch.bool: + raise ValueError( + "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor." + ) + + return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True)) + + @staticmethod + def _ignore_causal_mask_sdpa( + attention_mask: Optional[torch.Tensor], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, + is_training: bool = False, + ) -> bool: + """ + Detects whether the optional user-specified attention_mask & the automatically created causal mask can be + ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument. + + In case no token is masked in the `attention_mask` argument, if `query_length == 1` or + `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks, + allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is + passed). + """ + + _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1] + key_value_length = query_length + past_key_values_length + + is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling() + + ignore_causal_mask = False + + if attention_mask is None: + # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input + # shape, thus SDPA's `is_causal` argument is rightfully updated + # (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using + # `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is + # hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` + # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108). + # Thus, we only set `ignore_causal_mask = True` if the model is set to training. + # + # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` + # ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor"). + if ( + (is_training or not is_tracing) + and (query_length == 1 or key_value_length == query_length) + and (sliding_window is None or key_value_length < sliding_window) + ): + ignore_causal_mask = True + elif sliding_window is None or key_value_length < sliding_window: + if len(attention_mask.shape) == 4: + return False + elif not is_tracing and torch.all(attention_mask == 1): + if query_length == 1 or key_value_length == query_length: + # For query_length == 1, causal attention and bi-directional attention are the same. + ignore_causal_mask = True + + # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore + # the attention mask, as SDPA causal mask generation may be wrong. We will set `is_causal=False` in + # SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. + # Reference: https://github.com/pytorch/pytorch/issues/108108 + # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3. + + return ignore_causal_mask + + +def _prepare_4d_causal_attention_mask( + attention_mask: Optional[torch.Tensor], + input_shape: Union[torch.Size, Tuple, List], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)` + + Args: + attention_mask (`torch.Tensor` or `None`): + A 2D attention mask of shape `(batch_size, key_value_length)` + input_shape (`tuple(int)` or `list(int)` or `torch.Size`): + The input shape should be a tuple that defines `(batch_size, query_length)`. + inputs_embeds (`torch.Tensor`): + The embedded inputs as a torch Tensor. + past_key_values_length (`int`): + The length of the key value cache. + sliding_window (`int`, *optional*): + If the model uses windowed attention, a sliding window should be passed. + """ + attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) + + key_value_length = input_shape[-1] + past_key_values_length + + # 4d mask is passed through the layers + if attention_mask is not None and len(attention_mask.shape) == 2: + attention_mask = attn_mask_converter.to_4d( + attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype + ) + elif attention_mask is not None and len(attention_mask.shape) == 4: + expected_shape = (input_shape[0], 1, input_shape[1], key_value_length) + if tuple(attention_mask.shape) != expected_shape: + raise ValueError( + f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." + ) + else: + # if the 4D mask has correct shape - invert it and fill with negative infinity + inverted_mask = 1.0 - attention_mask + attention_mask = inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min + ) + else: + attention_mask = attn_mask_converter.to_causal_4d( + input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + + return attention_mask + + +# Adapted from _prepare_4d_causal_attention_mask +def _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask: Optional[torch.Tensor], + input_shape: Union[torch.Size, Tuple, List], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, +): + """ + Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`. + + In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and + `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks, + allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). + """ + attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) + + key_value_length = input_shape[-1] + past_key_values_length + + # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1` + # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing. + # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400). + is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling() + + ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + sliding_window=sliding_window, + ) + + if ignore_causal_mask: + expanded_4d_mask = None + elif attention_mask is None: + expanded_4d_mask = attn_mask_converter.to_causal_4d( + input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + else: + if attention_mask.dim() == 4: + expanded_4d_mask = attention_mask + else: + expanded_4d_mask = attn_mask_converter.to_4d( + attention_mask, + input_shape[-1], + dtype=inputs_embeds.dtype, + key_value_length=key_value_length, + ) + + # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + if not is_tracing and expanded_4d_mask.device.type == "cuda": + expanded_4d_mask = AttentionMaskConverter._unmask_unattended( + expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min + ) + + return expanded_4d_mask + + +def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)` + + Args: + mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` + dtype (`torch.dtype`): + The torch dtype the created mask shall have. + tgt_len (`int`): + The target length or query length the created mask shall have. + """ + return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + + +def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)` + + Args: + mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` + dtype (`torch.dtype`): + The torch dtype the created mask shall have. + tgt_len (`int`): + The target length or query length the created mask shall have. + """ + _, key_value_length = mask.shape + tgt_len = tgt_len if tgt_len is not None else key_value_length + + is_tracing = torch.jit.is_tracing() or isinstance(mask, torch.fx.Proxy) or is_torchdynamo_compiling() + + # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent controlflows. + if not is_tracing and torch.all(mask == 1): + return None + else: + return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + + +def _create_4d_causal_attention_mask( + input_shape: Union[torch.Size, Tuple, List], + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, +) -> Optional[torch.Tensor]: + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` + + Args: + input_shape (`tuple(int)` or `list(int)` or `torch.Size`): + The input shape should be a tuple that defines `(batch_size, query_length)`. + dtype (`torch.dtype`): + The torch dtype the created mask shall have. + device (`int`): + The torch device the created mask shall have. + sliding_window (`int`, *optional*): + If the model uses windowed attention, a sliding window should be passed. + """ + attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) + + key_value_length = past_key_values_length + input_shape[-1] + attention_mask = attn_mask_converter.to_causal_4d( + input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device + ) + + return attention_mask diff --git a/modeling_flash_attention_utils.py b/modeling_flash_attention_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6adda0036cc0963d1a25e5774050cd1476a8dcb1 --- /dev/null +++ b/modeling_flash_attention_utils.py @@ -0,0 +1,389 @@ +# coding=utf-8 +# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import os +from typing import Optional, Tuple, TypedDict + +import torch +import torch.nn.functional as F + +from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal, logging + + +logger = logging.get_logger(__name__) + + +if is_flash_attn_2_available(): + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + from flash_attn import flash_attn_func, flash_attn_varlen_func + + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) + + +def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]: + """ + Retrieves indexing data required to repad unpadded (ragged) tensors. + + Arguments: + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + indices (`torch.Tensor`): + The indices of non-masked tokens from the flattened input sequence. + cu_seqlens (`torch.Tensor`): + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + max_seqlen_in_batch (`int`): + Maximum sequence length in batch. + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _upad_input( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, +): + """ + Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. + + This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary + tensors for query, key, value tensors. + + Arguments: + query_layer (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + query_length (`int`): + Target length. + + Return: + query_layer (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +def prepare_fa2_from_position_ids(query, key, value, position_ids): + """ + This function returns necessary arguments to call `flash_attn_varlen_func`. + All three query, key, value states will be flattened. + Cummulative lengths of each examples in the batch will be extracted from position_ids. + + NOTE: ideally cummulative lengths should be prepared at the data collator stage + + Arguments: + query (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + position_ids (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + query (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + query = query.view(-1, query.size(-2), query.size(-1)) + key = key.contiguous().view(-1, key.size(-2), key.size(-1)) + value = value.contiguous().view(-1, value.size(-2), value.size(-1)) + position_ids = position_ids.flatten() + indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) + + cu_seq_lens = torch.cat( + ( + indices_q[position_ids == 0], + torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), + ) + ) + + max_length = position_ids.max() + 1 + + return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) + + +def fa_peft_integration_check( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + target_dtype: Optional[torch.dtype] = None, +): + """ + PEFT usually casts the layer norms in float32 for training stability reasons + therefore the input hidden states gets silently casted in float32. Hence, we need + cast them back in float16 / bfloat16 just to be sure everything works as expected. + This might slowdown training & inference so it is recommended to not cast the LayerNorms! + + Args: + query (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value (`torch.Tensor`): + Input value states to be passed to Flash Attention API + target_dtype (`torch.dtype`, *optional*): + The dtype to convert the attention tensors to. Conversion can be ignored by + not providing the target dtype. + """ + if target_dtype is None: + return query, key, value + + input_dtype = value.dtype + if input_dtype == torch.float32: + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + return query, key, value + + +flash_241 = is_flash_attn_greater_or_equal("2.4.1") +deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + + +def _flash_attention_forward( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, + is_causal: bool, + dropout: float = 0.0, + position_ids: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + sliding_window: Optional[int] = None, + use_top_left_mask: bool = False, + softcap: Optional[float] = None, + deterministic: bool = None, + cu_seq_lens_q: Optional[torch.LongTensor] = None, + cu_seq_lens_k: Optional[torch.LongTensor] = None, + max_length_q: Optional[int] = None, + max_length_k: Optional[int] = None, + target_dtype: Optional[torch.dtype] = None, + **kwargs, +): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_top_left_mask (`bool`, defaults to `False`): + flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. + softcap (`float`, *optional*): + Softcap for the attention logits, used e.g. in gemma2. + deterministic (`bool`, *optional*): + Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled. + """ + if not use_top_left_mask: + causal = is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. + causal = is_causal and query_length != 1 + + # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). + use_sliding_windows = ( + _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window + ) + flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} + + if flash_241: + if deterministic is None: + deterministic = deterministic_g + flash_kwargs["deterministic"] = deterministic + + if softcap is not None: + flash_kwargs["softcap"] = softcap + + # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op + query_states, key_states, value_states = fa_peft_integration_check( + query_states, key_states, value_states, target_dtype + ) + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + **flash_kwargs, + ) + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + + # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing + # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage. + # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach + elif position_ids is not None and ( + max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()) + ): + batch_size = query_states.size(0) + + if cu_seq_lens_q is None or cu_seq_lens_k is None: + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = ( + prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids) + ) + + cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens + max_length_q, max_length_k = max_seq_lens + + else: + query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1)) + key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1)) + value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1)) + + attn_output = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seq_lens_q, + cu_seqlens_k=cu_seq_lens_k, + max_seqlen_q=max_length_q, + max_seqlen_k=max_length_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + **flash_kwargs, + ) + + attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) + + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs + ) + + return attn_output + + +class FlashAttentionKwargs(TypedDict, total=False): + """ + Keyword arguments for Flash Attention with Compile. + + Attributes: + cu_seq_lens_q (`torch.LongTensor`, *optional*) + Gets cumlative sequence length for query state. + cu_seq_lens_k (`torch.LongTensor`, *optional*) + Gets cumlative sequence length for key state. + max_length_q (`int`, *optional*): + Maximum sequence length for query state. + max_length_k (`int`, *optional*): + Maximum sequence length for key state. + """ + + cu_seq_lens_q: Optional[torch.LongTensor] + cu_seq_lens_k: Optional[torch.LongTensor] + max_length_q: Optional[int] + max_length_k: Optional[int] diff --git a/modeling_flax_outputs.py b/modeling_flax_outputs.py new file mode 100644 index 0000000000000000000000000000000000000000..179a0b787936960c118bbb5ad34f73d00469d481 --- /dev/null +++ b/modeling_flax_outputs.py @@ -0,0 +1,700 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Optional, Tuple + +import flax +import jax.numpy as jnp + +from .utils import ModelOutput + + +@flax.struct.dataclass +class FlaxBaseModelOutput(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxBaseModelOutputWithNoAttention(ModelOutput): + """ + Base class for model's outputs, with potential hidden states. + + Args: + last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one + for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the + model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxBaseModelOutputWithPoolingAndNoAttention(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`): + Last layer hidden-state after a pooling operation on the spatial dimensions. + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one + for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the + model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: jnp.ndarray = None + pooler_output: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxImageClassifierOutputWithNoAttention(ModelOutput): + """ + Base class for outputs of image classification models. + + Args: + logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one + for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also + called feature maps) of the model at the output of each stage. + """ + + logits: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxBaseModelOutputWithPast(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`Dict[str, jnp.ndarray]`): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: jnp.ndarray = None + past_key_values: Optional[Dict[str, jnp.ndarray]] = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxBaseModelOutputWithPooling(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) further processed by a + Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence + prediction (classification) objective during pretraining. + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: jnp.ndarray = None + pooler_output: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) after further processing + through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns + the classification token after processing through a linear layer and a tanh activation function. The linear + layer weights are trained from the next sentence prediction (classification) objective during pretraining. + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one + for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + """ + + last_hidden_state: jnp.ndarray = None + pooler_output: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None + cross_attentions: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxBaseModelOutputWithPastAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + """ + + last_hidden_state: jnp.ndarray = None + past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None + cross_attentions: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxSeq2SeqModelOutput(ModelOutput): + """ + Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential + decoding. + + Args: + last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + last_hidden_state: jnp.ndarray = None + past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None + decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None + decoder_attentions: Optional[Tuple[jnp.ndarray]] = None + cross_attentions: Optional[Tuple[jnp.ndarray]] = None + encoder_last_hidden_state: Optional[jnp.ndarray] = None + encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None + encoder_attentions: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxCausalLMOutputWithCrossAttentions(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Cross attentions weights after the attention softmax, used to compute the weighted average in the + cross-attention heads. + past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `jnp.ndarray` tuples of length `config.n_layers`, with each tuple containing the cached key, value + states of the self-attention and the cross-attention layers if model is used in encoder-decoder setting. + Only relevant if `config.is_decoder = True`. + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + """ + + logits: jnp.ndarray = None + past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None + cross_attentions: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxMaskedLMOutput(ModelOutput): + """ + Base class for masked language models outputs. + + Args: + logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + logits: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None + + +FlaxCausalLMOutput = FlaxMaskedLMOutput + + +@flax.struct.dataclass +class FlaxSeq2SeqLMOutput(ModelOutput): + """ + Base class for sequence-to-sequence language models outputs. + + Args: + logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + logits: jnp.ndarray = None + past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None + decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None + decoder_attentions: Optional[Tuple[jnp.ndarray]] = None + cross_attentions: Optional[Tuple[jnp.ndarray]] = None + encoder_last_hidden_state: Optional[jnp.ndarray] = None + encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None + encoder_attentions: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxNextSentencePredictorOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + logits (`jnp.ndarray` of shape `(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + logits: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxSequenceClassifierOutput(ModelOutput): + """ + Base class for outputs of sentence classification models. + + Args: + logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + logits: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxSeq2SeqSequenceClassifierOutput(ModelOutput): + """ + Base class for outputs of sequence-to-sequence sentence classification models. + + Args: + logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + logits: jnp.ndarray = None + past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None + decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None + decoder_attentions: Optional[Tuple[jnp.ndarray]] = None + cross_attentions: Optional[Tuple[jnp.ndarray]] = None + encoder_last_hidden_state: Optional[jnp.ndarray] = None + encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None + encoder_attentions: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxMultipleChoiceModelOutput(ModelOutput): + """ + Base class for outputs of multiple choice models. + + Args: + logits (`jnp.ndarray` of shape `(batch_size, num_choices)`): + *num_choices* is the second dimension of the input tensors. (see *input_ids* above). + + Classification scores (before SoftMax). + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + logits: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxTokenClassifierOutput(ModelOutput): + """ + Base class for outputs of token classification models. + + Args: + logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.num_labels)`): + Classification scores (before SoftMax). + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + logits: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxQuestionAnsweringModelOutput(ModelOutput): + """ + Base class for outputs of question answering models. + + Args: + start_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Span-start scores (before SoftMax). + end_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Span-end scores (before SoftMax). + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + start_logits: jnp.ndarray = None + end_logits: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxSeq2SeqQuestionAnsweringModelOutput(ModelOutput): + """ + Base class for outputs of sequence-to-sequence question answering models. + + Args: + start_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Span-start scores (before SoftMax). + end_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Span-end scores (before SoftMax). + past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + start_logits: jnp.ndarray = None + end_logits: jnp.ndarray = None + past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None + decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None + decoder_attentions: Optional[Tuple[jnp.ndarray]] = None + cross_attentions: Optional[Tuple[jnp.ndarray]] = None + encoder_last_hidden_state: Optional[jnp.ndarray] = None + encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None + encoder_attentions: Optional[Tuple[jnp.ndarray]] = None diff --git a/modeling_flax_pytorch_utils.py b/modeling_flax_pytorch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8fbba8a1651364eef58cbed7fa6c037c6eee86e2 --- /dev/null +++ b/modeling_flax_pytorch_utils.py @@ -0,0 +1,492 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch - Flax general utilities.""" + +import os +from pickle import UnpicklingError +from typing import Dict, Tuple + +import jax +import jax.numpy as jnp +import numpy as np +from flax.serialization import from_bytes +from flax.traverse_util import flatten_dict, unflatten_dict + +import transformers + +from . import is_safetensors_available, is_torch_available +from .utils import logging + + +if is_torch_available(): + import torch + +if is_safetensors_available(): + from safetensors import safe_open + from safetensors.flax import load_file as safe_load_file + + +logger = logging.get_logger(__name__) + + +##################### +# PyTorch => Flax # +##################### + + +def load_pytorch_checkpoint_in_flax_state_dict( + flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False +): + """Load pytorch checkpoints in a flax model""" + + if not is_sharded: + pt_path = os.path.abspath(pytorch_checkpoint_path) + logger.info(f"Loading PyTorch weights from {pt_path}") + + if pt_path.endswith(".safetensors"): + pt_state_dict = {} + with safe_open(pt_path, framework="flax") as f: + for k in f.keys(): + pt_state_dict[k] = f.get_tensor(k) + else: + try: + import torch # noqa: F401 + except (ImportError, ModuleNotFoundError): + logger.error( + "Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see" + " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation" + " instructions." + ) + raise + + weights_only_kwarg = {"weights_only": True} + pt_state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg) + logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.") + + flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model) + else: + # model is sharded and pytorch_checkpoint_path already contains the list of .pt shard files + flax_state_dict = convert_pytorch_sharded_state_dict_to_flax(pytorch_checkpoint_path, flax_model) + return flax_state_dict + + +def rename_key_and_reshape_tensor( + pt_tuple_key: Tuple[str], + pt_tensor: np.ndarray, + random_flax_state_dict: Dict[str, jnp.ndarray], + model_prefix: str, +) -> (Tuple[str], np.ndarray): + """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary""" + + def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool: + """Checks if `key` of `(prefix,) + key` is in random_flax_state_dict""" + return len(set(random_flax_state_dict) & {key, (model_prefix,) + key}) > 0 + + # layer norm + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) + if pt_tuple_key[-1] in ["weight", "gamma"] and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key): + return renamed_pt_tuple_key, pt_tensor + + # batch norm layer mean + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("mean",) + if pt_tuple_key[-1] == "running_mean" and not is_key_or_prefix_key_in_dict(pt_tuple_key): + return renamed_pt_tuple_key, pt_tensor + + # batch norm layer var + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("var",) + if pt_tuple_key[-1] == "running_var" and not is_key_or_prefix_key_in_dict(pt_tuple_key): + return renamed_pt_tuple_key, pt_tensor + + # embedding + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",) + if pt_tuple_key[-1] == "weight" and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key): + return renamed_pt_tuple_key, pt_tensor + + # conv layer + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) + if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and not is_key_or_prefix_key_in_dict(pt_tuple_key): + pt_tensor = pt_tensor.transpose(2, 3, 1, 0) + return renamed_pt_tuple_key, pt_tensor + + # linear layer + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) + if pt_tuple_key[-1] == "weight" and not is_key_or_prefix_key_in_dict(pt_tuple_key): + pt_tensor = pt_tensor.T + return renamed_pt_tuple_key, pt_tensor + + # old PyTorch layer norm weight + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",) + if pt_tuple_key[-1] == "gamma": + return renamed_pt_tuple_key, pt_tensor + + # old PyTorch layer norm bias + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",) + if pt_tuple_key[-1] == "beta": + return renamed_pt_tuple_key, pt_tensor + + # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030 + name = None + if pt_tuple_key[-3::2] == ("parametrizations", "original0"): + name = pt_tuple_key[-2] + "_g" + elif pt_tuple_key[-3::2] == ("parametrizations", "original1"): + name = pt_tuple_key[-2] + "_v" + if name is not None: + renamed_pt_tuple_key = pt_tuple_key[:-3] + (name,) + return renamed_pt_tuple_key, pt_tensor + + return pt_tuple_key, pt_tensor + + +def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): + # convert pytorch tensor to numpy + from_bin = is_torch_available() and isinstance(next(iter(pt_state_dict.values())), torch.Tensor) + bfloat16 = torch.bfloat16 if from_bin else "bfloat16" + + weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()} + + if from_bin: + for k, v in pt_state_dict.items(): + # numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision + if v.dtype == bfloat16: + v = v.float() + pt_state_dict[k] = v.cpu().numpy() + + model_prefix = flax_model.base_model_prefix + + # use params dict if the model contains batch norm layers + if "params" in flax_model.params: + flax_model_params = flax_model.params["params"] + else: + flax_model_params = flax_model.params + random_flax_state_dict = flatten_dict(flax_model_params) + + # add batch_stats keys,values to dict + if "batch_stats" in flax_model.params: + flax_batch_stats = flatten_dict(flax_model.params["batch_stats"]) + random_flax_state_dict.update(flax_batch_stats) + + flax_state_dict = {} + + load_model_with_head_into_base_model = (model_prefix not in flax_model_params) and ( + model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()} + ) + load_base_model_into_model_with_head = (model_prefix in flax_model_params) and ( + model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()} + ) + + # Need to change some parameters name to match Flax names + for pt_key, pt_tensor in pt_state_dict.items(): + pt_tuple_key = tuple(pt_key.split(".")) + is_bfloat_16 = weight_dtypes[pt_key] == bfloat16 + + # remove base model prefix if necessary + has_base_model_prefix = pt_tuple_key[0] == model_prefix + if load_model_with_head_into_base_model and has_base_model_prefix: + pt_tuple_key = pt_tuple_key[1:] + + # Correctly rename weight parameters + flax_key, flax_tensor = rename_key_and_reshape_tensor( + pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix + ) + + # add model prefix if necessary + require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict + if load_base_model_into_model_with_head and require_base_model_prefix: + flax_key = (model_prefix,) + flax_key + + if flax_key in random_flax_state_dict: + if flax_tensor.shape != random_flax_state_dict[flax_key].shape: + raise ValueError( + f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape " + f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}." + ) + + # add batch stats if the model contains batchnorm layers + if "batch_stats" in flax_model.params: + if "mean" in flax_key[-1] or "var" in flax_key[-1]: + flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor) + continue + # remove num_batches_tracked key + if "num_batches_tracked" in flax_key[-1]: + flax_state_dict.pop(flax_key, None) + continue + + # also add unexpected weight so that warning is thrown + flax_state_dict[("params",) + flax_key] = ( + jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16) + ) + else: + # also add unexpected weight so that warning is thrown + flax_state_dict[flax_key] = ( + jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16) + ) + + return unflatten_dict(flax_state_dict) + + +############################ +# Sharded Pytorch => Flax # +############################ + + +def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model): + import torch + + # Load the index + flax_state_dict = {} + for shard_file in shard_filenames: + # load using msgpack utils + weights_only_kwarg = {"weights_only": True} + pt_state_dict = torch.load(shard_file, **weights_only_kwarg) + weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()} + pt_state_dict = { + k: v.numpy() if v.dtype != torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items() + } + + model_prefix = flax_model.base_model_prefix + + # use params dict if the model contains batch norm layers and then add batch_stats keys,values to dict + if "batch_stats" in flax_model.params: + flax_model_params = flax_model.params["params"] + + random_flax_state_dict = flatten_dict(flax_model_params) + random_flax_state_dict.update(flatten_dict(flax_model.params["batch_stats"])) + else: + flax_model_params = flax_model.params + random_flax_state_dict = flatten_dict(flax_model_params) + + load_model_with_head_into_base_model = (model_prefix not in flax_model_params) and ( + model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()} + ) + load_base_model_into_model_with_head = (model_prefix in flax_model_params) and ( + model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()} + ) + # Need to change some parameters name to match Flax names + for pt_key, pt_tensor in pt_state_dict.items(): + pt_tuple_key = tuple(pt_key.split(".")) + is_bfloat_16 = weight_dtypes[pt_key] == torch.bfloat16 + + # remove base model prefix if necessary + has_base_model_prefix = pt_tuple_key[0] == model_prefix + if load_model_with_head_into_base_model and has_base_model_prefix: + pt_tuple_key = pt_tuple_key[1:] + + # Correctly rename weight parameters + flax_key, flax_tensor = rename_key_and_reshape_tensor( + pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix + ) + # add model prefix if necessary + require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict + if load_base_model_into_model_with_head and require_base_model_prefix: + flax_key = (model_prefix,) + flax_key + + if flax_key in random_flax_state_dict: + if flax_tensor.shape != random_flax_state_dict[flax_key].shape: + raise ValueError( + f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape " + f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}." + ) + + # add batch stats if the model contains batchnorm layers + if "batch_stats" in flax_model.params: + if "mean" in flax_key[-1]: + flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor) + continue + if "var" in flax_key[-1]: + flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor) + continue + # remove num_batches_tracked key + if "num_batches_tracked" in flax_key[-1]: + flax_state_dict.pop(flax_key, None) + continue + + # also add unexpected weight so that warning is thrown + flax_state_dict[("params",) + flax_key] = ( + jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16) + ) + + else: + # also add unexpected weight so that warning is thrown + flax_state_dict[flax_key] = ( + jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16) + ) + return unflatten_dict(flax_state_dict) + + +##################### +# Flax => PyTorch # +##################### + + +def load_flax_checkpoint_in_pytorch_model(model, flax_checkpoint_path): + """Load flax checkpoints in a PyTorch model""" + flax_checkpoint_path = os.path.abspath(flax_checkpoint_path) + logger.info(f"Loading Flax weights from {flax_checkpoint_path}") + + # import correct flax class + flax_cls = getattr(transformers, "Flax" + model.__class__.__name__) + + # load flax weight dict + if flax_checkpoint_path.endswith(".safetensors"): + flax_state_dict = safe_load_file(flax_checkpoint_path) + flax_state_dict = unflatten_dict(flax_state_dict, sep=".") + else: + with open(flax_checkpoint_path, "rb") as state_f: + try: + flax_state_dict = from_bytes(flax_cls, state_f.read()) + except UnpicklingError: + raise EnvironmentError(f"Unable to convert {flax_checkpoint_path} to Flax deserializable object. ") + + return load_flax_weights_in_pytorch_model(model, flax_state_dict) + + +def load_flax_weights_in_pytorch_model(pt_model, flax_state): + """Load flax checkpoints in a PyTorch model""" + + try: + import torch # noqa: F401 + except (ImportError, ModuleNotFoundError): + logger.error( + "Loading a Flax weights in PyTorch, requires both PyTorch and Flax to be installed. Please see" + " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation" + " instructions." + ) + raise + + # check if we have bf16 weights + is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values() + if any(is_type_bf16): + # convert all weights to fp32 if the are bf16 since torch.from_numpy can-not handle bf16 + # and bf16 is not fully supported in PT yet. + logger.warning( + "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` " + "before loading those in PyTorch model." + ) + flax_state = jax.tree_util.tree_map( + lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state + ) + + flax_state_dict = flatten_dict(flax_state) + pt_model_dict = pt_model.state_dict() + + load_model_with_head_into_base_model = (pt_model.base_model_prefix in flax_state) and ( + pt_model.base_model_prefix not in {k.split(".")[0] for k in pt_model_dict.keys()} + ) + load_base_model_into_model_with_head = (pt_model.base_model_prefix not in flax_state) and ( + pt_model.base_model_prefix in {k.split(".")[0] for k in pt_model_dict.keys()} + ) + + # keep track of unexpected & missing keys + unexpected_keys = [] + missing_keys = set(pt_model_dict.keys()) + + for flax_key_tuple, flax_tensor in flax_state_dict.items(): + has_base_model_prefix = flax_key_tuple[0] == pt_model.base_model_prefix + require_base_model_prefix = ".".join((pt_model.base_model_prefix,) + flax_key_tuple) in pt_model_dict + + # adapt flax_key to prepare for loading from/to base model only + if load_model_with_head_into_base_model and has_base_model_prefix: + flax_key_tuple = flax_key_tuple[1:] + elif load_base_model_into_model_with_head and require_base_model_prefix: + flax_key_tuple = (pt_model.base_model_prefix,) + flax_key_tuple + + # rename flax weights to PyTorch format + if flax_key_tuple[-1] == "kernel" and flax_tensor.ndim == 4 and ".".join(flax_key_tuple) not in pt_model_dict: + # conv layer + flax_key_tuple = flax_key_tuple[:-1] + ("weight",) + flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1)) + elif flax_key_tuple[-1] == "kernel" and ".".join(flax_key_tuple) not in pt_model_dict: + # linear layer + flax_key_tuple = flax_key_tuple[:-1] + ("weight",) + flax_tensor = flax_tensor.T + elif flax_key_tuple[-1] in ["scale", "embedding"]: + flax_key_tuple = flax_key_tuple[:-1] + ("weight",) + + # adding batch stats from flax batch norm to pt + elif "mean" in flax_key_tuple[-1]: + flax_key_tuple = flax_key_tuple[:-1] + ("running_mean",) + elif "var" in flax_key_tuple[-1]: + flax_key_tuple = flax_key_tuple[:-1] + ("running_var",) + + if "batch_stats" in flax_state: + flax_key = ".".join(flax_key_tuple[1:]) # Remove the params/batch_stats header + else: + flax_key = ".".join(flax_key_tuple) + + # We also need to look at `pt_model_dict` and see if there are keys requiring further transformation. + special_pt_names = {} + # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030 + for key in pt_model_dict: + key_components = key.split(".") + name = None + if key_components[-3::2] == ["parametrizations", "original0"]: + name = key_components[-2] + "_g" + elif key_components[-3::2] == ["parametrizations", "original1"]: + name = key_components[-2] + "_v" + if name is not None: + key_components = key_components[:-3] + [name] + key_to_check = ".".join(key_components) + special_pt_names[key_to_check] = key + + if flax_key in special_pt_names: + flax_key = special_pt_names[flax_key] + + if flax_key in pt_model_dict: + if flax_tensor.shape != pt_model_dict[flax_key].shape: + raise ValueError( + f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected " + f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}." + ) + else: + # add weight to pytorch dict + flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor + pt_model_dict[flax_key] = torch.from_numpy(flax_tensor) + # remove from missing keys + missing_keys.remove(flax_key) + else: + # weight is not expected by PyTorch model + unexpected_keys.append(flax_key) + + pt_model.load_state_dict(pt_model_dict) + + # re-transform missing_keys to list + missing_keys = list(missing_keys) + + if len(unexpected_keys) > 0: + logger.warning( + "Some weights of the Flax model were not used when initializing the PyTorch model" + f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing" + f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture" + " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This" + f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect" + " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a" + " FlaxBertForSequenceClassification model)." + ) + else: + logger.warning(f"All Flax model weights were used when initializing {pt_model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly" + f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to" + " use it for predictions and inference." + ) + else: + logger.warning( + f"All the weights of {pt_model.__class__.__name__} were initialized from the Flax model.\n" + "If your task is similar to the task the model of the checkpoint was trained on, " + f"you can already use {pt_model.__class__.__name__} for predictions without further training." + ) + + return pt_model diff --git a/modeling_flax_utils.py b/modeling_flax_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dc4a3be732a4f961d08b8b2cb9a9dc490d4b34e5 --- /dev/null +++ b/modeling_flax_utils.py @@ -0,0 +1,1290 @@ +# coding=utf-8 +# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import gc +import json +import os +import re +import warnings +from functools import partial +from pickle import UnpicklingError +from typing import Any, Dict, Optional, Set, Tuple, Union + +import flax.linen as nn +import jax +import jax.numpy as jnp +import msgpack.exceptions +from flax.core.frozen_dict import FrozenDict, unfreeze +from flax.serialization import from_bytes, to_bytes +from flax.traverse_util import flatten_dict, unflatten_dict +from jax.random import PRNGKey + +from .configuration_utils import PretrainedConfig +from .dynamic_module_utils import custom_object_save +from .generation import FlaxGenerationMixin, GenerationConfig +from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict +from .utils import ( + FLAX_WEIGHTS_INDEX_NAME, + FLAX_WEIGHTS_NAME, + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + PushToHubMixin, + add_code_sample_docstrings, + add_start_docstrings_to_model_forward, + cached_file, + copy_func, + download_url, + has_file, + is_offline_mode, + is_remote_url, + logging, + replace_return_docstrings, +) +from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files +from .utils.import_utils import is_safetensors_available + + +if is_safetensors_available(): + from safetensors import safe_open + from safetensors.flax import load_file as safe_load_file + from safetensors.flax import save_file as safe_save_file + +logger = logging.get_logger(__name__) + + +def quick_gelu(x): + return x * jax.nn.sigmoid(1.702 * x) + + +ACT2FN = { + "gelu": partial(nn.gelu, approximate=False), + "relu": nn.relu, + "silu": nn.swish, + "swish": nn.swish, + "gelu_new": partial(nn.gelu, approximate=True), + "quick_gelu": quick_gelu, + "gelu_pytorch_tanh": partial(nn.gelu, approximate=True), +} + + +def dtype_byte_size(dtype): + """ + Returns the size (in bytes) occupied by one parameter of type `dtype`. Example: + ```py + >>> dtype_byte_size(np.float32) + 4 + ``` + """ + if dtype is bool: + return 1 / 8 + bit_search = re.search(r"[^\d](\d+)$", dtype.name) + if bit_search is None: + raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") + bit_size = int(bit_search.groups()[0]) + return bit_size // 8 + + +def flax_shard_checkpoint(params, max_shard_size="10GB"): + """ + Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a + given size. The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so + there is no optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For + example, if the limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as + [6GB], [6+2GB], [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB]. + + + + If one of the model's weight is bigger that `max_shard_size`, it will end up in its own sub-checkpoint which will + have a size greater than `max_shard_size`. + + + + Args: + params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters. + max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): + The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit + (like `"5MB"`). + """ + max_shard_size = convert_file_size_to_int(max_shard_size) + + sharded_state_dicts = [] + current_block = {} + current_block_size = 0 + total_size = 0 + + # flatten the weights to chunk + weights = flatten_dict(params, sep="/") + for item in weights: + weight_size = weights[item].size * dtype_byte_size(weights[item].dtype) + + # If this weight is going to tip up over the maximal size, we split. + if current_block_size + weight_size > max_shard_size: + sharded_state_dicts.append(current_block) + current_block = {} + current_block_size = 0 + + current_block[item] = weights[item] + current_block_size += weight_size + total_size += weight_size + + # Add the last block + sharded_state_dicts.append(current_block) + + # If we only have one shard, we return it + if len(sharded_state_dicts) == 1: + return {FLAX_WEIGHTS_NAME: sharded_state_dicts[0]}, None + + # Otherwise, let's build the index + weight_map = {} + shards = {} + for idx, shard in enumerate(sharded_state_dicts): + shard_file = FLAX_WEIGHTS_NAME.replace(".msgpack", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.msgpack") + shards[shard_file] = shard + for weight_name in shard.keys(): + weight_map[weight_name] = shard_file + + # Add the metadata + metadata = {"total_size": total_size} + index = {"metadata": metadata, "weight_map": weight_map} + return shards, index + + +class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): + r""" + Base class for all models. + + [`FlaxPreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading, + downloading and saving models. + + Class attributes (overridden by derived classes): + + - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class + for this model architecture. + - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived + classes of the same architecture adding modules on top of the base model. + - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP + models, `pixel_values` for vision models and `input_values` for speech models). + """ + + config_class = None + base_model_prefix = "" + main_input_name = "input_ids" + _auto_class = None + _missing_keys = set() + + def __init__( + self, + config: PretrainedConfig, + module: nn.Module, + input_shape: Tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + ): + if config is None: + raise ValueError("config cannot be None") + + if module is None: + raise ValueError("module cannot be None") + + # Those are private to be exposed as typed property on derived classes. + self._config = config + self._module = module + + # Those are public as their type is generic to every derived classes. + self.key = PRNGKey(seed) + self.dtype = dtype + self.input_shape = input_shape + self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None + + # To check if the model was initialized automatically. + self._is_initialized = _do_init + + if _do_init: + # randomly initialized parameters + random_params = self.init_weights(self.key, input_shape) + params_shape_tree = jax.eval_shape(lambda params: params, random_params) + else: + init_fn = partial(self.init_weights, input_shape=input_shape) + params_shape_tree = jax.eval_shape(init_fn, self.key) + + logger.info( + "Model weights are not initialized as `_do_init` is set to `False`. " + f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights." + ) + + # get the shape of the parameters + self._params_shape_tree = params_shape_tree + + # save required_params as set + self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys()) + + # initialize the parameters + if _do_init: + self.params = random_params + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict: + raise NotImplementedError(f"init method has to be implemented for {self}") + + def enable_gradient_checkpointing(self): + raise NotImplementedError(f"gradient checkpointing method has to be implemented for {self}") + + @classmethod + def _from_config(cls, config, **kwargs): + """ + All context managers that the model should be initialized under go here. + """ + return cls(config, **kwargs) + + @property + def framework(self) -> str: + """ + :str: Identifies that this is a Flax model. + """ + return "flax" + + @property + def config(self) -> PretrainedConfig: + return self._config + + @property + def module(self) -> nn.Module: + return self._module + + @property + def params(self) -> Union[Dict, FrozenDict]: + if not self._is_initialized: + raise ValueError( + "`params` cannot be accessed from model when the model is created with `_do_init=False`. " + "You must call `init_weights` manually and store the params outside of the model and " + "pass it explicitly where needed." + ) + return self._params + + @property + def required_params(self) -> Set: + return self._required_params + + @property + def params_shape_tree(self) -> Dict: + return self._params_shape_tree + + @params.setter + def params(self, params: Union[Dict, FrozenDict]): + # don't set params if the model is not initialized + if not self._is_initialized: + raise ValueError( + "`params` cannot be set from model when the model is created with `_do_init=False`. " + "You store the params outside of the model." + ) + + if isinstance(params, FrozenDict): + params = unfreeze(params) + param_keys = set(flatten_dict(params).keys()) + if len(self.required_params - param_keys) > 0: + raise ValueError( + "Some parameters are missing. Make sure that `params` include the following " + f"parameters {self.required_params - param_keys}" + ) + self._params = params + + def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any: + """ + Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`. + """ + + # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27 + def conditional_cast(param): + if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating): + param = param.astype(dtype) + return param + + if mask is None: + return jax.tree_util.tree_map(conditional_cast, params) + + flat_params = flatten_dict(params) + flat_mask, _ = jax.tree_util.tree_flatten(mask) + + for masked, key in zip(flat_mask, sorted(flat_params.keys())): + if masked: + flat_params[key] = conditional_cast(flat_params[key]) + + return unflatten_dict(flat_params) + + def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None): + r""" + Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast + the `params` in place. + + This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full + half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed. + + Arguments: + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + mask (`Union[Dict, FrozenDict]`): + A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params + you want to cast, and should be `False` for those you want to skip. + + Examples: + + ```python + >>> from transformers import FlaxBertModel + + >>> # load model + >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased") + >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision + >>> model.params = model.to_bf16(model.params) + >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale) + >>> # then pass the mask as follows + >>> from flax import traverse_util + + >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased") + >>> flat_params = traverse_util.flatten_dict(model.params) + >>> mask = { + ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) + ... for path in flat_params + ... } + >>> mask = traverse_util.unflatten_dict(mask) + >>> model.params = model.to_bf16(model.params, mask) + ```""" + return self._cast_floating_to(params, jnp.bfloat16, mask) + + def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None): + r""" + Cast the floating-point `parmas` to `jax.numpy.float32`. This method can be used to explicitly convert the + model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place. + + Arguments: + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + mask (`Union[Dict, FrozenDict]`): + A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params + you want to cast, and should be `False` for those you want to skip + + Examples: + + ```python + >>> from transformers import FlaxBertModel + + >>> # Download model and configuration from huggingface.co + >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased") + >>> # By default, the model params will be in fp32, to illustrate the use of this method, + >>> # we'll first cast to fp16 and back to fp32 + >>> model.params = model.to_f16(model.params) + >>> # now cast back to fp32 + >>> model.params = model.to_fp32(model.params) + ```""" + return self._cast_floating_to(params, jnp.float32, mask) + + def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None): + r""" + Cast the floating-point `parmas` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the + `params` in place. + + This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full + half-precision training or to save weights in float16 for inference in order to save memory and improve speed. + + Arguments: + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + mask (`Union[Dict, FrozenDict]`): + A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params + you want to cast, and should be `False` for those you want to skip + + Examples: + + ```python + >>> from transformers import FlaxBertModel + + >>> # load model + >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased") + >>> # By default, the model params will be in fp32, to cast these to float16 + >>> model.params = model.to_fp16(model.params) + >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale) + >>> # then pass the mask as follows + >>> from flax import traverse_util + + >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased") + >>> flat_params = traverse_util.flatten_dict(model.params) + >>> mask = { + ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) + ... for path in flat_params + ... } + >>> mask = traverse_util.unflatten_dict(mask) + >>> model.params = model.to_fp16(model.params, mask) + ```""" + return self._cast_floating_to(params, jnp.float16, mask) + + @classmethod + def load_flax_weights(cls, resolved_archive_file): + try: + if resolved_archive_file.endswith(".safetensors"): + state = safe_load_file(resolved_archive_file) + state = unflatten_dict(state, sep=".") + else: + with open(resolved_archive_file, "rb") as state_f: + state = from_bytes(cls, state_f.read()) + except (UnpicklingError, msgpack.exceptions.ExtraData) as e: + try: + with open(resolved_archive_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please" + " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" + " folder you cloned." + ) + else: + raise ValueError from e + except (UnicodeDecodeError, ValueError): + raise EnvironmentError(f"Unable to convert {resolved_archive_file} to Flax deserializable object. ") + + return state + + @classmethod + def load_flax_sharded_weights(cls, shard_files): + """ + This is the same as [`flax.serialization.from_bytes`] + (https:lax.readthedocs.io/en/latest/_modules/flax/serialization.html#from_bytes) but for a sharded checkpoint. + + This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being + loaded in the model. + + Args: + shard_files (`List[str]`: + The list of shard files to load. + + Returns: + `Dict`: A nested dictionary of the model parameters, in the expected format for flax models : `{'model': + {'params': {'...'}}}`. + """ + + # Load the index + state_sharded_dict = {} + + for shard_file in shard_files: + # load using msgpack utils + try: + with open(shard_file, "rb") as state_f: + state = from_bytes(cls, state_f.read()) + except (UnpicklingError, msgpack.exceptions.ExtraData) as e: + with open(shard_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please" + " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" + " folder you cloned." + ) + else: + raise ValueError from e + except (UnicodeDecodeError, ValueError): + raise EnvironmentError(f"Unable to convert {shard_file} to Flax deserializable object. ") + + state = flatten_dict(state, sep="/") + state_sharded_dict.update(state) + del state + gc.collect() + + # the state dict is unflattened to the match the format of model.params + return unflatten_dict(state_sharded_dict, sep="/") + + @classmethod + def can_generate(cls) -> bool: + """ + Returns whether this model can generate sequences with `.generate()`. Returns: + `bool`: Whether this model can generate sequences with `.generate()`. + """ + # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation. + # Alternativelly, the model can also have a custom `generate` function. + if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate): + return False + return True + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + dtype: jnp.dtype = jnp.float32, + *model_args, + config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, + cache_dir: Optional[Union[str, os.PathLike]] = None, + ignore_mismatched_sizes: bool = False, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + **kwargs, + ): + r""" + Instantiate a pretrained flax model from a pre-trained model configuration. + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *pt index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In this case, + `from_pt` should be set to `True`. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. + model_args (sequence of positional arguments, *optional*): + All remaining positional arguments will be passed to the underlying model's `__init__` method. + config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*): + Can be either: + + - an instance of a class derived from [`PretrainedConfig`], + - a string or path valid as input to [`~PretrainedConfig.from_pretrained`]. + + Configuration for the model to use instead of an automatically loaded configuration. Configuration can + be automatically loaded when: + + - The model is a model provided by the library (loaded with the *model id* string of a pretrained + model). + - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the + save directory. + - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a + configuration JSON file named *config.json* is found in the directory. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + from_pt (`bool`, *optional*, defaults to `False`): + Load the model weights from a PyTorch checkpoint save file (see docstring of + `pretrained_model_name_or_path` argument). + ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): + Whether or not to raise an error if some of the weights from the checkpoint do not have the same size + as the weights of the model (if for instance, you are instantiating a model with 10 labels from a + checkpoint with 3 labels). + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use + the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + + To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`. + + + + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can + specify the folder name here. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). Behaves differently depending on whether a `config` is provided or + automatically loaded: + + - If a configuration is provided with `config`, `**kwargs` will be directly passed to the + underlying model's `__init__` method (we assume all relevant updates to the configuration have + already been done) + - If a configuration is not provided, `kwargs` will be first passed to the configuration class + initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that + corresponds to a configuration attribute will be used to override said attribute with the + supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute + will be passed to the underlying model's `__init__` function. + + Examples: + + ```python + >>> from transformers import BertConfig, FlaxBertModel + + >>> # Download model and configuration from huggingface.co and cache. + >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased") + >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable). + >>> model = FlaxBertModel.from_pretrained("./test/saved_model/") + >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable). + >>> config = BertConfig.from_json_file("./pt_model/config.json") + >>> model = FlaxBertModel.from_pretrained("./pt_model/pytorch_model.bin", from_pt=True, config=config) + ```""" + from_pt = kwargs.pop("from_pt", False) + resume_download = kwargs.pop("resume_download", None) + proxies = kwargs.pop("proxies", None) + use_auth_token = kwargs.pop("use_auth_token", None) + trust_remote_code = kwargs.pop("trust_remote_code", None) + from_pipeline = kwargs.pop("_from_pipeline", None) + from_auto_class = kwargs.pop("_from_auto", False) + _do_init = kwargs.pop("_do_init", True) + subfolder = kwargs.pop("subfolder", "") + commit_hash = kwargs.pop("_commit_hash", None) + + # Not relevant for Flax Models + _ = kwargs.pop("adapter_kwargs", None) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if trust_remote_code is True: + logger.warning( + "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is" + " ignored." + ) + + user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class} + if from_pipeline is not None: + user_agent["using_pipeline"] = from_pipeline + + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + # Load config if we don't provide a configuration + if not isinstance(config, PretrainedConfig): + config_path = config if config is not None else pretrained_model_name_or_path + config, model_kwargs = cls.config_class.from_pretrained( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + _commit_hash=commit_hash, + **kwargs, + ) + else: + model_kwargs = kwargs.copy() + + if commit_hash is None: + commit_hash = getattr(config, "_commit_hash", None) + + # Add the dtype to model_kwargs + model_kwargs["dtype"] = dtype + + # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the + # index of the files. + is_sharded = False + + # Load model + if pretrained_model_name_or_path is not None: + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + is_local = os.path.isdir(pretrained_model_name_or_path) + if os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)): + # Load from a Flax checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)): + # Load from a sharded Flax checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME) + is_sharded = True + elif is_safetensors_available() and os.path.isfile( + os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME) + ): + # Load from a safetensors checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME) + elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)): + # Load from a PyTorch checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME) + elif from_pt and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME) + ): + # Load from a sharded pytorch checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME) + is_sharded = True + # At this stage we don't have a weight file so we will raise an error. + elif is_safetensors_available() and os.path.isfile( + os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) + ): + # Load from a sharded safetensors checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) + is_sharded = True + raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!") + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)): + raise EnvironmentError( + f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " + "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those " + "weights." + ) + else: + raise EnvironmentError( + f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory " + f"{pretrained_model_name_or_path}." + ) + elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): + archive_file = pretrained_model_name_or_path + is_local = True + elif is_remote_url(pretrained_model_name_or_path): + filename = pretrained_model_name_or_path + resolved_archive_file = download_url(pretrained_model_name_or_path) + else: + if from_pt: + filename = WEIGHTS_NAME + else: + filename = FLAX_WEIGHTS_NAME + + try: + # Load from URL or cache if already cached + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "token": token, + "user_agent": user_agent, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) + + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME: + resolved_archive_file = cached_file( + pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME, **cached_file_kwargs + ) + if resolved_archive_file is not None: + is_sharded = True + + # Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case. + if resolved_archive_file is None and from_pt: + resolved_archive_file = cached_file( + pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs + ) + if resolved_archive_file is not None: + is_sharded = True + + # If we still haven't found anything, look for `safetensors`. + if resolved_archive_file is None: + # No support for sharded safetensors yet, so we'll raise an error if that's all we find. + filename = SAFE_WEIGHTS_NAME + resolved_archive_file = cached_file( + pretrained_model_name_or_path, SAFE_WEIGHTS_NAME, **cached_file_kwargs + ) + + # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None + # result when internet is up, the repo and revision exist, but the file does not. + if resolved_archive_file is None: + # Otherwise, maybe there is a TF or Torch model file. We try those to give a helpful error + # message. + has_file_kwargs = { + "revision": revision, + "proxies": proxies, + "token": token, + "cache_dir": cache_dir, + "local_files_only": local_files_only, + } + if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs): + is_sharded = True + raise NotImplementedError( + "Support for sharded checkpoints using safetensors is coming soon!" + ) + elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to" + " load this model from those weights." + ) + elif has_file(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **has_file_kwargs): + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {FLAX_WEIGHTS_INDEX_NAME} but there is a sharded file for PyTorch weights. Use" + " `from_pt=True` to load this model from those weights." + ) + else: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}." + ) + except EnvironmentError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted + # to the original exception. + raise + except Exception: + # For any other exception, we throw a generic error. + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" + " from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" + f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}." + ) + + if is_local: + logger.info(f"loading weights file {archive_file}") + resolved_archive_file = archive_file + filename = resolved_archive_file.split(os.path.sep)[-1] + else: + logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") + else: + resolved_archive_file = None + + # We'll need to download and cache each checkpoint shard if the checkpoint is sharded. + if is_sharded: + # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. + resolved_archive_file, _ = get_checkpoint_shard_files( + pretrained_model_name_or_path, + resolved_archive_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _commit_hash=commit_hash, + ) + + safetensors_from_pt = False + if filename == SAFE_WEIGHTS_NAME: + with safe_open(resolved_archive_file, framework="flax") as f: + safetensors_metadata = f.metadata() + if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]: + raise OSError( + f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata." + " Make sure you save your model with the `save_pretrained` method." + ) + safetensors_from_pt = safetensors_metadata.get("format") == "pt" + + # init random models + model = cls(config, *model_args, _do_init=_do_init, **model_kwargs) + + if from_pt or safetensors_from_pt: + state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded) + else: + if is_sharded: + state = cls.load_flax_sharded_weights(resolved_archive_file) + else: + state = cls.load_flax_weights(resolved_archive_file) + # make sure all arrays are stored as jnp.arrays + # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4: + # https://github.com/google/flax/issues/1261 + if _do_init: + state = jax.tree_util.tree_map(jnp.array, state) + else: + # keep the params on CPU if we don't want to initialize + state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.local_devices(backend="cpu")[0]), state) + + if "batch_stats" in state: # if flax model contains batch norm layers + # if model is base model only use model_prefix key + if ( + cls.base_model_prefix not in dict(model.params_shape_tree["params"]) + and cls.base_model_prefix in state["params"] + ): + state["params"] = state["params"][cls.base_model_prefix] + state["batch_stats"] = state["batch_stats"][cls.base_model_prefix] + + # if model is head model and we are loading weights from base model + # we initialize new params dict with base_model_prefix + if ( + cls.base_model_prefix in dict(model.params_shape_tree["params"]) + and cls.base_model_prefix not in state["params"] + ): + state = { + "params": {cls.base_model_prefix: state["params"]}, + "batch_stats": {cls.base_model_prefix: state["batch_stats"]}, + } + + else: + # if model is base model only use model_prefix key + if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state: + state = state[cls.base_model_prefix] + + # if model is head model and we are loading weights from base model + # we initialize new params dict with base_model_prefix + if cls.base_model_prefix in dict(model.params_shape_tree) and cls.base_model_prefix not in state: + state = {cls.base_model_prefix: state} + + # flatten dicts + state = flatten_dict(state) + + random_state = flatten_dict(unfreeze(model.params if _do_init else model.params_shape_tree)) + + missing_keys = model.required_params - set(state.keys()) + unexpected_keys = set(state.keys()) - model.required_params + + # Disabling warning when porting pytorch weights to flax, flax does not uses num_batches_tracked + for unexpected_key in unexpected_keys.copy(): + if "num_batches_tracked" in unexpected_key[-1]: + unexpected_keys.remove(unexpected_key) + + if missing_keys and not _do_init: + logger.warning( + f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. " + "Make sure to call model.init_weights to initialize the missing weights." + ) + cls._missing_keys = missing_keys + + # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not + # matching the weights in the model. + mismatched_keys = [] + for key in state.keys(): + if key in random_state and state[key].shape != random_state[key].shape: + if ignore_mismatched_sizes: + mismatched_keys.append((key, state[key].shape, random_state[key].shape)) + state[key] = random_state[key] + else: + raise ValueError( + f"Trying to load the pretrained weight for {key} failed: checkpoint has shape " + f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. " + "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this " + "model." + ) + + # add missing keys as random parameters if we are initializing + if missing_keys and _do_init: + for missing_key in missing_keys: + state[missing_key] = random_state[missing_key] + + # remove unexpected keys to not be saved again + for unexpected_key in unexpected_keys: + del state[unexpected_key] + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" + " with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical" + " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" + f" was trained on, you can already use {model.__class__.__name__} for predictions without further" + " training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" + " to use it for predictions and inference." + ) + + # dictionary of key: dtypes for the model params + param_dtypes = jax.tree_util.tree_map(lambda x: x.dtype, state) + # extract keys of parameters not in jnp.float32 + fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16] + bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16] + + # raise a warning if any of the parameters are not in jnp.float32 + if len(fp16_params) > 0: + logger.warning( + f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from " + f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n" + "You should probably UPCAST the model weights to float32 if this was not intended. " + "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this." + ) + + if len(bf16_params) > 0: + logger.warning( + f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from " + f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n" + "You should probably UPCAST the model weights to float32 if this was not intended. " + "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this." + ) + + # If it is a model with generation capabilities, attempt to load the generation config + if model.can_generate(): + try: + model.generation_config = GenerationConfig.from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + **kwargs, + ) + except OSError: + logger.info( + "Generation config file not found, using a generation config created from the model config." + ) + pass + + if _do_init: + # set correct parameters + model.params = unflatten_dict(state) + return model + else: + return model, unflatten_dict(state) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + params=None, + push_to_hub=False, + max_shard_size="10GB", + token: Optional[Union[str, bool]] = None, + safe_serialization: bool = False, + **kwargs, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + `[`~FlaxPreTrainedModel.from_pretrained`]` class method + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): + The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size + lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). + + + + If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard + which will be bigger than `max_shard_size`. + + + + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use + the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). + kwargs (`Dict[str, Any]`, *optional*): + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + safe_serialization (`bool`, *optional*, defaults to `False`): + Whether to save the model using `safetensors` or through msgpack. + """ + use_auth_token = kwargs.pop("use_auth_token", None) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if token is not None: + kwargs["token"] = token + + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = self._create_repo(repo_id, **kwargs) + files_timestamps = self._get_files_timestamps(save_directory) + + # get abs dir + save_directory = os.path.abspath(save_directory) + # save config as well + self.config.architectures = [self.__class__.__name__[4:]] + + # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be + # loaded from the Hub. + if self._auto_class is not None: + custom_object_save(self, save_directory, config=self.config) + + self.config.save_pretrained(save_directory) + if self.can_generate(): + self.generation_config.save_pretrained(save_directory) + + # save model + weights_name = SAFE_WEIGHTS_NAME if safe_serialization else FLAX_WEIGHTS_NAME + output_model_file = os.path.join(save_directory, weights_name) + + shards, index = flax_shard_checkpoint(params if params is not None else self.params, max_shard_size) + # Clean the folder from a previous save + for filename in os.listdir(save_directory): + full_filename = os.path.join(save_directory, filename) + weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "") + if ( + filename.startswith(weights_no_suffix) + and os.path.isfile(full_filename) + and filename not in shards.keys() + ): + os.remove(full_filename) + + if index is None: + if safe_serialization: + params = params if params is not None else self.params + flat_dict = flatten_dict(params, sep=".") + safe_save_file(flat_dict, output_model_file, metadata={"format": "flax"}) + else: + with open(output_model_file, "wb") as f: + params = params if params is not None else self.params + model_bytes = to_bytes(params) + f.write(model_bytes) + + else: + save_index_file = os.path.join(save_directory, FLAX_WEIGHTS_INDEX_NAME) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + for shard_file, shard in shards.items(): + # the shard item are unflattened, to save them we need to flatten them again + with open(os.path.join(save_directory, shard_file), mode="wb") as f: + params = unflatten_dict(shard, sep="/") + shard_bytes = to_bytes(params) + f.write(shard_bytes) + + logger.info(f"Model weights saved in {output_model_file}") + + if push_to_hub: + self._upload_modified_files( + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=token, + ) + + @classmethod + def register_for_auto_class(cls, auto_class="FlaxAutoModel"): + """ + Register this class with a given auto class. This should only be used for custom models as the ones in the + library are already mapped with an auto class. + + + + This API is experimental and may have some slight breaking changes in the next releases. + + + + Args: + auto_class (`str` or `type`, *optional*, defaults to `"FlaxAutoModel"`): + The auto class to register this new model with. + """ + if not isinstance(auto_class, str): + auto_class = auto_class.__name__ + + import transformers.models.auto as auto_module + + if not hasattr(auto_module, auto_class): + raise ValueError(f"{auto_class} is not a valid auto class.") + + cls._auto_class = auto_class + + +# To update the docstring, we need to copy the method, otherwise we change the original docstring. +FlaxPreTrainedModel.push_to_hub = copy_func(FlaxPreTrainedModel.push_to_hub) +if FlaxPreTrainedModel.push_to_hub.__doc__ is not None: + FlaxPreTrainedModel.push_to_hub.__doc__ = FlaxPreTrainedModel.push_to_hub.__doc__.format( + object="model", object_class="FlaxAutoModel", object_files="model checkpoint" + ) + + +def overwrite_call_docstring(model_class, docstring): + # copy __call__ function to be sure docstring is changed only for this function + model_class.__call__ = copy_func(model_class.__call__) + # delete existing docstring + model_class.__call__.__doc__ = None + # set correct docstring + model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__) + + +def append_call_sample_docstring( + model_class, checkpoint, output_type, config_class, mask=None, revision=None, real_checkpoint=None +): + model_class.__call__ = copy_func(model_class.__call__) + model_class.__call__ = add_code_sample_docstrings( + checkpoint=checkpoint, + output_type=output_type, + config_class=config_class, + model_cls=model_class.__name__, + revision=revision, + real_checkpoint=real_checkpoint, + )(model_class.__call__) + + +def append_replace_return_docstrings(model_class, output_type, config_class): + model_class.__call__ = copy_func(model_class.__call__) + model_class.__call__ = replace_return_docstrings( + output_type=output_type, + config_class=config_class, + )(model_class.__call__) diff --git a/modeling_gguf_pytorch_utils.py b/modeling_gguf_pytorch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9b20c1b61226a0fa94a5160d7ab6d64cd476617c --- /dev/null +++ b/modeling_gguf_pytorch_utils.py @@ -0,0 +1,471 @@ +# coding=utf-8 +# Copyright 2024 The ggml.ai team and The HuggingFace Inc. team. and pygguf author (github.com/99991) +# https://github.com/99991/pygguf +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Dict, NamedTuple, Optional + +import numpy as np +from tqdm import tqdm + +from .integrations import ( + GGUF_CONFIG_MAPPING, + GGUF_TOKENIZER_MAPPING, + _gguf_parse_value, +) +from .utils import is_torch_available +from .utils.import_utils import is_gguf_available +from .utils.logging import get_logger + + +if is_torch_available(): + import torch + +logger = get_logger(__name__) + + +GGUF_TO_TRANSFORMERS_MAPPING = { + "ignore": { + "GGUF": { + "version": "version", + "tensor_count": "tensor_count", + "kv_count": "kv_count", + }, + "general": {"file_type": "file_type", "quantization_version": "quantization_version"}, + }, + "config": GGUF_CONFIG_MAPPING, + "tokenizer": {"tokenizer": GGUF_TOKENIZER_MAPPING["tokenizer"]}, + "tokenizer_config": {"tokenizer": GGUF_TOKENIZER_MAPPING["tokenizer_config"]}, +} + +GGUF_SUPPORTED_ARCHITECTURES = list(GGUF_TO_TRANSFORMERS_MAPPING["config"].keys()) + + +class GGUFTensor(NamedTuple): + weights: np.ndarray + name: str + metadata: dict + + +class TensorProcessor: + def __init__(self, config=None): + self.config = config or {} + + def process(self, weights, name, **kwargs): + return GGUFTensor(weights, name, {}) + + +class LlamaTensorProcessor(TensorProcessor): + def __init__(self, config=None): + super().__init__(config=config) + + def process(self, weights, name, **kwargs): + if ".attn_k." in name or ".attn_q." in name: + num_heads = self.config.get("num_attention_heads") + num_kv_heads = self.config.get("num_key_value_heads") + + if None in (num_heads, num_kv_heads): + return GGUFTensor(weights, name, {}) + if ".attn_q." in name: + weights = self._reverse_permute_weights(weights, num_heads, num_heads) + elif ".attn_k." in name: + weights = self._reverse_permute_weights(weights, num_heads, num_kv_heads) + return GGUFTensor(weights, name, {}) + + def _reverse_permute_weights( + self, weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None + ) -> np.ndarray: + # Original permutation implementation + # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L1402-L1408 + if num_kv_heads is not None and n_head != num_kv_heads: + n_head = num_kv_heads + + dim = weights.shape[0] // n_head // 2 + w = weights.reshape(n_head, dim, 2, *weights.shape[1:]) + return w.swapaxes(2, 1).reshape(weights.shape) + + +class Qwen2MoeTensorProcessor(TensorProcessor): + def __init__(self, config=None): + super().__init__(config=config) + + def process(self, weights, name, **kwargs): + if "_exp" in name: + tensor_key_mapping = kwargs.get("tensor_key_mapping") + parsed_parameters = kwargs.get("parsed_parameters") + if tensor_key_mapping: + self._split_moe_expert_tensor(weights, parsed_parameters, name, tensor_key_mapping) + return GGUFTensor(weights, None, {}) + if "ffn_gate_inp_shexp" in name: + # for compatibility tensor shared_expert_gate must be (1, 2048) dim, + # quantized one is (2048) + weights = np.expand_dims(weights, axis=0) + return GGUFTensor(weights, name, {}) + + def _split_moe_expert_tensor( + self, weights: np.ndarray, parsed_parameters: Dict[str, Dict], name: str, tensor_key_mapping: dict + ): + # Original merge implementation + # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L1994-L2022 + name = tensor_key_mapping[name] + w_counter = self.config.get("num_experts", 60) + for i in range(0, w_counter): + temp_name = name.replace("mlp.experts.", f"mlp.experts.{i}.") + exp_weight = weights[i] + parsed_parameters["tensors"][temp_name] = torch.from_numpy(np.copy(exp_weight)) + + +class BloomTensorProcessor(TensorProcessor): + def __init__(self, config=None): + super().__init__(config=config) + + def process(self, weights, name, **kwargs): + if "attn_qkv" in name: + num_heads = self.config["n_head"] + n_embed = self.config["hidden_size"] + if "weight" in name: + weights = self._reverse_reshape_weights(weights, num_heads, n_embed) + else: + weights = self._reverse_reshape_bias(weights, num_heads, n_embed) + return GGUFTensor(weights, name, {}) + + def _reverse_reshape_weights(self, weights: np.ndarray, n_head: int, n_embed: int): + # Original reshape implementation + # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L972-L985 + q, k, v = np.array_split(weights, 3, axis=0) + + q = q.reshape(n_head, n_embed // n_head, n_embed) + k = k.reshape(n_head, n_embed // n_head, n_embed) + v = v.reshape(n_head, n_embed // n_head, n_embed) + qkv_weights = np.stack([q, k, v], axis=1) + + return qkv_weights.reshape(n_head * 3 * (n_embed // n_head), n_embed) + + def _reverse_reshape_bias(self, weights: np.ndarray, n_head: int, n_embed: int): + # Original reshape implementation + # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L986-L998 + q_bias, k_bias, v_bias = np.array_split(weights, 3) + + q_bias = q_bias.reshape(n_head, n_embed // n_head) + k_bias = k_bias.reshape(n_head, n_embed // n_head) + v_bias = v_bias.reshape(n_head, n_embed // n_head) + + qkv_bias = np.stack([q_bias, k_bias, v_bias], axis=1).flatten() + return qkv_bias + + +class T5TensorProcessor(TensorProcessor): + def __init__(self, config=None): + super().__init__(config=config) + + def process(self, weights, name, **kwargs): + bid = None + for chunk in name.split("."): + if chunk.isdigit(): + bid = int(chunk) + break + return GGUFTensor(weights, name, {"bid": bid}) + + +class GPT2TensorProcessor(TensorProcessor): + def __init__(self, config=None): + super().__init__(config=config) + + def process(self, weights, name, **kwargs): + # Original transpose implementation + # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L2060-L2061 + if ( + "attn_qkv.weight" in name + or "ffn_down.weight" in name + or "ffn_up.weight" in name + or "attn_output.weight" in name + ): + weights = weights.T + + # Handle special case for output.weight + if name == "output.weight": + # output.weight has conflicts with attn_output.weight in name checking + # Store the tensor directly and signal to skip further processing + name = "lm_head.weight" + parsed_parameters = kwargs.get("parsed_parameters", {}) + parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights)) + name = None # Signal to skip further processing + return GGUFTensor(weights, name, {}) + + +class MambaTensorProcessor(TensorProcessor): + def __init__(self, config=None): + super().__init__(config=config) + + def process(self, weights, name, **kwargs): + if "ssm_conv1d.weight" in name: + # for compatibility tensor ssm_conv1d must be (5120, 1, 4]) dim, + # quantized one is (5120, 4) + weights = np.expand_dims(weights, axis=1) + if "ssm_a" in name: + # Original exponential implementation + # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L2975-L2977 + weights = np.log(-weights) + return GGUFTensor(weights, name, {}) + + +class Gemma2TensorProcessor(TensorProcessor): + def __init__(self, config=None): + super().__init__(config=config) + + # ref: https://github.com/ggerganov/llama.cpp/blob/d79d8f39b4da6deca4aea8bf130c6034c482b320/convert_hf_to_gguf.py#L3191 + # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89 + def process(self, weights, name, **kwargs): + if "norm.weight" in name: + weights = weights - 1 + return GGUFTensor(weights, name, {}) + + +TENSOR_PROCESSORS = { + "llama": LlamaTensorProcessor, + "qwen2moe": Qwen2MoeTensorProcessor, + "bloom": BloomTensorProcessor, + "t5": T5TensorProcessor, + "t5encoder": T5TensorProcessor, + "gpt2": GPT2TensorProcessor, + "mamba": MambaTensorProcessor, + "gemma2": Gemma2TensorProcessor, +} + + +def read_field(reader, field): + value = reader.fields[field] + return [_gguf_parse_value(value.parts[_data_index], value.types) for _data_index in value.data] + + +# modified from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/loader.py#L1115-L1147 +def get_gguf_hf_weights_map( + hf_model, + model_type: Optional[str] = None, + num_layers: Optional[int] = None, + qual_name: str = "", +): + """ + GGUF uses this naming convention for their tensors from HF checkpoint: + `blk.N.BB.weight` and `blk.N.BB.bias` + where N signifies the block number of a layer, and BB signifies the + attention/mlp layer components. + See "Standardized tensor names" in + https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details. + """ + if is_gguf_available() and is_torch_available(): + from gguf import MODEL_ARCH_NAMES, get_tensor_name_map + else: + logger.error( + "Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see " + "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions." + ) + raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.") + + model_type = hf_model.config.model_type if model_type is None else model_type + num_layers = hf_model.config.num_hidden_layers if num_layers is None else num_layers + # hack: ggufs have a different name for cohere + if model_type == "cohere": + model_type = "command-r" + if model_type == "qwen2_moe": + model_type = "qwen2moe" + arch = None + for key, value in MODEL_ARCH_NAMES.items(): + if value == model_type: + arch = key + break + if arch is None: + raise NotImplementedError( + f"Unknown gguf model_type: {model_type} in gguf-py. " + "This might because you're using an outdated version of gguf-py package, " + "you can install `gguf` package from source refer to " + "https://github.com/ggerganov/llama.cpp/tree/master/gguf-py#development" + ) + name_map = get_tensor_name_map(arch, num_layers) + + # Use a dummy conversion to get the mapping, because + # hf => gguf and gguf => hf mappings are reversed + gguf_to_hf_name_map = {} + state_dict = hf_model.state_dict() + for hf_name in state_dict.keys(): + # An exception for qwen2moe model, where the expert layers are packed + if model_type == "qwen2moe" and "mlp.experts." in hf_name: + hf_name = re.sub(r"mlp.experts.\d+.", "mlp.experts.", hf_name) + + name, suffix = hf_name, "" + if hf_name.endswith(".weight") or hf_name.endswith(".bias"): + name, suffix = hf_name.rsplit(".", 1) + suffix = "." + suffix + + gguf_name = name_map.get_name(name) + if gguf_name is None: + continue + + gguf_to_hf_name_map[gguf_name + suffix] = qual_name + hf_name + + # Some model like Bloom converted from BloomModel instead of BloomForCausalLM + # Therefore, we need to check submodule as well to get a correct mapping + if named_children := hf_model.named_children(): + for name, child in named_children: + sub_map = get_gguf_hf_weights_map(child, model_type, num_layers, qual_name=f"{qual_name}{name}.") + # Ignore the keys that are already in the main map to avoid overwriting + sub_map = {k: v for k, v in sub_map.items() if k not in gguf_to_hf_name_map} + gguf_to_hf_name_map.update(sub_map) + + return gguf_to_hf_name_map + + +def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False, model_to_load=None): + """ + Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed + tokenizer and config attributes. + + Args: + gguf_checkpoint_path (`str`): + The path the to GGUF file to load + return_tensors (`bool`, defaults to `True`): + Whether to read the tensors from the file and return them. Not doing so is faster + and only loads the metadata in memory. + """ + if is_gguf_available() and is_torch_available(): + from gguf import GGUFReader, dequantize + else: + logger.error( + "Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see " + "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions." + ) + raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.") + + reader = GGUFReader(gguf_checkpoint_path) + fields = reader.fields + reader_keys = list(fields.keys()) + + parsed_parameters = {k: {} for k in GGUF_TO_TRANSFORMERS_MAPPING} + + architecture = read_field(reader, "general.architecture")[0] + model_name = read_field(reader, "general.name") + + # in llama.cpp mistral models use the same architecture as llama. We need + # to add this patch to ensure things work correctly on our side. + if "llama" in architecture and "mistral" in model_name: + updated_architecture = "mistral" + # FIXME: Currnetly this implementation is only for flan-t5 architecture. + # It needs to be developed for supporting legacy t5. + elif "t5" in architecture or "t5encoder" in architecture: + parsed_parameters["config"]["is_gated_act"] = True + updated_architecture = "t5" + else: + updated_architecture = architecture + + if "qwen2moe" in architecture: + updated_architecture = "qwen2_moe" + + # For stablelm architecture, we need to set qkv_bias and use_parallel_residual from tensors + # If `qkv_bias=True`, qkv_proj with bias will be present in the tensors + # If `use_parallel_residual=False`, ffn_norm will be present in the tensors + if "stablelm" in architecture: + attn_bias_name = {"attn_q.bias", "attn_k.bias", "attn_v.bias"} + ffn_norm_name = "ffn_norm" + qkv_bias = any(bias_name in tensor.name for tensor in reader.tensors for bias_name in attn_bias_name) + use_parallel_residual = any(ffn_norm_name in tensor.name for tensor in reader.tensors) + parsed_parameters["config"]["use_qkv_bias"] = qkv_bias + parsed_parameters["config"]["use_parallel_residual"] = not use_parallel_residual + + if architecture not in GGUF_SUPPORTED_ARCHITECTURES: + raise ValueError(f"GGUF model with architecture {architecture} is not supported yet.") + + # Handle tie_word_embeddings, if lm_head.weight is not present in tensors, + # tie_word_embeddings is true otherwise false + parsed_parameters["config"]["tie_word_embeddings"] = all( + "output.weight" != tensor.name for tensor in reader.tensors + ) + + # List all key-value pairs in a columnized format + for gguf_key, field in reader.fields.items(): + gguf_key = gguf_key.replace(architecture, updated_architecture) + split = gguf_key.split(".") + prefix = split[0] + config_key = ".".join(split[1:]) + + value = [_gguf_parse_value(field.parts[_data_index], field.types) for _data_index in field.data] + + if len(value) == 1: + value = value[0] + + if isinstance(value, str) and architecture in value: + value = value.replace(architecture, updated_architecture) + + for parameter in GGUF_TO_TRANSFORMERS_MAPPING: + parameter_renames = GGUF_TO_TRANSFORMERS_MAPPING[parameter] + if prefix in parameter_renames and config_key in parameter_renames[prefix]: + renamed_config_key = parameter_renames[prefix][config_key] + if renamed_config_key == -1: + continue + + if renamed_config_key is not None: + parsed_parameters[parameter][renamed_config_key] = value + + if gguf_key in reader_keys: + reader_keys.remove(gguf_key) + + if gguf_key in reader_keys: + logger.info(f"Some keys were not parsed and added into account {gguf_key} | {value}") + + # retrieve config vocab_size from tokenizer + # Pleas refer to https://github.com/huggingface/transformers/issues/32526 for more details + if "vocab_size" not in parsed_parameters["config"]: + tokenizer_parameters = parsed_parameters["tokenizer"] + if "tokens" in tokenizer_parameters: + parsed_parameters["config"]["vocab_size"] = len(tokenizer_parameters["tokens"]) + else: + logger.warning( + "Can't find a way to retrieve missing config vocab_size from tokenizer parameters. " + "This will use default value from model config class and cause unexpected behavior." + ) + + if return_tensors: + parsed_parameters["tensors"] = {} + + tensor_key_mapping = get_gguf_hf_weights_map(model_to_load) + config = parsed_parameters.get("config", {}) + + ProcessorClass = TENSOR_PROCESSORS.get(architecture, TensorProcessor) + processor = ProcessorClass(config=config) + + for tensor in tqdm(reader.tensors, desc="Converting and de-quantizing GGUF tensors..."): + name = tensor.name + weights = dequantize(tensor.data, tensor.tensor_type) + + result = processor.process( + weights=weights, + name=name, + tensor_key_mapping=tensor_key_mapping, + parsed_parameters=parsed_parameters, + ) + + weights = result.weights + name = result.name + + if name not in tensor_key_mapping: + continue + + name = tensor_key_mapping[name] + + parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights)) + + if len(reader_keys) > 0: + logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}") + + return parsed_parameters diff --git a/modeling_outputs.py b/modeling_outputs.py new file mode 100644 index 0000000000000000000000000000000000000000..7328e05186f2deddebb54f76d64427475de849a6 --- /dev/null +++ b/modeling_outputs.py @@ -0,0 +1,1753 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch + +from .utils import ModelOutput + + +@dataclass +class BaseModelOutput(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithNoAttention(ModelOutput): + """ + Base class for model's outputs, with potential hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPooling(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) after further processing + through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns + the classification token after processing through a linear layer and a tanh activation function. The linear + layer weights are trained from the next sentence prediction (classification) objective during pretraining. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPoolingAndNoAttention(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state after a pooling operation on the spatial dimensions. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPast(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithCrossAttentions(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) after further processing + through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns + the classification token after processing through a linear layer and a tanh activation function. The linear + layer weights are trained from the next sentence prediction (classification) objective during pretraining. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPastAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class MoECausalLMOutputWithPast(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs as well as Mixture of Expert's router hidden + states terms, to train a MoE model. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + z_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + z_loss for the sparse modules. + aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + aux_loss for the sparse modules. + router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse + modules. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + z_loss: torch.FloatTensor = None + aux_loss: torch.FloatTensor = None + router_logits: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class MoEModelOutput(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary + loss and the z_loss for Mixture of Experts models. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + router_probs: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class MoeModelOutputWithPast(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary + loss for Mixture of Experts models. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + router_logits: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class MoeCausalLMOutputWithPast(ModelOutput): + """ + Base class for causal language model (or autoregressive) with mixture of experts outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + + aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + aux_loss for the sparse modules. + + router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary + loss for Mixture of Experts models. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + aux_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + router_logits: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class MoEModelOutputWithPastAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding) as well as + Mixture of Expert's router hidden states terms, to train a MoE model. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary + loss and the z_loss for Mixture of Experts models. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + router_probs: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class Seq2SeqModelOutput(ModelOutput): + """ + Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential + decoding. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Seq2SeqMoEModelOutput(ModelOutput): + """ + Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential + decoding. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + decoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse + modules. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class CausalLMOutput(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class CausalLMOutputWithPast(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class CausalLMOutputWithCrossAttentions(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Cross attentions weights after the attention softmax, used to compute the weighted average in the + cross-attention heads. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `torch.FloatTensor` tuples of length `config.n_layers`, with each tuple containing the cached key, + value states of the self-attention and the cross-attention layers if model is used in encoder-decoder + setting. Only relevant if `config.is_decoder = True`. + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class SequenceClassifierOutputWithPast(ModelOutput): + """ + Base class for outputs of sentence classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class MaskedLMOutput(ModelOutput): + """ + Base class for masked language models outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Masked language modeling (MLM) loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Seq2SeqLMOutput(ModelOutput): + """ + Base class for sequence-to-sequence language models outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Seq2SeqMoEOutput(ModelOutput): + """ + Base class for sequence-to-sequence language models outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + decoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Router logits of the encoder model, useful to compute the auxiliary loss and z_loss for Mixture of Experts + models. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + encoder_z_loss: torch.FloatTensor = None + decoder_z_loss: torch.FloatTensor = None + encoder_aux_loss: torch.FloatTensor = None + decoder_aux_loss: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class NextSentencePredictorOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `next_sentence_label` is provided): + Next sequence prediction (classification) loss. + logits (`torch.FloatTensor` of shape `(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class SequenceClassifierOutput(ModelOutput): + """ + Base class for outputs of sentence classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Seq2SeqSequenceClassifierOutput(ModelOutput): + """ + Base class for outputs of sequence-to-sequence sentence classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `label` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class MultipleChoiceModelOutput(ModelOutput): + """ + Base class for outputs of multiple choice models. + + Args: + loss (`torch.FloatTensor` of shape *(1,)*, *optional*, returned when `labels` is provided): + Classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): + *num_choices* is the second dimension of the input tensors. (see *input_ids* above). + + Classification scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class TokenClassifierOutput(ModelOutput): + """ + Base class for outputs of token classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) : + Classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`): + Classification scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class QuestionAnsweringModelOutput(ModelOutput): + """ + Base class for outputs of question answering models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Span-start scores (before SoftMax). + end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Span-end scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + start_logits: torch.FloatTensor = None + end_logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Seq2SeqQuestionAnsweringModelOutput(ModelOutput): + """ + Base class for outputs of sequence-to-sequence question answering models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Span-start scores (before SoftMax). + end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Span-end scores (before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + start_logits: torch.FloatTensor = None + end_logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class SemanticSegmenterOutput(ModelOutput): + """ + Base class for outputs of semantic segmentation models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`): + Classification scores for each pixel. + + + + The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is + to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the + original image size as post-processing. You should always check your logits shape and resize as needed. + + + + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, patch_size, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class ImageClassifierOutput(ModelOutput): + """ + Base class for outputs of image classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states + (also called feature maps) of the model at the output of each stage. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class ImageClassifierOutputWithNoAttention(ModelOutput): + """ + Base class for outputs of image classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also + called feature maps) of the model at the output of each stage. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class DepthEstimatorOutput(ModelOutput): + """ + Base class for outputs of depth estimation models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + predicted_depth (`torch.FloatTensor` of shape `(batch_size, height, width)`): + Predicted depth for each pixel. + + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + predicted_depth: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class ImageSuperResolutionOutput(ModelOutput): + """ + Base class for outputs of image super resolution models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Reconstruction loss. + reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Reconstructed images, possibly upscaled. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states + (also called feature maps) of the model at the output of each stage. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + reconstruction: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Wav2Vec2BaseModelOutput(ModelOutput): + """ + Base class for models that have been trained with the Wav2Vec2 loss objective. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`): + Sequence of extracted feature vectors of the last convolutional layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor = None + extract_features: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class XVectorOutput(ModelOutput): + """ + Output type of [`Wav2Vec2ForXVector`]. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`): + Classification hidden states before AMSoftmax. + embeddings (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`): + Utterance embeddings used for vector similarity-based retrieval. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + embeddings: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BackboneOutput(ModelOutput): + """ + Base class for outputs of backbones. + + Args: + feature_maps (`tuple(torch.FloatTensor)` of shape `(batch_size, num_channels, height, width)`): + Feature maps of the stages. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, num_channels, height, width)`, + depending on the backbone. + + Hidden-states of the model at the output of each stage plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Only applicable if the backbone uses attention. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + feature_maps: Tuple[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPoolingAndProjection(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) after further processing + through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns + the classification token after processing through a linear layer and a tanh activation function. The linear + layer weights are trained from the next sentence prediction (classification) objective during pretraining. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + projection_state (`tuple(torch.FloatTensor)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` of shape `(batch_size,config.project_dim)`. + + Text embeddings before the projection layer, used to mimic the last hidden state of the teacher encoder. + """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + projection_state: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class Seq2SeqSpectrogramOutput(ModelOutput): + """ + Base class for sequence-to-sequence spectrogram outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Spectrogram generation loss. + spectrogram (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_bins)`): + The predicted spectrogram. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + spectrogram: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Seq2SeqTSModelOutput(ModelOutput): + """ + Base class for time series model's encoder outputs that also contains pre-computed hidden states that can speed up + sequential decoding. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + loc (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*): + Shift values of each time series' context window which is used to give the model inputs of the same + magnitude and then used to shift back to the original magnitude. + scale (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*): + Scaling values of each time series' context window which is used to give the model inputs of the same + magnitude and then used to rescale back to the original magnitude. + static_features (`torch.FloatTensor` of shape `(batch_size, feature size)`, *optional*): + Static features of each time series' in a batch which are copied to the covariates at inference time. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + loc: Optional[torch.FloatTensor] = None + scale: Optional[torch.FloatTensor] = None + static_features: Optional[torch.FloatTensor] = None + + +@dataclass +class Seq2SeqTSPredictionOutput(ModelOutput): + """ + Base class for time series model's decoder outputs that also contain the loss as well as the parameters of the + chosen distribution. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when a `future_values` is provided): + Distributional loss. + params (`torch.FloatTensor` of shape `(batch_size, num_samples, num_params)`): + Parameters of the chosen distribution. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + loc (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*): + Shift values of each time series' context window which is used to give the model inputs of the same + magnitude and then used to shift back to the original magnitude. + scale (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*): + Scaling values of each time series' context window which is used to give the model inputs of the same + magnitude and then used to rescale back to the original magnitude. + static_features (`torch.FloatTensor` of shape `(batch_size, feature size)`, *optional*): + Static features of each time series' in a batch which are copied to the covariates at inference time. + """ + + loss: Optional[torch.FloatTensor] = None + params: Optional[Tuple[torch.FloatTensor]] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + loc: Optional[torch.FloatTensor] = None + scale: Optional[torch.FloatTensor] = None + static_features: Optional[torch.FloatTensor] = None + + +@dataclass +class SampleTSPredictionOutput(ModelOutput): + """ + Base class for time series model's predictions outputs that contains the sampled values from the chosen + distribution. + + Args: + sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length)` or `(batch_size, num_samples, prediction_length, input_size)`): + Sampled values from the chosen distribution. + """ + + sequences: torch.FloatTensor = None + + +@dataclass +class MaskedImageModelingOutput(ModelOutput): + """ + Base class for outputs of masked image completion / in-painting models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): + Reconstruction loss. + reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Reconstructed / completed images. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or + when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states + (also called feature maps) of the model at the output of each stage. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + reconstruction: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + @property + def logits(self): + warnings.warn( + "logits attribute is deprecated and will be removed in version 5 of Transformers." + " Please use the reconstruction attribute to retrieve the final output instead.", + FutureWarning, + ) + return self.reconstruction diff --git a/modeling_rope_utils.py b/modeling_rope_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b2d343e0237fa0492208e1a00ec899f548d7d5e4 --- /dev/null +++ b/modeling_rope_utils.py @@ -0,0 +1,568 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Tuple + +from .configuration_utils import PretrainedConfig +from .utils import is_torch_available, logging + + +logger = logging.get_logger(__name__) + + +if is_torch_available(): + import torch + + +def _compute_default_rope_parameters( + config: Optional[PretrainedConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + **rope_kwargs, +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + if config is not None and len(rope_kwargs) > 0: + raise ValueError( + "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " + f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" + ) + if len(rope_kwargs) > 0: + base = rope_kwargs["base"] + dim = rope_kwargs["dim"] + elif config is not None: + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) + return inv_freq, attention_factor + + +def _compute_linear_scaling_rope_parameters( + config: Optional[PretrainedConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + **rope_kwargs, +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + if config is not None and len(rope_kwargs) > 0: + raise ValueError( + "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " + f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" + ) + if len(rope_kwargs) > 0: + factor = rope_kwargs["factor"] + elif config is not None: + factor = config.rope_scaling["factor"] + + # Gets the default RoPE parameters + inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs) + + # Then applies linear scaling to the frequencies. + # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so + # applying scaling to the inverse frequencies is equivalent. + inv_freq /= factor + return inv_freq, attention_factor + + +def _compute_dynamic_ntk_parameters( + config: Optional[PretrainedConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + **rope_kwargs, +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length, used to update the dynamic RoPE at inference time. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling + if config is not None and len(rope_kwargs) > 0: + raise ValueError( + "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " + f"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" + ) + if len(rope_kwargs) > 0: + base = rope_kwargs["base"] + dim = rope_kwargs["dim"] + max_position_embeddings = rope_kwargs["max_position_embeddings"] + factor = rope_kwargs["factor"] + elif config is not None: + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + max_position_embeddings = config.max_position_embeddings + factor = config.rope_scaling["factor"] + + attention_factor = 1.0 # Unused in this type of RoPE + + # seq_len: default to max_position_embeddings, e.g. at init time + seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings + + # Compute the inverse frequencies + base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) + return inv_freq, attention_factor + + +def _compute_yarn_parameters( + config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with NTK scaling. Please refer to the + [original paper](https://arxiv.org/abs/2309.00071) + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. + """ + # No need to keep BC with yarn, unreleased when this new pattern was created. + if len(rope_kwargs) > 0: + raise ValueError( + f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}" + ) + + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + max_position_embeddings = config.max_position_embeddings + factor = config.rope_scaling["factor"] + + # Sets the attention factor as suggested in the paper + attention_factor = config.rope_scaling.get("attention_factor") + if attention_factor is None: + attention_factor = 0.1 * math.log(factor) + 1.0 + + # Optional config options + # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) + beta_fast = config.rope_scaling.get("beta_fast") or 32 + beta_slow = config.rope_scaling.get("beta_slow") or 1 + + # Compute the inverse frequencies + def find_correction_dim(num_rotations, dim, base, max_position_embeddings): + """Inverse dimension formula to find the dimension based on the number of rotations""" + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings): + """Find dimension range bounds based on rotations""" + low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs + # to expand the possible context length. In other words, interpolation = apply scaling factor. + pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (factor * pos_freqs) + + low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings) + + # Get n-dimensional rotational scaling corrected for extrapolation + inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device) + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + + inv_freq_extrapolation * inv_freq_extrapolation_factor + ) + + return inv_freq, attention_factor + + +def _compute_longrope_parameters( + config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with LongRoPE scaling. Please refer to the + [original implementation](https://github.com/microsoft/LongRoPE) + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. + """ + # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling + # No need to keep BC with longrope, unreleased when this new pattern was created. + if len(rope_kwargs) > 0: + raise ValueError( + "Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got " + f"{rope_kwargs}" + ) + + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + long_factor = config.rope_scaling["long_factor"] + short_factor = config.rope_scaling["short_factor"] + factor = config.rope_scaling.get("factor") + attention_factor = config.rope_scaling.get("attention_factor") + + # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a + # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two + # values to compute the default attention scaling factor, instead of using `factor`. + if hasattr(config, "original_max_position_embeddings"): + original_max_position_embeddings = config.original_max_position_embeddings + factor = config.max_position_embeddings / config.original_max_position_embeddings + else: + original_max_position_embeddings = config.max_position_embeddings + + # Sets the attention factor as suggested in the paper + if attention_factor is None: + if factor <= 1.0: + attention_factor = 1.0 + else: + attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings)) + + # Compute the inverse frequencies -- scaled based on the target sequence length + if seq_len and seq_len > original_max_position_embeddings: + ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device) + else: + ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device) + inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim + inv_freq = 1.0 / (ext_factors * base**inv_freq_shape) + + return inv_freq, attention_factor + + +def _compute_llama3_parameters( + config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies for llama 3.1. + + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. + """ + # Gets the default RoPE parameters + inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs) + + factor = config.rope_scaling["factor"] # `8` in the original implementation + low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation + high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation + old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + wavelen = 2 * math.pi / inv_freq + # wavelen < high_freq_wavelen: do nothing + # wavelen > low_freq_wavelen: divide by factor + inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) + # otherwise: interpolate between the two, using a smooth factor + smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) + + return inv_freq_llama, attention_factor + + +# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters +# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE +# parameterizations, as long as the callable has the same signature. +ROPE_INIT_FUNCTIONS = { + "default": _compute_default_rope_parameters, + "linear": _compute_linear_scaling_rope_parameters, + "dynamic": _compute_dynamic_ntk_parameters, + "yarn": _compute_yarn_parameters, + "longrope": _compute_longrope_parameters, + "llama3": _compute_llama3_parameters, +} + + +def _check_received_keys( + rope_type: str, + received_keys: set, + required_keys: set, + optional_keys: Optional[set] = None, + ignore_keys: Optional[set] = None, +): + """Compare the received keys in `config.rope_scaling` against the expected and optional keys""" + # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present + if "type" in received_keys: + received_keys -= {"type"} + required_keys.add("rope_type") + + # Some models need to store model-specific keys, and we don't want to throw warning at them + if ignore_keys is not None: + received_keys -= ignore_keys + + missing_keys = required_keys - received_keys + if missing_keys: + raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}") + + if optional_keys is not None: + unused_keys = received_keys - required_keys - optional_keys + else: + unused_keys = received_keys - required_keys + if unused_keys: + logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}") + + +def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) + + +def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + +def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor"} + # TODO (joao): update logic for the inclusion of `original_max_position_embeddings` + optional_keys = {"original_max_position_embeddings"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + +def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor"} + optional_keys = {"attention_factor", "beta_fast", "beta_slow"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + attention_factor = rope_scaling.get("attention_factor") + if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0): + logger.warning( + f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + beta_fast = rope_scaling.get("beta_fast") + if beta_fast is not None and not isinstance(beta_fast, float): + logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}") + beta_slow = rope_scaling.get("beta_slow") + if beta_slow is not None and not isinstance(beta_slow, float): + logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}") + + if (beta_fast or 32) < (beta_slow or 1): + logger.warning( + f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} " + f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)" + ) + + +def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "short_factor", "long_factor"} + # TODO (joao): update logic for the inclusion of `original_max_position_embeddings` + optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) + + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + + short_factor = rope_scaling.get("short_factor") + if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor): + logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}") + if not len(short_factor) == dim // 2: + logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}") + + long_factor = rope_scaling.get("long_factor") + if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor): + logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}") + if not len(long_factor) == dim // 2: + logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}") + + # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over + # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is + # unique to longrope (= undesirable) + if hasattr(config, "original_max_position_embeddings"): + logger.warning_once( + "This model has set a `original_max_position_embeddings` field, to be used together with " + "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling`" + "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, " + "as it is compatible with most model architectures." + ) + else: + factor = rope_scaling.get("factor") + if factor is None: + logger.warning("Missing required keys in `rope_scaling`: 'factor'") + elif not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + attention_factor = rope_scaling.get("attention_factor") + if attention_factor is not None: + if not isinstance(attention_factor, float) or attention_factor < 0.0: + logger.warning( + f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + + +def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + low_freq_factor = rope_scaling["low_freq_factor"] + high_freq_factor = rope_scaling["high_freq_factor"] + if low_freq_factor is None or not isinstance(low_freq_factor, float): + logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}") + if high_freq_factor is None or not isinstance(high_freq_factor, float): + logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}") + if high_freq_factor <= low_freq_factor: + logger.warning( + "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor=" + f"{high_freq_factor} and low_freq_factor={low_freq_factor}" + ) + + original_max_position_embeddings = rope_scaling["original_max_position_embeddings"] + if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int): + logger.warning( + "`rope_scaling`'s original_max_position_embeddings field must be an integer, got " + f"{original_max_position_embeddings}" + ) + if original_max_position_embeddings >= config.max_position_embeddings: + logger.warning( + "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got " + f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}" + ) + + +# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types. +ROPE_VALIDATION_FUNCTIONS = { + "default": _validate_default_rope_parameters, + "linear": _validate_linear_scaling_rope_parameters, + "dynamic": _validate_dynamic_scaling_rope_parameters, + "yarn": _validate_yarn_parameters, + "longrope": _validate_longrope_parameters, + "llama3": _validate_llama3_parameters, +} + + +def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None): + """ + Validate the RoPE config arguments, given a `PretrainedConfig` object + """ + rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig` + if rope_scaling is None: + return + + # BC: "rope_type" was originally "type" + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) + validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type) + if validation_fn is not None: + validation_fn(config, ignore_keys=ignore_keys) + else: + logger.warning( + f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'" + ) diff --git a/modeling_tf_outputs.py b/modeling_tf_outputs.py new file mode 100644 index 0000000000000000000000000000000000000000..357c34bc1f25fc1ea8da9dd9d5870cf3bdc7add7 --- /dev/null +++ b/modeling_tf_outputs.py @@ -0,0 +1,991 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warnings +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import tensorflow as tf + +from .utils import ModelOutput + + +@dataclass +class TFBaseModelOutput(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(tf.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: tf.Tensor = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFBaseModelOutputWithNoAttention(ModelOutput): + """ + Base class for model's outputs, with potential hidden states. + + Args: + last_hidden_state (`tf.Tensor` shape `(batch_size, num_channels, height, width)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each layer) of shape `(batch_size, num_channels, height, width)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: tf.Tensor = None + hidden_states: Optional[Tuple[tf.Tensor, ...]] = None + + +@dataclass +class TFBaseModelOutputWithPooling(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) further processed by a + Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence + prediction (classification) objective during pretraining. + + This output is usually *not* a good summary of the semantic content of the input, you're often better with + averaging or pooling the sequence of hidden-states for the whole input sequence. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: tf.Tensor = None + pooler_output: tf.Tensor = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFBaseModelOutputWithPoolingAndNoAttention(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state after a pooling operation on the spatial dimensions. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each layer) of shape `(batch_size, num_channels, height, width)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: tf.Tensor = None + pooler_output: tf.Tensor = None + hidden_states: Optional[Tuple[tf.Tensor, ...]] = None + + +@dataclass +class TFBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) further processed by a + Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence + prediction (classification) objective during pretraining. + + This output is usually *not* a good summary of the semantic content of the input, you're often better with + averaging or pooling the sequence of hidden-states for the whole input sequence. + past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + """ + + last_hidden_state: tf.Tensor = None + pooler_output: tf.Tensor = None + past_key_values: List[tf.Tensor] | None = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + cross_attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFBaseModelOutputWithPast(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: tf.Tensor = None + past_key_values: List[tf.Tensor] | None = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFBaseModelOutputWithCrossAttentions(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(tf.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + """ + + last_hidden_state: tf.Tensor = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + cross_attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFBaseModelOutputWithPastAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(tf.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + """ + + last_hidden_state: tf.Tensor = None + past_key_values: List[tf.Tensor] | None = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + cross_attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFSeq2SeqModelOutput(ModelOutput): + """ + Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential + decoding. + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be + used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + last_hidden_state: tf.Tensor = None + past_key_values: List[tf.Tensor] | None = None + decoder_hidden_states: Tuple[tf.Tensor] | None = None + decoder_attentions: Tuple[tf.Tensor] | None = None + cross_attentions: Tuple[tf.Tensor] | None = None + encoder_last_hidden_state: tf.Tensor | None = None + encoder_hidden_states: Tuple[tf.Tensor] | None = None + encoder_attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFCausalLMOutput(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFCausalLMOutputWithPast(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor = None + past_key_values: List[tf.Tensor] | None = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFCausalLMOutputWithCrossAttentions(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor = None + past_key_values: List[tf.Tensor] | None = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + cross_attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFMaskedLMOutput(ModelOutput): + """ + Base class for masked language models outputs. + + Args: + loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided): + Masked language modeling (MLM) loss. + logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFSeq2SeqLMOutput(ModelOutput): + """ + Base class for sequence-to-sequence language models outputs. + + Args: + loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided): + Language modeling loss. + logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be + used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor = None + past_key_values: List[tf.Tensor] | None = None + decoder_hidden_states: Tuple[tf.Tensor] | None = None + decoder_attentions: Tuple[tf.Tensor] | None = None + cross_attentions: Tuple[tf.Tensor] | None = None + encoder_last_hidden_state: tf.Tensor | None = None + encoder_hidden_states: Tuple[tf.Tensor] | None = None + encoder_attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFNextSentencePredictorOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `next_sentence_label` is provided): + Next sentence prediction loss. + logits (`tf.Tensor` of shape `(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFSequenceClassifierOutput(ModelOutput): + """ + Base class for outputs of sentence classification models. + + Args: + loss (`tf.Tensor` of shape `(batch_size, )`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFSeq2SeqSequenceClassifierOutput(ModelOutput): + """ + Base class for outputs of sequence-to-sequence sentence classification models. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `label` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be + used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)` + encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor = None + past_key_values: List[tf.Tensor] | None = None + decoder_hidden_states: Tuple[tf.Tensor] | None = None + decoder_attentions: Tuple[tf.Tensor] | None = None + cross_attentions: Tuple[tf.Tensor] | None = None + encoder_last_hidden_state: tf.Tensor | None = None + encoder_hidden_states: Tuple[tf.Tensor] | None = None + encoder_attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFSemanticSegmenterOutput(ModelOutput): + """ + Base class for outputs of semantic segmentation models. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`tf.Tensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`): + Classification scores for each pixel. + + + + The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is + to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the + original image size as post-processing. You should always check your logits shape and resize as needed. + + + + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each layer) of shape `(batch_size, patch_size, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFSemanticSegmenterOutputWithNoAttention(ModelOutput): + """ + Base class for outputs of semantic segmentation models that do not output attention scores. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`tf.Tensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`): + Classification scores for each pixel. + + + + The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is + to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the + original image size as post-processing. You should always check your logits shape and resize as needed. + + + + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each layer) of shape `(batch_size, patch_size, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor = None + hidden_states: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFImageClassifierOutput(ModelOutput): + """ + Base class for outputs of image classification models. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called + feature maps) of the model at the output of each stage. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFMultipleChoiceModelOutput(ModelOutput): + """ + Base class for outputs of multiple choice models. + + Args: + loss (`tf.Tensor` of shape *(batch_size, )*, *optional*, returned when `labels` is provided): + Classification loss. + logits (`tf.Tensor` of shape `(batch_size, num_choices)`): + *num_choices* is the second dimension of the input tensors. (see *input_ids* above). + + Classification scores (before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFTokenClassifierOutput(ModelOutput): + """ + Base class for outputs of token classification models. + + Args: + loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of unmasked labels, returned when `labels` is provided) : + Classification loss. + logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.num_labels)`): + Classification scores (before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFQuestionAnsweringModelOutput(ModelOutput): + """ + Base class for outputs of question answering models. + + Args: + loss (`tf.Tensor` of shape `(batch_size, )`, *optional*, returned when `start_positions` and `end_positions` are provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + start_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Span-start scores (before SoftMax). + end_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Span-end scores (before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + start_logits: tf.Tensor = None + end_logits: tf.Tensor = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFSeq2SeqQuestionAnsweringModelOutput(ModelOutput): + """ + Base class for outputs of sequence-to-sequence question answering models. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + start_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Span-start scores (before SoftMax). + end_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Span-end scores (before SoftMax). + past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be + used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: tf.Tensor | None = None + start_logits: tf.Tensor = None + end_logits: tf.Tensor = None + past_key_values: List[tf.Tensor] | None = None + decoder_hidden_states: Tuple[tf.Tensor] | None = None + decoder_attentions: Tuple[tf.Tensor] | None = None + encoder_last_hidden_state: tf.Tensor | None = None + encoder_hidden_states: Tuple[tf.Tensor] | None = None + encoder_attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFSequenceClassifierOutputWithPast(ModelOutput): + """ + Base class for outputs of sentence classification models. + + Args: + loss (`tf.Tensor` of shape `(batch_size, )`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor = None + past_key_values: List[tf.Tensor] | None = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFImageClassifierOutputWithNoAttention(ModelOutput): + """ + Base class for outputs of image classification models. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also called + feature maps) of the model at the output of each stage. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor = None + hidden_states: Optional[Tuple[tf.Tensor, ...]] = None + + +@dataclass +class TFMaskedImageModelingOutput(ModelOutput): + """ + Base class for outputs of masked image completion / in-painting models. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): + Reconstruction loss. + reconstruction (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Reconstructed / completed images. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called + feature maps) of the model at the output of each stage. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + reconstruction: tf.Tensor = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + + @property + def logits(self): + warnings.warn( + "logits attribute is deprecated and will be removed in version 5 of Transformers." + " Please use the reconstruction attribute to retrieve the final output instead.", + FutureWarning, + ) + return self.reconstruction diff --git a/modeling_tf_pytorch_utils.py b/modeling_tf_pytorch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8ec24d6e1872ef1ab8878e5cf6e8df2919b76cf7 --- /dev/null +++ b/modeling_tf_pytorch_utils.py @@ -0,0 +1,673 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch - TF 2.0 general utilities.""" + +import os +import re + +import numpy + +from .utils import ( + ExplicitEnum, + expand_dims, + is_numpy_array, + is_safetensors_available, + is_torch_tensor, + logging, + reshape, + squeeze, + tensor_size, +) +from .utils import transpose as transpose_func + + +if is_safetensors_available(): + from safetensors import safe_open + + +logger = logging.get_logger(__name__) + + +class TransposeType(ExplicitEnum): + """ + Possible ... + """ + + NO = "no" + SIMPLE = "simple" + CONV1D = "conv1d" + CONV2D = "conv2d" + + +def convert_tf_weight_name_to_pt_weight_name( + tf_name, start_prefix_to_remove="", tf_weight_shape=None, name_scope=None +): + """ + Convert a TF 2.0 model variable name in a pytorch model weight name. + + Conventions for TF2.0 scopes -> PyTorch attribute names conversions: + + - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch) + - '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList) + + return tuple with: + + - pytorch model weight name + - transpose: `TransposeType` member indicating whether and how TF2.0 and PyTorch weights matrices should be + transposed with regards to each other + """ + if name_scope is not None: + if not tf_name.startswith(name_scope) and "final_logits_bias" not in tf_name: + raise ValueError( + f"Weight name {tf_name} does not start with name_scope {name_scope}. This is an internal error " + "in Transformers, so (unless you were doing something really evil) please open an issue to report it!" + ) + tf_name = tf_name[len(name_scope) :] + tf_name = tf_name.lstrip("/") + tf_name = tf_name.replace(":0", "") # device ids + tf_name = re.sub( + r"/[^/]*___([^/]*)/", r"/\1/", tf_name + ) # '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch) + tf_name = tf_name.replace( + "_._", "/" + ) # '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList) + tf_name = re.sub(r"//+", "/", tf_name) # Remove empty levels at the end + tf_name = tf_name.split("/") # Convert from TF2.0 '/' separators to PyTorch '.' separators + # Some weights have a single name without "/" such as final_logits_bias in BART + if len(tf_name) > 1: + tf_name = tf_name[1:] # Remove level zero + + tf_weight_shape = list(tf_weight_shape) + + # When should we transpose the weights + if tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 4: + transpose = TransposeType.CONV2D + elif tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 3: + transpose = TransposeType.CONV1D + elif bool( + tf_name[-1] in ["kernel", "pointwise_kernel", "depthwise_kernel"] + or "emb_projs" in tf_name + or "out_projs" in tf_name + ): + transpose = TransposeType.SIMPLE + else: + transpose = TransposeType.NO + + # Convert standard TF2.0 names in PyTorch names + if tf_name[-1] == "kernel" or tf_name[-1] == "embeddings" or tf_name[-1] == "gamma": + tf_name[-1] = "weight" + if tf_name[-1] == "beta": + tf_name[-1] = "bias" + + # The SeparableConv1D TF layer contains two weights that are translated to PyTorch Conv1D here + if tf_name[-1] == "pointwise_kernel" or tf_name[-1] == "depthwise_kernel": + tf_name[-1] = tf_name[-1].replace("_kernel", ".weight") + + # Remove prefix if needed + tf_name = ".".join(tf_name) + if start_prefix_to_remove: + tf_name = tf_name.replace(start_prefix_to_remove, "", 1) + + return tf_name, transpose + + +def apply_transpose(transpose: TransposeType, weight, match_shape=None, pt_to_tf=True): + """ + Apply a transpose to some weight then tries to reshape the weight to the same shape as a given shape, all in a + framework agnostic way. + """ + if transpose is TransposeType.CONV2D: + # Conv2D weight: + # PT: (num_out_channel, num_in_channel, kernel[0], kernel[1]) + # -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel) + axes = (2, 3, 1, 0) if pt_to_tf else (3, 2, 0, 1) + weight = transpose_func(weight, axes=axes) + elif transpose is TransposeType.CONV1D: + # Conv1D weight: + # PT: (num_out_channel, num_in_channel, kernel) + # -> TF: (kernel, num_in_channel, num_out_channel) + weight = transpose_func(weight, axes=(2, 1, 0)) + elif transpose is TransposeType.SIMPLE: + weight = transpose_func(weight) + + if match_shape is None: + return weight + + if len(match_shape) < len(weight.shape): + weight = squeeze(weight) + elif len(match_shape) > len(weight.shape): + weight = expand_dims(weight, axis=0) + + if list(match_shape) != list(weight.shape): + try: + weight = reshape(weight, match_shape) + except AssertionError as e: + e.args += (match_shape, match_shape) + raise e + + return weight + + +##################### +# PyTorch => TF 2.0 # +##################### + + +def load_pytorch_checkpoint_in_tf2_model( + tf_model, + pytorch_checkpoint_path, + tf_inputs=None, + allow_missing_keys=False, + output_loading_info=False, + _prefix=None, + tf_to_pt_weight_rename=None, +): + """Load pytorch checkpoints in a TF 2.0 model""" + try: + import tensorflow as tf # noqa: F401 + import torch # noqa: F401 + from safetensors.torch import load_file as safe_load_file # noqa: F401 + except ImportError: + logger.error( + "Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see " + "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions." + ) + raise + + # Treats a single file as a collection of shards with 1 shard. + if isinstance(pytorch_checkpoint_path, str): + pytorch_checkpoint_path = [pytorch_checkpoint_path] + + # Loads all shards into a single state dictionary + pt_state_dict = {} + for path in pytorch_checkpoint_path: + pt_path = os.path.abspath(path) + logger.info(f"Loading PyTorch weights from {pt_path}") + if pt_path.endswith(".safetensors"): + state_dict = safe_load_file(pt_path) + else: + weights_only_kwarg = {"weights_only": True} + state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg) + + pt_state_dict.update(state_dict) + + logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters") + + return load_pytorch_weights_in_tf2_model( + tf_model, + pt_state_dict, + tf_inputs=tf_inputs, + allow_missing_keys=allow_missing_keys, + output_loading_info=output_loading_info, + _prefix=_prefix, + tf_to_pt_weight_rename=tf_to_pt_weight_rename, + ) + + +def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_missing_keys=False): + """Load pytorch checkpoints in a TF 2.0 model""" + pt_state_dict = pt_model.state_dict() + + return load_pytorch_weights_in_tf2_model( + tf_model, pt_state_dict, tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys + ) + + +def load_pytorch_weights_in_tf2_model( + tf_model, + pt_state_dict, + tf_inputs=None, + allow_missing_keys=False, + output_loading_info=False, + _prefix=None, + tf_to_pt_weight_rename=None, +): + """Load pytorch state_dict in a TF 2.0 model.""" + try: + import tensorflow as tf # noqa: F401 + import torch # noqa: F401 + except ImportError: + logger.error( + "Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see " + "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions." + ) + raise + + # Numpy doesn't understand bfloat16, so upcast to a dtype that doesn't lose precision + pt_state_dict = { + k: v.numpy() if v.dtype != torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items() + } + return load_pytorch_state_dict_in_tf2_model( + tf_model, + pt_state_dict, + tf_inputs=tf_inputs, + allow_missing_keys=allow_missing_keys, + output_loading_info=output_loading_info, + _prefix=_prefix, + tf_to_pt_weight_rename=tf_to_pt_weight_rename, + ) + + +def _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name): + if len(unexpected_keys) > 0: + logger.warning( + "Some weights of the PyTorch model were not used when initializing the TF 2.0 model" + f" {class_name}: {unexpected_keys}\n- This IS expected if you are initializing" + f" {class_name} from a PyTorch model trained on another task or with another architecture" + " (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n- This IS" + f" NOT expected if you are initializing {class_name} from a PyTorch model that you expect" + " to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a" + " BertForSequenceClassification model)." + ) + else: + logger.warning(f"All PyTorch model weights were used when initializing {class_name}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights or buffers of the TF 2.0 model {class_name} were not initialized from the" + f" PyTorch model and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a" + " down-stream task to be able to use it for predictions and inference." + ) + else: + logger.warning( + f"All the weights of {class_name} were initialized from the PyTorch model.\n" + "If your task is similar to the task the model of the checkpoint was trained on, " + f"you can already use {class_name} for predictions without further training." + ) + + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {class_name} were not initialized from the model checkpoint" + f" are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" + " to use it for predictions and inference." + ) + + +def load_pytorch_state_dict_in_tf2_model( + tf_model, + pt_state_dict, + tf_inputs=None, + allow_missing_keys=False, + output_loading_info=False, + _prefix=None, + tf_to_pt_weight_rename=None, + ignore_mismatched_sizes=False, + skip_logger_warnings=False, +): + """Load a pytorch state_dict in a TF 2.0 model. pt_state_dict can be either an actual dict or a lazy-loading + safetensors archive created with the safe_open() function.""" + import tensorflow as tf + + if tf_inputs is None: + tf_inputs = tf_model.dummy_inputs + + if _prefix is None: + _prefix = "" + if tf_inputs: + with tf.name_scope(_prefix): + tf_model(tf_inputs, training=False) # Make sure model is built + # Convert old format to new format if needed from a PyTorch state_dict + tf_keys_to_pt_keys = {} + for key in pt_state_dict.keys(): + new_key = None + if "gamma" in key: + new_key = key.replace("gamma", "weight") + if "beta" in key: + new_key = key.replace("beta", "bias") + if "running_var" in key: + new_key = key.replace("running_var", "moving_variance") + if "running_mean" in key: + new_key = key.replace("running_mean", "moving_mean") + + # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030 + key_components = key.split(".") + name = None + if key_components[-3::2] == ["parametrizations", "original0"]: + name = key_components[-2] + "_g" + elif key_components[-3::2] == ["parametrizations", "original1"]: + name = key_components[-2] + "_v" + if name is not None: + key_components = key_components[:-3] + [name] + new_key = ".".join(key_components) + + if new_key is None: + new_key = key + tf_keys_to_pt_keys[new_key] = key + + # Matt: All TF models store the actual model stem in a MainLayer class, including the base model. + # In PT, the derived models (with heads) use the base model class as the stem instead, + # and there is no MainLayer class. This means that TF base classes have one + # extra layer in their weight names, corresponding to the MainLayer class. This code block compensates for that. + start_prefix_to_remove = "" + if not any(s.startswith(tf_model.base_model_prefix) for s in tf_keys_to_pt_keys.keys()): + start_prefix_to_remove = tf_model.base_model_prefix + "." + + symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights + tf_loaded_numel = 0 + all_pytorch_weights = set(tf_keys_to_pt_keys.keys()) + missing_keys = [] + mismatched_keys = [] + is_safetensor_archive = hasattr(pt_state_dict, "get_tensor") + for symbolic_weight in symbolic_weights: + sw_name = symbolic_weight.name + name, transpose = convert_tf_weight_name_to_pt_weight_name( + sw_name, + start_prefix_to_remove=start_prefix_to_remove, + tf_weight_shape=symbolic_weight.shape, + name_scope=_prefix, + ) + if tf_to_pt_weight_rename is not None: + aliases = tf_to_pt_weight_rename(name) # Is a tuple to account for possible name aliasing + for alias in aliases: # The aliases are in priority order, take the first one that matches + if alias in tf_keys_to_pt_keys: + name = alias + break + else: + # If none of the aliases match, just use the first one (it'll be reported as missing) + name = aliases[0] + + # Find associated numpy array in pytorch model state dict + if name not in tf_keys_to_pt_keys: + if allow_missing_keys: + missing_keys.append(name) + continue + elif tf_model._keys_to_ignore_on_load_missing is not None: + # authorized missing keys don't have to be loaded + if any(re.search(pat, name) is not None for pat in tf_model._keys_to_ignore_on_load_missing): + continue + raise AttributeError(f"{name} not found in PyTorch model") + state_dict_name = tf_keys_to_pt_keys[name] + if is_safetensor_archive: + array = pt_state_dict.get_tensor(state_dict_name) + else: + array = pt_state_dict[state_dict_name] + try: + array = apply_transpose(transpose, array, symbolic_weight.shape) + except tf.errors.InvalidArgumentError as e: + if not ignore_mismatched_sizes: + error_msg = str(e) + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." + ) + raise tf.errors.InvalidArgumentError(error_msg) + else: + mismatched_keys.append((name, array.shape, symbolic_weight.shape)) + continue + + tf_loaded_numel += tensor_size(array) + + symbolic_weight.assign(tf.cast(array, symbolic_weight.dtype)) + del array # Immediately free memory to keep peak usage as low as possible + all_pytorch_weights.discard(name) + + logger.info(f"Loaded {tf_loaded_numel:,} parameters in the TF 2.0 model.") + + unexpected_keys = list(all_pytorch_weights) + + if tf_model._keys_to_ignore_on_load_missing is not None: + for pat in tf_model._keys_to_ignore_on_load_missing: + missing_keys = [k for k in missing_keys if re.search(pat, k) is None] + if tf_model._keys_to_ignore_on_load_unexpected is not None: + for pat in tf_model._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + if not skip_logger_warnings: + _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__) + + if output_loading_info: + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + } + return tf_model, loading_info + + return tf_model + + +def load_sharded_pytorch_safetensors_in_tf2_model( + tf_model, + safetensors_shards, + tf_inputs=None, + allow_missing_keys=False, + output_loading_info=False, + _prefix=None, + tf_to_pt_weight_rename=None, + ignore_mismatched_sizes=False, +): + all_loading_infos = [] + for shard in safetensors_shards: + with safe_open(shard, framework="tf") as safetensors_archive: + tf_model, loading_info = load_pytorch_state_dict_in_tf2_model( + tf_model, + safetensors_archive, + tf_inputs=tf_inputs, + allow_missing_keys=allow_missing_keys, + output_loading_info=True, + _prefix=_prefix, + tf_to_pt_weight_rename=tf_to_pt_weight_rename, + ignore_mismatched_sizes=ignore_mismatched_sizes, + skip_logger_warnings=True, # We will emit merged warnings at the end + ) + all_loading_infos.append(loading_info) + # Now we just need to merge the loading info + # Keys are missing only if they're missing in *every* shard + missing_keys = sorted(set.intersection(*[set(info["missing_keys"]) for info in all_loading_infos])) + # Keys are unexpected/mismatched if they're unexpected/mismatched in *any* shard + unexpected_keys = sum([info["unexpected_keys"] for info in all_loading_infos], []) + mismatched_keys = sum([info["mismatched_keys"] for info in all_loading_infos], []) + + _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__) + + if output_loading_info: + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + } + return tf_model, loading_info + + return tf_model + + +##################### +# TF 2.0 => PyTorch # +##################### + + +def load_tf2_checkpoint_in_pytorch_model( + pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False, output_loading_info=False +): + """ + Load TF 2.0 HDF5 checkpoint in a PyTorch model We use HDF5 to easily do transfer learning (see + https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357). + """ + try: + import tensorflow as tf # noqa: F401 + import torch # noqa: F401 + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see " + "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions." + ) + raise + + import transformers + + from .modeling_tf_utils import load_tf_weights + + logger.info(f"Loading TensorFlow weights from {tf_checkpoint_path}") + + # Instantiate and load the associated TF 2.0 model + tf_model_class_name = "TF" + pt_model.__class__.__name__ # Add "TF" at the beginning + tf_model_class = getattr(transformers, tf_model_class_name) + tf_model = tf_model_class(pt_model.config) + + if tf_inputs is None: + tf_inputs = tf_model.dummy_inputs + + if tf_inputs is not None: + tf_model(tf_inputs, training=False) # Make sure model is built + + load_tf_weights(tf_model, tf_checkpoint_path) + + return load_tf2_model_in_pytorch_model( + pt_model, tf_model, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info + ) + + +def load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=False, output_loading_info=False): + """Load TF 2.0 model in a pytorch model""" + weights = tf_model.weights + + return load_tf2_weights_in_pytorch_model( + pt_model, weights, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info + ) + + +def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=False, output_loading_info=False): + """Load TF2.0 symbolic weights in a PyTorch model""" + try: + import tensorflow as tf # noqa: F401 + import torch # noqa: F401 + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see " + "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions." + ) + raise + + tf_state_dict = {tf_weight.name: tf_weight.numpy() for tf_weight in tf_weights} + return load_tf2_state_dict_in_pytorch_model( + pt_model, tf_state_dict, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info + ) + + +def load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict, allow_missing_keys=False, output_loading_info=False): + import torch + + new_pt_params_dict = {} + current_pt_params_dict = dict(pt_model.named_parameters()) + + # Make sure we are able to load PyTorch base models as well as derived models (with heads) + # TF models always have a prefix, some of PyTorch models (base ones) don't + start_prefix_to_remove = "" + if not any(s.startswith(pt_model.base_model_prefix) for s in current_pt_params_dict.keys()): + start_prefix_to_remove = pt_model.base_model_prefix + "." + + # Build a map from potential PyTorch weight names to TF 2.0 Variables + tf_weights_map = {} + for name, tf_weight in tf_state_dict.items(): + pt_name, transpose = convert_tf_weight_name_to_pt_weight_name( + name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=tf_weight.shape + ) + tf_weights_map[pt_name] = (tf_weight, transpose) + + all_tf_weights = set(tf_weights_map.keys()) + loaded_pt_weights_data_ptr = {} + missing_keys_pt = [] + for pt_weight_name, pt_weight in current_pt_params_dict.items(): + # Handle PyTorch shared weight ()not duplicated in TF 2.0 + if pt_weight.data_ptr() in loaded_pt_weights_data_ptr: + new_pt_params_dict[pt_weight_name] = loaded_pt_weights_data_ptr[pt_weight.data_ptr()] + continue + + pt_weight_name_to_check = pt_weight_name + # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030 + key_components = pt_weight_name.split(".") + name = None + if key_components[-3::2] == ["parametrizations", "original0"]: + name = key_components[-2] + "_g" + elif key_components[-3::2] == ["parametrizations", "original1"]: + name = key_components[-2] + "_v" + if name is not None: + key_components = key_components[:-3] + [name] + pt_weight_name_to_check = ".".join(key_components) + + # Find associated numpy array in pytorch model state dict + if pt_weight_name_to_check not in tf_weights_map: + if allow_missing_keys: + missing_keys_pt.append(pt_weight_name) + continue + + raise AttributeError(f"{pt_weight_name} not found in TF 2.0 model") + + array, transpose = tf_weights_map[pt_weight_name_to_check] + + array = apply_transpose(transpose, array, pt_weight.shape, pt_to_tf=False) + + if numpy.isscalar(array): + array = numpy.array(array) + if not is_torch_tensor(array) and not is_numpy_array(array): + array = array.numpy() + if is_numpy_array(array): + # Convert to torch tensor + array = torch.from_numpy(array) + + new_pt_params_dict[pt_weight_name] = array + loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = array + all_tf_weights.discard(pt_weight_name) + + missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False) + missing_keys += missing_keys_pt + + # Some models may have keys that are not in the state by design, removing them before needlessly warning + # the user. + if pt_model._keys_to_ignore_on_load_missing is not None: + for pat in pt_model._keys_to_ignore_on_load_missing: + missing_keys = [k for k in missing_keys if re.search(pat, k) is None] + + if pt_model._keys_to_ignore_on_load_unexpected is not None: + for pat in pt_model._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + logger.warning( + "Some weights of the TF 2.0 model were not used when initializing the PyTorch model" + f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing" + f" {pt_model.__class__.__name__} from a TF 2.0 model trained on another task or with another architecture" + " (e.g. initializing a BertForSequenceClassification model from a TFBertForPreTraining model).\n- This IS" + f" NOT expected if you are initializing {pt_model.__class__.__name__} from a TF 2.0 model that you expect" + " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a" + " TFBertForSequenceClassification model)." + ) + else: + logger.warning(f"All TF 2.0 model weights were used when initializing {pt_model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {pt_model.__class__.__name__} were not initialized from the TF 2.0 model and are newly" + f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to" + " use it for predictions and inference." + ) + else: + logger.warning( + f"All the weights of {pt_model.__class__.__name__} were initialized from the TF 2.0 model.\n" + "If your task is similar to the task the model of the checkpoint was trained on, " + f"you can already use {pt_model.__class__.__name__} for predictions without further training." + ) + + logger.info(f"Weights or buffers not loaded from TF 2.0 model: {all_tf_weights}") + + if output_loading_info: + loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys} + return pt_model, loading_info + + return pt_model diff --git a/modeling_tf_utils.py b/modeling_tf_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8264f48818cb38bf6eacc7debabeae5793a44f85 --- /dev/null +++ b/modeling_tf_utils.py @@ -0,0 +1,3555 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF general model utils.""" + +from __future__ import annotations + +import functools +import gc +import inspect +import json +import os +import pickle +import re +import warnings +from collections.abc import Mapping +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union + +import h5py +import numpy as np +import tensorflow as tf +from packaging.version import parse + +from . import DataCollatorWithPadding, DefaultDataCollator +from .activations_tf import get_tf_activation +from .configuration_utils import PretrainedConfig +from .dynamic_module_utils import custom_object_save +from .generation import GenerationConfig, TFGenerationMixin +from .tf_utils import ( + convert_batch_encoding, + expand_1d, + load_attributes_from_hdf5_group, + save_attributes_to_hdf5_group, + shape_list, +) +from .utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + TF2_WEIGHTS_INDEX_NAME, + TF2_WEIGHTS_NAME, + TF_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + ModelOutput, + PushToHubMixin, + cached_file, + download_url, + find_labels, + has_file, + is_offline_mode, + is_remote_url, + is_safetensors_available, + is_tf_symbolic_tensor, + logging, + requires_backends, + working_or_temp_dir, +) +from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files + + +if is_safetensors_available(): + from safetensors import safe_open + from safetensors.tensorflow import save_file as safe_save_file + +if TYPE_CHECKING: + from . import PreTrainedTokenizerBase + +logger = logging.get_logger(__name__) + +if "TF_USE_LEGACY_KERAS" not in os.environ: + os.environ["TF_USE_LEGACY_KERAS"] = "1" # Compatibility fix to make sure tf.keras stays at Keras 2 +elif os.environ["TF_USE_LEGACY_KERAS"] != "1": + logger.warning( + "Transformers is only compatible with Keras 2, but you have explicitly set `TF_USE_LEGACY_KERAS` to `0`. " + "This may result in unexpected behaviour or errors if Keras 3 objects are passed to Transformers models." + ) + +try: + import tf_keras as keras + from tf_keras import backend as K +except (ModuleNotFoundError, ImportError): + import keras + from keras import backend as K + + if parse(keras.__version__).major > 2: + raise ValueError( + "Your currently installed version of Keras is Keras 3, but this is not yet supported in " + "Transformers. Please install the backwards-compatible tf-keras package with " + "`pip install tf-keras`." + ) + + +tf_logger = tf.get_logger() + +TFModelInputType = Union[ + List[tf.Tensor], + List[np.ndarray], + Dict[str, tf.Tensor], + Dict[str, np.ndarray], + tf.Tensor, + np.ndarray, +] + + +def dummy_loss(y_true, y_pred): + if y_pred.shape.rank <= 1: + return y_pred + else: + reduction_axes = list(range(1, y_pred.shape.rank)) + return tf.reduce_mean(y_pred, axis=reduction_axes) + + +class TFModelUtilsMixin: + """ + A few utilities for `keras.Model`, to be used as a mixin. + """ + + def num_parameters(self, only_trainable: bool = False) -> int: + """ + Get the number of (optionally, trainable) parameters in the model. + + Args: + only_trainable (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of trainable parameters + + Returns: + `int`: The number of parameters. + """ + if only_trainable: + return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables)) + else: + return self.count_params() + + +def keras_serializable(cls): + """ + Decorate a Keras Layer class to support Keras serialization. + + This is done by: + + 1. Adding a `transformers_config` dict to the Keras config dictionary in `get_config` (called by Keras at + serialization time. + 2. Wrapping `__init__` to accept that `transformers_config` dict (passed by Keras at deserialization time) and + convert it to a config object for the actual layer initializer. + 3. Registering the class as a custom object in Keras (if the Tensorflow version supports this), so that it does not + need to be supplied in `custom_objects` in the call to `keras.models.load_model`. + + Args: + cls (a `keras.layers.Layers subclass`): + Typically a `TF.MainLayer` class in this project, in general must accept a `config` argument to its + initializer. + + Returns: + The same class object, with modifications for Keras deserialization. + """ + initializer = cls.__init__ + + config_class = getattr(cls, "config_class", None) + if config_class is None: + raise AttributeError("Must set `config_class` to use @keras_serializable") + + @functools.wraps(initializer) + def wrapped_init(self, *args, **kwargs): + config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.pop("config", None) + + if isinstance(config, dict): + config = config_class.from_dict(config) + initializer(self, config, *args, **kwargs) + elif isinstance(config, PretrainedConfig): + if len(args) > 0: + initializer(self, *args, **kwargs) + else: + initializer(self, config, *args, **kwargs) + else: + raise ValueError("Must pass either `config` (PretrainedConfig) or `config` (dict)") + + self._config = config + self._kwargs = kwargs + + cls.__init__ = wrapped_init + + if not hasattr(cls, "get_config"): + raise TypeError("Only use @keras_serializable on keras.layers.Layer subclasses") + if hasattr(cls.get_config, "_is_default"): + + def get_config(self): + cfg = super(cls, self).get_config() + cfg["config"] = self._config.to_dict() + cfg.update(self._kwargs) + return cfg + + cls.get_config = get_config + + cls._keras_serializable = True + if hasattr(keras.utils, "register_keras_serializable"): + cls = keras.utils.register_keras_serializable()(cls) + return cls + + +class TFCausalLanguageModelingLoss: + """ + Loss function suitable for causal language modeling (CLM), that is, the task of guessing the next token. + + + + Any label of -100 will be ignored (along with the corresponding logits) in the loss computation. + + + """ + + def hf_compute_loss(self, labels, logits): + loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE) + if self.config.tf_legacy_loss: + # make sure only labels that are not equal to -100 affect the loss + active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100) + reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss) + labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss) + return loss_fn(labels, reduced_logits) + + # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway + unmasked_loss = loss_fn(tf.nn.relu(labels), logits) + # make sure only labels that are not equal to -100 affect the loss + loss_mask = tf.cast(labels != -100, dtype=unmasked_loss.dtype) + masked_loss = unmasked_loss * loss_mask + reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask) + return tf.reshape(reduced_masked_loss, (1,)) + + +class TFQuestionAnsweringLoss: + """ + Loss function suitable for question answering. + """ + + def hf_compute_loss(self, labels, logits): + loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE) + start_loss = loss_fn(labels["start_position"], logits[0]) + end_loss = loss_fn(labels["end_position"], logits[1]) + + return (start_loss + end_loss) / 2.0 + + +class TFTokenClassificationLoss: + """ + Loss function suitable for token classification. + + + + Any label of -100 will be ignored (along with the corresponding logits) in the loss computation. + + + """ + + def hf_compute_loss(self, labels, logits): + loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE) + if tf.executing_eagerly(): # Data-dependent conditionals are forbidden in XLA + if tf.math.reduce_any(labels == -1): + tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.") + + if self.config.tf_legacy_loss: + # make sure only labels that are not equal to -100 + # are taken into account as loss + if tf.math.reduce_any(labels == -1): + tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.") + active_loss = tf.reshape(labels, (-1,)) != -1 + else: + active_loss = tf.reshape(labels, (-1,)) != -100 + reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss) + labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss) + + return loss_fn(labels, reduced_logits) + + # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway + unmasked_loss = loss_fn(tf.nn.relu(labels), logits) + # make sure only labels that are not equal to -100 or -1 + # are taken into account as loss + loss_mask = tf.cast(labels >= 0, dtype=unmasked_loss.dtype) + # Avoid possible division by zero later + # Masked positions will have a loss of NaN because -100 and -1 are not valid labels + masked_loss = unmasked_loss * loss_mask + reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask) + return tf.reshape(reduced_masked_loss, (1,)) + + +class TFSequenceClassificationLoss: + """ + Loss function suitable for sequence classification. + """ + + def hf_compute_loss(self, labels, logits): + if logits.shape.rank == 1 or logits.shape[1] == 1: + loss_fn = keras.losses.MeanSquaredError(reduction=keras.losses.Reduction.NONE) + if labels.shape.rank == 1: + # MeanSquaredError returns a scalar loss if the labels are 1D, so avoid that + labels = tf.expand_dims(labels, axis=-1) + else: + loss_fn = keras.losses.SparseCategoricalCrossentropy( + from_logits=True, reduction=keras.losses.Reduction.NONE + ) + + return loss_fn(labels, logits) + + +class TFMultipleChoiceLoss: + """Loss function suitable for multiple choice tasks.""" + + def hf_compute_loss(self, labels, logits): + loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE) + return loss_fn(labels, logits) + + +class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss): + """ + Loss function suitable for masked language modeling (MLM), that is, the task of guessing the masked tokens. + + + + Any label of -100 will be ignored (along with the corresponding logits) in the loss computation. + + + """ + + +class TFNextSentencePredictionLoss: + """ + Loss function suitable for next sentence prediction (NSP), that is, the task of guessing the next sentence. + + + + Any label of -100 will be ignored (along with the corresponding logits) in the loss computation. + + + """ + + def hf_compute_loss(self, labels, logits): + loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE) + if self.config.tf_legacy_loss: + # make sure only labels that are not equal to -100 + # are taken into account as loss + next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100) + next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss) + next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss) + + return loss_fn(next_sentence_label, next_sentence_reduced_logits) + + # make sure only labels that are not equal to -100 + # are taken into account as loss + + # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway + unmasked_ns_loss = loss_fn(y_true=tf.nn.relu(labels), y_pred=logits) + ns_loss_mask = tf.cast(labels != -100, dtype=unmasked_ns_loss.dtype) + # Just zero out samples where label is -100, no reduction + masked_ns_loss = unmasked_ns_loss * ns_loss_mask + + return masked_ns_loss + + +def booleans_processing(config, **kwargs): + """ + Process the input booleans of each model. + + Args: + config ([`PretrainedConfig`]): + The config of the running model. + **kwargs: + The boolean parameters + + Returns: + A dictionary with the proper values for each boolean + """ + final_booleans = {} + + # Pure conv models (such as ConvNext) do not have `output_attentions`. If the signature has + # `output_attentions`, it will be present here in `kwargs`, even if unset (in that case, as `None`) + if "output_attentions" in kwargs: + final_booleans["output_attentions"] = ( + kwargs["output_attentions"] if kwargs["output_attentions"] is not None else config.output_attentions + ) + final_booleans["output_hidden_states"] = ( + kwargs["output_hidden_states"] if kwargs["output_hidden_states"] is not None else config.output_hidden_states + ) + final_booleans["return_dict"] = kwargs["return_dict"] if kwargs["return_dict"] is not None else config.return_dict + + if "use_cache" in kwargs: + final_booleans["use_cache"] = ( + kwargs["use_cache"] if kwargs["use_cache"] is not None else getattr(config, "use_cache", None) + ) + return final_booleans + + +def unpack_inputs(func): + """ + Decorator that processes the inputs to a Keras layer, passing them to the layer as keyword arguments. This enables + downstream use of the inputs by their variable name, even if they arrive packed as a dictionary in the first input + (common case in Keras). + + Args: + func (`callable`): + The callable function of the TensorFlow model. + + + Returns: + A callable that wraps the original `func` with the behavior described above. + """ + + original_signature = inspect.signature(func) + + @functools.wraps(func) + def run_call_with_unpacked_inputs(self, *args, **kwargs): + # isolates the actual `**kwargs` for the decorated function + kwargs_call = {key: val for key, val in kwargs.items() if key not in dict(original_signature.parameters)} + fn_args_and_kwargs = {key: val for key, val in kwargs.items() if key not in kwargs_call} + fn_args_and_kwargs.update({"kwargs_call": kwargs_call}) + + # move any arg into kwargs, if they exist + fn_args_and_kwargs.update(dict(zip(func.__code__.co_varnames[1:], args))) + + # Encoder Decoder models delegate the application of the configuration options to their inner models. + if "EncoderDecoder" in self.__class__.__name__: + config = None + else: + config = self.config + + unpacked_inputs = input_processing(func, config, **fn_args_and_kwargs) + return func(self, **unpacked_inputs) + + # Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This + # function does not follow wrapper chains (i.e. ignores `functools.wraps()`), meaning that without the line below + # Keras would attempt to check the first argument against the literal signature of the wrapper. + run_call_with_unpacked_inputs.__signature__ = original_signature + + return run_call_with_unpacked_inputs + + +def input_processing(func, config, **kwargs): + """ + Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input + has to be named accordingly to the parameters name, i.e. `input_ids = keras.Input(shape=(128,), dtype='int32', + name="input_ids")` otherwise the order of the tensors will not be guaranteed during the training. + + Args: + func (`callable`): + The callable function of the TensorFlow model. + config ([`PretrainedConfig`]): + The config of the running model. + **kwargs: + The inputs of the model. + + Returns: + Two lists, one for the missing layers, and another one for the unexpected layers. + """ + signature = dict(inspect.signature(func).parameters) + has_kwargs = bool(signature.pop("kwargs", None)) + signature.pop("self", None) + parameter_names = list(signature.keys()) + main_input_name = parameter_names[0] + main_input = kwargs.pop(main_input_name, None) + output = {} + allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray) + + if "inputs" in kwargs["kwargs_call"]: + warnings.warn( + "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.", + FutureWarning, + ) + + output["input_ids"] = kwargs["kwargs_call"].pop("inputs") + + if "decoder_cached_states" in kwargs["kwargs_call"]: + warnings.warn( + "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use" + " `past_key_values` instead.", + FutureWarning, + ) + output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states") + + if "past" in kwargs["kwargs_call"] and "past_key_values" in parameter_names: + warnings.warn( + "The `past` argument is deprecated and will be removed in a future version, use `past_key_values`" + " instead.", + FutureWarning, + ) + kwargs["past_key_values"] = kwargs["kwargs_call"].pop("past") + elif "past_key_values" in kwargs["kwargs_call"] and "past" in parameter_names: + kwargs["past"] = kwargs["kwargs_call"].pop("past_key_values") + + if has_kwargs: + output["kwargs"] = kwargs.pop("kwargs_call", {}) + else: + if len(kwargs["kwargs_call"]) > 0: + raise ValueError( + "The following keyword arguments are not supported by this model:" + f" {list(kwargs['kwargs_call'].keys())}." + ) + kwargs.pop("kwargs_call") + + for k, v in kwargs.items(): + if isinstance(v, allowed_types) or tf.is_tensor(v) or v is None: + output[k] = v + else: + raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.") + + if isinstance(main_input, (tuple, list)): + for i, input in enumerate(main_input): + # EagerTensors don't allow to use the .name property so we check for a real Tensor + if is_tf_symbolic_tensor(input): + # Tensor names have always the pattern `name:id` then we check only the + # `name` part + tensor_name = input.name.split(":")[0] + + if tensor_name in parameter_names: + output[tensor_name] = input + else: + output[parameter_names[i]] = input + elif isinstance(input, allowed_types) or input is None: + output[parameter_names[i]] = input + else: + raise ValueError( + f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for" + f" {parameter_names[i]}." + ) + elif isinstance(main_input, Mapping): + if "inputs" in main_input: + warnings.warn( + "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids`" + " instead.", + FutureWarning, + ) + + output["input_ids"] = main_input.pop("inputs") + + if "decoder_cached_states" in main_input: + warnings.warn( + "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use" + " `past_key_values` instead.", + FutureWarning, + ) + output["past_key_values"] = main_input.pop("decoder_cached_states") + + for k, v in dict(main_input).items(): + if isinstance(v, allowed_types) or v is None: + output[k] = v + elif k not in parameter_names and "args" not in parameter_names: + logger.warning( + f"The parameter {k} does not belongs to the parameter list {parameter_names} and will be ignored." + ) + continue + else: + raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.") + else: + if tf.is_tensor(main_input) or main_input is None: + output[main_input_name] = main_input + else: + raise ValueError( + f"Data of type {type(main_input)} is not allowed only {allowed_types} is accepted for" + f" {main_input_name}." + ) + + # Populates any unspecified argument with their default value, according to the signature. + for name in parameter_names: + if name not in list(output.keys()) and name != "args": + output[name] = kwargs.pop(name, signature[name].default) + + # When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs) + # So to respect the proper output we have to add this exception + if "args" in output: + if output["args"] is not None and is_tf_symbolic_tensor(output["args"]): + tensor_name = output["args"].name.split(":")[0] + output[tensor_name] = output["args"] + else: + # `args` in this case is always the first parameter, then `input_ids` + output["input_ids"] = output["args"] + + del output["args"] + + if "kwargs" in output: + del output["kwargs"] + + cast_output = {} + for key, val in output.items(): + if isinstance(val, tf.Tensor) and val.dtype == tf.int64: + cast_output[key] = tf.cast(val, tf.int32) + elif isinstance(val, np.ndarray) and val.dtype == np.int64: + cast_output[key] = val.astype(np.int32) + else: + cast_output[key] = val + + output = cast_output + del cast_output + + if config is not None: + boolean_dict = { + k: v + for k, v in output.items() + if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"] + } + + output.update( + booleans_processing( + config=config, + **boolean_dict, + ) + ) + + return output + + +def dtype_byte_size(dtype): + """ + Returns the size (in bytes) occupied by one parameter of type `dtype`. + + Example: + + ```py + >>> dtype_byte_size(tf.float32) + 4 + ``` + """ + if dtype == tf.bool: + return 1 / 8 + bit_search = re.search(r"[^\d](\d+)$", dtype.name) + if bit_search is None: + raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") + bit_size = int(bit_search.groups()[0]) + return bit_size // 8 + + +def strip_model_name_and_prefix(name, _prefix=None): + if _prefix is not None and name.startswith(_prefix): + name = name[len(_prefix) :] + if name.startswith("/"): + name = name[1:] + if "model." not in name and len(name.split("/")) > 1: + name = "/".join(name.split("/")[1:]) + return name + + +def tf_shard_checkpoint(weights, max_shard_size="10GB", weights_name: str = TF2_WEIGHTS_NAME): + """ + Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a + given size. + + The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no + optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the + limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], + [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB]. + + + + If one of the model's weight is bigger that `max_shard_size`, it will end up in its own sub-checkpoint which will + have a size greater than `max_shard_size`. + + + + Args: + weights (`Dict[str, tf.RessourceVariable]`): The list of tf.RessourceVariable of a model to save. + max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): + The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit + (like `"5MB"`). + """ + max_shard_size = convert_file_size_to_int(max_shard_size) + + sharded_state_dicts = [] + current_block = [] + current_block_size = 0 + total_size = 0 + + for item in weights: + weight_size = item.numpy().size * dtype_byte_size(item.dtype) + + # If this weight is going to tip up over the maximal size, we split. + if current_block_size + weight_size > max_shard_size: + sharded_state_dicts.append(current_block) + current_block = [] + current_block_size = 0 + + current_block.append(item) + current_block_size += weight_size + total_size += weight_size + + # Add the last block + sharded_state_dicts.append(current_block) + + # If we only have one shard, we return it + if len(sharded_state_dicts) == 1: + return {weights_name: sharded_state_dicts[0]}, None + + # Otherwise, let's build the index + weight_map = {} + shards = {} + for idx, shard in enumerate(sharded_state_dicts): + shard_file = weights_name.replace(".h5", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.h5") + shard_file = shard_file.replace( + ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors" + ) + shards[shard_file] = shard + for weight in shard: + weight_name = weight.name + weight_map[weight_name] = shard_file + + # Add the metadata + metadata = {"total_size": total_size} + index = {"metadata": metadata, "weight_map": weight_map} + return shards, index + + +def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, strict=False, _prefix=None): + """ + This is the same as `load_tf_weights` but for a sharded checkpoint. Detect missing and unexpected layers and load + the TF weights from the shard file accordingly to their names and shapes. + + This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being + loaded in the model. + + Args: + model (`keras.models.Model`): The model in which to load the checkpoint. + shard_files (`str` or `os.PathLike`): A list containing the sharded checkpoint names. + ignore_mismatched_sizes`bool`, *optional`, defaults to `True`): + Whether or not to ignore the mismatch between the sizes + strict (`bool`, *optional*, defaults to `True`): + Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint. + + Returns: + Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the + mismatched layers. + """ + + # Load the index + unexpected_keys = set() + saved_keys = set() + mismatched_keys = set() + + # Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load + # the weight, we have to get rid of the first prefix of the name of the layer. + model_keys = set() + model_layer_map = {} + for i, k in enumerate(model.weights): + layer_name = k.name + if _prefix is not None and layer_name.startswith(_prefix): + layer_name = layer_name[len(_prefix) :] + layer_name = layer_name.lstrip("/") + if not ("model." in layer_name or len(layer_name.split("/")) == 1): + layer_name = "/".join(layer_name.split("/")[1:]) + model_keys.add(layer_name) + model_layer_map[layer_name] = i + + for shard_file in shard_files: + saved_weight_names_set, unexpected_keys_set, mismatched_keys_set = load_tf_shard( + model, + model_layer_map, + shard_file, + ignore_mismatched_sizes=ignore_mismatched_sizes, + _prefix=_prefix, + ) + saved_keys.update(saved_weight_names_set) + unexpected_keys.update(unexpected_keys_set) + mismatched_keys.update(mismatched_keys_set) + gc.collect() + + missing_keys = model_keys - saved_keys + if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0): + error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}" + if len(missing_keys) > 0: + str_missing_keys = ",".join([f'"{k}"' for k in missing_keys]) + error_message += f"\nMissing key(s): {str_missing_keys}." + if len(unexpected_keys) > 0: + str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys]) + error_message += f"\nMissing key(s): {str_unexpected_keys}." + raise RuntimeError(error_message) + + return missing_keys, unexpected_keys, mismatched_keys + + +def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None): + """ + Loads a shard from a sharded checkpoint file. Can be either H5 or Safetensors. + Handles missing keys and unexpected keys. + + Args: + model (`keras.models.Model`): Model in which the weights are loaded + model_layer_map (`Dict`): A dictionary mapping the layer name to the index of the layer in the model. + resolved_archive_file (`str`): Path to the checkpoint file from which the weights will be loaded + ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): Whether to ignore the mismatched keys + + Returns: + `keras.models.Model`: Three lists, one for the layers that were found and succesfully restored (from the + shard file), one for the mismatched layers, and another one for the unexpected layers. + """ + saved_weight_names_set = set() + saved_weights = {} + mismatched_keys = set() + unexpected_keys = set() + # Read the H5 file + try: + with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file: + # Retrieve the name of each layer from the H5 file + saved_h5_model_layers_name = set(load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names")) + weight_value_tuples = [] + + # Compute missing and unexpected sub layers + # Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...] + for layer_name in saved_h5_model_layers_name: + h5_layer_object = sharded_checkpoint_file[layer_name] + saved_weights[layer_name] = np.asarray(h5_layer_object) + + saved_weight_names_set.add(layer_name) + + if layer_name not in model_layer_map: + unexpected_keys.add(layer_name) + else: + symbolic_weight = model.weights[model_layer_map[layer_name]] + + saved_weight_value = saved_weights[layer_name] + # If the current weight is found + if saved_weight_value is not None: + # Check if the shape of the current weight and the one from the H5 file are different + if K.int_shape(symbolic_weight) != saved_weight_value.shape: + # If yes we reshape the weight from the H5 file accordingly to the current weight + # If the two shapes are not compatible we raise an issue + try: + array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight)) + except ValueError as e: + if ignore_mismatched_sizes: + mismatched_keys.add( + (layer_name, saved_weight_value.shape, K.int_shape(symbolic_weight)) + ) + continue + else: + raise e + else: + array = saved_weight_value + + # We create the tuple that will be loaded and add it to the final list + weight_value_tuples.append((symbolic_weight, array)) + + K.batch_set_value(weight_value_tuples) + + return saved_weight_names_set, unexpected_keys, mismatched_keys + + except Exception as e: + try: + with open(resolved_archive_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please install " + "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " + "you cloned." + ) + else: + raise ValueError( + f"Unable to locate the file {resolved_archive_file} which is necessary to load this pretrained" + " model. Make sure you have saved the model properly." + ) from e + except (UnicodeDecodeError, ValueError): + raise OSError( + f"Unable to load weights from TF checkpoint file for '{resolved_archive_file}' " + f"at '{resolved_archive_file}'. " + "If you tried to load a TF model from a sharded checkpoint, you should try converting the model " + "by loading it in pytorch and saving it localy. A convertion script should be realeased soon." + ) + + +def load_tf_sharded_weights_from_safetensors( + model, shard_files, ignore_mismatched_sizes=False, strict=False, _prefix=None +): + """ + This is the same as `load_tf_weights_from_safetensors` but for a sharded TF-format safetensors checkpoint. + Detect missing and unexpected layers and load the TF weights from the shard file accordingly to their names and + shapes. + + This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being + loaded in the model. + + Args: + model (`keras.models.Model`): The model in which to load the checkpoint. + shard_files (`str` or `os.PathLike`): A list containing the sharded checkpoint names. + ignore_mismatched_sizes`bool`, *optional`, defaults to `True`): + Whether or not to ignore the mismatch between the sizes + strict (`bool`, *optional*, defaults to `True`): + Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint. + + Returns: + Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the + mismatched layers. + """ + + # Load the index + unexpected_keys = set() + all_missing_keys = [] + mismatched_keys = set() + + for shard_file in shard_files: + missing_layers, unexpected_layers, mismatched_layers = load_tf_weights_from_safetensors( + model, + shard_file, + ignore_mismatched_sizes=ignore_mismatched_sizes, + _prefix=_prefix, + ) + all_missing_keys.append(set(missing_layers)) + unexpected_keys.update(unexpected_layers) + mismatched_keys.update(mismatched_layers) + gc.collect() + missing_keys = set.intersection(*all_missing_keys) + + if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0): + error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}" + if len(missing_keys) > 0: + str_missing_keys = ",".join([f'"{k}"' for k in missing_keys]) + error_message += f"\nMissing key(s): {str_missing_keys}." + if len(unexpected_keys) > 0: + str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys]) + error_message += f"\nMissing key(s): {str_unexpected_keys}." + raise RuntimeError(error_message) + + return missing_keys, unexpected_keys, mismatched_keys + + +def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None): + """ + Detect missing and unexpected layers and load the TF weights from the shard file accordingly to their names and + shapes. + + Args: + model (`keras.models.Model`): + The model to load the weights into. + resolved_archive_file (`str`): + The location of the H5 file. + ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): + Whether or not to ignore weights with shapes that don't match between the checkpoint of the model. + + Returns: + Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the + mismatched layers. + """ + if resolved_archive_file.endswith(".safetensors"): + load_function = load_tf_weights_from_safetensors + else: + load_function = load_tf_weights_from_h5 + + return load_function( + model, resolved_archive_file, ignore_mismatched_sizes=ignore_mismatched_sizes, _prefix=_prefix + ) + + +def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None): + mismatched_layers = [] + + # Read the H5 file + with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file: + # Retrieve the name of each layer from the H5 file + saved_h5_model_layers_name = set(load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names")) + + # Find the missing layers from the high level list of layers + missing_layers = list({layer.name for layer in model.layers} - saved_h5_model_layers_name) + + # Find the unexpected layers from the high level list of layers + unexpected_layers = list(saved_h5_model_layers_name - {layer.name for layer in model.layers}) + saved_weight_names_set = set() + symbolic_weights_names = set() + weight_value_tuples = [] + + # Compute missing and unexpected sub layers + # Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...] + for layer in model.layers: + # if layer_name from the H5 file belongs to the layers from the instantiated model + if layer.name in saved_h5_model_layers_name: + # Get the H5 layer object from its name + h5_layer_object = sharded_checkpoint_file[layer.name] + # Get all the weights as a list from the layer object + symbolic_weights = layer.trainable_weights + layer.non_trainable_weights + saved_weights = {} + + # Create a dict from the H5 saved model that looks like {"weight_name": weight_value} + # And a set with only the names + for weight_name in load_attributes_from_hdf5_group(h5_layer_object, "weight_names"): + # TF names always start with the model name so we ignore it + name = "/".join(weight_name.split("/")[1:]) + + if _prefix is not None: + name = _prefix + "/" + name + + saved_weights[name] = np.asarray(h5_layer_object[weight_name]) + + # Add the updated name to the final list for computing missing/unexpected values + saved_weight_names_set.add(name) + + # Loop over each weights from the instantiated model and compare with the weights from the H5 file + for symbolic_weight in symbolic_weights: + # TF names always start with the model name so we ignore it + if _prefix is not None: + delimeter = len(_prefix.split("/")) + symbolic_weight_name = "/".join( + symbolic_weight.name.split("/")[:delimeter] + + symbolic_weight.name.split("/")[delimeter + 1 :] + ) + else: + symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:]) + + # here we check if the current weight is among the weights from the H5 file + # If yes, get the weight_value of the corresponding weight from the H5 file + # If not, make the value to None + saved_weight_value = saved_weights.get(symbolic_weight_name, None) + + # Retrocompatibility patch: some embeddings are stored with the weights name (e.g. Bart's + # `model.shared/embeddings:0` are stored as `model.shared/weights:0`) + if saved_weight_value is None and symbolic_weight_name.endswith("embeddings:0"): + symbolic_weight_name = symbolic_weight_name[:-12] + "weight:0" + saved_weight_value = saved_weights.get(symbolic_weight_name, None) + + # Add the updated name to the final list for computing missing/unexpected values + symbolic_weights_names.add(symbolic_weight_name) + + # If the current weight is found + if saved_weight_value is not None: + # Check if the shape of the current weight and the one from the H5 file are different + if K.int_shape(symbolic_weight) != saved_weight_value.shape: + # If yes we reshape the weight from the H5 file accordingly to the current weight + # If the two shapes are not compatible we raise an issue + try: + array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight)) + except ValueError as e: + if ignore_mismatched_sizes: + mismatched_layers.append( + (symbolic_weight_name, saved_weight_value.shape, K.int_shape(symbolic_weight)) + ) + continue + else: + raise e + else: + array = saved_weight_value + + # We create the tuple that will be loaded and add it to the final list + weight_value_tuples.append((symbolic_weight, array)) + + # Load all the weights + K.batch_set_value(weight_value_tuples) + + # Compute the missing and unexpected layers + missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set)) + unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names)) + + return missing_layers, unexpected_layers, mismatched_layers + + +def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None): + # Read the safetensors file + with safe_open(resolved_archive_file, framework="tf") as safetensors_archive: + mismatched_layers = [] + weight_names = [strip_model_name_and_prefix(w.name, _prefix=_prefix) for w in model.weights] + loaded_weight_names = list(safetensors_archive.keys()) + # Find the missing layers from the high level list of layers + missing_layers = list(set(weight_names) - set(loaded_weight_names)) + # Find the unexpected layers from the high level list of layers + unexpected_layers = list(set(loaded_weight_names) - set(weight_names)) + + for weight in model.weights: + weight_name = strip_model_name_and_prefix(weight.name, _prefix=_prefix) + if weight_name in loaded_weight_names: + weight_value = safetensors_archive.get_tensor(weight_name) + # Check if the shape of the current weight and the one from the H5 file are different + if K.int_shape(weight) != weight_value.shape: + # If yes we reshape the weight from the H5 file accordingly to the current weight + # If the two shapes are not compatible we raise an issue + try: + weight_value = tf.reshape(weight_value, K.int_shape(weight)) + except (ValueError, tf.errors.InvalidArgumentError) as e: + if ignore_mismatched_sizes: + mismatched_layers.append((weight_name, weight_value.shape, K.int_shape(weight))) + continue + else: + raise e + + K.set_value(weight, weight_value) # weight.assign() might break if weight is a DTensor + return missing_layers, unexpected_layers, mismatched_layers + + +def init_copy_embeddings(old_embeddings, new_num_tokens): + r""" + This function aims to reduce the embeddings in case new_num_tokens < old_num_tokens or to pad with -1 in case + new_num_tokens > old_num_tokens. A mask is also computed in order to know which weight in the embeddings should be + kept or not. Example: + + - if new_num_tokens=5 and old_num_tokens=4 and old_embeddings=[w1,w2,w3,w4] + + - mask=[True,True,True,True,False] and current_weights=[w1,w2,w3,w4,-1] + - if new_num_tokens=4 and old_num_tokens=5 and old_embeddings=[w1,w2,w3,w4,w5] + + - mask=[True,True,True,True] and current_weights=[w1,w2,w3,w4] + """ + old_num_tokens, old_embedding_dim = shape_list(old_embeddings) + size_diff = new_num_tokens - old_num_tokens + + # initialize new embeddings + # Copy token embeddings from the previous ones + if tf.math.greater(size_diff, 0): + # if the new size is greater than the old one, we extend the current embeddings with a padding until getting new size + # and we create a mask to properly identify the padded values and be replaced by the values of the newly created + # embeddings + current_weights = tf.pad( + old_embeddings.value(), tf.convert_to_tensor([[0, size_diff], [0, 0]]), constant_values=-1 + ) + num_tokens_to_copy = min(old_num_tokens, new_num_tokens) + mask = tf.fill(tf.convert_to_tensor([num_tokens_to_copy, 1]), True) + mask = tf.pad(mask, tf.convert_to_tensor([[0, size_diff], [0, 0]]), constant_values=False) + else: + # if the new size if lower than the old one, we take the current embeddings until the new size + current_weights = tf.slice( + old_embeddings.value(), + tf.convert_to_tensor([0, 0]), + tf.convert_to_tensor([new_num_tokens, old_embedding_dim]), + ) + mask = tf.fill(tf.convert_to_tensor([new_num_tokens, 1]), True) + + return mask, current_weights + + +class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin): + r""" + Base class for all TF models. + + [`TFPreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading, + downloading and saving models as well as a few methods common to all models to: + + - resize the input embeddings, + - prune heads in the self-attention heads. + + Class attributes (overridden by derived classes): + + - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class + for this model architecture. + - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived + classes of the same architecture adding modules on top of the base model. + - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP + models, `pixel_values` for vision models and `input_values` for speech models). + """ + + config_class = None + base_model_prefix = "" + main_input_name = "input_ids" + _auto_class = None + _using_dummy_loss = None + _label_to_output_map = None + + # a list of re pattern of tensor names to ignore from the model when loading the model weights + # (and avoid unnecessary warnings). + _keys_to_ignore_on_load_missing = None + # a list of re pattern of tensor names to ignore from the weights when loading the model weights + # (and avoid unnecessary warnings). + _keys_to_ignore_on_load_unexpected = None + _requires_load_weight_prefix = False + + @property + def dummy_inputs(self) -> Dict[str, tf.Tensor]: + """ + Dummy inputs to build the network. + + Returns: + `Dict[str, tf.Tensor]`: The dummy inputs. + """ + dummies = {} + for key, spec in self.input_signature.items(): + # 2 is the most correct arbitrary size. I will not be taking questions + dummy_shape = [dim if dim is not None else 2 for dim in spec.shape] + if spec.shape[0] is None: + # But let's make the batch size 1 to save memory anyway + dummy_shape[0] = 1 + dummies[key] = tf.ones(shape=dummy_shape, dtype=spec.dtype) + if key == "token_type_ids": + # Some models have token_type_ids but with a vocab_size of 1 + dummies[key] = tf.zeros_like(dummies[key]) + if self.config.add_cross_attention and "encoder_hidden_states" in inspect.signature(self.call).parameters: + if "encoder_hidden_states" not in dummies: + if self.main_input_name == "input_ids": + dummies["encoder_hidden_states"] = tf.ones( + shape=(1, 2, self.config.hidden_size), dtype=tf.float32, name="encoder_hidden_states" + ) + else: + raise NotImplementedError( + "Model has cross-attention but we couldn't infer the shape for the encoder hidden states. Please manually override dummy_inputs!" + ) + return dummies + + def build_in_name_scope(self): + with tf.name_scope(self.name): + self.build(input_shape=None) + + @property + def framework(self) -> str: + """ + :str: Identifies that this is a TensorFlow model. + """ + return "tf" + + def build(self, input_shape=None): + pass # This is just here to make sure we don't call the superclass build() + + def __init__(self, config, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + if not isinstance(config, PretrainedConfig): + raise TypeError( + f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class " + "`PretrainedConfig`. To create a model from a pretrained model use " + f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + # Save config and origin of the pretrained weights if given in model + self.config = config + self.name_or_path = config.name_or_path + self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None + self._set_save_spec(self.input_signature) + + def get_config(self): + return self.config.to_dict() + + @functools.wraps(keras.Model.fit) + def fit(self, *args, **kwargs): + args, kwargs = convert_batch_encoding(*args, **kwargs) + return super().fit(*args, **kwargs) + + @functools.wraps(keras.Model.train_on_batch) + def train_on_batch(self, *args, **kwargs): + args, kwargs = convert_batch_encoding(*args, **kwargs) + return super().train_on_batch(*args, **kwargs) + + @functools.wraps(keras.Model.test_on_batch) + def test_on_batch(self, *args, **kwargs): + args, kwargs = convert_batch_encoding(*args, **kwargs) + return super().test_on_batch(*args, **kwargs) + + @functools.wraps(keras.Model.predict_on_batch) + def predict_on_batch(self, *args, **kwargs): + args, kwargs = convert_batch_encoding(*args, **kwargs) + return super().predict_on_batch(*args, **kwargs) + + @functools.wraps(keras.Model.predict) + def predict(self, *args, **kwargs): + args, kwargs = convert_batch_encoding(*args, **kwargs) + return super().predict(*args, **kwargs) + + @functools.wraps(keras.Model.evaluate) + def evaluate(self, *args, **kwargs): + args, kwargs = convert_batch_encoding(*args, **kwargs) + return super().evaluate(*args, **kwargs) + + @classmethod + def from_config(cls, config, **kwargs): + if isinstance(config, PretrainedConfig): + return cls._from_config(config, **kwargs) + return cls._from_config(cls.config_class.from_dict(config, **kwargs)) + + @classmethod + def _from_config(cls, config, **kwargs): + """ + All context managers that the model should be initialized under go here. + """ + return cls(config, **kwargs) + + def get_head_mask(self, head_mask: tf.Tensor | None, num_hidden_layers: int) -> tf.Tensor: + """ + Prepare the head mask if needed. + + Args: + head_mask (`tf.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*): + The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard). + num_hidden_layers (`int`): + The number of hidden layers in the model. + + Returns: + `tf.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with + `[None]` for each layer. + """ + if head_mask is not None: + head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers) + else: + head_mask = [None] * num_hidden_layers + + return head_mask + + def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers): + """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]""" + if head_mask.shape.rank == 1: + head_mask = head_mask[None, None, :, None, None] + head_mask = tf.repeat(head_mask, repeats=num_hidden_layers, axis=0) + elif head_mask.shape.rank == 2: + head_mask = head_mask[:, None, :, None, None] + assert head_mask.shape.rank == 5, f"head_mask.dim != 5, instead {head_mask.dim()}" + head_mask = tf.cast(head_mask, tf.float32) # switch to float if need + fp16 compatibility + return head_mask + + @tf.function + def serving(self, inputs): + """ + Args: + Method used for serving the model. Does not have a specific signature, but will be specialized as concrete + functions when saving with `save_pretrained`. + inputs (`Dict[str, tf.Tensor]`): + The input of the saved model as a dictionary of tensors. + """ + output = self.call(inputs) + + return self.serving_output(output) + + @property + def input_signature(self) -> Dict[str, tf.TensorSpec]: + """ + This property should return a dict mapping input names to tf.TensorSpec objects, representing the expected + shape and dtype for model inputs. It is used for both serving and for generating dummy inputs. + """ + model_inputs = list(inspect.signature(self.call).parameters) + sig = {} + if "input_ids" in model_inputs: + if self.__class__.__name__.endswith("ForMultipleChoice"): + text_dims = 3 + else: + text_dims = 2 + for input_name in ( + "input_ids", + "attention_mask", + "token_type_ids", + "decoder_input_ids", + "decoder_attention_mask", + ): + if input_name in model_inputs: + sig[input_name] = tf.TensorSpec([None] * text_dims, tf.int32, name=input_name) + if "pixel_values" in model_inputs: + pixel_values_shape = [None, None, None, None] + if hasattr(self.config, "vision_config"): + vision_config = self.config.vision_config + else: + vision_config = self.config + if hasattr(vision_config, "num_channels"): + pixel_values_shape[1] = vision_config.num_channels + else: + raise NotImplementedError( + "Could not infer number of channels from config, please override input_signature to specify input shapes." + ) + if hasattr(vision_config, "image_size"): + pixel_values_shape[2] = pixel_values_shape[3] = vision_config.image_size + elif hasattr(vision_config, "input_size"): + pixel_values_shape[2] = pixel_values_shape[3] = vision_config.input_size + else: + raise NotImplementedError( + "Could not infer input image shape from config, please override input_signature to specify input shapes." + ) + sig["pixel_values"] = tf.TensorSpec(pixel_values_shape, tf.float32, name="pixel_values") + if "input_features" in model_inputs: + raise NotImplementedError("Audio models need a manually defined input_signature") + return sig + + def serving_output(self, output): + """ + Prepare the output of the saved model. Can be overridden if specific serving modifications are required. + """ + if not isinstance(output, ModelOutput): + return output + for key in output: + if key.endswith("hidden_states") and not getattr(self.config, "output_hidden_states", False): + output[key] = None + elif key.endswith("attentions") and not getattr(self.config, "output_attentions", False): + output[key] = None + elif key == "past_key_values" and not getattr(self.config, "use_cache", False): + output[key] = None + elif key == "cross_attentions" and not ( + getattr(self.config, "output_attentions", False) and getattr(self.config, "add_cross_attention", False) + ): + output[key] = None + if isinstance(output[key], (tuple, list)): + try: + output[key] = tf.convert_to_tensor(output[key]) + except (ValueError, tf.errors.InvalidArgumentError): + pass # Layers may not have the same dimensions + return output + + @classmethod + def can_generate(cls) -> bool: + """ + Returns whether this model can generate sequences with `.generate()`. + + Returns: + `bool`: Whether this model can generate sequences with `.generate()`. + """ + # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation. + # Alternativelly, the model can also have a custom `generate` function. + if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate): + return False + return True + + def get_input_embeddings(self) -> keras.layers.Layer: + """ + Returns the model's input embeddings layer. + + Returns: + `tf.Variable`: The embeddings layer mapping vocabulary to hidden states. + """ + main_layer = getattr(self, self.base_model_prefix, self) + + if main_layer is not self: + return main_layer.get_input_embeddings() + else: + raise NotImplementedError + + def _save_checkpoint(self, checkpoint_dir, epoch): + if not os.path.isdir(checkpoint_dir): + os.mkdir(checkpoint_dir) + # We avoid tf.train.checkpoint or saving weights in TF format, even though that includes optimizer + # state for us, because it requires special handling for objects like custom losses, which we use + # internally and which users are likely to use too + weights_path = os.path.join(checkpoint_dir, "weights.h5") + self.save_weights(weights_path) + extra_data = {"epoch": epoch, "optimizer_state": self.optimizer.get_weights()} + extra_data_path = os.path.join(checkpoint_dir, "extra_data.pickle") + with open(extra_data_path, "wb") as f: + pickle.dump(extra_data, f) + + def prepare_tf_dataset( + self, + dataset: "datasets.Dataset", # noqa:F821 + batch_size: int = 8, + shuffle: bool = True, + tokenizer: Optional["PreTrainedTokenizerBase"] = None, + collate_fn: Optional[Callable] = None, + collate_fn_args: Optional[Dict[str, Any]] = None, + drop_remainder: Optional[bool] = None, + prefetch: bool = True, + ): + """ + Wraps a HuggingFace [`~datasets.Dataset`] as a `tf.data.Dataset` with collation and batching. This method is + designed to create a "ready-to-use" dataset that can be passed directly to Keras methods like `fit()` without + further modification. The method will drop columns from the dataset if they don't match input names for the + model. If you want to specify the column names to return rather than using the names that match this model, we + recommend using `Dataset.to_tf_dataset()` instead. + + Args: + dataset (`Any`): + A [~`datasets.Dataset`] to be wrapped as a `tf.data.Dataset`. + batch_size (`int`, *optional*, defaults to 8): + The size of batches to return. + shuffle (`bool`, defaults to `True`): + Whether to return samples from the dataset in random order. Usually `True` for training datasets and + `False` for validation/test datasets. + tokenizer ([`PreTrainedTokenizerBase`], *optional*): + A `PreTrainedTokenizer` that will be used to pad samples to create batches. Has no effect if a specific + `collate_fn` is passed instead. + collate_fn (`Callable`, *optional*): + A function that collates samples from the dataset into a single batch. Defaults to + `DefaultDataCollator` if no `tokenizer` is supplied or `DataCollatorWithPadding` if a `tokenizer` is + passed. + collate_fn_args (`Dict[str, Any]`, *optional*): + A dict of arguments to pass to the `collate_fn` alongside the list of samples. + drop_remainder (`bool`, *optional*): + Whether to drop the final batch, if the batch_size does not evenly divide the dataset length. Defaults + to the same setting as `shuffle`. + prefetch (`bool`, defaults to `True`): + Whether to add prefetching to the end of the `tf.data` pipeline. This is almost always beneficial for + performance, but can be disabled in edge cases. + + + Returns: + `Dataset`: A `tf.data.Dataset` which is ready to pass to the Keras API. + """ + requires_backends(self, ["datasets"]) + import datasets + + if collate_fn is None: + if tokenizer is None: + collate_fn = DefaultDataCollator(return_tensors="np") + else: + collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="np") + if collate_fn_args is None: + collate_fn_args = {} + + if not isinstance(dataset, datasets.Dataset): + raise TypeError("Dataset argument should be a datasets.Dataset!") + model_inputs = list(inspect.signature(self.call).parameters) + model_labels = find_labels(self.__class__) + if "cols_to_retain" in list(inspect.signature(dataset._get_output_signature).parameters.keys()): + output_signature, _ = dataset._get_output_signature( + dataset, + batch_size=None, + collate_fn=collate_fn, + collate_fn_args=collate_fn_args, + cols_to_retain=model_inputs, + ) + else: + # TODO Matt: This is a workaround for older versions of datasets that are missing the `cols_to_retain` + # argument. We should remove this once the minimum supported version of datasets is > 2.3.2 + unwanted_columns = [ + feature + for feature in dataset.features + if feature not in model_inputs and feature not in ("label_ids", "label") + ] + dataset = dataset.remove_columns(unwanted_columns) + output_signature, _ = dataset._get_output_signature( + dataset, batch_size=None, collate_fn=collate_fn, collate_fn_args=collate_fn_args + ) + output_columns = list(output_signature.keys()) + feature_cols = [col for col in output_columns if col in model_inputs and col not in model_labels] + label_cols = [col for col in output_columns if col in model_labels] + + # Backwards compatibility for older versions of datasets. Previously, if `columns` or `label_cols` + # were a single element list, the returned element spec would be a single element. Now, passing [feature] + # will return a dict structure {"feature": feature}, and passing a single string will return a single element. + feature_cols = feature_cols[0] if len(feature_cols) == 1 else feature_cols + label_cols = label_cols[0] if len(label_cols) == 1 else label_cols + + if drop_remainder is None: + drop_remainder = shuffle + tf_dataset = dataset.to_tf_dataset( + columns=feature_cols, + label_cols=label_cols, + batch_size=batch_size, + shuffle=shuffle, + drop_remainder=drop_remainder, + collate_fn=collate_fn, + collate_fn_args=collate_fn_args, + prefetch=prefetch, + ) + return tf_dataset + + def compile( + self, + optimizer="rmsprop", + loss="auto_with_warning", + metrics=None, + loss_weights=None, + weighted_metrics=None, + run_eagerly=None, + steps_per_execution=None, + **kwargs, + ): + """ + This is a thin wrapper that sets the model's loss output head as the loss if the user does not specify a loss + function themselves. + """ + if loss in ("auto_with_warning", "passthrough"): # "passthrough" for workflow backward compatibility + logger.info( + "No loss specified in compile() - the model's internal loss computation will be used as the " + "loss. Don't panic - this is a common way to train TensorFlow models in Transformers! " + "To disable this behaviour please pass a loss argument, or explicitly pass " + "`loss=None` if you do not want your model to compute a loss. You can also specify `loss='auto'` to " + "get the internal loss without printing this info string." + ) + loss = "auto" + if loss == "auto": + loss = dummy_loss + self._using_dummy_loss = True + else: + self._using_dummy_loss = False + parent_args = list(inspect.signature(keras.Model.compile).parameters.keys()) + # This argument got renamed, we need to support both versions + if "steps_per_execution" in parent_args: + super().compile( + optimizer=optimizer, + loss=loss, + metrics=metrics, + loss_weights=loss_weights, + weighted_metrics=weighted_metrics, + run_eagerly=run_eagerly, + steps_per_execution=steps_per_execution, + **kwargs, + ) + else: + super().compile( + optimizer=optimizer, + loss=loss, + metrics=metrics, + loss_weights=loss_weights, + weighted_metrics=weighted_metrics, + run_eagerly=run_eagerly, + experimental_steps_per_execution=steps_per_execution, + **kwargs, + ) + + def compute_loss(self, *args, **kwargs): + if hasattr(keras.Model, "compute_loss"): + # This will be true in TF 2.8 or greater + return super().compute_loss(*args, **kwargs) + else: + warnings.warn( + "The old compute_loss method is deprecated as it conflicts with the Keras compute_loss " + "method added in TF 2.8. If you want the original HF compute_loss, please call " + "hf_compute_loss() instead. From TF versions >= 2.8, or Transformers versions >= 5, " + "calling compute_loss() will get the Keras method instead.", + FutureWarning, + ) + return self.hf_compute_loss(*args, **kwargs) + + def get_label_to_output_name_mapping(self): + arg_names = list(inspect.signature(self.call).parameters) + if self._label_to_output_map is not None: + return self._label_to_output_map + elif "start_positions" in arg_names: + return {"start_positions": "start_logits", "end_positions": "end_logits"} + elif "sentence_order_label" in arg_names: + return {"labels": "prediction_logits", "sentence_order_label": "sop_logits"} + elif "next_sentence_label" in arg_names: + return {"labels": "prediction_logits", "next_sentence_label": "seq_relationship_logits"} + elif "mc_labels" in arg_names: + return {"labels": "logits", "mc_labels": "mc_logits"} + else: + return {} + + def train_step(self, data): + """ + A modification of Keras's default `train_step` that correctly handles matching outputs to labels for our models + and supports directly training on the loss output head. In addition, it ensures input keys are copied to the + labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure + that they are available to the model during the forward pass. + """ + + # We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map` + arg_names = list(inspect.signature(self.call).parameters) + label_kwargs = find_labels(self.__class__) + label_to_output = self.get_label_to_output_name_mapping() + output_to_label = {val: key for key, val in label_to_output.items()} + if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"): + # Newer TF train steps leave this out + data = expand_1d(data) + x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data) + # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify + # them during input/label pre-processing. This avoids surprising the user by wrecking their data. + # In addition, modifying mutable Python inputs makes XLA compilation impossible. + if isinstance(x, dict): + x = x.copy() + if isinstance(y, dict): + y = y.copy() + + # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments, + # if those keys are not already present in the input dict + if self._using_dummy_loss and y is not None: + # If y is a tensor and the model only has one label-like input, map y to that input + if len(label_kwargs) == 1 and isinstance(y, tf.Tensor): + if isinstance(x, tf.Tensor): + x = {arg_names[0]: x} + label_kwarg = next(iter(label_kwargs)) + if label_kwarg not in x: + x[label_kwarg] = y + # Otherwise, copy keys from y to x as long as they weren't already present in x + elif isinstance(y, dict): + if isinstance(x, tf.Tensor): + x = {arg_names[0]: x} + for key, val in y.items(): + if key in arg_names and key not in x: + x[key] = val + elif output_to_label.get(key, None) in arg_names and key not in x: + x[output_to_label[key]] = val + if y is None: + y = {key: val for key, val in x.items() if key in label_kwargs} + if not y and not self._using_dummy_loss: + raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!") + + if isinstance(y, dict): + # Rename labels at this point to match output heads + y = {label_to_output.get(key, key): val for key, val in y.items()} + + # Run forward pass. + with tf.GradientTape() as tape: + if self._using_dummy_loss and "return_loss" in arg_names: + y_pred = self(x, training=True, return_loss=True) + else: + y_pred = self(x, training=True) + if self._using_dummy_loss: + loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses) + else: + loss = None + + # This next block matches outputs to label keys. Tensorflow's standard method for doing this + # can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors) + if isinstance(y, dict) and len(y) == 1: + if list(y.keys())[0] in y_pred.keys(): + y_pred = y_pred[list(y.keys())[0]] + elif list(y_pred.keys())[0] == "loss": + y_pred = y_pred[1] + else: + y_pred = y_pred[0] + _, y = y.popitem() + elif isinstance(y, dict): + # If the labels are a dict, match keys from the output by name + y_pred = {key: val for key, val in y_pred.items() if key in y} + elif isinstance(y, tuple) or isinstance(y, list): + # If the labels are a tuple/list, match keys to the output by order, skipping the loss. + if list(y_pred.keys())[0] == "loss": + y_pred = y_pred.to_tuple()[1:] + else: + y_pred = y_pred.to_tuple() + y_pred = y_pred[: len(y)] # Remove unused fields in case those cause problems + else: + # If the labels are a single tensor, match them to the first non-loss tensor in the output + if list(y_pred.keys())[0] == "loss": + y_pred = y_pred[1] + else: + y_pred = y_pred[0] + + if loss is None: + loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses) + + # Run backwards pass. + self.optimizer.minimize(loss, self.trainable_variables, tape=tape) + + self.compiled_metrics.update_state(y, y_pred, sample_weight) + # Collect metrics to return + return_metrics = {} + for metric in self.metrics: + result = metric.result() + if isinstance(result, dict): + return_metrics.update(result) + else: + return_metrics[metric.name] = result + return return_metrics + + def test_step(self, data): + """ + A modification of Keras's default `train_step` that correctly handles matching outputs to labels for our models + and supports directly training on the loss output head. In addition, it ensures input keys are copied to the + labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure + that they are available to the model during the forward pass. + """ + # We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map` + arg_names = list(inspect.signature(self.call).parameters) + label_kwargs = find_labels(self.__class__) + label_to_output = self.get_label_to_output_name_mapping() + output_to_label = {val: key for key, val in label_to_output.items()} + if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"): + # Newer versions leave this out + data = expand_1d(data) + x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data) + # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify + # them during input/label pre-processing. This avoids surprising the user by wrecking their data. + # In addition, modifying mutable Python inputs makes XLA compilation impossible. + if isinstance(x, dict): + x = x.copy() + if isinstance(y, dict): + y = y.copy() + + # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments, + # if those keys are not already present in the input dict + if self._using_dummy_loss and y is not None: + arg_names = list(inspect.signature(self.call).parameters) + # If y is a tensor and the model only has one label-like input, map y to that input + if len(label_kwargs) == 1 and isinstance(y, tf.Tensor): + if isinstance(x, tf.Tensor): + x = {arg_names[0]: x} + label_kwarg = next(iter(label_kwargs)) + if label_kwarg not in x: + x[label_kwarg] = y + # Otherwise, copy keys from y to x as long as they weren't already present in x + elif isinstance(y, dict): + if isinstance(x, tf.Tensor): + x = {arg_names[0]: x} + for key, val in y.items(): + if key in arg_names and key not in x: + x[key] = val + elif output_to_label.get(key, None) in arg_names and key not in x: + x[output_to_label[key]] = val + if y is None: + y = {key: val for key, val in x.items() if key in label_kwargs} + if not y and not self._using_dummy_loss: + raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!") + + if isinstance(y, dict): + # Rename labels at this point to match output heads + y = {label_to_output.get(key, key): val for key, val in y.items()} + + # Run forward pass. + if self._using_dummy_loss and "return_loss" in arg_names: + y_pred = self(x, return_loss=True, training=False) + else: + y_pred = self(x, training=False) + if self._using_dummy_loss: + loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses) + else: + loss = None + + # This next block matches outputs to label keys. Tensorflow's standard method for doing this + # can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors) + if isinstance(y, dict) and len(y) == 1: + if list(y.keys())[0] in y_pred.keys(): + y_pred = y_pred[list(y.keys())[0]] + elif list(y_pred.keys())[0] == "loss": + y_pred = y_pred[1] + else: + y_pred = y_pred[0] + _, y = y.popitem() + elif isinstance(y, dict): + # If the labels are a dict, match keys from the output by name + y_pred = {key: val for key, val in y_pred.items() if key in y} + elif isinstance(y, tuple) or isinstance(y, list): + # If the labels are a tuple/list, match keys to the output by order, skipping the loss. + if list(y_pred.keys())[0] == "loss": + y_pred = y_pred.to_tuple()[1:] + else: + y_pred = y_pred.to_tuple() + y_pred = y_pred[: len(y)] # Remove unused fields in case those cause problems + else: + # If the labels are a single tensor, match them to the first non-loss tensor in the output + if list(y_pred.keys())[0] == "loss": + y_pred = y_pred[1] + else: + y_pred = y_pred[0] + + if loss is None: + loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses) + + self.compiled_metrics.update_state(y, y_pred, sample_weight) + # Collect metrics to return + return_metrics = {} + for metric in self.metrics: + result = metric.result() + if isinstance(result, dict): + return_metrics.update(result) + else: + return_metrics[metric.name] = result + return return_metrics + + def create_model_card( + self, + output_dir, + model_name: str, + language: Optional[str] = None, + license: Optional[str] = None, + tags: Optional[str] = None, + finetuned_from: Optional[str] = None, + tasks: Optional[str] = None, + dataset_tags: Optional[Union[str, List[str]]] = None, + dataset: Optional[Union[str, List[str]]] = None, + dataset_args: Optional[Union[str, List[str]]] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + output_dir (`str` or `os.PathLike`): + The folder in which to create the model card. + model_name (`str`, *optional*): + The name of the model. + language (`str`, *optional*): + The language of the model (if applicable) + license (`str`, *optional*): + The license of the model. Will default to the license of the pretrained model used, if the original + model given to the `Trainer` comes from a repo on the Hub. + tags (`str` or `List[str]`, *optional*): + Some tags to be included in the metadata of the model card. + finetuned_from (`str`, *optional*): + The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo + of the original model given to the `Trainer` (if it comes from the Hub). + tasks (`str` or `List[str]`, *optional*): + One or several task identifiers, to be included in the metadata of the model card. + dataset_tags (`str` or `List[str]`, *optional*): + One or several dataset tags, to be included in the metadata of the model card. + dataset (`str` or `List[str]`, *optional*): + One or several dataset identifiers, to be included in the metadata of the model card. + dataset_args (`str` or `List[str]`, *optional*): + One or several dataset arguments, to be included in the metadata of the model card. + """ + # Avoids a circular import by doing this when necessary. + from .modelcard import TrainingSummary # tests_ignore + + training_summary = TrainingSummary.from_keras( + self, + keras_history=self.history, + language=language, + license=license, + tags=tags, + model_name=model_name, + finetuned_from=finetuned_from, + tasks=tasks, + dataset_tags=dataset_tags, + dataset=dataset, + dataset_args=dataset_args, + ) + model_card = training_summary.to_model_card() + with open(os.path.join(output_dir, "README.md"), "w") as f: + f.write(model_card) + + def set_input_embeddings(self, value): + """ + Set model's input embeddings + + Args: + value (`tf.Variable`): + The new weights mapping hidden states to vocabulary. + """ + main_layer = getattr(self, self.base_model_prefix) + + if main_layer is None: + raise NotImplementedError("The model does not implements the base_model_prefix attribute.") + + try: + main_layer.set_input_embeddings(value) + except AttributeError: + logger.info("Building the model") + self.build_in_name_scope() + main_layer.set_input_embeddings(value) + + def get_output_embeddings(self) -> Union[None, keras.layers.Layer]: + """ + Returns the model's output embeddings + + Returns: + `tf.Variable`: The new weights mapping vocabulary to hidden states. + """ + if self.get_lm_head() is not None: + lm_head = self.get_lm_head() + + try: + return lm_head.get_output_embeddings() + except AttributeError: + logger.info("Building the model") + self.build_in_name_scope() + + return lm_head().get_output_embeddings() + + return None # Overwrite for models with output embeddings + + def set_output_embeddings(self, value): + """ + Set model's output embeddings + + Args: + value (`tf.Variable`): + The new weights mapping hidden states to vocabulary. + """ + if self.get_lm_head() is not None: + lm_head = self.get_lm_head() + try: + lm_head.set_output_embeddings(value) + except AttributeError: + logger.info("Building the model") + self.build_in_name_scope() + lm_head.set_output_embeddings(value) + + def get_output_layer_with_bias(self) -> Union[None, keras.layers.Layer]: + """ + Get the layer that handles a bias attribute in case the model has an LM head with weights tied to the + embeddings + + Return: + `keras.layers.Layer`: The layer that handles the bias, None if not an LM model. + """ + warnings.warn( + "The method get_output_layer_with_bias is deprecated. Please use `get_lm_head` instead.", FutureWarning + ) + return self.get_lm_head() + + def get_prefix_bias_name(self) -> Union[None, str]: + """ + Get the concatenated _prefix name of the bias from the model name to the parent layer + + Return: + `str`: The _prefix name of the bias. + """ + warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) + return None + + def get_bias(self) -> Union[None, Dict[str, tf.Variable]]: + """ + Dict of bias attached to an LM head. The key represents the name of the bias attribute. + + Return: + `tf.Variable`: The weights representing the bias, None if not an LM model. + """ + if self.get_lm_head() is not None: + lm_head = self.get_lm_head() + try: + return lm_head.get_bias() + except AttributeError: + self.build_in_name_scope() + + return lm_head.get_bias() + return None + + def set_bias(self, value): + """ + Set all the bias in the LM head. + + Args: + value (`Dict[tf.Variable]`): + All the new bias attached to an LM head. + """ + if self.get_lm_head() is not None: + lm_head = self.get_lm_head() + try: + lm_head.set_bias(value) + except AttributeError: + self.build_in_name_scope() + lm_head.set_bias(value) + + def get_lm_head(self) -> keras.layers.Layer: + """ + The LM Head layer. This method must be overwritten by all the models that have a lm head. + + Return: + `keras.layers.Layer`: The LM head layer if the model has one, None if not. + """ + return None + + def resize_token_embeddings( + self, new_num_tokens: Optional[int] = None + ) -> Union[keras.layers.Embedding, tf.Variable]: + """ + Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`. + + Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. + + Arguments: + new_num_tokens (`int`, *optional*): + The number of new tokens in the embedding matrix. Increasing the size will add newly initialized + vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just + returns a pointer to the input tokens without doing anything. + + Return: + `tf.Variable` or `keras.layers.Embedding`: Pointer to the input tokens of the model. + """ + # TODO (joao): flagged for replacement (by `_v2_resized_token_embeddings`) due to embeddings refactor + + # Run the new code path if the model has a keras embeddings layer + if isinstance(self.get_input_embeddings(), keras.layers.Embedding): + return self._v2_resized_token_embeddings(new_num_tokens) + + if new_num_tokens is None or new_num_tokens == self.config.vocab_size: + return self._get_word_embedding_weight(self.get_input_embeddings()) + + model_embeds = self._resize_token_embeddings(new_num_tokens) + + # Update base model and current model config + self.config.vocab_size = new_num_tokens + + return model_embeds + + def _v2_resized_token_embeddings(self, new_num_tokens: Optional[int] = None) -> keras.layers.Embedding: + """ + Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`. + + Arguments: + new_num_tokens (`int`, *optional*): + The number of new tokens in the embedding matrix. Increasing the size will add newly initialized + vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just + returns a pointer to the input tokens without doing anything. + + Return: + `keras.layers.Embedding`: Pointer to the input tokens of the model. + """ + if new_num_tokens is None or new_num_tokens == self.config.vocab_size: + return self.get_input_embeddings() + + model_embeds = self._v2_resize_token_embeddings(new_num_tokens) + + # Update base model and current model config + self.config.vocab_size = new_num_tokens + + return model_embeds + + def _get_word_embedding_weight(model, embedding_layer): + # TODO (joao): flagged for delection due to embeddings refactor + + # If the variable holds the weights themselves, return them + if isinstance(embedding_layer, tf.Tensor): + return embedding_layer + # Otherwise, try to get them from the layer's attributes + + embeds = getattr(embedding_layer, "weight", None) + if embeds is not None: + return embeds + + embeds = getattr(embedding_layer, "decoder", None) + if embeds is not None: + return embeds + + # The reason why the attributes don't exist might be + # because the model is not built, so retry getting + # the argument after building the model + model.build_in_name_scope() + + embeds = getattr(embedding_layer, "weight", None) + if embeds is not None: + return embeds + + embeds = getattr(embedding_layer, "decoder", None) + if embeds is not None: + return embeds + + return None + + def _resize_token_embeddings(self, new_num_tokens): + # TODO (joao): flagged for replacement (by `_v2_resize_token_embeddings`) due to embeddings refactor + old_embeddings = self._get_word_embedding_weight(self.get_input_embeddings()) + new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) + + # if word embeddings are not tied, make sure that lm head bias is resized as well + if self.get_bias() is not None: + old_lm_head_bias = self.get_bias() + new_lm_head_bias = self._get_resized_lm_head_bias(old_lm_head_bias, new_num_tokens) + + self.set_bias(new_lm_head_bias) + + # if word embeddings are not tied, make sure that lm head decoder is resized as well + if self.get_output_embeddings() is not None: + old_lm_head_decoder = self._get_word_embedding_weight(self.get_output_embeddings()) + new_lm_head_decoder = self._get_resized_lm_head_decoder(old_lm_head_decoder, new_num_tokens) + + self.set_output_embeddings(new_lm_head_decoder) + + self.set_input_embeddings(new_embeddings) + + return self.get_input_embeddings() + + def _v2_resize_token_embeddings(self, new_num_tokens): + old_embeddings = self.get_input_embeddings() + new_embeddings = self._v2_get_resized_embeddings(old_embeddings, new_num_tokens) + self.set_input_embeddings(new_embeddings) + + # If word embeddings are not tied, make sure that lm head bias is resized as well + if self.get_bias() is not None: + old_lm_head_bias = self.get_bias() + new_lm_head_bias = self._v2_get_resized_lm_head_bias(old_lm_head_bias, new_num_tokens) + self.set_bias(new_lm_head_bias) + + # If word embeddings are not tied, make sure that lm head decoder is resized as well. + tied_weights = self.get_input_embeddings() == self.get_output_embeddings() + if self.get_output_embeddings() is not None and not tied_weights: + old_lm_head_decoder = self._get_word_embedding_weight(self.get_output_embeddings()) + # TODO (joao): this one probably needs a v2 version with other models + new_lm_head_decoder = self._get_resized_lm_head_decoder(old_lm_head_decoder, new_num_tokens) + self.set_output_embeddings(new_lm_head_decoder) + + return self.get_input_embeddings() + + def _get_resized_lm_head_bias(self, old_lm_head_bias, new_num_tokens): + """ + Build a resized bias from the old ones. Increasing the size will add newly initialized vectors at the end. + Reducing the size will remove vectors from the end + + Args: + old_lm_head_bias (`tf.Variable`): + Old lm head bias to be resized. + new_num_tokens (`int`, *optional*): + New number of tokens in the linear matrix. + + Increasing the size will add newly initialized vectors at the end. Reducing the size will remove + vectors from the end. If not provided or `None`, just returns None + + Return: + `tf.Variable`: Pointer to the resized bias. + """ + # TODO (joao): flagged for replacement (by `_v2_get_resized_lm_head_bias`) due to embeddings refactor + new_lm_head_bias = {} + + for attr, weight in old_lm_head_bias.items(): + first_dim, old_num_tokens = (None, shape_list(weight)[0]) if tf.rank(weight) == 1 else shape_list(weight) + size_diff = new_num_tokens - old_num_tokens + final_shape = [new_num_tokens] if first_dim is None else [first_dim, new_num_tokens] + + # initialize new bias + if tf.math.greater(size_diff, 0): + padding_shape = [[0, size_diff]] if first_dim is None else [[0, 0], [0, size_diff]] + current_bias = tf.pad(weight.value(), tf.convert_to_tensor(padding_shape), constant_values=-1) + num_tokens_to_copy = min(old_num_tokens, new_num_tokens) + mask_shape = [num_tokens_to_copy] if first_dim is None else [1, num_tokens_to_copy] + bias_mask = tf.fill(tf.convert_to_tensor(mask_shape), True) + bias_mask = tf.pad(bias_mask, tf.convert_to_tensor(padding_shape), constant_values=False) + else: + slice_from = [0] if first_dim is None else [0, 0] + current_bias = tf.slice( + weight.value(), tf.convert_to_tensor(slice_from), tf.convert_to_tensor(final_shape) + ) + bias_mask = tf.fill(tf.convert_to_tensor(final_shape), True) + + new_bias = self.add_weight( + shape=final_shape, + initializer="zeros", + trainable=True, + name=weight.name.split(":")[0], + ) + init_bias = tf.where(bias_mask, current_bias, new_bias.value()) + + new_bias.assign(init_bias) + new_lm_head_bias[attr] = new_bias + + return new_lm_head_bias + + def _v2_get_resized_lm_head_bias( + self, old_lm_head_bias: Dict[str, tf.Variable], new_num_tokens: int + ) -> Dict[str, tf.Tensor]: + """ + Build a resized bias from the old ones. Increasing the size will add newly initialized vectors at the end. + Reducing the size will remove vectors from the end + + Args: + old_lm_head_bias (`Dict[str, tf.Variable]`): + Old lm head bias to be resized. + new_num_tokens (`int`): + New number of tokens in the linear matrix. Increasing the size will add newly initialized vectors at + the end. Reducing the size will remove vectors from the end. + + Return: + `tf.Tensor`: Values for the resized bias. + """ + new_lm_head_bias = {} + + for attr, weight in old_lm_head_bias.items(): + # Determine the size difference (depending on the shape) + first_dim, old_num_tokens = (None, shape_list(weight)[0]) if tf.rank(weight) == 1 else shape_list(weight) + size_diff = new_num_tokens - old_num_tokens + + # Copy the old bias values to the new bias + if old_num_tokens > new_num_tokens: + new_bias = weight.value()[..., :new_num_tokens] + else: + padding_shape = [[0, size_diff]] if first_dim is None else [[0, 0], [0, size_diff]] + new_bias = tf.pad(weight.value(), tf.convert_to_tensor(padding_shape)) + + new_lm_head_bias[attr] = new_bias + return new_lm_head_bias + + def _get_resized_lm_head_decoder(self, old_lm_head_decoder, new_num_tokens): + """ + Build a resized decoder from the old ones. Increasing the size will add newly initialized vectors at the end. + Reducing the size will remove vectors from the end + + Args: + old_lm_head_decoder (`tf.Variable`): + Old lm head decoder to be resized. + new_num_tokens (`int`, *optional*): + New number of tokens in the linear matrix. + + Increasing the size will add newly initialized vectors at the end. Reducing the size will remove + vectors from the end. If not provided or `None`, just returns None + + Return: + `tf.Variable`: Pointer to the resized decoder or None if the output embeddings are different from the input + ones. + """ + new_lm_head_decoder = old_lm_head_decoder + is_input_output_equals = tf.reduce_any( + self._get_word_embedding_weight(self.get_input_embeddings()) == old_lm_head_decoder + ) + + if old_lm_head_decoder is not None and not is_input_output_equals: + old_embedding_dim = shape_list(old_lm_head_decoder)[1] + decoder_mask, current_decoder = init_copy_embeddings(old_lm_head_decoder, new_num_tokens) + new_lm_head_decoder = self.add_weight( + shape=(new_num_tokens, old_embedding_dim), + initializer="zeros", + trainable=True, + name=old_lm_head_decoder.name.split(":")[0], + ) + init_decoder = tf.where(decoder_mask, current_decoder, new_lm_head_decoder.value()) + + new_lm_head_decoder.assign(init_decoder) + + return new_lm_head_decoder + + def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Variable: + """ + Build a resized Embedding weights from a provided token Embedding weights. Increasing the size will add newly + initialized vectors at the end. Reducing the size will remove vectors from the end + + Args: + old_embeddings (`tf.Variable`): + Old embeddings to be resized. + new_num_tokens (`int`, *optional*): + New number of tokens in the embedding matrix. + + Increasing the size will add newly initialized vectors at the end. Reducing the size will remove + vectors from the end. If not provided or `None`, just returns a pointer to the input tokens + `tf.Variable` module of the model without doing anything. + + Return: + `tf.Variable`: Pointer to the resized Embedding Module or the old Embedding Module if `new_num_tokens` is + `None` + """ + # TODO (joao): flagged for replacement (by `_v2_get_resized_embeddings`) due to embeddings refactor + old_embedding_dim = shape_list(old_embeddings)[1] + init_range = getattr(self.config, "initializer_range", 0.02) + embeddings_mask, current_embeddings = init_copy_embeddings(old_embeddings, new_num_tokens) + new_embeddings = self.add_weight( + name=old_embeddings.name.split(":")[0], + shape=[new_num_tokens, old_embedding_dim], + initializer=get_initializer(init_range), + dtype=tf.float32, + ) + init_embeddings = tf.where(embeddings_mask, current_embeddings, new_embeddings.value()) + + new_embeddings.assign(init_embeddings) + + return new_embeddings + + def _v2_get_resized_embeddings( + self, old_embeddings: keras.layers.Embedding, new_num_tokens: int + ) -> keras.layers.Embedding: + """ + Build a resized Embedding layer from a provided Embedding layer. Increasing the size will add newly initialized + vectors at the end. Reducing the size will remove vectors from the end. + + Args: + old_embeddings (`keras.layers.Embedding`): + Old embeddings to be resized. + new_num_tokens (`int`, *optional*): + New number of tokens in the embedding matrix. + + Return: + `keras.layers.Embedding`: Resized Embedding layer. + """ + + # Get the initialization range for the embeddings + init_range = 0.02 # default value + potential_initialization_variable_names = [ + "initializer_range", # most common + "initializer_factor", # e.g. T5 + "init_std", # e.g BART + ] + for var_name in potential_initialization_variable_names: + if hasattr(self.config, var_name): + init_range = getattr(self.config, var_name) + + # Get a new (initialized) embeddings layer + new_embeddings = keras.layers.Embedding( + input_dim=new_num_tokens, + output_dim=old_embeddings.output_dim, + embeddings_initializer=keras.initializers.TruncatedNormal(stddev=init_range), + name=old_embeddings.embeddings.name[:-13], # exact same scoped name except "/embeddings:0" + ) + new_embeddings(tf.constant([[0]])) + + # Copy the old embeddings to the new embeddings + if old_embeddings.input_dim >= new_num_tokens: + init_embeddings = old_embeddings.embeddings[:new_num_tokens] + else: + init_embeddings = tf.concat( + [old_embeddings.embeddings, new_embeddings.embeddings[old_embeddings.input_dim :]], axis=0 + ) + new_embeddings.embeddings.assign(init_embeddings) + return new_embeddings + + def prune_heads(self, heads_to_prune): + """ + Prunes heads of the base model. + + Arguments: + heads_to_prune (`Dict[int, List[int]]`): + Dictionary with keys being selected layer indices (`int`) and associated values being the list of heads + to prune in said layer (list of `int`). For instance {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on + layer 1 and heads 2 and 3 on layer 2. + """ + raise NotImplementedError + + def save_pretrained( + self, + save_directory, + saved_model=False, + version=1, + push_to_hub=False, + signatures=None, + max_shard_size: Union[int, str] = "5GB", + create_pr: bool = False, + safe_serialization: bool = False, + token: Optional[Union[str, bool]] = None, + **kwargs, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + [`~TFPreTrainedModel.from_pretrained`] class method. + + Arguments: + save_directory (`str`): + Directory to which to save. Will be created if it doesn't exist. + saved_model (`bool`, *optional*, defaults to `False`): + If the model has to be saved in saved model format as well or not. + version (`int`, *optional*, defaults to 1): + The version of the saved model. A saved model needs to be versioned in order to be properly loaded by + TensorFlow Serving as detailed in the official documentation + https://www.tensorflow.org/tfx/serving/serving_basic + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + signatures (`dict` or `tf.function`, *optional*): + Model's signature used for serving. This will be passed to the `signatures` argument of model.save(). + max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): + The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size + lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). + + + + If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard + which will be bigger than `max_shard_size`. + + + + create_pr (`bool`, *optional*, defaults to `False`): + Whether or not to create a PR with the uploaded files or directly commit. + safe_serialization (`bool`, *optional*, defaults to `False`): + Whether to save the model using `safetensors` or the traditional TensorFlow way (that uses `h5`). + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use + the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). + kwargs (`Dict[str, Any]`, *optional*): + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + use_auth_token = kwargs.pop("use_auth_token", None) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if token is not None: + kwargs["token"] = token + + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = self._create_repo(repo_id, **kwargs) + files_timestamps = self._get_files_timestamps(save_directory) + + if saved_model: + # If `torch_dtype` is in the config with a torch dtype class as the value, we need to change it to string. + # (Although TF doesn't care about this attribute, we can't just remove it or set it to `None`.) + if getattr(self.config, "torch_dtype", None) is not None and not isinstance(self.config.torch_dtype, str): + self.config.torch_dtype = str(self.config.torch_dtype).split(".")[1] + if signatures is None: + serving_default = self.serving.get_concrete_function(self.input_signature) + if any(spec.dtype == tf.int32 for spec in self.input_signature.values()): + int64_spec = { + key: tf.TensorSpec( + shape=spec.shape, dtype=tf.int64 if spec.dtype == tf.int32 else spec.dtype, name=spec.name + ) + for key, spec in self.input_signature.items() + } + int64_serving = self.serving.get_concrete_function(int64_spec) + signatures = {"serving_default": serving_default, "int64_serving": int64_serving} + else: + signatures = serving_default + saved_model_dir = os.path.join(save_directory, "saved_model", str(version)) + self.save(saved_model_dir, include_optimizer=False, signatures=signatures) + logger.info(f"Saved model created in {saved_model_dir}") + + # Save configuration file + self.config.architectures = [self.__class__.__name__[2:]] + + # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be + # loaded from the Hub. + if self._auto_class is not None: + custom_object_save(self, save_directory, config=self.config) + + self.config.save_pretrained(save_directory) + if self.can_generate(): + self.generation_config.save_pretrained(save_directory) + + # If we save using the predefined names, we can load using `from_pretrained` + weights_name = SAFE_WEIGHTS_NAME if safe_serialization else TF2_WEIGHTS_NAME + output_model_file = os.path.join(save_directory, weights_name) + + shards, index = tf_shard_checkpoint(self.weights, max_shard_size, weights_name=weights_name) + + # Clean the folder from a previous save + for filename in os.listdir(save_directory): + full_filename = os.path.join(save_directory, filename) + # If we have a shard file that is not going to be replaced, we delete it, but only from the main process + # in distributed settings to avoid race conditions. + weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "") + if ( + filename.startswith(weights_no_suffix) + and os.path.isfile(full_filename) + and filename not in shards.keys() + ): + os.remove(full_filename) + + if index is None: + if safe_serialization: + state_dict = {strip_model_name_and_prefix(w.name): w.value() for w in self.weights} + safe_save_file(state_dict, output_model_file, metadata={"format": "tf"}) + else: + self.save_weights(output_model_file) + logger.info(f"Model weights saved in {output_model_file}") + else: + save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else TF2_WEIGHTS_INDEX_NAME + save_index_file = os.path.join(save_directory, save_index_file) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as index_file: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + index_file.write(content) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + for shard_file, shard in shards.items(): + if safe_serialization: + shard_state_dict = {strip_model_name_and_prefix(w.name): w.value() for w in shard} + safe_save_file( + shard_state_dict, os.path.join(save_directory, shard_file), metadata={"format": "tf"} + ) + else: + with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_file: + layers = [] + for layer in sorted(shard, key=lambda x: x.name): + if "model." in layer.name or len(layer.name.split("/")) == 1: + layer_name = layer.name + else: + layer_name = "/".join(layer.name.split("/")[1:]) + param_dset = shard_file.create_dataset( + layer_name, layer.numpy().shape, dtype=layer.numpy().dtype + ) + param_dset[:] = layer.numpy() + layers.append(layer_name.encode("utf8")) + save_attributes_to_hdf5_group(shard_file, "layer_names", layers) + + if push_to_hub: + self._upload_modified_files( + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=token, + ) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *model_args, + config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, + cache_dir: Optional[Union[str, os.PathLike]] = None, + ignore_mismatched_sizes: bool = False, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + use_safetensors: bool = None, + **kwargs, + ): + r""" + Instantiate a pretrained TF 2.0 model from a pre-trained model configuration. + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str`, *optional*): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *PyTorch state_dict save file* (e.g, `./pt_model/pytorch_model.bin`). In this + case, `from_pt` should be set to `True` and a configuration object should be provided as `config` + argument. This loading path is slower than converting the PyTorch model in a TensorFlow model + using the provided conversion scripts and loading the TensorFlow model afterwards. + - `None` if you are both providing the configuration and state dictionary (resp. with keyword + arguments `config` and `state_dict`). + model_args (sequence of positional arguments, *optional*): + All remaining positional arguments will be passed to the underlying model's `__init__` method. + config (`Union[PretrainedConfig, str]`, *optional*): + Can be either: + + - an instance of a class derived from [`PretrainedConfig`], + - a string valid as input to [`~PretrainedConfig.from_pretrained`]. + + Configuration for the model to use instead of an automatically loaded configuration. Configuration can + be automatically loaded when: + + - The model is a model provided by the library (loaded with the *model id* string of a pretrained + model). + - The model was saved using [`~TFPreTrainedModel.save_pretrained`] and is reloaded by supplying the + save directory. + - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a + configuration JSON file named *config.json* is found in the directory. + from_pt (`bool`, *optional*, defaults to `False`): + Load the model weights from a PyTorch state_dict save file (see docstring of + `pretrained_model_name_or_path` argument). + ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): + Whether or not to raise an error if some of the weights from the checkpoint do not have the same size + as the weights of the model (if for instance, you are instantiating a model with 10 labels from a + checkpoint with 3 labels). + cache_dir (`str`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies: + (`Dict[str, str], `optional`): A dictionary of proxy servers to use by protocol or endpoint, e.g., + `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): Whether ot not to also return a + dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (e.g., not try downloading the model). + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use + the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + + To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`. + + + + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can + specify the folder name here. + tf_to_pt_weight_rename (`Callable`, *optional*): + A function that is called to transform the names of weights during the PyTorch to TensorFlow + crossloading process. This is not necessary for most models, but is useful to allow composite models to + be crossloaded correctly. + use_safetensors (`bool`, *optional*, defaults to `None`): + Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors` + is not installed, it will be set to `False`. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). Behaves differently depending on whether a `config` is provided or + automatically loaded: + + - If a configuration is provided with `config`, `**kwargs` will be directly passed to the + underlying model's `__init__` method (we assume all relevant updates to the configuration have + already been done) + - If a configuration is not provided, `kwargs` will be first passed to the configuration class + initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that + corresponds to a configuration attribute will be used to override said attribute with the + supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute + will be passed to the underlying model's `__init__` function. + + Examples: + + ```python + >>> from transformers import BertConfig, TFBertModel + + >>> # Download model and configuration from huggingface.co and cache. + >>> model = TFBertModel.from_pretrained("google-bert/bert-base-uncased") + >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable). + >>> model = TFBertModel.from_pretrained("./test/saved_model/") + >>> # Update configuration during loading. + >>> model = TFBertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True) + >>> assert model.config.output_attentions == True + >>> # Loading from a Pytorch model file instead of a TensorFlow checkpoint (slower, for example purposes, not runnable). + >>> config = BertConfig.from_json_file("./pt_model/my_pt_model_config.json") + >>> model = TFBertModel.from_pretrained("./pt_model/my_pytorch_model.bin", from_pt=True, config=config) + ```""" + from_pt = kwargs.pop("from_pt", False) + resume_download = kwargs.pop("resume_download", None) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + use_auth_token = kwargs.pop("use_auth_token", None) + trust_remote_code = kwargs.pop("trust_remote_code", None) + _ = kwargs.pop("mirror", None) + load_weight_prefix = kwargs.pop("load_weight_prefix", None) + from_pipeline = kwargs.pop("_from_pipeline", None) + from_auto_class = kwargs.pop("_from_auto", False) + subfolder = kwargs.pop("subfolder", "") + commit_hash = kwargs.pop("_commit_hash", None) + tf_to_pt_weight_rename = kwargs.pop("tf_to_pt_weight_rename", None) + + # Not relevant for TF models + _ = kwargs.pop("adapter_kwargs", None) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if trust_remote_code is True: + logger.warning( + "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is" + " ignored." + ) + + user_agent = {"file_type": "model", "framework": "tensorflow", "from_auto_class": from_auto_class} + if from_pipeline is not None: + user_agent["using_pipeline"] = from_pipeline + + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + if use_safetensors is None and not is_safetensors_available(): + use_safetensors = False + + # Load config if we don't provide a configuration + if not isinstance(config, PretrainedConfig): + config_path = config if config is not None else pretrained_model_name_or_path + config, model_kwargs = cls.config_class.from_pretrained( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + _commit_hash=commit_hash, + **kwargs, + ) + else: + model_kwargs = kwargs + + if commit_hash is None: + commit_hash = getattr(config, "_commit_hash", None) + + # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the + # index of the files. + is_sharded = False + # Load model + if pretrained_model_name_or_path is not None: + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + is_local = os.path.isdir(pretrained_model_name_or_path) + if is_local: + if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): + # Load from a PyTorch checkpoint in priority if from_pt + archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)): + # Load from a sharded PyTorch checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME) + is_sharded = True + elif use_safetensors is not False and os.path.isfile( + os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME) + ): + # Load from a safetensors checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME) + elif use_safetensors is not False and os.path.isfile( + os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) + ): + # Load from a sharded safetensors checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) + is_sharded = True + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)): + # Load from a TF 2.0 checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME) + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)): + # Load from a sharded TF 2.0 checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME) + is_sharded = True + + # At this stage we don't have a weight file so we will raise an error. + elif use_safetensors: + raise EnvironmentError( + f"Error no file named {SAFE_WEIGHTS_NAME} or {SAFE_WEIGHTS_INDEX_NAME} found in directory {pretrained_model_name_or_path}. " + f"Please make sure that the model has been saved with `safe_serialization=True` or do not " + f"set `use_safetensors=True`." + ) + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)) or os.path.isfile( + os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME) + ): + raise EnvironmentError( + f"Error no file named {TF2_WEIGHTS_NAME} or {SAFE_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " + "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those " + "weights." + ) + else: + raise EnvironmentError( + f"Error no file named {TF2_WEIGHTS_NAME}, {SAFE_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory " + f"{pretrained_model_name_or_path}." + ) + elif os.path.isfile(pretrained_model_name_or_path): + archive_file = pretrained_model_name_or_path + is_local = True + elif os.path.isfile(pretrained_model_name_or_path + ".index"): + archive_file = pretrained_model_name_or_path + ".index" + is_local = True + elif is_remote_url(pretrained_model_name_or_path): + filename = pretrained_model_name_or_path + resolved_archive_file = download_url(pretrained_model_name_or_path) + else: + # set correct filename + if from_pt: + filename = WEIGHTS_NAME + elif use_safetensors is not False: + filename = SAFE_WEIGHTS_NAME + else: + filename = TF2_WEIGHTS_NAME + + try: + # Load from URL or cache if already cached + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "token": token, + "user_agent": user_agent, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) + + # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None + # result when internet is up, the repo and revision exist, but the file does not. + if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME: + # Did not find the safetensors file, let's fallback to TF. + # No support for sharded safetensors yet, so we'll raise an error if that's all we find. + filename = TF2_WEIGHTS_NAME + resolved_archive_file = cached_file( + pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **cached_file_kwargs + ) + if resolved_archive_file is None and filename == TF2_WEIGHTS_NAME: + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME, **cached_file_kwargs + ) + if resolved_archive_file is not None: + is_sharded = True + if resolved_archive_file is None and filename == WEIGHTS_NAME: + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs + ) + if resolved_archive_file is not None: + is_sharded = True + if resolved_archive_file is None: + # Otherwise, maybe there is a PyTorch or Flax model file. We try those to give a helpful error + # message. + has_file_kwargs = { + "revision": revision, + "proxies": proxies, + "token": token, + "cache_dir": cache_dir, + "local_files_only": local_files_only, + } + if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs): + is_sharded = True + elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {TF2_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to" + " load this model from those weights." + ) + else: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}," + f" {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}" + ) + + except EnvironmentError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted + # to the original exception. + raise + except Exception: + # For any other exception, we throw a generic error. + + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" + " from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" + f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}" + ) + if is_local: + logger.info(f"loading weights file {archive_file}") + resolved_archive_file = archive_file + filename = resolved_archive_file.split(os.path.sep)[-1] + else: + logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") + else: + resolved_archive_file = None + + # We'll need to download and cache each checkpoint shard if the checkpoint is sharded. + if is_sharded: + # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. + resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + pretrained_model_name_or_path, + resolved_archive_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + _commit_hash=commit_hash, + ) + + safetensors_from_pt = False + if filename == SAFE_WEIGHTS_NAME: + with safe_open(resolved_archive_file, framework="tf") as f: + safetensors_metadata = f.metadata() + if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]: + raise OSError( + f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata." + " Make sure you save your model with the `save_pretrained` method." + ) + safetensors_from_pt = safetensors_metadata.get("format") == "pt" + elif filename == SAFE_WEIGHTS_INDEX_NAME: + with safe_open(resolved_archive_file[0], framework="tf") as f: + safetensors_metadata = f.metadata() + if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]: + raise OSError( + f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata." + " Make sure you save your model with the `save_pretrained` method." + ) + safetensors_from_pt = safetensors_metadata.get("format") == "pt" + + config.name_or_path = pretrained_model_name_or_path + + # composed models, *e.g.* TFRag, require special treatment when it comes to loading + # pre-trained weights. + if cls._requires_load_weight_prefix and model_kwargs.get("name") is not None: + model_kwargs["load_weight_prefix"] = load_weight_prefix + "/" + model_kwargs.get("name") + + # Instantiate model. + model = cls(config, *model_args, **model_kwargs) + + if tf_to_pt_weight_rename is None and hasattr(model, "tf_to_pt_weight_rename"): + # TODO Matt: This is a temporary workaround to allow weight renaming, but requires a method + # to be defined for each class that requires a rename. We can probably just have a class-level + # dict and a single top-level method or something and cut down a lot of boilerplate code + tf_to_pt_weight_rename = model.tf_to_pt_weight_rename + + if from_pt: + from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model + + # Load from a PyTorch checkpoint + return load_pytorch_checkpoint_in_tf2_model( + model, + resolved_archive_file, + allow_missing_keys=True, + output_loading_info=output_loading_info, + _prefix=load_weight_prefix, + tf_to_pt_weight_rename=tf_to_pt_weight_rename, + ) + + # we might need to extend the variable scope for composite models + if load_weight_prefix is not None: + with tf.compat.v1.variable_scope(load_weight_prefix): + model.build_in_name_scope() # build the network with dummy inputs + else: + model.build_in_name_scope() # build the network with dummy inputs + + if safetensors_from_pt and not is_sharded: + from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model + + with safe_open(resolved_archive_file, framework="tf") as safetensors_archive: + # Load from a PyTorch safetensors checkpoint + # We load in TF format here because PT weights often need to be transposed, and this is much + # faster on GPU. Loading as numpy and transposing on CPU adds several seconds to load times. + return load_pytorch_state_dict_in_tf2_model( + model, + safetensors_archive, + tf_inputs=False, # No need to build the model again + allow_missing_keys=True, + output_loading_info=output_loading_info, + _prefix=load_weight_prefix, + ignore_mismatched_sizes=ignore_mismatched_sizes, + tf_to_pt_weight_rename=tf_to_pt_weight_rename, + ) + elif safetensors_from_pt: + from .modeling_tf_pytorch_utils import load_sharded_pytorch_safetensors_in_tf2_model + + return load_sharded_pytorch_safetensors_in_tf2_model( + model, + resolved_archive_file, + tf_inputs=False, + allow_missing_keys=True, + output_loading_info=output_loading_info, + _prefix=load_weight_prefix, + ignore_mismatched_sizes=ignore_mismatched_sizes, + tf_to_pt_weight_rename=tf_to_pt_weight_rename, + ) + + # 'by_name' allow us to do transfer learning by skipping/adding layers + # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357 + try: + if is_sharded: + for file in resolved_archive_file: + os.path.isfile(file), f"Error retrieving files {file}" + if filename == SAFE_WEIGHTS_INDEX_NAME: + missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights_from_safetensors( + model, + resolved_archive_file, + ignore_mismatched_sizes=ignore_mismatched_sizes, + _prefix=load_weight_prefix, + ) + else: + missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights( + model, + resolved_archive_file, + ignore_mismatched_sizes=ignore_mismatched_sizes, + _prefix=load_weight_prefix, + ) + else: + # Handles both H5 and safetensors + missing_keys, unexpected_keys, mismatched_keys = load_tf_weights( + model, + resolved_archive_file, + ignore_mismatched_sizes=ignore_mismatched_sizes, + _prefix=load_weight_prefix, + ) + except OSError as e: + try: + with open(resolved_archive_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please install " + "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " + "you cloned." + ) + else: + raise ValueError from e + except (UnicodeDecodeError, ValueError): + raise OSError( + "Unable to load weights from h5 file. " + "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. " + ) + + if cls._keys_to_ignore_on_load_missing is not None: + for pat in cls._keys_to_ignore_on_load_missing: + missing_keys = [k for k in missing_keys if re.search(pat, k) is None] + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + logger.warning( + f"Some layers from the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" + " with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical" + " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." + ) + else: + logger.warning(f"All model checkpoint layers were used when initializing {model.__class__.__name__}.\n") + + if len(missing_keys) > 0: + logger.warning( + f"Some layers of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.warning( + f"All the layers of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" + f" was trained on, you can already use {model.__class__.__name__} for predictions without further" + " training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" + " to use it for predictions and inference." + ) + + # If it is a model with generation capabilities, attempt to load the generation config + if model.can_generate(): + try: + model.generation_config = GenerationConfig.from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + **kwargs, + ) + except OSError: + logger.info( + "Generation config file not found, using a generation config created from the model config." + ) + pass + + if output_loading_info: + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + } + + return model, loading_info + + return model + + def push_to_hub( + self, + repo_id: str, + use_temp_dir: Optional[bool] = None, + commit_message: Optional[str] = None, + private: Optional[bool] = None, + max_shard_size: Optional[Union[int, str]] = "10GB", + token: Optional[Union[bool, str]] = None, + # (`use_auth_token` is deprecated: we have to keep it here as we don't have **kwargs) + use_auth_token: Optional[Union[bool, str]] = None, + create_pr: bool = False, + **base_model_card_args, + ) -> str: + """ + Upload the model files to the 🤗 Model Hub while synchronizing a local clone of the repo in `repo_path_or_name`. + + Parameters: + repo_id (`str`): + The name of the repository you want to push your model to. It should contain your organization name + when pushing to a given organization. + use_temp_dir (`bool`, *optional*): + Whether or not to use a temporary directory to store the files saved before they are pushed to the Hub. + Will default to `True` if there is no directory named like `repo_id`, `False` otherwise. + commit_message (`str`, *optional*): + Message to commit while pushing. Will default to `"Upload model"`. + private (`bool`, *optional*): + Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists. + token (`bool` or `str`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). Will default to `True` if `repo_url` + is not specified. + max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): + Only applicable for models. The maximum size for a checkpoint before being sharded. Checkpoints shard + will then be each of size lower than this size. If expressed as a string, needs to be digits followed + by a unit (like `"5MB"`). + create_pr (`bool`, *optional*, defaults to `False`): + Whether or not to create a PR with the uploaded files or directly commit. + + Examples: + + ```python + from transformers import TFAutoModel + + model = TFAutoModel.from_pretrained("google-bert/bert-base-cased") + + # Push the model to your namespace with the name "my-finetuned-bert". + model.push_to_hub("my-finetuned-bert") + + # Push the model to an organization with the name "my-finetuned-bert". + model.push_to_hub("huggingface/my-finetuned-bert") + ``` + """ + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if "repo_path_or_name" in base_model_card_args: + warnings.warn( + "The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use " + "`repo_id` instead." + ) + repo_id = base_model_card_args.pop("repo_path_or_name") + # Deprecation warning will be sent after for repo_url and organization + repo_url = base_model_card_args.pop("repo_url", None) + organization = base_model_card_args.pop("organization", None) + + if os.path.isdir(repo_id): + working_dir = repo_id + repo_id = repo_id.split(os.path.sep)[-1] + else: + working_dir = repo_id.split("/")[-1] + + repo_id = self._create_repo( + repo_id, private=private, token=token, repo_url=repo_url, organization=organization + ) + + if use_temp_dir is None: + use_temp_dir = not os.path.isdir(working_dir) + + with working_or_temp_dir(working_dir=working_dir, use_temp_dir=use_temp_dir) as work_dir: + files_timestamps = self._get_files_timestamps(work_dir) + + # Save all files. + self.save_pretrained(work_dir, max_shard_size=max_shard_size) + if hasattr(self, "history") and hasattr(self, "create_model_card"): + # This is a Keras model and we might be able to fish out its History and make a model card out of it + base_model_card_args = { + "output_dir": work_dir, + "model_name": Path(repo_id).name, + } + base_model_card_args.update(base_model_card_args) + self.create_model_card(**base_model_card_args) + + self._upload_modified_files( + work_dir, + repo_id, + files_timestamps, + commit_message=commit_message, + token=token, + create_pr=create_pr, + ) + + @classmethod + def register_for_auto_class(cls, auto_class="TFAutoModel"): + """ + Register this class with a given auto class. This should only be used for custom models as the ones in the + library are already mapped with an auto class. + + + + This API is experimental and may have some slight breaking changes in the next releases. + + + + Args: + auto_class (`str` or `type`, *optional*, defaults to `"TFAutoModel"`): + The auto class to register this new model with. + """ + if not isinstance(auto_class, str): + auto_class = auto_class.__name__ + + import transformers.models.auto as auto_module + + if not hasattr(auto_module, auto_class): + raise ValueError(f"{auto_class} is not a valid auto class.") + + cls._auto_class = auto_class + + +class TFConv1D(keras.layers.Layer): + """ + 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). + + Basically works like a linear layer but the weights are transposed. + + Args: + nf (`int`): + The number of output features. + nx (`int`): + The number of input features. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation to use to initialize the weights. + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the `__init__` of `keras.layers.Layer`. + """ + + def __init__(self, nf, nx, initializer_range=0.02, **kwargs): + super().__init__(**kwargs) + self.nf = nf + self.nx = nx + self.initializer_range = initializer_range + + def build(self, input_shape): + if self.built: + return + self.built = True + self.weight = self.add_weight( + "weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range) + ) + self.bias = self.add_weight("bias", shape=[1, self.nf], initializer=tf.zeros_initializer()) + + def call(self, x): + bz, sl = shape_list(x)[:2] + + x = tf.reshape(x, [-1, self.nx]) + x = tf.matmul(x, self.weight) + self.bias + + x = tf.reshape(x, [bz, sl, self.nf]) + + return x + + +class TFSharedEmbeddings(keras.layers.Layer): + r""" + Construct shared token embeddings. + + The weights of the embedding layer is usually shared with the weights of the linear decoder when doing language + modeling. + + Args: + vocab_size (`int`): + The size of the vocabulary, e.g., the number of unique tokens. + hidden_size (`int`): + The size of the embedding vectors. + initializer_range (`float`, *optional*): + The standard deviation to use when initializing the weights. If no value is provided, it will default to + \\(1/\sqrt{hidden\_size}\\). + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the `__init__` of `keras.layers.Layer`. + """ + + # TODO (joao): flagged for delection due to embeddings refactor + + def __init__(self, vocab_size: int, hidden_size: int, initializer_range: Optional[float] = None, **kwargs): + super().__init__(**kwargs) + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.initializer_range = hidden_size**-0.5 if initializer_range is None else initializer_range + warnings.warn( + "`TFSharedEmbeddings` is scheduled for deletion in v4.32, use `keras.layers.Embedding` instead.", + DeprecationWarning, + ) + + def build(self, input_shape): + """ + Build shared token embedding layer Shared weights logic adapted from + https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24 + """ + self.weight = self.add_weight( + "weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(self.initializer_range) + ) + super().build(input_shape) + + def get_config(self): + config = { + "vocab_size": self.vocab_size, + "hidden_size": self.hidden_size, + "initializer_range": self.initializer_range, + } + base_config = super().get_config() + + return dict(list(base_config.items()) + list(config.items())) + + def call(self, inputs: tf.Tensor, mode: str = "embedding") -> tf.Tensor: + """ + Get token embeddings of inputs or decode final hidden state. + + Args: + inputs (`tf.Tensor`): + In embedding mode, should be an int64 tensor with shape `[batch_size, length]`. + + In linear mode, should be a float tensor with shape `[batch_size, length, hidden_size]`. + mode (`str`, defaults to `"embedding"`): + A valid value is either `"embedding"` or `"linear"`, the first one indicates that the layer should be + used as an embedding layer, the second one that the layer should be used as a linear decoder. + + Returns: + `tf.Tensor`: In embedding mode, the output is a float32 embedding tensor, with shape `[batch_size, length, + embedding_size]`. + + In linear mode, the output is a float32 with shape `[batch_size, length, vocab_size]`. + + Raises: + ValueError: if `mode` is not valid. + + Shared weights logic is adapted from + [here](https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24). + """ + if mode == "embedding": + return self._embedding(inputs) + elif mode == "linear": + return self._linear(inputs) + else: + raise ValueError(f"mode {mode} is not valid.") + + def _embedding(self, input_ids): + """Applies embedding based on inputs tensor.""" + return tf.gather(self.weight, input_ids) + + def _linear(self, inputs): + """ + Computes logits by running inputs through a linear layer. + + Args: + inputs: A float32 tensor with shape [..., hidden_size] + + Returns: + float32 tensor with shape [..., vocab_size]. + """ + first_dims = shape_list(inputs)[:-1] + x = tf.reshape(inputs, [-1, self.hidden_size]) + logits = tf.matmul(x, self.weight, transpose_b=True) + + return tf.reshape(logits, first_dims + [self.vocab_size]) + + +class TFSequenceSummary(keras.layers.Layer): + """ + Compute a single vector summary of a sequence hidden states. + + Args: + config ([`PretrainedConfig`]): + The config used by the model. Relevant arguments in the config class of the model are (refer to the actual + config class of your model for the default values it uses): + + - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are: + + - `"last"` -- Take the last token hidden state (like XLNet) + - `"first"` -- Take the first token hidden state (like Bert) + - `"mean"` -- Take the mean of all tokens hidden states + - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) + - `"attn"` -- Not implemented now, use multi-head attention + + - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. + - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes + (otherwise to `config.hidden_size`). + - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, + another string or `None` will add no activation. + - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. + - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. + + initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation to use to initialize the weights. + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the `__init__` of `keras.layers.Layer`. + """ + + def __init__(self, config: PretrainedConfig, initializer_range: float = 0.02, **kwargs): + super().__init__(**kwargs) + + self.summary_type = config.summary_type if hasattr(config, "summary_use_proj") else "last" + if self.summary_type == "attn": + # We should use a standard multi-head attention module with absolute positional embedding for that. + # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 + # We can probably just use the multi-head attention module of PyTorch >=1.1.0 + raise NotImplementedError + + self.has_summary = hasattr(config, "summary_use_proj") and config.summary_use_proj + if self.has_summary: + if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: + num_classes = config.num_labels + else: + num_classes = config.hidden_size + self.summary = keras.layers.Dense( + num_classes, kernel_initializer=get_initializer(initializer_range), name="summary" + ) + + self.has_activation = False + activation_string = getattr(config, "summary_activation", None) + if activation_string is not None: + self.has_activation = True + self.activation = get_tf_activation(activation_string) + + self.has_first_dropout = hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0 + if self.has_first_dropout: + self.first_dropout = keras.layers.Dropout(config.summary_first_dropout) + + self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0 + if self.has_last_dropout: + self.last_dropout = keras.layers.Dropout(config.summary_last_dropout) + self.hidden_size = config.hidden_size + + def call(self, inputs, cls_index=None, training=False): + if not isinstance(inputs, (dict, tuple, list)): + hidden_states = inputs + elif isinstance(inputs, (tuple, list)): + hidden_states = inputs[0] + cls_index = inputs[1] if len(inputs) > 1 else None + assert len(inputs) <= 2, "Too many inputs." + else: + hidden_states = inputs.get("hidden_states") + cls_index = inputs.get("cls_index", None) + + if self.summary_type == "last": + output = hidden_states[:, -1] + elif self.summary_type == "first": + output = hidden_states[:, 0] + elif self.summary_type == "mean": + output = tf.reduce_mean(hidden_states, axis=1) + elif self.summary_type == "cls_index": + hidden_shape = shape_list(hidden_states) # e.g. [batch, num choices, seq length, hidden dims] + if cls_index is None: + cls_index = tf.fill( + hidden_shape[:-2], hidden_shape[-2] - 1 + ) # A tensor full of shape [batch] or [batch, num choices] full of sequence length + cls_shape = shape_list(cls_index) + if len(cls_shape) <= len(hidden_shape) - 2: + cls_index = tf.expand_dims(cls_index, axis=-1) + # else: + # cls_index = cls_index[..., tf.newaxis] + # cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),)) + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states + output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2) + output = tf.squeeze( + output, axis=len(hidden_shape) - 2 + ) # shape of output: (batch, num choices, hidden_size) + elif self.summary_type == "attn": + raise NotImplementedError + + if self.has_first_dropout: + output = self.first_dropout(output, training=training) + + if self.has_summary: + output = self.summary(output) + + if self.has_activation: + output = self.activation(output) + + if self.has_last_dropout: + output = self.last_dropout(output, training=training) + + return output + + def build(self, input_shape): + if self.built: + return + self.built = True + if getattr(self, "summary", None) is not None: + with tf.name_scope("summary"): + self.summary.build(self.hidden_size) + + +def get_initializer(initializer_range: float = 0.02) -> keras.initializers.TruncatedNormal: + """ + Creates a `keras.initializers.TruncatedNormal` with the given range. + + Args: + initializer_range (*float*, defaults to 0.02): Standard deviation of the initializer range. + + Returns: + `keras.initializers.TruncatedNormal`: The truncated normal initializer. + """ + return keras.initializers.TruncatedNormal(stddev=initializer_range) diff --git a/modeling_utils.py b/modeling_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5453c1ac4de5b5d8a74a58246d5ab89a20dca7fe --- /dev/null +++ b/modeling_utils.py @@ -0,0 +1,5672 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import collections +import copy +import functools +import gc +import importlib.metadata +import inspect +import itertools +import json +import os +import re +import shutil +import tempfile +import warnings +from contextlib import contextmanager +from dataclasses import dataclass +from functools import partial, wraps +from threading import Thread +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union +from zipfile import is_zipfile + +import torch +from huggingface_hub import split_torch_state_dict_into_shards +from packaging import version +from torch import Tensor, nn +from torch.nn import CrossEntropyLoss, Identity +from torch.utils.checkpoint import checkpoint + +from .activations import get_activation +from .configuration_utils import PretrainedConfig +from .dynamic_module_utils import custom_object_save +from .generation import CompileConfig, GenerationConfig, GenerationMixin +from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled +from .integrations.flash_attention import flash_attention_forward +from .integrations.flex_attention import flex_attention_forward +from .integrations.sdpa_attention import sdpa_attention_forward +from .loss.loss_utils import LOSS_MAPPING +from .pytorch_utils import ( # noqa: F401 + Conv1D, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + id_tensor_storage, + prune_conv1d_layer, + prune_layer, + prune_linear_layer, + translate_to_torch_parallel_style, +) +from .quantizers import AutoHfQuantizer, HfQuantizer +from .quantizers.quantizers_utils import get_module_from_name +from .safetensors_conversion import auto_conversion +from .utils import ( + ACCELERATE_MIN_VERSION, + ADAPTER_SAFE_WEIGHTS_NAME, + ADAPTER_WEIGHTS_NAME, + CONFIG_NAME, + DUMMY_INPUTS, + FLAX_WEIGHTS_NAME, + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + TF2_WEIGHTS_NAME, + TF_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + ContextManagers, + ModelOutput, + PushToHubMixin, + cached_file, + copy_func, + download_url, + extract_commit_hash, + has_file, + is_accelerate_available, + is_bitsandbytes_available, + is_flash_attn_2_available, + is_offline_mode, + is_optimum_available, + is_peft_available, + is_remote_url, + is_safetensors_available, + is_torch_flex_attn_available, + is_torch_greater_or_equal, + is_torch_sdpa_available, + is_torch_xla_available, + logging, + replace_return_docstrings, + strtobool, +) +from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files +from .utils.import_utils import ( + ENV_VARS_TRUE_VALUES, + is_sagemaker_mp_enabled, + is_torch_fx_proxy, + is_torchdynamo_compiling, +) +from .utils.quantization_config import BitsAndBytesConfig, QuantizationMethod + + +XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper() +XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper() + + +if is_accelerate_available(): + from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights + from accelerate.hooks import add_hook_to_module + from accelerate.utils import ( + check_tied_parameters_on_same_device, + extract_model_from_parallel, + find_tied_parameters, + get_balanced_memory, + get_max_memory, + load_offloaded_weights, + offload_weight, + save_offload_index, + set_module_tensor_to_device, + ) + + accelerate_version = version.parse(importlib.metadata.version("accelerate")) + if accelerate_version >= version.parse("0.31"): + from accelerate.utils.modeling import get_state_dict_from_offload + +if is_safetensors_available(): + from safetensors import safe_open + from safetensors.torch import load_file as safe_load_file + from safetensors.torch import save_file as safe_save_file + +logger = logging.get_logger(__name__) + + +_init_weights = True +_is_quantized = False +_is_ds_init_called = False + + +def is_fsdp_enabled(): + return ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1 + and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1 + ) + + +def is_local_dist_rank_0(): + return ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + and int(os.environ.get("LOCAL_RANK", -1)) == 0 + ) + + +if is_sagemaker_mp_enabled(): + import smdistributed.modelparallel.torch as smp + from smdistributed.modelparallel import __version__ as SMP_VERSION + + IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") +else: + IS_SAGEMAKER_MP_POST_1_10 = False + +if is_peft_available(): + from .utils import find_adapter_config_file + +SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel") + +TORCH_INIT_FUNCTIONS = { + "uniform_": nn.init.uniform_, + "normal_": nn.init.normal_, + "trunc_normal_": nn.init.trunc_normal_, + "constant_": nn.init.constant_, + "xavier_uniform_": nn.init.xavier_uniform_, + "xavier_normal_": nn.init.xavier_normal_, + "kaiming_uniform_": nn.init.kaiming_uniform_, + "kaiming_normal_": nn.init.kaiming_normal_, + "uniform": nn.init.uniform, + "normal": nn.init.normal, + "xavier_uniform": nn.init.xavier_uniform, + "xavier_normal": nn.init.xavier_normal, + "kaiming_uniform": nn.init.kaiming_uniform, + "kaiming_normal": nn.init.kaiming_normal, +} + + +@contextmanager +def no_init_weights(_enable=True): + """ + Context manager to globally disable weight initialization to speed up loading large models. + + TODO(Patrick): Delete safety argument `_enable=True` at next major version. . + """ + global _init_weights + old_init_weights = _init_weights + + if _enable: + _init_weights = False + + def _skip_init(*args, **kwargs): + pass + + # # Save the original initialization functions + for name, init_func in TORCH_INIT_FUNCTIONS.items(): + setattr(torch.nn.init, name, _skip_init) + try: + yield + finally: + _init_weights = old_init_weights + if _enable: + # # Restore the original initialization functions + for name, init_func in TORCH_INIT_FUNCTIONS.items(): + setattr(torch.nn.init, name, init_func) + + +@contextmanager +def set_quantized_state(): + global _is_quantized + _is_quantized = True + try: + yield + finally: + _is_quantized = False + + +# Skip recursive calls to deepspeed.zero.Init to avoid pinning errors. +# This issue occurs with ZeRO stage 3 when using NVMe offloading. +# For more details, refer to issue #34429. +@contextmanager +def set_zero3_state(): + global _is_ds_init_called + _is_ds_init_called = True + try: + yield + finally: + _is_ds_init_called = False + + +def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]): + try: + return next(parameter.parameters()).device + except StopIteration: + # For nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].device + + +def get_first_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]): + """ + Returns the first parameter dtype (can be non-floating) or asserts if none were found. + """ + try: + return next(parameter.parameters()).dtype + except StopIteration: + # For nn.DataParallel compatibility in PyTorch > 1.5 + + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].dtype + + +def get_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]): + """ + Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found. + """ + last_dtype = None + for t in parameter.parameters(): + last_dtype = t.dtype + if t.is_floating_point(): + # Adding fix for https://github.com/pytorch/xla/issues/4152 + # Fixes issue where the model code passes a value that is out of range for XLA_USE_BF16=1 + # and XLA_DOWNCAST_BF16=1 so the conversion would cast it to -inf + # NOTE: `is_torch_xla_available()` is checked last as it induces a graph break in torch dynamo + if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available(): + return torch.bfloat16 + if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available(): + if t.dtype == torch.float: + return torch.bfloat16 + if t.dtype == torch.double: + return torch.float32 + return t.dtype + + if last_dtype is not None: + # if no floating dtype was found return whatever the first dtype is + return last_dtype + + # For nn.DataParallel compatibility in PyTorch > 1.5 + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + last_tuple = None + for tuple in gen: + last_tuple = tuple + if tuple[1].is_floating_point(): + return tuple[1].dtype + + if last_tuple is not None: + # fallback to the last dtype + return last_tuple[1].dtype + + # fallback to buffer dtype + for t in parameter.buffers(): + last_dtype = t.dtype + if t.is_floating_point(): + return t.dtype + return last_dtype + + +def get_state_dict_float_dtype(state_dict): + """ + Returns the first found floating dtype in `state_dict` or asserts if none were found. + """ + for t in state_dict.values(): + if t.is_floating_point(): + return t.dtype + + raise ValueError("couldn't find any floating point dtypes in state_dict") + + +def get_state_dict_dtype(state_dict): + """ + Returns the first found floating dtype in `state_dict` if there is one, otherwise returns the first dtype. + """ + for t in state_dict.values(): + if t.is_floating_point(): + return t.dtype + + # if no floating dtype was found return whatever the first dtype is + else: + return next(state_dict.values()).dtype + + +def dtype_byte_size(dtype): + """ + Returns the size (in bytes) occupied by one parameter of type `dtype`. + + Example: + + ```py + >>> dtype_byte_size(torch.float32) + 4 + ``` + """ + if dtype == torch.bool: + return 1 / 8 + bit_search = re.search(r"[^\d](\d+)_?", str(dtype)) + if bit_search is None: + raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") + bit_size = int(bit_search.groups()[0]) + return bit_size // 8 + + +def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""): + """ + Checks if `model_to_load` supports param buffer assignment (such + as when loading in empty weights) by first checking + if the model explicitly disables it, then by ensuring that the state dict keys + are a subset of the model's parameters. + + Note: We fully disable this if we are using `deepspeed` + """ + if model_to_load.device.type == "meta": + return False + + if len([key for key in state_dict if key.startswith(start_prefix)]) == 0: + return False + + if is_deepspeed_zero3_enabled(): + return False + + # Some models explicitly do not support param buffer assignment + if not getattr(model_to_load, "_supports_param_buffer_assignment", True): + logger.debug( + f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower" + ) + return False + + # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype + first_key = next(iter(model_to_load.state_dict().keys())) + if start_prefix + first_key in state_dict: + return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype + + # For cases when the `state_dict` doesn't contain real weights to the model (`test_model_weights_reload_no_missing_tied_weights`) + return False + + +def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True): + """ + This is the same as + [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict) + but for a sharded checkpoint. + + This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being + loaded in the model. + + Args: + model (`torch.nn.Module`): The model in which to load the checkpoint. + folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint. + strict (`bool`, *optional`, defaults to `True`): + Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint. + prefer_safe (`bool`, *optional*, defaults to `False`) + If both safetensors and PyTorch save files are present in checkpoint and `prefer_safe` is True, the + safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible. + + Returns: + `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields + - `missing_keys` is a list of str containing the missing keys + - `unexpected_keys` is a list of str containing the unexpected keys + """ + # Load the index + index_file = os.path.join(folder, WEIGHTS_INDEX_NAME) + safe_index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME) + + index_present = os.path.isfile(index_file) + safe_index_present = os.path.isfile(safe_index_file) + + if not index_present and not (safe_index_present and is_safetensors_available()): + filenames = ( + (WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME) if is_safetensors_available() else (WEIGHTS_INDEX_NAME,) + ) + raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.") + + load_safe = False + if safe_index_present: + if prefer_safe: + if is_safetensors_available(): + load_safe = True # load safe due to preference + else: + logger.warning( + f"Cannot load sharded checkpoint at {folder} safely since safetensors is not installed!" + ) + elif not index_present: + load_safe = True # load safe since we have no other choice + + load_index = safe_index_file if load_safe else index_file + + with open(load_index, "r", encoding="utf-8") as f: + index = json.load(f) + + shard_files = list(set(index["weight_map"].values())) + + # If strict=True, error before loading any of the state dicts. + loaded_keys = index["weight_map"].keys() + model_keys = model.state_dict().keys() + missing_keys = [key for key in model_keys if key not in loaded_keys] + unexpected_keys = [key for key in loaded_keys if key not in model_keys] + if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0): + error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}" + if len(missing_keys) > 0: + str_missing_keys = ",".join([f'"{k}"' for k in missing_keys]) + error_message += f"\nMissing key(s): {str_missing_keys}." + if len(unexpected_keys) > 0: + str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys]) + error_message += f"\nMissing key(s): {str_unexpected_keys}." + raise RuntimeError(error_message) + + weights_only_kwarg = {"weights_only": True} + loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg) + + for shard_file in shard_files: + state_dict = loader(os.path.join(folder, shard_file)) + model.load_state_dict(state_dict, strict=False) + + # Make sure memory is freed before we load the next state dict. + del state_dict + gc.collect() + + # Return the same thing as PyTorch load_state_dict function. + return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys) + + +def load_state_dict( + checkpoint_file: Union[str, os.PathLike], + is_quantized: bool = False, + map_location: Optional[Union[str, torch.device]] = None, + weights_only: bool = True, +): + """ + Reads a PyTorch checkpoint file, returning properly formatted errors if they arise. + """ + if checkpoint_file.endswith(".safetensors") and is_safetensors_available(): + # Check format of the archive + with safe_open(checkpoint_file, framework="pt") as f: + metadata = f.metadata() + if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "mlx"]: + raise OSError( + f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure " + "you save your model with the `save_pretrained` method." + ) + return safe_load_file(checkpoint_file) + try: + if map_location is None: + if ( + ( + is_deepspeed_zero3_enabled() + and torch.distributed.is_initialized() + and torch.distributed.get_rank() > 0 + ) + or (is_fsdp_enabled() and not is_local_dist_rank_0()) + ) and not is_quantized: + map_location = "meta" + else: + map_location = "cpu" + extra_args = {} + # mmap can only be used with files serialized with zipfile-based format. + if ( + isinstance(checkpoint_file, str) + and map_location != "meta" + and version.parse(torch.__version__) >= version.parse("2.1.0") + and is_zipfile(checkpoint_file) + ): + extra_args = {"mmap": True} + weights_only_kwarg = {"weights_only": weights_only} + return torch.load( + checkpoint_file, + map_location=map_location, + **weights_only_kwarg, + **extra_args, + ) + except Exception as e: + try: + with open(checkpoint_file) as f: + if f.read(7) == "version": + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please install " + "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " + "you cloned." + ) + else: + raise ValueError( + f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained " + "model. Make sure you have saved the model properly." + ) from e + except (UnicodeDecodeError, ValueError): + raise OSError( + f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' " + f"at '{checkpoint_file}'. " + "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True." + ) + + +def set_initialized_submodules(model, state_dict_keys): + """ + Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state + dict. + """ + not_initialized_submodules = {} + for module_name, module in model.named_modules(): + loaded_keys = {k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")} + # When checking if the root module is loaded all state_dict_keys must be used. + if module_name == "": + loaded_keys = set(state_dict_keys) + if loaded_keys.issuperset(module.state_dict()): + module._is_hf_initialized = True + else: + not_initialized_submodules[module_name] = module + return not_initialized_submodules + + +def _end_ptr(tensor: torch.Tensor) -> int: + # extract the end of the pointer if the tensor is a slice of a bigger tensor + if tensor.nelement(): + stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size() + else: + stop = tensor.data_ptr() + return stop + + +def _get_tied_weight_keys(module: nn.Module, prefix=""): + tied_weight_keys = [] + if getattr(module, "_tied_weights_keys", None) is not None: + names = [f"{prefix}.{k}" if prefix else k for k in module._tied_weights_keys] + tied_weight_keys.extend(names) + if getattr(module, "_dynamic_tied_weights_keys", None) is not None: + names = [f"{prefix}.{k}" if prefix else k for k in module._dynamic_tied_weights_keys] + tied_weight_keys.extend(names) + for name, submodule in module.named_children(): + local_prefix = f"{prefix}.{name}" if prefix else name + tied_weight_keys.extend(_get_tied_weight_keys(submodule, prefix=local_prefix)) + return tied_weight_keys + + +def _find_disjoint(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], List[str]]: + filtered_tensors = [] + for shared in tensors: + if len(shared) < 2: + filtered_tensors.append(shared) + continue + + areas = [] + for name in shared: + tensor = state_dict[name] + areas.append((tensor.data_ptr(), _end_ptr(tensor), name)) + areas.sort() + + _, last_stop, last_name = areas[0] + filtered_tensors.append({last_name}) + for start, stop, name in areas[1:]: + if start >= last_stop: + filtered_tensors.append({name}) + else: + filtered_tensors[-1].add(name) + last_stop = stop + disjoint_tensors = [] + shared_tensors = [] + for tensors in filtered_tensors: + if len(tensors) == 1: + disjoint_tensors.append(tensors.pop()) + else: + shared_tensors.append(tensors) + return shared_tensors, disjoint_tensors + + +def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]: + shared_tensors = [] + identical = [] + for shared in tensors: + if len(shared) < 2: + continue + + areas = collections.defaultdict(set) + for name in shared: + tensor = state_dict[name] + area = (tensor.device, tensor.data_ptr(), _end_ptr(tensor)) + areas[area].add(name) + if len(areas) == 1: + identical.append(shared) + else: + shared_tensors.append(shared) + return shared_tensors, identical + + +def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False): + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + error_msgs = [] + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. + def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + local_metadata["assign_to_params_buffers"] = assign_to_params_buffers + + args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) + # Parameters of module and children will start with prefix. We can exit early if there are none in this + # state_dict + if len([key for key in state_dict if key.startswith(prefix)]) > 0: + if is_deepspeed_zero3_enabled(): + import deepspeed + + # In sharded models, each shard has only part of the full state_dict, so only gather + # parameters that are in the current state_dict. + named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False)) + params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters] + if len(params_to_gather) > 0: + # because zero3 puts placeholders in model params, this context + # manager gathers (unpartitions) the params of the current layer, then loads from + # the state dict and then re-partitions them again + with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0): + if torch.distributed.get_rank() == 0: + module._load_from_state_dict(*args) + else: + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, state_dict, prefix + name + ".", assign_to_params_buffers) + + load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers) + # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so + # it's safe to delete it. + del state_dict + + return error_msgs + + +def find_submodule_and_param_name(model, long_key, start_prefix): + """ + A helper util to find the last sub-module and the param/buffer name. If `start_prefix` is supplied it'll be removed + from the start of the key + """ + + if len(start_prefix) > 0 and long_key.startswith(start_prefix): + long_key = ".".join(long_key.split(".")[1:]) + + split_key = long_key.split(".") + submodule = model + while len(split_key) > 1: + if hasattr(submodule, split_key[0]): + submodule = getattr(submodule, split_key[0]) + del split_key[0] + else: + submodule = None + break + if submodule == model: + submodule = None + return submodule, split_key[0] + + +def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix): + """ + Moves `loaded_state_dict_keys` in model to meta device which frees up the memory taken by those params. + + `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in + `bert.pooler.dense.weight` + + """ + + # dematerialize param storage for keys that are going to be replaced by state_dict, by + # putting those on the meta device + for k in loaded_state_dict_keys: + submodule, param_name = find_submodule_and_param_name(model, k, start_prefix) + if submodule is not None: + # selectively switch to the meta device only those params/buffers that will + # be next replaced from state_dict. This a complex way to do p.to_("meta") + # since we have no in-place to_ for tensors. + new_val = getattr(submodule, param_name) + if isinstance(new_val, torch.nn.Parameter): + # isinstance returns False for Params on meta device, so switch after the check + new_val = torch.nn.Parameter(new_val.to("meta")) + else: + new_val = new_val.to("meta") + setattr(submodule, param_name, new_val) + + +def _load_state_dict_into_meta_model( + model, + state_dict, + start_prefix, + expected_keys, + device_map=None, + offload_folder=None, + offload_index=None, + state_dict_folder=None, + state_dict_index=None, + dtype=None, + hf_quantizer=None, + is_safetensors=False, + keep_in_fp32_modules=None, + unexpected_keys=None, # passing `unexpected` for cleanup from quantization items + pretrained_model_name_or_path=None, # for flagging the user when the model contains renamed keys +): + """ + This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its + params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the + params back to the normal device, but only for `loaded_state_dict_keys`. + + `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in + `bert.pooler.dense.weight` + + """ + + # XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model + # - deepspeed zero 3 support + # - need to copy metadata if any - see _load_state_dict_into_model + # - handling error_msgs - mimicking the error handling in module._load_from_state_dict() + + error_msgs = [] + + is_quantized = hf_quantizer is not None + + is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") + + for param_name, param in state_dict.items(): + if param_name not in expected_keys: + continue + + if param_name.startswith(start_prefix): + param_name = param_name[len(start_prefix) :] + + module_name = param_name + set_module_kwargs = {} + + # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params + # in int/uint/bool and not cast them. + is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn + if dtype is not None and torch.is_floating_point(param) and not is_param_float8_e4m3fn: + if ( + keep_in_fp32_modules is not None + and any( + module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules + ) + and dtype == torch.float16 + ): + param = param.to(torch.float32) + + # For backward compatibility with older versions of `accelerate` + # TODO: @sgugger replace this check with version check at the next `accelerate` release + if "dtype" in list(inspect.signature(set_module_tensor_to_device).parameters): + set_module_kwargs["dtype"] = torch.float32 + else: + param = param.to(dtype) + + # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which + # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model. + # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29 + + old_param = model + splits = param_name.split(".") + for split in splits: + # We shouldn't hit the default value unless for quant methods like hqq that modifies expected_keys. + old_param = getattr(old_param, split, None) + if old_param is None: + break + + if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)): + old_param = None + + if old_param is not None: + if dtype is None: + param = param.to(old_param.dtype) + + if old_param.is_contiguous(): + param = param.contiguous() + + set_module_kwargs["value"] = param + + if device_map is None: + param_device = "cpu" + else: + # find next higher level module that is defined in device_map: + # bert.lm_head.weight -> bert.lm_head -> bert -> '' + while len(module_name) > 0 and module_name not in device_map: + module_name = ".".join(module_name.split(".")[:-1]) + if module_name == "" and "" not in device_map: + # TODO: group all errors and raise at the end. + raise ValueError(f"{param_name} doesn't have any device set.") + param_device = device_map[module_name] + + if param_device == "disk": + if not is_safetensors: + offload_index = offload_weight(param, param_name, offload_folder, offload_index) + elif param_device == "cpu" and state_dict_index is not None: + state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index) + elif ( + not is_quantized + or (not hf_quantizer.requires_parameters_quantization) + or ( + not hf_quantizer.check_quantized_param( + model, param, param_name, state_dict, param_device=param_device, device_map=device_map + ) + ) + ): + if is_fsdp_enabled(): + param_device = "cpu" if is_local_dist_rank_0() else "meta" + + # For backward compatibility with older versions of `accelerate` and for non-quantized params + set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs) + else: + hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys) + # For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU + # and then cast it to CPU to avoid excessive memory usage on each GPU + # in comparison to the sharded model across GPUs. + if is_fsdp_enabled() or is_deepspeed_zero3_enabled(): + module, tensor_name = get_module_from_name(model, param_name) + value = getattr(module, tensor_name) + param_to = "cpu" + if is_fsdp_enabled() and not is_local_dist_rank_0(): + param_to = "meta" + val_kwargs = {} + if hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params": + val_kwargs["requires_grad"] = False + value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__) + setattr(module, tensor_name, value) + # TODO: consider removing used param_parts from state_dict before return + + return error_msgs, offload_index, state_dict_index + + +def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: + if variant is not None: + splits = weights_name.split(".") + splits = splits[:-1] + [variant] + splits[-1:] + weights_name = ".".join(splits) + + return weights_name + + +class ModuleUtilsMixin: + """ + A few utilities for `torch.nn.Modules`, to be used as a mixin. + """ + + @staticmethod + def _hook_rss_memory_pre_forward(module, *args, **kwargs): + try: + import psutil + except ImportError: + raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.") + + process = psutil.Process(os.getpid()) + mem = process.memory_info() + module.mem_rss_pre_forward = mem.rss + return None + + @staticmethod + def _hook_rss_memory_post_forward(module, *args, **kwargs): + try: + import psutil + except ImportError: + raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.") + + process = psutil.Process(os.getpid()) + mem = process.memory_info() + module.mem_rss_post_forward = mem.rss + mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward + module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0) + return None + + def add_memory_hooks(self): + """ + Add a memory hook before and after each sub-module forward pass to record increase in memory consumption. + + Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero + with `model.reset_memory_hooks_state()`. + """ + for module in self.modules(): + module.register_forward_pre_hook(self._hook_rss_memory_pre_forward) + module.register_forward_hook(self._hook_rss_memory_post_forward) + self.reset_memory_hooks_state() + + def reset_memory_hooks_state(self): + """ + Reset the `mem_rss_diff` attribute of each module (see [`~modeling_utils.ModuleUtilsMixin.add_memory_hooks`]). + """ + for module in self.modules(): + module.mem_rss_diff = 0 + module.mem_rss_post_forward = 0 + module.mem_rss_pre_forward = 0 + + @property + def device(self) -> torch.device: + """ + `torch.device`: The device on which the module is (assuming that all the module parameters are on the same + device). + """ + return get_parameter_device(self) + + @property + def dtype(self) -> torch.dtype: + """ + `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). + """ + return get_parameter_dtype(self) + + def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor: + """ + Invert an attention mask (e.g., switches 0. and 1.). + + Args: + encoder_attention_mask (`torch.Tensor`): An attention mask. + + Returns: + `torch.Tensor`: The inverted attention mask. + """ + if encoder_attention_mask.dim() == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if encoder_attention_mask.dim() == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow + # /transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = (encoder_extended_attention_mask == + # encoder_extended_attention_mask.transpose(-1, -2)) + encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(self.dtype).min + + return encoder_extended_attention_mask + + @staticmethod + def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device=None): + if device is not None: + warnings.warn( + "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning + ) + else: + device = attention_mask.device + batch_size, seq_length = input_shape + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + return extended_attention_mask + + def get_extended_attention_mask( + self, attention_mask: Tensor, input_shape: Tuple[int], device: torch.device = None, dtype: torch.float = None + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (`Tuple[int]`): + The shape of the input to the model. + + Returns: + `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. + """ + if dtype is None: + dtype = self.dtype + + if not (attention_mask.dim() == 2 and self.config.is_decoder): + # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder` + if device is not None: + warnings.warn( + "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning + ) + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder: + extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder( + input_shape, attention_mask, device + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min + return extended_attention_mask + + def get_head_mask( + self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False + ) -> Tensor: + """ + Prepare the head mask if needed. + + Args: + head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*): + The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard). + num_hidden_layers (`int`): + The number of hidden layers in the model. + is_attention_chunked (`bool`, *optional*, defaults to `False`): + Whether or not the attentions scores are computed by chunks or not. + + Returns: + `torch.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with + `[None]` for each layer. + """ + if head_mask is not None: + head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers) + if is_attention_chunked is True: + head_mask = head_mask.unsqueeze(-1) + else: + head_mask = [None] * num_hidden_layers + + return head_mask + + def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers): + """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]""" + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1) + elif head_mask.dim() == 2: + head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer + assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}" + head_mask = head_mask.to(dtype=self.dtype) # switch to float if need + fp16 compatibility + return head_mask + + def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int: + """ + Get number of (optionally, trainable or non-embeddings) parameters in the module. + + Args: + only_trainable (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of trainable parameters + + exclude_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of non-embeddings parameters + + Returns: + `int`: The number of parameters. + """ + + if exclude_embeddings: + embedding_param_names = [ + f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding) + ] + total_parameters = [ + parameter for name, parameter in self.named_parameters() if name not in embedding_param_names + ] + else: + total_parameters = list(self.parameters()) + + total_numel = [] + is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False) + + if is_loaded_in_4bit: + if is_bitsandbytes_available(): + import bitsandbytes as bnb + else: + raise ValueError( + "bitsandbytes is not installed but it seems that the model has been loaded in 4bit precision, something went wrong" + " make sure to install bitsandbytes with `pip install bitsandbytes`. You also need a GPU. " + ) + + for param in total_parameters: + if param.requires_grad or not only_trainable: + # For 4bit models, we need to multiply the number of parameters by 2 as half of the parameters are + # used for the 4bit quantization (uint8 tensors are stored) + if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit): + if hasattr(param, "element_size"): + num_bytes = param.element_size() + elif hasattr(param, "quant_storage"): + num_bytes = param.quant_storage.itemsize + else: + num_bytes = 1 + total_numel.append(param.numel() * 2 * num_bytes) + else: + total_numel.append(param.numel()) + + return sum(total_numel) + + def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int: + """ + Helper function to estimate the total number of tokens from the model inputs. + + Args: + inputs (`dict`): The model inputs. + + Returns: + `int`: The total number of tokens. + """ + if not hasattr(self, "warnings_issued"): + self.warnings_issued = {} + if self.main_input_name in input_dict: + return input_dict[self.main_input_name].numel() + elif "estimate_tokens" not in self.warnings_issued: + logger.warning( + "Could not estimate the number of tokens of the input, floating-point operations will not be computed" + ) + self.warnings_issued["estimate_tokens"] = True + return 0 + + def floating_point_ops( + self, input_dict: Dict[str, Union[torch.Tensor, Any]], exclude_embeddings: bool = True + ) -> int: + """ + Get number of (optionally, non-embeddings) floating-point operations for the forward and backward passes of a + batch with this transformer model. Default approximation neglects the quadratic dependency on the number of + tokens (valid if `12 * d_model << sequence_length`) as laid out in [this + paper](https://arxiv.org/pdf/2001.08361.pdf) section 2.1. Should be overridden for transformers with parameter + re-use e.g. Albert or Universal Transformers, or if doing long-range modeling with very high sequence lengths. + + Args: + batch_size (`int`): + The batch size for the forward pass. + + sequence_length (`int`): + The number of tokens in each line of the batch. + + exclude_embeddings (`bool`, *optional*, defaults to `True`): + Whether or not to count embedding and softmax operations. + + Returns: + `int`: The number of floating-point operations. + """ + + return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings) + + +# TODO (joao): remove `GenerationMixin` inheritance in v4.50 +class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin): + r""" + Base class for all models. + + [`PreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading, + downloading and saving models as well as a few methods common to all models to: + + - resize the input embeddings, + - prune heads in the self-attention heads. + + Class attributes (overridden by derived classes): + + - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class + for this model architecture. + - **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model, + taking as arguments: + + - **model** ([`PreTrainedModel`]) -- An instance of the model on which to load the TensorFlow checkpoint. + - **config** ([`PreTrainedConfig`]) -- An instance of the configuration associated to the model. + - **path** (`str`) -- A path to the TensorFlow checkpoint. + + - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived + classes of the same architecture adding modules on top of the base model. + - **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization. + - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP + models, `pixel_values` for vision models and `input_values` for speech models). + """ + + config_class = None + base_model_prefix = "" + main_input_name = "input_ids" + model_tags = None + + _auto_class = None + _no_split_modules = None + _skip_keys_device_placement = None + _keep_in_fp32_modules = None + + # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing + # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings. + _keys_to_ignore_on_load_missing = None + # a list of `re` patterns of `state_dict` keys that should be removed from the list of + # unexpected keys we find (keys inside the checkpoint but not the model) and avoid unnecessary + # warnings. + _keys_to_ignore_on_load_unexpected = None + # a list of `state_dict` keys to ignore when saving the model (useful for keys that aren't + # trained, but which are either deterministic or tied variables) + _keys_to_ignore_on_save = None + # a list of `state_dict` keys that are potentially tied to another key in the state_dict. + _tied_weights_keys = None + + is_parallelizable = False + supports_gradient_checkpointing = False + _is_stateful = False + + # Flash Attention 2 support + _supports_flash_attn_2 = False + + # SDPA support + _supports_sdpa = False + + # Flex Attention support + _supports_flex_attn = False + + # Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`? + _supports_cache_class = False + _supports_static_cache = False + + # Has support for a `QuantoQuantizedCache` instance as `past_key_values` + _supports_quantized_cache = False + + # A tensor parallel plan to be applied to the model when TP is enabled. For + # top-level models, this attribute is currently defined in respective model + # code. For base models, this attribute comes from + # `config.base_model_tp_plan` during `post_init`. + _tp_plan = None + + @property + def dummy_inputs(self) -> Dict[str, torch.Tensor]: + """ + `Dict[str, torch.Tensor]`: Dummy inputs to do a forward pass in the network. + """ + return {"input_ids": torch.tensor(DUMMY_INPUTS)} + + @property + def framework(self) -> str: + """ + :str: Identifies that this is a PyTorch model. + """ + return "pt" + + def __init__(self, config: PretrainedConfig, *inputs, **kwargs): + super().__init__() + if not isinstance(config, PretrainedConfig): + raise ValueError( + f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class " + "`PretrainedConfig`. To create a model from a pretrained model use " + f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + # Save config and origin of the pretrained weights if given in model + if not getattr(config, "_attn_implementation_autoset", False): + config = self._autoset_attn_implementation( + config, torch_dtype=torch.get_default_dtype(), check_device_map=False + ) + self.config = config + + # for initialization of the loss + loss_type = self.__class__.__name__ + if loss_type not in LOSS_MAPPING: + loss_groups = f"({'|'.join(LOSS_MAPPING)})" + loss_type = re.findall(loss_groups, self.__class__.__name__) + if len(loss_type) > 0: + loss_type = loss_type[0] + else: + loss_type = None + self.loss_type = loss_type + + self.name_or_path = config.name_or_path + self.warnings_issued = {} + self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None + # Overwrite the class attribute to make it an instance attribute, so models like + # `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute + # when a different component (e.g. language_model) is used. + self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules) + + def post_init(self): + """ + A method executed at the end of each Transformer model initialization, to execute code that needs the model's + modules properly initialized (such as weight initialization). + """ + self.init_weights() + self._backward_compatibility_gradient_checkpointing() + # If current model is a base model, attach `base_model_tp_plan` from config + if self.base_model is self: + self._tp_plan = self.config.base_model_tp_plan + + def dequantize(self): + """ + Potentially dequantize the model in case it has been quantized by a quantization method that support + dequantization. + """ + hf_quantizer = getattr(self, "hf_quantizer", None) + + if hf_quantizer is None: + raise ValueError("You need to first quantize your model in order to dequantize it") + + return hf_quantizer.dequantize(self) + + def _backward_compatibility_gradient_checkpointing(self): + if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False): + self.gradient_checkpointing_enable() + # Remove the attribute now that is has been consumed, so it's no saved in the config. + delattr(self.config, "gradient_checkpointing") + + def add_model_tags(self, tags: Union[List[str], str]) -> None: + r""" + Add custom tags into the model that gets pushed to the Hugging Face Hub. Will + not overwrite existing tags in the model. + + Args: + tags (`Union[List[str], str]`): + The desired tags to inject in the model + + Examples: + + ```python + from transformers import AutoModel + + model = AutoModel.from_pretrained("google-bert/bert-base-cased") + + model.add_model_tags(["custom", "custom-bert"]) + + # Push the model to your namespace with the name "my-custom-bert". + model.push_to_hub("my-custom-bert") + ``` + """ + if isinstance(tags, str): + tags = [tags] + + if self.model_tags is None: + self.model_tags = [] + + for tag in tags: + if tag not in self.model_tags: + self.model_tags.append(tag) + + @classmethod + def _from_config(cls, config, **kwargs): + """ + All context managers that the model should be initialized under go here. + + Args: + torch_dtype (`torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. + """ + # when we init a model from within another model (e.g. VLMs) and dispatch on FA2 + # a warning is raised that dtype should be fp16. Since we never pass dtype from within + # modeling code, we can try to infer it here same way as done in `from_pretrained` + torch_dtype = kwargs.pop("torch_dtype", torch.get_default_dtype()) + use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) + + # override default dtype if needed + dtype_orig = None + if torch_dtype is not None: + dtype_orig = cls._set_default_torch_dtype(torch_dtype) + + config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config. + + if config._attn_implementation_internal is not None: + # In this case, the config has been created with the attn_implementation set by the user, which we + # should respect. + attn_implementation = config._attn_implementation_internal + else: + attn_implementation = None + + config._attn_implementation = kwargs.pop("attn_implementation", attn_implementation) + if not getattr(config, "_attn_implementation_autoset", False): + config = cls._autoset_attn_implementation( + config, + use_flash_attention_2=use_flash_attention_2, + check_device_map=False, + torch_dtype=torch_dtype, + ) + + if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called: + import deepspeed + + logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model") + # this immediately partitions the model across all gpus, to avoid the overhead in time + # and memory copying it on CPU or each GPU first + init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()] + with ContextManagers(init_contexts): + model = cls(config, **kwargs) + + else: + model = cls(config, **kwargs) + + # restore default dtype if it was modified + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) + + return model + + @classmethod + def _autoset_attn_implementation( + cls, + config, + use_flash_attention_2: bool = False, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + check_device_map: bool = True, + ): + """ + Automatically checks and dispatches to a default attention implementation. In order of priority: + 1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained). + 2. DEPRECATED: if use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example) + 3. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example) + 4. The default model's implementation otherwise (`LlamaAttention` for example) . + """ + # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitely set by the user. + # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager"). + # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model) + requested_attn_implementation = None + if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None: + if config._attn_implementation != "flash_attention_2" and use_flash_attention_2: + raise ValueError( + f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were used when loading the model, which are not compatible.' + ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.' + ) + + if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in [ + "eager" + ] + list(ALL_ATTENTION_FUNCTIONS.keys()): + message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)' + if cls._supports_flash_attn_2: + message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)' + if cls._supports_sdpa: + message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)' + if cls._supports_flex_attn: + message += ( + ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)' + ) + raise ValueError(message + ".") + + # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available. + requested_attn_implementation = config._attn_implementation_internal + + # Composite models consisting of several PretrainedModels have to specify attention impl as a dict + # where keys are sub-config names. But most people will specify one `str` which means that should dispatch it + # for all sub-models. + # Below we check if a config is composite and manually prepare a dict of attn impl if not already passed as a dict. + # Later each sub-module will dispatch with its own attn impl, by calling `XXXModel._from_config(config.text_config)` + # If any of sub-modules doesn't support requested attn, an error will be raised. See https://github.com/huggingface/transformers/pull/32238 + for key in config.sub_configs.keys(): + sub_config = getattr(config, key) + curr_attn_implementation = ( + requested_attn_implementation + if not isinstance(requested_attn_implementation, dict) + else requested_attn_implementation.get(key, None) + ) + sub_config._attn_implementation_internal = curr_attn_implementation + + if use_flash_attention_2: + logger.warning_once( + 'The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.' + ) + config._attn_implementation = "flash_attention_2" + + if config._attn_implementation == "flash_attention_2": + cls._check_and_enable_flash_attn_2( + config, + torch_dtype=torch_dtype, + device_map=device_map, + hard_check_only=False, + check_device_map=check_device_map, + ) + elif requested_attn_implementation == "flex_attention": + config = cls._check_and_enable_flex_attn(config, hard_check_only=True) + elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available(): + # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif. + config = cls._check_and_enable_sdpa( + config, + hard_check_only=False if requested_attn_implementation is None else True, + ) + + if ( + torch.version.hip is not None + and config._attn_implementation == "sdpa" + and torch.cuda.device_count() > 1 + ): + logger.warning_once( + "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends." + ) + torch.backends.cuda.enable_flash_sdp(False) + elif requested_attn_implementation in list(ALL_ATTENTION_FUNCTIONS.keys()): + config._attn_implementation = requested_attn_implementation + elif isinstance(requested_attn_implementation, dict): + config._attn_implementation = None + else: + config._attn_implementation = "eager" + + config._attn_implementation_autoset = True + return config + + @classmethod + def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype: + """ + Change the default dtype and return the previous one. This is needed when wanting to instantiate the model + under specific dtype. + + Args: + dtype (`torch.dtype`): + a floating dtype to set to. + + Returns: + `torch.dtype`: the original `dtype` that can be used to restore `torch.set_default_dtype(dtype)` if it was + modified. If it wasn't, returns `None`. + + Note `set_default_dtype` currently only works with floating-point types and asserts if for example, + `torch.int64` is passed. So if a non-float `dtype` is passed this functions will throw an exception. + """ + if not dtype.is_floating_point: + raise ValueError( + f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype" + ) + + logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.") + dtype_orig = torch.get_default_dtype() + torch.set_default_dtype(dtype) + return dtype_orig + + @property + def base_model(self) -> nn.Module: + """ + `torch.nn.Module`: The main body of the model. + """ + return getattr(self, self.base_model_prefix, self) + + @classmethod + def can_generate(cls) -> bool: + """ + Returns whether this model can generate sequences with `.generate()`. + + Returns: + `bool`: Whether this model can generate sequences with `.generate()`. + """ + # Directly inherits `GenerationMixin` -> can generate + if "GenerationMixin" in str(cls.__bases__): + return True + # Model class overwrites `generate` (e.g. time series models) -> can generate + if str(cls.__name__) in str(cls.generate): + return True + # The class inherits from a class that can generate (recursive check) -> can generate + for base in cls.__bases__: + if not hasattr(base, "can_generate"): + continue + if "PreTrainedModel" not in str(base) and base.can_generate(): + return True + # BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this + # was how we detected whether a model could generate. + if "GenerationMixin" not in str(cls.prepare_inputs_for_generation): + logger.warning_once( + f"{cls.__name__} has generative capabilities, as `prepare_inputs_for_generation` is explicitly " + "overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, " + "`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability " + "to call `generate` and other related functions." + "\n - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the " + "model with an auto class. See https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes" + "\n - If you are the owner of the model architecture code, please modify your model class such that " + "it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception)." + "\n - If you are not the owner of the model architecture class, please contact the model code owner " + "to update it." + ) + return True + # Otherwise, can't generate + return False + + @classmethod + def _check_and_enable_flash_attn_2( + cls, + config, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + check_device_map: bool = True, + hard_check_only: bool = False, + ) -> PretrainedConfig: + """ + Checks the availability of Flash Attention 2 and compatibility with the current model. + + If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module. + """ + if not cls._supports_flash_attn_2: + raise ValueError( + f"{cls.__name__} does not support Flash Attention 2.0 yet. Please request to add support where" + f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new" + " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new" + ) + + if not is_flash_attn_2_available(): + preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:" + install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2." + + if importlib.util.find_spec("flash_attn") is None: + raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}") + + flash_attention_version = version.parse(importlib.metadata.version("flash_attn")) + if torch.version.cuda: + if flash_attention_version < version.parse("2.1.0"): + raise ImportError( + f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}" + ) + elif not torch.cuda.is_available(): + raise ValueError( + f"{preface} Flash Attention 2 is not available on CPU. Please make sure torch can access a CUDA device." + ) + else: + raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}") + elif torch.version.hip: + if flash_attention_version < version.parse("2.0.4"): + raise ImportError( + f"{preface} you need flash_attn package version to be greater or equal than 2.0.4. Make sure to have that version installed - detected version {flash_attention_version}. {install_message}" + ) + else: + raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}") + + _is_bettertransformer = getattr(cls, "use_bettertransformer", False) + + if _is_bettertransformer: + raise ValueError( + "Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()" + ) + + if torch_dtype is None: + logger.warning_once( + "You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour" + ) + elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]: + logger.warning_once( + "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but" + f" the current dype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator," + ' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`' + ) + + # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called, + # or the model may be initialized under the context manager `with torch.device("cuda"):`. + if check_device_map and device_map is None and torch.empty(0).device.type != "cuda": + if torch.cuda.is_available(): + logger.warning_once( + "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU" + " after initializing it on CPU with `model.to('cuda')`." + ) + else: + raise ValueError( + "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU and with no GPU available. " + "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map " + "or initialising the model on CPU and then moving it to GPU." + ) + elif ( + check_device_map + and device_map is not None + and isinstance(device_map, dict) + and ("cpu" in device_map.values() or "disk" in device_map.values()) + ): + raise ValueError( + "You are attempting to use Flash Attention 2.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to " + "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys." + ) + if not hard_check_only: + config._attn_implementation = "flash_attention_2" + return config + + @classmethod + def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig: + """ + Checks the availability of SDPA for a given model. + + If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "sdpa" so that the model can initialize the correct attention module. + """ + if hard_check_only: + if not cls._supports_sdpa: + raise ValueError( + f"{cls.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet." + " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe" + ' this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`' + ) + if not is_torch_sdpa_available(): + raise ImportError( + "PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1." + ) + + if not is_torch_sdpa_available() or not cls._supports_sdpa: + return config + + _is_bettertransformer = getattr(cls, "use_bettertransformer", False) + if _is_bettertransformer: + return config + + if not hard_check_only: + config._attn_implementation = "sdpa" + return config + + @classmethod + def _check_and_enable_flex_attn(cls, config, hard_check_only: bool = False) -> PretrainedConfig: + """ + Checks the availability of Flex Attention for a given model. + + If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flex_attention" so that the model can initialize the correct attention module. + """ + if hard_check_only: + if not cls._supports_flex_attn: + raise ValueError( + f"{cls.__name__} does not support an attention implementation through torch's flex_attention." + " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/34809." + " If you believe this error is a bug, please open an issue in Transformers GitHub repository" + ' and load your model with the argument `attn_implementation="eager"` meanwhile.' + ' Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`' + ) + if not is_torch_flex_attn_available(): + raise ImportError( + "PyTorch Flex Attention requirements in Transformers are not met. Please install torch>=2.5.0." + ) + + if not is_torch_flex_attn_available() or not cls._supports_flex_attn: + return config + + if not hard_check_only: + config._attn_implementation = "flex_attention" + + return config + + def enable_input_require_grads(self): + """ + Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping + the model weights fixed. + """ + + def make_inputs_require_grads(module, input, output): + output.requires_grad_(True) + + self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) + + def disable_input_require_grads(self): + """ + Removes the `_require_grads_hook`. + """ + self._require_grads_hook.remove() + + def get_input_embeddings(self) -> nn.Module: + """ + Returns the model's input embeddings. + + Returns: + `nn.Module`: A torch module mapping vocabulary to hidden states. + """ + base_model = getattr(self, self.base_model_prefix, self) + if base_model is not self: + return base_model.get_input_embeddings() + else: + raise NotImplementedError + + def set_input_embeddings(self, value: nn.Module): + """ + Set model's input embeddings. + + Args: + value (`nn.Module`): A module mapping vocabulary to hidden states. + """ + base_model = getattr(self, self.base_model_prefix, self) + if base_model is not self: + base_model.set_input_embeddings(value) + else: + raise NotImplementedError + + def get_output_embeddings(self) -> nn.Module: + """ + Returns the model's output embeddings. + + Returns: + `nn.Module`: A torch module mapping hidden states to vocabulary. + """ + return None # Overwrite for models with output embeddings + + def _init_weights(self, module): + """ + Initialize the weights. This method should be overridden by derived class and is + the only initialization method that will be called when loading a checkpoint + using `from_pretrained`. Any attempt to initialize outside of this function + will be useless as the torch.nn.init function are all replaced with skip. + """ + pass + + def _initialize_weights(self, module): + """ + Initialize the weights if they are not already initialized. + """ + if getattr(module, "_is_hf_initialized", False): + return + self._init_weights(module) + module._is_hf_initialized = True + + def tie_weights(self): + """ + Tie the weights between the input embeddings and the output embeddings. + + If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the + weights instead. + """ + if getattr(self.config, "tie_word_embeddings", True): + output_embeddings = self.get_output_embeddings() + if output_embeddings is not None: + self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) + + if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False): + if hasattr(self, self.base_model_prefix): + self = getattr(self, self.base_model_prefix) + tied_weights = self._tie_encoder_decoder_weights( + self.encoder, self.decoder, self.base_model_prefix, "encoder" + ) + # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class + # attributed not an instance member, therefore modifying it will modify the entire class + # Leading to issues on subsequent calls by different tests or subsequent calls. + self._dynamic_tied_weights_keys = tied_weights + + for module in self.modules(): + if hasattr(module, "_tie_weights"): + module._tie_weights() + + @staticmethod + def _tie_encoder_decoder_weights( + encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, base_encoder_name: str + ): + uninitialized_encoder_weights: List[str] = [] + tied_weights: List[str] = [] + if decoder.__class__ != encoder.__class__: + logger.info( + f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder" + " weights are correctly initialized." + ) + + def tie_encoder_to_decoder_recursively( + decoder_pointer: nn.Module, + encoder_pointer: nn.Module, + module_name: str, + base_encoder_name: str, + uninitialized_encoder_weights: List[str], + depth=0, + total_decoder_name="", + total_encoder_name="", + ): + assert isinstance(decoder_pointer, nn.Module) and isinstance( + encoder_pointer, nn.Module + ), f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module" + if hasattr(decoder_pointer, "weight"): + assert hasattr(encoder_pointer, "weight") + encoder_pointer.weight = decoder_pointer.weight + tied_weights.append(f"{base_encoder_name}{total_encoder_name}.weight") + if hasattr(decoder_pointer, "bias"): + assert hasattr(encoder_pointer, "bias") + tied_weights.append(f"{base_encoder_name}{total_encoder_name}.bias") + encoder_pointer.bias = decoder_pointer.bias + return + + encoder_modules = encoder_pointer._modules + decoder_modules = decoder_pointer._modules + if len(decoder_modules) > 0: + assert ( + len(encoder_modules) > 0 + ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" + + all_encoder_weights = {module_name + "/" + sub_name for sub_name in encoder_modules.keys()} + encoder_layer_pos = 0 + for name, module in decoder_modules.items(): + if name.isdigit(): + encoder_name = str(int(name) + encoder_layer_pos) + decoder_name = name + if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len( + encoder_modules + ) != len(decoder_modules): + # this can happen if the name corresponds to the position in a list module list of layers + # in this case the decoder has added a cross-attention that the encoder does not have + # thus skip this step and subtract one layer pos from encoder + encoder_layer_pos -= 1 + continue + elif name not in encoder_modules: + continue + elif depth > 500: + raise ValueError( + "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is" + " a circular dependency between two or more `nn.Modules` of your model." + ) + else: + decoder_name = encoder_name = name + tie_encoder_to_decoder_recursively( + decoder_modules[decoder_name], + encoder_modules[encoder_name], + module_name + "/" + name, + base_encoder_name, + uninitialized_encoder_weights, + depth=depth + 1, + total_encoder_name=f"{total_encoder_name}.{encoder_name}", + total_decoder_name=f"{total_decoder_name}.{decoder_name}", + ) + all_encoder_weights.remove(module_name + "/" + encoder_name) + + uninitialized_encoder_weights += list(all_encoder_weights) + + # tie weights recursively + tie_encoder_to_decoder_recursively( + decoder, encoder, base_model_prefix, base_encoder_name, uninitialized_encoder_weights + ) + + if len(uninitialized_encoder_weights) > 0: + logger.warning( + f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}" + ) + return tied_weights + + def _tie_or_clone_weights(self, output_embeddings, input_embeddings): + """Tie or clone module weights depending of whether we are using TorchScript or not""" + if self.config.torchscript: + output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone()) + else: + output_embeddings.weight = input_embeddings.weight + + if getattr(output_embeddings, "bias", None) is not None: + output_embeddings.bias.data = nn.functional.pad( + output_embeddings.bias.data, + ( + 0, + output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0], + ), + "constant", + 0, + ) + if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"): + output_embeddings.out_features = input_embeddings.num_embeddings + + def _get_no_split_modules(self, device_map: str): + """ + Get the modules of the model that should not be spit when using device_map. We iterate through the modules to + get the underlying `_no_split_modules`. + + Args: + device_map (`str`): + The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"] + + Returns: + `List[str]`: List of modules that should not be split + """ + _no_split_modules = set() + modules_to_check = [self] + while len(modules_to_check) > 0: + module = modules_to_check.pop(-1) + # if the module does not appear in _no_split_modules, we also check the children + if module.__class__.__name__ not in _no_split_modules: + if isinstance(module, PreTrainedModel): + if module._no_split_modules is None: + raise ValueError( + f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model " + "class needs to implement the `_no_split_modules` attribute." + ) + else: + _no_split_modules = _no_split_modules | set(module._no_split_modules) + modules_to_check += list(module.children()) + return list(_no_split_modules) + + def resize_token_embeddings( + self, + new_num_tokens: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + mean_resizing: bool = True, + ) -> nn.Embedding: + """ + Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`. + + Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. + + Arguments: + new_num_tokens (`int`, *optional*): + The new number of tokens in the embedding matrix. Increasing the size will add newly initialized + vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just + returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything. + pad_to_multiple_of (`int`, *optional*): + If set will pad the embedding matrix to a multiple of the provided value.If `new_num_tokens` is set to + `None` will just pad the embedding to a multiple of `pad_to_multiple_of`. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more + details about this, or help on choosing the correct value for resizing, refer to this guide: + https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc + mean_resizing (`bool`): + Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and + covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`. + + Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models, + where the generated tokens' probabilities won't be affected by the added embeddings because initializing the new embeddings with the + old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings. + Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html + + Return: + `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model. + """ + model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing) + if new_num_tokens is None and pad_to_multiple_of is None: + return model_embeds + + # Since we are basically resuing the same old embeddings with new weight values, gathering is required + is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None + if is_deepspeed_zero3_enabled() and not is_quantized: + import deepspeed + + with deepspeed.zero.GatheredParameters(model_embeds.weight, modifier_rank=None): + vocab_size = model_embeds.weight.shape[0] + else: + vocab_size = model_embeds.weight.shape[0] + + # Update base model and current model config. + self.config.get_text_config().vocab_size = vocab_size + self.vocab_size = vocab_size + + # Tie weights again if needed + self.tie_weights() + + return model_embeds + + def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean_resizing=True): + old_embeddings = self.get_input_embeddings() + new_embeddings = self._get_resized_embeddings( + old_embeddings, new_num_tokens, pad_to_multiple_of, mean_resizing + ) + if hasattr(old_embeddings, "_hf_hook"): + hook = old_embeddings._hf_hook + add_hook_to_module(new_embeddings, hook) + old_embeddings_requires_grad = old_embeddings.weight.requires_grad + new_embeddings.requires_grad_(old_embeddings_requires_grad) + self.set_input_embeddings(new_embeddings) + is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None + + # Update new_num_tokens with the actual size of new_embeddings + if pad_to_multiple_of is not None: + if is_deepspeed_zero3_enabled() and not is_quantized: + import deepspeed + + with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None): + new_num_tokens = new_embeddings.weight.shape[0] + else: + new_num_tokens = new_embeddings.weight.shape[0] + + # if word embeddings are not tied, make sure that lm head is resized as well + if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings: + old_lm_head = self.get_output_embeddings() + if isinstance(old_lm_head, torch.nn.Embedding): + new_lm_head = self._get_resized_embeddings(old_lm_head, new_num_tokens, mean_resizing=mean_resizing) + else: + new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens, mean_resizing=mean_resizing) + if hasattr(old_lm_head, "_hf_hook"): + hook = old_lm_head._hf_hook + add_hook_to_module(new_lm_head, hook) + old_lm_head_requires_grad = old_lm_head.weight.requires_grad + new_lm_head.requires_grad_(old_lm_head_requires_grad) + self.set_output_embeddings(new_lm_head) + + return self.get_input_embeddings() + + def _get_resized_embeddings( + self, + old_embeddings: nn.Embedding, + new_num_tokens: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + mean_resizing: bool = True, + ) -> nn.Embedding: + """ + Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly + initialized vectors at the end. Reducing the size will remove vectors from the end + + Args: + old_embeddings (`torch.nn.Embedding`): + Old embeddings to be resized. + new_num_tokens (`int`, *optional*): + New number of tokens in the embedding matrix. + + Increasing the size will add newly initialized vectors at the end. Reducing the size will remove + vectors from the end. If not provided or `None`, just returns a pointer to the input tokens + `torch.nn.Embedding` module of the model without doing anything. + pad_to_multiple_of (`int`, *optional*): + If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to + `None` will just pad the embedding to a multiple of `pad_to_multiple_of`. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more + details about this, or help on choosing the correct value for resizing, refer to this guide: + https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc + mean_resizing (`bool`): + Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and + covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`. + + Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models, + where the generated tokens' probabilities will not be affected by the added embeddings because initializing the new embeddings with the + old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings. + Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html + + + Return: + `torch.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if + `new_num_tokens` is `None` + """ + + if pad_to_multiple_of is not None: + if not isinstance(pad_to_multiple_of, int): + raise ValueError( + f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, which is not and integer. Please make sure to pass an integer" + ) + if new_num_tokens is None: + new_num_tokens = old_embeddings.weight.shape[0] + new_num_tokens = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of + else: + logger.info( + "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding" + f" dimension will be {new_num_tokens}. This might induce some performance reduction as *Tensor Cores* will not be available." + " For more details about this, or help on choosing the correct value for resizing, refer to this guide:" + " https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc" + ) + + if new_num_tokens is None: + return old_embeddings + + is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None + if is_deepspeed_zero3_enabled() and not is_quantized: + import deepspeed + + with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None): + old_num_tokens, old_embedding_dim = old_embeddings.weight.size() + else: + old_num_tokens, old_embedding_dim = old_embeddings.weight.size() + + if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled(): + return old_embeddings + + if not isinstance(old_embeddings, nn.Embedding): + raise TypeError( + f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}. You" + " should either use a different resize function or make sure that `old_embeddings` are an instance of" + f" {nn.Embedding}." + ) + + # Build new embeddings + + # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init + # because the shape of the new embedding layer is used across various modeling files + # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading + # to errors when training. + new_embeddings = nn.Embedding( + new_num_tokens, + old_embedding_dim, + device=old_embeddings.weight.device, + dtype=old_embeddings.weight.dtype, + ) + + if new_num_tokens > old_num_tokens and not mean_resizing: + # initialize new embeddings (in particular added tokens) with a mean of 0 and std equals `config.initializer_range`. + self._init_weights(new_embeddings) + + elif new_num_tokens > old_num_tokens and mean_resizing: + # initialize new embeddings (in particular added tokens). The new embeddings will be initialized + # from a multivariate normal distribution that has old embeddings' mean and covariance. + # as described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html + logger.warning_once( + "The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. " + "As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. " + "To disable this, use `mean_resizing=False`" + ) + + added_num_tokens = new_num_tokens - old_num_tokens + if is_deepspeed_zero3_enabled() and not is_quantized: + import deepspeed + + with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None): + self._init_added_embeddings_weights_with_mean( + old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens + ) + else: + self._init_added_embeddings_weights_with_mean( + old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens + ) + + # Copy token embeddings from the previous weights + + # numbers of tokens to copy + n = min(old_num_tokens, new_num_tokens) + + if is_deepspeed_zero3_enabled() and not is_quantized: + import deepspeed + + params = [old_embeddings.weight, new_embeddings.weight] + with deepspeed.zero.GatheredParameters(params, modifier_rank=0): + new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :] + else: + new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :] + + # Replace weights in old_embeddings and return to maintain the same embedding type. + # This ensures correct functionality when a Custom Embedding class is passed as input. + # The input and output embedding types remain consistent. (c.f. https://github.com/huggingface/transformers/pull/31979) + if is_deepspeed_zero3_enabled() and not is_quantized: + import deepspeed + + params = [old_embeddings.weight, new_embeddings.weight] + with deepspeed.zero.GatheredParameters(params, modifier_rank=0): + old_embeddings.weight = new_embeddings.weight + old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0] + + # If the new number of tokens is smaller than the original `padding_idx`, the `padding_idx` + # will be set to `None` in the resized embeddings. + if old_embeddings.padding_idx is not None and (new_num_tokens - 1) < old_embeddings.padding_idx: + old_embeddings.padding_idx = None + else: + old_embeddings.weight.data = new_embeddings.weight.data + old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0] + if old_embeddings.padding_idx is not None and (new_num_tokens - 1) < old_embeddings.padding_idx: + old_embeddings.padding_idx = None + + return old_embeddings + + def _get_resized_lm_head( + self, + old_lm_head: nn.Linear, + new_num_tokens: Optional[int] = None, + transposed: Optional[bool] = False, + mean_resizing: bool = True, + ) -> nn.Linear: + """ + Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized + vectors at the end. Reducing the size will remove vectors from the end + + Args: + old_lm_head (`torch.nn.Linear`): + Old lm head liner layer to be resized. + new_num_tokens (`int`, *optional*): + New number of tokens in the linear matrix. + + Increasing the size will add newly initialized vectors at the end. Reducing the size will remove + vectors from the end. If not provided or `None`, just returns a pointer to the input tokens + `torch.nn.Linear` module of the model without doing anything. transposed (`bool`, *optional*, defaults + to `False`): Whether `old_lm_head` is transposed or not. If True `old_lm_head.size()` is `lm_head_dim, + vocab_size` else `vocab_size, lm_head_dim`. + mean_resizing (`bool`): + Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and + covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`. + + Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models, + where the generated tokens' probabilities will not be affected by the added embeddings because initializing the new embeddings with the + old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings. + Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html + + Return: + `torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens` is + `None` + """ + if new_num_tokens is None: + return old_lm_head + + is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None + if is_deepspeed_zero3_enabled() and not is_quantized: + import deepspeed + + with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None): + old_num_tokens, old_lm_head_dim = ( + old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size() + ) + else: + old_num_tokens, old_lm_head_dim = ( + old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size() + ) + + if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled(): + return old_lm_head + + if not isinstance(old_lm_head, nn.Linear): + raise TypeError( + f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Linear}. You" + " should either use a different resize function or make sure that `old_lm_head` are an instance of" + f" {nn.Linear}." + ) + + # Build new lm head + new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim) + has_new_lm_head_bias = old_lm_head.bias is not None + + # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init + # because the shape of the new embedding layer is used across various modeling files + # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading + # to errors when training. + new_lm_head = nn.Linear( + *new_lm_head_shape, + bias=has_new_lm_head_bias, + device=old_lm_head.weight.device, + dtype=old_lm_head.weight.dtype, + ) + + if new_num_tokens > old_num_tokens and not mean_resizing: + # initialize new embeddings (in particular added tokens) with a mean of 0 and std equals `config.initializer_range`. + self._init_weights(new_lm_head) + + elif new_num_tokens > old_num_tokens and mean_resizing: + # initialize new lm_head weights (in particular added tokens). The new lm_head weights + # will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. + # as described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html + logger.warning_once( + "The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. " + "As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. " + "To disable this, use `mean_resizing=False`" + ) + + added_num_tokens = new_num_tokens - old_num_tokens + if is_deepspeed_zero3_enabled() and not is_quantized: + import deepspeed + + params = [old_lm_head.weight] + if has_new_lm_head_bias: + params += [old_lm_head.bias] + with deepspeed.zero.GatheredParameters(params, modifier_rank=None): + self._init_added_lm_head_weights_with_mean( + old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens, transposed + ) + if has_new_lm_head_bias: + self._init_added_lm_head_bias_with_mean(old_lm_head, new_lm_head, added_num_tokens) + + else: + self._init_added_lm_head_weights_with_mean( + old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens, transposed + ) + if has_new_lm_head_bias: + self._init_added_lm_head_bias_with_mean(old_lm_head, new_lm_head, added_num_tokens) + + num_tokens_to_copy = min(old_num_tokens, new_num_tokens) + + if is_deepspeed_zero3_enabled() and not is_quantized: + import deepspeed + + params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias] + with deepspeed.zero.GatheredParameters(params, modifier_rank=0): + self._copy_lm_head_original_to_resized( + new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias + ) + else: + self._copy_lm_head_original_to_resized( + new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias + ) + + return new_lm_head + + def _init_added_embeddings_weights_with_mean( + self, old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens + ): + old_embeddings_weight = old_embeddings.weight.data.to(torch.float32) + mean_embeddings = torch.mean(old_embeddings_weight, axis=0) + old_centered_embeddings = old_embeddings_weight - mean_embeddings + covariance = old_centered_embeddings.T @ old_centered_embeddings / old_num_tokens + + # Check if the covariance is positive definite. + eigenvalues = torch.linalg.eigvals(covariance) + is_covariance_psd = bool( + (covariance == covariance.T).all() and not torch.is_complex(eigenvalues) and (eigenvalues > 0).all() + ) + if is_covariance_psd: + # If covariances is positive definite, a distribution can be created. and we can sample new weights from it. + distribution = torch.distributions.multivariate_normal.MultivariateNormal( + mean_embeddings, covariance_matrix=1e-9 * covariance + ) + new_embeddings.weight.data[-1 * added_num_tokens :, :] = distribution.sample( + sample_shape=(added_num_tokens,) + ).to(old_embeddings.weight.dtype) + else: + # Otherwise, just initialize with the mean. because distribtion will not be created. + new_embeddings.weight.data[-1 * added_num_tokens :, :] = ( + mean_embeddings[None, :].repeat(added_num_tokens, 1).to(old_embeddings.weight.dtype) + ) + + def _init_added_lm_head_weights_with_mean( + self, + old_lm_head, + new_lm_head, + old_lm_head_dim, + old_num_tokens, + added_num_tokens, + transposed=False, + ): + if transposed: + # Transpose to the desired shape for the function. + new_lm_head.weight.data = new_lm_head.weight.data.T + old_lm_head.weight.data = old_lm_head.weight.data.T + + # The same initilization logic as Embeddings. + self._init_added_embeddings_weights_with_mean( + old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens + ) + + if transposed: + # Transpose again to the correct shape. + new_lm_head.weight.data = new_lm_head.weight.data.T + old_lm_head.weight.data = old_lm_head.weight.data.T + + def _init_added_lm_head_bias_with_mean(self, old_lm_head, new_lm_head, added_num_tokens): + bias_mean = torch.mean(old_lm_head.bias.data, axis=0, dtype=torch.float32) + bias_std = torch.std(old_lm_head.bias.data, axis=0).to(torch.float32) + new_lm_head.bias.data[-1 * added_num_tokens :].normal_(mean=bias_mean, std=1e-9 * bias_std) + + def _copy_lm_head_original_to_resized( + self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias + ): + # Copy old lm head weights to new lm head + if not transposed: + new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :] + else: + new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy] + + # Copy bias weights to new lm head + if has_new_lm_head_bias: + new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy] + + def resize_position_embeddings(self, new_num_position_embeddings: int): + raise NotImplementedError( + f"`resize_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should " + f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`" + ) + + def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]: + raise NotImplementedError( + f"`get_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should " + f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`" + ) + + def init_weights(self): + """ + If needed prunes and maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any + initialization logic in `_init_weights`. + """ + # Prune heads if needed + if self.config.pruned_heads: + self.prune_heads(self.config.pruned_heads) + + if _init_weights: + # Initialize weights + self.apply(self._initialize_weights) + + # Tie weights should be skipped when not initializing all weights + # since from_pretrained(...) calls tie weights anyways + self.tie_weights() + + def prune_heads(self, heads_to_prune: Dict[int, List[int]]): + """ + Prunes heads of the base model. + + Arguments: + heads_to_prune (`Dict[int, List[int]]`): + Dictionary with keys being selected layer indices (`int`) and associated values being the list of heads + to prune in said layer (list of `int`). For instance {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on + layer 1 and heads 2 and 3 on layer 2. + """ + # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads + for layer, heads in heads_to_prune.items(): + union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads) + self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON + + self.base_model._prune_heads(heads_to_prune) + + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): + """ + Activates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + + We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of + the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 + + Args: + gradient_checkpointing_kwargs (dict, *optional*): + Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function. + """ + if not self.supports_gradient_checkpointing: + raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") + + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {"use_reentrant": True} + + gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs) + + # For old GC format (transformers < 4.35.0) for models that live on the Hub + # we will fall back to the overwritten `_set_gradient_checkpointing` method + _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters + + if not _is_using_old_format: + self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) + else: + self.apply(partial(self._set_gradient_checkpointing, value=True)) + logger.warning( + "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)." + "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model." + ) + + if getattr(self, "_hf_peft_config_loaded", False): + # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True + # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334 + # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate + # the gradients to make sure the gradient flows. + self.enable_input_require_grads() + + def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint): + is_gradient_checkpointing_set = False + + # Apply it on the top-level module in case the top-level modules supports it + # for example, LongT5Stack inherits from `PreTrainedModel`. + if hasattr(self, "gradient_checkpointing"): + self._gradient_checkpointing_func = gradient_checkpointing_func + self.gradient_checkpointing = enable + is_gradient_checkpointing_set = True + + for module in self.modules(): + if hasattr(module, "gradient_checkpointing"): + module._gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = enable + is_gradient_checkpointing_set = True + + if not is_gradient_checkpointing_set: + raise ValueError( + f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute" + " `gradient_checkpointing` to modules of the model that uses checkpointing." + ) + + def gradient_checkpointing_disable(self): + """ + Deactivates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + if self.supports_gradient_checkpointing: + # For old GC format (transformers < 4.35.0) for models that live on the Hub + # we will fall back to the overwritten `_set_gradient_checkpointing` methid + _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters + if not _is_using_old_format: + self._set_gradient_checkpointing(enable=False) + else: + logger.warning( + "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)." + "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model." + ) + self.apply(partial(self._set_gradient_checkpointing, value=False)) + + if getattr(self, "_hf_peft_config_loaded", False): + self.disable_input_require_grads() + + @property + def is_gradient_checkpointing(self) -> bool: + """ + Whether gradient checkpointing is activated for this model or not. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + state_dict: Optional[dict] = None, + save_function: Callable = torch.save, + push_to_hub: bool = False, + max_shard_size: Union[int, str] = "5GB", + safe_serialization: bool = True, + variant: Optional[str] = None, + token: Optional[Union[str, bool]] = None, + save_peft_format: bool = True, + **kwargs, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + [`~PreTrainedModel.from_pretrained`] class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. + state_dict (nested dictionary of `torch.Tensor`): + The state dictionary of the model to save. Will default to `self.state_dict()`, but can be used to only + save parts of the model or if special precautions need to be taken when recovering the state dictionary + of a model (like when using model parallelism). + save_function (`Callable`): + The function to use to save the state dictionary. Useful on distributed training like TPUs when one + need to replace `torch.save` by another method. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + max_shard_size (`int` or `str`, *optional*, defaults to `"5GB"`): + The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size + lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). + We default it to 5GB in order for models to be able to run easily on free-tier google colab instances + without CPU OOM issues. + + + + If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard + which will be bigger than `max_shard_size`. + + + + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). + variant (`str`, *optional*): + If specified, weights are saved in the format pytorch_model..bin. + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use + the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). + save_peft_format (`bool`, *optional*, defaults to `True`): + For backward compatibility with PEFT library, in case adapter weights are attached to the model, all + keys of the state dict of adapters needs to be pre-pended with `base_model.model`. Advanced users can + disable this behaviours by setting `save_peft_format` to `False`. + kwargs (`Dict[str, Any]`, *optional*): + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + use_auth_token = kwargs.pop("use_auth_token", None) + ignore_metadata_errors = kwargs.pop("ignore_metadata_errors", False) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if token is not None: + kwargs["token"] = token + + _hf_peft_config_loaded = getattr(self, "_hf_peft_config_loaded", False) + + hf_quantizer = getattr(self, "hf_quantizer", None) + quantization_serializable = ( + hf_quantizer is not None + and isinstance(hf_quantizer, HfQuantizer) + and hf_quantizer.is_serializable(safe_serialization=safe_serialization) + ) + + if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable: + raise ValueError( + f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from" + " the logger on the traceback to understand the reason why the quantized model is not serializable." + ) + + if "save_config" in kwargs: + warnings.warn( + "`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead." + ) + is_main_process = kwargs.pop("save_config") + if safe_serialization and not is_safetensors_available(): + raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.") + + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = self._create_repo(repo_id, **kwargs) + files_timestamps = self._get_files_timestamps(save_directory) + + # Only save the model itself if we are using distributed training + model_to_save = unwrap_model(self) + + # save the string version of dtype to the config, e.g. convert torch.float32 => "float32" + # we currently don't use this setting automatically, but may start to use with v5 + dtype = get_parameter_dtype(model_to_save) + model_to_save.config.torch_dtype = str(dtype).split(".")[1] + + # Attach architecture to the config + model_to_save.config.architectures = [model_to_save.__class__.__name__] + + # Unset attn implementation so it can be set to another one when loading back + model_to_save.config._attn_implementation_autoset = False + + # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be + # loaded from the Hub. + if self._auto_class is not None: + custom_object_save(self, save_directory, config=self.config) + + # Save the config + if is_main_process: + if not _hf_peft_config_loaded: + # If the model config has set attributes that should be in the generation config, move them there. + misplaced_generation_parameters = model_to_save.config._get_non_default_generation_parameters() + if self.can_generate() and len(misplaced_generation_parameters) > 0: + warnings.warn( + "Moving the following attributes in the config to the generation config: " + f"{misplaced_generation_parameters}. You are seeing this warning because you've set " + "generation parameters in the model config, as opposed to in the generation config.", + UserWarning, + ) + for param_name, param_value in misplaced_generation_parameters.items(): + setattr(model_to_save.generation_config, param_name, param_value) + setattr(model_to_save.config, param_name, None) + + model_to_save.config.save_pretrained(save_directory) + if self.can_generate(): + model_to_save.generation_config.save_pretrained(save_directory) + + if _hf_peft_config_loaded: + logger.info( + "Detected adapters on the model, saving the model in the PEFT format, only adapter weights will be saved." + ) + state_dict = model_to_save.get_adapter_state_dict() + + if save_peft_format: + logger.info( + "To match the expected format of the PEFT library, all keys of the state dict of adapters will be pre-pended with `base_model.model`." + ) + peft_state_dict = {} + for key, value in state_dict.items(): + peft_state_dict[f"base_model.model.{key}"] = value + state_dict = peft_state_dict + + active_adapter = self.active_adapters() + + if len(active_adapter) > 1: + raise ValueError( + "Multiple active adapters detected, saving multiple active adapters is not supported yet. You can save adapters separately one by one " + "by iteratively calling `model.set_adapter(adapter_name)` then `model.save_pretrained(...)`" + ) + active_adapter = active_adapter[0] + + current_peft_config = self.peft_config[active_adapter] + current_peft_config.save_pretrained(save_directory) + + # for offloaded modules + module_map = {} + + # Save the model + if state_dict is None: + # if any model parameters are offloaded, make module map + if ( + hasattr(self, "hf_device_map") + and len(set(self.hf_device_map.values())) > 1 + and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values()) + ): + warnings.warn( + "Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory exceeds the `shard_size` (5GB default)" + ) + for name, module in model_to_save.named_modules(): + if name == "": + continue + module_state_dict = module.state_dict() + + for key in module_state_dict: + module_map[name + f".{key}"] = module + state_dict = model_to_save.state_dict() + + # Translate state_dict from smp to hf if saving with smp >= 1.10 + if IS_SAGEMAKER_MP_POST_1_10: + for smp_to_hf, _ in smp.state.module_manager.translate_functions: + state_dict = smp_to_hf(state_dict) + + # Handle the case where some state_dict keys shouldn't be saved + if self._keys_to_ignore_on_save is not None: + for ignore_key in self._keys_to_ignore_on_save: + if ignore_key in state_dict.keys(): + del state_dict[ignore_key] + + # Rename state_dict keys before saving to file. Do nothing unless overriden in a particular model. + # (initially introduced with TimmWrapperModel to remove prefix and make checkpoints compatible with timm) + state_dict = self._fix_state_dict_keys_on_save(state_dict) + + if safe_serialization: + # Safetensors does not allow tensor aliasing. + # We're going to remove aliases before saving + ptrs = collections.defaultdict(list) + for name, tensor in state_dict.items(): + # Sometimes in the state_dict we have non-tensor objects. + # e.g. in bitsandbytes we have some `str` objects in the state_dict + if isinstance(tensor, torch.Tensor): + ptrs[id_tensor_storage(tensor)].append(name) + else: + # In the non-tensor case, fall back to the pointer of the object itself + ptrs[id(tensor)].append(name) + + # These are all the pointers of shared tensors + if hasattr(self, "hf_device_map"): + # if the model has offloaded parameters, we must check using find_tied_parameters() + tied_params = find_tied_parameters(self) + if tied_params: + tied_names = tied_params[0] + shared_ptrs = { + ptr: names for ptr, names in ptrs.items() if any(name in tied_names for name in names) + } + else: + shared_ptrs = {} + else: + shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1} + + # Recursively descend to find tied weight keys + _tied_weights_keys = _get_tied_weight_keys(self) + error_names = [] + to_delete_names = set() + for names in shared_ptrs.values(): + # Removing the keys which are declared as known duplicates on + # load. This allows to make sure the name which is kept is consistent. + if _tied_weights_keys is not None: + found = 0 + for name in sorted(names): + matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys) + if matches_pattern and name in state_dict: + found += 1 + if found < len(names): + to_delete_names.add(name) + # We are entering a place where the weights and the transformers configuration do NOT match. + shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict) + # Those are actually tensor sharing but disjoint from each other, we can safely clone them + # Reloaded won't have the same property, but it shouldn't matter in any meaningful way. + for name in disjoint_names: + state_dict[name] = state_dict[name].clone() + + # When not all duplicates have been cleaned, still remove those keys, but put a clear warning. + # If the link between tensors was done at runtime then `from_pretrained` will not get + # the key back leading to random tensor. A proper warning will be shown + # during reload (if applicable), but since the file is not necessarily compatible with + # the config, better show a proper warning. + shared_names, identical_names = _find_identical(shared_names, state_dict) + # delete tensors that have identical storage + for inames in identical_names: + known = inames.intersection(to_delete_names) + for name in known: + del state_dict[name] + unknown = inames.difference(to_delete_names) + if len(unknown) > 1: + error_names.append(unknown) + + if shared_names: + error_names.append(set(shared_names)) + + if len(error_names) > 0: + raise RuntimeError( + f"The weights trying to be saved contained shared tensors {error_names} that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.", + ) + + # Shard the model if it is too big. + if not _hf_peft_config_loaded: + weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME + weights_name = _add_variant(weights_name, variant) + else: + weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME + + filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors") + state_dict_split = split_torch_state_dict_into_shards( + state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size + ) + # Save index if sharded + index = None + if state_dict_split.is_sharded: + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + + # Clean the folder from a previous save + for filename in os.listdir(save_directory): + full_filename = os.path.join(save_directory, filename) + # If we have a shard file that is not going to be replaced, we delete it, but only from the main process + # in distributed settings to avoid race conditions. + weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "") + + # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005 + filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "") + reg = re.compile(r"(.*?)-\d{5}-of-\d{5}") + + if ( + filename.startswith(weights_no_suffix) + and os.path.isfile(full_filename) + and filename not in state_dict_split.filename_to_tensors.keys() + and is_main_process + and reg.fullmatch(filename_no_suffix) is not None + ): + os.remove(full_filename) + # Save the model + filename_to_tensors = state_dict_split.filename_to_tensors.items() + if module_map: + filename_to_tensors = logging.tqdm(filename_to_tensors, desc="Saving checkpoint shards") + for shard_file, tensors in filename_to_tensors: + shard = {} + for tensor in tensors: + shard[tensor] = state_dict[tensor].contiguous() + # delete reference, see https://github.com/huggingface/transformers/pull/34890 + del state_dict[tensor] + + # remake shard with onloaded parameters if necessary + if module_map: + if accelerate_version < version.parse("0.31"): + raise ImportError( + f"You need accelerate version to be greater or equal than 0.31 to save models with offloaded parameters. Detected version {accelerate_version}. " + f"Please upgrade accelerate with `pip install -U accelerate`" + ) + # init state_dict for this shard + shard_state_dict = {name: "" for name in shard} + for module_name in shard: + module = module_map[module_name] + # update state dict with onloaded parameters + shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict) + + # assign shard to be the completed state dict + shard = shard_state_dict + del shard_state_dict + gc.collect() + + if safe_serialization: + # At some point we will need to deal better with save_function (used for TPU and other distributed + # joyfulness), but for now this enough. + safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"}) + else: + save_function(shard, os.path.join(save_directory, shard_file)) + + del state_dict + + if index is None: + path_to_weights = os.path.join(save_directory, weights_name) + logger.info(f"Model weights saved in {path_to_weights}") + else: + save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME + save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant)) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + if push_to_hub: + # Eventually create an empty model card + model_card = create_and_tag_model_card( + repo_id, self.model_tags, token=token, ignore_metadata_errors=ignore_metadata_errors + ) + + # Update model card if needed: + model_card.save(os.path.join(save_directory, "README.md")) + + self._upload_modified_files( + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=token, + ) + + @wraps(PushToHubMixin.push_to_hub) + def push_to_hub(self, *args, **kwargs): + tags = self.model_tags if self.model_tags is not None else [] + + tags_kwargs = kwargs.get("tags", []) + if isinstance(tags_kwargs, str): + tags_kwargs = [tags_kwargs] + + for tag in tags_kwargs: + if tag not in tags: + tags.append(tag) + + if tags: + kwargs["tags"] = tags + return super().push_to_hub(*args, **kwargs) + + def get_memory_footprint(self, return_buffers=True): + r""" + Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. + Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the + PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2 + + Arguments: + return_buffers (`bool`, *optional*, defaults to `True`): + Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers + are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch + norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2 + """ + mem = sum([param.nelement() * param.element_size() for param in self.parameters()]) + if return_buffers: + mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()]) + mem = mem + mem_bufs + return mem + + @wraps(torch.nn.Module.cuda) + def cuda(self, *args, **kwargs): + if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ: + raise ValueError("`.cuda` is not supported for HQQ-quantized models.") + # Checks if the model has been loaded in 4-bit or 8-bit with BNB + if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: + if getattr(self, "is_loaded_in_8bit", False): + raise ValueError( + "Calling `cuda()` is not supported for `8-bit` quantized models. " + " Please use the model as it is, since the model has already been set to the correct devices." + ) + elif version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"): + raise ValueError( + "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " + f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." + ) + else: + return super().cuda(*args, **kwargs) + + @wraps(torch.nn.Module.to) + def to(self, *args, **kwargs): + # For BNB/GPTQ models, we prevent users from casting the model to another dtype to restrict unwanted behaviours. + # the correct API should be to load the model with the desired dtype directly through `from_pretrained`. + dtype_present_in_args = "dtype" in kwargs + + if not dtype_present_in_args: + for arg in args: + if isinstance(arg, torch.dtype): + dtype_present_in_args = True + break + + if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ: + raise ValueError("`.to` is not supported for HQQ-quantized models.") + # Checks if the model has been loaded in 4-bit or 8-bit with BNB + if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: + if dtype_present_in_args: + raise ValueError( + "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the" + " desired `dtype` by passing the correct `torch_dtype` argument." + ) + + if getattr(self, "is_loaded_in_8bit", False): + raise ValueError( + "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the" + " model has already been set to the correct devices and casted to the correct `dtype`." + ) + elif version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"): + raise ValueError( + "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " + f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." + ) + elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ: + if dtype_present_in_args: + raise ValueError( + "You cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired" + " `dtype` by passing the correct `torch_dtype` argument." + ) + return super().to(*args, **kwargs) + + def half(self, *args): + # Checks if the model is quantized + if getattr(self, "is_quantized", False): + raise ValueError( + "`.half()` is not supported for quantized model. Please use the model as it is, since the" + " model has already been casted to the correct `dtype`." + ) + else: + return super().half(*args) + + def float(self, *args): + # Checks if the model is quantized + if getattr(self, "is_quantized", False): + raise ValueError( + "`.float()` is not supported for quantized model. Please use the model as it is, since the" + " model has already been casted to the correct `dtype`." + ) + else: + return super().float(*args) + + @classmethod + def from_pretrained( + cls: Type[SpecificPreTrainedModelType], + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *model_args, + config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, + cache_dir: Optional[Union[str, os.PathLike]] = None, + ignore_mismatched_sizes: bool = False, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + use_safetensors: Optional[bool] = None, + weights_only: bool = True, + **kwargs, + ) -> SpecificPreTrainedModelType: + r""" + Instantiate a pretrained pytorch model from a pre-trained model configuration. + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train + the model, you should first set it back in training mode with `model.train()`. + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + If model weights are the same precision as the base model (and is a supported model), weights will be lazily loaded + in using the `meta` device and brought into memory once an input is passed through that layer regardless of + `low_cpu_mem_usage`. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In + this case, `from_tf` should be set to `True` and a configuration object should be provided as + `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a + PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + - A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g, + `./flax_model/` containing `flax_model.msgpack`). In this case, `from_flax` should be set to + `True`. + - `None` if you are both providing the configuration and state dictionary (resp. with keyword + arguments `config` and `state_dict`). + model_args (sequence of positional arguments, *optional*): + All remaining positional arguments will be passed to the underlying model's `__init__` method. + config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*): + Can be either: + + - an instance of a class derived from [`PretrainedConfig`], + - a string or path valid as input to [`~PretrainedConfig.from_pretrained`]. + + Configuration for the model to use instead of an automatically loaded configuration. Configuration can + be automatically loaded when: + + - The model is a model provided by the library (loaded with the *model id* string of a pretrained + model). + - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the + save directory. + - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a + configuration JSON file named *config.json* is found in the directory. + state_dict (`Dict[str, torch.Tensor]`, *optional*): + A state dictionary to use instead of a state dictionary loaded from saved weights file. + + This option can be used if you want to create a model from a pretrained configuration but load your own + weights. In this case though, you should check if using [`~PreTrainedModel.save_pretrained`] and + [`~PreTrainedModel.from_pretrained`] is not a simpler option. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + from_tf (`bool`, *optional*, defaults to `False`): + Load the model weights from a TensorFlow checkpoint save file (see docstring of + `pretrained_model_name_or_path` argument). + from_flax (`bool`, *optional*, defaults to `False`): + Load the model weights from a Flax checkpoint save file (see docstring of + `pretrained_model_name_or_path` argument). + ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): + Whether or not to raise an error if some of the weights from the checkpoint do not have the same size + as the weights of the model (if for instance, you are instantiating a model with 10 labels from a + checkpoint with 3 labels). + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use + the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`. + + + + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + _fast_init(`bool`, *optional*, defaults to `True`): + Whether or not to disable fast initialization. + + + + One should only disable *_fast_init* to ensure backwards compatibility with `transformers.__version__ < + 4.6.0` for seeded model initialization. This argument will be removed at the next major version. See + [pull request 11471](https://github.com/huggingface/transformers/pull/11471) for more information. + + + attn_implementation (`str`, *optional*): + The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation. + + > Parameters for big model inference + + low_cpu_mem_usage(`bool`, *optional*): + Tries not to use more than 1x model size in CPU memory (including peak memory) while loading the model. + Generally should be combined with a `device_map` (such as `"auto"`) for best results. + This is an experimental feature and a subject to change at any moment. + + If the model weights are in the same precision as the model loaded in, `low_cpu_mem_usage` (without + `device_map`) is redundant and will not provide any benefit in regards to CPU memory usage. However, + this should still be enabled if you are passing in a `device_map`. + + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under a specific `dtype`. The different options + are: + + 1. `torch.float16` or `torch.bfloat16` or `torch.float`: load in a specified + `dtype`, ignoring the model's `config.torch_dtype` if one exists. If not specified + - the model will get loaded in `torch.float` (fp32). + + 2. `"auto"` - A `torch_dtype` entry in the `config.json` file of the model will be + attempted to be used. If this entry isn't found then next check the `dtype` of the first weight in + the checkpoint that's of a floating point type and use that as `dtype`. This will load the model + using the `dtype` it was saved in at the end of the training. It can't be used as an indicator of how + the model was trained. Since it could be trained in one of half precision dtypes, but saved in fp32. + + 3. A string that is a valid `torch.dtype`. E.g. "float32" loads the model in `torch.float32`, "float16" loads in `torch.float16` etc. + + + + For some models the `dtype` they were trained in is unknown - you may try to check the model's paper or + reach out to the authors and ask them to add this information to the model's card and to insert the + `torch_dtype` entry in `config.json` on the hub. + + + + device_map (`str` or `Dict[str, Union[int, str, torch.device]]` or `int` or `torch.device`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the + same device. If we only pass the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank + like `1`) on which the model will be allocated, the device map will map the entire model to this + device. Passing `device_map = 0` means put the whole model on GPU 0. + + To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier to maximum memory. Will default to the maximum memory available for each + GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + If the `device_map` contains any value `"disk"`, the folder where we will offload weights. + offload_state_dict (`bool`, *optional*): + If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU + RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to + `True` when there is some disk offload. + offload_buffers (`bool`, *optional*): + Whether or not to offload the buffers with the model parameters. + quantization_config (`Union[QuantizationConfigMixin,Dict]`, *optional*): + A dictionary of configuration parameters or a QuantizationConfigMixin object for quantization (e.g + bitsandbytes, gptq). There may be other quantization-related kwargs, including `load_in_4bit` and + `load_in_8bit`, which are parsed by QuantizationConfigParser. Supported only for bitsandbytes + quantizations and not preferred. consider inserting all such arguments into quantization_config + instead. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can + specify the folder name here. + variant (`str`, *optional*): + If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is + ignored when using `from_tf` or `from_flax`. + use_safetensors (`bool`, *optional*, defaults to `None`): + Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors` + is not installed, it will be set to `False`. + + weights_only (`bool`, *optional*, defaults to `True`): + Indicates whether unpickler should be restricted to loading only tensors, primitive types, + dictionaries and any types added via torch.serialization.add_safe_globals(). + When set to False, we can load wrapper tensor subclass weights. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). Behaves differently depending on whether a `config` is provided or + automatically loaded: + + - If a configuration is provided with `config`, `**kwargs` will be directly passed to the + underlying model's `__init__` method (we assume all relevant updates to the configuration have + already been done) + - If a configuration is not provided, `kwargs` will be first passed to the configuration class + initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that + corresponds to a configuration attribute will be used to override said attribute with the + supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute + will be passed to the underlying model's `__init__` function. + + + + Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to + use this method in a firewalled environment. + + + + Examples: + + ```python + >>> from transformers import BertConfig, BertModel + + >>> # Download model and configuration from huggingface.co and cache. + >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased") + >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable). + >>> model = BertModel.from_pretrained("./test/saved_model/") + >>> # Update configuration during loading. + >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True) + >>> assert model.config.output_attentions == True + >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable). + >>> config = BertConfig.from_json_file("./tf_model/my_tf_model_config.json") + >>> model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config) + >>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower) + >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True) + ``` + + * `low_cpu_mem_usage` algorithm: + + This is an experimental function that loads the model using ~1x model size CPU memory + + Here is how it works: + + 1. save which state_dict keys we have + 2. drop state_dict before the model is created, since the latter takes 1x model size CPU memory + 3. after the model has been instantiated switch to the meta device all params/buffers that + are going to be replaced from the loaded state_dict + 4. load state_dict 2nd time + 5. replace the params/buffers from the state_dict + + Currently, it can't handle deepspeed ZeRO stage 3 and ignores loading errors + + """ + state_dict = kwargs.pop("state_dict", None) + from_tf = kwargs.pop("from_tf", False) + from_flax = kwargs.pop("from_flax", False) + resume_download = kwargs.pop("resume_download", None) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + use_auth_token = kwargs.pop("use_auth_token", None) + trust_remote_code = kwargs.pop("trust_remote_code", None) + _ = kwargs.pop("mirror", None) + from_pipeline = kwargs.pop("_from_pipeline", None) + from_auto_class = kwargs.pop("_from_auto", False) + _fast_init = kwargs.pop("_fast_init", True) + torch_dtype = kwargs.pop("torch_dtype", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", None) + device_map = kwargs.pop("device_map", None) + max_memory = kwargs.pop("max_memory", None) + offload_folder = kwargs.pop("offload_folder", None) + offload_state_dict = kwargs.pop("offload_state_dict", False) + offload_buffers = kwargs.pop("offload_buffers", False) + load_in_8bit = kwargs.pop("load_in_8bit", False) + load_in_4bit = kwargs.pop("load_in_4bit", False) + quantization_config = kwargs.pop("quantization_config", None) + subfolder = kwargs.pop("subfolder", "") + commit_hash = kwargs.pop("_commit_hash", None) + variant = kwargs.pop("variant", None) + adapter_kwargs = kwargs.pop("adapter_kwargs", {}) + adapter_name = kwargs.pop("adapter_name", "default") + use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) + generation_config = kwargs.pop("generation_config", None) + + gguf_file = kwargs.pop("gguf_file", None) + # Cache path to the GGUF file + gguf_path = None + + tp_plan = kwargs.pop("tp_plan", None) + if tp_plan is not None and tp_plan != "auto": + # TODO: we can relax this check when we support taking tp_plan from a json file, for example. + raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.") + + if is_fsdp_enabled(): + low_cpu_mem_usage = True + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if token is not None and adapter_kwargs is not None and "token" not in adapter_kwargs: + adapter_kwargs["token"] = token + + if use_safetensors is None and not is_safetensors_available(): + use_safetensors = False + if trust_remote_code is True: + logger.warning( + "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is" + " ignored." + ) + + if gguf_file is not None and not is_accelerate_available(): + raise ValueError("accelerate is required when loading a GGUF file `pip install accelerate`.") + + if commit_hash is None: + if not isinstance(config, PretrainedConfig): + # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible + resolved_config_file = cached_file( + pretrained_model_name_or_path, + CONFIG_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_gated_repo=False, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + ) + commit_hash = extract_commit_hash(resolved_config_file, commit_hash) + else: + commit_hash = getattr(config, "_commit_hash", None) + + if is_peft_available(): + _adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None) + + if _adapter_model_path is None: + _adapter_model_path = find_adapter_config_file( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + _commit_hash=commit_hash, + **adapter_kwargs, + ) + if _adapter_model_path is not None and os.path.isfile(_adapter_model_path): + with open(_adapter_model_path, "r", encoding="utf-8") as f: + _adapter_model_path = pretrained_model_name_or_path + pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"] + else: + _adapter_model_path = None + + # change device_map into a map if we passed an int, a str or a torch.device + if isinstance(device_map, torch.device): + device_map = {"": device_map} + elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]: + try: + device_map = {"": torch.device(device_map)} + except RuntimeError: + raise ValueError( + "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or " + f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}." + ) + elif isinstance(device_map, int): + if device_map < 0: + raise ValueError( + "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' " + ) + else: + device_map = {"": device_map} + + if device_map is not None: + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + elif not low_cpu_mem_usage: + raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`") + + if low_cpu_mem_usage: + if is_deepspeed_zero3_enabled(): + raise ValueError( + "DeepSpeed Zero-3 is not compatible with `low_cpu_mem_usage=True` or with passing a `device_map`." + ) + elif not is_accelerate_available(): + raise ImportError( + f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`" + ) + + # handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation. + if load_in_4bit or load_in_8bit: + if quantization_config is not None: + raise ValueError( + "You can't pass `load_in_4bit`or `load_in_8bit` as a kwarg when passing " + "`quantization_config` argument at the same time." + ) + + # preparing BitsAndBytesConfig from kwargs + config_dict = {k: v for k, v in kwargs.items() if k in inspect.signature(BitsAndBytesConfig).parameters} + config_dict = {**config_dict, "load_in_4bit": load_in_4bit, "load_in_8bit": load_in_8bit} + quantization_config, kwargs = BitsAndBytesConfig.from_dict( + config_dict=config_dict, return_unused_kwargs=True, **kwargs + ) + logger.warning( + "The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. " + "Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead." + ) + + from_pt = not (from_tf | from_flax) + + user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class} + if from_pipeline is not None: + user_agent["using_pipeline"] = from_pipeline + + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + # Load config if we don't provide a configuration + if not isinstance(config, PretrainedConfig): + config_path = config if config is not None else pretrained_model_name_or_path + config, model_kwargs = cls.config_class.from_pretrained( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + **kwargs, + ) + else: + # In case one passes a config to `from_pretrained` + "attn_implementation" + # override the `_attn_implementation` attribute to `attn_implementation` of the kwargs + # Please see: https://github.com/huggingface/transformers/issues/28038 + + # Overwrite `config._attn_implementation` by the one from the kwargs --> in auto-factory + # we pop attn_implementation from the kwargs but this handles the case where users + # passes manually the config to `from_pretrained`. + config = copy.deepcopy(config) + + kwarg_attn_imp = kwargs.pop("attn_implementation", None) + if kwarg_attn_imp is not None: + config._attn_implementation = kwarg_attn_imp + + model_kwargs = kwargs + + pre_quantized = getattr(config, "quantization_config", None) is not None + if pre_quantized or quantization_config is not None: + if pre_quantized: + config.quantization_config = AutoHfQuantizer.merge_quantization_configs( + config.quantization_config, quantization_config + ) + else: + config.quantization_config = quantization_config + + hf_quantizer = AutoHfQuantizer.from_config( + config.quantization_config, + pre_quantized=pre_quantized, + ) + + else: + hf_quantizer = None + + if hf_quantizer is not None: + hf_quantizer.validate_environment( + torch_dtype=torch_dtype, + from_tf=from_tf, + from_flax=from_flax, + device_map=device_map, + weights_only=weights_only, + ) + torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype) + device_map = hf_quantizer.update_device_map(device_map) + + # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry` + user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value + + # Force-set to `True` for more mem efficiency + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + logger.warning("`low_cpu_mem_usage` was None, now default to True since model is quantized.") + is_quantized = hf_quantizer is not None + + # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the + # index of the files. + is_sharded = False + sharded_metadata = None + # Load model + loading_info = None + + # Keep in fp32 modules + keep_in_fp32_modules = None + use_keep_in_fp32_modules = False + + if gguf_file is not None and hf_quantizer is not None: + raise ValueError( + "You cannot combine Quantization and loading a model from a GGUF file, try again by making sure you did not passed a `quantization_config` or that you did not load a quantized model from the Hub." + ) + + if pretrained_model_name_or_path is not None and gguf_file is None: + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + is_local = os.path.isdir(pretrained_model_name_or_path) + if is_local: + if from_tf and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index") + ): + # Load from a TF 1.0 checkpoint in priority if from_tf + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index") + elif from_tf and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME) + ): + # Load from a TF 2.0 checkpoint in priority if from_tf + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME) + elif from_flax and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) + ): + # Load from a Flax checkpoint in priority if from_flax + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) + elif use_safetensors is not False and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)) + ): + # Load from a safetensors checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant) + ) + elif use_safetensors is not False and os.path.isfile( + os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) + ) + ): + # Load from a sharded safetensors checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) + ) + is_sharded = True + elif not use_safetensors and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)) + ): + # Load from a PyTorch checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant) + ) + elif not use_safetensors and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)) + ): + # Load from a sharded PyTorch checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant) + ) + is_sharded = True + # At this stage we don't have a weight file so we will raise an error. + elif not use_safetensors and ( + os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")) + or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)) + ): + raise EnvironmentError( + f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory" + f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use" + " `from_tf=True` to load this model from those weights." + ) + elif not use_safetensors and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) + ): + raise EnvironmentError( + f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory" + f" {pretrained_model_name_or_path} but there is a file for Flax weights. Use `from_flax=True`" + " to load this model from those weights." + ) + elif use_safetensors: + raise EnvironmentError( + f"Error no file named {_add_variant(SAFE_WEIGHTS_NAME, variant)} found in directory" + f" {pretrained_model_name_or_path}." + ) + else: + raise EnvironmentError( + f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)}," + f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory" + f" {pretrained_model_name_or_path}." + ) + elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): + archive_file = pretrained_model_name_or_path + is_local = True + elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")): + if not from_tf: + raise ValueError( + f"We found a TensorFlow checkpoint at {pretrained_model_name_or_path + '.index'}, please set " + "from_tf to True to load from this checkpoint." + ) + archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index") + is_local = True + elif is_remote_url(pretrained_model_name_or_path): + filename = pretrained_model_name_or_path + resolved_archive_file = download_url(pretrained_model_name_or_path) + else: + # set correct filename + if from_tf: + filename = TF2_WEIGHTS_NAME + elif from_flax: + filename = FLAX_WEIGHTS_NAME + elif use_safetensors is not False: + filename = _add_variant(SAFE_WEIGHTS_NAME, variant) + else: + filename = _add_variant(WEIGHTS_NAME, variant) + + try: + # Load from URL or cache if already cached + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "token": token, + "user_agent": user_agent, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) + + # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None + # result when internet is up, the repo and revision exist, but the file does not. + if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, + _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + is_sharded = True + elif use_safetensors: + if revision == "main": + resolved_archive_file, revision, is_sharded = auto_conversion( + pretrained_model_name_or_path, **cached_file_kwargs + ) + cached_file_kwargs["revision"] = revision + if resolved_archive_file is None: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} " + "and thus cannot be loaded with `safetensors`. Please make sure that the model has " + "been saved with `safe_serialization=True` or do not set `use_safetensors=True`." + ) + else: + # This repo has no safetensors file of any kind, we switch to PyTorch. + filename = _add_variant(WEIGHTS_NAME, variant) + resolved_archive_file = cached_file( + pretrained_model_name_or_path, filename, **cached_file_kwargs + ) + if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, + _add_variant(WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + is_sharded = True + if not local_files_only and not is_offline_mode(): + if resolved_archive_file is not None: + if filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]: + # If the PyTorch file was found, check if there is a safetensors file on the repository + # If there is no safetensors file on the repositories, start an auto conversion + safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME + has_file_kwargs = { + "revision": revision, + "proxies": proxies, + "token": token, + "cache_dir": cache_dir, + "local_files_only": local_files_only, + } + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "resume_download": resume_download, + "local_files_only": local_files_only, + "user_agent": user_agent, + "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + **has_file_kwargs, + } + if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs): + Thread( + target=auto_conversion, + args=(pretrained_model_name_or_path,), + kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs}, + name="Thread-auto_conversion", + ).start() + else: + # Otherwise, no PyTorch file was found, maybe there is a TF or Flax model file. + # We try those to give a helpful error message. + has_file_kwargs = { + "revision": revision, + "proxies": proxies, + "token": token, + "cache_dir": cache_dir, + "local_files_only": local_files_only, + } + if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs): + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for TensorFlow weights." + " Use `from_tf=True` to load this model from those weights." + ) + elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs): + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for Flax weights. Use" + " `from_flax=True` to load this model from those weights." + ) + elif variant is not None and has_file( + pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs + ): + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant" + f" {variant}. Use `variant=None` to load this model from those weights." + ) + else: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)}," + f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}." + ) + + except EnvironmentError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted + # to the original exception. + raise + except Exception as e: + # For any other exception, we throw a generic error. + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" + " from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" + f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}," + f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}." + ) from e + + if is_local: + logger.info(f"loading weights file {archive_file}") + resolved_archive_file = archive_file + else: + logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") + elif gguf_file: + from .modeling_gguf_pytorch_utils import load_gguf_checkpoint + + # Case 1: the GGUF file is present locally + if os.path.isfile(gguf_file): + gguf_path = gguf_file + # Case 2: The GGUF path is a location on the Hub + # Load from URL or cache if already cached + else: + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "token": token, + "user_agent": user_agent, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + + gguf_path = cached_file(pretrained_model_name_or_path, gguf_file, **cached_file_kwargs) + + # we need a dummy model to help rename state_dict + with torch.device("meta"): + dummy_model = cls(config) + state_dict = load_gguf_checkpoint(gguf_path, return_tensors=True, model_to_load=dummy_model)["tensors"] + + resolved_archive_file = None + is_sharded = False + else: + resolved_archive_file = None + + # We'll need to download and cache each checkpoint shard if the checkpoint is sharded. + if is_sharded: + # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. + resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + pretrained_model_name_or_path, + resolved_archive_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _commit_hash=commit_hash, + ) + + if ( + is_safetensors_available() + and isinstance(resolved_archive_file, str) + and resolved_archive_file.endswith(".safetensors") + ): + with safe_open(resolved_archive_file, framework="pt") as f: + metadata = f.metadata() + + if metadata is None: + # Assume it's a pytorch checkpoint (introduced for timm checkpoints) + pass + elif metadata.get("format") == "pt": + pass + elif metadata.get("format") == "tf": + from_tf = True + logger.info("A TensorFlow safetensors file is being loaded in a PyTorch model.") + elif metadata.get("format") == "flax": + from_flax = True + logger.info("A Flax safetensors file is being loaded in a PyTorch model.") + elif metadata.get("format") == "mlx": + # This is a mlx file, we assume weights are compatible with pt + pass + else: + raise ValueError( + f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax', 'mlx'] but {metadata.get('format')}" + ) + + from_pt = not (from_tf | from_flax) + + # load pt weights early so that we know which dtype to init the model under + + if from_pt: + if not is_sharded and state_dict is None: + # Time to load the checkpoint + state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only) + + # set dtype to instantiate the model under: + # 1. If torch_dtype is not None, we use that dtype + # 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first + # weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype + # we also may have config.torch_dtype available, but we won't rely on it till v5 + dtype_orig = None + + if torch_dtype is not None: + if isinstance(torch_dtype, str): + if torch_dtype == "auto": + if hasattr(config, "torch_dtype") and config.torch_dtype is not None: + torch_dtype = config.torch_dtype + logger.info(f"Will use torch_dtype={torch_dtype} as defined in model's config object") + else: + if is_sharded and "dtype" in sharded_metadata: + torch_dtype = sharded_metadata["dtype"] + elif not is_sharded: + torch_dtype = get_state_dict_dtype(state_dict) + else: + one_state_dict = load_state_dict(resolved_archive_file[0], weights_only=weights_only) + torch_dtype = get_state_dict_dtype(one_state_dict) + del one_state_dict # free CPU memory + logger.info( + "Since the `torch_dtype` attribute can't be found in model's config object, " + "will use torch_dtype={torch_dtype} as derived from model's weights" + ) + elif hasattr(torch, torch_dtype): + torch_dtype = getattr(torch, torch_dtype) + for sub_config_key in config.sub_configs.keys(): + sub_config = getattr(config, sub_config_key) + sub_config.torch_dtype = torch_dtype + elif isinstance(torch_dtype, torch.dtype): + for sub_config_key in config.sub_configs.keys(): + sub_config = getattr(config, sub_config_key) + sub_config.torch_dtype = torch_dtype + elif isinstance(torch_dtype, dict): + for key, curr_dtype in torch_dtype.items(): + if hasattr(config, key): + value = getattr(config, key) + value.torch_dtype = curr_dtype + # main torch dtype for modules that aren't part of any sub-config + torch_dtype = torch_dtype.get("") + config.torch_dtype = torch_dtype + if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype): + torch_dtype = getattr(torch, torch_dtype) + elif torch_dtype is None: + torch_dtype = torch.float32 + else: + raise ValueError( + f"`torch_dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `torch_dtype` " + f"for each sub-config in composite configs, but received {torch_dtype}" + ) + + dtype_orig = cls._set_default_torch_dtype(torch_dtype) + + # Check if `_keep_in_fp32_modules` is not None + use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and ( + (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules") + ) + + if is_sharded: + loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] + else: + loaded_state_dict_keys = list(state_dict.keys()) + if ( + gguf_path is None + and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available())) + and pretrained_model_name_or_path is not None + ): + # In case some weights need to be kept in float32 and accelerate is not installed, + # we later on want to take the path where state_dict is not None, that is the one + # that do not require accelerate. + state_dict = None + + config.name_or_path = pretrained_model_name_or_path + + # Instantiate model. + init_contexts = [no_init_weights(_enable=_fast_init)] + tp_device = None + + if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called: + import deepspeed + + logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model") + init_contexts = [ + deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), + set_zero3_state(), + ] + init_contexts + elif low_cpu_mem_usage: + if not is_accelerate_available(): + raise ImportError( + f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`" + ) + init_contexts.append(init_empty_weights()) + elif tp_plan is not None: + if not torch.distributed.is_initialized(): + raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.") + + # Detect the accelerator on the machine. If no accelerator is available, it returns CPU. + device_type = torch._C._get_accelerator().type + device_module = torch.get_device_module(device_type) + # Get device with index assuming equal number of devices per host + tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count()) + init_contexts.append(tp_device) + + if is_deepspeed_zero3_enabled() and is_quantized: + init_contexts.append(set_quantized_state()) + + config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained. + if not getattr(config, "_attn_implementation_autoset", False): + config = cls._autoset_attn_implementation( + config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map + ) + + with ContextManagers(init_contexts): + # Let's make sure we don't run the init function of buffer modules + model = cls(config, *model_args, **model_kwargs) + + # make sure we use the model's config since the __init__ call might have copied it + config = model.config + + # Check first if we are `from_pt` + if use_keep_in_fp32_modules: + if is_accelerate_available() and not is_deepspeed_zero3_enabled(): + low_cpu_mem_usage = True + keep_in_fp32_modules = model._keep_in_fp32_modules + else: + keep_in_fp32_modules = [] + + if hf_quantizer is not None: + hf_quantizer.preprocess_model( + model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules + ) + + # We store the original dtype for quantized models as we cannot easily retrieve it + # once the weights have been quantized + # Note that once you have loaded a quantized model, you can't change its dtype so this will + # remain a single source of truth + config._pre_quantization_dtype = torch_dtype + + if isinstance(device_map, str): + special_dtypes = {} + + if hf_quantizer is not None: + special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype)) + + special_dtypes.update( + { + name: torch.float32 + for name, _ in model.named_parameters() + if any(m in name for m in keep_in_fp32_modules) + } + ) + + target_dtype = torch_dtype + + if hf_quantizer is not None: + target_dtype = hf_quantizer.adjust_target_dtype(target_dtype) + + no_split_modules = model._get_no_split_modules(device_map) + if device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]: + raise ValueError( + "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or " + "'sequential'." + ) + + device_map_kwargs = {"no_split_module_classes": no_split_modules} + if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters: + device_map_kwargs["special_dtypes"] = special_dtypes + elif len(special_dtypes) > 0: + logger.warning( + "This model has some weights that should be kept in higher precision, you need to upgrade " + "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)." + ) + if device_map != "sequential": + max_memory = get_balanced_memory( + model, + dtype=target_dtype, + low_zero=(device_map == "balanced_low_0"), + max_memory=max_memory, + **device_map_kwargs, + ) + else: + max_memory = get_max_memory(max_memory) + if hf_quantizer is not None: + max_memory = hf_quantizer.adjust_max_memory(max_memory) + device_map_kwargs["max_memory"] = max_memory + + # Make sure tied weights are tied before creating the device map. + model.tie_weights() + device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs) + + if hf_quantizer is not None: + hf_quantizer.validate_environment(device_map=device_map) + + elif device_map is not None: + model.tie_weights() + tied_params = find_tied_parameters(model) + # check if we don't have tied param in different devices + check_tied_parameters_on_same_device(tied_params, device_map) + + if from_tf: + if resolved_archive_file.endswith(".index"): + # Load from a TensorFlow 1.X checkpoint - provided by original authors + model = cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index' + else: + # Load from our TensorFlow 2.0 checkpoints + try: + from .modeling_tf_pytorch_utils import load_tf2_checkpoint_in_pytorch_model + + model, loading_info = load_tf2_checkpoint_in_pytorch_model( + model, resolved_archive_file, allow_missing_keys=True, output_loading_info=True + ) + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed." + " Please see https://pytorch.org/ and https://www.tensorflow.org/install/ for installation" + " instructions." + ) + raise + elif from_flax: + try: + from .modeling_flax_pytorch_utils import load_flax_checkpoint_in_pytorch_model + + model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file) + except ImportError: + logger.error( + "Loading a Flax model in PyTorch, requires both PyTorch and Flax to be installed. Please see" + " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for" + " installation instructions." + ) + raise + elif from_pt: + # restore default dtype + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) + + load_contexts = [] + # Make sure we load onto targeted device + if tp_device is not None: + load_contexts.append(tp_device) + + with ContextManagers(load_contexts): + ( + model, + missing_keys, + unexpected_keys, + mismatched_keys, + offload_index, + error_msgs, + ) = cls._load_pretrained_model( + model, + state_dict, + loaded_state_dict_keys, # XXX: rename? + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + sharded_metadata=sharded_metadata, + _fast_init=_fast_init, + low_cpu_mem_usage=low_cpu_mem_usage, + device_map=device_map, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=torch_dtype, + hf_quantizer=hf_quantizer, + keep_in_fp32_modules=keep_in_fp32_modules, + gguf_path=gguf_path, + weights_only=weights_only, + ) + + # make sure token embedding weights are still tied if needed + model.tie_weights() + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + + # If it is a model with generation capabilities, attempt to load the generation config + if model.can_generate() and generation_config is not None: + logger.info("The user-defined `generation_config` will be used to override the default generation config.") + model.generation_config = model.generation_config.from_dict(generation_config.to_dict()) + elif model.can_generate() and pretrained_model_name_or_path is not None: + try: + model.generation_config = GenerationConfig.from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + **kwargs, + ) + except OSError: + logger.info( + "Generation config file not found, using a generation config created from the model config." + ) + pass + + # Dispatch model with hooks on all devices if necessary + if device_map is not None: + device_map_kwargs = { + "device_map": device_map, + "offload_dir": offload_folder, + "offload_index": offload_index, + "offload_buffers": offload_buffers, + } + if "skip_keys" in inspect.signature(dispatch_model).parameters: + device_map_kwargs["skip_keys"] = model._skip_keys_device_placement + # For HQQ method we force-set the hooks for single GPU envs + if ( + "force_hooks" in inspect.signature(dispatch_model).parameters + and hf_quantizer is not None + and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ + ): + device_map_kwargs["force_hooks"] = True + if ( + hf_quantizer is not None + and hf_quantizer.quantization_config.quant_method == QuantizationMethod.FBGEMM_FP8 + and isinstance(device_map, dict) + and ("cpu" in device_map.values() or "disk" in device_map.values()) + ): + device_map_kwargs["offload_buffers"] = True + + if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled(): + dispatch_model(model, **device_map_kwargs) + + if hf_quantizer is not None: + hf_quantizer.postprocess_model(model, config=config) + model.hf_quantizer = hf_quantizer + + if _adapter_model_path is not None: + model.load_adapter( + _adapter_model_path, + adapter_name=adapter_name, + token=token, + adapter_kwargs=adapter_kwargs, + ) + + if output_loading_info: + if loading_info is None: + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + return model, loading_info + + if tp_plan is not None: + assert tp_device is not None, "tp_device not set!" + if not model.supports_tp_plan: + raise NotImplementedError("This model does not have a tensor parallel plan.") + # Assuming sharding the model onto the world + world_size = torch.distributed.get_world_size() + device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,)) + # Apply Tensor Parallelism + model.tensor_parallel(device_mesh) + + return model + + @staticmethod + def _fix_state_dict_key_on_load(key): + """Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight.""" + + if "beta" in key: + return key.replace("beta", "bias") + if "gamma" in key: + return key.replace("gamma", "weight") + + # to avoid logging parametrized weight norm renaming + if hasattr(nn.utils.parametrizations, "weight_norm"): + if "weight_g" in key: + return key.replace("weight_g", "parametrizations.weight.original0") + if "weight_v" in key: + return key.replace("weight_v", "parametrizations.weight.original1") + else: + if "parametrizations.weight.original0" in key: + return key.replace("parametrizations.weight.original0", "weight_g") + if "parametrizations.weight.original1" in key: + return key.replace("parametrizations.weight.original1", "weight_v") + return key + + @classmethod + def _fix_state_dict_keys_on_load(cls, state_dict): + """Fixes state dict keys by replacing legacy parameter names with their modern equivalents. + Logs if any parameters have been renamed. + """ + + renamed_keys = {} + state_dict_keys = list(state_dict.keys()) + for key in state_dict_keys: + new_key = cls._fix_state_dict_key_on_load(key) + if new_key != key: + state_dict[new_key] = state_dict.pop(key) + + # add it once for logging + if "gamma" in key and "gamma" not in renamed_keys: + renamed_keys["gamma"] = (key, new_key) + if "beta" in key and "beta" not in renamed_keys: + renamed_keys["beta"] = (key, new_key) + + if renamed_keys: + warning_msg = f"A pretrained model of type `{cls.__name__}` " + warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n" + for old_key, new_key in renamed_keys.values(): + warning_msg += f"* `{old_key}` -> `{new_key}`\n" + warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users." + logger.info_once(warning_msg) + + return state_dict + + @staticmethod + def _fix_state_dict_key_on_save(key): + """ + Similar to `_fix_state_dict_key_on_load` allows to define hook for state dict key renaming on model save. + Do nothing by default, but can be overriden in particular models. + """ + return key + + def _fix_state_dict_keys_on_save(self, state_dict): + """ + Similar to `_fix_state_dict_keys_on_load` allows to define hook for state dict key renaming on model save. + Apply `_fix_state_dict_key_on_save` to all keys in `state_dict`. + """ + return {self._fix_state_dict_key_on_save(key): value for key, value in state_dict.items()} + + @classmethod + def _load_pretrained_model( + cls, + model, + state_dict, + loaded_keys, + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=False, + sharded_metadata=None, + _fast_init=True, + low_cpu_mem_usage=False, + device_map=None, + offload_folder=None, + offload_state_dict=None, + dtype=None, + hf_quantizer=None, + keep_in_fp32_modules=None, + gguf_path=None, + weights_only=True, + ): + is_safetensors = False + is_quantized = hf_quantizer is not None + state_dict_folder = None + state_dict_index = None + + if device_map is not None and "disk" in device_map.values(): + archive_file = ( + resolved_archive_file[0] if isinstance(resolved_archive_file, (list, tuple)) else resolved_archive_file + ) + is_safetensors = archive_file.endswith(".safetensors") + if offload_folder is None and not is_safetensors: + raise ValueError( + "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`" + " for them. Alternatively, make sure you have `safetensors` installed if the model you are using" + " offers the weights in this format." + ) + if offload_folder is not None: + os.makedirs(offload_folder, exist_ok=True) + if offload_state_dict is None: + offload_state_dict = True + + is_sharded_safetensors = is_safetensors and sharded_metadata is not None + + # tie the model weights before retrieving the state_dict + model.tie_weights() + + # Retrieve missing & unexpected_keys + model_state_dict = model.state_dict() + expected_keys = list(model_state_dict.keys()) + prefix = model.base_model_prefix + + if hf_quantizer is not None: + expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys) + + original_loaded_keys = loaded_keys + loaded_keys = [cls._fix_state_dict_key_on_load(key) for key in loaded_keys] + + if len(prefix) > 0: + has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) + expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) + else: + has_prefix_module = False + expects_prefix_module = False + + # key re-naming operations are never done on the keys + # that are loaded, but always on the keys of the newly initialized model + remove_prefix_from_model = not has_prefix_module and expects_prefix_module + add_prefix_to_model = has_prefix_module and not expects_prefix_module + + if remove_prefix_from_model: + _prefix = f"{prefix}." + expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(_prefix)] + expected_keys = [s[len(_prefix) :] if s.startswith(_prefix) else s for s in expected_keys] + elif add_prefix_to_model: + expected_keys = [".".join([prefix, s]) for s in expected_keys] + + missing_keys = sorted(set(expected_keys) - set(loaded_keys)) + unexpected_keys = set(loaded_keys) - set(expected_keys) + + # Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model + # buffers + model_buffers = {n for n, _ in model.named_buffers()} + if remove_prefix_from_model: + model_buffers = {key[len(_prefix) :] if key.startswith(_prefix) else key for key in model_buffers} + elif add_prefix_to_model: + model_buffers = {".".join([prefix, key]) for key in model_buffers} + unexpected_keys = sorted(unexpected_keys - model_buffers) + + # Clean up buffer for `inv-freq` because RoPE embedding moved under base model (https://github.com/huggingface/transformers/pull/34858) + has_inv_freq_buffers = any(buffer.endswith("rotary_emb.inv_freq") for buffer in model_buffers) + if has_inv_freq_buffers: + unexpected_keys = {k for k in unexpected_keys if "rotary_emb.inv_freq" not in k} + + model.tie_weights() + if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled(): + ptrs = collections.defaultdict(list) + for name, tensor in model.state_dict().items(): + id_tensor = id_tensor_storage(tensor) + ptrs[id_tensor].append(name) + + # These are all the pointers of shared tensors. + tied_params = [names for _, names in ptrs.items() if len(names) > 1] + else: + # id function doesn't work for meta tensor so we need this function + tied_params = find_tied_parameters(model) + + for group in tied_params: + if remove_prefix_from_model: + group = [key[len(_prefix) :] if key.startswith(_prefix) else key for key in group] + elif add_prefix_to_model: + group = [".".join([prefix, key]) for key in group] + missing_in_group = [k for k in missing_keys if k in group] + if len(missing_in_group) > 0 and len(missing_in_group) < len(group): + missing_keys = [k for k in missing_keys if k not in missing_in_group] + + # Some models may have keys that are not in the state by design, removing them before needlessly warning + # the user. + if cls._keys_to_ignore_on_load_missing is not None: + for pat in cls._keys_to_ignore_on_load_missing: + missing_keys = [k for k in missing_keys if re.search(pat, k) is None] + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + if hf_quantizer is not None: + missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix) + + # retrieve weights on meta device and put them back on CPU. + # This is not ideal in terms of memory, but if we don't do that not, we can't initialize them in the next step + if low_cpu_mem_usage: + for key in missing_keys: + if key in list(model_state_dict.keys()): + key = key + elif f"{prefix}.{key}" in list(model_state_dict.keys()): + key = f"{prefix}.{key}" + elif key.startswith(prefix) and ".".join(key.split(".")[1:]) in list(model_state_dict.keys()): + key = ".".join(key.split(".")[1:]) + param = model_state_dict[key] + + # upcast in fp32 if any + target_dtype = dtype + if ( + keep_in_fp32_modules is not None + and dtype == torch.float16 + and any( + module_to_keep_in_fp32 in key.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules + ) + ): + target_dtype = torch.float32 + + if param.device == torch.device("meta"): + value = torch.empty(*param.size(), dtype=target_dtype) + if ( + not is_quantized + or (getattr(hf_quantizer, "requires_parameters_quantization", False)) + or not hf_quantizer.check_quantized_param( + model, param_value=value, param_name=key, state_dict={} + ) + ): + set_module_tensor_to_device(model, key, "cpu", value) + else: + hf_quantizer.create_quantized_param(model, value, key, "cpu", state_dict, unexpected_keys) + + # retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights. + if _fast_init: + if not ignore_mismatched_sizes: + if remove_prefix_from_model: + _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys] + elif add_prefix_to_model: + _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys] + else: + _loaded_keys = loaded_keys + not_initialized_submodules = set_initialized_submodules(model, _loaded_keys) + # If we're about to tie the output embeds to the input embeds we don't need to init them + if hasattr(model.config, "tie_word_embeddings") and model.config.tie_word_embeddings: + output_embeddings = model.get_output_embeddings() + if output_embeddings is not None: + # Still need to initialize if there is a bias term since biases are not tied. + if not hasattr(output_embeddings, "bias") or output_embeddings.bias is None: + output_embeddings._is_hf_initialized = True + else: + not_initialized_submodules = dict(model.named_modules()) + # This will only initialize submodules that are not marked as initialized by the line above. + if is_deepspeed_zero3_enabled() and not is_quantized: + import deepspeed + + not_initialized_parameters = list( + set( + itertools.chain.from_iterable( + submodule.parameters(recurse=False) for submodule in not_initialized_submodules.values() + ) + ) + ) + with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0): + model.apply(model._initialize_weights) + else: + model.apply(model._initialize_weights) + + # Set some modules to fp32 if any + if keep_in_fp32_modules is not None: + for name, param in model.named_parameters(): + if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules): + # param = param.to(torch.float32) does not work here as only in the local scope. + param.data = param.data.to(torch.float32) + + # Make sure we are able to load base models as well as derived models (with heads) + start_prefix = "" + model_to_load = model + if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module: + start_prefix = cls.base_model_prefix + "." + if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module: + model_to_load = getattr(model, cls.base_model_prefix) + base_model_expected_keys = list(model_to_load.state_dict().keys()) + if any(key in expected_keys_not_prefixed and key not in base_model_expected_keys for key in loaded_keys): + raise ValueError( + "The state dictionary of the model you are trying to load is corrupted. Are you sure it was " + "properly saved?" + ) + if device_map is not None: + device_map = {k.replace(f"{cls.base_model_prefix}.", ""): v for k, v in device_map.items()} + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + original_loaded_keys, + add_prefix_to_model, + remove_prefix_from_model, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key, model_key in zip(original_loaded_keys, loaded_keys): + # If the checkpoint is sharded, we may not have the key here. + if checkpoint_key not in state_dict: + continue + if remove_prefix_from_model: + # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it. + model_key = f"{prefix}.{model_key}" + elif add_prefix_to_model: + # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it. + model_key = ".".join(model_key.split(".")[1:]) + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + if ( + state_dict[checkpoint_key].shape[-1] == 1 + and state_dict[checkpoint_key].numel() * 2 == model_state_dict[model_key].numel() + ): + # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences. + # Without matching with module type or paramter type it seems like a practical way to detect valid 4bit weights. + pass + else: + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if resolved_archive_file is not None: + folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1]) + else: + folder = None + if device_map is not None and is_safetensors: + param_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix) + str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32" + if sharded_metadata is None: + archive_file = ( + resolved_archive_file[0] + if isinstance(resolved_archive_file, (list, tuple)) + else resolved_archive_file + ) + weight_map = {p: archive_file for p in original_loaded_keys} + else: + weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()} + offload_index = { + p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype} + for p, f in weight_map.items() + if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk" + } + else: + offload_index = None + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + original_loaded_keys, + add_prefix_to_model, + remove_prefix_from_model, + ignore_mismatched_sizes, + ) + + # For GGUF models `state_dict` is never set to None as the state dict is always small + if gguf_path or low_cpu_mem_usage: + fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict) + error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( + model_to_load, + fixed_state_dict, + start_prefix, + expected_keys, + device_map=device_map, + offload_folder=offload_folder, + offload_index=offload_index, + state_dict_folder=state_dict_folder, + state_dict_index=state_dict_index, + dtype=dtype, + hf_quantizer=hf_quantizer, + is_safetensors=is_safetensors, + keep_in_fp32_modules=keep_in_fp32_modules, + unexpected_keys=unexpected_keys, + ) + else: + # Sharded checkpoint or whole but low_cpu_mem_usage==True + assign_to_params_buffers = check_support_param_buffer_assignment( + model_to_load, state_dict, start_prefix + ) + fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict) + error_msgs = _load_state_dict_into_model( + model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers + ) + + else: + # This should always be a list but, just to be sure. + if not isinstance(resolved_archive_file, list): + resolved_archive_file = [resolved_archive_file] + + error_msgs = [] + mismatched_keys = [] + if not is_safetensors: + offload_index = {} if device_map is not None and "disk" in device_map.values() else None + if offload_state_dict: + state_dict_folder = tempfile.mkdtemp() + state_dict_index = {} + else: + state_dict_folder = None + state_dict_index = None + + if is_sharded_safetensors: + disk_only_shard_files = get_disk_only_shard_files( + device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix + ) + disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files] + else: + disk_only_shard_files = [] + + if len(resolved_archive_file) > 1: + resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards") + assign_to_params_buffers = None + for shard_file in resolved_archive_file: + # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload. + if shard_file in disk_only_shard_files: + continue + map_location = None + if ( + device_map is not None + and hf_quantizer is not None + and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO + and hf_quantizer.quantization_config.quant_type == "int4_weight_only" + ): + map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0]) + state_dict = load_state_dict( + shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only + ) + + # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not + # matching the weights in the model. + mismatched_keys += _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + original_loaded_keys, + add_prefix_to_model, + remove_prefix_from_model, + ignore_mismatched_sizes, + ) + if low_cpu_mem_usage: + if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized: + for key, param in model_to_load.state_dict().items(): + if param.device == torch.device("meta"): + set_module_tensor_to_device( + model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype) + ) + else: + fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict) + new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( + model_to_load, + fixed_state_dict, + start_prefix, + expected_keys, + device_map=device_map, + offload_folder=offload_folder, + offload_index=offload_index, + state_dict_folder=state_dict_folder, + state_dict_index=state_dict_index, + dtype=dtype, + hf_quantizer=hf_quantizer, + is_safetensors=is_safetensors, + keep_in_fp32_modules=keep_in_fp32_modules, + unexpected_keys=unexpected_keys, + ) + error_msgs += new_error_msgs + else: + # Sharded checkpoint or whole but low_cpu_mem_usage==True + if assign_to_params_buffers is None: + assign_to_params_buffers = check_support_param_buffer_assignment( + model_to_load, state_dict, start_prefix + ) + fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict) + error_msgs += _load_state_dict_into_model( + model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers + ) + + # force memory release + del state_dict + gc.collect() + + if offload_index is not None and len(offload_index) > 0: + if model != model_to_load: + # We need to add the prefix of the base model + prefix = cls.base_model_prefix + if not is_safetensors: + for weight_name in offload_index: + shutil.move( + os.path.join(offload_folder, f"{weight_name}.dat"), + os.path.join(offload_folder, f"{prefix}.{weight_name}.dat"), + ) + offload_index = {f"{prefix}.{key}": value for key, value in offload_index.items()} + if not is_safetensors: + save_offload_index(offload_index, offload_folder) + offload_index = None + + if offload_state_dict: + # Load back temporarily offloaded state dict + load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder) + shutil.rmtree(state_dict_folder) + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." + ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + if len(unexpected_keys) > 0: + archs = [] if model.config.architectures is None else model.config.architectures + warner = logger.warning if model.__class__.__name__ in archs else logger.info + warner( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" + " with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical" + " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" + f" was trained on, you can already use {model.__class__.__name__} for predictions without further" + " training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" + " to use it for predictions and inference." + ) + + return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs + + def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False): + module_keys = {".".join(key.split(".")[:-1]) for key in names} + + # torch.nn.ParameterList is a special case where two parameter keywords + # are appended to the module name, *e.g.* bert.special_embeddings.0 + module_keys = module_keys.union( + {".".join(key.split(".")[:-2]) for key in names if len(key) > 0 and key[-1].isdigit()} + ) + + retrieved_modules = [] + # retrieve all modules that has at least one missing weight name + for name, module in self.named_modules(): + if remove_prefix: + _prefix = f"{self.base_model_prefix}." + name = name[len(_prefix) :] if name.startswith(_prefix) else name + elif add_prefix: + name = ".".join([self.base_model_prefix, name]) if len(name) > 0 else self.base_model_prefix + + if name in module_keys: + retrieved_modules.append(module) + + return retrieved_modules + + @staticmethod + def _load_pretrained_model_low_mem( + model, + loaded_state_dict_keys, + resolved_archive_file, + start_prefix="", + hf_quantizer=None, + pretrained_model_name_or_path=None, + weights_only=True, + ): + """ + This is an experimental function that loads the model using ~1.x model size CPU memory + + Before you call it do: + + 1. save which state_dict keys are available + 2. drop state_dict before model is created, since the latter takes 1x model size memory + + Here then we continue: + + 3. switch to the meta device all params/buffers that are going to be replaced from the loaded state_dict + 4. load state_dict 2nd time + 5. replace the params/buffers from the state_dict + + Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed. To + handle bitsandbytes, needs non-empty hf_quantizer argument. + """ + + _move_model_to_meta(model, loaded_state_dict_keys, start_prefix) + state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only) + expected_keys = loaded_state_dict_keys # plug for missing expected_keys. TODO: replace with proper keys + fixed_state_dict = model._fix_state_dict_keys_on_load(state_dict) + error_msgs = _load_state_dict_into_meta_model( + model, + fixed_state_dict, + start_prefix, + expected_keys=expected_keys, + hf_quantizer=hf_quantizer, + ) + return error_msgs + + @classmethod + def register_for_auto_class(cls, auto_class="AutoModel"): + """ + Register this class with a given auto class. This should only be used for custom models as the ones in the + library are already mapped with an auto class. + + + + This API is experimental and may have some slight breaking changes in the next releases. + + + + Args: + auto_class (`str` or `type`, *optional*, defaults to `"AutoModel"`): + The auto class to register this new model with. + """ + if not isinstance(auto_class, str): + auto_class = auto_class.__name__ + + import transformers.models.auto as auto_module + + if not hasattr(auto_module, auto_class): + raise ValueError(f"{auto_class} is not a valid auto class.") + + cls._auto_class = auto_class + + def to_bettertransformer(self) -> "PreTrainedModel": + """ + Converts the model to use [PyTorch's native attention + implementation](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html), integrated to + Transformers through [Optimum library](https://huggingface.co/docs/optimum/bettertransformer/overview). Only a + subset of all Transformers models are supported. + + PyTorch's attention fastpath allows to speed up inference through kernel fusions and the use of [nested + tensors](https://pytorch.org/docs/stable/nested.html). Detailed benchmarks can be found in [this blog + post](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2). + + Returns: + [`PreTrainedModel`]: The model converted to BetterTransformer. + """ + if not is_optimum_available(): + raise ImportError("The package `optimum` is required to use Better Transformer.") + + from optimum.version import __version__ as optimum_version + + if version.parse(optimum_version) < version.parse("1.7.0"): + raise ImportError( + f"Please install optimum>=1.7.0 to use Better Transformer. The version {optimum_version} was found." + ) + + from optimum.bettertransformer import BetterTransformer + + return BetterTransformer.transform(self) + + def reverse_bettertransformer(self): + """ + Reverts the transformation from [`~PreTrainedModel.to_bettertransformer`] so that the original modeling is + used, for example in order to save the model. + + Returns: + [`PreTrainedModel`]: The model converted back to the original modeling. + """ + if not is_optimum_available(): + raise ImportError("The package `optimum` is required to use Better Transformer.") + + from optimum.version import __version__ as optimum_version + + if version.parse(optimum_version) < version.parse("1.7.0"): + raise ImportError( + f"Please install optimum>=1.7.0 to use Better Transformer. The version {optimum_version} was found." + ) + + from optimum.bettertransformer import BetterTransformer + + return BetterTransformer.reverse(self) + + def warn_if_padding_and_no_attention_mask(self, input_ids, attention_mask): + """ + Shows a one-time warning if the input_ids appear to contain padding and no attention mask was given. + """ + + # Skip the check during tracing. + if is_torch_fx_proxy(input_ids) or torch.jit.is_tracing() or is_torchdynamo_compiling(): + return + + if (attention_mask is not None) or (self.config.pad_token_id is None): + return + + # Check only the first and last input IDs to reduce overhead. + if self.config.pad_token_id in input_ids[:, [-1, 0]]: + warn_string = ( + "We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See " + "https://huggingface.co/docs/transformers/troubleshooting" + "#incorrect-output-when-padding-tokens-arent-masked." + ) + + # If the pad token is equal to either BOS, EOS, or SEP, we do not know whether the user should use an + # attention_mask or not. In this case, we should still show a warning because this is a rare case. + if ( + (self.config.bos_token_id is not None and self.config.bos_token_id == self.config.pad_token_id) + or (self.config.eos_token_id is not None and self.config.eos_token_id == self.config.pad_token_id) + or (self.config.sep_token_id is not None and self.config.sep_token_id == self.config.pad_token_id) + ): + warn_string += ( + f"\nYou may ignore this warning if your `pad_token_id` ({self.config.pad_token_id}) is identical " + f"to the `bos_token_id` ({self.config.bos_token_id}), `eos_token_id` ({self.config.eos_token_id}), " + f"or the `sep_token_id` ({self.config.sep_token_id}), and your input is not padded." + ) + + logger.warning_once(warn_string) + + @property + def supports_tp_plan(self): + """ + Returns whether the model has a tensor parallelism plan. + """ + if self._tp_plan is not None: + return True + # Check if base model has a TP plan + if getattr(self.base_model, "_tp_plan", None) is not None: + return True + return False + + def tensor_parallel(self, device_mesh): + """ + Tensor parallelize the model across the given device mesh. + + Args: + device_mesh (`torch.distributed.DeviceMesh`): + The device mesh to use for tensor parallelism. + """ + if not is_torch_greater_or_equal("2.5"): + raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.") + + # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module. + # No op if `_tp_plan` attribute does not exist under the module. + # This is a helper function to be used with `model.apply` to recursively + # parallelize a model. + def tplize(mod: torch.nn.Module) -> None: + tp_plan = getattr(mod, "_tp_plan", None) + if tp_plan is None: + return + logger.debug(f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}") + # In model configs, we use a neutral type (string) to specify + # parallel styles, here we translate them into torch TP types. + # Using tree_map because `tp_plan` is a dict. + tp_plan = torch.utils._pytree.tree_map( + translate_to_torch_parallel_style, + tp_plan, + ) + # Apply TP to current module. + torch.distributed.tensor.parallel.parallelize_module( + mod, + device_mesh=device_mesh, + parallelize_plan=tp_plan, + ) + + # `apply` is a native method of `nn.Module` that recursively applies a + # function to every submodule. + self.apply(tplize) + + @property + def loss_function(self): + if hasattr(self, "_loss_function"): + return self._loss_function + + loss_type = getattr(self, "loss_type", None) + + if loss_type is None or loss_type not in LOSS_MAPPING: + logger.warning_once( + f"`loss_type={loss_type}` was set in the config but it is unrecognised." + f"Using the default loss: `ForCausalLMLoss`." + ) + loss_type = "ForCausalLM" + return LOSS_MAPPING[loss_type] + + @loss_function.setter + def loss_function(self, value): + self._loss_function = value + + def get_compiled_call(self, compile_config: CompileConfig): + """Return a `torch.compile`'d version of `self.__call__`. This is useful to dynamically choose between + non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't + want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding + (where we want the speed-ups of compiled version with static shapes).""" + # Only reset it if not present or different from previous config + default_config = getattr(self.generation_config, "compile_config", CompileConfig()) + if ( + not hasattr(self, "_compiled_call") + or getattr(self, "_last_compile_config", default_config) != compile_config + ): + self._last_compile_config = compile_config + self._compiled_call = torch.compile(self.__call__, **compile_config.to_dict()) + return self._compiled_call + + +PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub) +if PreTrainedModel.push_to_hub.__doc__ is not None: + PreTrainedModel.push_to_hub.__doc__ = PreTrainedModel.push_to_hub.__doc__.format( + object="model", object_class="AutoModel", object_files="model file" + ) + + +class PoolerStartLogits(nn.Module): + """ + Compute SQuAD start logits from sequence hidden states. + + Args: + config ([`PretrainedConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model. + """ + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, 1) + + def forward( + self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token + should be masked. + + Returns: + `torch.FloatTensor`: The start logits for SQuAD. + """ + x = self.dense(hidden_states).squeeze(-1) + + if p_mask is not None: + if get_parameter_dtype(self) == torch.float16: + x = x * (1 - p_mask) - 65500 * p_mask + else: + x = x * (1 - p_mask) - 1e30 * p_mask + + return x + + +class PoolerEndLogits(nn.Module): + """ + Compute SQuAD end logits from sequence hidden states. + + Args: + config ([`PretrainedConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps` + to use. + """ + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) + self.activation = nn.Tanh() + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dense_1 = nn.Linear(config.hidden_size, 1) + + def forward( + self, + hidden_states: torch.FloatTensor, + start_states: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + p_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*): + The hidden states of the first tokens for the labeled span. + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + The position of the first token for the labeled span. + p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token + should be masked. + + + + One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides + `start_states`. + + + + Returns: + `torch.FloatTensor`: The end logits for SQuAD. + """ + assert ( + start_states is not None or start_positions is not None + ), "One of start_states, start_positions should be not None" + if start_positions is not None: + slen, hsz = hidden_states.shape[-2:] + start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz) + start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz) + + x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1)) + x = self.activation(x) + x = self.LayerNorm(x) + x = self.dense_1(x).squeeze(-1) + + if p_mask is not None: + if get_parameter_dtype(self) == torch.float16: + x = x * (1 - p_mask) - 65500 * p_mask + else: + x = x * (1 - p_mask) - 1e30 * p_mask + + return x + + +class PoolerAnswerClass(nn.Module): + """ + Compute SQuAD 2.0 answer class from classification and start tokens hidden states. + + Args: + config ([`PretrainedConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model. + """ + + def __init__(self, config): + super().__init__() + self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) + self.activation = nn.Tanh() + self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False) + + def forward( + self, + hidden_states: torch.FloatTensor, + start_states: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + cls_index: Optional[torch.LongTensor] = None, + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*): + The hidden states of the first tokens for the labeled span. + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + The position of the first token for the labeled span. + cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Position of the CLS token for each sentence in the batch. If `None`, takes the last token. + + + + One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides + `start_states`. + + + + Returns: + `torch.FloatTensor`: The SQuAD 2.0 answer class. + """ + # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample. + hsz = hidden_states.shape[-1] + assert ( + start_states is not None or start_positions is not None + ), "One of start_states, start_positions should be not None" + if start_positions is not None: + start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz) + + if cls_index is not None: + cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz) + else: + cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz) + + x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1)) + x = self.activation(x) + x = self.dense_1(x).squeeze(-1) + + return x + + +@dataclass +class SquadHeadOutput(ModelOutput): + """ + Base class for outputs of question answering models using a [`~modeling_utils.SQuADHead`]. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided): + Classification loss as the sum of start token, end token (and is_impossible if provided) classification + losses. + start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Log probabilities for the top config.start_n_top start token possibilities (beam-search). + start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Indices for the top config.start_n_top start token possibilities (beam-search). + end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities + (beam-search). + end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search). + cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Log probabilities for the `is_impossible` label of the answers. + + """ + + loss: Optional[torch.FloatTensor] = None + start_top_log_probs: Optional[torch.FloatTensor] = None + start_top_index: Optional[torch.LongTensor] = None + end_top_log_probs: Optional[torch.FloatTensor] = None + end_top_index: Optional[torch.LongTensor] = None + cls_logits: Optional[torch.FloatTensor] = None + + +class SQuADHead(nn.Module): + r""" + A SQuAD head inspired by XLNet. + + Args: + config ([`PretrainedConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps` + to use. + """ + + def __init__(self, config): + super().__init__() + self.start_n_top = config.start_n_top + self.end_n_top = config.end_n_top + + self.start_logits = PoolerStartLogits(config) + self.end_logits = PoolerEndLogits(config) + self.answer_class = PoolerAnswerClass(config) + + @replace_return_docstrings(output_type=SquadHeadOutput, config_class=PretrainedConfig) + def forward( + self, + hidden_states: torch.FloatTensor, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + cls_index: Optional[torch.LongTensor] = None, + is_impossible: Optional[torch.LongTensor] = None, + p_mask: Optional[torch.FloatTensor] = None, + return_dict: bool = False, + ) -> Union[SquadHeadOutput, Tuple[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + Final hidden states of the model on the sequence tokens. + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Positions of the first token for the labeled span. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Positions of the last token for the labeled span. + cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Position of the CLS token for each sentence in the batch. If `None`, takes the last token. + is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Whether the question has a possible answer in the paragraph or not. + p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token + should be masked. + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + """ + start_logits = self.start_logits(hidden_states, p_mask=p_mask) + + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, let's remove the dimension added by batch splitting + for x in (start_positions, end_positions, cls_index, is_impossible): + if x is not None and x.dim() > 1: + x.squeeze_(-1) + + # during training, compute the end logits based on the ground truth of the start position + end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask) + + loss_fct = CrossEntropyLoss() + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if cls_index is not None and is_impossible is not None: + # Predict answerability from the representation of CLS and START + cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index) + loss_fct_cls = nn.BCEWithLogitsLoss() + cls_loss = loss_fct_cls(cls_logits, is_impossible) + + # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss + total_loss += cls_loss * 0.5 + + return SquadHeadOutput(loss=total_loss) if return_dict else (total_loss,) + + else: + # during inference, compute the end logits based on beam search + bsz, slen, hsz = hidden_states.size() + start_log_probs = nn.functional.softmax(start_logits, dim=-1) # shape (bsz, slen) + + start_top_log_probs, start_top_index = torch.topk( + start_log_probs, self.start_n_top, dim=-1 + ) # shape (bsz, start_n_top) + start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz) + start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz) + start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz) + + hidden_states_expanded = hidden_states.unsqueeze(2).expand_as( + start_states + ) # shape (bsz, slen, start_n_top, hsz) + p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None + end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask) + end_log_probs = nn.functional.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top) + + end_top_log_probs, end_top_index = torch.topk( + end_log_probs, self.end_n_top, dim=1 + ) # shape (bsz, end_n_top, start_n_top) + end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top) + end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top) + + start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) + cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) + + if not return_dict: + return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + else: + return SquadHeadOutput( + start_top_log_probs=start_top_log_probs, + start_top_index=start_top_index, + end_top_log_probs=end_top_log_probs, + end_top_index=end_top_index, + cls_logits=cls_logits, + ) + + +class SequenceSummary(nn.Module): + r""" + Compute a single vector summary of a sequence hidden states. + + Args: + config ([`PretrainedConfig`]): + The config used by the model. Relevant arguments in the config class of the model are (refer to the actual + config class of your model for the default values it uses): + + - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are: + + - `"last"` -- Take the last token hidden state (like XLNet) + - `"first"` -- Take the first token hidden state (like Bert) + - `"mean"` -- Take the mean of all tokens hidden states + - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) + - `"attn"` -- Not implemented now, use multi-head attention + + - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. + - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes + (otherwise to `config.hidden_size`). + - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, + another string or `None` will add no activation. + - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. + - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. + """ + + def __init__(self, config: PretrainedConfig): + super().__init__() + + self.summary_type = getattr(config, "summary_type", "last") + if self.summary_type == "attn": + # We should use a standard multi-head attention module with absolute positional embedding for that. + # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 + # We can probably just use the multi-head attention module of PyTorch >=1.1.0 + raise NotImplementedError + + self.summary = Identity() + if hasattr(config, "summary_use_proj") and config.summary_use_proj: + if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: + num_classes = config.num_labels + else: + num_classes = config.hidden_size + self.summary = nn.Linear(config.hidden_size, num_classes) + + activation_string = getattr(config, "summary_activation", None) + self.activation: Callable = get_activation(activation_string) if activation_string else Identity() + + self.first_dropout = Identity() + if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: + self.first_dropout = nn.Dropout(config.summary_first_dropout) + + self.last_dropout = Identity() + if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: + self.last_dropout = nn.Dropout(config.summary_last_dropout) + + def forward( + self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None + ) -> torch.FloatTensor: + """ + Compute a single vector summary of a sequence hidden states. + + Args: + hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`): + The hidden states of the last layer. + cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): + Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. + + Returns: + `torch.FloatTensor`: The summary of the sequence hidden states. + """ + if self.summary_type == "last": + output = hidden_states[:, -1] + elif self.summary_type == "first": + output = hidden_states[:, 0] + elif self.summary_type == "mean": + output = hidden_states.mean(dim=1) + elif self.summary_type == "cls_index": + if cls_index is None: + cls_index = torch.full_like( + hidden_states[..., :1, :], + hidden_states.shape[-2] - 1, + dtype=torch.long, + ) + else: + cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) + cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states + output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) + elif self.summary_type == "attn": + raise NotImplementedError + + output = self.first_dropout(output) + output = self.summary(output) + output = self.activation(output) + output = self.last_dropout(output) + + return output + + +def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module: + """ + Recursively unwraps a model from potential containers (as used in distributed training). + + Args: + model (`torch.nn.Module`): The model to unwrap. + recursive (`bool`, *optional*, defaults to `False`): + Whether to recursively extract all cases of `module.module` from `model` as well as unwrap child sublayers + recursively, not just the top-level distributed containers. + """ + # Use accelerate implementation if available (should always be the case when using torch) + # This is for pytorch, as we also have to handle things like dynamo + if is_accelerate_available(): + kwargs = {} + if recursive: + if not is_accelerate_available("0.29.0"): + raise RuntimeError( + "Setting `recursive=True` to `unwrap_model` requires `accelerate` v0.29.0. Please upgrade your version of accelerate" + ) + else: + kwargs["recursive"] = recursive + return extract_model_from_parallel(model, **kwargs) + else: + # since there could be multiple levels of wrapping, unwrap recursively + if hasattr(model, "module"): + return unwrap_model(model.module) + else: + return model + + +def expand_device_map(device_map, param_names, start_prefix): + """ + Expand a device map to return the correspondance parameter name to device. + """ + new_device_map = {} + param_names = [p[len(start_prefix) :] for p in param_names if p.startswith(start_prefix)] + for module, device in device_map.items(): + new_device_map.update( + {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""} + ) + return new_device_map + + +def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix): + """ + Returns the list of shard files containing only weights offloaded to disk. + """ + + weight_map = { + p[len(start_prefix) :]: v for p, v in sharded_metadata["weight_map"].items() if p.startswith(start_prefix) + } + files_content = collections.defaultdict(list) + for weight_name, filename in weight_map.items(): + while len(weight_name) > 0 and weight_name not in device_map: + weight_name = ".".join(weight_name.split(".")[:-1]) + files_content[filename].append(device_map[weight_name]) + + return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}] + + +ALL_ATTENTION_FUNCTIONS: Dict[str, Dict[str, Callable]] = {} + +ALL_ATTENTION_FUNCTIONS.update( + { + "flash_attention_2": flash_attention_forward, + "flex_attention": flex_attention_forward, + "sdpa": sdpa_attention_forward, + } +) diff --git a/optimization.py b/optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..0ca5d36d0f40e3e79a4308be6fed065be660c493 --- /dev/null +++ b/optimization.py @@ -0,0 +1,958 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch optimization for BERT model.""" + +import math +import warnings +from functools import partial +from typing import Callable, Iterable, Optional, Tuple, Union + +import torch +from torch import nn +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau + +from .trainer_pt_utils import LayerWiseDummyOptimizer, LayerWiseDummyScheduler +from .trainer_utils import SchedulerType +from .utils import logging +from .utils.versions import require_version + + +logger = logging.get_logger(__name__) + + +def _get_constant_lambda(_=None): + return 1 + + +def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1): + """ + Create a schedule with a constant learning rate, using the learning rate set in optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + return LambdaLR(optimizer, _get_constant_lambda, last_epoch=last_epoch) + + +def get_reduce_on_plateau_schedule(optimizer: Optimizer, **kwargs): + """ + Create a schedule with a constant learning rate that decreases when a metric has stopped improving. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + kwargs (`dict`, *optional*): + Extra parameters to be passed to the scheduler. See `torch.optim.lr_scheduler.ReduceLROnPlateau` + for possible parameters. + + Return: + `torch.optim.lr_scheduler.ReduceLROnPlateau` with the appropriate schedule. + """ + + return ReduceLROnPlateau(optimizer, **kwargs) + + +def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1.0, num_warmup_steps)) + return 1.0 + + +def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1): + """ + Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate + increases linearly between 0 and the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + lr_lambda = partial(_get_constant_schedule_with_warmup_lr_lambda, num_warmup_steps=num_warmup_steps) + return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) + + +def _get_linear_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))) + + +def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): + """ + Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after + a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + lr_lambda = partial( + _get_linear_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def _get_cosine_schedule_with_warmup_lr_lambda( + current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float +): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) + + +def get_cosine_schedule_with_warmup( + optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1 +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`float`, *optional*, defaults to 0.5): + The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 + following a half-cosine). + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + lr_lambda = partial( + _get_cosine_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda( + current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: int +): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + if progress >= 1.0: + return 0.0 + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) + + +def get_cosine_with_hard_restarts_schedule_with_warmup( + optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1 +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases + linearly between 0 and the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`int`, *optional*, defaults to 1): + The number of hard restarts to use. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + lr_lambda = partial( + _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def _get_polynomial_decay_schedule_with_warmup_lr_lambda( + current_step: int, + *, + num_warmup_steps: int, + num_training_steps: int, + lr_end: float, + power: float, + lr_init: int, +): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + elif current_step > num_training_steps: + return lr_end / lr_init # as LambdaLR multiplies by lr_init + else: + lr_range = lr_init - lr_end + decay_steps = num_training_steps - num_warmup_steps + pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps + decay = lr_range * pct_remaining**power + lr_end + return decay / lr_init # as LambdaLR multiplies by lr_init + + +def get_polynomial_decay_schedule_with_warmup( + optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1 +): + """ + Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the + optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + lr_end (`float`, *optional*, defaults to 1e-7): + The end LR. + power (`float`, *optional*, defaults to 1.0): + Power factor. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT + implementation at + https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37 + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + + """ + + lr_init = optimizer.defaults["lr"] + if not (lr_init > lr_end): + raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})") + + lr_lambda = partial( + _get_polynomial_decay_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + lr_end=lr_end, + power=power, + lr_init=lr_init, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def _get_inverse_sqrt_schedule_lr_lambda(current_step: int, *, num_warmup_steps: int, timescale: int = None): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + shift = timescale - num_warmup_steps + decay = 1.0 / math.sqrt((current_step + shift) / timescale) + return decay + + +def get_inverse_sqrt_schedule( + optimizer: Optimizer, num_warmup_steps: int, timescale: int = None, last_epoch: int = -1 +): + """ + Create a schedule with an inverse square-root learning rate, from the initial lr set in the optimizer, after a + warmup period which increases lr linearly from 0 to the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + timescale (`int`, *optional*, defaults to `num_warmup_steps`): + Time scale. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + # Note: this implementation is adapted from + # https://github.com/google-research/big_vision/blob/f071ce68852d56099437004fd70057597a95f6ef/big_vision/utils.py#L930 + + if timescale is None: + timescale = num_warmup_steps or 10_000 + + lr_lambda = partial(_get_inverse_sqrt_schedule_lr_lambda, num_warmup_steps=num_warmup_steps, timescale=timescale) + return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) + + +def _get_cosine_schedule_with_warmup_lr_lambda( + current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0 +): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)) + factor = factor * (1 - min_lr_rate) + min_lr_rate + return max(0, factor) + + +def get_cosine_with_min_lr_schedule_with_warmup( + optimizer: Optimizer, + num_warmup_steps: int, + num_training_steps: int, + num_cycles: float = 0.5, + last_epoch: int = -1, + min_lr: float = None, + min_lr_rate: float = None, +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to min_lr, after a warmup period during which it increases linearly between 0 and the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`float`, *optional*, defaults to 0.5): + The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 + following a half-cosine). + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + min_lr (`float`, *optional*): + The minimum learning rate to reach after the cosine schedule. + min_lr_rate (`float`, *optional*): + The minimum learning rate as a ratio of the initial learning rate. If set, `min_lr` should not be set. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + if min_lr is not None and min_lr_rate is not None: + raise ValueError("Only one of min_lr or min_lr_rate should be set") + elif min_lr is not None: + min_lr_rate = min_lr / optimizer.defaults["lr"] + elif min_lr_rate is None: + raise ValueError("One of min_lr or min_lr_rate should be set through the `lr_scheduler_kwargs`") + + lr_lambda = partial( + _get_cosine_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles, + min_lr_rate=min_lr_rate, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def _get_wsd_scheduler_lambda( + current_step: int, + *, + num_warmup_steps: int, + num_stable_steps: int, + num_decay_steps: int, + num_cycles: float, + min_lr_ratio: float, +): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + if current_step < num_warmup_steps + num_stable_steps: + return 1.0 + if current_step < num_warmup_steps + num_stable_steps + num_decay_steps: + progress = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps)) + value = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) + return (1.0 - min_lr_ratio) * value + min_lr_ratio + return min_lr_ratio + + +def get_wsd_schedule( + optimizer: Optimizer, + num_warmup_steps: int, + num_stable_steps: int, + num_decay_steps: int, + min_lr_ratio: float = 0, + num_cycles: float = 0.5, + last_epoch: int = -1, +): + """ + Create a schedule with a learning rate that has three stages: + 1. linear increase from 0 to initial lr. + 2. constant lr (equal to initial lr). + 3. decrease following the values of the cosine function between the initial lr set in the optimizer to + a fraction of initial lr. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_stable_steps (`int`): + The number of steps for the stable phase. + num_decay_steps (`int`): + The number of steps for the cosine annealing phase. + min_lr_ratio (`float`, *optional*, defaults to 0): + The minimum learning rate as a ratio of the initial learning rate. + num_cycles (`float`, *optional*, defaults to 0.5): + The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 + following a half-cosine). + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + lr_lambda = partial( + _get_wsd_scheduler_lambda, + num_warmup_steps=num_warmup_steps, + num_stable_steps=num_stable_steps, + num_decay_steps=num_decay_steps, + min_lr_ratio=min_lr_ratio, + num_cycles=num_cycles, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +TYPE_TO_SCHEDULER_FUNCTION = { + SchedulerType.LINEAR: get_linear_schedule_with_warmup, + SchedulerType.COSINE: get_cosine_schedule_with_warmup, + SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup, + SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup, + SchedulerType.CONSTANT: get_constant_schedule, + SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup, + SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule, + SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule, + SchedulerType.COSINE_WITH_MIN_LR: get_cosine_with_min_lr_schedule_with_warmup, + SchedulerType.WARMUP_STABLE_DECAY: get_wsd_schedule, +} + + +def get_scheduler( + name: Union[str, SchedulerType], + optimizer: Optimizer, + num_warmup_steps: Optional[int] = None, + num_training_steps: Optional[int] = None, + scheduler_specific_kwargs: Optional[dict] = None, +): + """ + Unified API to get any scheduler from its name. + + Args: + name (`str` or `SchedulerType`): + The name of the scheduler to use. + optimizer (`torch.optim.Optimizer`): + The optimizer that will be used during training. + num_warmup_steps (`int`, *optional*): + The number of warmup steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. + num_training_steps (`int``, *optional*): + The number of training steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. + scheduler_specific_kwargs (`dict`, *optional*): + Extra parameters for schedulers such as cosine with restarts. Mismatched scheduler types and scheduler + parameters will cause the scheduler function to raise a TypeError. + """ + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + + # If a `LayerWiseDummyOptimizer` is passed we extract the optimizer dict and + # recursively call `get_scheduler` to get the proper schedulers on each parameter + if optimizer is not None and isinstance(optimizer, LayerWiseDummyOptimizer): + optimizer_dict = optimizer.optimizer_dict + scheduler_dict = {} + + for param in optimizer_dict.keys(): + scheduler_dict[param] = get_scheduler( + name, + optimizer=optimizer_dict[param], + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + ) + + def scheduler_hook(param): + # Since the optimizer hook has been already attached we only need to + # attach the scheduler hook, the gradients have been zeroed here + scheduler_dict[param].step() + + for param in optimizer_dict.keys(): + if param.requires_grad: + param.register_post_accumulate_grad_hook(scheduler_hook) + + return LayerWiseDummyScheduler(optimizer_dict=optimizer_dict, lr=optimizer.defaults["lr"]) + + if name == SchedulerType.CONSTANT: + return schedule_func(optimizer) + + if scheduler_specific_kwargs is None: + scheduler_specific_kwargs = {} + + if name == SchedulerType.REDUCE_ON_PLATEAU: + return schedule_func(optimizer, **scheduler_specific_kwargs) + + # All other schedulers require `num_warmup_steps` + if num_warmup_steps is None: + raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") + + if name == SchedulerType.CONSTANT_WITH_WARMUP: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) + + if name == SchedulerType.INVERSE_SQRT: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) + + if name == SchedulerType.WARMUP_STABLE_DECAY: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **scheduler_specific_kwargs) + + # All other schedulers require `num_training_steps` + if num_training_steps is None: + raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") + + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + **scheduler_specific_kwargs, + ) + + +class AdamW(Optimizer): + """ + Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay + Regularization](https://arxiv.org/abs/1711.05101). + + Parameters: + params (`Iterable[nn.parameter.Parameter]`): + Iterable of parameters to optimize or dictionaries defining parameter groups. + lr (`float`, *optional*, defaults to 0.001): + The learning rate to use. + betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`): + Adam's betas parameters (b1, b2). + eps (`float`, *optional*, defaults to 1e-06): + Adam's epsilon for numerical stability. + weight_decay (`float`, *optional*, defaults to 0.0): + Decoupled weight decay to apply. + correct_bias (`bool`, *optional*, defaults to `True`): + Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`). + no_deprecation_warning (`bool`, *optional*, defaults to `False`): + A flag used to disable the deprecation warning (set to `True` to disable the warning). + """ + + def __init__( + self, + params: Iterable[nn.parameter.Parameter], + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-6, + weight_decay: float = 0.0, + correct_bias: bool = True, + no_deprecation_warning: bool = False, + ): + if not no_deprecation_warning: + warnings.warn( + "This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch" + " implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this" + " warning", + FutureWarning, + ) + require_version("torch>=1.5.0") # add_ with alpha + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0") + defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias} + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure: Callable = None): + """ + Performs a single optimization step. + + Arguments: + closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + # Decay the first and second moment running average coefficient + # In-place operations to update the averages at the same time + exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1)) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + step_size = group["lr"] + if group["correct_bias"]: # No bias correction for Bert + bias_correction1 = 1.0 - beta1 ** state["step"] + bias_correction2 = 1.0 - beta2 ** state["step"] + step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 + + p.addcdiv_(exp_avg, denom, value=-step_size) + + # Just adding the square of the weights to the loss function is *not* + # the correct way of using L2 regularization/weight decay with Adam, + # since that will interact with the m and v parameters in strange ways. + # + # Instead we want to decay the weights in a manner that doesn't interact + # with the m/v parameters. This is equivalent to adding the square + # of the weights to the loss with plain (non-momentum) SGD. + # Add weight decay at the end (fixed version) + if group["weight_decay"] > 0.0: + p.add_(p, alpha=(-group["lr"] * group["weight_decay"])) + + return loss + + +class Adafactor(Optimizer): + """ + AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code: + https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py + + Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that + this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and + `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and + `relative_step=False`. + + Arguments: + params (`Iterable[nn.parameter.Parameter]`): + Iterable of parameters to optimize or dictionaries defining parameter groups. + lr (`float`, *optional*): + The external learning rate. + eps (`Tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`): + Regularization constants for square gradient and parameter scale respectively + clip_threshold (`float`, *optional*, defaults to 1.0): + Threshold of root mean square of final gradient update + decay_rate (`float`, *optional*, defaults to -0.8): + Coefficient used to compute running averages of square + beta1 (`float`, *optional*): + Coefficient used for computing running averages of gradient + weight_decay (`float`, *optional*, defaults to 0.0): + Weight decay (L2 penalty) + scale_parameter (`bool`, *optional*, defaults to `True`): + If True, learning rate is scaled by root mean square + relative_step (`bool`, *optional*, defaults to `True`): + If True, time-dependent learning rate is computed instead of external learning rate + warmup_init (`bool`, *optional*, defaults to `False`): + Time-dependent learning rate computation depends on whether warm-up initialization is being used + + This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested. + + Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3): + + - Training without LR warmup or clip_threshold is not recommended. + + - use scheduled LR warm-up to fixed LR + - use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235) + - Disable relative updates + - Use scale_parameter=False + - Additional optimizer operations like gradient clipping should not be used alongside Adafactor + + Example: + + ```python + Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3) + ``` + + Others reported the following combination to work well: + + ```python + Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None) + ``` + + When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`] + scheduler as following: + + ```python + from transformers.optimization import Adafactor, AdafactorSchedule + + optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None) + lr_scheduler = AdafactorSchedule(optimizer) + trainer = Trainer(..., optimizers=(optimizer, lr_scheduler)) + ``` + + Usage: + + ```python + # replace AdamW with Adafactor + optimizer = Adafactor( + model.parameters(), + lr=1e-3, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + relative_step=False, + scale_parameter=False, + warmup_init=False, + ) + ```""" + + def __init__( + self, + params, + lr=None, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + scale_parameter=True, + relative_step=True, + warmup_init=False, + ): + require_version("torch>=1.5.0") # add_ with alpha + if lr is not None and relative_step: + raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") + if warmup_init and not relative_step: + raise ValueError("`warmup_init=True` requires `relative_step=True`") + + defaults = { + "lr": lr, + "eps": eps, + "clip_threshold": clip_threshold, + "decay_rate": decay_rate, + "beta1": beta1, + "weight_decay": weight_decay, + "scale_parameter": scale_parameter, + "relative_step": relative_step, + "warmup_init": warmup_init, + } + super().__init__(params, defaults) + + @staticmethod + def _get_lr(param_group, param_state): + rel_step_sz = param_group["lr"] + if param_group["relative_step"]: + min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 + rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) + param_scale = 1.0 + if param_group["scale_parameter"]: + param_scale = max(param_group["eps"][1], param_state["RMS"]) + return param_scale * rel_step_sz + + @staticmethod + def _get_options(param_group, param_shape): + factored = len(param_shape) >= 2 + use_first_moment = param_group["beta1"] is not None + return factored, use_first_moment + + @staticmethod + def _rms(tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + @staticmethod + def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): + # copy from fairseq's adafactor implementation: + # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505 + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + @torch.no_grad() + def step(self, closure=None): + """ + Performs a single optimization step + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = self._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + if factored: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) + state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) + else: + state["exp_avg_sq"] = torch.zeros_like(grad) + + state["RMS"] = 0 + else: + if use_first_moment: + state["exp_avg"] = state["exp_avg"].to(grad) + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + else: + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + + p_data_fp32 = p + if p.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state["step"] += 1 + state["RMS"] = self._rms(p_data_fp32) + lr = self._get_lr(group, state) + + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad**2) + group["eps"][0] + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) + + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + update.mul_(lr) + + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) + update = exp_avg + + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) + + p_data_fp32.add_(-update) + + if p.dtype in {torch.float16, torch.bfloat16}: + p.copy_(p_data_fp32) + + return loss + + +class AdafactorSchedule(LambdaLR): + """ + Since [`~optimization.Adafactor`] performs its own scheduling, if the training loop relies on a scheduler (e.g., + for logging), this class creates a proxy object that retrieves the current lr values from the optimizer. + + It returns `initial_lr` during startup and the actual `lr` during stepping. + """ + + def __init__(self, optimizer, initial_lr=0.0): + def lr_lambda(_): + return initial_lr + + for group in optimizer.param_groups: + group["initial_lr"] = initial_lr + super().__init__(optimizer, lr_lambda) + for group in optimizer.param_groups: + del group["initial_lr"] + + def get_lr(self): + opt = self.optimizer + lrs = [ + opt._get_lr(group, opt.state[group["params"][0]]) + for group in opt.param_groups + if group["params"][0].grad is not None + ] + if len(lrs) == 0: + lrs = self.base_lrs # if called before stepping + return lrs + + +def get_adafactor_schedule(optimizer, initial_lr=0.0): + """ + Get a proxy schedule for [`~optimization.Adafactor`] + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + initial_lr (`float`, *optional*, defaults to 0.0): + Initial lr + + Return: + [`~optimization.Adafactor`] proxy schedule object. + + + """ + return AdafactorSchedule(optimizer, initial_lr) diff --git a/optimization_tf.py b/optimization_tf.py new file mode 100644 index 0000000000000000000000000000000000000000..f27913156c44614d662e81d4b520778d2d98fcd1 --- /dev/null +++ b/optimization_tf.py @@ -0,0 +1,379 @@ +# Copyright 2019 The TensorFlow Authors, The Hugging Face Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functions and classes related to optimization (weight updates).""" + +import re +from typing import Callable, List, Optional, Union + +import tensorflow as tf + + +try: + from tf_keras.optimizers.legacy import Adam +except (ImportError, ModuleNotFoundError): + from tensorflow.keras.optimizers.legacy import Adam + +from .modeling_tf_utils import keras + + +# This block because Keras loves randomly moving things to different places - this changed somewhere between 2.10 - 2.15 +if hasattr(keras.optimizers.schedules, "learning_rate_schedule"): + schedules = keras.optimizers.schedules.learning_rate_schedule +else: + schedules = keras.optimizers.schedules + + +class WarmUp(schedules.LearningRateSchedule): + """ + Applies a warmup schedule on a given learning rate decay schedule. + + Args: + initial_learning_rate (`float`): + The initial learning rate for the schedule after the warmup (so this will be the learning rate at the end + of the warmup). + decay_schedule_fn (`Callable`): + The schedule function to apply after the warmup for the rest of training. + warmup_steps (`int`): + The number of steps for the warmup part of training. + power (`float`, *optional*, defaults to 1.0): + The power to use for the polynomial warmup (defaults is a linear warmup). + name (`str`, *optional*): + Optional name prefix for the returned tensors during the schedule. + """ + + def __init__( + self, + initial_learning_rate: float, + decay_schedule_fn: Callable, + warmup_steps: int, + power: float = 1.0, + name: str = None, + ): + super().__init__() + self.initial_learning_rate = initial_learning_rate + self.warmup_steps = warmup_steps + self.power = power + self.decay_schedule_fn = decay_schedule_fn + self.name = name + + def __call__(self, step): + with tf.name_scope(self.name or "WarmUp") as name: + # Implements polynomial warmup. i.e., if global_step < warmup_steps, the + # learning rate will be `global_step/num_warmup_steps * init_lr`. + global_step_float = tf.cast(step, tf.float32) + warmup_steps_float = tf.cast(self.warmup_steps, tf.float32) + warmup_percent_done = global_step_float / warmup_steps_float + warmup_learning_rate = self.initial_learning_rate * tf.math.pow(warmup_percent_done, self.power) + return tf.cond( + global_step_float < warmup_steps_float, + lambda: warmup_learning_rate, + lambda: self.decay_schedule_fn(step - self.warmup_steps), + name=name, + ) + + def get_config(self): + return { + "initial_learning_rate": self.initial_learning_rate, + "decay_schedule_fn": self.decay_schedule_fn, + "warmup_steps": self.warmup_steps, + "power": self.power, + "name": self.name, + } + + +def create_optimizer( + init_lr: float, + num_train_steps: int, + num_warmup_steps: int, + min_lr_ratio: float = 0.0, + adam_beta1: float = 0.9, + adam_beta2: float = 0.999, + adam_epsilon: float = 1e-8, + adam_clipnorm: Optional[float] = None, + adam_global_clipnorm: Optional[float] = None, + weight_decay_rate: float = 0.0, + power: float = 1.0, + include_in_weight_decay: Optional[List[str]] = None, +): + """ + Creates an optimizer with a learning rate schedule using a warmup phase followed by a linear decay. + + Args: + init_lr (`float`): + The desired learning rate at the end of the warmup phase. + num_train_steps (`int`): + The total number of training steps. + num_warmup_steps (`int`): + The number of warmup steps. + min_lr_ratio (`float`, *optional*, defaults to 0): + The final learning rate at the end of the linear decay will be `init_lr * min_lr_ratio`. + adam_beta1 (`float`, *optional*, defaults to 0.9): + The beta1 to use in Adam. + adam_beta2 (`float`, *optional*, defaults to 0.999): + The beta2 to use in Adam. + adam_epsilon (`float`, *optional*, defaults to 1e-8): + The epsilon to use in Adam. + adam_clipnorm (`float`, *optional*, defaults to `None`): + If not `None`, clip the gradient norm for each weight tensor to this value. + adam_global_clipnorm (`float`, *optional*, defaults to `None`) + If not `None`, clip gradient norm to this value. When using this argument, the norm is computed over all + weight tensors, as if they were concatenated into a single vector. + weight_decay_rate (`float`, *optional*, defaults to 0): + The weight decay to use. + power (`float`, *optional*, defaults to 1.0): + The power to use for PolynomialDecay. + include_in_weight_decay (`List[str]`, *optional*): + List of the parameter names (or re patterns) to apply weight decay to. If none is passed, weight decay is + applied to all parameters except bias and layer norm parameters. + """ + # Implements linear decay of the learning rate. + lr_schedule = schedules.PolynomialDecay( + initial_learning_rate=init_lr, + decay_steps=num_train_steps - num_warmup_steps, + end_learning_rate=init_lr * min_lr_ratio, + power=power, + ) + if num_warmup_steps: + lr_schedule = WarmUp( + initial_learning_rate=init_lr, + decay_schedule_fn=lr_schedule, + warmup_steps=num_warmup_steps, + ) + if weight_decay_rate > 0.0: + optimizer = AdamWeightDecay( + learning_rate=lr_schedule, + weight_decay_rate=weight_decay_rate, + beta_1=adam_beta1, + beta_2=adam_beta2, + epsilon=adam_epsilon, + clipnorm=adam_clipnorm, + global_clipnorm=adam_global_clipnorm, + exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"], + include_in_weight_decay=include_in_weight_decay, + ) + else: + optimizer = keras.optimizers.Adam( + learning_rate=lr_schedule, + beta_1=adam_beta1, + beta_2=adam_beta2, + epsilon=adam_epsilon, + clipnorm=adam_clipnorm, + global_clipnorm=adam_global_clipnorm, + ) + # We return the optimizer and the LR scheduler in order to better track the + # evolution of the LR independently of the optimizer. + return optimizer, lr_schedule + + +class AdamWeightDecay(Adam): + """ + Adam enables L2 weight decay and clip_by_global_norm on gradients. Just adding the square of the weights to the + loss function is *not* the correct way of using L2 regularization/weight decay with Adam, since that will interact + with the m and v parameters in strange ways as shown in [Decoupled Weight Decay + Regularization](https://arxiv.org/abs/1711.05101). + + Instead we want to decay the weights in a manner that doesn't interact with the m/v parameters. This is equivalent + to adding the square of the weights to the loss with plain (non-momentum) SGD. + + Args: + learning_rate (`Union[float, LearningRateSchedule]`, *optional*, defaults to 0.001): + The learning rate to use or a schedule. + beta_1 (`float`, *optional*, defaults to 0.9): + The beta1 parameter in Adam, which is the exponential decay rate for the 1st momentum estimates. + beta_2 (`float`, *optional*, defaults to 0.999): + The beta2 parameter in Adam, which is the exponential decay rate for the 2nd momentum estimates. + epsilon (`float`, *optional*, defaults to 1e-07): + The epsilon parameter in Adam, which is a small constant for numerical stability. + amsgrad (`bool`, *optional*, defaults to `False`): + Whether to apply AMSGrad variant of this algorithm or not, see [On the Convergence of Adam and + Beyond](https://arxiv.org/abs/1904.09237). + weight_decay_rate (`float`, *optional*, defaults to 0.0): + The weight decay to apply. + include_in_weight_decay (`List[str]`, *optional*): + List of the parameter names (or re patterns) to apply weight decay to. If none is passed, weight decay is + applied to all parameters by default (unless they are in `exclude_from_weight_decay`). + exclude_from_weight_decay (`List[str]`, *optional*): + List of the parameter names (or re patterns) to exclude from applying weight decay to. If a + `include_in_weight_decay` is passed, the names in it will supersede this list. + name (`str`, *optional*, defaults to `"AdamWeightDecay"`): + Optional name for the operations created when applying gradients. + kwargs (`Dict[str, Any]`, *optional*): + Keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by + norm; `clipvalue` is clip gradients by value, `decay` is included for backward compatibility to allow time + inverse decay of learning rate. `lr` is included for backward compatibility, recommended to use + `learning_rate` instead. + """ + + def __init__( + self, + learning_rate: Union[float, schedules.LearningRateSchedule] = 0.001, + beta_1: float = 0.9, + beta_2: float = 0.999, + epsilon: float = 1e-7, + amsgrad: bool = False, + weight_decay_rate: float = 0.0, + include_in_weight_decay: Optional[List[str]] = None, + exclude_from_weight_decay: Optional[List[str]] = None, + name: str = "AdamWeightDecay", + **kwargs, + ): + super().__init__(learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs) + self.weight_decay_rate = weight_decay_rate + self._include_in_weight_decay = include_in_weight_decay + self._exclude_from_weight_decay = exclude_from_weight_decay + + @classmethod + def from_config(cls, config): + """Creates an optimizer from its config with WarmUp custom object.""" + custom_objects = {"WarmUp": WarmUp} + return super(AdamWeightDecay, cls).from_config(config, custom_objects=custom_objects) + + def _prepare_local(self, var_device, var_dtype, apply_state): + super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype, apply_state) + apply_state[(var_device, var_dtype)]["weight_decay_rate"] = tf.constant( + self.weight_decay_rate, name="adam_weight_decay_rate" + ) + + def _decay_weights_op(self, var, learning_rate, apply_state): + do_decay = self._do_use_weight_decay(var.name) + if do_decay: + return var.assign_sub( + learning_rate * var * apply_state[(var.device, var.dtype.base_dtype)]["weight_decay_rate"], + use_locking=self._use_locking, + ) + return tf.no_op() + + def apply_gradients(self, grads_and_vars, name=None, **kwargs): + grads, tvars = list(zip(*grads_and_vars)) + return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars), name=name, **kwargs) + + def _get_lr(self, var_device, var_dtype, apply_state): + """Retrieves the learning rate with the given state.""" + if apply_state is None: + return self._decayed_lr_t[var_dtype], {} + + apply_state = apply_state or {} + coefficients = apply_state.get((var_device, var_dtype)) + if coefficients is None: + coefficients = self._fallback_apply_state(var_device, var_dtype) + apply_state[(var_device, var_dtype)] = coefficients + + return coefficients["lr_t"], {"apply_state": apply_state} + + def _resource_apply_dense(self, grad, var, apply_state=None): + lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state) + decay = self._decay_weights_op(var, lr_t, apply_state) + with tf.control_dependencies([decay]): + return super(AdamWeightDecay, self)._resource_apply_dense(grad, var, **kwargs) + + def _resource_apply_sparse(self, grad, var, indices, apply_state=None): + lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state) + decay = self._decay_weights_op(var, lr_t, apply_state) + with tf.control_dependencies([decay]): + return super(AdamWeightDecay, self)._resource_apply_sparse(grad, var, indices, **kwargs) + + def get_config(self): + config = super().get_config() + config.update({"weight_decay_rate": self.weight_decay_rate}) + return config + + def _do_use_weight_decay(self, param_name): + """Whether to use L2 weight decay for `param_name`.""" + if self.weight_decay_rate == 0: + return False + + if self._include_in_weight_decay: + for r in self._include_in_weight_decay: + if re.search(r, param_name) is not None: + return True + + if self._exclude_from_weight_decay: + for r in self._exclude_from_weight_decay: + if re.search(r, param_name) is not None: + return False + return True + + +# Extracted from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py +class GradientAccumulator: + """ + Gradient accumulation utility. When used with a distribution strategy, the accumulator should be called in a + replica context. Gradients will be accumulated locally on each replica and without synchronization. Users should + then call `.gradients`, scale the gradients if required, and pass the result to `apply_gradients`. + """ + + # We use the ON_READ synchronization policy so that no synchronization is + # performed on assignment. To get the value, we call .value() which returns the + # value on the current replica without synchronization. + + def __init__(self): + """Initializes the accumulator.""" + self._gradients = [] + self._accum_steps = None + + @property + def step(self): + """Number of accumulated steps.""" + if self._accum_steps is None: + self._accum_steps = tf.Variable( + tf.constant(0, dtype=tf.int64), + trainable=False, + synchronization=tf.VariableSynchronization.ON_READ, + aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, + ) + + return self._accum_steps.value() + + @property + def gradients(self): + """The accumulated gradients on the current replica.""" + if not self._gradients: + raise ValueError("The accumulator should be called first to initialize the gradients") + return [gradient.value() if gradient is not None else gradient for gradient in self._gradients] + + def __call__(self, gradients): + """Accumulates `gradients` on the current replica.""" + if not self._gradients: + _ = self.step # Create the step variable. + self._gradients.extend( + [ + tf.Variable( + tf.zeros_like(gradient), + trainable=False, + synchronization=tf.VariableSynchronization.ON_READ, + aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, + ) + if gradient is not None + else gradient + for gradient in gradients + ] + ) + if len(gradients) != len(self._gradients): + raise ValueError(f"Expected {len(self._gradients)} gradients, but got {len(gradients)}") + + for accum_gradient, gradient in zip(self._gradients, gradients): + if accum_gradient is not None and gradient is not None: + accum_gradient.assign_add(gradient) + + self._accum_steps.assign_add(1) + + def reset(self): + """Resets the accumulated gradients on the current replica.""" + if not self._gradients: + return + self._accum_steps.assign(0) + for gradient in self._gradients: + if gradient is not None: + gradient.assign(tf.zeros_like(gradient)) diff --git a/processing_utils.py b/processing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c26463003670337a5e620d073ffe4dd53f6de934 --- /dev/null +++ b/processing_utils.py @@ -0,0 +1,1212 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processing saving/loading class for common processors. +""" + +import copy +import inspect +import json +import os +import sys +import typing +import warnings +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union + +import numpy as np +import typing_extensions + +from .dynamic_module_utils import custom_object_save +from .image_utils import ChannelDimension, is_valid_image, is_vision_available + + +if is_vision_available(): + from .image_utils import PILImageResampling + +from .tokenization_utils_base import ( + PaddingStrategy, + PreTokenizedInput, + PreTrainedTokenizerBase, + TextInput, + TruncationStrategy, +) +from .utils import ( + PROCESSOR_NAME, + PushToHubMixin, + TensorType, + add_model_info_to_auto_map, + add_model_info_to_custom_pipelines, + cached_file, + copy_func, + direct_transformers_import, + download_url, + is_offline_mode, + is_remote_url, + logging, +) + + +logger = logging.get_logger(__name__) + +# Dynamically import the Transformers module to grab the attribute classes of the processor form their names. +transformers_module = direct_transformers_import(Path(__file__).parent) + + +AUTO_TO_BASE_CLASS_MAPPING = { + "AutoTokenizer": "PreTrainedTokenizerBase", + "AutoFeatureExtractor": "FeatureExtractionMixin", + "AutoImageProcessor": "ImageProcessingMixin", +} + +if sys.version_info >= (3, 11): + Unpack = typing.Unpack +else: + Unpack = typing_extensions.Unpack + + +class TextKwargs(TypedDict, total=False): + """ + Keyword arguments for text processing. For extended documentation, check out tokenization_utils_base methods and + docstrings associated. + + Attributes: + add_special_tokens (`bool`, *optional*) + Whether or not to add special tokens when encoding the sequences. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*) + Activates and controls padding. + truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*): + Activates and controls truncation. + max_length (`int`, *optional*): + Controls the maximum length to use by one of the truncation/padding parameters. + stride (`int`, *optional*): + If set, the overflowing tokens will contain some tokens from the end of the truncated sequence. + is_split_into_words (`bool`, *optional*): + Whether or not the input is already pre-tokenized. + pad_to_multiple_of (`int`, *optional*): + If set, will pad the sequence to a multiple of the provided value. + return_token_type_ids (`bool`, *optional*): + Whether to return token type IDs. + return_attention_mask (`bool`, *optional*): + Whether to return the attention mask. + return_overflowing_tokens (`bool`, *optional*): + Whether or not to return overflowing token sequences. + return_special_tokens_mask (`bool`, *optional*): + Whether or not to return special tokens mask information. + return_offsets_mapping (`bool`, *optional*): + Whether or not to return `(char_start, char_end)` for each token. + return_length (`bool`, *optional*): + Whether or not to return the lengths of the encoded inputs. + verbose (`bool`, *optional*): + Whether or not to print more information and warnings. + padding_side (`str`, *optional*): + The side on which padding will be applied. + """ + + text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] + text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] + text_pair_target: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] + add_special_tokens: Optional[bool] + padding: Union[bool, str, PaddingStrategy] + truncation: Union[bool, str, TruncationStrategy] + max_length: Optional[int] + stride: Optional[int] + is_split_into_words: Optional[bool] + pad_to_multiple_of: Optional[int] + return_token_type_ids: Optional[bool] + return_attention_mask: Optional[bool] + return_overflowing_tokens: Optional[bool] + return_special_tokens_mask: Optional[bool] + return_offsets_mapping: Optional[bool] + return_length: Optional[bool] + verbose: Optional[bool] + padding_side: Optional[str] + + +class ImagesKwargs(TypedDict, total=False): + """ + Keyword arguments for image processing. For extended documentation, check the appropriate ImageProcessor + class methods and docstrings. + + Attributes: + do_resize (`bool`, *optional*): + Whether to resize the image. + size (`Dict[str, int]`, *optional*): + Resize the shorter side of the input to `size["shortest_edge"]`. + size_divisor (`int`, *optional*): + The size by which to make sure both the height and width can be divided. + crop_size (`Dict[str, int]`, *optional*): + Desired output size when applying center-cropping. + resample (`PILImageResampling`, *optional*): + Resampling filter to use if resizing the image. + do_rescale (`bool`, *optional*): + Whether to rescale the image by the specified scale `rescale_factor`. + rescale_factor (`int` or `float`, *optional*): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*): + Mean to use if normalizing the image. + image_std (`float` or `List[float]`, *optional*): + Standard deviation to use if normalizing the image. + do_pad (`bool`, *optional*): + Whether to pad the image to the `(max_height, max_width)` of the images in the batch. + pad_size (`Dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. + do_center_crop (`bool`, *optional*): + Whether to center crop the image. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. + """ + + do_resize: Optional[bool] + size: Optional[Dict[str, int]] + size_divisor: Optional[int] + crop_size: Optional[Dict[str, int]] + resample: Optional[Union["PILImageResampling", int]] + do_rescale: Optional[bool] + rescale_factor: Optional[float] + do_normalize: Optional[bool] + image_mean: Optional[Union[float, List[float]]] + image_std: Optional[Union[float, List[float]]] + do_pad: Optional[bool] + pad_size: Optional[Dict[str, int]] + do_center_crop: Optional[bool] + data_format: Optional[ChannelDimension] + input_data_format: Optional[Union[str, ChannelDimension]] + + +class VideosKwargs(TypedDict, total=False): + """ + Keyword arguments for video processing. + + Attributes: + do_resize (`bool`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*): + Resize the shorter side of the input to `size["shortest_edge"]`. + size_divisor (`int`, *optional*): + The size by which to make sure both the height and width can be divided. + resample (`PILImageResampling`, *optional*): + Resampling filter to use if resizing the image. + do_rescale (`bool`, *optional*): + Whether to rescale the image by the specified scale `rescale_factor`. + rescale_factor (`int` or `float`, *optional*): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*): + Mean to use if normalizing the image. + image_std (`float` or `List[float]`, *optional*): + Standard deviation to use if normalizing the image. + do_pad (`bool`, *optional*): + Whether to pad the image to the `(max_height, max_width)` of the images in the batch. + do_center_crop (`bool`, *optional*): + Whether to center crop the image. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. + """ + + do_resize: Optional[bool] + size: Optional[Dict[str, int]] + size_divisor: Optional[int] + resample: Optional["PILImageResampling"] + do_rescale: Optional[bool] + rescale_factor: Optional[float] + do_normalize: Optional[bool] + image_mean: Optional[Union[float, List[float]]] + image_std: Optional[Union[float, List[float]]] + do_pad: Optional[bool] + do_center_crop: Optional[bool] + data_format: Optional[ChannelDimension] + input_data_format: Optional[Union[str, ChannelDimension]] + + +class AudioKwargs(TypedDict, total=False): + """ + Keyword arguments for audio processing. + + Attributes: + sampling_rate (`int`, *optional*): + The sampling rate at which the `raw_speech` input was sampled. + raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`): + The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float + values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not + stereo, i.e. single float per timestep. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`, *optional*): + Activates truncation to cut input sequences longer than *max_length* to *max_length*. + pad_to_multiple_of (`int`, *optional*): + If set, will pad the sequence to a multiple of the provided value. + return_attention_mask (`bool`, *optional*): + Whether or not [`~ASTFeatureExtractor.__call__`] should return `attention_mask`. + """ + + sampling_rate: Optional[int] + raw_speech: Optional[Union["np.ndarray", List[float], List["np.ndarray"], List[List[float]]]] + padding: Optional[Union[bool, str, PaddingStrategy]] + max_length: Optional[int] + truncation: Optional[bool] + pad_to_multiple_of: Optional[int] + return_attention_mask: Optional[bool] + + +class CommonKwargs(TypedDict, total=False): + return_tensors: Optional[Union[str, TensorType]] + + +class ProcessingKwargs(TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, CommonKwargs, total=False): + """ + Base class for kwargs passing to processors. + A model should have its own `ModelProcessorKwargs` class that inherits from `ProcessingKwargs` to provide: + 1) Additional typed keys and that this model requires to process inputs. + 2) Default values for existing keys under a `_defaults` attribute. + New keys have to be defined as follows to ensure type hinting is done correctly. + + ```python + # adding a new image kwarg for this model + class ModelImagesKwargs(ImagesKwargs, total=False): + new_image_kwarg: Optional[bool] + + class ModelProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: ModelImagesKwargs + _defaults = { + "images_kwargs: { + "new_image_kwarg": False, + } + "text_kwargs": { + "padding": "max_length", + }, + } + + ``` + + For Python 3.8 compatibility, when inheriting from this class and overriding one of the kwargs, + you need to manually update the __annotations__ dictionary. This can be done as follows: + + ```python + class CustomProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: CustomImagesKwargs + + CustomProcessorKwargs.__annotations__["images_kwargs"] = CustomImagesKwargs # python 3.8 compatibility + ```python + + """ + + common_kwargs: CommonKwargs = { + **CommonKwargs.__annotations__, + } + text_kwargs: TextKwargs = { + **TextKwargs.__annotations__, + } + images_kwargs: ImagesKwargs = { + **ImagesKwargs.__annotations__, + } + videos_kwargs: VideosKwargs = { + **VideosKwargs.__annotations__, + } + audio_kwargs: AudioKwargs = { + **AudioKwargs.__annotations__, + } + + +class ProcessorMixin(PushToHubMixin): + """ + This is a mixin used to provide saving/loading functionality for all processor classes. + """ + + attributes = ["feature_extractor", "tokenizer"] + optional_attributes = ["chat_template"] + optional_call_args: List[str] = [] + # Names need to be attr_class for attr in attributes + feature_extractor_class = None + tokenizer_class = None + _auto_class = None + valid_kwargs: List[str] = [] + + # args have to match the attributes class attribute + def __init__(self, *args, **kwargs): + # First, extract optional attributes from kwargs if present + # Optional attributes can never be positional arguments + for optional_attribute in self.optional_attributes: + setattr(self, optional_attribute, kwargs.pop(optional_attribute, None)) + # Sanitize args and kwargs + for key in kwargs: + if key not in self.attributes: + raise TypeError(f"Unexpected keyword argument {key}.") + for arg, attribute_name in zip(args, self.attributes): + if attribute_name in kwargs: + raise TypeError(f"Got multiple values for argument {attribute_name}.") + else: + kwargs[attribute_name] = arg + + if len(kwargs) != len(self.attributes): + raise ValueError( + f"This processor requires {len(self.attributes)} arguments: {', '.join(self.attributes)}. Got " + f"{len(args)} arguments instead." + ) + + # Check each arg is of the proper class (this will also catch a user initializing in the wrong order) + for attribute_name, arg in kwargs.items(): + class_name = getattr(self, f"{attribute_name}_class") + # Nothing is ever going to be an instance of "AutoXxx", in that case we check the base class. + class_name = AUTO_TO_BASE_CLASS_MAPPING.get(class_name, class_name) + if isinstance(class_name, tuple): + proper_class = tuple(getattr(transformers_module, n) for n in class_name if n is not None) + else: + proper_class = getattr(transformers_module, class_name) + + if not isinstance(arg, proper_class): + raise TypeError( + f"Received a {type(arg).__name__} for argument {attribute_name}, but a {class_name} was expected." + ) + + setattr(self, attribute_name, arg) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. + + Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this processor instance. + """ + output = copy.deepcopy(self.__dict__) + + # Get the kwargs in `__init__`. + sig = inspect.signature(self.__init__) + # Only save the attributes that are presented in the kwargs of `__init__`. + attrs_to_save = sig.parameters + # Don't save attributes like `tokenizer`, `image processor` etc. + attrs_to_save = [x for x in attrs_to_save if x not in self.__class__.attributes] + # extra attributes to be kept + attrs_to_save += ["auto_map"] + + output = {k: v for k, v in output.items() if k in attrs_to_save} + + output["processor_class"] = self.__class__.__name__ + + if "tokenizer" in output: + del output["tokenizer"] + if "image_processor" in output: + del output["image_processor"] + if "feature_extractor" in output: + del output["feature_extractor"] + if "chat_template" in output: + del output["chat_template"] + + # Some attributes have different names but containing objects that are not simple strings + output = { + k: v + for k, v in output.items() + if not (isinstance(v, PushToHubMixin) or v.__class__.__name__ == "BeamSearchDecoderCTC") + } + + return output + + def to_json_string(self) -> str: + """ + Serializes this instance to a JSON string. + + Returns: + `str`: String containing all the attributes that make up this feature_extractor instance in JSON format. + """ + dictionary = self.to_dict() + + return json.dumps(dictionary, indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this processor instance's parameters will be saved. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string()) + + def __repr__(self): + attributes_repr = [f"- {name}: {repr(getattr(self, name))}" for name in self.attributes] + attributes_repr = "\n".join(attributes_repr) + return f"{self.__class__.__name__}:\n{attributes_repr}\n\n{self.to_json_string()}" + + def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs): + """ + Saves the attributes of this processor (feature extractor, tokenizer...) in the specified directory so that it + can be reloaded using the [`~ProcessorMixin.from_pretrained`] method. + + + + This class method is simply calling [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] and + [`~tokenization_utils_base.PreTrainedTokenizerBase.save_pretrained`]. Please refer to the docstrings of the + methods above for more information. + + + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will + be created if it does not exist). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + use_auth_token = kwargs.pop("use_auth_token", None) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if kwargs.get("token", None) is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + kwargs["token"] = use_auth_token + + os.makedirs(save_directory, exist_ok=True) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = self._create_repo(repo_id, **kwargs) + files_timestamps = self._get_files_timestamps(save_directory) + # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be + # loaded from the Hub. + if self._auto_class is not None: + attrs = [getattr(self, attribute_name) for attribute_name in self.attributes] + configs = [(a.init_kwargs if isinstance(a, PreTrainedTokenizerBase) else a) for a in attrs] + configs.append(self) + custom_object_save(self, save_directory, config=configs) + + for attribute_name in self.attributes: + attribute = getattr(self, attribute_name) + # Include the processor class in the attribute config so this processor can then be reloaded with the + # `AutoProcessor` API. + if hasattr(attribute, "_set_processor_class"): + attribute._set_processor_class(self.__class__.__name__) + attribute.save_pretrained(save_directory) + + if self._auto_class is not None: + # We added an attribute to the init_kwargs of the tokenizers, which needs to be cleaned up. + for attribute_name in self.attributes: + attribute = getattr(self, attribute_name) + if isinstance(attribute, PreTrainedTokenizerBase): + del attribute.init_kwargs["auto_map"] + + # If we save using the predefined names, we can load using `from_pretrained` + # plus we save chat_template in its own file + output_processor_file = os.path.join(save_directory, PROCESSOR_NAME) + output_raw_chat_template_file = os.path.join(save_directory, "chat_template.jinja") + output_chat_template_file = os.path.join(save_directory, "chat_template.json") + + processor_dict = self.to_dict() + # Save `chat_template` in its own file. We can't get it from `processor_dict` as we popped it in `to_dict` + # to avoid serializing chat template in json config file. So let's get it from `self` directly + if self.chat_template is not None: + if kwargs.get("save_raw_chat_template", False): + with open(output_raw_chat_template_file, "w", encoding="utf-8") as writer: + writer.write(self.chat_template) + logger.info(f"chat template saved in {output_raw_chat_template_file}") + else: + chat_template_json_string = ( + json.dumps({"chat_template": self.chat_template}, indent=2, sort_keys=True) + "\n" + ) + with open(output_chat_template_file, "w", encoding="utf-8") as writer: + writer.write(chat_template_json_string) + logger.info(f"chat template saved in {output_chat_template_file}") + + # For now, let's not save to `processor_config.json` if the processor doesn't have extra attributes and + # `auto_map` is not specified. + if set(processor_dict.keys()) != {"processor_class"}: + self.to_json_file(output_processor_file) + logger.info(f"processor saved in {output_processor_file}") + + if push_to_hub: + self._upload_modified_files( + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=kwargs.get("token"), + ) + + if set(processor_dict.keys()) == {"processor_class"}: + return [] + return [output_processor_file] + + @classmethod + def get_processor_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a + processor of type [`~processing_utils.ProcessingMixin`] using `from_args_and_dict`. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`): + The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can + specify the folder name here. + + Returns: + `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the processor object. + """ + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", None) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", "") + + from_pipeline = kwargs.pop("_from_pipeline", None) + from_auto_class = kwargs.pop("_from_auto", False) + + user_agent = {"file_type": "processor", "from_auto_class": from_auto_class} + if from_pipeline is not None: + user_agent["using_pipeline"] = from_pipeline + + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + is_local = os.path.isdir(pretrained_model_name_or_path) + if os.path.isdir(pretrained_model_name_or_path): + processor_file = os.path.join(pretrained_model_name_or_path, PROCESSOR_NAME) + + if os.path.isfile(pretrained_model_name_or_path): + resolved_processor_file = pretrained_model_name_or_path + # cant't load chat-template when given a file as pretrained_model_name_or_path + resolved_chat_template_file = None + resolved_raw_chat_template_file = None + is_local = True + elif is_remote_url(pretrained_model_name_or_path): + processor_file = pretrained_model_name_or_path + resolved_processor_file = download_url(pretrained_model_name_or_path) + # can't load chat-template when given a file url as pretrained_model_name_or_path + resolved_chat_template_file = None + resolved_raw_chat_template_file = None + else: + processor_file = PROCESSOR_NAME + chat_template_file = "chat_template.json" + raw_chat_template_file = "chat_template.jinja" + try: + # Load from local folder or from cache or download from model Hub and cache + resolved_processor_file = cached_file( + pretrained_model_name_or_path, + processor_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_missing_entries=False, + ) + + # Load chat template from a separate json if exists + # because making it part of processor-config break BC. + # Processors in older version do not accept any kwargs + resolved_chat_template_file = cached_file( + pretrained_model_name_or_path, + chat_template_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_missing_entries=False, + ) + + resolved_raw_chat_template_file = cached_file( + pretrained_model_name_or_path, + raw_chat_template_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_missing_entries=False, + ) + except EnvironmentError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to + # the original exception. + raise + except Exception: + # For any other exception, we throw a generic error. + raise EnvironmentError( + f"Can't load processor for '{pretrained_model_name_or_path}'. If you were trying to load" + " it from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" + f" directory containing a {PROCESSOR_NAME} file" + ) + + # Add chat template as kwarg before returning because most models don't have processor config + if resolved_raw_chat_template_file is not None: + with open(resolved_raw_chat_template_file, "r", encoding="utf-8") as reader: + chat_template = reader.read() + kwargs["chat_template"] = chat_template + elif resolved_chat_template_file is not None: + with open(resolved_chat_template_file, "r", encoding="utf-8") as reader: + text = reader.read() + chat_template = json.loads(text)["chat_template"] + kwargs["chat_template"] = chat_template + + # Existing processors on the Hub created before #27761 being merged don't have `processor_config.json` (if not + # updated afterward), and we need to keep `from_pretrained` work. So here it fallbacks to the empty dict. + # (`cached_file` called using `_raise_exceptions_for_missing_entries=False` to avoid exception) + # However, for models added in the future, we won't get the expected error if this file is missing. + if resolved_processor_file is None: + return {}, kwargs + + try: + # Load processor dict + with open(resolved_processor_file, "r", encoding="utf-8") as reader: + text = reader.read() + processor_dict = json.loads(text) + + except json.JSONDecodeError: + raise EnvironmentError( + f"It looks like the config file at '{resolved_processor_file}' is not a valid JSON file." + ) + + if is_local: + logger.info(f"loading configuration file {resolved_processor_file}") + else: + logger.info(f"loading configuration file {processor_file} from cache at {resolved_processor_file}") + + if "chat_template" in processor_dict and processor_dict["chat_template"] is not None: + logger.warning_once( + "Chat templates should be in a 'chat_template.jinja' file but found key='chat_template' " + "in the processor's config. Make sure to move your template to its own file." + ) + + if not is_local: + if "auto_map" in processor_dict: + processor_dict["auto_map"] = add_model_info_to_auto_map( + processor_dict["auto_map"], pretrained_model_name_or_path + ) + if "custom_pipelines" in processor_dict: + processor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines( + processor_dict["custom_pipelines"], pretrained_model_name_or_path + ) + + return processor_dict, kwargs + + @classmethod + def from_args_and_dict(cls, args, processor_dict: Dict[str, Any], **kwargs): + """ + Instantiates a type of [`~processing_utils.ProcessingMixin`] from a Python dictionary of parameters. + + Args: + processor_dict (`Dict[str, Any]`): + Dictionary that will be used to instantiate the processor object. Such a dictionary can be + retrieved from a pretrained checkpoint by leveraging the + [`~processing_utils.ProcessingMixin.to_dict`] method. + kwargs (`Dict[str, Any]`): + Additional parameters from which to initialize the processor object. + + Returns: + [`~processing_utils.ProcessingMixin`]: The processor object instantiated from those + parameters. + """ + processor_dict = processor_dict.copy() + return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) + chat_template = kwargs.pop("chat_template", None) + + # We have to pop up some unused (but specific) kwargs and then validate that it doesn't contain unused kwargs + # If we don't pop, some specific kwargs will raise a warning + if "processor_class" in processor_dict: + del processor_dict["processor_class"] + + if "auto_map" in processor_dict: + del processor_dict["auto_map"] + + unused_kwargs = cls.validate_init_kwargs(processor_config=processor_dict, valid_kwargs=cls.valid_kwargs) + processor = cls(*args, **processor_dict) + if chat_template is not None: + setattr(processor, "chat_template", chat_template) + + # Update processor with kwargs if needed + for key in set(kwargs.keys()): + if hasattr(processor, key): + setattr(processor, key, kwargs.pop(key)) + + kwargs.update(unused_kwargs) + logger.info(f"Processor {processor}") + if return_unused_kwargs: + return processor, kwargs + else: + return processor + + def _merge_kwargs( + self, + ModelProcessorKwargs: ProcessingKwargs, + tokenizer_init_kwargs: Optional[Dict] = None, + **kwargs, + ) -> Dict[str, Dict]: + """ + Method to merge dictionaries of kwargs cleanly separated by modality within a Processor instance. + The order of operations is as follows: + 1) kwargs passed as before have highest priority to preserve BC. + ```python + high_priority_kwargs = {"crop_size" = {"height": 222, "width": 222}, "padding" = "max_length"} + processor(..., **high_priority_kwargs) + ``` + 2) kwargs passed as modality-specific kwargs have second priority. This is the recommended API. + ```python + processor(..., text_kwargs={"padding": "max_length"}, images_kwargs={"crop_size": {"height": 222, "width": 222}}}) + ``` + 3) kwargs passed during instantiation of a modality processor have fourth priority. + ```python + tokenizer = tokenizer_class(..., {"padding": "max_length"}) + image_processor = image_processor_class(...) + processor(tokenizer, image_processor) # will pass max_length unless overriden by kwargs at call + ``` + 4) defaults kwargs specified at processor level have lowest priority. + ```python + class MyProcessingKwargs(ProcessingKwargs, CommonKwargs, TextKwargs, ImagesKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": "max_length", + "max_length": 64, + }, + } + ``` + Args: + ModelProcessorKwargs (`ProcessingKwargs`): + Typed dictionary of kwargs specifically required by the model passed. + tokenizer_init_kwargs (`Dict`, *optional*): + Dictionary of kwargs the tokenizer was instantiated with and need to take precedence over defaults. + + Returns: + output_kwargs (`Dict`): + Dictionary of per-modality kwargs to be passed to each modality-specific processor. + + """ + # Initialize dictionaries + output_kwargs = { + "text_kwargs": {}, + "images_kwargs": {}, + "audio_kwargs": {}, + "videos_kwargs": {}, + "common_kwargs": {}, + } + + default_kwargs = { + "text_kwargs": {}, + "images_kwargs": {}, + "audio_kwargs": {}, + "videos_kwargs": {}, + "common_kwargs": {}, + } + + used_keys = set() + + # get defaults from set model processor kwargs if they exist + for modality in default_kwargs: + default_kwargs[modality] = ModelProcessorKwargs._defaults.get(modality, {}).copy() + # update defaults with arguments from tokenizer init + for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__.keys(): + # init with tokenizer init kwargs if necessary + if modality_key in tokenizer_init_kwargs: + value = ( + getattr(self.tokenizer, modality_key) + if hasattr(self.tokenizer, modality_key) + else tokenizer_init_kwargs[modality_key] + ) + default_kwargs[modality][modality_key] = value + # now defaults kwargs are updated with the tokenizers defaults. + # pass defaults to output dictionary + output_kwargs.update(default_kwargs) + + # update modality kwargs with passed kwargs + non_modality_kwargs = set(kwargs) - set(output_kwargs) + for modality in output_kwargs: + for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__.keys(): + # check if we received a structured kwarg dict or not to handle it correctly + if modality in kwargs: + kwarg_value = kwargs[modality].pop(modality_key, "__empty__") + # check if this key was passed as a flat kwarg. + if kwarg_value != "__empty__" and modality_key in non_modality_kwargs: + raise ValueError( + f"Keyword argument {modality_key} was passed two times:\n" + f"in a dictionary for {modality} and as a **kwarg." + ) + elif modality_key in kwargs: + # we get a modality_key instead of popping it because modality-specific processors + # can have overlapping kwargs + kwarg_value = kwargs.get(modality_key, "__empty__") + else: + kwarg_value = "__empty__" + if kwarg_value != "__empty__": + output_kwargs[modality][modality_key] = kwarg_value + used_keys.add(modality_key) + + # Determine if kwargs is a flat dictionary or contains nested dictionaries + if any(key in default_kwargs for key in kwargs): + # kwargs is dictionary-based, and some keys match modality names + for modality, subdict in kwargs.items(): + if modality in default_kwargs: + for subkey, subvalue in subdict.items(): + if subkey not in used_keys: + output_kwargs[modality][subkey] = subvalue + used_keys.add(subkey) + else: + # kwargs is a flat dictionary + for key in kwargs: + if key not in used_keys: + if key in ModelProcessorKwargs.__annotations__["common_kwargs"].__annotations__.keys(): + output_kwargs["common_kwargs"][key] = kwargs[key] + else: + logger.warning_once( + f"Keyword argument `{key}` is not a valid argument for this processor and will be ignored." + ) + + # all modality-specific kwargs are updated with common kwargs + for modality in output_kwargs: + output_kwargs[modality].update(output_kwargs["common_kwargs"]) + return output_kwargs + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + **kwargs, + ): + r""" + Instantiate a processor associated with a pretrained model. + + + + This class method is simply calling the feature extractor + [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`], image processor + [`~image_processing_utils.ImageProcessingMixin`] and the tokenizer + [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`] methods. Please refer to the docstrings of the + methods above for more information. + + + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on + huggingface.co. + - a path to a *directory* containing a feature extractor file saved using the + [`~SequenceFeatureExtractor.save_pretrained`] method, e.g., `./my_model_directory/`. + - a path or url to a saved feature extractor JSON *file*, e.g., + `./my_model_directory/preprocessor_config.json`. + **kwargs + Additional keyword arguments passed along to both + [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] and + [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`]. + """ + kwargs["cache_dir"] = cache_dir + kwargs["force_download"] = force_download + kwargs["local_files_only"] = local_files_only + kwargs["revision"] = revision + + use_auth_token = kwargs.pop("use_auth_token", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if token is not None: + kwargs["token"] = token + + args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs) + processor_dict, kwargs = cls.get_processor_dict(pretrained_model_name_or_path, **kwargs) + + return cls.from_args_and_dict(args, processor_dict, **kwargs) + + @classmethod + def register_for_auto_class(cls, auto_class="AutoProcessor"): + """ + Register this class with a given auto class. This should only be used for custom feature extractors as the ones + in the library are already mapped with `AutoProcessor`. + + + + This API is experimental and may have some slight breaking changes in the next releases. + + + + Args: + auto_class (`str` or `type`, *optional*, defaults to `"AutoProcessor"`): + The auto class to register this new feature extractor with. + """ + if not isinstance(auto_class, str): + auto_class = auto_class.__name__ + + import transformers.models.auto as auto_module + + if not hasattr(auto_module, auto_class): + raise ValueError(f"{auto_class} is not a valid auto class.") + + cls._auto_class = auto_class + + @classmethod + def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + args = [] + for attribute_name in cls.attributes: + class_name = getattr(cls, f"{attribute_name}_class") + if isinstance(class_name, tuple): + classes = tuple(getattr(transformers_module, n) if n is not None else None for n in class_name) + use_fast = kwargs.get("use_fast", True) + if use_fast and classes[1] is not None: + attribute_class = classes[1] + else: + attribute_class = classes[0] + else: + attribute_class = getattr(transformers_module, class_name) + + args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs)) + return args + + @property + def model_input_names(self): + first_attribute = getattr(self, self.attributes[0]) + return getattr(first_attribute, "model_input_names", None) + + @staticmethod + def validate_init_kwargs(processor_config, valid_kwargs): + kwargs_from_config = processor_config.keys() + unused_kwargs = {} + unused_keys = set(kwargs_from_config) - set(valid_kwargs) + if unused_keys: + unused_key_str = ", ".join(unused_keys) + logger.warning( + f"Some kwargs in processor config are unused and will not have any effect: {unused_key_str}. " + ) + unused_kwargs = {k: processor_config[k] for k in unused_keys} + return unused_kwargs + + def prepare_and_validate_optional_call_args(self, *args): + """ + Matches optional positional arguments to their corresponding names in `optional_call_args` + in the processor class in the order they are passed to the processor call. + + Note that this should only be used in the `__call__` method of the processors with special + arguments. Special arguments are arguments that aren't `text`, `images`, `audio`, nor `videos` + but also aren't passed to the tokenizer, image processor, etc. Examples of such processors are: + - `CLIPSegProcessor` + - `LayoutLMv2Processor` + - `OwlViTProcessor` + + Also note that passing by position to the processor call is now deprecated and will be disallowed + in future versions. We only have this for backward compatibility. + + Example: + Suppose that the processor class has `optional_call_args = ["arg_name_1", "arg_name_2"]`. + And we define the call method as: + ```python + def __call__( + self, + text: str, + images: Optional[ImageInput] = None, + *arg, + audio=None, + videos=None, + ) + ``` + + Then, if we call the processor as: + ```python + images = [...] + processor("What is common in these images?", images, arg_value_1, arg_value_2) + ``` + + Then, this method will return: + ```python + { + "arg_name_1": arg_value_1, + "arg_name_2": arg_value_2, + } + ``` + which we could then pass as kwargs to `self._merge_kwargs` + """ + if len(args): + warnings.warn( + "Passing positional arguments to the processor call is now deprecated and will be disallowed in v4.47. " + "Please pass all arguments as keyword arguments." + ) + if len(args) > len(self.optional_call_args): + raise ValueError( + f"Expected *at most* {len(self.optional_call_args)} optional positional arguments in processor call" + f"which will be matched with {' '.join(self.optional_call_args)} in the order they are passed." + f"However, got {len(args)} positional arguments instead." + "Please pass all arguments as keyword arguments instead (e.g. `processor(arg_name_1=..., arg_name_2=...))`." + ) + return {arg_name: arg_value for arg_value, arg_name in zip(args, self.optional_call_args)} + + def apply_chat_template( + self, + conversation: Union[List[Dict[str, str]]], + chat_template: Optional[str] = None, + tokenize: bool = False, + **kwargs, + ) -> str: + """ + Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input + conversations to turn them into a single tokenizable string. + + Args: + conversation (`List[Dict, str, str]`): + The conversation to format. + chat_template (`Optional[str]`, *optional*): + The Jinja template to use for formatting the conversation. If not provided, the tokenizer's + chat template is used. + tokenize (`bool`, *optional*, defaults to `False`): + Whether to tokenize the output or not. + **kwargs: + Additional keyword arguments + """ + + if chat_template is None: + if self.chat_template is not None: + chat_template = self.chat_template + else: + raise ValueError( + "No chat template is set for this processor. Please either set the `chat_template` attribute, " + "or provide a chat template as an argument. See " + "https://huggingface.co/docs/transformers/main/en/chat_templating for more information." + ) + return self.tokenizer.apply_chat_template( + conversation, chat_template=chat_template, tokenize=tokenize, **kwargs + ) + + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-process the output of a vlm to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + + Returns: + `List[str]`: The decoded text. + """ + return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True) + + +def _validate_images_text_input_order(images, text): + """ + For backward compatibility: reverse the order of `images` and `text` inputs if they are swapped. + This method should only be called for processors where `images` and `text` have been swapped for uniformization purposes. + Note that this method assumes that two `None` inputs are valid inputs. If this is not the case, it should be handled + in the processor's `__call__` method before calling this method. + """ + + def is_url(val) -> bool: + return isinstance(val, str) and val.startswith("http") + + def _is_valid_images_input_for_processor(imgs): + # If we have an list of images, make sure every image is valid + if isinstance(imgs, (list, tuple)): + for img in imgs: + if not _is_valid_images_input_for_processor(img): + return False + # If not a list or tuple, we have been given a single image or batched tensor of images + elif not (is_valid_image(imgs) or is_url(imgs)): + return False + return True + + def _is_valid_text_input_for_processor(t): + if isinstance(t, str): + # Strings are fine + return True + elif isinstance(t, (list, tuple)): + # List are fine as long as they are... + if len(t) == 0: + # ... not empty + return False + for t_s in t: + return _is_valid_text_input_for_processor(t_s) + return False + + def _is_valid(input, validator): + return validator(input) or input is None + + images_is_valid = _is_valid(images, _is_valid_images_input_for_processor) + images_is_text = _is_valid_text_input_for_processor(images) + + text_is_valid = _is_valid(text, _is_valid_text_input_for_processor) + text_is_images = _is_valid_images_input_for_processor(text) + # Handle cases where both inputs are valid + if images_is_valid and text_is_valid: + return images, text + + # Handle cases where inputs need to and can be swapped + if (images is None and text_is_images) or (text is None and images_is_text) or (images_is_text and text_is_images): + logger.warning_once( + "You may have used the wrong order for inputs. `images` should be passed before `text`. " + "The `images` and `text` inputs will be swapped. This behavior will be deprecated in transformers v4.47." + ) + return text, images + + raise ValueError("Invalid input type. Check that `images` and/or `text` are valid inputs.") + + +ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub) +if ProcessorMixin.push_to_hub.__doc__ is not None: + ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format( + object="processor", object_class="AutoProcessor", object_files="processor files" + ) diff --git a/pytorch_utils.py b/pytorch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6469a5930234ba675e314005e2df2f8d52005415 --- /dev/null +++ b/pytorch_utils.py @@ -0,0 +1,362 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import inspect +from typing import Callable, List, Optional, Set, Tuple, Union + +import torch +from packaging import version +from safetensors.torch import storage_ptr, storage_size +from torch import nn + +from .utils import is_torch_greater_or_equal, is_torch_xla_available, logging + + +ALL_LAYERNORM_LAYERS = [nn.LayerNorm] + +logger = logging.get_logger(__name__) + +parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version) + +is_torch_greater_or_equal_than_2_4 = parsed_torch_version_base >= version.parse("2.4") +is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3") +is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2") +is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1") + +# For backwards compatibility (e.g. some remote codes on Hub using those variables). +is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0") +is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13") +is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12") + +# Cache this result has it's a C FFI call which can be pretty time-consuming +_torch_distributed_available = torch.distributed.is_available() + +if is_torch_greater_or_equal("2.5") and _torch_distributed_available: + from torch.distributed.tensor import Replicate + from torch.distributed.tensor.parallel import ( + ColwiseParallel, + RowwiseParallel, + ) + + +def softmax_backward_data(parent, grad_output, output, dim, self): + """ + A function that calls the internal `_softmax_backward_data` PyTorch method and that adjusts the arguments according + to the torch version detected. + """ + + from torch import _softmax_backward_data + + return _softmax_backward_data(grad_output, output, parent.dim, self.dtype) + + +def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear: + """ + Prune a linear layer to keep only entries in index. + + Used to remove heads. + + Args: + layer (`torch.nn.Linear`): The layer to prune. + index (`torch.LongTensor`): The indices to keep in the layer. + dim (`int`, *optional*, defaults to 0): The dimension on which to keep the indices. + + Returns: + `torch.nn.Linear`: The pruned layer as a new layer with `requires_grad=True`. + """ + index = index.to(layer.weight.device) + W = layer.weight.index_select(dim, index).clone().detach() + if layer.bias is not None: + if dim == 1: + b = layer.bias.clone().detach() + else: + b = layer.bias[index].clone().detach() + new_size = list(layer.weight.size()) + new_size[dim] = len(index) + new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device) + new_layer.weight.requires_grad = False + new_layer.weight.copy_(W.contiguous()) + new_layer.weight.requires_grad = True + if layer.bias is not None: + new_layer.bias.requires_grad = False + new_layer.bias.copy_(b.contiguous()) + new_layer.bias.requires_grad = True + return new_layer + + +class Conv1D(nn.Module): + """ + 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). + + Basically works like a linear layer but the weights are transposed. + + Args: + nf (`int`): The number of output features. + nx (`int`): The number of input features. + """ + + def __init__(self, nf, nx): + super().__init__() + self.nf = nf + self.nx = nx + self.weight = nn.Parameter(torch.empty(nx, nf)) + self.bias = nn.Parameter(torch.zeros(nf)) + nn.init.normal_(self.weight, std=0.02) + + def __repr__(self) -> str: + return "Conv1D(nf={nf}, nx={nx})".format(**self.__dict__) + + def forward(self, x): + size_out = x.size()[:-1] + (self.nf,) + x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) + x = x.view(size_out) + return x + + +def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D: + """ + Prune a Conv1D layer to keep only entries in index. A Conv1D work as a Linear layer (see e.g. BERT) but the weights + are transposed. + + Used to remove heads. + + Args: + layer ([`~pytorch_utils.Conv1D`]): The layer to prune. + index (`torch.LongTensor`): The indices to keep in the layer. + dim (`int`, *optional*, defaults to 1): The dimension on which to keep the indices. + + Returns: + [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`. + """ + index = index.to(layer.weight.device) + W = layer.weight.index_select(dim, index).clone().detach() + if dim == 0: + b = layer.bias.clone().detach() + else: + b = layer.bias[index].clone().detach() + new_size = list(layer.weight.size()) + new_size[dim] = len(index) + new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device) + new_layer.weight.requires_grad = False + new_layer.weight.copy_(W.contiguous()) + new_layer.weight.requires_grad = True + new_layer.bias.requires_grad = False + new_layer.bias.copy_(b.contiguous()) + new_layer.bias.requires_grad = True + return new_layer + + +def prune_layer( + layer: Union[nn.Linear, Conv1D], index: torch.LongTensor, dim: Optional[int] = None +) -> Union[nn.Linear, Conv1D]: + """ + Prune a Conv1D or linear layer to keep only entries in index. + + Used to remove heads. + + Args: + layer (`Union[torch.nn.Linear, Conv1D]`): The layer to prune. + index (`torch.LongTensor`): The indices to keep in the layer. + dim (`int`, *optional*): The dimension on which to keep the indices. + + Returns: + `torch.nn.Linear` or [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`. + """ + if isinstance(layer, nn.Linear): + return prune_linear_layer(layer, index, dim=0 if dim is None else dim) + elif isinstance(layer, Conv1D): + return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim) + else: + raise ValueError(f"Can't prune layer of class {layer.__class__}") + + +def apply_chunking_to_forward( + forward_fn: Callable[..., torch.Tensor], + chunk_size: int, + chunk_dim: int, + *input_tensors, +) -> torch.Tensor: + """ + This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension + `chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory. + + If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly + applying `forward_fn` to `input_tensors`. + + Args: + forward_fn (`Callable[..., torch.Tensor]`): + The forward function of the model. + chunk_size (`int`): + The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`. + chunk_dim (`int`): + The dimension over which the `input_tensors` should be chunked. + input_tensors (`Tuple[torch.Tensor]`): + The input tensors of `forward_fn` which will be chunked + + Returns: + `torch.Tensor`: A tensor with the same shape as the `forward_fn` would have given if applied`. + + + Examples: + + ```python + # rename the usual forward() fn to forward_chunk() + def forward_chunk(self, hidden_states): + hidden_states = self.decoder(hidden_states) + return hidden_states + + + # implement a chunked forward function + def forward(self, hidden_states): + return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states) + ```""" + + assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors" + + # inspect.signature exist since python 3.5 and is a python method -> no problem with backward compatibility + num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters) + if num_args_in_forward_chunk_fn != len(input_tensors): + raise ValueError( + f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but only {len(input_tensors)} input " + "tensors are given" + ) + + if chunk_size > 0: + tensor_shape = input_tensors[0].shape[chunk_dim] + for input_tensor in input_tensors: + if input_tensor.shape[chunk_dim] != tensor_shape: + raise ValueError( + f"All input tenors have to be of the same shape: {tensor_shape}, " + f"found shape {input_tensor.shape[chunk_dim]}" + ) + + if input_tensors[0].shape[chunk_dim] % chunk_size != 0: + raise ValueError( + f"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} has to be a multiple of the chunk " + f"size {chunk_size}" + ) + + num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size + + # chunk input tensor into tuples + input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors) + # apply forward fn to every tuple + output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks)) + # concatenate output at same dimension + return torch.cat(output_chunks, dim=chunk_dim) + + return forward_fn(*input_tensors) + + +def find_pruneable_heads_and_indices( + heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int] +) -> Tuple[Set[int], torch.LongTensor]: + """ + Finds the heads and their indices taking `already_pruned_heads` into account. + + Args: + heads (`List[int]`): List of the indices of heads to prune. + n_heads (`int`): The number of heads in the model. + head_size (`int`): The size of each head. + already_pruned_heads (`Set[int]`): A set of already pruned heads. + + Returns: + `Tuple[Set[int], torch.LongTensor]`: A tuple with the indices of heads to prune taking `already_pruned_heads` + into account and the indices of rows/columns to keep in the layer weight. + """ + mask = torch.ones(n_heads, head_size) + heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads + for head in heads: + # Compute how many pruned heads are before the head and move the index accordingly + head = head - sum(1 if h < head else 0 for h in already_pruned_heads) + mask[head] = 0 + mask = mask.view(-1).contiguous().eq(1) + index: torch.LongTensor = torch.arange(len(mask))[mask].long() + return heads, index + + +def meshgrid( + *tensors: Union[torch.Tensor, List[torch.Tensor]], indexing: Optional[str] = None +) -> Tuple[torch.Tensor, ...]: + """ + Wrapper around torch.meshgrid to avoid warning messages about the introduced `indexing` argument. + + Reference: https://pytorch.org/docs/1.13/generated/torch.meshgrid.html + """ + return torch.meshgrid(*tensors, indexing=indexing) + + +def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]: + """ + Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For + example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is + guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with + non-overlapping lifetimes may have the same id. + """ + if tensor.device.type == "xla" and is_torch_xla_available(): + # NOTE: xla tensors dont have storage + # use some other unique id to distinguish. + # this is a XLA tensor, it must be created using torch_xla's + # device. So the following import is safe: + import torch_xla + + unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor) + else: + unique_id = storage_ptr(tensor) + + return tensor.device, unique_id, storage_size(tensor) + + +def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int) -> torch.Tensor: + """ + Same as `torch.isin` without flags, but MPS-friendly. We can remove this function when we stop supporting + torch <= 2.3. See https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075 + + Args: + elements (`torch.Tensor`): Input elements + test_elements (`torch.Tensor` or `int`): The elements to check against. + + Returns: + `torch.Tensor`: A boolean tensor of the same shape as `elements` that is True for `elements` in `test_elements` + and False otherwise + """ + + if elements.device.type == "mps" and not is_torch_greater_or_equal_than_2_4: + test_elements = torch.tensor(test_elements) + if test_elements.ndim == 0: + test_elements = test_elements.unsqueeze(0) + return elements.tile(test_elements.shape[0], 1).eq(test_elements.unsqueeze(1)).sum(dim=0).bool().squeeze() + else: + # Note: don't use named arguments in `torch.isin`, see https://github.com/pytorch/pytorch/issues/126045 + return torch.isin(elements, test_elements) + + +def translate_to_torch_parallel_style(style: str): + """ + In model configurations, we use a neutral type (string) to specify parallel + styles, here we translate them into torch.distributed tensor-parallel + types. + """ + if not isinstance(style, str): + raise ValueError(f"Unsupported parallel style type {type(style)}, expected str") + + if style == "colwise": + return ColwiseParallel() + elif style == "rowwise": + return RowwiseParallel() + elif style == "colwise_rep": + return ColwiseParallel(output_layouts=Replicate()) + else: + raise ValueError(f"Unsupported parallel style value: {style}") diff --git a/safetensors_conversion.py b/safetensors_conversion.py new file mode 100644 index 0000000000000000000000000000000000000000..f1612d3ea57c98fd1d383887cfbeb4e2882d3963 --- /dev/null +++ b/safetensors_conversion.py @@ -0,0 +1,105 @@ +from typing import Optional + +import requests +from huggingface_hub import Discussion, HfApi, get_repo_discussions + +from .utils import cached_file, http_user_agent, logging + + +logger = logging.get_logger(__name__) + + +def previous_pr(api: HfApi, model_id: str, pr_title: str, token: str) -> Optional["Discussion"]: + main_commit = api.list_repo_commits(model_id, token=token)[0].commit_id + for discussion in get_repo_discussions(repo_id=model_id, token=token): + if discussion.title == pr_title and discussion.status == "open" and discussion.is_pull_request: + commits = api.list_repo_commits(model_id, revision=discussion.git_reference, token=token) + + if main_commit == commits[1].commit_id: + return discussion + return None + + +def spawn_conversion(token: str, private: bool, model_id: str): + logger.info("Attempting to convert .bin model on the fly to safetensors.") + + safetensors_convert_space_url = "https://safetensors-convert.hf.space" + sse_url = f"{safetensors_convert_space_url}/call/run" + + def start(_sse_connection): + for line in _sse_connection.iter_lines(): + line = line.decode() + if line.startswith("event:"): + status = line[7:] + logger.debug(f"Safetensors conversion status: {status}") + + if status == "complete": + return + elif status == "heartbeat": + logger.debug("Heartbeat") + else: + logger.debug(f"Unknown status {status}") + else: + logger.debug(line) + + data = {"data": [model_id, private, token]} + + result = requests.post(sse_url, stream=True, json=data).json() + event_id = result["event_id"] + + with requests.get(f"{sse_url}/{event_id}", stream=True) as sse_connection: + try: + logger.debug("Spawning safetensors automatic conversion.") + start(sse_connection) + except Exception as e: + logger.warning(f"Error during conversion: {repr(e)}") + + +def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs): + private = api.model_info(model_id).private + + logger.info("Attempting to create safetensors variant") + pr_title = "Adding `safetensors` variant of this model" + token = kwargs.get("token") + + # This looks into the current repo's open PRs to see if a PR for safetensors was already open. If so, it + # returns it. It checks that the PR was opened by the bot and not by another user so as to prevent + # security breaches. + pr = previous_pr(api, model_id, pr_title, token=token) + + if pr is None or (not private and pr.author != "SFconvertbot"): + spawn_conversion(token, private, model_id) + pr = previous_pr(api, model_id, pr_title, token=token) + else: + logger.info("Safetensors PR exists") + + sha = f"refs/pr/{pr.num}" + + return sha + + +def auto_conversion(pretrained_model_name_or_path: str, ignore_errors_during_conversion=False, **cached_file_kwargs): + try: + api = HfApi(token=cached_file_kwargs.get("token"), headers={"user-agent": http_user_agent()}) + sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs) + + if sha is None: + return None, None + cached_file_kwargs["revision"] = sha + del cached_file_kwargs["_commit_hash"] + + # This is an additional HEAD call that could be removed if we could infer sharded/non-sharded from the PR + # description. + sharded = api.file_exists( + pretrained_model_name_or_path, + "model.safetensors.index.json", + revision=sha, + token=cached_file_kwargs.get("token"), + ) + filename = "model.safetensors.index.json" if sharded else "model.safetensors" + + resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) + return resolved_archive_file, sha, sharded + except Exception as e: + if not ignore_errors_during_conversion: + raise e diff --git a/sentencepiece_bpe.cpython-312.pyc b/sentencepiece_bpe.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..260af1df06280e71b122d14601ef4ba74107df2d Binary files /dev/null and b/sentencepiece_bpe.cpython-312.pyc differ diff --git a/sentencepiece_bpe.py b/sentencepiece_bpe.py new file mode 100644 index 0000000000000000000000000000000000000000..d16f343bff05727a358b48596e02aadfd05178eb --- /dev/null +++ b/sentencepiece_bpe.py @@ -0,0 +1,103 @@ +from typing import Dict, Iterator, List, Optional, Tuple, Union + +from tokenizers import AddedToken, Tokenizer, decoders, pre_tokenizers, trainers +from tokenizers.models import BPE +from tokenizers.normalizers import NFKC + +from .base_tokenizer import BaseTokenizer + + +class SentencePieceBPETokenizer(BaseTokenizer): + """SentencePiece BPE Tokenizer + + Represents the BPE algorithm, with the pretokenization used by SentencePiece + """ + + def __init__( + self, + vocab: Optional[Union[str, Dict[str, int]]] = None, + merges: Optional[Union[str, Dict[Tuple[int, int], Tuple[int, int]]]] = None, + unk_token: Union[str, AddedToken] = "", + replacement: str = "▁", + add_prefix_space: bool = True, + dropout: Optional[float] = None, + fuse_unk: Optional[bool] = False, + ): + if vocab is not None and merges is not None: + tokenizer = Tokenizer(BPE(vocab, merges, dropout=dropout, unk_token=unk_token, fuse_unk=fuse_unk)) + else: + tokenizer = Tokenizer(BPE(dropout=dropout, unk_token=unk_token, fuse_unk=fuse_unk)) + + if tokenizer.token_to_id(str(unk_token)) is not None: + tokenizer.add_special_tokens([str(unk_token)]) + + tokenizer.normalizer = NFKC() + prepend_scheme = "always" if add_prefix_space else "never" + tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme) + tokenizer.decoder = decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme) + + parameters = { + "model": "SentencePieceBPE", + "unk_token": unk_token, + "replacement": replacement, + "add_prefix_space": add_prefix_space, + "dropout": dropout, + } + + super().__init__(tokenizer, parameters) + + @staticmethod + def from_file(vocab_filename: str, merges_filename: str, **kwargs): + vocab, merges = BPE.read_file(vocab_filename, merges_filename) + return SentencePieceBPETokenizer(vocab, merges, **kwargs) + + def train( + self, + files: Union[str, List[str]], + vocab_size: int = 30000, + min_frequency: int = 2, + special_tokens: List[Union[str, AddedToken]] = [""], + limit_alphabet: int = 1000, + initial_alphabet: List[str] = [], + show_progress: bool = True, + ): + """Train the model using the given files""" + + trainer = trainers.BpeTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + special_tokens=special_tokens, + limit_alphabet=limit_alphabet, + initial_alphabet=initial_alphabet, + show_progress=show_progress, + ) + if isinstance(files, str): + files = [files] + self._tokenizer.train(files, trainer=trainer) + + def train_from_iterator( + self, + iterator: Union[Iterator[str], Iterator[Iterator[str]]], + vocab_size: int = 30000, + min_frequency: int = 2, + special_tokens: List[Union[str, AddedToken]] = [""], + limit_alphabet: int = 1000, + initial_alphabet: List[str] = [], + show_progress: bool = True, + length: Optional[int] = None, + ): + """Train the model using the given iterator""" + + trainer = trainers.BpeTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + special_tokens=special_tokens, + limit_alphabet=limit_alphabet, + initial_alphabet=initial_alphabet, + show_progress=show_progress, + ) + self._tokenizer.train_from_iterator( + iterator, + trainer=trainer, + length=length, + ) diff --git a/sentencepiece_unigram.cpython-312.pyc b/sentencepiece_unigram.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..188d9f6bf04d463288b659d215161465d201a925 Binary files /dev/null and b/sentencepiece_unigram.cpython-312.pyc differ diff --git a/sentencepiece_unigram.py b/sentencepiece_unigram.py new file mode 100644 index 0000000000000000000000000000000000000000..da3fe32e88fe5521a6615ea86998f349471e8646 --- /dev/null +++ b/sentencepiece_unigram.py @@ -0,0 +1,196 @@ +import json +import os +from typing import Iterator, List, Optional, Union, Tuple + +from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, trainers +from tokenizers.models import Unigram + +from .base_tokenizer import BaseTokenizer + + +class SentencePieceUnigramTokenizer(BaseTokenizer): + """SentencePiece Unigram Tokenizer + + Represents the Unigram algorithm, with the pretokenization used by SentencePiece + """ + + def __init__( + self, + vocab: Optional[List[Tuple[str, float]]] = None, + replacement: str = "▁", + add_prefix_space: bool = True, + ): + if vocab is not None: + # Let Unigram(..) fail if only one of them is None + tokenizer = Tokenizer(Unigram(vocab)) + else: + tokenizer = Tokenizer(Unigram()) + + tokenizer.normalizer = normalizers.Sequence( + [normalizers.Nmt(), normalizers.NFKC(), normalizers.Replace(Regex(" {2,}"), " ")] + ) + prepend_scheme = "always" if add_prefix_space else "never" + tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme) + tokenizer.decoder = decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme) + + parameters = { + "model": "SentencePieceUnigram", + "replacement": replacement, + "add_prefix_space": add_prefix_space, + } + + super().__init__(tokenizer, parameters) + + def train( + self, + files: Union[str, List[str]], + vocab_size: int = 8000, + show_progress: bool = True, + special_tokens: Optional[List[Union[str, AddedToken]]] = None, + initial_alphabet: Optional[List[str]] = None, + unk_token: Optional[str] = None, + ): + """ + Train the model using the given files + + Args: + files (:obj:`List[str]`): + A list of path to the files that we should use for training + vocab_size (:obj:`int`): + The size of the final vocabulary, including all tokens and alphabet. + show_progress (:obj:`bool`): + Whether to show progress bars while training. + special_tokens (:obj:`List[Union[str, AddedToken]]`, `optional`): + A list of special tokens the model should know of. + initial_alphabet (:obj:`List[str]`, `optional`): + A list of characters to include in the initial alphabet, even + if not seen in the training dataset. + If the strings contain more than one character, only the first one + is kept. + unk_token (:obj:`str`, `optional`): + The unknown token to be used by the model. + """ + + if special_tokens is None: + special_tokens = [] + + if initial_alphabet is None: + initial_alphabet = [] + + trainer = trainers.UnigramTrainer( + vocab_size=vocab_size, + special_tokens=special_tokens, + show_progress=show_progress, + initial_alphabet=initial_alphabet, + unk_token=unk_token, + ) + + if isinstance(files, str): + files = [files] + self._tokenizer.train(files, trainer=trainer) + + def train_from_iterator( + self, + iterator: Union[Iterator[str], Iterator[Iterator[str]]], + vocab_size: int = 8000, + show_progress: bool = True, + special_tokens: Optional[List[Union[str, AddedToken]]] = None, + initial_alphabet: Optional[List[str]] = None, + unk_token: Optional[str] = None, + length: Optional[int] = None, + ): + """ + Train the model using the given iterator + + Args: + iterator (:obj:`Union[Iterator[str], Iterator[Iterator[str]]]`): + Any iterator over strings or list of strings + vocab_size (:obj:`int`): + The size of the final vocabulary, including all tokens and alphabet. + show_progress (:obj:`bool`): + Whether to show progress bars while training. + special_tokens (:obj:`List[Union[str, AddedToken]]`, `optional`): + A list of special tokens the model should know of. + initial_alphabet (:obj:`List[str]`, `optional`): + A list of characters to include in the initial alphabet, even + if not seen in the training dataset. + If the strings contain more than one character, only the first one + is kept. + unk_token (:obj:`str`, `optional`): + The unknown token to be used by the model. + length (:obj:`int`, `optional`): + The total number of sequences in the iterator. This is used to + provide meaningful progress tracking + """ + + if special_tokens is None: + special_tokens = [] + + if initial_alphabet is None: + initial_alphabet = [] + + trainer = trainers.UnigramTrainer( + vocab_size=vocab_size, + special_tokens=special_tokens, + show_progress=show_progress, + initial_alphabet=initial_alphabet, + unk_token=unk_token, + ) + + self._tokenizer.train_from_iterator( + iterator, + trainer=trainer, + length=length, + ) + + @staticmethod + def from_spm(filename: str): + try: + import sys + + sys.path.append(".") + + import sentencepiece_model_pb2 as model + except Exception: + raise Exception( + "You don't seem to have the required protobuf file, in order to use this function you need to run `pip install protobuf` and `wget https://raw.githubusercontent.com/google/sentencepiece/master/python/src/sentencepiece/sentencepiece_model_pb2.py` for us to be able to read the intrinsics of your spm_file. `pip install sentencepiece` is not required." + ) + + m = model.ModelProto() + m.ParseFromString(open(filename, "rb").read()) + + precompiled_charsmap = m.normalizer_spec.precompiled_charsmap + vocab = [(piece.piece, piece.score) for piece in m.pieces] + unk_id = m.trainer_spec.unk_id + model_type = m.trainer_spec.model_type + byte_fallback = m.trainer_spec.byte_fallback + if model_type != 1: + raise Exception( + "You're trying to run a `Unigram` model but you're file was trained with a different algorithm" + ) + + replacement = "▁" + add_prefix_space = True + + tokenizer = Tokenizer(Unigram(vocab, unk_id, byte_fallback)) + + if precompiled_charsmap: + tokenizer.normalizer = normalizers.Sequence( + [ + normalizers.Precompiled(precompiled_charsmap), + normalizers.Replace(Regex(" {2,}"), " "), + ] + ) + else: + tokenizer.normalizer = normalizers.Sequence([normalizers.Replace(Regex(" {2,}"), " ")]) + prepend_scheme = "always" if add_prefix_space else "never" + tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme) + tokenizer.decoder = decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme) + + parameters = { + "model": "SentencePieceUnigram", + } + + obj = BaseTokenizer.__new__(SentencePieceUnigramTokenizer, tokenizer, parameters) + BaseTokenizer.__init__(obj, tokenizer, parameters) + return obj diff --git a/testing_utils.py b/testing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7876b22a2bb9070828f76c2cc229ab8a7c264825 --- /dev/null +++ b/testing_utils.py @@ -0,0 +1,2862 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import contextlib +import copy +import doctest +import functools +import gc +import importlib +import inspect +import logging +import multiprocessing +import os +import re +import shlex +import shutil +import subprocess +import sys +import tempfile +import threading +import time +import unittest +from collections import defaultdict +from collections.abc import Mapping +from dataclasses import MISSING, fields +from functools import wraps +from io import StringIO +from pathlib import Path +from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union +from unittest import mock +from unittest.mock import patch + +import huggingface_hub.utils +import urllib3 +from huggingface_hub import delete_repo + +from transformers import logging as transformers_logging + +from .integrations import ( + is_clearml_available, + is_optuna_available, + is_ray_available, + is_sigopt_available, + is_tensorboard_available, + is_wandb_available, +) +from .integrations.deepspeed import is_deepspeed_available +from .utils import ( + ACCELERATE_MIN_VERSION, + GGUF_MIN_VERSION, + is_accelerate_available, + is_apex_available, + is_aqlm_available, + is_auto_awq_available, + is_auto_gptq_available, + is_av_available, + is_bitsandbytes_available, + is_bitsandbytes_multi_backend_available, + is_bs4_available, + is_compressed_tensors_available, + is_cv2_available, + is_cython_available, + is_detectron2_available, + is_eetq_available, + is_essentia_available, + is_faiss_available, + is_fbgemm_gpu_available, + is_flash_attn_2_available, + is_flax_available, + is_flute_available, + is_fsdp_available, + is_ftfy_available, + is_g2p_en_available, + is_galore_torch_available, + is_gguf_available, + is_grokadamw_available, + is_hadamard_available, + is_ipex_available, + is_jieba_available, + is_jinja_available, + is_jumanpp_available, + is_keras_nlp_available, + is_levenshtein_available, + is_librosa_available, + is_liger_kernel_available, + is_lomo_available, + is_natten_available, + is_nltk_available, + is_onnx_available, + is_optimum_available, + is_optimum_quanto_available, + is_pandas_available, + is_peft_available, + is_phonemizer_available, + is_pretty_midi_available, + is_pyctcdecode_available, + is_pytesseract_available, + is_pytest_available, + is_pytorch_quantization_available, + is_rjieba_available, + is_sacremoses_available, + is_safetensors_available, + is_schedulefree_available, + is_scipy_available, + is_sentencepiece_available, + is_seqio_available, + is_soundfile_available, + is_spacy_available, + is_sudachi_available, + is_sudachi_projection_available, + is_tensorflow_probability_available, + is_tensorflow_text_available, + is_tf2onnx_available, + is_tf_available, + is_tiktoken_available, + is_timm_available, + is_tokenizers_available, + is_torch_available, + is_torch_bf16_available_on_device, + is_torch_bf16_cpu_available, + is_torch_bf16_gpu_available, + is_torch_deterministic, + is_torch_fp16_available_on_device, + is_torch_neuroncore_available, + is_torch_npu_available, + is_torch_sdpa_available, + is_torch_tensorrt_fx_available, + is_torch_tf32_available, + is_torch_xla_available, + is_torch_xpu_available, + is_torchao_available, + is_torchaudio_available, + is_torchdynamo_available, + is_torchvision_available, + is_vision_available, + is_vptq_available, + strtobool, +) + + +if is_accelerate_available(): + from accelerate.state import AcceleratorState, PartialState + from accelerate.utils.imports import is_fp8_available + + +if is_pytest_available(): + from _pytest.doctest import ( + Module, + _get_checker, + _get_continue_on_failure, + _get_runner, + _is_mocked, + _patch_unwrap_mock_aware, + get_optionflags, + ) + from _pytest.outcomes import skip + from _pytest.pathlib import import_path + from pytest import DoctestItem +else: + Module = object + DoctestItem = object + + +SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy" +DUMMY_UNKNOWN_IDENTIFIER = "julien-c/dummy-unknown" +DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer" +# Used to test Auto{Config, Model, Tokenizer} model_type detection. + +# Used to test the hub +USER = "__DUMMY_TRANSFORMERS_USER__" +ENDPOINT_STAGING = "https://hub-ci.huggingface.co" + +# Not critical, only usable on the sandboxed CI instance. +TOKEN = "hf_94wBhPGp6KrrTH3KDchhKpRxZwd6dmHWLL" + +if is_torch_available(): + import torch + + IS_ROCM_SYSTEM = torch.version.hip is not None + IS_CUDA_SYSTEM = torch.version.cuda is not None +else: + IS_ROCM_SYSTEM = False + IS_CUDA_SYSTEM = False + + +def parse_flag_from_env(key, default=False): + try: + value = os.environ[key] + except KeyError: + # KEY isn't set, default to `default`. + _value = default + else: + # KEY is set, convert it to True or False. + try: + _value = strtobool(value) + except ValueError: + # More values are supported, but let's keep the message simple. + raise ValueError(f"If set, {key} must be yes or no.") + return _value + + +def parse_int_from_env(key, default=None): + try: + value = os.environ[key] + except KeyError: + _value = default + else: + try: + _value = int(value) + except ValueError: + raise ValueError(f"If set, {key} must be a int.") + return _value + + +_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) +_run_pt_tf_cross_tests = parse_flag_from_env("RUN_PT_TF_CROSS_TESTS", default=True) +_run_pt_flax_cross_tests = parse_flag_from_env("RUN_PT_FLAX_CROSS_TESTS", default=True) +_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False) +_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False) +_tf_gpu_memory_limit = parse_int_from_env("TF_GPU_MEMORY_LIMIT", default=None) +_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True) +_run_agent_tests = parse_flag_from_env("RUN_AGENT_TESTS", default=False) +_run_third_party_device_tests = parse_flag_from_env("RUN_THIRD_PARTY_DEVICE_TESTS", default=False) + + +def get_device_count(): + import torch + + if is_torch_xpu_available(): + num_devices = torch.xpu.device_count() + else: + num_devices = torch.cuda.device_count() + + return num_devices + + +def is_pt_tf_cross_test(test_case): + """ + Decorator marking a test as a test that control interactions between PyTorch and TensorFlow. + + PT+TF tests are skipped by default and we can run only them by setting RUN_PT_TF_CROSS_TESTS environment variable + to a truthy value and selecting the is_pt_tf_cross_test pytest mark. + + """ + if not _run_pt_tf_cross_tests or not is_torch_available() or not is_tf_available(): + return unittest.skip(reason="test is PT+TF test")(test_case) + else: + try: + import pytest # We don't need a hard dependency on pytest in the main library + except ImportError: + return test_case + else: + return pytest.mark.is_pt_tf_cross_test()(test_case) + + +def is_pt_flax_cross_test(test_case): + """ + Decorator marking a test as a test that control interactions between PyTorch and Flax + + PT+FLAX tests are skipped by default and we can run only them by setting RUN_PT_FLAX_CROSS_TESTS environment + variable to a truthy value and selecting the is_pt_flax_cross_test pytest mark. + + """ + if not _run_pt_flax_cross_tests or not is_torch_available() or not is_flax_available(): + return unittest.skip(reason="test is PT+FLAX test")(test_case) + else: + try: + import pytest # We don't need a hard dependency on pytest in the main library + except ImportError: + return test_case + else: + return pytest.mark.is_pt_flax_cross_test()(test_case) + + +def is_staging_test(test_case): + """ + Decorator marking a test as a staging test. + + Those tests will run using the staging environment of huggingface.co instead of the real model hub. + """ + if not _run_staging: + return unittest.skip(reason="test is staging test")(test_case) + else: + try: + import pytest # We don't need a hard dependency on pytest in the main library + except ImportError: + return test_case + else: + return pytest.mark.is_staging_test()(test_case) + + +def is_pipeline_test(test_case): + """ + Decorator marking a test as a pipeline test. If RUN_PIPELINE_TESTS is set to a falsy value, those tests will be + skipped. + """ + if not _run_pipeline_tests: + return unittest.skip(reason="test is pipeline test")(test_case) + else: + try: + import pytest # We don't need a hard dependency on pytest in the main library + except ImportError: + return test_case + else: + return pytest.mark.is_pipeline_test()(test_case) + + +def is_agent_test(test_case): + """ + Decorator marking a test as an agent test. If RUN_TOOL_TESTS is set to a falsy value, those tests will be skipped. + """ + if not _run_agent_tests: + return unittest.skip(reason="test is an agent test")(test_case) + else: + try: + import pytest # We don't need a hard dependency on pytest in the main library + except ImportError: + return test_case + else: + return pytest.mark.is_agent_test()(test_case) + + +def slow(test_case): + """ + Decorator marking a test as slow. + + Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them. + + """ + return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) + + +def tooslow(test_case): + """ + Decorator marking a test as too slow. + + Slow tests are skipped while they're in the process of being fixed. No test should stay tagged as "tooslow" as + these will not be tested by the CI. + + """ + return unittest.skip(reason="test is too slow")(test_case) + + +def skip_if_not_implemented(test_func): + @functools.wraps(test_func) + def wrapper(*args, **kwargs): + try: + return test_func(*args, **kwargs) + except NotImplementedError as e: + raise unittest.SkipTest(f"Test skipped due to NotImplementedError: {e}") + + return wrapper + + +def apply_skip_if_not_implemented(cls): + """ + Class decorator to apply @skip_if_not_implemented to all test methods. + """ + for attr_name in dir(cls): + if attr_name.startswith("test_"): + attr = getattr(cls, attr_name) + if callable(attr): + setattr(cls, attr_name, skip_if_not_implemented(attr)) + return cls + + +def custom_tokenizers(test_case): + """ + Decorator marking a test for a custom tokenizer. + + Custom tokenizers require additional dependencies, and are skipped by default. Set the RUN_CUSTOM_TOKENIZERS + environment variable to a truthy value to run them. + """ + return unittest.skipUnless(_run_custom_tokenizers, "test of custom tokenizers")(test_case) + + +def require_bs4(test_case): + """ + Decorator marking a test that requires BeautifulSoup4. These tests are skipped when BeautifulSoup4 isn't installed. + """ + return unittest.skipUnless(is_bs4_available(), "test requires BeautifulSoup4")(test_case) + + +def require_galore_torch(test_case): + """ + Decorator marking a test that requires GaLore. These tests are skipped when GaLore isn't installed. + https://github.com/jiaweizzhao/GaLore + """ + return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case) + + +def require_lomo(test_case): + """ + Decorator marking a test that requires LOMO. These tests are skipped when LOMO-optim isn't installed. + https://github.com/OpenLMLab/LOMO + """ + return unittest.skipUnless(is_lomo_available(), "test requires LOMO")(test_case) + + +def require_grokadamw(test_case): + """ + Decorator marking a test that requires GrokAdamW. These tests are skipped when GrokAdamW isn't installed. + """ + return unittest.skipUnless(is_grokadamw_available(), "test requires GrokAdamW")(test_case) + + +def require_schedulefree(test_case): + """ + Decorator marking a test that requires schedulefree. These tests are skipped when schedulefree isn't installed. + https://github.com/facebookresearch/schedule_free + """ + return unittest.skipUnless(is_schedulefree_available(), "test requires schedulefree")(test_case) + + +def require_cv2(test_case): + """ + Decorator marking a test that requires OpenCV. + + These tests are skipped when OpenCV isn't installed. + + """ + return unittest.skipUnless(is_cv2_available(), "test requires OpenCV")(test_case) + + +def require_levenshtein(test_case): + """ + Decorator marking a test that requires Levenshtein. + + These tests are skipped when Levenshtein isn't installed. + + """ + return unittest.skipUnless(is_levenshtein_available(), "test requires Levenshtein")(test_case) + + +def require_nltk(test_case): + """ + Decorator marking a test that requires NLTK. + + These tests are skipped when NLTK isn't installed. + + """ + return unittest.skipUnless(is_nltk_available(), "test requires NLTK")(test_case) + + +def require_accelerate(test_case, min_version: str = ACCELERATE_MIN_VERSION): + """ + Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed. + """ + return unittest.skipUnless( + is_accelerate_available(min_version), f"test requires accelerate version >= {min_version}" + )(test_case) + + +def require_gguf(test_case, min_version: str = GGUF_MIN_VERSION): + """ + Decorator marking a test that requires ggguf. These tests are skipped when gguf isn't installed. + """ + return unittest.skipUnless(is_gguf_available(min_version), f"test requires gguf version >= {min_version}")( + test_case + ) + + +def require_fsdp(test_case, min_version: str = "1.12.0"): + """ + Decorator marking a test that requires fsdp. These tests are skipped when fsdp isn't installed. + """ + return unittest.skipUnless(is_fsdp_available(min_version), f"test requires torch version >= {min_version}")( + test_case + ) + + +def require_g2p_en(test_case): + """ + Decorator marking a test that requires g2p_en. These tests are skipped when SentencePiece isn't installed. + """ + return unittest.skipUnless(is_g2p_en_available(), "test requires g2p_en")(test_case) + + +def require_safetensors(test_case): + """ + Decorator marking a test that requires safetensors. These tests are skipped when safetensors isn't installed. + """ + return unittest.skipUnless(is_safetensors_available(), "test requires safetensors")(test_case) + + +def require_rjieba(test_case): + """ + Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed. + """ + return unittest.skipUnless(is_rjieba_available(), "test requires rjieba")(test_case) + + +def require_jieba(test_case): + """ + Decorator marking a test that requires jieba. These tests are skipped when jieba isn't installed. + """ + return unittest.skipUnless(is_jieba_available(), "test requires jieba")(test_case) + + +def require_jinja(test_case): + """ + Decorator marking a test that requires jinja. These tests are skipped when jinja isn't installed. + """ + return unittest.skipUnless(is_jinja_available(), "test requires jinja")(test_case) + + +def require_tf2onnx(test_case): + return unittest.skipUnless(is_tf2onnx_available(), "test requires tf2onnx")(test_case) + + +def require_onnx(test_case): + return unittest.skipUnless(is_onnx_available(), "test requires ONNX")(test_case) + + +def require_timm(test_case): + """ + Decorator marking a test that requires Timm. + + These tests are skipped when Timm isn't installed. + + """ + return unittest.skipUnless(is_timm_available(), "test requires Timm")(test_case) + + +def require_natten(test_case): + """ + Decorator marking a test that requires NATTEN. + + These tests are skipped when NATTEN isn't installed. + + """ + return unittest.skipUnless(is_natten_available(), "test requires natten")(test_case) + + +def require_torch(test_case): + """ + Decorator marking a test that requires PyTorch. + + These tests are skipped when PyTorch isn't installed. + + """ + return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case) + + +def require_flash_attn(test_case): + """ + Decorator marking a test that requires Flash Attention. + + These tests are skipped when Flash Attention isn't installed. + + """ + return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case) + + +def require_torch_sdpa(test_case): + """ + Decorator marking a test that requires PyTorch's SDPA. + + These tests are skipped when requirements are not met (torch version). + """ + return unittest.skipUnless(is_torch_sdpa_available(), "test requires PyTorch SDPA")(test_case) + + +def require_read_token(fn): + """ + A decorator that loads the HF token for tests that require to load gated models. + """ + token = os.getenv("HF_HUB_READ_TOKEN") + + @wraps(fn) + def _inner(*args, **kwargs): + if token is not None: + with patch("huggingface_hub.utils._headers.get_token", return_value=token): + return fn(*args, **kwargs) + else: # Allow running locally with the default token env variable + return fn(*args, **kwargs) + + return _inner + + +def require_peft(test_case): + """ + Decorator marking a test that requires PEFT. + + These tests are skipped when PEFT isn't installed. + + """ + return unittest.skipUnless(is_peft_available(), "test requires PEFT")(test_case) + + +def require_torchvision(test_case): + """ + Decorator marking a test that requires Torchvision. + + These tests are skipped when Torchvision isn't installed. + + """ + return unittest.skipUnless(is_torchvision_available(), "test requires Torchvision")(test_case) + + +def require_torch_or_tf(test_case): + """ + Decorator marking a test that requires PyTorch or TensorFlow. + + These tests are skipped when neither PyTorch not TensorFlow is installed. + + """ + return unittest.skipUnless(is_torch_available() or is_tf_available(), "test requires PyTorch or TensorFlow")( + test_case + ) + + +def require_intel_extension_for_pytorch(test_case): + """ + Decorator marking a test that requires Intel Extension for PyTorch. + + These tests are skipped when Intel Extension for PyTorch isn't installed or it does not match current PyTorch + version. + + """ + return unittest.skipUnless( + is_ipex_available(), + "test requires Intel Extension for PyTorch to be installed and match current PyTorch version, see" + " https://github.com/intel/intel-extension-for-pytorch", + )(test_case) + + +def require_tensorflow_probability(test_case): + """ + Decorator marking a test that requires TensorFlow probability. + + These tests are skipped when TensorFlow probability isn't installed. + + """ + return unittest.skipUnless(is_tensorflow_probability_available(), "test requires TensorFlow probability")( + test_case + ) + + +def require_torchaudio(test_case): + """ + Decorator marking a test that requires torchaudio. These tests are skipped when torchaudio isn't installed. + """ + return unittest.skipUnless(is_torchaudio_available(), "test requires torchaudio")(test_case) + + +def require_tf(test_case): + """ + Decorator marking a test that requires TensorFlow. These tests are skipped when TensorFlow isn't installed. + """ + return unittest.skipUnless(is_tf_available(), "test requires TensorFlow")(test_case) + + +def require_flax(test_case): + """ + Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed + """ + return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case) + + +def require_sentencepiece(test_case): + """ + Decorator marking a test that requires SentencePiece. These tests are skipped when SentencePiece isn't installed. + """ + return unittest.skipUnless(is_sentencepiece_available(), "test requires SentencePiece")(test_case) + + +def require_sacremoses(test_case): + """ + Decorator marking a test that requires Sacremoses. These tests are skipped when Sacremoses isn't installed. + """ + return unittest.skipUnless(is_sacremoses_available(), "test requires Sacremoses")(test_case) + + +def require_seqio(test_case): + """ + Decorator marking a test that requires SentencePiece. These tests are skipped when SentencePiece isn't installed. + """ + return unittest.skipUnless(is_seqio_available(), "test requires Seqio")(test_case) + + +def require_scipy(test_case): + """ + Decorator marking a test that requires Scipy. These tests are skipped when SentencePiece isn't installed. + """ + return unittest.skipUnless(is_scipy_available(), "test requires Scipy")(test_case) + + +def require_tokenizers(test_case): + """ + Decorator marking a test that requires 🤗 Tokenizers. These tests are skipped when 🤗 Tokenizers isn't installed. + """ + return unittest.skipUnless(is_tokenizers_available(), "test requires tokenizers")(test_case) + + +def require_tensorflow_text(test_case): + """ + Decorator marking a test that requires tensorflow_text. These tests are skipped when tensroflow_text isn't + installed. + """ + return unittest.skipUnless(is_tensorflow_text_available(), "test requires tensorflow_text")(test_case) + + +def require_keras_nlp(test_case): + """ + Decorator marking a test that requires keras_nlp. These tests are skipped when keras_nlp isn't installed. + """ + return unittest.skipUnless(is_keras_nlp_available(), "test requires keras_nlp")(test_case) + + +def require_pandas(test_case): + """ + Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed. + """ + return unittest.skipUnless(is_pandas_available(), "test requires pandas")(test_case) + + +def require_pytesseract(test_case): + """ + Decorator marking a test that requires PyTesseract. These tests are skipped when PyTesseract isn't installed. + """ + return unittest.skipUnless(is_pytesseract_available(), "test requires PyTesseract")(test_case) + + +def require_pytorch_quantization(test_case): + """ + Decorator marking a test that requires PyTorch Quantization Toolkit. These tests are skipped when PyTorch + Quantization Toolkit isn't installed. + """ + return unittest.skipUnless(is_pytorch_quantization_available(), "test requires PyTorch Quantization Toolkit")( + test_case + ) + + +def require_vision(test_case): + """ + Decorator marking a test that requires the vision dependencies. These tests are skipped when torchaudio isn't + installed. + """ + return unittest.skipUnless(is_vision_available(), "test requires vision")(test_case) + + +def require_ftfy(test_case): + """ + Decorator marking a test that requires ftfy. These tests are skipped when ftfy isn't installed. + """ + return unittest.skipUnless(is_ftfy_available(), "test requires ftfy")(test_case) + + +def require_spacy(test_case): + """ + Decorator marking a test that requires SpaCy. These tests are skipped when SpaCy isn't installed. + """ + return unittest.skipUnless(is_spacy_available(), "test requires spacy")(test_case) + + +def require_torch_multi_gpu(test_case): + """ + Decorator marking a test that requires a multi-GPU setup (in PyTorch). These tests are skipped on a machine without + multiple GPUs. + + To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests -k "multi_gpu" + """ + if not is_torch_available(): + return unittest.skip(reason="test requires PyTorch")(test_case) + + device_count = get_device_count() + + return unittest.skipUnless(device_count > 1, "test requires multiple GPUs")(test_case) + + +def require_torch_multi_accelerator(test_case): + """ + Decorator marking a test that requires a multi-accelerator (in PyTorch). These tests are skipped on a machine + without multiple accelerators. To run *only* the multi_accelerator tests, assuming all test names contain + multi_accelerator: $ pytest -sv ./tests -k "multi_accelerator" + """ + if not is_torch_available(): + return unittest.skip(reason="test requires PyTorch")(test_case) + + return unittest.skipUnless(backend_device_count(torch_device) > 1, "test requires multiple accelerators")( + test_case + ) + + +def require_torch_non_multi_gpu(test_case): + """ + Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch). + """ + if not is_torch_available(): + return unittest.skip(reason="test requires PyTorch")(test_case) + + import torch + + return unittest.skipUnless(torch.cuda.device_count() < 2, "test requires 0 or 1 GPU")(test_case) + + +def require_torch_non_multi_accelerator(test_case): + """ + Decorator marking a test that requires 0 or 1 accelerator setup (in PyTorch). + """ + if not is_torch_available(): + return unittest.skip(reason="test requires PyTorch")(test_case) + + return unittest.skipUnless(backend_device_count(torch_device) < 2, "test requires 0 or 1 accelerator")(test_case) + + +def require_torch_up_to_2_gpus(test_case): + """ + Decorator marking a test that requires 0 or 1 or 2 GPU setup (in PyTorch). + """ + if not is_torch_available(): + return unittest.skip(reason="test requires PyTorch")(test_case) + + import torch + + return unittest.skipUnless(torch.cuda.device_count() < 3, "test requires 0 or 1 or 2 GPUs")(test_case) + + +def require_torch_up_to_2_accelerators(test_case): + """ + Decorator marking a test that requires 0 or 1 or 2 accelerator setup (in PyTorch). + """ + if not is_torch_available(): + return unittest.skip(reason="test requires PyTorch")(test_case) + + return unittest.skipUnless(backend_device_count(torch_device) < 3, "test requires 0 or 1 or 2 accelerators")( + test_case + ) + + +def require_torch_xla(test_case): + """ + Decorator marking a test that requires TorchXLA (in PyTorch). + """ + return unittest.skipUnless(is_torch_xla_available(), "test requires TorchXLA")(test_case) + + +def require_torch_neuroncore(test_case): + """ + Decorator marking a test that requires NeuronCore (in PyTorch). + """ + return unittest.skipUnless(is_torch_neuroncore_available(check_device=False), "test requires PyTorch NeuronCore")( + test_case + ) + + +def require_torch_npu(test_case): + """ + Decorator marking a test that requires NPU (in PyTorch). + """ + return unittest.skipUnless(is_torch_npu_available(), "test requires PyTorch NPU")(test_case) + + +def require_torch_multi_npu(test_case): + """ + Decorator marking a test that requires a multi-NPU setup (in PyTorch). These tests are skipped on a machine without + multiple NPUs. + + To run *only* the multi_npu tests, assuming all test names contain multi_npu: $ pytest -sv ./tests -k "multi_npu" + """ + if not is_torch_npu_available(): + return unittest.skip(reason="test requires PyTorch NPU")(test_case) + + return unittest.skipUnless(torch.npu.device_count() > 1, "test requires multiple NPUs")(test_case) + + +def require_torch_xpu(test_case): + """ + Decorator marking a test that requires XPU (in PyTorch). + + These tests are skipped when XPU backend is not available. XPU backend might be available either via stock + PyTorch (>=2.4) or via Intel Extension for PyTorch. In the latter case, if IPEX is installed, its version + must match match current PyTorch version. + """ + return unittest.skipUnless(is_torch_xpu_available(), "test requires XPU device")(test_case) + + +def require_non_xpu(test_case): + """ + Decorator marking a test that should be skipped for XPU. + """ + return unittest.skipUnless(torch_device != "xpu", "test requires a non-XPU")(test_case) + + +def require_torch_multi_xpu(test_case): + """ + Decorator marking a test that requires a multi-XPU setup (in PyTorch). These tests are skipped on a machine without + multiple XPUs. + + To run *only* the multi_xpu tests, assuming all test names contain multi_xpu: $ pytest -sv ./tests -k "multi_xpu" + """ + if not is_torch_xpu_available(): + return unittest.skip(reason="test requires PyTorch XPU")(test_case) + + return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case) + + +if is_torch_available(): + # Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode + import torch + + if "TRANSFORMERS_TEST_BACKEND" in os.environ: + backend = os.environ["TRANSFORMERS_TEST_BACKEND"] + try: + _ = importlib.import_module(backend) + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"Failed to import `TRANSFORMERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module. The original error (look up to see its" + f" traceback):\n{e}" + ) from e + + if "TRANSFORMERS_TEST_DEVICE" in os.environ: + torch_device = os.environ["TRANSFORMERS_TEST_DEVICE"] + if torch_device == "cuda" and not torch.cuda.is_available(): + raise ValueError( + f"TRANSFORMERS_TEST_DEVICE={torch_device}, but CUDA is unavailable. Please double-check your testing environment." + ) + if torch_device == "xpu" and not is_torch_xpu_available(): + raise ValueError( + f"TRANSFORMERS_TEST_DEVICE={torch_device}, but XPU is unavailable. Please double-check your testing environment." + ) + if torch_device == "npu" and not is_torch_npu_available(): + raise ValueError( + f"TRANSFORMERS_TEST_DEVICE={torch_device}, but NPU is unavailable. Please double-check your testing environment." + ) + + try: + # try creating device to see if provided device is valid + _ = torch.device(torch_device) + except RuntimeError as e: + raise RuntimeError( + f"Unknown testing device specified by environment variable `TRANSFORMERS_TEST_DEVICE`: {torch_device}" + ) from e + elif torch.cuda.is_available(): + torch_device = "cuda" + elif _run_third_party_device_tests and is_torch_npu_available(): + torch_device = "npu" + elif _run_third_party_device_tests and is_torch_xpu_available(): + torch_device = "xpu" + else: + torch_device = "cpu" +else: + torch_device = None + +if is_tf_available(): + import tensorflow as tf + +if is_flax_available(): + import jax + + jax_device = jax.default_backend() +else: + jax_device = None + + +def require_torchdynamo(test_case): + """Decorator marking a test that requires TorchDynamo""" + return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case) + + +def require_torchao(test_case): + """Decorator marking a test that requires torchao""" + return unittest.skipUnless(is_torchao_available(), "test requires torchao")(test_case) + + +def require_torch_tensorrt_fx(test_case): + """Decorator marking a test that requires Torch-TensorRT FX""" + return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case) + + +def require_torch_gpu(test_case): + """Decorator marking a test that requires CUDA and PyTorch.""" + return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case) + + +def require_torch_gpu_if_bnb_not_multi_backend_enabled(test_case): + """ + Decorator marking a test that requires a GPU if bitsandbytes multi-backend feature is not enabled. + """ + if is_bitsandbytes_available() and is_bitsandbytes_multi_backend_available(): + return test_case + return require_torch_gpu(test_case) + + +def require_torch_accelerator(test_case): + """Decorator marking a test that requires an accessible accelerator and PyTorch.""" + return unittest.skipUnless(torch_device is not None and torch_device != "cpu", "test requires accelerator")( + test_case + ) + + +def require_torch_fp16(test_case): + """Decorator marking a test that requires a device that supports fp16""" + return unittest.skipUnless( + is_torch_fp16_available_on_device(torch_device), "test requires device with fp16 support" + )(test_case) + + +def require_fp8(test_case): + """Decorator marking a test that requires supports for fp8""" + return unittest.skipUnless(is_accelerate_available() and is_fp8_available(), "test requires fp8 support")( + test_case + ) + + +def require_torch_bf16(test_case): + """Decorator marking a test that requires a device that supports bf16""" + return unittest.skipUnless( + is_torch_bf16_available_on_device(torch_device), "test requires device with bf16 support" + )(test_case) + + +def require_torch_bf16_gpu(test_case): + """Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0""" + return unittest.skipUnless( + is_torch_bf16_gpu_available(), + "test requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0", + )(test_case) + + +def require_torch_bf16_cpu(test_case): + """Decorator marking a test that requires torch>=1.10, using CPU.""" + return unittest.skipUnless( + is_torch_bf16_cpu_available(), + "test requires torch>=1.10, using CPU", + )(test_case) + + +def require_deterministic_for_xpu(test_case): + if is_torch_xpu_available(): + return unittest.skipUnless(is_torch_deterministic(), "test requires torch to use deterministic algorithms")( + test_case + ) + else: + return test_case + + +def require_torch_tf32(test_case): + """Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7.""" + return unittest.skipUnless( + is_torch_tf32_available(), "test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7" + )(test_case) + + +def require_detectron2(test_case): + """Decorator marking a test that requires detectron2.""" + return unittest.skipUnless(is_detectron2_available(), "test requires `detectron2`")(test_case) + + +def require_faiss(test_case): + """Decorator marking a test that requires faiss.""" + return unittest.skipUnless(is_faiss_available(), "test requires `faiss`")(test_case) + + +def require_optuna(test_case): + """ + Decorator marking a test that requires optuna. + + These tests are skipped when optuna isn't installed. + + """ + return unittest.skipUnless(is_optuna_available(), "test requires optuna")(test_case) + + +def require_ray(test_case): + """ + Decorator marking a test that requires Ray/tune. + + These tests are skipped when Ray/tune isn't installed. + + """ + return unittest.skipUnless(is_ray_available(), "test requires Ray/tune")(test_case) + + +def require_sigopt(test_case): + """ + Decorator marking a test that requires SigOpt. + + These tests are skipped when SigOpt isn't installed. + + """ + return unittest.skipUnless(is_sigopt_available(), "test requires SigOpt")(test_case) + + +def require_wandb(test_case): + """ + Decorator marking a test that requires wandb. + + These tests are skipped when wandb isn't installed. + + """ + return unittest.skipUnless(is_wandb_available(), "test requires wandb")(test_case) + + +def require_clearml(test_case): + """ + Decorator marking a test requires clearml. + + These tests are skipped when clearml isn't installed. + + """ + return unittest.skipUnless(is_clearml_available(), "test requires clearml")(test_case) + + +def require_soundfile(test_case): + """ + Decorator marking a test that requires soundfile + + These tests are skipped when soundfile isn't installed. + + """ + return unittest.skipUnless(is_soundfile_available(), "test requires soundfile")(test_case) + + +def require_deepspeed(test_case): + """ + Decorator marking a test that requires deepspeed + """ + return unittest.skipUnless(is_deepspeed_available(), "test requires deepspeed")(test_case) + + +def require_apex(test_case): + """ + Decorator marking a test that requires apex + """ + return unittest.skipUnless(is_apex_available(), "test requires apex")(test_case) + + +def require_aqlm(test_case): + """ + Decorator marking a test that requires aqlm + """ + return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case) + + +def require_vptq(test_case): + """ + Decorator marking a test that requires vptq + """ + return unittest.skipUnless(is_vptq_available(), "test requires vptq")(test_case) + + +def require_eetq(test_case): + """ + Decorator marking a test that requires eetq + """ + eetq_available = is_eetq_available() + if eetq_available: + try: + import eetq # noqa: F401 + except ImportError as exc: + if "shard_checkpoint" in str(exc): + # EETQ 1.0.0 is currently broken with the latest transformers because it tries to import the removed + # shard_checkpoint function, see https://github.com/NetEase-FuXi/EETQ/issues/34. + # TODO: Remove once eetq releases a fix and this release is used in CI + eetq_available = False + return unittest.skipUnless(eetq_available, "test requires eetq")(test_case) + + +def require_av(test_case): + """ + Decorator marking a test that requires av + """ + return unittest.skipUnless(is_av_available(), "test requires av")(test_case) + + +def require_bitsandbytes(test_case): + """ + Decorator marking a test that requires the bitsandbytes library. Will be skipped when the library or its hard dependency torch is not installed. + """ + if is_bitsandbytes_available() and is_torch_available(): + try: + import pytest + + return pytest.mark.bitsandbytes(test_case) + except ImportError: + return test_case + else: + return unittest.skip(reason="test requires bitsandbytes and torch")(test_case) + + +def require_optimum(test_case): + """ + Decorator for optimum dependency + """ + return unittest.skipUnless(is_optimum_available(), "test requires optimum")(test_case) + + +def require_tensorboard(test_case): + """ + Decorator for `tensorboard` dependency + """ + return unittest.skipUnless(is_tensorboard_available(), "test requires tensorboard") + + +def require_auto_gptq(test_case): + """ + Decorator for auto_gptq dependency + """ + return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case) + + +def require_auto_awq(test_case): + """ + Decorator for auto_awq dependency + """ + return unittest.skipUnless(is_auto_awq_available(), "test requires autoawq")(test_case) + + +def require_optimum_quanto(test_case): + """ + Decorator for quanto dependency + """ + return unittest.skipUnless(is_optimum_quanto_available(), "test requires optimum-quanto")(test_case) + + +def require_compressed_tensors(test_case): + """ + Decorator for compressed_tensors dependency + """ + return unittest.skipUnless(is_compressed_tensors_available(), "test requires compressed_tensors")(test_case) + + +def require_fbgemm_gpu(test_case): + """ + Decorator for fbgemm_gpu dependency + """ + return unittest.skipUnless(is_fbgemm_gpu_available(), "test requires fbgemm-gpu")(test_case) + + +def require_flute_hadamard(test_case): + """ + Decorator marking a test that requires higgs and hadamard + """ + return unittest.skipUnless( + is_flute_available() and is_hadamard_available(), "test requires flute and fast_hadamard_transform" + )(test_case) + + +def require_phonemizer(test_case): + """ + Decorator marking a test that requires phonemizer + """ + return unittest.skipUnless(is_phonemizer_available(), "test requires phonemizer")(test_case) + + +def require_pyctcdecode(test_case): + """ + Decorator marking a test that requires pyctcdecode + """ + return unittest.skipUnless(is_pyctcdecode_available(), "test requires pyctcdecode")(test_case) + + +def require_librosa(test_case): + """ + Decorator marking a test that requires librosa + """ + return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case) + + +def require_liger_kernel(test_case): + """ + Decorator marking a test that requires liger_kernel + """ + return unittest.skipUnless(is_liger_kernel_available(), "test requires liger_kernel")(test_case) + + +def require_essentia(test_case): + """ + Decorator marking a test that requires essentia + """ + return unittest.skipUnless(is_essentia_available(), "test requires essentia")(test_case) + + +def require_pretty_midi(test_case): + """ + Decorator marking a test that requires pretty_midi + """ + return unittest.skipUnless(is_pretty_midi_available(), "test requires pretty_midi")(test_case) + + +def cmd_exists(cmd): + return shutil.which(cmd) is not None + + +def require_usr_bin_time(test_case): + """ + Decorator marking a test that requires `/usr/bin/time` + """ + return unittest.skipUnless(cmd_exists("/usr/bin/time"), "test requires /usr/bin/time")(test_case) + + +def require_sudachi(test_case): + """ + Decorator marking a test that requires sudachi + """ + return unittest.skipUnless(is_sudachi_available(), "test requires sudachi")(test_case) + + +def require_sudachi_projection(test_case): + """ + Decorator marking a test that requires sudachi_projection + """ + return unittest.skipUnless(is_sudachi_projection_available(), "test requires sudachi which supports projection")( + test_case + ) + + +def require_jumanpp(test_case): + """ + Decorator marking a test that requires jumanpp + """ + return unittest.skipUnless(is_jumanpp_available(), "test requires jumanpp")(test_case) + + +def require_cython(test_case): + """ + Decorator marking a test that requires jumanpp + """ + return unittest.skipUnless(is_cython_available(), "test requires cython")(test_case) + + +def require_tiktoken(test_case): + """ + Decorator marking a test that requires TikToken. These tests are skipped when TikToken isn't installed. + """ + return unittest.skipUnless(is_tiktoken_available(), "test requires TikToken")(test_case) + + +def get_gpu_count(): + """ + Return the number of available gpus (regardless of whether torch, tf or jax is used) + """ + if is_torch_available(): + import torch + + return torch.cuda.device_count() + elif is_tf_available(): + import tensorflow as tf + + return len(tf.config.list_physical_devices("GPU")) + elif is_flax_available(): + import jax + + return jax.device_count() + else: + return 0 + + +def get_tests_dir(append_path=None): + """ + Args: + append_path: optional path to append to the tests dir path + + Return: + The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is + joined after the `tests` dir the former is provided. + + """ + # this function caller's __file__ + caller__file__ = inspect.stack()[1][1] + tests_dir = os.path.abspath(os.path.dirname(caller__file__)) + + while not tests_dir.endswith("tests"): + tests_dir = os.path.dirname(tests_dir) + + if append_path: + return os.path.join(tests_dir, append_path) + else: + return tests_dir + + +# +# Helper functions for dealing with testing text outputs +# The original code came from: +# https://github.com/fastai/fastai/blob/master/tests/utils/text.py + + +# When any function contains print() calls that get overwritten, like progress bars, +# a special care needs to be applied, since under pytest -s captured output (capsys +# or contextlib.redirect_stdout) contains any temporary printed strings, followed by +# \r's. This helper function ensures that the buffer will contain the same output +# with and without -s in pytest, by turning: +# foo bar\r tar mar\r final message +# into: +# final message +# it can handle a single string or a multiline buffer +def apply_print_resets(buf): + return re.sub(r"^.*\r", "", buf, 0, re.M) + + +def assert_screenout(out, what): + out_pr = apply_print_resets(out).lower() + match_str = out_pr.find(what.lower()) + assert match_str != -1, f"expecting to find {what} in output: f{out_pr}" + + +def set_model_tester_for_less_flaky_test(test_case): + target_num_hidden_layers = 1 + # TODO (if possible): Avoid exceptional cases + exceptional_classes = [ + "ZambaModelTester", + "RwkvModelTester", + "AriaVisionText2TextModelTester", + "GPTNeoModelTester", + "DPTModelTester", + ] + if test_case.model_tester.__class__.__name__ in exceptional_classes: + target_num_hidden_layers = None + if hasattr(test_case.model_tester, "out_features") or hasattr(test_case.model_tester, "out_indices"): + target_num_hidden_layers = None + + if hasattr(test_case.model_tester, "num_hidden_layers") and target_num_hidden_layers is not None: + test_case.model_tester.num_hidden_layers = target_num_hidden_layers + if ( + hasattr(test_case.model_tester, "vision_config") + and "num_hidden_layers" in test_case.model_tester.vision_config + and target_num_hidden_layers is not None + ): + test_case.model_tester.vision_config = copy.deepcopy(test_case.model_tester.vision_config) + test_case.model_tester.vision_config["num_hidden_layers"] = target_num_hidden_layers + if ( + hasattr(test_case.model_tester, "text_config") + and "num_hidden_layers" in test_case.model_tester.text_config + and target_num_hidden_layers is not None + ): + test_case.model_tester.text_config = copy.deepcopy(test_case.model_tester.text_config) + test_case.model_tester.text_config["num_hidden_layers"] = target_num_hidden_layers + + # A few model class specific handling + + # For Albert + if hasattr(test_case.model_tester, "num_hidden_groups"): + test_case.model_tester.num_hidden_groups = test_case.model_tester.num_hidden_layers + + +def set_config_for_less_flaky_test(config): + target_attrs = [ + "rms_norm_eps", + "layer_norm_eps", + "norm_eps", + "norm_epsilon", + "layer_norm_epsilon", + "batch_norm_eps", + ] + for target_attr in target_attrs: + setattr(config, target_attr, 1.0) + + # norm layers (layer/group norm, etc.) could cause flaky tests when the tensors have very small variance. + # (We don't need the original epsilon values to check eager/sdpa matches) + attrs = ["text_config", "vision_config", "text_encoder", "audio_encoder", "decoder"] + for attr in attrs: + if hasattr(config, attr): + for target_attr in target_attrs: + setattr(getattr(config, attr), target_attr, 1.0) + + +def set_model_for_less_flaky_test(model): + # Another way to make sure norm layers have desired epsilon. (Some models don't set it from its config.) + target_names = ("LayerNorm", "GroupNorm", "BatchNorm", "RMSNorm", "BatchNorm2d", "BatchNorm1d") + target_attrs = ["eps", "epsilon", "variance_epsilon"] + if is_torch_available() and isinstance(model, torch.nn.Module): + for module in model.modules(): + if type(module).__name__.endswith(target_names): + for attr in target_attrs: + if hasattr(module, attr): + setattr(module, attr, 1.0) + + +class CaptureStd: + """ + Context manager to capture: + + - stdout: replay it, clean it up and make it available via `obj.out` + - stderr: replay it and make it available via `obj.err` + + Args: + out (`bool`, *optional*, defaults to `True`): Whether to capture stdout or not. + err (`bool`, *optional*, defaults to `True`): Whether to capture stderr or not. + replay (`bool`, *optional*, defaults to `True`): Whether to replay or not. + By default each captured stream gets replayed back on context's exit, so that one can see what the test was + doing. If this is a not wanted behavior and the captured data shouldn't be replayed, pass `replay=False` to + disable this feature. + + Examples: + + ```python + # to capture stdout only with auto-replay + with CaptureStdout() as cs: + print("Secret message") + assert "message" in cs.out + + # to capture stderr only with auto-replay + import sys + + with CaptureStderr() as cs: + print("Warning: ", file=sys.stderr) + assert "Warning" in cs.err + + # to capture both streams with auto-replay + with CaptureStd() as cs: + print("Secret message") + print("Warning: ", file=sys.stderr) + assert "message" in cs.out + assert "Warning" in cs.err + + # to capture just one of the streams, and not the other, with auto-replay + with CaptureStd(err=False) as cs: + print("Secret message") + assert "message" in cs.out + # but best use the stream-specific subclasses + + # to capture without auto-replay + with CaptureStd(replay=False) as cs: + print("Secret message") + assert "message" in cs.out + ```""" + + def __init__(self, out=True, err=True, replay=True): + self.replay = replay + + if out: + self.out_buf = StringIO() + self.out = "error: CaptureStd context is unfinished yet, called too early" + else: + self.out_buf = None + self.out = "not capturing stdout" + + if err: + self.err_buf = StringIO() + self.err = "error: CaptureStd context is unfinished yet, called too early" + else: + self.err_buf = None + self.err = "not capturing stderr" + + def __enter__(self): + if self.out_buf: + self.out_old = sys.stdout + sys.stdout = self.out_buf + + if self.err_buf: + self.err_old = sys.stderr + sys.stderr = self.err_buf + + return self + + def __exit__(self, *exc): + if self.out_buf: + sys.stdout = self.out_old + captured = self.out_buf.getvalue() + if self.replay: + sys.stdout.write(captured) + self.out = apply_print_resets(captured) + + if self.err_buf: + sys.stderr = self.err_old + captured = self.err_buf.getvalue() + if self.replay: + sys.stderr.write(captured) + self.err = captured + + def __repr__(self): + msg = "" + if self.out_buf: + msg += f"stdout: {self.out}\n" + if self.err_buf: + msg += f"stderr: {self.err}\n" + return msg + + +# in tests it's the best to capture only the stream that's wanted, otherwise +# it's easy to miss things, so unless you need to capture both streams, use the +# subclasses below (less typing). Or alternatively, configure `CaptureStd` to +# disable the stream you don't need to test. + + +class CaptureStdout(CaptureStd): + """Same as CaptureStd but captures only stdout""" + + def __init__(self, replay=True): + super().__init__(err=False, replay=replay) + + +class CaptureStderr(CaptureStd): + """Same as CaptureStd but captures only stderr""" + + def __init__(self, replay=True): + super().__init__(out=False, replay=replay) + + +class CaptureLogger: + """ + Context manager to capture `logging` streams + + Args: + logger: 'logging` logger object + + Returns: + The captured output is available via `self.out` + + Example: + + ```python + >>> from transformers import logging + >>> from transformers.testing_utils import CaptureLogger + + >>> msg = "Testing 1, 2, 3" + >>> logging.set_verbosity_info() + >>> logger = logging.get_logger("transformers.models.bart.tokenization_bart") + >>> with CaptureLogger(logger) as cl: + ... logger.info(msg) + >>> assert cl.out, msg + "\n" + ``` + """ + + def __init__(self, logger): + self.logger = logger + self.io = StringIO() + self.sh = logging.StreamHandler(self.io) + self.out = "" + + def __enter__(self): + self.logger.addHandler(self.sh) + return self + + def __exit__(self, *exc): + self.logger.removeHandler(self.sh) + self.out = self.io.getvalue() + + def __repr__(self): + return f"captured: {self.out}\n" + + +@contextlib.contextmanager +def LoggingLevel(level): + """ + This is a context manager to temporarily change transformers modules logging level to the desired value and have it + restored to the original setting at the end of the scope. + + Example: + + ```python + with LoggingLevel(logging.INFO): + AutoModel.from_pretrained("openai-community/gpt2") # calls logger.info() several times + ``` + """ + orig_level = transformers_logging.get_verbosity() + try: + transformers_logging.set_verbosity(level) + yield + finally: + transformers_logging.set_verbosity(orig_level) + + +class TemporaryHubRepo: + """Create a temporary Hub repository and return its `RepoUrl` object. This is similar to + `tempfile.TemporaryDirectory` and can be used as a context manager. For example: + + with TemporaryHubRepo(token=self._token) as temp_repo: + ... + + Upon exiting the context, the repository and everything contained in it are removed. + + Example: + + ```python + with TemporaryHubRepo(token=self._token) as temp_repo: + model.push_to_hub(tmp_repo.repo_id, token=self._token) + ``` + """ + + def __init__(self, namespace: Optional[str] = None, token: Optional[str] = None) -> None: + self.token = token + with tempfile.TemporaryDirectory() as tmp_dir: + repo_id = Path(tmp_dir).name + if namespace is not None: + repo_id = f"{namespace}/{repo_id}" + self.repo_url = huggingface_hub.create_repo(repo_id, token=self.token) + + def __enter__(self): + return self.repo_url + + def __exit__(self, exc, value, tb): + delete_repo(repo_id=self.repo_url.repo_id, token=self.token, missing_ok=True) + + +@contextlib.contextmanager +# adapted from https://stackoverflow.com/a/64789046/9201239 +def ExtendSysPath(path: Union[str, os.PathLike]) -> Iterator[None]: + """ + Temporary add given path to `sys.path`. + + Usage : + + ```python + with ExtendSysPath("/path/to/dir"): + mymodule = importlib.import_module("mymodule") + ``` + """ + + path = os.fspath(path) + try: + sys.path.insert(0, path) + yield + finally: + sys.path.remove(path) + + +class TestCasePlus(unittest.TestCase): + """ + This class extends *unittest.TestCase* with additional features. + + Feature 1: A set of fully resolved important file and dir path accessors. + + In tests often we need to know where things are relative to the current test file, and it's not trivial since the + test could be invoked from more than one directory or could reside in sub-directories with different depths. This + class solves this problem by sorting out all the basic paths and provides easy accessors to them: + + - `pathlib` objects (all fully resolved): + + - `test_file_path` - the current test file path (=`__file__`) + - `test_file_dir` - the directory containing the current test file + - `tests_dir` - the directory of the `tests` test suite + - `examples_dir` - the directory of the `examples` test suite + - `repo_root_dir` - the directory of the repository + - `src_dir` - the directory of `src` (i.e. where the `transformers` sub-dir resides) + + - stringified paths---same as above but these return paths as strings, rather than `pathlib` objects: + + - `test_file_path_str` + - `test_file_dir_str` + - `tests_dir_str` + - `examples_dir_str` + - `repo_root_dir_str` + - `src_dir_str` + + Feature 2: Flexible auto-removable temporary dirs which are guaranteed to get removed at the end of test. + + 1. Create a unique temporary dir: + + ```python + def test_whatever(self): + tmp_dir = self.get_auto_remove_tmp_dir() + ``` + + `tmp_dir` will contain the path to the created temporary dir. It will be automatically removed at the end of the + test. + + + 2. Create a temporary dir of my choice, ensure it's empty before the test starts and don't + empty it after the test. + + ```python + def test_whatever(self): + tmp_dir = self.get_auto_remove_tmp_dir("./xxx") + ``` + + This is useful for debug when you want to monitor a specific directory and want to make sure the previous tests + didn't leave any data in there. + + 3. You can override the first two options by directly overriding the `before` and `after` args, leading to the + following behavior: + + `before=True`: the temporary dir will always be cleared at the beginning of the test. + + `before=False`: if the temporary dir already existed, any existing files will remain there. + + `after=True`: the temporary dir will always be deleted at the end of the test. + + `after=False`: the temporary dir will always be left intact at the end of the test. + + Note 1: In order to run the equivalent of `rm -r` safely, only subdirs of the project repository checkout are + allowed if an explicit `tmp_dir` is used, so that by mistake no `/tmp` or similar important part of the filesystem + will get nuked. i.e. please always pass paths that start with `./` + + Note 2: Each test can register multiple temporary dirs and they all will get auto-removed, unless requested + otherwise. + + Feature 3: Get a copy of the `os.environ` object that sets up `PYTHONPATH` specific to the current test suite. This + is useful for invoking external programs from the test suite - e.g. distributed training. + + + ```python + def test_whatever(self): + env = self.get_env() + ```""" + + def setUp(self): + # get_auto_remove_tmp_dir feature: + self.teardown_tmp_dirs = [] + + # figure out the resolved paths for repo_root, tests, examples, etc. + self._test_file_path = inspect.getfile(self.__class__) + path = Path(self._test_file_path).resolve() + self._test_file_dir = path.parents[0] + for up in [1, 2, 3]: + tmp_dir = path.parents[up] + if (tmp_dir / "src").is_dir() and (tmp_dir / "tests").is_dir(): + break + if tmp_dir: + self._repo_root_dir = tmp_dir + else: + raise ValueError(f"can't figure out the root of the repo from {self._test_file_path}") + self._tests_dir = self._repo_root_dir / "tests" + self._examples_dir = self._repo_root_dir / "examples" + self._src_dir = self._repo_root_dir / "src" + + @property + def test_file_path(self): + return self._test_file_path + + @property + def test_file_path_str(self): + return str(self._test_file_path) + + @property + def test_file_dir(self): + return self._test_file_dir + + @property + def test_file_dir_str(self): + return str(self._test_file_dir) + + @property + def tests_dir(self): + return self._tests_dir + + @property + def tests_dir_str(self): + return str(self._tests_dir) + + @property + def examples_dir(self): + return self._examples_dir + + @property + def examples_dir_str(self): + return str(self._examples_dir) + + @property + def repo_root_dir(self): + return self._repo_root_dir + + @property + def repo_root_dir_str(self): + return str(self._repo_root_dir) + + @property + def src_dir(self): + return self._src_dir + + @property + def src_dir_str(self): + return str(self._src_dir) + + def get_env(self): + """ + Return a copy of the `os.environ` object that sets up `PYTHONPATH` correctly, depending on the test suite it's + invoked from. This is useful for invoking external programs from the test suite - e.g. distributed training. + + It always inserts `./src` first, then `./tests` or `./examples` depending on the test suite type and finally + the preset `PYTHONPATH` if any (all full resolved paths). + + """ + env = os.environ.copy() + paths = [self.src_dir_str] + if "/examples" in self.test_file_dir_str: + paths.append(self.examples_dir_str) + else: + paths.append(self.tests_dir_str) + paths.append(env.get("PYTHONPATH", "")) + + env["PYTHONPATH"] = ":".join(paths) + return env + + def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None): + """ + Args: + tmp_dir (`string`, *optional*): + if `None`: + + - a unique temporary path will be created + - sets `before=True` if `before` is `None` + - sets `after=True` if `after` is `None` + else: + + - `tmp_dir` will be created + - sets `before=True` if `before` is `None` + - sets `after=False` if `after` is `None` + before (`bool`, *optional*): + If `True` and the `tmp_dir` already exists, make sure to empty it right away if `False` and the + `tmp_dir` already exists, any existing files will remain there. + after (`bool`, *optional*): + If `True`, delete the `tmp_dir` at the end of the test if `False`, leave the `tmp_dir` and its contents + intact at the end of the test. + + Returns: + tmp_dir(`string`): either the same value as passed via *tmp_dir* or the path to the auto-selected tmp dir + """ + if tmp_dir is not None: + # defining the most likely desired behavior for when a custom path is provided. + # this most likely indicates the debug mode where we want an easily locatable dir that: + # 1. gets cleared out before the test (if it already exists) + # 2. is left intact after the test + if before is None: + before = True + if after is None: + after = False + + # using provided path + path = Path(tmp_dir).resolve() + + # to avoid nuking parts of the filesystem, only relative paths are allowed + if not tmp_dir.startswith("./"): + raise ValueError( + f"`tmp_dir` can only be a relative path, i.e. `./some/path`, but received `{tmp_dir}`" + ) + + # ensure the dir is empty to start with + if before is True and path.exists(): + shutil.rmtree(tmp_dir, ignore_errors=True) + + path.mkdir(parents=True, exist_ok=True) + + else: + # defining the most likely desired behavior for when a unique tmp path is auto generated + # (not a debug mode), here we require a unique tmp dir that: + # 1. is empty before the test (it will be empty in this situation anyway) + # 2. gets fully removed after the test + if before is None: + before = True + if after is None: + after = True + + # using unique tmp dir (always empty, regardless of `before`) + tmp_dir = tempfile.mkdtemp() + + if after is True: + # register for deletion + self.teardown_tmp_dirs.append(tmp_dir) + + return tmp_dir + + def python_one_liner_max_rss(self, one_liner_str): + """ + Runs the passed python one liner (just the code) and returns how much max cpu memory was used to run the + program. + + Args: + one_liner_str (`string`): + a python one liner code that gets passed to `python -c` + + Returns: + max cpu memory bytes used to run the program. This value is likely to vary slightly from run to run. + + Requirements: + this helper needs `/usr/bin/time` to be installed (`apt install time`) + + Example: + + ``` + one_liner_str = 'from transformers import AutoModel; AutoModel.from_pretrained("google-t5/t5-large")' + max_rss = self.python_one_liner_max_rss(one_liner_str) + ``` + """ + + if not cmd_exists("/usr/bin/time"): + raise ValueError("/usr/bin/time is required, install with `apt install time`") + + cmd = shlex.split(f"/usr/bin/time -f %M python -c '{one_liner_str}'") + with CaptureStd() as cs: + execute_subprocess_async(cmd, env=self.get_env()) + # returned data is in KB so convert to bytes + max_rss = int(cs.err.split("\n")[-2].replace("stderr: ", "")) * 1024 + return max_rss + + def tearDown(self): + # get_auto_remove_tmp_dir feature: remove registered temp dirs + for path in self.teardown_tmp_dirs: + shutil.rmtree(path, ignore_errors=True) + self.teardown_tmp_dirs = [] + if is_accelerate_available(): + AcceleratorState._reset_state() + PartialState._reset_state() + + # delete all the env variables having `ACCELERATE` in them + for k in list(os.environ.keys()): + if "ACCELERATE" in k: + del os.environ[k] + + +def mockenv(**kwargs): + """ + this is a convenience wrapper, that allows this :: + + @mockenv(RUN_SLOW=True, USE_TF=False) def test_something(): + run_slow = os.getenv("RUN_SLOW", False) use_tf = os.getenv("USE_TF", False) + + """ + return mock.patch.dict(os.environ, kwargs) + + +# from https://stackoverflow.com/a/34333710/9201239 +@contextlib.contextmanager +def mockenv_context(*remove, **update): + """ + Temporarily updates the `os.environ` dictionary in-place. Similar to mockenv + + The `os.environ` dictionary is updated in-place so that the modification is sure to work in all situations. + + Args: + remove: Environment variables to remove. + update: Dictionary of environment variables and values to add/update. + """ + env = os.environ + update = update or {} + remove = remove or [] + + # List of environment variables being updated or removed. + stomped = (set(update.keys()) | set(remove)) & set(env.keys()) + # Environment variables and values to restore on exit. + update_after = {k: env[k] for k in stomped} + # Environment variables and values to remove on exit. + remove_after = frozenset(k for k in update if k not in env) + + try: + env.update(update) + [env.pop(k, None) for k in remove] + yield + finally: + env.update(update_after) + [env.pop(k) for k in remove_after] + + +# --- pytest conf functions --- # + +# to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once +pytest_opt_registered = {} + + +def pytest_addoption_shared(parser): + """ + This function is to be called from `conftest.py` via `pytest_addoption` wrapper that has to be defined there. + + It allows loading both `conftest.py` files at once without causing a failure due to adding the same `pytest` + option. + + """ + option = "--make-reports" + if option not in pytest_opt_registered: + parser.addoption( + option, + action="store", + default=False, + help="generate report files. The value of this option is used as a prefix to report names", + ) + pytest_opt_registered[option] = 1 + + +def pytest_terminal_summary_main(tr, id): + """ + Generate multiple reports at the end of test suite run - each report goes into a dedicated file in the current + directory. The report files are prefixed with the test suite name. + + This function emulates --duration and -rA pytest arguments. + + This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined + there. + + Args: + - tr: `terminalreporter` passed from `conftest.py` + - id: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is + needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other. + + NB: this functions taps into a private _pytest API and while unlikely, it could break should pytest do internal + changes - also it calls default internal methods of terminalreporter which can be hijacked by various `pytest-` + plugins and interfere. + + """ + from _pytest.config import create_terminal_writer + + if not len(id): + id = "tests" + + config = tr.config + orig_writer = config.get_terminal_writer() + orig_tbstyle = config.option.tbstyle + orig_reportchars = tr.reportchars + + dir = f"reports/{id}" + Path(dir).mkdir(parents=True, exist_ok=True) + report_files = { + k: f"{dir}/{k}.txt" + for k in [ + "durations", + "errors", + "failures_long", + "failures_short", + "failures_line", + "passes", + "stats", + "summary_short", + "warnings", + ] + } + + # custom durations report + # note: there is no need to call pytest --durations=XX to get this separate report + # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66 + dlist = [] + for replist in tr.stats.values(): + for rep in replist: + if hasattr(rep, "duration"): + dlist.append(rep) + if dlist: + dlist.sort(key=lambda x: x.duration, reverse=True) + with open(report_files["durations"], "w") as f: + durations_min = 0.05 # sec + f.write("slowest durations\n") + for i, rep in enumerate(dlist): + if rep.duration < durations_min: + f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted") + break + f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n") + + def summary_failures_short(tr): + # expecting that the reports were --tb=long (default) so we chop them off here to the last frame + reports = tr.getreports("failed") + if not reports: + return + tr.write_sep("=", "FAILURES SHORT STACK") + for rep in reports: + msg = tr._getfailureheadline(rep) + tr.write_sep("_", msg, red=True, bold=True) + # chop off the optional leading extra frames, leaving only the last one + longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.M | re.S) + tr._tw.line(longrepr) + # note: not printing out any rep.sections to keep the report short + + # use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each + # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814 + # note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g. + # pytest-instafail does that) + + # report failures with line/short/long styles + config.option.tbstyle = "auto" # full tb + with open(report_files["failures_long"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_failures() + + # config.option.tbstyle = "short" # short tb + with open(report_files["failures_short"], "w") as f: + tr._tw = create_terminal_writer(config, f) + summary_failures_short(tr) + + config.option.tbstyle = "line" # one line per error + with open(report_files["failures_line"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_failures() + + with open(report_files["errors"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_errors() + + with open(report_files["warnings"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_warnings() # normal warnings + tr.summary_warnings() # final warnings + + tr.reportchars = "wPpsxXEf" # emulate -rA (used in summary_passes() and short_test_summary()) + + # Skip the `passes` report, as it starts to take more than 5 minutes, and sometimes it timeouts on CircleCI if it + # takes > 10 minutes (as this part doesn't generate any output on the terminal). + # (also, it seems there is no useful information in this report, and we rarely need to read it) + # with open(report_files["passes"], "w") as f: + # tr._tw = create_terminal_writer(config, f) + # tr.summary_passes() + + with open(report_files["summary_short"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.short_test_summary() + + with open(report_files["stats"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_stats() + + # restore: + tr._tw = orig_writer + tr.reportchars = orig_reportchars + config.option.tbstyle = orig_tbstyle + + +# --- distributed testing functions --- # + +# adapted from https://stackoverflow.com/a/59041913/9201239 +import asyncio # noqa + + +class _RunOutput: + def __init__(self, returncode, stdout, stderr): + self.returncode = returncode + self.stdout = stdout + self.stderr = stderr + + +async def _read_stream(stream, callback): + while True: + line = await stream.readline() + if line: + callback(line) + else: + break + + +async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput: + if echo: + print("\nRunning: ", " ".join(cmd)) + + p = await asyncio.create_subprocess_exec( + cmd[0], + *cmd[1:], + stdin=stdin, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=env, + ) + + # note: there is a warning for a possible deadlock when using `wait` with huge amounts of data in the pipe + # https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait + # + # If it starts hanging, will need to switch to the following code. The problem is that no data + # will be seen until it's done and if it hangs for example there will be no debug info. + # out, err = await p.communicate() + # return _RunOutput(p.returncode, out, err) + + out = [] + err = [] + + def tee(line, sink, pipe, label=""): + line = line.decode("utf-8").rstrip() + sink.append(line) + if not quiet: + print(label, line, file=pipe) + + # XXX: the timeout doesn't seem to make any difference here + await asyncio.wait( + [ + _read_stream(p.stdout, lambda l: tee(l, out, sys.stdout, label="stdout:")), + _read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:")), + ], + timeout=timeout, + ) + return _RunOutput(await p.wait(), out, err) + + +def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput: + loop = asyncio.get_event_loop() + result = loop.run_until_complete( + _stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo) + ) + + cmd_str = " ".join(cmd) + if result.returncode > 0: + stderr = "\n".join(result.stderr) + raise RuntimeError( + f"'{cmd_str}' failed with returncode {result.returncode}\n\n" + f"The combined stderr from workers follows:\n{stderr}" + ) + + # check that the subprocess actually did run and produced some output, should the test rely on + # the remote side to do the testing + if not result.stdout and not result.stderr: + raise RuntimeError(f"'{cmd_str}' produced no output.") + + return result + + +def pytest_xdist_worker_id(): + """ + Returns an int value of worker's numerical id under `pytest-xdist`'s concurrent workers `pytest -n N` regime, or 0 + if `-n 1` or `pytest-xdist` isn't being used. + """ + worker = os.environ.get("PYTEST_XDIST_WORKER", "gw0") + worker = re.sub(r"^gw", "", worker, 0, re.M) + return int(worker) + + +def get_torch_dist_unique_port(): + """ + Returns a port number that can be fed to `torch.distributed.launch`'s `--master_port` argument. + + Under `pytest-xdist` it adds a delta number based on a worker id so that concurrent tests don't try to use the same + port at once. + """ + port = 29500 + uniq_delta = pytest_xdist_worker_id() + return port + uniq_delta + + +def nested_simplify(obj, decimals=3): + """ + Simplifies an object by rounding float numbers, and downcasting tensors/numpy arrays to get simple equality test + within tests. + """ + import numpy as np + + if isinstance(obj, list): + return [nested_simplify(item, decimals) for item in obj] + if isinstance(obj, tuple): + return tuple([nested_simplify(item, decimals) for item in obj]) + elif isinstance(obj, np.ndarray): + return nested_simplify(obj.tolist()) + elif isinstance(obj, Mapping): + return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()} + elif isinstance(obj, (str, int, np.int64)): + return obj + elif obj is None: + return obj + elif is_torch_available() and isinstance(obj, torch.Tensor): + return nested_simplify(obj.tolist(), decimals) + elif is_tf_available() and tf.is_tensor(obj): + return nested_simplify(obj.numpy().tolist()) + elif isinstance(obj, float): + return round(obj, decimals) + elif isinstance(obj, (np.int32, np.float32, np.float16)): + return nested_simplify(obj.item(), decimals) + else: + raise Exception(f"Not supported: {type(obj)}") + + +def check_json_file_has_correct_format(file_path): + with open(file_path, "r") as f: + lines = f.readlines() + if len(lines) == 1: + # length can only be 1 if dict is empty + assert lines[0] == "{}" + else: + # otherwise make sure json has correct format (at least 3 lines) + assert len(lines) >= 3 + # each key one line, ident should be 2, min length is 3 + assert lines[0].strip() == "{" + for line in lines[1:-1]: + left_indent = len(lines[1]) - len(lines[1].lstrip()) + assert left_indent == 2 + assert lines[-1].strip() == "}" + + +def to_2tuple(x): + if isinstance(x, collections.abc.Iterable): + return x + return (x, x) + + +# These utils relate to ensuring the right error message is received when running scripts +class SubprocessCallException(Exception): + pass + + +def run_command(command: List[str], return_stdout=False): + """ + Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture + if an error occured while running `command` + """ + try: + output = subprocess.check_output(command, stderr=subprocess.STDOUT) + if return_stdout: + if hasattr(output, "decode"): + output = output.decode("utf-8") + return output + except subprocess.CalledProcessError as e: + raise SubprocessCallException( + f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}" + ) from e + + +class RequestCounter: + """ + Helper class that will count all requests made online. + + Might not be robust if urllib3 changes its logging format but should be good enough for us. + + Usage: + ```py + with RequestCounter() as counter: + _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert") + assert counter["GET"] == 0 + assert counter["HEAD"] == 1 + assert counter.total_calls == 1 + ``` + """ + + def __enter__(self): + self._counter = defaultdict(int) + self._thread_id = threading.get_ident() + self._extra_info = [] + + def patched_with_thread_info(func): + def wrap(*args, **kwargs): + self._extra_info.append(threading.get_ident()) + return func(*args, **kwargs) + + return wrap + + self.patcher = patch.object( + urllib3.connectionpool.log, "debug", side_effect=patched_with_thread_info(urllib3.connectionpool.log.debug) + ) + self.mock = self.patcher.start() + return self + + def __exit__(self, *args, **kwargs) -> None: + assert len(self.mock.call_args_list) == len(self._extra_info) + + for thread_id, call in zip(self._extra_info, self.mock.call_args_list): + if thread_id != self._thread_id: + continue + log = call.args[0] % call.args[1:] + for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"): + if method in log: + self._counter[method] += 1 + break + self.patcher.stop() + + def __getitem__(self, key: str) -> int: + return self._counter[key] + + @property + def total_calls(self) -> int: + return sum(self._counter.values()) + + +def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None): + """ + To decorate flaky tests. They will be retried on failures. + + Args: + max_attempts (`int`, *optional*, defaults to 5): + The maximum number of attempts to retry the flaky test. + wait_before_retry (`float`, *optional*): + If provided, will wait that number of seconds before retrying the test. + description (`str`, *optional*): + A string to describe the situation (what / where / why is flaky, link to GH issue/PR comments, errors, + etc.) + """ + + def decorator(test_func_ref): + @functools.wraps(test_func_ref) + def wrapper(*args, **kwargs): + retry_count = 1 + + while retry_count < max_attempts: + try: + return test_func_ref(*args, **kwargs) + + except Exception as err: + print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr) + if wait_before_retry is not None: + time.sleep(wait_before_retry) + retry_count += 1 + + return test_func_ref(*args, **kwargs) + + return wrapper + + return decorator + + +def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None): + """ + To run a test in a subprocess. In particular, this can avoid (GPU) memory issue. + + Args: + test_case (`unittest.TestCase`): + The test that will run `target_func`. + target_func (`Callable`): + The function implementing the actual testing logic. + inputs (`dict`, *optional*, defaults to `None`): + The inputs that will be passed to `target_func` through an (input) queue. + timeout (`int`, *optional*, defaults to `None`): + The timeout (in seconds) that will be passed to the input and output queues. If not specified, the env. + variable `PYTEST_TIMEOUT` will be checked. If still `None`, its value will be set to `600`. + """ + if timeout is None: + timeout = int(os.environ.get("PYTEST_TIMEOUT", 600)) + + start_methohd = "spawn" + ctx = multiprocessing.get_context(start_methohd) + + input_queue = ctx.Queue(1) + output_queue = ctx.JoinableQueue(1) + + # We can't send `unittest.TestCase` to the child, otherwise we get issues regarding pickle. + input_queue.put(inputs, timeout=timeout) + + process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout)) + process.start() + # Kill the child process if we can't get outputs from it in time: otherwise, the hanging subprocess prevents + # the test to exit properly. + try: + results = output_queue.get(timeout=timeout) + output_queue.task_done() + except Exception as e: + process.terminate() + test_case.fail(e) + process.join(timeout=timeout) + + if results["error"] is not None: + test_case.fail(f'{results["error"]}') + + +def run_test_using_subprocess(func): + """ + To decorate a test to run in a subprocess using the `subprocess` module. This could avoid potential GPU memory + issues (GPU OOM or a test that causes many subsequential failing with `CUDA error: device-side assert triggered`). + """ + import pytest + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if os.getenv("_INSIDE_SUB_PROCESS", None) == "1": + func(*args, **kwargs) + else: + test = " ".join(os.environ.get("PYTEST_CURRENT_TEST").split(" ")[:-1]) + try: + import copy + + env = copy.deepcopy(os.environ) + env["_INSIDE_SUB_PROCESS"] = "1" + # This prevents the entries in `short test summary info` given by the subprocess being truncated. so the + # full information can be passed to the parent pytest process. + # See: https://docs.pytest.org/en/stable/explanation/ci.html + env["CI"] = "true" + + # If not subclass of `unitTest.TestCase` and `pytestconfig` is used: try to grab and use the arguments + if "pytestconfig" in kwargs: + command = list(kwargs["pytestconfig"].invocation_params.args) + for idx, x in enumerate(command): + if x in kwargs["pytestconfig"].args: + test = test.split("::")[1:] + command[idx] = "::".join([f"{func.__globals__['__file__']}"] + test) + command = [f"{sys.executable}", "-m", "pytest"] + command + command = [x for x in command if x not in ["--no-summary"]] + # Otherwise, simply run the test with no option at all + else: + command = [f"{sys.executable}", "-m", "pytest", f"{test}"] + + subprocess.run(command, env=env, check=True, capture_output=True) + except subprocess.CalledProcessError as e: + exception_message = e.stdout.decode() + lines = exception_message.split("\n") + # Add a first line with more informative information instead of just `= test session starts =`. + # This makes the `short test summary info` section more useful. + if "= test session starts =" in lines[0]: + text = "" + for line in lines[1:]: + if line.startswith("FAILED "): + text = line[len("FAILED ") :] + text = "".join(text.split(" - ")[1:]) + elif line.startswith("=") and line.endswith("=") and " failed in " in line: + break + elif len(text) > 0: + text += f"\n{line}" + text = "(subprocess) " + text + lines = [text] + lines + exception_message = "\n".join(lines) + raise pytest.fail(exception_message, pytrace=False) + + return wrapper + + +""" +The following contains utils to run the documentation tests without having to overwrite any files. + +The `preprocess_string` function adds `# doctest: +IGNORE_RESULT` markers on the fly anywhere a `load_dataset` call is +made as a print would otherwise fail the corresonding line. + +To skip cuda tests, make sure to call `SKIP_CUDA_DOCTEST=1 pytest --doctest-modules +""" + + +def preprocess_string(string, skip_cuda_tests): + """Prepare a docstring or a `.md` file to be run by doctest. + + The argument `string` would be the whole file content if it is a `.md` file. For a python file, it would be one of + its docstring. In each case, it may contain multiple python code examples. If `skip_cuda_tests` is `True` and a + cuda stuff is detective (with a heuristic), this method will return an empty string so no doctest will be run for + `string`. + """ + codeblock_pattern = r"(```(?:python|py)\s*\n\s*>>> )((?:.*?\n)*?.*?```)" + codeblocks = re.split(re.compile(codeblock_pattern, flags=re.MULTILINE | re.DOTALL), string) + is_cuda_found = False + for i, codeblock in enumerate(codeblocks): + if "load_dataset(" in codeblock and "# doctest: +IGNORE_RESULT" not in codeblock: + codeblocks[i] = re.sub(r"(>>> .*load_dataset\(.*)", r"\1 # doctest: +IGNORE_RESULT", codeblock) + if ( + (">>>" in codeblock or "..." in codeblock) + and re.search(r"cuda|to\(0\)|device=0", codeblock) + and skip_cuda_tests + ): + is_cuda_found = True + break + + modified_string = "" + if not is_cuda_found: + modified_string = "".join(codeblocks) + + return modified_string + + +class HfDocTestParser(doctest.DocTestParser): + """ + Overwrites the DocTestParser from doctest to properly parse the codeblocks that are formatted with black. This + means that there are no extra lines at the end of our snippets. The `# doctest: +IGNORE_RESULT` marker is also + added anywhere a `load_dataset` call is made as a print would otherwise fail the corresponding line. + + Tests involving cuda are skipped base on a naive pattern that should be updated if it is not enough. + """ + + # This regular expression is used to find doctest examples in a + # string. It defines three groups: `source` is the source code + # (including leading indentation and prompts); `indent` is the + # indentation of the first (PS1) line of the source code; and + # `want` is the expected output (including leading indentation). + # fmt: off + _EXAMPLE_RE = re.compile(r''' + # Source consists of a PS1 line followed by zero or more PS2 lines. + (?P + (?:^(?P [ ]*) >>> .*) # PS1 line + (?:\n [ ]* \.\.\. .*)*) # PS2 lines + \n? + # Want consists of any non-blank lines that do not start with PS1. + (?P (?:(?![ ]*$) # Not a blank line + (?![ ]*>>>) # Not a line starting with PS1 + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + (?:(?!```).)* # Match any character except '`' until a '```' is found (this is specific to HF because black removes the last line) + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + (?:\n|$) # Match a new line or end of string + )*) + ''', re.MULTILINE | re.VERBOSE + ) + # fmt: on + + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + skip_cuda_tests: bool = bool(os.environ.get("SKIP_CUDA_DOCTEST", False)) + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + + def parse(self, string, name=""): + """ + Overwrites the `parse` method to incorporate a skip for CUDA tests, and remove logs and dataset prints before + calling `super().parse` + """ + string = preprocess_string(string, self.skip_cuda_tests) + return super().parse(string, name) + + +class HfDoctestModule(Module): + """ + Overwrites the `DoctestModule` of the pytest package to make sure the HFDocTestParser is used when discovering + tests. + """ + + def collect(self) -> Iterable[DoctestItem]: + class MockAwareDocTestFinder(doctest.DocTestFinder): + """A hackish doctest finder that overrides stdlib internals to fix a stdlib bug. + + https://github.com/pytest-dev/pytest/issues/3456 https://bugs.python.org/issue25532 + """ + + def _find_lineno(self, obj, source_lines): + """Doctest code does not take into account `@property`, this + is a hackish way to fix it. https://bugs.python.org/issue17446 + + Wrapped Doctests will need to be unwrapped so the correct line number is returned. This will be + reported upstream. #8796 + """ + if isinstance(obj, property): + obj = getattr(obj, "fget", obj) + + if hasattr(obj, "__wrapped__"): + # Get the main obj in case of it being wrapped + obj = inspect.unwrap(obj) + + # Type ignored because this is a private function. + return super()._find_lineno( # type:ignore[misc] + obj, + source_lines, + ) + + def _find(self, tests, obj, name, module, source_lines, globs, seen) -> None: + if _is_mocked(obj): + return + with _patch_unwrap_mock_aware(): + # Type ignored because this is a private function. + super()._find( # type:ignore[misc] + tests, obj, name, module, source_lines, globs, seen + ) + + if self.path.name == "conftest.py": + module = self.config.pluginmanager._importconftest( + self.path, + self.config.getoption("importmode"), + rootpath=self.config.rootpath, + ) + else: + try: + module = import_path( + self.path, + root=self.config.rootpath, + mode=self.config.getoption("importmode"), + ) + except ImportError: + if self.config.getvalue("doctest_ignore_import_errors"): + skip("unable to import module %r" % self.path) + else: + raise + + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + finder = MockAwareDocTestFinder(parser=HfDocTestParser()) + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + optionflags = get_optionflags(self) + runner = _get_runner( + verbose=False, + optionflags=optionflags, + checker=_get_checker(), + continue_on_failure=_get_continue_on_failure(self.config), + ) + for test in finder.find(module, module.__name__): + if test.examples: # skip empty doctests and cuda + yield DoctestItem.from_parent(self, name=test.name, runner=runner, dtest=test) + + +def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs): + if device not in dispatch_table: + return dispatch_table["default"](*args, **kwargs) + + fn = dispatch_table[device] + + # Some device agnostic functions return values. Need to guard against `None` + # instead at user level. + if fn is None: + return None + return fn(*args, **kwargs) + + +if is_torch_available(): + # Mappings from device names to callable functions to support device agnostic + # testing. + BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed} + BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "default": None} + BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "default": lambda: 1} +else: + BACKEND_MANUAL_SEED = {"default": None} + BACKEND_EMPTY_CACHE = {"default": None} + BACKEND_DEVICE_COUNT = {"default": lambda: 0} + + +def backend_manual_seed(device: str, seed: int): + return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed) + + +def backend_empty_cache(device: str): + return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE) + + +def backend_device_count(device: str): + return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT) + + +if is_torch_available(): + # If `TRANSFORMERS_TEST_DEVICE_SPEC` is enabled we need to import extra entries + # into device to function mappings. + if "TRANSFORMERS_TEST_DEVICE_SPEC" in os.environ: + device_spec_path = os.environ["TRANSFORMERS_TEST_DEVICE_SPEC"] + if not Path(device_spec_path).is_file(): + raise ValueError( + f"Specified path to device spec file is not a file or not found. Received '{device_spec_path}" + ) + + # Try to strip extension for later import – also verifies we are importing a + # python file. + try: + import_name = device_spec_path[: device_spec_path.index(".py")] + except ValueError as e: + raise ValueError(f"Provided device spec file was not a Python file! Received '{device_spec_path}") from e + + device_spec_module = importlib.import_module(import_name) + + # Imported file must contain `DEVICE_NAME`. If it doesn't, terminate early. + try: + device_name = device_spec_module.DEVICE_NAME + except AttributeError as e: + raise AttributeError("Device spec file did not contain `DEVICE_NAME`") from e + + if "TRANSFORMERS_TEST_DEVICE" in os.environ and torch_device != device_name: + msg = f"Mismatch between environment variable `TRANSFORMERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n" + msg += "Either unset `TRANSFORMERS_TEST_DEVICE` or ensure it matches device spec name." + raise ValueError(msg) + + torch_device = device_name + + def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name: str): + try: + # Try to import the function directly + spec_fn = getattr(device_spec_module, attribute_name) + device_fn_dict[torch_device] = spec_fn + except AttributeError as e: + # If the function doesn't exist, and there is no default, throw an error + if "default" not in device_fn_dict: + raise AttributeError( + f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found." + ) from e + + # Add one entry here for each `BACKEND_*` dictionary. + update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN") + update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN") + update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN") + + +def compare_pipeline_output_to_hub_spec(output, hub_spec): + missing_keys = [] + unexpected_keys = [] + all_field_names = {field.name for field in fields(hub_spec)} + matching_keys = sorted([key for key in output.keys() if key in all_field_names]) + + # Fields with a MISSING default are required and must be in the output + for field in fields(hub_spec): + if field.default is MISSING and field.name not in output: + missing_keys.append(field.name) + + # All output keys must match either a required or optional field in the Hub spec + for output_key in output: + if output_key not in all_field_names: + unexpected_keys.append(output_key) + + if missing_keys or unexpected_keys: + error = ["Pipeline output does not match Hub spec!"] + if matching_keys: + error.append(f"Matching keys: {matching_keys}") + if missing_keys: + error.append(f"Missing required keys in pipeline output: {missing_keys}") + if unexpected_keys: + error.append(f"Keys in pipeline output that are not in Hub spec: {unexpected_keys}") + raise KeyError("\n".join(error)) + + +@require_torch +def cleanup(device: str, gc_collect=False): + if gc_collect: + gc.collect() + backend_empty_cache(device) diff --git a/tf_utils.py b/tf_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b91a2ea520f0d03ac724566c2c8ce2ca6361f52d --- /dev/null +++ b/tf_utils.py @@ -0,0 +1,294 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +import numpy as np +import tensorflow as tf + +from .feature_extraction_utils import BatchFeature +from .tokenization_utils_base import BatchEncoding +from .utils import logging + + +logger = logging.get_logger(__name__) + + +def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]: + """ + Deal with dynamic shape in tensorflow cleanly. + + Args: + tensor (`tf.Tensor` or `np.ndarray`): The tensor we want the shape of. + + Returns: + `List[int]`: The shape of the tensor as a list. + """ + if isinstance(tensor, np.ndarray): + return list(tensor.shape) + + dynamic = tf.shape(tensor) + + if tensor.shape == tf.TensorShape(None): + return dynamic + + static = tensor.shape.as_list() + + return [dynamic[i] if s is None else s for i, s in enumerate(static)] + + +def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional[str] = None) -> tf.Tensor: + """ + Stable wrapper that returns the same output as `tf.nn.softmax`, but that works reliably with XLA on CPU. It is + meant as a workaround for the [following issue](https://github.com/tensorflow/tensorflow/issues/55682), and will be + removed after it gets fixed. The arguments and outputs are the same as `tf.nn.softmax`, and relies on the fact that + `softmax(x) = softmax(x + c)` (see https://ogunlao.github.io/2020/04/26/you_dont_really_know_softmax.html). + + Args: + logits (`tf.Tensor`): + Must be one of the following types: half, float32, float64. + axis (`int`, *optional*): + The dimension softmax would be performed on. The default is -1 which indicates the last dimension. + name (`str`, *optional*): + A name for the operation. + + Returns: + `tf.Tensor`: + A Tensor. Has the same type and shape as logits. + """ + # TODO: When the issue linked above gets sorted, add a check on TF version here and use the original function if + # it has the fix. After we drop the support for unfixed versions, remove this function. + return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name) + + +def functional_layernorm(inputs, weight, bias, epsilon=1e-5, axis=-1): + # This is a very simplified functional layernorm, designed to duplicate + # the functionality of PyTorch nn.functional.layer_norm when this is needed to port + # models in Transformers. + + if weight.shape.rank != 1 or bias.shape.rank != 1 or not isinstance(axis, int): + raise NotImplementedError("Only 1D weight and bias tensors are supported for now, with only a single axis.") + + # Get mean and variance on the axis to be normalized + mean, variance = tf.nn.moments(inputs, axes=[axis], keepdims=True) + + if axis != -1: + # Reshape scale and weight to have the same rank as inputs, but with 1 dimensions + # on every dimension except axis + shape = [1] * inputs.shape.rank + shape[axis] = shape_list(inputs)[axis] + weight = tf.reshape(weight, shape) + bias = tf.reshape(bias, shape) + + # Compute layer normalization using the batch_normalization + # function. + outputs = tf.nn.batch_normalization( + inputs, + mean, + variance, + offset=bias, + scale=weight, + variance_epsilon=epsilon, + ) + return outputs + + +def scaled_dot_product_attention( + query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale: float = None +): + """TF equivalent for torch's nn.functional.scaled_dot_product_attention""" + if dropout_p != 0.0: + raise ValueError( + "Dropout is not supported in this implementation - file an issue " + "with Transformers and ping @Rocketknight1 if you need it for a port!" + ) + if is_causal and attn_mask is not None: + raise ValueError("You cannot specify an attn_mask and is_causal at the same time!") + if is_causal: + attn_mask = tf.ones((tf.shape(query)[-2], tf.shape(key)[-2]), dtype=tf.int32) + attn_mask = tf.experimental.numpy.tril(attn_mask, k=0) + if attn_mask is not None and (attn_mask.dtype.is_integer or attn_mask.dtype.is_bool): + # Convert boolean mask to a negative logit bias + attn_mask = tf.where(attn_mask > 0, tf.cast(0.0, query.dtype), tf.cast(-1000.0, query.dtype)) + logits = tf.einsum("...qd, ...kd -> ...qk", query, key) + if scale is None: + scale = tf.cast(tf.shape(key)[-1], logits.dtype) ** -0.5 + logits *= scale # scale by 1/sqrt(key_dim) + if attn_mask is not None: + logits += attn_mask + probs = tf.nn.softmax(logits) + return probs @ value + + +def flatten(input, start_dim=0, end_dim=-1): + # Replicates the behavior of torch.flatten in TF + + # If end_dim or start_dim is negative, count them from the end + if end_dim < 0: + end_dim += input.shape.rank + if start_dim < 0: + start_dim += input.shape.rank + + if start_dim == end_dim: + return input + + in_shape = tf.shape(input) + flattened_dim = tf.math.reduce_prod(in_shape[start_dim : end_dim + 1]) + out_shape = tf.concat([in_shape[:start_dim], [flattened_dim], in_shape[end_dim + 1 :]], axis=0) + return tf.reshape(input, out_shape) + + +def invert_attention_mask(encoder_attention_mask: tf.Tensor) -> tf.Tensor: + """ + Invert an attention mask (e.g., switches 0. and 1.). + + Args: + encoder_attention_mask (`torch.Tensor`): An attention mask. + + Returns: + `tf.Tensor`: The inverted attention mask. + """ + if not isinstance(encoder_attention_mask, tf.Tensor): + encoder_attention_mask = tf.convert_to_tensor(encoder_attention_mask) # Catches stray NumPy inputs + if encoder_attention_mask.shape.rank == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if encoder_attention_mask.shape.rank == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow + # /transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = (encoder_extended_attention_mask == + # encoder_extended_attention_mask.transpose(-1, -2)) + encoder_extended_attention_mask = ( + tf.cast(1, encoder_attention_mask.dtype) - encoder_extended_attention_mask + ) * encoder_extended_attention_mask.dtype.min + + return encoder_extended_attention_mask + + +def check_embeddings_within_bounds(tensor: tf.Tensor, embed_dim: int, tensor_name: str = "input_ids") -> None: + """ + `tf.gather`, on which TF embedding layers are based, won't check positive out of bound indices on GPU, returning + zeros instead. This function adds a check against that dangerous silent behavior. + + Args: + tensor (`tf.Tensor`): The tensor of indices to check. + embed_dim (`int`): The embedding dimension. + tensor_name (`str`, *optional*): The name of the tensor to use in the error message. + """ + tf.debugging.assert_less( + tensor, + tf.cast(embed_dim, dtype=tensor.dtype), + message=( + f"The maximum value of {tensor_name} ({tf.math.reduce_max(tensor)}) must be smaller than the embedding " + f"layer's input dimension ({embed_dim}). The likely cause is some problem at tokenization time." + ), + ) + + +def save_attributes_to_hdf5_group(group, name, data): + """Saves attributes (data) of the specified name into the HDF5 group. + + This method deals with an inherent problem of HDF5 file which is not able to store data larger than + HDF5_OBJECT_HEADER_LIMIT bytes. + + Args: + group: A pointer to a HDF5 group. + name: A name of the attributes to save. + data: Attributes data to store. + + Raises: + RuntimeError: If any single attribute is too large to be saved. + + Copied from Keras to Transformers to avoid versioning issues. + """ + HDF5_OBJECT_HEADER_LIMIT = 64512 + # Check that no item in `data` is larger than `HDF5_OBJECT_HEADER_LIMIT` + # because in that case even chunking the array would not make the saving + # possible. + bad_attributes = [x for x in data if len(x) > HDF5_OBJECT_HEADER_LIMIT] + + # Expecting this to never be true. + if bad_attributes: + raise RuntimeError( + "The following attributes cannot be saved to HDF5 file because " + f"they are larger than {HDF5_OBJECT_HEADER_LIMIT} " + f"bytes: {bad_attributes}" + ) + + data_npy = np.asarray(data) + + num_chunks = 1 + chunked_data = np.array_split(data_npy, num_chunks) + + # This will never loop forever thanks to the test above. + while any(x.nbytes > HDF5_OBJECT_HEADER_LIMIT for x in chunked_data): + num_chunks += 1 + chunked_data = np.array_split(data_npy, num_chunks) + + if num_chunks > 1: + for chunk_id, chunk_data in enumerate(chunked_data): + group.attrs["%s%d" % (name, chunk_id)] = chunk_data + else: + group.attrs[name] = data + + +def load_attributes_from_hdf5_group(group, name): + """Loads attributes of the specified name from the HDF5 group. + + This method deals with an inherent problem of HDF5 file which is not able to store data larger than + HDF5_OBJECT_HEADER_LIMIT bytes. + + Args: + group: A pointer to a HDF5 group. + name: A name of the attributes to load. + + Returns: + data: Attributes data. + + Copied from Keras to Transformers to avoid versioning issues. + """ + if name in group.attrs: + data = [n.decode("utf8") if hasattr(n, "decode") else n for n in group.attrs[name]] + else: + data = [] + chunk_id = 0 + while "%s%d" % (name, chunk_id) in group.attrs: + data.extend( + [n.decode("utf8") if hasattr(n, "decode") else n for n in group.attrs["%s%d" % (name, chunk_id)]] + ) + chunk_id += 1 + return data + + +def expand_1d(data): + """Expands 1-dimensional `Tensor`s into 2-dimensional `Tensor`s. + Copied from Keras to here to avoid versioning issues.""" + + def _expand_single_1d_tensor(t): + if isinstance(t, tf.Tensor) and t.shape.rank == 1: + return tf.expand_dims(t, axis=-1) + return t + + return tf.nest.map_structure(_expand_single_1d_tensor, data) + + +def convert_batch_encoding(*args, **kwargs): + # Convert HF BatchEncoding/BatchFeature objects in the inputs to dicts that Keras understands + if args and isinstance(args[0], (BatchEncoding, BatchFeature)): + args = list(args) + args[0] = dict(args[0]) + elif "x" in kwargs and isinstance(kwargs["x"], (BatchEncoding, BatchFeature)): + kwargs["x"] = dict(kwargs["x"]) + return args, kwargs diff --git a/time_series_utils.py b/time_series_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7d9716e481240294fde41608f6e00e3485e93063 --- /dev/null +++ b/time_series_utils.py @@ -0,0 +1,226 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Time series distributional output classes and utilities. +""" + +from typing import Callable, Dict, Optional, Tuple + +import torch +from torch import nn +from torch.distributions import ( + AffineTransform, + Distribution, + Independent, + NegativeBinomial, + Normal, + StudentT, + TransformedDistribution, +) + + +class AffineTransformed(TransformedDistribution): + def __init__(self, base_distribution: Distribution, loc=None, scale=None, event_dim=0): + self.scale = 1.0 if scale is None else scale + self.loc = 0.0 if loc is None else loc + + super().__init__(base_distribution, [AffineTransform(loc=self.loc, scale=self.scale, event_dim=event_dim)]) + + @property + def mean(self): + """ + Returns the mean of the distribution. + """ + return self.base_dist.mean * self.scale + self.loc + + @property + def variance(self): + """ + Returns the variance of the distribution. + """ + return self.base_dist.variance * self.scale**2 + + @property + def stddev(self): + """ + Returns the standard deviation of the distribution. + """ + return self.variance.sqrt() + + +class ParameterProjection(nn.Module): + def __init__( + self, in_features: int, args_dim: Dict[str, int], domain_map: Callable[..., Tuple[torch.Tensor]], **kwargs + ) -> None: + super().__init__(**kwargs) + self.args_dim = args_dim + self.proj = nn.ModuleList([nn.Linear(in_features, dim) for dim in args_dim.values()]) + self.domain_map = domain_map + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]: + params_unbounded = [proj(x) for proj in self.proj] + + return self.domain_map(*params_unbounded) + + +class LambdaLayer(nn.Module): + def __init__(self, function): + super().__init__() + self.function = function + + def forward(self, x, *args): + return self.function(x, *args) + + +class DistributionOutput: + distribution_class: type + in_features: int + args_dim: Dict[str, int] + + def __init__(self, dim: int = 1) -> None: + self.dim = dim + self.args_dim = {k: dim * self.args_dim[k] for k in self.args_dim} + + def _base_distribution(self, distr_args): + if self.dim == 1: + return self.distribution_class(*distr_args) + else: + return Independent(self.distribution_class(*distr_args), 1) + + def distribution( + self, + distr_args, + loc: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + ) -> Distribution: + distr = self._base_distribution(distr_args) + if loc is None and scale is None: + return distr + else: + return AffineTransformed(distr, loc=loc, scale=scale, event_dim=self.event_dim) + + @property + def event_shape(self) -> Tuple: + r""" + Shape of each individual event contemplated by the distributions that this object constructs. + """ + return () if self.dim == 1 else (self.dim,) + + @property + def event_dim(self) -> int: + r""" + Number of event dimensions, i.e., length of the `event_shape` tuple, of the distributions that this object + constructs. + """ + return len(self.event_shape) + + @property + def value_in_support(self) -> float: + r""" + A float that will have a valid numeric value when computing the log-loss of the corresponding distribution. By + default 0.0. This value will be used when padding data series. + """ + return 0.0 + + def get_parameter_projection(self, in_features: int) -> nn.Module: + r""" + Return the parameter projection layer that maps the input to the appropriate parameters of the distribution. + """ + return ParameterProjection( + in_features=in_features, + args_dim=self.args_dim, + domain_map=LambdaLayer(self.domain_map), + ) + + def domain_map(self, *args: torch.Tensor): + r""" + Converts arguments to the right shape and domain. The domain depends on the type of distribution, while the + correct shape is obtained by reshaping the trailing axis in such a way that the returned tensors define a + distribution of the right event_shape. + """ + raise NotImplementedError() + + @staticmethod + def squareplus(x: torch.Tensor) -> torch.Tensor: + r""" + Helper to map inputs to the positive orthant by applying the square-plus operation. Reference: + https://twitter.com/jon_barron/status/1387167648669048833 + """ + return (x + torch.sqrt(torch.square(x) + 4.0)) / 2.0 + + +class StudentTOutput(DistributionOutput): + """ + Student-T distribution output class. + """ + + args_dim: Dict[str, int] = {"df": 1, "loc": 1, "scale": 1} + distribution_class: type = StudentT + + @classmethod + def domain_map(cls, df: torch.Tensor, loc: torch.Tensor, scale: torch.Tensor): + scale = cls.squareplus(scale).clamp_min(torch.finfo(scale.dtype).eps) + df = 2.0 + cls.squareplus(df) + return df.squeeze(-1), loc.squeeze(-1), scale.squeeze(-1) + + +class NormalOutput(DistributionOutput): + """ + Normal distribution output class. + """ + + args_dim: Dict[str, int] = {"loc": 1, "scale": 1} + distribution_class: type = Normal + + @classmethod + def domain_map(cls, loc: torch.Tensor, scale: torch.Tensor): + scale = cls.squareplus(scale).clamp_min(torch.finfo(scale.dtype).eps) + return loc.squeeze(-1), scale.squeeze(-1) + + +class NegativeBinomialOutput(DistributionOutput): + """ + Negative Binomial distribution output class. + """ + + args_dim: Dict[str, int] = {"total_count": 1, "logits": 1} + distribution_class: type = NegativeBinomial + + @classmethod + def domain_map(cls, total_count: torch.Tensor, logits: torch.Tensor): + total_count = cls.squareplus(total_count) + return total_count.squeeze(-1), logits.squeeze(-1) + + def _base_distribution(self, distr_args) -> Distribution: + total_count, logits = distr_args + if self.dim == 1: + return self.distribution_class(total_count=total_count, logits=logits) + else: + return Independent(self.distribution_class(total_count=total_count, logits=logits), 1) + + # Overwrites the parent class method. We cannot scale using the affine + # transformation since negative binomial should return integers. Instead + # we scale the parameters. + def distribution( + self, distr_args, loc: Optional[torch.Tensor] = None, scale: Optional[torch.Tensor] = None + ) -> Distribution: + total_count, logits = distr_args + + if scale is not None: + # See scaling property of Gamma. + logits += scale.log() + + return self._base_distribution((total_count, logits)) diff --git a/tokenization_utils.py b/tokenization_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1bc13020e65b66af6fdbd621aec5a466c494c68f --- /dev/null +++ b/tokenization_utils.py @@ -0,0 +1,1134 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tokenization classes for python tokenizers. For fast tokenizers (provided by HuggingFace's tokenizers library) see +tokenization_utils_fast.py +""" + +import bisect +import itertools +import re +import unicodedata +from collections import OrderedDict +from typing import Any, Dict, List, Optional, Tuple, Union, overload + +from .tokenization_utils_base import ( + ENCODE_KWARGS_DOCSTRING, + ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING, + INIT_TOKENIZER_DOCSTRING, + AddedToken, + BatchEncoding, + EncodedInput, + EncodedInputPair, + PreTokenizedInput, + PreTokenizedInputPair, + PreTrainedTokenizerBase, + TextInput, + TextInputPair, + TruncationStrategy, +) +from .utils import PaddingStrategy, TensorType, add_end_docstrings, logging + + +logger = logging.get_logger(__name__) + +# Slow tokenizers are saved in a vocabulary plus three separated files +SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json" +ADDED_TOKENS_FILE = "added_tokens.json" +TOKENIZER_CONFIG_FILE = "tokenizer_config.json" + + +class Trie: + """ + Trie in Python. Creates a Trie out of a list of words. The trie is used to split on `added_tokens` in one pass + Loose reference https://en.wikipedia.org/wiki/Trie + """ + + def __init__(self, *args): + self.data = {} + self._tokens = set() + self._termination_char = "" + self.update(*args) + + def update(self, *args): + """ + Updates the Trie with new tokens provided as arguments. + + Args: + *args: Variable number of words to be added to the Trie. + """ + for token in tuple(*args): + self.add(token) + + def add(self, word: str): + """ + Passes over every char (utf-8 char) on word and recursively adds it to the internal `data` trie representation. + The special key `""` in `self._termination_char` is used to represent termination. + + This function is idempotent, adding twice the same word will leave the trie unchanged + + Example: + + ```python + >>> trie = Trie() + >>> trie.add("Hello 友達") + >>> trie.data + {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}} + + >>> trie.add("Hello") + >>> trie.data + {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}} + ``` + """ + if not word: + # Prevent empty string + return + + self._tokens.add(word) + ref = self.data + for char in word: + ref[char] = ref.setdefault(char, {}) + ref = ref[char] + ref[self._termination_char] = 1 + + def split(self, text: str) -> List[str]: + """ + Will look for the words added to the trie within `text`. Output is the original string splitted along the + boundaries of the words found. + + This trie will match the longest possible word first ! + + Example: + + ```python + >>> trie = Trie() + >>> trie.split("[CLS] This is a extra_id_100") + ["[CLS] This is a extra_id_100"] + + >>> trie.add("[CLS]") + >>> trie.add("extra_id_1") + >>> trie.add("extra_id_100") + >>> trie.split("[CLS] This is a extra_id_100") + ["[CLS]", " This is a ", "extra_id_100"] + ``` + """ + # indexes are counted left of the chars index. + # "hello", index 0, is left of h, index 1 is between h and e. + # index 5 is right of the "o". + + # States are going to capture every possible start (indexes as above) + # as keys, and have as values, a pointer to the position in the trie + # where we're at. This is a partial match for now. + # This enables to keep track of multiple matches while we're iterating + # the string + # If the trie contains, "blowing", and "lower" and we encounter the + # string "blower", we need to split into ["b", "lower"]. + # This is where we need to keep track of multiple possible starts. + states = OrderedDict() + + # This will contain every indices where we need + # to cut. + # We force to cut at offset 0 and len(text) (added later) + offsets = [0] + + # This is used by the lookahead which needs to skip over + # some text where the full match exceeded the place in the initial + # for loop + skip = 0 + # Main loop, Giving this algorithm O(n) complexity + for current, current_char in enumerate(text): + if skip and current < skip: + # Prevents the lookahead for matching twice + # like extra_id_100 and id_100 + continue + + # This will track every state + # that stop matching, we need to stop tracking them. + # If we look at "lowball", we're going to match "l" (add it to states), "o", "w", then + # fail on "b", we need to remove 0 from the valid states. + to_remove = set() + # Whenever we found a match, we need to drop everything + # this is a greedy algorithm, it will match on the first found token + reset = False + + # In this case, we already have partial matches (But unfinished) + for start, trie_pointer in states.items(): + if "" in trie_pointer: + # This is a final match, we need to reset and + # store the results in `offsets`. + + # Lookahead to match longest first + # Important in case of extra_id_1 vs extra_id_100 + # Here we are also actively looking for other earlier partial + # matches + # "[CLS]", "L", we need to match CLS even if L is special + for lookstart, looktrie_pointer in states.items(): + if lookstart > start: + # This partial match is later, we can stop looking + break + elif lookstart < start: + # This partial match is earlier, the trie pointer + # was already updated, so index is + 1 + lookahead_index = current + 1 + end = current + 1 + else: + # Here lookstart == start and + # looktrie_pointer == trie_pointer + # It wasn't updated yet so indices are current ones + lookahead_index = current + end = current + next_char = text[lookahead_index] if lookahead_index < len(text) else None + if "" in looktrie_pointer: + start = lookstart + end = lookahead_index + skip = lookahead_index + + while next_char in looktrie_pointer: + looktrie_pointer = looktrie_pointer[next_char] + lookahead_index += 1 + if "" in looktrie_pointer: + start = lookstart + end = lookahead_index + skip = lookahead_index + + if lookahead_index == len(text): + # End of string + break + next_char = text[lookahead_index] + # End lookahead + + # Storing and resetting + offsets.append(start) + offsets.append(end) + reset = True + break + elif current_char in trie_pointer: + # The current character being looked at has a match within the trie + # update the pointer (it will be stored back into states later). + trie_pointer = trie_pointer[current_char] + + # Storing back the new pointer into the states. + # Partial matches got longer by one. + states[start] = trie_pointer + else: + # The new character has not match in the trie, we need + # to stop keeping track of this partial match. + # We can't do it directly within the loop because of how + # python iteration works + to_remove.add(start) + + # Either clearing the full start (we found a real match) + # Or clearing only the partial matches that didn't work. + if reset: + states = {} + else: + for start in to_remove: + del states[start] + + # If this character is a starting character within the trie + # start keeping track of this partial match. + if current >= skip and current_char in self.data: + states[current] = self.data[current_char] + + # We have a cut at the end with states. + for start, trie_pointer in states.items(): + if "" in trie_pointer: + # This is a final match, we need to reset and + # store the results in `offsets`. + end = len(text) + offsets.append(start) + offsets.append(end) + # Longest cut is always the one with lower start so the first + # item so we need to break. + break + + return self.cut_text(text, offsets) + + def cut_text(self, text, offsets): + # We have all the offsets now, we just need to do the actual splitting. + # We need to eventually add the first part of the string and the eventual + # last part. + offsets.append(len(text)) + tokens = [] + start = 0 + for end in offsets: + if start > end: + logger.error( + "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it" + " anyway." + ) + continue + elif start == end: + # This might happen if there's a match at index 0 + # we're also preventing zero-width cuts in case of two + # consecutive matches + continue + tokens.append(text[start:end]) + start = end + + return tokens + + +class ExtensionsTrie(Trie): + def __init__(self, *args): + super().__init__(*args) + + def extensions(self, prefix: str): + """ + Generates all extensions of a given prefix token in the Trie. + + Example: + + ```python + >>> trie = Trie() + >>> trie.add("apple") + >>> trie.add("app") + >>> trie.add("application") + >>> trie.extensions("app") + ['app', 'apple', 'application'] + ``` + """ + prefix_node = self._get_node(prefix) + ret = self._collect_tokens(prefix_node) + return [prefix + token for token in ret] + + def _get_node(self, token: str) -> dict: + """ + Retrieves the node corresponding to the given token in the Trie. + + Args: + token (str): The token for which the corresponding node needs to be retrieved. + + Returns: + dict: The node in the Trie corresponding to the given token. + """ + node = self.data + for char in token: + if char not in node: + break + + node = node[char] + return node + + def _collect_tokens(self, node: dict) -> list: + """ + Generates all tokens in the Trie starting from a given node. + + Args: + node (dict): The node in the Trie from which tokens need to be generated. + + Returns: + list: List of tokens generated from the given node. + """ + tokens = [self._termination_char] if self._termination_char in node else [] + for token, subtrie_head in node.items(): + if token != self._termination_char: + subtokens = self._collect_tokens(subtrie_head) + tokens.extend([token + subtoken for subtoken in subtokens]) + return tokens + + +def _is_whitespace(char): + """Checks whether `char` is a whitespace character.""" + # \t, \n, and \r are technically control characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `char` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat.startswith("C"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `char` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False + + +def _is_end_of_word(text): + """Checks whether the last character in text is one of a punctuation, control or whitespace character.""" + last_char = text[-1] + return bool(_is_control(last_char) | _is_punctuation(last_char) | _is_whitespace(last_char)) + + +def _is_start_of_word(text): + """Checks whether the first character in text is one of a punctuation, control or whitespace character.""" + first_char = text[0] + return bool(_is_control(first_char) | _is_punctuation(first_char) | _is_whitespace(first_char)) + + +def _insert_one_token_to_ordered_list(token_list: List[str], new_token: str): + """ + Inserts one token to an ordered list if it does not already exist. Note: token_list must be sorted. + """ + insertion_idx = bisect.bisect_left(token_list, new_token) + # Checks if new_token is already in the ordered token_list + if insertion_idx < len(token_list) and token_list[insertion_idx] == new_token: + # new_token is in token_list, don't add + return + else: + token_list.insert(insertion_idx, new_token) + + +@add_end_docstrings(INIT_TOKENIZER_DOCSTRING) +class PreTrainedTokenizer(PreTrainedTokenizerBase): + """ + Base class for all slow tokenizers. + + Inherits from [`~tokenization_utils_base.PreTrainedTokenizerBase`]. + + Handle all the shared methods for tokenization and special tokens as well as methods downloading/caching/loading + pretrained tokenizers as well as adding tokens to the vocabulary. + + This class also contain the added tokens in a unified way on top of all tokenizers so we don't have to handle the + specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...). + """ + + def __init__(self, **kwargs): + # 1. Init the parent class + + self.tokens_trie = Trie() + + # 2. init `_added_tokens_decoder` if child class did not + if not hasattr(self, "_added_tokens_decoder"): + self._added_tokens_decoder: Dict[int, AddedToken] = {} + + # 3. if a `added_tokens_decoder` is passed, we are loading from a saved tokenizer, we overwrite + self._added_tokens_decoder.update(kwargs.pop("added_tokens_decoder", {})) + self._added_tokens_encoder: Dict[str, int] = {k.content: v for v, k in self._added_tokens_decoder.items()} + + # 4 init the parent class + super().__init__(**kwargs) + + # 4. If some of the special tokens are not part of the vocab, we add them, at the end. + # the order of addition is the same as self.SPECIAL_TOKENS_ATTRIBUTES following `tokenizers` + self._add_tokens( + [token for token in self.all_special_tokens_extended if token not in self._added_tokens_encoder], + special_tokens=True, + ) + + self._decode_use_source_tokenizer = False + + @property + def is_fast(self) -> bool: + return False + + @property + def vocab_size(self) -> int: + """ + `int`: Size of the base vocabulary (without the added tokens). + """ + raise NotImplementedError + + @property + def added_tokens_encoder(self) -> Dict[str, int]: + """ + Returns the sorted mapping from string to index. The added tokens encoder is cached for performance + optimisation in `self._added_tokens_encoder` for the slow tokenizers. + """ + return {k.content: v for v, k in sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])} + + @property + def added_tokens_decoder(self) -> Dict[int, AddedToken]: + """ + Returns the added tokens in the vocabulary as a dictionary of index to AddedToken. + + Returns: + `Dict[str, int]`: The added tokens. + """ + return dict(sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])) + + @added_tokens_decoder.setter + def added_tokens_decoder(self, value: Dict[int, Union[AddedToken, str]]) -> Dict[int, AddedToken]: + # Always raise an error if string because users should define the behavior + for index, token in value.items(): + if not isinstance(token, (str, AddedToken)) or not isinstance(index, int): + raise TypeError( + f"The provided `added_tokens_decoder` has an element of type {index.__class__, token.__class__}, should be a dict of {int, Union[AddedToken, str]}" + ) + + self._added_tokens_decoder[index] = AddedToken(token) if isinstance(token, str) else token + self._added_tokens_encoder[str(token)] = index + self._update_total_vocab_size() + + def get_added_vocab(self) -> Dict[str, int]: + """ + Returns the added tokens in the vocabulary as a dictionary of token to index. Results might be different from + the fast call because for now we always add the tokens even if they are already in the vocabulary. This is + something we should change. + + Returns: + `Dict[str, int]`: The added tokens. + """ + return self._added_tokens_encoder + + def __len__(self): + """ + Size of the full vocabulary with the added tokens. + """ + return self.total_vocab_size + + def _update_total_vocab_size(self): + """ + Update the size of the full vocabulary with the added tokens. Counts the `keys` and not the `values` because + otherwise if there is a hole in the vocab, we will add tokenizers at a wrong index. This operation is slow and + is only updated when adding tokens. + """ + self.total_vocab_size = len(self.get_vocab()) + + def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int: + """ + Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to + it with indices starting from length of the current vocabulary. Special tokens are sometimes already in the + vocab which is why they have to be handled specifically. + + Args: + new_tokens (`List[str]`or `List[tokenizers.AddedToken]`): + Token(s) to add in vocabulary. A token is counted as added if it's not already in the vocabulary + (tested by checking if the tokenizer assign the index of the `unk_token` to them). If a token is part + of the vocabulary then we simply mark this token as an `AddedToken` which allows to control the + stripping and normalization of this token. This is NOT possible in `tokenizers`. + special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the tokens should be added as special tokens. + + Returns: + `int`: The number of tokens actually added to the vocabulary. + + Examples: + + ```python + # Let's see how to increase the vocabulary of Bert model and tokenizer + tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased") + model = BertModel.from_pretrained("google-bert/bert-base-uncased") + + num_added_toks = tokenizer.add_tokens(["new_tok1", "my_new-tok2"]) + print("We have added", num_added_toks, "tokens") + # Note: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer. + model.resize_token_embeddings(len(tokenizer)) + ```""" + added_tokens = 0 + if new_tokens is None: + return added_tokens + # TODO this is fairly slow to improve! + current_vocab = self.get_vocab().copy() + new_idx = len(current_vocab) # only call this once, len gives the last index + 1 + for token in new_tokens: + if not isinstance(token, (str, AddedToken)): + raise TypeError(f"Token {token} is not a string but a {type(token)}.") + if str(token) == "": + continue + if isinstance(token, str): + if token in self._added_tokens_encoder: + continue + else: + # very important for fast and slow equivalence! + is_special = token in self.all_special_tokens or special_tokens + token = AddedToken( + token, rstrip=False, lstrip=False, normalized=not is_special, special=is_special + ) + elif special_tokens: + # doing token.special=True changes the normalization! will fix in rust + # this is important and the only reason why the AddedTokens in each class are normalized by default + token.__setstate__({"special": True, "normalized": token.normalized}) + if token in self._added_tokens_decoder: + continue + if not token.special and token.normalized and getattr(self, "do_lower_case", False): + # Normalize if requested + token.content = token.content.lower() + if token.content not in current_vocab: + token_index = new_idx + added_tokens + current_vocab[token.content] = token_index + added_tokens += 1 + else: + token_index = current_vocab[token.content] + + if token.special and str(token) not in self.all_special_tokens: + self._special_tokens_map["additional_special_tokens"].append(token) + # the setter automatically updates the reverse map + self._added_tokens_decoder[token_index] = token + self._added_tokens_encoder[token.content] = token_index + if self.verbose: + logger.info(f"Adding {token} to the vocabulary") + + self._update_trie() + self._update_total_vocab_size() + return added_tokens + + def _update_trie(self, unique_no_split_tokens: Optional[str] = []): + for token in self._added_tokens_decoder.values(): + if token not in self.tokens_trie._tokens: + self.tokens_trie.add(token.content) + for token in unique_no_split_tokens: + if token not in self.tokens_trie._tokens: + self.tokens_trie.add(token) + + def num_special_tokens_to_add(self, pair: bool = False) -> int: + """ + Returns the number of added tokens when encoding a sequence with special tokens. + + + + This encodes a dummy input and checks the number of added tokens, and is therefore not efficient. Do not put + this inside your training loop. + + + + Args: + pair (`bool`, *optional*, defaults to `False`): + Whether the number of added tokens should be computed in the case of a sequence pair or a single + sequence. + + Returns: + `int`: Number of special tokens added to sequences. + """ + token_ids_0 = [] + token_ids_1 = [] + return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None)) + + def tokenize(self, text: TextInput, **kwargs) -> List[str]: + """ + Converts a string into a sequence of tokens, using the tokenizer. + + Split in words for word-based vocabulary or sub-words for sub-word-based vocabularies + (BPE/SentencePieces/WordPieces). Takes care of added tokens. + + Args: + text (`str`): + The sequence to be encoded. + **kwargs (additional keyword arguments): + Passed along to the model-specific `prepare_for_tokenization` preprocessing method. + + Returns: + `List[str]`: The list of tokens. + """ + split_special_tokens = kwargs.pop("split_special_tokens", self.split_special_tokens) + + text, kwargs = self.prepare_for_tokenization(text, **kwargs) + + if kwargs: + logger.warning(f"Keyword arguments {kwargs} not recognized.") + + if hasattr(self, "do_lower_case") and self.do_lower_case: + # convert non-special tokens to lowercase. Might be super slow as well? + escaped_special_toks = [re.escape(s_tok) for s_tok in (self.all_special_tokens)] + escaped_special_toks += [ + re.escape(s_tok.content) + for s_tok in (self._added_tokens_decoder.values()) + if not s_tok.special and s_tok.normalized + ] + pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)" + text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text) + + if split_special_tokens: + no_split_token = [] + tokens = [text] + else: + no_split_token = self._added_tokens_encoder.keys() # don't split on any of the added tokens + # "This is something else" + tokens = self.tokens_trie.split(text) + + # ["This is something", "", " else"] + for i, token in enumerate(tokens): + if token in no_split_token: + tok_extended = self._added_tokens_decoder.get(self._added_tokens_encoder[token], None) + left = tokens[i - 1] if i > 0 else None + right = tokens[i + 1] if i < len(tokens) - 1 else None + if isinstance(tok_extended, AddedToken): + if tok_extended.rstrip and right: + # A bit counter-intuitive but we strip the left of the string + # since tok_extended.rstrip means the special token is eating all white spaces on its right + tokens[i + 1] = right.lstrip() + # Strip white spaces on the left + if tok_extended.lstrip and left: + tokens[i - 1] = left.rstrip() # Opposite here + if tok_extended.single_word and left and left[-1] != " ": + tokens[i - 1] += token + tokens[i] = "" + elif tok_extended.single_word and right and right[0] != " ": + tokens[i + 1] = token + tokens[i + 1] + tokens[i] = "" + else: + raise ValueError( + f"{tok_extended} cannot be tokenized because it was not properly added" + f" to the tokenizer. This means that it is not an `AddedToken` but a {type(tok_extended)}" + ) + # ["This is something", "", "else"] + tokenized_text = [] + for token in tokens: + # Need to skip eventual empty (fully stripped) tokens + if not token: + continue + if token in no_split_token: + tokenized_text.append(token) + else: + tokenized_text.extend(self._tokenize(token)) + # ["This", " is", " something", "", "else"] + return tokenized_text + + def _tokenize(self, text, **kwargs): + """ + Converts a string into a sequence of tokens (string), using the tokenizer. Split in words for word-based + vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces). + + Do NOT take care of added tokens. + """ + raise NotImplementedError + + def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]: + """ + Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the + vocabulary. + + Args: + tokens (`str` or `List[str]`): One or several token(s) to convert to token id(s). + + Returns: + `int` or `List[int]`: The token id or list of token ids. + """ + if tokens is None: + return None + + if isinstance(tokens, str): + return self._convert_token_to_id_with_added_voc(tokens) + + ids = [] + for token in tokens: + ids.append(self._convert_token_to_id_with_added_voc(token)) + return ids + + def _convert_token_to_id_with_added_voc(self, token): + if token is None: + return None + + if token in self._added_tokens_encoder: + return self._added_tokens_encoder[token] + return self._convert_token_to_id(token) + + def _convert_token_to_id(self, token): + raise NotImplementedError + + def _encode_plus( + self, + text: Union[TextInput, PreTokenizedInput, EncodedInput], + text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + def get_input_ids(text): + if isinstance(text, str): + tokens = self.tokenize(text, **kwargs) + return self.convert_tokens_to_ids(tokens) + elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str): + if is_split_into_words: + tokens = list( + itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text)) + ) + return self.convert_tokens_to_ids(tokens) + else: + return self.convert_tokens_to_ids(text) + elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int): + return text + else: + if is_split_into_words: + raise ValueError( + f"Input {text} is not valid. Should be a string or a list/tuple of strings when" + " `is_split_into_words=True`." + ) + else: + raise ValueError( + f"Input {text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of" + " integers." + ) + + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast. " + "More information on available tokenizers at " + "https://github.com/huggingface/transformers/pull/2674" + ) + + first_ids = get_input_ids(text) + second_ids = get_input_ids(text_pair) if text_pair is not None else None + + return self.prepare_for_model( + first_ids, + pair_ids=second_ids, + add_special_tokens=add_special_tokens, + padding=padding_strategy.value, + truncation=truncation_strategy.value, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + prepend_batch_axis=True, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + verbose=verbose, + ) + + def _batch_encode_plus( + self, + batch_text_or_text_pairs: Union[ + List[TextInput], + List[TextInputPair], + List[PreTokenizedInput], + List[PreTokenizedInputPair], + List[EncodedInput], + List[EncodedInputPair], + ], + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + split_special_tokens: bool = False, + **kwargs, + ) -> BatchEncoding: + def get_input_ids(text): + if isinstance(text, str): + tokens = self.tokenize(text, **kwargs) + return self.convert_tokens_to_ids(tokens) + elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str): + if is_split_into_words: + tokens = list( + itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text)) + ) + return self.convert_tokens_to_ids(tokens) + else: + return self.convert_tokens_to_ids(text) + elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int): + return text + else: + raise ValueError( + "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers." + ) + + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast." + ) + + input_ids = [] + for ids_or_pair_ids in batch_text_or_text_pairs: + if not isinstance(ids_or_pair_ids, (list, tuple)): + ids, pair_ids = ids_or_pair_ids, None + elif is_split_into_words and not isinstance(ids_or_pair_ids[0], (list, tuple)): + ids, pair_ids = ids_or_pair_ids, None + else: + ids, pair_ids = ids_or_pair_ids + + first_ids = get_input_ids(ids) + second_ids = get_input_ids(pair_ids) if pair_ids is not None else None + input_ids.append((first_ids, second_ids)) + + batch_outputs = self._batch_prepare_for_model( + input_ids, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=return_tensors, + verbose=verbose, + split_special_tokens=split_special_tokens, + ) + + return BatchEncoding(batch_outputs) + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def _batch_prepare_for_model( + self, + batch_ids_pairs: List[Union[PreTokenizedInputPair, Tuple[List[int], None]]], + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[bool] = None, + return_tensors: Optional[str] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_length: bool = False, + verbose: bool = True, + split_special_tokens: bool = False, + ) -> BatchEncoding: + """ + Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It + adds special tokens, truncates sequences if overflowing while taking into account the special tokens and + manages a moving window (with user defined stride) for overflowing tokens + + Args: + batch_ids_pairs: list of tokenized input ids or input ids pairs + """ + + batch_outputs = {} + for first_ids, second_ids in batch_ids_pairs: + outputs = self.prepare_for_model( + first_ids, + second_ids, + add_special_tokens=add_special_tokens, + padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward + truncation=truncation_strategy.value, + max_length=max_length, + stride=stride, + pad_to_multiple_of=None, # we pad in batch afterward + padding_side=None, # we pad in batch afterward + return_attention_mask=False, # we pad in batch afterward + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=None, # We convert the whole batch to tensors at the end + prepend_batch_axis=False, + verbose=verbose, + split_special_tokens=split_special_tokens, + ) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + + batch_outputs = self.pad( + batch_outputs, + padding=padding_strategy.value, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + ) + + batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors) + + return batch_outputs + + def prepare_for_tokenization( + self, text: str, is_split_into_words: bool = False, **kwargs + ) -> Tuple[str, Dict[str, Any]]: + """ + Performs any necessary transformations before tokenization. + + This method should pop the arguments from kwargs and return the remaining `kwargs` as well. We test the + `kwargs` at the end of the encoding process to be sure all the arguments have been used. + + Args: + text (`str`): + The text to prepare. + is_split_into_words (`bool`, *optional*, defaults to `False`): + Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the + tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace) + which it will tokenize. This is useful for NER or token classification. + kwargs (`Dict[str, Any]`, *optional*): + Keyword arguments to use for the tokenization. + + Returns: + `Tuple[str, Dict[str, Any]]`: The prepared text and the unused kwargs. + """ + return (text, kwargs) + + def get_special_tokens_mask( + self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods. + + Args: + token_ids_0 (`List[int]`): + List of ids of the first sequence. + token_ids_1 (`List[int]`, *optional*): + List of ids of the second sequence. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formatted with special tokens for the model." + ) + + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0)) + + @overload + def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str: ... + + @overload + def convert_ids_to_tokens(self, ids: List[int], skip_special_tokens: bool = False) -> List[str]: ... + + def convert_ids_to_tokens( + self, ids: Union[int, List[int]], skip_special_tokens: bool = False + ) -> Union[str, List[str]]: + """ + Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and + added tokens. + + Args: + ids (`int` or `List[int]`): + The token id (or token ids) to convert to tokens. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + + Returns: + `str` or `List[str]`: The decoded token(s). + """ + if isinstance(ids, int): + if ids in self._added_tokens_decoder: + return self._added_tokens_decoder[ids].content + else: + return self._convert_id_to_token(ids) + tokens = [] + for index in ids: + index = int(index) + if skip_special_tokens and index in self.all_special_ids: + continue + if index in self._added_tokens_decoder: + tokens.append(self._added_tokens_decoder[index].content) + else: + tokens.append(self._convert_id_to_token(index)) + return tokens + + def _convert_id_to_token(self, index: int) -> str: + raise NotImplementedError + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + return " ".join(tokens) + + def _decode( + self, + token_ids: Union[int, List[int]], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = None, + spaces_between_special_tokens: bool = True, + **kwargs, + ) -> str: + self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False) + + filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) + # If given is a single id, prevents splitting the string in upcoming loop + if isinstance(filtered_tokens, str): + filtered_tokens = [filtered_tokens] + + legacy_added_tokens = set(self._added_tokens_encoder.keys()) - set(self.all_special_tokens) | { + token for token in self.additional_special_tokens if self.convert_tokens_to_ids(token) >= self.vocab_size + } + # To avoid mixing byte-level and unicode for byte-level BPT + # we need to build string separately for added tokens and byte-level tokens + # cf. https://github.com/huggingface/transformers/issues/1133 + sub_texts = [] + current_sub_text = [] + # TODO @ArthurZ in version 5, special tokens should be handled in convert_tokens_to_string, while _convert_tokens_to_string + for token in filtered_tokens: + if skip_special_tokens and token in self.all_special_tokens: + continue + if token in legacy_added_tokens: + if current_sub_text: + string = self.convert_tokens_to_string(current_sub_text) + if len(string) > 0: + sub_texts.append(string) + current_sub_text = [] + sub_texts.append(token) + else: + current_sub_text.append(token) + if current_sub_text: + sub_texts.append(self.convert_tokens_to_string(current_sub_text)) + + if spaces_between_special_tokens: + text = " ".join(sub_texts) + else: + text = "".join(sub_texts) + + clean_up_tokenization_spaces = ( + clean_up_tokenization_spaces + if clean_up_tokenization_spaces is not None + else self.clean_up_tokenization_spaces + ) + if clean_up_tokenization_spaces: + clean_text = self.clean_up_tokenization(text) + return clean_text + else: + return text diff --git a/tokenization_utils_base.py b/tokenization_utils_base.py new file mode 100644 index 0000000000000000000000000000000000000000..86e07a382f8812f0b6b790139a003e9ce3bb1532 --- /dev/null +++ b/tokenization_utils_base.py @@ -0,0 +1,4157 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Base classes common to both the slow and the fast tokenization classes: PreTrainedTokenizerBase (host all the user +fronting encoding methods) Special token mixing (host the special tokens logic) and BatchEncoding (wrap the dictionary +of output with special method for the Fast tokenizers) +""" + +import copy +import json +import os +import re +import warnings +from collections import UserDict +from collections.abc import Mapping, Sized +from contextlib import contextmanager +from dataclasses import dataclass +from inspect import isfunction +from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union + +import numpy as np +from packaging import version + +from . import __version__ +from .dynamic_module_utils import custom_object_save +from .utils import ( + ExplicitEnum, + PaddingStrategy, + PushToHubMixin, + TensorType, + add_end_docstrings, + add_model_info_to_auto_map, + add_model_info_to_custom_pipelines, + cached_file, + copy_func, + download_url, + extract_commit_hash, + get_json_schema, + is_flax_available, + is_jax_tensor, + is_mlx_available, + is_numpy_array, + is_offline_mode, + is_protobuf_available, + is_remote_url, + is_tf_available, + is_tf_tensor, + is_tokenizers_available, + is_torch_available, + is_torch_device, + is_torch_tensor, + logging, + requires_backends, + to_py_obj, +) +from .utils.chat_template_utils import _compile_jinja_template, _render_with_assistant_indices +from .utils.import_utils import PROTOBUF_IMPORT_ERROR + + +if TYPE_CHECKING: + if is_torch_available(): + import torch + if is_tf_available(): + import tensorflow as tf + if is_flax_available(): + import jax.numpy as jnp # noqa: F401 + + +def import_protobuf_decode_error(error_message=""): + if is_protobuf_available(): + from google.protobuf.message import DecodeError + + return DecodeError + else: + raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message)) + + +if is_tokenizers_available(): + from tokenizers import AddedToken + from tokenizers import Encoding as EncodingFast +else: + + @dataclass(frozen=False, eq=True) + class AddedToken: + """ + AddedToken represents a token to be added to a Tokenizer An AddedToken can have special options defining the + way it should behave. + + The `normalized` will default to `not special` if it is not specified, similarly to the definition in + `tokenizers`. + """ + + def __init__( + self, content: str, single_word=False, lstrip=False, rstrip=False, special=False, normalized=None + ): + self.content = content + self.single_word = single_word + self.lstrip = lstrip + self.rstrip = rstrip + self.special = special + self.normalized = normalized if normalized is not None else not special + + def __getstate__(self): + return self.__dict__ + + def __str__(self): + return self.content + + @dataclass + class EncodingFast: + """This is dummy class because without the `tokenizers` library we don't have these objects anyway""" + + pass + + +logger = logging.get_logger(__name__) + +VERY_LARGE_INTEGER = int(1e30) # This is used to set the max input length for a model with infinite size input +LARGE_INTEGER = int(1e20) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER + +# Define type aliases and NamedTuples +TextInput = str +PreTokenizedInput = List[str] +EncodedInput = List[int] +TextInputPair = Tuple[str, str] +PreTokenizedInputPair = Tuple[List[str], List[str]] +EncodedInputPair = Tuple[List[int], List[int]] + +# Define type aliases for text-related non-text modalities +AudioInput = Union["np.ndarray", "torch.Tensor", List["np.ndarray"], List["torch.Tensor"]] + +# Slow tokenizers used to be saved in three separated files +SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json" +ADDED_TOKENS_FILE = "added_tokens.json" +TOKENIZER_CONFIG_FILE = "tokenizer_config.json" +CHAT_TEMPLATE_FILE = "chat_template.jinja" + +# Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file +FULL_TOKENIZER_FILE = "tokenizer.json" +_re_tokenizer_file = re.compile(r"tokenizer\.(.*)\.json") + + +class TruncationStrategy(ExplicitEnum): + """ + Possible values for the `truncation` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for tab-completion in + an IDE. + """ + + ONLY_FIRST = "only_first" + ONLY_SECOND = "only_second" + LONGEST_FIRST = "longest_first" + DO_NOT_TRUNCATE = "do_not_truncate" + + +class CharSpan(NamedTuple): + """ + Character span in the original string. + + Args: + start (`int`): Index of the first character in the original string. + end (`int`): Index of the character following the last character in the original string. + """ + + start: int + end: int + + +class TokenSpan(NamedTuple): + """ + Token span in an encoded string (list of tokens). + + Args: + start (`int`): Index of the first token in the span. + end (`int`): Index of the token following the last token in the span. + """ + + start: int + end: int + + +class BatchEncoding(UserDict): + """ + Holds the output of the [`~tokenization_utils_base.PreTrainedTokenizerBase.__call__`], + [`~tokenization_utils_base.PreTrainedTokenizerBase.encode_plus`] and + [`~tokenization_utils_base.PreTrainedTokenizerBase.batch_encode_plus`] methods (tokens, attention_masks, etc). + + This class is derived from a python dictionary and can be used as a dictionary. In addition, this class exposes + utility methods to map from word/character space to token space. + + Args: + data (`dict`, *optional*): + Dictionary of lists/arrays/tensors returned by the `__call__`/`encode_plus`/`batch_encode_plus` methods + ('input_ids', 'attention_mask', etc.). + encoding (`tokenizers.Encoding` or `Sequence[tokenizers.Encoding]`, *optional*): + If the tokenizer is a fast tokenizer which outputs additional information like mapping from word/character + space to token space the `tokenizers.Encoding` instance or list of instance (for batches) hold this + information. + tensor_type (`Union[None, str, TensorType]`, *optional*): + You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at + initialization. + prepend_batch_axis (`bool`, *optional*, defaults to `False`): + Whether or not to add a batch axis when converting to tensors (see `tensor_type` above). Note that this + parameter has an effect if the parameter `tensor_type` is set, *otherwise has no effect*. + n_sequences (`Optional[int]`, *optional*): + You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at + initialization. + """ + + def __init__( + self, + data: Optional[Dict[str, Any]] = None, + encoding: Optional[Union[EncodingFast, Sequence[EncodingFast]]] = None, + tensor_type: Union[None, str, TensorType] = None, + prepend_batch_axis: bool = False, + n_sequences: Optional[int] = None, + ): + super().__init__(data) + + if isinstance(encoding, EncodingFast): + encoding = [encoding] + + self._encodings = encoding + + if n_sequences is None and encoding is not None and len(encoding): + n_sequences = encoding[0].n_sequences + + self._n_sequences = n_sequences + + self.convert_to_tensors(tensor_type=tensor_type, prepend_batch_axis=prepend_batch_axis) + + @property + def n_sequences(self) -> Optional[int]: + """ + `Optional[int]`: The number of sequences used to generate each sample from the batch encoded in this + [`BatchEncoding`]. Currently can be one of `None` (unknown), `1` (a single sentence) or `2` (a pair of + sentences) + """ + return self._n_sequences + + @property + def is_fast(self) -> bool: + """ + `bool`: Indicate whether this [`BatchEncoding`] was generated from the result of a [`PreTrainedTokenizerFast`] + or not. + """ + return self._encodings is not None + + def __getitem__(self, item: Union[int, str]) -> Union[Any, EncodingFast]: + """ + If the key is a string, returns the value of the dict associated to `key` ('input_ids', 'attention_mask', + etc.). + + If the key is an integer, get the `tokenizers.Encoding` for batch item with index `key`. + + If the key is a slice, returns the value of the dict associated to `key` ('input_ids', 'attention_mask', etc.) + with the constraint of slice. + """ + if isinstance(item, str): + return self.data[item] + elif self._encodings is not None: + return self._encodings[item] + elif isinstance(item, slice): + return {key: self.data[key][item] for key in self.data.keys()} + else: + raise KeyError( + "Invalid key. Only three types of key are available: " + "(1) string, (2) integers for backend Encoding, and (3) slices for data subsetting." + ) + + def __getattr__(self, item: str): + try: + return self.data[item] + except KeyError: + raise AttributeError + + def __getstate__(self): + return {"data": self.data, "encodings": self._encodings} + + def __setstate__(self, state): + if "data" in state: + self.data = state["data"] + + if "encodings" in state: + self._encodings = state["encodings"] + + def keys(self): + return self.data.keys() + + def values(self): + return self.data.values() + + def items(self): + return self.data.items() + + # After this point: + # Extended properties and methods only available for fast (Rust-based) tokenizers + # provided by HuggingFace tokenizers library. + + @property + def encodings(self) -> Optional[List[EncodingFast]]: + """ + `Optional[List[tokenizers.Encoding]]`: The list all encodings from the tokenization process. Returns `None` if + the input was tokenized through Python (i.e., not a fast) tokenizer. + """ + return self._encodings + + def tokens(self, batch_index: int = 0) -> List[str]: + """ + Return the list of tokens (sub-parts of the input strings after word/subword splitting and before conversion to + integer indices) at a given batch index (only works for the output of a fast tokenizer). + + Args: + batch_index (`int`, *optional*, defaults to 0): The index to access in the batch. + + Returns: + `List[str]`: The list of tokens at that index. + """ + if not self._encodings: + raise ValueError( + "tokens() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`" + " class)." + ) + return self._encodings[batch_index].tokens + + def sequence_ids(self, batch_index: int = 0) -> List[Optional[int]]: + """ + Return a list mapping the tokens to the id of their original sentences: + + - `None` for special tokens added around or between sequences, + - `0` for tokens corresponding to words in the first sequence, + - `1` for tokens corresponding to words in the second sequence when a pair of sequences was jointly + encoded. + + Args: + batch_index (`int`, *optional*, defaults to 0): The index to access in the batch. + + Returns: + `List[Optional[int]]`: A list indicating the sequence id corresponding to each token. Special tokens added + by the tokenizer are mapped to `None` and other tokens are mapped to the index of their corresponding + sequence. + """ + if not self._encodings: + raise ValueError( + "sequence_ids() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`" + " class)." + ) + return self._encodings[batch_index].sequence_ids + + def words(self, batch_index: int = 0) -> List[Optional[int]]: + """ + Return a list mapping the tokens to their actual word in the initial sentence for a fast tokenizer. + + Args: + batch_index (`int`, *optional*, defaults to 0): The index to access in the batch. + + Returns: + `List[Optional[int]]`: A list indicating the word corresponding to each token. Special tokens added by the + tokenizer are mapped to `None` and other tokens are mapped to the index of their corresponding word + (several tokens will be mapped to the same word index if they are parts of that word). + """ + if not self._encodings: + raise ValueError( + "words() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`" + " class)." + ) + warnings.warn( + "`BatchEncoding.words()` property is deprecated and should be replaced with the identical, " + "but more self-explanatory `BatchEncoding.word_ids()` property.", + FutureWarning, + ) + return self.word_ids(batch_index) + + def word_ids(self, batch_index: int = 0) -> List[Optional[int]]: + """ + Return a list mapping the tokens to their actual word in the initial sentence for a fast tokenizer. + + Args: + batch_index (`int`, *optional*, defaults to 0): The index to access in the batch. + + Returns: + `List[Optional[int]]`: A list indicating the word corresponding to each token. Special tokens added by the + tokenizer are mapped to `None` and other tokens are mapped to the index of their corresponding word + (several tokens will be mapped to the same word index if they are parts of that word). + """ + if not self._encodings: + raise ValueError( + "word_ids() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`" + " class)." + ) + return self._encodings[batch_index].word_ids + + def token_to_sequence(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int: + """ + Get the index of the sequence represented by the given token. In the general use case, this method returns `0` + for a single sequence or the first sequence of a pair, and `1` for the second sequence of a pair + + Can be called as: + + - `self.token_to_sequence(token_index)` if batch size is 1 + - `self.token_to_sequence(batch_index, token_index)` if batch size is greater than 1 + + This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e., + words are defined by the user). In this case it allows to easily associate encoded tokens with provided + tokenized words. + + Args: + batch_or_token_index (`int`): + Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of + the token in the sequence. + token_index (`int`, *optional*): + If a batch index is provided in *batch_or_token_index*, this can be the index of the token in the + sequence. + + Returns: + `int`: Index of the word in the input sequence. + """ + + if not self._encodings: + raise ValueError("token_to_sequence() is not available when using Python based tokenizers") + if token_index is not None: + batch_index = batch_or_token_index + else: + batch_index = 0 + token_index = batch_or_token_index + if batch_index < 0: + batch_index = self._batch_size + batch_index + if token_index < 0: + token_index = self._seq_len + token_index + return self._encodings[batch_index].token_to_sequence(token_index) + + def token_to_word(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int: + """ + Get the index of the word corresponding (i.e. comprising) to an encoded token in a sequence of the batch. + + Can be called as: + + - `self.token_to_word(token_index)` if batch size is 1 + - `self.token_to_word(batch_index, token_index)` if batch size is greater than 1 + + This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e., + words are defined by the user). In this case it allows to easily associate encoded tokens with provided + tokenized words. + + Args: + batch_or_token_index (`int`): + Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of + the token in the sequence. + token_index (`int`, *optional*): + If a batch index is provided in *batch_or_token_index*, this can be the index of the token in the + sequence. + + Returns: + `int`: Index of the word in the input sequence. + """ + + if not self._encodings: + raise ValueError("token_to_word() is not available when using Python based tokenizers") + if token_index is not None: + batch_index = batch_or_token_index + else: + batch_index = 0 + token_index = batch_or_token_index + if batch_index < 0: + batch_index = self._batch_size + batch_index + if token_index < 0: + token_index = self._seq_len + token_index + return self._encodings[batch_index].token_to_word(token_index) + + def word_to_tokens( + self, batch_or_word_index: int, word_index: Optional[int] = None, sequence_index: int = 0 + ) -> Optional[TokenSpan]: + """ + Get the encoded token span corresponding to a word in a sequence of the batch. + + Token spans are returned as a [`~tokenization_utils_base.TokenSpan`] with: + + - **start** -- Index of the first token. + - **end** -- Index of the token following the last token. + + Can be called as: + + - `self.word_to_tokens(word_index, sequence_index: int = 0)` if batch size is 1 + - `self.word_to_tokens(batch_index, word_index, sequence_index: int = 0)` if batch size is greater or equal to + 1 + + This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words + are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized + words. + + Args: + batch_or_word_index (`int`): + Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of + the word in the sequence. + word_index (`int`, *optional*): + If a batch index is provided in *batch_or_token_index*, this can be the index of the word in the + sequence. + sequence_index (`int`, *optional*, defaults to 0): + If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0 + or 1) the provided word index belongs to. + + Returns: + ([`~tokenization_utils_base.TokenSpan`], *optional*): Span of tokens in the encoded sequence. Returns + `None` if no tokens correspond to the word. This can happen especially when the token is a special token + that has been used to format the tokenization. For example when we add a class token at the very beginning + of the tokenization. + """ + + if not self._encodings: + raise ValueError("word_to_tokens() is not available when using Python based tokenizers") + if word_index is not None: + batch_index = batch_or_word_index + else: + batch_index = 0 + word_index = batch_or_word_index + if batch_index < 0: + batch_index = self._batch_size + batch_index + if word_index < 0: + word_index = self._seq_len + word_index + span = self._encodings[batch_index].word_to_tokens(word_index, sequence_index) + return TokenSpan(*span) if span is not None else None + + def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = None) -> CharSpan: + """ + Get the character span corresponding to an encoded token in a sequence of the batch. + + Character spans are returned as a [`~tokenization_utils_base.CharSpan`] with: + + - **start** -- Index of the first character in the original string associated to the token. + - **end** -- Index of the character following the last character in the original string associated to the + token. + + Can be called as: + + - `self.token_to_chars(token_index)` if batch size is 1 + - `self.token_to_chars(batch_index, token_index)` if batch size is greater or equal to 1 + + Args: + batch_or_token_index (`int`): + Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of + the token in the sequence. + token_index (`int`, *optional*): + If a batch index is provided in *batch_or_token_index*, this can be the index of the token or tokens in + the sequence. + + Returns: + [`~tokenization_utils_base.CharSpan`]: Span of characters in the original string, or None, if the token + (e.g. , ) doesn't correspond to any chars in the origin string. + """ + + if not self._encodings: + raise ValueError("token_to_chars() is not available when using Python based tokenizers") + if token_index is not None: + batch_index = batch_or_token_index + else: + batch_index = 0 + token_index = batch_or_token_index + span_indices = self._encodings[batch_index].token_to_chars(token_index) + + return CharSpan(*span_indices) if span_indices is not None else None + + def char_to_token( + self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0 + ) -> int: + """ + Get the index of the token in the encoded output comprising a character in the original string for a sequence + of the batch. + + Can be called as: + + - `self.char_to_token(char_index)` if batch size is 1 + - `self.char_to_token(batch_index, char_index)` if batch size is greater or equal to 1 + + This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words + are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized + words. + + Args: + batch_or_char_index (`int`): + Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of + the word in the sequence + char_index (`int`, *optional*): + If a batch index is provided in *batch_or_token_index*, this can be the index of the word in the + sequence. + sequence_index (`int`, *optional*, defaults to 0): + If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0 + or 1) the provided character index belongs to. + + + Returns: + `int`: Index of the token, or None if the char index refers to a whitespace only token and whitespace is + trimmed with `trim_offsets=True`. + """ + + if not self._encodings: + raise ValueError("char_to_token() is not available when using Python based tokenizers") + if char_index is not None: + batch_index = batch_or_char_index + else: + batch_index = 0 + char_index = batch_or_char_index + return self._encodings[batch_index].char_to_token(char_index, sequence_index) + + def word_to_chars( + self, batch_or_word_index: int, word_index: Optional[int] = None, sequence_index: int = 0 + ) -> CharSpan: + """ + Get the character span in the original string corresponding to given word in a sequence of the batch. + + Character spans are returned as a CharSpan NamedTuple with: + + - start: index of the first character in the original string + - end: index of the character following the last character in the original string + + Can be called as: + + - `self.word_to_chars(word_index)` if batch size is 1 + - `self.word_to_chars(batch_index, word_index)` if batch size is greater or equal to 1 + + Args: + batch_or_word_index (`int`): + Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of + the word in the sequence + word_index (`int`, *optional*): + If a batch index is provided in *batch_or_token_index*, this can be the index of the word in the + sequence. + sequence_index (`int`, *optional*, defaults to 0): + If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0 + or 1) the provided word index belongs to. + + Returns: + `CharSpan` or `List[CharSpan]`: Span(s) of the associated character or characters in the string. CharSpan + are NamedTuple with: + + - start: index of the first character associated to the token in the original string + - end: index of the character following the last character associated to the token in the original + string + """ + + if not self._encodings: + raise ValueError("word_to_chars() is not available when using Python based tokenizers") + if word_index is not None: + batch_index = batch_or_word_index + else: + batch_index = 0 + word_index = batch_or_word_index + return CharSpan(*(self._encodings[batch_index].word_to_chars(word_index, sequence_index))) + + def char_to_word(self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0) -> int: + """ + Get the word in the original string corresponding to a character in the original string of a sequence of the + batch. + + Can be called as: + + - `self.char_to_word(char_index)` if batch size is 1 + - `self.char_to_word(batch_index, char_index)` if batch size is greater than 1 + + This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words + are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized + words. + + Args: + batch_or_char_index (`int`): + Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of + the character in the original string. + char_index (`int`, *optional*): + If a batch index is provided in *batch_or_token_index*, this can be the index of the character in the + original string. + sequence_index (`int`, *optional*, defaults to 0): + If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0 + or 1) the provided character index belongs to. + + + Returns: + `int` or `List[int]`: Index or indices of the associated encoded token(s). + """ + + if not self._encodings: + raise ValueError("char_to_word() is not available when using Python based tokenizers") + if char_index is not None: + batch_index = batch_or_char_index + else: + batch_index = 0 + char_index = batch_or_char_index + return self._encodings[batch_index].char_to_word(char_index, sequence_index) + + def convert_to_tensors( + self, tensor_type: Optional[Union[str, TensorType]] = None, prepend_batch_axis: bool = False + ): + """ + Convert the inner content to tensors. + + Args: + tensor_type (`str` or [`~utils.TensorType`], *optional*): + The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If + `None`, no modification is done. + prepend_batch_axis (`int`, *optional*, defaults to `False`): + Whether or not to add the batch dimension during the conversion. + """ + if tensor_type is None: + return self + + # Convert to TensorType + if not isinstance(tensor_type, TensorType): + tensor_type = TensorType(tensor_type) + + # Get a function reference for the correct framework + if tensor_type == TensorType.TENSORFLOW: + if not is_tf_available(): + raise ImportError( + "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed." + ) + import tensorflow as tf + + as_tensor = tf.constant + is_tensor = tf.is_tensor + elif tensor_type == TensorType.PYTORCH: + if not is_torch_available(): + raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.") + import torch + + is_tensor = torch.is_tensor + + def as_tensor(value, dtype=None): + if isinstance(value, list) and isinstance(value[0], np.ndarray): + return torch.from_numpy(np.array(value)) + return torch.tensor(value) + + elif tensor_type == TensorType.JAX: + if not is_flax_available(): + raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.") + import jax.numpy as jnp # noqa: F811 + + as_tensor = jnp.array + is_tensor = is_jax_tensor + + elif tensor_type == TensorType.MLX: + if not is_mlx_available(): + raise ImportError("Unable to convert output to MLX tensors format, MLX is not installed.") + import mlx.core as mx + + as_tensor = mx.array + + def is_tensor(obj): + return isinstance(obj, mx.array) + else: + + def as_tensor(value, dtype=None): + if isinstance(value, (list, tuple)) and isinstance(value[0], (list, tuple, np.ndarray)): + value_lens = [len(val) for val in value] + if len(set(value_lens)) > 1 and dtype is None: + # we have a ragged list so handle explicitly + value = as_tensor([np.asarray(val) for val in value], dtype=object) + return np.asarray(value, dtype=dtype) + + is_tensor = is_numpy_array + + # Do the tensor conversion in batch + for key, value in self.items(): + try: + if prepend_batch_axis: + value = [value] + + if not is_tensor(value): + tensor = as_tensor(value) + + # Removing this for now in favor of controlling the shape with `prepend_batch_axis` + # # at-least2d + # if tensor.ndim > 2: + # tensor = tensor.squeeze(0) + # elif tensor.ndim < 2: + # tensor = tensor[None, :] + + self[key] = tensor + except Exception as e: + if key == "overflowing_tokens": + raise ValueError( + "Unable to create tensor returning overflowing tokens of different lengths. " + "Please see if a fast version of this tokenizer is available to have this feature available." + ) from e + raise ValueError( + "Unable to create tensor, you should probably activate truncation and/or padding with" + " 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your" + f" features (`{key}` in this case) have excessive nesting (inputs type `list` where type `int` is" + " expected)." + ) from e + + return self + + def to(self, device: Union[str, "torch.device"], *, non_blocking: bool = False) -> "BatchEncoding": + """ + Send all values to device by calling `v.to(device, non_blocking=non_blocking)` (PyTorch only). + + Args: + device (`str` or `torch.device`): The device to put the tensors on. + non_blocking (`bool`): Whether to perform the copy asynchronously. + + Returns: + [`BatchEncoding`]: The same instance after modification. + """ + requires_backends(self, ["torch"]) + import torch + + # This check catches things like APEX blindly calling "to" on all inputs to a module + # Otherwise it passes the casts down and casts the LongTensor containing the token idxs + # into a HalfTensor + if isinstance(device, str) or is_torch_device(device) or isinstance(device, int): + self.data = { + k: v.to(device=device, non_blocking=non_blocking) if isinstance(v, torch.Tensor) else v + for k, v in self.data.items() + } + else: + logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.") + return self + + +class SpecialTokensMixin: + """ + A mixin derived by [`PreTrainedTokenizer`] and [`PreTrainedTokenizerFast`] to handle specific behaviors related to + special tokens. In particular, this class hold the attributes which can be used to directly access these special + tokens in a model-independent manner and allow to set and update the special tokens. + + Args: + bos_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token representing the beginning of a sentence. + eos_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token representing the end of a sentence. + unk_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token representing an out-of-vocabulary token. + sep_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token separating two different sentences in the same input (used by BERT for instance). + pad_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by + attention mechanisms or loss computation. + cls_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token representing the class of the input (used by BERT for instance). + mask_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token representing a masked token (used by masked-language modeling pretraining objectives, like + BERT). + additional_special_tokens (tuple or list of `str` or `tokenizers.AddedToken`, *optional*): + A tuple or a list of additional tokens, which will be marked as `special`, meaning that they will be + skipped when decoding if `skip_special_tokens` is set to `True`. + """ + + SPECIAL_TOKENS_ATTRIBUTES = [ + "bos_token", + "eos_token", + "unk_token", + "sep_token", + "pad_token", + "cls_token", + "mask_token", + "additional_special_tokens", + ] + + def __init__(self, verbose=False, **kwargs): + self._pad_token_type_id = 0 + self.verbose = verbose + self._special_tokens_map = {attr: None for attr in self.SPECIAL_TOKENS_ATTRIBUTES} + self._special_tokens_map["additional_special_tokens"] = [] # for BC where it defaults to empty list + + # We directly set the hidden value to allow initialization with special tokens + # which are not yet in the vocabulary. Necessary for serialization/de-serialization + # TODO clean this up at some point (probably by switching to fast tokenizers) + + for key, value in kwargs.items(): + if value is None: + continue + if key in self.SPECIAL_TOKENS_ATTRIBUTES: + if key == "additional_special_tokens": + assert isinstance(value, (list, tuple)), f"Value {value} is not a list or tuple" + assert all( + isinstance(t, (str, AddedToken)) for t in value + ), "One of the tokens is not a string or an AddedToken" + setattr(self, key, value) + elif isinstance(value, (str, AddedToken)): + setattr(self, key, value) + else: + raise TypeError(f"Special token {key} has to be either str or AddedToken but got: {type(value)}") + + def sanitize_special_tokens(self) -> int: + """ + The `sanitize_special_tokens` is now deprecated kept for backward compatibility and will be removed in + transformers v5. + """ + logger.warning_once("The `sanitize_special_tokens` will be removed in transformers v5.") + return self.add_tokens(self.all_special_tokens_extended, special_tokens=True) + + def add_special_tokens( + self, special_tokens_dict: Dict[str, Union[str, AddedToken]], replace_additional_special_tokens=True + ) -> int: + """ + Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder and link them to class attributes. If + special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last index of the + current vocabulary). + + When adding new tokens to the vocabulary, you should make sure to also resize the token embedding matrix of the + model so that its embedding matrix matches the tokenizer. + + In order to do that, please use the [`~PreTrainedModel.resize_token_embeddings`] method. + + Using `add_special_tokens` will ensure your special tokens can be used in several ways: + + - Special tokens can be skipped when decoding using `skip_special_tokens = True`. + - Special tokens are carefully handled by the tokenizer (they are never split), similar to `AddedTokens`. + - You can easily refer to special tokens using tokenizer class attributes like `tokenizer.cls_token`. This + makes it easy to develop model-agnostic training and fine-tuning scripts. + + When possible, special tokens are already registered for provided pretrained models (for instance + [`BertTokenizer`] `cls_token` is already registered to be :obj*'[CLS]'* and XLM's one is also registered to be + `''`). + + Args: + special_tokens_dict (dictionary *str* to *str* or `tokenizers.AddedToken`): + Keys should be in the list of predefined special attributes: [`bos_token`, `eos_token`, `unk_token`, + `sep_token`, `pad_token`, `cls_token`, `mask_token`, `additional_special_tokens`]. + + Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer + assign the index of the `unk_token` to them). + replace_additional_special_tokens (`bool`, *optional*,, defaults to `True`): + If `True`, the existing list of additional special tokens will be replaced by the list provided in + `special_tokens_dict`. Otherwise, `self._special_tokens_map["additional_special_tokens"]` is just extended. In the former + case, the tokens will NOT be removed from the tokenizer's full vocabulary - they are only being flagged + as non-special tokens. Remember, this only affects which tokens are skipped during decoding, not the + `added_tokens_encoder` and `added_tokens_decoder`. This means that the previous + `additional_special_tokens` are still added tokens, and will not be split by the model. + + Returns: + `int`: Number of tokens added to the vocabulary. + + Examples: + + ```python + # Let's see how to add a new classification token to GPT-2 + tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") + model = GPT2Model.from_pretrained("openai-community/gpt2") + + special_tokens_dict = {"cls_token": ""} + + num_added_toks = tokenizer.add_special_tokens(special_tokens_dict) + print("We have added", num_added_toks, "tokens") + # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e., the length of the tokenizer. + model.resize_token_embeddings(len(tokenizer)) + + assert tokenizer.cls_token == "" + ```""" + if not special_tokens_dict: + return 0 + + added_tokens = [] + for key, value in special_tokens_dict.items(): + assert key in self.SPECIAL_TOKENS_ATTRIBUTES, f"Key {key} is not a special token" + + if self.verbose: + logger.info(f"Assigning {value} to the {key} key of the tokenizer") + + if key == "additional_special_tokens": + assert isinstance(value, (list, tuple)) and all( + isinstance(t, (str, AddedToken)) for t in value + ), f"Tokens {value} for key {key} should all be str or AddedToken instances" + + to_add = [] + for token in value: + if isinstance(token, str): + # for legacy purpose we default to stripping. `test_add_tokens_tokenizer` depends on this + token = AddedToken(token, rstrip=False, lstrip=False, normalized=False, special=True) + if not replace_additional_special_tokens and str(token) in self.additional_special_tokens: + continue + to_add.append(token) + if replace_additional_special_tokens and len(to_add) > 0: + setattr(self, key, list(to_add)) + else: + self._special_tokens_map["additional_special_tokens"].extend(to_add) + added_tokens += to_add + + else: + if not isinstance(value, (str, AddedToken)): + raise ValueError(f"Token {value} for key {key} should be a str or an AddedToken instance") + if isinstance(value, (str)): + # for legacy purpose we default to stripping. `False` depends on this + value = AddedToken(value, rstrip=False, lstrip=False, normalized=False, special=True) + if isinstance(value, AddedToken): + setattr(self, key, value) + if value not in added_tokens: + added_tokens.append(value) + + # if we are adding tokens that were not part of the vocab, we ought to add them + added_tokens = self.add_tokens(added_tokens, special_tokens=True) + return added_tokens + + def add_tokens( + self, new_tokens: Union[str, AddedToken, List[Union[str, AddedToken]]], special_tokens: bool = False + ) -> int: + """ + Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to + it with indices starting from length of the current vocabulary and will be isolated before the tokenization + algorithm is applied. Added tokens and tokens from the vocabulary of the tokenization algorithm are therefore + not treated in the same way. + + Note, when adding new tokens to the vocabulary, you should make sure to also resize the token embedding matrix + of the model so that its embedding matrix matches the tokenizer. + + In order to do that, please use the [`~PreTrainedModel.resize_token_embeddings`] method. + + Args: + new_tokens (`str`, `tokenizers.AddedToken` or a list of *str* or `tokenizers.AddedToken`): + Tokens are only added if they are not already in the vocabulary. `tokenizers.AddedToken` wraps a string + token to let you personalize its behavior: whether this token should only match against a single word, + whether this token should strip all potential whitespaces on the left side, whether this token should + strip all potential whitespaces on the right side, etc. + special_tokens (`bool`, *optional*, defaults to `False`): + Can be used to specify if the token is a special token. This mostly change the normalization behavior + (special tokens like CLS or [MASK] are usually not lower-cased for instance). + + See details for `tokenizers.AddedToken` in HuggingFace tokenizers library. + + Returns: + `int`: Number of tokens added to the vocabulary. + + Examples: + + ```python + # Let's see how to increase the vocabulary of Bert model and tokenizer + tokenizer = BertTokenizerFast.from_pretrained("google-bert/bert-base-uncased") + model = BertModel.from_pretrained("google-bert/bert-base-uncased") + + num_added_toks = tokenizer.add_tokens(["new_tok1", "my_new-tok2"]) + print("We have added", num_added_toks, "tokens") + # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e., the length of the tokenizer. + model.resize_token_embeddings(len(tokenizer)) + ```""" + if not new_tokens: + return 0 + + if not isinstance(new_tokens, (list, tuple)): + new_tokens = [new_tokens] + + return self._add_tokens(new_tokens, special_tokens=special_tokens) + + def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int: + raise NotImplementedError + + @property + def pad_token_type_id(self) -> int: + """ + `int`: Id of the padding token type in the vocabulary. + """ + return self._pad_token_type_id + + def __setattr__(self, key, value): + key_without_id = key + key_is_special_id = key.endswith("_id") or key.endswith("_ids") + if key_is_special_id: + key_without_id = key[:-3] if not key.endswith("_ids") else key[:-4] + + if self.__dict__.get("_special_tokens_map", None) is not None and any( + name in self.__dict__["_special_tokens_map"] for name in [key, key_without_id] + ): + if key_is_special_id: + if value is not None: + value = ( + self.convert_ids_to_tokens(value) + if key != "additional_special_tokens" + else [self.convert_ids_to_tokens(val) for val in value] + ) + key = key_without_id + + if key != "additional_special_tokens" and not isinstance(value, (str, AddedToken)) and value is not None: + raise ValueError(f"Cannot set a non-string value as the {key}") + self._special_tokens_map[key] = value + else: + super().__setattr__(key, value) + + def __getattr__(self, key): + key_without_id = key + key_is_special_id = key.endswith("_id") or key.endswith("_ids") + if key_is_special_id: + key_without_id = key[:-3] if not key.endswith("_ids") else key[:-4] + + if self.__dict__.get("_special_tokens_map", None) is not None and any( + name in self.__dict__["_special_tokens_map"] for name in [key, key_without_id] + ): + _special_tokens_map = self.__dict__["_special_tokens_map"] + if not key_is_special_id: + if _special_tokens_map[key] is None: + if self.verbose: + logger.error(f"Using {key}, but it is not set yet.") + return None + value = _special_tokens_map[key] + return str(value) if key != "additional_special_tokens" else [str(tok) for tok in value] + else: + attr_as_tokens = getattr(self, key_without_id) + return self.convert_tokens_to_ids(attr_as_tokens) if attr_as_tokens is not None else None + + if key not in self.__dict__: + raise AttributeError(f"{self.__class__.__name__} has no attribute {key}") + else: + return super().__getattr__(key) + + @property + def special_tokens_map(self) -> Dict[str, Union[str, List[str]]]: + """ + `Dict[str, Union[str, List[str]]]`: A dictionary mapping special token class attributes (`cls_token`, + `unk_token`, etc.) to their values (`''`, `''`, etc.). + + Convert potential tokens of `tokenizers.AddedToken` type to string. + """ + set_attr = {} + for attr in self.SPECIAL_TOKENS_ATTRIBUTES: + attr_value = getattr(self, attr) + if attr_value: + set_attr[attr] = attr_value + return set_attr + + @property + def special_tokens_map_extended(self) -> Dict[str, Union[str, AddedToken, List[Union[str, AddedToken]]]]: + """ + `Dict[str, Union[str, tokenizers.AddedToken, List[Union[str, tokenizers.AddedToken]]]]`: A dictionary mapping + special token class attributes (`cls_token`, `unk_token`, etc.) to their values (`''`, `''`, etc.). + + Don't convert tokens of `tokenizers.AddedToken` type to string so they can be used to control more finely how + special tokens are tokenized. + """ + set_attr = {} + for attr in self.SPECIAL_TOKENS_ATTRIBUTES: + attr_value = self._special_tokens_map[attr] + if attr_value: + set_attr[attr] = attr_value + return set_attr + + @property + def all_special_tokens_extended(self) -> List[Union[str, AddedToken]]: + """ + `List[Union[str, tokenizers.AddedToken]]`: All the special tokens (`''`, `''`, etc.), the order has + nothing to do with the index of each tokens. If you want to know the correct indices, check + `self.added_tokens_encoder`. We can't create an order anymore as the keys are `AddedTokens` and not `Strings`. + + Don't convert tokens of `tokenizers.AddedToken` type to string so they can be used to control more finely how + special tokens are tokenized. + """ + all_tokens = [] + seen = set() + for value in self.special_tokens_map_extended.values(): + if isinstance(value, (list, tuple)): + tokens_to_add = [token for token in value if str(token) not in seen] + else: + tokens_to_add = [value] if str(value) not in seen else [] + seen.update(map(str, tokens_to_add)) + all_tokens.extend(tokens_to_add) + return all_tokens + + @property + def all_special_tokens(self) -> List[str]: + """ + `List[str]`: A list of the unique special tokens (`''`, `''`, ..., etc.). + + Convert tokens of `tokenizers.AddedToken` type to string. + """ + all_toks = [str(s) for s in self.all_special_tokens_extended] + return all_toks + + @property + def all_special_ids(self) -> List[int]: + """ + `List[int]`: List the ids of the special tokens(`''`, `''`, etc.) mapped to class attributes. + """ + all_toks = self.all_special_tokens + all_ids = self.convert_tokens_to_ids(all_toks) + return all_ids + + def _set_model_specific_special_tokens(self, special_tokens: List[str]): + """ + Adds new special tokens to the "SPECIAL_TOKENS_ATTRIBUTES" list which will be part + of "self.special_tokens" and saved as a special token in tokenizer's config. + This allows us to dynamically add new model-type specific tokens after initilizing the tokenizer. + For example: if the model tokenizers is multimodal, we can support special image or audio tokens. + """ + self.SPECIAL_TOKENS_ATTRIBUTES = self.SPECIAL_TOKENS_ATTRIBUTES + list(special_tokens.keys()) + for key, value in special_tokens.items(): + if isinstance(value, (str, AddedToken)): + self._special_tokens_map[key] = value + else: + raise TypeError(f"Special token {key} has to be either str or AddedToken but got: {type(value)}") + + +ENCODE_KWARGS_DOCSTRING = r""" + add_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to add special tokens when encoding the sequences. This will use the underlying + `PretrainedTokenizerBase.build_inputs_with_special_tokens` function, which defines which tokens are + automatically added to the input ids. This is usefull if you want to add `bos` or `eos` tokens + automatically. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Activates and controls padding. Accepts the following values: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`): + Activates and controls truncation. Accepts the following values: + + - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will + truncate token by token, removing a token from the longest sequence in the pair if a pair of + sequences (or a batch of pairs) is provided. + - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths + greater than the model maximum admissible input size). + max_length (`int`, *optional*): + Controls the maximum length to use by one of the truncation/padding parameters. + + If left unset or set to `None`, this will use the predefined model maximum length if a maximum length + is required by one of the truncation/padding parameters. If the model has no specific maximum input + length (like XLNet) truncation/padding to a maximum length will be deactivated. + stride (`int`, *optional*, defaults to 0): + If set to a number along with `max_length`, the overflowing tokens returned when + `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence + returned to provide some overlap between truncated and overflowing sequences. The value of this + argument defines the number of overlapping tokens. + is_split_into_words (`bool`, *optional*, defaults to `False`): + Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the + tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace) + which it will tokenize. This is useful for NER or token classification. + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. Requires `padding` to be activated. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + padding_side (`str`, *optional*): + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. +""" + +ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r""" + return_token_type_ids (`bool`, *optional*): + Whether to return token type IDs. If left to the default, will return the token type IDs according to + the specific tokenizer's default, defined by the `return_outputs` attribute. + + [What are token type IDs?](../glossary#token-type-ids) + return_attention_mask (`bool`, *optional*): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific tokenizer's default, defined by the `return_outputs` attribute. + + [What are attention masks?](../glossary#attention-mask) + return_overflowing_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch + of pairs) is provided with `truncation_strategy = longest_first` or `True`, an error is raised instead + of returning overflowing tokens. + return_special_tokens_mask (`bool`, *optional*, defaults to `False`): + Whether or not to return special tokens mask information. + return_offsets_mapping (`bool`, *optional*, defaults to `False`): + Whether or not to return `(char_start, char_end)` for each token. + + This is only available on fast tokenizers inheriting from [`PreTrainedTokenizerFast`], if using + Python's tokenizer, this method will raise `NotImplementedError`. + return_length (`bool`, *optional*, defaults to `False`): + Whether or not to return the lengths of the encoded inputs. + verbose (`bool`, *optional*, defaults to `True`): + Whether or not to print more information and warnings. + **kwargs: passed to the `self.tokenize()` method + + Return: + [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. + + [What are input IDs?](../glossary#input-ids) + + - **token_type_ids** -- List of token type ids to be fed to a model (when `return_token_type_ids=True` or + if *"token_type_ids"* is in `self.model_input_names`). + + [What are token type IDs?](../glossary#token-type-ids) + + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names`). + + [What are attention masks?](../glossary#attention-mask) + + - **overflowing_tokens** -- List of overflowing tokens sequences (when a `max_length` is specified and + `return_overflowing_tokens=True`). + - **num_truncated_tokens** -- Number of tokens truncated (when a `max_length` is specified and + `return_overflowing_tokens=True`). + - **special_tokens_mask** -- List of 0s and 1s, with 1 specifying added special tokens and 0 specifying + regular sequence tokens (when `add_special_tokens=True` and `return_special_tokens_mask=True`). + - **length** -- The length of the inputs (when `return_length=True`) +""" + + +INIT_TOKENIZER_DOCSTRING = r""" + Class attributes (overridden by derived classes) + + - **vocab_files_names** (`Dict[str, str]`) -- A dictionary with, as keys, the `__init__` keyword name of each + vocabulary file required by the model, and as associated values, the filename for saving the associated file + (string). + - **pretrained_vocab_files_map** (`Dict[str, Dict[str, str]]`) -- A dictionary of dictionaries, with the + high-level keys being the `__init__` keyword name of each vocabulary file required by the model, the + low-level being the `short-cut-names` of the pretrained models with, as associated values, the `url` to the + associated pretrained vocabulary file. + - **model_input_names** (`List[str]`) -- A list of inputs expected in the forward pass of the model. + - **padding_side** (`str`) -- The default value for the side on which the model should have padding applied. + Should be `'right'` or `'left'`. + - **truncation_side** (`str`) -- The default value for the side on which the model should have truncation + applied. Should be `'right'` or `'left'`. + + Args: + model_max_length (`int`, *optional*): + The maximum length (in number of tokens) for the inputs to the transformer model. When the tokenizer is + loaded with [`~tokenization_utils_base.PreTrainedTokenizerBase.from_pretrained`], this will be set to the + value stored for the associated model in `max_model_input_sizes` (see above). If no value is provided, will + default to VERY_LARGE_INTEGER (`int(1e30)`). + padding_side (`str`, *optional*): + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. + truncation_side (`str`, *optional*): + The side on which the model should have truncation applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. + chat_template (`str`, *optional*): + A Jinja template string that will be used to format lists of chat messages. See + https://huggingface.co/docs/transformers/chat_templating for a full description. + model_input_names (`List[string]`, *optional*): + The list of inputs accepted by the forward pass of the model (like `"token_type_ids"` or + `"attention_mask"`). Default value is picked from the class attribute of the same name. + bos_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token representing the beginning of a sentence. Will be associated to `self.bos_token` and + `self.bos_token_id`. + eos_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token representing the end of a sentence. Will be associated to `self.eos_token` and + `self.eos_token_id`. + unk_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token representing an out-of-vocabulary token. Will be associated to `self.unk_token` and + `self.unk_token_id`. + sep_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token separating two different sentences in the same input (used by BERT for instance). Will be + associated to `self.sep_token` and `self.sep_token_id`. + pad_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by + attention mechanisms or loss computation. Will be associated to `self.pad_token` and `self.pad_token_id`. + cls_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token representing the class of the input (used by BERT for instance). Will be associated to + `self.cls_token` and `self.cls_token_id`. + mask_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token representing a masked token (used by masked-language modeling pretraining objectives, like + BERT). Will be associated to `self.mask_token` and `self.mask_token_id`. + additional_special_tokens (tuple or list of `str` or `tokenizers.AddedToken`, *optional*): + A tuple or a list of additional special tokens. Add them here to ensure they are skipped when decoding with + `skip_special_tokens` is set to True. If they are not part of the vocabulary, they will be added at the end + of the vocabulary. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): + Whether or not the model should cleanup the spaces that were added when splitting the input text during the + tokenization process. + split_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the special tokens should be split during the tokenization process. Passing will affect the + internal state of the tokenizer. The default behavior is to not split special tokens. This means that if + `` is the `bos_token`, then `tokenizer.tokenize("") = ['`]. Otherwise, if + `split_special_tokens=True`, then `tokenizer.tokenize("")` will be give `['<','s', '>']`. +""" + + +@add_end_docstrings(INIT_TOKENIZER_DOCSTRING) +class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): + """ + Base class for [`PreTrainedTokenizer`] and [`PreTrainedTokenizerFast`]. + + Handles shared (mostly boiler plate) methods for those two classes. + """ + + vocab_files_names: Dict[str, str] = {} + pretrained_vocab_files_map: Dict[str, Dict[str, str]] = {} + _auto_class: Optional[str] = None + + # first name has to correspond to main model input name + # to make sure `tokenizer.pad(...)` works correctly + model_input_names: List[str] = ["input_ids", "token_type_ids", "attention_mask"] + padding_side: str = "right" + truncation_side: str = "right" + slow_tokenizer_class = None + + def __init__(self, **kwargs): + # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``) + self.init_inputs = () + for key in kwargs: + if hasattr(self, key) and callable(getattr(self, key)): + raise AttributeError(f"{key} conflicts with the method {key} in {self.__class__.__name__}") + + self.init_kwargs = copy.deepcopy(kwargs) + self.name_or_path = kwargs.pop("name_or_path", "") + self._processor_class = kwargs.pop("processor_class", None) + + # For backward compatibility we fallback to set model_max_length from max_len if provided + model_max_length = kwargs.pop("model_max_length", kwargs.pop("max_len", None)) + self.model_max_length = model_max_length if model_max_length is not None else VERY_LARGE_INTEGER + + # Padding and truncation side are right by default and overridden in subclasses. If specified in the kwargs, it + # is changed. + self.padding_side = kwargs.pop("padding_side", self.padding_side) + if self.padding_side not in ["right", "left"]: + raise ValueError( + f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}" + ) + + self.truncation_side = kwargs.pop("truncation_side", self.truncation_side) + if self.truncation_side not in ["right", "left"]: + raise ValueError( + f"Truncation side should be selected between 'right' and 'left', current value: {self.truncation_side}" + ) + + self.model_input_names = kwargs.pop("model_input_names", self.model_input_names) + + # By default, cleaning tokenization spaces for both fast and slow tokenizers + self.clean_up_tokenization_spaces = kwargs.pop("clean_up_tokenization_spaces", False) + + # By default, do not split special tokens for both fast and slow tokenizers + self.split_special_tokens = kwargs.pop("split_special_tokens", False) + + self.deprecation_warnings = {} # Use to store when we have already noticed a deprecation warning (avoid overlogging). + self._in_target_context_manager = False + + # Stores a Jinja template that formats chat histories into tokenizable strings + self.chat_template = kwargs.pop("chat_template", None) + if isinstance(self.chat_template, (list, tuple)): + # Chat templates are stored as lists of dicts with fixed key names, + # we reconstruct that into a single dict while loading them. + self.chat_template = {template["name"]: template["template"] for template in self.chat_template} + + super().__init__(**kwargs) + + self.extra_special_tokens = kwargs.pop("extra_special_tokens", {}) + self._set_model_specific_special_tokens(special_tokens=self.extra_special_tokens) + + @property + def max_len_single_sentence(self) -> int: + """ + `int`: The maximum length of a sentence that can be fed to the model. + """ + return self.model_max_length - self.num_special_tokens_to_add(pair=False) + + @property + def max_len_sentences_pair(self) -> int: + """ + `int`: The maximum combined length of a pair of sentences that can be fed to the model. + """ + return self.model_max_length - self.num_special_tokens_to_add(pair=True) + + @max_len_single_sentence.setter + def max_len_single_sentence(self, value) -> int: + # For backward compatibility, allow to try to setup 'max_len_single_sentence'. + if value == self.model_max_length - self.num_special_tokens_to_add(pair=False) and self.verbose: + if not self.deprecation_warnings.get("max_len_single_sentence", False): + logger.warning( + "Setting 'max_len_single_sentence' is now deprecated. This value is automatically set up." + ) + self.deprecation_warnings["max_len_single_sentence"] = True + else: + raise ValueError( + "Setting 'max_len_single_sentence' is now deprecated. This value is automatically set up." + ) + + @max_len_sentences_pair.setter + def max_len_sentences_pair(self, value) -> int: + # For backward compatibility, allow to try to setup 'max_len_sentences_pair'. + if value == self.model_max_length - self.num_special_tokens_to_add(pair=True) and self.verbose: + if not self.deprecation_warnings.get("max_len_sentences_pair", False): + logger.warning( + "Setting 'max_len_sentences_pair' is now deprecated. This value is automatically set up." + ) + self.deprecation_warnings["max_len_sentences_pair"] = True + else: + raise ValueError("Setting 'max_len_sentences_pair' is now deprecated. This value is automatically set up.") + + def _set_processor_class(self, processor_class: str): + """Sets processor class as an attribute.""" + self._processor_class = processor_class + + @property + def added_tokens_decoder(self) -> Dict[int, AddedToken]: + raise NotImplementedError() + + def __repr__(self) -> str: + added_tokens_decoder_rep = "\n\t".join([f"{k}: {v.__repr__()}," for k, v in self.added_tokens_decoder.items()]) + return ( + f"{self.__class__.__name__}(name_or_path='{self.name_or_path}'," + f" vocab_size={self.vocab_size}, model_max_length={self.model_max_length}, is_fast={self.is_fast}," + f" padding_side='{self.padding_side}', truncation_side='{self.truncation_side}'," + f" special_tokens={self.special_tokens_map}, clean_up_tokenization_spaces={self.clean_up_tokenization_spaces}," + " added_tokens_decoder={\n\t" + added_tokens_decoder_rep + "\n}\n)" + ) + + def __len__(self) -> int: + raise NotImplementedError() + + def get_vocab(self) -> Dict[str, int]: + """ + Returns the vocabulary as a dictionary of token to index. + + `tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the + vocab. + + Returns: + `Dict[str, int]`: The vocabulary. + """ + raise NotImplementedError() + + def apply_chat_template( + self, + conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]], + tools: Optional[List[Union[Dict, Callable]]] = None, + documents: Optional[List[Dict[str, str]]] = None, + chat_template: Optional[str] = None, + add_generation_prompt: bool = False, + continue_final_message: bool = False, + tokenize: bool = True, + padding: bool = False, + truncation: bool = False, + max_length: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_dict: bool = False, + return_assistant_tokens_mask: bool = False, + tokenizer_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]: + """ + Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token + ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to + determine the format and control tokens to use when converting. + + Args: + conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]]]): A list of dicts + with "role" and "content" keys, representing the chat history so far. + tools (`List[Dict]`, *optional*): + A list of tools (callable functions) that will be accessible to the model. If the template does not + support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema, + giving the name, description and argument types for the tool. See our + [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) + for more information. + documents (`List[Dict[str, str]]`, *optional*): + A list of dicts representing documents that will be accessible to the model if it is performing RAG + (retrieval-augmented generation). If the template does not support RAG, this argument will have no + effect. We recommend that each document should be a dict containing "title" and "text" keys. Please + see the RAG section of the [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG) + for examples of passing documents with chat templates. + chat_template (`str`, *optional*): + A Jinja template to use for this conversion. It is usually not necessary to pass anything to this + argument, as the model's template will be used by default. + add_generation_prompt (bool, *optional*): + If this is set, a prompt with the token(s) that indicate + the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model. + Note that this argument will be passed to the chat template, and so it must be supported in the + template for this argument to have any effect. + continue_final_message (bool, *optional*): + If this is set, the chat will be formatted so that the final + message in the chat is open-ended, without any EOS tokens. The model will continue this message + rather than starting a new one. This allows you to "prefill" part of + the model's response for it. Cannot be used at the same time as `add_generation_prompt`. + tokenize (`bool`, defaults to `True`): + Whether to tokenize the output. If `False`, the output will be a string. + padding (`bool`, defaults to `False`): + Whether to pad sequences to the maximum length. Has no effect if tokenize is `False`. + truncation (`bool`, defaults to `False`): + Whether to truncate sequences at the maximum length. Has no effect if tokenize is `False`. + max_length (`int`, *optional*): + Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is `False`. If + not specified, the tokenizer's `max_length` attribute will be used as a default. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Has no effect if tokenize is `False`. Acceptable + values are: + - `'tf'`: Return TensorFlow `tf.Tensor` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + return_dict (`bool`, defaults to `False`): + Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`. + tokenizer_kwargs (`Dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer. + return_assistant_tokens_mask (`bool`, defaults to `False`): + Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant, + the mask will contain 1. For user and system tokens, the mask will contain 0. + This functionality is only available for chat templates that support it via the `{% generation %}` keyword. + **kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template. + + Returns: + `Union[List[int], Dict]`: A list of token ids representing the tokenized chat so far, including control tokens. This + output is ready to pass to the model, either directly or via methods like `generate()`. If `return_dict` is + set, will return a dict of tokenizer outputs instead. + """ + + if return_dict and not tokenize: + raise ValueError( + "`return_dict=True` is incompatible with `tokenize=False`, because there is no dict " + "of tokenizer outputs to return." + ) + + if return_assistant_tokens_mask and not return_dict: + raise ValueError("`return_assistant_tokens_mask=True` is incompatible with `return_dict=False`") + + if tokenizer_kwargs is None: + tokenizer_kwargs = {} + + chat_template = self.get_chat_template(chat_template, tools) + + if return_assistant_tokens_mask and not re.search(r"\{\%-?\s*generation\s*-?\%\}", chat_template): + logger.warning_once( + "return_assistant_tokens_mask==True but chat template does not contain `{% generation %}` keyword." + ) + + # Compilation function uses a cache to avoid recompiling the same template + compiled_template = _compile_jinja_template(chat_template) + + if isinstance(conversation, (list, tuple)) and ( + isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages") + ): + conversations = conversation + is_batched = True + else: + conversations = [conversation] + is_batched = False + + if continue_final_message: + if add_generation_prompt: + raise ValueError( + "continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead." + ) + if return_assistant_tokens_mask: + raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.") + + # We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas + if tools is not None: + tool_schemas = [] + for tool in tools: + if isinstance(tool, dict): + tool_schemas.append(tool) + elif isfunction(tool): + tool_schemas.append(get_json_schema(tool)) + else: + raise ValueError( + "Tools should either be a JSON schema, or a callable function with type hints " + "and a docstring suitable for auto-conversion to a schema." + ) + else: + tool_schemas = None + + if documents is not None: + for document in documents: + if not isinstance(document, dict): + raise TypeError("Documents should be a list of dicts with 'title' and 'text' keys!") + + rendered = [] + all_generation_indices = [] + template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present + for chat in conversations: + if hasattr(chat, "messages"): + # Indicates it's a Conversation object + chat = chat.messages + if return_assistant_tokens_mask: + rendered_chat, generation_indices = _render_with_assistant_indices( + compiled_template=compiled_template, + messages=chat, + tools=tool_schemas, + documents=documents, + add_generation_prompt=add_generation_prompt, + **template_kwargs, + ) + all_generation_indices.append(generation_indices) + else: + rendered_chat = compiled_template.render( + messages=chat, + tools=tool_schemas, + documents=documents, + add_generation_prompt=add_generation_prompt, + **template_kwargs, + ) + if continue_final_message: + final_message = chat[-1]["content"] + if isinstance(final_message, (list, tuple)): + final_message = final_message[-1]["text"] + try: + rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)] + except: # noqa: E722 + # Some chat templates like Llama-3.1 trim messages before rendering, so we must do the same here. + final_message = final_message.strip() + rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)] + rendered.append(rendered_chat) + + if not is_batched: + rendered = rendered[0] + + if tokenize: + out = self( + rendered, + padding=padding, + truncation=truncation, + max_length=max_length, + add_special_tokens=False, + return_tensors=return_tensors, + **tokenizer_kwargs, + ) + if return_dict: + if return_assistant_tokens_mask: + assistant_masks = [] + if is_batched or return_tensors: + input_ids = out["input_ids"] + else: + input_ids = [out["input_ids"]] + for i in range(len(input_ids)): + current_mask = [0] * len(input_ids[i]) + for assistant_start_char, assistant_end_char in all_generation_indices[i]: + start_token = out.char_to_token(i, assistant_start_char) + end_token = out.char_to_token(i, assistant_end_char - 1) + if start_token is None: + # start_token is out of bounds maybe due to truncation. + break + for token_id in range(start_token, end_token + 1 if end_token else len(input_ids[i])): + current_mask[token_id] = 1 + assistant_masks.append(current_mask) + out["assistant_masks"] = assistant_masks if is_batched else assistant_masks[0] + return out + else: + return out["input_ids"] + else: + return rendered + + def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional[List[Dict]] = None) -> str: + """ + Retrieve the chat template string used for tokenizing chat messages. This template is used + internally by the `apply_chat_template` method and can also be used externally to retrieve the model's chat + template for better generation tracking. + + Args: + chat_template (`str`, *optional*): + A Jinja template or the name of a template to use for this conversion. + It is usually not necessary to pass anything to this argument, + as the model's template will be used by default. + tools (`List[Dict]`, *optional*): + A list of tools (callable functions) that will be accessible to the model. If the template does not + support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema, + giving the name, description and argument types for the tool. See our + [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) + for more information. + + Returns: + `str`: The chat template string. + """ + # First, handle the cases when the model has a dict of multiple templates + if isinstance(self.chat_template, dict): + template_dict = self.chat_template + if chat_template is not None and chat_template in template_dict: + # The user can pass the name of a template to the chat template argument instead of an entire template + chat_template = template_dict[chat_template] + elif chat_template is None: + if tools is not None and "tool_use" in template_dict: + chat_template = template_dict["tool_use"] + elif "default" in template_dict: + chat_template = template_dict["default"] + else: + raise ValueError( + "This model has multiple chat templates with no default specified! Please either pass a chat " + "template or the name of the template you wish to use to the `chat_template` argument. Available " + f"template names are {sorted(template_dict.keys())}." + ) + + elif chat_template is None: + # These are the cases when the model has a single template + # priority: `chat_template` argument > `tokenizer.chat_template` + if self.chat_template is not None: + chat_template = self.chat_template + else: + raise ValueError( + "Cannot use chat template functions because tokenizer.chat_template is not set and no template " + "argument was passed! For information about writing templates and setting the " + "tokenizer.chat_template attribute, please see the documentation at " + "https://huggingface.co/docs/transformers/main/en/chat_templating" + ) + + return chat_template + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + *init_inputs, + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + trust_remote_code=False, + **kwargs, + ): + r""" + Instantiate a [`~tokenization_utils_base.PreTrainedTokenizerBase`] (or a derived class) from a predefined + tokenizer. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + Can be either: + + - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co. + - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved + using the [`~tokenization_utils_base.PreTrainedTokenizerBase.save_pretrained`] method, e.g., + `./my_model_directory/`. + - (**Deprecated**, not applicable to all derived classes) A path or url to a single saved vocabulary + file (if and only if the tokenizer only requires a single vocabulary file like Bert or XLNet), e.g., + `./my_model_directory/vocab.txt`. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download the vocabulary files and override the cached versions if they + exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). + local_files_only (`bool`, *optional*, defaults to `False`): + Whether or not to only rely on local files and not to attempt to download any files. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + subfolder (`str`, *optional*): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for + facebook/rag-token-base), specify it here. + inputs (additional positional arguments, *optional*): + Will be passed along to the Tokenizer `__init__` method. + trust_remote_code (`bool`, *optional*, defaults to `False`): + Whether or not to allow for custom models defined on the Hub in their own modeling files. This option + should only be set to `True` for repositories you trust and in which you have read the code, as it will + execute code present on the Hub on your local machine. + kwargs (additional keyword arguments, *optional*): + Will be passed to the Tokenizer `__init__` method. Can be used to set special tokens like `bos_token`, + `eos_token`, `unk_token`, `sep_token`, `pad_token`, `cls_token`, `mask_token`, + `additional_special_tokens`. See parameters in the `__init__` for more details. + + + + Passing `token=True` is required when you want to use a private model. + + + + Examples: + + ```python + # We can't instantiate directly the base class *PreTrainedTokenizerBase* so let's show our examples on a derived class: BertTokenizer + # Download vocabulary from huggingface.co and cache. + tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased") + + # Download vocabulary from huggingface.co (user-uploaded) and cache. + tokenizer = BertTokenizer.from_pretrained("dbmdz/bert-base-german-cased") + + # If vocabulary files are in a directory (e.g. tokenizer was saved using *save_pretrained('./test/saved_model/')*) + tokenizer = BertTokenizer.from_pretrained("./test/saved_model/") + + # If the tokenizer uses a single vocabulary file, you can point directly to this file + tokenizer = BertTokenizer.from_pretrained("./test/saved_model/my_vocab.txt") + + # You can link tokens to special vocabulary when instantiating + tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased", unk_token="") + # You should be sure '' is in the vocabulary when doing that. + # Otherwise use tokenizer.add_special_tokens({'unk_token': ''}) instead) + assert tokenizer.unk_token == "" + ```""" + resume_download = kwargs.pop("resume_download", None) + proxies = kwargs.pop("proxies", None) + use_auth_token = kwargs.pop("use_auth_token", None) + subfolder = kwargs.pop("subfolder", None) + from_pipeline = kwargs.pop("_from_pipeline", None) + from_auto_class = kwargs.pop("_from_auto", False) + commit_hash = kwargs.pop("_commit_hash", None) + gguf_file = kwargs.get("gguf_file", None) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + user_agent = {"file_type": "tokenizer", "from_auto_class": from_auto_class, "is_fast": "Fast" in cls.__name__} + if from_pipeline is not None: + user_agent["using_pipeline"] = from_pipeline + + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + vocab_files = {} + init_configuration = {} + + is_local = os.path.isdir(pretrained_model_name_or_path) + single_file_id = None + if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): + if len(cls.vocab_files_names) > 1 and not gguf_file: + raise ValueError( + f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not " + "supported for this tokenizer. Use a model identifier or the path to a directory instead." + ) + warnings.warn( + f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is deprecated and " + "won't be possible anymore in v5. Use a model identifier or the path to a directory instead.", + FutureWarning, + ) + file_id = list(cls.vocab_files_names.keys())[0] + + vocab_files[file_id] = pretrained_model_name_or_path + single_file_id = file_id + else: + if gguf_file: + vocab_files["vocab_file"] = gguf_file + else: + # At this point pretrained_model_name_or_path is either a directory or a model identifier name + additional_files_names = { + "added_tokens_file": ADDED_TOKENS_FILE, # kept only for legacy + "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE, # kept only for legacy + "tokenizer_config_file": TOKENIZER_CONFIG_FILE, + # tokenizer_file used to initialize a slow from a fast. Properly copy the `addedTokens` instead of adding in random orders + "tokenizer_file": FULL_TOKENIZER_FILE, + "chat_template_file": CHAT_TEMPLATE_FILE, + } + vocab_files = {**cls.vocab_files_names, **additional_files_names} + if "tokenizer_file" in vocab_files: + # Try to get the tokenizer config to see if there are versioned tokenizer files. + fast_tokenizer_file = FULL_TOKENIZER_FILE + resolved_config_file = cached_file( + pretrained_model_name_or_path, + TOKENIZER_CONFIG_FILE, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + revision=revision, + local_files_only=local_files_only, + subfolder=subfolder, + user_agent=user_agent, + _raise_exceptions_for_gated_repo=False, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + _commit_hash=commit_hash, + ) + commit_hash = extract_commit_hash(resolved_config_file, commit_hash) + if resolved_config_file is not None: + with open(resolved_config_file, encoding="utf-8") as reader: + tokenizer_config = json.load(reader) + if "fast_tokenizer_files" in tokenizer_config: + fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"]) + vocab_files["tokenizer_file"] = fast_tokenizer_file + + # Get files from url, cache, or disk depending on the case + resolved_vocab_files = {} + unresolved_files = [] + for file_id, file_path in vocab_files.items(): + if file_path is None: + resolved_vocab_files[file_id] = None + elif single_file_id == file_id: + if os.path.isfile(file_path): + resolved_vocab_files[file_id] = file_path + elif is_remote_url(file_path): + resolved_vocab_files[file_id] = download_url(file_path, proxies=proxies) + else: + resolved_vocab_files[file_id] = cached_file( + pretrained_model_name_or_path, + file_path, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_gated_repo=False, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + _commit_hash=commit_hash, + ) + commit_hash = extract_commit_hash(resolved_vocab_files[file_id], commit_hash) + + if len(unresolved_files) > 0: + logger.info( + f"Can't load following files from cache: {unresolved_files} and cannot check if these " + "files are necessary for the tokenizer to operate." + ) + + # If one passes a GGUF file path to `gguf_file` there is no need for this check as the tokenizer will be + # loaded directly from the GGUF file. + if all(full_file_name is None for full_file_name in resolved_vocab_files.values()) and not gguf_file: + raise EnvironmentError( + f"Can't load tokenizer for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing all relevant files for a {cls.__name__} tokenizer." + ) + + for file_id, file_path in vocab_files.items(): + if file_id not in resolved_vocab_files: + continue + + if is_local: + logger.info(f"loading file {file_path}") + else: + logger.info(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}") + + return cls._from_pretrained( + resolved_vocab_files, + pretrained_model_name_or_path, + init_configuration, + *init_inputs, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + _commit_hash=commit_hash, + _is_local=is_local, + trust_remote_code=trust_remote_code, + **kwargs, + ) + + @classmethod + def _from_pretrained( + cls, + resolved_vocab_files, + pretrained_model_name_or_path, + init_configuration, + *init_inputs, + token=None, + cache_dir=None, + local_files_only=False, + _commit_hash=None, + _is_local=False, + trust_remote_code=False, + **kwargs, + ): + # We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json + # file or if `from_slow` is set to True. + from_slow = kwargs.get("from_slow", False) + gguf_file = kwargs.get("gguf_file", None) + has_tokenizer_file = resolved_vocab_files.get("tokenizer_file", None) is not None + + # If one passes a GGUF file path to `gguf_file` there is no need for this check as the tokenizer will be + # loaded directly from the GGUF file. + if (from_slow or not has_tokenizer_file) and cls.slow_tokenizer_class is not None and not gguf_file: + slow_tokenizer = (cls.slow_tokenizer_class)._from_pretrained( + copy.deepcopy(resolved_vocab_files), + pretrained_model_name_or_path, + copy.deepcopy(init_configuration), + *init_inputs, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + _commit_hash=_commit_hash, + **(copy.deepcopy(kwargs)), + ) + else: + slow_tokenizer = None + + # Prepare tokenizer initialization kwargs + # Did we saved some inputs and kwargs to reload ? + tokenizer_config_file = resolved_vocab_files.pop("tokenizer_config_file", None) + if tokenizer_config_file is not None: + with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle: + init_kwargs = json.load(tokenizer_config_handle) + # First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers. + config_tokenizer_class = init_kwargs.get("tokenizer_class") + init_kwargs.pop("tokenizer_class", None) + if not has_tokenizer_file: + init_kwargs.pop("tokenizer_file", None) + saved_init_inputs = init_kwargs.pop("init_inputs", ()) + if not init_inputs: + init_inputs = saved_init_inputs + else: + config_tokenizer_class = None + init_kwargs = init_configuration + + # If an independent chat template file exists, it takes priority over template entries in the tokenizer config + chat_template_file = resolved_vocab_files.pop("chat_template_file", None) + if chat_template_file is not None: + with open(chat_template_file) as chat_template_handle: + init_kwargs["chat_template"] = chat_template_handle.read() # Clobbers any template in the config + + if not _is_local: + if "auto_map" in init_kwargs: + # For backward compatibility with odl format. + if isinstance(init_kwargs["auto_map"], (tuple, list)): + init_kwargs["auto_map"] = {"AutoTokenizer": init_kwargs["auto_map"]} + init_kwargs["auto_map"] = add_model_info_to_auto_map( + init_kwargs["auto_map"], pretrained_model_name_or_path + ) + if "custom_pipelines" in init_kwargs: + init_kwargs["custom_pipelines"] = add_model_info_to_custom_pipelines( + init_kwargs["custom_pipelines"], pretrained_model_name_or_path + ) + + if config_tokenizer_class is None: + # Matt: This entire block is only used to decide if the tokenizer class matches the class in the repo. + # If not, it raises a warning, but otherwise continues. Since we mostly load tokenizers with + # AutoTokenizer these days, it seems like a lot of work (and a source of bugs) for little gain. + # Maybe we can just remove this entirely? + from .models.auto.configuration_auto import AutoConfig # tests_ignore + + # Second attempt. If we have not yet found tokenizer_class, let's try to use the config. + try: + config = AutoConfig.from_pretrained( + pretrained_model_name_or_path, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + trust_remote_code=trust_remote_code, + _commit_hash=_commit_hash, + ) + config_tokenizer_class = config.tokenizer_class + except (OSError, ValueError, KeyError): + # skip if an error occurred. + config = None + if config_tokenizer_class is None: + # Third attempt. If we have not yet found the original type of the tokenizer, + # we are loading we see if we can infer it from the type of the configuration file + from .models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES # tests_ignore + + if hasattr(config, "model_type"): + model_type = config.model_type + else: + # Fallback: use pattern matching on the string. + model_type = None + for pattern in TOKENIZER_MAPPING_NAMES.keys(): + if pattern in str(pretrained_model_name_or_path): + model_type = pattern + break + + if model_type is not None: + config_tokenizer_class, config_tokenizer_class_fast = TOKENIZER_MAPPING_NAMES.get( + model_type, (None, None) + ) + if config_tokenizer_class is None: + config_tokenizer_class = config_tokenizer_class_fast + + if config_tokenizer_class is not None: + if cls.__name__.replace("Fast", "") != config_tokenizer_class.replace("Fast", ""): + logger.warning( + "The tokenizer class you load from this checkpoint is not the same type as the class this" + " function is called from. It may result in unexpected tokenization. \nThe tokenizer class you" + f" load from this checkpoint is '{config_tokenizer_class}'. \nThe class this function is called" + f" from is '{cls.__name__}'." + ) + + # Update with newly provided kwargs + init_kwargs.update(kwargs) + + # Merge resolved_vocab_files arguments in init_kwargs. + added_tokens_file = resolved_vocab_files.pop("added_tokens_file", None) + special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None) + for args_name, file_path in resolved_vocab_files.items(): + if args_name not in init_kwargs: + init_kwargs[args_name] = file_path + tokenizer_file = resolved_vocab_files.pop("tokenizer_file", None) + + if slow_tokenizer is not None: + init_kwargs["__slow_tokenizer"] = slow_tokenizer + init_kwargs["name_or_path"] = pretrained_model_name_or_path + + #### Handle tokenizer serialization of added and special tokens + added_tokens_decoder: Dict[int, AddedToken] = {} + added_tokens_map: Dict[str, AddedToken] = {} + # if we have info on the slow added tokens + if "added_tokens_decoder" in init_kwargs: + for idx, token in init_kwargs["added_tokens_decoder"].items(): + if isinstance(token, dict): + token = AddedToken(**token) + if isinstance(token, AddedToken): + added_tokens_decoder[int(idx)] = token + added_tokens_map[str(token)] = token + else: + raise ValueError( + f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary or an AddedToken instance" + ) + else: + # begin legacy: read the added_tokens_file and update kwargs with special_tokens_map if modified + if special_tokens_map_file is not None: + with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle: + special_tokens_map = json.load(special_tokens_map_handle) + for key, value in special_tokens_map.items(): + if key in kwargs and kwargs[key]: + # This value has already been redefined by the kwargs + # We keep this new value and ignore the one stored in the special_tokens_map_file + continue + if isinstance(value, dict): + value["special"] = True + value = AddedToken(**value) + elif key == "additional_special_tokens" and isinstance(value, list): + additional_special_tokens = init_kwargs.pop("additional_special_tokens", []) or [] + for token in value: + if isinstance(token, dict): + token["special"] = True + token = AddedToken(**token) + if token not in additional_special_tokens: + additional_special_tokens.append(token) + value = additional_special_tokens + init_kwargs[key] = value + + # slow -> slow|fast, legacy: convert the `"added_tokens.json"` file to `added_tokens_decoder`. + # this is for legacy purpose. We don't add the tokens after init for efficiency. + if added_tokens_file is not None: + special_tokens = [] + for key in cls.SPECIAL_TOKENS_ATTRIBUTES & init_kwargs.keys(): + if init_kwargs[key] is not None: + if key == "additional_special_tokens": + special_tokens += [str(token) for token in init_kwargs[key]] + else: + special_tokens.append(str(init_kwargs[key])) + + with open(added_tokens_file, encoding="utf-8") as added_tokens_handle: + added_tok_encoder = json.load(added_tokens_handle) + for str_token, index in added_tok_encoder.items(): + # if index not in added_tokens_decoder and str_token not in added_tokens_map: + special = str_token in special_tokens + added_tokens_decoder[index] = AddedToken( + str_token, rstrip=False, lstrip=False, normalized=not special, special=special + ) + added_tokens_map[str(token)] = added_tokens_decoder[index] + + # allows converting a fast -> slow: add the `tokenizer.json`'s `"added_tokens"` to the slow tokenizer + # if `tokenizer_config.json` is `None` + if tokenizer_file is not None: + # This is for slow so can be done before + with open(tokenizer_file, encoding="utf-8") as tokenizer_file_handle: + tokenizer_file_handle = json.load(tokenizer_file_handle) + added_tokens = tokenizer_file_handle.pop("added_tokens") + for serialized_tokens in added_tokens: + idx = serialized_tokens.pop("id") + added_tokens_decoder[idx] = AddedToken(**serialized_tokens) + added_tokens_map[str(added_tokens_decoder[idx])] = added_tokens_decoder[idx] + # end legacy + + # Passing AddedTokens and not strings to the class to prevent it from casting the string to a different AddedToken + # convert {'__type': 'AddedToken', 'content': '', 'lstrip': False, 'normalized': True, ...} to AddedTokens + init_kwargs["added_tokens_decoder"] = added_tokens_decoder + init_kwargs = cls.convert_added_tokens(init_kwargs, save=False) + for key in cls.SPECIAL_TOKENS_ATTRIBUTES & init_kwargs.keys(): + if added_tokens_map != {} and init_kwargs[key] is not None: + if key != "additional_special_tokens": + init_kwargs[key] = added_tokens_map.get(str(init_kwargs[key]), init_kwargs[key]) + + # Instantiate the tokenizer. + try: + tokenizer = cls(*init_inputs, **init_kwargs) + except import_protobuf_decode_error(): + logger.info( + "Unable to load tokenizer model from SPM, loading from TikToken will be attempted instead." + "(Google protobuf error: Tried to load SPM model with non-SPM vocab file).", + ) + return False + except RuntimeError as e: + if "sentencepiece_processor.cc" in str(e): + logger.info( + "Unable to load tokenizer model from SPM, loading from TikToken will be attempted instead." + "(SentencePiece RuntimeError: Tried to load SPM model with non-SPM vocab file).", + ) + return False + except OSError: + raise OSError( + "Unable to load vocabulary from file. " + "Please check that the provided vocabulary is accessible and not corrupted." + ) + except RuntimeError as e: + if "sentencepiece_processor.cc" in str(e): + logger.info( + "Unable to load tokenizer model from SPM, loading from TikToken will be attempted instead." + "(SentencePiece RuntimeError: Tried to load SPM model with non-SPM vocab file).", + ) + return False + + if added_tokens_decoder != {} and max(list(added_tokens_decoder.keys())[-1], 0) > tokenizer.vocab_size: + logger.info( + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are" + " fine-tuned or trained." + ) + return tokenizer + + @staticmethod + def _eventually_correct_t5_max_length(pretrained_model_name_or_path, max_model_length, init_max_model_length): + # This method should be deleted in Transformers v5 + # Its only purpose is to potentially throw a warning + # that incorrectly defined max lengths of T5's tokenizer are used + # which we will correct in Transformers v5. + return max_model_length + + @classmethod + def convert_added_tokens(cls, obj: Union[AddedToken, Any], save=False, add_type_field=True): + if isinstance(obj, dict) and "__type" in obj and obj["__type"] == "AddedToken": + obj.pop("__type") + return AddedToken(**obj) + if isinstance(obj, AddedToken) and save: + obj = obj.__getstate__() + if add_type_field: + obj["__type"] = "AddedToken" + else: + # Don't save "special" for previous tokenizers + obj.pop("special") + return obj + elif isinstance(obj, (list, tuple)): + return [cls.convert_added_tokens(o, save=save, add_type_field=add_type_field) for o in obj] + elif isinstance(obj, dict): + return {k: cls.convert_added_tokens(v, save=save, add_type_field=add_type_field) for k, v in obj.items()} + return obj + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + legacy_format: Optional[bool] = None, + filename_prefix: Optional[str] = None, + push_to_hub: bool = False, + **kwargs, + ) -> Tuple[str]: + """ + Save the full tokenizer state. + + + This method make sure the full tokenizer can then be re-loaded using the + [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`] class method.. + + Warning,None This won't save modifications you may have applied to the tokenizer after the instantiation (for + instance, modifying `tokenizer.do_lower_case` after creation). + + Args: + save_directory (`str` or `os.PathLike`): The path to a directory where the tokenizer will be saved. + legacy_format (`bool`, *optional*): + Only applicable for a fast tokenizer. If unset (default), will save the tokenizer in the unified JSON + format as well as in legacy format if it exists, i.e. with tokenizer specific vocabulary and a separate + added_tokens files. + + If `False`, will only save the tokenizer in the unified JSON format. This format is incompatible with + "slow" tokenizers (not powered by the *tokenizers* library), so the tokenizer will not be able to be + loaded in the corresponding "slow" tokenizer. + + If `True`, will save the tokenizer in legacy format. If the "slow" tokenizer doesn't exits, a value + error is raised. + filename_prefix (`str`, *optional*): + A prefix to add to the names of the files saved by the tokenizer. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + + Returns: + A tuple of `str`: The files saved. + """ + use_auth_token = kwargs.pop("use_auth_token", None) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if kwargs.get("token", None) is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + kwargs["token"] = use_auth_token + + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = self._create_repo(repo_id, **kwargs) + files_timestamps = self._get_files_timestamps(save_directory) + + special_tokens_map_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + SPECIAL_TOKENS_MAP_FILE + ) + tokenizer_config_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_CONFIG_FILE + ) + chat_template_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + CHAT_TEMPLATE_FILE + ) + + tokenizer_config = copy.deepcopy(self.init_kwargs) + + # Let's save the init kwargs + target_keys = set(self.init_kwargs.keys()) + # Let's save the special tokens map (only the strings) + target_keys.update(["model_max_length", "clean_up_tokenization_spaces"]) + + for k in target_keys: + if hasattr(self, k): + tokenizer_config[k] = getattr(self, k) + + # Let's make sure we properly save the special tokens + tokenizer_config.update(self.special_tokens_map) + if "extra_special_tokens" not in tokenizer_config: + tokenizer_config["extra_special_tokens"] = self.extra_special_tokens + tokenizer_config.update(self.extra_special_tokens) + + saved_raw_chat_template = False + if self.chat_template is not None: + if isinstance(self.chat_template, dict): + # Chat template dicts are saved to the config as lists of dicts with fixed key names. + # They will be reconstructed as a single dict during loading. + # We're trying to discourage chat template dicts, and they are always + # saved in the config, never as single files. + tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()] + elif kwargs.get("save_raw_chat_template", False): + with open(chat_template_file, "w", encoding="utf-8") as f: + f.write(self.chat_template) + saved_raw_chat_template = True + logger.info(f"chat template saved in {chat_template_file}") + if "chat_template" in tokenizer_config: + tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too + else: + tokenizer_config["chat_template"] = self.chat_template + + if len(self.init_inputs) > 0: + tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs) + for file_id in self.vocab_files_names.keys(): + tokenizer_config.pop(file_id, None) + + # no typefields, this way old fast and slow can load it + tokenizer_config = self.convert_added_tokens(tokenizer_config, add_type_field=True, save=True) + + # Process added tokens seperatly: allows previous versions to ignore it! + added_tokens = {} + for key, value in self.added_tokens_decoder.items(): + added_tokens[key] = value.__getstate__() + tokenizer_config["added_tokens_decoder"] = added_tokens + + # Add tokenizer class to the tokenizer config to be able to reload it with from_pretrained + tokenizer_class = self.__class__.__name__ + # Remove the Fast at the end unless we have a special `PreTrainedTokenizerFast` + if tokenizer_class.endswith("Fast") and tokenizer_class != "PreTrainedTokenizerFast": + tokenizer_class = tokenizer_class[:-4] + tokenizer_config["tokenizer_class"] = tokenizer_class + if getattr(self, "_auto_map", None) is not None: + tokenizer_config["auto_map"] = self._auto_map + if getattr(self, "_processor_class", None) is not None: + tokenizer_config["processor_class"] = self._processor_class + + # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be + # loaded from the Hub. + if self._auto_class is not None: + custom_object_save(self, save_directory, config=tokenizer_config) + + # remove private information + if "name_or_path" in tokenizer_config: + tokenizer_config.pop("name_or_path") + tokenizer_config.pop("special_tokens_map_file", None) + tokenizer_config.pop("tokenizer_file", None) + if "device_map" in tokenizer_config: + tokenizer_config.pop("device_map") + + with open(tokenizer_config_file, "w", encoding="utf-8") as f: + out_str = json.dumps(tokenizer_config, indent=2, sort_keys=True, ensure_ascii=False) + "\n" + f.write(out_str) + logger.info(f"tokenizer config file saved in {tokenizer_config_file}") + + # Sanitize AddedTokens in special_tokens_map + + # kept for forward compatibility, will be removed in transoformers 5. Typefields are not saved for FC, special should not be save either + write_dict = self.convert_added_tokens(self.special_tokens_map_extended, save=True, add_type_field=False) + with open(special_tokens_map_file, "w", encoding="utf-8") as f: + out_str = json.dumps(write_dict, indent=2, sort_keys=True, ensure_ascii=False) + "\n" + f.write(out_str) + logger.info(f"Special tokens file saved in {special_tokens_map_file}") + + file_names = (tokenizer_config_file, special_tokens_map_file) + if saved_raw_chat_template: + file_names += (chat_template_file,) + + save_files = self._save_pretrained( + save_directory=save_directory, + file_names=file_names, + legacy_format=legacy_format, + filename_prefix=filename_prefix, + ) + + if push_to_hub: + self._upload_modified_files( + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=kwargs.get("token"), + ) + + return save_files + + def _save_pretrained( + self, + save_directory: Union[str, os.PathLike], + file_names: Tuple[str], + legacy_format: Optional[bool] = None, + filename_prefix: Optional[str] = None, + ) -> Tuple[str]: + """ + Save a tokenizer using the slow-tokenizer/legacy format: vocabulary + added tokens. + + Fast tokenizers can also be saved in a unique JSON file containing {config + vocab + added-tokens} using the + specific [`~tokenization_utils_fast.PreTrainedTokenizerFast._save_pretrained`] + """ + if legacy_format is False: + raise ValueError( + "Only fast tokenizers (instances of PreTrainedTokenizerFast) can be saved in non legacy format." + ) + + save_directory = str(save_directory) + + added_tokens_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + ADDED_TOKENS_FILE + ) + # the new get_added_vocab() also returns special tokens and tokens that have an index < vocab_size + added_vocab = {tok: index for tok, index in self.added_tokens_encoder.items() if index >= self.vocab_size} + if added_vocab: + with open(added_tokens_file, "w", encoding="utf-8") as f: + out_str = json.dumps(added_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n" + f.write(out_str) + logger.info(f"added tokens file saved in {added_tokens_file}") + + vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix) + + return file_names + vocab_files + (added_tokens_file,) + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + """ + Save only the vocabulary of the tokenizer (vocabulary + added tokens). + + This method won't save the configuration and special token mappings of the tokenizer. Use + [`~PreTrainedTokenizerFast._save_pretrained`] to save the whole state of the tokenizer. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + filename_prefix (`str`, *optional*): + An optional prefix to add to the named of the saved files. + + Returns: + `Tuple(str)`: Paths to the files saved. + """ + raise NotImplementedError + + def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> List[str]: + """ + Converts a string into a sequence of tokens, replacing unknown tokens with the `unk_token`. + + Args: + text (`str`): + The sequence to be encoded. + pair (`str`, *optional*): + A second sequence to be encoded with the first. + add_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to add the special tokens associated with the corresponding model. + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific encode method. See details in + [`~PreTrainedTokenizerBase.__call__`] + + Returns: + `List[str]`: The list of tokens. + """ + raise NotImplementedError + + @add_end_docstrings( + ENCODE_KWARGS_DOCSTRING, + """ + **kwargs: Passed along to the `.tokenize()` method. + """, + """ + Returns: + `List[int]`, `torch.Tensor`, `tf.Tensor` or `np.ndarray`: The tokenized ids of the text. + """, + ) + def encode( + self, + text: Union[TextInput, PreTokenizedInput, EncodedInput], + text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + padding_side: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> List[int]: + """ + Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary. + + Same as doing `self.convert_tokens_to_ids(self.tokenize(text))`. + + Args: + text (`str`, `List[str]` or `List[int]`): + The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the + `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method). + text_pair (`str`, `List[str]` or `List[int]`, *optional*): + Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using + the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method). + """ + encoded_inputs = self.encode_plus( + text, + text_pair=text_pair, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + padding_side=padding_side, + return_tensors=return_tensors, + **kwargs, + ) + + return encoded_inputs["input_ids"] + + def num_special_tokens_to_add(self, pair: bool = False) -> int: + raise NotImplementedError + + def _get_padding_truncation_strategies( + self, padding=False, truncation=None, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs + ): + """ + Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy + and pad_to_max_length) and behaviors. + """ + old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate") + old_pad_to_max_length = kwargs.pop("pad_to_max_length", False) + + # Backward compatibility for previous behavior, maybe we should deprecate it: + # If you only set max_length, it activates truncation for max_length + if max_length is not None and padding is False and truncation is None: + if verbose: + if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False): + logger.warning( + "Truncation was not explicitly activated but `max_length` is provided a specific value, please" + " use `truncation=True` to explicitly truncate examples to max length. Defaulting to" + " 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the" + " tokenizer you can select this strategy more precisely by providing a specific strategy to" + " `truncation`." + ) + self.deprecation_warnings["Truncation-not-explicitly-activated"] = True + truncation = "longest_first" + + # Get padding strategy + if padding is False and old_pad_to_max_length: + if verbose: + warnings.warn( + "The `pad_to_max_length` argument is deprecated and will be removed in a future version, " + "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or " + "use `padding='max_length'` to pad to a max length. In this case, you can give a specific " + "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the " + "maximal input size of the model (e.g. 512 for Bert).", + FutureWarning, + ) + if max_length is None: + padding_strategy = PaddingStrategy.LONGEST + else: + padding_strategy = PaddingStrategy.MAX_LENGTH + elif padding is not False: + if padding is True: + if verbose: + if max_length is not None and ( + truncation is None or truncation is False or truncation == "do_not_truncate" + ): + warnings.warn( + "`max_length` is ignored when `padding`=`True` and there is no truncation strategy. " + "To pad to max length, use `padding='max_length'`." + ) + if old_pad_to_max_length is not False: + warnings.warn("Though `pad_to_max_length` = `True`, it is ignored because `padding`=`True`.") + padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch + elif not isinstance(padding, PaddingStrategy): + padding_strategy = PaddingStrategy(padding) + elif isinstance(padding, PaddingStrategy): + padding_strategy = padding + else: + padding_strategy = PaddingStrategy.DO_NOT_PAD + + # Get truncation strategy + if truncation is None and old_truncation_strategy != "do_not_truncate": + if verbose: + warnings.warn( + "The `truncation_strategy` argument is deprecated and will be removed in a future version, use" + " `truncation=True` to truncate examples to a max length. You can give a specific length with" + " `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the maximal input" + " size of the model (e.g. 512 for Bert). If you have pairs of inputs, you can give a specific" + " truncation strategy selected among `truncation='only_first'` (will only truncate the first" + " sentence in the pairs) `truncation='only_second'` (will only truncate the second sentence in the" + " pairs) or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence" + " in the pairs).", + FutureWarning, + ) + truncation_strategy = TruncationStrategy(old_truncation_strategy) + elif truncation is not False and truncation is not None: + if truncation is True: + truncation_strategy = ( + TruncationStrategy.LONGEST_FIRST + ) # Default to truncate the longest sequences in pairs of inputs + elif not isinstance(truncation, TruncationStrategy): + truncation_strategy = TruncationStrategy(truncation) + elif isinstance(truncation, TruncationStrategy): + truncation_strategy = truncation + else: + truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE + + # Set max length if needed + if max_length is None: + if padding_strategy == PaddingStrategy.MAX_LENGTH: + if self.model_max_length > LARGE_INTEGER: + if verbose: + if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False): + logger.warning( + "Asking to pad to max_length but no maximum length is provided and the model has no" + " predefined maximum length. Default to no padding." + ) + self.deprecation_warnings["Asking-to-pad-to-max_length"] = True + padding_strategy = PaddingStrategy.DO_NOT_PAD + else: + max_length = self.model_max_length + + if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE: + if self.model_max_length > LARGE_INTEGER: + if verbose: + if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False): + logger.warning( + "Asking to truncate to max_length but no maximum length is provided and the model has" + " no predefined maximum length. Default to no truncation." + ) + self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True + truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE + else: + max_length = self.model_max_length + + # Test if we have a padding token + if padding_strategy != PaddingStrategy.DO_NOT_PAD and (self.pad_token is None or self.pad_token_id < 0): + raise ValueError( + "Asking to pad but the tokenizer does not have a padding token. " + "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` " + "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`." + ) + + # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided + if ( + truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE + and padding_strategy != PaddingStrategy.DO_NOT_PAD + and pad_to_multiple_of is not None + and max_length is not None + and (max_length % pad_to_multiple_of != 0) + ): + raise ValueError( + "Truncation and padding are both activated but " + f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})." + ) + + return padding_strategy, truncation_strategy, max_length, kwargs + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + text_pair_target: Optional[ + Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] + ] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of + sequences. + + Args: + text (`str`, `List[str]`, `List[List[str]]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + text_pair (`str`, `List[str]`, `List[List[str]]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + text_target (`str`, `List[str]`, `List[List[str]]`, *optional*): + The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a + list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), + you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + text_pair_target (`str`, `List[str]`, `List[List[str]]`, *optional*): + The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a + list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), + you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + """ + # To avoid duplicating + all_kwargs = { + "add_special_tokens": add_special_tokens, + "padding": padding, + "truncation": truncation, + "max_length": max_length, + "stride": stride, + "is_split_into_words": is_split_into_words, + "pad_to_multiple_of": pad_to_multiple_of, + "padding_side": padding_side, + "return_tensors": return_tensors, + "return_token_type_ids": return_token_type_ids, + "return_attention_mask": return_attention_mask, + "return_overflowing_tokens": return_overflowing_tokens, + "return_special_tokens_mask": return_special_tokens_mask, + "return_offsets_mapping": return_offsets_mapping, + "return_length": return_length, + "split_special_tokens": kwargs.pop("split_special_tokens", self.split_special_tokens), + "verbose": verbose, + } + all_kwargs.update(kwargs) + if text is None and text_target is None: + raise ValueError("You need to specify either `text` or `text_target`.") + if text is not None: + # The context manager will send the inputs as normal texts and not text_target, but we shouldn't change the + # input mode in this case. + if not self._in_target_context_manager: + self._switch_to_input_mode() + encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs) + if text_target is not None: + self._switch_to_target_mode() + target_encodings = self._call_one(text=text_target, text_pair=text_pair_target, **all_kwargs) + # Leave back tokenizer in input mode + self._switch_to_input_mode() + + if text_target is None: + return encodings + elif text is None: + return target_encodings + else: + encodings["labels"] = target_encodings["input_ids"] + return encodings + + def _call_one( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], + text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + split_special_tokens: bool = False, + **kwargs, + ) -> BatchEncoding: + # Input type checking for clearer error + def _is_valid_text_input(t): + if isinstance(t, str): + # Strings are fine + return True + elif isinstance(t, (list, tuple)): + # List are fine as long as they are... + if len(t) == 0: + # ... empty + return True + elif isinstance(t[0], str): + # ... list of strings + return True + elif isinstance(t[0], (list, tuple)): + # ... list with an empty list or with a list of strings + return len(t[0]) == 0 or isinstance(t[0][0], str) + else: + return False + else: + return False + + if not _is_valid_text_input(text): + raise ValueError( + "text input must be of type `str` (single example), `List[str]` (batch or single pretokenized example) " + "or `List[List[str]]` (batch of pretokenized examples)." + ) + + if text_pair is not None and not _is_valid_text_input(text_pair): + raise ValueError( + "text input must be of type `str` (single example), `List[str]` (batch or single pretokenized example) " + "or `List[List[str]]` (batch of pretokenized examples)." + ) + + if is_split_into_words: + is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple)) + else: + is_batched = isinstance(text, (list, tuple)) + + if is_batched: + if isinstance(text_pair, str): + raise TypeError( + "when tokenizing batches of text, `text_pair` must be a list or tuple with the same length as" + " `text`." + ) + if text_pair is not None and len(text) != len(text_pair): + raise ValueError( + f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:" + f" {len(text_pair)}." + ) + batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text + return self.batch_encode_plus( + batch_text_or_text_pairs=batch_text_or_text_pairs, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + split_special_tokens=split_special_tokens, + **kwargs, + ) + else: + return self.encode_plus( + text=text, + text_pair=text_pair, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + split_special_tokens=split_special_tokens, + **kwargs, + ) + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def encode_plus( + self, + text: Union[TextInput, PreTokenizedInput, EncodedInput], + text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Tokenize and prepare for the model a sequence or a pair of sequences. + + + + This method is deprecated, `__call__` should be used instead. + + + + Args: + text (`str`, `List[str]` or (for non-fast tokenizers) `List[int]`): + The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the + `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method). + text_pair (`str`, `List[str]` or `List[int]`, *optional*): + Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using + the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method). + """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._encode_plus( + text=text, + text_pair=text_pair, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + split_special_tokens=kwargs.pop("split_special_tokens", self.split_special_tokens), + **kwargs, + ) + + def _encode_plus( + self, + text: Union[TextInput, PreTokenizedInput, EncodedInput], + text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + split_special_tokens: bool = False, + **kwargs, + ) -> BatchEncoding: + raise NotImplementedError + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def batch_encode_plus( + self, + batch_text_or_text_pairs: Union[ + List[TextInput], + List[TextInputPair], + List[PreTokenizedInput], + List[PreTokenizedInputPair], + List[EncodedInput], + List[EncodedInputPair], + ], + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + split_special_tokens: bool = False, + **kwargs, + ) -> BatchEncoding: + """ + Tokenize and prepare for the model a list of sequences or a list of pairs of sequences. + + + + This method is deprecated, `__call__` should be used instead. + + + + Args: + batch_text_or_text_pairs (`List[str]`, `List[Tuple[str, str]]`, `List[List[str]]`, `List[Tuple[List[str], List[str]]]`, and for not-fast tokenizers, also `List[List[int]]`, `List[Tuple[List[int], List[int]]]`): + Batch of sequences or pair of sequences to be encoded. This can be a list of + string/string-sequences/int-sequences or a list of pair of string/string-sequences/int-sequence (see + details in `encode_plus`). + """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._batch_encode_plus( + batch_text_or_text_pairs=batch_text_or_text_pairs, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + split_special_tokens=split_special_tokens, + **kwargs, + ) + + def _batch_encode_plus( + self, + batch_text_or_text_pairs: Union[ + List[TextInput], + List[TextInputPair], + List[PreTokenizedInput], + List[PreTokenizedInputPair], + List[EncodedInput], + List[EncodedInputPair], + ], + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + split_special_tokens: bool = False, + **kwargs, + ) -> BatchEncoding: + raise NotImplementedError + + def pad( + self, + encoded_inputs: Union[ + BatchEncoding, + List[BatchEncoding], + Dict[str, EncodedInput], + Dict[str, List[EncodedInput]], + List[Dict[str, EncodedInput]], + ], + padding: Union[bool, str, PaddingStrategy] = True, + max_length: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + verbose: bool = True, + ) -> BatchEncoding: + """ + Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length + in the batch. + + Padding side (left/right) padding token ids are defined at the tokenizer level (with `self.padding_side`, + `self.pad_token_id` and `self.pad_token_type_id`). + + Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the + text followed by a call to the `pad` method to get a padded encoding. + + + + If the `encoded_inputs` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the + result will use the same type unless you provide a different tensor type with `return_tensors`. In the case of + PyTorch tensors, you will lose the specific device of your tensors however. + + + + Args: + encoded_inputs ([`BatchEncoding`], list of [`BatchEncoding`], `Dict[str, List[int]]`, `Dict[str, List[List[int]]` or `List[Dict[str, List[int]]]`): + Tokenized inputs. Can represent one input ([`BatchEncoding`] or `Dict[str, List[int]]`) or a batch of + tokenized inputs (list of [`BatchEncoding`], *Dict[str, List[List[int]]]* or *List[Dict[str, + List[int]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader + collate function. + + Instead of `List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors), see + the note above for the return type. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + padding_side (`str`, *optional*): + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. + return_attention_mask (`bool`, *optional*): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific tokenizer's default, defined by the `return_outputs` attribute. + + [What are attention masks?](../glossary#attention-mask) + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + verbose (`bool`, *optional*, defaults to `True`): + Whether or not to print more information and warnings. + """ + if self.__class__.__name__.endswith("Fast"): + if not self.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False): + logger.warning_advice( + f"You're using a {self.__class__.__name__} tokenizer. Please note that with a fast tokenizer," + " using the `__call__` method is faster than using a method to encode the text followed by a call" + " to the `pad` method to get a padded encoding." + ) + self.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True + + # If we have a list of dicts, let's convert it in a dict of lists + # We do this to allow using this method as a collate_fn function in PyTorch Dataloader + if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping): + encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()} + + # The model's main input name, usually `input_ids`, has been passed for padding + if self.model_input_names[0] not in encoded_inputs: + raise ValueError( + "You should supply an encoding or a list of encodings to this method " + f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}" + ) + + required_input = encoded_inputs[self.model_input_names[0]] + + if required_input is None or (isinstance(required_input, Sized) and len(required_input) == 0): + if return_attention_mask: + encoded_inputs["attention_mask"] = [] + return encoded_inputs + + # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects + # and rebuild them afterwards if no return_tensors is specified + # Note that we lose the specific device the tensor may be on for PyTorch + + first_element = required_input[0] + if isinstance(first_element, (list, tuple)): + # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element. + for item in required_input: + if len(item) != 0: + first_element = item[0] + break + # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do. + if not isinstance(first_element, (int, list, tuple)): + if is_tf_tensor(first_element): + return_tensors = "tf" if return_tensors is None else return_tensors + elif is_torch_tensor(first_element): + return_tensors = "pt" if return_tensors is None else return_tensors + elif isinstance(first_element, np.ndarray): + return_tensors = "np" if return_tensors is None else return_tensors + else: + raise ValueError( + f"type of {first_element} unknown: {type(first_element)}. " + "Should be one of a python, numpy, pytorch or tensorflow object." + ) + + for key, value in encoded_inputs.items(): + encoded_inputs[key] = to_py_obj(value) + + # Convert padding_strategy in PaddingStrategy + padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies( + padding=padding, max_length=max_length, verbose=verbose + ) + + required_input = encoded_inputs[self.model_input_names[0]] + if required_input and not isinstance(required_input[0], (list, tuple)): + encoded_inputs = self._pad( + encoded_inputs, + max_length=max_length, + padding_strategy=padding_strategy, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + ) + return BatchEncoding(encoded_inputs, tensor_type=return_tensors) + + batch_size = len(required_input) + assert all( + len(v) == batch_size for v in encoded_inputs.values() + ), "Some items in the output dictionary have a different batch size than others." + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = max(len(inputs) for inputs in required_input) + padding_strategy = PaddingStrategy.MAX_LENGTH + + batch_outputs = {} + for i in range(batch_size): + inputs = {k: v[i] for k, v in encoded_inputs.items()} + outputs = self._pad( + inputs, + max_length=max_length, + padding_strategy=padding_strategy, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + ) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + + return BatchEncoding(batch_outputs, tensor_type=return_tensors) + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create the token type IDs corresponding to the sequences passed. [What are token type + IDs?](../glossary#token-type-ids) + + Should be overridden in a subclass if the model has a special way of building those. + + Args: + token_ids_0 (`List[int]`): The first tokenized sequence. + token_ids_1 (`List[int]`, *optional*): The second tokenized sequence. + + Returns: + `List[int]`: The token type ids. + """ + if token_ids_1 is None: + return len(token_ids_0) * [0] + return [0] * len(token_ids_0) + [1] * len(token_ids_1) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. + + This implementation does not add special tokens and this method should be overridden in a subclass. + + Args: + token_ids_0 (`List[int]`): The first tokenized sequence. + token_ids_1 (`List[int]`, *optional*): The second tokenized sequence. + + Returns: + `List[int]`: The model input with special tokens. + """ + if token_ids_1 is None: + return token_ids_0 + return token_ids_0 + token_ids_1 + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def prepare_for_model( + self, + ids: List[int], + pair_ids: Optional[List[int]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + prepend_batch_axis: bool = False, + **kwargs, + ) -> BatchEncoding: + """ + Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It + adds special tokens, truncates sequences if overflowing while taking into account the special tokens and + manages a moving window (with user defined stride) for overflowing tokens. Please Note, for *pair_ids* + different than `None` and *truncation_strategy = longest_first* or `True`, it is not possible to return + overflowing tokens. Such a combination of arguments will raise an error. + + Args: + ids (`List[int]`): + Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and + `convert_tokens_to_ids` methods. + pair_ids (`List[int]`, *optional*): + Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize` + and `convert_tokens_to_ids` methods. + """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + pair = bool(pair_ids is not None) + len_ids = len(ids) + len_pair_ids = len(pair_ids) if pair else 0 + + if return_token_type_ids and not add_special_tokens: + raise ValueError( + "Asking to return token_type_ids while setting add_special_tokens to False " + "results in an undefined behavior. Please set add_special_tokens to True or " + "set return_token_type_ids to None." + ) + + if ( + return_overflowing_tokens + and truncation_strategy == TruncationStrategy.LONGEST_FIRST + and pair_ids is not None + ): + raise ValueError( + "Not possible to return overflowing tokens for pair of sequences with the " + "`longest_first`. Please select another truncation strategy than `longest_first`, " + "for instance `only_second` or `only_first`." + ) + + # Load from model defaults + if return_token_type_ids is None: + return_token_type_ids = "token_type_ids" in self.model_input_names + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + encoded_inputs = {} + + # Compute the total size of the returned encodings + total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0) + + # Truncation: Handle max sequence length + overflowing_tokens = [] + if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length: + ids, pair_ids, overflowing_tokens = self.truncate_sequences( + ids, + pair_ids=pair_ids, + num_tokens_to_remove=total_len - max_length, + truncation_strategy=truncation_strategy, + stride=stride, + ) + + if return_overflowing_tokens: + encoded_inputs["overflowing_tokens"] = overflowing_tokens + encoded_inputs["num_truncated_tokens"] = total_len - max_length + + # Add special tokens + if add_special_tokens: + sequence = self.build_inputs_with_special_tokens(ids, pair_ids) + token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) + else: + sequence = ids + pair_ids if pair else ids + token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else []) + + # Build output dictionary + encoded_inputs["input_ids"] = sequence + if return_token_type_ids: + encoded_inputs["token_type_ids"] = token_type_ids + if return_special_tokens_mask: + if add_special_tokens: + encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids) + else: + encoded_inputs["special_tokens_mask"] = [0] * len(sequence) + + # Check lengths + self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose) + + # Padding + if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask: + encoded_inputs = self.pad( + encoded_inputs, + max_length=max_length, + padding=padding_strategy.value, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + ) + + if return_length: + encoded_inputs["length"] = len(encoded_inputs["input_ids"]) + + batch_outputs = BatchEncoding( + encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis + ) + + return batch_outputs + + def truncate_sequences( + self, + ids: List[int], + pair_ids: Optional[List[int]] = None, + num_tokens_to_remove: int = 0, + truncation_strategy: Union[str, TruncationStrategy] = "longest_first", + stride: int = 0, + ) -> Tuple[List[int], List[int], List[int]]: + """ + Truncates a sequence pair in-place following the strategy. + + Args: + ids (`List[int]`): + Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and + `convert_tokens_to_ids` methods. + pair_ids (`List[int]`, *optional*): + Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize` + and `convert_tokens_to_ids` methods. + num_tokens_to_remove (`int`, *optional*, defaults to 0): + Number of tokens to remove using the truncation strategy. + truncation_strategy (`str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `'longest_first'`): + The strategy to follow for truncation. Can be: + + - `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will truncate + token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a + batch of pairs) is provided. + - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths greater + than the model maximum admissible input size). + stride (`int`, *optional*, defaults to 0): + If set to a positive number, the overflowing tokens returned will contain some tokens from the main + sequence returned. The value of this argument defines the number of additional tokens. + + Returns: + `Tuple[List[int], List[int], List[int]]`: The truncated `ids`, the truncated `pair_ids` and the list of + overflowing tokens. Note: The *longest_first* strategy returns empty list of overflowing tokens if a pair + of sequences (or a batch of pairs) is provided. + """ + if num_tokens_to_remove <= 0: + return ids, pair_ids, [] + + if not isinstance(truncation_strategy, TruncationStrategy): + truncation_strategy = TruncationStrategy(truncation_strategy) + + overflowing_tokens = [] + if truncation_strategy == TruncationStrategy.ONLY_FIRST or ( + truncation_strategy == TruncationStrategy.LONGEST_FIRST and pair_ids is None + ): + if len(ids) > num_tokens_to_remove: + window_len = min(len(ids), stride + num_tokens_to_remove) + if self.truncation_side == "left": + overflowing_tokens = ids[:window_len] + ids = ids[num_tokens_to_remove:] + elif self.truncation_side == "right": + overflowing_tokens = ids[-window_len:] + ids = ids[:-num_tokens_to_remove] + else: + raise ValueError(f"invalid truncation strategy: {self.truncation_side}, use 'left' or 'right'.") + + else: + error_msg = ( + f"We need to remove {num_tokens_to_remove} to truncate the input " + f"but the first sequence has a length {len(ids)}. " + ) + if truncation_strategy == TruncationStrategy.ONLY_FIRST: + error_msg = ( + error_msg + "Please select another truncation strategy than " + f"{truncation_strategy}, for instance 'longest_first' or 'only_second'." + ) + logger.error(error_msg) + elif truncation_strategy == TruncationStrategy.LONGEST_FIRST: + logger.warning( + "Be aware, overflowing tokens are not returned for the setting you have chosen," + f" i.e. sequence pairs with the '{TruncationStrategy.LONGEST_FIRST.value}' " + "truncation strategy. So the returned list will always be empty even if some " + "tokens have been removed." + ) + len_pair_ids = len(pair_ids) if pair_ids is not None else 0 + len_ids = len(ids) + first_remove = min(abs(len_pair_ids - len_ids), num_tokens_to_remove) + second_remove = num_tokens_to_remove - first_remove + if len_ids > len_pair_ids: + ids_to_move = first_remove + second_remove // 2 + pair_ids_to_move = second_remove - second_remove // 2 + else: + ids_to_move = second_remove // 2 + pair_ids_to_move = first_remove + second_remove - (second_remove // 2) + + if self.truncation_side == "right": + ids = ids[:-ids_to_move] if ids_to_move > 0 else ids + pair_ids = pair_ids[:-pair_ids_to_move] if pair_ids is not None and pair_ids_to_move > 0 else pair_ids + elif self.truncation_side == "left": + ids = ids[ids_to_move:] + pair_ids = pair_ids[pair_ids_to_move:] if pair_ids is not None else None + else: + raise ValueError(f"invalid truncation strategy:{self.truncation_side}") + + elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None: + if len(pair_ids) > num_tokens_to_remove: + window_len = min(len(pair_ids), stride + num_tokens_to_remove) + if self.truncation_side == "right": + overflowing_tokens = pair_ids[-window_len:] + pair_ids = pair_ids[:-num_tokens_to_remove] + elif self.truncation_side == "left": + overflowing_tokens = pair_ids[:window_len] + pair_ids = pair_ids[num_tokens_to_remove:] + else: + raise ValueError(f"invalid truncation strategy:{self.truncation_side}") + else: + logger.error( + f"We need to remove {num_tokens_to_remove} to truncate the input " + f"but the second sequence has a length {len(pair_ids)}. " + f"Please select another truncation strategy than {truncation_strategy}, " + "for instance 'longest_first' or 'only_first'." + ) + + return (ids, pair_ids, overflowing_tokens) + + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in `padding_side` argument: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + padding_side: + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + required_input = encoded_inputs[self.model_input_names[0]] + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length + + # Initialize attention mask if not present. + if return_attention_mask and "attention_mask" not in encoded_inputs: + encoded_inputs["attention_mask"] = [1] * len(required_input) + + if needs_to_be_padded: + difference = max_length - len(required_input) + padding_side = padding_side if padding_side is not None else self.padding_side + + if padding_side == "right": + if return_attention_mask: + encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = ( + encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference + ) + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference + encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference + elif padding_side == "left": + if return_attention_mask: + encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"] + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ + "token_type_ids" + ] + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] + encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input + else: + raise ValueError(f"Invalid padding strategy:{padding_side}") + + return encoded_inputs + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + """ + Converts a sequence of tokens in a single string. The most simple way to do it is `" ".join(tokens)` but we + often want to remove sub-word tokenization artifacts at the same time. + + Args: + tokens (`List[str]`): The token to join in a string. + + Returns: + `str`: The joined tokens. + """ + raise NotImplementedError + + def batch_decode( + self, + sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = None, + **kwargs, + ) -> List[str]: + """ + Convert a list of lists of token ids into a list of strings by calling decode. + + Args: + sequences (`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`): + List of tokenized input ids. Can be obtained using the `__call__` method. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + clean_up_tokenization_spaces (`bool`, *optional*): + Whether or not to clean up the tokenization spaces. If `None`, will default to + `self.clean_up_tokenization_spaces`. + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific decode method. + + Returns: + `List[str]`: The list of decoded sentences. + """ + return [ + self.decode( + seq, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + for seq in sequences + ] + + def decode( + self, + token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = None, + **kwargs, + ) -> str: + """ + Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special + tokens and clean up tokenization spaces. + + Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`. + + Args: + token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`): + List of tokenized input ids. Can be obtained using the `__call__` method. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + clean_up_tokenization_spaces (`bool`, *optional*): + Whether or not to clean up the tokenization spaces. If `None`, will default to + `self.clean_up_tokenization_spaces`. + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific decode method. + + Returns: + `str`: The decoded sentence. + """ + # Convert inputs to python lists + token_ids = to_py_obj(token_ids) + + return self._decode( + token_ids=token_ids, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + def _decode( + self, + token_ids: Union[int, List[int]], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = None, + **kwargs, + ) -> str: + raise NotImplementedError + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods. + + Args: + token_ids_0 (`List[int]`): + List of ids of the first sequence. + token_ids_1 (`List[int]`, *optional*): + List of ids of the second sequence. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + assert already_has_special_tokens and token_ids_1 is None, ( + "You cannot use ``already_has_special_tokens=False`` with this tokenizer. " + "Please use a slow (full python) tokenizer to activate this argument. " + "Or set `return_special_tokens_mask=True` when calling the encoding method " + "to get the special tokens mask in any tokenizer. " + ) + + all_special_ids = self.all_special_ids # cache the property + + special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0] + + return special_tokens_mask + + @staticmethod + def clean_up_tokenization(out_string: str) -> str: + """ + Clean up a list of simple English tokenization artifacts like spaces before punctuations and abbreviated forms. + + Args: + out_string (`str`): The text to clean up. + + Returns: + `str`: The cleaned-up string. + """ + out_string = ( + out_string.replace(" .", ".") + .replace(" ?", "?") + .replace(" !", "!") + .replace(" ,", ",") + .replace(" ' ", "'") + .replace(" n't", "n't") + .replace(" 'm", "'m") + .replace(" 's", "'s") + .replace(" 've", "'ve") + .replace(" 're", "'re") + ) + return out_string + + def _eventual_warn_about_too_long_sequence(self, ids: List[int], max_length: Optional[int], verbose: bool): + """ + Depending on the input and internal state we might trigger a warning about a sequence that is too long for its + corresponding model + + Args: + ids (`List[str]`): The ids produced by the tokenization + max_length (`int`, *optional*): The max_length desired (does not trigger a warning if it is set) + verbose (`bool`): Whether or not to print more information and warnings. + + """ + if max_length is None and len(ids) > self.model_max_length and verbose: + if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False): + logger.warning( + "Token indices sequence length is longer than the specified maximum sequence length " + f"for this model ({len(ids)} > {self.model_max_length}). Running this sequence through the model " + "will result in indexing errors" + ) + self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True + + def _switch_to_input_mode(self): + """ + Private method to put the tokenizer in input mode (when it has different modes for input/outputs) + """ + pass + + def _switch_to_target_mode(self): + """ + Private method to put the tokenizer in target mode (when it has different modes for input/outputs) + """ + pass + + @contextmanager + def as_target_tokenizer(self): + """ + Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to + sequence-to-sequence models that need a slightly different processing for the labels. + """ + warnings.warn( + "`as_target_tokenizer` is deprecated and will be removed in v5 of Transformers. You can tokenize your " + "labels by using the argument `text_target` of the regular `__call__` method (either in the same call as " + "your input texts if you use the same keyword arguments, or in a separate call." + ) + self._switch_to_target_mode() + self._in_target_context_manager = True + yield + self._in_target_context_manager = False + self._switch_to_input_mode() + + @classmethod + def register_for_auto_class(cls, auto_class="AutoTokenizer"): + """ + Register this class with a given auto class. This should only be used for custom tokenizers as the ones in the + library are already mapped with `AutoTokenizer`. + + + + This API is experimental and may have some slight breaking changes in the next releases. + + + + Args: + auto_class (`str` or `type`, *optional*, defaults to `"AutoTokenizer"`): + The auto class to register this new tokenizer with. + """ + if not isinstance(auto_class, str): + auto_class = auto_class.__name__ + + import transformers.models.auto as auto_module + + if not hasattr(auto_module, auto_class): + raise ValueError(f"{auto_class} is not a valid auto class.") + + cls._auto_class = auto_class + + def prepare_seq2seq_batch( + self, + src_texts: List[str], + tgt_texts: Optional[List[str]] = None, + max_length: Optional[int] = None, + max_target_length: Optional[int] = None, + padding: str = "longest", + return_tensors: str = None, + truncation: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Prepare model inputs for translation. For best performance, translate one sentence at a time. + + Arguments: + src_texts (`List[str]`): + List of documents to summarize or source language texts. + tgt_texts (`list`, *optional*): + List of summaries or target language texts. + max_length (`int`, *optional*): + Controls the maximum length for encoder inputs (documents to summarize or source language texts) If + left unset or set to `None`, this will use the predefined model maximum length if a maximum length is + required by one of the truncation/padding parameters. If the model has no specific maximum input length + (like XLNet) truncation/padding to a maximum length will be deactivated. + max_target_length (`int`, *optional*): + Controls the maximum length of decoder inputs (target language texts or summaries) If left unset or set + to `None`, this will use the max_length value. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Activates and controls padding. Accepts the following values: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `True`): + Activates and controls truncation. Accepts the following values: + + - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will + truncate token by token, removing a token from the longest sequence in the pair if a pair of + sequences (or a batch of pairs) is provided. + - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths + greater than the model maximum admissible input size). + **kwargs: + Additional keyword arguments passed along to `self.__call__`. + + Return: + [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + + - **input_ids** -- List of token ids to be fed to the encoder. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. + - **labels** -- List of token ids for tgt_texts. + + The full set of keys `[input_ids, attention_mask, labels]`, will only be returned if tgt_texts is passed. + Otherwise, input_ids, attention_mask will be the only keys. + """ + # docstyle-ignore + formatted_warning = """ +`prepare_seq2seq_batch` is deprecated and will be removed in version 5 of HuggingFace Transformers. Use the regular +`__call__` method to prepare your inputs and targets. + +Here is a short example: + +model_inputs = tokenizer(src_texts, text_target=tgt_texts, ...) + +If you either need to use different keyword arguments for the source and target texts, you should do two calls like +this: + +model_inputs = tokenizer(src_texts, ...) +labels = tokenizer(text_target=tgt_texts, ...) +model_inputs["labels"] = labels["input_ids"] + +See the documentation of your specific tokenizer for more details on the specific arguments to the tokenizer of choice. +For a more complete example, see the implementation of `prepare_seq2seq_batch`. +""" + warnings.warn(formatted_warning, FutureWarning) + # mBART-specific kwargs that should be ignored by other models. + kwargs.pop("src_lang", None) + kwargs.pop("tgt_lang", None) + if max_length is None: + max_length = self.model_max_length + model_inputs = self( + src_texts, + add_special_tokens=True, + return_tensors=return_tensors, + max_length=max_length, + padding=padding, + truncation=truncation, + **kwargs, + ) + if tgt_texts is None: + return model_inputs + # Process tgt_texts + if max_target_length is None: + max_target_length = max_length + with self.as_target_tokenizer(): + labels = self( + tgt_texts, + add_special_tokens=True, + return_tensors=return_tensors, + padding=padding, + max_length=max_target_length, + truncation=truncation, + **kwargs, + ) + model_inputs["labels"] = labels["input_ids"] + return model_inputs + + +def get_fast_tokenizer_file(tokenization_files: List[str]) -> str: + """ + Get the tokenization file to use for this version of transformers. + + Args: + tokenization_files (`List[str]`): The list of available configuration files. + + Returns: + `str`: The tokenization file to use. + """ + tokenizer_files_map = {} + for file_name in tokenization_files: + search = _re_tokenizer_file.search(file_name) + if search is not None: + v = search.groups()[0] + tokenizer_files_map[v] = file_name + available_versions = sorted(tokenizer_files_map.keys()) + + # Defaults to FULL_TOKENIZER_FILE and then try to look at some newer versions. + tokenizer_file = FULL_TOKENIZER_FILE + transformers_version = version.parse(__version__) + for v in available_versions: + if version.parse(v) <= transformers_version: + tokenizer_file = tokenizer_files_map[v] + else: + # No point going further since the versions are sorted. + break + + return tokenizer_file + + +# To update the docstring, we need to copy the method, otherwise we change the original docstring. +PreTrainedTokenizerBase.push_to_hub = copy_func(PreTrainedTokenizerBase.push_to_hub) +if PreTrainedTokenizerBase.push_to_hub.__doc__ is not None: + PreTrainedTokenizerBase.push_to_hub.__doc__ = PreTrainedTokenizerBase.push_to_hub.__doc__.format( + object="tokenizer", object_class="AutoTokenizer", object_files="tokenizer files" + ) diff --git a/tokenization_utils_fast.py b/tokenization_utils_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..925069f2c2f9574c173f4d7ca399b0786dbc0f2e --- /dev/null +++ b/tokenization_utils_fast.py @@ -0,0 +1,908 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tokenization classes for fast tokenizers (provided by HuggingFace's tokenizers library). For slow (python) tokenizers +see tokenization_utils.py +""" + +import copy +import json +import os +from collections import defaultdict +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import tokenizers.pre_tokenizers as pre_tokenizers_fast +from tokenizers import Encoding as EncodingFast +from tokenizers import Tokenizer as TokenizerFast +from tokenizers.decoders import Decoder as DecoderFast +from tokenizers.trainers import BpeTrainer, UnigramTrainer, WordLevelTrainer, WordPieceTrainer + +from .convert_slow_tokenizer import convert_slow_tokenizer +from .integrations.ggml import convert_gguf_tokenizer +from .modeling_gguf_pytorch_utils import load_gguf_checkpoint +from .tokenization_utils import PreTrainedTokenizer +from .tokenization_utils_base import ( + INIT_TOKENIZER_DOCSTRING, + AddedToken, + BatchEncoding, + PreTokenizedInput, + PreTokenizedInputPair, + PreTrainedTokenizerBase, + SpecialTokensMixin, + TextInput, + TextInputPair, + TruncationStrategy, +) +from .utils import PaddingStrategy, add_end_docstrings, logging + + +logger = logging.get_logger(__name__) + +# Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file +TOKENIZER_FILE = "tokenizer.json" +SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json" +TOKENIZER_CONFIG_FILE = "tokenizer_config.json" +TIKTOKEN_VOCAB_FILE = "tokenizer.model" + +# Slow tokenizers have an additional added tokens files +ADDED_TOKENS_FILE = "added_tokens.json" + +INIT_TOKENIZER_DOCSTRING += """ + tokenizer_object ([`tokenizers.Tokenizer`]): + A [`tokenizers.Tokenizer`] object from 🤗 tokenizers to instantiate from. See [Using tokenizers from 🤗 + tokenizers](../fast_tokenizers) for more information. + tokenizer_file ([`str`]): + A path to a local JSON file representing a previously serialized [`tokenizers.Tokenizer`] object from 🤗 + tokenizers. +""" + +MODEL_TO_TRAINER_MAPPING = { + "BPE": BpeTrainer, + "Unigram": UnigramTrainer, + "WordLevel": WordLevelTrainer, + "WordPiece": WordPieceTrainer, +} + +VOCAB_FILES_NAMES = {"tokenizer_file": TOKENIZER_FILE, "vocab_file": TIKTOKEN_VOCAB_FILE} + + +@add_end_docstrings(INIT_TOKENIZER_DOCSTRING) +class PreTrainedTokenizerFast(PreTrainedTokenizerBase): + """ + Base class for all fast tokenizers (wrapping HuggingFace tokenizers library). + + Inherits from [`~tokenization_utils_base.PreTrainedTokenizerBase`]. + + Handles all the shared methods for tokenization and special tokens, as well as methods for + downloading/caching/loading pretrained tokenizers, as well as adding tokens to the vocabulary. + + This class also contains the added tokens in a unified way on top of all tokenizers so we don't have to handle the + specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...). + """ + + vocab_files_names = VOCAB_FILES_NAMES + slow_tokenizer_class: PreTrainedTokenizer = None + + def __init__(self, *args, **kwargs): + tokenizer_object = kwargs.pop("tokenizer_object", None) + slow_tokenizer = kwargs.pop("__slow_tokenizer", None) + gguf_file = kwargs.pop("gguf_file", None) + fast_tokenizer_file = kwargs.pop("tokenizer_file", None) + from_slow = kwargs.pop("from_slow", False) + added_tokens_decoder = kwargs.pop("added_tokens_decoder", {}) + self.add_prefix_space = kwargs.get("add_prefix_space", False) + + if from_slow and slow_tokenizer is None and self.slow_tokenizer_class is None: + raise ValueError( + "Cannot instantiate this tokenizer from a slow version. If it's based on sentencepiece, make sure you " + "have sentencepiece installed." + ) + + if tokenizer_object is not None: + fast_tokenizer = copy.deepcopy(tokenizer_object) + elif fast_tokenizer_file is not None and not from_slow: + # We have a serialization from tokenizers which let us directly build the backend + fast_tokenizer = TokenizerFast.from_file(fast_tokenizer_file) + elif slow_tokenizer: + # We need to convert a slow tokenizer to build the backend + fast_tokenizer = convert_slow_tokenizer(slow_tokenizer) + elif gguf_file is not None: + # We need to convert a slow tokenizer to build the backend + gguf_param = load_gguf_checkpoint(kwargs.get("vocab_file")) + architecture = gguf_param["config"]["model_type"] + tokenizer_dict = gguf_param["tokenizer"] + tokenizer_config = gguf_param["tokenizer_config"] + fast_tokenizer, additional_kwargs = convert_gguf_tokenizer(architecture, tokenizer_dict) + kwargs.update(tokenizer_config) + if len(additional_kwargs) > 0: + kwargs.update(additional_kwargs) + elif self.slow_tokenizer_class is not None and slow_tokenizer is not False: + # We need to create and convert a slow tokenizer to build the backend + slow_tokenizer = self.slow_tokenizer_class(*args, **kwargs) + fast_tokenizer = convert_slow_tokenizer(slow_tokenizer) + elif not slow_tokenizer: + # We tried loading a slow_tokenizer with spm and failed, try to load with tiktoken + self.vocab_file = kwargs.get("vocab_file", None) + self.additional_special_tokens = kwargs.get("additional_special_tokens", []) + fast_tokenizer = convert_slow_tokenizer(self, from_tiktoken=True) + slow_tokenizer = None + else: + raise ValueError( + "Couldn't instantiate the backend tokenizer from one of: \n" + "(1) a `tokenizers` library serialization file, \n" + "(2) a slow tokenizer instance to convert or \n" + "(3) an equivalent slow tokenizer class to instantiate and convert. \n" + "You need to have sentencepiece or tiktoken installed to convert a slow tokenizer to a fast one." + ) + + self._tokenizer = fast_tokenizer + + if slow_tokenizer is not None: + kwargs.update(slow_tokenizer.init_kwargs) + + self._decode_use_source_tokenizer = False + + _truncation = self._tokenizer.truncation + + if _truncation is not None: + self._tokenizer.enable_truncation(**_truncation) + kwargs.setdefault("max_length", _truncation["max_length"]) + kwargs.setdefault("truncation_side", _truncation["direction"]) + kwargs.setdefault("stride", _truncation["stride"]) + kwargs.setdefault("truncation_strategy", _truncation["strategy"]) + else: + self._tokenizer.no_truncation() + + _padding = self._tokenizer.padding + if _padding is not None: + self._tokenizer.enable_padding(**_padding) + kwargs.setdefault("pad_token", _padding["pad_token"]) + kwargs.setdefault("pad_token_type_id", _padding["pad_type_id"]) + kwargs.setdefault("padding_side", _padding["direction"]) + kwargs.setdefault("max_length", _padding["length"]) + kwargs.setdefault("pad_to_multiple_of", _padding["pad_to_multiple_of"]) + + # We call this after having initialized the backend tokenizer because we update it. + super().__init__(**kwargs) + self._tokenizer.encode_special_tokens = self.split_special_tokens + + added_tokens_decoder_hash = {hash(repr(token)) for token in self.added_tokens_decoder} + tokens_to_add = [ + token + for index, token in sorted(added_tokens_decoder.items(), key=lambda x: x[0]) + if hash(repr(token)) not in added_tokens_decoder_hash + ] + encoder = list(self.added_tokens_encoder.keys()) + [str(token) for token in tokens_to_add] + # if some of the special tokens are strings, we check if we don't already have a token + tokens_to_add += [ + token for token in self.all_special_tokens_extended if token not in encoder and token not in tokens_to_add + ] + + if len(tokens_to_add) > 0: + tokens = [] + special_tokens = self.all_special_tokens + for token in tokens_to_add: + is_special = ( + (token.special or str(token) in special_tokens) + if isinstance(token, AddedToken) + else str(token) in special_tokens + ) + if isinstance(token, str): + token = AddedToken(token, special=is_special) + else: + token.special = is_special + tokens.append(token) + if tokens: + self.add_tokens(tokens) + + try: + pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__()) + if pre_tok_state.get("add_prefix_space", self.add_prefix_space) != self.add_prefix_space: + pre_tok_class = getattr(pre_tokenizers_fast, pre_tok_state.pop("type")) + pre_tok_state["add_prefix_space"] = self.add_prefix_space + self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state) + except Exception: + # We'll get an error if there is no pre_tokenizer, or if it's a custom pre_tokenizer that can + # not be serialized. In those cases, we just ignore the error as there's no pre_tokenizer + # for which we need to update the `add_prefix_space` attribute. + pass + + @property + def is_fast(self) -> bool: + return True + + @property + def can_save_slow_tokenizer(self) -> bool: + """ + `bool`: Whether or not the slow tokenizer can be saved. Usually for sentencepiece based slow tokenizer, this + can only be `True` if the original `"sentencepiece.model"` was not deleted. + """ + return True + + @property + def vocab_size(self) -> int: + """ + `int`: Size of the base vocabulary (without the added tokens). + """ + return self._tokenizer.get_vocab_size(with_added_tokens=False) + + def get_vocab(self) -> Dict[str, int]: + return self._tokenizer.get_vocab(with_added_tokens=True) + + @property + def vocab(self) -> Dict[str, int]: + return self.get_vocab() + + @property + def added_tokens_encoder(self) -> Dict[str, int]: + """ + Returns the sorted mapping from string to index. The added tokens encoder is cached for performance + optimisation in `self._added_tokens_encoder` for the slow tokenizers. + """ + return {k.content: v for v, k in sorted(self.added_tokens_decoder.items(), key=lambda item: item[0])} + + @property + def added_tokens_decoder(self) -> Dict[int, AddedToken]: + """ + Returns the added tokens in the vocabulary as a dictionary of index to AddedToken. + + Returns: + `Dict[str, int]`: The added tokens. + """ + return self._tokenizer.get_added_tokens_decoder() + + def get_added_vocab(self) -> Dict[str, int]: + """ + Returns the added tokens in the vocabulary as a dictionary of token to index. + + Returns: + `Dict[str, int]`: The added tokens. + """ + return {k.content: v for v, k in sorted(self.added_tokens_decoder.items(), key=lambda item: item[0])} + + def __len__(self) -> int: + """ + Size of the full vocabulary with the added tokens. + """ + return self._tokenizer.get_vocab_size(with_added_tokens=True) + + @property + def backend_tokenizer(self) -> TokenizerFast: + """ + `tokenizers.implementations.BaseTokenizer`: The Rust tokenizer used as a backend. + """ + return self._tokenizer + + @property + def decoder(self) -> DecoderFast: + """ + `tokenizers.decoders.Decoder`: The Rust decoder for this tokenizer. + """ + return self._tokenizer.decoder + + def _convert_encoding( + self, + encoding: EncodingFast, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + ) -> Tuple[Dict[str, Any], List[EncodingFast]]: + """ + Convert the encoding representation (from low-level HuggingFace tokenizer output) to a python Dict and a list + of encodings, take care of building a batch from overflowing tokens. + + Overflowing tokens are converted to additional examples (like batches) so the output values of the dict are + lists (overflows) of lists (tokens). + + Output shape: (overflows, sequence length) + """ + if return_token_type_ids is None: + return_token_type_ids = "token_type_ids" in self.model_input_names + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + if return_overflowing_tokens and encoding.overflowing is not None: + encodings = [encoding] + encoding.overflowing + else: + encodings = [encoding] + + encoding_dict = defaultdict(list) + for e in encodings: + encoding_dict["input_ids"].append(e.ids) + + if return_token_type_ids: + encoding_dict["token_type_ids"].append(e.type_ids) + if return_attention_mask: + encoding_dict["attention_mask"].append(e.attention_mask) + if return_special_tokens_mask: + encoding_dict["special_tokens_mask"].append(e.special_tokens_mask) + if return_offsets_mapping: + encoding_dict["offset_mapping"].append(e.offsets) + if return_length: + encoding_dict["length"].append(len(e.ids)) + + return encoding_dict, encodings + + def convert_tokens_to_ids(self, tokens: Union[str, Iterable[str]]) -> Union[int, List[int]]: + """ + Converts a token string (or a sequence of tokens) in a single integer id (or a Iterable of ids), using the + vocabulary. + + Args: + tokens (`str` or `Iterable[str]`): One or several token(s) to convert to token id(s). + + Returns: + `int` or `List[int]`: The token id or list of token ids. + """ + if isinstance(tokens, str): + return self._convert_token_to_id_with_added_voc(tokens) + + return [self._convert_token_to_id_with_added_voc(token) for token in tokens] + + def _convert_token_to_id_with_added_voc(self, token: str) -> int: + index = self._tokenizer.token_to_id(token) + if index is None: + return self.unk_token_id + return index + + def _convert_id_to_token(self, index: int) -> Optional[str]: + return self._tokenizer.id_to_token(int(index)) + + def _add_tokens(self, new_tokens: List[Union[str, AddedToken]], special_tokens=False) -> int: + if special_tokens: + return self._tokenizer.add_special_tokens(new_tokens) + + return self._tokenizer.add_tokens(new_tokens) + + def num_special_tokens_to_add(self, pair: bool = False) -> int: + """ + Returns the number of added tokens when encoding a sequence with special tokens. + + + + This encodes a dummy input and checks the number of added tokens, and is therefore not efficient. Do not put + this inside your training loop. + + + + Args: + pair (`bool`, *optional*, defaults to `False`): + Whether the number of added tokens should be computed in the case of a sequence pair or a single + sequence. + + Returns: + `int`: Number of special tokens added to sequences. + """ + return self._tokenizer.num_special_tokens_to_add(pair) + + def convert_ids_to_tokens( + self, ids: Union[int, List[int]], skip_special_tokens: bool = False + ) -> Union[str, List[str]]: + """ + Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and + added tokens. + + Args: + ids (`int` or `List[int]`): + The token id (or token ids) to convert to tokens. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + + Returns: + `str` or `List[str]`: The decoded token(s). + """ + if isinstance(ids, int): + return self._tokenizer.id_to_token(ids) + tokens = [] + for index in ids: + index = int(index) + if skip_special_tokens and index in self.all_special_ids: + continue + tokens.append(self._tokenizer.id_to_token(index)) + return tokens + + def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> List[str]: + return self.encode_plus(text=text, text_pair=pair, add_special_tokens=add_special_tokens, **kwargs).tokens() + + def set_truncation_and_padding( + self, + padding_strategy: PaddingStrategy, + truncation_strategy: TruncationStrategy, + max_length: int, + stride: int, + pad_to_multiple_of: Optional[int], + padding_side: Optional[bool], + ): + """ + Define the truncation and the padding strategies for fast tokenizers (provided by HuggingFace tokenizers + library) and restore the tokenizer settings afterwards. + + The provided tokenizer has no padding / truncation strategy before the managed section. If your tokenizer set a + padding / truncation strategy before, then it will be reset to no padding / truncation when exiting the managed + section. + + Args: + padding_strategy ([`~utils.PaddingStrategy`]): + The kind of padding that will be applied to the input + truncation_strategy ([`~tokenization_utils_base.TruncationStrategy`]): + The kind of truncation that will be applied to the input + max_length (`int`): + The maximum size of a sequence. + stride (`int`): + The stride to use when handling overflow. + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. This is especially useful to enable + the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta). + padding_side (`str`, *optional*): + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. + """ + _truncation = self._tokenizer.truncation + _padding = self._tokenizer.padding + # Set truncation and padding on the backend tokenizer + if truncation_strategy == TruncationStrategy.DO_NOT_TRUNCATE: + if _truncation is not None: + self._tokenizer.no_truncation() + else: + target = { + "max_length": max_length, + "stride": stride, + "strategy": truncation_strategy.value, + "direction": self.truncation_side, + } + + # _truncation might contain more keys that the target `transformers` + # supports. Use only the target keys to trigger `enable_truncation`. + # This should enable this code to works on various `tokenizers` + # targets. + if _truncation is None: + current = None + else: + current = {k: _truncation.get(k, None) for k in target} + + if current != target: + self._tokenizer.enable_truncation(**target) + + if padding_strategy == PaddingStrategy.DO_NOT_PAD: + if _padding is not None: + self._tokenizer.no_padding() + else: + length = max_length if padding_strategy == PaddingStrategy.MAX_LENGTH else None + target = { + "length": length, + "direction": padding_side if padding_side is not None else self.padding_side, + "pad_id": self.pad_token_id, + "pad_token": self.pad_token, + "pad_type_id": self.pad_token_type_id, + "pad_to_multiple_of": pad_to_multiple_of, + } + if _padding != target: + self._tokenizer.enable_padding(**target) + + def _batch_encode_plus( + self, + batch_text_or_text_pairs: Union[ + List[TextInput], List[TextInputPair], List[PreTokenizedInput], List[PreTokenizedInputPair] + ], + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[bool] = None, + return_tensors: Optional[str] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + split_special_tokens: bool = False, + ) -> BatchEncoding: + if not isinstance(batch_text_or_text_pairs, (tuple, list)): + raise TypeError( + f"batch_text_or_text_pairs has to be a list or a tuple (got {type(batch_text_or_text_pairs)})" + ) + + # Set the truncation and padding strategy and restore the initial configuration + self.set_truncation_and_padding( + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + ) + + if self._tokenizer.encode_special_tokens != split_special_tokens: + self._tokenizer.encode_special_tokens = split_special_tokens + + encodings = self._tokenizer.encode_batch( + batch_text_or_text_pairs, + add_special_tokens=add_special_tokens, + is_pretokenized=is_split_into_words, + ) + + # Convert encoding to dict + # `Tokens` has type: Tuple[ + # List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]], + # List[EncodingFast] + # ] + # with nested dimensions corresponding to batch, overflows, sequence length + tokens_and_encodings = [ + self._convert_encoding( + encoding=encoding, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + ) + for encoding in encodings + ] + + # Convert the output to have dict[list] from list[dict] and remove the additional overflows dimension + # From (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length) + # (we say ~ because the number of overflow varies with the example in the batch) + # + # To match each overflowing sample with the original sample in the batch + # we add an overflow_to_sample_mapping array (see below) + sanitized_tokens = {} + for key in tokens_and_encodings[0][0].keys(): + stack = [e for item, _ in tokens_and_encodings for e in item[key]] + sanitized_tokens[key] = stack + sanitized_encodings = [e for _, item in tokens_and_encodings for e in item] + + # If returning overflowing tokens, we need to return a mapping + # from the batch idx to the original sample + if return_overflowing_tokens: + overflow_to_sample_mapping = [] + for i, (toks, _) in enumerate(tokens_and_encodings): + overflow_to_sample_mapping += [i] * len(toks["input_ids"]) + sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping + + for input_ids in sanitized_tokens["input_ids"]: + self._eventual_warn_about_too_long_sequence(input_ids, max_length, verbose) + return BatchEncoding(sanitized_tokens, sanitized_encodings, tensor_type=return_tensors) + + def _encode_plus( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[Union[TextInput, PreTokenizedInput]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[bool] = None, + return_tensors: Optional[bool] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + split_special_tokens: bool = False, + **kwargs, + ) -> BatchEncoding: + batched_input = [(text, text_pair)] if text_pair else [text] + batched_output = self._batch_encode_plus( + batched_input, + is_split_into_words=is_split_into_words, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + split_special_tokens=split_special_tokens, + **kwargs, + ) + + # Return tensor is None, then we can remove the leading batch axis + # Overflowing tokens are returned as a batch of output so we keep them in this case + if return_tensors is None and not return_overflowing_tokens: + batched_output = BatchEncoding( + { + key: (value[0] if len(value) > 0 and isinstance(value[0], list) else value) + for key, value in batched_output.items() + }, + batched_output.encodings, + ) + + self._eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose) + + return batched_output + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + return ( + self.backend_tokenizer.decoder.decode(tokens) + if self.backend_tokenizer.decoder is not None + else " ".join(tokens) + ) + + def _decode( + self, + token_ids: Union[int, List[int]], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = None, + **kwargs, + ) -> str: + self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False) + + if isinstance(token_ids, int): + token_ids = [token_ids] + text = self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) + + clean_up_tokenization_spaces = ( + clean_up_tokenization_spaces + if clean_up_tokenization_spaces is not None + else self.clean_up_tokenization_spaces + ) + if clean_up_tokenization_spaces: + clean_text = self.clean_up_tokenization(text) + return clean_text + else: + return text + + def _save_pretrained( + self, + save_directory: Union[str, os.PathLike], + file_names: Tuple[str], + legacy_format: Optional[bool] = None, + filename_prefix: Optional[str] = None, + ) -> Tuple[str]: + """ + Save a tokenizer using the slow-tokenizer/legacy format: vocabulary + added tokens as well as in a unique JSON + file containing {config + vocab + added-tokens}. + """ + save_directory = str(save_directory) + + if self.slow_tokenizer_class is None and legacy_format is True: + raise ValueError( + "Your tokenizer does not have a legacy version defined and therefore cannot register this version. You" + " might consider leaving the legacy_format at `None` or setting it to `False`." + ) + + save_slow = ( + (legacy_format is None or legacy_format is True) + and self.slow_tokenizer_class is not None + and self.can_save_slow_tokenizer + ) + save_fast = legacy_format is None or legacy_format is False + + if save_slow: + added_tokens_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + ADDED_TOKENS_FILE + ) + # make sure to be foward compatible + added_vocab = {tok: index for tok, index in self.added_tokens_encoder.items() if index >= self.vocab_size} + if added_vocab: + with open(added_tokens_file, "w", encoding="utf-8") as f: + out_str = json.dumps(added_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n" + f.write(out_str) + + vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix) + file_names = file_names + vocab_files + (added_tokens_file,) + + if save_fast: + tokenizer_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_FILE + ) + self.backend_tokenizer.save(tokenizer_file) + file_names = file_names + (tokenizer_file,) + + return file_names + + def train_new_from_iterator( + self, + text_iterator, + vocab_size, + length=None, + new_special_tokens=None, + special_tokens_map=None, + **kwargs, + ): + """ + Trains a tokenizer on a new corpus with the same defaults (in terms of special tokens or tokenization pipeline) + as the current one. + + Args: + text_iterator (generator of `List[str]`): + The training corpus. Should be a generator of batches of texts, for instance a list of lists of texts + if you have everything in memory. + vocab_size (`int`): + The size of the vocabulary you want for your tokenizer. + length (`int`, *optional*): + The total number of sequences in the iterator. This is used to provide meaningful progress tracking + new_special_tokens (list of `str` or `AddedToken`, *optional*): + A list of new special tokens to add to the tokenizer you are training. + special_tokens_map (`Dict[str, str]`, *optional*): + If you want to rename some of the special tokens this tokenizer uses, pass along a mapping old special + token name to new special token name in this argument. + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the trainer from the 🤗 Tokenizers library. + + Returns: + [`PreTrainedTokenizerFast`]: A new tokenizer of the same type as the original one, trained on + `text_iterator`. + + """ + tokenizer_json = json.loads(self._tokenizer.to_str()) + # Remove added tokens for now (uses IDs of tokens) + added_tokens = tokenizer_json.pop("added_tokens") + # Remove post processor for now (uses IDs of tokens) + post_processor = tokenizer_json.pop("post_processor") + + unk_token = None + # Remove vocab + if tokenizer_json["model"]["type"] == "BPE": + tokenizer_json["model"]["vocab"] = {} + tokenizer_json["model"]["merges"] = [] + elif tokenizer_json["model"]["type"] == "Unigram": + if tokenizer_json["model"]["unk_id"] is not None: + unk_id = tokenizer_json["model"]["unk_id"] + unk_token = tokenizer_json["model"]["vocab"][unk_id][0] + if special_tokens_map is not None and unk_token in special_tokens_map: + unk_token = special_tokens_map[unk_token] + tokenizer_json["model"]["unk_id"] = 0 + tokenizer_json["model"]["vocab"] = [[unk_token, 0.0]] + elif tokenizer_json["model"]["type"] in ["WordLevel", "WordPiece"]: + tokenizer_json["model"]["vocab"] = {} + else: + raise ValueError( + f"This method does not support this type of tokenizer (found {tokenizer_json['model']['type']}) " + "only BPE, Unigram, WordLevel and WordPiece." + ) + + if ( + special_tokens_map is not None + and "unk_token" in tokenizer_json["model"] + and tokenizer_json["model"]["unk_token"] in special_tokens_map + ): + tokenizer_json["model"]["unk_token"] = special_tokens_map[tokenizer_json["model"]["unk_token"]] + + tokenizer = TokenizerFast.from_str(json.dumps(tokenizer_json)) + + # Get the special tokens from the current tokenizer if none are specified. + special_tokens = [] + for added_token in added_tokens: + special = added_token.pop("special", None) + _ = added_token.pop("id", None) + if tokenizer_json["model"]["type"] != "Unigram" and not special: + continue + if special_tokens_map is not None and added_token["content"] in special_tokens_map: + added_token["content"] = special_tokens_map[added_token["content"]] + special_tokens.append(AddedToken(**added_token)) + + if new_special_tokens is not None: + special_tokens.extend(new_special_tokens) + + # Trainer needs to know the end of word / continuing subword thingies in BPE + if ( + tokenizer_json["model"]["type"] == "BPE" + and "continuing_subword_prefix" not in kwargs + and tokenizer_json["model"]["continuing_subword_prefix"] is not None + ): + kwargs["continuing_subword_prefix"] = tokenizer_json["model"]["continuing_subword_prefix"] + if ( + tokenizer_json["model"]["type"] == "BPE" + and "end_of_word_suffix" not in kwargs + and tokenizer_json["model"]["end_of_word_suffix"] is not None + ): + kwargs["end_of_word_suffix"] = tokenizer_json["model"]["end_of_word_suffix"] + if tokenizer_json["model"]["type"] == "Unigram" and unk_token is not None: + kwargs["unk_token"] = unk_token + if tokenizer_json["pre_tokenizer"] is not None: + if ( + tokenizer_json["pre_tokenizer"]["type"] == "ByteLevel" + or tokenizer_json["pre_tokenizer"]["type"] == "Sequence" + and "pretokenizers" in tokenizer_json["pre_tokenizer"] + and any( + pretokenizer["type"] == "ByteLevel" + for pretokenizer in tokenizer_json["pre_tokenizer"]["pretokenizers"] + ) + ): + kwargs["initial_alphabet"] = pre_tokenizers_fast.ByteLevel.alphabet() + + trainer_class = MODEL_TO_TRAINER_MAPPING[tokenizer_json["model"]["type"]] + trainer = trainer_class(vocab_size=vocab_size, special_tokens=special_tokens, **kwargs) + tokenizer.train_from_iterator(text_iterator, length=length, trainer=trainer) + + if post_processor is not None: + trained_tokenizer_json = json.loads(tokenizer.to_str()) + # Almost done, we just have to adjust the token IDs in the post processor + if "special_tokens" in post_processor: + for key in post_processor["special_tokens"]: + tokens = post_processor["special_tokens"][key]["tokens"] + if special_tokens_map is not None: + tokens = [special_tokens_map.get(token, token) for token in tokens] + post_processor["special_tokens"][key]["tokens"] = tokens + for token in tokens: + token_id = tokenizer.token_to_id(token) + if token_id is None: + raise ValueError( + "Attempted to set a token in the post processor that does not exist in the mapping" + ) + + post_processor["special_tokens"][key]["ids"] = [tokenizer.token_to_id(token) for token in tokens] + + for special_token in ["cls", "sep"]: + if special_token in post_processor: + token, _ = post_processor[special_token] + if special_tokens_map is not None and token in special_tokens_map: + token = special_tokens_map[token] + token_id = tokenizer.token_to_id(token) + if token_id is None: + raise ValueError( + "Attempted to set a token in the post processor that does not exist in the mapping" + ) + post_processor[special_token] = [token, token_id] + + trained_tokenizer_json["post_processor"] = post_processor + tokenizer = TokenizerFast.from_str(json.dumps(trained_tokenizer_json)) + + kwargs = self.init_kwargs.copy() + # Map pad/cls/mask token at the Transformers level + special_tokens_list = SpecialTokensMixin.SPECIAL_TOKENS_ATTRIBUTES.copy() + special_tokens_list.remove("additional_special_tokens") + for token in special_tokens_list: + if getattr(self, token) is not None: + special_token = getattr(self, token) + if special_tokens_map is not None and special_token in special_tokens_map: + special_token = special_tokens_map[special_token] + + special_token_full = self._special_tokens_map.get(token, None) + if isinstance(special_token_full, AddedToken): + # Create an added token with the same parameters except the content + kwargs[token] = AddedToken( + special_token, + single_word=special_token_full.single_word, + lstrip=special_token_full.lstrip, + rstrip=special_token_full.rstrip, + normalized=special_token_full.normalized, + special=True, + ) + else: + kwargs[token] = special_token + + additional_special_tokens = self.additional_special_tokens + if new_special_tokens is not None: + additional_special_tokens.extend(new_special_tokens) + if len(additional_special_tokens) > 0: + kwargs["additional_special_tokens"] = additional_special_tokens + + return self.__class__(tokenizer_object=tokenizer, **kwargs) diff --git a/tokenizers.pyd b/tokenizers.pyd new file mode 100644 index 0000000000000000000000000000000000000000..5679f0f0a7cbb51e22f45c93ea287c10ab39065d --- /dev/null +++ b/tokenizers.pyd @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fbcb8a7adfebb3ebe57fe5b1ba7c9668fbbe45e16de35d6e8180f6a2750e02f2 +size 6626816 diff --git a/trainer.py b/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..f7108b12b9048eb8ae937cb04df3916d7ed02a02 --- /dev/null +++ b/trainer.py @@ -0,0 +1,5170 @@ +# coding=utf-8 +# Copyright 2020-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task. +""" + +import contextlib +import copy +import functools +import glob +import importlib.metadata +import inspect +import json +import math +import os +import random +import re +import shutil +import sys +import tempfile +import time +import warnings +from collections.abc import Mapping +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union + + +# Integrations must be imported before ML frameworks: +# isort: off +from .integrations import ( + get_reporting_integration_callbacks, + hp_params, +) + +# isort: on + +import huggingface_hub.utils as hf_hub_utils +import numpy as np +import torch +import torch.distributed as dist +from huggingface_hub import ModelCard, create_repo, upload_folder +from packaging import version +from torch import nn +from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler + +from . import __version__ +from .configuration_utils import PretrainedConfig +from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator +from .debug_utils import DebugOption, DebugUnderflowOverflow +from .feature_extraction_sequence_utils import SequenceFeatureExtractor +from .feature_extraction_utils import FeatureExtractionMixin +from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend +from .image_processing_utils import BaseImageProcessor +from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available +from .integrations.tpu import tpu_spmd_dataloader +from .modelcard import TrainingSummary +from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model +from .models.auto.modeling_auto import ( + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + MODEL_MAPPING_NAMES, +) +from .optimization import Adafactor, get_scheduler +from .processing_utils import ProcessorMixin +from .pytorch_utils import ( + ALL_LAYERNORM_LAYERS, + is_torch_greater_or_equal_than_2_3, +) +from .tokenization_utils_base import PreTrainedTokenizerBase +from .trainer_callback import ( + CallbackHandler, + DefaultFlowCallback, + ExportableState, + PrinterCallback, + ProgressCallback, + TrainerCallback, + TrainerControl, + TrainerState, +) +from .trainer_pt_utils import ( + DistributedTensorGatherer, + EvalLoopContainer, + IterableDatasetShard, + LabelSmoother, + LayerWiseDummyOptimizer, + LengthGroupedSampler, + SequentialDistributedSampler, + distributed_broadcast_scalars, + distributed_concat, + find_batch_size, + get_model_param_count, + get_module_class_from_name, + get_parameter_names, + nested_concat, + nested_detach, + nested_numpify, + nested_xla_mesh_reduce, + reissue_pt_warnings, + remove_dummy_checkpoint, +) +from .trainer_utils import ( + PREFIX_CHECKPOINT_DIR, + BestRun, + EvalLoopOutput, + EvalPrediction, + HPSearchBackend, + HubStrategy, + PredictionOutput, + RemoveColumnsCollator, + SaveStrategy, + TrainerMemoryTracker, + TrainOutput, + check_target_module_exists, + default_compute_objective, + denumpify_detensorize, + enable_full_determinism, + find_executable_batch_size, + get_last_checkpoint, + has_length, + neftune_post_forward_hook, + number_of_arguments, + seed_worker, + set_seed, + speed_metrics, +) +from .training_args import OptimizerNames, ParallelMode, TrainingArguments +from .utils import ( + ADAPTER_CONFIG_NAME, + ADAPTER_SAFE_WEIGHTS_NAME, + ADAPTER_WEIGHTS_NAME, + CONFIG_NAME, + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + XLA_FSDPV2_MIN_VERSION, + PushInProgress, + PushToHubMixin, + can_return_loss, + find_labels, + is_accelerate_available, + is_apex_available, + is_bitsandbytes_available, + is_datasets_available, + is_galore_torch_available, + is_grokadamw_available, + is_in_notebook, + is_ipex_available, + is_liger_kernel_available, + is_lomo_available, + is_peft_available, + is_safetensors_available, + is_sagemaker_dp_enabled, + is_sagemaker_mp_enabled, + is_schedulefree_available, + is_torch_compile_available, + is_torch_mlu_available, + is_torch_mps_available, + is_torch_musa_available, + is_torch_neuroncore_available, + is_torch_npu_available, + is_torch_xla_available, + is_torch_xpu_available, + is_torchao_available, + logging, + strtobool, +) +from .utils.deprecation import deprecate_kwarg +from .utils.quantization_config import QuantizationMethod + + +DEFAULT_CALLBACKS = [DefaultFlowCallback] +DEFAULT_PROGRESS_CALLBACK = ProgressCallback + +if is_in_notebook(): + from .utils.notebook import NotebookProgressCallback + + DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback + +if is_apex_available(): + from apex import amp + +if is_datasets_available(): + import datasets + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + import torch_xla.debug.metrics as met + from torch_xla import __version__ as XLA_VERSION + + IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION) + if IS_XLA_FSDPV2_POST_2_2: + import torch_xla.distributed.spmd as xs + import torch_xla.runtime as xr +else: + IS_XLA_FSDPV2_POST_2_2 = False + + +if is_sagemaker_mp_enabled(): + import smdistributed.modelparallel.torch as smp + from smdistributed.modelparallel import __version__ as SMP_VERSION + + IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") + + from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat +else: + IS_SAGEMAKER_MP_POST_1_10 = False + + +if is_safetensors_available(): + import safetensors.torch + +if is_peft_available(): + from peft import PeftModel + + +if is_accelerate_available(): + from accelerate import Accelerator, skip_first_batches + from accelerate import __version__ as accelerate_version + from accelerate.state import AcceleratorState + from accelerate.utils import ( + DistributedDataParallelKwargs, + DistributedType, + load_fsdp_model, + load_fsdp_optimizer, + save_fsdp_model, + save_fsdp_optimizer, + ) + + DATA_SAMPLERS = [RandomSampler] + if version.parse(accelerate_version) > version.parse("0.23.0"): + from accelerate.data_loader import SeedableRandomSampler + + DATA_SAMPLERS += [SeedableRandomSampler] + + if is_deepspeed_available(): + from accelerate.utils import DeepSpeedSchedulerWrapper + +if is_accelerate_available("0.28.0"): + from accelerate.utils import DataLoaderConfiguration + + +def _is_peft_model(model): + if is_peft_available(): + classes_to_check = (PeftModel,) if is_peft_available() else () + # Here we also check if the model is an instance of `PeftMixedModel` introduced in peft>=0.7.0: https://github.com/huggingface/transformers/pull/28321 + if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"): + from peft import PeftMixedModel + + classes_to_check = (*classes_to_check, PeftMixedModel) + return isinstance(model, classes_to_check) + return False + + +def _get_fsdp_ckpt_kwargs(): + # TODO: @AjayP13, @younesbelkada replace this check with version check at the next `accelerate` release + if is_accelerate_available() and "adapter_only" in list(inspect.signature(save_fsdp_model).parameters): + return {"adapter_only": True} + else: + return {} + + +def safe_globals(): + # Starting from version 2.4 PyTorch introduces a check for the objects loaded + # with torch.load(weights_only=True). Starting from 2.6 weights_only=True becomes + # a default and requires allowlisting of objects being loaded. + # See: https://github.com/pytorch/pytorch/pull/137602 + # See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals + # See: https://github.com/huggingface/accelerate/pull/3036 + if version.parse(torch.__version__).release < version.parse("2.6").release: + return contextlib.nullcontext() + + np_core = np._core if version.parse(np.__version__) >= version.parse("2.0.0") else np.core + allowlist = [np_core.multiarray._reconstruct, np.ndarray, np.dtype] + # numpy >1.25 defines numpy.dtypes.UInt32DType, but below works for + # all versions of numpy + allowlist += [type(np.dtype(np.uint32))] + + return torch.serialization.safe_globals(allowlist) + + +if TYPE_CHECKING: + import optuna + + if is_datasets_available(): + import datasets + +logger = logging.get_logger(__name__) + + +# Name of the files used for checkpointing +TRAINING_ARGS_NAME = "training_args.bin" +TRAINER_STATE_NAME = "trainer_state.json" +OPTIMIZER_NAME = "optimizer.pt" +OPTIMIZER_NAME_BIN = "optimizer.bin" +SCHEDULER_NAME = "scheduler.pt" +SCALER_NAME = "scaler.pt" +FSDP_MODEL_NAME = "pytorch_model_fsdp" + + +class Trainer: + """ + Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers. + + Args: + model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*): + The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed. + + + + [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use + your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers + models. + + + + args ([`TrainingArguments`], *optional*): + The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the + `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided. + data_collator (`DataCollator`, *optional*): + The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will + default to [`default_data_collator`] if no `processing_class` is provided, an instance of + [`DataCollatorWithPadding`] otherwise if the processing_class is a feature extractor or tokenizer. + train_dataset (Union[`torch.utils.data.Dataset`, `torch.utils.data.IterableDataset`, `datasets.Dataset`], *optional*): + The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the + `model.forward()` method are automatically removed. + + Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a + distributed fashion, your iterable dataset should either use a internal attribute `generator` that is a + `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will + manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally + sets the seed of the RNGs used. + eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`, `datasets.Dataset`]), *optional*): + The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the + `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each + dataset prepending the dictionary key to the metric name. + processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*): + Processing class used to process the data. If provided, will be used to automatically process the inputs + for the model, and it will be saved along the model to make it easier to rerun an interrupted training or + reuse the fine-tuned model. + This supercedes the `tokenizer` argument, which is now deprecated. + model_init (`Callable[[], PreTrainedModel]`, *optional*): + A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start + from a new instance of the model as given by this function. + + The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to + be able to choose different architectures according to hyper parameters (such as layer count, sizes of + inner layers, dropout probabilities etc). + compute_loss_func (`Callable`, *optional*): + A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated + batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, see the default [loss function](https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/trainer.py#L3618) used by [`Trainer`]. + compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): + The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return + a dictionary string to metric values. *Note* When passing TrainingArgs with `batch_eval_metrics` set to + `True`, your compute_metrics function must take a boolean `compute_result` argument. This will be triggered + after the last eval batch to signal that the function needs to calculate and return the global summary + statistics rather than accumulating the batch-level statistics + callbacks (List of [`TrainerCallback`], *optional*): + A list of callbacks to customize the training loop. Will add those to the list of default callbacks + detailed in [here](callback). + + If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method. + optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`): + A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your + model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. + optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*): + A tuple containing the optimizer class and keyword arguments to use. + Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument. + + Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before initializing the Trainer. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*): + A function that preprocess the logits right before caching them at each evaluation step. Must take two + tensors, the logits and the labels, and return the logits once processed as desired. The modifications made + by this function will be reflected in the predictions received by `compute_metrics`. + + Note that the labels (second parameter) will be `None` if the dataset does not have them. + + Important attributes: + + - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`] + subclass. + - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the + original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`, + the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner + model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`. + - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from + data parallelism, this means some of the model layers are split on different GPUs). + - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set + to `False` if model parallel or deepspeed is used, or if the default + `TrainingArguments.place_model_on_device` is overridden to return `False` . + - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while + in `train`) + + """ + + # Those are used as methods of the Trainer in examples. + from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state + + @deprecate_kwarg("tokenizer", new_name="processing_class", version="5.0.0", raise_if_both_names=True) + def __init__( + self, + model: Union[PreTrainedModel, nn.Module] = None, + args: TrainingArguments = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None, + processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] + ] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + compute_loss_func: Optional[Callable] = None, + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), + optimizer_cls_and_kwargs: Optional[Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] = None, + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + ): + if args is None: + output_dir = "tmp_trainer" + logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.") + args = TrainingArguments(output_dir=output_dir) + if args.batch_eval_metrics and compute_metrics is not None: + if "compute_result" not in inspect.signature(compute_metrics).parameters.keys(): + raise ValueError( + "When using `batch_eval_metrics`, your `compute_metrics` function must take a `compute_result`" + " boolean argument which will be triggered after the last batch of the eval set to signal that the" + " summary statistics should be returned by the function." + ) + if args.eval_strategy is not None and args.eval_strategy != "no" and eval_dataset is None: + raise ValueError( + f"You have set `args.eval_strategy` to {args.eval_strategy} but you didn't pass an `eval_dataset` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an `eval_dataset`. " + ) + if args.save_strategy == SaveStrategy.BEST or args.load_best_model_at_end: + if args.metric_for_best_model is None: + raise ValueError( + "`args.metric_for_best_model` must be provided when using 'best' save_strategy or if `args.load_best_model_at_end` is set to `True`." + ) + + self.args = args + self.compute_loss_func = compute_loss_func + # Seed must be set before instantiating the model when using model + enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) + + self.hp_name = None + self.deepspeed = None + self.is_in_train = False + + self.create_accelerator_and_postprocess() + + # memory metrics - must set up as early as possible + self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics) + self._memory_tracker.start() + + # set the correct log level depending on the node + log_level = args.get_process_log_level() + logging.set_verbosity(log_level) + + # force device and distributed setup init explicitly + args._setup_devices + + if model is None: + if model_init is not None: + self.model_init = model_init + model = self.call_model_init() + else: + raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument") + else: + if model_init is not None: + warnings.warn( + "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will" + " overwrite your model when calling the `train` method. This will become a fatal error in the next" + " release.", + FutureWarning, + ) + self.model_init = model_init + + if model.__class__.__name__ in MODEL_MAPPING_NAMES: + raise ValueError( + f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only " + "computes hidden states and does not accept any labels. You should choose a model with a head " + "suitable for your task like any of the `AutoModelForXxx` listed at " + "https://huggingface.co/docs/transformers/model_doc/auto" + ) + + if getattr(model, "is_parallelizable", False) and getattr(model, "model_parallel", False): + self.is_model_parallel = True + else: + self.is_model_parallel = False + + if getattr(model, "hf_device_map", None) is not None: + devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]] + if len(devices) > 1: + self.is_model_parallel = True + elif len(devices) == 1: + self.is_model_parallel = self.args.device != torch.device(devices[0]) + else: + self.is_model_parallel = False + + # warn users + if self.is_model_parallel: + logger.info( + "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set" + " to `True` to avoid any unexpected behavior such as device placement mismatching." + ) + + if self.args.use_liger_kernel: + if is_liger_kernel_available(): + from liger_kernel.transformers import _apply_liger_kernel_to_instance + + if isinstance(model, PreTrainedModel): + # Patch the model with liger kernels. Use the default kernel configurations. + _apply_liger_kernel_to_instance(model=model) + else: + logger.warning( + "The model is not an instance of PreTrainedModel. No liger kernels will be applied." + ) + else: + raise ImportError( + "You have set `use_liger_kernel` to `True` but liger-kernel >= 0.3.0 is not available. " + "Please install it with `pip install liger-kernel`" + ) + + _is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr( + model, "_hf_peft_config_loaded", False + ) + _quantization_method_supports_training = ( + getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable + ) + + _is_model_quantized_and_qat_trainable = getattr(model, "hf_quantizer", None) is not None and getattr( + model.hf_quantizer, "is_qat_trainable", False + ) + + # Filter out quantized + compiled models + if _is_quantized_and_base_model and hasattr(model, "_orig_mod"): + raise ValueError( + "You cannot fine-tune quantized model with `torch.compile()` make sure to pass a non-compiled model when fine-tuning a quantized model with PEFT" + ) + + # At this stage the model is already loaded + if _is_quantized_and_base_model and not _is_peft_model(model) and not _is_model_quantized_and_qat_trainable: + raise ValueError( + "You cannot perform fine-tuning on purely quantized models. Please attach trainable adapters on top of" + " the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft" + " for more details" + ) + elif _is_quantized_and_base_model and not _quantization_method_supports_training: + raise ValueError( + f"The model you are trying to fine-tune is quantized with {model.hf_quantizer.quantization_config.quant_method}" + " but that quantization method do not support training. Please open an issue on GitHub: https://github.com/huggingface/transformers" + f" to request the support for training support for {model.hf_quantizer.quantization_config.quant_method}" + ) + + self.is_fsdp_xla_enabled = args.fsdp_config["xla"] + if len(args.fsdp) > 0: + if self.is_deepspeed_enabled: + raise ValueError( + "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags." + ) + if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED: + raise ValueError("Using fsdp only works in distributed training.") + + # one place to sort out whether to place the model on device or not + # postpone switching model to cuda when: + # 1. MP - since we are trying to fit a much bigger than 1 gpu model + # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway, + # and we only use deepspeed for training at the moment + # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first + # 4. FSDP - same as MP + self.place_model_on_device = args.place_model_on_device + if ( + self.is_model_parallel + or self.is_deepspeed_enabled + or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train) + or self.is_fsdp_xla_enabled + or self.is_fsdp_enabled + ): + self.place_model_on_device = False + + default_collator = ( + DataCollatorWithPadding(processing_class) + if processing_class is not None + and isinstance(processing_class, (PreTrainedTokenizerBase, SequenceFeatureExtractor)) + else default_data_collator + ) + self.data_collator = data_collator if data_collator is not None else default_collator + self.train_dataset = train_dataset + self.eval_dataset = eval_dataset + self.processing_class = processing_class + + # Bnb Quantized models doesn't support `.to` operation. + if ( + self.place_model_on_device + and not getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES + ): + self._move_model_to_device(model, args.device) + + # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs + if self.is_model_parallel: + self.args._n_gpu = 1 + + # later use `self.model is self.model_wrapped` to check if it's wrapped or not + self.model_wrapped = model + self.model = model + + # Just in case the model was wrapped outside of the `Trainer` + unwrapped_model = self.accelerator.unwrap_model(model) + model_forward = ( + unwrapped_model.forward + if not _is_peft_model(unwrapped_model) + else unwrapped_model.get_base_model().forward + ) + forward_params = inspect.signature(model_forward).parameters + + # Check if the model has explicit setup for loss kwargs, + # if not, check if `**kwargs` are in model.forward + if hasattr(model, "accepts_loss_kwargs"): + self.model_accepts_loss_kwargs = model.accepts_loss_kwargs + else: + self.model_accepts_loss_kwargs = any( + k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values() + ) + + self.neftune_noise_alpha = args.neftune_noise_alpha + + self.compute_metrics = compute_metrics + self.preprocess_logits_for_metrics = preprocess_logits_for_metrics + self.optimizer, self.lr_scheduler = optimizers + self.optimizer_cls_and_kwargs = optimizer_cls_and_kwargs + if self.optimizer_cls_and_kwargs is not None and self.optimizer is not None: + raise RuntimeError("Passing both `optimizers` and `optimizer_cls_and_kwargs` arguments is incompatible.") + if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None): + raise RuntimeError( + "Passing a `model_init` is incompatible with providing the `optimizers` argument. " + "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." + ) + if is_torch_xla_available() and self.optimizer is not None: + for param in self.model.parameters(): + model_device = param.device + break + for param_group in self.optimizer.param_groups: + if len(param_group["params"]) > 0: + optimizer_device = param_group["params"][0].device + break + if model_device != optimizer_device: + raise ValueError( + "The model and the optimizer parameters are not on the same device, which probably means you" + " created an optimizer around your model **before** putting on the device and passing it to the" + " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and" + " `model.to(xm.xla_device())` is performed before the optimizer creation in your script." + ) + if (self.is_fsdp_xla_enabled or self.is_fsdp_enabled) and ( + self.optimizer is not None or self.lr_scheduler is not None + ): + raise RuntimeError( + "Passing `optimizers` is not allowed if PyTorch FSDP is enabled. " + "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." + ) + default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) + callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks + self.callback_handler = CallbackHandler( + callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler + ) + self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) + + # Will be set to True by `self._setup_loggers()` on first call to `self.log()`. + self._loggers_initialized = False + + # Create distant repo and output directory if needed + self.hub_model_id = None + if self.args.push_to_hub: + self.init_hf_repo() + if self.args.should_save: + os.makedirs(self.args.output_dir, exist_ok=True) + + if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)): + raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).") + + if args.max_steps > 0 and args.num_train_epochs > 0: + logger.info("max_steps is given, it will override any value given in num_train_epochs") + + if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0: + raise ValueError( + "The train_dataset does not implement __len__, max_steps has to be specified. " + "The number of steps needs to be known in advance for the learning rate scheduler." + ) + + if ( + train_dataset is not None + and isinstance(train_dataset, torch.utils.data.IterableDataset) + and args.group_by_length + ): + raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset") + + self._signature_columns = None + + # Mixed precision setup + self.use_apex = False + self.use_cpu_amp = False + + # Mixed precision setup for SageMaker Model Parallel + if is_sagemaker_mp_enabled(): + # BF16 + model parallelism in SageMaker: currently not supported, raise an error + if args.bf16: + raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ") + + if IS_SAGEMAKER_MP_POST_1_10: + # When there's mismatch between SMP config and trainer argument, use SMP config as truth + if args.fp16 != smp.state.cfg.fp16: + logger.warning( + f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, " + f"but FP16 provided in trainer argument is {args.fp16}, " + f"setting to {smp.state.cfg.fp16}" + ) + args.fp16 = smp.state.cfg.fp16 + else: + # smp < 1.10 does not support fp16 in trainer. + if hasattr(smp.state.cfg, "fp16"): + logger.warning( + f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, " + "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer." + ) + if (args.fp16 or args.bf16) and args.half_precision_backend == "auto": + if args.device == torch.device("cpu"): + if args.fp16: + if not is_torch_greater_or_equal_than_2_3: + raise ValueError("Tried to use `fp16` but it is not supported on cpu") + else: + args.half_precision_backend = "cpu_amp" + logger.info(f"Using {args.half_precision_backend} half precision backend") + + if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()): + # deepspeed and SageMaker Model Parallel manage their own half precision + if args.half_precision_backend == "cpu_amp": + self.use_cpu_amp = True + self.amp_dtype = torch.bfloat16 + elif args.half_precision_backend == "apex": + if not is_apex_available(): + raise ImportError( + "Using FP16 with APEX but APEX is not installed, please refer to" + " https://www.github.com/nvidia/apex." + ) + self.use_apex = True + + # Label smoothing + if self.args.label_smoothing_factor != 0: + self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor) + else: + self.label_smoother = None + + self.control = TrainerControl() + + self.state = TrainerState( + is_local_process_zero=self.is_local_process_zero(), + is_world_process_zero=self.is_world_process_zero(), + stateful_callbacks=[ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ], + ) + # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then + # returned to 0 every time flos need to be logged + self.current_flos = 0 + self.hp_search_backend = None + default_label_names = find_labels(self.model.__class__) + self.label_names = default_label_names if self.args.label_names is None else self.args.label_names + self.can_return_loss = can_return_loss(self.model.__class__) + self.control = self.callback_handler.on_init_end(self.args, self.state, self.control) + + # Internal variables to help with automatic batch size reduction + self._train_batch_size = args.train_batch_size + self._created_lr_scheduler = False + + # very last + self._memory_tracker.stop_and_update_metrics() + + # torch.compile + if args.torch_compile and not is_torch_compile_available(): + raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.") + + self.is_fsdp_xla_v2_enabled = args.fsdp_config.get("xla_fsdp_v2", False) + if self.is_fsdp_xla_v2_enabled: + if not IS_XLA_FSDPV2_POST_2_2: + raise ValueError("FSDPv2 requires `torch_xla` 2.2 or higher.") + # Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper. + # Tensor axis is just a placeholder where it will not be used in FSDPv2. + num_devices = xr.global_runtime_device_count() + xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor"))) + self.is_fsdp_xla_v1_enabled = self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled + + @property + def tokenizer(self) -> Optional[PreTrainedTokenizerBase]: + logger.warning("Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.") + return self.processing_class + + @tokenizer.setter + def tokenizer(self, processing_class) -> None: + logger.warning( + "Trainer.tokenizer is now deprecated. You should use `Trainer.processing_class = processing_class` instead." + ) + self.processing_class = processing_class + + def _activate_neftune(self, model): + r""" + Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper: + https://arxiv.org/abs/2310.05914 + """ + unwrapped_model = self.accelerator.unwrap_model(model) + + if _is_peft_model(unwrapped_model): + embeddings = unwrapped_model.base_model.model.get_input_embeddings() + else: + embeddings = unwrapped_model.get_input_embeddings() + + del unwrapped_model + + embeddings.neftune_noise_alpha = self.neftune_noise_alpha + hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook) + self.neftune_hook_handle = hook_handle + return model + + def _deactivate_neftune(self, model): + """ + Deactivates the neftune method. Make sure to call `_activate_neftune` first. + """ + if not hasattr(self, "neftune_hook_handle"): + raise ValueError("Neftune is not activated make sure to call `trainer._activate_neftune()` first") + + unwrapped_model = self.accelerator.unwrap_model(model) + + if _is_peft_model(unwrapped_model): + embeddings = unwrapped_model.base_model.model.get_input_embeddings() + else: + embeddings = unwrapped_model.get_input_embeddings() + + self.neftune_hook_handle.remove() + del embeddings.neftune_noise_alpha, unwrapped_model + + def add_callback(self, callback): + """ + Add a callback to the current list of [`~transformers.TrainerCallback`]. + + Args: + callback (`type` or [`~transformers.TrainerCallback]`): + A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the + first case, will instantiate a member of that class. + """ + self.callback_handler.add_callback(callback) + + def pop_callback(self, callback): + """ + Remove a callback from the current list of [`~transformers.TrainerCallback`] and returns it. + + If the callback is not found, returns `None` (and no error is raised). + + Args: + callback (`type` or [`~transformers.TrainerCallback]`): + A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the + first case, will pop the first member of that class found in the list of callbacks. + + Returns: + [`~transformers.TrainerCallback`]: The callback removed, if found. + """ + return self.callback_handler.pop_callback(callback) + + def remove_callback(self, callback): + """ + Remove a callback from the current list of [`~transformers.TrainerCallback`]. + + Args: + callback (`type` or [`~transformers.TrainerCallback]`): + A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the + first case, will remove the first member of that class found in the list of callbacks. + """ + self.callback_handler.remove_callback(callback) + + def _move_model_to_device(self, model, device): + model = model.to(device) + # Moving a model to an XLA device disconnects the tied weights, so we have to retie them. + if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"): + model.tie_weights() + + def _set_signature_columns_if_needed(self): + if self._signature_columns is None: + # Inspect model forward signature to keep only the arguments it accepts. + model_to_inspect = self.model + if _is_peft_model(self.model): + if hasattr(self.model, "get_base_model"): + model_to_inspect = self.model.get_base_model() + else: + # PeftMixedModel do not provide a `get_base_model` method + model_to_inspect = self.model.base_model.model + signature = inspect.signature(model_to_inspect.forward) + self._signature_columns = list(signature.parameters.keys()) + # Labels may be named label or label_ids, the default data collator handles that. + self._signature_columns += list(set(["label", "label_ids"] + self.label_names)) + + def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None): + if not self.args.remove_unused_columns: + return dataset + self._set_signature_columns_if_needed() + signature_columns = self._signature_columns + + ignored_columns = list(set(dataset.column_names) - set(signature_columns)) + if len(ignored_columns) > 0: + dset_description = "" if description is None else f"in the {description} set" + logger.info( + f"The following columns {dset_description} don't have a corresponding argument in " + f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}." + f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, " + " you can safely ignore this message." + ) + + columns = [k for k in signature_columns if k in dataset.column_names] + if len(columns) == 0: + raise ValueError( + "No columns in the dataset match the model's forward method signature. " + f"The following columns have been ignored: [{', '.join(ignored_columns)}]. " + "Please check the dataset and model. You may need to set `remove_unused_columns=False` in `TrainingArguments`." + ) + + if version.parse(datasets.__version__) < version.parse("1.4.0"): + dataset.set_format( + type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"] + ) + return dataset + else: + return dataset.remove_columns(ignored_columns) + + def _get_collator_with_removed_columns( + self, data_collator: Callable, description: Optional[str] = None + ) -> Callable: + """Wrap the data collator in a callable removing unused columns.""" + if not self.args.remove_unused_columns: + return data_collator + self._set_signature_columns_if_needed() + signature_columns = self._signature_columns + + remove_columns_collator = RemoveColumnsCollator( + data_collator=data_collator, + signature_columns=signature_columns, + logger=logger, + description=description, + model_name=self.model.__class__.__name__, + ) + return remove_columns_collator + + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: + if self.train_dataset is None or not has_length(self.train_dataset): + return None + + # Build the sampler. + if self.args.group_by_length: + if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset): + lengths = ( + self.train_dataset[self.args.length_column_name] + if self.args.length_column_name in self.train_dataset.column_names + else None + ) + else: + lengths = None + model_input_name = ( + self.processing_class.model_input_names[0] if self.processing_class is not None else None + ) + return LengthGroupedSampler( + self.args.train_batch_size * self.args.gradient_accumulation_steps, + dataset=self.train_dataset, + lengths=lengths, + model_input_name=model_input_name, + ) + + else: + return RandomSampler(self.train_dataset) + + def get_train_dataloader(self) -> DataLoader: + """ + Returns the training [`~torch.utils.data.DataLoader`]. + + Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed + training if necessary) otherwise. + + Subclass and override this method if you want to inject some custom behavior. + """ + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + train_dataset = self._remove_unused_columns(train_dataset, description="training") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="training") + + dataloader_params = { + "batch_size": self._train_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_train_sampler() + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = seed_worker + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + + def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]: + if eval_dataset is None or not has_length(eval_dataset): + return None + # Build the sampler. + + # Deprecated code + if self.args.use_legacy_prediction_loop: + if is_torch_xla_available(): + return SequentialDistributedSampler( + eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal() + ) + elif is_sagemaker_mp_enabled(): + return SequentialDistributedSampler( + eval_dataset, + num_replicas=smp.dp_size(), + rank=smp.dp_rank(), + batch_size=self.args.per_device_eval_batch_size, + ) + else: + return SequentialSampler(eval_dataset) + + if self.args.group_by_length: + if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): + lengths = ( + eval_dataset[self.args.length_column_name] + if self.args.length_column_name in eval_dataset.column_names + else None + ) + else: + lengths = None + model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None + return LengthGroupedSampler( + self.args.eval_batch_size, + dataset=eval_dataset, + lengths=lengths, + model_input_name=model_input_name, + ) + + if self.args.world_size <= 1: + return SequentialSampler(eval_dataset) + else: + return None + + def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader: + """ + Returns the evaluation [`~torch.utils.data.DataLoader`]. + + Subclass and override this method if you want to inject some custom behavior. + + Args: + eval_dataset (`str` or `torch.utils.data.Dataset`, *optional*): + If a `str`, will use `self.eval_dataset[eval_dataset]` as the evaluation dataset. If a `Dataset`, will override `self.eval_dataset` and must implement `__len__`. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed. + """ + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + + # If we have persistent workers, don't do a fork bomb especially as eval datasets + # don't change during training + dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval" + if ( + hasattr(self, "_eval_dataloaders") + and dataloader_key in self._eval_dataloaders + and self.args.dataloader_persistent_workers + ): + return self.accelerator.prepare(self._eval_dataloaders[dataloader_key]) + + eval_dataset = ( + self.eval_dataset[eval_dataset] + if isinstance(eval_dataset, str) + else eval_dataset + if eval_dataset is not None + else self.eval_dataset + ) + data_collator = self.data_collator + + if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): + eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation") + + dataloader_params = { + "batch_size": self.args.eval_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(eval_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset) + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + # accelerator.free_memory() will destroy the references, so + # we need to store the non-prepared version + eval_dataloader = DataLoader(eval_dataset, **dataloader_params) + if self.args.dataloader_persistent_workers: + if hasattr(self, "_eval_dataloaders"): + self._eval_dataloaders[dataloader_key] = eval_dataloader + else: + self._eval_dataloaders = {dataloader_key: eval_dataloader} + + return self.accelerator.prepare(eval_dataloader) + + def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: + """ + Returns the test [`~torch.utils.data.DataLoader`]. + + Subclass and override this method if you want to inject some custom behavior. + + Args: + test_dataset (`torch.utils.data.Dataset`, *optional*): + The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the + `model.forward()` method are automatically removed. It must implement `__len__`. + """ + data_collator = self.data_collator + + if is_datasets_available() and isinstance(test_dataset, datasets.Dataset): + test_dataset = self._remove_unused_columns(test_dataset, description="test") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="test") + + dataloader_params = { + "batch_size": self.args.eval_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(test_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_eval_sampler(test_dataset) + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + # We use the same batch_size as for eval. + return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params)) + + def create_optimizer_and_scheduler(self, num_training_steps: int): + """ + Setup the optimizer and the learning rate scheduler. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or + `create_scheduler`) in a subclass. + """ + self.create_optimizer() + if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16: + # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer + optimizer = self.optimizer.optimizer + else: + optimizer = self.optimizer + self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer) + + def get_decay_parameter_names(self, model) -> List[str]: + """ + Get all parameter names that weight decay will be applied to + + Note that some models implement their own layernorm instead of calling nn.LayerNorm, weight decay could still + apply to those modules since this function only filter out instance of nn.LayerNorm + """ + decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS) + decay_parameters = [name for name in decay_parameters if "bias" not in name] + return decay_parameters + + def create_optimizer(self): + """ + Setup the optimizer. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method in a subclass. + """ + opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model + + if self.optimizer is None: + decay_parameters = self.get_decay_parameter_names(opt_model) + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) + ], + "weight_decay": self.args.weight_decay, + }, + { + "params": [ + p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) + ], + "weight_decay": 0.0, + }, + ] + + if self.optimizer_cls_and_kwargs is not None: + optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs + else: + optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model) + + # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs` + # e.g. for GaLore optimizer. + if "params" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop("params") + + # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs` + # e.g. for LOMO optimizer. + if "model" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop("model") + + # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict` + # to avoid arguments conflicts. + if "optimizer_dict" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict") + + self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) + + if optimizer_cls.__name__ == "Adam8bit": + import bitsandbytes + + manager = bitsandbytes.optim.GlobalOptimManager.get_instance() + + skipped = 0 + for module in opt_model.modules(): + if isinstance(module, nn.Embedding): + skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) + logger.info(f"skipped {module}: {skipped/2**20}M params") + manager.register_module_override(module, "weight", {"optim_bits": 32}) + logger.debug(f"bitsandbytes: will optimize {module} in fp32") + logger.info(f"skipped: {skipped/2**20}M params") + + if is_sagemaker_mp_enabled(): + self.optimizer = smp.DistributedOptimizer(self.optimizer) + + return self.optimizer + + def get_num_trainable_parameters(self): + """ + Get the number of trainable parameters. + """ + return sum(p.numel() for p in self.model.parameters() if p.requires_grad) + + def get_learning_rates(self): + """ + Returns the learning rate of each parameter from self.optimizer. + """ + if self.optimizer is None: + raise ValueError("Trainer optimizer is None, please make sure you have setup the optimizer before.") + return [group["lr"] for group in self.optimizer.param_groups] + + def get_optimizer_group(self, param: Optional[Union[str, torch.nn.parameter.Parameter]] = None): + """ + Returns optimizer group for a parameter if given, else returns all optimizer groups for params. + + Args: + param (`str` or `torch.nn.parameter.Parameter`, *optional*): + The parameter for which optimizer group needs to be returned. + """ + if self.optimizer is None: + raise ValueError("Trainer optimizer is None, please make sure you have setup the optimizer before.") + if param is not None: + for group in self.optimizer.param_groups: + if param in group["params"]: + return group + return [group["params"] for group in self.optimizer.param_groups] + + @staticmethod + def get_optimizer_cls_and_kwargs( + args: TrainingArguments, model: Optional[PreTrainedModel] = None + ) -> Tuple[Any, Any]: + """ + Returns the optimizer class and optimizer parameters based on the training arguments. + + Args: + args (`transformers.training_args.TrainingArguments`): + The training arguments for the training session. + + """ + + # parse args.optim_args + optim_args = {} + if args.optim_args: + for mapping in args.optim_args.replace(" ", "").split(","): + key, value = mapping.split("=") + optim_args[key] = value + + optimizer_kwargs = {"lr": args.learning_rate} + + adam_kwargs = { + "betas": (args.adam_beta1, args.adam_beta2), + "eps": args.adam_epsilon, + } + if args.optim == OptimizerNames.ADAFACTOR: + optimizer_cls = Adafactor + optimizer_kwargs.update({"scale_parameter": False, "relative_step": False}) + elif args.optim == OptimizerNames.ADAMW_HF: + from .optimization import AdamW + + optimizer_cls = AdamW + optimizer_kwargs.update(adam_kwargs) + elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]: + from torch.optim import AdamW + + optimizer_cls = AdamW + optimizer_kwargs.update(adam_kwargs) + if args.optim == OptimizerNames.ADAMW_TORCH_FUSED: + optimizer_kwargs.update({"fused": True}) + elif args.optim == OptimizerNames.ADAMW_TORCH_XLA: + try: + from torch_xla.amp.syncfree import AdamW + + optimizer_cls = AdamW + optimizer_kwargs.update(adam_kwargs) + except ImportError: + raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.") + elif args.optim == OptimizerNames.ADAMW_TORCH_NPU_FUSED: + try: + from torch_npu.optim import NpuFusedAdamW + + optimizer_cls = NpuFusedAdamW + optimizer_kwargs.update(adam_kwargs) + except ImportError: + raise ValueError("Trainer failed to import FusedAdamW from torch_npu.") + elif args.optim == OptimizerNames.ADAMW_APEX_FUSED: + try: + from apex.optimizers import FusedAdam + + optimizer_cls = FusedAdam + optimizer_kwargs.update(adam_kwargs) + except ImportError: + raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!") + elif args.optim in [ + OptimizerNames.ADAMW_BNB, + OptimizerNames.ADAMW_8BIT, + OptimizerNames.PAGED_ADAMW, + OptimizerNames.PAGED_ADAMW_8BIT, + OptimizerNames.ADEMAMIX, + OptimizerNames.ADEMAMIX_8BIT, + OptimizerNames.PAGED_ADEMAMIX, + OptimizerNames.PAGED_ADEMAMIX_8BIT, + OptimizerNames.LION, + OptimizerNames.LION_8BIT, + OptimizerNames.PAGED_LION, + OptimizerNames.PAGED_LION_8BIT, + OptimizerNames.RMSPROP_BNB, + OptimizerNames.RMSPROP_8BIT, + OptimizerNames.RMSPROP_32BIT, + ]: + try: + from bitsandbytes.optim import AdamW, Lion, RMSprop + + is_paged = False + optim_bits = 32 + optimizer_cls = None + additional_optim_kwargs = adam_kwargs + if "paged" in args.optim: + is_paged = True + if "8bit" in args.optim: + optim_bits = 8 + if "adam" in args.optim: + optimizer_cls = AdamW + elif "lion" in args.optim: + optimizer_cls = Lion + additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)} + elif "rmsprop" in args.optim: + optimizer_cls = RMSprop + # Above we pass all `adam_kwargs` to the optimizer, here + # we only pass `optim_args` which can be passed by the user. + additional_optim_kwargs = optim_args + elif "ademamix" in args.optim: + if is_bitsandbytes_available() and version.parse( + importlib.metadata.version("bitsandbytes") + ) < version.parse("0.44.0"): + raise ValueError( + "The AdEMAMix optimizer is not supported by your current version of `bitsandbytes`. " + "Please install `bitsandbytes` >= 0.44.0." + ) + + from bitsandbytes.optim import AdEMAMix + + optimizer_cls = AdEMAMix + additional_optim_kwargs = { + "betas": ( + float(optim_args.get("beta1", args.adam_beta1)), + float(optim_args.get("beta2", args.adam_beta2)), + float(optim_args.get("beta3", 0.9999)), + ), + "alpha": float(optim_args.get("alpha", 5.0)), + "eps": float(optim_args.get("eps", args.adam_epsilon)), + } + + if "t_alpha" in optim_args: + additional_optim_kwargs["t_alpha"] = int(optim_args["t_alpha"]) + + if "t_beta3" in optim_args: + additional_optim_kwargs["t_beta3"] = int(optim_args["t_beta3"]) + + bnb_kwargs = {"optim_bits": optim_bits} + if "rmsprop" not in args.optim: + bnb_kwargs["is_paged"] = is_paged + + optimizer_kwargs.update(additional_optim_kwargs) + optimizer_kwargs.update(bnb_kwargs) + except ImportError: + raise ValueError("Trainer tried to instantiate bnb optimizer but `bitsandbytes` is not installed!") + if is_bitsandbytes_available() and version.parse( + importlib.metadata.version("bitsandbytes") + ) < version.parse("0.41.1"): + logger.warning( + "You are using 8-bit optimizers with a version of `bitsandbytes` < 0.41.1. " + "It is recommended to update your version as a major bug has been fixed in 8-bit optimizers." + ) + elif args.optim == OptimizerNames.ADAMW_ANYPRECISION: + try: + from torchdistx.optimizers import AnyPrecisionAdamW + + optimizer_cls = AnyPrecisionAdamW + optimizer_kwargs.update(adam_kwargs) + + # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx. + optimizer_kwargs.update( + { + "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")), + "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")), + "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")), + "compensation_buffer_dtype": getattr( + torch, optim_args.get("compensation_buffer_dtype", "bfloat16") + ), + } + ) + except ImportError: + raise ValueError("Please install https://github.com/pytorch/torchdistx") + elif args.optim == OptimizerNames.SGD: + optimizer_cls = torch.optim.SGD + elif args.optim == OptimizerNames.ADAGRAD: + optimizer_cls = torch.optim.Adagrad + elif args.optim == OptimizerNames.RMSPROP: + optimizer_cls = torch.optim.RMSprop + elif args.optim in [ + OptimizerNames.GALORE_ADAMW, + OptimizerNames.GALORE_ADAMW_8BIT, + OptimizerNames.GALORE_ADAFACTOR, + OptimizerNames.GALORE_ADAMW_LAYERWISE, + OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE, + OptimizerNames.GALORE_ADAFACTOR_LAYERWISE, + ]: + if not is_galore_torch_available(): + raise ImportError( + "You need to install `galore_torch` in order to use GaLore optimizers" + " install it with `pip install git+https://github.com/jiaweizzhao/GaLore`" + ) + from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit + + is_layerwise = args.optim.lower().endswith("layerwise") + if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED: + raise NotImplementedError("Layer-wise GaLore does not support DDP at this time") + + optimizer_mapping = { + OptimizerNames.GALORE_ADAMW: GaLoreAdamW, + OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit, + OptimizerNames.GALORE_ADAFACTOR: GaLoreAdafactor, + OptimizerNames.GALORE_ADAMW_LAYERWISE: GaLoreAdamW, + OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE: GaLoreAdamW8bit, + OptimizerNames.GALORE_ADAFACTOR_LAYERWISE: GaLoreAdafactor, + } + + optimizer_cls = optimizer_mapping[args.optim] + + if args.optim_target_modules is None: + raise ValueError( + "You need to define a `optim_target_modules` in order to properly use GaLore optimizers" + ) + + if not isinstance(args.optim_target_modules, (list, str)): + raise ValueError( + f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, or a specific module or 'all-linear', you passed {args.optim_target_modules}" + ) + + if model is None: + raise ValueError("You need to pass a model in order to correctly initialize a GaLore optimizer.") + + logger.warning( + "Activated GaLoRE fine-tuning, depending on your model size and hardware, the training might take a while before starting. Please be patient !" + ) + + all_linear = ( + isinstance(args.optim_target_modules, str) + and args.optim_target_modules.replace("_", "-") == "all-linear" + ) + + galore_params = [] + galore_params_names = [] + for module_name, module in model.named_modules(): + target_module_exists, is_regex = check_target_module_exists( + args.optim_target_modules, module_name, return_is_regex=True + ) + + if not isinstance(module, nn.Linear): + # Warn in case we match but it's not a linear layer + if target_module_exists and not is_regex: + logger.warning( + f"{module_name} has been matched but ignored as GaLore only supports linear layers. Please double check your `optim_target_modules`!" + ) + + continue + + if not target_module_exists and not all_linear: + continue + + galore_params.append(module.weight) + galore_params_names.append(module_name + ".weight") + + if len(galore_params) == 0: + raise ValueError( + f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `target_modules`." + ) + + non_galore_params = [p for n, p in model.named_parameters() if n not in galore_params_names] + + galore_optim_kwargs = { + "rank": int(optim_args.pop("rank", 128)), + "update_proj_gap": int(optim_args.pop("update_proj_gap", 200)), + "scale": float(optim_args.pop("scale", 0.25)), + "proj_type": optim_args.pop("proj_type", "std"), + } + + # The default args are from the official repository: https://github.com/jiaweizzhao/GaLore + param_groups = [ + {"params": non_galore_params}, + {"params": galore_params, **galore_optim_kwargs}, + ] + + if is_layerwise: + # For layer-wise optimizers, the optimization step is done through post accumulation + # gradient hooks. The trick is to first attach these hooks to the model parameters then + # create a dummy optimizer that will perform no-ops in the Trainer. + # See the original implementation or the nice implementation from @hiyouga + # here: https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba + if args.gradient_accumulation_steps != 1: + raise ValueError("Layerwise GaLoRE optimizer do not support gradient accumulation !") + + optimizer_dict = {} + for param in non_galore_params: + param_groups = [{"params": [param]}] + optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs) + for param in galore_params: + param_groups = [{"params": [param], **galore_optim_kwargs}] + optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs) + + def optimizer_hook(param): + if param.grad is not None: + optimizer_dict[param].step() + optimizer_dict[param].zero_grad() + + for param in model.parameters(): + if param.requires_grad: + param.register_post_accumulate_grad_hook(optimizer_hook) + + optimizer_cls = LayerWiseDummyOptimizer + optimizer_kwargs.update({"optimizer_dict": optimizer_dict}) + + optimizer_kwargs.update({"params": param_groups}) + + if args.optim == OptimizerNames.GALORE_ADAFACTOR: + optimizer_kwargs.update({"scale_parameter": False, "relative_step": False}) + elif args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + if not is_lomo_available(): + raise ImportError( + "You need to install `lomo_optim` in order to use LOMO optimizers" + " install it with `pip install lomo-optim`" + ) + if not is_accelerate_available("0.30.0"): + raise ImportError("You need to have `accelerate>=0.30.0` to be able to use LOMO optimizers") + + if model is None: + raise ValueError("You need to pass a `model` in order to correctly initialize a LOMO optimizer.") + + from lomo_optim import AdaLomo, Lomo + + if "ada" in args.optim: + optimizer_cls = AdaLomo + else: + optimizer_cls = Lomo + + optimizer_kwargs.update({"model": model}) + elif args.optim == OptimizerNames.GROKADAMW: + if not is_grokadamw_available(): + raise ValueError("Please install grokadamw with `pip install grokadamw`") + + from grokadamw import GrokAdamW + + optimizer_cls = GrokAdamW + optimizer_kwargs.update( + { + "alpha_init": float(optim_args.get("alpha_init", 0.98)), + "lamb": float(optim_args.get("lamb", 2.0)), + "gamma": float(optim_args.get("gamma", 0.1)), + "grokking_signal_decay_rate": float(optim_args.get("grokking_signal_decay_rate", 0.1)), + "gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)), + } + ) + elif args.optim == OptimizerNames.ADAMW_TORCH_4BIT: + if not is_torchao_available() or version.parse(importlib.metadata.version("torchao")) < version.parse( + "0.4.0" + ): + raise ImportError( + "You need to have `torchao>=0.4.0` in order to use torch 4-bit optimizers." + "Install it with `pip install torchao` or follow the instructions here: https://github.com/pytorch/ao" + ) + if version.parse(importlib.metadata.version("torch")) <= version.parse("2.4"): + raise ImportError( + "You need to have `torch>2.4` in order to use torch 4-bit optimizers. " + "Install it with `pip install --upgrade torch` it is available on pipy. Otherwise, you need to install torch nightly." + ) + from torchao.prototype.low_bit_optim import AdamW4bit + + optimizer_cls = AdamW4bit + optimizer_kwargs.update(adam_kwargs) + elif args.optim in [ + OptimizerNames.SCHEDULE_FREE_ADAMW, + OptimizerNames.SCHEDULE_FREE_SGD, + ]: + if not is_schedulefree_available(): + raise ImportError( + "You need to install `schedulefree` in order to use schedulefree optimizers" + " install it with `pip install schedulefree`" + ) + if not is_accelerate_available("0.30.0"): + raise ImportError("You need to have `accelerate>=0.30.0` to be able to use schedulefree optimizers") + from schedulefree import AdamWScheduleFree, SGDScheduleFree + + additional_optim_kwargs = {} + if args.optim == OptimizerNames.SCHEDULE_FREE_ADAMW: + optimizer_cls = AdamWScheduleFree + additional_optim_kwargs = adam_kwargs + elif args.optim == OptimizerNames.SCHEDULE_FREE_SGD: + optimizer_cls = SGDScheduleFree + else: + raise ValueError("Invalid schedulefree optimizer") + additional_optim_kwargs["weight_decay"] = args.weight_decay + additional_optim_kwargs["warmup_steps"] = args.warmup_steps + additional_optim_kwargs.update( + { + "weight_lr_power": float(optim_args.get("weight_lr_power", 2.0)), + "r": float(optim_args.get("r", 0.0)), + } + ) + optimizer_kwargs.update(additional_optim_kwargs) + else: + raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}") + return optimizer_cls, optimizer_kwargs + + def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None): + """ + Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or + passed as an argument. + + Args: + num_training_steps (int): The number of training steps to do. + """ + if self.lr_scheduler is None: + self.lr_scheduler = get_scheduler( + self.args.lr_scheduler_type, + optimizer=self.optimizer if optimizer is None else optimizer, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + scheduler_specific_kwargs=self.args.lr_scheduler_kwargs, + ) + self._created_lr_scheduler = True + return self.lr_scheduler + + def num_examples(self, dataloader: DataLoader) -> int: + """ + Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When + dataloader.dataset does not exist or has no length, estimates as best it can + """ + try: + dataset = dataloader.dataset + # Special case for IterableDatasetShard, we need to dig deeper + if isinstance(dataset, IterableDatasetShard): + return len(dataloader.dataset.dataset) + return len(dataloader.dataset) + except (NameError, AttributeError, TypeError): # no dataset or length, estimate by length of dataloader + return len(dataloader) * self.args.per_device_train_batch_size + + @staticmethod + def num_tokens(train_dl: DataLoader, max_steps: Optional[int] = None) -> int: + """ + Helper to get number of tokens in a [`~torch.utils.data.DataLoader`] by enumerating dataloader. + """ + train_tokens = 0 + try: + for batch in train_dl: + tokens = batch["input_ids"].numel() + if max_steps is not None: + return tokens * max_steps + train_tokens += tokens + except KeyError: + logger.warning("Cannot get num_tokens from dataloader") + return train_tokens + + def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]): + """HP search setup code""" + self._trial = trial + + if self.hp_search_backend is None or trial is None: + return + if self.hp_search_backend == HPSearchBackend.OPTUNA: + params = self.hp_space(trial) + elif self.hp_search_backend == HPSearchBackend.RAY: + params = trial + params.pop("wandb", None) + elif self.hp_search_backend == HPSearchBackend.SIGOPT: + params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()} + elif self.hp_search_backend == HPSearchBackend.WANDB: + params = trial + + for key, value in params.items(): + if not hasattr(self.args, key): + logger.warning( + f"Trying to set {key} in the hyperparameter search but there is no corresponding field in" + " `TrainingArguments`." + ) + continue + old_attr = getattr(self.args, key, None) + # Casting value to the proper type + if old_attr is not None: + value = type(old_attr)(value) + + setattr(self.args, key, value) + if self.hp_search_backend == HPSearchBackend.OPTUNA: + logger.info(f"Trial: {trial.params}") + if self.hp_search_backend == HPSearchBackend.SIGOPT: + logger.info(f"SigOpt Assignments: {trial.assignments}") + if self.hp_search_backend == HPSearchBackend.WANDB: + logger.info(f"W&B Sweep parameters: {trial}") + if self.is_deepspeed_enabled: + if self.args.deepspeed is None: + raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set") + + self.accelerator.free_memory() + + # Rebuild the deepspeed config to reflect the updated training parameters + from accelerate.utils import DeepSpeedPlugin + + from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig + + self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed) + self.args.hf_deepspeed_config.trainer_config_process(self.args) + self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config) + + # From 1.0 on, we need to fully wipe the DS plugin when doing sweeps. + # Simply calling `_reset_state` is enough and doesn't need a version pin. + AcceleratorState()._reset_state() + + self.create_accelerator_and_postprocess() + + def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]): + if self.hp_search_backend is None or trial is None: + return + metrics = metrics.copy() + self.objective = self.compute_objective(metrics) + if self.hp_search_backend == HPSearchBackend.OPTUNA: + import optuna + + if hasattr(trial, "study") and not trial.study._is_multi_objective(): + trial.report(self.objective, step) + if trial.should_prune(): + self.callback_handler.on_train_end(self.args, self.state, self.control) + raise optuna.TrialPruned() + elif self.hp_search_backend == HPSearchBackend.RAY: + import ray.train + + with tempfile.TemporaryDirectory() as temp_checkpoint_dir: + checkpoint = None + if self.control.should_save: + self._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir) + checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir) + metrics["objective"] = self.objective + ray.train.report(metrics, checkpoint=checkpoint) + + def _tune_save_checkpoint(self, checkpoint_dir: str): + output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}") + self.save_model(output_dir, _internal_call=True) + if self.args.should_save: + # Update the `TrainerControl` state to where we are currently + self.state.stateful_callbacks["TrainerControl"] = self.control.state() + self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) + torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + + def call_model_init(self, trial=None): + model_init_argcount = number_of_arguments(self.model_init) + if model_init_argcount == 0: + model = self.model_init() + elif model_init_argcount == 1: + model = self.model_init(trial) + else: + raise RuntimeError("model_init should have 0 or 1 argument.") + + if model is None: + raise RuntimeError("model_init should not return None.") + + return model + + def torch_jit_model_eval(self, model, dataloader, training=False): + if not training: + if dataloader is None: + logger.warning("failed to use PyTorch jit mode due to current dataloader is none.") + return model + example_batch = next(iter(dataloader)) + example_batch = self._prepare_inputs(example_batch) + try: + jit_model = copy.copy(model) + jit_model.eval() + original_forward = jit_model.__dict__.pop("_original_forward", None) + # remove mixed precision hooks from the model + if original_forward: + jit_model.forward = original_forward + with self.accelerator.autocast(cache_enabled=False), torch.no_grad(): + if version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.0.0"): + if isinstance(example_batch, dict): + jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False) + else: + jit_model = torch.jit.trace( + jit_model, + example_kwarg_inputs={key: example_batch[key] for key in example_batch}, + strict=False, + ) + else: + jit_inputs = [] + for key in example_batch: + example_tensor = torch.ones_like(example_batch[key]) + jit_inputs.append(example_tensor) + jit_inputs = tuple(jit_inputs) + jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False) + jit_model = torch.jit.freeze(jit_model) + with torch.no_grad(): + jit_model(**example_batch) + jit_model(**example_batch) + model = jit_model + self.use_cpu_amp = False + except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e: + logger.warning(f"failed to use PyTorch jit mode due to: {e}.") + + return model + + def ipex_optimize_model(self, model, training=False, dtype=torch.float32): + if not is_ipex_available(): + raise ImportError( + "Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer" + " to https://github.com/intel/intel-extension-for-pytorch." + ) + + import intel_extension_for_pytorch as ipex + + if not training: + model.eval() + dtype = torch.bfloat16 if not self.is_in_train and self.args.bf16_full_eval else dtype + # conv_bn_folding is disabled as it fails in symbolic tracing, resulting in ipex warnings + model = ipex.optimize(model, dtype=dtype, level="O1", conv_bn_folding=False, inplace=not self.is_in_train) + else: + if not model.training: + model.train() + model, self.optimizer = ipex.optimize( + model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1" + ) + + return model + + def compare_trainer_and_checkpoint_args(self, training_args, trainer_state): + attributes_map = { + "logging_steps": "logging_steps", + "eval_steps": "eval_steps", + "save_steps": "save_steps", + } + + has_warning = False + warning_str = "Warning: The following arguments do not match the ones in the `trainer_state.json` within the checkpoint directory: " + for arg_attr, state_attr in attributes_map.items(): + arg_value = getattr(training_args, arg_attr, None) + state_value = getattr(trainer_state, state_attr, None) + + if arg_value is not None and state_value is not None and arg_value != state_value: + warning_str += f"\n\t{arg_attr}: {arg_value} (from args) != {state_value} (from trainer_state.json)" + has_warning = True + + # train bs is special as we need to account for multi-GPU + train_bs_args = training_args.per_device_train_batch_size + train_bs_state = trainer_state.train_batch_size // max(1, training_args.n_gpu) + + if train_bs_args != train_bs_state: + warning_str += f"\n\tper_device_train_batch_size: {train_bs_args} (from args) != {train_bs_state} (from trainer_state.json)" + has_warning = True + + if has_warning: + logger.warning_once(warning_str) + + def _wrap_model(self, model, training=True, dataloader=None): + if self.args.use_ipex: + dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32 + model = self.ipex_optimize_model(model, training, dtype=dtype) + + if is_sagemaker_mp_enabled(): + # Wrapping the base model twice in a DistributedModel will raise an error. + if isinstance(self.model_wrapped, smp.model.DistributedModel): + return self.model_wrapped + return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps) + + # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again + if self.accelerator.unwrap_model(model) is not model: + return model + + # Mixed precision training with apex (torch < 1.6) + if self.use_apex and training: + model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) + + # Multi-gpu training (should be after apex fp16 initialization) / 8bit models does not support DDP + if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False): + model = nn.DataParallel(model) + + if self.args.jit_mode_eval: + start_time = time.time() + model = self.torch_jit_model_eval(model, dataloader, training) + self.jit_compilation_time = round(time.time() - start_time, 4) + + # Note: in torch.distributed mode, there's no point in wrapping the model + # inside a DistributedDataParallel as we'll be under `no_grad` anyways. + if not training: + return model + + # Distributed training (should be after apex fp16 initialization) + # Distributed training using PyTorch FSDP + if self.is_fsdp_xla_enabled: + try: + from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP + from torch_xla.distributed.fsdp import checkpoint_module + from torch_xla.distributed.fsdp.wrap import ( + size_based_auto_wrap_policy, + transformer_auto_wrap_policy, + ) + + if self.is_fsdp_xla_v2_enabled: + from torch_xla.experimental.spmd_fully_sharded_data_parallel import ( + SpmdFullyShardedDataParallel as FSDPv2, + ) + except ImportError: + raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.") + auto_wrap_policy = None + auto_wrapper_callable = None + default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None) + fsdp_transformer_layer_cls_to_wrap = self.args.fsdp_config.get( + "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap + ) + + if self.args.fsdp_config["min_num_params"] > 0: + auto_wrap_policy = functools.partial( + size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["min_num_params"] + ) + elif fsdp_transformer_layer_cls_to_wrap is not None: + transformer_cls_to_wrap = set() + for layer_class in fsdp_transformer_layer_cls_to_wrap: + transformer_cls = get_module_class_from_name(model, layer_class) + if transformer_cls is None: + raise Exception("Could not find the transformer layer class to wrap in the model.") + else: + transformer_cls_to_wrap.add(transformer_cls) + + auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + # Transformer layer class to wrap + transformer_layer_cls=transformer_cls_to_wrap, + ) + fsdp_kwargs = self.args.xla_fsdp_config + if self.args.fsdp_config["xla_fsdp_grad_ckpt"]: + if model.config.use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + model.config.use_cache = False + + # Apply gradient checkpointing to auto-wrapped sub-modules if specified + def auto_wrapper_callable(m, *args, **kwargs): + target_cls = FSDP if not self.is_fsdp_xla_v2_enabled else FSDPv2 + return target_cls(checkpoint_module(m), *args, **kwargs) + + # Wrap the base model with an outer FSDP wrapper + if self.is_fsdp_xla_v2_enabled: + + def shard_output(output, mesh): + from .modeling_outputs import CausalLMOutputWithPast + + real_output = None + if isinstance(output, torch.Tensor): + real_output = output + elif isinstance(output, tuple): + real_output = output[0] + elif isinstance(output, CausalLMOutputWithPast): + real_output = output.logits + + if real_output is None: + raise ValueError("Something went wrong, the output of the model shouldn't be `None`") + xs.mark_sharding(real_output, mesh, ("fsdp", None, None)) + + self.model = model = FSDPv2( + model, + shard_output=shard_output, + auto_wrap_policy=auto_wrap_policy, + auto_wrapper_callable=auto_wrapper_callable, + ) + else: + self.model = model = FSDP( + model, + auto_wrap_policy=auto_wrap_policy, + auto_wrapper_callable=auto_wrapper_callable, + **fsdp_kwargs, + ) + + # Patch `xm.optimizer_step` should not reduce gradients in this case, + # as FSDP does not need gradient reduction over sharded parameters. + def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}): + loss = optimizer.step(**optimizer_args) + if barrier: + xm.mark_step() + return loss + + xm.optimizer_step = patched_optimizer_step + elif is_sagemaker_dp_enabled(): + model = nn.parallel.DistributedDataParallel( + model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))] + ) + elif self.args.parallel_mode == ParallelMode.DISTRIBUTED: + if is_torch_neuroncore_available(): + return model + kwargs = {} + if self.args.ddp_find_unused_parameters is not None: + kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters + elif isinstance(model, PreTrainedModel): + # find_unused_parameters breaks checkpointing as per + # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 + kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing + else: + kwargs["find_unused_parameters"] = True + + if self.args.ddp_bucket_cap_mb is not None: + kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb + + if self.args.ddp_broadcast_buffers is not None: + kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers + + self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs) + + return model + + def train( + self, + resume_from_checkpoint: Optional[Union[str, bool]] = None, + trial: Union["optuna.Trial", Dict[str, Any]] = None, + ignore_keys_for_eval: Optional[List[str]] = None, + **kwargs, + ): + """ + Main training entry point. + + Args: + resume_from_checkpoint (`str` or `bool`, *optional*): + If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a + `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance + of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here. + trial (`optuna.Trial` or `Dict[str, Any]`, *optional*): + The trial run or the hyperparameter dictionary for hyperparameter search. + ignore_keys_for_eval (`List[str]`, *optional*) + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions for evaluation during the training. + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments used to hide deprecated arguments + """ + if resume_from_checkpoint is False: + resume_from_checkpoint = None + + # memory metrics - must set up as early as possible + self._memory_tracker.start() + + args = self.args + + self.is_in_train = True + + # Attach NEFTune hooks if necessary + if self.neftune_noise_alpha is not None: + self.model = self._activate_neftune(self.model) + + # do_train is not a reliable argument, as it might not be set and .train() still called, so + # the following is a workaround: + if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train and not self.is_model_parallel: + self._move_model_to_device(self.model, args.device) + + if "model_path" in kwargs: + resume_from_checkpoint = kwargs.pop("model_path") + warnings.warn( + "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` " + "instead.", + FutureWarning, + ) + if len(kwargs) > 0: + raise TypeError(f"train() got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.") + # This might change the seed so needs to run first. + self._hp_search_setup(trial) + self._train_batch_size = self.args.train_batch_size + + # Model re-init + model_reloaded = False + if self.model_init is not None: + # Seed must be set before instantiating the model when using model_init. + enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) + self.model = self.call_model_init(trial) + model_reloaded = True + # Reinitializes optimizer and scheduler + self.optimizer, self.lr_scheduler = None, None + + # Load potential model checkpoint + if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint: + resume_from_checkpoint = get_last_checkpoint(args.output_dir) + if resume_from_checkpoint is None: + raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})") + + if resume_from_checkpoint is not None: + if not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled and not self.is_fsdp_enabled: + self._load_from_checkpoint(resume_from_checkpoint) + # In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly + state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) + if state.train_batch_size is not None: + self._train_batch_size = state.train_batch_size + + # If model was re-initialized, put it on the right device and update self.model_wrapped + if model_reloaded: + if self.place_model_on_device: + self._move_model_to_device(self.model, args.device) + self.model_wrapped = self.model + + inner_training_loop = find_executable_batch_size( + self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size + ) + if args.push_to_hub: + try: + # Disable progress bars when uploading models during checkpoints to avoid polluting stdout + hf_hub_utils.disable_progress_bars() + return inner_training_loop( + args=args, + resume_from_checkpoint=resume_from_checkpoint, + trial=trial, + ignore_keys_for_eval=ignore_keys_for_eval, + ) + finally: + hf_hub_utils.enable_progress_bars() + else: + return inner_training_loop( + args=args, + resume_from_checkpoint=resume_from_checkpoint, + trial=trial, + ignore_keys_for_eval=ignore_keys_for_eval, + ) + + def _inner_training_loop( + self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None + ): + self.accelerator.free_memory() + self._train_batch_size = batch_size + if self.args.auto_find_batch_size: + if self.state.train_batch_size != self._train_batch_size: + from accelerate.utils import release_memory + + (self.model_wrapped,) = release_memory(self.model_wrapped) + self.model_wrapped = self.model + + # Check for DeepSpeed *after* the intial pass and modify the config + if self.is_deepspeed_enabled: + # Temporarily unset `self.args.train_batch_size` + original_bs = self.args.per_device_train_batch_size + self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu) + self.propagate_args_to_deepspeed(True) + self.args.per_device_train_batch_size = original_bs + self.state.train_batch_size = self._train_batch_size + logger.debug(f"Currently training with a batch size of: {self._train_batch_size}") + # Data loader and number of training steps + train_dataloader = self.get_train_dataloader() + if self.is_fsdp_xla_v2_enabled: + train_dataloader = tpu_spmd_dataloader(train_dataloader) + + # Setting up training control variables: + # number of training epochs: num_train_epochs + # number of training steps per epoch: num_update_steps_per_epoch + # total number of training steps to execute: max_steps + total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size + + len_dataloader = None + num_train_tokens = None + if has_length(train_dataloader): + len_dataloader = len(train_dataloader) + num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps + num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) + num_examples = self.num_examples(train_dataloader) + if args.max_steps > 0: + max_steps = args.max_steps + num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( + args.max_steps % num_update_steps_per_epoch > 0 + ) + # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's + # the best we can do. + num_train_samples = args.max_steps * total_train_batch_size + if args.include_tokens_per_second: + num_train_tokens = ( + self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps + ) + else: + max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) + num_train_epochs = math.ceil(args.num_train_epochs) + num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs + if args.include_tokens_per_second: + num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs + elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size + max_steps = args.max_steps + # Setting a very large number of epochs so we go as many times as necessary over the iterator. + num_train_epochs = sys.maxsize + num_update_steps_per_epoch = max_steps + num_examples = total_train_batch_size * args.max_steps + num_train_samples = args.max_steps * total_train_batch_size + if args.include_tokens_per_second: + num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps + else: + raise ValueError( + "args.max_steps must be set to a positive value if dataloader does not have a length, was" + f" {args.max_steps}" + ) + + if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: + if self.args.n_gpu > 1: + # nn.DataParallel(model) replicates the model, creating new variables and module + # references registered here no longer work on other gpus, breaking the module + raise ValueError( + "Currently --debug underflow_overflow is not supported under DP. Please use DDP" + " (torchrun or torch.distributed.launch (deprecated))." + ) + else: + debug_overflow = DebugUnderflowOverflow(self.model) # noqa + + delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled + + # We need to reset the scheduler, as its parameters may be different on subsequent calls + if self._created_lr_scheduler: + self.lr_scheduler = None + self._created_lr_scheduler = False + + if self.is_deepspeed_enabled: + self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps) + + if not delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + self.state = TrainerState( + stateful_callbacks=[ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ] + ) + self.state.is_hyper_param_search = trial is not None + self.state.train_batch_size = self._train_batch_size + + # Compute absolute values for logging, eval, and save if given as ratio + if args.logging_steps is not None: + if args.logging_steps < 1: + self.state.logging_steps = math.ceil(max_steps * args.logging_steps) + else: + self.state.logging_steps = args.logging_steps + if args.eval_steps is not None: + if args.eval_steps < 1: + self.state.eval_steps = math.ceil(max_steps * args.eval_steps) + else: + self.state.eval_steps = args.eval_steps + if args.save_steps is not None: + if args.save_steps < 1: + self.state.save_steps = math.ceil(max_steps * args.save_steps) + else: + self.state.save_steps = args.save_steps + + # Activate gradient checkpointing if needed + if args.gradient_checkpointing: + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs) + + model = self._wrap_model(self.model_wrapped) + + # as the model is wrapped, don't use `accelerator.prepare` + # this is for unhandled cases such as + # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX + use_accelerator_prepare = True if model is self.model else False + + if use_accelerator_prepare and self.is_fsdp_enabled: + # In case of auto_find_batch_size=True + # Remove FSDP wrapping from sub-models. + self.model = unwrap_model(self.model, recursive=True) + + if delay_optimizer_creation: + if use_accelerator_prepare: + # configure fsdp plugin for qlora if any + self._fsdp_qlora_plugin_updates() + if self.accelerator.mixed_precision != "fp8": + self.model = self.accelerator.prepare(self.model) + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + # prepare using `accelerator` prepare + if use_accelerator_prepare: + self.model.train() + if hasattr(self.lr_scheduler, "step"): + if self.use_apex: + model = self.accelerator.prepare(self.model) + else: + model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) + else: + # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config. + model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( + self.model, self.optimizer, self.lr_scheduler + ) + elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + # In this case we are in DDP + LOMO, which should be supported + self.optimizer = self.accelerator.prepare(self.optimizer) + + if self.is_fsdp_enabled: + self.model = self.model_wrapped = model + + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model_wrapped + + # ckpt loading + if resume_from_checkpoint is not None: + if self.is_deepspeed_enabled: + deepspeed_load_checkpoint( + self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model) + ) + elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled: + self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped) + + # Check if saved optimizer or scheduler states exist + self._load_optimizer_and_scheduler(resume_from_checkpoint) + + # important: at this point: + # self.model is the Transformers Model + # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), + # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc. + + # Train! + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_examples:,}") + logger.info(f" Num Epochs = {num_train_epochs:,}") + logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") + if self.args.per_device_train_batch_size != self._train_batch_size: + logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_steps:,}") + logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") + + self.state.epoch = 0 + start_time = time.time() + epochs_trained = 0 + steps_trained_in_current_epoch = 0 + steps_trained_progress_bar = None + + # Check if continuing training from a checkpoint + if resume_from_checkpoint is not None and os.path.isfile( + os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) + ): + self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) + self.compare_trainer_and_checkpoint_args(self.args, self.state) + self._load_callback_state() + epochs_trained = int(self.state.global_step // num_update_steps_per_epoch) + if not args.ignore_data_skip: + steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) + steps_trained_in_current_epoch *= args.gradient_accumulation_steps + else: + steps_trained_in_current_epoch = 0 + + logger.info(" Continuing training from checkpoint, will skip to saved global_step") + logger.info(f" Continuing training from epoch {epochs_trained}") + logger.info(f" Continuing training from global step {self.state.global_step}") + if not args.ignore_data_skip: + logger.info( + f" Will skip the first {epochs_trained} epochs then the first" + f" {steps_trained_in_current_epoch} batches in the first epoch." + ) + + # Update the references + self.callback_handler.model = self.model + self.callback_handler.optimizer = self.optimizer + self.callback_handler.lr_scheduler = self.lr_scheduler + self.callback_handler.train_dataloader = train_dataloader + if self.hp_name is not None and self._trial is not None: + # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial + # parameter to Train when using DDP. + self.state.trial_name = self.hp_name(self._trial) + if trial is not None: + assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial + self.state.trial_params = hp_params(assignments) + else: + self.state.trial_params = None + # This should be the same if the state has been saved but in case the training arguments changed, it's safer + # to set this after the load. + self.state.max_steps = max_steps + self.state.num_train_epochs = num_train_epochs + self.state.is_local_process_zero = self.is_local_process_zero() + self.state.is_world_process_zero = self.is_world_process_zero() + + # tr_loss is a tensor to avoid synchronization of TPUs through .item() + tr_loss = torch.tensor(0.0).to(args.device) + # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses + self._total_loss_scalar = 0.0 + self._globalstep_last_logged = self.state.global_step + model.zero_grad() + grad_norm: Optional[float] = None + self.control = self.callback_handler.on_train_begin(args, self.state, self.control) + + if args.eval_on_start: + self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True) + + for epoch in range(epochs_trained, num_train_epochs): + epoch_dataloader = train_dataloader + if hasattr(epoch_dataloader, "set_epoch"): + epoch_dataloader.set_epoch(epoch) + + # Reset the past mems state at the beginning of each epoch if necessary. + if args.past_index >= 0: + self._past = None + + steps_in_epoch = ( + len(epoch_dataloader) + if len_dataloader is not None + else args.max_steps * args.gradient_accumulation_steps + ) + self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) + + if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + + rng_to_sync = False + steps_skipped = 0 + if steps_trained_in_current_epoch > 0: + epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch) + steps_skipped = steps_trained_in_current_epoch + steps_trained_in_current_epoch = 0 + rng_to_sync = True + + step = -1 + epoch_iterator = iter(epoch_dataloader) + # We chunkify the epoch iterator into gradient accumulation steps `n` batches + remainder = num_examples % args.gradient_accumulation_steps + if remainder == 0: + remainder = args.gradient_accumulation_steps + update_step = -1 + total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1 + for _ in range(total_updates): + update_step += 1 + num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder + batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches) + for i, inputs in enumerate(batch_samples): + step += 1 + do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch + # Since we perform prefetching, we need to manually set sync_gradients + if not do_sync_step: + self.accelerator.gradient_state._set_sync_gradients(False) + else: + self.accelerator.gradient_state._set_sync_gradients(True) + + if self.args.include_num_input_tokens_seen: + main_input_name = getattr(self.model, "main_input_name", "input_ids") + if main_input_name not in inputs: + logger.warning( + "Tried to track the number of tokens seen, however the current model is " + "not configured properly to know what item is the input. To fix this, add " + "a `main_input_name` attribute to the model class you are using." + ) + else: + input_tokens = inputs[main_input_name].numel() + input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64) + self.state.num_input_tokens_seen += ( + self.accelerator.gather(input_tokens).sum().cpu().item() + ) + if rng_to_sync: + self._load_rng_state(resume_from_checkpoint) + rng_to_sync = False + + # Skip past any already trained steps if resuming training + if steps_trained_in_current_epoch > 0: + steps_trained_in_current_epoch -= 1 + if steps_trained_progress_bar is not None: + steps_trained_progress_bar.update(1) + if steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + continue + elif steps_trained_progress_bar is not None: + steps_trained_progress_bar.close() + steps_trained_progress_bar = None + + if step % args.gradient_accumulation_steps == 0: + self.control = self.callback_handler.on_step_begin(args, self.state, self.control) + + # We explicitly want to avoid relying on `accelerator.accumulate` for generation training + context = ( + functools.partial(self.accelerator.no_sync, model=model) + if i != len(batch_samples) - 1 + and self.accelerator.distributed_type != DistributedType.DEEPSPEED + else contextlib.nullcontext + ) + with context(): + tr_loss_step = self.training_step(model, inputs, num_items_in_batch) + + if ( + args.logging_nan_inf_filter + and not is_torch_xla_available() + and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) + ): + # if loss is nan or inf simply add the average of previous logged losses + tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) + else: + if tr_loss.device != tr_loss_step.device: + raise ValueError( + f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}" + ) + tr_loss = tr_loss + tr_loss_step + + self.current_flos += float(self.floating_point_ops(inputs)) + + if do_sync_step: + # Since we perform prefetching, we need to manually set sync_gradients to True + self.accelerator.gradient_state._set_sync_gradients(True) + + # Gradient clipping + if args.max_grad_norm is not None and args.max_grad_norm > 0: + # deepspeed does its own clipping + + if is_sagemaker_mp_enabled() and args.fp16: + _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm) + elif self.use_apex: + # Revert to normal clipping otherwise, handling Apex or full precision + _grad_norm = nn.utils.clip_grad_norm_( + amp.master_params(self.optimizer), + args.max_grad_norm, + ) + else: + _grad_norm = self.accelerator.clip_grad_norm_( + model.parameters(), + args.max_grad_norm, + ) + + if ( + is_accelerate_available() + and self.accelerator.distributed_type == DistributedType.DEEPSPEED + ): + grad_norm = model.get_global_grad_norm() + # In some cases the grad norm may not return a float + if hasattr(grad_norm, "item"): + grad_norm = grad_norm.item() + else: + grad_norm = _grad_norm + + self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control) + + self.optimizer.step() + + self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control) + + optimizer_was_run = not self.accelerator.optimizer_step_was_skipped + if optimizer_was_run: + # Delay optimizer scheduling until metrics are generated + if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step() + + model.zero_grad() + self.state.global_step += 1 + self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + self._maybe_log_save_evaluate( + tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time + ) + else: + self.control = self.callback_handler.on_substep_end(args, self.state, self.control) + + # PyTorch/XLA relies on the data loader to insert the mark_step for + # each step. Since we are breaking the loop early, we need to manually + # insert the mark_step here. + if self.control.should_epoch_stop or self.control.should_training_stop: + if is_torch_xla_available(): + xm.mark_step() + break + # We also need to break out of the nested loop + if self.control.should_epoch_stop or self.control.should_training_stop: + if is_torch_xla_available(): + xm.mark_step() + break + if step < 0: + logger.warning( + "There seems not to be a single sample in your epoch_iterator, stopping training at step" + f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" + f" num_steps ({max_steps}) higher than the number of available samples." + ) + self.control.should_training_stop = True + + self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) + self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time) + + if DebugOption.TPU_METRICS_DEBUG in self.args.debug: + if is_torch_xla_available(): + # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + xm.master_print(met.metrics_report()) + else: + logger.warning( + "You enabled PyTorch/XLA debug metrics but you don't have a TPU " + "configured. Check your training configuration if this is unexpected." + ) + if self.control.should_training_stop: + break + + if args.past_index and hasattr(self, "_past"): + # Clean the state at the end of training + delattr(self, "_past") + + logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") + if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: + # Wait for everyone to get here so we are sure the model has been saved by process 0. + if is_torch_xla_available(): + xm.rendezvous("load_best_model_at_end") + elif args.parallel_mode == ParallelMode.DISTRIBUTED: + dist.barrier() + elif is_sagemaker_mp_enabled(): + smp.barrier() + + self._load_best_model() + + # add remaining tr_loss + self._total_loss_scalar += tr_loss.item() + effective_global_step = max(self.state.global_step, 0.001) # Avoid ZeroDivisionError + train_loss = self._total_loss_scalar / effective_global_step + + metrics = speed_metrics( + "train", + start_time, + num_samples=num_train_samples, + num_steps=self.state.max_steps, + num_tokens=num_train_tokens, + ) + self.store_flos() + metrics["total_flos"] = self.state.total_flos + metrics["train_loss"] = train_loss + + self.is_in_train = False + + self._memory_tracker.stop_and_update_metrics(metrics) + + self.log(metrics) + + run_dir = self._get_output_dir(trial) + checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir) + + # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save. + if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: + for checkpoint in checkpoints_sorted: + if not os.path.samefile(checkpoint, self.state.best_model_checkpoint): + logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") + shutil.rmtree(checkpoint, ignore_errors=True) + + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + + # Wait for the checkpoint to be uploaded. + self._finish_current_push() + + # After training we make sure to retrieve back the original forward pass method + # for the embedding layer by removing the forward post hook. + if self.neftune_noise_alpha is not None: + self._deactivate_neftune(self.model) + + return TrainOutput(self.state.global_step, train_loss, metrics) + + def _get_output_dir(self, trial): + if self.hp_search_backend is not None and trial is not None: + if self.hp_search_backend == HPSearchBackend.OPTUNA: + run_id = trial.number + elif self.hp_search_backend == HPSearchBackend.RAY: + import ray.train + + run_id = ray.train.get_context().get_trial_id() + elif self.hp_search_backend == HPSearchBackend.SIGOPT: + run_id = trial.id + elif self.hp_search_backend == HPSearchBackend.WANDB: + import wandb + + run_id = wandb.run.id + run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}" + run_dir = os.path.join(self.args.output_dir, run_name) + else: + run_dir = self.args.output_dir + return run_dir + + def _load_from_checkpoint(self, resume_from_checkpoint, model=None): + if model is None: + model = self.model + + config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME) + adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME) + adapter_safe_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME) + weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME) + weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME) + safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME) + safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME) + is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and ( + # this checks the FSDP state dict when `SHARDED_STATE_DICT` is used + any( + FSDP_MODEL_NAME in folder_name + for folder_name in os.listdir(resume_from_checkpoint) + if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name)) + ) + # this checks the FSDP state dict when `FULL_STATE_DICT` is used + or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin")) + ) + # if multiple adapters exist, they get saved in sub directories + adapter_subdirs = ( + [ + folder_name + for folder_name in os.listdir(resume_from_checkpoint) + if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name)) + and ( + os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_WEIGHTS_NAME)) + or os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_SAFE_WEIGHTS_NAME)) + ) + ] + if os.path.isdir(resume_from_checkpoint) + else [] + ) + + if is_fsdp_ckpt and not self.is_fsdp_enabled: + raise ValueError(f"Checkpoint found at {resume_from_checkpoint} is only supported when using PyTorch FSDP") + + if not ( + any( + os.path.isfile(f) + for f in [ + weights_file, + safe_weights_file, + weights_index_file, + safe_weights_index_file, + adapter_weights_file, + adapter_safe_weights_file, + ] + ) + or is_fsdp_ckpt + or adapter_subdirs + ): + raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}") + + logger.info(f"Loading model from {resume_from_checkpoint}.") + + if os.path.isfile(config_file): + config = PretrainedConfig.from_json_file(config_file) + checkpoint_version = config.transformers_version + if checkpoint_version is not None and checkpoint_version != __version__: + logger.warning( + f"You are resuming training from a checkpoint trained with {checkpoint_version} of " + f"Transformers but your current version is {__version__}. This is not recommended and could " + "yield to errors or unwanted behaviors." + ) + + if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt: + weights_only_kwarg = {"weights_only": True} + # If the model is on the GPU, it still works! + if is_sagemaker_mp_enabled(): + if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")): + # If the 'user_content.pt' file exists, load with the new smp api. + # Checkpoint must have been saved with the new smp api. + smp.resume_from_checkpoint( + path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False + ) + else: + # If the 'user_content.pt' file does NOT exist, load with the old smp api. + # Checkpoint must have been saved with the old smp api. + if hasattr(self.args, "fp16") and self.args.fp16 is True: + logger.warning( + "Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported." + ) + state_dict = torch.load( + weights_file, + map_location="cpu", + **weights_only_kwarg, + ) + # Required for smp to not auto-translate state_dict from hf to smp (is already smp). + state_dict["_smp_is_partial"] = False + load_result = model.load_state_dict(state_dict, strict=True) + # release memory + del state_dict + elif self.is_fsdp_enabled: + load_fsdp_model( + self.accelerator.state.fsdp_plugin, + self.accelerator, + model, + resume_from_checkpoint, + **_get_fsdp_ckpt_kwargs(), + ) + else: + # We load the model state dict on the CPU to avoid an OOM error. + if self.args.save_safetensors and os.path.isfile(safe_weights_file): + state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu") + else: + state_dict = torch.load( + weights_file, + map_location="cpu", + **weights_only_kwarg, + ) + + # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963 + # which takes *args instead of **kwargs + load_result = model.load_state_dict(state_dict, False) + # release memory + del state_dict + self._issue_warnings_after_load(load_result) + + # Load adapters following PR # 24096 + elif _is_peft_model(model): + # If train a model using PEFT & LoRA, assume that adapter have been saved properly. + # TODO: in the future support only specific min PEFT versions + if (hasattr(model, "active_adapter") or hasattr(model, "active_adapters")) and hasattr( + model, "load_adapter" + ): + if os.path.exists(resume_from_checkpoint): + # For BC for older PEFT versions + if hasattr(model, "active_adapters"): + active_adapters = model.active_adapters + if len(active_adapters) > 1: + logger.warning("Multiple active adapters detected will only consider the first adapter") + active_adapter = active_adapters[0] + else: + active_adapter = model.active_adapter + + if adapter_subdirs: + for subdir_name in adapter_subdirs: + peft_id = os.path.join(resume_from_checkpoint, subdir_name) + model.load_adapter(peft_id, subdir_name, is_trainable=(subdir_name == active_adapter)) + model.set_adapter(active_adapter) + else: + model.load_adapter(resume_from_checkpoint, active_adapter, is_trainable=True) + else: + logger.warning( + "The intermediate checkpoints of PEFT may not be saved correctly, " + f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. " + "Check some examples here: https://github.com/huggingface/peft/issues/96" + ) + else: + logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed") + else: + # We load the sharded checkpoint + load_result = load_sharded_checkpoint( + model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled(), prefer_safe=self.args.save_safetensors + ) + if not is_sagemaker_mp_enabled(): + self._issue_warnings_after_load(load_result) + + def _load_best_model(self): + logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).") + best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME) + best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME) + best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME) + best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME) + + model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model + if self.is_deepspeed_enabled: + deepspeed_load_checkpoint( + self.model_wrapped, + self.state.best_model_checkpoint, + load_module_strict=not _is_peft_model(self.model), + ) + elif self.is_fsdp_enabled: + load_result = load_fsdp_model( + self.accelerator.state.fsdp_plugin, + self.accelerator, + model, + self.state.best_model_checkpoint, + **_get_fsdp_ckpt_kwargs(), + ) + elif ( + os.path.exists(best_model_path) + or os.path.exists(best_safe_model_path) + or os.path.exists(best_adapter_model_path) + or os.path.exists(best_safe_adapter_model_path) + ): + has_been_loaded = True + weights_only_kwarg = {"weights_only": True} + if is_sagemaker_mp_enabled(): + if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")): + # If the 'user_content.pt' file exists, load with the new smp api. + # Checkpoint must have been saved with the new smp api. + smp.resume_from_checkpoint( + path=self.state.best_model_checkpoint, + tag=WEIGHTS_NAME, + partial=False, + load_optimizer=False, + ) + else: + # If the 'user_content.pt' file does NOT exist, load with the old smp api. + # Checkpoint must have been saved with the old smp api. + if self.args.save_safetensors and os.path.isfile(best_safe_model_path): + state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu") + else: + state_dict = torch.load( + best_model_path, + map_location="cpu", + **weights_only_kwarg, + ) + + state_dict["_smp_is_partial"] = False + load_result = model.load_state_dict(state_dict, strict=True) + else: + if _is_peft_model(model): + # If train a model using PEFT & LoRA, assume that adapter have been saved properly. + # TODO: in the future support only specific min PEFT versions + if (hasattr(model, "active_adapter") or hasattr(model, "active_adapters")) and hasattr( + model, "load_adapter" + ): + # For BC for older PEFT versions + if hasattr(model, "active_adapters"): + active_adapter = model.active_adapters[0] + if len(model.active_adapters) > 1: + logger.warning("Detected multiple active adapters, will only consider the first one") + else: + active_adapter = model.active_adapter + + if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path): + try: + model.load_adapter(self.state.best_model_checkpoint, active_adapter) + except RuntimeError as exc: + if model.peft_config[active_adapter].is_prompt_learning: + # for context: https://github.com/huggingface/peft/issues/2256 + msg = ( + "When using prompt learning PEFT methods such as " + f"{model.peft_config[active_adapter].peft_type.value}, setting " + "load_best_model_at_end=True can lead to errors, it is recommended " + "to set this to False and to load the model manually from the checkpoint " + "directory using PeftModel.from_pretrained(base_model, ) after training " + "has finished." + ) + raise RuntimeError(msg) from exc + else: + raise + # Load_adapter has no return value present, modify it when appropriate. + from torch.nn.modules.module import _IncompatibleKeys + + load_result = _IncompatibleKeys([], []) + else: + logger.warning( + "The intermediate checkpoints of PEFT may not be saved correctly, " + f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. " + "Check some examples here: https://github.com/huggingface/peft/issues/96" + ) + has_been_loaded = False + else: + logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed") + has_been_loaded = False + else: + # We load the model state dict on the CPU to avoid an OOM error. + if self.args.save_safetensors and os.path.isfile(best_safe_model_path): + state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu") + else: + state_dict = torch.load( + best_model_path, + map_location="cpu", + **weights_only_kwarg, + ) + + # If the model is on the GPU, it still works! + # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963 + # which takes *args instead of **kwargs + load_result = model.load_state_dict(state_dict, False) + if not is_sagemaker_mp_enabled() and has_been_loaded: + self._issue_warnings_after_load(load_result) + elif os.path.exists(os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_INDEX_NAME)) or os.path.exists( + os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME) + ): + load_result = load_sharded_checkpoint( + model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled() + ) + if not is_sagemaker_mp_enabled(): + self._issue_warnings_after_load(load_result) + else: + logger.warning( + f"Could not locate the best model at {best_model_path}, if you are running a distributed training " + "on multiple nodes, you should activate `--save_on_each_node`." + ) + + def _issue_warnings_after_load(self, load_result): + if len(load_result.missing_keys) != 0: + if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set( + self.model._keys_to_ignore_on_save + ): + self.model.tie_weights() + else: + logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.") + if len(load_result.unexpected_keys) != 0: + logger.warning( + f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}." + ) + + def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False): + metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) + self._report_to_hp_search(trial, self.state.global_step, metrics) + + # Run delayed LR scheduler now that metrics are populated + if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) and not skip_scheduler: + metric_to_check = self.args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + try: + self.lr_scheduler.step(metrics[metric_to_check]) + except KeyError as exc: + raise KeyError( + f"The `metric_for_best_model` training argument is set to '{metric_to_check}', " + f"which is not found in the evaluation metrics. " + f"The available evaluation metrics are: {list(metrics.keys())}. " + f"Please ensure that the `compute_metrics` function returns a dictionary that includes '{metric_to_check}' or " + f"consider changing the `metric_for_best_model` via the TrainingArguments." + ) from exc + return metrics + + def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time): + if self.control.should_log and self.state.global_step > self._globalstep_last_logged: + if is_torch_xla_available(): + xm.mark_step() + + logs: Dict[str, float] = {} + + # all_gather + mean() to get average loss over all processes + tr_loss_scalar = self._nested_gather(tr_loss).mean().item() + + # reset tr_loss to zero + tr_loss -= tr_loss + + logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) + if grad_norm is not None: + logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm + logs["learning_rate"] = self._get_learning_rate() + + self._total_loss_scalar += tr_loss_scalar + self._globalstep_last_logged = self.state.global_step + self.store_flos() + + self.log(logs, start_time) + + metrics = None + if self.control.should_evaluate: + metrics = self._evaluate(trial, ignore_keys_for_eval) + is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial) + + if self.args.save_strategy == SaveStrategy.BEST: + self.control.should_save = is_new_best_metric + + if self.control.should_save: + self._save_checkpoint(model, trial) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + + def _load_rng_state(self, checkpoint): + # Load RNG states from `checkpoint` + if checkpoint is None: + return + + if self.args.world_size > 1: + process_index = self.args.process_index + rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth") + if not os.path.isfile(rng_file): + logger.info( + f"Didn't find an RNG file for process {process_index}, if you are resuming a training that " + "wasn't launched in a distributed fashion, reproducibility is not guaranteed." + ) + return + else: + rng_file = os.path.join(checkpoint, "rng_state.pth") + if not os.path.isfile(rng_file): + logger.info( + "Didn't find an RNG file, if you are resuming a training that was launched in a distributed " + "fashion, reproducibility is not guaranteed." + ) + return + + with safe_globals(): + checkpoint_rng_state = torch.load(rng_file) + random.setstate(checkpoint_rng_state["python"]) + np.random.set_state(checkpoint_rng_state["numpy"]) + torch.random.set_rng_state(checkpoint_rng_state["cpu"]) + if torch.cuda.is_available(): + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"]) + else: + try: + torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"]) + except Exception as e: + logger.info( + f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}" + "\nThis won't yield the same results as if the training had not been interrupted." + ) + if is_torch_xla_available(): + xm.set_rng_state(checkpoint_rng_state["xla"]) + if is_torch_npu_available(): + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + torch.npu.random.set_rng_state_all(checkpoint_rng_state["npu"]) + else: + try: + torch.npu.random.set_rng_state(checkpoint_rng_state["npu"]) + except Exception as e: + logger.info( + f"Didn't manage to set back the RNG states of the NPU because of the following error:\n {e}" + "\nThis won't yield the same results as if the training had not been interrupted." + ) + if is_torch_mlu_available(): + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + torch.mlu.random.set_rng_state_all(checkpoint_rng_state["mlu"]) + else: + try: + torch.mlu.random.set_rng_state(checkpoint_rng_state["mlu"]) + except Exception as e: + logger.info( + f"Didn't manage to set back the RNG states of the MLU because of the following error:\n {e}" + "\nThis won't yield the same results as if the training had not been interrupted." + ) + if is_torch_musa_available(): + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + torch.musa.set_rng_state_all(checkpoint_rng_state["musa"]) + else: + try: + torch.musa.set_rng_state(checkpoint_rng_state["musa"]) + except Exception as e: + logger.info( + f"Didn't manage to set back the RNG states of the MUSA because of the following error:\n {e}" + "\nThis won't yield the same results as if the training had not been interrupted." + ) + + def _determine_best_metric(self, metrics, trial): + """ + Determine if the model should be saved based on the evaluation metrics. + + Returns: + bool: True if a new best metric was found, else False + """ + is_new_best_metric = False + + if self.args.metric_for_best_model is not None: + metric_to_check = self.args.metric_for_best_model + + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + + try: + metric_value = metrics[metric_to_check] + except KeyError as exc: + raise KeyError( + f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. " + f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments." + ) from exc + + operator = np.greater if self.args.greater_is_better else np.less + + if self.state.best_metric is None: + self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf") + + if operator(metric_value, self.state.best_metric): + run_dir = self._get_output_dir(trial=trial) + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + output_dir = os.path.join(run_dir, checkpoint_folder) + + self.state.best_metric = metric_value + self.state.best_model_checkpoint = output_dir + + is_new_best_metric = True + + return is_new_best_metric + + def _save_checkpoint(self, model, trial): + # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we + # want to save except FullyShardedDDP. + # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" + + # Save model checkpoint + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + + if self.hp_search_backend is None and trial is None: + self.store_flos() + + run_dir = self._get_output_dir(trial=trial) + output_dir = os.path.join(run_dir, checkpoint_folder) + self.save_model(output_dir, _internal_call=True) + + if not self.args.save_only_model: + # Save optimizer and scheduler + self._save_optimizer_and_scheduler(output_dir) + # Save RNG state + self._save_rng_state(output_dir) + + # Save the Trainer state + if self.args.should_save: + # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently + for cb in [ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ]: + cb_name = cb.__class__.__name__ + cb_state = cb.state() + if isinstance(self.state.stateful_callbacks[cb_name], list): + self.state.stateful_callbacks[cb_name].append(cb_state) + else: + self.state.stateful_callbacks[cb_name] = cb_state + self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) + + if self.args.push_to_hub: + self._push_from_checkpoint(output_dir) + + # Maybe delete some older checkpoints. + if self.args.should_save: + # Solely rely on numerical checkpoint id for rotation. + # mtime is not reliable especially on some fuse fs in cloud environments. + self._rotate_checkpoints(use_mtime=False, output_dir=run_dir) + + def _save_rng_state(self, output_dir): + # Save RNG state in non-distributed training + rng_states = { + "python": random.getstate(), + "numpy": np.random.get_state(), + "cpu": torch.random.get_rng_state(), + } + if torch.cuda.is_available(): + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + # In non distributed, we save the global CUDA RNG state (will take care of DataParallel) + rng_states["cuda"] = torch.cuda.random.get_rng_state_all() + else: + rng_states["cuda"] = torch.cuda.random.get_rng_state() + + if is_torch_xla_available(): + rng_states["xla"] = xm.get_rng_state() + + if is_torch_npu_available(): + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + rng_states["npu"] = torch.npu.random.get_rng_state_all() + else: + rng_states["npu"] = torch.npu.random.get_rng_state() + + if is_torch_mlu_available(): + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + rng_states["mlu"] = torch.mlu.random.get_rng_state_all() + else: + rng_states["mlu"] = torch.mlu.random.get_rng_state() + + if is_torch_musa_available(): + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + rng_states["musa"] = torch.musa.get_rng_state_all() + else: + rng_states["musa"] = torch.musa.get_rng_state() + + # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may + # not yet exist. + os.makedirs(output_dir, exist_ok=True) + + if self.args.world_size <= 1: + torch.save(rng_states, os.path.join(output_dir, "rng_state.pth")) + else: + torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth")) + + def _save_optimizer_and_scheduler(self, output_dir): + if is_torch_xla_available(): + xm.rendezvous("saving_optimizer_states") + if self.is_fsdp_xla_v1_enabled: + optm = { + "optimizer": self.optimizer.state_dict(), + "shard_metadata": self.model.get_shard_metadata(), + } + xm.save( + optm, + os.path.join( + output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}" + ), + master_only=False, + ) + else: + xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + with warnings.catch_warnings(record=True) as caught_warnings: + xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + reissue_pt_warnings(caught_warnings) + elif is_sagemaker_mp_enabled(): + opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False) + smp.barrier() + if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state: + smp.save( + opt_state_dict, + os.path.join(output_dir, OPTIMIZER_NAME), + partial=True, + v3=smp.state.cfg.shard_optimizer_state, + ) + elif self.is_deepspeed_enabled: + # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed + # config `stage3_gather_16bit_weights_on_model_save` is True + accept_exclude_frozen_parameters = "exclude_frozen_parameters" in set( + inspect.signature(self.model_wrapped.save_checkpoint).parameters.keys() + ) + if accept_exclude_frozen_parameters and _is_peft_model(self.model): + self.model_wrapped.save_checkpoint(output_dir, exclude_frozen_parameters=True) + else: + self.model_wrapped.save_checkpoint(output_dir) + elif self.is_fsdp_enabled: + # save fsdp specific ckpt for resuming from ckpt + save_fsdp_model( + self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir, **_get_fsdp_ckpt_kwargs() + ) + save_fsdp_optimizer( + self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir + ) + elif self.args.should_save: + # deepspeed.save_checkpoint above saves model/optim/sched + torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + + # Save SCHEDULER & SCALER + is_deepspeed_custom_scheduler = self.is_deepspeed_enabled and not isinstance( + self.lr_scheduler, DeepSpeedSchedulerWrapper + ) + if ( + self.args.should_save + and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler) + and not is_torch_xla_available() + ): + with warnings.catch_warnings(record=True) as caught_warnings: + torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + reissue_pt_warnings(caught_warnings) + + def _load_optimizer_and_scheduler(self, checkpoint): + """If optimizer and scheduler states exist, load them.""" + if checkpoint is None: + return + + if self.is_deepspeed_enabled: + # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init + if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper): + with warnings.catch_warnings(record=True) as caught_warnings: + self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME))) + reissue_pt_warnings(caught_warnings) + return + + checkpoint_file_exists = ( + glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*") + if is_sagemaker_mp_enabled() + else ( + os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) + or os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME_BIN)) + or ( + os.path.isdir(checkpoint) + and any( + OPTIMIZER_NAME_BIN.split(".")[0] in folder_name + for folder_name in os.listdir(checkpoint) + if os.path.isdir(os.path.join(checkpoint, folder_name)) + ) + ) + ) + ) + checkpoint_file_exists = ( + glob.glob(os.path.join(checkpoint, f"rank*-of-{self.args.world_size}-{OPTIMIZER_NAME}")) + if self.is_fsdp_xla_v1_enabled + else checkpoint_file_exists + ) + if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)): + # Load in optimizer and scheduler states + if is_torch_xla_available(): + # On TPU we have to take some extra precautions to properly load the states on the right device. + if self.is_fsdp_xla_v1_enabled: + optimizer_state = torch.load( + os.path.join( + checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}" + ), + map_location="cpu", + ) + # We only need `optimizer` when resuming from checkpoint + optimizer_state = optimizer_state["optimizer"] + else: + optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu") + with warnings.catch_warnings(record=True) as caught_warnings: + lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu") + reissue_pt_warnings(caught_warnings) + + xm.send_cpu_data_to_device(optimizer_state, self.args.device) + xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device) + + self.optimizer.load_state_dict(optimizer_state) + self.lr_scheduler.load_state_dict(lr_scheduler_state) + else: + if is_sagemaker_mp_enabled(): + if os.path.isfile(os.path.join(checkpoint, "user_content.pt")): + # Optimizer checkpoint was saved with smp >= 1.10 + def opt_load_hook(mod, opt): + opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True)) + + else: + # Optimizer checkpoint was saved with smp < 1.10 + def opt_load_hook(mod, opt): + if IS_SAGEMAKER_MP_POST_1_10: + opt.load_state_dict( + smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True) + ) + else: + opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True)) + + self.model_wrapped.register_post_step_hook(opt_load_hook) + else: + # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models. + # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more + # likely to get OOM on CPU (since we load num_gpu times the optimizer state + map_location = self.args.device if self.args.world_size > 1 else "cpu" + if self.is_fsdp_enabled: + load_fsdp_optimizer( + self.accelerator.state.fsdp_plugin, + self.accelerator, + self.optimizer, + self.model, + checkpoint, + **_get_fsdp_ckpt_kwargs(), + ) + else: + self.optimizer.load_state_dict( + torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location) + ) + with warnings.catch_warnings(record=True) as caught_warnings: + self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME))) + reissue_pt_warnings(caught_warnings) + + def _load_callback_state(self): + """If callback states exist and were passed in, restore their states if enabled""" + if not self.args.restore_callback_states_from_checkpoint: + return + # Callback states are stored in stateful_callbacks + not_found = [] + new_callbacks = [] + original_callbacks = self.callback_handler.callbacks + [self.control] + for stored_callback, data in self.state.stateful_callbacks.items(): + if not isinstance(data, list): + data = [data] + if any(callback.__class__.__name__ == stored_callback for callback in original_callbacks): + # We can load/restore from multiple callbacks of the same type. + duplicates = [ + callback for callback in original_callbacks if callback.__class__.__name__ == stored_callback + ] + for callback, callback_data in zip(duplicates, data): + args = callback_data.get("args", {}) + attributes = callback_data.get("attributes", {}) + new_callback = type(callback)(**args) + for attribute, value in attributes.items(): + setattr(new_callback, attribute, value) + if isinstance(callback, TrainerControl): + # Specifically for restoring the `control` state + self.control = new_callback + else: + new_callbacks.append(new_callback) + # We remove the existing callback and add it to the list of new callbacks + self.callback_handler.remove_callback(type(new_callback)) + logger.info("Continuing training from checkpoint, restoring any callbacks that were passed in") + else: + not_found.append(stored_callback) + if len(not_found) > 0: + logger.warning( + f"Checkpoint included callbacks not included in current configuration. Ignoring. ({', '.join(not_found)})" + ) + for callback in new_callbacks: + self.callback_handler.add_callback(callback) + + def hyperparameter_search( + self, + hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None, + compute_objective: Optional[Callable[[Dict[str, float]], float]] = None, + n_trials: int = 20, + direction: Union[str, List[str]] = "minimize", + backend: Optional[Union["str", HPSearchBackend]] = None, + hp_name: Optional[Callable[["optuna.Trial"], str]] = None, + **kwargs, + ) -> Union[BestRun, List[BestRun]]: + """ + Launch an hyperparameter search using `optuna` or `Ray Tune` or `SigOpt`. The optimized quantity is determined + by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided, + the sum of all metrics otherwise. + + + + To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to + reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to + subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom + optimizer/scheduler. + + + + Args: + hp_space (`Callable[["optuna.Trial"], Dict[str, float]]`, *optional*): + A function that defines the hyperparameter search space. Will default to + [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or + [`~trainer_utils.default_hp_space_sigopt`] depending on your backend. + compute_objective (`Callable[[Dict[str, float]], float]`, *optional*): + A function computing the objective to minimize or maximize from the metrics returned by the `evaluate` + method. Will default to [`~trainer_utils.default_compute_objective`]. + n_trials (`int`, *optional*, defaults to 100): + The number of trial runs to test. + direction (`str` or `List[str]`, *optional*, defaults to `"minimize"`): + If it's single objective optimization, direction is `str`, can be `"minimize"` or `"maximize"`, you + should pick `"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or + several metrics. If it's multi objectives optimization, direction is `List[str]`, can be List of + `"minimize"` and `"maximize"`, you should pick `"minimize"` when optimizing the validation loss, + `"maximize"` when optimizing one or several metrics. + backend (`str` or [`~training_utils.HPSearchBackend`], *optional*): + The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending + on which one is installed. If all are installed, will default to optuna. + hp_name (`Callable[["optuna.Trial"], str]]`, *optional*): + A function that defines the trial/run name. Will default to None. + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments for each backend: + + - `optuna`: parameters from + [optuna.study.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html) + and also the parameters `timeout`, `n_jobs` and `gc_after_trial` from + [optuna.study.Study.optimize](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.Study.html#optuna.study.Study.optimize) + - `ray`: parameters from [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run). + If `resources_per_trial` is not set in the `kwargs`, it defaults to 1 CPU core and 1 GPU (if available). + If `progress_reporter` is not set in the `kwargs`, + [ray.tune.CLIReporter](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.CLIReporter.html) is used. + - `sigopt`: the parameter `proxies` from + [sigopt.Connection.set_proxies](https://docs.sigopt.com/support/faq#how-do-i-use-sigopt-with-a-proxy). + + Returns: + [`trainer_utils.BestRun` or `List[trainer_utils.BestRun]`]: All the information about the best run or best + runs for multi-objective optimization. Experiment summary can be found in `run_summary` attribute for Ray + backend. + """ + if backend is None: + backend = default_hp_search_backend() + backend = HPSearchBackend(backend) + backend_obj = ALL_HYPERPARAMETER_SEARCH_BACKENDS[backend]() + backend_obj.ensure_available() + self.hp_search_backend = backend + if self.model_init is None: + raise RuntimeError( + "To use hyperparameter search, you need to pass your model through a model_init function." + ) + + self.hp_space = backend_obj.default_hp_space if hp_space is None else hp_space + self.hp_name = hp_name + self.compute_objective = default_compute_objective if compute_objective is None else compute_objective + + best_run = backend_obj.run(self, n_trials, direction, **kwargs) + + self.hp_search_backend = None + return best_run + + def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: + """ + Log `logs` on the various objects watching training. + + Subclass and override this method to inject custom behavior. + + Args: + logs (`Dict[str, float]`): + The values to log. + start_time (`Optional[float]`): + The start of training. + """ + if self.state.epoch is not None: + logs["epoch"] = self.state.epoch + if self.args.include_num_input_tokens_seen: + logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen + if start_time is not None: + speed_metrics("train", start_time, num_tokens=self.state.num_input_tokens_seen) + + output = {**logs, **{"step": self.state.global_step}} + self.state.log_history.append(output) + self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs) + + def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]: + """ + Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors. + """ + if isinstance(data, Mapping): + return type(data)({k: self._prepare_input(v) for k, v in data.items()}) + elif isinstance(data, (tuple, list)): + return type(data)(self._prepare_input(v) for v in data) + elif isinstance(data, torch.Tensor): + kwargs = {"device": self.args.device} + if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)): + # NLP models inputs are int/uint and those get adjusted to the right dtype of the + # embedding. Other models such as wav2vec2's inputs are already float and thus + # may need special handling to match the dtypes of the model + kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()}) + return data.to(**kwargs) + return data + + def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]: + """ + Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and + handling potential state. + """ + inputs = self._prepare_input(inputs) + if len(inputs) == 0: + raise ValueError( + "The batch received was empty, your model won't be able to train on it. Double-check that your " + f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}." + ) + if self.args.past_index >= 0 and self._past is not None: + inputs["mems"] = self._past + + return inputs + + def compute_loss_context_manager(self): + """ + A helper wrapper to group together context managers. + """ + return self.autocast_smart_context_manager() + + def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True): + """ + A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired + arguments, depending on the situation. + """ + if self.use_cpu_amp: + ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype) + else: + ctx_manager = contextlib.nullcontext() + + return ctx_manager + + def training_step( + self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None + ) -> torch.Tensor: + """ + Perform a training step on a batch of inputs. + + Subclass and override to inject custom behavior. + + Args: + model (`nn.Module`): + The model to train. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + + Return: + `torch.Tensor`: The tensor with training loss on this batch. + """ + model.train() + if hasattr(self.optimizer, "train") and callable(self.optimizer.train): + self.optimizer.train() + + inputs = self._prepare_inputs(inputs) + if is_sagemaker_mp_enabled(): + loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) + return loss_mb.reduce_mean().detach().to(self.args.device) + + with self.compute_loss_context_manager(): + loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) + + del inputs + if ( + self.args.torch_empty_cache_steps is not None + and self.state.global_step % self.args.torch_empty_cache_steps == 0 + ): + if is_torch_xpu_available(): + torch.xpu.empty_cache() + elif is_torch_mlu_available(): + torch.mlu.empty_cache() + elif is_torch_musa_available(): + torch.musa.empty_cache() + elif is_torch_npu_available(): + torch.npu.empty_cache() + elif is_torch_mps_available(min_version="2.0"): + torch.mps.empty_cache() + else: + torch.cuda.empty_cache() + + kwargs = {} + + # For LOMO optimizers you need to explicitly use the learnign rate + if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + kwargs["learning_rate"] = self._get_learning_rate() + + if self.args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel training + + if self.use_apex: + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + else: + # Finally we need to normalize the loss for reporting + if not self.model_accepts_loss_kwargs and self.compute_loss_func is None: + loss = loss / self.args.gradient_accumulation_steps + + self.accelerator.backward(loss, **kwargs) + + return loss.detach() + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + """ + How the loss is computed by Trainer. By default, all models return the loss in the first element. + + Subclass and override for custom behavior. + """ + if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs: + labels = inputs.pop("labels") + else: + labels = None + if self.model_accepts_loss_kwargs: + loss_kwargs = {} + if num_items_in_batch is not None: + loss_kwargs["num_items_in_batch"] = num_items_in_batch + inputs = {**inputs, **loss_kwargs} + outputs = model(**inputs) + # Save past state if it exists + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index] + + if labels is not None: + unwrapped_model = self.accelerator.unwrap_model(model) + if _is_peft_model(unwrapped_model): + model_name = unwrapped_model.base_model.model._get_name() + else: + model_name = unwrapped_model._get_name() + # User-defined compute_loss function + if self.compute_loss_func is not None: + loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch) + elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): + loss = self.label_smoother(outputs, labels, shift_labels=True) + else: + loss = self.label_smoother(outputs, labels) + else: + if isinstance(outputs, dict) and "loss" not in outputs: + raise ValueError( + "The model did not return a loss from the inputs, only the following keys: " + f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." + ) + # We don't use .loss here since the model may return tuples instead of ModelOutput. + loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + + if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs: + loss *= self.accelerator.num_processes + + return (loss, outputs) if return_outputs else loss + + def is_local_process_zero(self) -> bool: + """ + Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several + machines) main process. + """ + return self.args.local_process_index == 0 + + def is_world_process_zero(self) -> bool: + """ + Whether or not this process is the global main process (when training in a distributed fashion on several + machines, this is only going to be `True` for one process). + """ + # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global + # process index. + if is_sagemaker_mp_enabled(): + return smp.rank() == 0 + else: + return self.args.process_index == 0 + + def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): + """ + Will save the model, so you can reload it using `from_pretrained()`. + + Will only save from the main process. + """ + + if output_dir is None: + output_dir = self.args.output_dir + + if is_torch_xla_available(): + self._save_tpu(output_dir) + elif is_sagemaker_mp_enabled(): + # Calling the state_dict needs to be done on the wrapped model and on all processes. + os.makedirs(output_dir, exist_ok=True) + state_dict = self.model_wrapped.state_dict() + if self.args.should_save: + self._save(output_dir, state_dict=state_dict) + if IS_SAGEMAKER_MP_POST_1_10: + # 'user_content.pt' indicates model state_dict saved with smp >= 1.10 + Path(os.path.join(output_dir, "user_content.pt")).touch() + elif self.is_fsdp_enabled: + if ("FULL_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type)) and ( + version.parse(accelerate_version) > version.parse("0.24.1") + ): + state_dict = self.accelerator.get_state_dict(self.model) + if self.args.should_save: + self._save(output_dir, state_dict=state_dict) + elif self.is_deepspeed_enabled: + try: + state_dict = self.accelerator.get_state_dict(self.deepspeed) + if self.args.should_save: + self._save(output_dir, state_dict=state_dict) + except ValueError: + logger.warning( + " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use" + " zero_to_fp32.py to recover weights" + ) + if self.args.should_save: + self._save(output_dir, state_dict={}) + # remove the dummy state_dict + remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]) + self.model_wrapped.save_checkpoint(output_dir) + + elif self.args.should_save: + self._save(output_dir) + + # Push to the Hub when `save_model` is called by the user. + if self.args.push_to_hub and not _internal_call: + self.push_to_hub(commit_message="Model save") + + def _save_tpu(self, output_dir: Optional[str] = None): + output_dir = output_dir if output_dir is not None else self.args.output_dir + + logger.info(f"Saving model checkpoint to {output_dir}") + model = self.model + xm.mark_step() + + if xm.is_master_ordinal(local=False): + os.makedirs(output_dir, exist_ok=True) + torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + + # Save a trained model and configuration using `save_pretrained()`. + # They can then be reloaded using `from_pretrained()` + supported_classes = (PushToHubMixin,) + xm.rendezvous("saving_checkpoint") + if self.is_fsdp_xla_v1_enabled: + ckpt = { + "model": model.state_dict(), + "shard_metadata": model.get_shard_metadata(), + } + ckpt_path = os.path.join( + output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{WEIGHTS_NAME}" + ) + # All ranks save sharded checkpoint + xm.save(ckpt, ckpt_path, master_only=False) + # Make sure all ranks have saved checkpoints + xm.rendezvous("save_full_checkpoints") + # Master save full checkpoint + if self.args.should_save: + from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints + + full_state_dict, _ = consolidate_sharded_model_checkpoints( + ckpt_prefix=os.path.join(output_dir, ""), + ckpt_suffix=f"rank*-of-*-{WEIGHTS_NAME}", + save_model=False, + ) + model = model.module.module + unwrapped_model = self.accelerator.unwrap_model(model) + if isinstance(unwrapped_model, supported_classes): + unwrapped_model.save_pretrained( + output_dir, + state_dict=full_state_dict, + save_function=xm.save, + safe_serialization=self.args.save_safetensors, + ) + else: + logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") + xm.save(full_state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + elif not isinstance(model, supported_classes): + if isinstance(self.accelerator.unwrap_model(model), supported_classes): + self.accelerator.unwrap_model(model).save_pretrained( + output_dir, + is_main_process=self.args.should_save, + state_dict=xm._maybe_convert_to_cpu(model.state_dict()), + save_function=xm.save, + safe_serialization=self.args.save_safetensors, + ) + else: + logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") + state_dict = xm._maybe_convert_to_cpu(model.state_dict()) + xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + else: + model.save_pretrained( + output_dir, + is_main_process=self.args.should_save, + save_function=xm.save, + safe_serialization=self.args.save_safetensors, + state_dict=xm._maybe_convert_to_cpu(model.state_dict()), + ) + if self.processing_class is not None and self.args.should_save: + self.processing_class.save_pretrained(output_dir) + + def _save(self, output_dir: Optional[str] = None, state_dict=None): + # If we are executing this function, we are the process zero, so we don't check for that. + output_dir = output_dir if output_dir is not None else self.args.output_dir + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Saving model checkpoint to {output_dir}") + + supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel) + # Save a trained model and configuration using `save_pretrained()`. + # They can then be reloaded using `from_pretrained()` + if not isinstance(self.model, supported_classes): + if state_dict is None: + state_dict = self.model.state_dict() + + if isinstance(self.accelerator.unwrap_model(self.model), supported_classes): + self.accelerator.unwrap_model(self.model).save_pretrained( + output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors + ) + else: + logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") + if self.args.save_safetensors: + safetensors.torch.save_file( + state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"} + ) + else: + torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + else: + self.model.save_pretrained( + output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors + ) + + if self.processing_class is not None: + self.processing_class.save_pretrained(output_dir) + + # Good practice: save your training arguments together with the trained model + torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + + def store_flos(self): + # Storing the number of floating-point operations that went into the model + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + self.state.total_flos += ( + distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item() + ) + self.current_flos = 0 + else: + self.state.total_flos += self.current_flos + self.current_flos = 0 + + def _sorted_checkpoints( + self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False + ) -> List[str]: + ordering_and_checkpoint_path = [] + + glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)] + + for path in glob_checkpoints: + if use_mtime: + ordering_and_checkpoint_path.append((os.path.getmtime(path), path)) + else: + regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path) + if regex_match is not None and regex_match.groups() is not None: + ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path)) + + checkpoints_sorted = sorted(ordering_and_checkpoint_path) + checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted] + # Make sure we don't delete the best model. + if ( + self.state.best_model_checkpoint is not None + and str(Path(self.state.best_model_checkpoint)) in checkpoints_sorted + ): + best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint))) + for i in range(best_model_index, len(checkpoints_sorted) - 2): + checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i] + return checkpoints_sorted + + def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None: + if self.args.save_total_limit is None or self.args.save_total_limit <= 0: + return + + # Check if we should delete older checkpoint(s) + checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir) + if len(checkpoints_sorted) <= self.args.save_total_limit: + return + + # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which + # we don't do to allow resuming. + save_total_limit = self.args.save_total_limit + if ( + self.state.best_model_checkpoint is not None + and self.args.save_total_limit == 1 + and checkpoints_sorted[-1] != self.state.best_model_checkpoint + ): + save_total_limit = 2 + + number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit) + checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete] + for checkpoint in checkpoints_to_be_deleted: + logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") + shutil.rmtree(checkpoint, ignore_errors=True) + + def evaluate( + self, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + ) -> Dict[str, float]: + """ + Run evaluation and returns metrics. + + The calling script will be responsible for providing a method to compute metrics, as they are task-dependent + (pass it to the init `compute_metrics` argument). + + You can also subclass and override this method to inject custom behavior. + + Args: + eval_dataset (Union[`Dataset`, Dict[str, `Dataset`]), *optional*): + Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns + not accepted by the `model.forward()` method are automatically removed. If it is a dictionary, it will + evaluate on each dataset, prepending the dictionary key to the metric name. Datasets must implement the + `__len__` method. + + + + If you pass a dictionary with names of datasets as keys and datasets as values, evaluate will run + separate evaluations on each dataset. This can be useful to monitor how training affects other + datasets or simply to get a more fine-grained evaluation. + When used with `load_best_model_at_end`, make sure `metric_for_best_model` references exactly one + of the datasets. If you, for example, pass in `{"data1": data1, "data2": data2}` for two datasets + `data1` and `data2`, you could specify `metric_for_best_model="eval_data1_loss"` for using the + loss on `data1` and `metric_for_best_model="eval_data2_loss"` for the loss on `data2`. + + + + ignore_keys (`List[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + metric_key_prefix (`str`, *optional*, defaults to `"eval"`): + An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named + "eval_bleu" if the prefix is "eval" (default) + + Returns: + A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The + dictionary also contains the epoch number which comes from the training state. + """ + # handle multipe eval datasets + override = eval_dataset is not None + eval_dataset = eval_dataset if override else self.eval_dataset + if isinstance(eval_dataset, dict): + metrics = {} + for eval_dataset_name, _eval_dataset in eval_dataset.items(): + dataset_metrics = self.evaluate( + eval_dataset=_eval_dataset if override else eval_dataset_name, + ignore_keys=ignore_keys, + metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}", + ) + metrics.update(dataset_metrics) + return metrics + + # memory metrics - must set up as early as possible + self._memory_tracker.start() + + eval_dataloader = self.get_eval_dataloader(eval_dataset) + if self.is_fsdp_xla_v2_enabled: + eval_dataloader = tpu_spmd_dataloader(eval_dataloader) + + start_time = time.time() + + eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop + output = eval_loop( + eval_dataloader, + description="Evaluation", + # No point gathering the predictions if there are no metrics, otherwise we defer to + # self.args.prediction_loss_only + prediction_loss_only=True if self.compute_metrics is None else None, + ignore_keys=ignore_keys, + metric_key_prefix=metric_key_prefix, + ) + + total_batch_size = self.args.eval_batch_size * self.args.world_size + if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: + start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] + if f"{metric_key_prefix}_model_preparation_time" in output.metrics: + start_time += output.metrics[f"{metric_key_prefix}_model_preparation_time"] + output.metrics.update( + speed_metrics( + metric_key_prefix, + start_time, + num_samples=output.num_samples, + num_steps=math.ceil(output.num_samples / total_batch_size), + ) + ) + + self.log(output.metrics) + + if DebugOption.TPU_METRICS_DEBUG in self.args.debug: + # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + xm.master_print(met.metrics_report()) + + self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) + + self._memory_tracker.stop_and_update_metrics(output.metrics) + + return output.metrics + + def predict( + self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test" + ) -> PredictionOutput: + """ + Run prediction and returns predictions and potential metrics. + + Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method + will also return metrics, like in `evaluate()`. + + Args: + test_dataset (`Dataset`): + Dataset to run the predictions on. If it is an `datasets.Dataset`, columns not accepted by the + `model.forward()` method are automatically removed. Has to implement the method `__len__` + ignore_keys (`List[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + metric_key_prefix (`str`, *optional*, defaults to `"test"`): + An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named + "test_bleu" if the prefix is "test" (default) + + + + If your predictions or labels have different sequence length (for instance because you're doing dynamic padding + in a token classification task) the predictions will be padded (on the right) to allow for concatenation into + one array. The padding index is -100. + + + + Returns: *NamedTuple* A namedtuple with the following keys: + + - predictions (`np.ndarray`): The predictions on `test_dataset`. + - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some). + - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained + labels). + """ + # memory metrics - must set up as early as possible + self._memory_tracker.start() + + test_dataloader = self.get_test_dataloader(test_dataset) + start_time = time.time() + + eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop + output = eval_loop( + test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix + ) + total_batch_size = self.args.eval_batch_size * self.args.world_size + if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: + start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] + if f"{metric_key_prefix}_model_preparation_time" in output.metrics: + start_time += output.metrics[f"{metric_key_prefix}_model_preparation_time"] + output.metrics.update( + speed_metrics( + metric_key_prefix, + start_time, + num_samples=output.num_samples, + num_steps=math.ceil(output.num_samples / total_batch_size), + ) + ) + + self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics) + self._memory_tracker.stop_and_update_metrics(output.metrics) + + return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics) + + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: Optional[bool] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + args = self.args + + prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only + + # if eval is called w/o train, handle model prep here + if self.is_deepspeed_enabled and self.deepspeed is None: + _, _ = deepspeed_init(self, num_training_steps=0, inference=True) + + model = self._wrap_model(self.model, training=False, dataloader=dataloader) + + if len(self.accelerator._models) == 0 and model is self.model: + start_time = time.time() + model = ( + self.accelerator.prepare(model) + if self.is_deepspeed_enabled or (self.is_fsdp_enabled and self.accelerator.mixed_precision != "fp8") + else self.accelerator.prepare_model(model, evaluation_mode=True) + ) + self.model_preparation_time = round(time.time() - start_time, 4) + + if self.is_fsdp_enabled: + self.model = model + + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model_wrapped + + # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called + # while ``train`` is running, cast it to the right dtype first and then put on device + if not self.is_in_train: + if args.fp16_full_eval: + model = model.to(dtype=torch.float16, device=args.device) + elif args.bf16_full_eval: + model = model.to(dtype=torch.bfloat16, device=args.device) + + batch_size = self.args.eval_batch_size + + logger.info(f"\n***** Running {description} *****") + if has_length(dataloader): + logger.info(f" Num examples = {self.num_examples(dataloader)}") + else: + logger.info(" Num examples: Unknown") + logger.info(f" Batch size = {batch_size}") + + model.eval() + if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval): + self.optimizer.eval() + + self.callback_handler.eval_dataloader = dataloader + # Do this before wrapping. + eval_dataset = getattr(dataloader, "dataset", None) + + if args.past_index >= 0: + self._past = None + + # Initialize containers + all_losses = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) + all_preds = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) + all_labels = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) + all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) + + metrics = None + eval_set_kwargs = {} + + # Will be useful when we have an iterable dataset so don't know its length. + observed_num_examples = 0 + + # Main evaluation loop + for step, inputs in enumerate(dataloader): + # Update the observed num examples + observed_batch_size = find_batch_size(inputs) + if observed_batch_size is not None: + observed_num_examples += observed_batch_size + # For batch samplers, batch_size is not known by the dataloader in advance. + if batch_size is None: + batch_size = observed_batch_size + + # Prediction step + losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) + main_input_name = getattr(self.model, "main_input_name", "input_ids") + inputs_decode = ( + self._prepare_input(inputs[main_input_name]) if "inputs" in args.include_for_metrics else None + ) + + if is_torch_xla_available(): + xm.mark_step() + + # Update containers + if losses is not None: + losses = self.gather_function((losses.repeat(batch_size))) + all_losses.add(losses) + if inputs_decode is not None: + inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100) + inputs_decode = self.gather_function((inputs_decode)) + if not self.args.batch_eval_metrics or description == "Prediction": + all_inputs.add(inputs_decode) + if labels is not None: + # Pad labels here, preparing for preprocess_logits_for_metrics in next logits block. + labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100) + if logits is not None: + logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100) + if self.preprocess_logits_for_metrics is not None: + logits = self.preprocess_logits_for_metrics(logits, labels) + logits = self.gather_function((logits)) + if not self.args.batch_eval_metrics or description == "Prediction": + all_preds.add(logits) + if labels is not None: + labels = self.gather_function((labels)) + if not self.args.batch_eval_metrics or description == "Prediction": + all_labels.add(labels) + + self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) + + if self.args.batch_eval_metrics: + if self.compute_metrics is not None and logits is not None and labels is not None: + is_last_step = self.accelerator.gradient_state.end_of_dataloader + batch_kwargs = {} + batch_kwargs["losses"] = losses if "loss" in args.include_for_metrics else None + batch_kwargs["inputs"] = inputs if "inputs" in args.include_for_metrics else None + metrics = self.compute_metrics( + EvalPrediction(predictions=logits, label_ids=labels, **batch_kwargs), + compute_result=is_last_step, + ) + + del losses, logits, labels, inputs + torch.cuda.empty_cache() + + # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. + elif args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: + all_losses.to_cpu_and_numpy() + all_preds.to_cpu_and_numpy() + all_labels.to_cpu_and_numpy() + all_inputs.to_cpu_and_numpy() + + del losses, logits, labels, inputs + torch.cuda.empty_cache() + + # After all calls to `.gather_function`, reset to `gather_for_metrics`: + self.gather_function = self.accelerator.gather_for_metrics + if args.past_index and hasattr(self, "_past"): + # Clean the state at the end of the evaluation loop + delattr(self, "_past") + + # Gather all remaining tensors and put them back on the CPU + all_losses = all_losses.get_arrays() + all_preds = all_preds.get_arrays() + all_labels = all_labels.get_arrays() + all_inputs = all_inputs.get_arrays() + + # Number of samples + if has_length(eval_dataset): + num_samples = len(eval_dataset) + # The instance check is weird and does not actually check for the type, but whether the dataset has the right + # methods. Therefore we need to make sure it also has the attribute. + elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0: + num_samples = eval_dataset.num_examples + else: + if has_length(dataloader): + num_samples = self.num_examples(dataloader) + else: # both len(dataloader.dataset) and len(dataloader) fail + num_samples = observed_num_examples + if num_samples == 0 and observed_num_examples > 0: + num_samples = observed_num_examples + + # Metrics! + if ( + self.compute_metrics is not None + and all_preds is not None + and all_labels is not None + and not self.args.batch_eval_metrics + ): + eval_set_kwargs["losses"] = all_losses if "loss" in args.include_for_metrics else None + eval_set_kwargs["inputs"] = all_inputs if "inputs" in args.include_for_metrics else None + metrics = self.compute_metrics( + EvalPrediction(predictions=all_preds, label_ids=all_labels, **eval_set_kwargs) + ) + elif metrics is None: + metrics = {} + + # To be JSON-serializable, we need to remove numpy types or zero-d tensors + metrics = denumpify_detensorize(metrics) + + if isinstance(all_losses, list) and all_losses: + metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item() + elif isinstance(all_losses, np.ndarray): + metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item() + if hasattr(self, "jit_compilation_time"): + metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time + if hasattr(self, "model_preparation_time"): + metrics[f"{metric_key_prefix}_model_preparation_time"] = self.model_preparation_time + + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + + return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples) + + def _nested_gather(self, tensors, name=None): + """ + Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before + concatenating them to `gathered` + """ + if tensors is None: + return + if is_torch_xla_available(): + if name is None: + name = "nested_gather" + tensors = nested_xla_mesh_reduce(tensors, name) + elif is_sagemaker_mp_enabled(): + tensors = smp_gather(tensors) + elif (self.args.distributed_state is not None and self.args.distributed_state.distributed_type != "NO") or ( + self.args.distributed_state is None and self.args.local_rank != -1 + ): + tensors = distributed_concat(tensors) + return tensors + + def prediction_step( + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Perform an evaluation step on `model` using `inputs`. + + Subclass and override to inject custom behavior. + + Args: + model (`nn.Module`): + The model to evaluate. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + prediction_loss_only (`bool`): + Whether or not to return the loss only. + ignore_keys (`List[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + + Return: + Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, + logits and labels (each being optional). + """ + has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names) + # For CLIP-like models capable of returning loss values. + # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss` + # is `True` in `model.forward`. + return_loss = inputs.get("return_loss", None) + if return_loss is None: + return_loss = self.can_return_loss + loss_without_labels = True if len(self.label_names) == 0 and return_loss else False + + inputs = self._prepare_inputs(inputs) + if ignore_keys is None: + if hasattr(self.model, "config"): + ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + # labels may be popped when computing the loss (label smoothing for instance) so we grab them first. + if has_labels or loss_without_labels: + labels = nested_detach(tuple(inputs.get(name) for name in self.label_names)) + if len(labels) == 1: + labels = labels[0] + else: + labels = None + + with torch.no_grad(): + if is_sagemaker_mp_enabled(): + raw_outputs = smp_forward_only(model, inputs) + if has_labels or loss_without_labels: + if isinstance(raw_outputs, dict): + loss_mb = raw_outputs["loss"] + logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"]) + else: + loss_mb = raw_outputs[0] + logits_mb = raw_outputs[1:] + + loss = loss_mb.reduce_mean().detach().cpu() + logits = smp_nested_concat(logits_mb) + else: + loss = None + if isinstance(raw_outputs, dict): + logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys) + else: + logits_mb = raw_outputs + logits = smp_nested_concat(logits_mb) + else: + if has_labels or loss_without_labels: + with self.compute_loss_context_manager(): + loss, outputs = self.compute_loss(model, inputs, return_outputs=True) + loss = loss.mean().detach() + + if isinstance(outputs, dict): + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"]) + else: + logits = outputs[1:] + else: + loss = None + with self.compute_loss_context_manager(): + outputs = model(**inputs) + if isinstance(outputs, dict): + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) + else: + logits = outputs + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index - 1] + + if prediction_loss_only: + return (loss, None, None) + + logits = nested_detach(logits) + if len(logits) == 1: + logits = logits[0] + + return (loss, logits, labels) + + def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]): + """ + For models that inherit from [`PreTrainedModel`], uses that method to compute the number of floating point + operations for every backward + forward pass. If using another model, either implement such a method in the + model or subclass and override this method. + + Args: + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + Returns: + `int`: The number of floating-point operations. + """ + if hasattr(self.model, "floating_point_ops"): + return self.model.floating_point_ops(inputs) + else: + return 0 + + def init_hf_repo(self, token: Optional[str] = None): + """ + Initializes a git repo in `self.args.hub_model_id`. + """ + # Only on process zero + if not self.is_world_process_zero(): + return + + if self.args.hub_model_id is None: + repo_name = Path(self.args.output_dir).absolute().name + else: + repo_name = self.args.hub_model_id + + token = token if token is not None else self.args.hub_token + repo_url = create_repo(repo_name, token=token, private=self.args.hub_private_repo, exist_ok=True) + self.hub_model_id = repo_url.repo_id + self.push_in_progress = None + + def create_model_card( + self, + language: Optional[str] = None, + license: Optional[str] = None, + tags: Union[str, List[str], None] = None, + model_name: Optional[str] = None, + finetuned_from: Optional[str] = None, + tasks: Union[str, List[str], None] = None, + dataset_tags: Union[str, List[str], None] = None, + dataset: Union[str, List[str], None] = None, + dataset_args: Union[str, List[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + language (`str`, *optional*): + The language of the model (if applicable) + license (`str`, *optional*): + The license of the model. Will default to the license of the pretrained model used, if the original + model given to the `Trainer` comes from a repo on the Hub. + tags (`str` or `List[str]`, *optional*): + Some tags to be included in the metadata of the model card. + model_name (`str`, *optional*): + The name of the model. + finetuned_from (`str`, *optional*): + The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo + of the original model given to the `Trainer` (if it comes from the Hub). + tasks (`str` or `List[str]`, *optional*): + One or several task identifiers, to be included in the metadata of the model card. + dataset_tags (`str` or `List[str]`, *optional*): + One or several dataset tags, to be included in the metadata of the model card. + dataset (`str` or `List[str]`, *optional*): + One or several dataset identifiers, to be included in the metadata of the model card. + dataset_args (`str` or `List[str]`, *optional*): + One or several dataset arguments, to be included in the metadata of the model card. + """ + if not self.is_world_process_zero(): + return + + model_card_filepath = os.path.join(self.args.output_dir, "README.md") + is_peft_library = False + if os.path.exists(model_card_filepath): + library_name = ModelCard.load(model_card_filepath).data.get("library_name") + is_peft_library = library_name == "peft" + + # Append existing tags in `tags` + existing_tags = ModelCard.load(model_card_filepath).data.tags + if tags is not None and existing_tags is not None: + if isinstance(tags, str): + tags = [tags] + for tag in existing_tags: + if tag not in tags: + tags.append(tag) + + training_summary = TrainingSummary.from_trainer( + self, + language=language, + license=license, + tags=tags, + model_name=model_name, + finetuned_from=finetuned_from, + tasks=tasks, + dataset_tags=dataset_tags, + dataset=dataset, + dataset_args=dataset_args, + ) + model_card = training_summary.to_model_card() + with open(model_card_filepath, "w") as f: + f.write(model_card) + + if is_peft_library: + self.accelerator.unwrap_model(self.model).create_or_update_model_card(self.args.output_dir) + + def _push_from_checkpoint(self, checkpoint_folder): + # Only push from one node. + if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END: + return + # If we haven't finished the last push, we don't do this one unless args.hub_always_push=True. + if not self.args.hub_always_push and self.push_in_progress is not None and not self.push_in_progress.is_done(): + return + + output_dir = self.args.output_dir + # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder + modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME] + # Add sharded checkpoints if we have an index + for index_file in [WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]: + index_path = os.path.join(checkpoint_folder, index_file) + if os.path.isfile(index_path): + modeling_files.append(index_file) + with open(index_path) as f: + index = json.loads(f.read()) + shard_files = list(set(index["weight_map"].values())) + modeling_files.extend(shard_files) + if is_peft_available(): + modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME]) + for modeling_file in modeling_files: + if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)): + shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file)) + # Saving the processing class is fast and we don't know how many files it may have spawned, so we resave it to be sure. + if self.processing_class is not None: + self.processing_class.save_pretrained(output_dir) + # Same for the training arguments + torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + + if self.args.save_strategy == SaveStrategy.STEPS: + commit_message = f"Training in progress, step {self.state.global_step}" + else: + commit_message = f"Training in progress, epoch {int(self.state.epoch)}" + + model_push_job = upload_folder( + repo_id=self.hub_model_id, + folder_path=output_dir, + commit_message=commit_message, + token=self.args.hub_token, + run_as_future=True, + ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"], + ) + + push_jobs = [model_push_job] + + if self.args.hub_strategy in [HubStrategy.CHECKPOINT, HubStrategy.ALL_CHECKPOINTS]: + path_in_repo = ( + "last-checkpoint" if self.args.hub_strategy == HubStrategy.CHECKPOINT else Path(checkpoint_folder).name + ) + checkpoint_push = upload_folder( + repo_id=self.hub_model_id, + folder_path=checkpoint_folder, + path_in_repo=path_in_repo, + commit_message=commit_message + ", checkpoint", + token=self.args.hub_token, + run_as_future=True, + ) + push_jobs.append(checkpoint_push) + + if self.push_in_progress is None or self.push_in_progress.is_done(): + self.push_in_progress = PushInProgress(push_jobs) + else: + self.push_in_progress.jobs.extend(push_jobs) + + def _finish_current_push(self): + if not hasattr(self, "push_in_progress"): + return + if self.push_in_progress is not None and not self.push_in_progress.is_done(): + logger.info("Waiting for the current checkpoint push to be finished, this might take a couple of minutes.") + self.push_in_progress.wait_until_done() + + def push_to_hub( + self, + commit_message: Optional[str] = "End of training", + blocking: bool = True, + token: Optional[str] = None, + revision: Optional[str] = None, + **kwargs, + ) -> str: + """ + Upload `self.model` and `self.processing_class` to the 🤗 model hub on the repo `self.args.hub_model_id`. + + Parameters: + commit_message (`str`, *optional*, defaults to `"End of training"`): + Message to commit while pushing. + blocking (`bool`, *optional*, defaults to `True`): + Whether the function should return only when the `git push` has finished. + token (`str`, *optional*, defaults to `None`): + Token with write permission to overwrite Trainer's original args. + revision (`str`, *optional*): + The git revision to commit from. Defaults to the head of the "main" branch. + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to [`~Trainer.create_model_card`]. + + Returns: + The URL of the repository where the model was pushed if `blocking=False`, or a `Future` object tracking the + progress of the commit if `blocking=True`. + """ + model_name = kwargs.pop("model_name", None) + if model_name is None and self.args.should_save: + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + token = token if token is not None else self.args.hub_token + + # In case the user calls this method with args.push_to_hub = False + if self.hub_model_id is None: + self.init_hf_repo(token=token) + + # Needs to be executed on all processes for TPU training, but will only save on the processed determined by + # self.args.should_save. + self.save_model(_internal_call=True) + + # Only push from one node. + if not self.is_world_process_zero(): + return + + # Add additional tags in the case the model has already some tags and users pass + # "tags" argument to `push_to_hub` so that trainer automatically handles internal tags + # from all models since Trainer does not call `model.push_to_hub`. + if getattr(self.model, "model_tags", None) is not None: + if "tags" not in kwargs: + kwargs["tags"] = [] + + # If it is a string, convert it to a list + if isinstance(kwargs["tags"], str): + kwargs["tags"] = [kwargs["tags"]] + + for model_tag in self.model.model_tags: + if model_tag not in kwargs["tags"]: + kwargs["tags"].append(model_tag) + + self.create_model_card(model_name=model_name, **kwargs) + + # Wait for the current upload to be finished. + self._finish_current_push() + return upload_folder( + repo_id=self.hub_model_id, + folder_path=self.args.output_dir, + commit_message=commit_message, + token=token, + run_as_future=not blocking, + ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"], + revision=revision, + ) + + # + # Deprecated code + # + + def prediction_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: Optional[bool] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + args = self.args + + if not has_length(dataloader): + raise ValueError("dataloader must implement a working __len__") + + prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only + + # if eval is called w/o train, handle model prep here + if self.is_deepspeed_enabled and self.deepspeed is None: + _, _ = deepspeed_init(self, num_training_steps=0, inference=True) + + model = self._wrap_model(self.model, training=False, dataloader=dataloader) + + if len(self.accelerator._models) == 0 and model is self.model: + model = ( + self.accelerator.prepare(model) + if self.is_deepspeed_enabled or self.is_fsdp_enabled + else self.accelerator.prepare_model(model, evaluation_mode=True) + ) + + if self.is_fsdp_enabled: + self.model = model + + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model_wrapped + + # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called + # while ``train`` is running, cast it to the right dtype first and then put on device + if not self.is_in_train: + if args.fp16_full_eval: + model = model.to(dtype=torch.float16, device=args.device) + elif args.bf16_full_eval: + model = model.to(dtype=torch.bfloat16, device=args.device) + + batch_size = ( + dataloader.total_batch_size + if getattr(dataloader, "_is_accelerate_prepared", False) + else dataloader.batch_size + ) + + if batch_size is None: + raise ValueError( + "Batch size cannot be None. Ensure the dataloader has a valid batch_size or total_batch_size." + ) + + num_examples = self.num_examples(dataloader) + logger.info(f"\n***** Running {description} *****") + logger.info(f" Num examples = {num_examples}") + logger.info(f" Batch size = {batch_size}") + + losses_host: torch.Tensor = None + preds_host: Union[torch.Tensor, List[torch.Tensor]] = None + labels_host: Union[torch.Tensor, List[torch.Tensor]] = None + inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None + metrics: Optional[dict] = None + eval_set_kwargs: dict = {} + + world_size = max(1, args.world_size) + + eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size) + if not prediction_loss_only: + # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass + # a batch size to the sampler) + make_multiple_of = None + if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler): + make_multiple_of = dataloader.sampler.batch_size + preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) + labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) + inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) + + model.eval() + if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval): + self.optimizer.eval() + + if args.past_index >= 0: + self._past = None + + self.callback_handler.eval_dataloader = dataloader + + for step, inputs in enumerate(dataloader): + loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) + main_input_name = getattr(self.model, "main_input_name", "input_ids") + inputs_decode = ( + self._prepare_input(inputs[main_input_name]) if "inputs" in args.include_for_metrics else None + ) + + if loss is not None: + losses = loss.repeat(batch_size) + losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) + if logits is not None: + preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) + if labels is not None: + labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) + if inputs_decode is not None: + inputs_host = ( + inputs_decode + if inputs_host is None + else nested_concat(inputs_host, inputs_decode, padding_index=-100) + ) + self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) + + if self.args.batch_eval_metrics: + if self.compute_metrics is not None and preds_host is not None and labels_host is not None: + is_last_step = self.accelerator.gradient_state.end_of_dataloader + batch_kwargs = {} + batch_kwargs["losses"] = losses_host if "loss" in args.include_for_metrics else None + batch_kwargs["inputs"] = inputs_host if "inputs" in args.include_for_metrics else None + metrics = self.compute_metrics( + EvalPrediction(predictions=preds_host, label_ids=labels_host, **batch_kwargs), + compute_result=is_last_step, + ) + + if self.args.batch_eval_metrics or ( + args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0 + ): + # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. + eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) + if not prediction_loss_only: + preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) + labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) + inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids")) + + # Set back to None to begin a new accumulation + del losses_host, preds_host, labels_host, inputs_host + torch.cuda.empty_cache() + losses_host, preds_host, labels_host, inputs_host = None, None, None, None + + if args.past_index and hasattr(self, "_past"): + # Clean the state at the end of the evaluation loop + delattr(self, "_past") + + # Gather all remaining tensors and put them back on the CPU + eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) + if not prediction_loss_only: + preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) + labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) + inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids")) + + eval_loss = eval_losses_gatherer.finalize() + preds = preds_gatherer.finalize() if not prediction_loss_only else None + label_ids = labels_gatherer.finalize() if not prediction_loss_only else None + inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None + + if ( + self.compute_metrics is not None + and preds is not None + and label_ids is not None + and not self.args.batch_eval_metrics + ): + eval_set_kwargs["losses"] = eval_loss if "loss" in args.include_for_metrics else None + eval_set_kwargs["inputs"] = inputs_ids if "inputs" in args.include_for_metrics else None + metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids, **eval_set_kwargs)) + elif metrics is None: + metrics = {} + + # To be JSON-serializable, we need to remove numpy types or zero-d tensors + metrics = denumpify_detensorize(metrics) + + if eval_loss is not None: + metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item() + + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + + return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples) + + def _gather_and_numpify(self, tensors, name): + """ + Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before + concatenating them to `gathered` + """ + if tensors is None: + return + if is_torch_xla_available(): + tensors = nested_xla_mesh_reduce(tensors, name) + elif is_sagemaker_mp_enabled(): + tensors = smp_gather(tensors) + elif self.args.parallel_mode == ParallelMode.DISTRIBUTED: + tensors = distributed_concat(tensors) + + return nested_numpify(tensors) + + def _add_sm_patterns_to_gitignore(self) -> None: + """Add SageMaker Checkpointing patterns to .gitignore file.""" + # Make sure we only do this on the main process + if not self.is_world_process_zero(): + return + + patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"] + + # Get current .gitignore content + if os.path.exists(os.path.join(self.repo.local_dir, ".gitignore")): + with open(os.path.join(self.repo.local_dir, ".gitignore"), "r") as f: + current_content = f.read() + else: + current_content = "" + + # Add the patterns to .gitignore + content = current_content + for pattern in patterns: + if pattern not in content: + if content.endswith("\n"): + content += pattern + else: + content += f"\n{pattern}" + + # Write the .gitignore file if it has changed + if content != current_content: + with open(os.path.join(self.repo.local_dir, ".gitignore"), "w") as f: + logger.debug(f"Writing .gitignore file. Content: {content}") + f.write(content) + + self.repo.git_add(".gitignore") + + # avoid race condition with git status + time.sleep(0.5) + + if not self.repo.is_repo_clean(): + self.repo.git_commit("Add *.sagemaker patterns to .gitignore.") + self.repo.git_push() + + def create_accelerator_and_postprocess(self): + # We explicitly don't rely on the `Accelerator` to do gradient accumulation + grad_acc_kwargs = {} + if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None: + grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs + + # check if num_steps is attempted to be passed in gradient_accumulation_kwargs + if "num_steps" in grad_acc_kwargs: + if self.args.gradient_accumulation_steps > 1: + # raise because we do not know which setting is intended. + raise ValueError( + "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`" + "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`." + ) + else: + self.args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"] + + accelerator_config = self.args.accelerator_config.to_dict() + + if is_accelerate_available("0.28.0"): + dataloader_config = DataLoaderConfiguration( + split_batches=accelerator_config.pop("split_batches"), + dispatch_batches=accelerator_config.pop("dispatch_batches"), + even_batches=accelerator_config.pop("even_batches"), + use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"), + ) + if is_accelerate_available("1.1.0"): + dataloader_config.data_seed = self.args.data_seed + + non_blocking = accelerator_config.pop("non_blocking") + if not is_accelerate_available("0.30.0"): + if non_blocking: + raise ImportError( + "`non_blocking` is only supported in accelerate v0.30.0 and above. Please upgrade accelerate to use this feature." + ) + else: + if non_blocking and not self.args.dataloader_pin_memory: + logger.warning( + "`non_blocking` is enabled but `dataloader_pin_memory` is not. For the best performance, it's recommended to enable both." + ) + dataloader_config.non_blocking = non_blocking + # this would have been updated above, no need for it anymore + accelerator_config.pop("gradient_accumulation_kwargs") + + args = { + "deepspeed_plugin": self.args.deepspeed_plugin, + } + if is_accelerate_available("0.28.0"): + args["dataloader_config"] = dataloader_config + else: + args.update(accelerator_config) + + # create accelerator object + self.accelerator = Accelerator(**args) + # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag + self.gather_function = self.accelerator.gather_for_metrics + + if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys(): + self.gather_function = functools.partial( + self.gather_function, use_gather_object=self.args.eval_use_gather_object + ) + + # deepspeed and accelerate flags covering both trainer args and accelerate launcher + self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None + self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None + + # post accelerator creation setup + if self.is_fsdp_enabled: + fsdp_plugin = self.accelerator.state.fsdp_plugin + fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get( + "limit_all_gathers", fsdp_plugin.limit_all_gathers + ) + fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get( + "activation_checkpointing", fsdp_plugin.activation_checkpointing + ) + if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing: + raise ValueError( + "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg " + "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic " + "when using FSDP." + ) + + if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None: + self.propagate_args_to_deepspeed() + + # `save_only_model` can't be used with DeepSpeed/FSDP along with `load_best_model_at_end` + if ( + self.args.save_only_model + and (self.is_deepspeed_enabled or self.is_fsdp_enabled) + and self.args.load_best_model_at_end + ): + wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP" + raise ValueError(f"{wrapper} can't be used with `save_only_model` along with `load_best_model_at_end`.") + + # `auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3 + if ( + self.is_deepspeed_enabled + and self.accelerator.state.deepspeed_plugin.zero_stage == 3 + and self.args.auto_find_batch_size + ): + raise ValueError( + "`auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3. Please consider using Zero-2, Zero-1, or FSDP" + ) + + def propagate_args_to_deepspeed(self, auto_find_batch_size=False): + """ + Sets values in the deepspeed plugin based on the Trainer args + """ + from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig + + ds_plugin = self.accelerator.state.deepspeed_plugin + + ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config) + ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config + ds_plugin.hf_ds_config.trainer_config_process(self.args, auto_find_batch_size) + + def _fsdp_qlora_plugin_updates(self): + if self.is_fsdp_enabled and _is_peft_model(self.model): + from peft import LoraConfig + from peft.utils.other import fsdp_auto_wrap_policy + + if isinstance(self.model.active_peft_config, LoraConfig): + fsdp_plugin = self.accelerator.state.fsdp_plugin + fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(self.model) + if ( + getattr(self.model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES + and self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage.is_floating_point + and version.parse(accelerate_version) > version.parse("0.27.0") + ): + fsdp_plugin.set_mixed_precision( + self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True + ) + + def get_batch_samples(self, epoch_iterator, num_batches): + batch_samples = [] + num_items_in_batch = None + for _ in range(num_batches): + try: + batch_samples += [next(epoch_iterator)] + except StopIteration: + break + + if len(batch_samples) > 0 and "labels" in batch_samples[0]: + # For now we don't support object detection + try: + num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) + except (TypeError, AttributeError): + pass + + if self.args.average_tokens_across_devices and num_items_in_batch is not None: + num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() + + if torch.is_tensor(num_items_in_batch): + num_items_in_batch = num_items_in_batch.item() + + return batch_samples, num_items_in_batch diff --git a/trainer_callback.py b/trainer_callback.py new file mode 100644 index 0000000000000000000000000000000000000000..7b711f65701d44e6ddfec089aefa2a35992531d6 --- /dev/null +++ b/trainer_callback.py @@ -0,0 +1,744 @@ +# coding=utf-8 +# Copyright 2020-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Callbacks to use with the Trainer class and customize the training loop. +""" + +import dataclasses +import json +from dataclasses import dataclass +from typing import Dict, List, Optional, Union + +import numpy as np +from tqdm.auto import tqdm + +from .trainer_utils import IntervalStrategy, SaveStrategy, has_length +from .training_args import TrainingArguments +from .utils import logging + + +logger = logging.get_logger(__name__) + + +@dataclass +class TrainerState: + """ + A class containing the [`Trainer`] inner state that will be saved along the model and optimizer when checkpointing + and passed to the [`TrainerCallback`]. + + + + In all this class, one step is to be understood as one update step. When using gradient accumulation, one update + step may require several forward and backward passes: if you use `gradient_accumulation_steps=n`, then one update + step requires going through *n* batches. + + + + Args: + epoch (`float`, *optional*): + Only set during training, will represent the epoch the training is at (the decimal part being the + percentage of the current epoch completed). + global_step (`int`, *optional*, defaults to 0): + During training, represents the number of update steps completed. + max_steps (`int`, *optional*, defaults to 0): + The number of update steps to do during the current training. + logging_steps (`int`, *optional*, defaults to 500): + Log every X updates steps + eval_steps (`int`, *optional*): + Run an evaluation every X steps. + save_steps (`int`, *optional*, defaults to 500): + Save checkpoint every X updates steps. + train_batch_size (`int`, *optional*): + The batch size for the training dataloader. Only needed when + `auto_find_batch_size` has been used. + num_input_tokens_seen (`int`, *optional*, defaults to 0): + When tracking the inputs tokens, the number of tokens seen during training (number of input tokens, not the + number of prediction tokens). + total_flos (`float`, *optional*, defaults to 0): + The total number of floating operations done by the model since the beginning of training (stored as floats + to avoid overflow). + log_history (`List[Dict[str, float]]`, *optional*): + The list of logs done since the beginning of training. + best_metric (`float`, *optional*): + When tracking the best model, the value of the best metric encountered so far. + best_model_checkpoint (`str`, *optional*): + When tracking the best model, the value of the name of the checkpoint for the best model encountered so + far. + is_local_process_zero (`bool`, *optional*, defaults to `True`): + Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on + several machines) main process. + is_world_process_zero (`bool`, *optional*, defaults to `True`): + Whether or not this process is the global main process (when training in a distributed fashion on several + machines, this is only going to be `True` for one process). + is_hyper_param_search (`bool`, *optional*, defaults to `False`): + Whether we are in the process of a hyper parameter search using Trainer.hyperparameter_search. This will + impact the way data will be logged in TensorBoard. + stateful_callbacks (`List[StatefulTrainerCallback]`, *optional*): + Callbacks attached to the `Trainer` that should have their states be saved or restored. + Relevent callbacks should implement a `state` and `from_state` function. + """ + + epoch: Optional[float] = None + global_step: int = 0 + max_steps: int = 0 + logging_steps: int = 500 + eval_steps: int = 500 + save_steps: int = 500 + train_batch_size: int = None + num_train_epochs: int = 0 + num_input_tokens_seen: int = 0 + total_flos: float = 0 + log_history: List[Dict[str, float]] = None + best_metric: Optional[float] = None + best_model_checkpoint: Optional[str] = None + is_local_process_zero: bool = True + is_world_process_zero: bool = True + is_hyper_param_search: bool = False + trial_name: str = None + trial_params: Dict[str, Union[str, float, int, bool]] = None + stateful_callbacks: List["TrainerCallback"] = None + + def __post_init__(self): + if self.log_history is None: + self.log_history = [] + if self.stateful_callbacks is None: + self.stateful_callbacks = {} + elif isinstance(self.stateful_callbacks, dict): + # We are loading the callbacks in from the state file, no need to process them + pass + else: + # Saveable callbacks get stored as dict of kwargs + stateful_callbacks = {} + for callback in self.stateful_callbacks: + if not isinstance(callback, (ExportableState)): + raise TypeError( + f"All callbacks passed to be saved must inherit `ExportableState`, but received {type(callback)}" + ) + name = callback.__class__.__name__ + if name in stateful_callbacks: + # We can have multiple versions of the same callback + # if so, we store them as a list of states to restore + if not isinstance(stateful_callbacks[name], list): + stateful_callbacks[name] = [stateful_callbacks[name]] + stateful_callbacks[name].append(callback.state()) + else: + stateful_callbacks[name] = callback.state() + self.stateful_callbacks = stateful_callbacks + + def save_to_json(self, json_path: str): + """Save the content of this instance in JSON format inside `json_path`.""" + json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n" + with open(json_path, "w", encoding="utf-8") as f: + f.write(json_string) + + @classmethod + def load_from_json(cls, json_path: str): + """Create an instance from the content of `json_path`.""" + with open(json_path, "r", encoding="utf-8") as f: + text = f.read() + return cls(**json.loads(text)) + + +class ExportableState: + """ + A class for objects that include the ability to have its state + be saved during `Trainer._save_checkpoint` and loaded back in during + `Trainer._load_from_checkpoint`. + + These must implement a `state` function that gets called during the respective + Trainer function call. It should only include parameters and attributes needed to + recreate the state at a particular time, to avoid utilizing pickle/maintain standard + file IO writing. + + Example: + + ```python + class EarlyStoppingCallback(TrainerCallback, ExportableState): + def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0): + self.early_stopping_patience = early_stopping_patience + self.early_stopping_threshold = early_stopping_threshold + # early_stopping_patience_counter denotes the number of times validation metrics failed to improve. + self.early_stopping_patience_counter = 0 + + def state(self) -> dict: + return { + "args": { + "early_stopping_patience": self.early_stopping_patience, + "early_stopping_threshold": self.early_stopping_threshold, + }, + "attributes": { + "early_stopping_patience_counter": self.early_stopping_patience_counter, + } + } + ```""" + + def state(self) -> dict: + raise NotImplementedError("You must implement a `state` function to utilize this class.") + + @classmethod + def from_state(cls, state): + instance = cls(**state["args"]) + for k, v in state["attributes"].items(): + setattr(instance, k, v) + return instance + + +@dataclass +class TrainerControl(ExportableState): + """ + A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate some + switches in the training loop. + + Args: + should_training_stop (`bool`, *optional*, defaults to `False`): + Whether or not the training should be interrupted. + + If `True`, this variable will not be set back to `False`. The training will just stop. + should_epoch_stop (`bool`, *optional*, defaults to `False`): + Whether or not the current epoch should be interrupted. + + If `True`, this variable will be set back to `False` at the beginning of the next epoch. + should_save (`bool`, *optional*, defaults to `False`): + Whether or not the model should be saved at this step. + + If `True`, this variable will be set back to `False` at the beginning of the next step. + should_evaluate (`bool`, *optional*, defaults to `False`): + Whether or not the model should be evaluated at this step. + + If `True`, this variable will be set back to `False` at the beginning of the next step. + should_log (`bool`, *optional*, defaults to `False`): + Whether or not the logs should be reported at this step. + + If `True`, this variable will be set back to `False` at the beginning of the next step. + """ + + should_training_stop: bool = False + should_epoch_stop: bool = False + should_save: bool = False + should_evaluate: bool = False + should_log: bool = False + + def _new_training(self): + """Internal method that resets the variable for a new training.""" + self.should_training_stop = False + + def _new_epoch(self): + """Internal method that resets the variable for a new epoch.""" + self.should_epoch_stop = False + + def _new_step(self): + """Internal method that resets the variable for a new step.""" + self.should_save = False + self.should_evaluate = False + self.should_log = False + + def state(self) -> dict: + return { + "args": { + "should_training_stop": self.should_training_stop, + "should_epoch_stop": self.should_epoch_stop, + "should_save": self.should_save, + "should_evaluate": self.should_evaluate, + "should_log": self.should_log, + }, + "attributes": {}, + } + + +class TrainerCallback: + # no-format + """ + A class for objects that will inspect the state of the training loop at some events and take some decisions. At + each of those events the following arguments are available: + + Args: + args ([`TrainingArguments`]): + The training arguments used to instantiate the [`Trainer`]. + state ([`TrainerState`]): + The current state of the [`Trainer`]. + control ([`TrainerControl`]): + The object that is returned to the [`Trainer`] and can be used to make some decisions. + model ([`PreTrainedModel`] or `torch.nn.Module`): + The model being trained. + tokenizer ([`PreTrainedTokenizer`]): + The tokenizer used for encoding the data. This is deprecated in favour of `processing_class`. + processing_class ([`PreTrainedTokenizer` or `BaseImageProcessor` or `ProcessorMixin` or `FeatureExtractionMixin`]): + The processing class used for encoding the data. Can be a tokenizer, a processor, an image processor or a feature extractor. + optimizer (`torch.optim.Optimizer`): + The optimizer used for the training steps. + lr_scheduler (`torch.optim.lr_scheduler.LambdaLR`): + The scheduler used for setting the learning rate. + train_dataloader (`torch.utils.data.DataLoader`, *optional*): + The current dataloader used for training. + eval_dataloader (`torch.utils.data.DataLoader`, *optional*): + The current dataloader used for evaluation. + metrics (`Dict[str, float]`): + The metrics computed by the last evaluation phase. + + Those are only accessible in the event `on_evaluate`. + logs (`Dict[str, float]`): + The values to log. + + Those are only accessible in the event `on_log`. + + The `control` object is the only one that can be changed by the callback, in which case the event that changes it + should return the modified version. + + The argument `args`, `state` and `control` are positionals for all events, all the others are grouped in `kwargs`. + You can unpack the ones you need in the signature of the event using them. As an example, see the code of the + simple [`~transformers.PrinterCallback`]. + + Example: + + ```python + class PrinterCallback(TrainerCallback): + def on_log(self, args, state, control, logs=None, **kwargs): + _ = logs.pop("total_flos", None) + if state.is_local_process_zero: + print(logs) + ```""" + + def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the end of the initialization of the [`Trainer`]. + """ + pass + + def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the beginning of training. + """ + pass + + def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the end of training. + """ + pass + + def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the beginning of an epoch. + """ + pass + + def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the end of an epoch. + """ + pass + + def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the beginning of a training step. If using gradient accumulation, one training step might take + several inputs. + """ + pass + + def on_pre_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called before the optimizer step but after gradient clipping. Useful for monitoring gradients. + """ + pass + + def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called after the optimizer step but before gradients are zeroed out. Useful for monitoring gradients. + """ + pass + + def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the end of an substep during gradient accumulation. + """ + pass + + def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the end of a training step. If using gradient accumulation, one training step might take + several inputs. + """ + pass + + def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called after an evaluation phase. + """ + pass + + def on_predict(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics, **kwargs): + """ + Event called after a successful prediction. + """ + pass + + def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called after a checkpoint save. + """ + pass + + def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called after logging the last logs. + """ + pass + + def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called after a prediction step. + """ + pass + + +class CallbackHandler(TrainerCallback): + """Internal class that just calls the list of callbacks in order.""" + + def __init__(self, callbacks, model, processing_class, optimizer, lr_scheduler): + self.callbacks = [] + for cb in callbacks: + self.add_callback(cb) + self.model = model + self.processing_class = processing_class + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.train_dataloader = None + self.eval_dataloader = None + + if not any(isinstance(cb, DefaultFlowCallback) for cb in self.callbacks): + logger.warning( + "The Trainer will not work properly if you don't have a `DefaultFlowCallback` in its callbacks. You\n" + + "should add one before training with `trainer.add_callback(DefaultFlowCallback). The current list of" + + "callbacks is\n:" + + self.callback_list + ) + + def add_callback(self, callback): + cb = callback() if isinstance(callback, type) else callback + cb_class = callback if isinstance(callback, type) else callback.__class__ + if cb_class in [c.__class__ for c in self.callbacks]: + logger.warning( + f"You are adding a {cb_class} to the callbacks of this Trainer, but there is already one. The current" + + "list of callbacks is\n:" + + self.callback_list + ) + self.callbacks.append(cb) + + def pop_callback(self, callback): + if isinstance(callback, type): + for cb in self.callbacks: + if isinstance(cb, callback): + self.callbacks.remove(cb) + return cb + else: + for cb in self.callbacks: + if cb == callback: + self.callbacks.remove(cb) + return cb + + def remove_callback(self, callback): + if isinstance(callback, type): + for cb in self.callbacks: + if isinstance(cb, callback): + self.callbacks.remove(cb) + return + else: + self.callbacks.remove(callback) + + @property + def callback_list(self): + return "\n".join(cb.__class__.__name__ for cb in self.callbacks) + + def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + return self.call_event("on_init_end", args, state, control) + + def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + control.should_training_stop = False + return self.call_event("on_train_begin", args, state, control) + + def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + return self.call_event("on_train_end", args, state, control) + + def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + control.should_epoch_stop = False + return self.call_event("on_epoch_begin", args, state, control) + + def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + return self.call_event("on_epoch_end", args, state, control) + + def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + control.should_log = False + control.should_evaluate = False + control.should_save = False + return self.call_event("on_step_begin", args, state, control) + + def on_pre_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + return self.call_event("on_pre_optimizer_step", args, state, control) + + def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + return self.call_event("on_optimizer_step", args, state, control) + + def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + return self.call_event("on_substep_end", args, state, control) + + def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + return self.call_event("on_step_end", args, state, control) + + def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics): + control.should_evaluate = False + return self.call_event("on_evaluate", args, state, control, metrics=metrics) + + def on_predict(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics): + return self.call_event("on_predict", args, state, control, metrics=metrics) + + def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + control.should_save = False + return self.call_event("on_save", args, state, control) + + def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs): + control.should_log = False + return self.call_event("on_log", args, state, control, logs=logs) + + def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + return self.call_event("on_prediction_step", args, state, control) + + def call_event(self, event, args, state, control, **kwargs): + for callback in self.callbacks: + result = getattr(callback, event)( + args, + state, + control, + model=self.model, + processing_class=self.processing_class, + optimizer=self.optimizer, + lr_scheduler=self.lr_scheduler, + train_dataloader=self.train_dataloader, + eval_dataloader=self.eval_dataloader, + **kwargs, + ) + # A Callback can skip the return of `control` if it doesn't change it. + if result is not None: + control = result + return control + + +class DefaultFlowCallback(TrainerCallback): + """ + A [`TrainerCallback`] that handles the default flow of the training loop for logs, evaluation and checkpoints. + """ + + def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + # Log + if state.global_step == 1 and args.logging_first_step: + control.should_log = True + if args.logging_strategy == IntervalStrategy.STEPS and state.global_step % state.logging_steps == 0: + control.should_log = True + + # Evaluate + if ( + args.eval_strategy == IntervalStrategy.STEPS + and state.global_step % state.eval_steps == 0 + and args.eval_delay <= state.global_step + ): + control.should_evaluate = True + + # Save + if ( + args.save_strategy == SaveStrategy.STEPS + and state.save_steps > 0 + and state.global_step % state.save_steps == 0 + ): + control.should_save = True + + # End training + if state.global_step >= state.max_steps: + control.should_training_stop = True + # Save the model at the end if we have a save strategy + if args.save_strategy not in [SaveStrategy.NO, SaveStrategy.BEST]: + control.should_save = True + + return control + + def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + # Log + if args.logging_strategy == IntervalStrategy.EPOCH: + control.should_log = True + + # Evaluate + if args.eval_strategy == IntervalStrategy.EPOCH and args.eval_delay <= state.epoch: + control.should_evaluate = True + + # Save + if args.save_strategy == SaveStrategy.EPOCH: + control.should_save = True + + return control + + +class ProgressCallback(TrainerCallback): + """ + A [`TrainerCallback`] that displays the progress of training or evaluation. + You can modify `max_str_len` to control how long strings are truncated when logging. + """ + + def __init__(self, max_str_len: int = 100): + """ + Initialize the callback with optional max_str_len parameter to control string truncation length. + + Args: + max_str_len (`int`): + Maximum length of strings to display in logs. + Longer strings will be truncated with a message. + """ + self.training_bar = None + self.prediction_bar = None + self.max_str_len = max_str_len + + def on_train_begin(self, args, state, control, **kwargs): + if state.is_world_process_zero: + self.training_bar = tqdm(total=state.max_steps, dynamic_ncols=True) + self.current_step = 0 + + def on_step_end(self, args, state, control, **kwargs): + if state.is_world_process_zero: + self.training_bar.update(state.global_step - self.current_step) + self.current_step = state.global_step + + def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs): + if state.is_world_process_zero and has_length(eval_dataloader): + if self.prediction_bar is None: + self.prediction_bar = tqdm( + total=len(eval_dataloader), leave=self.training_bar is None, dynamic_ncols=True + ) + self.prediction_bar.update(1) + + def on_evaluate(self, args, state, control, **kwargs): + if state.is_world_process_zero: + if self.prediction_bar is not None: + self.prediction_bar.close() + self.prediction_bar = None + + def on_predict(self, args, state, control, **kwargs): + if state.is_world_process_zero: + if self.prediction_bar is not None: + self.prediction_bar.close() + self.prediction_bar = None + + def on_log(self, args, state, control, logs=None, **kwargs): + if state.is_world_process_zero and self.training_bar is not None: + # make a shallow copy of logs so we can mutate the fields copied + # but avoid doing any value pickling. + shallow_logs = {} + for k, v in logs.items(): + if isinstance(v, str) and len(v) > self.max_str_len: + shallow_logs[k] = ( + f"[String too long to display, length: {len(v)} > {self.max_str_len}. " + "Consider increasing `max_str_len` if needed.]" + ) + else: + shallow_logs[k] = v + _ = shallow_logs.pop("total_flos", None) + # round numbers so that it looks better in console + if "epoch" in shallow_logs: + shallow_logs["epoch"] = round(shallow_logs["epoch"], 2) + self.training_bar.write(str(shallow_logs)) + + def on_train_end(self, args, state, control, **kwargs): + if state.is_world_process_zero: + self.training_bar.close() + self.training_bar = None + + +class PrinterCallback(TrainerCallback): + """ + A bare [`TrainerCallback`] that just prints the logs. + """ + + def on_log(self, args, state, control, logs=None, **kwargs): + _ = logs.pop("total_flos", None) + if state.is_local_process_zero: + print(logs) + + +class EarlyStoppingCallback(TrainerCallback, ExportableState): + """ + A [`TrainerCallback`] that handles early stopping. + + Args: + early_stopping_patience (`int`): + Use with `metric_for_best_model` to stop training when the specified metric worsens for + `early_stopping_patience` evaluation calls. + early_stopping_threshold(`float`, *optional*): + Use with TrainingArguments `metric_for_best_model` and `early_stopping_patience` to denote how much the + specified metric must improve to satisfy early stopping conditions. ` + + This callback depends on [`TrainingArguments`] argument *load_best_model_at_end* functionality to set best_metric + in [`TrainerState`]. Note that if the [`TrainingArguments`] argument *save_steps* differs from *eval_steps*, the + early stopping will not occur until the next save step. + """ + + def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0): + self.early_stopping_patience = early_stopping_patience + self.early_stopping_threshold = early_stopping_threshold + # early_stopping_patience_counter denotes the number of times validation metrics failed to improve. + self.early_stopping_patience_counter = 0 + + def check_metric_value(self, args, state, control, metric_value): + # best_metric is set by code for load_best_model + operator = np.greater if args.greater_is_better else np.less + if state.best_metric is None or ( + operator(metric_value, state.best_metric) + and abs(metric_value - state.best_metric) > self.early_stopping_threshold + ): + self.early_stopping_patience_counter = 0 + else: + self.early_stopping_patience_counter += 1 + + def on_train_begin(self, args, state, control, **kwargs): + assert args.load_best_model_at_end, "EarlyStoppingCallback requires load_best_model_at_end = True" + assert ( + args.metric_for_best_model is not None + ), "EarlyStoppingCallback requires metric_for_best_model is defined" + assert ( + args.eval_strategy != IntervalStrategy.NO + ), "EarlyStoppingCallback requires IntervalStrategy of steps or epoch" + + def on_evaluate(self, args, state, control, metrics, **kwargs): + metric_to_check = args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + metric_value = metrics.get(metric_to_check) + + if metric_value is None: + logger.warning( + f"early stopping required metric_for_best_model, but did not find {metric_to_check} so early stopping" + " is disabled" + ) + return + + self.check_metric_value(args, state, control, metric_value) + if self.early_stopping_patience_counter >= self.early_stopping_patience: + control.should_training_stop = True + + def state(self) -> dict: + return { + "args": { + "early_stopping_patience": self.early_stopping_patience, + "early_stopping_threshold": self.early_stopping_threshold, + }, + "attributes": { + "early_stopping_patience_counter": self.early_stopping_patience_counter, + }, + } diff --git a/trainer_pt_utils.py b/trainer_pt_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..da95329e1845676ddde9086fa9a8c0cf67233c3a --- /dev/null +++ b/trainer_pt_utils.py @@ -0,0 +1,1392 @@ +# coding=utf-8 +# Copyright 2020-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Torch utilities for the Trainer class. +""" + +import copy +import datetime +import io +import json +import math +import os +import sys +import warnings +from collections.abc import Mapping +from contextlib import contextmanager +from dataclasses import dataclass, field +from itertools import chain +from logging import StreamHandler +from typing import Any, Dict, Iterator, List, Optional, Union + +import numpy as np +import torch +import torch.distributed as dist +from torch import nn +from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler +from torch.utils.data.distributed import DistributedSampler + +from .integrations.deepspeed import is_deepspeed_zero3_enabled +from .tokenization_utils_base import BatchEncoding +from .utils import ( + is_sagemaker_mp_enabled, + is_torch_available, + is_torch_xla_available, + is_training_run_on_sagemaker, + logging, +) + + +if is_training_run_on_sagemaker(): + logging.add_handler(StreamHandler(sys.stdout)) + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + +if is_torch_available(): + from torch.optim.lr_scheduler import LRScheduler + + +logger = logging.get_logger(__name__) + + +def get_dataloader_sampler(dataloader): + if hasattr(dataloader, "batch_sampler") and dataloader.batch_sampler is not None: + return get_dataloader_sampler(dataloader.batch_sampler) + elif hasattr(dataloader, "sampler"): + return dataloader.sampler + + +def atleast_1d(tensor_or_array: Union[torch.Tensor, np.ndarray]): + if isinstance(tensor_or_array, torch.Tensor): + if hasattr(torch, "atleast_1d"): + tensor_or_array = torch.atleast_1d(tensor_or_array) + elif tensor_or_array.ndim < 1: + tensor_or_array = tensor_or_array[None] + else: + tensor_or_array = np.atleast_1d(tensor_or_array) + return tensor_or_array + + +def torch_pad_and_concatenate(tensor1, tensor2, padding_index=-100): + """Concatenates `tensor1` and `tensor2` on first axis, applying padding on the second if necessary.""" + tensor1 = atleast_1d(tensor1) + tensor2 = atleast_1d(tensor2) + + if len(tensor1.shape) == 1 or tensor1.shape[1] == tensor2.shape[1]: + return torch.cat((tensor1, tensor2), dim=0) + + # Let's figure out the new shape + new_shape = (tensor1.shape[0] + tensor2.shape[0], max(tensor1.shape[1], tensor2.shape[1])) + tensor1.shape[2:] + + # Now let's fill the result tensor + result = tensor1.new_full(new_shape, padding_index) + result[: tensor1.shape[0], : tensor1.shape[1]] = tensor1 + result[tensor1.shape[0] :, : tensor2.shape[1]] = tensor2 + return result + + +def numpy_pad_and_concatenate(array1, array2, padding_index=-100): + """Concatenates `array1` and `array2` on first axis, applying padding on the second if necessary.""" + array1 = atleast_1d(array1) + array2 = atleast_1d(array2) + + if len(array1.shape) == 1 or array1.shape[1] == array2.shape[1]: + return np.concatenate((array1, array2), axis=0) + + # Let's figure out the new shape + new_shape = (array1.shape[0] + array2.shape[0], max(array1.shape[1], array2.shape[1])) + array1.shape[2:] + + # Now let's fill the result tensor + result = np.full_like(array1, padding_index, shape=new_shape) + result[: array1.shape[0], : array1.shape[1]] = array1 + result[array1.shape[0] :, : array2.shape[1]] = array2 + return result + + +def nested_concat(tensors, new_tensors, padding_index=-100): + """ + Concat the `new_tensors` to `tensors` on the first dim and pad them on the second if needed. Works for tensors or + nested list/tuples/dict of tensors. + """ + if not (isinstance(tensors, torch.Tensor) and isinstance(new_tensors, torch.Tensor)): + assert ( + type(tensors) is type(new_tensors) + ), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}." + if isinstance(tensors, (list, tuple)): + return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors)) + elif isinstance(tensors, torch.Tensor): + return torch_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index) + elif isinstance(tensors, Mapping): + return type(tensors)( + {k: nested_concat(t, new_tensors[k], padding_index=padding_index) for k, t in tensors.items()} + ) + elif isinstance(tensors, np.ndarray): + return numpy_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index) + else: + raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}") + + +def find_batch_size(tensors): + """ + Find the first dimension of a tensor in a nested list/tuple/dict of tensors. + """ + if isinstance(tensors, (list, tuple)): + for t in tensors: + result = find_batch_size(t) + if result is not None: + return result + elif isinstance(tensors, Mapping): + for key, value in tensors.items(): + result = find_batch_size(value) + if result is not None: + return result + elif isinstance(tensors, torch.Tensor): + return tensors.shape[0] if len(tensors.shape) >= 1 else None + elif isinstance(tensors, np.ndarray): + return tensors.shape[0] if len(tensors.shape) >= 1 else None + + +def nested_numpify(tensors): + "Numpify `tensors` (even if it's a nested list/tuple/dict of tensors)." + if isinstance(tensors, (list, tuple)): + return type(tensors)(nested_numpify(t) for t in tensors) + if isinstance(tensors, Mapping): + return type(tensors)({k: nested_numpify(t) for k, t in tensors.items()}) + + t = tensors.cpu() + if t.dtype == torch.bfloat16: + # As of Numpy 1.21.4, NumPy does not support bfloat16 (see + # https://github.com/numpy/numpy/blob/a47ecdea856986cd60eabbd53265c2ca5916ad5d/doc/source/user/basics.types.rst ). + # Until Numpy adds bfloat16, we must convert float32. + t = t.to(torch.float32) + return t.numpy() + + +def nested_detach(tensors): + "Detach `tensors` (even if it's a nested list/tuple/dict of tensors)." + if isinstance(tensors, (list, tuple)): + return type(tensors)(nested_detach(t) for t in tensors) + elif isinstance(tensors, Mapping): + return type(tensors)({k: nested_detach(t) for k, t in tensors.items()}) + return tensors.detach() if isinstance(tensors, torch.Tensor) else tensors + + +def nested_xla_mesh_reduce(tensors, name): + if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + if isinstance(tensors, (list, tuple)): + return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors)) + if isinstance(tensors, Mapping): + return type(tensors)( + {k: nested_xla_mesh_reduce(t, f"{name}_{i}") for i, (k, t) in enumerate(tensors.items())} + ) + + tensors = atleast_1d(tensors) + return xm.mesh_reduce(name, tensors, torch.cat) + else: + raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`") + + +def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) -> Any: + try: + if isinstance(tensor, (tuple, list)): + return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor) + if isinstance(tensor, Mapping): + return type(tensor)({k: distributed_concat(t, num_total_examples) for k, t in tensor.items()}) + tensor = atleast_1d(tensor).contiguous() + output_tensors = [tensor.clone() for _ in range(dist.get_world_size())] + dist.all_gather(output_tensors, tensor) + concat = torch.cat(output_tensors, dim=0) + + # truncate the dummy elements added by SequentialDistributedSampler + if num_total_examples is not None: + concat = concat[:num_total_examples] + return concat + except AssertionError: + raise AssertionError("Not currently using distributed training") + + +def distributed_broadcast_scalars( + scalars: List[Union[int, float]], + num_total_examples: Optional[int] = None, + device: Optional[torch.device] = torch.device("cuda"), +) -> torch.Tensor: + try: + tensorized_scalar = torch.tensor(scalars).to(device) + output_tensors = [tensorized_scalar.clone() for _ in range(dist.get_world_size())] + dist.all_gather(output_tensors, tensorized_scalar) + concat = torch.cat(output_tensors, dim=0) + + # truncate the dummy elements added by SequentialDistributedSampler + if num_total_examples is not None: + concat = concat[:num_total_examples] + return concat + except AssertionError: + raise AssertionError("Not currently using distributed training") + + +def reissue_pt_warnings(caught_warnings): + # Reissue warnings + if len(caught_warnings) > 1: + for w in caught_warnings: + if w.category is not UserWarning: + warnings.warn(w.message, w.category) + + +@contextmanager +def torch_distributed_zero_first(local_rank: int): + """ + Decorator to make all processes in distributed training wait for each local_master to do something. + + Args: + local_rank (`int`): The rank of the local process. + """ + if local_rank not in [-1, 0]: + dist.barrier() + yield + if local_rank == 0: + dist.barrier() + + +class DistributedSamplerWithLoop(DistributedSampler): + """ + Like a torch.utils.data.distributed.DistributedSampler` but loops at the end back to the beginning of the shuffled + samples to make each process have a round multiple of batch_size samples. + + Args: + dataset (`torch.utils.data.Dataset`): + Dataset used for sampling. + batch_size (`int`): + The batch size used with this sampler + kwargs (`Dict[str, Any]`, *optional*): + All other keyword arguments passed to `DistributedSampler`. + """ + + def __init__(self, dataset, batch_size, **kwargs): + super().__init__(dataset, **kwargs) + self.batch_size = batch_size + + def __iter__(self): + indices = list(super().__iter__()) + remainder = 0 if len(indices) % self.batch_size == 0 else self.batch_size - len(indices) % self.batch_size + # DistributedSampler already added samples from the beginning to make the number of samples a round multiple + # of the world size, so we skip those. + start_remainder = 1 if self.rank < len(self.dataset) % self.num_replicas else 0 + indices += indices[start_remainder : start_remainder + remainder] + return iter(indices) + + +class EvalLoopContainer: + """ + Container to store intermediate results of evaluation loop + + Args: + do_nested_concat (`bool`, *optional*, defaults to `True`): + If set to `True`, each iteration will recursively concatenate a new object containing tensors to + the existing stored tensors, provided that the structure of the existing object and the new one + are identical. If set to `False`, all newly added tensors will be stored in a list. + padding_index (`int`, *optional*, defaults to -100): + Value used to pad tensors of different shapes when `do_nested_concat=True`. + """ + + def __init__(self, do_nested_concat: bool = True, padding_index: int = -100): + self.do_nested_concat = do_nested_concat + self.padding_index = padding_index + self.tensors = None + self.arrays = None + + def add(self, tensors) -> None: + """Add tensors to the stored objects. If `do_nested_concat=True`, the tensors will be concatenated recursively.""" + if self.tensors is None: + self.tensors = tensors if self.do_nested_concat else [tensors] + elif self.do_nested_concat: + self.tensors = nested_concat(self.tensors, tensors, padding_index=self.padding_index) + else: + self.tensors.append(tensors) + + def to_cpu_and_numpy(self) -> None: + """Move tensors in stored objects to CPU and convert them to numpy arrays.""" + + # Check if we have something to add, if not just return + if self.tensors is None: + return + + new_arrays = nested_numpify(self.tensors) + if self.arrays is None: + self.arrays = new_arrays + elif self.do_nested_concat: + self.arrays = nested_concat(self.arrays, new_arrays, padding_index=self.padding_index) + else: + self.arrays.extend(new_arrays) + + # reset device tensors after adding to cpu + self.tensors = None + + def get_arrays(self): + """Returns the numpified and moved to CPU stored objects.""" + self.to_cpu_and_numpy() + return self.arrays + + +class SequentialDistributedSampler(Sampler): + """ + Distributed Sampler that subsamples indices sequentially, making it easier to collate all results at the end. + + Even though we only use this sampler for eval and predict (no training), which means that the model params won't + have to be synced (i.e. will not hang for synchronization even if varied number of forward passes), we still add + extra samples to the sampler to make it evenly divisible (like in `DistributedSampler`) to make it easy to `gather` + or `reduce` resulting tensors at the end of the loop. + """ + + def __init__(self, dataset, num_replicas=None, rank=None, batch_size=None): + warnings.warn( + "SequentialDistributedSampler is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + num_samples = len(self.dataset) + # Add extra samples to make num_samples a multiple of batch_size if passed + if batch_size is not None: + self.num_samples = int(math.ceil(num_samples / (batch_size * num_replicas))) * batch_size + else: + self.num_samples = int(math.ceil(num_samples / num_replicas)) + self.total_size = self.num_samples * self.num_replicas + self.batch_size = batch_size + + def __iter__(self): + indices = list(range(len(self.dataset))) + + # add extra samples to make it evenly divisible + indices += indices[: (self.total_size - len(indices))] + assert ( + len(indices) == self.total_size + ), f"Indices length {len(indices)} and total size {self.total_size} mismatched" + + # subsample + indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples] + assert ( + len(indices) == self.num_samples + ), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched" + + return iter(indices) + + def __len__(self): + return self.num_samples + + +def get_tpu_sampler(dataset: torch.utils.data.Dataset, batch_size: int): + if xm.xrt_world_size() <= 1: + return RandomSampler(dataset) + return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) + + +def nested_new_like(arrays, num_samples, padding_index=-100): + """Create the same nested structure as `arrays` with a first dimension always at `num_samples`.""" + if isinstance(arrays, (list, tuple)): + return type(arrays)(nested_new_like(x, num_samples) for x in arrays) + return np.full_like(arrays, padding_index, shape=(num_samples, *arrays.shape[1:])) + + +def expand_like(arrays, new_seq_length, padding_index=-100): + """Expand the `arrays` so that the second dimension grows to `new_seq_length`. Uses `padding_index` for padding.""" + result = np.full_like(arrays, padding_index, shape=(arrays.shape[0], new_seq_length) + arrays.shape[2:]) + result[:, : arrays.shape[1]] = arrays + return result + + +def nested_truncate(tensors, limit): + "Truncate `tensors` at `limit` (even if it's a nested list/tuple/dict of tensors)." + if isinstance(tensors, (list, tuple)): + return type(tensors)(nested_truncate(t, limit) for t in tensors) + if isinstance(tensors, Mapping): + return type(tensors)({k: nested_truncate(t, limit) for k, t in tensors.items()}) + + return tensors[:limit] + + +class DistributedTensorGatherer: + """ + A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks. + + If our dataset has 16 samples with a batch size of 2 on 3 processes and we gather then transfer on CPU at every + step, our sampler will generate the following indices: + + `[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1]` + + to get something of size a multiple of 3 (so that each process gets the same dataset length). Then process 0, 1 and + 2 will be responsible of making predictions for the following samples: + + - P0: `[0, 1, 2, 3, 4, 5]` + - P1: `[6, 7, 8, 9, 10, 11]` + - P2: `[12, 13, 14, 15, 0, 1]` + + The first batch treated on each process will be + + - P0: `[0, 1]` + - P1: `[6, 7]` + - P2: `[12, 13]` + + So if we gather at the end of the first batch, we will get a tensor (nested list/tuple of tensor) corresponding to + the following indices: + + `[0, 1, 6, 7, 12, 13]` + + If we directly concatenate our results without taking any precautions, the user will then get the predictions for + the indices in this order at the end of the prediction loop: + + `[0, 1, 6, 7, 12, 13, 2, 3, 8, 9, 14, 15, 4, 5, 10, 11, 0, 1]` + + For some reason, that's not going to roll their boat. This class is there to solve that problem. + + Args: + world_size (`int`): + The number of processes used in the distributed training. + num_samples (`int`): + The number of samples in our dataset. + make_multiple_of (`int`, *optional*): + If passed, the class assumes the datasets passed to each process are made to be a multiple of this argument + (by adding samples). + padding_index (`int`, *optional*, defaults to -100): + The padding index to use if the arrays don't all have the same sequence length. + """ + + def __init__(self, world_size, num_samples, make_multiple_of=None, padding_index=-100): + warnings.warn( + "DistributedTensorGatherer is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.world_size = world_size + self.num_samples = num_samples + total_size = world_size if make_multiple_of is None else world_size * make_multiple_of + self.total_samples = int(np.ceil(num_samples / total_size)) * total_size + self.process_length = self.total_samples // world_size + self._storage = None + self._offsets = None + self.padding_index = padding_index + + def add_arrays(self, arrays): + """ + Add `arrays` to the internal storage, Will initialize the storage to the full size at the first arrays passed + so that if we're bound to get an OOM, it happens at the beginning. + """ + if arrays is None: + return + if self._storage is None: + self._storage = nested_new_like(arrays, self.total_samples, padding_index=self.padding_index) + self._offsets = list(range(0, self.total_samples, self.process_length)) + + slice_len, self._storage = self._nested_set_tensors(self._storage, arrays) + for i in range(self.world_size): + self._offsets[i] += slice_len + + def _nested_set_tensors(self, storage, arrays): + if isinstance(arrays, (list, tuple)): + result = [self._nested_set_tensors(x, y) for x, y in zip(storage, arrays)] + return result[0][0], type(arrays)(r[1] for r in result) + assert ( + arrays.shape[0] % self.world_size == 0 + ), f"Arrays passed should all have a first dimension multiple of {self.world_size}, found {arrays.shape[0]}." + + slice_len = arrays.shape[0] // self.world_size + for i in range(self.world_size): + if len(arrays.shape) == 1: + storage[self._offsets[i] : self._offsets[i] + slice_len] = arrays[i * slice_len : (i + 1) * slice_len] + else: + # Expand the array on the fly if needed. + if len(storage.shape) > 1 and storage.shape[1] < arrays.shape[1]: + storage = expand_like(storage, arrays.shape[1], padding_index=self.padding_index) + storage[self._offsets[i] : self._offsets[i] + slice_len, : arrays.shape[1]] = arrays[ + i * slice_len : (i + 1) * slice_len + ] + return slice_len, storage + + def finalize(self): + """ + Return the properly gathered arrays and truncate to the number of samples (since the sampler added some extras + to get each process a dataset of the same length). + """ + if self._storage is None: + return + if self._offsets[0] != self.process_length: + logger.warning("Not all data has been set. Are you sure you passed all values?") + return nested_truncate(self._storage, self.num_samples) + + +@dataclass +class LabelSmoother: + """ + Adds label-smoothing on a pre-computed output from a Transformers model. + + Args: + epsilon (`float`, *optional*, defaults to 0.1): + The label smoothing factor. + ignore_index (`int`, *optional*, defaults to -100): + The index in the labels to ignore when computing the loss. + """ + + epsilon: float = 0.1 + ignore_index: int = -100 + + def __call__(self, model_output, labels, shift_labels=False): + logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0] + if shift_labels: + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + + log_probs = -nn.functional.log_softmax(logits, dim=-1) + if labels.dim() == log_probs.dim() - 1: + labels = labels.unsqueeze(-1) + + padding_mask = labels.eq(self.ignore_index) + # In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask + # will ignore them in any case. + labels = torch.clamp(labels, min=0) + nll_loss = log_probs.gather(dim=-1, index=labels) + # works for fp16 input tensor too, by internally upcasting it to fp32 + smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32) + + nll_loss.masked_fill_(padding_mask, 0.0) + smoothed_loss.masked_fill_(padding_mask, 0.0) + + # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded): + num_active_elements = padding_mask.numel() - padding_mask.long().sum() + nll_loss = nll_loss.sum() / num_active_elements + smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1]) + return (1 - self.epsilon) * nll_loss + self.epsilon * smoothed_loss + + +def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None): + """ + Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar + lengths. To do this, the indices are: + + - randomly permuted + - grouped in mega-batches of size `mega_batch_mult * batch_size` + - sorted by length in each mega-batch + + The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of + maximum length placed first, so that an OOM happens sooner rather than later. + """ + # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller. + if mega_batch_mult is None: + mega_batch_mult = min(len(lengths) // (batch_size * 4), 50) + # Just in case, for tiny datasets + if mega_batch_mult == 0: + mega_batch_mult = 1 + + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. + indices = torch.randperm(len(lengths), generator=generator) + megabatch_size = mega_batch_mult * batch_size + megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] + megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] + + # The rest is to get the biggest batch first. + # Since each megabatch is sorted by descending length, the longest element is the first + megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches] + max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item() + # Switch to put the longest element in first position + megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0] + + return [i for megabatch in megabatches for i in megabatch] + + +class LengthGroupedSampler(Sampler): + r""" + Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while + keeping a bit of randomness. + """ + + def __init__( + self, + batch_size: int, + dataset: Optional[Dataset] = None, + lengths: Optional[List[int]] = None, + model_input_name: Optional[str] = None, + generator=None, + ): + if dataset is None and lengths is None: + raise ValueError("One of dataset and lengths must be provided.") + + self.batch_size = batch_size + if lengths is None: + model_input_name = model_input_name if model_input_name is not None else "input_ids" + if ( + not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding)) + or model_input_name not in dataset[0] + ): + raise ValueError( + "Can only automatically infer lengths for datasets whose items are dictionaries with an " + f"'{model_input_name}' key." + ) + lengths = [len(feature[model_input_name]) for feature in dataset] + elif isinstance(lengths, torch.Tensor): + logger.info( + "If lengths is a torch.Tensor, LengthGroupedSampler will be slow. Converting lengths to List[int]..." + ) + lengths = lengths.tolist() + + self.lengths = lengths + self.generator = generator + + def __len__(self): + return len(self.lengths) + + def __iter__(self): + indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=self.generator) + return iter(indices) + + +class DistributedLengthGroupedSampler(DistributedSampler): + r""" + Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same + length while keeping a bit of randomness. + """ + + # Copied and adapted from PyTorch DistributedSampler. + def __init__( + self, + batch_size: int, + dataset: Optional[Dataset] = None, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + seed: int = 0, + drop_last: bool = False, + lengths: Optional[List[int]] = None, + model_input_name: Optional[str] = None, + ): + if dataset is None and lengths is None: + raise ValueError("One of dataset and lengths must be provided.") + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + + self.batch_size = batch_size + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.drop_last = drop_last + + if lengths is None: + model_input_name = model_input_name if model_input_name is not None else "input_ids" + if ( + not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding)) + or model_input_name not in dataset[0] + ): + raise ValueError( + "Can only automatically infer lengths for datasets whose items are dictionaries with an " + f"'{model_input_name}' key." + ) + lengths = [len(feature[model_input_name]) for feature in dataset] + elif isinstance(lengths, torch.Tensor): + logger.info( + "If lengths is a torch.Tensor, DistributedLengthGroupedSampler will be slow. Converting lengths to" + " List[int]..." + ) + lengths = lengths.tolist() + + self.lengths = lengths + + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. + if self.drop_last and len(self.lengths) % self.num_replicas != 0: + # Split to nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. + self.num_samples = math.ceil((len(self.lengths) - self.num_replicas) / self.num_replicas) + else: + self.num_samples = math.ceil(len(self.lengths) / self.num_replicas) + self.total_size = self.num_samples * self.num_replicas + self.seed = seed + + def __iter__(self) -> Iterator: + # Deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g) + + if not self.drop_last: + # add extra samples to make it evenly divisible + indices += indices[: (self.total_size - len(indices))] + else: + # remove tail of data to make it evenly divisible. + indices = indices[: self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + +class ShardSampler(Sampler): + """ + Sampler that shards batches between several processes. Dispatches indices batch by batch: on 2 processes with batch + size 4, the first two batches are `[0, 1, 2, 3, 4, 5, 6, 7]` and `[8, 9, 10, 11, 12, 13, 14, 15]`, which shard into + `[0, 1, 2, 3]` and `[8, 9, 10, 11]` for GPU-0 and `[4, 5, 6, 7]` and `[12, 13, 14, 15]` for GPU-1. + + The sampler thus yields `[0, 1, 2, 3, 8, 9, 10, 11]` on GPU-0 and `[4, 5, 6, 7, 12, 13, 14, 15]` on GPU-1. + """ + + def __init__( + self, + dataset: Dataset, + batch_size: int = 1, + drop_last: bool = False, + num_processes: int = 1, + process_index: int = 0, + ): + self.dataset = dataset + self.batch_size = batch_size + self.drop_last = drop_last + self.num_processes = num_processes + self.process_index = process_index + + self.total_batch_size = total_batch_size = batch_size * num_processes + + num_batches = len(dataset) // total_batch_size if drop_last else math.ceil(len(dataset) / total_batch_size) + self.total_num_samples = num_batches * total_batch_size + + def __iter__(self): + indices = list(range(len(self.dataset))) + + # Add extra samples to make it evenly divisible. While loop is there in the edge case we have a tiny dataset + # and it needs to be done several times. + while len(indices) < self.total_num_samples: + indices += indices[: (self.total_num_samples - len(indices))] + + result = [] + for batch_start in range(self.batch_size * self.process_index, self.total_num_samples, self.total_batch_size): + result += indices[batch_start : batch_start + self.batch_size] + + return iter(result) + + def __len__(self): + # Each shard only sees a fraction of total_num_samples. + return self.total_num_samples // self.num_processes + + +class IterableDatasetShard(IterableDataset): + """ + Wraps a PyTorch `IterableDataset` to generate samples for one of the processes only. Instances of this class will + always yield a number of samples that is a round multiple of the actual batch size (which is `batch_size x + num_processes`). Depending on the value of the `drop_last` attribute, it will either stop the iteration at the + first batch that would be too small or loop with indices from the beginning. + + On two processes with an iterable dataset yielding of `[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]` with a batch size of + 2: + + - the shard on process 0 will yield `[0, 1, 4, 5, 8, 9]` so will see batches `[0, 1]`, `[4, 5]`, `[8, 9]` + - the shard on process 1 will yield `[2, 3, 6, 7, 10, 11]` so will see batches `[2, 3]`, `[6, 7]`, `[10, 11]` + + + + If your IterableDataset implements some randomization that needs to be applied the same way on all processes + (for instance, a shuffling), you should use a `torch.Generator` in a `generator` attribute of the `dataset` to + generate your random numbers and call the [`~trainer_pt_utils.IterableDatasetShard.set_epoch`] method of this + object. It will set the seed of this `generator` to `seed + epoch` on all processes before starting the + iteration. Alternatively, you can also implement a `set_epoch()` method in your iterable dataset to deal with + this. + + + + Args: + dataset (`torch.utils.data.IterableDataset`): + The batch sampler to split in several shards. + batch_size (`int`, *optional*, defaults to 1): + The size of the batches per shard. + drop_last (`bool`, *optional*, defaults to `False`): + Whether or not to drop the last incomplete batch or complete the last batches by using the samples from the + beginning. + num_processes (`int`, *optional*, defaults to 1): + The number of processes running concurrently. + process_index (`int`, *optional*, defaults to 0): + The index of the current process. + seed (`int`, *optional*, defaults to 0): + A random seed that will be used for the random number generation in + [`~trainer_pt_utils.IterableDatasetShard.set_epoch`]. + """ + + def __init__( + self, + dataset: IterableDataset, + batch_size: int = 1, + drop_last: bool = False, + num_processes: int = 1, + process_index: int = 0, + seed: int = 0, + ): + self.dataset = dataset + self.batch_size = batch_size + self.drop_last = drop_last + self.num_processes = num_processes + self.process_index = process_index + self.seed = seed + self.epoch = 0 + self.num_examples = 0 + + def set_epoch(self, epoch): + self.epoch = epoch + if hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(epoch) + + def __iter__(self): + self.num_examples = 0 + if ( + not hasattr(self.dataset, "set_epoch") + and hasattr(self.dataset, "generator") + and isinstance(self.dataset.generator, torch.Generator) + ): + self.dataset.generator.manual_seed(self.seed + self.epoch) + real_batch_size = self.batch_size * self.num_processes + process_slice = range(self.process_index * self.batch_size, (self.process_index + 1) * self.batch_size) + + first_batch = None + current_batch = [] + for element in self.dataset: + self.num_examples += 1 + current_batch.append(element) + # Wait to have a full batch before yielding elements. + if len(current_batch) == real_batch_size: + for i in process_slice: + yield current_batch[i] + if first_batch is None: + first_batch = current_batch.copy() + current_batch = [] + + # Finished if drop_last is True, otherwise complete the last batch with elements from the beginning. + if not self.drop_last and len(current_batch) > 0: + if first_batch is None: + first_batch = current_batch.copy() + while len(current_batch) < real_batch_size: + current_batch += first_batch + for i in process_slice: + yield current_batch[i] + + def __len__(self): + # Will raise an error if the underlying dataset is not sized. + if self.drop_last: + return (len(self.dataset) // (self.batch_size * self.num_processes)) * self.batch_size + else: + return math.ceil(len(self.dataset) / (self.batch_size * self.num_processes)) * self.batch_size + + +# In order to keep `trainer.py` compact and easy to understand, place any secondary PT Trainer +# helper methods here + + +def _get_learning_rate(self): + if self.is_deepspeed_enabled: + # with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may + # not run for the first few dozen steps while loss scale is too large, and thus during + # that time `get_last_lr` will fail if called during that warm up stage, so work around it: + try: + last_lr = self.lr_scheduler.get_last_lr()[0] + except AssertionError as e: + if "need to call step" in str(e): + logger.warning("tried to get lr value before scheduler/optimizer started stepping, returning lr=0") + last_lr = 0 + else: + raise + else: + if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + last_lr = self.optimizer.param_groups[0]["lr"] + else: + last_lr = self.lr_scheduler.get_last_lr()[0] + if torch.is_tensor(last_lr): + last_lr = last_lr.item() + return last_lr + + +def _secs2timedelta(secs): + """ + convert seconds to hh:mm:ss.msec, msecs rounded to 2 decimals + """ + + msec = int(abs(secs - int(secs)) * 100) + return f"{datetime.timedelta(seconds=int(secs))}.{msec:02d}" + + +def metrics_format(self, metrics: Dict[str, float]) -> Dict[str, float]: + """ + Reformat Trainer metrics values to a human-readable format + + Args: + metrics (`Dict[str, float]`): + The metrics returned from train/evaluate/predict + + Returns: + metrics (`Dict[str, float]`): The reformatted metrics + """ + + metrics_copy = metrics.copy() + for k, v in metrics_copy.items(): + if "_mem_" in k: + metrics_copy[k] = f"{ v >> 20 }MB" + elif "_runtime" in k: + metrics_copy[k] = _secs2timedelta(v) + elif k == "total_flos": + metrics_copy[k] = f"{ int(v) >> 30 }GF" + elif isinstance(metrics_copy[k], float): + metrics_copy[k] = round(v, 4) + + return metrics_copy + + +def log_metrics(self, split, metrics): + """ + Log metrics in a specially formatted way + + Under distributed environment this is done only for a process with rank 0. + + Args: + split (`str`): + Mode/split name: one of `train`, `eval`, `test` + metrics (`Dict[str, float]`): + The metrics returned from train/evaluate/predictmetrics: metrics dict + + Notes on memory reports: + + In order to get memory usage report you need to install `psutil`. You can do that with `pip install psutil`. + + Now when this method is run, you will see a report that will include: : + + ``` + init_mem_cpu_alloc_delta = 1301MB + init_mem_cpu_peaked_delta = 154MB + init_mem_gpu_alloc_delta = 230MB + init_mem_gpu_peaked_delta = 0MB + train_mem_cpu_alloc_delta = 1345MB + train_mem_cpu_peaked_delta = 0MB + train_mem_gpu_alloc_delta = 693MB + train_mem_gpu_peaked_delta = 7MB + ``` + + **Understanding the reports:** + + - the first segment, e.g., `train__`, tells you which stage the metrics are for. Reports starting with `init_` + will be added to the first stage that gets run. So that if only evaluation is run, the memory usage for the + `__init__` will be reported along with the `eval_` metrics. + - the third segment, is either `cpu` or `gpu`, tells you whether it's the general RAM or the gpu0 memory + metric. + - `*_alloc_delta` - is the difference in the used/allocated memory counter between the end and the start of the + stage - it can be negative if a function released more memory than it allocated. + - `*_peaked_delta` - is any extra memory that was consumed and then freed - relative to the current allocated + memory counter - it is never negative. When you look at the metrics of any stage you add up `alloc_delta` + + `peaked_delta` and you know how much memory was needed to complete that stage. + + The reporting happens only for process of rank 0 and gpu 0 (if there is a gpu). Typically this is enough since the + main process does the bulk of work, but it could be not quite so if model parallel is used and then other GPUs may + use a different amount of gpu memory. This is also not the same under DataParallel where gpu0 may require much more + memory than the rest since it stores the gradient and optimizer states for all participating GPUS. Perhaps in the + future these reports will evolve to measure those too. + + The CPU RAM metric measures RSS (Resident Set Size) includes both the memory which is unique to the process and the + memory shared with other processes. It is important to note that it does not include swapped out memory, so the + reports could be imprecise. + + The CPU peak memory is measured using a sampling thread. Due to python's GIL it may miss some of the peak memory if + that thread didn't get a chance to run when the highest memory was used. Therefore this report can be less than + reality. Using `tracemalloc` would have reported the exact peak memory, but it doesn't report memory allocations + outside of python. So if some C++ CUDA extension allocated its own memory it won't be reported. And therefore it + was dropped in favor of the memory sampling approach, which reads the current process memory usage. + + The GPU allocated and peak memory reporting is done with `torch.cuda.memory_allocated()` and + `torch.cuda.max_memory_allocated()`. This metric reports only "deltas" for pytorch-specific allocations, as + `torch.cuda` memory management system doesn't track any memory allocated outside of pytorch. For example, the very + first cuda call typically loads CUDA kernels, which may take from 0.5 to 2GB of GPU memory. + + Note that this tracker doesn't account for memory allocations outside of [`Trainer`]'s `__init__`, `train`, + `evaluate` and `predict` calls. + + Because `evaluation` calls may happen during `train`, we can't handle nested invocations because + `torch.cuda.max_memory_allocated` is a single counter, so if it gets reset by a nested eval call, `train`'s tracker + will report incorrect info. If this [pytorch issue](https://github.com/pytorch/pytorch/issues/16266) gets resolved + it will be possible to change this class to be re-entrant. Until then we will only track the outer level of + `train`, `evaluate` and `predict` methods. Which means that if `eval` is called during `train`, it's the latter + that will account for its memory usage and that of the former. + + This also means that if any other tool that is used along the [`Trainer`] calls + `torch.cuda.reset_peak_memory_stats`, the gpu peak memory stats could be invalid. And the [`Trainer`] will disrupt + the normal behavior of any such tools that rely on calling `torch.cuda.reset_peak_memory_stats` themselves. + + For best performance you may want to consider turning the memory profiling off for production runs. + """ + if not self.is_world_process_zero(): + return + + print(f"***** {split} metrics *****") + metrics_formatted = self.metrics_format(metrics) + k_width = max(len(str(x)) for x in metrics_formatted.keys()) + v_width = max(len(str(x)) for x in metrics_formatted.values()) + for key in sorted(metrics_formatted.keys()): + print(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}") + + +def save_metrics(self, split, metrics, combined=True): + """ + Save metrics into a json file for that split, e.g. `train_results.json`. + + Under distributed environment this is done only for a process with rank 0. + + Args: + split (`str`): + Mode/split name: one of `train`, `eval`, `test`, `all` + metrics (`Dict[str, float]`): + The metrics returned from train/evaluate/predict + combined (`bool`, *optional*, defaults to `True`): + Creates combined metrics by updating `all_results.json` with metrics of this call + + To understand the metrics please read the docstring of [`~Trainer.log_metrics`]. The only difference is that raw + unformatted numbers are saved in the current method. + + """ + if not self.is_world_process_zero(): + return + + path = os.path.join(self.args.output_dir, f"{split}_results.json") + with open(path, "w") as f: + json.dump(metrics, f, indent=4, sort_keys=True) + + if combined: + path = os.path.join(self.args.output_dir, "all_results.json") + if os.path.exists(path): + with open(path, "r") as f: + all_metrics = json.load(f) + else: + all_metrics = {} + + all_metrics.update(metrics) + with open(path, "w") as f: + json.dump(all_metrics, f, indent=4, sort_keys=True) + + +def save_state(self): + """ + Saves the Trainer state, since Trainer.save_model saves only the tokenizer with the model + + Under distributed environment this is done only for a process with rank 0. + """ + if not self.is_world_process_zero(): + return + + path = os.path.join(self.args.output_dir, "trainer_state.json") + self.state.save_to_json(path) + + +def get_model_param_count(model, trainable_only=False): + """ + Calculate model's total param count. If trainable_only is True then count only those requiring grads + """ + if is_deepspeed_zero3_enabled(): + + def numel(p): + return p.ds_numel if hasattr(p, "ds_numel") else p.numel() + + else: + + def numel(p): + return p.numel() + + return sum(numel(p) for p in model.parameters() if not trainable_only or p.requires_grad) + + +def get_parameter_names(model, forbidden_layer_types): + """ + Returns the names of the model parameters that are not inside a forbidden layer. + """ + result = [] + for name, child in model.named_children(): + result += [ + f"{name}.{n}" + for n in get_parameter_names(child, forbidden_layer_types) + if not isinstance(child, tuple(forbidden_layer_types)) + ] + # Add model specific parameters (defined with nn.Parameter) since they are not in any child. + result += list(model._parameters.keys()) + return result + + +def get_module_class_from_name(module, name): + """ + Gets a class from a module by its name. + + Args: + module (`torch.nn.Module`): The module to get the class from. + name (`str`): The name of the class. + """ + modules_children = list(module.children()) + if module.__class__.__name__ == name: + return module.__class__ + elif len(modules_children) == 0: + return + else: + for child_module in modules_children: + module_class = get_module_class_from_name(child_module, name) + if module_class is not None: + return module_class + + +def remove_dummy_checkpoint(is_main_process, output_dir, filenames): + if is_main_process: + for filename in filenames: + file = os.path.join(output_dir, filename) + if os.path.isfile(file): + os.remove(file) + + +if is_sagemaker_mp_enabled(): + import smdistributed.modelparallel.torch as smp + + @smp.step() + def smp_forward_backward(model, inputs, gradient_accumulation_steps=1): + outputs = model(**inputs) + loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + loss /= gradient_accumulation_steps + model.backward(loss) + return loss + + @smp.step() + def smp_forward_only(model, inputs): + return model(**inputs) + + def smp_gather(tensor): + if isinstance(tensor, (list, tuple)): + return type(tensor)(smp_gather(t) for t in tensor) + elif isinstance(tensor, dict): + return type(tensor)({k: smp_gather(v) for k, v in tensor.items()}) + elif not isinstance(tensor, torch.Tensor): + raise TypeError( + f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors." + ) + all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP) + all_tensors = [atleast_1d(t) for t in all_tensors] + return torch.cat([t.cpu() for t in all_tensors], dim=0) + + def smp_nested_concat(tensor): + if isinstance(tensor, (list, tuple)): + return type(tensor)(smp_nested_concat(t) for t in tensor) + elif isinstance(tensor, dict): + return type(tensor)({k: smp_nested_concat(v) for k, v in tensor.items()}) + # It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step` + # which is also the name of the decorator so Python is confused. + return tensor.concat().detach().cpu() + + +@dataclass +class AcceleratorConfig: + """ + A subset of arguments relating to the underlying [`accelerate.Accelerator`] + implementation utilized in the `Trainer` that can be customized. + Mostly relating to data. + + Parameters: + split_batches (`bool`, *optional*, defaults to `False`): + Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If + `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a + round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set + in your script multiplied by the number of processes. + dispatch_batches (`bool`, *optional*): + If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process + and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose + underlying dataset is an `IterableDataset`, `False` otherwise. + even_batches (`bool`, *optional*, defaults to `True`): + If set to `True`, in cases where the total batch size across all processes does not exactly divide the + dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among + all workers. + use_seedable_sampler (`bool`, *optional*, defaults to `True`): + Whether or not use a fully seedable random sampler ([`accelerate.data_loader.SeedableRandomSampler`]). Ensures + training results are fully reproducable using a different sampling technique. While seed-to-seed results + may differ, on average the differences are neglible when using multiple different seeds to compare. Should + also be ran with [`~utils.set_seed`] for the best results. + gradient_accumulation_kwargs (`dict`, *optional*): + Additional kwargs to configure gradient accumulation, see [`accelerate.utils.GradientAccumulationPlugin`]. + Any of the following (optional) keys are acceptable: + num_steps (`int`): Will take precedence over [`~.TrainingArguments.gradient_accumulation_steps`] if + the latter is set to 1, otherwise an exception will be raised. + adjust_scheduler (`bool`): Whether to adjust the scheduler steps to account for [`~.TrainingArguments.gradient_accumulation_steps`]. + The [`accelerate.utils.GradientAccumulationPlugin`] default is `True`. + sync_each_batch (`bool`): Whether to synchronize the gradients at each data batch. + The [`accelerate.utils.GradientAccumulationPlugin`] default is `False`. + non_blocking (`bool`, *optional*, defaults to `False`): + Whether to use non-blocking CUDA calls to help minimize synchronization during + distributed training with prepared `DataLoader` inputs being moved to device. + Best if used with `pin_memory=True` in the `TrainingArguments`. + use_configured_state (`bool*, *optional*, defaults to `False`): + Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined + before calling `TrainingArguments`. If `True`, an `Accelerator` or `PartialState` + must be initialized. May lead to issues using sweeps or hyperparameter tuning. + + """ + + # Data related arguments + split_batches: bool = field( + default=False, + metadata={ + "help": "Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If" + " `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a" + " round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set" + " in your script multiplied by the number of processes." + }, + ) + dispatch_batches: bool = field( + default=None, + metadata={ + "help": "If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process" + " and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose" + " underlying dataset is an `IterableDataslet`, `False` otherwise." + }, + ) + even_batches: bool = field( + default=True, + metadata={ + "help": "If set to `True`, in cases where the total batch size across all processes does not exactly divide the" + " dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among" + " all workers." + }, + ) + use_seedable_sampler: bool = field( + default=True, + metadata={ + "help": "Whether or not use a fully seedable random sampler ([`accelerate.data_loader.SeedableRandomSampler`])." + "Ensures training results are fully reproducable using a different sampling technique. " + "While seed-to-seed results may differ, on average the differences are neglible when using" + "multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results." + }, + ) + + non_blocking: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether to use non-blocking CUDA calls to help minimize synchronization during " + "distributed training with prepared `DataLoader` inputs being moved to device. " + "Best if used with `pin_memory=True` in the `TrainingArguments`. Requires accelerate " + "v0.30.0." + }, + ) + + gradient_accumulation_kwargs: Optional[Dict] = field( + default=None, + metadata={ + "help": "Additional kwargs to configure gradient accumulation, see [`accelerate.utils.GradientAccumulationPlugin`]. " + "Any of the following (optional) keys are acceptable: " + " num_steps (`int`): Will take precedence over [`~.TrainingArguments.gradient_accumulation_steps`] if " + " the latter is set to 1, otherwise an exception will be raised. " + " adjust_scheduler (`bool`): Whether to adjust the scheduler steps to account for [`~.TrainingArguments.gradient_accumulation_steps`]. " + " The [`accelerate.utils.GradientAccumulationPlugin`] default is `True`. " + " sync_each_batch (`bool`): Whether to synchronize the gradients at each data batch. " + " The [`accelerate.utils.GradientAccumulationPlugin`] default is `False`." + }, + ) + use_configured_state: bool = field( + default=False, + metadata={ + "help": "Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined before calling `TrainingArguments`." + "If `True`, an `Accelerator` or `PartialState` must be initialized. May lead to issues using sweeps or hyperparameter tuning." + }, + ) + + @classmethod + def from_json_file(cls, json_file): + # Check if exists + open_file = io.open if os.path.exists(json_file) else open + with open_file(json_file, "r", encoding="utf-8") as f: + config_dict = json.load(f) + # Check for keys and load sensible defaults + extra_keys = sorted(key for key in config_dict.keys() if key not in cls.__dataclass_fields__.keys()) + if len(extra_keys) > 0: + raise ValueError( + f"The config file at {json_file} had unknown keys ({extra_keys}), please try upgrading your `transformers`" + " version or fix (and potentially remove these keys) from your config file." + ) + return cls(**config_dict) + + def to_dict(self): + return copy.deepcopy(self.__dict__) + + def pop(self, key, default=None): + return self.__dict__.pop(key, default) + + +class LayerWiseDummyOptimizer(torch.optim.Optimizer): + """ + For Layer-wise optimizers such as GaLoRE optimizer, the optimization + step is already done through the post gradient hooks. Therefore + the trick is to create a dummy optimizer that can take arbitrary + args and kwargs and return a no-op during training. + + Initial idea from @hiyouga in LLaMA-Factory: + https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba + """ + + def __init__(self, optimizer_dict=None, *args, **kwargs): + dummy_tensor = torch.randn(1, 1) + self.optimizer_dict = optimizer_dict + super().__init__([dummy_tensor], {"lr": kwargs.get("lr", 1e-03)}) + + def zero_grad(self, set_to_none: bool = True) -> None: + pass + + def step(self, closure=None) -> Optional[float]: + pass + + +class LayerWiseDummyScheduler(LRScheduler): + """ + For Layer-wise optimizers such as GaLoRE optimizer, the optimization and scheduling step + are already done through the post gradient hooks. Therefore + the trick is to create a dummy scheduler that can take arbitrary + args and kwargs and return a no-op during training. + """ + + def __init__(self, *args, **kwargs): + self.default_lr = kwargs["lr"] + optimizer = LayerWiseDummyOptimizer(**kwargs) + last_epoch = -1 + verbose = False + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + # default value + lrs = [self.default_lr] + + # we take each lr in the parameters if they exist, assumes the optimizer to be the `LayerWiseDummyOptimizer` + if self.optimizer is not None: + param_wise_lrs = [ + [group["lr"] for group in optim.param_groups] for optim in self.optimizer.optimizer_dict.values() + ] + lrs = list(chain(*param_wise_lrs)) + + return lrs + + def _get_closed_form_lr(self): + return self.base_lrs diff --git a/trainer_seq2seq.py b/trainer_seq2seq.py new file mode 100644 index 0000000000000000000000000000000000000000..76b7c1556d8a4786b6cd800152beb69be5acefb3 --- /dev/null +++ b/trainer_seq2seq.py @@ -0,0 +1,392 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import warnings +from copy import deepcopy +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel +from torch.utils.data import Dataset + +from .generation.configuration_utils import GenerationConfig +from .integrations.deepspeed import is_deepspeed_zero3_enabled +from .integrations.fsdp import is_fsdp_managed_module +from .trainer import Trainer +from .utils import is_datasets_available, logging +from .utils.deprecation import deprecate_kwarg + + +if is_datasets_available(): + import datasets + +if TYPE_CHECKING: + from torch.utils.data import IterableDataset + + from .data.data_collator import DataCollator + from .feature_extraction_utils import FeatureExtractionMixin + from .image_processing_utils import BaseImageProcessor + from .modeling_utils import PreTrainedModel + from .processing_utils import ProcessorMixin + from .tokenization_utils_base import PreTrainedTokenizerBase + from .trainer_callback import TrainerCallback + from .trainer_utils import EvalPrediction, PredictionOutput + from .training_args import TrainingArguments + + +logger = logging.get_logger(__name__) + + +class Seq2SeqTrainer(Trainer): + @deprecate_kwarg("tokenizer", new_name="processing_class", version="5.0.0", raise_if_both_names=True) + def __init__( + self, + model: Union["PreTrainedModel", nn.Module] = None, + args: "TrainingArguments" = None, + data_collator: Optional["DataCollator"] = None, + train_dataset: Optional[Union[Dataset, "IterableDataset", "datasets.Dataset"]] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + processing_class: Optional[ + Union["PreTrainedTokenizerBase", "BaseImageProcessor", "FeatureExtractionMixin", "ProcessorMixin"] + ] = None, + model_init: Optional[Callable[[], "PreTrainedModel"]] = None, + compute_loss_func: Optional[Callable] = None, + compute_metrics: Optional[Callable[["EvalPrediction"], Dict]] = None, + callbacks: Optional[List["TrainerCallback"]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + ): + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + model_init=model_init, + compute_loss_func=compute_loss_func, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Override self.model.generation_config if a GenerationConfig is specified in args. + # Priority: args.generation_config > model.generation_config > default GenerationConfig. + if self.args.generation_config is not None: + gen_config = self.load_generation_config(self.args.generation_config) + self.model.generation_config = gen_config + + @staticmethod + def load_generation_config(gen_config_arg: Union[str, GenerationConfig]) -> GenerationConfig: + """ + Loads a `~generation.GenerationConfig` from the `Seq2SeqTrainingArguments.generation_config` arguments. + + Args: + gen_config_arg (`str` or [`~generation.GenerationConfig]`): + `Seq2SeqTrainingArguments.generation_config` argument. + + Returns: + A `~generation.GenerationConfig`. + """ + + # GenerationConfig provided, nothing to do + if isinstance(gen_config_arg, GenerationConfig): + gen_config = deepcopy(gen_config_arg) + else: + # str or Path + pretrained_model_name = Path(gen_config_arg) if isinstance(gen_config_arg, str) else gen_config_arg + config_file_name = None + + # Figuring if it is path pointing to a file, pointing to a directory or else a model id or URL + # This step is required in order to determine config_file_name + if pretrained_model_name.is_file(): + config_file_name = pretrained_model_name.name + pretrained_model_name = pretrained_model_name.parent + # dir path + elif pretrained_model_name.is_dir(): + pass + # model id or URL + else: + pretrained_model_name = gen_config_arg + + gen_config = GenerationConfig.from_pretrained(pretrained_model_name, config_file_name) + + # Strict validation to fail early. `GenerationConfig.save_pretrained()`, run at the end of training, throws + # an exception if there are warnings at validation time. + try: + with warnings.catch_warnings(record=True) as caught_warnings: + gen_config.validate() + if len(caught_warnings) > 0: + raise ValueError(str([w.message for w in caught_warnings])) + except ValueError as exc: + raise ValueError( + "The loaded generation config instance is invalid -- `GenerationConfig.validate()` throws warnings " + "and/or exceptions. Fix these issues to train your model.\n\nThrown during validation:\n" + str(exc) + ) + return gen_config + + def evaluate( + self, + eval_dataset: Optional[Dataset] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + **gen_kwargs, + ) -> Dict[str, float]: + """ + Run evaluation and returns metrics. + + The calling script will be responsible for providing a method to compute metrics, as they are task-dependent + (pass it to the init `compute_metrics` argument). + + You can also subclass and override this method to inject custom behavior. + + Args: + eval_dataset (`Dataset`, *optional*): + Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns + not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__` + method. + ignore_keys (`List[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + metric_key_prefix (`str`, *optional*, defaults to `"eval"`): + An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named + "eval_bleu" if the prefix is `"eval"` (default) + max_length (`int`, *optional*): + The maximum target length to use when predicting with the generate method. + num_beams (`int`, *optional*): + Number of beams for beam search that will be used when predicting with the generate method. 1 means no + beam search. + gen_kwargs: + Additional `generate` specific kwargs. + + Returns: + A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The + dictionary also contains the epoch number which comes from the training state. + """ + + gen_kwargs = gen_kwargs.copy() + + # Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the + # training args + if ( + gen_kwargs.get("max_length") is None + and gen_kwargs.get("max_new_tokens") is None + and self.args.generation_max_length is not None + ): + gen_kwargs["max_length"] = self.args.generation_max_length + if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None: + gen_kwargs["num_beams"] = self.args.generation_num_beams + # We don't want to drop samples in general + self.gather_function = self.accelerator.gather + self._gen_kwargs = gen_kwargs + return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) + + def predict( + self, + test_dataset: Dataset, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "test", + **gen_kwargs, + ) -> "PredictionOutput": + """ + Run prediction and returns predictions and potential metrics. + + Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method + will also return metrics, like in `evaluate()`. + + Args: + test_dataset (`Dataset`): + Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the + `model.forward()` method are automatically removed. Has to implement the method `__len__` + ignore_keys (`List[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + metric_key_prefix (`str`, *optional*, defaults to `"eval"`): + An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named + "eval_bleu" if the prefix is `"eval"` (default) + max_length (`int`, *optional*): + The maximum target length to use when predicting with the generate method. + num_beams (`int`, *optional*): + Number of beams for beam search that will be used when predicting with the generate method. 1 means no + beam search. + gen_kwargs: + Additional `generate` specific kwargs. + + + + If your predictions or labels have different sequence lengths (for instance because you're doing dynamic + padding in a token classification task) the predictions will be padded (on the right) to allow for + concatenation into one array. The padding index is -100. + + + + Returns: *NamedTuple* A namedtuple with the following keys: + + - predictions (`np.ndarray`): The predictions on `test_dataset`. + - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some). + - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained + labels). + """ + + gen_kwargs = gen_kwargs.copy() + + # Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the + # training args + if ( + gen_kwargs.get("max_length") is None + and gen_kwargs.get("max_new_tokens") is None + and self.args.generation_max_length is not None + ): + gen_kwargs["max_length"] = self.args.generation_max_length + if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None: + gen_kwargs["num_beams"] = self.args.generation_num_beams + self.gather_function = self.accelerator.gather + self._gen_kwargs = gen_kwargs + + return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) + + def prediction_step( + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + **gen_kwargs, + ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Perform an evaluation step on `model` using `inputs`. + + Subclass and override to inject custom behavior. + + Args: + model (`nn.Module`): + The model to evaluate. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + prediction_loss_only (`bool`): + Whether or not to return the loss only. + gen_kwargs: + Additional `generate` specific kwargs. + + Return: + Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and + labels (each being optional). + """ + + if not self.args.predict_with_generate or prediction_loss_only: + return super().prediction_step( + model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys + ) + + has_labels = "labels" in inputs + inputs = self._prepare_inputs(inputs) + + # Priority (handled in generate): + # non-`None` gen_kwargs > model.generation_config > default GenerationConfig() + if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"): + gen_kwargs = self._gen_kwargs.copy() + if "num_beams" in gen_kwargs and gen_kwargs["num_beams"] is None: + gen_kwargs.pop("num_beams") + if "max_length" in gen_kwargs and gen_kwargs["max_length"] is None: + gen_kwargs.pop("max_length") + + default_synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self.model) + gen_kwargs["synced_gpus"] = gen_kwargs.get("synced_gpus", default_synced_gpus) + + generation_inputs = inputs.copy() + # If the `decoder_input_ids` was created from `labels`, evict the former, so that the model can freely generate + # (otherwise, it would continue generating from the padded `decoder_input_ids`) + if ( + "labels" in generation_inputs + and "decoder_input_ids" in generation_inputs + and generation_inputs["labels"].shape == generation_inputs["decoder_input_ids"].shape + ): + generation_inputs = { + k: v for k, v in inputs.items() if k not in ("decoder_input_ids", "decoder_attention_mask") + } + + summon_full_params_context = ( + FullyShardedDataParallel.summon_full_params(self.model) + if isinstance(self.model, FullyShardedDataParallel) + else contextlib.nullcontext() + ) + + with summon_full_params_context: + generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs) + + # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop + # TODO: remove this hack when the legacy code that initializes generation_config from a model config is + # removed in https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183 + if self.model.generation_config._from_model_config: + self.model.generation_config._from_model_config = False + + # Retrieves GenerationConfig from model.generation_config + gen_config = self.model.generation_config + # in case the batch is shorter than max length, the output should be padded + if generated_tokens.shape[-1] < gen_config.max_length: + generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length) + elif gen_config.max_new_tokens is not None and generated_tokens.shape[-1] < gen_config.max_new_tokens + 1: + generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_new_tokens + 1) + + with torch.no_grad(): + if has_labels: + with self.compute_loss_context_manager(): + outputs = model(**inputs) + if self.label_smoother is not None: + loss = self.label_smoother(outputs, inputs["labels"]).mean().detach() + else: + loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach() + else: + loss = None + + if self.args.prediction_loss_only: + return loss, None, None + + if has_labels: + labels = inputs["labels"] + if labels.shape[-1] < gen_config.max_length: + labels = self._pad_tensors_to_max_len(labels, gen_config.max_length) + elif gen_config.max_new_tokens is not None and labels.shape[-1] < gen_config.max_new_tokens + 1: + labels = self._pad_tensors_to_max_len(labels, gen_config.max_new_tokens + 1) + else: + labels = None + + return loss, generated_tokens, labels + + def _pad_tensors_to_max_len(self, tensor, max_length): + if self.processing_class is not None and hasattr(self.processing_class, "pad_token_id"): + # If PAD token is not defined at least EOS token has to be defined + pad_token_id = ( + self.processing_class.pad_token_id + if self.processing_class.pad_token_id is not None + else self.processing_class.eos_token_id + ) + else: + if self.model.config.pad_token_id is not None: + pad_token_id = self.model.config.pad_token_id + else: + raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors") + + padded_tensor = pad_token_id * torch.ones( + (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device + ) + padded_tensor[:, : tensor.shape[-1]] = tensor + return padded_tensor diff --git a/trainer_utils.py b/trainer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..42088cd730628db229bd8accbcac74864f435f72 --- /dev/null +++ b/trainer_utils.py @@ -0,0 +1,887 @@ +# coding=utf-8 +# Copyright 2020-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +PyTorch-independent utilities for the Trainer class. +""" + +import copy +import functools +import gc +import inspect +import os +import random +import re +import threading +import time +from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union + +import numpy as np + +from .utils import ( + ExplicitEnum, + is_psutil_available, + is_tf_available, + is_torch_available, + is_torch_cuda_available, + is_torch_mlu_available, + is_torch_mps_available, + is_torch_musa_available, + is_torch_npu_available, + is_torch_xla_available, + is_torch_xpu_available, + requires_backends, +) + + +if is_torch_available(): + import torch + + +def seed_worker(_): + """ + Helper function to set worker seed during Dataloader initialization. + """ + worker_seed = torch.initial_seed() % 2**32 + set_seed(worker_seed) + + +def enable_full_determinism(seed: int, warn_only: bool = False): + """ + Helper function for reproducible behavior during distributed training. See + - https://pytorch.org/docs/stable/notes/randomness.html for pytorch + - https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_op_determinism for tensorflow + """ + # set seed first + set_seed(seed) + + if is_torch_available(): + # Enable PyTorch deterministic mode. This potentially requires either the environment + # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set, + # depending on the CUDA version, so we set them both here + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" + # The environment variable required to enable deterministic mode on Ascend NPUs. + os.environ["ASCEND_LAUNCH_BLOCKING"] = "1" + os.environ["HCCL_DETERMINISTIC"] = "1" + + os.environ["FLASH_ATTENTION_DETERMINISTIC"] = "1" + torch.use_deterministic_algorithms(True, warn_only=warn_only) + + # Enable CUDNN deterministic mode + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + if is_tf_available(): + import tensorflow as tf + + tf.config.experimental.enable_op_determinism() + + +def set_seed(seed: int, deterministic: bool = False): + """ + Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch` and/or `tf` (if installed). + + Args: + seed (`int`): + The seed to set. + deterministic (`bool`, *optional*, defaults to `False`): + Whether to use deterministic algorithms where available. Can slow down training. + """ + random.seed(seed) + np.random.seed(seed) + if is_torch_available(): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # ^^ safe to call this function even if cuda is not available + if deterministic: + torch.use_deterministic_algorithms(True) + if is_torch_mlu_available(): + torch.mlu.manual_seed_all(seed) + if is_torch_musa_available(): + torch.musa.manual_seed_all(seed) + if is_torch_npu_available(): + torch.npu.manual_seed_all(seed) + if is_torch_xpu_available(): + torch.xpu.manual_seed_all(seed) + if is_tf_available(): + import tensorflow as tf + + tf.random.set_seed(seed) + if deterministic: + tf.config.experimental.enable_op_determinism() + + +def neftune_post_forward_hook(module, input, output): + """ + Implements the NEFTune forward pass for the model using forward hooks. Note this works only for torch.nn.Embedding + layers. This method is slightly adapted from the original source code that can be found here: + https://github.com/neelsjain/NEFTune Simply add it to your model as follows: + ```python + model = ... + model.embed_tokens.neftune_noise_alpha = 0.1 + model.embed_tokens.register_forward_hook(neftune_post_forward_hook) + ``` + Args: + module (`torch.nn.Module`): + The embedding module where the hook is attached. Note that you need to set `module.neftune_noise_alpha` to + the desired noise alpha value. + input (`torch.Tensor`): + The input tensor to the model. + output (`torch.Tensor`): + The output tensor of the model (i.e. the embeddings). + """ + if module.training: + dims = torch.tensor(output.size(1) * output.size(2)) + mag_norm = module.neftune_noise_alpha / torch.sqrt(dims) + output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm) + return output + + +class EvalPrediction: + """ + Evaluation output (always contains labels), to be used to compute metrics. + + Parameters: + predictions (`np.ndarray`): Predictions of the model. + label_ids (`np.ndarray`): Targets to be matched. + inputs (`np.ndarray`, *optional*): Input data passed to the model. + losses (`np.ndarray`, *optional*): Loss values computed during evaluation. + """ + + def __init__( + self, + predictions: Union[np.ndarray, Tuple[np.ndarray]], + label_ids: Union[np.ndarray, Tuple[np.ndarray]], + inputs: Optional[Union[np.ndarray, Tuple[np.ndarray]]] = None, + losses: Optional[Union[np.ndarray, Tuple[np.ndarray]]] = None, + ): + self.predictions = predictions + self.label_ids = label_ids + self.inputs = inputs + self.losses = losses + self.elements = (self.predictions, self.label_ids) + if self.inputs is not None: + self.elements += (self.inputs,) + if self.losses is not None: + self.elements += (self.losses,) + + def __iter__(self): + return iter(self.elements) + + def __getitem__(self, idx): + if idx < 0 or idx >= len(self.elements): + raise IndexError("tuple index out of range") + return self.elements[idx] + + +class EvalLoopOutput(NamedTuple): + predictions: Union[np.ndarray, Tuple[np.ndarray]] + label_ids: Optional[Union[np.ndarray, Tuple[np.ndarray]]] + metrics: Optional[Dict[str, float]] + num_samples: Optional[int] + + +class PredictionOutput(NamedTuple): + predictions: Union[np.ndarray, Tuple[np.ndarray]] + label_ids: Optional[Union[np.ndarray, Tuple[np.ndarray]]] + metrics: Optional[Dict[str, float]] + + +class TrainOutput(NamedTuple): + global_step: int + training_loss: float + metrics: Dict[str, float] + + +PREFIX_CHECKPOINT_DIR = "checkpoint" +_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$") + + +def get_last_checkpoint(folder): + content = os.listdir(folder) + checkpoints = [ + path + for path in content + if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path)) + ] + if len(checkpoints) == 0: + return + return os.path.join(folder, max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0]))) + + +class IntervalStrategy(ExplicitEnum): + NO = "no" + STEPS = "steps" + EPOCH = "epoch" + + +class SaveStrategy(ExplicitEnum): + NO = "no" + STEPS = "steps" + EPOCH = "epoch" + BEST = "best" + + +class EvaluationStrategy(ExplicitEnum): + NO = "no" + STEPS = "steps" + EPOCH = "epoch" + + +class HubStrategy(ExplicitEnum): + END = "end" + EVERY_SAVE = "every_save" + CHECKPOINT = "checkpoint" + ALL_CHECKPOINTS = "all_checkpoints" + + +class BestRun(NamedTuple): + """ + The best run found by a hyperparameter search (see [`~Trainer.hyperparameter_search`]). + + Parameters: + run_id (`str`): + The id of the best run (if models were saved, the corresponding checkpoint will be in the folder ending + with run-{run_id}). + objective (`float`): + The objective that was obtained for this run. + hyperparameters (`Dict[str, Any]`): + The hyperparameters picked to get this run. + run_summary (`Optional[Any]`): + A summary of tuning experiments. `ray.tune.ExperimentAnalysis` object for Ray backend. + """ + + run_id: str + objective: Union[float, List[float]] + hyperparameters: Dict[str, Any] + run_summary: Optional[Any] = None + + +def default_compute_objective(metrics: Dict[str, float]) -> float: + """ + The default objective to maximize/minimize when doing an hyperparameter search. It is the evaluation loss if no + metrics are provided to the [`Trainer`], the sum of all metrics otherwise. + + Args: + metrics (`Dict[str, float]`): The metrics returned by the evaluate method. + + Return: + `float`: The objective to minimize or maximize + """ + metrics = copy.deepcopy(metrics) + loss = metrics.pop("eval_loss", None) + _ = metrics.pop("epoch", None) + # Remove speed metrics + speed_metrics = [ + m + for m in metrics.keys() + if m.endswith("_runtime") or m.endswith("_per_second") or m.endswith("_compilation_time") + ] + for sm in speed_metrics: + _ = metrics.pop(sm, None) + return loss if len(metrics) == 0 else sum(metrics.values()) + + +def default_hp_space_optuna(trial) -> Dict[str, float]: + from .integrations import is_optuna_available + + assert is_optuna_available(), "This function needs Optuna installed: `pip install optuna`" + return { + "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True), + "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5), + "seed": trial.suggest_int("seed", 1, 40), + "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [4, 8, 16, 32, 64]), + } + + +def default_hp_space_ray(trial) -> Dict[str, float]: + from .integrations import is_ray_tune_available + + assert is_ray_tune_available(), "This function needs ray installed: `pip install ray[tune]`" + from ray import tune + + return { + "learning_rate": tune.loguniform(1e-6, 1e-4), + "num_train_epochs": tune.choice(list(range(1, 6))), + "seed": tune.uniform(1, 40), + "per_device_train_batch_size": tune.choice([4, 8, 16, 32, 64]), + } + + +def default_hp_space_sigopt(trial): + return [ + {"bounds": {"min": 1e-6, "max": 1e-4}, "name": "learning_rate", "type": "double", "transformamtion": "log"}, + {"bounds": {"min": 1, "max": 6}, "name": "num_train_epochs", "type": "int"}, + {"bounds": {"min": 1, "max": 40}, "name": "seed", "type": "int"}, + { + "categorical_values": ["4", "8", "16", "32", "64"], + "name": "per_device_train_batch_size", + "type": "categorical", + }, + ] + + +def default_hp_space_wandb(trial) -> Dict[str, float]: + from .integrations import is_wandb_available + + if not is_wandb_available(): + raise ImportError("This function needs wandb installed: `pip install wandb`") + + return { + "method": "random", + "metric": {"name": "objective", "goal": "minimize"}, + "parameters": { + "learning_rate": {"distribution": "uniform", "min": 1e-6, "max": 1e-4}, + "num_train_epochs": {"distribution": "int_uniform", "min": 1, "max": 6}, + "seed": {"distribution": "int_uniform", "min": 1, "max": 40}, + "per_device_train_batch_size": {"values": [4, 8, 16, 32, 64]}, + }, + } + + +class HPSearchBackend(ExplicitEnum): + OPTUNA = "optuna" + RAY = "ray" + SIGOPT = "sigopt" + WANDB = "wandb" + + +def is_main_process(local_rank): + """ + Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on + `local_rank`. + """ + if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + return xm.get_ordinal() == 0 + return local_rank in [-1, 0] + + +def total_processes_number(local_rank): + """ + Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs. + """ + if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + return xm.xrt_world_size() + elif local_rank != -1 and is_torch_available(): + import torch + + return torch.distributed.get_world_size() + return 1 + + +def speed_metrics(split, start_time, num_samples=None, num_steps=None, num_tokens=None): + """ + Measure and return speed performance metrics. + + This function requires a time snapshot `start_time` before the operation to be measured starts and this function + should be run immediately after the operation to be measured has completed. + + Args: + - split: name to prefix metric (like train, eval, test...) + - start_time: operation start time + - num_samples: number of samples processed + - num_steps: number of steps processed + - num_tokens: number of tokens processed + """ + runtime = time.time() - start_time + result = {f"{split}_runtime": round(runtime, 4)} + if runtime == 0: + return result + if num_samples is not None: + samples_per_second = num_samples / runtime + result[f"{split}_samples_per_second"] = round(samples_per_second, 3) + if num_steps is not None: + steps_per_second = num_steps / runtime + result[f"{split}_steps_per_second"] = round(steps_per_second, 3) + if num_tokens is not None: + tokens_per_second = num_tokens / runtime + result[f"{split}_tokens_per_second"] = round(tokens_per_second, 3) + return result + + +class SchedulerType(ExplicitEnum): + """ + Scheduler names for the parameter `lr_scheduler_type` in [`TrainingArguments`]. + By default, it uses "linear". Internally, this retrieves `get_linear_schedule_with_warmup` scheduler from [`Trainer`]. + Scheduler types: + - "linear" = get_linear_schedule_with_warmup + - "cosine" = get_cosine_schedule_with_warmup + - "cosine_with_restarts" = get_cosine_with_hard_restarts_schedule_with_warmup + - "polynomial" = get_polynomial_decay_schedule_with_warmup + - "constant" = get_constant_schedule + - "constant_with_warmup" = get_constant_schedule_with_warmup + - "inverse_sqrt" = get_inverse_sqrt_schedule + - "reduce_lr_on_plateau" = get_reduce_on_plateau_schedule + - "cosine_with_min_lr" = get_cosine_with_min_lr_schedule_with_warmup + - "warmup_stable_decay" = get_wsd_schedule + """ + + LINEAR = "linear" + COSINE = "cosine" + COSINE_WITH_RESTARTS = "cosine_with_restarts" + POLYNOMIAL = "polynomial" + CONSTANT = "constant" + CONSTANT_WITH_WARMUP = "constant_with_warmup" + INVERSE_SQRT = "inverse_sqrt" + REDUCE_ON_PLATEAU = "reduce_lr_on_plateau" + COSINE_WITH_MIN_LR = "cosine_with_min_lr" + WARMUP_STABLE_DECAY = "warmup_stable_decay" + + +class TrainerMemoryTracker: + """ + A helper class that tracks cpu and gpu memory. + + This class will silently skip unless `psutil` is available. Install with `pip install psutil`. + + When a stage completes, it can pass metrics dict to update with the memory metrics gathered during this stage. + + Example : + + ```python + self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics) + self._memory_tracker.start() + # code ... + metrics = {"train_runtime": 10.5} + self._memory_tracker.stop_and_update_metrics(metrics) + ``` + + At the moment GPU tracking is only for `pytorch`, but can be extended to support `tensorflow`. + + To understand this class' intricacies please read the documentation of [`~Trainer.log_metrics`]. + """ + + # map trainer methods to metrics prefix + stages = { + "__init__": "init", + "train": "train", + "_inner_training_loop": "train", + "evaluate": "eval", + "predict": "test", + } + + def __init__(self, skip_memory_metrics=False): + self.skip_memory_metrics = skip_memory_metrics + + if not is_psutil_available(): + # soft dependency on psutil + self.skip_memory_metrics = True + + if self.skip_memory_metrics: + return + + import psutil # noqa + + if is_torch_cuda_available() or is_torch_mlu_available() or is_torch_musa_available(): + import torch + + self.torch = torch + self.gpu = {} + elif is_torch_mps_available(): + import torch + + self.torch = torch + self.gpu = {} + elif is_torch_xpu_available(): + import torch + + self.torch = torch + self.gpu = {} + elif is_torch_npu_available(): + import torch + + self.torch = torch + self.gpu = {} + else: + self.torch = None + + self.process = psutil.Process() + + self.cur_stage = None + self.cpu = {} + self.init_reported = False + + def derive_stage(self): + """derives the stage/caller name automatically""" + caller = inspect.currentframe().f_back.f_back.f_code.co_name + if caller in self.stages: + return self.stages[caller] + else: + raise ValueError( + f"was called from {caller}, but only expect to be called from one of {self.stages.keys()}" + ) + + def cpu_mem_used(self): + """get resident set size memory for the current process""" + return self.process.memory_info().rss + + def peak_monitor_func(self): + self.cpu_mem_used_peak = -1 + + while True: + self.cpu_mem_used_peak = max(self.cpu_mem_used(), self.cpu_mem_used_peak) + + # can't sleep or will not catch the peak right (this comment is here on purpose) + # time.sleep(0.001) # 1msec + + if not self.peak_monitoring: + break + + def start(self): + """start tracking for the caller's stage""" + if self.skip_memory_metrics: + return + + stage = self.derive_stage() + # deal with nested calls of eval during train - simply ignore those + if self.cur_stage is not None and self.cur_stage != stage: + return + + self.cur_stage = stage + + gc.collect() + + if self.torch is not None: + if torch.cuda.is_available(): + self.torch.cuda.reset_peak_memory_stats() + self.torch.cuda.empty_cache() + elif is_torch_mlu_available(): + self.torch.mlu.reset_peak_memory_stats() + self.torch.mlu.empty_cache() + elif is_torch_musa_available(): + self.torch.musa.reset_peak_memory_stats() + self.torch.musa.empty_cache() + elif is_torch_xpu_available(): + self.torch.xpu.reset_peak_memory_stats() + self.torch.xpu.empty_cache() + elif is_torch_npu_available(): + self.torch.npu.reset_peak_memory_stats() + self.torch.npu.empty_cache() + elif is_torch_mps_available(): + self.torch.mps.empty_cache() + + # gpu + if self.torch is not None: + if torch.cuda.is_available(): + self.gpu_mem_used_at_start = self.torch.cuda.memory_allocated() + elif is_torch_mlu_available(): + self.gpu_mem_used_at_start = self.torch.mlu.memory_allocated() + elif is_torch_musa_available(): + self.gpu_mem_used_at_start = self.torch.musa.memory_allocated() + elif is_torch_xpu_available(): + self.gpu_mem_used_at_start = self.torch.xpu.memory_allocated() + elif is_torch_npu_available(): + self.gpu_mem_used_at_start = self.torch.npu.memory_allocated() + elif is_torch_mps_available(): + self.gpu_mem_used_at_start = self.torch.mps.current_allocated_memory() + + # cpu + self.cpu_mem_used_at_start = self.cpu_mem_used() + + self.peak_monitoring = True + peak_monitor_thread = threading.Thread(target=self.peak_monitor_func) + peak_monitor_thread.daemon = True + peak_monitor_thread.start() + + def stop(self, stage): + """stop tracking for the passed stage""" + + # deal with nested calls of eval during train - simply ignore those + if self.cur_stage is not None and self.cur_stage != stage: + return + + # this sends a signal to peak_monitor_func to complete its loop + self.peak_monitoring = False + + # first ensure all objects get collected and their memory is freed + gc.collect() + + if self.torch is not None: + if torch.cuda.is_available(): + self.torch.cuda.empty_cache() + elif is_torch_mlu_available(): + self.torch.mlu.empty_cache() + elif is_torch_musa_available(): + self.torch.musa.empty_cache() + elif is_torch_xpu_available(): + self.torch.xpu.empty_cache() + elif is_torch_npu_available(): + self.torch.npu.empty_cache() + elif is_torch_mps_available(): + self.torch.mps.empty_cache() + + # concepts: + # - alloc_delta: the difference of allocated memory between the end and the start + # - peaked_delta: the difference between the peak memory and the current memory + # in order to know how much memory the measured code consumed one needs to sum these two + + # gpu + if self.torch is not None: + if torch.cuda.is_available(): + self.gpu_mem_used_now = self.torch.cuda.memory_allocated() + self.gpu_mem_used_peak = self.torch.cuda.max_memory_allocated() + elif is_torch_mlu_available(): + self.gpu_mem_used_now = self.torch.mlu.memory_allocated() + self.gpu_mem_used_peak = self.torch.mlu.max_memory_allocated() + elif is_torch_musa_available(): + self.gpu_mem_used_now = self.torch.musa.memory_allocated() + self.gpu_mem_used_peak = self.torch.musa.max_memory_allocated() + elif is_torch_xpu_available(): + self.gpu_mem_used_now = self.torch.xpu.memory_allocated() + self.gpu_mem_used_peak = self.torch.xpu.max_memory_allocated() + elif is_torch_npu_available(): + self.gpu_mem_used_now = self.torch.npu.memory_allocated() + self.gpu_mem_used_peak = self.torch.npu.max_memory_allocated() + elif is_torch_mps_available(): + self.gpu_mem_used_now = self.torch.mps.current_allocated_memory() + # self.torch.mps.max_memory_allocated() does not exist yet + self.gpu_mem_used_peak = None + + else: + raise ValueError("No available GPU device found!") + + self.gpu[self.cur_stage] = { + "begin": self.gpu_mem_used_at_start, + "end": self.gpu_mem_used_now, + "alloc": (self.gpu_mem_used_now - self.gpu_mem_used_at_start), + } + if self.gpu_mem_used_peak is not None: + self.gpu[self.cur_stage]["peaked"] = max(0, self.gpu_mem_used_peak - self.gpu_mem_used_now) + else: + self.gpu[self.cur_stage]["peaked"] = "Not available" + + # cpu + self.cpu_mem_used_now = self.cpu_mem_used() + self.cpu[self.cur_stage] = { + "begin": self.cpu_mem_used_at_start, + "end": self.cpu_mem_used_now, + "alloc": (self.cpu_mem_used_now - self.cpu_mem_used_at_start), + "peaked": max(0, self.cpu_mem_used_peak - self.cpu_mem_used_now), + } + + # reset - cycle finished + self.cur_stage = None + + def update_metrics(self, stage, metrics): + """updates the metrics""" + if self.skip_memory_metrics: + return + + # deal with nested calls of eval during train - simply ignore those + if self.cur_stage is not None and self.cur_stage != stage: + return + + # since we don't have a way to return init metrics, we push them into the first of train/val/predict + stages = [stage] + if not self.init_reported: + stages.insert(0, "init") + self.init_reported = True + + for stage in stages: + for t in ["alloc", "peaked"]: + if stage in self.cpu and t in self.cpu[stage]: + metrics[f"{stage}_mem_cpu_{t}_delta"] = self.cpu[stage][t] + if self.torch is not None and stage in self.gpu and t in self.gpu[stage]: + metrics[f"{stage}_mem_gpu_{t}_delta"] = self.gpu[stage][t] + # if we need additional debug info, enable the following + # for t in ["begin", "end"]: + # if stage in self.cpu and t in self.cpu[stage]: + # metrics[f"{stage}_mem_cpu_{t}"] = self.cpu[stage][t] + # if self.torch is not None and stage in self.gpu and t in self.gpu[stage]: + # metrics[f"{stage}_mem_gpu_{t}"] = self.gpu[stage][t] + + # since memory can be allocated before init, and it might be difficult to track overall + # memory usage, in particular for GPU, let's report memory usage at the point init was called + if stages[0] == "init": + metrics["before_init_mem_cpu"] = self.cpu["init"]["begin"] + if self.torch is not None: + metrics["before_init_mem_gpu"] = self.gpu["init"]["begin"] + # if we also wanted to report any additional memory allocations in between init and + # whatever the next stage was we could also report this: + # if self.cpu["init"]["end"] != self.cpu[stage]["begin"]: + # metrics[f"after_init_mem_cpu_delta"] = self.cpu[stage]["begin"] - self.cpu["init"]["end"] + # if self.torch is not None and self.gpu["init"]["end"] != self.gpu[stage]["begin"]: + # metrics[f"after_init_mem_gpu_delta"] = self.gpu[stage]["begin"] - self.gpu["init"]["end"] + + def stop_and_update_metrics(self, metrics=None): + """combine stop and metrics update in one call for simpler code""" + if self.skip_memory_metrics: + return + + stage = self.derive_stage() + self.stop(stage) + + # init doesn't have metrics to update so we just save that data for later stages to retrieve + if metrics is not None: + self.update_metrics(stage, metrics) + + +def has_length(dataset): + """ + Checks if the dataset implements __len__() and it doesn't raise an error + """ + try: + return len(dataset) is not None + except TypeError: + # TypeError: len() of unsized object + return False + + +def denumpify_detensorize(metrics): + """ + Recursively calls `.item()` on the element of the dictionary passed + """ + if isinstance(metrics, (list, tuple)): + return type(metrics)(denumpify_detensorize(m) for m in metrics) + elif isinstance(metrics, dict): + return type(metrics)({k: denumpify_detensorize(v) for k, v in metrics.items()}) + elif isinstance(metrics, np.generic): + return metrics.item() + elif is_torch_available() and isinstance(metrics, torch.Tensor) and metrics.numel() == 1: + return metrics.item() + return metrics + + +def number_of_arguments(func): + """ + Return the number of arguments of the passed function, even if it's a partial function. + """ + if isinstance(func, functools.partial): + total_args = len(inspect.signature(func.func).parameters) + return total_args - len(func.args) - len(func.keywords) + return len(inspect.signature(func).parameters) + + +def find_executable_batch_size( + function: callable = None, starting_batch_size: int = 128, auto_find_batch_size: bool = False +): + """ + Args: + A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or + CUDNN, the batch size is cut in half and passed to `function`. `function` must take in a `batch_size` parameter as + its first argument. + function (`callable`, *optional*) + A function to wrap + starting_batch_size (`int`, *optional*) + The batch size to try and fit into memory + auto_find_batch_size (`bool`, *optional*) + If False, will just execute `function` + """ + if function is None: + return functools.partial( + find_executable_batch_size, + starting_batch_size=starting_batch_size, + auto_find_batch_size=auto_find_batch_size, + ) + + if auto_find_batch_size: + requires_backends(find_executable_batch_size, "accelerate") + from accelerate.utils import find_executable_batch_size as accelerate_find_executable_batch_size + + return accelerate_find_executable_batch_size(function=function, starting_batch_size=starting_batch_size) + + return functools.partial(function, batch_size=starting_batch_size) + + +class FSDPOption(ExplicitEnum): + FULL_SHARD = "full_shard" + SHARD_GRAD_OP = "shard_grad_op" + NO_SHARD = "no_shard" + HYBRID_SHARD = "hybrid_shard" + HYBRID_SHARD_ZERO2 = "hybrid_shard_zero2" + OFFLOAD = "offload" + AUTO_WRAP = "auto_wrap" + + +class RemoveColumnsCollator: + """Wrap the data collator to remove unused columns before they are passed to the collator.""" + + def __init__( + self, + data_collator, + signature_columns, + logger=None, + model_name: Optional[str] = None, + description: Optional[str] = None, + ): + self.data_collator = data_collator + self.signature_columns = signature_columns + self.logger = logger + self.description = description + self.model_name = model_name + self.message_logged = False + + def _remove_columns(self, feature: dict) -> dict: + if not isinstance(feature, dict): + return feature + if not self.message_logged and self.logger and self.model_name: + ignored_columns = list(set(feature.keys()) - set(self.signature_columns)) + if len(ignored_columns) > 0: + dset_description = "" if self.description is None else f"in the {self.description} set" + self.logger.info( + f"The following columns {dset_description} don't have a corresponding argument in " + f"`{self.model_name}.forward` and have been ignored: {', '.join(ignored_columns)}." + f" If {', '.join(ignored_columns)} are not expected by `{self.model_name}.forward`, " + " you can safely ignore this message." + ) + self.message_logged = True + return {k: v for k, v in feature.items() if k in self.signature_columns} + + def __call__(self, features: List[dict]): + features = [self._remove_columns(feature) for feature in features] + return self.data_collator(features) + + +def check_target_module_exists(optim_target_modules, key: str, return_is_regex: bool = False): + """A helper method to check if the passed module's key name matches any of the target modules in the optim_target_modules. + + Args: + optim_target_modules (`Union[str, List[str]]`): + A list of strings to try to match. Can be also a full string. + key (`str`): + A key to search any matches in optim_target_modules + return_is_regex (`bool`): + If set to `True`, the method will return whether the passed `optim_target_modules` + is a regex or not. + + Returns: + `bool` : True of match object if key matches any target modules from config, False or + None if no match found + `bool` : If the matched target module is a regex to silence out the warnings in Trainer + for extra modules being found (only if `target_module_found=True` for an array of regex). + """ + target_module_found = False + is_regex = False + + if isinstance(optim_target_modules, str): + target_module_found = bool(re.fullmatch(optim_target_modules, key)) + is_regex = True if not optim_target_modules == key else False + elif key in optim_target_modules: # from here, target_module_found must be a list of str + # this module is specified directly in target_modules + target_module_found = True + elif any(target_key in key for target_key in optim_target_modules): + target_module_found = True + elif any(bool(re.fullmatch(optim_target_module, key)) for optim_target_module in optim_target_modules): + target_module_found = True + is_regex = True + + if return_is_regex: + return target_module_found, is_regex + + return target_module_found diff --git a/training_args.py b/training_args.py new file mode 100644 index 0000000000000000000000000000000000000000..a7b2ba0db3a79e8cec3142a4cf9a04cd5a4d4169 --- /dev/null +++ b/training_args.py @@ -0,0 +1,3094 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import io +import json +import math +import os +import warnings +from dataclasses import asdict, dataclass, field, fields +from datetime import timedelta +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +from huggingface_hub import get_full_repo_name +from packaging import version + +from .debug_utils import DebugOption +from .trainer_utils import ( + EvaluationStrategy, + FSDPOption, + HubStrategy, + IntervalStrategy, + SaveStrategy, + SchedulerType, +) +from .utils import ( + ACCELERATE_MIN_VERSION, + ExplicitEnum, + cached_property, + is_accelerate_available, + is_ipex_available, + is_safetensors_available, + is_sagemaker_dp_enabled, + is_sagemaker_mp_enabled, + is_torch_available, + is_torch_bf16_cpu_available, + is_torch_bf16_gpu_available, + is_torch_mlu_available, + is_torch_mps_available, + is_torch_musa_available, + is_torch_neuroncore_available, + is_torch_npu_available, + is_torch_tf32_available, + is_torch_xla_available, + is_torch_xpu_available, + logging, + requires_backends, +) +from .utils.generic import strtobool +from .utils.import_utils import is_optimum_neuron_available + + +logger = logging.get_logger(__name__) +log_levels = logging.get_log_levels_dict().copy() +trainer_log_levels = dict(**log_levels, passive=-1) + +if is_torch_available(): + import torch + import torch.distributed as dist + +if is_accelerate_available(): + from accelerate.state import AcceleratorState, PartialState + from accelerate.utils import DistributedType + + from .trainer_pt_utils import AcceleratorConfig + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + +if is_torch_neuroncore_available(check_device=False): + # torchrun support + # https://github.com/pytorch/xla/pull/3609 + if os.environ.get("TORCHELASTIC_RUN_ID"): + if is_optimum_neuron_available(): + logger.info( + "Make sure that you are performing the training with the NeuronTrainer from optimum[neuron], this " + "will fail otherwise." + ) + else: + logger.warning( + "Please use the NeuronTrainer from optimum[neuron] instead of the Transformers library to perform " + "training on AWS Trainium instances. More information here: " + "https://github.com/huggingface/optimum-neuron" + ) + import torch_xla.distributed.xla_backend as xbn + + if not isinstance(dist.group.WORLD, xbn.ProcessGroupXla): + dist.init_process_group(backend="xla") + if not isinstance(dist.group.WORLD, xbn.ProcessGroupXla): + raise AssertionError("Failed to initialize torch.distributed process group using XLA backend.") + + +if is_sagemaker_mp_enabled(): + import smdistributed.modelparallel.torch as smp + + smp.init() + + +def default_logdir() -> str: + """ + Same default as PyTorch + """ + import socket + from datetime import datetime + + current_time = datetime.now().strftime("%b%d_%H-%M-%S") + return os.path.join("runs", current_time + "_" + socket.gethostname()) + + +def get_int_from_env(env_keys, default): + """Returns the first positive env value found in the `env_keys` list or the default.""" + for e in env_keys: + val = int(os.environ.get(e, -1)) + if val >= 0: + return val + return default + + +def get_xla_device_type(device: "torch.device") -> Optional[str]: + """ + Returns the xla device type (CPU|GPU|TPU) or None if the device is a non-xla device. + """ + if is_torch_xla_available(): + if device.type == "cpu": + return "CPU" + return xm.xla_real_devices([device])[0].split(":")[0] + return None + + +class OptimizerNames(ExplicitEnum): + """ + Stores the acceptable string identifiers for optimizers. + """ + + ADAMW_HF = "adamw_hf" + ADAMW_TORCH = "adamw_torch" + ADAMW_TORCH_FUSED = "adamw_torch_fused" + ADAMW_TORCH_XLA = "adamw_torch_xla" + ADAMW_TORCH_NPU_FUSED = "adamw_torch_npu_fused" + ADAMW_APEX_FUSED = "adamw_apex_fused" + ADAFACTOR = "adafactor" + ADAMW_ANYPRECISION = "adamw_anyprecision" + ADAMW_TORCH_4BIT = "adamw_torch_4bit" + ADEMAMIX = "ademamix" + SGD = "sgd" + ADAGRAD = "adagrad" + ADAMW_BNB = "adamw_bnb_8bit" + ADAMW_8BIT = "adamw_8bit" # just an alias for adamw_bnb_8bit + ADEMAMIX_8BIT = "ademamix_8bit" + LION_8BIT = "lion_8bit" + LION = "lion_32bit" + PAGED_ADAMW = "paged_adamw_32bit" + PAGED_ADAMW_8BIT = "paged_adamw_8bit" + PAGED_ADEMAMIX = "paged_ademamix_32bit" + PAGED_ADEMAMIX_8BIT = "paged_ademamix_8bit" + PAGED_LION = "paged_lion_32bit" + PAGED_LION_8BIT = "paged_lion_8bit" + RMSPROP = "rmsprop" + RMSPROP_BNB = "rmsprop_bnb" + RMSPROP_8BIT = "rmsprop_bnb_8bit" + RMSPROP_32BIT = "rmsprop_bnb_32bit" + GALORE_ADAMW = "galore_adamw" + GALORE_ADAMW_8BIT = "galore_adamw_8bit" + GALORE_ADAFACTOR = "galore_adafactor" + GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise" + GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise" + GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise" + LOMO = "lomo" + ADALOMO = "adalomo" + GROKADAMW = "grokadamw" + SCHEDULE_FREE_ADAMW = "schedule_free_adamw" + SCHEDULE_FREE_SGD = "schedule_free_sgd" + + +# Sometimes users will pass in a `str` repr of a dict in the CLI +# We need to track what fields those can be. Each time a new arg +# has a dict type, it must be added to this list. +# Important: These should be typed with Optional[Union[dict,str,...]] +_VALID_DICT_FIELDS = [ + "accelerator_config", + "fsdp_config", + "deepspeed", + "gradient_checkpointing_kwargs", + "lr_scheduler_kwargs", +] + + +def _convert_str_dict(passed_value: dict): + "Safely checks that a passed value is a dictionary and converts any string values to their appropriate types." + for key, value in passed_value.items(): + if isinstance(value, dict): + passed_value[key] = _convert_str_dict(value) + elif isinstance(value, str): + # First check for bool and convert + if value.lower() in ("true", "false"): + passed_value[key] = value.lower() == "true" + # Check for digit + elif value.isdigit(): + passed_value[key] = int(value) + elif value.replace(".", "", 1).isdigit(): + passed_value[key] = float(value) + + return passed_value + + +# TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903 +@dataclass +class TrainingArguments: + """ + TrainingArguments is the subset of the arguments we use in our example scripts **which relate to the training loop + itself**. + + Using [`HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + output_dir (`str`): + The output directory where the model predictions and checkpoints will be written. + overwrite_output_dir (`bool`, *optional*, defaults to `False`): + If `True`, overwrite the content of the output directory. Use this to continue training if `output_dir` + points to a checkpoint directory. + do_train (`bool`, *optional*, defaults to `False`): + Whether to run training or not. This argument is not directly used by [`Trainer`], it's intended to be used + by your training/evaluation scripts instead. See the [example + scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details. + do_eval (`bool`, *optional*): + Whether to run evaluation on the validation set or not. Will be set to `True` if `eval_strategy` is + different from `"no"`. This argument is not directly used by [`Trainer`], it's intended to be used by your + training/evaluation scripts instead. See the [example + scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details. + do_predict (`bool`, *optional*, defaults to `False`): + Whether to run predictions on the test set or not. This argument is not directly used by [`Trainer`], it's + intended to be used by your training/evaluation scripts instead. See the [example + scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details. + eval_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"no"`): + The evaluation strategy to adopt during training. Possible values are: + + - `"no"`: No evaluation is done during training. + - `"steps"`: Evaluation is done (and logged) every `eval_steps`. + - `"epoch"`: Evaluation is done at the end of each epoch. + + prediction_loss_only (`bool`, *optional*, defaults to `False`): + When performing evaluation and generating predictions, only returns the loss. + per_device_train_batch_size (`int`, *optional*, defaults to 8): + The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for training. + per_device_eval_batch_size (`int`, *optional*, defaults to 8): + The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for evaluation. + gradient_accumulation_steps (`int`, *optional*, defaults to 1): + Number of updates steps to accumulate the gradients for, before performing a backward/update pass. + + + + When using gradient accumulation, one step is counted as one step with backward pass. Therefore, logging, + evaluation, save will be conducted every `gradient_accumulation_steps * xxx_step` training examples. + + + + eval_accumulation_steps (`int`, *optional*): + Number of predictions steps to accumulate the output tensors for, before moving the results to the CPU. If + left unset, the whole predictions are accumulated on GPU/NPU/TPU before being moved to the CPU (faster but + requires more memory). + eval_delay (`float`, *optional*): + Number of epochs or steps to wait for before the first evaluation can be performed, depending on the + eval_strategy. + torch_empty_cache_steps (`int`, *optional*): + Number of steps to wait before calling `torch..empty_cache()`. If left unset or set to None, cache will not be emptied. + + + + This can help avoid CUDA out-of-memory errors by lowering peak VRAM usage at a cost of about [10% slower performance](https://github.com/huggingface/transformers/issues/31372). + + + + learning_rate (`float`, *optional*, defaults to 5e-5): + The initial learning rate for [`AdamW`] optimizer. + weight_decay (`float`, *optional*, defaults to 0): + The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights in [`AdamW`] + optimizer. + adam_beta1 (`float`, *optional*, defaults to 0.9): + The beta1 hyperparameter for the [`AdamW`] optimizer. + adam_beta2 (`float`, *optional*, defaults to 0.999): + The beta2 hyperparameter for the [`AdamW`] optimizer. + adam_epsilon (`float`, *optional*, defaults to 1e-8): + The epsilon hyperparameter for the [`AdamW`] optimizer. + max_grad_norm (`float`, *optional*, defaults to 1.0): + Maximum gradient norm (for gradient clipping). + num_train_epochs(`float`, *optional*, defaults to 3.0): + Total number of training epochs to perform (if not an integer, will perform the decimal part percents of + the last epoch before stopping training). + max_steps (`int`, *optional*, defaults to -1): + If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`. + For a finite dataset, training is reiterated through the dataset (if all data is exhausted) until + `max_steps` is reached. + lr_scheduler_type (`str` or [`SchedulerType`], *optional*, defaults to `"linear"`): + The scheduler type to use. See the documentation of [`SchedulerType`] for all possible values. + lr_scheduler_kwargs ('dict', *optional*, defaults to {}): + The extra arguments for the lr_scheduler. See the documentation of each scheduler for possible values. + warmup_ratio (`float`, *optional*, defaults to 0.0): + Ratio of total training steps used for a linear warmup from 0 to `learning_rate`. + warmup_steps (`int`, *optional*, defaults to 0): + Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of `warmup_ratio`. + log_level (`str`, *optional*, defaults to `passive`): + Logger log level to use on the main process. Possible choices are the log levels as strings: 'debug', + 'info', 'warning', 'error' and 'critical', plus a 'passive' level which doesn't set anything and keeps the + current log level for the Transformers library (which will be `"warning"` by default). + log_level_replica (`str`, *optional*, defaults to `"warning"`): + Logger log level to use on replicas. Same choices as `log_level`" + log_on_each_node (`bool`, *optional*, defaults to `True`): + In multinode distributed training, whether to log using `log_level` once per node, or only on the main + node. + logging_dir (`str`, *optional*): + [TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to + *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***. + logging_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`): + The logging strategy to adopt during training. Possible values are: + + - `"no"`: No logging is done during training. + - `"epoch"`: Logging is done at the end of each epoch. + - `"steps"`: Logging is done every `logging_steps`. + + logging_first_step (`bool`, *optional*, defaults to `False`): + Whether to log the first `global_step` or not. + logging_steps (`int` or `float`, *optional*, defaults to 500): + Number of update steps between two logs if `logging_strategy="steps"`. Should be an integer or a float in + range `[0,1)`. If smaller than 1, will be interpreted as ratio of total training steps. + logging_nan_inf_filter (`bool`, *optional*, defaults to `True`): + Whether to filter `nan` and `inf` losses for logging. If set to `True` the loss of every step that is `nan` + or `inf` is filtered and the average loss of the current logging window is taken instead. + + + + `logging_nan_inf_filter` only influences the logging of loss values, it does not change the behavior the + gradient is computed or applied to the model. + + + + save_strategy (`str` or [`~trainer_utils.SaveStrategy`], *optional*, defaults to `"steps"`): + The checkpoint save strategy to adopt during training. Possible values are: + + - `"no"`: No save is done during training. + - `"epoch"`: Save is done at the end of each epoch. + - `"steps"`: Save is done every `save_steps`. + - `"best"`: Save is done whenever a new `best_metric` is achieved. + + If `"epoch"` or `"steps"` is chosen, saving will also be performed at the + very end of training, always. + save_steps (`int` or `float`, *optional*, defaults to 500): + Number of updates steps before two checkpoint saves if `save_strategy="steps"`. Should be an integer or a + float in range `[0,1)`. If smaller than 1, will be interpreted as ratio of total training steps. + save_total_limit (`int`, *optional*): + If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in + `output_dir`. When `load_best_model_at_end` is enabled, the "best" checkpoint according to + `metric_for_best_model` will always be retained in addition to the most recent ones. For example, for + `save_total_limit=5` and `load_best_model_at_end`, the four last checkpoints will always be retained + alongside the best model. When `save_total_limit=1` and `load_best_model_at_end`, it is possible that two + checkpoints are saved: the last one and the best one (if they are different). + save_safetensors (`bool`, *optional*, defaults to `True`): + Use [safetensors](https://huggingface.co/docs/safetensors) saving and loading for state dicts instead of + default `torch.load` and `torch.save`. + save_on_each_node (`bool`, *optional*, defaults to `False`): + When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on + the main one. + + This should not be activated when the different nodes use the same storage as the files will be saved with + the same names for each node. + save_only_model (`bool`, *optional*, defaults to `False`): + When checkpointing, whether to only save the model, or also the optimizer, scheduler & rng state. + Note that when this is true, you won't be able to resume training from checkpoint. + This enables you to save storage by not storing the optimizer, scheduler & rng state. + You can only load the model using `from_pretrained` with this option set to `True`. + restore_callback_states_from_checkpoint (`bool`, *optional*, defaults to `False`): + Whether to restore the callback states from the checkpoint. If `True`, will override + callbacks passed to the `Trainer` if they exist in the checkpoint." + use_cpu (`bool`, *optional*, defaults to `False`): + Whether or not to use cpu. If set to False, we will use cuda or mps device if available. + seed (`int`, *optional*, defaults to 42): + Random seed that will be set at the beginning of training. To ensure reproducibility across runs, use the + [`~Trainer.model_init`] function to instantiate the model if it has some randomly initialized parameters. + data_seed (`int`, *optional*): + Random seed to be used with data samplers. If not set, random generators for data sampling will use the + same seed as `seed`. This can be used to ensure reproducibility of data sampling, independent of the model + seed. + jit_mode_eval (`bool`, *optional*, defaults to `False`): + Whether or not to use PyTorch jit trace for inference. + use_ipex (`bool`, *optional*, defaults to `False`): + Use Intel extension for PyTorch when it is available. [IPEX + installation](https://github.com/intel/intel-extension-for-pytorch). + bf16 (`bool`, *optional*, defaults to `False`): + Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. Requires Ampere or higher + NVIDIA architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change. + fp16 (`bool`, *optional*, defaults to `False`): + Whether to use fp16 16-bit (mixed) precision training instead of 32-bit training. + fp16_opt_level (`str`, *optional*, defaults to 'O1'): + For `fp16` training, Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. See details on + the [Apex documentation](https://nvidia.github.io/apex/amp). + fp16_backend (`str`, *optional*, defaults to `"auto"`): + This argument is deprecated. Use `half_precision_backend` instead. + half_precision_backend (`str`, *optional*, defaults to `"auto"`): + The backend to use for mixed precision training. Must be one of `"auto", "apex", "cpu_amp"`. `"auto"` will + use CPU/CUDA AMP or APEX depending on the PyTorch version detected, while the other choices will force the + requested backend. + bf16_full_eval (`bool`, *optional*, defaults to `False`): + Whether to use full bfloat16 evaluation instead of 32-bit. This will be faster and save memory but can harm + metric values. This is an experimental API and it may change. + fp16_full_eval (`bool`, *optional*, defaults to `False`): + Whether to use full float16 evaluation instead of 32-bit. This will be faster and save memory but can harm + metric values. + tf32 (`bool`, *optional*): + Whether to enable the TF32 mode, available in Ampere and newer GPU architectures. The default value depends + on PyTorch's version default of `torch.backends.cuda.matmul.allow_tf32`. For more details please refer to + the [TF32](https://huggingface.co/docs/transformers/perf_train_gpu_one#tf32) documentation. This is an + experimental API and it may change. + local_rank (`int`, *optional*, defaults to -1): + Rank of the process during distributed training. + ddp_backend (`str`, *optional*): + The backend to use for distributed training. Must be one of `"nccl"`, `"mpi"`, `"ccl"`, `"gloo"`, `"hccl"`. + tpu_num_cores (`int`, *optional*): + When training on TPU, the number of TPU cores (automatically passed by launcher script). + dataloader_drop_last (`bool`, *optional*, defaults to `False`): + Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size) + or not. + eval_steps (`int` or `float`, *optional*): + Number of update steps between two evaluations if `eval_strategy="steps"`. Will default to the same + value as `logging_steps` if not set. Should be an integer or a float in range `[0,1)`. If smaller than 1, + will be interpreted as ratio of total training steps. + dataloader_num_workers (`int`, *optional*, defaults to 0): + Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in the + main process. + past_index (`int`, *optional*, defaults to -1): + Some models like [TransformerXL](../model_doc/transformerxl) or [XLNet](../model_doc/xlnet) can make use of + the past hidden states for their predictions. If this argument is set to a positive int, the `Trainer` will + use the corresponding output (usually index 2) as the past state and feed it to the model at the next + training step under the keyword argument `mems`. + run_name (`str`, *optional*, defaults to `output_dir`): + A descriptor for the run. Typically used for [wandb](https://www.wandb.com/), + [mlflow](https://www.mlflow.org/) and [comet](https://www.comet.com/site) logging. If not specified, will + be the same as `output_dir`. + disable_tqdm (`bool`, *optional*): + Whether or not to disable the tqdm progress bars and table of metrics produced by + [`~notebook.NotebookTrainingTracker`] in Jupyter Notebooks. Will default to `True` if the logging level is + set to warn or lower (default), `False` otherwise. + remove_unused_columns (`bool`, *optional*, defaults to `True`): + Whether or not to automatically remove the columns unused by the model forward method. + label_names (`List[str]`, *optional*): + The list of keys in your dictionary of inputs that correspond to the labels. + + Will eventually default to the list of argument names accepted by the model that contain the word "label", + except if the model used is one of the `XxxForQuestionAnswering` in which case it will also include the + `["start_positions", "end_positions"]` keys. + load_best_model_at_end (`bool`, *optional*, defaults to `False`): + Whether or not to load the best model found during training at the end of training. When this option is + enabled, the best checkpoint will always be saved. See + [`save_total_limit`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.save_total_limit) + for more. + + + + When set to `True`, the parameters `save_strategy` needs to be the same as `eval_strategy`, and in + the case it is "steps", `save_steps` must be a round multiple of `eval_steps`. + + + + metric_for_best_model (`str`, *optional*): + Use in conjunction with `load_best_model_at_end` to specify the metric to use to compare two different + models. Must be the name of a metric returned by the evaluation with or without the prefix `"eval_"`. + + If not specified, this will default to `"loss"` when either `load_best_model_at_end == True` + or `lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU` (to use the evaluation loss). + + If you set this value, `greater_is_better` will default to `True` unless the name ends with "loss". + Don't forget to set it to `False` if your metric is better when lower. + greater_is_better (`bool`, *optional*): + Use in conjunction with `load_best_model_at_end` and `metric_for_best_model` to specify if better models + should have a greater metric or not. Will default to: + + - `True` if `metric_for_best_model` is set to a value that doesn't end in `"loss"`. + - `False` if `metric_for_best_model` is not set, or set to a value that ends in `"loss"`. + ignore_data_skip (`bool`, *optional*, defaults to `False`): + When resuming training, whether or not to skip the epochs and batches to get the data loading at the same + stage as in the previous training. If set to `True`, the training will begin faster (as that skipping step + can take a long time) but will not yield the same results as the interrupted training would have. + fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `''`): + Use PyTorch Distributed Parallel Training (in distributed training only). + + A list of options along the following: + + - `"full_shard"`: Shard parameters, gradients and optimizer states. + - `"shard_grad_op"`: Shard optimizer states and gradients. + - `"hybrid_shard"`: Apply `FULL_SHARD` within a node, and replicate parameters across nodes. + - `"hybrid_shard_zero2"`: Apply `SHARD_GRAD_OP` within a node, and replicate parameters across nodes. + - `"offload"`: Offload parameters and gradients to CPUs (only compatible with `"full_shard"` and + `"shard_grad_op"`). + - `"auto_wrap"`: Automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`. + fsdp_config (`str` or `dict`, *optional*): + Config to be used with fsdp (Pytorch Distributed Parallel Training). The value is either a location of + fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`. + + A List of config and its options: + - min_num_params (`int`, *optional*, defaults to `0`): + FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is + passed). + - transformer_layer_cls_to_wrap (`List[str]`, *optional*): + List of transformer layer class names (case-sensitive) to wrap, e.g, `BertLayer`, `GPTJBlock`, + `T5Block` .... (useful only when `fsdp` flag is passed). + - backward_prefetch (`str`, *optional*) + FSDP's backward prefetch mode. Controls when to prefetch next set of parameters (useful only when + `fsdp` field is passed). + + A list of options along the following: + + - `"backward_pre"` : Prefetches the next set of parameters before the current set of parameter's + gradient + computation. + - `"backward_post"` : This prefetches the next set of parameters after the current set of + parameter’s + gradient computation. + - forward_prefetch (`bool`, *optional*, defaults to `False`) + FSDP's forward prefetch mode (useful only when `fsdp` field is passed). + If `"True"`, then FSDP explicitly prefetches the next upcoming all-gather while executing in the + forward pass. + - limit_all_gathers (`bool`, *optional*, defaults to `False`) + FSDP's limit_all_gathers (useful only when `fsdp` field is passed). + If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight + all-gathers. + - use_orig_params (`bool`, *optional*, defaults to `True`) + If `"True"`, allows non-uniform `requires_grad` during init, which means support for interspersed + frozen and trainable paramteres. Useful in cases such as parameter-efficient fine-tuning. Please + refer this + [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019 + - sync_module_states (`bool`, *optional*, defaults to `True`) + If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to + ensure they are the same across all ranks after initialization + - cpu_ram_efficient_loading (`bool`, *optional*, defaults to `False`) + If `"True"`, only the first process loads the pretrained model checkpoint while all other processes + have empty weights. When this setting as `"True"`, `sync_module_states` also must to be `"True"`, + otherwise all the processes except the main process would have random weights leading to unexpected + behaviour during training. + - activation_checkpointing (`bool`, *optional*, defaults to `False`): + If `"True"`, activation checkpointing is a technique to reduce memory usage by clearing activations of + certain layers and recomputing them during a backward pass. Effectively, this trades extra + computation time for reduced memory usage. + - xla (`bool`, *optional*, defaults to `False`): + Whether to use PyTorch/XLA Fully Sharded Data Parallel Training. This is an experimental feature + and its API may evolve in the future. + - xla_fsdp_settings (`dict`, *optional*) + The value is a dictionary which stores the XLA FSDP wrapping parameters. + + For a complete list of options, please see [here]( + https://github.com/pytorch/xla/blob/master/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py). + - xla_fsdp_grad_ckpt (`bool`, *optional*, defaults to `False`): + Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be + used when the xla flag is set to true, and an auto wrapping policy is specified through + fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap. + + deepspeed (`str` or `dict`, *optional*): + Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may + evolve in the future. The value is either the location of DeepSpeed json config file (e.g., + `ds_config.json`) or an already loaded json file as a `dict`" + + + If enabling any Zero-init, make sure that your model is not initialized until + *after* initializing the `TrainingArguments`, else it will not be applied. + + + accelerator_config (`str`, `dict`, or `AcceleratorConfig`, *optional*): + Config to be used with the internal `Accelerator` implementation. The value is either a location of + accelerator json config file (e.g., `accelerator_config.json`), an already loaded json file as `dict`, + or an instance of [`~trainer_pt_utils.AcceleratorConfig`]. + + A list of config and its options: + - split_batches (`bool`, *optional*, defaults to `False`): + Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If + `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a + round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set + in your script multiplied by the number of processes. + - dispatch_batches (`bool`, *optional*): + If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process + and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose + underlying dataset is an `IterableDataset`, `False` otherwise. + - even_batches (`bool`, *optional*, defaults to `True`): + If set to `True`, in cases where the total batch size across all processes does not exactly divide the + dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among + all workers. + - use_seedable_sampler (`bool`, *optional*, defaults to `True`): + Whether or not use a fully seedable random sampler ([`accelerate.data_loader.SeedableRandomSampler`]). Ensures + training results are fully reproducable using a different sampling technique. While seed-to-seed results + may differ, on average the differences are neglible when using multiple different seeds to compare. Should + also be ran with [`~utils.set_seed`] for the best results. + - use_configured_state (`bool`, *optional*, defaults to `False`): + Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined before calling `TrainingArguments`. + If `True`, an `Accelerator` or `PartialState` must be initialized. Note that by doing so, this could lead to issues + with hyperparameter tuning. + + label_smoothing_factor (`float`, *optional*, defaults to 0.0): + The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded + labels are changed from 0s and 1s to `label_smoothing_factor/num_labels` and `1 - label_smoothing_factor + + label_smoothing_factor/num_labels` respectively. + debug (`str` or list of [`~debug_utils.DebugOption`], *optional*, defaults to `""`): + Enable one or more debug features. This is an experimental feature. + + Possible options are: + + - `"underflow_overflow"`: detects overflow in model's input/outputs and reports the last frames that led to + the event + - `"tpu_metrics_debug"`: print debug metrics on TPU + + The options should be separated by whitespaces. + optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_torch"`): + The optimizer to use, such as "adamw_hf", "adamw_torch", "adamw_torch_fused", "adamw_apex_fused", "adamw_anyprecision", + "adafactor". See `OptimizerNames` in [training_args.py](https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py) + for a full list of optimizers. + optim_args (`str`, *optional*): + Optional arguments that are supplied to optimizers such as AnyPrecisionAdamW, AdEMAMix, and GaLore. + group_by_length (`bool`, *optional*, defaults to `False`): + Whether or not to group together samples of roughly the same length in the training dataset (to minimize + padding applied and be more efficient). Only useful if applying dynamic padding. + length_column_name (`str`, *optional*, defaults to `"length"`): + Column name for precomputed lengths. If the column exists, grouping by length will use these values rather + than computing them on train startup. Ignored unless `group_by_length` is `True` and the dataset is an + instance of `Dataset`. + report_to (`str` or `List[str]`, *optional*, defaults to `"all"`): + The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`, + `"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`, `"neptune"`, + `"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed, `"none"` for no + integrations. + ddp_find_unused_parameters (`bool`, *optional*): + When using distributed training, the value of the flag `find_unused_parameters` passed to + `DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise. + ddp_bucket_cap_mb (`int`, *optional*): + When using distributed training, the value of the flag `bucket_cap_mb` passed to `DistributedDataParallel`. + ddp_broadcast_buffers (`bool`, *optional*): + When using distributed training, the value of the flag `broadcast_buffers` passed to + `DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise. + dataloader_pin_memory (`bool`, *optional*, defaults to `True`): + Whether you want to pin memory in data loaders or not. Will default to `True`. + dataloader_persistent_workers (`bool`, *optional*, defaults to `False`): + If True, the data loader will not shut down the worker processes after a dataset has been consumed once. + This allows to maintain the workers Dataset instances alive. Can potentially speed up training, but will + increase RAM usage. Will default to `False`. + dataloader_prefetch_factor (`int`, *optional*): + Number of batches loaded in advance by each worker. + 2 means there will be a total of 2 * num_workers batches prefetched across all workers. + skip_memory_metrics (`bool`, *optional*, defaults to `True`): + Whether to skip adding of memory profiler reports to metrics. This is skipped by default because it slows + down the training and evaluation speed. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push the model to the Hub every time the model is saved. If this is activated, + `output_dir` will begin a git directory synced with the repo (determined by `hub_model_id`) and the content + will be pushed each time a save is triggered (depending on your `save_strategy`). Calling + [`~Trainer.save_model`] will also trigger a push. + + + + If `output_dir` exists, it needs to be a local clone of the repository to which the [`Trainer`] will be + pushed. + + + + resume_from_checkpoint (`str`, *optional*): + The path to a folder with a valid checkpoint for your model. This argument is not directly used by + [`Trainer`], it's intended to be used by your training/evaluation scripts instead. See the [example + scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details. + hub_model_id (`str`, *optional*): + The name of the repository to keep in sync with the local *output_dir*. It can be a simple model ID in + which case the model will be pushed in your namespace. Otherwise it should be the whole repository name, + for instance `"user_name/model"`, which allows you to push to an organization you are a member of with + `"organization_name/model"`. Will default to `user_name/output_dir_name` with *output_dir_name* being the + name of `output_dir`. + + Will default to the name of `output_dir`. + hub_strategy (`str` or [`~trainer_utils.HubStrategy`], *optional*, defaults to `"every_save"`): + Defines the scope of what is pushed to the Hub and when. Possible values are: + + - `"end"`: push the model, its configuration, the processing class e.g. tokenizer (if passed along to the [`Trainer`]) and a + draft of a model card when the [`~Trainer.save_model`] method is called. + - `"every_save"`: push the model, its configuration, the processing class e.g. tokenizer (if passed along to the [`Trainer`]) and + a draft of a model card each time there is a model save. The pushes are asynchronous to not block + training, and in case the save are very frequent, a new push is only attempted if the previous one is + finished. A last push is made with the final model at the end of training. + - `"checkpoint"`: like `"every_save"` but the latest checkpoint is also pushed in a subfolder named + last-checkpoint, allowing you to resume training easily with + `trainer.train(resume_from_checkpoint="last-checkpoint")`. + - `"all_checkpoints"`: like `"checkpoint"` but all checkpoints are pushed like they appear in the output + folder (so you will get one checkpoint folder per folder in your final repository) + + hub_token (`str`, *optional*): + The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with + `huggingface-cli login`. + hub_private_repo (`bool`, *optional*): + Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists. + hub_always_push (`bool`, *optional*, defaults to `False`): + Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not finished. + gradient_checkpointing (`bool`, *optional*, defaults to `False`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. + gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`): + Key word arguments to be passed to the `gradient_checkpointing_enable` method. + include_inputs_for_metrics (`bool`, *optional*, defaults to `False`): + This argument is deprecated. Use `include_for_metrics` instead, e.g, `include_for_metrics = ["inputs"]`. + include_for_metrics (`List[str]`, *optional*, defaults to `[]`): + Include additional data in the `compute_metrics` function if needed for metrics computation. + Possible options to add to `include_for_metrics` list: + - `"inputs"`: Input data passed to the model, intended for calculating input dependent metrics. + - `"loss"`: Loss values computed during evaluation, intended for calculating loss dependent metrics. + eval_do_concat_batches (`bool`, *optional*, defaults to `True`): + Whether to recursively concat inputs/losses/labels/predictions across batches. If `False`, + will instead store them as lists, with each batch kept separate. + auto_find_batch_size (`bool`, *optional*, defaults to `False`) + Whether to find a batch size that will fit into memory automatically through exponential decay, avoiding + CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`) + full_determinism (`bool`, *optional*, defaults to `False`) + If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in + distributed training. Important: this will negatively impact the performance, so only use it for debugging. + torchdynamo (`str`, *optional*): + If set, the backend compiler for TorchDynamo. Possible choices are `"eager"`, `"aot_eager"`, `"inductor"`, + `"nvfuser"`, `"aot_nvfuser"`, `"aot_cudagraphs"`, `"ofi"`, `"fx2trt"`, `"onnxrt"` and `"ipex"`. + ray_scope (`str`, *optional*, defaults to `"last"`): + The scope to use when doing hyperparameter search with Ray. By default, `"last"` will be used. Ray will + then use the last checkpoint of all trials, compare those, and select the best one. However, other options + are also available. See the [Ray documentation]( + https://docs.ray.io/en/latest/tune/api_docs/analysis.html#ray.tune.ExperimentAnalysis.get_best_trial) for + more options. + ddp_timeout (`int`, *optional*, defaults to 1800): + The timeout for `torch.distributed.init_process_group` calls, used to avoid GPU socket timeouts when + performing slow operations in distributed runnings. Please refer the [PyTorch documentation] + (https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for more + information. + use_mps_device (`bool`, *optional*, defaults to `False`): + This argument is deprecated.`mps` device will be used if it is available similar to `cuda` device. + torch_compile (`bool`, *optional*, defaults to `False`): + Whether or not to compile the model using PyTorch 2.0 + [`torch.compile`](https://pytorch.org/get-started/pytorch-2.0/). + + This will use the best defaults for the [`torch.compile` + API](https://pytorch.org/docs/stable/generated/torch.compile.html?highlight=torch+compile#torch.compile). + You can customize the defaults with the argument `torch_compile_backend` and `torch_compile_mode` but we + don't guarantee any of them will work as the support is progressively rolled in in PyTorch. + + This flag and the whole compile API is experimental and subject to change in future releases. + torch_compile_backend (`str`, *optional*): + The backend to use in `torch.compile`. If set to any value, `torch_compile` will be set to `True`. + + Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions. + + This flag is experimental and subject to change in future releases. + torch_compile_mode (`str`, *optional*): + The mode to use in `torch.compile`. If set to any value, `torch_compile` will be set to `True`. + + Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions. + + This flag is experimental and subject to change in future releases. + split_batches (`bool`, *optional*): + Whether or not the accelerator should split the batches yielded by the dataloaders across the devices + during distributed training. If + + set to `True`, the actual batch size used will be the same on any kind of distributed processes, but it + must be a + + round multiple of the number of processes you are using (such as GPUs). + include_tokens_per_second (`bool`, *optional*): + Whether or not to compute the number of tokens per second per device for training speed metrics. + + This will iterate over the entire training dataloader once beforehand, + + and will slow down the entire process. + + include_num_input_tokens_seen (`bool`, *optional*): + Whether or not to track the number of input tokens seen throughout training. + + May be slower in distributed training as gather operations must be called. + + neftune_noise_alpha (`Optional[float]`): + If not `None`, this will activate NEFTune noise embeddings. This can drastically improve model performance + for instruction fine-tuning. Check out the [original paper](https://arxiv.org/abs/2310.05914) and the + [original code](https://github.com/neelsjain/NEFTune). Support transformers `PreTrainedModel` and also + `PeftModel` from peft. The original paper used values in the range [5.0, 15.0]. + optim_target_modules (`Union[str, List[str]]`, *optional*): + The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm + https://arxiv.org/abs/2403.03507 + See: https://github.com/jiaweizzhao/GaLore for more details. You need to make sure to pass a valid GaloRe + optimizer, e.g. one of: "galore_adamw", "galore_adamw_8bit", "galore_adafactor" and make sure that the target modules are `nn.Linear` modules + only. + + batch_eval_metrics (`Optional[bool]`, defaults to `False`): + If set to `True`, evaluation will call compute_metrics at the end of each batch to accumulate statistics + rather than saving all eval logits in memory. When set to `True`, you must pass a compute_metrics function + that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global + summary statistics from the batch-level summary statistics you've accumulated over the evaluation set. + + eval_on_start (`bool`, *optional*, defaults to `False`): + Whether to perform a evaluation step (sanity check) before the training to ensure the validation steps works correctly. + + eval_use_gather_object (`bool`, *optional*, defaults to `False`): + Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices. This should only be enabled if users are not just returning tensors, and this is actively discouraged by PyTorch. + + use_liger_kernel (`bool`, *optional*, defaults to `False`): + Whether enable [Liger](https://github.com/linkedin/Liger-Kernel) Kernel for LLM model training. + It can effectively increase multi-GPU training throughput by ~20% and reduces memory usage by ~60%, works out of the box with + flash attention, PyTorch FSDP, and Microsoft DeepSpeed. Currently, it supports llama, mistral, mixtral and gemma models. + """ + + framework = "pt" + output_dir: str = field( + metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, + ) + overwrite_output_dir: bool = field( + default=False, + metadata={ + "help": ( + "Overwrite the content of the output directory. " + "Use this to continue training if output_dir points to a checkpoint directory." + ) + }, + ) + + do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) + do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."}) + do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."}) + eval_strategy: Union[IntervalStrategy, str] = field( + default="no", + metadata={"help": "The evaluation strategy to use."}, + ) + prediction_loss_only: bool = field( + default=False, + metadata={"help": "When performing evaluation and predictions, only returns the loss."}, + ) + + per_device_train_batch_size: int = field( + default=8, metadata={"help": "Batch size per GPU/TPU/MPS/NPU core/CPU for training."} + ) + per_device_eval_batch_size: int = field( + default=8, metadata={"help": "Batch size per GPU/TPU/MPS/NPU core/CPU for evaluation."} + ) + + per_gpu_train_batch_size: Optional[int] = field( + default=None, + metadata={ + "help": ( + "Deprecated, the use of `--per_device_train_batch_size` is preferred. " + "Batch size per GPU/TPU core/CPU for training." + ) + }, + ) + per_gpu_eval_batch_size: Optional[int] = field( + default=None, + metadata={ + "help": ( + "Deprecated, the use of `--per_device_eval_batch_size` is preferred. " + "Batch size per GPU/TPU core/CPU for evaluation." + ) + }, + ) + + gradient_accumulation_steps: int = field( + default=1, + metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}, + ) + eval_accumulation_steps: Optional[int] = field( + default=None, + metadata={"help": "Number of predictions steps to accumulate before moving the tensors to the CPU."}, + ) + + eval_delay: Optional[float] = field( + default=0, + metadata={ + "help": ( + "Number of epochs or steps to wait for before the first evaluation can be performed, depending on the" + " eval_strategy." + ) + }, + ) + + torch_empty_cache_steps: Optional[int] = field( + default=None, + metadata={ + "help": "Number of steps to wait before calling `torch..empty_cache()`." + "This can help avoid CUDA out-of-memory errors by lowering peak VRAM usage at a cost of about [10% slower performance](https://github.com/huggingface/transformers/issues/31372)." + "If left unset or set to None, cache will not be emptied." + }, + ) + + learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."}) + weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."}) + adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"}) + adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"}) + adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}) + max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."}) + + num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."}) + max_steps: int = field( + default=-1, + metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."}, + ) + lr_scheduler_type: Union[SchedulerType, str] = field( + default="linear", + metadata={"help": "The scheduler type to use."}, + ) + lr_scheduler_kwargs: Optional[Union[dict, str]] = field( + default_factory=dict, + metadata={ + "help": ( + "Extra parameters for the lr_scheduler such as {'num_cycles': 1} for the cosine with hard restarts." + ) + }, + ) + warmup_ratio: float = field( + default=0.0, metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."} + ) + warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."}) + + log_level: Optional[str] = field( + default="passive", + metadata={ + "help": ( + "Logger log level to use on the main node. Possible choices are the log levels as strings: 'debug'," + " 'info', 'warning', 'error' and 'critical', plus a 'passive' level which doesn't set anything and" + " lets the application set the level. Defaults to 'passive'." + ), + "choices": trainer_log_levels.keys(), + }, + ) + log_level_replica: Optional[str] = field( + default="warning", + metadata={ + "help": "Logger log level to use on replica nodes. Same choices and defaults as ``log_level``", + "choices": trainer_log_levels.keys(), + }, + ) + log_on_each_node: bool = field( + default=True, + metadata={ + "help": ( + "When doing a multinode distributed training, whether to log once per node or just once on the main" + " node." + ) + }, + ) + logging_dir: Optional[str] = field(default=None, metadata={"help": "Tensorboard log dir."}) + logging_strategy: Union[IntervalStrategy, str] = field( + default="steps", + metadata={"help": "The logging strategy to use."}, + ) + logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"}) + logging_steps: float = field( + default=500, + metadata={ + "help": ( + "Log every X updates steps. Should be an integer or a float in range `[0,1)`. " + "If smaller than 1, will be interpreted as ratio of total training steps." + ) + }, + ) + logging_nan_inf_filter: bool = field(default=True, metadata={"help": "Filter nan and inf losses for logging."}) + save_strategy: Union[SaveStrategy, str] = field( + default="steps", + metadata={"help": "The checkpoint save strategy to use."}, + ) + save_steps: float = field( + default=500, + metadata={ + "help": ( + "Save checkpoint every X updates steps. Should be an integer or a float in range `[0,1)`. " + "If smaller than 1, will be interpreted as ratio of total training steps." + ) + }, + ) + save_total_limit: Optional[int] = field( + default=None, + metadata={ + "help": ( + "If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in" + " `output_dir`. When `load_best_model_at_end` is enabled, the 'best' checkpoint according to" + " `metric_for_best_model` will always be retained in addition to the most recent ones. For example," + " for `save_total_limit=5` and `load_best_model_at_end=True`, the four last checkpoints will always be" + " retained alongside the best model. When `save_total_limit=1` and `load_best_model_at_end=True`," + " it is possible that two checkpoints are saved: the last one and the best one (if they are different)." + " Default is unlimited checkpoints" + ) + }, + ) + save_safetensors: Optional[bool] = field( + default=True, + metadata={ + "help": "Use safetensors saving and loading for state dicts instead of default torch.load and torch.save." + }, + ) + save_on_each_node: bool = field( + default=False, + metadata={ + "help": ( + "When doing multi-node distributed training, whether to save models and checkpoints on each node, or" + " only on the main one" + ) + }, + ) + save_only_model: bool = field( + default=False, + metadata={ + "help": ( + "When checkpointing, whether to only save the model, or also the optimizer, scheduler & rng state." + "Note that when this is true, you won't be able to resume training from checkpoint." + "This enables you to save storage by not storing the optimizer, scheduler & rng state." + "You can only load the model using from_pretrained with this option set to True." + ) + }, + ) + restore_callback_states_from_checkpoint: bool = field( + default=False, + metadata={ + "help": "Whether to restore the callback states from the checkpoint. If `True`, will override callbacks passed to the `Trainer` if they exist in the checkpoint." + }, + ) + no_cuda: bool = field( + default=False, + metadata={"help": "This argument is deprecated. It will be removed in version 5.0 of 🤗 Transformers."}, + ) + use_cpu: bool = field( + default=False, + metadata={ + "help": "Whether or not to use cpu. If set to False, we will use cuda/tpu/mps/npu device if available." + }, + ) + use_mps_device: bool = field( + default=False, + metadata={ + "help": "This argument is deprecated. `mps` device will be used if available similar to `cuda` device." + " It will be removed in version 5.0 of 🤗 Transformers" + }, + ) + seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."}) + data_seed: Optional[int] = field(default=None, metadata={"help": "Random seed to be used with data samplers."}) + jit_mode_eval: bool = field( + default=False, metadata={"help": "Whether or not to use PyTorch jit trace for inference"} + ) + use_ipex: bool = field( + default=False, + metadata={ + "help": ( + "Use Intel extension for PyTorch when it is available, installation:" + " 'https://github.com/intel/intel-extension-for-pytorch'" + ) + }, + ) + bf16: bool = field( + default=False, + metadata={ + "help": ( + "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA" + " architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change." + ) + }, + ) + fp16: bool = field( + default=False, + metadata={"help": "Whether to use fp16 (mixed) precision instead of 32-bit"}, + ) + fp16_opt_level: str = field( + default="O1", + metadata={ + "help": ( + "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. " + "See details at https://nvidia.github.io/apex/amp.html" + ) + }, + ) + half_precision_backend: str = field( + default="auto", + metadata={ + "help": "The backend to be used for half precision.", + "choices": ["auto", "apex", "cpu_amp"], + }, + ) + bf16_full_eval: bool = field( + default=False, + metadata={ + "help": ( + "Whether to use full bfloat16 evaluation instead of 32-bit. This is an experimental API and it may" + " change." + ) + }, + ) + fp16_full_eval: bool = field( + default=False, + metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"}, + ) + tf32: Optional[bool] = field( + default=None, + metadata={ + "help": ( + "Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental" + " API and it may change." + ) + }, + ) + local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"}) + ddp_backend: Optional[str] = field( + default=None, + metadata={ + "help": "The backend to be used for distributed training", + "choices": ["nccl", "gloo", "mpi", "ccl", "hccl", "cncl", "mccl"], + }, + ) + tpu_num_cores: Optional[int] = field( + default=None, metadata={"help": "TPU: Number of TPU cores (automatically passed by launcher script)"} + ) + tpu_metrics_debug: bool = field( + default=False, + metadata={ + "help": ( + "Deprecated, the use of `--debug tpu_metrics_debug` is preferred. TPU: Whether to print debug metrics" + ) + }, + ) + debug: Union[str, List[DebugOption]] = field( + default="", + metadata={ + "help": ( + "Whether or not to enable debug mode. Current options: " + "`underflow_overflow` (Detect underflow and overflow in activations and weights), " + "`tpu_metrics_debug` (print debug metrics on TPU)." + ) + }, + ) + + dataloader_drop_last: bool = field( + default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."} + ) + eval_steps: Optional[float] = field( + default=None, + metadata={ + "help": ( + "Run an evaluation every X steps. Should be an integer or a float in range `[0,1)`. " + "If smaller than 1, will be interpreted as ratio of total training steps." + ) + }, + ) + dataloader_num_workers: int = field( + default=0, + metadata={ + "help": ( + "Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded" + " in the main process." + ) + }, + ) + dataloader_prefetch_factor: Optional[int] = field( + default=None, + metadata={ + "help": ( + "Number of batches loaded in advance by each worker. " + "2 means there will be a total of 2 * num_workers batches prefetched across all workers. " + "Default is 2 for PyTorch < 2.0.0 and otherwise None." + ) + }, + ) + past_index: int = field( + default=-1, + metadata={"help": "If >=0, uses the corresponding part of the output as the past state for next step."}, + ) + + run_name: Optional[str] = field( + default=None, + metadata={"help": "An optional descriptor for the run. Notably used for wandb, mlflow and comet logging."}, + ) + disable_tqdm: Optional[bool] = field( + default=None, metadata={"help": "Whether or not to disable the tqdm progress bars."} + ) + + remove_unused_columns: Optional[bool] = field( + default=True, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."} + ) + label_names: Optional[List[str]] = field( + default=None, metadata={"help": "The list of keys in your dictionary of inputs that correspond to the labels."} + ) + load_best_model_at_end: Optional[bool] = field( + default=False, + metadata={ + "help": ( + "Whether or not to load the best model found during training at the end of training. When this option" + " is enabled, the best checkpoint will always be saved. See `save_total_limit` for more." + ) + }, + ) + metric_for_best_model: Optional[str] = field( + default=None, metadata={"help": "The metric to use to compare two different models."} + ) + greater_is_better: Optional[bool] = field( + default=None, metadata={"help": "Whether the `metric_for_best_model` should be maximized or not."} + ) + ignore_data_skip: bool = field( + default=False, + metadata={ + "help": ( + "When resuming training, whether or not to skip the first epochs and batches to get to the same" + " training data." + ) + }, + ) + fsdp: Optional[Union[List[FSDPOption], str]] = field( + default="", + metadata={ + "help": ( + "Whether or not to use PyTorch Fully Sharded Data Parallel (FSDP) training (in distributed training" + " only). The base option should be `full_shard`, `shard_grad_op` or `no_shard` and you can add" + " CPU-offload to `full_shard` or `shard_grad_op` like this: full_shard offload` or `shard_grad_op" + " offload`. You can add auto-wrap to `full_shard` or `shard_grad_op` with the same syntax: full_shard" + " auto_wrap` or `shard_grad_op auto_wrap`." + ), + }, + ) + fsdp_min_num_params: int = field( + default=0, + metadata={ + "help": ( + "This parameter is deprecated. FSDP's minimum number of parameters for Default Auto Wrapping. (useful" + " only when `fsdp` field is passed)." + ) + }, + ) + fsdp_config: Optional[Union[dict, str]] = field( + default=None, + metadata={ + "help": ( + "Config to be used with FSDP (Pytorch Fully Sharded Data Parallel). The value is either a " + "fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`." + ) + }, + ) + fsdp_transformer_layer_cls_to_wrap: Optional[str] = field( + default=None, + metadata={ + "help": ( + "This parameter is deprecated. Transformer layer class name (case-sensitive) to wrap, e.g," + " `BertLayer`, `GPTJBlock`, `T5Block` .... (useful only when `fsdp` flag is passed)." + ) + }, + ) + accelerator_config: Optional[Union[dict, str]] = field( + default=None, + metadata={ + "help": ( + "Config to be used with the internal Accelerator object initializtion. The value is either a " + "accelerator json config file (e.g., `accelerator_config.json`) or an already loaded json file as `dict`." + ) + }, + ) + deepspeed: Optional[Union[dict, str]] = field( + default=None, + metadata={ + "help": ( + "Enable deepspeed and pass the path to deepspeed json config file (e.g. `ds_config.json`) or an already" + " loaded json file as a dict" + ) + }, + ) + label_smoothing_factor: float = field( + default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."} + ) + + default_optim = "adamw_torch" + # XXX: enable when pytorch==2.0.1 comes out - we want to give it time to get all the bugs sorted out + # if is_torch_available() and version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.1.0"): + # default_optim = "adamw_torch_fused" + # and update the doc above to: + # optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_torch_fused"` (for torch<2.1.0 `"adamw_torch"`): + optim: Union[OptimizerNames, str] = field( + default=default_optim, + metadata={"help": "The optimizer to use."}, + ) + optim_args: Optional[str] = field(default=None, metadata={"help": "Optional arguments to supply to optimizer."}) + adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."}) + group_by_length: bool = field( + default=False, + metadata={"help": "Whether or not to group samples of roughly the same length together when batching."}, + ) + length_column_name: Optional[str] = field( + default="length", + metadata={"help": "Column name with precomputed lengths to use when grouping by length."}, + ) + report_to: Union[None, str, List[str]] = field( + default=None, metadata={"help": "The list of integrations to report the results and logs to."} + ) + ddp_find_unused_parameters: Optional[bool] = field( + default=None, + metadata={ + "help": ( + "When using distributed training, the value of the flag `find_unused_parameters` passed to " + "`DistributedDataParallel`." + ) + }, + ) + ddp_bucket_cap_mb: Optional[int] = field( + default=None, + metadata={ + "help": ( + "When using distributed training, the value of the flag `bucket_cap_mb` passed to " + "`DistributedDataParallel`." + ) + }, + ) + ddp_broadcast_buffers: Optional[bool] = field( + default=None, + metadata={ + "help": ( + "When using distributed training, the value of the flag `broadcast_buffers` passed to " + "`DistributedDataParallel`." + ) + }, + ) + dataloader_pin_memory: bool = field( + default=True, metadata={"help": "Whether or not to pin memory for DataLoader."} + ) + dataloader_persistent_workers: bool = field( + default=False, + metadata={ + "help": "If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive. Can potentially speed up training, but will increase RAM usage." + }, + ) + skip_memory_metrics: bool = field( + default=True, metadata={"help": "Whether or not to skip adding of memory profiler reports to metrics."} + ) + use_legacy_prediction_loop: bool = field( + default=False, metadata={"help": "Whether or not to use the legacy prediction_loop in the Trainer."} + ) + push_to_hub: bool = field( + default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."} + ) + resume_from_checkpoint: Optional[str] = field( + default=None, + metadata={"help": "The path to a folder with a valid checkpoint for your model."}, + ) + hub_model_id: Optional[str] = field( + default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."} + ) + hub_strategy: Union[HubStrategy, str] = field( + default="every_save", + metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."}, + ) + hub_token: Optional[str] = field(default=None, metadata={"help": "The token to use to push to the Model Hub."}) + hub_private_repo: Optional[bool] = field( + default=None, + metadata={ + "help": "Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists." + }, + ) + hub_always_push: bool = field( + default=False, + metadata={"help": "Unless `True`, the Trainer will skip pushes if the previous one wasn't finished yet."}, + ) + gradient_checkpointing: bool = field( + default=False, + metadata={ + "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." + }, + ) + gradient_checkpointing_kwargs: Optional[Union[dict, str]] = field( + default=None, + metadata={ + "help": "Gradient checkpointing key word arguments such as `use_reentrant`. Will be passed to `torch.utils.checkpoint.checkpoint` through `model.gradient_checkpointing_enable`." + }, + ) + include_inputs_for_metrics: bool = field( + default=False, + metadata={ + "help": "This argument is deprecated and will be removed in version 5 of 🤗 Transformers. Use `include_for_metrics` instead." + }, + ) + include_for_metrics: List[str] = field( + default_factory=list, + metadata={ + "help": "List of strings to specify additional data to include in the `compute_metrics` function." + "Options: 'inputs', 'loss'." + }, + ) + eval_do_concat_batches: bool = field( + default=True, + metadata={ + "help": "Whether to recursively concat inputs/losses/labels/predictions across batches. If `False`, will instead store them as lists, with each batch kept separate." + }, + ) + # Deprecated arguments + fp16_backend: str = field( + default="auto", + metadata={ + "help": "Deprecated. Use half_precision_backend instead", + "choices": ["auto", "apex", "cpu_amp"], + }, + ) + evaluation_strategy: Union[IntervalStrategy, str] = field( + default=None, + metadata={"help": "Deprecated. Use `eval_strategy` instead"}, + ) + push_to_hub_model_id: Optional[str] = field( + default=None, metadata={"help": "The name of the repository to which push the `Trainer`."} + ) + push_to_hub_organization: Optional[str] = field( + default=None, metadata={"help": "The name of the organization in with to which push the `Trainer`."} + ) + push_to_hub_token: Optional[str] = field( + default=None, metadata={"help": "The token to use to push to the Model Hub."} + ) + _n_gpu: int = field(init=False, repr=False, default=-1) + mp_parameters: str = field( + default="", + metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in Trainer"}, + ) + + auto_find_batch_size: bool = field( + default=False, + metadata={ + "help": ( + "Whether to automatically decrease the batch size in half and rerun the training loop again each time" + " a CUDA Out-of-Memory was reached" + ) + }, + ) + full_determinism: bool = field( + default=False, + metadata={ + "help": ( + "Whether to call enable_full_determinism instead of set_seed for reproducibility in distributed" + " training. Important: this will negatively impact the performance, so only use it for debugging." + ) + }, + ) + torchdynamo: Optional[str] = field( + default=None, + metadata={ + "help": "This argument is deprecated, use `--torch_compile_backend` instead.", + }, + ) + ray_scope: Optional[str] = field( + default="last", + metadata={ + "help": ( + 'The scope to use when doing hyperparameter search with Ray. By default, `"last"` will be used. Ray' + " will then use the last checkpoint of all trials, compare those, and select the best one. However," + " other options are also available. See the Ray documentation" + " (https://docs.ray.io/en/latest/tune/api_docs/analysis.html" + "#ray.tune.ExperimentAnalysis.get_best_trial)" + " for more options." + ) + }, + ) + ddp_timeout: Optional[int] = field( + default=1800, + metadata={ + "help": "Overrides the default timeout for distributed training (value should be given in seconds)." + }, + ) + torch_compile: bool = field( + default=False, metadata={"help": "If set to `True`, the model will be wrapped in `torch.compile`."} + ) + torch_compile_backend: Optional[str] = field( + default=None, + metadata={ + "help": "Which backend to use with `torch.compile`, passing one will trigger a model compilation.", + }, + ) + torch_compile_mode: Optional[str] = field( + default=None, + metadata={ + "help": "Which mode to use with `torch.compile`, passing one will trigger a model compilation.", + }, + ) + + dispatch_batches: Optional[bool] = field( + default=None, + metadata={"help": "Deprecated. Pass {'dispatch_batches':VALUE} to `accelerator_config`."}, + ) + + split_batches: Optional[bool] = field( + default=None, + metadata={"help": "Deprecated. Pass {'split_batches':True} to `accelerator_config`."}, + ) + + include_tokens_per_second: Optional[bool] = field( + default=False, + metadata={"help": "If set to `True`, the speed metrics will include `tgs` (tokens per second per device)."}, + ) + + include_num_input_tokens_seen: Optional[bool] = field( + default=False, + metadata={ + "help": "If set to `True`, will track the number of input tokens seen throughout training. (May be slower in distributed training)" + }, + ) + + neftune_noise_alpha: Optional[float] = field( + default=None, + metadata={ + "help": "Activates neftune noise embeddings into the model. NEFTune has been proven to drastically improve model performances for instrcution fine-tuning. Check out the original paper here: https://arxiv.org/abs/2310.05914 and the original code here: https://github.com/neelsjain/NEFTune. Only supported for `PreTrainedModel` and `PeftModel` classes." + }, + ) + + optim_target_modules: Union[None, str, List[str]] = field( + default=None, + metadata={ + "help": "Target modules for the optimizer defined in the `optim` argument. Only used for the GaLore optimizer at the moment." + }, + ) + + batch_eval_metrics: bool = field( + default=False, + metadata={"help": "Break eval metrics calculation into batches to save memory."}, + ) + + eval_on_start: bool = field( + default=False, + metadata={ + "help": "Whether to run through the entire `evaluation` step at the very beginning of training as a sanity check." + }, + ) + + use_liger_kernel: Optional[bool] = field( + default=False, + metadata={"help": "Whether or not to enable the Liger Kernel for model training."}, + ) + + eval_use_gather_object: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices." + }, + ) + + average_tokens_across_devices: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether or not to average tokens across devices. If enabled, will use all_reduce to " + "synchronize num_tokens_in_batch for precise loss calculation. Reference: " + "https://github.com/huggingface/transformers/issues/34242" + }, + ) + + def __post_init__(self): + # Parse in args that could be `dict` sent in from the CLI as a string + for field in _VALID_DICT_FIELDS: + passed_value = getattr(self, field) + # We only want to do this if the str starts with a bracket to indiciate a `dict` + # else its likely a filename if supported + if isinstance(passed_value, str) and passed_value.startswith("{"): + loaded_dict = json.loads(passed_value) + # Convert str values to types if applicable + loaded_dict = _convert_str_dict(loaded_dict) + setattr(self, field, loaded_dict) + + # expand paths, if not os.makedirs("~/bar") will make directory + # in the current directory instead of the actual home + # see https://github.com/huggingface/transformers/issues/10628 + if self.output_dir is not None: + self.output_dir = os.path.expanduser(self.output_dir) + if self.logging_dir is None and self.output_dir is not None: + self.logging_dir = os.path.join(self.output_dir, default_logdir()) + if self.logging_dir is not None: + self.logging_dir = os.path.expanduser(self.logging_dir) + + if self.disable_tqdm is None: + self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN + + if self.evaluation_strategy is not None: + warnings.warn( + "`evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead", + FutureWarning, + ) + self.eval_strategy = self.evaluation_strategy + + if isinstance(self.eval_strategy, EvaluationStrategy): + warnings.warn( + "using `EvaluationStrategy` for `eval_strategy` is deprecated and will be removed in version 5" + " of 🤗 Transformers. Use `IntervalStrategy` instead", + FutureWarning, + ) + # Go back to the underlying string or we won't be able to instantiate `IntervalStrategy` on it. + self.eval_strategy = self.eval_strategy.value + if self.no_cuda: + warnings.warn( + "using `no_cuda` is deprecated and will be removed in version 5.0 of 🤗 Transformers. " + "Use `use_cpu` instead", + FutureWarning, + ) + self.use_cpu = self.no_cuda + + self.eval_strategy = IntervalStrategy(self.eval_strategy) + self.logging_strategy = IntervalStrategy(self.logging_strategy) + self.save_strategy = SaveStrategy(self.save_strategy) + self.hub_strategy = HubStrategy(self.hub_strategy) + + self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type) + if self.do_eval is False and self.eval_strategy != IntervalStrategy.NO: + self.do_eval = True + + if self.torch_empty_cache_steps is not None: + if not (isinstance(self.torch_empty_cache_steps, int) or self.torch_empty_cache_steps > 0): + raise ValueError( + f"`torch_empty_cache_steps` must be an integer bigger than 0, got {self.torch_empty_cache_steps}." + ) + + # eval_steps has to be defined and non-zero, fallbacks to logging_steps if the latter is non-zero + if self.eval_strategy == IntervalStrategy.STEPS and (self.eval_steps is None or self.eval_steps == 0): + if self.logging_steps > 0: + logger.info(f"using `logging_steps` to initialize `eval_steps` to {self.logging_steps}") + self.eval_steps = self.logging_steps + else: + raise ValueError( + f"evaluation strategy {self.eval_strategy} requires either non-zero --eval_steps or" + " --logging_steps" + ) + + # logging_steps must be non-zero for logging_strategy that is other than 'no' + if self.logging_strategy == IntervalStrategy.STEPS and self.logging_steps == 0: + raise ValueError(f"logging strategy {self.logging_strategy} requires non-zero --logging_steps") + + if self.logging_strategy == IntervalStrategy.STEPS and self.logging_steps > 1: + if self.logging_steps != int(self.logging_steps): + raise ValueError(f"--logging_steps must be an integer if bigger than 1: {self.logging_steps}") + self.logging_steps = int(self.logging_steps) + if self.eval_strategy == IntervalStrategy.STEPS and self.eval_steps > 1: + if self.eval_steps != int(self.eval_steps): + raise ValueError(f"--eval_steps must be an integer if bigger than 1: {self.eval_steps}") + self.eval_steps = int(self.eval_steps) + if self.save_strategy == SaveStrategy.STEPS and self.save_steps > 1: + if self.save_steps != int(self.save_steps): + raise ValueError(f"--save_steps must be an integer if bigger than 1: {self.save_steps}") + self.save_steps = int(self.save_steps) + + # Sanity checks for load_best_model_at_end: we require save and eval strategies to be compatible. + if self.load_best_model_at_end and self.save_strategy != SaveStrategy.BEST: + if self.eval_strategy != self.save_strategy: + raise ValueError( + "--load_best_model_at_end requires the save and eval strategy to match, but found\n- Evaluation " + f"strategy: {self.eval_strategy}\n- Save strategy: {self.save_strategy}" + ) + if self.eval_strategy == IntervalStrategy.STEPS and self.save_steps % self.eval_steps != 0: + if self.eval_steps < 1 or self.save_steps < 1: + if not (self.eval_steps < 1 and self.save_steps < 1): + raise ValueError( + "--load_best_model_at_end requires the saving steps to be a multiple of the evaluation " + "steps, which cannot get guaranteed when mixing ratio and absolute steps for save_steps " + f"{self.save_steps} and eval_steps {self.eval_steps}." + ) + # Work around floating point precision issues + LARGE_MULTIPLIER = 1_000_000 + if (self.save_steps * LARGE_MULTIPLIER) % (self.eval_steps * LARGE_MULTIPLIER) != 0: + raise ValueError( + "--load_best_model_at_end requires the saving steps to be a multiple of the evaluation " + f"steps, but found {self.save_steps}, which is not a multiple of {self.eval_steps}." + ) + raise ValueError( + "--load_best_model_at_end requires the saving steps to be a round multiple of the evaluation " + f"steps, but found {self.save_steps}, which is not a round multiple of {self.eval_steps}." + ) + + safetensors_available = is_safetensors_available() + if self.save_safetensors and not safetensors_available: + raise ValueError(f"--save_safetensors={self.save_safetensors} requires safetensors to be installed!") + if not self.save_safetensors and safetensors_available: + logger.info( + f"Found safetensors installation, but --save_safetensors={self.save_safetensors}. " + f"Safetensors should be a preferred weights saving format due to security and performance reasons. " + f"If your model cannot be saved by safetensors please feel free to open an issue at " + f"https://github.com/huggingface/safetensors!" + ) + + if ( + self.load_best_model_at_end or self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU + ) and self.metric_for_best_model is None: + self.metric_for_best_model = "loss" + if self.greater_is_better is None and self.metric_for_best_model is not None: + self.greater_is_better = not (self.metric_for_best_model.endswith("loss")) + if self.run_name is None: + self.run_name = self.output_dir + if self.framework == "pt" and is_torch_available(): + if self.fp16_backend and self.fp16_backend != "auto": + warnings.warn( + "`fp16_backend` is deprecated and will be removed in version 5 of 🤗 Transformers. Use" + " `half_precision_backend` instead", + FutureWarning, + ) + self.half_precision_backend = self.fp16_backend + + if self.bf16 or self.bf16_full_eval: + if self.use_cpu and not is_torch_bf16_cpu_available() and not is_torch_xla_available(): + # cpu + raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10") + elif not self.use_cpu: + if torch.cuda.is_available() and not is_torch_bf16_gpu_available(): + # gpu + raise ValueError( + "Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0" + ) + + if self.fp16 and self.bf16: + raise ValueError("At most one of fp16 and bf16 can be True, but not both") + + if self.fp16_full_eval and self.bf16_full_eval: + raise ValueError("At most one of fp16 and bf16 can be True for full eval, but not both") + + if self.bf16: + if self.half_precision_backend == "apex": + raise ValueError(" `--half_precision_backend apex`: GPU bf16 is not supported by apex.") + + if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU: + if self.eval_strategy == IntervalStrategy.NO: + raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires an eval strategy") + if not is_torch_available(): + raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires torch>=0.2.0") + + self.optim = OptimizerNames(self.optim) + if self.adafactor: + warnings.warn( + "`--adafactor` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--optim" + " adafactor` instead", + FutureWarning, + ) + self.optim = OptimizerNames.ADAFACTOR + if self.optim == OptimizerNames.ADAMW_TORCH_FUSED and is_torch_available(): + if version.parse(version.parse(torch.__version__).base_version) < version.parse("2.0.0"): + raise ValueError("--optim adamw_torch_fused requires PyTorch 2.0 or higher") + # there is a bug in fp16/AMP in pt-2.0.0 + if version.parse(version.parse(torch.__version__).base_version) == version.parse("2.0.0") and self.fp16: + raise ValueError("--optim adamw_torch_fused with --fp16 requires PyTorch>2.0") + + # We need to setup the accelerator config here *before* the first call to `self.device` + if is_accelerate_available(): + if not isinstance(self.accelerator_config, (AcceleratorConfig)): + if self.accelerator_config is None: + self.accelerator_config = AcceleratorConfig() + elif isinstance(self.accelerator_config, dict): + self.accelerator_config = AcceleratorConfig(**self.accelerator_config) + # Check that a user didn't pass in the class instantiator + # such as `accelerator_config = AcceleratorConfig` + elif isinstance(self.accelerator_config, type): + raise NotImplementedError( + "Tried passing in a callable to `accelerator_config`, but this is not supported. " + "Please pass in a fully constructed `AcceleratorConfig` object instead." + ) + else: + self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config) + + if self.dispatch_batches is not None: + warnings.warn( + "Using `--dispatch_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use" + " `--accelerator_config {'dispatch_batches':VALUE} instead", + FutureWarning, + ) + self.accelerator_config.dispatch_batches = self.dispatch_batches + + if self.split_batches is not None: + warnings.warn( + "Using `--split_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use" + " `--accelerator_config {'split_batches':VALUE} instead", + FutureWarning, + ) + self.accelerator_config.split_batches = self.split_batches + + # Initialize device before we proceed + if self.framework == "pt" and is_torch_available(): + self.device + + # Disable average tokens when using single device + if self.average_tokens_across_devices: + try: + if self.world_size == 1: + logger.warning( + "average_tokens_across_devices is set to True but it is invalid when world size is" + "1. Turn it to False automatically." + ) + self.average_tokens_across_devices = False + except ImportError as e: + logger.warning(f"Can not specify world size due to {e}. Turn average_tokens_across_devices to False.") + self.average_tokens_across_devices = False + + if self.torchdynamo is not None: + warnings.warn( + "`torchdynamo` is deprecated and will be removed in version 5 of 🤗 Transformers. Use" + " `torch_compile_backend` instead", + FutureWarning, + ) + self.torch_compile_backend = self.torchdynamo + if (self.torch_compile_mode is not None or self.torch_compile_backend is not None) and not self.torch_compile: + self.torch_compile = True + if self.torch_compile and self.torch_compile_backend is None: + self.torch_compile_backend = "inductor" + + # accelerate integration for torch compile + if self.torch_compile: + # set env vars for accelerate + prefix = "ACCELERATE_DYNAMO_" + os.environ[prefix + "BACKEND"] = self.torch_compile_backend + if self.torch_compile_mode is not None: + os.environ[prefix + "MODE"] = self.torch_compile_mode + + if self.framework == "pt" and is_torch_available() and self.torch_compile: + if is_torch_tf32_available(): + if self.tf32 is None and not self.fp16 or self.bf16: + logger.info( + "Setting TF32 in CUDA backends to speedup torch compile, you won't see any improvement" + " otherwise." + ) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + else: + logger.warning( + "The speedups for torchdynamo mostly come wih GPU Ampere or higher and which is not detected here." + ) + if self.framework == "pt" and is_torch_available() and self.tf32 is not None: + if self.tf32: + if is_torch_tf32_available(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + else: + raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7") + else: + if is_torch_tf32_available(): + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + # no need to assert on else + + # if training args is specified, it will override the one specified in the accelerate config + if self.half_precision_backend != "apex": + mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no") + if self.fp16: + mixed_precision_dtype = "fp16" + elif self.bf16: + mixed_precision_dtype = "bf16" + os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype + + if self.report_to is None: + logger.info( + "The default value for the training argument `--report_to` will change in v5 (from all installed " + "integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as " + "now. You should start updating your code and make this info disappear :-)." + ) + self.report_to = "all" + if self.report_to == "all" or self.report_to == ["all"]: + # Import at runtime to avoid a circular import. + from .integrations import get_available_reporting_integrations + + self.report_to = get_available_reporting_integrations() + + if "codecarbon" in self.report_to and torch.version.hip: + logger.warning( + "When using the Trainer, CodeCarbonCallback requires the `codecarbon` package, which is not compatible with AMD ROCm (https://github.com/mlco2/codecarbon/pull/490). Automatically disabling the codecarbon callback. Reference: https://huggingface.co/docs/transformers/v4.39.3/en/main_classes/trainer#transformers.TrainingArguments.report_to." + ) + self.report_to.remove("codecarbon") + + elif self.report_to == "none" or self.report_to == ["none"]: + self.report_to = [] + elif not isinstance(self.report_to, list): + self.report_to = [self.report_to] + + if self.warmup_ratio < 0 or self.warmup_ratio > 1: + raise ValueError("warmup_ratio must lie in range [0,1]") + elif self.warmup_ratio > 0 and self.warmup_steps > 0: + logger.info( + "Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio" + " during training" + ) + + if not isinstance(self.warmup_steps, int) or self.warmup_steps < 0: + raise ValueError("warmup_steps must be of type int and must be 0 or a positive integer.") + + if isinstance(self.fsdp, bool): + self.fsdp = [FSDPOption.FULL_SHARD] if self.fsdp else "" + if isinstance(self.fsdp, str): + self.fsdp = [FSDPOption(s) for s in self.fsdp.split()] + if self.fsdp == [FSDPOption.OFFLOAD]: + raise ValueError( + "`--fsdp offload` can't work on its own. It needs to be added to `--fsdp full_shard` or " + '`--fsdp shard_grad_op`. For example, `--fsdp "full_shard offload"`.' + ) + elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp: + raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.") + + if self.gradient_checkpointing and ( + FSDPOption.FULL_SHARD in self.fsdp or FSDPOption.HYBRID_SHARD in self.fsdp + ): + logger.warning( + "When using FSDP full shard, instead of using `gradient_checkpointing` in TrainingArguments, please" + " use `activation_checkpointing` in `fsdp_config`. The former introduces a redundant AllGather" + " operation in backward pass. Reference: https://github.com/huggingface/transformers/issues/30404" + ) + + if self.fsdp_config is None: + self.fsdp_config = {} + + if isinstance(self.fsdp_config, str): + if len(self.fsdp) == 0: + warnings.warn("`--fsdp_config` is useful only when `--fsdp` is specified.") + with io.open(self.fsdp_config, "r", encoding="utf-8") as f: + self.fsdp_config = json.load(f) + for k in list(self.fsdp_config.keys()): + if k.startswith("fsdp_"): + v = self.fsdp_config.pop(k) + self.fsdp_config[k[5:]] = v + + if self.fsdp_min_num_params > 0: + warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning) + + self.fsdp_config["min_num_params"] = max(self.fsdp_config.get("min_num_params", 0), self.fsdp_min_num_params) + + # if fsdp_config["transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object + if isinstance(self.fsdp_config.get("transformer_layer_cls_to_wrap", None), str): + self.fsdp_config["transformer_layer_cls_to_wrap"] = [self.fsdp_config["transformer_layer_cls_to_wrap"]] + + if self.fsdp_transformer_layer_cls_to_wrap is not None: + warnings.warn( + "using `--fsdp_transformer_layer_cls_to_wrap` is deprecated. Use fsdp_config instead ", FutureWarning + ) + self.fsdp_config["transformer_layer_cls_to_wrap"] = self.fsdp_config.get( + "transformer_layer_cls_to_wrap", [] + ) + [self.fsdp_transformer_layer_cls_to_wrap] + + if len(self.fsdp) == 0 and self.fsdp_config["min_num_params"] > 0: + warnings.warn("`min_num_params` is useful only when `--fsdp` is specified.") + + if len(self.fsdp) == 0 and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None: + warnings.warn("`transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.") + + if ( + len(self.fsdp) > 0 + and self.fsdp_config["min_num_params"] > 0 + and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None + ): + raise ValueError("`min_num_params` and `transformer_layer_cls_to_wrap` are mutually exclusive.") + self.fsdp_config["xla"] = self.fsdp_config.get("xla", False) + self.fsdp_config["xla_fsdp_v2"] = self.fsdp_config.get("xla_fsdp_v2", False) + self.fsdp_config["xla_fsdp_grad_ckpt"] = self.fsdp_config.get("xla_fsdp_grad_ckpt", False) + if self.fsdp_config["xla"]: + if len(self.fsdp) > 0: + # store XLA fsdp configuration parameters into a dictionary + # Copy the config to avoid modifying the original config (which may be used for JSON serialization) + self.xla_fsdp_config = self.fsdp_config.get("xla_fsdp_settings", {}).copy() + # apply appropriate string to torch.dtype conversions for parameters + if "compute_dtype" in self.xla_fsdp_config: + self.xla_fsdp_config["compute_dtype"] = getattr(torch, self.xla_fsdp_config["compute_dtype"]) + if "buffer_dtype" in self.xla_fsdp_config: + self.xla_fsdp_config["buffer_dtype"] = getattr(torch, self.xla_fsdp_config["buffer_dtype"]) + else: + warnings.warn("XLA FSDP can be used only when `--fsdp` is specified.") + else: + if self.fsdp_config["xla_fsdp_grad_ckpt"]: + warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.") + + # accelerate integration for FSDP + if len(self.fsdp) > 0 and not self.fsdp_config["xla"]: + os.environ["ACCELERATE_USE_FSDP"] = "true" + from accelerate.utils.constants import ( + FSDP_AUTO_WRAP_POLICY, + FSDP_SHARDING_STRATEGY, + ) + + prefix = "FSDP_" + for fsdp_option in self.fsdp: + if fsdp_option.upper() in FSDP_SHARDING_STRATEGY: + # set environment variable for FSDP sharding strategy + os.environ[f"{prefix}SHARDING_STRATEGY"] = str( + FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1 + ) + elif fsdp_option == FSDPOption.OFFLOAD: + os.environ[f"{prefix}OFFLOAD_PARAMS"] = "true" + elif fsdp_option == FSDPOption.AUTO_WRAP: + os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0] + if self.fsdp_config["min_num_params"] > 0: + os.environ[f"{prefix}MIN_NUM_PARAMS"] = str(self.fsdp_config["min_num_params"]) + os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1] + elif self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None: + os.environ[f"{prefix}TRANSFORMER_CLS_TO_WRAP"] = ",".join( + self.fsdp_config["transformer_layer_cls_to_wrap"] + ) + prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH") + os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper() + os.environ[f"{prefix}FORWARD_PREFETCH"] = str(self.fsdp_config.get("forward_prefetch", "false")).lower() + + sync_module_states = str(self.fsdp_config.get("sync_module_states", "true")).lower() + cpu_ram_efficient_loading = str(self.fsdp_config.get("cpu_ram_efficient_loading", "false")).lower() + + if sync_module_states == "false" and cpu_ram_efficient_loading == "true": + # In this case, all the processes except the main process would have random weights leading + # to unexpected behaviour during training, thus throwing error here to prevent it. + raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`') + + os.environ[f"{prefix}SYNC_MODULE_STATES"] = sync_module_states + os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading + + os.environ[f"{prefix}USE_ORIG_PARAMS"] = str(self.fsdp_config.get("use_orig_params", "true")).lower() + + if self.tpu_metrics_debug: + warnings.warn( + "using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use" + " `--debug tpu_metrics_debug` instead", + FutureWarning, + ) + if self.debug is None: + self.debug = " tpu_metrics_debug" + else: + self.debug += " tpu_metrics_debug" + self.tpu_metrics_debug = False + + if isinstance(self.debug, str): + self.debug = [DebugOption(s) for s in self.debug.split()] + elif self.debug is None: + self.debug = [] + + self.deepspeed_plugin = None + if self.deepspeed: + # - must be run very last in arg parsing, since it will use a lot of these settings. + # - must be run before the model is created. + if not is_accelerate_available(): + raise ValueError( + f"--deepspeed requires Accelerate to be installed: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`." + ) + from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig + + # will be used later by the Trainer + # note: leave self.deepspeed unmodified in case a user relies on it not to be modified) + self.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.deepspeed) + self.hf_deepspeed_config.trainer_config_process(self) + + # Accelerate DeepSpeed Plugin + from accelerate.utils import DeepSpeedPlugin + + os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" + self.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.hf_deepspeed_config) + elif strtobool(os.environ.get("ACCELERATE_USE_DEEPSPEED", "false")): + # Accelerate DeepSpeed Plugin + from accelerate.utils import DeepSpeedPlugin + + self.deepspeed_plugin = DeepSpeedPlugin() + mixed_precision = os.environ.get("ACCELERATE_MIXED_PRECISION", "no") + self.deepspeed_plugin.set_mixed_precision(mixed_precision) + self.deepspeed_plugin.set_deepspeed_weakref() + + if self.use_cpu: + self.dataloader_pin_memory = False + + if self.dataloader_num_workers == 0 and self.dataloader_prefetch_factor is not None: + raise ValueError( + "--dataloader_prefetch_factor can only be set when data is loaded in a different process, i.e." + " when --dataloader_num_workers > 1." + ) + + if self.push_to_hub_token is not None: + warnings.warn( + "`--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use " + "`--hub_token` instead.", + FutureWarning, + ) + self.hub_token = self.push_to_hub_token + + if self.push_to_hub_model_id is not None: + self.hub_model_id = get_full_repo_name( + self.push_to_hub_model_id, organization=self.push_to_hub_organization, token=self.hub_token + ) + if self.push_to_hub_organization is not None: + warnings.warn( + "`--push_to_hub_model_id` and `--push_to_hub_organization` are deprecated and will be removed in " + "version 5 of 🤗 Transformers. Use `--hub_model_id` instead and pass the full repo name to this " + f"argument (in this case {self.hub_model_id}).", + FutureWarning, + ) + else: + warnings.warn( + "`--push_to_hub_model_id` is deprecated and will be removed in version 5 of 🤗 Transformers. Use " + "`--hub_model_id` instead and pass the full repo name to this argument (in this case " + f"{self.hub_model_id}).", + FutureWarning, + ) + elif self.push_to_hub_organization is not None: + self.hub_model_id = f"{self.push_to_hub_organization}/{Path(self.output_dir).name}" + warnings.warn( + "`--push_to_hub_organization` is deprecated and will be removed in version 5 of 🤗 Transformers. Use " + "`--hub_model_id` instead and pass the full repo name to this argument (in this case " + f"{self.hub_model_id}).", + FutureWarning, + ) + + if self.eval_use_gather_object and not is_accelerate_available("0.30.0"): + raise ValueError( + "--eval_use_gather_object requires Accelerate to be version of `accelerate` > 0.30.0." + "This is not supported and we recommend you to update your version." + ) + + if self.data_seed is not None: + if not is_accelerate_available("1.1.0"): + raise NotImplementedError( + "data_seed requires Accelerate version `accelerate` >= 1.1.0. " + "This is not supported and we recommend you to update your version." + ) + + if self.include_inputs_for_metrics: + logger.warning( + "Using `include_inputs_for_metrics` is deprecated and will be removed in version 5 of 🤗 Transformers. Please use `include_for_metrics` list argument instead." + ) + self.include_for_metrics.append("inputs") + + def __str__(self): + self_as_dict = asdict(self) + + # Remove deprecated arguments. That code should be removed once + # those deprecated arguments are removed from TrainingArguments. (TODO: v5) + del self_as_dict["per_gpu_train_batch_size"] + del self_as_dict["per_gpu_eval_batch_size"] + + self_as_dict = {k: f"<{k.upper()}>" if k.endswith("_token") else v for k, v in self_as_dict.items()} + + attrs_as_str = [f"{k}={v},\n" for k, v in sorted(self_as_dict.items())] + return f"{self.__class__.__name__}(\n{''.join(attrs_as_str)})" + + __repr__ = __str__ + + @property + def train_batch_size(self) -> int: + """ + The actual batch size for training (may differ from `per_gpu_train_batch_size` in distributed training). + """ + if self.per_gpu_train_batch_size: + logger.warning( + "Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future " + "version. Using `--per_device_train_batch_size` is preferred." + ) + per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size + train_batch_size = per_device_batch_size * max(1, self.n_gpu) + return train_batch_size + + @property + def eval_batch_size(self) -> int: + """ + The actual batch size for evaluation (may differ from `per_gpu_eval_batch_size` in distributed training). + """ + if self.per_gpu_eval_batch_size: + logger.warning( + "Using deprecated `--per_gpu_eval_batch_size` argument which will be removed in a future " + "version. Using `--per_device_eval_batch_size` is preferred." + ) + per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size + eval_batch_size = per_device_batch_size * max(1, self.n_gpu) + return eval_batch_size + + @property + def ddp_timeout_delta(self) -> timedelta: + """ + The actual timeout for torch.distributed.init_process_group since it expects a timedelta variable. + """ + return timedelta(seconds=self.ddp_timeout) + + @cached_property + def _setup_devices(self) -> "torch.device": + requires_backends(self, ["torch"]) + logger.info("PyTorch: setting up devices") + if not is_sagemaker_mp_enabled(): + if not is_accelerate_available(): + raise ImportError( + f"Using the `Trainer` with `PyTorch` requires `accelerate>={ACCELERATE_MIN_VERSION}`: " + f"Please run `pip install transformers[torch]` or `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`" + ) + # We delay the init of `PartialState` to the end for clarity + accelerator_state_kwargs = {"enabled": True, "use_configured_state": False} + if isinstance(self.accelerator_config, AcceleratorConfig): + accelerator_state_kwargs["use_configured_state"] = self.accelerator_config.pop( + "use_configured_state", False + ) + if accelerator_state_kwargs["use_configured_state"]: + if PartialState._shared_state == {}: + raise ValueError( + "Passing `'use_configured_state':True` to the AcceleratorConfig requires a pre-configured " + "`AcceleratorState` or `PartialState` to be defined before calling `TrainingArguments`. " + ) + # We rely on `PartialState` to yell if there's issues here (which it will) + self.distributed_state = PartialState(cpu=self.use_cpu) + if self.deepspeed and self.distributed_state.distributed_type != DistributedType.DEEPSPEED: + raise RuntimeError( + "Tried to use an already configured `Accelerator` or `PartialState` that was not initialized for DeepSpeed, " + "but also passed in a `deepspeed` configuration to the `TrainingArguments`. Please set " + "`use_configured_state:False` instead or setup your `Accelerator` or `PartialState` properly." + ) + else: + AcceleratorState._reset_state(reset_partial_state=True) + self.distributed_state = None + if not self.use_ipex and "ACCELERATE_USE_IPEX" not in os.environ: + os.environ["ACCELERATE_USE_IPEX"] = "false" + + self._n_gpu = 1 + if self.use_cpu or strtobool(os.environ.get("ACCELERATE_USE_CPU", "False")): + accelerator_state_kwargs["cpu"] = True + accelerator_state_kwargs["backend"] = self.ddp_backend + self._n_gpu = 0 + elif is_sagemaker_mp_enabled(): + accelerator_state_kwargs["enabled"] = False + local_rank = smp.local_rank() + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + elif is_sagemaker_dp_enabled(): + accelerator_state_kwargs["_use_sagemaker_dp"] = True + elif self.deepspeed: + accelerator_state_kwargs["use_deepspeed"] = True + accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout) + else: + accelerator_state_kwargs["backend"] = self.ddp_backend + accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout) + + # Now we pop everything + if accelerator_state_kwargs.pop("enabled", False) and not accelerator_state_kwargs.pop( + "use_configured_state", False + ): + # We need to patch this env var when enabling to detect deepspeed + use_deepspeed = accelerator_state_kwargs.pop("use_deepspeed", False) + if use_deepspeed: + os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" + self.distributed_state = PartialState(**accelerator_state_kwargs) + if use_deepspeed: + del os.environ["ACCELERATE_USE_DEEPSPEED"] + if not is_sagemaker_mp_enabled(): + device = self.distributed_state.device + self.local_rank = self.distributed_state.local_process_index + if dist.is_available() and dist.is_initialized() and self.parallel_mode != ParallelMode.DISTRIBUTED: + logger.warning( + "torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. " + "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch" + ) + if is_torch_xla_available(): + device = self.distributed_state.device + self._n_gpu = 0 + elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled(): + # Already set _n_gpu + pass + elif self.distributed_state.distributed_type == DistributedType.NO: + if self.use_mps_device: + warnings.warn( + "`use_mps_device` is deprecated and will be removed in version 5.0 of 🤗 Transformers. " + "`mps` device will be used by default if available similar to the way `cuda` device is used." + "Therefore, no action from user is required. " + ) + if device.type != "mps": + raise ValueError( + "Either you do not have an MPS-enabled device on this machine or MacOS version is not 12.3+ " + "or current PyTorch install was not built with MPS enabled." + ) + if self.use_cpu: + device = torch.device("cpu") + elif is_torch_mps_available(): + device = torch.device("mps") + elif is_torch_xpu_available(): + if not is_ipex_available() and not is_accelerate_available("0.32.0.dev"): + raise ImportError("Using the XPU PyTorch backend requires `accelerate>=0.32.0.dev`") + device = torch.device("xpu:0") + torch.xpu.set_device(device) + elif is_torch_mlu_available(): + device = torch.device("mlu:0") + torch.mlu.set_device(device) + elif is_torch_musa_available(): + device = torch.device("musa:0") + torch.musa.set_device(device) + elif is_torch_npu_available(): + device = torch.device("npu:0") + torch.npu.set_device(device) + else: + # if n_gpu is > 1 we'll use nn.DataParallel. + # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0` + # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will + # trigger an error that a device index is missing. Index 0 takes into account the + # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0` + # will use the first GPU in that env, i.e. GPU#1 + device = torch.device( + "cuda:0" if torch.cuda.is_available() else os.environ.get("ACCELERATE_TORCH_DEVICE", "cpu") + ) + # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at + # the default value. + self._n_gpu = torch.cuda.device_count() + if device.type == "cuda": + torch.cuda.set_device(device) + return device + + @property + def device(self) -> "torch.device": + """ + The device used by this process. + """ + requires_backends(self, ["torch"]) + return self._setup_devices + + @property + def n_gpu(self): + """ + The number of GPUs used by this process. + + Note: + This will only be greater than one when you have multiple GPUs available but are not using distributed + training. For distributed training, it will always be 1. + """ + requires_backends(self, ["torch"]) + # Make sure `self._n_gpu` is properly setup. + if not hasattr(self, "_n_gpu"): + _ = self._setup_devices + return self._n_gpu + + @property + def parallel_mode(self): + """ + The current mode used for parallelism if multiple GPUs/TPU cores are available. One of: + + - `ParallelMode.NOT_PARALLEL`: no parallelism (CPU or one GPU). + - `ParallelMode.NOT_DISTRIBUTED`: several GPUs in one single process (uses `torch.nn.DataParallel`). + - `ParallelMode.DISTRIBUTED`: several GPUs, each having its own process (uses + `torch.nn.DistributedDataParallel`). + - `ParallelMode.TPU`: several TPU cores. + """ + requires_backends(self, ["torch"]) + if is_torch_xla_available(): + return ParallelMode.TPU + elif is_sagemaker_mp_enabled(): + return ParallelMode.SAGEMAKER_MODEL_PARALLEL + elif is_sagemaker_dp_enabled(): + return ParallelMode.SAGEMAKER_DATA_PARALLEL + elif ( + self.distributed_state is not None and self.distributed_state.distributed_type != DistributedType.NO + ) or (self.distributed_state is None and self.local_rank != -1): + return ParallelMode.DISTRIBUTED + elif self.n_gpu > 1: + return ParallelMode.NOT_DISTRIBUTED + else: + return ParallelMode.NOT_PARALLEL + + @property + def world_size(self): + """ + The number of processes used in parallel. + """ + requires_backends(self, ["torch"]) + if self.distributed_state is not None: + return self.distributed_state.num_processes + elif is_sagemaker_mp_enabled(): + return smp.dp_size() if not smp.state.cfg.prescaled_batch else smp.rdp_size() + return 1 + + @property + def process_index(self): + """ + The index of the current process used. + """ + requires_backends(self, ["torch"]) + if self.distributed_state is not None: + return self.distributed_state.process_index + elif is_sagemaker_mp_enabled(): + return smp.dp_rank() if not smp.state.cfg.prescaled_batch else smp.rdp_rank() + return 0 + + @property + def local_process_index(self): + """ + The index of the local process used. + """ + requires_backends(self, ["torch"]) + + if self.distributed_state is not None: + return self.distributed_state.local_process_index + elif is_sagemaker_mp_enabled(): + return smp.local_rank() + return 0 + + @property + def should_log(self): + """ + Whether or not the current process should produce log. + """ + if self.log_on_each_node: + return self.local_process_index == 0 + else: + if is_sagemaker_mp_enabled(): + return smp.rank() == 0 + else: + return self.process_index == 0 + + @property + def should_save(self): + """ + Whether or not the current process should write to disk, e.g., to save models and checkpoints. + """ + if self.save_on_each_node: + return self.local_process_index == 0 + else: + if is_sagemaker_mp_enabled(): + return smp.rank() == 0 + else: + return self.process_index == 0 + + def get_process_log_level(self): + """ + Returns the log level to be used depending on whether this process is the main process of node 0, main process + of node non-0, or a non-main process. + + For the main process the log level defaults to the logging level set (`logging.WARNING` if you didn't do + anything) unless overridden by `log_level` argument. + + For the replica processes the log level defaults to `logging.WARNING` unless overridden by `log_level_replica` + argument. + + The choice between the main and replica process settings is made according to the return value of `should_log`. + """ + + # convert to int + log_level = trainer_log_levels[self.log_level] + log_level_replica = trainer_log_levels[self.log_level_replica] + + log_level_main_node = logging.get_verbosity() if log_level == -1 else log_level + log_level_replica_node = logging.get_verbosity() if log_level_replica == -1 else log_level_replica + return log_level_main_node if self.should_log else log_level_replica_node + + @property + def place_model_on_device(self): + """ + Can be subclassed and overridden for some specific integrations. + """ + return not is_sagemaker_mp_enabled() + + @property + def _no_sync_in_gradient_accumulation(self): + """ + Whether or not to use no_sync for the gradients when doing gradient accumulation. + """ + return not ( + self.deepspeed or is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled() or is_torch_neuroncore_available() + ) + + @contextlib.contextmanager + def main_process_first(self, local=True, desc="work"): + """ + A context manager for torch distributed environment where on needs to do something on the main process, while + blocking replicas, and when it's finished releasing the replicas. + + One such use is for `datasets`'s `map` feature which to be efficient should be run once on the main process, + which upon completion saves a cached version of results and which then automatically gets loaded by the + replicas. + + Args: + local (`bool`, *optional*, defaults to `True`): + if `True` first means process of rank 0 of each node if `False` first means process of rank 0 of node + rank 0 In multi-node environment with a shared filesystem you most likely will want to use + `local=False` so that only the main process of the first node will do the processing. If however, the + filesystem is not shared, then the main process of each node will need to do the processing, which is + the default behavior. + desc (`str`, *optional*, defaults to `"work"`): + a work description to be used in debug logs + + """ + if is_torch_available() and self.world_size > 1: + main_process_desc = "main local process" if local else "main process" + if self.distributed_state is not None: + is_main_process = ( + self.distributed_state.is_local_main_process if local else self.distributed_state.is_main_process + ) + elif is_sagemaker_mp_enabled(): + is_main_process = smp.rank() == 0 + + try: + if not is_main_process: + # tell all replicas to wait + logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}") + + if is_torch_xla_available(): + xm.rendezvous(desc) + else: + dist.barrier() + yield + finally: + if is_main_process: + # the wait is over + logger.debug(f"{self.process_index}: {main_process_desc} completed {desc}, releasing all replicas") + if is_torch_xla_available(): + xm.rendezvous(desc) + else: + dist.barrier() + else: + yield + + def get_warmup_steps(self, num_training_steps: int): + """ + Get number of steps used for a linear warmup. + """ + warmup_steps = ( + self.warmup_steps if self.warmup_steps > 0 else math.ceil(num_training_steps * self.warmup_ratio) + ) + return warmup_steps + + def _dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None: + """ + Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None, + converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"* + string, which can then be stored in the json format. + """ + if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str): + d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1] + for value in d.values(): + if isinstance(value, dict): + self._dict_torch_dtype_to_str(value) + + def to_dict(self): + """ + Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates + the token values by removing their value. + """ + # filter out fields that are defined as field(init=False) + d = {field.name: getattr(self, field.name) for field in fields(self) if field.init} + + for k, v in d.items(): + if isinstance(v, Enum): + d[k] = v.value + if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum): + d[k] = [x.value for x in v] + if k.endswith("_token"): + d[k] = f"<{k.upper()}>" + # Handle the accelerator_config if passed + if is_accelerate_available() and isinstance(v, AcceleratorConfig): + d[k] = v.to_dict() + self._dict_torch_dtype_to_str(d) + + return d + + def to_json_string(self): + """ + Serializes this instance to a JSON string. + """ + return json.dumps(self.to_dict(), indent=2) + + def to_sanitized_dict(self) -> Dict[str, Any]: + """ + Sanitized serialization to use with TensorBoard’s hparams + """ + d = self.to_dict() + d = {**d, **{"train_batch_size": self.train_batch_size, "eval_batch_size": self.eval_batch_size}} + + valid_types = [bool, int, float, str] + if is_torch_available(): + valid_types.append(torch.Tensor) + + return {k: v if type(v) in valid_types else str(v) for k, v in d.items()} + + # The following methods are there to simplify the instantiation of `TrainingArguments` + def set_training( + self, + learning_rate: float = 5e-5, + batch_size: int = 8, + weight_decay: float = 0, + num_epochs: float = 3, + max_steps: int = -1, + gradient_accumulation_steps: int = 1, + seed: int = 42, + gradient_checkpointing: bool = False, + ): + """ + A method that regroups all basic arguments linked to the training. + + + + Calling this method will automatically set `self.do_train` to `True`. + + + + Args: + learning_rate (`float`, *optional*, defaults to 5e-5): + The initial learning rate for the optimizer. + batch_size (`int` *optional*, defaults to 8): + The batch size per device (GPU/TPU core/CPU...) used for training. + weight_decay (`float`, *optional*, defaults to 0): + The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights in the + optimizer. + num_train_epochs(`float`, *optional*, defaults to 3.0): + Total number of training epochs to perform (if not an integer, will perform the decimal part percents + of the last epoch before stopping training). + max_steps (`int`, *optional*, defaults to -1): + If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`. + For a finite dataset, training is reiterated through the dataset (if all data is exhausted) until + `max_steps` is reached. + gradient_accumulation_steps (`int`, *optional*, defaults to 1): + Number of updates steps to accumulate the gradients for, before performing a backward/update pass. + + + + When using gradient accumulation, one step is counted as one step with backward pass. Therefore, + logging, evaluation, save will be conducted every `gradient_accumulation_steps * xxx_step` training + examples. + + + + seed (`int`, *optional*, defaults to 42): + Random seed that will be set at the beginning of training. To ensure reproducibility across runs, use + the [`~Trainer.model_init`] function to instantiate the model if it has some randomly initialized + parameters. + gradient_checkpointing (`bool`, *optional*, defaults to `False`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_training(learning_rate=1e-4, batch_size=32) + >>> args.learning_rate + 1e-4 + ``` + """ + self.do_train = True + self.learning_rate = learning_rate + self.per_device_train_batch_size = batch_size + self.weight_decay = weight_decay + self.num_train_epochs = num_epochs + self.max_steps = max_steps + self.gradient_accumulation_steps = gradient_accumulation_steps + self.seed = seed + self.gradient_checkpointing = gradient_checkpointing + return self + + def set_evaluate( + self, + strategy: Union[str, IntervalStrategy] = "no", + steps: int = 500, + batch_size: int = 8, + accumulation_steps: Optional[int] = None, + delay: Optional[float] = None, + loss_only: bool = False, + jit_mode: bool = False, + ): + """ + A method that regroups all arguments linked to evaluation. + + Args: + strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"no"`): + The evaluation strategy to adopt during training. Possible values are: + + - `"no"`: No evaluation is done during training. + - `"steps"`: Evaluation is done (and logged) every `steps`. + - `"epoch"`: Evaluation is done at the end of each epoch. + + Setting a `strategy` different from `"no"` will set `self.do_eval` to `True`. + steps (`int`, *optional*, defaults to 500): + Number of update steps between two evaluations if `strategy="steps"`. + batch_size (`int` *optional*, defaults to 8): + The batch size per device (GPU/TPU core/CPU...) used for evaluation. + accumulation_steps (`int`, *optional*): + Number of predictions steps to accumulate the output tensors for, before moving the results to the CPU. + If left unset, the whole predictions are accumulated on GPU/TPU before being moved to the CPU (faster + but requires more memory). + delay (`float`, *optional*): + Number of epochs or steps to wait for before the first evaluation can be performed, depending on the + eval_strategy. + loss_only (`bool`, *optional*, defaults to `False`): + Ignores all outputs except the loss. + jit_mode (`bool`, *optional*): + Whether or not to use PyTorch jit trace for inference. + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_evaluate(strategy="steps", steps=100) + >>> args.eval_steps + 100 + ``` + """ + self.eval_strategy = IntervalStrategy(strategy) + if self.eval_strategy == IntervalStrategy.STEPS and steps == 0: + raise ValueError("Setting `strategy` as 'steps' requires a positive value for `steps`.") + self.do_eval = self.eval_strategy != IntervalStrategy.NO + self.eval_steps = steps + self.per_device_eval_batch_size = batch_size + self.eval_accumulation_steps = accumulation_steps + self.eval_delay = delay + self.prediction_loss_only = loss_only + self.jit_mode_eval = jit_mode + return self + + def set_testing( + self, + batch_size: int = 8, + loss_only: bool = False, + jit_mode: bool = False, + ): + """ + A method that regroups all basic arguments linked to testing on a held-out dataset. + + + + Calling this method will automatically set `self.do_predict` to `True`. + + + + Args: + batch_size (`int` *optional*, defaults to 8): + The batch size per device (GPU/TPU core/CPU...) used for testing. + loss_only (`bool`, *optional*, defaults to `False`): + Ignores all outputs except the loss. + jit_mode (`bool`, *optional*): + Whether or not to use PyTorch jit trace for inference. + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_testing(batch_size=32) + >>> args.per_device_eval_batch_size + 32 + ``` + """ + self.do_predict = True + self.per_device_eval_batch_size = batch_size + self.prediction_loss_only = loss_only + self.jit_mode_eval = jit_mode + return self + + def set_save( + self, + strategy: Union[str, IntervalStrategy] = "steps", + steps: int = 500, + total_limit: Optional[int] = None, + on_each_node: bool = False, + ): + """ + A method that regroups all arguments linked to checkpoint saving. + + Args: + strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`): + The checkpoint save strategy to adopt during training. Possible values are: + + - `"no"`: No save is done during training. + - `"epoch"`: Save is done at the end of each epoch. + - `"steps"`: Save is done every `save_steps`. + + steps (`int`, *optional*, defaults to 500): + Number of updates steps before two checkpoint saves if `strategy="steps"`. + total_limit (`int`, *optional*): + If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in + `output_dir`. + on_each_node (`bool`, *optional*, defaults to `False`): + When doing multi-node distributed training, whether to save models and checkpoints on each node, or + only on the main one. + + This should not be activated when the different nodes use the same storage as the files will be saved + with the same names for each node. + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_save(strategy="steps", steps=100) + >>> args.save_steps + 100 + ``` + """ + self.save_strategy = SaveStrategy(strategy) + if self.save_strategy == SaveStrategy.STEPS and steps == 0: + raise ValueError("Setting `strategy` as 'steps' requires a positive value for `steps`.") + self.save_steps = steps + self.save_total_limit = total_limit + self.save_on_each_node = on_each_node + return self + + def set_logging( + self, + strategy: Union[str, IntervalStrategy] = "steps", + steps: int = 500, + report_to: Union[str, List[str]] = "none", + level: str = "passive", + first_step: bool = False, + nan_inf_filter: bool = False, + on_each_node: bool = False, + replica_level: str = "passive", + ): + """ + A method that regroups all arguments linked to logging. + + Args: + strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`): + The logging strategy to adopt during training. Possible values are: + + - `"no"`: No logging is done during training. + - `"epoch"`: Logging is done at the end of each epoch. + - `"steps"`: Logging is done every `logging_steps`. + + steps (`int`, *optional*, defaults to 500): + Number of update steps between two logs if `strategy="steps"`. + level (`str`, *optional*, defaults to `"passive"`): + Logger log level to use on the main process. Possible choices are the log levels as strings: `"debug"`, + `"info"`, `"warning"`, `"error"` and `"critical"`, plus a `"passive"` level which doesn't set anything + and lets the application set the level. + report_to (`str` or `List[str]`, *optional*, defaults to `"all"`): + The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`, + `"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`, + `"neptune"`, `"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed, + `"none"` for no integrations. + first_step (`bool`, *optional*, defaults to `False`): + Whether to log and evaluate the first `global_step` or not. + nan_inf_filter (`bool`, *optional*, defaults to `True`): + Whether to filter `nan` and `inf` losses for logging. If set to `True` the loss of every step that is + `nan` or `inf` is filtered and the average loss of the current logging window is taken instead. + + + + `nan_inf_filter` only influences the logging of loss values, it does not change the behavior the + gradient is computed or applied to the model. + + + + on_each_node (`bool`, *optional*, defaults to `True`): + In multinode distributed training, whether to log using `log_level` once per node, or only on the main + node. + replica_level (`str`, *optional*, defaults to `"passive"`): + Logger log level to use on replicas. Same choices as `log_level` + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_logging(strategy="steps", steps=100) + >>> args.logging_steps + 100 + ``` + """ + self.logging_strategy = IntervalStrategy(strategy) + if self.logging_strategy == IntervalStrategy.STEPS and steps == 0: + raise ValueError("Setting `strategy` as 'steps' requires a positive value for `steps`.") + self.logging_steps = steps + self.report_to = report_to + self.log_level = level + self.logging_first_step = first_step + self.logging_nan_inf_filter = nan_inf_filter + self.log_on_each_node = on_each_node + self.log_level_replica = replica_level + return self + + def set_push_to_hub( + self, + model_id: str, + strategy: Union[str, HubStrategy] = "every_save", + token: Optional[str] = None, + private_repo: Optional[bool] = None, + always_push: bool = False, + ): + """ + A method that regroups all arguments linked to synchronizing checkpoints with the Hub. + + + + Calling this method will set `self.push_to_hub` to `True`, which means the `output_dir` will begin a git + directory synced with the repo (determined by `model_id`) and the content will be pushed each time a save is + triggered (depending on your `self.save_strategy`). Calling [`~Trainer.save_model`] will also trigger a push. + + + + Args: + model_id (`str`): + The name of the repository to keep in sync with the local *output_dir*. It can be a simple model ID in + which case the model will be pushed in your namespace. Otherwise it should be the whole repository + name, for instance `"user_name/model"`, which allows you to push to an organization you are a member of + with `"organization_name/model"`. + strategy (`str` or [`~trainer_utils.HubStrategy`], *optional*, defaults to `"every_save"`): + Defines the scope of what is pushed to the Hub and when. Possible values are: + + - `"end"`: push the model, its configuration, the processing_class e.g. tokenizer (if passed along to the [`Trainer`]) and a + draft of a model card when the [`~Trainer.save_model`] method is called. + - `"every_save"`: push the model, its configuration, the processing_class e.g. tokenizer (if passed along to the [`Trainer`]) + and + a draft of a model card each time there is a model save. The pushes are asynchronous to not block + training, and in case the save are very frequent, a new push is only attempted if the previous one is + finished. A last push is made with the final model at the end of training. + - `"checkpoint"`: like `"every_save"` but the latest checkpoint is also pushed in a subfolder named + last-checkpoint, allowing you to resume training easily with + `trainer.train(resume_from_checkpoint="last-checkpoint")`. + - `"all_checkpoints"`: like `"checkpoint"` but all checkpoints are pushed like they appear in the + output + folder (so you will get one checkpoint folder per folder in your final repository) + + token (`str`, *optional*): + The token to use to push the model to the Hub. Will default to the token in the cache folder obtained + with `huggingface-cli login`. + private_repo (`bool`, *optional*, defaults to `False`): + Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists. + always_push (`bool`, *optional*, defaults to `False`): + Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not + finished. + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_push_to_hub("me/awesome-model") + >>> args.hub_model_id + 'me/awesome-model' + ``` + """ + self.push_to_hub = True + self.hub_model_id = model_id + self.hub_strategy = HubStrategy(strategy) + self.hub_token = token + self.hub_private_repo = private_repo + self.hub_always_push = always_push + return self + + def set_optimizer( + self, + name: Union[str, OptimizerNames] = "adamw_torch", + learning_rate: float = 5e-5, + weight_decay: float = 0, + beta1: float = 0.9, + beta2: float = 0.999, + epsilon: float = 1e-8, + args: Optional[str] = None, + ): + """ + A method that regroups all arguments linked to the optimizer and its hyperparameters. + + Args: + name (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_torch"`): + The optimizer to use: `"adamw_hf"`, `"adamw_torch"`, `"adamw_torch_fused"`, `"adamw_apex_fused"`, + `"adamw_anyprecision"` or `"adafactor"`. + learning_rate (`float`, *optional*, defaults to 5e-5): + The initial learning rate. + weight_decay (`float`, *optional*, defaults to 0): + The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights. + beta1 (`float`, *optional*, defaults to 0.9): + The beta1 hyperparameter for the adam optimizer or its variants. + beta2 (`float`, *optional*, defaults to 0.999): + The beta2 hyperparameter for the adam optimizer or its variants. + epsilon (`float`, *optional*, defaults to 1e-8): + The epsilon hyperparameter for the adam optimizer or its variants. + args (`str`, *optional*): + Optional arguments that are supplied to AnyPrecisionAdamW (only useful when + `optim="adamw_anyprecision"`). + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_optimizer(name="adamw_torch", beta1=0.8) + >>> args.optim + 'adamw_torch' + ``` + """ + self.optim = OptimizerNames(name) + self.learning_rate = learning_rate + self.weight_decay = weight_decay + self.adam_beta1 = beta1 + self.adam_beta2 = beta2 + self.adam_epsilon = epsilon + self.optim_args = args + return self + + def set_lr_scheduler( + self, + name: Union[str, SchedulerType] = "linear", + num_epochs: float = 3.0, + max_steps: int = -1, + warmup_ratio: float = 0, + warmup_steps: int = 0, + ): + """ + A method that regroups all arguments linked to the learning rate scheduler and its hyperparameters. + + Args: + name (`str` or [`SchedulerType`], *optional*, defaults to `"linear"`): + The scheduler type to use. See the documentation of [`SchedulerType`] for all possible values. + num_epochs(`float`, *optional*, defaults to 3.0): + Total number of training epochs to perform (if not an integer, will perform the decimal part percents + of the last epoch before stopping training). + max_steps (`int`, *optional*, defaults to -1): + If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`. + For a finite dataset, training is reiterated through the dataset (if all data is exhausted) until + `max_steps` is reached. + warmup_ratio (`float`, *optional*, defaults to 0.0): + Ratio of total training steps used for a linear warmup from 0 to `learning_rate`. + warmup_steps (`int`, *optional*, defaults to 0): + Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of + `warmup_ratio`. + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_lr_scheduler(name="cosine", warmup_ratio=0.05) + >>> args.warmup_ratio + 0.05 + ``` + """ + self.lr_scheduler_type = SchedulerType(name) + self.num_train_epochs = num_epochs + self.max_steps = max_steps + self.warmup_ratio = warmup_ratio + self.warmup_steps = warmup_steps + return self + + def set_dataloader( + self, + train_batch_size: int = 8, + eval_batch_size: int = 8, + drop_last: bool = False, + num_workers: int = 0, + pin_memory: bool = True, + persistent_workers: bool = False, + prefetch_factor: Optional[int] = None, + auto_find_batch_size: bool = False, + ignore_data_skip: bool = False, + sampler_seed: Optional[int] = None, + ): + """ + A method that regroups all arguments linked to the dataloaders creation. + + Args: + drop_last (`bool`, *optional*, defaults to `False`): + Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch + size) or not. + num_workers (`int`, *optional*, defaults to 0): + Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in + the main process. + pin_memory (`bool`, *optional*, defaults to `True`): + Whether you want to pin memory in data loaders or not. Will default to `True`. + persistent_workers (`bool`, *optional*, defaults to `False`): + If True, the data loader will not shut down the worker processes after a dataset has been consumed + once. This allows to maintain the workers Dataset instances alive. Can potentially speed up training, + but will increase RAM usage. Will default to `False`. + prefetch_factor (`int`, *optional*): + Number of batches loaded in advance by each worker. + 2 means there will be a total of 2 * num_workers batches prefetched across all workers. + auto_find_batch_size (`bool`, *optional*, defaults to `False`) + Whether to find a batch size that will fit into memory automatically through exponential decay, + avoiding CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`) + ignore_data_skip (`bool`, *optional*, defaults to `False`): + When resuming training, whether or not to skip the epochs and batches to get the data loading at the + same stage as in the previous training. If set to `True`, the training will begin faster (as that + skipping step can take a long time) but will not yield the same results as the interrupted training + would have. + sampler_seed (`int`, *optional*): + Random seed to be used with data samplers. If not set, random generators for data sampling will use the + same seed as `self.seed`. This can be used to ensure reproducibility of data sampling, independent of + the model seed. + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_dataloader(train_batch_size=16, eval_batch_size=64) + >>> args.per_device_train_batch_size + 16 + ``` + """ + self.per_device_train_batch_size = train_batch_size + self.per_device_eval_batch_size = eval_batch_size + self.dataloader_drop_last = drop_last + self.dataloader_num_workers = num_workers + self.dataloader_pin_memory = pin_memory + self.dataloader_persistent_workers = persistent_workers + self.dataloader_prefetch_factor = prefetch_factor + self.auto_find_batch_size = auto_find_batch_size + self.ignore_data_skip = ignore_data_skip + self.data_seed = sampler_seed + return self + + +class ParallelMode(Enum): + NOT_PARALLEL = "not_parallel" + NOT_DISTRIBUTED = "not_distributed" + DISTRIBUTED = "distributed" + SAGEMAKER_MODEL_PARALLEL = "sagemaker_model_parallel" + SAGEMAKER_DATA_PARALLEL = "sagemaker_data_parallel" + TPU = "tpu" diff --git a/training_args_seq2seq.py b/training_args_seq2seq.py new file mode 100644 index 0000000000000000000000000000000000000000..5342b7add3932c542e35247e52920d8fc91ed325 --- /dev/null +++ b/training_args_seq2seq.py @@ -0,0 +1,90 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional, Union + +from .generation.configuration_utils import GenerationConfig +from .training_args import TrainingArguments +from .utils import add_start_docstrings + + +logger = logging.getLogger(__name__) + + +@dataclass +@add_start_docstrings(TrainingArguments.__doc__) +class Seq2SeqTrainingArguments(TrainingArguments): + """ + Args: + predict_with_generate (`bool`, *optional*, defaults to `False`): + Whether to use generate to calculate generative metrics (ROUGE, BLEU). + generation_max_length (`int`, *optional*): + The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default to the + `max_length` value of the model configuration. + generation_num_beams (`int`, *optional*): + The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default to the + `num_beams` value of the model configuration. + generation_config (`str` or `Path` or [`~generation.GenerationConfig`], *optional*): + Allows to load a [`~generation.GenerationConfig`] from the `from_pretrained` method. This can be either: + + - a string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. + - a path to a *directory* containing a configuration file saved using the + [`~GenerationConfig.save_pretrained`] method, e.g., `./my_model_directory/`. + - a [`~generation.GenerationConfig`] object. + """ + + sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."}) + predict_with_generate: bool = field( + default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."} + ) + generation_max_length: Optional[int] = field( + default=None, + metadata={ + "help": ( + "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default " + "to the `max_length` value of the model configuration." + ) + }, + ) + generation_num_beams: Optional[int] = field( + default=None, + metadata={ + "help": ( + "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default " + "to the `num_beams` value of the model configuration." + ) + }, + ) + generation_config: Optional[Union[str, Path, GenerationConfig]] = field( + default=None, + metadata={ + "help": "Model id, file path or url pointing to a GenerationConfig json file, to use during prediction." + }, + ) + + def to_dict(self): + """ + Serializes this instance while replace `Enum` by their values and `GenerationConfig` by dictionaries (for JSON + serialization support). It obfuscates the token values by removing their value. + """ + # filter out fields that are defined as field(init=False) + d = super().to_dict() + for k, v in d.items(): + if isinstance(v, GenerationConfig): + d[k] = v.to_dict() + return d diff --git a/training_args_tf.py b/training_args_tf.py new file mode 100644 index 0000000000000000000000000000000000000000..3716a78879d50170149f2fe4fdf4f5e1740b4cec --- /dev/null +++ b/training_args_tf.py @@ -0,0 +1,299 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from dataclasses import dataclass, field +from typing import Optional, Tuple + +from .training_args import TrainingArguments +from .utils import cached_property, is_tf_available, logging, requires_backends + + +logger = logging.get_logger(__name__) + +if is_tf_available(): + import tensorflow as tf + + from .modeling_tf_utils import keras + + +@dataclass +class TFTrainingArguments(TrainingArguments): + """ + TrainingArguments is the subset of the arguments we use in our example scripts **which relate to the training loop + itself**. + + Using [`HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + output_dir (`str`): + The output directory where the model predictions and checkpoints will be written. + overwrite_output_dir (`bool`, *optional*, defaults to `False`): + If `True`, overwrite the content of the output directory. Use this to continue training if `output_dir` + points to a checkpoint directory. + do_train (`bool`, *optional*, defaults to `False`): + Whether to run training or not. This argument is not directly used by [`Trainer`], it's intended to be used + by your training/evaluation scripts instead. See the [example + scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details. + do_eval (`bool`, *optional*): + Whether to run evaluation on the validation set or not. Will be set to `True` if `eval_strategy` is + different from `"no"`. This argument is not directly used by [`Trainer`], it's intended to be used by your + training/evaluation scripts instead. See the [example + scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details. + do_predict (`bool`, *optional*, defaults to `False`): + Whether to run predictions on the test set or not. This argument is not directly used by [`Trainer`], it's + intended to be used by your training/evaluation scripts instead. See the [example + scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details. + eval_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"no"`): + The evaluation strategy to adopt during training. Possible values are: + + - `"no"`: No evaluation is done during training. + - `"steps"`: Evaluation is done (and logged) every `eval_steps`. + - `"epoch"`: Evaluation is done at the end of each epoch. + + per_device_train_batch_size (`int`, *optional*, defaults to 8): + The batch size per GPU/TPU core/CPU for training. + per_device_eval_batch_size (`int`, *optional*, defaults to 8): + The batch size per GPU/TPU core/CPU for evaluation. + gradient_accumulation_steps (`int`, *optional*, defaults to 1): + Number of updates steps to accumulate the gradients for, before performing a backward/update pass. + + + + When using gradient accumulation, one step is counted as one step with backward pass. Therefore, logging, + evaluation, save will be conducted every `gradient_accumulation_steps * xxx_step` training examples. + + + + learning_rate (`float`, *optional*, defaults to 5e-5): + The initial learning rate for Adam. + weight_decay (`float`, *optional*, defaults to 0): + The weight decay to apply (if not zero). + adam_beta1 (`float`, *optional*, defaults to 0.9): + The beta1 hyperparameter for the Adam optimizer. + adam_beta2 (`float`, *optional*, defaults to 0.999): + The beta2 hyperparameter for the Adam optimizer. + adam_epsilon (`float`, *optional*, defaults to 1e-8): + The epsilon hyperparameter for the Adam optimizer. + max_grad_norm (`float`, *optional*, defaults to 1.0): + Maximum gradient norm (for gradient clipping). + num_train_epochs(`float`, *optional*, defaults to 3.0): + Total number of training epochs to perform. + max_steps (`int`, *optional*, defaults to -1): + If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`. + For a finite dataset, training is reiterated through the dataset (if all data is exhausted) until + `max_steps` is reached. + warmup_ratio (`float`, *optional*, defaults to 0.0): + Ratio of total training steps used for a linear warmup from 0 to `learning_rate`. + warmup_steps (`int`, *optional*, defaults to 0): + Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of `warmup_ratio`. + logging_dir (`str`, *optional*): + [TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to + *runs/**CURRENT_DATETIME_HOSTNAME***. + logging_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`): + The logging strategy to adopt during training. Possible values are: + + - `"no"`: No logging is done during training. + - `"epoch"`: Logging is done at the end of each epoch. + - `"steps"`: Logging is done every `logging_steps`. + + logging_first_step (`bool`, *optional*, defaults to `False`): + Whether to log and evaluate the first `global_step` or not. + logging_steps (`int`, *optional*, defaults to 500): + Number of update steps between two logs if `logging_strategy="steps"`. + save_strategy (`str` or [`~trainer_utils.SaveStrategy`], *optional*, defaults to `"steps"`): + The checkpoint save strategy to adopt during training. Possible values are: + + - `"no"`: No save is done during training. + - `"epoch"`: Save is done at the end of each epoch. + - `"steps"`: Save is done every `save_steps`. + + save_steps (`int`, *optional*, defaults to 500): + Number of updates steps before two checkpoint saves if `save_strategy="steps"`. + save_total_limit (`int`, *optional*): + If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in + `output_dir`. + no_cuda (`bool`, *optional*, defaults to `False`): + Whether to not use CUDA even when it is available or not. + seed (`int`, *optional*, defaults to 42): + Random seed that will be set at the beginning of training. + fp16 (`bool`, *optional*, defaults to `False`): + Whether to use 16-bit (mixed) precision training (through NVIDIA Apex) instead of 32-bit training. + fp16_opt_level (`str`, *optional*, defaults to 'O1'): + For `fp16` training, Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. See details on + the [Apex documentation](https://nvidia.github.io/apex/amp). + local_rank (`int`, *optional*, defaults to -1): + During distributed training, the rank of the process. + tpu_num_cores (`int`, *optional*): + When training on TPU, the number of TPU cores (automatically passed by launcher script). + debug (`bool`, *optional*, defaults to `False`): + Whether to activate the trace to record computation graphs and profiling information or not. + dataloader_drop_last (`bool`, *optional*, defaults to `False`): + Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size) + or not. + eval_steps (`int`, *optional*, defaults to 1000): + Number of update steps before two evaluations. + past_index (`int`, *optional*, defaults to -1): + Some models like [TransformerXL](../model_doc/transformerxl) or :doc*XLNet <../model_doc/xlnet>* can make + use of the past hidden states for their predictions. If this argument is set to a positive int, the + `Trainer` will use the corresponding output (usually index 2) as the past state and feed it to the model at + the next training step under the keyword argument `mems`. + tpu_name (`str`, *optional*): + The name of the TPU the process is running on. + tpu_zone (`str`, *optional*): + The zone of the TPU the process is running on. If not specified, we will attempt to automatically detect + from metadata. + gcp_project (`str`, *optional*): + Google Cloud Project name for the Cloud TPU-enabled project. If not specified, we will attempt to + automatically detect from metadata. + run_name (`str`, *optional*): + A descriptor for the run. Notably used for wandb, mlflow and comet logging. + xla (`bool`, *optional*): + Whether to activate the XLA compilation or not. + """ + + framework = "tf" + tpu_name: Optional[str] = field( + default=None, + metadata={"help": "Name of TPU"}, + ) + + tpu_zone: Optional[str] = field( + default=None, + metadata={"help": "Zone of TPU"}, + ) + + gcp_project: Optional[str] = field( + default=None, + metadata={"help": "Name of Cloud TPU-enabled project"}, + ) + + poly_power: float = field( + default=1.0, + metadata={"help": "Power for the Polynomial decay LR scheduler."}, + ) + + xla: bool = field(default=False, metadata={"help": "Whether to activate the XLA compilation or not"}) + + @cached_property + def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]: + requires_backends(self, ["tf"]) + logger.info("Tensorflow: setting up strategy") + + gpus = tf.config.list_physical_devices("GPU") + + # Set to float16 at first + if self.fp16: + keras.mixed_precision.set_global_policy("mixed_float16") + + if self.no_cuda: + strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0") + else: + try: + if self.tpu_name: + tpu = tf.distribute.cluster_resolver.TPUClusterResolver( + self.tpu_name, zone=self.tpu_zone, project=self.gcp_project + ) + else: + tpu = tf.distribute.cluster_resolver.TPUClusterResolver() + except ValueError: + if self.tpu_name: + raise RuntimeError(f"Couldn't connect to TPU {self.tpu_name}!") + else: + tpu = None + + if tpu: + # Set to bfloat16 in case of TPU + if self.fp16: + keras.mixed_precision.set_global_policy("mixed_bfloat16") + + tf.config.experimental_connect_to_cluster(tpu) + tf.tpu.experimental.initialize_tpu_system(tpu) + + strategy = tf.distribute.TPUStrategy(tpu) + + elif len(gpus) == 0: + strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0") + elif len(gpus) == 1: + strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0") + elif len(gpus) > 1: + # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0` + strategy = tf.distribute.MirroredStrategy() + else: + raise ValueError("Cannot find the proper strategy, please check your environment properties.") + + return strategy + + @property + def strategy(self) -> "tf.distribute.Strategy": + """ + The strategy used for distributed training. + """ + requires_backends(self, ["tf"]) + return self._setup_strategy + + @property + def n_replicas(self) -> int: + """ + The number of replicas (CPUs, GPUs or TPU cores) used in this training. + """ + requires_backends(self, ["tf"]) + return self._setup_strategy.num_replicas_in_sync + + @property + def should_log(self): + """ + Whether or not the current process should produce log. + """ + return False # TF Logging is handled by Keras not the Trainer + + @property + def train_batch_size(self) -> int: + """ + The actual batch size for training (may differ from `per_gpu_train_batch_size` in distributed training). + """ + if self.per_gpu_train_batch_size: + logger.warning( + "Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future " + "version. Using `--per_device_train_batch_size` is preferred." + ) + per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size + return per_device_batch_size * self.n_replicas + + @property + def eval_batch_size(self) -> int: + """ + The actual batch size for evaluation (may differ from `per_gpu_eval_batch_size` in distributed training). + """ + if self.per_gpu_eval_batch_size: + logger.warning( + "Using deprecated `--per_gpu_eval_batch_size` argument which will be removed in a future " + "version. Using `--per_device_eval_batch_size` is preferred." + ) + per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size + return per_device_batch_size * self.n_replicas + + @property + def n_gpu(self) -> int: + """ + The number of replicas (CPUs, GPUs or TPU cores) used in this training. + """ + requires_backends(self, ["tf"]) + warnings.warn( + "The n_gpu argument is deprecated and will be removed in a future version, use n_replicas instead.", + FutureWarning, + ) + return self._setup_strategy.num_replicas_in_sync