import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast


# NOTE(review): `MeshConfig` and `MeshLayer` are referenced below but are not
# defined or imported in the visible part of this file -- confirm they are
# declared before this class is executed.
class MeshModel(PreTrainedModel):
    """Causal language model made of a token embedding, a stack of
    ``MeshLayer`` blocks, a final ``LayerNorm`` and a linear LM head.

    Each layer is called with the current hidden states only and is expected
    to return a ``(hidden_states, expert_indices)`` pair.
    """

    config_class = MeshConfig

    def __init__(self, config: MeshConfig):
        """Build the sub-modules from ``config`` and run ``post_init()``.

        Reads ``vocab_size``, ``hidden_size``, ``num_hidden_layers`` and
        ``rms_norm_eps`` from ``config``.
        """
        super().__init__(config)
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [MeshLayer(config) for _ in range(config.num_hidden_layers)]
        )
        # The epsilon comes from `rms_norm_eps`, but the module is a standard
        # LayerNorm, not an RMSNorm.
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the model and optionally compute the next-token loss.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.
        ``attention_mask``, ``token_type_ids``, ``position_ids``,
        ``head_mask``, ``encoder_hidden_states``, ``encoder_attention_mask``,
        ``past_key_values``, ``use_cache``, ``output_attentions`` and
        ``output_hidden_states`` are accepted for signature compatibility but
        are currently not used by this method.

        Args:
            input_ids: Token ids to embed with ``self.embedding``.
            inputs_embeds: Pre-computed embeddings, used as-is.
            labels: Optional target ids. When given, a shifted cross-entropy
                loss is computed (position ``t`` predicts token ``t + 1``).
            return_dict: Whether to return a ``CausalLMOutputWithPast``;
                falls back to ``config.use_return_dict`` when ``None``.

        Returns:
            ``CausalLMOutputWithPast`` when ``return_dict`` is truthy, where
            ``hidden_states`` holds only the final normalized hidden-state
            tensor (not a per-layer tuple) and ``past_key_values`` /
            ``attentions`` are ``None``. Otherwise the tuple
            ``(loss, logits, hidden_states, expert_indices_list)``, where
            ``loss`` is ``None`` if no labels were supplied and
            ``expert_indices_list`` has one entry per layer.

        Raises:
            ValueError: If both or neither of ``input_ids`` and
                ``inputs_embeds`` are provided.
        """
        # Fall back to the config's setting when the caller did not choose.
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is not None:
            inputs_embeds = self.embedding(input_ids)
        elif inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        hidden_states = inputs_embeds
        # One entry per layer, in layer order.
        expert_indices_list = []
        for layer in self.layers:
            hidden_states, expert_indices = layer(hidden_states)
            expert_indices_list.append(expert_indices)

        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Shift so that tokens < n predict n.
            shift_logits = logits[..., :-1, :].contiguous()
            # Labels may arrive on a different device than the logits (e.g.
            # when the model is sharded); cross-entropy needs them co-located.
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            # Flatten using the logits' own last dimension so the reshape is
            # always consistent with what `lm_head` actually produced.
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=None,  # Caching is not implemented.
                hidden_states=hidden_states,
                attentions=None,  # Attention outputs are not implemented.
            )

        # Tuple form: loss comes first so that a training loop reading the
        # first element gets the loss; the per-layer expert indices are
        # appended as an extra trailing element.
        return (loss, logits, hidden_states, expert_indices_list)