from transformers import PretrainedConfig


class InfinityFormerConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`InfinityFormerModel`]. It is used to
    instantiate an InfinityFormer model according to the specified arguments, defining the model architecture.

    All options are passed as keyword arguments and stored as attributes of the same name. Any keyword not
    listed below is forwarded unchanged to [`PretrainedConfig`].

    Args:
        vocab_size (`int`, *optional*, defaults to 151669)
        hidden_size (`int`, *optional*, defaults to 768):
            Must be a multiple of `num_attention_heads`.
        num_hidden_layers (`int`, *optional*, defaults to 54)
        num_attention_heads (`int`, *optional*, defaults to 12):
            Must be a positive integer.
        intermediate_size (`int`, *optional*, defaults to 3072)
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1)
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1)
        max_position_embeddings (`int`, *optional*, defaults to 8192)
        initializer_range (`float`, *optional*, defaults to 0.02)
        layer_norm_eps (`float`, *optional*, defaults to 1e-5)
        use_rotary_embeddings (`bool`, *optional*, defaults to `True`)
        rotary_embedding_base (`int`, *optional*, defaults to 10000)
        use_multi_scale_memory (`bool`, *optional*, defaults to `True`)
        num_memory_scales (`int`, *optional*, defaults to 3)
        memory_compression_ratio (`float`, *optional*, defaults to 0.5)
        memory_compression_frequency (`int`, *optional*, defaults to 100)
        kernel_type (`str`, *optional*, defaults to `"elu"`):
            One of `"elu"`, `"relu"` or `"learnable"`.
        kernel_epsilon (`float`, *optional*, defaults to 0.1)
        use_gating (`bool`, *optional*, defaults to `True`)
        gate_init_bias (`float`, *optional*, defaults to -2.0)
        use_memory_attention (`bool`, *optional*, defaults to `False`)
        use_gradient_checkpointing (`bool`, *optional*, defaults to `False`)
        return_dict (`bool`, *optional*, defaults to `True`):
            Standard [`PretrainedConfig`] option; respected as given (including when reloading a saved config).
        use_return_dict (`bool`, *optional*):
            Legacy alias for `return_dict`. When supplied, it takes precedence over `return_dict`.

    Raises:
        ValueError: If `num_attention_heads` is not positive, if `hidden_size` is not a multiple of
            `num_attention_heads`, or if `kernel_type` is not one of the supported values.
    """

    model_type = "infinity_former"

    def __init__(self, **kwargs):
        self.vocab_size = kwargs.pop("vocab_size", 151669)
        self.hidden_size = kwargs.pop("hidden_size", 768)
        self.num_hidden_layers = kwargs.pop("num_hidden_layers", 54)
        self.num_attention_heads = kwargs.pop("num_attention_heads", 12)
        self.intermediate_size = kwargs.pop("intermediate_size", 3072)
        self.hidden_dropout_prob = kwargs.pop("hidden_dropout_prob", 0.1)
        self.attention_probs_dropout_prob = kwargs.pop("attention_probs_dropout_prob", 0.1)
        self.max_position_embeddings = kwargs.pop("max_position_embeddings", 8192)
        self.initializer_range = kwargs.pop("initializer_range", 0.02)
        self.layer_norm_eps = kwargs.pop("layer_norm_eps", 1e-5)

        # Rotary position embeddings.
        self.use_rotary_embeddings = kwargs.pop("use_rotary_embeddings", True)
        self.rotary_embedding_base = kwargs.pop("rotary_embedding_base", 10000)

        # Multi-scale memory.
        self.use_multi_scale_memory = kwargs.pop("use_multi_scale_memory", True)
        self.num_memory_scales = kwargs.pop("num_memory_scales", 3)
        self.memory_compression_ratio = kwargs.pop("memory_compression_ratio", 0.5)
        self.memory_compression_frequency = kwargs.pop("memory_compression_frequency", 100)

        # Kernel / gating options.
        self.kernel_type = kwargs.pop("kernel_type", 'elu')
        self.kernel_epsilon = kwargs.pop("kernel_epsilon", 0.1)
        self.use_gating = kwargs.pop("use_gating", True)
        self.gate_init_bias = kwargs.pop("gate_init_bias", -2.0)
        self.use_memory_attention = kwargs.pop("use_memory_attention", False)
        self.use_gradient_checkpointing = kwargs.pop("use_gradient_checkpointing", False)

        # `use_return_dict` must not reach PretrainedConfig as a plain attribute (the original code
        # popped it for that reason). Previously it was always popped with a default of True and
        # written over `self.return_dict` after super().__init__(), clobbering any explicit or
        # serialized `return_dict`. Now it is only translated into `return_dict` when the caller
        # actually supplied it, and PretrainedConfig handles `return_dict` (default True) itself.
        use_return_dict = kwargs.pop("use_return_dict", None)
        if use_return_dict is not None:
            kwargs["return_dict"] = use_return_dict

        super().__init__(**kwargs)

        # Reject non-positive head counts explicitly: 0 would raise ZeroDivisionError in the
        # modulo check below, and a negative value would silently pass it (e.g. 768 % -12 == 0).
        if self.num_attention_heads <= 0:
            raise ValueError(
                f"`num_attention_heads` must be a positive integer, got {self.num_attention_heads}"
            )
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"`hidden_size` ({self.hidden_size}) must be a multiple of `num_attention_heads` "
                f"({self.num_attention_heads})"
            )

        supported_kernel_types = ('elu', 'relu', 'learnable')
        if self.kernel_type not in supported_kernel_types:
            raise ValueError(f"`kernel_type` must be one of 'elu', 'relu', or 'learnable', got {self.kernel_type}")