# Ministral-3-14B-Instruct-2512-NVFP4 / convert_ministral_hf_to_mistral.py
# Uploaded by ChibuUkachi using huggingface_hub (commit 7d439c2, verified).
# coding=utf-8
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import json
import os
import re
import torch
from safetensors.torch import save_file
from safetensors.torch import safe_open
from huggingface_hub import snapshot_download
from transformers import Mistral3Config, Mistral3ForConditionalGeneration
# Regex pattern -> replacement pairs that rename HF Mistral3 checkpoint keys to
# Mistral-native names. map_hf_key_to_mistral applies only the FIRST pattern that
# matches (dict insertion order), so the order of entries is significant.
# Literal dots in patterns are escaped so "." matches only a dot.
# fmt: off
STATE_DICT_MAPPING = {
    # Text decoder
    r"^language_model\.lm_head": r"output",
    r"^language_model\.model\.norm": r"norm",
    r"^language_model\.model\.embed_tokens": r"tok_embeddings",
    r"^language_model\.model\.layers\.(\d+)\.input_layernorm": r"layers.\1.attention_norm",
    r"^language_model\.model\.layers\.(\d+)\.post_attention_layernorm": r"layers.\1.ffn_norm",
    r"^language_model\.model\.layers\.(\d+)\.self_attn\.(q|k|v|o)_proj": r"layers.\1.attention.w\2",
    r"^language_model\.model\.layers\.(\d+)\.mlp\.gate_proj": r"layers.\1.feed_forward.w1",
    r"^language_model\.model\.layers\.(\d+)\.mlp\.down_proj": r"layers.\1.feed_forward.w2",
    r"^language_model\.model\.layers\.(\d+)\.mlp\.up_proj": r"layers.\1.feed_forward.w3",
    # Multimodal projector. These patterns are not anchored with "^", so re.subn
    # rewrites a match wherever it occurs in the key.
    r"multi_modal_projector\.patch_merger\.merging_layer\.weight": r"patch_merger.merging_layer.weight",
    r"multi_modal_projector\.norm\.weight": r"pre_mm_projector_norm.weight",
    r"multi_modal_projector\.linear_1\.weight": r"vision_language_adapter.w_in.weight",
    r"multi_modal_projector\.linear_2\.weight": r"vision_language_adapter.w_out.weight",
    # Vision encoder
    r"vision_tower\.ln_pre\.weight": r"vision_encoder.ln_pre.weight",
    r"vision_tower\.patch_conv\.weight": r"vision_encoder.patch_conv.weight",
    r"^vision_tower\.transformer\.layers\.(\d+)\.attention_norm": r"vision_encoder.transformer.layers.\1.attention_norm",
    r"^vision_tower\.transformer\.layers\.(\d+)\.ffn_norm": r"vision_encoder.transformer.layers.\1.ffn_norm",
    r"^vision_tower\.transformer\.layers\.(\d+)\.attention\.(q|k|v|o)_proj": r"vision_encoder.transformer.layers.\1.attention.w\2",
    r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.gate_proj": r"vision_encoder.transformer.layers.\1.feed_forward.w1",
    r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.down_proj": r"vision_encoder.transformer.layers.\1.feed_forward.w2",
    r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.up_proj": r"vision_encoder.transformer.layers.\1.feed_forward.w3",
}
# fmt: on

# Mapping applied to the module names in the quantization config's ``ignore`` list
# (see add_quantization_config): a leading "model.language_model" becomes
# "language_model.model".
IGNORE_STATE_DICT_MAPPING = {
    r"^model\.language_model": r"language_model.model",
}

# Exact HF keys to drop entirely during conversion (currently none).
SKIP_KEYS = []
def add_quantization_config(config, hf_config: "Mistral3ForConditionalGeneration"):
    """Attach the HF model's compressed-tensors quantization config to a Mistral params dict.

    Each module name in the quantizer's ``ignore`` list is passed through
    ``map_hf_key_to_mistral`` with ``IGNORE_STATE_DICT_MAPPING``. The config is then
    serialized with ``to_dict()``, and ``scale_dtype`` / ``zp_dtype`` are removed from
    the ``weights`` and ``input_activations`` sections of every quantization group.

    Args:
        config: Contents of a Mistral ``params.json`` (a dict). It is modified in place.
        hf_config: The loaded HF model. Despite the parameter name, this is a model
            instance, and it must expose ``hf_quantizer``.

    Returns:
        The same ``config`` dict, with a ``"quantization_config"`` entry added.

    Note:
        The ``ignore`` list on the model's own quantization config object is
        overwritten with the remapped names as a side effect.
    """
    quantization_config = hf_config.hf_quantizer.quantization_config
    inner_config = quantization_config.quantization_config
    inner_config.ignore = [
        map_hf_key_to_mistral(hf_key, state_dict_mapping=IGNORE_STATE_DICT_MAPPING)
        for hf_key in inner_config.ignore
    ]
    quant_config_dict = quantization_config.to_dict()

    # Drop dtype metadata from every group, not just "group_0". Absent fields and
    # sections that are None (e.g. no input-activation quantization) are tolerated.
    # NOTE(review): presumably these fields are not accepted by the Mistral-format
    # consumer; confirm.
    for group in (quant_config_dict.get("config_groups") or {}).values():
        for section in ("input_activations", "weights"):
            section_args = group.get(section)
            if section_args is not None:
                section_args.pop("scale_dtype", None)
                section_args.pop("zp_dtype", None)

    config["quantization_config"] = quant_config_dict
    return config
