minpeter committed
Commit f9156b3 · verified · 1 Parent(s): 19826d0

Delete mixtral-patch.py

Files changed (1)
  1. mixtral-patch.py +0 -187
mixtral-patch.py DELETED
@@ -1,187 +0,0 @@
- # Copyright (C) 2025 Arcee AI
- #
- # This software is free software: you can redistribute it and/or
- # modify it under the terms of the GNU Lesser General Public License as
- # published by the Free Software Foundation, either version 3 of the
- # License, or (at your option) any later version.
- #
- # This software is distributed in the hope that it will be useful, but
- # WITHOUT ANY WARRANTY; without even the implied warranty of
- # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
- # Lesser General Public License for more details.
- #
- # You should have received a copy of the GNU Lesser General Public License
- # along with this program. If not, see http://www.gnu.org/licenses/.
-
- import logging
- from typing import List, Optional
-
- import torch
- import tqdm
- import transformers
-
- from mergekit.architecture import MISTRAL_INFO, WeightInfo
- from mergekit.moe.arch import MoEOutputArchitecture
- from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype
- from mergekit.moe.config import MoEMergeConfig
- from mergekit.options import MergeOptions
-
-
- class MixtralMoE(MoEOutputArchitecture):
-     def name(self) -> str:
-         return "Mixtral"
-
-     def supports_config(
-         self,
-         config: MoEMergeConfig,
-         explain: bool = False,
-         trust_remote_code: bool = False,
-     ) -> bool:
-         if config.shared_experts:
-             if explain:
-                 logging.warning("Mixtral does not support shared experts")
-             return False
-
-         model_types = []
-         for model_ref in [config.base_model] + [e.source_model for e in config.experts]:
-             model_cfg = model_ref.config(trust_remote_code=trust_remote_code)
-             model_types.append(model_cfg.model_type)
-
-         if len(set(model_types)) != 1:
-             if explain:
-                 logging.warning(
-                     "Mixtral requires all input models to have the same architecture"
-                 )
-             return False
-         if model_types[0] not in ("llama", "mistral"):
-             if explain:
-                 logging.warning(
-                     "Mixtral requires all input models to be Llama or Mistral models"
-                 )
-             return False
-         return True
-
-     def _generate_config(
-         self,
-         base_config: transformers.PretrainedConfig,
-         num_experts: int,
-         shared_experts: Optional[int] = None,
-         experts_per_token: Optional[int] = None,
-     ) -> transformers.PretrainedConfig:
-         if shared_experts:
-             raise NotImplementedError("Shared experts not supported for Mixtral output")
-
-         if not isinstance(base_config, transformers.MistralConfig):
-             base_cfg_mistral = transformers.MistralConfig(**base_config.to_dict())
-             base_cfg_mistral.sliding_window = None
-             base_cfg_mistral.max_position_embeddings = (
-                 base_config.max_position_embeddings
-             )
-             base_config = base_cfg_mistral
-
-         out_cfg = transformers.MixtralConfig(**base_config.to_dict())
-         out_cfg.architectures = ["MixtralForCausalLM"]
-         out_cfg.num_local_experts = num_experts
-         out_cfg.num_experts_per_tok = experts_per_token or 2
-         out_cfg.sliding_window = None
-
-         if (out_cfg.num_local_experts & (out_cfg.num_local_experts - 1)) != 0:
-             logging.warning(
-                 f"Your model has {out_cfg.num_local_experts} experts, which is "
-                 "not a power of two. The model will not be usable in llama.cpp."
-             )
-         return out_cfg
-
-     def _remap_weight_name(self, weight: WeightInfo) -> str:
-         if ".mlp." not in weight.name:
-             # Everything but MLP is identical to base Mistral
-             return weight.name
-
-         res = weight.name
-         for needle, replacement in [
-             (".mlp.gate_proj", ".block_sparse_moe.experts.{expert_idx}.w1"),
-             (".mlp.down_proj", ".block_sparse_moe.experts.{expert_idx}.w2"),
-             (".mlp.up_proj", ".block_sparse_moe.experts.{expert_idx}.w3"),
-         ]:
-             res = res.replace(needle, replacement)
-         return res
-
-     def _router_weight_name(self, layer_idx: int) -> str:
-         return f"model.layers.{layer_idx}.block_sparse_moe.gate.weight"
-
-     def write_model(
-         self,
-         out_path: str,
-         config: MoEMergeConfig,
-         merge_options: MergeOptions,
-         router_weights: List[torch.Tensor],
-         shared_router_weights: Optional[List[torch.Tensor]] = None,
-     ):
-         base_model = config.base_model
-         base_cfg = base_model.config(trust_remote_code=merge_options.trust_remote_code)
-
-         assert len(router_weights) == base_cfg.num_hidden_layers, (
-             f"Expected {base_cfg.num_hidden_layers} router weights, "
-             f"got {len(router_weights)}"
-         )
-
-         out_dtype = select_dtype(config, base_cfg)
-         out_cfg = self._generate_config(
-             base_cfg,
-             len(config.experts),
-             len(config.shared_experts or []),
-             config.experts_per_token,
-         )
-         out_cfg.torch_dtype = out_dtype
-         out_cfg.save_pretrained(out_path)
-
-         loaders, base_loader, writer = initialize_io(config, out_path, merge_options)
-         for weight_info in tqdm.tqdm(
-             MISTRAL_INFO.all_weights(base_cfg),
-             desc="Weights",
-         ):
-             tensor_name = self._remap_weight_name(weight_info)
-             if "{expert_idx}" in tensor_name:
-                 for expert_index, expert in enumerate(config.experts):
-                     expert_name = tensor_name.replace("{expert_idx}", str(expert_index))
-                     expert_loader = loaders.get(expert.source_model)
-                     copy_tensor_out(
-                         weight_info,
-                         expert_loader,
-                         writer,
-                         expert=expert,
-                         out_dtype=out_dtype,
-                         output_name=expert_name,
-                         clone=merge_options.clone_tensors,
-                         is_residual="down_proj" in tensor_name,
-                     )
-             else:
-                 # START FINAL PATCH
-                 # Because WeightInfo is a frozen Pydantic model, we cannot modify it.
-                 # We must manually load and save the tensor for the tied weights case.
-                 if (
-                     weight_info.name == "lm_head.weight"
-                     and base_cfg.tie_word_embeddings
-                 ):
-                     # If tie_word_embeddings is used, lm_head.weight should not be copied.
-                     pass
-
-                 else:
-                     tensor = base_loader.get_tensor(weight_info.name)
-                     writer.save_tensor(
-                         weight_info.name,  # Always save with the correct destination name
-                         tensor.to(dtype=out_dtype),
-                         clone=merge_options.clone_tensors,
-                     )
-                 # END FINAL PATCH
-
-         for layer_idx, weight in enumerate(
-             tqdm.tqdm(router_weights, desc="Router weights")
-         ):
-             writer.save_tensor(
-                 self._router_weight_name(layer_idx),
-                 weight.to(dtype=out_dtype).contiguous(),
-                 clone=merge_options.clone_tensors,
-             )
-
-         writer.finalize()
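
For context, the core of the deleted file is the Mistral-to-Mixtral weight-name remapping plus the per-layer router weight naming. Below is a minimal standalone sketch of that mapping; it mirrors `_remap_weight_name` and `_router_weight_name` from the file above, and the example input name is illustrative rather than taken from a real checkpoint.

    # Sketch: rewrite a Mistral MLP weight name into Mixtral's per-expert layout,
    # following the replacement table in the deleted _remap_weight_name.
    def remap_weight_name(name: str) -> str:
        if ".mlp." not in name:
            # Attention, norm, and embedding weights keep their Mistral names.
            return name
        for needle, replacement in [
            (".mlp.gate_proj", ".block_sparse_moe.experts.{expert_idx}.w1"),
            (".mlp.down_proj", ".block_sparse_moe.experts.{expert_idx}.w2"),
            (".mlp.up_proj", ".block_sparse_moe.experts.{expert_idx}.w3"),
        ]:
            name = name.replace(needle, replacement)
        return name

    # Example (illustrative input name):
    # remap_weight_name("model.layers.0.mlp.up_proj.weight")
    #   -> "model.layers.0.block_sparse_moe.experts.{expert_idx}.w3.weight"
    # The "{expert_idx}" placeholder is then filled in once per expert, and the
    # per-layer router is written to "model.layers.{layer_idx}.block_sparse_moe.gate.weight".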