|
from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import math |
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
|
|
|
class MeshRouter(nn.Module):
    """Top-k gating router over a 2-D mesh of routing targets.

    A linear gate produces one logit per cell of the ``mesh_grid_size`` grid
    (``mesh_grid_size[0] * mesh_grid_size[1]`` logits in total). The logits are
    turned into probabilities with a softmax, the ``routing_k`` largest are
    kept, and the kept weights are renormalised so they sum to (just under) one.
    """

    def __init__(self, config: "MeshConfig"):
        """Build the router.

        Args:
            config: Object exposing ``hidden_size`` (size of the last input
                dimension), ``mesh_grid_size`` (indexable pair whose product
                is the number of routing targets) and ``routing_k`` (number
                of targets selected per input position).
        """
        # The annotation is quoted so it is not evaluated at class-definition
        # time: MeshConfig is not imported or defined above this class, and an
        # eager annotation would raise NameError on import.
        super().__init__()
        num_cells = config.mesh_grid_size[0] * config.mesh_grid_size[1]
        self.gate = nn.Linear(config.hidden_size, num_cells)
        self.softmax = nn.Softmax(dim=-1)
        self.routing_k = config.routing_k

    def forward(self, x):
        """Select the top-k mesh cells for each input position.

        Args:
            x: Tensor whose last dimension is ``hidden_size``.

        Returns:
            ``(topk_weights, topk_indices)``, each of shape
            ``(*x.shape[:-1], routing_k)``. Weights are in the gate's output
            dtype and, for ``routing_k >= 1``, sum to just under 1 along the
            last dimension. Indices are positions in the flattened grid
            (presumably row-major cells -- confirm against the consumer).
        """
        gate_scores = self.gate(x)

        # Run softmax/top-k in at least float32. In fp16/bf16 the
        # probabilities lose precision and near-equal cells can tie, making
        # routing unstable. promote_types keeps float64 inputs at float64.
        compute_dtype = torch.promote_types(gate_scores.dtype, torch.float32)
        gate_weights = self.softmax(gate_scores.to(compute_dtype))

        topk_weights, topk_indices = torch.topk(gate_weights, self.routing_k, dim=-1)

        # Renormalise over the selected cells. For routing_k >= 1 the sum is
        # at least routing_k / num_cells (> 0); the epsilon is a safety margin.
        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-6)

        # Hand weights back in the gate's output dtype, as before.
        return topk_weights.to(gate_scores.dtype), topk_indices
|
|