# Source code for maxtext.checkpoint_conversion.utils.hf_model_configs

#  Copyright 2023–2026 Google LLC
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

"""
This config defines the architectural configurations of the Hugging Face version of a model.
"""


import transformers

# Transformers 5 exposes the base config class as `PreTrainedConfig`; older releases
# expose it as `PretrainedConfig`. Probe for the name instead of comparing
# `transformers.__version__` as a string: string comparison is lexicographic, so a
# version such as "10.0.0" would compare lower than "5.0.0" and pick the wrong branch.
try:
  from transformers.configuration_utils import PreTrainedConfig as PTConfig  # pytype: disable=import-error
except ImportError:
  from transformers.configuration_utils import PretrainedConfig as PTConfig


# Raw Hugging Face config fields for Gemma 4 26B (MoE text stack plus vision settings).
# Kept as a plain dict so it can be expanded with `**` into a config class.
gemma4_26b_dict = {
    "architectures": ["Gemma4ForConditionalGeneration"],
    "audio_config": None,
    # Special token ids.
    "audio_token_id": 258881,
    "boa_token_id": 256000,
    "boi_token_id": 255999,
    "dtype": "bfloat16",
    "eoa_token_id": 258883,
    "eoa_token_index": 258883,
    "eoi_token_id": 258882,
    "eos_token_id": [1, 106],
    "image_token_id": 258880,
    "initializer_range": 0.02,
    "model_type": "gemma4",
    # Text decoder settings.
    "text_config": {
        "attention_bias": False,
        "attention_dropout": 0.0,
        "attention_k_eq_v": True,
        "bos_token_id": 2,
        "dtype": "bfloat16",
        "enable_moe_block": True,
        "eos_token_id": 1,
        "moe_intermediate_size": 704,
        "final_logit_softcapping": 30.0,
        "global_head_dim": 512,
        "head_dim": 256,
        "hidden_activation": "gelu_pytorch_tanh",
        "hidden_size": 2816,
        "hidden_size_per_layer_input": 0,
        "initializer_range": 0.02,
        "intermediate_size": 2112,
        # Five sliding-attention layers followed by one full-attention layer, tiled 5x (30 layers).
        "layer_types": (["sliding_attention"] * 5 + ["full_attention"]) * 5,
        "max_position_embeddings": 262144,
        "model_type": "gemma4_text",
        "num_attention_heads": 16,
        "num_experts": 128,
        "num_global_key_value_heads": 2,
        "num_hidden_layers": 30,
        "num_key_value_heads": 8,
        "num_kv_shared_layers": 0,
        "pad_token_id": 0,
        "rms_norm_eps": 1e-6,
        # RoPE settings are given separately per attention layer type.
        "rope_parameters": {
            "full_attention": {
                "partial_rotary_factor": 0.25,
                "rope_theta": 1e6,
                "rope_type": "proportional",
            },
            "sliding_attention": {
                "rope_theta": 1e4,
                "rope_type": "default",
            },
        },
        "sliding_window": 1024,
        "tie_word_embeddings": True,
        "top_k_experts": 8,
        "use_bidirectional_attention": "vision",
        "use_cache": True,
        "use_double_wide_mlp": False,
        "vocab_size": 262144,
        "vocab_size_per_layer_input": 262144,
    },
    "tie_word_embeddings": True,
    "transformers_version": "5.5.0.dev0",
    "video_token_id": 258884,
    # Vision encoder settings.
    "vision_config": {
        "attention_bias": False,
        "attention_dropout": 0.0,
        "default_output_length": 280,
        "dtype": "bfloat16",
        "global_head_dim": 72,
        "head_dim": 72,
        "hidden_activation": "gelu_pytorch_tanh",
        "hidden_size": 1152,
        "intermediate_size": 4304,
        "max_position_embeddings": 131072,
        "model_type": "gemma4_vision",
        "num_attention_heads": 16,
        "num_hidden_layers": 27,
        "num_key_value_heads": 16,
        "patch_size": 16,
        "pooling_kernel_size": 3,
        "position_embedding_size": 10240,
        "rms_norm_eps": 1e-6,
        "rope_parameters": {"rope_theta": 100.0, "rope_type": "default"},
        "standardize": True,
        "use_clipped_linears": False,
    },
    "vision_soft_tokens_per_image": 280,
}


# Gemma 4 31B: starts from the 26B fields and overrides the text stack with a non-MoE,
# wider and deeper configuration. Only the top-level dict and `text_config` are new
# objects; other nested values (e.g. `vision_config`) are shared with `gemma4_26b_dict`.
gemma4_31b_dict = {
    **gemma4_26b_dict,
    "text_config": {
        **gemma4_26b_dict["text_config"],
        "enable_moe_block": False,
        "hidden_size": 5376,
        "intermediate_size": 21504,
        # Same 5 sliding + 1 full pattern as the 26B dict, tiled 10x (60 layers).
        "layer_types": (["sliding_attention"] * 5 + ["full_attention"]) * 10,
        "num_attention_heads": 32,
        "num_experts": None,
        "num_global_key_value_heads": 4,
        "num_hidden_layers": 60,
        "num_key_value_heads": 16,
        "top_k_experts": None,
    },
}


# Build both Gemma 4 configs with the native config class when `transformers` exposes it.
# If an AttributeError is raised (e.g. `transformers.Gemma4Config` does not exist), both
# configs fall back to the generic PTConfig built from the same raw dicts.
try:
  gemma4_26b_config, gemma4_31b_config = (
      transformers.Gemma4Config(**gemma4_26b_dict),
      transformers.Gemma4Config(**gemma4_31b_dict),
  )
