import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast


# Define a single Expert within the Mesh.
class MeshExpert(nn.Module):
    """A single feed-forward expert: Linear -> GELU -> Linear.

    Projects the last dimension from ``config.hidden_size`` up to
    ``config.expert_intermediate_size`` and back down to
    ``config.hidden_size``.
    """

    # The annotation is a string (forward reference): MeshConfig is not
    # defined or imported above this class, and an unquoted annotation is
    # evaluated when ``def`` runs, which would raise NameError.
    def __init__(self, config: "MeshConfig"):
        """Create the two linear projections and the activation.

        Args:
            config: Object exposing ``hidden_size`` and
                ``expert_intermediate_size``. Presumably the ``MeshConfig``
                defined elsewhere in this module (TODO confirm).
        """
        super().__init__()
        self.fc1 = nn.Linear(config.hidden_size, config.expert_intermediate_size)
        self.gelu = nn.GELU()  # Using GELU as an example activation
        self.fc2 = nn.Linear(config.expert_intermediate_size, config.hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``fc1``, GELU, then ``fc2`` to ``x``.

        ``nn.Linear`` operates on the last dimension, so ``x`` must have
        trailing size ``hidden_size``. The result keeps the leading
        dimensions of ``x`` and has trailing size ``hidden_size``.
        """
        return self.fc2(self.gelu(self.fc1(x)))