# QuasarV4-Tiny / modeling_quasrav4.py
# Author: eyad-silx
# Last commit: 9c608de (verified) - "Update modeling_quasrav4.py"
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Union
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.generation import GenerationMixin
from transformers.utils import logging
from configuration_quasrav4 import InfinityFormerConfig
# Module-level logger, obtained through transformers' logging utilities.
logger = logging.get_logger(name=__name__)
class RotaryPositionEmbedding(nn.Module):
    """Rotary position embedding (RoPE) applied along the sequence axis.

    ``forward`` reads the sequence length from dim 1 of ``x`` and broadcasts
    the cos/sin tables over dims 0 and 2, i.e. it expects ``x`` laid out as
    ``(batch, seq_len, num_heads, dim)``. The output has the same shape and
    dtype as ``x``.

    Args:
        dim: size of the last axis of ``x`` (the per-head feature size).
        base: frequency base; channel pair ``i`` rotates at ``base**(-2i/dim)``.
    """

    def __init__(self, dim: int, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.base = base
        # One inverse frequency per channel pair, kept in float32.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _get_rotary_embeddings(self, x: torch.Tensor, seq_dim: int = -2) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cos, sin)`` tables of shape ``(seq_len, dim)`` for ``x``.

        Angles are computed in the buffer dtype (float32) for accuracy. The
        tables are then cast to ``x.dtype``; otherwise multiplying fp16/bf16
        activations by float32 tables would silently promote them to float32.
        """
        seq_len = x.size(seq_dim)
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        # Channels i and i + dim/2 share a frequency (matches rotate_half).
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """Map the halves ``(x1, x2)`` of the last axis to ``(-x2, x1)``."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Rotate ``x`` by the angles encoded in ``cos``/``sin`` (broadcast)."""
        return (x * cos) + (self.rotate_half(x) * sin)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply RoPE to ``x`` using positions ``0..x.size(1)-1``."""
        cos, sin = self._get_rotary_embeddings(x, seq_dim=1)
        # (seq, dim) -> (1, seq, 1, dim): broadcast over batch and heads.
        return self.apply_rotary_pos_emb(x, cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2))
class KernelFunction(nn.Module):
    """Positive feature map applied to queries and keys.

    ``config.kernel_type`` selects the map (``eps`` is ``config.kernel_epsilon``):

    * ``'elu'``       -> ``elu(x) + 1 + eps``
    * ``'relu'``      -> ``relu(x) + eps``
    * ``'learnable'`` -> ``elu(x * temperature) + 1 + eps``, where
      ``temperature`` is a trainable scalar initialised to 0.1.

    Any other kernel type raises ``ValueError`` when ``forward`` is called.
    """

    def __init__(self, config: InfinityFormerConfig):
        super().__init__()
        self.kernel_type = config.kernel_type
        self.epsilon = config.kernel_epsilon
        is_learnable = self.kernel_type == 'learnable'
        if is_learnable:
            self.temperature = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        kind = self.kernel_type
        if kind == 'relu':
            return F.relu(x) + self.epsilon
        # Both remaining kernels share the shifted-ELU form; they differ only
        # in whether the input is scaled by the learnable temperature first.
        if kind == 'elu':
            scaled = x
        elif kind == 'learnable':
            scaled = x * self.temperature
        else:
            raise ValueError(f"Unknown kernel type: {self.kernel_type}")
        return F.elu(scaled) + 1.0 + self.epsilon
class GatedFeedForward(nn.Module):
    """Pre-norm gated feed-forward sub-layer with its own residual connection.

    Returns ``x + dropout(fc2(activation_dropout(gelu(a) * sigmoid(g))))``,
    where ``(a, g)`` are the two halves of ``fc1(layer_norm(x))``.
    """

    def __init__(self, config: InfinityFormerConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # fc1 produces the value and gate projections side by side.
        self.fc1 = nn.Linear(self.hidden_size, self.intermediate_size * 2)
        self.fc2 = nn.Linear(self.intermediate_size, self.hidden_size)
        self.activation_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        value, gate = self.fc1(self.layer_norm(hidden_states)).chunk(2, dim=-1)
        gated = self.activation_dropout(F.gelu(value) * torch.sigmoid(gate))
        update = self.dropout(self.fc2(gated))
        # Residual: the un-normalised input is added back.
        return update + hidden_states
class LinearAttention(nn.Module):
    """Multi-head self-attention over kernel-transformed queries and keys.

    Queries and keys are optionally rotated with ``RotaryPositionEmbedding``
    and then mapped through ``KernelFunction``. The attention itself is
    computed with ``F.scaled_dot_product_attention`` (softmax attention), so
    despite the class name the cost is quadratic in sequence length.

    Masking:
        With ``is_causal=True`` (the default, which the shifted-label causal-LM
        loss requires) each query attends only to itself and earlier positions.
        An optional ``attention_mask`` is combined with the causal mask. It may be:

        * a 2-D ``(batch, seq_len)`` padding mask: non-zero / True marks a real
          token, zero / False marks padding (the usual Hugging Face convention);
        * a boolean 3-D ``(batch, q, k)`` or 4-D ``(batch, heads|1, q, k)`` mask
          in SDPA convention: True = may attend;
        * a floating-point 3-D/4-D additive mask: 0 = may attend, negative
          values (e.g. -inf) = masked.

        Every query may always attend to its own position. This guarantees no
        attention row is fully masked, which would otherwise produce NaN; it
        only matters for rows whose own position is padding or masked out.

    Args:
        config: reads ``hidden_size``, ``num_attention_heads``,
            ``attention_probs_dropout_prob``, ``use_rotary_embeddings`` and
            ``rotary_embedding_base``, plus whatever ``KernelFunction`` reads.
        layer_idx: index of the owning layer (stored only).
        is_causal: whether to apply causal masking. Defaults to True.
    """

    def __init__(self, config: InfinityFormerConfig, layer_idx: int = 0, is_causal: bool = True):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.layer_idx = layer_idx
        self.is_causal = is_causal
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size)
        # Its rate is reused as SDPA's attention dropout; the module is also applied to the output.
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.kernel = KernelFunction(config)
        self.use_memory = False  # Memory is disabled in this version
        self.use_rotary = config.use_rotary_embeddings
        if self.use_rotary:
            self.rotary_emb = RotaryPositionEmbedding(self.head_dim, config.rotary_embedding_base)

    def _build_attention_mask(self, attention_mask: Optional[torch.Tensor], seq_len: int, device: torch.device) -> Optional[torch.Tensor]:
        """Translate ``attention_mask`` into an SDPA boolean mask (True = may attend).

        Returns None when no explicit mask is given. The caller then relies on
        SDPA's own ``is_causal`` path. Otherwise the result broadcasts to
        ``(batch, heads, seq_len, seq_len)`` and already includes the causal
        constraint when ``self.is_causal`` is set.
        """
        if attention_mask is None:
            return None
        if attention_mask.dim() == 2:
            # (batch, key_len) padding mask -> (batch, 1, 1, key_len).
            keep = attention_mask[:, None, None, :] != 0
        else:
            if attention_mask.dim() == 3:
                # (batch, q, k) -> (batch, 1, q, k) so it broadcasts over heads.
                attention_mask = attention_mask.unsqueeze(1)
            if attention_mask.dtype == torch.bool:
                keep = attention_mask
            else:
                # Additive mask: negative entries are the masked positions.
                keep = attention_mask >= 0
        keep = keep.to(device)
        if self.is_causal:
            causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).tril()
            keep = keep & causal
        # Never leave a query row fully masked (SDPA would yield NaN there).
        return keep | torch.eye(seq_len, dtype=torch.bool, device=device)

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Attend over ``hidden_states`` of shape ``(batch, seq_len, hidden_size)``.

        Returns:
            ``(output, None)``, where ``output`` has the same shape as
            ``hidden_states``. The second slot is always None.
        """
        batch_size, seq_len, _ = hidden_states.size()
        head_shape = (batch_size, seq_len, self.num_heads, self.head_dim)
        q = self.q_proj(hidden_states).view(head_shape)
        k = self.k_proj(hidden_states).view(head_shape)
        v = self.v_proj(hidden_states).view(head_shape)
        if self.use_rotary:
            q = self.rotary_emb(q)
            k = self.rotary_emb(k)
        q = self.kernel(q)
        k = self.kernel(k)
        # SDPA requires q, k and v to share a dtype. Float32 rotary tables can
        # promote low-precision q/k, so align them with v (no-op if equal).
        q = q.to(v.dtype)
        k = k.to(v.dtype)

        # SDPA expects (batch, heads, seq, head_dim).
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        attn_mask = self._build_attention_mask(attention_mask, seq_len, hidden_states.device)
        context_output = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            dropout_p=self.dropout.p if self.training else 0.0,
            # SDPA forbids is_causal together with attn_mask; when an explicit
            # mask exists, the causal part has already been folded into it.
            is_causal=self.is_causal and attn_mask is None,
        )
        context_output = context_output.transpose(1, 2)
        final_output = context_output.reshape(batch_size, seq_len, self.hidden_size)
        final_output = self.out_proj(final_output)
        final_output = self.dropout(final_output)
        return final_output, None
