# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pytype: skip-file
"""Pydantic-based configuration management for MaxText."""
import logging
import os
import sys
from typing import Any
import copy
# Disable dill to avoid conflict with gfile (dill requires buffering=0, which gfile forbids)
os.environ["HF_DATASETS_DISABLE_DILL"] = "1"
import jax
import jax.numpy as jnp
import omegaconf
from maxtext.configs import pyconfig_deprecated
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR, MAXTEXT_ASSETS_ROOT, HF_IDS, MAXTEXT_PKG_DIR
from maxtext.common.common_types import DecoderBlockType, ShardMode
from maxtext.configs import types
from maxtext.configs.types import MaxTextConfig
from maxtext.inference.inference_utils import str2bool
from maxtext.utils import max_utils
from maxtext.utils import max_logging
logger = logging.getLogger(__name__)
# logging only accepts upper-case level names ("DEBUG", not "debug"), so normalize
# LOGLEVEL before use; otherwise a lowercase value raises ValueError at import time.
logger.setLevel(os.environ.get("LOGLEVEL", "INFO").upper())
# YAML key through which a config file names the parent config it inherits from.
_BASE_CONFIG_ATTR = "base_config"
# Prefix of the environment variables that override config keys (see yaml_key_to_env_key).
_MAX_PREFIX = "M_"
# Parser to apply to an environment-variable string, keyed by the type of the YAML value it replaces.
_yaml_types_to_parser = {str: str, int: int, float: float, bool: str2bool}
# Config keys whose values are never written to the config log.
KEYS_NO_LOGGING = ("hf_access_token",)
# Module paths to their default config file (relative to MAXTEXT_CONFIGS_DIR),
# consulted when no config file is passed explicitly.
_CONFIG_FILE_MAPPING: dict[str, str] = {
  **dict.fromkeys(
    (
      "maxtext.trainers.pre_train.train",
      "maxtext.trainers.pre_train.train_compile",
    ),
    "base.yml",
  ),
  "maxtext.trainers.post_train.distillation.train_distill": "post_train/distillation.yml",
  "maxtext.trainers.post_train.rl.train_rl": "post_train/rl.yml",
  "maxtext.trainers.post_train.sft.train_sft": "post_train/sft.yml",
  "maxtext.trainers.post_train.sft.train_sft_deprecated": "post_train/sft.yml",
  **dict.fromkeys(
    (
      "maxtext.inference.decode",
      "maxtext.inference.decode_multi",
      "maxtext.inference.inference_microbenchmark",
      "maxtext.inference.inference_microbenchmark_sweep",
      "maxtext.inference.maxengine.maxengine_server",
      "maxtext.inference.mlperf.microbenchmarks.benchmark_chunked_prefill",
      "maxtext.inference.vllm_decode",
      "maxtext.checkpoint_conversion.to_maxtext",
      "maxtext.checkpoint_conversion.to_huggingface",
    ),
    "base.yml",
  ),
}
def _module_from_path(path: str) -> str | None:
"""Convert a file path to module path for config inference."""
real_path = os.path.realpath(path)
pkg_parent = os.path.realpath(os.path.dirname(MAXTEXT_PKG_DIR))
if real_path.startswith(pkg_parent + os.sep):
relative = os.path.relpath(real_path, pkg_parent)
return relative.replace(os.sep, ".").removesuffix(".py")
return None
def _resolve_or_infer_config(argv: list[str] | None = None, **kwargs) -> tuple[str, list[str]]:
  """Resolves or infers the YAML config path and returns it with the remaining CLI args.

  Resolution order:
    1. An explicit ``base_config`` keyword argument.
    2. ``argv[1]`` when it names a YAML file (``.yml`` or ``.yaml``). An argument
       containing ``=`` is treated as a ``key=value`` override instead, unless it
       is an existing file. This keeps an override whose value happens to end in
       ``.yml`` from being mistaken for the config file.
    3. The default for the module of ``argv[0]`` from ``_CONFIG_FILE_MAPPING``,
       falling back to ``base.yml`` in ``MAXTEXT_CONFIGS_DIR``.

  Args:
    argv: Command-line style arguments. ``argv[0]`` is "" or the script path
      (e.g. train.py or train_rl.py). It is optionally followed by a YAML path and
      then by ``key=value`` overrides. Defaults to ``[""]``.
    **kwargs: Keyword overrides; only ``base_config`` is consulted here.

  Returns:
    ``(config_path, remaining_args)``, where ``remaining_args`` are the arguments
    still to be parsed as CLI overrides.
  """
  if argv is None:
    argv = [""]
  if kwargs.get("base_config"):
    logger.info("Using config : %s", kwargs["base_config"])
    return resolve_config_path(kwargs["base_config"]), argv[1:]
  if len(argv) >= 2:
    candidate = argv[1]
    is_yaml = candidate.endswith((".yml", ".yaml"))
    looks_like_override = "=" in candidate and not os.path.isfile(candidate)
    if is_yaml and not looks_like_override:
      return resolve_config_path(candidate), argv[2:]
  module = _module_from_path(argv[0]) if argv else None
  if module not in _CONFIG_FILE_MAPPING:
    config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "base.yml")
    logger.warning("No config file provided and no default config found for module '%s', using base.yml", module)
  else:
    config_path = os.path.join(MAXTEXT_CONFIGS_DIR, _CONFIG_FILE_MAPPING[module])
    logger.warning("No config file provided, using default config mapping: %s", config_path)
  return config_path, argv[1:]