except AttributeError:
  gemma4_26b_config, gemma4_31b_config = (
      PTConfig(**gemma4_26b_dict),  # pytype: disable=wrong-arg-types
      PTConfig(**gemma4_31b_dict),  # pytype: disable=wrong-arg-types
  )


# HF config registered under the MaxText name "gemma3-4b": a Gemma3ForConditionalGeneration
# architecture with a "gemma3_text" text_config and a "siglip_vision_model" vision_config.
gemma3_4b_config = transformers.Gemma3Config(
    architectures=["Gemma3ForConditionalGeneration"],
    boi_token_index=255999,
    eoi_token_index=256000,
    eos_token_id=[1, 106],
    image_token_index=262144,
    initializer_range=0.02,
    mm_tokens_per_image=256,
    model_type="gemma3",
    # Text decoder settings.
    text_config={
        "attention_bias": False,
        "attention_dropout": 0.0,
        "attn_logit_softcapping": None,
        "cache_implementation": "hybrid",
        "final_logit_softcapping": None,
        "head_dim": 256,
        "hidden_activation": "gelu",
        "hidden_size": 2560,
        "initializer_range": 0.02,
        "intermediate_size": 10240,
        "max_position_embeddings": 163840,
        "model_type": "gemma3_text",
        "num_attention_heads": 8,
        "num_hidden_layers": 34,
        "num_key_value_heads": 4,
        "query_pre_attn_scalar": 256,
        "rms_norm_eps": 1e-06,
        "rope_local_base_freq": 10000.0,
        "rope_scaling": {"factor": 8.0, "rope_type": "linear"},
        "rope_theta": 10000.0,
        "sliding_window": 1024,
        "sliding_window_pattern": 6,
        "use_cache": True,
        "vocab_size": 262144,
    },
    torch_dtype="bfloat16",
    # Vision encoder settings (identical across the gemma3 configs in this file).
    vision_config={
        "attention_dropout": 0.0,
        "hidden_act": "gelu_pytorch_tanh",
        "hidden_size": 1152,
        "image_size": 896,
        "intermediate_size": 4304,
        "layer_norm_eps": 1e-06,
        "model_type": "siglip_vision_model",
        "num_attention_heads": 16,
        "num_channels": 3,
        "num_hidden_layers": 27,
        "patch_size": 14,
        "vision_use_head": False,
    },
)

# HF config registered under the MaxText name "gemma3-12b": a Gemma3ForConditionalGeneration
# architecture with a "gemma3_text" text_config and a "siglip_vision_model" vision_config.
gemma3_12b_config = transformers.Gemma3Config(
    architectures=["Gemma3ForConditionalGeneration"],
    boi_token_index=255999,
    eoi_token_index=256000,
    eos_token_id=[1, 106],
    image_token_index=262144,
    initializer_range=0.02,
    mm_tokens_per_image=256,
    model_type="gemma3",
    # Text decoder settings.
    text_config={
        "attention_bias": False,
        "attention_dropout": 0.0,
        "attn_logit_softcapping": None,
        "cache_implementation": "hybrid",
        "final_logit_softcapping": None,
        "head_dim": 256,
        "hidden_activation": "gelu",
        "hidden_size": 3840,
        "initializer_range": 0.02,
        "intermediate_size": 15360,
        "max_position_embeddings": 163840,
        "model_type": "gemma3_text",
        "num_attention_heads": 16,
        "num_hidden_layers": 48,
        "num_key_value_heads": 8,
        "query_pre_attn_scalar": 256,
        "rms_norm_eps": 1e-06,
        "rope_local_base_freq": 10000.0,
        "rope_scaling": {"factor": 8.0, "rope_type": "linear"},
        "rope_theta": 10000.0,
        "sliding_window": 1024,
        "sliding_window_pattern": 6,
        "use_cache": True,
        "vocab_size": 262144,
    },
    torch_dtype="bfloat16",
    # Vision encoder settings (identical across the gemma3 configs in this file).
    vision_config={
        "attention_dropout": 0.0,
        "hidden_act": "gelu_pytorch_tanh",
        "hidden_size": 1152,
        "image_size": 896,
        "intermediate_size": 4304,
        "layer_norm_eps": 1e-06,
        "model_type": "siglip_vision_model",
        "num_attention_heads": 16,
        "num_channels": 3,
        "num_hidden_layers": 27,
        "patch_size": 14,
        "vision_use_head": False,
    },
)

# HF config registered under the MaxText name "gemma3-27b": a Gemma3ForConditionalGeneration
# architecture with a "gemma3_text" text_config and a "siglip_vision_model" vision_config.
# Unlike the 4b/12b configs in this file, head_dim is 128 and query_pre_attn_scalar is 168.
gemma3_27b_config = transformers.Gemma3Config(
    architectures=["Gemma3ForConditionalGeneration"],
    boi_token_index=255999,
    eoi_token_index=256000,
    eos_token_id=[1, 106],
    image_token_index=262144,
    initializer_range=0.02,
    mm_tokens_per_image=256,
    model_type="gemma3",
    # Text decoder settings.
    text_config={
        "attention_bias": False,
        "attention_dropout": 0.0,
        "attn_logit_softcapping": None,
        "cache_implementation": "hybrid",
        "final_logit_softcapping": None,
        "head_dim": 128,
        "hidden_activation": "gelu",
        "hidden_size": 5376,
        "initializer_range": 0.02,
        "intermediate_size": 21504,
        "max_position_embeddings": 163840,
        "model_type": "gemma3_text",
        "num_attention_heads": 32,
        "num_hidden_layers": 62,
        "num_key_value_heads": 16,
        "query_pre_attn_scalar": 168,
        "rms_norm_eps": 1e-06,
        "rope_local_base_freq": 10000.0,
        "rope_scaling": {"factor": 8.0, "rope_type": "linear"},
        "rope_theta": 10000.0,
        "sliding_window": 1024,
        "sliding_window_pattern": 6,
        "use_cache": True,
        "vocab_size": 262144,
    },
    torch_dtype="bfloat16",
    # Vision encoder settings (identical across the gemma3 configs in this file).
    vision_config={
        "attention_dropout": 0.0,
        "hidden_act": "gelu_pytorch_tanh",
        "hidden_size": 1152,
        "image_size": 896,
        "intermediate_size": 4304,
        "layer_norm_eps": 1e-06,
        "model_type": "siglip_vision_model",
        "num_attention_heads": 16,
        "num_channels": 3,
        "num_hidden_layers": 27,
        "patch_size": 14,
        "vision_use_head": False,
    },
)


