File size: 1,203 Bytes
fc24706 515cb80 fc24706 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM # Import AutoModelForCausalLM
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers.modeling_outputs import CausalLMOutputWithPast # Import the necessary output class
# Define the Router for dynamic routing
class MeshRouter(nn.Module):
    """Top-k router that scores every cell of the mesh grid for each input vector.

    A single linear layer maps the hidden state to one logit per expert
    (``mesh_grid_size[0] * mesh_grid_size[1]`` experts in total). The logits
    are turned into probabilities with a softmax over the expert dimension,
    the ``routing_k`` largest probabilities are kept, and those are
    renormalized so they sum to (approximately) one.
    """

    # NOTE: the annotation is a string (forward reference) on purpose. Python
    # evaluates parameter annotations when the ``def`` runs, so a bare
    # ``MeshConfig`` raises NameError at import time whenever the config class
    # is not already defined above this class.
    def __init__(self, config: "MeshConfig"):
        """Build the gating layer from the model configuration.

        Args:
            config: Object exposing ``hidden_size`` (input feature size),
                ``mesh_grid_size`` (indexable with at least two entries whose
                product is the number of experts) and ``routing_k`` (how many
                experts are selected per position).
        """
        super().__init__()
        num_experts = config.mesh_grid_size[0] * config.mesh_grid_size[1]
        self.gate = nn.Linear(config.hidden_size, num_experts)
        self.softmax = nn.Softmax(dim=-1)
        self.routing_k = config.routing_k

    def forward(self, x: "torch.Tensor") -> "tuple[torch.Tensor, torch.Tensor]":
        """Select the top-k experts for every position of ``x``.

        Args:
            x: Input activations. Expected to be
                ``(batch_size, sequence_length, hidden_size)``; the gate layer
                itself only requires the last dimension to be ``hidden_size``.

        Returns:
            A ``(topk_weights, topk_indices)`` pair, both with the leading
            dimensions of ``x`` and a last dimension of size ``routing_k``.
            ``topk_weights`` are the renormalized routing probabilities and
            ``topk_indices`` are the matching expert ids along the flattened
            mesh grid.
        """
        # One logit per expert: (..., num_experts).
        gate_scores = self.gate(x)
        gate_weights = self.softmax(gate_scores)

        # Keep only the k most probable experts for each position.
        topk_weights, topk_indices = torch.topk(gate_weights, self.routing_k, dim=-1)

        # Renormalize over the selected experts. The epsilon guards the
        # division, so the weights sum to marginally less than 1.
        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-6)
        return topk_weights, topk_indices
|