from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM  # Import AutoModelForCausalLM
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers.modeling_outputs import CausalLMOutputWithPast  # Import the necessary output class

# Define the main Mesh Layer
class MeshLayer(nn.Module):
    def __init__(self, config: MeshConfig):
        super().__init__()
        self.config = config
        self.router = MeshRouter(config)
        self.experts = nn.ModuleList(
            [MeshExpert(config) for _ in range(config.mesh_grid_size[0] * config.mesh_grid_size[1])]
        )
        self.neighbor_exchange = NeighborExchange(config)
        self.cross_expert_attention = CrossExpertAttention(config)

    def forward(self, hidden_states):
        # hidden_states shape: (batch_size, sequence_length, hidden_size)

        # 1. Routing
        topk_weights, topk_indices = self.router(hidden_states)
        # topk_weights shape: (batch_size, sequence_length, k)
        # topk_indices shape: (batch_size, sequence_length, k)

        # Prepare expert inputs: repeat hidden_states for each expert
        # shape: (batch_size, sequence_length, num_experts, hidden_size)
        expanded_hidden_states = hidden_states.unsqueeze(2).expand(
            -1, -1, self.config.mesh_grid_size[0] * self.config.mesh_grid_size[1], -1
        )

        # 2. Expert Computation
        # Compute output for all experts (can be optimized to only compute for selected experts)
        expert_outputs = torch.stack(
            [expert(expanded_hidden_states[:, :, i, :]) for i, expert in enumerate(self.experts)], dim=2
        )
        # expert_outputs shape: (batch_size, sequence_length, num_experts, hidden_size)

        # 3. Neighbor Exchange (conceptual implementation needed)
        exchanged_expert_outputs = self.neighbor_exchange(expert_outputs, topk_indices)

        # 4. Cross-Expert Attention (conceptual implementation needed)
        cross_attned_expert_outputs = self.cross_expert_attention(exchanged_expert_outputs)

        # 5. Combine expert outputs based on routing weights
        # Gather the outputs of the selected experts
        # shape: (batch_size, sequence_length, k, hidden_size)
        gathered_outputs = torch.gather(
            cross_attned_expert_outputs,
            dim=2,
            index=topk_indices.unsqueeze(-1).expand(-1, -1, -1, self.config.hidden_size),
        )

        # Apply routing weights: (batch_size, sequence_length, k, 1) * (batch_size, sequence_length, k, hidden_size)
        combined_output = (gathered_outputs * topk_weights.unsqueeze(-1)).sum(dim=2)
        # combined_output shape: (batch_size, sequence_length, hidden_size)

        # Return the combined output and the expert indices for potential visualization
        return combined_output, topk_indices
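
# --- Illustrative usage sketch (not part of the original layer definition) ---
# A minimal shape check for MeshLayer, assuming MeshConfig, MeshRouter, MeshExpert,
# NeighborExchange, and CrossExpertAttention are defined elsewhere with the interfaces
# used above: the router returns (topk_weights, topk_indices) of shape
# (batch_size, sequence_length, k), and each sub-module preserves the hidden_size
# dimension. The MeshConfig keyword arguments below are assumptions; only
# hidden_size and mesh_grid_size are actually referenced by MeshLayer.
if __name__ == "__main__":
    config = MeshConfig(hidden_size=64, mesh_grid_size=(2, 2))  # constructor signature assumed
    layer = MeshLayer(config)
    dummy_input = torch.randn(2, 8, config.hidden_size)  # (batch_size, sequence_length, hidden_size)
    output, expert_indices = layer(dummy_input)
    print(output.shape)          # expected: torch.Size([2, 8, 64])
    print(expert_indices.shape)  # expected: (2, 8, k), where k is the router's top-k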