class InfinityFormerLayer(nn.Module):
    """One transformer block: self-attention, optional second attention, then FFN.

    Each attention sub-layer is pre-norm with a residual connection.
    ``GatedFeedForward`` applies its own pre-norm and residual internally. The
    block output is finally passed through ``final_layer_norm``.

    The second ("memory") attention sub-layer is only created when
    ``config.use_memory_attention`` is truthy. It is another ``LinearAttention``
    over the same hidden states.
    """

    def __init__(self, config: InfinityFormerConfig, layer_idx: int):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.self_attn = LinearAttention(config, layer_idx)
        if config.use_memory_attention:
            self.mem_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
            self.mem_attn = LinearAttention(config, layer_idx)
        self.ffn = GatedFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    @staticmethod
    def _residual_attention(norm: nn.Module, attn: nn.Module, states: torch.Tensor, attention_mask: Optional[torch.Tensor], extra: dict) -> torch.Tensor:
        """Return ``states + attn(norm(states))[0]`` (pre-norm residual attention)."""
        attended = attn(norm(states), attention_mask=attention_mask, **extra)[0]
        return states + attended

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> Tuple[torch.Tensor, ...]:
        out = self._residual_attention(self.self_attn_layer_norm, self.self_attn, hidden_states, attention_mask, kwargs)
        if hasattr(self, 'mem_attn'):
            out = self._residual_attention(self.mem_attn_layer_norm, self.mem_attn, out, attention_mask, kwargs)
        out = self.final_layer_norm(self.ffn(out))
        return (out,)
