import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts
def convert_to_linear_experts(old_module: GptOssExperts, config) -> "NewGptOssExperts":
    """Copy the fused expert weights of a GptOssExperts module into per-expert nn.Linear layers."""
    new_mod = NewGptOssExperts(config).to(old_module.gate_up_proj.device)
    new_mod.alpha = old_module.alpha
    new_mod.limit = old_module.limit
    num_experts, _, _ = old_module.gate_up_proj.shape  # [num_experts, hidden_size, 2 * expert_dim]
    for e in range(num_experts):
        # gate/up projection: the fused slice is [hidden_size, 2 * expert_dim]; nn.Linear stores the transpose
        W_old = old_module.gate_up_proj[e].detach().to(config.torch_dtype)
        b_old = old_module.gate_up_proj_bias[e].detach().to(config.torch_dtype)
        new_mod.gate_up_projs[e].weight.data.copy_(W_old.transpose(0, 1))
        new_mod.gate_up_projs[e].bias.data.copy_(b_old)
        # down projection: [expert_dim, hidden_size] -> transpose for nn.Linear
        Wd_old = old_module.down_proj[e].detach().to(config.torch_dtype)
        bd_old = old_module.down_proj_bias[e].detach().to(config.torch_dtype)
        new_mod.down_projs[e].weight.data.copy_(Wd_old.transpose(0, 1))
        new_mod.down_projs[e].bias.data.copy_(bd_old)
    return new_mod
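# --- Optional sanity check (a sketch, not part of the original file) ---------------------------
# Compares the fused GptOssExperts module against its converted linear twin on random inputs in
# eval mode. It assumes both modules live on the same device/dtype and that the upstream module
# accepts the same (hidden_states, router_indices, routing_weights) call that NewGptOssExperts mirrors.
@torch.no_grad()
def check_conversion(old_module, new_module, config, num_tokens: int = 4) -> torch.Tensor:
    old_module.eval()
    new_module.eval()
    device = new_module.gate_up_projs[0].weight.device
    dtype = new_module.gate_up_projs[0].weight.dtype
    hidden = torch.randn(1, num_tokens, config.hidden_size, device=device, dtype=dtype)
    # dense per-token expert weights, normalized like post-softmax router scores
    weights = torch.rand(num_tokens, config.num_local_experts, device=device, dtype=dtype)
    weights = weights / weights.sum(dim=-1, keepdim=True)
    old_out = old_module(hidden, routing_weights=weights)
    new_out = new_module(hidden, routing_weights=weights)
    # report the worst-case discrepancy; expect values at floating-point noise level
    return (old_out.reshape(-1) - new_out.reshape(-1)).abs().max()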
class NewGptOssExperts(nn.Module):
    """Drop-in replacement for GptOssExperts that stores each expert as two nn.Linear layers."""

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_local_experts
        self.hidden_size = config.hidden_size
        self.expert_dim = config.intermediate_size
        self.alpha = 1.702
        self.limit = 7.0
        self.dtype = config.torch_dtype
        self.gate_up_projs = nn.ModuleList([
            nn.Linear(self.hidden_size, 2 * self.expert_dim, dtype=self.dtype)
            for _ in range(self.num_experts)
        ])
        self.down_projs = nn.ModuleList([
            nn.Linear(self.expert_dim, self.hidden_size, dtype=self.dtype)
            for _ in range(self.num_experts)
        ])

    def forward(
        self,
        hidden_states: torch.Tensor,
        router_indices=None,
        routing_weights=None,
    ) -> torch.Tensor:
        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, self.hidden_size)
        num_experts = routing_weights.shape[1]
        if self.training:
            next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
            with torch.no_grad():
                expert_mask = F.one_hot(router_indices, num_classes=num_experts)
                expert_mask = expert_mask.permute(2, 1, 0)
                # experts that received at least one token
                expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
            for expert_idx in expert_hitted[:]:
                with torch.no_grad():
                    _, token_idx = torch.where(expert_mask[expert_idx[0]])
                current_state = hidden_states[token_idx]
                gate_up = self.gate_up_projs[expert_idx](current_state)
                gate, up = gate_up[..., ::2], gate_up[..., 1::2]
                gate = gate.clamp(min=None, max=self.limit)
                up = up.clamp(min=-self.limit, max=self.limit)
                glu = gate * torch.sigmoid(gate * self.alpha)
                gated_output = (up + 1) * glu
                out = self.down_projs[expert_idx](gated_output)
                weighted_output = out * routing_weights[token_idx, expert_idx, None]
                next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
            next_states = next_states.view(batch_size, -1, self.hidden_size)
            return next_states
        else:
            # inference: run every expert over all tokens, then mix with the routing weights
            X_rep = hidden_states.unsqueeze(0).expand(num_experts, -1, -1)
            gate_up_list = [up_l(X_rep[e]) for e, up_l in enumerate(self.gate_up_projs)]
            gate_up = torch.stack(gate_up_list, dim=0)
            gate = gate_up[..., ::2]
            up_h = gate_up[..., 1::2]
            gate = gate.clamp(max=self.limit)
            up_h = up_h.clamp(min=-self.limit, max=self.limit)
            glu = gate * torch.sigmoid(gate * self.alpha)
            fused = (up_h + 1) * glu
            out_list = [down_l(fused[e]) for e, down_l in enumerate(self.down_projs)]
            outs = torch.stack(out_list, dim=0)
            rw = routing_weights.transpose(0, 1).unsqueeze(-1)
            mixed = (outs * rw).sum(dim=0)
            return mixed.view(batch_size, -1, self.hidden_size)
# To load a pre-converted checkpoint, first monkey-patch the linear implementation into transformers:
from transformers.models.gpt_oss import modeling_gpt_oss
modeling_gpt_oss.GptOssExperts = NewGptOssExperts
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model = AutoModelForCausalLM.from_pretrained(
    "imdatta0/gpt_oss_20b_linear",  # make sure you load the right weights
    device_map='cuda:0',  # modify appropriately.
    torch_dtype=torch.bfloat16,
)
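# Optional smoke test (a sketch, not from the original; the prompt is an illustrative assumption
# and the tokenizer is taken from the base openai/gpt-oss-20b repo):
tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
with torch.no_grad():
    generated = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(generated[0], skip_special_tokens=True))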
# Alternatively, convert on the fly: load the original fused checkpoint (do NOT apply the
# monkey patch above in this case) and swap each experts module after loading.
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    device_map='cuda:0',  # modify appropriately.
    torch_dtype=torch.bfloat16,
)
from tqdm import tqdm
for layer in tqdm(model.model.layers):
    experts = layer.mlp.experts
    if isinstance(experts, GptOssExperts):
        new_experts = convert_to_linear_experts(experts, model.config)  # function defined above
        layer.mlp.experts = new_experts.to(model.device, model.dtype)
print('✅ All experts converted to linear')
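# Optionally persist the converted model so it can later be loaded directly with the
# monkey-patched class (the output directory below is a placeholder, not from the original):
model.save_pretrained("gpt_oss_20b_linear_local")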