# v0.1-2x2-stage001 / meshlayer.py
# Last change: "Update meshlayer.py" by aquiffoo (commit 7ce8e21, verified).
from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM # Import AutoModelForCausalLM
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers.modeling_outputs import CausalLMOutputWithPast # Import the necessary output class
# Define the main Mesh Layer
class MeshLayer(nn.Module):
    """Mixture-of-experts layer whose experts are laid out on a 2-D mesh grid.

    The layer owns ``mesh_grid_size[0] * mesh_grid_size[1]`` experts. A forward
    pass (1) asks ``MeshRouter`` for per-token top-k expert weights and indices,
    (2) runs *every* expert densely on the full input, (3) feeds the stacked
    expert outputs through ``NeighborExchange`` and ``CrossExpertAttention``,
    and (4) returns the routing-weighted sum of the top-k selected expert
    outputs together with the selected indices.
    """

    def __init__(self, config: MeshConfig):
        super().__init__()
        self.config = config
        # Submodule construction order (router, experts, exchange, attention)
        # determines parameter-init RNG consumption and state_dict ordering.
        self.router = MeshRouter(config)
        num_experts = config.mesh_grid_size[0] * config.mesh_grid_size[1]
        self.experts = nn.ModuleList([MeshExpert(config) for _ in range(num_experts)])
        self.neighbor_exchange = NeighborExchange(config)
        self.cross_expert_attention = CrossExpertAttention(config)

    def forward(self, hidden_states):
        """Route ``hidden_states`` through the expert mesh.

        Args:
            hidden_states: Input activations, expected to be
                ``(batch_size, sequence_length, hidden_size)``.

        Returns:
            Tuple ``(combined_output, topk_indices)``:
              * ``combined_output`` -- ``(batch_size, sequence_length, hidden_size)``;
                the k selected expert outputs multiplied by their routing
                weights and summed over the k axis.
              * ``topk_indices`` -- the router's selected expert indices,
                ``(batch_size, sequence_length, k)``, returned unchanged
                (e.g. for visualization).
        """
        # 1. Routing: weights and indices are both (batch_size, sequence_length, k).
        topk_weights, topk_indices = self.router(hidden_states)

        # 2. Dense expert computation: every expert processes every token.
        # Each expert receives hidden_states directly. The former
        # unsqueeze(2).expand(...)[:, :, i, :] yielded exactly the same view,
        # but autograd's select backward then allocated a zero-filled
        # (batch, seq, num_experts, hidden) gradient buffer for every expert.
        # Result: (batch_size, sequence_length, num_experts, hidden_size).
        expert_outputs = torch.stack([expert(hidden_states) for expert in self.experts], dim=2)

        # 3. NeighborExchange (defined elsewhere) receives the stacked outputs
        #    plus the router's choices. NOTE(review): presumably mixes outputs
        #    of mesh-adjacent experts; originally flagged as a conceptual
        #    implementation -- confirm.
        exchanged_expert_outputs = self.neighbor_exchange(expert_outputs, topk_indices)

        # 4. CrossExpertAttention (defined elsewhere). NOTE(review): presumably
        #    attends across the expert axis; originally flagged as a conceptual
        #    implementation -- confirm.
        cross_attned_expert_outputs = self.cross_expert_attention(exchanged_expert_outputs)

        # 5. Select the k chosen experts per token along the expert axis (dim 2):
        #    (batch_size, sequence_length, k, hidden_size).
        gather_index = topk_indices.unsqueeze(-1).expand(-1, -1, -1, self.config.hidden_size)
        gathered_outputs = torch.gather(cross_attned_expert_outputs, dim=2, index=gather_index)

        # Scale each selected output by its routing weight, then sum over k:
        # (batch_size, sequence_length, k, 1) * (batch_size, sequence_length, k, hidden_size)
        # -> (batch_size, sequence_length, hidden_size).
        combined_output = (gathered_outputs * topk_weights.unsqueeze(-1)).sum(dim=2)

        return combined_output, topk_indices