# QuasarV4-Tiny — configuration_quasrav4.py
# Source: Hugging Face Hub upload by eyad-silx (commit 523f4fd, verified; 2.91 kB).
from transformers import PretrainedConfig
class QuasraV4Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`QuasraV4Model`]. It is used to instantiate a
    QuasraV4 model according to the specified arguments, defining the model architecture.

    Args:
        vocab_size (`int`, *optional*, defaults to 151669):
            Vocabulary size of the model (number of distinct token ids).
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the hidden representations. Must be divisible by `num_attention_heads`.
        num_hidden_layers (`int`, *optional*, defaults to 54):
            Number of hidden layers in the model.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads per attention layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the feed-forward (intermediate) layer.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            Dropout probability for hidden states.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            Dropout probability for attention weights.
        max_position_embeddings (`int`, *optional*, defaults to 812):
            Maximum sequence length the model supports.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Standard deviation for weight initialization.
        layer_norm_eps (`float`, *optional*, defaults to 1e-5):
            Epsilon used by layer-normalization layers.
        use_rotary_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to use rotary position embeddings.
        rotary_embedding_base (`int`, *optional*, defaults to 10000):
            Base used when computing rotary embedding frequencies.
        use_multi_scale_memory (`bool`, *optional*, defaults to `True`):
            Whether to enable the multi-scale memory mechanism.
        num_memory_scales (`int`, *optional*, defaults to 3):
            Number of memory scales when multi-scale memory is enabled.
        memory_compression_ratio (`float`, *optional*, defaults to 0.5):
            Compression ratio applied to the memory.
        memory_compression_frequency (`int`, *optional*, defaults to 100):
            How often (in steps) memory compression is applied.
        kernel_type (`str`, *optional*, defaults to `"elu"`):
            Feature-map kernel; one of `"elu"`, `"relu"` or `"learnable"`.
        kernel_epsilon (`float`, *optional*, defaults to 0.1):
            Numerical-stability epsilon added to the kernel feature map.
        use_gating (`bool`, *optional*, defaults to `True`):
            Whether to use gating.
        gate_init_bias (`float`, *optional*, defaults to -2.0):
            Initial bias of the gate.
        use_gradient_checkpointing (`bool`, *optional*, defaults to `False`):
            Whether to trade compute for memory via gradient checkpointing.
    """

    model_type = "quasarv4"

    def __init__(
        self,
        vocab_size=151669,
        hidden_size=768,
        num_hidden_layers=54,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=812,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        use_rotary_embeddings=True,
        rotary_embedding_base=10000,
        use_multi_scale_memory=True,
        num_memory_scales=3,
        memory_compression_ratio=0.5,
        memory_compression_frequency=100,
        kernel_type="elu",
        kernel_epsilon=0.1,
        use_gating=True,
        gate_init_bias=-2.0,
        use_gradient_checkpointing=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.use_rotary_embeddings = use_rotary_embeddings
        self.rotary_embedding_base = rotary_embedding_base
        self.use_multi_scale_memory = use_multi_scale_memory
        self.num_memory_scales = num_memory_scales
        self.memory_compression_ratio = memory_compression_ratio
        self.memory_compression_frequency = memory_compression_frequency
        self.kernel_type = kernel_type
        self.kernel_epsilon = kernel_epsilon
        self.use_gating = use_gating
        self.gate_init_bias = gate_init_bias
        self.use_gradient_checkpointing = use_gradient_checkpointing

        # `use_return_dict` is a read-only property on `PretrainedConfig` derived
        # from `return_dict`, so it must be popped before super().__init__ to avoid
        # an AttributeError. Pop it WITHOUT a truthy default: the old code defaulted
        # it to True and unconditionally overwrote `self.return_dict` after the
        # super call, silently clobbering a caller-supplied `return_dict=False`.
        was_provided = "use_return_dict" in kwargs
        use_return_dict = kwargs.pop("use_return_dict", None)

        # Pass the remaining arguments (return_dict, torch_dtype, etc.) to the parent.
        super().__init__(**kwargs)

        # Only override the underlying attribute when the caller explicitly asked.
        if was_provided:
            self.return_dict = use_return_dict

        # Validation logic — fail fast on inconsistent architecture settings.
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"`hidden_size` ({self.hidden_size}) must be a multiple of `num_attention_heads` "
                f"({self.num_attention_heads})"
            )
        if self.kernel_type not in ['elu', 'relu', 'learnable']:
            raise ValueError(f"`kernel_type` must be one of 'elu', 'relu', or 'learnable', got {self.kernel_type}")