eyad-silx committed
Commit 3921342 · verified · 1 Parent(s): bab819c

Create configuration_quasrav4.py

Files changed (1)
  1. configuration_quasrav4.py +58 -0
configuration_quasrav4.py ADDED
from transformers import PretrainedConfig


class QuasraV4Config(PretrainedConfig):
    """
    Configuration class for the QuasraV4 model.

    This class stores the configuration of a QuasraV4Model. It is used to instantiate the
    model according to the specified arguments, defining the model architecture.
    """

    model_type = "quasrav4"

    def __init__(
        self,
        # Model dimensions
        vocab_size: int = 151936,
        hidden_size: int = 768,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 2048,
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1e-5,
        # QuasraV4/InfinityFormer specific parameters
        use_rotary_embeddings: bool = True,
        rotary_embedding_base: int = 10000,
        use_multi_scale_memory: bool = True,
        num_memory_scales: int = 3,
        memory_compression_ratio: float = 0.5,
        memory_compression_frequency: int = 100,
        kernel_type: str = "elu",
        kernel_epsilon: float = 0.1,
        # Gating mechanism
        use_gating: bool = True,
        gate_init_bias: float = -2.0,
        # Training parameters
        use_gradient_checkpointing: bool = False,
        gradient_checkpointing_use_reentrant: bool = True,
        gradient_checkpointing_frequency: int = 1,
        # Standard options stored by PretrainedConfig (return_dict backs the
        # read-only `use_return_dict` property on the base class)
        tie_word_embeddings: bool = True,
        use_cache: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        pruned_heads: dict = None,
        **kwargs,
    ):
        # Validate the architecture before storing anything.
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"`hidden_size` ({hidden_size}) must be a multiple of `num_attention_heads` "
                f"({num_attention_heads})"
            )
        if kernel_type not in ("elu", "relu", "learnable"):
            raise ValueError(f"`kernel_type` must be one of 'elu', 'relu', or 'learnable', got {kernel_type}")

        # Model dimensions
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps

        # QuasraV4/InfinityFormer specific parameters
        self.use_rotary_embeddings = use_rotary_embeddings
        self.rotary_embedding_base = rotary_embedding_base
        self.use_multi_scale_memory = use_multi_scale_memory
        self.num_memory_scales = num_memory_scales
        self.memory_compression_ratio = memory_compression_ratio
        self.memory_compression_frequency = memory_compression_frequency
        self.kernel_type = kernel_type
        self.kernel_epsilon = kernel_epsilon

        # Gating mechanism
        self.use_gating = use_gating
        self.gate_init_bias = gate_init_bias

        # Training parameters
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.gradient_checkpointing_use_reentrant = gradient_checkpointing_use_reentrant
        self.gradient_checkpointing_frequency = gradient_checkpointing_frequency

        # The base class stores tie_word_embeddings, use_cache, output_*, return_dict,
        # pruned_heads and any remaining kwargs.
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            pruned_heads=pruned_heads if pruned_heads is not None else {},
            **kwargs,
        )
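Not part of the committed file, but for reference, a minimal usage sketch: it assumes the file is importable as configuration_quasrav4, and the override values and save directory below are arbitrary examples.

from transformers import AutoConfig

from configuration_quasrav4 import QuasraV4Config  # hypothetical import path

# Instantiate with defaults, overriding a couple of architecture fields.
config = QuasraV4Config(hidden_size=1024, num_attention_heads=16)
print(config.kernel_type)                                # "elu"
print(config.hidden_size // config.num_attention_heads)  # 64, the per-head dimension

# Invalid combinations are rejected at construction time.
try:
    QuasraV4Config(hidden_size=768, num_attention_heads=10)
except ValueError as err:
    print(err)

# Round-trip through the standard PretrainedConfig serialization helpers.
config.save_pretrained("./quasrav4-config")
reloaded = QuasraV4Config.from_pretrained("./quasrav4-config")
assert reloaded.hidden_size == config.hidden_size

# Optionally expose the config to the Auto* machinery under its model_type.
AutoConfig.register("quasrav4", QuasraV4Config)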