|
from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import math |
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
|
|
|
class MeshExpert(nn.Module):
    """Feed-forward expert: Linear -> GELU -> Linear.

    The first projection maps ``config.hidden_size`` features to
    ``config.expert_intermediate_size``; the second maps them back to
    ``config.hidden_size``, so the output has the same shape as the input.
    """

    # The annotation is a string (forward reference) so that defining this
    # class does not require ``MeshConfig`` to already exist at this point in
    # the module; an unquoted name would be evaluated eagerly and raise
    # NameError if the config class is declared later or elsewhere.
    def __init__(self, config: "MeshConfig"):
        """Build the expert's layers.

        Args:
            config: Object exposing integer attributes ``hidden_size`` and
                ``expert_intermediate_size``.
        """
        super().__init__()
        # Keep the attribute names and registration order stable: they
        # determine the state-dict keys and the order in which parameters
        # are initialised.
        self.fc1 = nn.Linear(config.hidden_size, config.expert_intermediate_size)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(config.expert_intermediate_size, config.hidden_size)

    def forward(self, x):
        """Apply the expert to ``x``.

        Args:
            x: Tensor whose last dimension equals ``config.hidden_size``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        hidden = self.fc1(x)
        activated = self.gelu(hidden)
        return self.fc2(activated)
|
|