def _resolve_or_infer_addl_config(**kwargs):
  """Infers values for settings the merged config leaves unset.

  * ``base_output_directory``: if missing or falsy, a warning is logged and the
    absolute path of ``./maxtext_output`` is used.
  * ``hf_access_token``: if missing or falsy, the ``HF_TOKEN`` environment
    variable is used when it is set and non-empty.

  Returns:
    A dict holding only the inferred keys (empty when nothing was inferred).
  """
  inferred = {}
  if not kwargs.get("base_output_directory"):
    max_logging.warning("base_output_directory is not provided; Using local directory called maxtext_output")
    inferred["base_output_directory"] = os.path.abspath("maxtext_output")
  token_from_env = None if kwargs.get("hf_access_token") else os.environ.get("HF_TOKEN")
  if token_from_env:
    inferred["hf_access_token"] = token_from_env
  return inferred
def yaml_key_to_env_key(s: str) -> str:
  """Returns the name of the environment variable that overrides config key ``s``.

  The key is upper-cased and prefixed with ``_MAX_PREFIX``.
  """
  return f"{_MAX_PREFIX}{s.upper()}"
def validate_no_keys_overridden_twice(keys1: list[str], keys2: list[str]):
  """Raises if any key is set both by the model config and by CLI/kwargs.

  Args:
    keys1: Keys coming from the model config (any iterable of str, e.g. a keys view).
    keys2: Keys coming from CLI/kwargs overrides (any iterable of str).

  Raises:
    ValueError: if the two collections share any key. The shared keys are listed
      in the order they appear in ``keys1``.
  """
  override_keys = set(keys2)  # O(1) membership instead of scanning keys2 per key.
  overridden_keys = [k for k in keys1 if k in override_keys]
  if overridden_keys:
    raise ValueError(
      f"Keys {overridden_keys} are overridden by both model config and CLI/kwargs. "
      "This is not allowed, unless setting `override_model_config=True`."
    )
def resolve_config_path(param: str) -> str:
  """Resolves a config path, rewriting legacy paths to the ``src`` layout.

  * An existing file path is returned unchanged.
  * A ``src/maxtext/configs/...`` path is looked up under ``MAXTEXT_CONFIGS_DIR``
    (the installed configs directory, e.g. for pip installs). It is returned there
    if that file exists.
  * Otherwise ``param`` is joined under ``src``. Note that ``os.path.join``
    returns an absolute ``param`` unchanged.
  """
  if os.path.isfile(param):
    return param
  src_prefix = "src/maxtext/configs/"
  if param.startswith(src_prefix):
    installed = os.path.join(MAXTEXT_CONFIGS_DIR, param.removeprefix(src_prefix))
    if os.path.isfile(installed):
      return installed
  return os.path.join("src", param)