# HF config registered under the MaxText name "gemma2-2b". Only the size fields below are
# overridden; every other field is left at whatever Gemma2Config defaults to.
gemma2_2b_config = transformers.Gemma2Config(
    num_hidden_layers=26,
    num_attention_heads=8,
    num_key_value_heads=4,
    hidden_size=2304,
    intermediate_size=9216,
)

# HF config registered under the MaxText name "gemma2-9b". Sets the size fields plus
# explicit soft-capping, head_dim, sliding-window and query scaling values.
gemma2_9b_config = transformers.Gemma2Config(
    num_hidden_layers=42,
    num_attention_heads=16,
    num_key_value_heads=8,
    hidden_size=3584,
    intermediate_size=14336,
    final_logit_softcapping=30.0,
    attn_logit_softcapping=50.0,
    head_dim=256,
    sliding_window=4096,
    query_pre_attn_scalar=224,
)

# HF config registered under the MaxText name "gemma2-27b". Sets the size fields plus
# explicit soft-capping, head_dim, sliding-window and query scaling values.
gemma2_27b_config = transformers.Gemma2Config(
    num_hidden_layers=46,
    num_attention_heads=32,
    num_key_value_heads=16,
    hidden_size=4608,
    intermediate_size=36864,
    final_logit_softcapping=30.0,
    attn_logit_softcapping=50.0,
    head_dim=128,
    sliding_window=4096,
    query_pre_attn_scalar=144,
)

# HF config registered under the MaxText name "qwen2.5-1.5b" (built on Qwen2Config).
# This variant ties word embeddings and uses a 151936-entry vocabulary.
qwen25_1_5b_config = transformers.Qwen2Config(
    vocab_size=151936,
    hidden_size=1536,
    intermediate_size=8960,
    num_hidden_layers=28,
    num_attention_heads=12,
    num_key_value_heads=2,
    hidden_act="silu",
    max_position_embeddings=32768,
    rms_norm_eps=1e-06,
    rope_theta=1000000.0,
    tie_word_embeddings=True,
    torch_dtype="bfloat16",
    attention_bias=True,
)

# HF config registered under the MaxText name "qwen2.5-7b" (built on Qwen2Config).
# Word embeddings are not tied; the vocabulary has 152064 entries.
qwen25_7b_config = transformers.Qwen2Config(
    vocab_size=152064,
    hidden_size=3584,
    intermediate_size=18944,
    num_hidden_layers=28,
    num_attention_heads=28,
    num_key_value_heads=4,
    hidden_act="silu",
    max_position_embeddings=32768,
    initializer_range=0.02,
    rms_norm_eps=1e-06,
    use_cache=True,
    rope_theta=1000000.0,
    tie_word_embeddings=False,
    torch_dtype="bfloat16",
    attention_bias=True,
)

# HF config registered under the MaxText name "qwen2.5-14b" (built on Qwen2Config).
# Word embeddings are not tied; the vocabulary has 152064 entries.
qwen25_14b_config = transformers.Qwen2Config(
    vocab_size=152064,
    hidden_size=5120,
    intermediate_size=13824,
    num_hidden_layers=48,
    num_attention_heads=40,
    num_key_value_heads=8,
    hidden_act="silu",
    max_position_embeddings=32768,
    rms_norm_eps=1e-06,
    rope_theta=1000000.0,
    tie_word_embeddings=False,
    torch_dtype="bfloat16",
    attention_bias=True,
)


# HF config registered under the MaxText name "qwen3-0.6b" (dense Qwen3, tied embeddings).
qwen3_0_6b_config = transformers.Qwen3Config(
    vocab_size=151936,
    hidden_size=1024,
    intermediate_size=3072,
    num_hidden_layers=28,
    num_attention_heads=16,
    num_key_value_heads=8,
    head_dim=128,
    hidden_act="silu",
    max_position_embeddings=40960,
    rms_norm_eps=1.0e-6,
    rope_theta=1000000.0,
    tie_word_embeddings=True,
    torch_dtype="bfloat16",
)

# HF config registered under the MaxText names "qwen3-1.7b" and "qwen3-1.7b-base"
# (dense Qwen3, tied embeddings).
qwen3_1_7b_config = transformers.Qwen3Config(
    vocab_size=151936,
    hidden_size=2048,
    intermediate_size=6144,
    num_hidden_layers=28,
    num_attention_heads=16,
    num_key_value_heads=8,
    head_dim=128,
    hidden_act="silu",
    max_position_embeddings=40960,
    rms_norm_eps=1.0e-6,
    rope_theta=1000000.0,
    tie_word_embeddings=True,
    torch_dtype="bfloat16",
)

