# QuasarV4-Tiny / configuration_quasrav4.py

from transformers import PretrainedConfig


class InfinityFormerConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`InfinityFormerModel`]. It is used to instantiate an
InfinityFormer model according to the specified arguments, defining the model architecture.
"""

    model_type = "infinity_former"

    def __init__(self, **kwargs):
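        # Model shape, vocabulary size, dropout, and initialization hyperparameters.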
self.vocab_size = kwargs.pop("vocab_size", 151669)
self.hidden_size = kwargs.pop("hidden_size", 768)
self.num_hidden_layers = kwargs.pop("num_hidden_layers", 54)
self.num_attention_heads = kwargs.pop("num_attention_heads", 12)
self.intermediate_size = kwargs.pop("intermediate_size", 3072)
self.hidden_dropout_prob = kwargs.pop("hidden_dropout_prob", 0.1)
self.attention_probs_dropout_prob = kwargs.pop("attention_probs_dropout_prob", 0.1)
self.max_position_embeddings = kwargs.pop("max_position_embeddings", 8192)
self.initializer_range = kwargs.pop("initializer_range", 0.02)
self.layer_norm_eps = kwargs.pop("layer_norm_eps", 1e-5)
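        # Rotary position embeddings and the base of their frequency schedule.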
self.use_rotary_embeddings = kwargs.pop("use_rotary_embeddings", True)
self.rotary_embedding_base = kwargs.pop("rotary_embedding_base", 10000)
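        # Multi-scale memory: number of scales and how strongly / how often the memory is compressed.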
self.use_multi_scale_memory = kwargs.pop("use_multi_scale_memory", True)
self.num_memory_scales = kwargs.pop("num_memory_scales", 3)
self.memory_compression_ratio = kwargs.pop("memory_compression_ratio", 0.5)
self.memory_compression_frequency = kwargs.pop("memory_compression_frequency", 100)
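        # Kernel feature map used by the attention mechanism ('elu', 'relu', or 'learnable').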
self.kernel_type = kwargs.pop("kernel_type", 'elu')
self.kernel_epsilon = kwargs.pop("kernel_epsilon", 0.1)
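        # Learnable gating; the negative initial bias starts the gates mostly closed (assuming a sigmoid activation).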
self.use_gating = kwargs.pop("use_gating", True)
self.gate_init_bias = kwargs.pop("gate_init_bias", -2.0)
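        # Optional attention over memory states and gradient checkpointing (trades compute for activation memory).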
self.use_memory_attention = kwargs.pop("use_memory_attention", False)
self.use_gradient_checkpointing = kwargs.pop("use_gradient_checkpointing", False)

        # `use_return_dict` is exposed as a read-only property on `PretrainedConfig`, so pop it
        # before calling the parent constructor and assign the stored value to `return_dict` afterwards.
        use_return_dict = kwargs.pop("use_return_dict", True)
        super().__init__(**kwargs)
        self.return_dict = use_return_dict

        # Basic sanity checks on the architecture hyperparameters.
if self.hidden_size % self.num_attention_heads != 0:
raise ValueError(
f"`hidden_size` ({self.hidden_size}) must be a multiple of `num_attention_heads` "
f"({self.num_attention_heads})"
)
if self.kernel_type not in ['elu', 'relu', 'learnable']:
raise ValueError(f"`kernel_type` must be one of 'elu', 'relu', or 'learnable', got {self.kernel_type}")