
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts


def convert_to_linear_experts(old_module: GptOssExperts, config) -> "NewGptOssExperts":
    """Copy the fused (E, hidden, 2*expert_dim) expert weights into per-expert nn.Linear layers."""
    new_mod = NewGptOssExperts(config).to(old_module.gate_up_proj.device)
    new_mod.alpha = old_module.alpha
    new_mod.limit = old_module.limit

    E, D, two_dexp = old_module.gate_up_proj.shape  # (num_experts, hidden_size, 2 * expert_dim)

    for e in range(E):
        # fused gate/up projection: stored as (hidden, 2*expert_dim); nn.Linear wants (out, in)
        W_old = old_module.gate_up_proj[e].detach().to(config.torch_dtype)
        b_old = old_module.gate_up_proj_bias[e].detach().to(config.torch_dtype)
        new_mod.gate_up_projs[e].weight.data.copy_(W_old.transpose(0, 1))
        new_mod.gate_up_projs[e].bias.data.copy_(b_old)

        # down projection: stored as (expert_dim, hidden); transpose for nn.Linear
        Wd_old = old_module.down_proj[e].detach().to(config.torch_dtype)
        bd_old = old_module.down_proj_bias[e].detach().to(config.torch_dtype)
        new_mod.down_projs[e].weight.data.copy_(Wd_old.transpose(0, 1))
        new_mod.down_projs[e].bias.data.copy_(bd_old)

    return new_mod

class NewGptOssExperts(nn.Module):
    """Drop-in replacement for GptOssExperts that keeps each expert's projections
    as separate nn.Linear modules instead of fused 3-D parameter tensors."""

    def __init__(self, config):
        super().__init__()

        self.num_experts = config.num_local_experts
        self.hidden_size = config.hidden_size
        self.expert_dim = config.intermediate_size
        self.alpha = 1.702
        self.limit = 7.0
        self.dtype = config.torch_dtype

        self.gate_up_projs = nn.ModuleList([
            nn.Linear(self.hidden_size, 2 * self.expert_dim, dtype=self.dtype)
            for _ in range(self.num_experts)
        ])
        self.down_projs = nn.ModuleList([
            nn.Linear(self.expert_dim, self.hidden_size, dtype=self.dtype)
            for _ in range(self.num_experts)
        ])

    def forward(self,
                hidden_states: torch.Tensor,
                router_indices = None,
                routing_weights = None
               ) -> torch.Tensor:

        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # flatten to (tokens, hidden_size)
        num_experts = routing_weights.shape[1]

        if self.training:
            # Training path: loop only over experts that actually received tokens.
            next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
            with torch.no_grad():
                expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
                expert_mask = expert_mask.permute(2, 1, 0)
                expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
            for expert_idx in expert_hitted[:]:
                with torch.no_grad():
                    _, token_idx = torch.where(expert_mask[expert_idx[0]])
                current_state = hidden_states[token_idx]
                gate_up = self.gate_up_projs[expert_idx](current_state)
                gate, up = gate_up[..., ::2], gate_up[..., 1::2]
                gate = gate.clamp(min=None, max=self.limit)
                up = up.clamp(min=-self.limit, max=self.limit)
                glu = gate * torch.sigmoid(gate * self.alpha)
                gated_output = (up + 1) * glu
                out = self.down_projs[expert_idx](gated_output)
                weighted_output = out * routing_weights[token_idx, expert_idx, None]
                next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
            next_states = next_states.view(batch_size, -1, self.hidden_size)
            return next_states

        else:
            # Inference path: run every expert over all tokens, then mix with the routing weights.
            X_rep = hidden_states.unsqueeze(0).expand(num_experts, -1, -1)
            gate_up_list = [up_l(X_rep[e]) for e, up_l in enumerate(self.gate_up_projs)]
            gate_up = torch.stack(gate_up_list, dim=0)
            gate = gate_up[..., ::2]
            up_h = gate_up[..., 1::2]
            gate = gate.clamp(max=self.limit)
            up_h = up_h.clamp(min=-self.limit, max=self.limit)
            glu = gate * torch.sigmoid(gate * self.alpha)
            fused = (up_h + 1) * glu
            out_list = [down_l(fused[e]) for e, down_l in enumerate(self.down_projs)]
            outs = torch.stack(out_list, dim=0)
            rw = routing_weights.transpose(0, 1).unsqueeze(-1)
            mixed = (outs * rw).sum(dim=0)
            return mixed.view(batch_size, -1, self.hidden_size)


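# Optional sanity check (an illustrative sketch, not part of the conversion itself).
# Assumes `old_experts` is a dequantized GptOssExperts module taken from a loaded model
# (e.g. model.model.layers[0].mlp.experts), that the config exposes num_local_experts /
# num_experts_per_tok as in GptOssConfig, and that the tolerance below is a rough guess
# for bfloat16 arithmetic.
@torch.no_grad()
def check_expert_parity(old_experts, config, tokens: int = 4, atol: float = 5e-2):
    new_experts = convert_to_linear_experts(old_experts, config).eval()
    old_experts = old_experts.eval()

    device = old_experts.gate_up_proj.device
    hidden = torch.randn(1, tokens, config.hidden_size, device=device, dtype=config.torch_dtype)

    # Fake router outputs: top-k scores scattered into a dense (tokens, num_experts) matrix,
    # mirroring the layout the experts' forward() above expects.
    logits = torch.randn(tokens, config.num_local_experts, device=device, dtype=config.torch_dtype)
    top_vals, router_indices = torch.topk(logits, config.num_experts_per_tok, dim=-1)
    routing_weights = torch.zeros_like(logits).scatter_(1, router_indices, top_vals.softmax(dim=-1))

    out_old = old_experts(hidden, router_indices=router_indices, routing_weights=routing_weights)
    out_new = new_experts(hidden, router_indices=router_indices, routing_weights=routing_weights)

    # Flatten both outputs so a possible (tokens, hidden) vs (1, tokens, hidden) shape
    # difference between implementations does not matter.
    diff = (out_old.reshape(-1, config.hidden_size).float()
            - out_new.reshape(-1, config.hidden_size).float()).abs().max().item()
    print(f'max abs diff between fused and linear experts: {diff:.4e}')
    assert diff < atol, 'converted experts diverge from the original module'

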
# Option 1: load a checkpoint whose weights are already stored as per-expert nn.Linear layers.
# Monkey-patch the transformers class before calling from_pretrained so the model is built
# with NewGptOssExperts and the checkpoint keys line up.
from transformers.models.gpt_oss import modeling_gpt_oss
modeling_gpt_oss.GptOssExperts = NewGptOssExperts

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained(
    "imdatta0/gpt_oss_20b_linear", # make sure you load the right weights
    device_map='cuda:0', # modify appropriately.
    torch_dtype=torch.bfloat16
)
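
# Quick smoke test after loading (a minimal sketch; it assumes the base repo's tokenizer
# is compatible with the converted checkpoint and that plain greedy generation is enough).
tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))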


# Option 2: convert on the fly from the original checkpoint.
# Skip the monkey patch above in this case: the model must first be built with the
# original GptOssExperts so its fused weights can be copied over.

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    device_map='cuda:0', # modify appropriately.
    torch_dtype=torch.bfloat16
)

from tqdm import tqdm
for layer in tqdm(model.model.layers):
    experts = layer.mlp.experts
    if isinstance(experts, GptOssExperts):
        new_experts = convert_to_linear_experts(experts, model.config)  # defined above
        layer.mlp.experts = new_experts.to(model.device, model.dtype)
print('✅ All experts converted to linear')
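
# Optionally persist the converted weights so they can be reloaded later via Option 1
# (a sketch; the output directory is just a placeholder, and reloading still requires
# patching GptOssExperts -> NewGptOssExperts before from_pretrained).
save_dir = "gpt_oss_20b_linear_local"  # hypothetical local path
model.save_pretrained(save_dir)
# copy the base repo's tokenizer so the folder is self-contained
AutoTokenizer.from_pretrained("openai/gpt-oss-20b").save_pretrained(save_dir)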