def _merge_logical_axis_rules(base_rules, new_rules):
  """Merges two lists of logical_axis_rules, letting ``new_rules`` win by name.

  A rule's name is its first element. Every rule in ``base_rules`` whose name
  also appears in ``new_rules`` is dropped. The remaining base rules keep their
  order, and all of ``new_rules`` is appended after them in a new list. When
  ``new_rules`` is empty or None, ``base_rules`` itself is returned.
  """
  if not new_rules:
    return base_rules
  replaced_names = set()
  for rule in new_rules:
    replaced_names.add(rule[0])
  kept = [rule for rule in base_rules if rule[0] not in replaced_names]
  return kept + list(new_rules)
def _apply_rules(base_rules, new_rules, config):
  """Combines ``new_rules`` with ``base_rules`` as directed by ``config``.

  If ``config.get("override_logical_axis_rules")`` is truthy, ``new_rules``
  replaces ``base_rules`` entirely. Otherwise the lists are merged by rule name
  via ``_merge_logical_axis_rules``.
  """
  replace_entirely = config.get("override_logical_axis_rules")
  return new_rules if replace_entirely else _merge_logical_axis_rules(base_rules, new_rules)
def _load_config(config_name: str, _visited: frozenset[str] = frozenset()) -> omegaconf.DictConfig:
  """Loads a YAML file and its base_configs recursively using OmegaConf.

  If the file sets ``base_config``, that parent is loaded first (recursively).
  The current file is then merged on top of it, so child values override parent
  values.

  A relative ``base_config`` is resolved next to the current file first. If no
  file exists there, it is resolved under ``MAXTEXT_CONFIGS_DIR``. An absolute
  ``base_config`` is used as-is.

  Args:
    config_name: Path of the YAML file to load.
    _visited: Internal. Real paths of the configs already on the current
      ``base_config`` chain; used to detect cycles.

  Returns:
    The merged configuration.

  Raises:
    ValueError: if the ``base_config`` chain loops back to a file already on it
      (previously this recursed until RecursionError).
  """
  real_path = os.path.realpath(config_name)
  if real_path in _visited:
    raise ValueError(f"Cyclic {_BASE_CONFIG_ATTR!r} chain detected: {config_name!r} is already being loaded.")
  cfg = omegaconf.OmegaConf.load(config_name)
  if _BASE_CONFIG_ATTR not in cfg:
    return cfg
  base_path = cfg[_BASE_CONFIG_ATTR]
  if os.path.isabs(base_path):
    parent_path = base_path
  else:
    # Search relative to the current config, then in the default configs folder.
    parent_path = os.path.join(os.path.dirname(config_name), base_path)
    if not os.path.isfile(parent_path):
      parent_path = os.path.join(MAXTEXT_CONFIGS_DIR, base_path)
  base_cfg = _load_config(parent_path, _visited | {real_path})
  return omegaconf.OmegaConf.merge(base_cfg, cfg)
