# Copyright (C) 2025 Arcee AI
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

import logging
from typing import List, Optional

import torch
import tqdm
import transformers

from mergekit.architecture import MISTRAL_INFO, WeightInfo
from mergekit.moe.arch import MoEOutputArchitecture
from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype
from mergekit.moe.config import MoEMergeConfig
from mergekit.options import MergeOptions


class MixtralMoE(MoEOutputArchitecture):
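    """Output architecture that writes merged MoE models in the Mixtral format."""
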
    def name(self) -> str:
        return "Mixtral"

    def supports_config(
        self,
        config: MoEMergeConfig,
        explain: bool = False,
        trust_remote_code: bool = False,
    ) -> bool:
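        """Return True if the merge config can be written as a Mixtral model.

        When ``explain`` is set, log a warning describing why a config is rejected.
        """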
        if config.shared_experts:
            if explain:
                logging.warning("Mixtral does not support shared experts")
            return False

        model_types = []
        for model_ref in (
            [config.base_model] + [e.source_model for e in config.experts]
        ):
            model_cfg = model_ref.config(trust_remote_code=trust_remote_code)
            model_types.append(model_cfg.model_type)

        if len(set(model_types)) != 1:
            if explain:
                logging.warning(
                    "Mixtral requires all input models to have the same architecture"
                )
            return False
        if model_types[0] not in ("llama", "mistral"):
            if explain:
                logging.warning(
                    "Mixtral requires all input models to be Llama or Mistral models"
                )
            return False
        return True

    def _generate_config(
        self,
        base_config: transformers.PretrainedConfig,
        num_experts: int,
        shared_experts: Optional[int] = None,
        experts_per_token: Optional[int] = None,
    ) -> transformers.PretrainedConfig:
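        """Build a MixtralConfig from the base model's config and the expert counts."""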
        if shared_experts:
            raise NotImplementedError("Shared experts not supported for Mixtral output")

        if not isinstance(base_config, transformers.MistralConfig):
            base_cfg_mistral = transformers.MistralConfig(**base_config.to_dict())
            base_cfg_mistral.sliding_window = None
            base_cfg_mistral.max_position_embeddings = (
                base_config.max_position_embeddings
            )
            base_config = base_cfg_mistral

        out_cfg = transformers.MixtralConfig(**base_config.to_dict())
        out_cfg.architectures = ["MixtralForCausalLM"]
        out_cfg.num_local_experts = num_experts
        out_cfg.num_experts_per_tok = experts_per_token or 2
        out_cfg.sliding_window = None

        if (out_cfg.num_local_experts & (out_cfg.num_local_experts - 1)) != 0:
            logging.warning(
                f"Your model has {out_cfg.num_local_experts} experts, which is "
                "not a power of two. The model will not be usable in llama.cpp."
            )
        return out_cfg

    def _remap_weight_name(self, weight: WeightInfo) -> str:
        if ".mlp." not in weight.name:
            # Everything but MLP is identical to base Mistral
            return weight.name

        res = weight.name
        for needle, replacement in [
            (".mlp.gate_proj", ".block_sparse_moe.experts.{expert_idx}.w1"),
            (".mlp.down_proj", ".block_sparse_moe.experts.{expert_idx}.w2"),
            (".mlp.up_proj", ".block_sparse_moe.experts.{expert_idx}.w3"),
        ]:
            res = res.replace(needle, replacement)
        return res

    def _router_weight_name(self, layer_idx: int) -> str:
        return f"model.layers.{layer_idx}.block_sparse_moe.gate.weight"

    def write_model(
        self,
        out_path: str,
        config: MoEMergeConfig,
        merge_options: MergeOptions,
        router_weights: List[torch.Tensor],
        shared_router_weights: Optional[List[torch.Tensor]] = None,
    ):
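        """Write the merged Mixtral model to ``out_path``.

        Expert MLP weights are copied from each expert model, all other weights
        from the base model, and the per-layer router weights are saved last.
        """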
        base_model = config.base_model
        base_cfg = base_model.config(trust_remote_code=merge_options.trust_remote_code)

        assert len(router_weights) == base_cfg.num_hidden_layers, (
            f"Expected {base_cfg.num_hidden_layers} router weights, "
            f"got {len(router_weights)}"
        )

        out_dtype = select_dtype(config, base_cfg)
        out_cfg = self._generate_config(
            base_cfg,
            len(config.experts),
            len(config.shared_experts or []),
            config.experts_per_token,
        )
        out_cfg.torch_dtype = out_dtype
        out_cfg.save_pretrained(out_path)

        loaders, base_loader, writer = initialize_io(config, out_path, merge_options)
        for weight_info in tqdm.tqdm(
            MISTRAL_INFO.all_weights(base_cfg),
            desc="Weights",
        ):
            tensor_name = self._remap_weight_name(weight_info)
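            # Expert MLP weights contain "{expert_idx}" and are written once per expert.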
            if "{expert_idx}" in tensor_name:
                for expert_index, expert in enumerate(config.experts):
                    expert_name = tensor_name.replace("{expert_idx}", str(expert_index))
                    expert_loader = loaders.get(expert.source_model)
                    copy_tensor_out(
                        weight_info,
                        expert_loader,
                        writer,
                        expert=expert,
                        out_dtype=out_dtype,
                        output_name=expert_name,
                        clone=merge_options.clone_tensors,
                        is_residual="down_proj" in tensor_name,
                    )
            else:
                # WeightInfo is a frozen Pydantic model, so instead of rewriting its
                # name we load and save the tensor manually for the tied-weights case.
                if weight_info.name == "lm_head.weight" and base_cfg.tie_word_embeddings:
                    # With tied word embeddings, lm_head.weight is shared with the
                    # input embeddings and should not be written to the output.
                    continue
                tensor = base_loader.get_tensor(weight_info.name)
                writer.save_tensor(
                    weight_info.name,
                    tensor.to(dtype=out_dtype),
                    clone=merge_options.clone_tensors,
                )

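        # Save the per-layer router (gate) weights.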
        for layer_idx, weight in enumerate(
            tqdm.tqdm(router_weights, desc="Router weights")
        ):
            writer.save_tensor(
                self._router_weight_name(layer_idx),
                weight.to(dtype=out_dtype).contiguous(),
                clone=merge_options.clone_tensors,
            )

        writer.finalize()