class InfinityFormerEmbeddings(nn.Module):
    """Token embeddings plus learned absolute position embeddings, then LayerNorm and dropout."""

    def __init__(self, config: InfinityFormerConfig):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id if hasattr(config, 'pad_token_id') else 0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Default positions 0..max_position_embeddings-1 with shape (1, max_positions);
        # non-persistent, so not written to the state dict.
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False)

    def forward(self, input_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None):
        """Embed ``input_ids`` (or use ``inputs_embeds``) and add position embeddings.

        The sequence length is taken from dim 1 of whichever input is given.
        If ``inputs_embeds`` is provided, it is used instead of looking up
        ``input_ids``. When ``position_ids`` is omitted, positions
        ``0..seq_len-1`` are used.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given,
                or if ``position_ids`` is omitted and the sequence is longer
                than ``max_position_embeddings``.
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        seq_length = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
        if position_ids is None:
            max_positions = self.position_ids.size(1)
            # Slicing would otherwise silently truncate and fail later with an opaque broadcast error.
            if seq_length > max_positions:
                raise ValueError(
                    f"Sequence length {seq_length} exceeds max_position_embeddings ({max_positions})."
                )
            position_ids = self.position_ids[:, :seq_length]
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = inputs_embeds + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
class InfinityFormerPreTrainedModel(PreTrainedModel):
    """Base class hooking InfinityFormer models into ``transformers.PreTrainedModel``.

    Declares the config class and the attribute name of the base model
    (``model``). It also enables gradient-checkpointing support and lists the
    layer class that must not be split across devices.
    """

    config_class = InfinityFormerConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["InfinityFormerLayer"]

    def _init_weights(self, module):
        """Initialise ``module``'s parameters in place.

        * ``nn.Linear``: weight ~ N(0, ``config.initializer_range``), bias zeroed.
        * ``nn.Embedding``: weight ~ N(0, ``config.initializer_range``), padding
          row zeroed.
        * ``nn.LayerNorm``: reset to the identity (weight 1, bias 0). This keeps
          a LayerNorm initialised through this hook well-defined instead of
          leaving whatever its storage held.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # weight/bias are None when elementwise_affine=False.
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
class InfinityFormerModel(InfinityFormerPreTrainedModel):
    """Embeddings followed by a stack of ``InfinityFormerLayer`` blocks."""

    def __init__(self, config: InfinityFormerConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = InfinityFormerEmbeddings(config)
        self.layers = nn.ModuleList([InfinityFormerLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
        # Flipped by PreTrainedModel.gradient_checkpointing_enable(), which also
        # installs self._gradient_checkpointing_func.
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _get_checkpoint_fn(self):
        """Return the checkpoint function installed by ``transformers``, or a stock fallback."""
        checkpoint_fn = getattr(self, "_gradient_checkpointing_func", None)
        if checkpoint_fn is None:
            from functools import partial
            from torch.utils.checkpoint import checkpoint

            checkpoint_fn = partial(checkpoint, use_reentrant=False)
        return checkpoint_fn

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the embedding layer and all transformer blocks.

        ``past_key_values``, ``use_cache`` and extra ``kwargs`` are accepted for
        API compatibility but ignored: no cache is produced, and the returned
        ``past_key_values`` is always None.

        Returns:
            ``BaseModelOutputWithPast`` with ``last_hidden_state`` set, or the
            tuple ``(hidden_states,)`` when ``return_dict`` resolves to False.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds)
        checkpoint_fn = self._get_checkpoint_fn() if (self.gradient_checkpointing and self.training) else None
        for layer_module in self.layers:
            if checkpoint_fn is not None:
                # Pass __call__ so module hooks still run under checkpointing.
                layer_outputs = checkpoint_fn(layer_module.__call__, hidden_states, attention_mask)
            else:
                layer_outputs = layer_module(hidden_states, attention_mask=attention_mask)
            hidden_states = layer_outputs[0]
        if not return_dict:
            return (hidden_states,)
        return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=None, hidden_states=None, attentions=None)
class InfinityFormerForCausalLM(GenerationMixin, InfinityFormerPreTrainedModel):
    """``InfinityFormerModel`` with a bias-free linear LM head for next-token prediction."""

    _auto_class = "AutoModelForCausalLM"

    def __init__(self, config: InfinityFormerConfig):
        super().__init__(config)
        self.model = InfinityFormerModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def tie_weights(self):
        """Share the LM head weight with the input embeddings when ``config.tie_word_embeddings`` is set."""
        if self.config.tie_word_embeddings:
            output_embeddings = self.get_output_embeddings()
            input_embeddings = self.get_input_embeddings()
            output_embeddings.weight = input_embeddings.weight
            # If the head has a bias, zero-pad it to the (possibly resized) vocab size.
            if getattr(output_embeddings, "bias", None) is not None:
                output_embeddings.bias.data = nn.functional.pad(
                    output_embeddings.bias.data,
                    (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
                    "constant",
                    0,
                )
        if hasattr(self, "tie_weights_post_actions"):
            self.tie_weights_post_actions()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute LM logits and, when ``labels`` is given, the shifted next-token loss.

        Labels are compared against the logits of the previous position.
        Entries equal to -100 (``CrossEntropyLoss``'s default ``ignore_index``)
        are ignored. ``output_attentions`` and ``output_hidden_states`` are
        accepted but not forwarded.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.model(
            input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, return_dict=return_dict, **kwargs
        )
        sequence_output = outputs[0]
        lm_logits = self.lm_head(sequence_output)
        loss = None
        if labels is not None:
            # Upcast so the cross-entropy is computed in float32 even for
            # fp16/bf16 logits, and move labels next to the logits: they can
            # live on another device when the model is split across devices.
            shift_logits = lm_logits[..., :-1, :].contiguous().float()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output
        return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)

    def set_to_generation_mode(self):
        """Sets the model to generation mode."""
        self.eval()  # Set the entire model to evaluation mode