eyad-silx committed
Commit f55b26b · verified · 1 Parent(s): d5cfeec

Update modeling_quasrav4.py

Files changed (1): modeling_quasrav4.py +19 -19
modeling_quasrav4.py CHANGED
@@ -8,7 +8,7 @@ from transformers import PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.utils import logging
 
-from .configuration_quasrav4 import QuasraV4Config
+from .configuration_quasrav4 import InfinityFormerConfig
 
 logger = logging.get_logger(__name__)
 
@@ -41,7 +41,7 @@ class RotaryPositionEmbedding(nn.Module):
         return self.apply_rotary_pos_emb(x, cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2))
 
 class KernelFunction(nn.Module):
-    def __init__(self, config: QuasraV4Config):
+    def __init__(self, config: InfinityFormerConfig):
         super().__init__()
         self.kernel_type = config.kernel_type
         self.epsilon = config.kernel_epsilon
@@ -59,7 +59,7 @@ class KernelFunction(nn.Module):
             raise ValueError(f"Unknown kernel type: {self.kernel_type}")
 
 class GatedFeedForward(nn.Module):
-    def __init__(self, config: QuasraV4Config):
+    def __init__(self, config: InfinityFormerConfig):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
@@ -80,7 +80,7 @@ class GatedFeedForward(nn.Module):
         return hidden_states + residual
 
 class LinearAttention(nn.Module):
-    def __init__(self, config: QuasraV4Config, layer_idx: int = 0):
+    def __init__(self, config: InfinityFormerConfig, layer_idx: int = 0):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
@@ -129,8 +129,8 @@ class LinearAttention(nn.Module):
 
 # --- Main Model Components ---
 
-class QuasraV4Layer(nn.Module):
-    def __init__(self, config: QuasraV4Config, layer_idx: int):
+class InfinityFormerLayer(nn.Module):
+    def __init__(self, config: InfinityFormerConfig, layer_idx: int):
         super().__init__()
         self.embed_dim = config.hidden_size
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
@@ -145,8 +145,8 @@ class QuasraV4Layer(nn.Module):
         hidden_states = self.ffn(hidden_states)
         return (hidden_states,)
 
-class QuasraV4Embeddings(nn.Module):
-    def __init__(self, config: QuasraV4Config):
+class InfinityFormerEmbeddings(nn.Module):
+    def __init__(self, config: InfinityFormerConfig):
         super().__init__()
         self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id if hasattr(config, 'pad_token_id') else 0)
         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
@@ -166,11 +166,11 @@ class QuasraV4Embeddings(nn.Module):
         embeddings = self.dropout(embeddings)
         return embeddings
 
-class QuasraV4PreTrainedModel(PreTrainedModel):
-    config_class = QuasraV4Config
-    base_model_prefix = "model"
+class InfinityFormerPreTrainedModel(PreTrainedModel):
+    config_class = InfinityFormerConfig
+    base_model_prefix = "infinity_former"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["QuasraV4Layer"]
+    _no_split_modules = ["InfinityFormerLayer"]
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -183,12 +183,12 @@ class QuasraV4PreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
-class QuasraV4Model(QuasraV4PreTrainedModel):
-    def __init__(self, config: QuasraV4Config):
+class InfinityFormerModel(InfinityFormerPreTrainedModel):
+    def __init__(self, config: InfinityFormerConfig):
         super().__init__(config)
         self.config = config
-        self.embeddings = QuasraV4Embeddings(config)
-        self.layers = nn.ModuleList([QuasraV4Layer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
+        self.embeddings = InfinityFormerEmbeddings(config)
+        self.layers = nn.ModuleList([InfinityFormerLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.gradient_checkpointing = False
         self.post_init()
@@ -213,12 +213,12 @@ class QuasraV4Model(QuasraV4PreTrainedModel):
             return (hidden_states,)
         return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=None, hidden_states=None, attentions=None)
 
-class QuasraV4ForCausalLM(QuasraV4PreTrainedModel):
+class InfinityFormerForCausalLM(InfinityFormerPreTrainedModel):
     _auto_class = "AutoModelForCausalLM"
 
-    def __init__(self, config: QuasraV4Config):
+    def __init__(self, config: InfinityFormerConfig):
         super().__init__(config)
-        self.model = QuasraV4Model(config)
+        self.model = InfinityFormerModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.post_init()
 
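Because InfinityFormerForCausalLM keeps _auto_class = "AutoModelForCausalLM", the renamed classes should still be reachable through the standard custom-code path, provided the repo's config.json carries matching auto_map entries. Below is a minimal loading sketch, not confirmed by this diff: the repo id "eyad-silx/quasrav4" is a placeholder, the presence of a saved tokenizer is assumed, and the forward call is assumed to return a CausalLMOutputWithPast with logits (as the file's imports suggest).

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

repo_id = "eyad-silx/quasrav4"  # placeholder repo id, not confirmed by this diff

# trust_remote_code=True fetches configuration_quasrav4.py / modeling_quasrav4.py from the repo,
# so AutoConfig resolves to InfinityFormerConfig and AutoModelForCausalLM to InfinityFormerForCausalLM.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id)  # assumes a tokenizer is saved in the repo

input_ids = tokenizer("Hello", return_tensors="pt").input_ids
with torch.no_grad():
    out = model(input_ids)  # assumed to return CausalLMOutputWithPast with .logits
print(out.logits.shape)  # (batch, seq_len, vocab_size), projected by the bias-free lm_head

Note that any code importing the old QuasraV4* names directly will break after this commit and needs the InfinityFormer* names instead.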