def map_hf_key_to_mistral(hf_key, state_dict_mapping=STATE_DICT_MAPPING):
    """Translate one HF checkpoint key into its Mistral-format name.

    Patterns in ``state_dict_mapping`` are tried in insertion order, and only the
    first one that matches is applied. If no pattern matches, the key is kept.
    In either case, every occurrence of "weight_scale" in the result is then
    renamed to "qscale_weight".
    """
    renamed = hf_key
    for pattern, replacement in state_dict_mapping.items():
        candidate, hits = re.subn(pattern, replacement, hf_key)
        if hits:
            renamed = candidate
            break
    return renamed.replace("weight_scale", "qscale_weight")
def permute_for_mistral_rope(tensor: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor:
    """Undo the HF rotary-embedding row permutation of a q/k projection tensor.

    ``tensor`` is viewed as ``(n_heads, 2, dim1 // n_heads // 2, dim2)``. Within each
    head, rows ordered ``[r0 .. r_{h-1}, r_h .. r_{2h-1}]`` are reordered to
    ``[r0, r_h, r1, r_{h+1}, ...]``. This is the inverse of
    ``view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2)``.

    Args:
        tensor: Tensor with ``dim1 * dim2`` elements.
        n_heads: Number of attention heads the rows are split into.
        dim1: Total number of rows (heads * head_dim).
        dim2: Number of columns (use 1 for a column of per-row values).

    Returns:
        A ``(dim1, dim2)`` tensor with rows reordered as described above.
    """
    per_head = tensor.view(n_heads, 2, dim1 // n_heads // 2, dim2)
    return per_head.transpose(1, 2).reshape(dim1, dim2)
def convert_state_dict(hf_state_dict, config):
    """Rename and re-layout an HF Ministral (Mistral3) state dict into Mistral format.

    Every kept tensor is renamed with ``map_hf_key_to_mistral``. For keys containing
    "language_model" (text decoder) or "vision_tower" (vision encoder), q/k projection
    weights, and ``weight_scale`` tensors that satisfy the size check below, are passed
    through ``permute_for_mistral_rope``.

    Args:
        hf_state_dict: Mapping of HF key -> tensor.
        config: Dict with ``text_config`` and ``vision_config`` sub-dicts (for example
            the output of a ``Mistral3Config.to_dict()``).

    Returns:
        A new dict mapping Mistral key -> tensor. Keys listed in ``SKIP_KEYS`` are omitted.
    """

    def rope_fix(key, tensor, q_heads, kv_heads, q_dim, kv_dim, hidden):
        # Returns ``tensor`` unchanged unless ``key`` is a q/k weight or a qualifying scale.
        if key.endswith("q_proj.weight"):
            return permute_for_mistral_rope(tensor, q_heads, q_dim, hidden)
        # NOTE(review): scales are permuted only when size(0) equals the head count, and
        # the view then needs exactly q_dim (or kv_dim) elements. A per-output-channel
        # scale (size(0) == q_dim) is left untouched. Keys with other suffixes (e.g. a
        # packed-weight key, if the checkpoint uses one; TODO confirm) match no branch.
        if key.endswith("q_proj.weight_scale") and tensor.size(0) == q_heads:
            return permute_for_mistral_rope(tensor, q_heads, q_dim, 1)
        if key.endswith("k_proj.weight"):
            return permute_for_mistral_rope(tensor, kv_heads, kv_dim, hidden)
        if key.endswith("k_proj.weight_scale") and tensor.size(0) == kv_heads:
            return permute_for_mistral_rope(tensor, kv_heads, kv_dim, 1)
        return tensor

    text_cfg = config["text_config"]
    vision_cfg = config["vision_config"]

    t_heads = text_cfg["num_attention_heads"]
    t_hidden = text_cfg["hidden_size"]
    t_head_dim = text_cfg["head_dim"]
    t_kv_heads = text_cfg["num_key_value_heads"]
    text_rope = (t_heads, t_kv_heads, t_head_dim * t_heads, t_head_dim * t_kv_heads, t_hidden)

    v_heads = vision_cfg["num_attention_heads"]
    v_hidden = vision_cfg["hidden_size"]
    v_head_dim = vision_cfg["head_dim"]
    # Vision k is permuted with the same head count as q.
    vision_rope = (v_heads, v_heads, v_head_dim * v_heads, v_head_dim * v_heads, v_hidden)

    converted = {}
    for hf_key, tensor in hf_state_dict.items():
        if hf_key in SKIP_KEYS:
            continue
        mistral_key = map_hf_key_to_mistral(hf_key)
        # Two independent checks (not elif): a key containing both substrings gets both.
        if "language_model" in hf_key:
            tensor = rope_fix(hf_key, tensor, *text_rope)
        if "vision_tower" in hf_key:
            tensor = rope_fix(hf_key, tensor, *vision_rope)
        converted[mistral_key] = tensor
    return converted
def _resolve_local_path(path_or_repo):
    """Return ``path_or_repo`` if it exists locally; otherwise download the Hub snapshot and return its path."""
    if os.path.exists(path_or_repo):
        return path_or_repo
    return snapshot_download(path_or_repo)


def _load_safetensors(directory, device):
    """Load every tensor from the top-level ``*.safetensors`` files in ``directory``, in sorted file order."""
    state_dict = {}
    tensor_files = sorted(name for name in os.listdir(directory) if name.endswith(".safetensors"))
    for file_name in tensor_files:
        with safe_open(os.path.join(directory, file_name), framework="pt", device=device) as handle:
            for key in handle.keys():
                state_dict[key] = handle.get_tensor(key)
    return state_dict


def write_model(
    input_path_or_repo,
    output_dir,
    unquantized_model_path=None,
    device="cuda",
):
    """Convert a quantized HF Ministral checkpoint into Mistral-native files.

    Writes two files to ``output_dir``:
    ``params.json``, which is the unquantized model's ``params.json`` plus a
    ``quantization_config`` entry, and ``consolidated.safetensors``, which holds all
    tensors renamed and permuted by ``convert_state_dict``.

    Args:
        input_path_or_repo: Local directory or Hub repo id of the quantized HF checkpoint.
        output_dir: Destination directory. It is created if missing.
        unquantized_model_path: Local directory or Hub repo id that provides ``params.json``.
            Required.
        device: Device passed to safetensors ``safe_open`` when reading tensors.

    Raises:
        ValueError: If ``unquantized_model_path`` is None. This is checked before any
            download, model load, or directory creation.
    """
    if unquantized_model_path is None:
        raise ValueError(
            "unquantized_model_path is required: params.json is read from the unquantized model."
        )

    print("Converting HF Ministral model to Mistral format.")
    os.makedirs(output_dir, exist_ok=True)

    print(f"Loading HF Ministral model from {input_path_or_repo}...")
    # The full model is loaded because add_quantization_config reads its hf_quantizer.
    # Its config also supplies the dimensions used for the RoPE permutation.
    hf_model = Mistral3ForConditionalGeneration.from_pretrained(input_path_or_repo)
    local_path = _resolve_local_path(input_path_or_repo)

    # Convert config: start from the unquantized params.json and add the quantization block.
    unquantized_path = _resolve_local_path(unquantized_model_path)
    with open(os.path.join(unquantized_path, "params.json"), "r") as f:
        config = json.load(f)
    config = add_quantization_config(config, hf_model)
    with open(os.path.join(output_dir, "params.json"), "w") as f:
        json.dump(config, f, indent=2)

    # Use the loaded checkpoint's own dimensions rather than Mistral3Config() defaults,
    # which only coincide with some model sizes.
    mistral_config = hf_model.config.to_dict()
    # The model is no longer needed. Free it before loading the raw tensors.
    del hf_model
    gc.collect()

    print("Converting state dict...")
    hf_state_dict = _load_safetensors(local_path, device)
    mistral_state_dict = convert_state_dict(hf_state_dict, mistral_config)
    save_file(mistral_state_dict, os.path.join(output_dir, "consolidated.safetensors"))
    del hf_state_dict, mistral_state_dict
    gc.collect()
    print("Model converted successfully.")
def write_tokenizer(input_path_or_repo: str, output_dir: str):
    """Load the Mistral tokenizer from ``input_path_or_repo`` and save it into ``output_dir``."""
    from transformers import MistralCommonBackend

    print("Extracting tokenizer...")
    MistralCommonBackend.from_pretrained(input_path_or_repo).save_pretrained(output_dir)
    print("Tokenizer saved successfully.")
def main():
    """Command-line entry point.

    Converts the model, then converts the tokenizer unless ``--skip_tokenizer`` is given.
    """
    cli = argparse.ArgumentParser(description="Convert HF Ministral weights to Mistral format")
    cli.add_argument(
        "--input_path_or_repo",
        type=str,
        default="Ministral-3-14B-Instruct-2512-QUANTIZED",
        help="Path or repo containing HF Ministral model",
    )
    cli.add_argument(
        "--output_dir",
        type=str,
        default="Ministral-3-14B-Instruct-2512-QUANTIZED-CONVERTED",
        help="Location to write Mistral model and tokenizer",
    )
    cli.add_argument("--skip_tokenizer", action="store_true", help="Skip tokenizer conversion")
    cli.add_argument(
        "--unquantized_model_path",
        type=str,
        default="mistralai/Ministral-3-14B-Instruct-2512-BF16",
        help="Path to the unquantized model",
    )
    opts = cli.parse_args()

    write_model(
        opts.input_path_or_repo,
        opts.output_dir,
        unquantized_model_path=opts.unquantized_model_path,
    )
    if opts.skip_tokenizer:
        return
    write_tokenizer(opts.input_path_or_repo, opts.output_dir)


if __name__ == "__main__":
    main()