import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
|
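# NOTE: `MeshConfig`, `MeshRouter`, `MeshExpert`, `NeighborExchange`, and
# `CrossExpertAttention` are referenced below but defined elsewhere in this
# module. From their use here we can infer the assumed interfaces: the router
# returns (topk_weights, topk_indices) of shape (batch, seq_len, top_k); each
# expert maps (batch, seq_len, hidden_size) -> (batch, seq_len, hidden_size);
# the exchange and cross-attention modules operate on the stacked per-expert
# outputs of shape (batch, seq_len, num_experts, hidden_size).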
class MeshLayer(nn.Module):
    """A mixture-of-experts layer whose experts are arranged on a 2D mesh grid.

    Tokens are routed to their top-k experts; every expert's output is refined
    by exchanging information with its grid neighbors and by cross-expert
    attention before the selected outputs are combined with the router weights.
    """

    def __init__(self, config: MeshConfig):
        super().__init__()
        self.config = config
        # One expert per cell of the (rows x cols) mesh grid.
        self.num_experts = config.mesh_grid_size[0] * config.mesh_grid_size[1]
        self.router = MeshRouter(config)
        self.experts = nn.ModuleList([MeshExpert(config) for _ in range(self.num_experts)])
        self.neighbor_exchange = NeighborExchange(config)
        self.cross_expert_attention = CrossExpertAttention(config)
|
    def forward(self, hidden_states):
        # Route each token: topk_weights and topk_indices have shape
        # (batch, seq_len, top_k).
        topk_weights, topk_indices = self.router(hidden_states)

        # Dense dispatch: replicate the hidden states once per expert so that
        # every expert sees every token, giving shape
        # (batch, seq_len, num_experts, hidden_size).
        expanded_hidden_states = hidden_states.unsqueeze(2).expand(
            -1, -1, self.num_experts, -1
        )

        # Run every expert on its slice and stack the results back into
        # (batch, seq_len, num_experts, hidden_size).
        expert_outputs = torch.stack(
            [expert(expanded_hidden_states[:, :, i, :]) for i, expert in enumerate(self.experts)],
            dim=2,
        )

        # Let each expert exchange information with its neighbors on the mesh grid.
        exchanged_expert_outputs = self.neighbor_exchange(expert_outputs, topk_indices)

        # Attend across experts so each expert's output can condition on the others.
        cross_attned_expert_outputs = self.cross_expert_attention(exchanged_expert_outputs)

        # Select the outputs of the top-k routed experts for every token:
        # result has shape (batch, seq_len, top_k, hidden_size).
        gathered_outputs = torch.gather(
            cross_attned_expert_outputs,
            dim=2,
            index=topk_indices.unsqueeze(-1).expand(-1, -1, -1, self.config.hidden_size),
        )

        # Combine the selected expert outputs, weighted by the router scores,
        # collapsing back to (batch, seq_len, hidden_size).
        combined_output = (gathered_outputs * topk_weights.unsqueeze(-1)).sum(dim=2)

        return combined_output, topk_indices
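

# Minimal smoke-test sketch. This assumes `MeshConfig` accepts `hidden_size`
# and `mesh_grid_size` as keyword arguments (only those two attributes are read
# by `MeshLayer` above); the concrete values here are illustrative, not canonical.
if __name__ == "__main__":
    config = MeshConfig(hidden_size=64, mesh_grid_size=(2, 2))  # assumed ctor signature
    layer = MeshLayer(config)
    x = torch.randn(2, 8, config.hidden_size)  # (batch, seq_len, hidden_size)
    out, topk_indices = layer(x)
    print(out.shape)           # expected: torch.Size([2, 8, 64])
    print(topk_indices.shape)  # (batch, seq_len, top_k) as returned by the router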
|
|