def _tuples_to_lists(l: list | tuple | Any) -> list | Any:
"""Recursively converts nested tuples to lists for Pydantic compatibility."""
return [_tuples_to_lists(x) for x in l] if isinstance(l, (list, tuple)) else l
def _lists_to_tuples(l: list | Any) -> tuple | Any:
"""Recursively converts nested lists to tuples for JAX compatibility."""
return tuple(_lists_to_tuples(x) for x in l) if isinstance(l, list) else l
def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]:
  """Validates keys and normalizes values of the merged raw config for MaxTextConfig.

  For each key/value in ``raw_keys``:
    * A key that is not a ``MaxTextConfig`` field raises ``ValueError``.
    * The string "none" (any case) becomes ``None``.
    * ``logical_axis_rules`` / ``data_sharding`` tuples become nested lists. A flat
      ``data_sharding`` list of strings is wrapped into a one-element list.
    * An empty string for the HF dataset/token/tokenizer keys becomes ``None``.
    * ``None`` for ``run_name`` and the HLO dump module-name keys becomes "".
    * An unset ``tokenizer_path`` is looked up in ``HF_IDS`` by ``model_name``,
      falling back to the bundled llama2 tokenizer with a warning.

  Args:
    raw_keys: Flat dict of config keys to values (merged YAML/CLI/kwargs/env).

  Returns:
    A new dict suitable for ``types.MaxTextConfig(**kwargs)``.

  Raises:
    ValueError: if ``raw_keys`` contains a key that is not a MaxTextConfig field.
  """
  # An empty value provided in the configuration for these keys is treated as None.
  empty_means_none = (
    "hf_train_files",
    "hf_eval_files",
    "hf_access_token",
    "hf_name",
    "hf_data_dir",
    "hf_eval_split",
    "tokenizer_path",
  )
  # None for these string keys is normalized to "".
  none_means_empty = ("run_name", "dump_hlo_local_module_name", "dump_hlo_module_name")

  pydantic_kwargs = {}
  valid_fields = types.MaxTextConfig.model_fields.keys()
  for key, value in raw_keys.items():
    if key not in valid_fields:
      # The key is rejected, not skipped, so report it as an error.
      logger.error("Invalid/unsupported field from YAML/CLI: %s", repr(key))
      raise ValueError(f"{key!r} not in {', '.join(map(repr, valid_fields))}.")
    new_value = value
    if isinstance(new_value, str) and new_value.lower() == "none":
      new_value = None
    # Pydantic validates enums from their values, so string is fine.
    # It also handles type coercion for simple types.
    if key in ("logical_axis_rules", "data_sharding"):
      if isinstance(new_value, tuple):
        new_value = _tuples_to_lists(new_value)
      # A single flat sharding spec (list of axis names) becomes a list of specs.
      if key == "data_sharding" and isinstance(new_value, list) and new_value and isinstance(new_value[0], str):
        new_value = [new_value]
    if key in empty_means_none and new_value == "":
      new_value = None
    if key in none_means_empty and new_value is None:
      new_value = ""
    if key == "tokenizer_path" and new_value is None:
      try:
        new_value = HF_IDS[raw_keys["model_name"]]
      except KeyError:
        new_value = os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers/tokenizer.llama2")
        # Implicit concatenation (not backslash continuation) so no indentation
        # whitespace leaks into the message.
        max_logging.warning(
          f"tokenizer_path is not set and model_name {raw_keys.get('model_name')!r} is not in HF_IDS "
          f"(maxtext/src/maxtext/utils/globals.py). Using the default tokenizer {new_value} instead. "
          "Please pass tokenizer_path in your command if this is not intended."
        )
    pydantic_kwargs[key] = new_value
  return pydantic_kwargs