# HF config registered under the MaxText names "qwen3-4b", "qwen3-4b-base" and
# "qwen3-4b-thinking-2507" (dense Qwen3, tied embeddings).
qwen3_4b_config = transformers.Qwen3Config(
    vocab_size=151936,
    hidden_size=2560,
    intermediate_size=9728,
    num_hidden_layers=36,
    num_attention_heads=32,
    num_key_value_heads=8,
    head_dim=128,
    hidden_act="silu",
    max_position_embeddings=40960,
    rms_norm_eps=1.0e-6,
    rope_theta=1000000.0,
    tie_word_embeddings=True,
    torch_dtype="bfloat16",
)

# HF config registered under the MaxText names "qwen3-8b" and "qwen3-8b-base"
# (dense Qwen3, untied embeddings).
qwen3_8b_config = transformers.Qwen3Config(
    vocab_size=151936,
    hidden_size=4096,
    intermediate_size=12288,
    num_hidden_layers=36,
    num_attention_heads=32,
    num_key_value_heads=8,
    head_dim=128,
    hidden_act="silu",
    max_position_embeddings=40960,
    rms_norm_eps=1.0e-6,
    rope_theta=1000000.0,
    tie_word_embeddings=False,
    torch_dtype="bfloat16",
)

# HF config registered under the MaxText names "qwen3-14b" and "qwen3-14b-base"
# (dense Qwen3, untied embeddings).
qwen3_14b_config = transformers.Qwen3Config(
    vocab_size=151936,
    hidden_size=5120,
    intermediate_size=17408,
    num_hidden_layers=40,
    num_attention_heads=40,
    num_key_value_heads=8,
    head_dim=128,
    hidden_act="silu",
    max_position_embeddings=40960,
    rms_norm_eps=1.0e-6,
    rope_theta=1000000.0,
    tie_word_embeddings=False,
    torch_dtype="bfloat16",
)

# HF config registered under the MaxText name "qwen3-32b" (dense Qwen3, untied embeddings).
qwen3_32b_config = transformers.Qwen3Config(
    vocab_size=151936,
    hidden_size=5120,
    intermediate_size=25600,
    num_hidden_layers=64,
    num_attention_heads=64,
    num_key_value_heads=8,
    head_dim=128,
    hidden_act="silu",
    max_position_embeddings=40960,
    rms_norm_eps=1.0e-6,
    rope_theta=1000000.0,
    tie_word_embeddings=False,
    torch_dtype="bfloat16",
)


