from transformers import PretrainedConfig


class InfinityFormerConfig(PretrainedConfig):
r""" |
|
This is the configuration class to store the configuration of a [`InfinityFormerModel`]. It is used to instantiate an |
|
InfinityFormer model according to the specified arguments, defining the model architecture. |
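
    Example:

    ```python
    >>> # Minimal usage sketch; the module path below is an assumption about how this file is packaged.
    >>> from configuration_infinity_former import InfinityFormerConfig

    >>> # Initializing an InfinityFormer configuration with the default architecture arguments
    >>> configuration = InfinityFormerConfig()

    >>> # Overriding selected arguments
    >>> configuration = InfinityFormerConfig(hidden_size=512, num_attention_heads=8)
    ```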
    """

    model_type = "infinity_former"

    def __init__(self, **kwargs):
        self.vocab_size = kwargs.pop("vocab_size", 151669)
        self.hidden_size = kwargs.pop("hidden_size", 768)
        self.num_hidden_layers = kwargs.pop("num_hidden_layers", 54)
        self.num_attention_heads = kwargs.pop("num_attention_heads", 12)
        self.intermediate_size = kwargs.pop("intermediate_size", 3072)
        self.hidden_dropout_prob = kwargs.pop("hidden_dropout_prob", 0.1)
        self.attention_probs_dropout_prob = kwargs.pop("attention_probs_dropout_prob", 0.1)
        self.max_position_embeddings = kwargs.pop("max_position_embeddings", 8192)
        self.initializer_range = kwargs.pop("initializer_range", 0.02)
        self.layer_norm_eps = kwargs.pop("layer_norm_eps", 1e-5)

        # Rotary position embedding settings
        self.use_rotary_embeddings = kwargs.pop("use_rotary_embeddings", True)
        self.rotary_embedding_base = kwargs.pop("rotary_embedding_base", 10000)

        # Multi-scale memory settings
        self.use_multi_scale_memory = kwargs.pop("use_multi_scale_memory", True)
        self.num_memory_scales = kwargs.pop("num_memory_scales", 3)
        self.memory_compression_ratio = kwargs.pop("memory_compression_ratio", 0.5)
        self.memory_compression_frequency = kwargs.pop("memory_compression_frequency", 100)

        # Kernel feature-map settings
        self.kernel_type = kwargs.pop("kernel_type", "elu")
        self.kernel_epsilon = kwargs.pop("kernel_epsilon", 0.1)

        # Gating and memory-attention options
        self.use_gating = kwargs.pop("use_gating", True)
        self.gate_init_bias = kwargs.pop("gate_init_bias", -2.0)
        self.use_memory_attention = kwargs.pop("use_memory_attention", False)

        self.use_gradient_checkpointing = kwargs.pop("use_gradient_checkpointing", False)

        # `use_return_dict` is accepted as an alias for `return_dict`; it overrides the value
        # that `PretrainedConfig.__init__` derives from the `return_dict` kwarg.
        use_return_dict = kwargs.pop("use_return_dict", True)
        super().__init__(**kwargs)
        self.return_dict = use_return_dict

        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"`hidden_size` ({self.hidden_size}) must be a multiple of `num_attention_heads` "
                f"({self.num_attention_heads})"
            )
        if self.kernel_type not in ["elu", "relu", "learnable"]:
            raise ValueError(f"`kernel_type` must be one of 'elu', 'relu', or 'learnable', got {self.kernel_type}")
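

if __name__ == "__main__":
    # Quick sanity-check sketch, not part of the library API: it exercises the defaults,
    # a keyword override, the serialization helpers inherited from `PretrainedConfig`,
    # and the validation performed in `__init__` above.
    config = InfinityFormerConfig()
    print(config.model_type, config.hidden_size, config.num_attention_heads)

    # Round-trip through `to_dict` / `from_dict` (both provided by `PretrainedConfig`).
    restored = InfinityFormerConfig.from_dict(config.to_dict())
    assert restored.num_memory_scales == config.num_memory_scales

    # A `hidden_size` that is not a multiple of `num_attention_heads` is rejected.
    try:
        InfinityFormerConfig(hidden_size=100, num_attention_heads=12)
    except ValueError as err:
        print(f"Rejected invalid configuration: {err}")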
|