from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MeshModel(PreTrainedModel):
    config_class = MeshConfig

    def __init__(self, config: MeshConfig):
        super().__init__(config)
        self.config = config

        # Token embeddings, the stack of MeshLayer blocks, the final norm, and the LM head.
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([MeshLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing (inherited from PreTrainedModel).
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Most of these arguments (attention_mask, past_key_values, use_cache, ...) are
        # accepted for Hugging Face API compatibility but are not used by this forward pass.

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            inputs_embeds = self.embedding(input_ids)
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        hidden_states = inputs_embeds
        expert_indices_list = []

        # Run the sequence through each MeshLayer, collecting the expert routing
        # indices that every layer returns alongside its hidden states.
        for layer in self.layers:
            hidden_states, expert_indices = layer(hidden_states)
            expert_indices_list.append(expert_indices)

        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Standard causal LM objective: predict token t+1 from tokens up to t,
            # so drop the last logit and the first label before the cross-entropy.
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=None,
                hidden_states=hidden_states,
                attentions=None,
            )
        else:
            # Tuple output, following the Hugging Face convention of omitting the loss
            # when no labels were provided.
            output = (logits, hidden_states, expert_indices_list)
            return (loss,) + output if loss is not None else output
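# --- Usage sketch (illustrative) ---
# A minimal smoke test, assuming MeshConfig and MeshLayer are defined earlier in this
# file and that MeshConfig accepts the keyword arguments below (the specific values are
# hypothetical, chosen only to keep the example small).
if __name__ == "__main__":
    config = MeshConfig(vocab_size=32000, hidden_size=256, num_hidden_layers=4)
    model = MeshModel(config)

    # Random token ids standing in for a tokenized batch of shape (batch, seq_len).
    input_ids = torch.randint(0, config.vocab_size, (1, 16))

    # Passing labels=input_ids triggers the shifted cross-entropy loss above.
    out = model(input_ids=input_ids, labels=input_ids, return_dict=True)
    print(out.loss)            # scalar training loss
    print(out.logits.shape)    # (1, 16, config.vocab_size)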