#!/usr/bin/env python3
"""
Mistral Model Transformer

This script transforms Mistral-Small-3.1-24B-Base-2503 into a text-only model by:
1. Removing multimodal functionality
2. Removing the vision encoder
3. Changing the architecture from "mistral3" to "mistral"
4. Ensuring the weight mapping structure matches Devstral-Small-2505 exactly

Usage:
    python convert.py --input-model mistralai/Mistral-Small-3.1-24B-Base-2503 --output-path ./mistral-small-text-only --reference-model mistralai/Devstral-Small-2505

Note:
    This script requires significant disk space to download and process the full model.
"""

import argparse
import json
import os
import shutil
from pathlib import Path
import logging

from huggingface_hub import snapshot_download, hf_hub_download
from safetensors.torch import load_file, save_file
from transformers import AutoConfig, AutoModelForCausalLM

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def parse_args():
    parser = argparse.ArgumentParser(description="Transform Mistral model to text-only version")
    parser.add_argument(
        "--input-model",
        type=str,
        default="mistralai/Mistral-Small-3.1-24B-Base-2503",
        help="Path or HF repo id of the input model"
    )
    parser.add_argument(
        "--output-path",
        type=str,
        required=True,
        help="Path to save the transformed model"
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default=None,
        help="Cache directory for downloading models"
    )
    parser.add_argument(
        "--reference-model",
        type=str,
        default="mistralai/Devstral-Small-2505",
        help="Path or HF repo id of the reference model for weight mapping"
    )
    return parser.parse_args()

def transform_config(config_path, output_path, reference_config=None):
    """
    Transform the model config by:
    1. Changing model_type from "mistral3" to "mistral"
    2. Removing vision_config
    3. Removing multimodal parameters
    4. Updating architectures to match Devstral exactly
    5. Ensuring all parameters match Devstral's config exactly
    """
    logger.info(f"Transforming config at {config_path}")
    
    with open(config_path, "r") as f:
        config = json.load(f)
    
    if reference_config:
        logger.info("Using reference config as template")
        new_config = reference_config.copy()
        
        text_config = config.get("text_config", config)
        
        for key, value in text_config.items():
            if key == "model_type":
                continue
            if key not in new_config:
                new_config[key] = value
                logger.info(f"Added parameter from original config: {key}")
            elif new_config[key] != value:
                logger.warning(f"Keeping reference value for '{key}' ({new_config[key]!r}) over original {value!r}")
    else:
        logger.info("No reference config available, using basic transformation")
        new_config = config.copy()
        
        # Change model_type from mistral3 to mistral
        if new_config.get("model_type") == "mistral3":
            new_config["model_type"] = "mistral"
            logger.info("Changed model_type from 'mistral3' to 'mistral'")
        
        # Update architectures to use MistralForCausalLM
        if "architectures" in new_config:
            new_config["architectures"] = ["MistralForCausalLM"]
            logger.info("Changed architecture to 'MistralForCausalLM'")
        
        # Remove vision_config
        if "vision_config" in new_config:
            del new_config["vision_config"]
            logger.info("Removed vision_config")
        
        # Remove multimodal-related parameters
        multimodal_params = [
            "image_token_index",
            "multimodal_projector_bias",
            "projector_hidden_act",
            "spatial_merge_size",
            "vision_tower_layer_list",
            "vision_feature_layer"
        ]
        
        for param in multimodal_params:
            if param in new_config:
                del new_config[param]
                logger.info(f"Removed multimodal parameter: {param}")
        
        if "text_config" in new_config:
            text_config = new_config.pop("text_config")
            for key, value in text_config.items():
                if key != "model_type":  # Don't overwrite the model_type
                    new_config[key] = value
            logger.info("Moved text_config parameters to top level")
        
        if "bos_token_id" not in new_config:
            new_config["bos_token_id"] = 1
            logger.info("Added bos_token_id: 1")
        
        if "eos_token_id" not in new_config:
            new_config["eos_token_id"] = 2
            logger.info("Added eos_token_id: 2")
        
        if "tie_word_embeddings" not in new_config:
            new_config["tie_word_embeddings"] = False
            logger.info("Added tie_word_embeddings: false")
        
        new_config["transformers_version"] = "4.51.3"
        logger.info("Updated transformers_version to 4.51.3")
    
    output_config_path = Path(output_path) / "config.json"
    with open(output_config_path, "w") as f:
        json.dump(new_config, f, indent=2)
    
    logger.info(f"Saved transformed config to {output_config_path}")
    return new_config
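
# Sketch of the intended config change (abridged; illustrative keys and
# values, not copied verbatim from the actual checkpoints):
#   before: {"model_type": "mistral3", "vision_config": {...},
#            "text_config": {"model_type": "mistral", "hidden_size": 5120, ...}}
#   after:  {"model_type": "mistral", "architectures": ["MistralForCausalLM"],
#            "hidden_size": 5120, ...}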

def is_vision_weight(weight_name):
    """Check if a weight is related to vision functionality"""
    vision_patterns = ["vision_tower", "multi_modal_projector"]
    return any(pattern in weight_name for pattern in vision_patterns)
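
# For example (illustrative weight names, not read from the checkpoint):
#   is_vision_weight("vision_tower.transformer.layers.0.attention.q_proj.weight")  -> True
#   is_vision_weight("multi_modal_projector.linear_1.weight")                      -> True
#   is_vision_weight("language_model.model.layers.0.self_attn.q_proj.weight")      -> False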

def transform_weights(model_path, output_path, safetensors_index_path, reference_weight_map=None):
    """
    Transform model weights by:
    1. Loading the weight map from safetensors index
    2. Filtering out vision-related weights
    3. Removing the "language_model." prefix from weight names
    4. Ensuring the exact same partitioning as Devstral
    5. Saving the filtered weights to the output path
    """
    logger.info(f"Transforming weights using index at {safetensors_index_path}")
    
    with open(safetensors_index_path, "r") as f:
        index_data = json.load(f)
    
    original_weight_map = index_data.get("weight_map", {})
    
    # Count vision and non-vision weights
    vision_weights = [name for name in original_weight_map if is_vision_weight(name)]
    non_vision_weights = [name for name in original_weight_map if not is_vision_weight(name)]
    
    logger.info(f"Found {len(vision_weights)} vision-related weights to remove")
    logger.info(f"Found {len(non_vision_weights)} non-vision weights to keep")
    
    # Create a mapping from original weight names to Devstral-style weight names
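    # e.g. "language_model.model.layers.0.self_attn.q_proj.weight"
    #   -> "model.layers.0.self_attn.q_proj.weight"  (illustrative layer index)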
    weight_name_mapping = {}
    for original_name in non_vision_weights:
        if original_name.startswith("language_model."):
            new_name = original_name[len("language_model."):]
            weight_name_mapping[original_name] = new_name
        else:
            weight_name_mapping[original_name] = original_name
    
    logger.info(f"Created mapping for {len(weight_name_mapping)} weight names")
    
    new_weight_map = {}
    
    if reference_weight_map and "weight_map" in reference_weight_map:
        devstral_weight_map = reference_weight_map["weight_map"]
        logger.info(f"Using Devstral reference weight map with {len(devstral_weight_map)} entries")
        
        for original_name, new_name in weight_name_mapping.items():
            if new_name in devstral_weight_map:
                new_weight_map[new_name] = devstral_weight_map[new_name]
            else:
                logger.warning(f"Weight {new_name} not found in Devstral reference map")
    else:
        logger.warning("No Devstral reference map available, using original partitioning")
        for original_name, new_name in weight_name_mapping.items():
            new_weight_map[new_name] = original_weight_map[original_name]
    
    # Build a reverse map (new name -> original name) once, then group the
    # weights by their target safetensor file for the actual transformation
    reverse_name_mapping = {new: orig for orig, new in weight_name_mapping.items()}
    file_to_weights = {}
    for new_name, file_name in new_weight_map.items():
        original_name = reverse_name_mapping.get(new_name)
        if original_name is not None:
            file_to_weights.setdefault(file_name, []).append((original_name, new_name))
    
    os.makedirs(Path(output_path), exist_ok=True)
    
    # Process each output safetensor file
    for file_name, weight_pairs in file_to_weights.items():
        logger.info(f"Processing {file_name} with {len(weight_pairs)} weights")
        
        tensors_to_save = {}
        
        # Group the weights by their source shard so that each input file is
        # loaded from disk only once per output file, rather than once per
        # tensor (each shard is several gigabytes)
        pairs_by_source = {}
        for original_name, new_name in weight_pairs:
            original_file = original_weight_map.get(original_name)
            if not original_file:
                logger.warning(f"Original file not found for weight {original_name}")
                continue
            pairs_by_source.setdefault(original_file, []).append((original_name, new_name))
        
        for original_file, source_pairs in pairs_by_source.items():
            input_file_path = Path(model_path) / original_file
            if not input_file_path.exists():
                logger.warning(f"File {input_file_path} does not exist, skipping")
                continue
            
            try:
                original_tensors = load_file(input_file_path)
            except Exception as e:
                logger.error(f"Error loading {original_file}: {e}")
                continue
            
            for original_name, new_name in source_pairs:
                if original_name in original_tensors:
                    tensors_to_save[new_name] = original_tensors[original_name]
                else:
                    logger.warning(f"Weight {original_name} not found in {original_file}")
        
        if tensors_to_save:
            output_file_path = Path(output_path) / file_name
            try:
                save_file(tensors_to_save, output_file_path)
                logger.info(f"Saved {len(tensors_to_save)} weights to {file_name}")
            except Exception as e:
                logger.error(f"Error saving {file_name}: {e}")
    
    # Save the new safetensors index
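    # The rebuilt index follows the standard sharded-safetensors layout, e.g.
    # (illustrative entries):
    #   {"metadata": {"total_size": ...},
    #    "weight_map": {"model.embed_tokens.weight": "model-00001-of-....safetensors", ...}}
    # Note: total_size is copied from the Devstral reference when available, on
    # the assumption that the text-only weights match its layout exactly.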
    new_index = {
        "metadata": {"total_size": reference_weight_map.get("metadata", {}).get("total_size", 0)} 
                  if reference_weight_map else index_data.get("metadata", {}),
        "weight_map": new_weight_map
    }
    
    output_index_path = Path(output_path) / "model.safetensors.index.json"
    with open(output_index_path, "w") as f:
        json.dump(new_index, f, indent=2)
    
    logger.info(f"Saved transformed safetensors index to {output_index_path}")

def copy_additional_files(model_path, output_path):
    """Copy additional model files like tokenizer, generation config, etc."""
    additional_files = [
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "generation_config.json"
    ]
    
    for filename in additional_files:
        src_path = Path(model_path) / filename
        if src_path.exists():
            dst_path = Path(output_path) / filename
            shutil.copy(src_path, dst_path)
            logger.info(f"Copied {filename} to output directory")
        else:
            logger.warning(f"File {filename} not found in model directory")

def download_minimal_files(repo_id, output_dir, cache_dir=None):
    """Download only the config/tokenizer files without the full model weights.

    Note: not called by main(), and the ``output_dir`` parameter is currently
    unused; kept as an optional standalone helper.
    """
    logger.info(f"Downloading minimal files from {repo_id}")
    
    # List of files to download
    files_to_download = [
        "config.json",
        "model.safetensors.index.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "generation_config.json"
    ]
    
    downloaded_files = {}
    
    for filename in files_to_download:
        try:
            file_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                cache_dir=cache_dir,
                local_files_only=False
            )
            downloaded_files[filename] = file_path
            logger.info(f"Downloaded {filename} to {file_path}")
        except Exception as e:
            logger.warning(f"Failed to download {filename}: {e}")
    
    return downloaded_files

def download_reference_weight_map(reference_model, cache_dir=None):
    """Download reference model's weight map to use as a reference"""
    logger.info(f"Downloading reference weight map from {reference_model}")
    
    try:
        file_path = hf_hub_download(
            repo_id=reference_model,
            filename="model.safetensors.index.json",
            cache_dir=cache_dir,
            local_files_only=False
        )
        
        with open(file_path, "r") as f:
            reference_map = json.load(f)
        
        logger.info(f"Successfully loaded reference weight map with {len(reference_map.get('weight_map', {}))} weights")
        return reference_map
    except Exception as e:
        logger.error(f"Failed to download reference weight map: {e}")
        return None

def download_reference_config(reference_model, cache_dir=None):
    """Download reference model's config.json to use as a reference"""
    logger.info(f"Downloading reference config from {reference_model}")
    
    try:
        file_path = hf_hub_download(
            repo_id=reference_model,
            filename="config.json",
            cache_dir=cache_dir,
            local_files_only=False
        )
        
        with open(file_path, "r") as f:
            reference_config = json.load(f)
        
        logger.info(f"Successfully loaded reference config")
        return reference_config
    except Exception as e:
        logger.error(f"Failed to download reference config: {e}")
        return None

def verify_model(output_path):
    """Verify that the transformed model can be loaded without errors"""
    logger.info(f"Verifying transformed model at {output_path}")
    try:
        config = AutoConfig.from_pretrained(output_path)
        logger.info(f"Successfully loaded config with model_type={config.model_type}")
        
        # Instantiate the model architecture (without pretrained weights) to
        # verify the configuration is valid. Note: from_config still allocates
        # randomly initialized parameters, so this step needs enough RAM to
        # hold the full model.
        AutoModelForCausalLM.from_config(config)
        logger.info("Successfully instantiated model architecture from config")
        
        return True
    except Exception as e:
        logger.error(f"Error verifying model: {e}")
        return False

def main():
    args = parse_args()
    
    input_model = args.input_model
    output_path = args.output_path
    cache_dir = args.cache_dir
    reference_model = args.reference_model
    
    # Download reference weight map and config
    reference_weight_map = download_reference_weight_map(reference_model, cache_dir)
    if not reference_weight_map:
        logger.warning("Could not download reference weight map. The weight partitioning may not match exactly.")
    
    reference_config = download_reference_config(reference_model, cache_dir)
    if not reference_config:
        logger.warning("Could not download reference config. The config may not match exactly.")
    
    # Create output directory
    os.makedirs(output_path, exist_ok=True)

    # Download the full model
    if not os.path.isdir(input_model):
        logger.info(f"Downloading model from {input_model}")
        try:
            model_path = snapshot_download(
                repo_id=input_model,
                cache_dir=cache_dir,
                local_files_only=False,
                ignore_patterns=["*consolidated*"]
            )
        except Exception as e:
            logger.error(f"Error downloading model: {e}")
            return
    else:
        model_path = input_model
    
    logger.info(f"Model path: {model_path}")
    
    # Transform config
    config_path = os.path.join(model_path, "config.json")
    transform_config(config_path, output_path, reference_config)
    
    # Transform weights
    safetensors_index_path = os.path.join(model_path, "model.safetensors.index.json")
    transform_weights(
        model_path, 
        output_path, 
        safetensors_index_path,
        reference_weight_map=reference_weight_map
    )
    
    # Copy additional files
    copy_additional_files(model_path, output_path)
    
    # Verify the transformed model
    success = verify_model(output_path)
    
    if success:
        logger.info(f"Successfully transformed model to {output_path}")
    else:
        logger.error(f"Failed to transform model properly")

if __name__ == "__main__":
    main()
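
# Example follow-up (a sketch, assuming the conversion succeeded; the local
# path is whatever was passed via --output-path):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("./mistral-small-text-only")
#   model = AutoModelForCausalLM.from_pretrained("./mistral-small-text-only")
#   print(model.config.model_type)  # expected: "mistral"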