import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Union

from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.generation import GenerationMixin
from transformers.utils import logging

from configuration_quasrav4 import InfinityFormerConfig

logger = logging.get_logger(__name__)


class RotaryPositionEmbedding(nn.Module):
    """Rotary position embedding (RoPE) applied to the per-head query/key tensors."""

    def __init__(self, dim: int, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _get_rotary_embeddings(self, x: torch.Tensor, seq_dim: int = -2) -> Tuple[torch.Tensor, torch.Tensor]:
        seq_len = x.size(seq_dim)
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        return (x * cos) + (self.rotate_half(x) * sin)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, num_heads, head_dim)
        cos, sin = self._get_rotary_embeddings(x, seq_dim=1)
        return self.apply_rotary_pos_emb(x, cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2))
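

# Illustrative note (not executed): `apply_rotary_pos_emb` rotates each channel pair
# (x_i, x_{i + dim/2}) by the angle t * inv_freq_i for position t, so the transform is
# norm-preserving. A quick sanity check, assuming a hypothetical head_dim of 8:
#
#     rope = RotaryPositionEmbedding(dim=8)
#     x = torch.randn(2, 5, 3, 8)  # (batch, seq, heads, head_dim)
#     torch.testing.assert_close(rope(x).norm(dim=-1), x.norm(dim=-1))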


class KernelFunction(nn.Module):
    def __init__(self, config: InfinityFormerConfig):
        super().__init__()
        self.kernel_type = config.kernel_type
        self.epsilon = config.kernel_epsilon
        if self.kernel_type == 'learnable':
            self.temperature = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.kernel_type == 'elu':
            return F.elu(x) + 1.0 + self.epsilon
        elif self.kernel_type == 'relu':
            return F.relu(x) + self.epsilon
        elif self.kernel_type == 'learnable':
            return F.elu(x * self.temperature) + 1.0 + self.epsilon
        else:
            raise ValueError(f"Unknown kernel type: {self.kernel_type}")


class GatedFeedForward(nn.Module):
    def __init__(self, config: InfinityFormerConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.fc1 = nn.Linear(self.hidden_size, self.intermediate_size * 2)
        self.fc2 = nn.Linear(self.intermediate_size, self.hidden_size)
        self.activation_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.layer_norm(hidden_states)
        hidden_states, gate = self.fc1(hidden_states).chunk(2, dim=-1)
        hidden_states = F.gelu(hidden_states) * torch.sigmoid(gate)
        hidden_states = self.activation_dropout(hidden_states)
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states + residual


class LinearAttention(nn.Module):
    """Multi-head attention that applies a positive kernel feature map to queries and
    keys before computing attention with ``scaled_dot_product_attention``."""

    def __init__(self, config: InfinityFormerConfig, layer_idx: int = 0):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.layer_idx = layer_idx
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.kernel = KernelFunction(config)
        self.use_memory = False
        self.use_rotary = config.use_rotary_embeddings
        if self.use_rotary:
            self.rotary_emb = RotaryPositionEmbedding(self.head_dim, config.rotary_embedding_base)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        batch_size, seq_len, _ = hidden_states.size()

        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim)

        if self.use_rotary:
            q = self.rotary_emb(q)
            k = self.rotary_emb(k)

        # Apply the positive kernel feature map to queries and keys.
        q = self.kernel(q)
        k = self.kernel(k)

        # (batch, num_heads, seq_len, head_dim) layout for scaled_dot_product_attention.
        q_for_sdpa = q.transpose(1, 2)
        k_for_sdpa = k.transpose(1, 2)
        v_for_sdpa = v.transpose(1, 2)

        # Build a boolean keep-mask (True = may attend): a causal mask, intersected
        # with the padding mask when one is provided.
        causal_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=hidden_states.device).tril()
        bool_attention_mask = causal_mask[None, None, :, :]
        if attention_mask is not None:
            if attention_mask.dim() == 2:
                # (batch, seq_len) mask with 1 for real tokens and 0 for padding.
                padding_mask = attention_mask[:, None, None, :].to(torch.bool)
            else:
                # Additive mask: large negative values mark the padded positions.
                padding_mask = attention_mask >= 0
            bool_attention_mask = bool_attention_mask & padding_mask

        context_output = F.scaled_dot_product_attention(
            q_for_sdpa,
            k_for_sdpa,
            v_for_sdpa,
            attn_mask=bool_attention_mask,
            dropout_p=self.dropout.p if self.training else 0.0,
        )

        context_output = context_output.transpose(1, 2)
        final_output = context_output.reshape(batch_size, seq_len, self.hidden_size)
        final_output = self.out_proj(final_output)
        final_output = self.dropout(final_output)
        return final_output, None
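

def _kernelized_linear_attention_reference(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, epsilon: float = 1e-6
) -> torch.Tensor:
    """Illustrative reference only; not called by the model above.

    Sketches the kernelized ("linear") attention that feature maps such as
    ``KernelFunction`` are usually paired with:

        out_t = phi(q_t) (sum_s phi(k_s) v_s^T) / (phi(q_t) . sum_s phi(k_s))

    ``q`` and ``k`` are assumed to be already kernel-mapped, non-negative tensors of
    shape (batch, num_heads, seq_len, head_dim); ``v`` has the same shape.
    """
    kv = torch.einsum('bhsd,bhse->bhde', k, v)          # sum_s phi(k_s) v_s^T
    z = k.sum(dim=2)                                    # sum_s phi(k_s)
    numerator = torch.einsum('bhsd,bhde->bhse', q, kv)  # phi(q_t)^T (KV)
    denominator = torch.einsum('bhsd,bhd->bhs', q, z) + epsilon
    return numerator / denominator.unsqueeze(-1)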


class InfinityFormerLayer(nn.Module):
    def __init__(self, config: InfinityFormerConfig, layer_idx: int):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.self_attn = LinearAttention(config, layer_idx)
        if config.use_memory_attention:
            self.mem_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
            self.mem_attn = LinearAttention(config, layer_idx)
        self.ffn = GatedFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> Tuple[torch.Tensor, ...]:
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states, attention_mask=attention_mask, **kwargs)[0]
        hidden_states = residual + hidden_states

        if hasattr(self, 'mem_attn'):
            mem_residual = hidden_states
            hidden_states = self.mem_attn_layer_norm(hidden_states)
            hidden_states = self.mem_attn(hidden_states, attention_mask=attention_mask, **kwargs)[0]
            hidden_states = mem_residual + hidden_states

        hidden_states = self.ffn(hidden_states)
        hidden_states = self.final_layer_norm(hidden_states)
        return (hidden_states,)


class InfinityFormerEmbeddings(nn.Module):
    def __init__(self, config: InfinityFormerConfig):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            padding_idx=config.pad_token_id if hasattr(config, 'pad_token_id') else 0,
        )
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        seq_length = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = inputs_embeds + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class InfinityFormerPreTrainedModel(PreTrainedModel):
    config_class = InfinityFormerConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["InfinityFormerLayer"]

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class InfinityFormerModel(InfinityFormerPreTrainedModel):
    def __init__(self, config: InfinityFormerConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = InfinityFormerEmbeddings(config)
        self.layers = nn.ModuleList([InfinityFormerLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds)
        for layer_module in self.layers:
            if self.gradient_checkpointing and self.training:
                # `_gradient_checkpointing_func` is set by `gradient_checkpointing_enable()`.
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__, hidden_states, attention_mask
                )
            else:
                layer_outputs = layer_module(hidden_states, attention_mask=attention_mask)
            hidden_states = layer_outputs[0]
        if not return_dict:
            return (hidden_states,)
        return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=None, hidden_states=None, attentions=None)


class InfinityFormerForCausalLM(GenerationMixin, InfinityFormerPreTrainedModel):
    _auto_class = "AutoModelForCausalLM"

    def __init__(self, config: InfinityFormerConfig):
        super().__init__(config)
        self.model = InfinityFormerModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def tie_weights(self):
        if self.config.tie_word_embeddings:
            output_embeddings = self.get_output_embeddings()
            input_embeddings = self.get_input_embeddings()
            output_embeddings.weight = input_embeddings.weight
            if getattr(output_embeddings, "bias", None) is not None:
                output_embeddings.bias.data = nn.functional.pad(
                    output_embeddings.bias.data,
                    (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
                    "constant",
                    0,
                )
        if hasattr(self, "tie_weights_post_actions"):
            self.tie_weights_post_actions()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            return_dict=return_dict,
            **kwargs,
        )
        sequence_output = outputs[0]
        lm_logits = self.lm_head(sequence_output)
        loss = None
        if labels is not None:
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def set_to_generation_mode(self):
        """Sets the model to generation mode."""
        self.eval()
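

if __name__ == "__main__":
    # Minimal smoke test (illustrative only). The config overrides below
    # (vocab_size, hidden_size, num_hidden_layers, num_attention_heads,
    # intermediate_size) are assumptions about InfinityFormerConfig's keyword
    # arguments, chosen small so the check runs quickly on CPU.
    config = InfinityFormerConfig(
        vocab_size=128,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=128,
    )
    model = InfinityFormerForCausalLM(config)
    input_ids = torch.randint(0, config.vocab_size, (1, 16))
    out = model(input_ids=input_ids, labels=input_ids)
    print("logits:", tuple(out.logits.shape), "loss:", float(out.loss))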