import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM  # Import AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast  # Import the necessary output class


# Define the Router for dynamic routing
class MeshRouter(nn.Module):
    """Token-level gate that selects the top-k experts of a 2-D expert mesh.

    A single linear layer scores every expert (one logit per cell of the
    ``config.mesh_grid_size`` grid). The scores are softmax-normalised, the
    ``config.routing_k`` highest are kept, and those k weights are
    renormalised so they sum to (almost exactly) 1.
    """

    def __init__(self, config: "MeshConfig"):
        """Build the gate from a mesh config.

        Args:
            config: Object exposing ``hidden_size``, ``mesh_grid_size`` (a
                2-element sequence of grid dimensions) and ``routing_k``.
                ``MeshConfig`` is not defined in this file's visible code,
                so the annotation is a quoted forward reference. That keeps
                the class definable even if ``MeshConfig`` is declared later
                or elsewhere.

        Raises:
            ValueError: If ``routing_k`` is not in ``[1, num_experts]``.
                Without this check, an oversized k only fails later inside
                ``torch.topk`` during ``forward``.
        """
        super().__init__()
        # One expert per cell of the mesh grid.
        self.num_experts = config.mesh_grid_size[0] * config.mesh_grid_size[1]
        if not 1 <= config.routing_k <= self.num_experts:
            raise ValueError(
                f"routing_k must be between 1 and the number of experts "
                f"({self.num_experts}), got {config.routing_k}"
            )
        # Attribute names are kept unchanged so existing state_dict keys still load.
        self.gate = nn.Linear(config.hidden_size, self.num_experts)
        self.softmax = nn.Softmax(dim=-1)
        self.routing_k = config.routing_k

    def forward(self, x):
        """Return the renormalised top-k routing weights and expert indices.

        Args:
            x: Tensor whose last dimension is ``hidden_size``. The original
                code documents it as (batch_size, sequence_length, hidden_size).

        Returns:
            ``(topk_weights, topk_indices)``. Both have ``x``'s leading
            dimensions with a final dimension of ``routing_k``. The indices
            are flat expert ids in ``[0, num_experts)``. The weights along
            the last dim sum to just under 1, because of the 1e-6 epsilon.
        """
        # (..., hidden_size) -> (..., num_experts)
        gate_scores = self.gate(x)
        gate_weights = self.softmax(gate_scores)

        # Keep only the k most probable experts per token.
        topk_weights, topk_indices = torch.topk(gate_weights, self.routing_k, dim=-1)

        # Renormalise the surviving weights. The epsilon guards the division,
        # although softmax outputs are strictly positive.
        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-6)
        return topk_weights, topk_indices