import torch
import torch.nn as nn
import torch.nn.functional as F

# import the original fused class before any monkey patching so we can still
# reference it for conversion and isinstance checks
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts


class NewGptOssExperts(nn.Module):
    """Drop-in replacement for GptOssExperts that stores each expert as a
    separate nn.Linear instead of fused 3D parameters."""

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_local_experts
        self.hidden_size = config.hidden_size
        self.expert_dim = config.intermediate_size
        self.alpha = 1.702
        self.limit = 7.0
        self.dtype = config.torch_dtype
        self.gate_up_projs = nn.ModuleList([
            nn.Linear(self.hidden_size, 2 * self.expert_dim, dtype=self.dtype)
            for _ in range(self.num_experts)
        ])
        self.down_projs = nn.ModuleList([
            nn.Linear(self.expert_dim, self.hidden_size, dtype=self.dtype)
            for _ in range(self.num_experts)
        ])

    def forward(self, hidden_states: torch.Tensor, router_indices=None,
                routing_weights=None) -> torch.Tensor:
        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, self.hidden_size)
        num_experts = routing_weights.shape[1]
        if self.training:
            # Sparse path: only run the experts that were actually routed to.
            next_states = torch.zeros_like(hidden_states)
            with torch.no_grad():
                expert_mask = F.one_hot(router_indices, num_classes=num_experts)
                expert_mask = expert_mask.permute(2, 1, 0)
                expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
            for expert_idx in expert_hitted:
                expert_idx = int(expert_idx)  # plain int for ModuleList indexing
                with torch.no_grad():
                    _, token_idx = torch.where(expert_mask[expert_idx])
                current_state = hidden_states[token_idx]
                gate_up = self.gate_up_projs[expert_idx](current_state)
                # gate and up are interleaved along the last dim, as in the fused module
                gate, up = gate_up[..., ::2], gate_up[..., 1::2]
                gate = gate.clamp(min=None, max=self.limit)
                up = up.clamp(min=-self.limit, max=self.limit)
                glu = gate * torch.sigmoid(gate * self.alpha)
                gated_output = (up + 1) * glu
                out = self.down_projs[expert_idx](gated_output)
                weighted_output = out * routing_weights[token_idx, expert_idx, None]
                next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
            return next_states.view(batch_size, -1, self.hidden_size)
        else:
            # Dense path: run every expert on every token, then mix by routing weight.
            X_rep = hidden_states.unsqueeze(0).expand(num_experts, -1, -1)
            gate_up_list = [proj(X_rep[e]) for e, proj in enumerate(self.gate_up_projs)]
            gate_up = torch.stack(gate_up_list, dim=0)
            gate, up_h = gate_up[..., ::2], gate_up[..., 1::2]
            gate = gate.clamp(max=self.limit)
            up_h = up_h.clamp(min=-self.limit, max=self.limit)
            glu = gate * torch.sigmoid(gate * self.alpha)
            fused = (up_h + 1) * glu
            out_list = [proj(fused[e]) for e, proj in enumerate(self.down_projs)]
            outs = torch.stack(out_list, dim=0)
            rw = routing_weights.transpose(0, 1).unsqueeze(-1)
            mixed = (outs * rw).sum(dim=0)
            return mixed.view(batch_size, -1, self.hidden_size)


def convert_to_linear_experts(old_module: GptOssExperts, config) -> NewGptOssExperts:
    """Copy the fused per-expert weights of a GptOssExperts module into
    per-expert nn.Linear layers."""
    new_mod = NewGptOssExperts(config).to(old_module.gate_up_proj.device)
    new_mod.alpha = old_module.alpha
    new_mod.limit = old_module.limit
    E = old_module.gate_up_proj.shape[0]  # (E, hidden, 2 * expert_dim)
    for e in range(E):
        # gate/up proj: fused weight is (hidden, 2 * expert_dim); nn.Linear wants (out, in)
        W_old = old_module.gate_up_proj[e].detach().to(config.torch_dtype)
        b_old = old_module.gate_up_proj_bias[e].detach().to(config.torch_dtype)
        new_mod.gate_up_projs[e].weight.data.copy_(W_old.transpose(0, 1))
        new_mod.gate_up_projs[e].bias.data.copy_(b_old)
        # down proj
        Wd_old = old_module.down_proj[e].detach().to(config.torch_dtype)
        bd_old = old_module.down_proj_bias[e].detach().to(config.torch_dtype)
        new_mod.down_projs[e].weight.data.copy_(Wd_old.transpose(0, 1))
        new_mod.down_projs[e].bias.data.copy_(bd_old)
    return new_mod


# To load pre-converted weights, first monkey-patch the linear implementation
# over the original class:
from transformers.models.gpt_oss import modeling_gpt_oss
modeling_gpt_oss.GptOssExperts = NewGptOssExperts
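# Optional sanity check (a minimal sketch, not part of the original recipe):
# build a tiny random fused GptOssExperts on CPU and verify the linear version
# matches it in eval mode. This assumes GptOssConfig is exported by transformers
# and that the fused module's eval forward accepts routing_weights of shape
# (num_tokens, num_experts), like the class above; sizes here are arbitrary.
from transformers import GptOssConfig

cfg = GptOssConfig(hidden_size=64, intermediate_size=128,
                   num_local_experts=4, num_experts_per_tok=2,
                   torch_dtype=torch.float32)
old = GptOssExperts(cfg).eval()
for p in old.parameters():
    nn.init.normal_(p, std=0.02)  # fused params start uninitialized (torch.empty)
new = convert_to_linear_experts(old, cfg).eval()

x = torch.randn(2, 5, cfg.hidden_size)
rw = torch.softmax(torch.randn(2 * 5, cfg.num_local_experts), dim=-1)
with torch.no_grad():
    old_out = old(x, routing_weights=rw).reshape(-1, cfg.hidden_size)
    new_out = new(x, routing_weights=rw).reshape(-1, cfg.hidden_size)
torch.testing.assert_close(old_out, new_out, rtol=1e-4, atol=1e-4)
print('✅ linear experts match fused experts in eval mode')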
# ... then load the pre-converted checkpoint:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained(
    "imdatta0/gpt_oss_20b_linear",  # make sure you load the matching linear weights
    device_map='cuda:0',            # modify appropriately
    torch_dtype=torch.bfloat16,
)

# Or, to convert on the fly, skip the monkey patch above (so the checkpoint's
# fused weights load into the original GptOssExperts modules) and use:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    device_map='cuda:0',  # modify appropriately
    torch_dtype=torch.bfloat16,
)

from tqdm import tqdm

for layer in tqdm(model.model.layers):
    experts = layer.mlp.experts
    if isinstance(experts, GptOssExperts):
        # convert_to_linear_experts is defined above
        new_experts = convert_to_linear_experts(experts, model.config)
        layer.mlp.experts = new_experts.to(model.device, model.dtype)

print('✅ All experts converted to linear')
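# Quick smoke test of the converted model (a sketch; the prompt and generation
# settings are arbitrary, not from the original recipe):
tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))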