# Copyright (C) 2025 Arcee AI
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

import logging
from typing import List, Optional

import torch
import tqdm
import transformers

from mergekit.architecture import MISTRAL_INFO, WeightInfo
from mergekit.moe.arch import MoEOutputArchitecture
from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype
from mergekit.moe.config import MoEMergeConfig
from mergekit.options import MergeOptions


class MixtralMoE(MoEOutputArchitecture):
    def name(self) -> str:
        return "Mixtral"

    def supports_config(
        self,
        config: MoEMergeConfig,
        explain: bool = False,
        trust_remote_code: bool = False,
    ) -> bool:
        if config.shared_experts:
            if explain:
                logging.warning("Mixtral does not support shared experts")
            return False

        model_types = []
        for model_ref in [config.base_model] + [e.source_model for e in config.experts]:
            model_cfg = model_ref.config(trust_remote_code=trust_remote_code)
            model_types.append(model_cfg.model_type)

        if len(set(model_types)) != 1:
            if explain:
                logging.warning(
                    "Mixtral requires all input models to have the same architecture"
                )
            return False
        if model_types[0] not in ("llama", "mistral"):
            if explain:
                logging.warning(
                    "Mixtral requires all input models to be Llama or Mistral models"
                )
            return False
        return True
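
    # Illustrative usage (hypothetical `moe_config` built elsewhere from a
    # mergekit-moe YAML): calling
    #     MixtralMoE().supports_config(moe_config, explain=True)
    # returns False and logs a warning if the config mixes Llama and Mistral
    # models or requests shared experts; otherwise it returns True.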

    def _generate_config(
        self,
        base_config: transformers.PretrainedConfig,
        num_experts: int,
        shared_experts: Optional[int] = None,
        experts_per_token: Optional[int] = None,
    ) -> transformers.PretrainedConfig:
        if shared_experts:
            raise NotImplementedError("Shared experts not supported for Mixtral output")

        if not isinstance(base_config, transformers.MistralConfig):
            base_cfg_mistral = transformers.MistralConfig(**base_config.to_dict())
            base_cfg_mistral.sliding_window = None
            base_cfg_mistral.max_position_embeddings = (
                base_config.max_position_embeddings
            )
            base_config = base_cfg_mistral

        out_cfg = transformers.MixtralConfig(**base_config.to_dict())
        out_cfg.architectures = ["MixtralForCausalLM"]
        out_cfg.num_local_experts = num_experts
        out_cfg.num_experts_per_tok = experts_per_token or 2
        out_cfg.sliding_window = None

        if (out_cfg.num_local_experts & (out_cfg.num_local_experts - 1)) != 0:
            logging.warning(
                f"Your model has {out_cfg.num_local_experts} experts, which is "
                "not a power of two. The model will not be usable in llama.cpp."
            )
        return out_cfg
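
    # Illustrative example (hypothetical values): merging four experts with
    # experts_per_token=2 yields a MixtralConfig with num_local_experts=4,
    # num_experts_per_tok=2, and sliding_window=None, while dimensions such as
    # hidden_size and num_hidden_layers are inherited from the base model.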

    def _remap_weight_name(self, weight: WeightInfo) -> str:
        if ".mlp." not in weight.name:
            # Everything but MLP is identical to base Mistral
            return weight.name

        res = weight.name
        for needle, replacement in [
            (".mlp.gate_proj", ".block_sparse_moe.experts.{expert_idx}.w1"),
            (".mlp.down_proj", ".block_sparse_moe.experts.{expert_idx}.w2"),
            (".mlp.up_proj", ".block_sparse_moe.experts.{expert_idx}.w3"),
        ]:
            res = res.replace(needle, replacement)
        return res

    def _router_weight_name(self, layer_idx: int) -> str:
        return f"model.layers.{layer_idx}.block_sparse_moe.gate.weight"

    def write_model(
        self,
        out_path: str,
        config: MoEMergeConfig,
        merge_options: MergeOptions,
        router_weights: List[torch.Tensor],
        shared_router_weights: Optional[List[torch.Tensor]] = None,
    ):
        base_model = config.base_model
        base_cfg = base_model.config(trust_remote_code=merge_options.trust_remote_code)

        assert len(router_weights) == base_cfg.num_hidden_layers, (
            f"Expected {base_cfg.num_hidden_layers} router weights, "
            f"got {len(router_weights)}"
        )

        out_dtype = select_dtype(config, base_cfg)
        out_cfg = self._generate_config(
            base_cfg,
            len(config.experts),
            len(config.shared_experts or []),
            config.experts_per_token,
        )
        out_cfg.torch_dtype = out_dtype
        out_cfg.save_pretrained(out_path)

        loaders, base_loader, writer = initialize_io(config, out_path, merge_options)
        for weight_info in tqdm.tqdm(
            MISTRAL_INFO.all_weights(base_cfg),
            desc="Weights",
        ):
            tensor_name = self._remap_weight_name(weight_info)
            if "{expert_idx}" in tensor_name:
                for expert_index, expert in enumerate(config.experts):
                    expert_name = tensor_name.replace("{expert_idx}", str(expert_index))
                    expert_loader = loaders.get(expert.source_model)
                    copy_tensor_out(
                        weight_info,
                        expert_loader,
                        writer,
                        expert=expert,
                        out_dtype=out_dtype,
                        output_name=expert_name,
                        clone=merge_options.clone_tensors,
                        # Check the original name: "down_proj" has already been
                        # remapped to "w2" in tensor_name, so testing tensor_name
                        # here would never flag the residual-writing weight.
                        is_residual="down_proj" in weight_info.name,
                    )
            else:
                # START FINAL PATCH
                # WeightInfo is a frozen Pydantic model, so its name cannot be
                # rewritten in place; load and save the tensor manually, and skip
                # lm_head.weight when it is tied to the input embeddings.
                if (
                    weight_info.name == "lm_head.weight"
                    and base_cfg.tie_word_embeddings
                ):
                    # With tie_word_embeddings, lm_head.weight is shared with the
                    # embedding matrix and should not be copied to the output.
                    pass
                else:
                    tensor = base_loader.get_tensor(weight_info.name)
                    writer.save_tensor(
                        weight_info.name,  # always save under the destination name
                        tensor.to(dtype=out_dtype),
                        clone=merge_options.clone_tensors,
                    )
                # END FINAL PATCH

        for layer_idx, weight in enumerate(
            tqdm.tqdm(router_weights, desc="Router weights")
        ):
            writer.save_tensor(
                self._router_weight_name(layer_idx),
                weight.to(dtype=out_dtype).contiguous(),
                clone=merge_options.clone_tensors,
            )
        writer.finalize()
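

# Sketch of how this class is typically driven (hypothetical variable names; in
# practice the mergekit-moe entry point builds the MoEMergeConfig from a YAML
# file and computes the router weights before calling write_model):
#
#   arch = MixtralMoE()
#   if arch.supports_config(moe_config, explain=True):
#       arch.write_model(
#           out_path="./mixtral-merge",
#           config=moe_config,
#           merge_options=merge_options,
#           router_weights=router_weights,  # one gate tensor per hidden layer
#       )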