class HyperParameters:
  """Read-only, attribute-style view over a validated MaxTextConfig.

  At construction time the pydantic model is dumped to a flat dict. Some fields
  are converted to the runtime objects callers expect:
    * ``jnp.dtype`` for the dtype fields;
    * nested tuples for ``logical_axis_rules`` / ``data_sharding``;
    * ``DecoderBlockType`` / ``ShardMode`` enums.
  Unknown attribute reads are served from that dict, and any attribute
  assignment raises ``ValueError``.
  """

  def __init__(self, pydantic_config: types.MaxTextConfig):
    flat = pydantic_config.model_dump()
    for dtype_key in ("dtype", "grad_dtype", "weight_dtype"):
      flat[dtype_key] = jnp.dtype(flat[dtype_key])
    # An unset/empty mu_dtype falls back to the (already converted) weight_dtype.
    flat["mu_dtype"] = jnp.dtype(flat["mu_dtype"]) if flat["mu_dtype"] else flat["weight_dtype"]
    for nested_key in ("logical_axis_rules", "data_sharding"):
      flat[nested_key] = _lists_to_tuples(flat[nested_key])
    flat["decoder_block"] = DecoderBlockType(flat["decoder_block"])
    flat["shard_mode"] = ShardMode(flat["shard_mode"])
    # __setattr__ is disabled, so store state through object.__setattr__.
    object.__setattr__(self, "_pydantic_config", pydantic_config)
    object.__setattr__(self, "_flat_config", flat)

  def __deepcopy__(self, memo):
    """Deep-copies the underlying pydantic config and rebuilds the view from it."""
    pydantic_copy = copy.deepcopy(self._pydantic_config, memo)
    return HyperParameters(pydantic_copy)

  def tree_flatten(self):
    """Returns no leaves; the whole object is carried as auxiliary data."""
    return (), self

  def __getattr__(self, attr: str) -> Any:
    """Looks up ``attr`` in the flat config dict (called only when normal lookup fails)."""
    # object.__getattribute__ avoids re-entering __getattr__ when _flat_config does not
    # exist yet (e.g. while unpickling), which would otherwise recurse.
    flat_config = object.__getattribute__(self, "_flat_config")
    if attr not in flat_config:
      raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{attr}'")
    return flat_config[attr]

  def __setattr__(self, attr: str, value: Any) -> None:
    """Rejects every assignment: the configuration is read-only."""
    raise ValueError("Configuration is read-only and cannot be modified after initialization.")

  def get_keys(self) -> dict[str, Any]:
    """Returns the flat configuration dict (the internal one, not a copy)."""
    return self._flat_config
def initialize(argv: list[str] | None = None, **kwargs) -> HyperParameters:
  """Initializes the configuration by loading YAML files, and applying CLI, env, and kwarg overrides.

  This is ``initialize_pydantic`` with the resulting MaxTextConfig wrapped in the
  read-only ``HyperParameters`` view.

  Args:
    argv: Command-line style arguments (``argv[0]`` is "" or the script path).
    **kwargs: Config overrides, e.g. ``base_config=...`` or any config key.

  Returns:
    The HyperParameters view of the final configuration.
  """
  # The stray `[docs]` statement that preceded this function (doc-rendering residue)
  # raised NameError at import time and has been removed.
  pydantic_config = initialize_pydantic(argv, **kwargs)
  return HyperParameters(pydantic_config)
