File size: 1,203 Bytes
fc24706
 
 
 
 
 
515cb80
fc24706
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM # Import AutoModelForCausalLM
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers.modeling_outputs import CausalLMOutputWithPast # Import the necessary output class

# Define the Router for dynamic routing
class MeshRouter(nn.Module):
    """Top-k gating network that routes each token to experts of a 2-D mesh.

    A linear gate scores each input vector against
    ``mesh_grid_size[0] * mesh_grid_size[1]`` experts. A softmax turns the
    scores into probabilities, and only the ``routing_k`` most probable
    experts are kept. The kept probabilities are renormalized so they sum to
    (approximately) one for every input vector.

    Args:
        config: Object exposing ``hidden_size``, ``mesh_grid_size`` (an
            indexable pair of grid dimensions) and ``routing_k``.

    Raises:
        ValueError: If ``routing_k`` is not in ``[1, num_experts]``.
    """

    # Quoted so the annotation does not require MeshConfig to be resolvable
    # here; an unquoted parameter annotation is evaluated when ``def`` runs
    # and raises NameError if the name is missing.
    def __init__(self, config: "MeshConfig"):
        super().__init__()
        num_experts = config.mesh_grid_size[0] * config.mesh_grid_size[1]
        # Fail fast: torch.topk would otherwise only error out in forward()
        # when k exceeds the number of experts, and k < 1 routes nowhere.
        if not 1 <= config.routing_k <= num_experts:
            raise ValueError(
                f"routing_k must be between 1 and the number of experts "
                f"({num_experts}), got {config.routing_k}"
            )
        self.num_experts = num_experts
        self.gate = nn.Linear(config.hidden_size, num_experts)
        self.softmax = nn.Softmax(dim=-1)
        self.routing_k = config.routing_k

    def forward(self, x):
        """Select the top-k experts for every input vector.

        Args:
            x: Tensor whose last dimension is ``hidden_size``, e.g.
                ``(batch_size, sequence_length, hidden_size)``. Any leading
                dimensions are allowed because the gate and top-k both act
                on the last dimension.

        Returns:
            A ``(topk_weights, topk_indices)`` pair, each of shape
            ``(*x.shape[:-1], routing_k)``. The weights are the renormalized
            softmax probabilities of the selected experts, ordered from
            largest to smallest. The indices are flat positions in
            ``[0, num_experts)``; how they map onto grid coordinates is up
            to the caller.
        """
        gate_scores = self.gate(x)  # (..., num_experts)
        gate_weights = self.softmax(gate_scores)

        topk_weights, topk_indices = torch.topk(gate_weights, self.routing_k, dim=-1)

        # Renormalize the kept probabilities. Softmax outputs are positive, so
        # the sum is normally nonzero; the epsilon only guards against
        # underflow.
        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-6)

        return topk_weights, topk_indices