eyad-silx committed on
Commit 11528c6 · verified · 1 Parent(s): 3921342

Create modeling_quasrav4.py

Files changed (1)
  1. modeling_quasrav4.py +268 -0
modeling_quasrav4.py ADDED
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Union

from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.utils import logging

from .configuration_quasrav4 import QuasraV4Config

logger = logging.get_logger(__name__)

# --- Helper Modules ---

class RotaryPositionEmbedding(nn.Module):
    def __init__(self, dim: int, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _get_rotary_embeddings(self, x: torch.Tensor, seq_dim: int = -2) -> Tuple[torch.Tensor, torch.Tensor]:
        seq_len = x.size(seq_dim)
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        return (x * cos) + (self.rotate_half(x) * sin)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        cos, sin = self._get_rotary_embeddings(x, seq_dim=1)
        return self.apply_rotary_pos_emb(x, cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2))

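# Illustrative snippet (not part of the committed file): a minimal sketch of how
# RotaryPositionEmbedding above is applied. Inputs are assumed to be shaped
# (batch, seq_len, num_heads, head_dim); the rotation preserves per-head norms.
rope_demo = RotaryPositionEmbedding(dim=16)
x_rot = torch.randn(2, 8, 4, 16)
y_rot = rope_demo(x_rot)
assert y_rot.shape == x_rot.shape
assert torch.allclose(x_rot.norm(dim=-1), y_rot.norm(dim=-1), atol=1e-5)
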
class KernelFunction(nn.Module):
    def __init__(self, config: QuasraV4Config):
        super().__init__()
        self.kernel_type = config.kernel_type
        self.epsilon = config.kernel_epsilon
        if self.kernel_type == 'learnable':
            self.temperature = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.kernel_type == 'elu':
            return F.elu(x) + 1.0 + self.epsilon
        elif self.kernel_type == 'relu':
            return F.relu(x) + self.epsilon
        elif self.kernel_type == 'learnable':
            return F.elu(x * self.temperature) + 1.0 + self.epsilon
        else:
            raise ValueError(f"Unknown kernel type: {self.kernel_type}")

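# Illustrative snippet (not part of the committed file): the 'elu' feature map
# used by KernelFunction is strictly positive, since elu(x) > -1 for all x and
# therefore elu(x) + 1 + epsilon > 0, which is the property kernelized
# attention relies on. A quick check:
x_demo = torch.randn(4, 8)
phi_demo = F.elu(x_demo) + 1.0 + 1e-6
assert (phi_demo > 0).all()
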
class GatedFeedForward(nn.Module):
    def __init__(self, config: QuasraV4Config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.fc1 = nn.Linear(self.hidden_size, self.intermediate_size * 2)
        self.fc2 = nn.Linear(self.intermediate_size, self.hidden_size)
        self.activation_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.layer_norm(hidden_states)
        hidden_states, gate = self.fc1(hidden_states).chunk(2, dim=-1)
        hidden_states = F.gelu(hidden_states) * torch.sigmoid(gate)
        hidden_states = self.activation_dropout(hidden_states)
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states + residual

class LinearAttention(nn.Module):
    def __init__(self, config: QuasraV4Config, layer_idx: int = 0):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.layer_idx = layer_idx
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.kernel = KernelFunction(config)
        self.use_memory = False  # Memory is disabled in this version
        self.use_rotary = config.use_rotary_embeddings
        if self.use_rotary:
            self.rotary_emb = RotaryPositionEmbedding(self.head_dim, config.rotary_embedding_base)

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        batch_size, seq_len, _ = hidden_states.size()
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim)
        if self.use_rotary:
            q = self.rotary_emb(q)
            k = self.rotary_emb(k)
        # The kernel feature map keeps queries/keys strictly positive before attention.
        q = self.kernel(q)
        k = self.kernel(k)
        q_for_sdpa = q.transpose(1, 2)
        k_for_sdpa = k.transpose(1, 2)
        v_for_sdpa = v.transpose(1, 2)
        attn_mask = None
        if attention_mask is not None:
            if attention_mask.dim() == 2:
                # HF-style padding mask (1 = real token, 0 = padding): convert to a
                # boolean mask (True = may attend) and combine it with a causal mask
                # so positions cannot attend to the future.
                padding_mask = attention_mask[:, None, None, :].to(torch.bool)
                causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=hidden_states.device))
                attn_mask = padding_mask & causal_mask
            else:
                # Assume a pre-built (e.g. additive float) mask and pass it through.
                attn_mask = attention_mask
        context_output = F.scaled_dot_product_attention(
            q_for_sdpa, k_for_sdpa, v_for_sdpa,
            attn_mask=attn_mask,
            dropout_p=self.dropout.p if self.training else 0.0,
            is_causal=attn_mask is None,
        )
        context_output = context_output.transpose(1, 2)
        final_output = context_output.reshape(batch_size, seq_len, self.hidden_size)
        final_output = self.out_proj(final_output)
        final_output = self.dropout(final_output)
        return final_output, None

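# Illustrative smoke test (not part of the committed file): exercises
# LinearAttention with a stand-in config object. The attribute names mirror the
# fields read above; the real QuasraV4Config may differ or define more options.
import types

_demo_cfg = types.SimpleNamespace(
    hidden_size=64,
    num_attention_heads=4,
    attention_probs_dropout_prob=0.0,
    kernel_type="elu",
    kernel_epsilon=1e-6,
    use_rotary_embeddings=True,
    rotary_embedding_base=10000,
)
_demo_attn = LinearAttention(_demo_cfg)
_demo_x = torch.randn(2, 16, 64)
_demo_mask = torch.ones(2, 16)  # 1 = real token, 0 = padding
_demo_out, _ = _demo_attn(_demo_x, attention_mask=_demo_mask)
assert _demo_out.shape == (2, 16, 64)
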
# --- Main Model Components ---

class QuasraV4Layer(nn.Module):
    def __init__(self, config: QuasraV4Config, layer_idx: int):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.self_attn = LinearAttention(config, layer_idx)
        self.ffn = GatedFeedForward(config)

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> Tuple[torch.Tensor, ...]:
        residual = hidden_states
        hidden_states_ln = self.self_attn_layer_norm(hidden_states)
        attn_outputs, _ = self.self_attn(hidden_states=hidden_states_ln, attention_mask=attention_mask, **kwargs)
        hidden_states = residual + attn_outputs
        hidden_states = self.ffn(hidden_states)
        return (hidden_states,)

class QuasraV4Embeddings(nn.Module):
    def __init__(self, config: QuasraV4Config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id if hasattr(config, 'pad_token_id') else 0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False)

    def forward(self, input_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None):
        seq_length = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = inputs_embeds + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

class QuasraV4PreTrainedModel(PreTrainedModel):
    config_class = QuasraV4Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["QuasraV4Layer"]

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

class QuasraV4Model(QuasraV4PreTrainedModel):
    def __init__(self, config: QuasraV4Config):
        super().__init__(config)
        self.config = config
        self.embeddings = QuasraV4Embeddings(config)
        self.layers = nn.ModuleList([QuasraV4Layer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def forward(self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs) -> Union[Tuple, BaseModelOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds)
        for layer_module in self.layers:
            if self.gradient_checkpointing and self.training:
                # self._gradient_checkpointing_func is set up by
                # PreTrainedModel.gradient_checkpointing_enable().
                layer_outputs = self._gradient_checkpointing_func(layer_module.__call__, hidden_states, attention_mask)
            else:
                layer_outputs = layer_module(hidden_states, attention_mask=attention_mask)
            hidden_states = layer_outputs[0]
        hidden_states = self.final_layer_norm(hidden_states)
        if not return_dict:
            return (hidden_states,)
        return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=None, hidden_states=None, attentions=None)

class QuasraV4ForCausalLM(QuasraV4PreTrainedModel):
    _auto_class = "AutoModelForCausalLM"

    def __init__(self, config: QuasraV4Config):
        super().__init__(config)
        self.model = QuasraV4Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def tie_weights(self):
        if self.config.tie_word_embeddings:
            output_embeddings = self.get_output_embeddings()
            input_embeddings = self.get_input_embeddings()
            output_embeddings.weight = input_embeddings.weight
            if getattr(output_embeddings, "bias", None) is not None:
                output_embeddings.bias.data = nn.functional.pad(
                    output_embeddings.bias.data,
                    (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
                    "constant",
                    0,
                )
        if hasattr(self, "tie_weights_post_actions"):
            self.tie_weights_post_actions()

    def forward(self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs) -> Union[Tuple, CausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.model(
            input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, return_dict=return_dict, **kwargs
        )
        sequence_output = outputs[0]
        lm_logits = self.lm_head(sequence_output)
        loss = None
        if labels is not None:
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output
        return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
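
Since the class registers itself for AutoModelForCausalLM via `_auto_class`, a checkpoint that ships this file can be loaded with `trust_remote_code=True`. A minimal sketch, assuming a hypothetical repo id and that the matching tokenizer and configuration_quasrav4.py are uploaded alongside it:

from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "eyad-silx/QuasraV4"  # placeholder repo id, not confirmed by this commit
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("Hello world", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])
print(outputs.loss.item(), outputs.logits.shape)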