minpeter committed on
Commit 36b369f · verified · 1 Parent(s): 5686743

Upload folder using huggingface_hub

Files changed (1)
  1. mixtral-patch.py +186 -0
mixtral-patch.py ADDED
@@ -0,0 +1,186 @@
# Copyright (C) 2025 Arcee AI
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

import logging
from typing import List, Optional

import torch
import tqdm
import transformers

from mergekit.architecture import MISTRAL_INFO, WeightInfo
from mergekit.moe.arch import MoEOutputArchitecture
from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype
from mergekit.moe.config import MoEMergeConfig
from mergekit.options import MergeOptions


class MixtralMoE(MoEOutputArchitecture):
    def name(self) -> str:
        return "Mixtral"

    def supports_config(
        self,
        config: MoEMergeConfig,
        explain: bool = False,
        trust_remote_code: bool = False,
    ) -> bool:
        if config.shared_experts:
            if explain:
                logging.warning("Mixtral does not support shared experts")
            return False

        model_types = []
        for model_ref in [config.base_model] + [e.source_model for e in config.experts]:
            model_cfg = model_ref.config(trust_remote_code=trust_remote_code)
            model_types.append(model_cfg.model_type)

        if len(set(model_types)) != 1:
            if explain:
                logging.warning(
                    "Mixtral requires all input models to have the same architecture"
                )
            return False
        if model_types[0] not in ("llama", "mistral"):
            if explain:
                logging.warning(
                    "Mixtral requires all input models to be Llama or Mistral models"
                )
            return False
        return True

    def _generate_config(
        self,
        base_config: transformers.PretrainedConfig,
        num_experts: int,
        shared_experts: Optional[int] = None,
        experts_per_token: Optional[int] = None,
    ) -> transformers.PretrainedConfig:
        if shared_experts:
            raise NotImplementedError("Shared experts not supported for Mixtral output")

        if not isinstance(base_config, transformers.MistralConfig):
            base_cfg_mistral = transformers.MistralConfig(**base_config.to_dict())
            base_cfg_mistral.sliding_window = None
            base_cfg_mistral.max_position_embeddings = (
                base_config.max_position_embeddings
            )
            base_config = base_cfg_mistral

        out_cfg = transformers.MixtralConfig(**base_config.to_dict())
        out_cfg.architectures = ["MixtralForCausalLM"]
        out_cfg.num_local_experts = num_experts
        out_cfg.num_experts_per_tok = experts_per_token or 2
        out_cfg.sliding_window = None

        if (out_cfg.num_local_experts & (out_cfg.num_local_experts - 1)) != 0:
            logging.warning(
                f"Your model has {out_cfg.num_local_experts} experts, which is "
                "not a power of two. The model will not be usable in llama.cpp."
            )
        return out_cfg

    def _remap_weight_name(self, weight: WeightInfo) -> str:
        if ".mlp." not in weight.name:
            # Everything but MLP is identical to base Mistral
            return weight.name

        res = weight.name
        for needle, replacement in [
            (".mlp.gate_proj", ".block_sparse_moe.experts.{expert_idx}.w1"),
            (".mlp.down_proj", ".block_sparse_moe.experts.{expert_idx}.w2"),
            (".mlp.up_proj", ".block_sparse_moe.experts.{expert_idx}.w3"),
        ]:
            res = res.replace(needle, replacement)
        return res

    def _router_weight_name(self, layer_idx: int) -> str:
        return f"model.layers.{layer_idx}.block_sparse_moe.gate.weight"

    def write_model(
        self,
        out_path: str,
        config: MoEMergeConfig,
        merge_options: MergeOptions,
        router_weights: List[torch.Tensor],
        shared_router_weights: Optional[List[torch.Tensor]] = None,
    ):
        base_model = config.base_model
        base_cfg = base_model.config(trust_remote_code=merge_options.trust_remote_code)

        assert len(router_weights) == base_cfg.num_hidden_layers, (
            f"Expected {base_cfg.num_hidden_layers} router weights, "
            f"got {len(router_weights)}"
        )

        out_dtype = select_dtype(config, base_cfg)
        out_cfg = self._generate_config(
            base_cfg,
            len(config.experts),
            len(config.shared_experts or []),
            config.experts_per_token,
        )
        out_cfg.torch_dtype = out_dtype
        out_cfg.save_pretrained(out_path)

        loaders, base_loader, writer = initialize_io(config, out_path, merge_options)
        for weight_info in tqdm.tqdm(
            MISTRAL_INFO.all_weights(base_cfg),
            desc="Weights",
        ):
            tensor_name = self._remap_weight_name(weight_info)
            if "{expert_idx}" in tensor_name:
                for expert_index, expert in enumerate(config.experts):
                    expert_name = tensor_name.replace("{expert_idx}", str(expert_index))
                    expert_loader = loaders.get(expert.source_model)
                    copy_tensor_out(
                        weight_info,
                        expert_loader,
                        writer,
                        expert=expert,
                        out_dtype=out_dtype,
                        output_name=expert_name,
                        clone=merge_options.clone_tensors,
                        is_residual="down_proj" in tensor_name,
                    )
            else:
                # START FINAL PATCH
                # Because WeightInfo is a frozen Pydantic model, we cannot modify it.
                # We must manually load and save the tensor for the tied weights case.
                tensor_to_load_name = weight_info.name
                if (
                    weight_info.name == "lm_head.weight"
                    and base_cfg.tie_word_embeddings
                ):
                    tensor_to_load_name = "model.embed_tokens.weight"

                tensor = base_loader.get_tensor(tensor_to_load_name)
                writer.save_tensor(
                    weight_info.name,  # Always save with the correct destination name
                    tensor.to(dtype=out_dtype),
                    clone=merge_options.clone_tensors,
                )
                # END FINAL PATCH

        for layer_idx, weight in enumerate(
            tqdm.tqdm(router_weights, desc="Router weights")
        ):
            writer.save_tensor(
                self._router_weight_name(layer_idx),
                weight.to(dtype=out_dtype).contiguous(),
                clone=merge_options.clone_tensors,
            )

        writer.finalize()
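
For illustration only, here is a minimal, self-contained sketch (not part of the uploaded patch) that mirrors two pieces of the logic above: the weight-name remapping table from MixtralMoE._remap_weight_name and the tied-embeddings redirect in the "FINAL PATCH" branch. The weight names in the prints are made-up examples; it needs no mergekit install to run.

REPLACEMENTS = [
    (".mlp.gate_proj", ".block_sparse_moe.experts.{expert_idx}.w1"),
    (".mlp.down_proj", ".block_sparse_moe.experts.{expert_idx}.w2"),
    (".mlp.up_proj", ".block_sparse_moe.experts.{expert_idx}.w3"),
]


def remap(name: str) -> str:
    # Non-MLP weights keep their Mistral names; MLP weights move under the
    # per-expert block_sparse_moe namespace, one copy per expert.
    if ".mlp." not in name:
        return name
    for needle, replacement in REPLACEMENTS:
        name = name.replace(needle, replacement)
    return name


def source_name(dest_name: str, tie_word_embeddings: bool) -> str:
    # Mirrors the patched else-branch: with tied embeddings, lm_head.weight is
    # loaded from model.embed_tokens.weight but still saved under its own name.
    if dest_name == "lm_head.weight" and tie_word_embeddings:
        return "model.embed_tokens.weight"
    return dest_name


print(remap("model.layers.0.mlp.up_proj.weight"))
# model.layers.0.block_sparse_moe.experts.{expert_idx}.w3.weight
print(remap("model.layers.0.self_attn.q_proj.weight"))
# unchanged: model.layers.0.self_attn.q_proj.weight
print(source_name("lm_head.weight", tie_word_embeddings=True))
# model.embed_tokens.weight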