# meshmodel.py (v0.1-2x2-stage001)
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

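# ---------------------------------------------------------------------------
# NOTE: MeshModel below references MeshConfig and MeshLayer, which are not
# defined in this file; the real definitions ship with the rest of the repo.
# The two classes that follow are minimal, clearly hypothetical stand-ins so
# the file can be imported and smoke-tested on its own. Their field names,
# defaults, and model_type are assumptions, not the released architecture.
# ---------------------------------------------------------------------------
class MeshConfig(PretrainedConfig):
    """Stand-in config carrying only the fields MeshModel reads."""

    model_type = "mesh"  # assumed identifier, used only for Auto registration

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=512,
        num_hidden_layers=4,
        rms_norm_eps=1e-6,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.rms_norm_eps = rms_norm_eps


class MeshLayer(nn.Module):
    """Stand-in layer honoring the (hidden_states, expert_indices) contract."""

    def __init__(self, config: "MeshConfig"):
        super().__init__()
        self.ff = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states):
        # Dummy routing decision: pretend every token was sent to expert 0.
        expert_indices = torch.zeros(
            hidden_states.shape[:-1], dtype=torch.long, device=hidden_states.device
        )
        return hidden_states + torch.tanh(self.ff(hidden_states)), expert_indices
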

class MeshModel(PreTrainedModel):
    """Causal language model: embedding -> MeshLayer stack -> LayerNorm -> LM head.

    Each MeshLayer also returns the expert indices it selected for the tokens.
    """

    config_class = MeshConfig

    def __init__(self, config: MeshConfig):
        super().__init__(config)
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([MeshLayer(config) for _ in range(config.num_hidden_layers)])
        # Final normalization before the LM head (eps taken from config.rms_norm_eps).
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Default to the config's return_dict setting when not specified.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            inputs_embeds = self.embedding(input_ids)
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Note: attention_mask, past_key_values, and use_cache are accepted for
        # API compatibility but are not yet applied by the layers below.
        hidden_states = inputs_embeds
        expert_indices_list = []  # Per-layer expert indices, collected for routing inspection
        for layer in self.layers:
            hidden_states, expert_indices = layer(hidden_states)
            expert_indices_list.append(expert_indices)

        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            # Standard causal LM loss: shift so that tokens < n predict token n.
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
        # Structured output (return_dict=True path).
        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=None,  # KV caching not implemented yet
                hidden_states=hidden_states,  # final hidden states; per-layer outputs not collected yet
                attentions=None,  # attention outputs not implemented yet
            )
        # Tuple output: follow the Transformers convention of returning the loss
        # first only when it was computed, so the Trainer can call backward on
        # outputs[0]. The collected expert indices are appended for custom
        # callbacks or logging.
        output = (logits, hidden_states, expert_indices_list)
        return (loss,) + output if loss is not None else output
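

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original checkpoint code): build
# a tiny MeshConfig, run one forward pass with labels, and optionally register
# the class with AutoModelForCausalLM. All hyperparameter values here are
# placeholders for a quick smoke test, not the released model's settings.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Registration lets AutoModelForCausalLM resolve this custom architecture.
    AutoModelForCausalLM.register(MeshConfig, MeshModel)

    config = MeshConfig(vocab_size=32000, hidden_size=256, num_hidden_layers=2)
    model = MeshModel(config)

    dummy_input = torch.randint(0, config.vocab_size, (1, 16))
    outputs = model(input_ids=dummy_input, labels=dummy_input, return_dict=True)
    print(outputs.loss.item(), outputs.logits.shape)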