# HF config registered under the MaxText names "llama3.1-8b" and "llama3.1-8b-Instruct".
llama31_8b_config = transformers.LlamaConfig(
    vocab_size=128256,
    hidden_size=4096,
    intermediate_size=14336,
    num_hidden_layers=32,
    num_attention_heads=32,
    num_key_value_heads=8,
    max_position_embeddings=131072,
    head_dim=128,
    rms_norm_eps=1e-5,
    bos_token_id=128000,
    eos_token_id=128001,
    attention_bias=False,
    attention_dropout=0.0,
    hidden_act="silu",
    initializer_range=0.02,
    mlp_bias=False,
    model_type="llama",
    pretraining_tp=1,
    # "llama3"-type RoPE scaling settings.
    rope_scaling={
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    rope_theta=500000.0,
    tie_word_embeddings=False,
    use_cache=True,
)

# HF config registered under the MaxText name "llama3.1-70b".
llama31_70b_config = transformers.LlamaConfig(
    vocab_size=128256,
    hidden_size=8192,
    intermediate_size=28672,
    num_hidden_layers=80,
    num_attention_heads=64,
    head_dim=128,
    num_key_value_heads=8,
    max_position_embeddings=131072,
    rms_norm_eps=1e-05,
    bos_token_id=128000,
    # Unlike the 8b config in this file, several end-of-sequence ids are listed.
    eos_token_id=[128001, 128008, 128009],
    # "llama3"-type RoPE scaling settings (same values as the 8b config).
    rope_scaling={
        "factor": 8.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    rope_theta=500000.0,
    tie_word_embeddings=False,
)

# HF config registered under the MaxText name "llama3.1-405b".
# rope_scaling, rope_theta and tie_word_embeddings are passed explicitly, with the same
# values as the 8b and 70b configs in this file. Without rope_scaling and rope_theta this
# config would silently fall back to the LlamaConfig defaults, unlike its siblings.
llama31_405b_config = transformers.LlamaConfig(
    vocab_size=128256,
    hidden_size=16384,
    intermediate_size=53248,
    num_hidden_layers=126,
    num_attention_heads=128,
    num_key_value_heads=8,
    head_dim=128,
    max_position_embeddings=131072,
    rms_norm_eps=1e-05,
    bos_token_id=128000,
    eos_token_id=128001,
    rope_scaling={
        "factor": 8.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    rope_theta=500000.0,
    tie_word_embeddings=False,
)

# HF config registered under the MaxText names "qwen3-30b-a3b" and "qwen3-30b-a3b-base".
# Qwen3 MoE: 128 experts with 8 selected per token.
qwen3_30b_a3b_thinking_2507_config = transformers.Qwen3MoeConfig(
    architectures=["Qwen3MoeForCausalLM"],
    attention_bias=False,
    attention_dropout=0.0,
    bos_token_id=151643,
    decoder_sparse_step=1,
    eos_token_id=151645,
    head_dim=128,
    hidden_act="silu",
    hidden_size=2048,
    initializer_range=0.02,
    intermediate_size=6144,
    max_position_embeddings=262144,
    max_window_layers=48,
    model_type="qwen3_moe",
    moe_intermediate_size=768,
    norm_topk_prob=True,
    num_attention_heads=32,
    num_experts=128,
    num_experts_per_tok=8,
    num_hidden_layers=48,
    num_key_value_heads=4,
    output_router_logits=False,
    rms_norm_eps=1e-06,
    rope_scaling=None,
    rope_theta=10000000,
    router_aux_loss_coef=0.001,
    sliding_window=None,
    tie_word_embeddings=False,
    torch_dtype="bfloat16",
    use_cache=True,
    vocab_size=151936,
)

# HF config registered under the MaxText name "qwen3-235b-a22b".
# Qwen3 MoE: 128 experts with 8 selected per token.
qwen3_235b_a22b_thinking_2507_config = transformers.Qwen3MoeConfig(
    architectures=["Qwen3MoeForCausalLM"],
    attention_bias=False,
    attention_dropout=0.0,
    bos_token_id=151643,
    decoder_sparse_step=1,
    eos_token_id=151645,
    head_dim=128,
    hidden_act="silu",
    hidden_size=4096,
    initializer_range=0.02,
    intermediate_size=12288,
    max_position_embeddings=262144,
    max_window_layers=94,
    mlp_only_layers=[],
    model_type="qwen3_moe",
    moe_intermediate_size=1536,
    norm_topk_prob=True,
    num_attention_heads=64,
    num_experts=128,
    num_experts_per_tok=8,
    num_hidden_layers=94,
    num_key_value_heads=4,
    output_router_logits=False,
    rms_norm_eps=1e-06,
    rope_scaling=None,
    rope_theta=5000000.0,
    router_aux_loss_coef=0.001,
    sliding_window=None,
    tie_word_embeddings=False,
    torch_dtype="bfloat16",
    transformers_version="4.51.0",
    use_cache=True,
    vocab_size=151936,
)

# HF config registered under the MaxText name "qwen3-480b-a35b".
# Qwen3 MoE: 160 experts with 8 selected per token. Also passes qkv_bias, use_qk_norm and
# shared_expert_intermediate_size, which the other Qwen3 MoE configs in this file do not set.
qwen3_coder_480b_a35b_config = transformers.Qwen3MoeConfig(
    architectures=["Qwen3MoeForCausalLM"],
    attention_dropout=0.0,
    decoder_sparse_step=1,
    eos_token_id=151645,
    head_dim=128,
    hidden_act="silu",
    hidden_size=6144,
    initializer_range=0.02,
    intermediate_size=8192,
    max_position_embeddings=262144,
    max_window_layers=62,
    mlp_only_layers=[],
    model_type="qwen3_moe",
    moe_intermediate_size=2560,
    norm_topk_prob=True,
    num_attention_heads=96,
    num_experts=160,
    num_experts_per_tok=8,
    num_hidden_layers=62,
    num_key_value_heads=8,
    output_router_logits=False,
    qkv_bias=False,
    rms_norm_eps=1e-06,
    rope_scaling=None,
    rope_theta=10000000,
    router_aux_loss_coef=0.0,
    shared_expert_intermediate_size=0,
    sliding_window=None,
    tie_word_embeddings=False,
    torch_dtype="bfloat16",
    transformers_version="4.51.0",
    use_cache=True,
    use_qk_norm=True,
    vocab_size=151936,
)

# from https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/config.json
# Raw config fields for the MaxText "deepseek2-16b" model; expanded into
# transformers.DeepseekV2Config right after the dict.
deepseek2_16b_dict = {
    "architectures": ["DeepseekV2ForCausalLM"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "auto_map": {
        "AutoConfig": "configuration_deepseek.DeepseekV2Config",
        "AutoModel": "modeling_deepseek.DeepseekV2Model",
        "AutoModelForCausalLM": "modeling_deepseek.DeepseekV2ForCausalLM",
    },
    "aux_loss_alpha": 0.001,
    "bos_token_id": 100000,
    "eos_token_id": 100001,
    "first_k_dense_replace": 1,
    "hidden_act": "silu",
    "hidden_size": 2048,
    "initializer_range": 0.02,
    "intermediate_size": 10944,
    "kv_lora_rank": 512,
    "max_position_embeddings": 163840,
    "model_type": "deepseek_v2",
    "moe_intermediate_size": 1408,
    "moe_layer_freq": 1,
    "n_group": 1,
    "n_routed_experts": 64,
    "n_shared_experts": 2,
    "norm_topk_prob": False,
    "num_attention_heads": 16,
    "num_experts_per_tok": 6,
    "num_hidden_layers": 27,
    "num_key_value_heads": 16,
    "pretraining_tp": 1,
    "q_lora_rank": None,
    "qk_nope_head_dim": 128,
    "qk_rope_head_dim": 64,
    "rms_norm_eps": 1e-06,
    # "yarn"-type RoPE scaling settings; rope_theta is repeated inside this sub-dict.
    "rope_scaling": {
        "beta_fast": 32.0,
        "beta_slow": 1.0,
        "factor": 40.0,
        "mscale": 0.707,
        "mscale_all_dim": 0.707,
        "original_max_position_embeddings": 4096,
        "rope_theta": 10_000,
        "type": "yarn",
    },
    "rope_theta": 10_000,
    "routed_scaling_factor": 1.0,
    "scoring_func": "softmax",
    "seq_aux": True,
    "tie_word_embeddings": False,
    "topk_group": 1,
    "topk_method": "greedy",
    "torch_dtype": "bfloat16",
    "transformers_version": "4.33.1",
    "use_cache": True,
    "v_head_dim": 128,
    "vocab_size": 102400,
}
deepseek2_16b_config = transformers.DeepseekV2Config(**deepseek2_16b_dict)

# from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/config.json
# remove fp8 quantization_config, since we are using bf16
# Raw config fields for the MaxText "deepseek3-671b" model; expanded into
# transformers.DeepseekV3Config right after the dict.
deepseek3_671b_dict = {
    "architectures": ["DeepseekV3ForCausalLM"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "auto_map": {
        "AutoConfig": "configuration_deepseek.DeepseekV3Config",
        "AutoModel": "modeling_deepseek.DeepseekV3Model",
        "AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM",
    },
    "bos_token_id": 0,
    "eos_token_id": 1,
    "ep_size": 1,
    "first_k_dense_replace": 3,
    "hidden_act": "silu",
    "hidden_size": 7168,
    "initializer_range": 0.02,
    "intermediate_size": 18432,
    "kv_lora_rank": 512,
    "max_position_embeddings": 163840,
    "model_type": "deepseek_v3",
    "moe_intermediate_size": 2048,
    "moe_layer_freq": 1,
    "n_group": 8,
    "n_routed_experts": 256,
    "n_shared_experts": 1,
    "norm_topk_prob": True,
    "num_attention_heads": 128,
    "num_experts_per_tok": 8,
    "num_hidden_layers": 61,
    "num_key_value_heads": 128,
    "num_nextn_predict_layers": 1,
    "q_lora_rank": 1536,
    "qk_nope_head_dim": 128,
    "qk_rope_head_dim": 64,
    "rms_norm_eps": 1e-06,
    # "yarn"-type RoPE scaling settings; rope_theta is repeated inside this sub-dict.
    "rope_scaling": {
        "beta_fast": 32.0,
        "beta_slow": 1.0,
        "factor": 40.0,
        "mscale": 1.0,
        "mscale_all_dim": 1.0,
        "original_max_position_embeddings": 4096,
        "rope_theta": 10_000,
        "type": "yarn",
    },
    "rope_theta": 10_000,
    "routed_scaling_factor": 2.5,
    "scoring_func": "sigmoid",
    "tie_word_embeddings": False,
    "topk_group": 4,
    "topk_method": "noaux_tc",
    "torch_dtype": "bfloat16",
    "transformers_version": "4.33.1",
    "use_cache": True,
    "v_head_dim": 128,
    "vocab_size": 129280,
}
deepseek3_671b_config = transformers.DeepseekV3Config(**deepseek3_671b_dict)

# from https://huggingface.co/deepseek-ai/DeepSeek-V3.2/blob/main/config.json
# remove fp8 quantization_config, since we are using bf16
# Raw config fields for the MaxText "deepseek3.2-671b" model. Compared with
# deepseek3_671b_dict it has no "auto_map", adds the index_* fields, and uses
# model_type "deepseek_v32".
deepseek32_671b_dict = {
    "architectures": ["DeepseekV32ForCausalLM"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 0,
    "eos_token_id": 1,
    "ep_size": 1,
    "first_k_dense_replace": 3,
    "hidden_act": "silu",
    "hidden_size": 7168,
    "index_head_dim": 128,
    "index_n_heads": 64,
    "index_topk": 2048,
    "initializer_range": 0.02,
    "intermediate_size": 18432,
    "kv_lora_rank": 512,
    "max_position_embeddings": 163840,
    "model_type": "deepseek_v32",
    "moe_intermediate_size": 2048,
    "moe_layer_freq": 1,
    "n_group": 8,
    "n_routed_experts": 256,
    "n_shared_experts": 1,
    "norm_topk_prob": True,
    "num_attention_heads": 128,
    "num_experts_per_tok": 8,
    "num_hidden_layers": 61,
    "num_key_value_heads": 128,
    "num_nextn_predict_layers": 1,
    "q_lora_rank": 1536,
    "qk_nope_head_dim": 128,
    "qk_rope_head_dim": 64,
    "rms_norm_eps": 1e-06,
    # "yarn"-type RoPE scaling settings; rope_theta is repeated inside this sub-dict.
    "rope_scaling": {
        "beta_fast": 32.0,
        "beta_slow": 1.0,
        "factor": 40.0,
        "mscale": 1.0,
        "mscale_all_dim": 1.0,
        "original_max_position_embeddings": 4096,
        "rope_theta": 10_000,
        "type": "yarn",
    },
    "rope_theta": 10_000,
    "routed_scaling_factor": 2.5,
    "scoring_func": "sigmoid",
    "tie_word_embeddings": False,
    "topk_group": 4,
    "topk_method": "noaux_tc",
    "torch_dtype": "bfloat16",
    "transformers_version": "4.44.2",
    "use_cache": True,
    "v_head_dim": 128,
    "vocab_size": 129280,
}


# TODO(shuningjin): replace with DeepseekV32Config when available in transformers library
class DeepseekV32Config(PTConfig):
  """Minimal stand-in config class for DeepSeek V3.2 (model_type "deepseek_v32").

  All keyword arguments are forwarded unchanged to the base config class.
  `max_position_embeddings` is additionally set on the instance up front, defaulting
  to 163840 when the caller does not pass it.
  """

  model_type = "deepseek_v32"

  def __init__(self, **kwargs):
    # Assigned before delegating so the attribute exists even when it is absent from kwargs.
    self.max_position_embeddings = kwargs.get("max_position_embeddings", 163840)
    super().__init__(**kwargs)
deepseek32_671b_config = DeepseekV32Config(**deepseek32_671b_dict)


# from https://huggingface.co/openai/gpt-oss-20b/blob/main/config.json
# remove mxfp4 quantization_config, since we are using bf16
gpt_oss_20b_dict = {
    "architectures": ["GptOssForCausalLM"],
    "attention_bias": True,
    "attention_dropout": 0.0,
    "eos_token_id": 200002,
    "experts_per_token": 4,
    "head_dim": 64,
    "hidden_act": "silu",
    "hidden_size": 2880,
    "initial_context_length": 4096,
    "initializer_range": 0.02,
    "intermediate_size": 2880,
    # Alternating sliding/full attention, one entry per layer (24 layers).
    "layer_types": ["sliding_attention", "full_attention"] * 12,
    "max_position_embeddings": 131072,
    "model_type": "gpt_oss",
    "num_attention_heads": 64,
    "num_experts_per_tok": 4,
    "num_hidden_layers": 24,
    "num_key_value_heads": 8,
    "num_local_experts": 32,
    "output_router_logits": False,
    "pad_token_id": 199999,
    "rms_norm_eps": 1e-05,
    "rope_scaling": {
        "beta_fast": 32.0,
        "beta_slow": 1.0,
        "factor": 32.0,
        "original_max_position_embeddings": 4096,
        "rope_theta": 150_000,
        "rope_type": "yarn",
        "truncate": False,
    },
    "rope_theta": 150_000,
    "router_aux_loss_coef": 0.9,
    "sliding_window": 128,
    "swiglu_limit": 7.0,
    "tie_word_embeddings": False,
    "transformers_version": "4.55.0.dev0",
    "use_cache": True,
    "vocab_size": 201088,
}
gpt_oss_20b_config = transformers.GptOssConfig(**gpt_oss_20b_dict)

# from https://huggingface.co/openai/gpt-oss-120b/blob/main/config.json
# remove mxfp4 quantization_config, since we are using bf16
gpt_oss_120b_dict = {
    "architectures": ["GptOssForCausalLM"],
    "attention_bias": True,
    "attention_dropout": 0.0,
    "eos_token_id": 200002,
    "experts_per_token": 4,
    "head_dim": 64,
    "hidden_act": "silu",
    "hidden_size": 2880,
    "initial_context_length": 4096,
    "initializer_range": 0.02,
    "intermediate_size": 2880,
    # Alternating sliding/full attention, one entry per layer (36 layers).
    "layer_types": ["sliding_attention", "full_attention"] * 18,
    "max_position_embeddings": 131072,
    "model_type": "gpt_oss",
    "num_attention_heads": 64,
    "num_experts_per_tok": 4,
    "num_hidden_layers": 36,
    "num_key_value_heads": 8,
    "num_local_experts": 128,
    "output_router_logits": False,
    "pad_token_id": 199999,
    "rms_norm_eps": 1e-05,
    "rope_scaling": {
        "beta_fast": 32.0,
        "beta_slow": 1.0,
        "factor": 32.0,
        "original_max_position_embeddings": 4096,
        "rope_theta": 150_000,
        "rope_type": "yarn",
        "truncate": False,
    },
    "rope_theta": 150_000,
    "router_aux_loss_coef": 0.9,
    "sliding_window": 128,
    "swiglu_limit": 7.0,
    "tie_word_embeddings": False,
    "transformers_version": "4.55.0.dev0",
    "use_cache": True,
    "vocab_size": 201088,
}
gpt_oss_120b_config = transformers.GptOssConfig(**gpt_oss_120b_dict)


# Only a few thinker sub-config fields are specified here; see the TODO below.
qwen3_omni_30b_a3b_config = transformers.Qwen3OmniMoeConfig(
    # TODO(hengtaoguo): Pure-text Omni model, need to fill in visual/audio/code2wav parts
    architectures=["Qwen3OmniMoeForConditionalGeneration"],
    thinker_config={
        "text_config": {
            "num_hidden_layers": 48,
            "num_experts": 128,
        },
        "audio_config": {
            "encoder_layers": 32,
            "d_model": 1280,
            "encoder_attention_heads": 20,
        },
        "vision_config": {
            "depth": 27,
            "num_heads": 16,
            "hidden_size": 1152,
        },
    },
)


# Raw config fields for the MaxText "qwen3-next-80b-a3b" model.
qwen3_next_80b_a3b_dict = {
    "architectures": ["Qwen3NextForCausalLM"],
    "attention_dropout": 0.0,
    "bos_token_id": 151643,
    "decoder_sparse_step": 1,
    "eos_token_id": 151645,
    "full_attention_interval": 4,
    "head_dim": 256,
    "hidden_act": "silu",
    "hidden_size": 2048,
    "initializer_range": 0.02,
    "intermediate_size": 5120,
    "linear_conv_kernel_dim": 4,
    "linear_key_head_dim": 128,
    "linear_num_key_heads": 16,
    "linear_num_value_heads": 32,
    "linear_value_head_dim": 128,
    "max_position_embeddings": 262144,
    "mlp_only_layers": [],
    "model_type": "qwen3_next",
    "moe_intermediate_size": 512,
    "norm_topk_prob": True,
    "num_attention_heads": 16,
    "num_experts": 512,
    "num_experts_per_tok": 10,
    "num_hidden_layers": 48,
    "num_key_value_heads": 2,
    "output_router_logits": False,
    "partial_rotary_factor": 0.25,
    "rms_norm_eps": 1e-06,
    "rope_scaling": None,
    "rope_theta": 10000000,
    "router_aux_loss_coef": 0.001,
    "shared_expert_intermediate_size": 512,
    "tie_word_embeddings": False,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.57.0.dev0",
    "use_cache": True,
    "vocab_size": 151936,
}
qwen3_next_80b_a3b_config = transformers.Qwen3NextConfig(**qwen3_next_80b_a3b_dict)


# from https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/config.json
mixtral_8x7b_dict = {
    "architectures": ["MixtralForCausalLM"],
    "attention_dropout": 0.0,
    "bos_token_id": 1,
    "eos_token_id": 2,
    "hidden_act": "silu",
    "hidden_size": 4096,
    "initializer_range": 0.02,
    "intermediate_size": 14336,
    "max_position_embeddings": 32768,
    "model_type": "mixtral",
    "num_attention_heads": 32,
    "num_experts_per_tok": 2,
    "num_hidden_layers": 32,
    "num_key_value_heads": 8,
    "num_local_experts": 8,
    "output_router_logits": False,
    "rms_norm_eps": 1e-05,
    "rope_theta": 1000000.0,
    "router_aux_loss_coef": 0.02,
    "sliding_window": None,
    "tie_word_embeddings": False,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.36.0.dev0",
    "use_cache": True,
    "vocab_size": 32000,
}
mixtral_8x7b_config = transformers.MixtralConfig(**mixtral_8x7b_dict)

# from https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1/blob/main/config.json
mixtral_8x22b_dict = {
    "architectures": ["MixtralForCausalLM"],
    "attention_dropout": 0.0,
    "bos_token_id": 1,
    "eos_token_id": 2,
    "hidden_act": "silu",
    "hidden_size": 6144,
    "initializer_range": 0.02,
    "intermediate_size": 16384,
    "max_position_embeddings": 65536,
    "model_type": "mixtral",
    "num_attention_heads": 48,
    "num_experts_per_tok": 2,
    "num_hidden_layers": 56,
    "num_key_value_heads": 8,
    "num_local_experts": 8,
    "output_router_logits": False,
    "rms_norm_eps": 1e-05,
    "rope_theta": 1000000.0,
    "router_aux_loss_coef": 0.001,
    "sliding_window": None,
    "tie_word_embeddings": False,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.38.0",
    "use_cache": True,
    "vocab_size": 32768,
}
mixtral_8x22b_config = transformers.MixtralConfig(**mixtral_8x22b_dict)


# shared by olmo3-7b and olmo3-7b-pt (only rope_scaling/max_position_embeddings differ)
olmo3_7b_dict = {
    "architectures": ["Olmo3ForCausalLM"],
    "model_type": "olmo3",
    "hidden_size": 4096,
    "num_hidden_layers": 32,
    "num_attention_heads": 32,
    "num_key_value_heads": 32,
    "intermediate_size": 11008,
    "vocab_size": 100278,
    "max_position_embeddings": 8192,
    "rope_theta": 500000,
    "sliding_window": 4096,
    "rms_norm_eps": 1.0e-6,
    "torch_dtype": "bfloat16",
    "tie_word_embeddings": False,
    "pad_token_id": 100277,
    "hidden_act": "silu",
    "attention_bias": False,
    "attention_dropout": 0.0,
    "use_cache": True,
}
olmo3_7b_config = transformers.Olmo3Config(**olmo3_7b_dict)

# from https://huggingface.co/allenai/Olmo-3.1-32B-Instruct/blob/main/config.json
# Starts from the 7b fields and overrides only the size-related entries.
olmo3_32b_dict = {
    **olmo3_7b_dict,
    "hidden_size": 5120,
    "num_hidden_layers": 64,
    "num_attention_heads": 40,
    "num_key_value_heads": 8,
    "intermediate_size": 27648,
}
olmo3_32b_config = transformers.Olmo3Config(**olmo3_32b_dict)


# {maxtext model name: hf model config}
# Several MaxText names intentionally map to the same config object.
HF_MODEL_CONFIGS = {
    "gemma2-2b": gemma2_2b_config,
    "gemma2-9b": gemma2_9b_config,
    "gemma2-27b": gemma2_27b_config,
    "gemma3-4b": gemma3_4b_config,
    "gemma3-12b": gemma3_12b_config,
    "gemma3-27b": gemma3_27b_config,
    "gemma4-26b": gemma4_26b_config,
    "gemma4-31b": gemma4_31b_config,
    "qwen2.5-1.5b": qwen25_1_5b_config,
    "qwen2.5-7b": qwen25_7b_config,
    "qwen2.5-14b": qwen25_14b_config,
    "qwen3-0.6b": qwen3_0_6b_config,
    "qwen3-1.7b": qwen3_1_7b_config,
    "qwen3-1.7b-base": qwen3_1_7b_config,
    "qwen3-4b": qwen3_4b_config,
    "qwen3-4b-base": qwen3_4b_config,
    "qwen3-4b-thinking-2507": qwen3_4b_config,
    "qwen3-8b": qwen3_8b_config,
    "qwen3-8b-base": qwen3_8b_config,
    "qwen3-14b": qwen3_14b_config,
    "qwen3-14b-base": qwen3_14b_config,
    "qwen3-32b": qwen3_32b_config,
    "llama3.1-8b": llama31_8b_config,
    "llama3.1-8b-Instruct": llama31_8b_config,
    "llama3.1-70b": llama31_70b_config,
    "llama3.1-405b": llama31_405b_config,
    "qwen3-30b-a3b": qwen3_30b_a3b_thinking_2507_config,
    "qwen3-30b-a3b-base": qwen3_30b_a3b_thinking_2507_config,
    "qwen3-235b-a22b": qwen3_235b_a22b_thinking_2507_config,
    "qwen3-480b-a35b": qwen3_coder_480b_a35b_config,
    "deepseek2-16b": deepseek2_16b_config,
    "deepseek3-671b": deepseek3_671b_config,
    "deepseek3.2-671b": deepseek32_671b_config,
    "gpt-oss-20b": gpt_oss_20b_config,
    "gpt-oss-120b": gpt_oss_120b_config,
    "qwen3-omni-30b-a3b": qwen3_omni_30b_a3b_config,
    "qwen3-next-80b-a3b": qwen3_next_80b_a3b_config,
    "mixtral-8x7b": mixtral_8x7b_config,
    "mixtral-8x22b": mixtral_8x22b_config,
    "olmo3-7b": olmo3_7b_config,
    "olmo3-7b-pt": olmo3_7b_config,
    "olmo3-32b": olmo3_32b_config,
}