# coding=utf-8
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert a quantized HF Ministral (``Mistral3ForConditionalGeneration``) checkpoint to Mistral format.

The output directory receives a Mistral ``params.json`` (the unquantized model's ``params.json``
extended with the quantization config), a single ``consolidated.safetensors`` and, unless
``--skip_tokenizer`` is given, the tokenizer.
"""

import argparse
import gc
import json
import os
import re

import torch
from huggingface_hub import snapshot_download
from safetensors.torch import safe_open, save_file

from transformers import Mistral3Config, Mistral3ForConditionalGeneration


# Regex on the HF checkpoint key -> replacement giving the Mistral key. Patterns are tried in
# order and the first one that matches is applied (see ``map_hf_key_to_mistral``).
# fmt: off
STATE_DICT_MAPPING = {
    r"^language_model\.lm_head": r"output",
    r"^language_model\.model\.norm": r"norm",
    r"^language_model\.model\.embed_tokens": r"tok_embeddings",
    r"^language_model\.model\.layers\.(\d+)\.input_layernorm": r"layers.\1.attention_norm",
    r"^language_model\.model\.layers\.(\d+)\.post_attention_layernorm": r"layers.\1.ffn_norm",
    r"^language_model\.model\.layers\.(\d+)\.self_attn\.(q|k|v|o)_proj": r"layers.\1.attention.w\2",
    r"^language_model\.model\.layers\.(\d+)\.mlp\.gate_proj": r"layers.\1.feed_forward.w1",
    r"^language_model\.model\.layers\.(\d+)\.mlp\.down_proj": r"layers.\1.feed_forward.w2",
    r"^language_model\.model\.layers\.(\d+)\.mlp\.up_proj": r"layers.\1.feed_forward.w3",
    r"multi_modal_projector.patch_merger.merging_layer.weight": r"patch_merger.merging_layer.weight",
    r"multi_modal_projector.norm.weight": r"pre_mm_projector_norm.weight",
    r"multi_modal_projector.linear_1.weight": r"vision_language_adapter.w_in.weight",
    r"multi_modal_projector.linear_2.weight": r"vision_language_adapter.w_out.weight",
    r"vision_tower.ln_pre.weight": r"vision_encoder.ln_pre.weight",
    r"vision_tower.patch_conv.weight": r"vision_encoder.patch_conv.weight",
    r"^vision_tower\.transformer\.layers\.(\d+)\.attention_norm": r"vision_encoder.transformer.layers.\1.attention_norm",
    r"^vision_tower\.transformer\.layers\.(\d+)\.ffn_norm": r"vision_encoder.transformer.layers.\1.ffn_norm",
    r"^vision_tower\.transformer\.layers\.(\d+)\.attention\.(q|k|v|o)_proj": r"vision_encoder.transformer.layers.\1.attention.w\2",
    r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.gate_proj": r"vision_encoder.transformer.layers.\1.feed_forward.w1",
    r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.down_proj": r"vision_encoder.transformer.layers.\1.feed_forward.w2",
    r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.up_proj": r"vision_encoder.transformer.layers.\1.feed_forward.w3",
}
# fmt: on

# Applied to the module names in the quantization config's ``ignore`` list
# (see ``add_quantization_config``).
IGNORE_STATE_DICT_MAPPING = {
    r"^model\.language_model": r"language_model.model",
}

# HF checkpoint keys that are dropped entirely during conversion.
SKIP_KEYS = [
]


def add_quantization_config(config, hf_config: Mistral3ForConditionalGeneration):
    """Attach the checkpoint's quantization config to a Mistral ``params.json`` dict.

    Args:
        config: Parsed ``params.json`` of the unquantized Mistral model; modified in place.
        hf_config: The loaded quantized HF *model* (not a config object); its
            ``hf_quantizer.quantization_config`` is read, and the ``ignore`` list of the inner
            ``quantization_config`` is rewritten in place.

    Returns:
        ``config`` with a ``"quantization_config"`` entry added.
    """
    quantization_config = hf_config.hf_quantizer.quantization_config

    # Rename the modules excluded from quantization with IGNORE_STATE_DICT_MAPPING.
    quantization_config.quantization_config.ignore = [
        map_hf_key_to_mistral(hf_key, state_dict_mapping=IGNORE_STATE_DICT_MAPPING)
        for hf_key in quantization_config.quantization_config.ignore
    ]

    quant_config_dict = quantization_config.to_dict()
    # Strip the scale / zero-point dtype fields (presumably not understood by the Mistral-side
    # loader -- confirm). Do it for every config group, and tolerate schemes where the fields, or
    # the activation scheme itself (weight-only quantization), are absent.
    for group in (quant_config_dict.get("config_groups") or {}).values():
        for scheme_name in ("input_activations", "weights"):
            scheme = group.get(scheme_name)
            if scheme is None:
                continue
            for field_name in ("scale_dtype", "zp_dtype"):
                scheme.pop(field_name, None)

    config["quantization_config"] = quant_config_dict
    return config


def map_hf_key_to_mistral(hf_key, state_dict_mapping=STATE_DICT_MAPPING):
    """Map a key from HF format to Mistral format.

    Args:
        hf_key: Key (or module name) in HF naming.
        state_dict_mapping: Ordered ``{regex: replacement}`` mapping; the first pattern that
            matches ``hf_key`` is applied and the rest are not tried.

    Returns:
        The rewritten key (``hf_key`` itself if no pattern matches), with every occurrence of
        ``weight_scale`` renamed to ``qscale_weight``.
    """
    for pattern, replacement in state_dict_mapping.items():
        new_key, n_replace = re.subn(pattern, replacement, hf_key)
        if n_replace > 0:
            break
    else:
        # If no mapping found, keep the original key.
        new_key = hf_key
    return new_key.replace("weight_scale", "qscale_weight")


def permute_for_mistral_rope(tensor, n_heads, dim1, dim2):
    """Reverse the HF RoPE row permutation to get back to Mistral format.

    Views ``tensor`` as ``(n_heads, 2, dim1 // n_heads // 2, dim2)`` and swaps the two middle axes.
    Within each head, row ``i`` of the first half of the rows moves to row ``2 * i`` and row ``i``
    of the second half moves to row ``2 * i + 1``.

    Args:
        tensor: View-compatible (e.g. contiguous) tensor with ``dim1 * dim2`` elements.
        n_heads: Number of attention heads stacked along the rows.
        dim1: Total number of rows (``n_heads * head_dim``).
        dim2: Number of columns.

    Returns:
        The permuted tensor, shaped ``(dim1, dim2)``.
    """
    tensor = tensor.view(n_heads, 2, dim1 // n_heads // 2, dim2)
    tensor = tensor.transpose(1, 2)
    return tensor.reshape(dim1, dim2)


def _permute_rope_scale(scale, n_heads, out_dim):
    """Re-order a q/k ``weight_scale`` the same way ``permute_for_mistral_rope`` re-orders its weight.

    Only scales with one row per output channel follow the weight's rows. These have leading
    dimension ``out_dim``, e.g. per-channel ``(out_dim, 1)`` or per-group ``(out_dim, n_groups)``
    scales. Any other layout (0-d or per-tensor, block-wise, ...) is returned unchanged.
    NOTE(review): leaving block-wise scales untouched assumes each scale block covers whole
    heads -- confirm for the block size in use.
    """
    if scale.ndim == 0 or scale.size(0) != out_dim:
        return scale
    rows = scale.reshape(out_dim, -1)
    permuted = permute_for_mistral_rope(rows, n_heads, out_dim, rows.size(1))
    return permuted.reshape(scale.shape)


def convert_state_dict(hf_state_dict, config):
    """Convert HF Ministral state dict to Mistral format.

    Every key is renamed with ``map_hf_key_to_mistral``; keys listed in ``SKIP_KEYS`` are dropped.
    The q/k projection weights of both the language model and the vision tower, and the matching
    per-row ``weight_scale`` tensors, are re-permuted for Mistral's RoPE layout.

    Args:
        hf_state_dict: Mapping of HF checkpoint key -> tensor.
        config: The HF model config as a dict (e.g. ``model.config.to_dict()``). Its
            ``text_config`` / ``vision_config`` entries must describe the checkpoint being
            converted, because they determine the permutation shapes.

    Returns:
        A new dict mapping Mistral key -> tensor.
    """
    text_config = config["text_config"]
    vision_config = config["vision_config"]
    # One entry per tower: (substring marking its keys, q heads, k heads, head_dim, hidden_size).
    # The vision tower's k head count is taken to equal its q head count.
    towers = (
        (
            "language_model",
            text_config["num_attention_heads"],
            text_config["num_key_value_heads"],
            text_config["head_dim"],
            text_config["hidden_size"],
        ),
        (
            "vision_tower",
            vision_config["num_attention_heads"],
            vision_config["num_attention_heads"],
            vision_config["head_dim"],
            vision_config["hidden_size"],
        ),
    )

    mistral_dict = {}
    for hf_key, tensor in hf_state_dict.items():
        if hf_key in SKIP_KEYS:
            continue
        mistral_key = map_hf_key_to_mistral(hf_key)

        for marker, n_q_heads, n_kv_heads, head_dim, hidden_size in towers:
            if marker not in hf_key:
                continue
            for proj, n_heads in (("q_proj", n_q_heads), ("k_proj", n_kv_heads)):
                out_dim = n_heads * head_dim
                if hf_key.endswith(f"{proj}.weight"):
                    tensor = permute_for_mistral_rope(tensor, n_heads, out_dim, hidden_size)
                elif hf_key.endswith(f"{proj}.weight_scale"):
                    # Only re-order scales whose dense weight is in the checkpoint (and is
                    # therefore permuted by the branch above); otherwise the rows of the weight
                    # and its scales would no longer line up.
                    weight_key = hf_key[: -len("_scale")]
                    if weight_key in hf_state_dict:
                        tensor = _permute_rope_scale(tensor, n_heads, out_dim)

        mistral_dict[mistral_key] = tensor
    return mistral_dict


def _resolve_local_dir(path_or_repo, allow_patterns=None):
    """Return ``path_or_repo`` if it exists locally; otherwise download that Hub repo.

    Only files matching ``allow_patterns`` are fetched (all files when it is None). The return
    value is the local snapshot directory.
    """
    if os.path.exists(path_or_repo):
        return path_or_repo
    return snapshot_download(path_or_repo, allow_patterns=allow_patterns)


def _load_safetensors(directory, device):
    """Load every tensor from the ``*.safetensors`` files in ``directory`` onto ``device``.

    Files are read in sorted name order.
    """
    state_dict = {}
    tensor_files = sorted(name for name in os.listdir(directory) if name.endswith(".safetensors"))
    for file_name in tensor_files:
        with safe_open(os.path.join(directory, file_name), framework="pt", device=device) as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key)
    return state_dict


def write_model(
    input_path_or_repo,
    output_dir,
    unquantized_model_path=None,
    device=None,
):
    """Convert a quantized HF Ministral checkpoint to Mistral format under ``output_dir``.

    Writes ``params.json`` (the unquantized model's ``params.json`` plus the quantization config)
    and ``consolidated.safetensors`` (all tensors renamed and RoPE-permuted).

    Args:
        input_path_or_repo: Local directory or Hub repo id of the quantized HF checkpoint.
        output_dir: Directory to write into; created if needed.
        unquantized_model_path: Local directory or Hub repo id holding the unquantized model's
            Mistral ``params.json``. Required.
        device: Device the raw checkpoint tensors are loaded onto for conversion. Defaults to
            ``"cuda"`` when available, else ``"cpu"``.

    Raises:
        ValueError: If ``unquantized_model_path`` is None.
        FileNotFoundError: If no ``params.json`` is found at ``unquantized_model_path``.
    """
    # Validate the cheap inputs before the expensive model load.
    if unquantized_model_path is None:
        raise ValueError(
            "unquantized_model_path is required: its params.json is the base of the converted config."
        )
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    print("Converting HF Ministral model to Mistral format.")
    os.makedirs(output_dir, exist_ok=True)

    # Only params.json is needed from the unquantized model; don't download its weights.
    unquantized_dir = _resolve_local_dir(unquantized_model_path, allow_patterns=["params.json"])
    config_path = os.path.join(unquantized_dir, "params.json")
    if not os.path.isfile(config_path):
        raise FileNotFoundError(f"Unquantized model config not found for {unquantized_model_path}")

    # Load the HF Ministral model: it provides the quantizer config and the model config.
    print(f"Loading HF Ministral model from {input_path_or_repo}...")
    hf_model = Mistral3ForConditionalGeneration.from_pretrained(input_path_or_repo)

    # Convert config
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    config = add_quantization_config(config, hf_model)
    with open(os.path.join(output_dir, "params.json"), "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    # The permutation shapes must come from this checkpoint's own config, not library defaults.
    model_config = hf_model.config.to_dict()
    # Everything needed from the instantiated model has been extracted; free it before the raw
    # checkpoint tensors are loaded so both are not held in memory at once.
    del hf_model
    gc.collect()

    # Convert state dict
    print("Converting state dict...")
    local_path = _resolve_local_dir(input_path_or_repo)
    hf_state_dict = _load_safetensors(local_path, device)
    mistral_state_dict = convert_state_dict(hf_state_dict, model_config)

    # save the state dict
    save_file(mistral_state_dict, os.path.join(output_dir, "consolidated.safetensors"))

    del hf_state_dict, mistral_state_dict
    gc.collect()
    print("Model converted successfully.")


def write_tokenizer(input_path_or_repo: str, output_dir: str):
    """Extract and save the tokenizer from Ministral model.

    Loads it with ``MistralCommonBackend.from_pretrained`` and writes it with ``save_pretrained``.
    """
    from transformers import MistralCommonBackend

    print("Extracting tokenizer...")
    tokenizer = MistralCommonBackend.from_pretrained(input_path_or_repo)
    tokenizer.save_pretrained(output_dir)
    print("Tokenizer saved successfully.")


def main():
    """Parse command-line arguments and run the model (and optionally tokenizer) conversion."""
    parser = argparse.ArgumentParser(description="Convert HF Ministral weights to Mistral format")
    parser.add_argument(
        "--input_path_or_repo",
        type=str,
        default="Ministral-3-14B-Instruct-2512-QUANTIZED",
        help="Path or repo containing HF Ministral model",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="Ministral-3-14B-Instruct-2512-QUANTIZED-CONVERTED",
        help="Location to write Mistral model and tokenizer",
    )
    parser.add_argument(
        "--skip_tokenizer", action="store_true", help="Skip tokenizer conversion"
    )
    parser.add_argument(
        "--unquantized_model_path",
        type=str,
        default="mistralai/Ministral-3-14B-Instruct-2512-BF16",
        help="Path to the unquantized model",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to load checkpoint tensors on (default: cuda if available, else cpu)",
    )
    args = parser.parse_args()

    write_model(
        args.input_path_or_repo,
        args.output_dir,
        unquantized_model_path=args.unquantized_model_path,
        device=args.device,
    )

    if not args.skip_tokenizer:
        write_tokenizer(
            args.input_path_or_repo,
            args.output_dir,
        )


if __name__ == "__main__":
    main()