def initialize_pydantic(argv: list[str] | None = None, **kwargs) -> MaxTextConfig:
  """Initializes the configuration by loading YAML files, and applying CLI, env, and kwarg overrides.

  Returns the pydantic ``MaxTextConfig``, whereas ``initialize`` wraps the same result
  in the original ``HyperParameters`` class.

  Merge order (later wins):
    1. the base YAML chain;
    2. ``models/<model_name>.yml``;
    3. CLI overrides;
    4. kwargs.
  Values inferred by ``_resolve_or_infer_addl_config`` are merged in with the
  overrides. Then, for keys already in the config, an environment variable named
  by ``yaml_key_to_env_key`` (``M_<KEY>``) replaces the value. Setting a key both
  via env and via CLI/kwargs is an error. So is setting it via env and the model
  config, unless ``override_model_config`` is set. ``logical_axis_rules`` is
  merged rule-by-rule (see ``_apply_rules``) instead of being replaced wholesale.

  Side effects: may set ``jax_debug_log_modules``, initializes the JAX distributed
  system (skipped when pytest is loaded), and may set the JAX compilation cache dir.

  Args:
    argv: Command-line style arguments (``argv[0]`` is "" or the script path,
      optionally followed by a YAML path and ``key=value`` overrides).
    **kwargs: Config overrides; ``base_config`` selects the YAML file.

  Returns:
    The validated MaxTextConfig.

  Raises:
    ValueError: on conflicting overrides, unparsable env values, unknown config keys,
      or when both ``use_tokamax_splash`` and ``use_jax_splash`` are set.
    TypeError: when an env override targets a key whose value type is not bool/str/int/float.
  """
  # (The stray `[docs]` statement that preceded this function raised NameError at
  # import time and has been removed.)
  # 1. Load base and inherited configs from file(s).
  config_path, cli_args = _resolve_or_infer_config(argv, **kwargs)
  base_yml_config = _load_config(config_path)

  # 2. Get overrides from CLI and kwargs; kwargs take precedence over CLI.
  cli_cfg = omegaconf.OmegaConf.from_cli(cli_args)
  kwargs_cfg = omegaconf.OmegaConf.create(kwargs)
  overrides_cfg = omegaconf.OmegaConf.merge(cli_cfg, kwargs_cfg)

  # 3.1. Infer values that were not provided and treat them as overrides.
  inferred_cfg = _resolve_or_infer_addl_config(**omegaconf.OmegaConf.merge(base_yml_config, overrides_cfg))
  overrides_cfg = omegaconf.OmegaConf.merge(overrides_cfg, inferred_cfg)
  temp_cfg = omegaconf.OmegaConf.merge(base_yml_config, overrides_cfg)

  # 3.2. Handle model-specific config.
  # `or "default"` also covers an explicit null/empty model_name, which previously
  # crashed on the "-Instruct" check (`"-Instruct" in None`).
  model_name = temp_cfg.get("model_name") or "default"
  # -Instruct and base models share the same architecture, so look up the base
  # model's config file (replace is a no-op when the suffix is absent).
  model_name = model_name.replace("-Instruct", "")
  model_cfg = {}
  if model_name != "default":
    # First try relative to base config path.
    model_config_path = os.path.join(os.path.dirname(config_path), "models", f"{model_name}.yml")
    # Try looking for "models" under "src/maxtext/configs/" (one level up).
    if not os.path.isfile(model_config_path):
      model_config_path = os.path.join(
        os.path.dirname(os.path.dirname(config_path)),
        "models",
        f"{model_name}.yml",
      )
    if not os.path.isfile(model_config_path):
      # Fallback to the default location within package.
      dir_path = os.path.dirname(os.path.realpath(__file__))
      model_config_path = os.path.join(dir_path, "models", f"{model_name}.yml")
    if os.path.exists(model_config_path):
      model_loaded_cfg = omegaconf.OmegaConf.load(model_config_path)
      if temp_cfg.get("override_model_config"):
        # Overrides win: only apply model configs for keys not present in overrides.
        model_cfg = {k: v for k, v in model_loaded_cfg.items() if k not in overrides_cfg}
      else:
        model_cfg = model_loaded_cfg
        # Without override_model_config, no key may come from both model config and CLI/kwargs.
        validate_no_keys_overridden_twice(model_loaded_cfg.keys(), overrides_cfg.keys())
    else:
      logger.warning("Model config for '%s' not found at %s", model_name, model_config_path)
  # Merged below in the order base, model, then overrides.
  model_cfg_oc = omegaconf.OmegaConf.create(model_cfg)

  # 4. Manually merge logical_axis_rules to avoid OmegaConf's list replacement behavior.
  base_rules_oc = base_yml_config.get("logical_axis_rules", [])
  model_rules_oc = model_cfg_oc.get("logical_axis_rules", [])
  overrides_rules_oc = overrides_cfg.get("logical_axis_rules", [])
  base_rules = omegaconf.OmegaConf.to_container(base_rules_oc, resolve=True) if base_rules_oc else []
  model_rules = omegaconf.OmegaConf.to_container(model_rules_oc, resolve=True) if model_rules_oc else []
  overrides_rules = omegaconf.OmegaConf.to_container(overrides_rules_oc, resolve=True) if overrides_rules_oc else []
  merged_rules = _apply_rules(base_rules, model_rules, model_cfg_oc)
  merged_rules = _apply_rules(merged_rules, overrides_rules, overrides_cfg)
  # Remove the rules from the original configs before the main merge.
  for source_cfg in (base_yml_config, model_cfg_oc, overrides_cfg):
    if "logical_axis_rules" in source_cfg:
      del source_cfg["logical_axis_rules"]

  # 5. Final merge for all other keys.
  final_config = omegaconf.OmegaConf.merge(base_yml_config, model_cfg_oc, overrides_cfg)
  final_config["logical_axis_rules"] = merged_rules
  raw_keys_dict = omegaconf.OmegaConf.to_container(final_config, resolve=True)

  # 6. Handle environment variable overrides for keys already in the config.
  cli_keys = frozenset(omegaconf.OmegaConf.to_container(cli_cfg, resolve=True).keys())
  kwargs_keys = frozenset(kwargs.keys())
  for k in tuple(raw_keys_dict.keys()):
    env_key = yaml_key_to_env_key(k)
    if env_key not in os.environ:
      continue
    # Validate that no keys are overridden by both CLI/kwargs and environment variable.
    if k in cli_keys or k in kwargs_keys:
      raise ValueError(
        f"Key '{k}' is overridden by both CLI/kwargs and environment variable '{env_key}'. This is not allowed."
      )
    # Validate that no keys are overridden by both model config and environment variable.
    if not temp_cfg.get("override_model_config") and k in model_cfg.keys():
      raise ValueError(
        f"Key '{k}' is overridden by both model config and environment variable '{env_key}'. "
        "This is not allowed, unless setting `override_model_config=True`."
      )
    new_proposal = os.environ.get(env_key)
    original_value = raw_keys_dict.get(k)
    # bool is checked first because bool is a subclass of int.
    if isinstance(original_value, bool):
      parser = _yaml_types_to_parser[bool]
    elif isinstance(original_value, (str, int, float)):
      parser = type(original_value)
    else:
      raise TypeError(f"Type {type(original_value)} for key '{k}' not supported for ENV override.")
    try:
      raw_keys_dict[k] = parser(new_proposal)
    except (ValueError, KeyError) as e:
      raise ValueError(f"Couldn't parse value from ENV '{new_proposal}' for key '{k}'") from e

  pydantic_kwargs = _prepare_for_pydantic(raw_keys_dict)
  if pydantic_kwargs.get("use_tokamax_splash") and pydantic_kwargs.get("use_jax_splash"):
    raise ValueError("At most one of `use_tokamax_splash` and `use_jax_splash` can be set to True.")

  # Initialize JAX distributed system before device backend is initialized.
  if pydantic_kwargs.get("jax_debug_log_modules"):
    jax.config.update("jax_debug_log_modules", pydantic_kwargs["jax_debug_log_modules"])
  # Do not initialize jax distributed system during pytest runs.
  if "pytest" not in sys.modules:
    max_utils.maybe_initialize_jax_distributed_system(pydantic_kwargs)
  if pydantic_kwargs.get("jax_cache_dir"):
    from jax.experimental.compilation_cache import compilation_cache  # pylint: disable=import-outside-toplevel

    compilation_cache.set_cache_dir(os.path.expanduser(pydantic_kwargs["jax_cache_dir"]))

  pydantic_config = types.MaxTextConfig(**pydantic_kwargs)
  # Log through HyperParameters so values appear with its runtime conversions
  # (dtypes, tuples, enums); secrets listed in KEYS_NO_LOGGING are skipped.
  config = HyperParameters(pydantic_config)
  if config.log_config:
    for k, v in sorted(config.get_keys().items()):
      if k not in KEYS_NO_LOGGING:
        logger.info("Config param %s: %s", k, v)
  return pydantic_config
# Public API of this module.
__all__ = ["initialize", "initialize_pydantic"]
# Re-exported for backward compatibility with pyconfig_deprecated_test.py.
validate_and_update_keys = pyconfig_deprecated.validate_and_update_keys
class _CallablePyconfigModule(type(sys.modules[__name__])):
  """Module subclass that lets this module be called directly, e.g. ``mt.pyconfig(argv)``."""

  def __call__(self, argv: list[str] | None = None, **kwargs) -> HyperParameters:
    """Equivalent to ``initialize(argv, **kwargs)``."""
    return initialize(argv, **kwargs)


# Swap the live module object's class so the module itself supports __call__.
sys.modules[__name__].__class__ = _CallablePyconfigModule