# Source code for maxtext.configs.pyconfig

# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pytype: skip-file
"""Pydantic-based configuration management for MaxText."""
import logging
import os
import sys
from typing import Any
import copy

# Disable dill to avoid conflict with gfile (dill requires buffering=0, which gfile forbids)
os.environ["HF_DATASETS_DISABLE_DILL"] = "1"

import jax
import jax.numpy as jnp

import omegaconf

from maxtext.configs import pyconfig_deprecated
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR, MAXTEXT_ASSETS_ROOT, HF_IDS, MAXTEXT_PKG_DIR
from maxtext.common.common_types import DecoderBlockType, ShardMode
from maxtext.configs import types
from maxtext.configs.types import MaxTextConfig
from maxtext.inference.inference_utils import str2bool
from maxtext.utils import max_utils
from maxtext.utils import max_logging

logger = logging.getLogger(__name__)
# Accept level names case-insensitively (e.g. LOGLEVEL=debug): Logger.setLevel only
# recognizes upper-case level names and would otherwise raise ValueError at import time.
logger.setLevel(os.environ.get("LOGLEVEL", "INFO").upper())

# Config key whose value names a parent YAML file; `_load_config` loads that parent
# first and merges the current file on top of it.
_BASE_CONFIG_ATTR = "base_config"
# Prefix for environment variables that override config keys (see `yaml_key_to_env_key`),
# e.g. M_RUN_NAME overrides `run_name`.
_MAX_PREFIX = "M_"
# Parser applied to an environment-variable override, chosen by the Python type of the
# value currently in the config. bool() would treat any non-empty string (including
# "False") as True, so booleans are parsed with str2bool instead.
_yaml_types_to_parser = {str: str, int: int, float: float, bool: str2bool}

# Don't log the following keys (their values are credentials, e.g. access tokens).
KEYS_NO_LOGGING = ("hf_access_token",)

# Module paths to their default config file (relative to MAXTEXT_CONFIGS_DIR).
# `_resolve_or_infer_config` consults this when no YAML file is passed explicitly;
# modules not listed here fall back to base.yml.
_CONFIG_FILE_MAPPING: dict[str, str] = {
    "maxtext.trainers.pre_train.train": "base.yml",
    "maxtext.trainers.pre_train.train_compile": "base.yml",
    "maxtext.trainers.post_train.distillation.train_distill": "post_train/distillation.yml",
    "maxtext.trainers.post_train.rl.train_rl": "post_train/rl.yml",
    "maxtext.trainers.post_train.sft.train_sft": "post_train/sft.yml",
    "maxtext.trainers.post_train.sft.train_sft_deprecated": "post_train/sft.yml",
    "maxtext.inference.decode": "base.yml",
    "maxtext.inference.decode_multi": "base.yml",
    "maxtext.inference.inference_microbenchmark": "base.yml",
    "maxtext.inference.inference_microbenchmark_sweep": "base.yml",
    "maxtext.inference.maxengine.maxengine_server": "base.yml",
    "maxtext.inference.mlperf.microbenchmarks.benchmark_chunked_prefill": "base.yml",
    "maxtext.inference.vllm_decode": "base.yml",
    "maxtext.checkpoint_conversion.to_maxtext": "base.yml",
    "maxtext.checkpoint_conversion.to_huggingface": "base.yml",
}


def _module_from_path(path: str) -> str | None:
  """Maps a script path inside the MaxText source tree to its dotted module name.

  Both `path` and the directory containing MAXTEXT_PKG_DIR are resolved through
  symlinks first. A trailing ".py" is stripped from the result.

  Returns:
    The dotted module path (e.g. "maxtext.inference.decode"), or None when `path`
    does not live under that directory.
  """
  package_root = os.path.realpath(os.path.dirname(MAXTEXT_PKG_DIR))
  resolved = os.path.realpath(path)
  if not resolved.startswith(package_root + os.sep):
    return None
  dotted = os.path.relpath(resolved, package_root).replace(os.sep, ".")
  return dotted.removesuffix(".py")


def _resolve_or_infer_config(argv: list[str] | None = None, **kwargs) -> tuple[str, list[str]]:
  """Resolves or infers the config file path and returns it with the remaining CLI args.

  Resolution order:
    1. An explicit `base_config` kwarg.
    2. A YAML file (".yml" or ".yaml") passed as `argv[1]`; `argv[0]` is the script
       name (e.g. train.py) or "".
    3. A default inferred from the running module via `_CONFIG_FILE_MAPPING`,
       falling back to base.yml in MAXTEXT_CONFIGS_DIR.

  Args:
    argv: Command-line style arguments; defaults to `[""]`.
    **kwargs: Keyword overrides; only `base_config` is consulted here.

  Returns:
    `(config_path, remaining_argv)`, where `remaining_argv` excludes the script name
    and, when it was consumed, the config file argument.
  """
  if argv is None:
    argv = [""]

  if kwargs.get("base_config"):
    logger.info("Using config : %s", kwargs["base_config"])
    return resolve_config_path(kwargs["base_config"]), argv[1:]

  # When passing at least two arguments via list (no kwargs), the first one is either ""
  # or the python script (train_rl.py, train.py, ...) and the second is the YAML file.
  # Both common YAML suffixes are accepted; previously a ".yaml" file was passed on to
  # the CLI-override parser as a bogus key.
  if len(argv) >= 2 and argv[1].endswith((".yml", ".yaml")):
    return resolve_config_path(argv[1]), argv[2:]

  module = _module_from_path(argv[0]) if argv else None
  if module in _CONFIG_FILE_MAPPING:
    config_path = os.path.join(MAXTEXT_CONFIGS_DIR, _CONFIG_FILE_MAPPING[module])
    logger.warning("No config file provided, using default config mapping: %s", config_path)
  else:
    config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "base.yml")
    logger.warning("No config file provided and no default config found for module '%s', using base.yml", module)

  return config_path, argv[1:]


def _resolve_or_infer_addl_config(**kwargs):
  """Resolves or infers more configs from module."""
  inferred_kwargs = {}
  # if base_output_directory key is not seen
  if not kwargs.get("base_output_directory"):
    max_logging.warning("base_output_directory is not provided; Using local directory called maxtext_output")
    base_output_directory = os.path.abspath("maxtext_output")
    inferred_kwargs["base_output_directory"] = base_output_directory

  # if hf_access_token key is not seen
  if not kwargs.get("hf_access_token"):
    hf_access_token = os.environ.get("HF_TOKEN")
    if hf_access_token:
      inferred_kwargs["hf_access_token"] = hf_access_token

  return inferred_kwargs


def yaml_key_to_env_key(s: str) -> str:
  """Returns the environment variable that overrides config key `s` (e.g. run_name -> M_RUN_NAME)."""
  return f"{_MAX_PREFIX}{s.upper()}"


def validate_no_keys_overridden_twice(keys1: list[str], keys2: list[str]):
  """Raises if any key is set both by the model config and by CLI/kwargs overrides.

  Args:
    keys1: Keys set by the model config.
    keys2: Keys set by CLI arguments / kwargs.

  Raises:
    ValueError: Listing, in `keys1` order, every key present in both inputs.
  """
  # Build the lookup once so the check is linear rather than quadratic.
  override_keys = set(keys2)
  overridden_keys = [k for k in keys1 if k in override_keys]
  if overridden_keys:
    raise ValueError(
        f"Keys {overridden_keys} are overridden by both model config and CLI/kwargs. "
        "This is not allowed, unless setting `override_model_config=True`."
    )


def resolve_config_path(param: str) -> str:
  """Maps a user-supplied config path to a file path that should exist.

  - An existing file path is returned unchanged.
  - A "src/maxtext/configs/..." path that doesn't exist locally (e.g. for a
    pip-installed package) is looked up under MAXTEXT_CONFIGS_DIR instead.
  - Anything else is prefixed with "src/" to auto-rewrite to the new src folder layout.
  """
  if os.path.isfile(param):
    return param

  src_configs_prefix = "src/maxtext/configs/"
  if param.startswith(src_configs_prefix):
    installed_candidate = os.path.join(MAXTEXT_CONFIGS_DIR, param.removeprefix(src_configs_prefix))
    if os.path.isfile(installed_candidate):
      return installed_candidate

  return os.path.join("src", param)


def _merge_logical_axis_rules(base_rules, new_rules):
  """Merges two lists of logical_axis_rules. Rules in new_rules override all rules
  with the same name in base_rules."""
  if not new_rules:
    return base_rules

  new_rule_keys = {rule[0] for rule in new_rules}

  # Filter old rules to exclude any that will be replaced.
  updated_rules = [rule for rule in base_rules if rule[0] not in new_rule_keys]

  # Add all the new rules.
  updated_rules.extend(new_rules)
  return updated_rules


def _apply_rules(base_rules, new_rules, config):
  if config.get("override_logical_axis_rules"):
    return new_rules
  return _merge_logical_axis_rules(base_rules, new_rules)


def _load_config(config_name: str, _visited: frozenset[str] = frozenset()) -> omegaconf.DictConfig:
  """Loads a YAML file and its base_configs recursively using OmegaConf.

  A config may name a parent via its `base_config` key. The parent chain is loaded
  first and the child is merged on top of it, so child values win. A relative parent
  path is looked up next to the child config first, then in MAXTEXT_CONFIGS_DIR.
  A null or empty `base_config` is treated as "no parent".

  Args:
    config_name: Path of the YAML file to load.
    _visited: Internal; real paths already on the current inheritance chain. Used to
      detect `base_config` cycles.

  Returns:
    The merged OmegaConf DictConfig.

  Raises:
    ValueError: If the `base_config` chain refers back to a file already being loaded.
      Without this check, a cycle would recurse until RecursionError.
  """
  real_name = os.path.realpath(config_name)
  if real_name in _visited:
    raise ValueError(f"Cyclic {_BASE_CONFIG_ATTR!r} chain detected: {config_name!r} is already being loaded.")

  cfg = omegaconf.OmegaConf.load(config_name)
  base_path = cfg.get(_BASE_CONFIG_ATTR)
  if base_path:
    if os.path.isabs(base_path):
      parent_path = base_path
    else:
      # Search relative to current config, then in the default configs folder.
      parent_path = os.path.join(os.path.dirname(config_name), base_path)
      if not os.path.isfile(parent_path):
        parent_path = os.path.join(MAXTEXT_CONFIGS_DIR, base_path)

    base_cfg = _load_config(parent_path, _visited | {real_name})
    cfg = omegaconf.OmegaConf.merge(base_cfg, cfg)
  return cfg


def _tuples_to_lists(l: list | tuple | Any) -> list | Any:
  """Recursively converts nested tuples to lists for Pydantic compatibility."""
  return [_tuples_to_lists(x) for x in l] if isinstance(l, (list, tuple)) else l


def _lists_to_tuples(l: list | Any) -> tuple | Any:
  """Recursively converts nested lists to tuples for JAX compatibility."""
  return tuple(_lists_to_tuples(x) for x in l) if isinstance(l, list) else l


def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]:
  """Prepares the raw dictionary for Pydantic model instantiation.

  Normalizes merged YAML/CLI/env values before validation:
    * the string "none" (any case) becomes None;
    * `logical_axis_rules` / `data_sharding` tuples become lists, and a flat
      `data_sharding` list of strings is wrapped into a list of lists;
    * empty strings for the HF/tokenizer fields listed below become None;
    * a None `run_name` / `dump_hlo_local_module_name` / `dump_hlo_module_name` becomes "";
    * a None `tokenizer_path` is looked up in HF_IDS by `model_name`, falling back to
      the bundled llama2 tokenizer with a warning.

  Args:
    raw_keys: Fully merged configuration values.

  Returns:
    A new dict suitable for `types.MaxTextConfig(**...)`.

  Raises:
    ValueError: If a key is not a field of `types.MaxTextConfig`.
  """
  pydantic_kwargs = {}
  valid_fields = types.MaxTextConfig.model_fields.keys()
  # An empty value provided in the configuration for these keys is treated as None.
  empty_means_none = (
      "hf_train_files",
      "hf_eval_files",
      "hf_access_token",
      "hf_name",
      "hf_data_dir",
      "hf_eval_split",
      "tokenizer_path",
  )
  none_means_empty = ("run_name", "dump_hlo_local_module_name", "dump_hlo_module_name")

  for key, value in raw_keys.items():
    if key not in valid_fields:
      # Unknown keys are rejected (not ignored), so the log text must say so.
      logger.warning("Rejecting invalid/unsupported field from YAML/CLI: %s", repr(key))
      raise ValueError(f"{key!r} not in {', '.join(map(repr, valid_fields))}.")

    new_value = value
    if isinstance(new_value, str) and new_value.lower() == "none":
      new_value = None

    # Pydantic validates enums from their values, so string is fine.
    # It also handles type coercion for simple types.
    if key in ("logical_axis_rules", "data_sharding"):
      if isinstance(new_value, tuple):
        new_value = _tuples_to_lists(new_value)
      if key == "data_sharding" and isinstance(new_value, list) and new_value and isinstance(new_value[0], str):
        new_value = [new_value]

    if key in empty_means_none and new_value == "":
      new_value = None

    if key in none_means_empty and new_value is None:
      new_value = ""

    if key == "tokenizer_path" and new_value is None:
      try:
        new_value = HF_IDS[raw_keys["model_name"]]
      except KeyError:
        new_value = os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers/tokenizer.llama2")
        # Implicit concatenation instead of backslash continuation, which embedded the
        # source indentation as runs of spaces in the message.
        max_logging.warning(
            "tokenizer_path not found in HF_IDS in maxtext/src/maxtext/utils/globals.py. "
            "Using the default src/maxtext/assets/tokenizers/tokenizer.llama2 instead. "
            "Please pass tokenizer_path in your command if this is not intended."
        )

    pydantic_kwargs[key] = new_value

  return pydantic_kwargs


class HyperParameters:
  """Read-only, attribute-style view over a validated MaxTextConfig.

  Keeps backward compatibility with code that reads `config.<key>` and expects
  JAX-friendly values:
    * `dtype`, `grad_dtype`, `weight_dtype` and `mu_dtype` become `jnp.dtype` objects;
      an unset `mu_dtype` falls back to `weight_dtype`.
    * `logical_axis_rules` and `data_sharding` have nested lists turned into tuples.
    * `decoder_block` and `shard_mode` become `DecoderBlockType` / `ShardMode`.
  Attribute assignment raises ValueError.
  """

  def __init__(self, pydantic_config: types.MaxTextConfig):
    # Bypass our own __setattr__, which forbids all assignment.
    object.__setattr__(self, "_pydantic_config", pydantic_config)

    flat = pydantic_config.model_dump()
    for dtype_key in ("dtype", "grad_dtype", "weight_dtype"):
      flat[dtype_key] = jnp.dtype(flat[dtype_key])
    # mu_dtype defaults to the (already converted) weight dtype when unset.
    flat["mu_dtype"] = jnp.dtype(flat["mu_dtype"]) if flat["mu_dtype"] else flat["weight_dtype"]

    for nested_key in ("logical_axis_rules", "data_sharding"):
      flat[nested_key] = _lists_to_tuples(flat[nested_key])

    flat["decoder_block"] = DecoderBlockType(flat["decoder_block"])
    flat["shard_mode"] = ShardMode(flat["shard_mode"])

    object.__setattr__(self, "_flat_config", flat)

  def __deepcopy__(self, memo):
    """Deep-copies the underlying pydantic config and rebuilds the wrapper from it."""
    copied_config = copy.deepcopy(self._pydantic_config, memo)
    return HyperParameters(copied_config)

  def tree_flatten(self):
    """Returns no pytree children and the instance itself as auxiliary data.

    NOTE(review): presumably paired with a pytree registration elsewhere; confirm.
    """
    leaves = ()
    return leaves, self

  def __getattr__(self, attr: str) -> Any:
    """Provides attribute-style access to the flattened configuration values."""
    # object.__getattribute__ avoids re-entering __getattr__ while _flat_config is not
    # yet set (e.g. during unpickling), which would otherwise recurse.
    flat_config = object.__getattribute__(self, "_flat_config")
    if attr not in flat_config:
      raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{attr}'")
    return flat_config[attr]

  def __setattr__(self, attr: str, value: Any) -> None:
    """Rejects every assignment: the configuration is read-only."""
    raise ValueError("Configuration is read-only and cannot be modified after initialization.")

  def get_keys(self) -> dict[str, Any]:
    """Returns the flattened configuration dict, for backward compatibility.

    NOTE(review): this is the internal dict itself, not a copy, so callers mutating it
    also change what attribute access returns.
    """
    return self._flat_config


def initialize(argv: list[str] | None = None, **kwargs) -> HyperParameters:
  """Builds the read-only HyperParameters configuration.

  Loads the YAML file(s) and applies CLI, env and kwarg overrides via
  `initialize_pydantic`, then wraps the validated MaxTextConfig.
  """
  return HyperParameters(initialize_pydantic(argv, **kwargs))
def _find_model_config_path(config_path: str, model_name: str) -> str:
  """Locates `models/<model_name>.yml` for the given base config file.

  Candidates are tried in order, and the first one that is a file wins:
    1. `models/` next to the base config file.
    2. `models/` one directory above it (e.g. "src/maxtext/configs/models" for a
       config living in a subfolder of "src/maxtext/configs/").
    3. `models/` next to this module (the packaged default).

  Returns:
    The first candidate that is a file, otherwise the last (packaged) candidate.
    Callers must still check that the returned path exists.
  """
  filename = f"{model_name}.yml"
  config_dir = os.path.dirname(config_path)
  candidates = (
      os.path.join(config_dir, "models", filename),
      os.path.join(os.path.dirname(config_dir), "models", filename),
      os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", filename),
  )
  for candidate in candidates[:-1]:
    if os.path.isfile(candidate):
      return candidate
  return candidates[-1]


def _apply_env_overrides(raw_keys_dict: dict[str, Any], explicit_keys, model_keys, override_model_config) -> None:
  """Applies `M_<KEY>` environment-variable overrides to `raw_keys_dict` in place.

  Only keys already present in `raw_keys_dict` can be overridden. The env string is
  parsed according to the type of the current value: bool via str2bool, and
  str/int/float via their constructors.

  Args:
    raw_keys_dict: Merged config values; updated in place.
    explicit_keys: Keys set via CLI or kwargs, which must not also come from the env.
    model_keys: Keys set by the model config.
    override_model_config: When falsy, env overrides of `model_keys` are rejected.

  Raises:
    ValueError: On a conflicting override or an unparsable env value.
    TypeError: If the current value's type is not bool/str/int/float.
  """
  for k in tuple(raw_keys_dict.keys()):
    env_key = yaml_key_to_env_key(k)
    if env_key not in os.environ:
      continue

    # Validate that no keys are overridden by both CLI/kwargs and environment variable.
    if k in explicit_keys:
      raise ValueError(
          f"Key '{k}' is overridden by both CLI/kwargs and environment variable '{env_key}'. This is not allowed."
      )
    # Validate that no keys are overridden by both model config and environment variable.
    if not override_model_config and k in model_keys:
      raise ValueError(
          f"Key '{k}' is overridden by both model config and environment variable '{env_key}'. "
          "This is not allowed, unless setting `override_model_config=True`."
      )

    new_proposal = os.environ.get(env_key)
    original_value = raw_keys_dict.get(k)
    # bool must be checked before int, since bool is a subclass of int.
    if isinstance(original_value, bool):
      parser = _yaml_types_to_parser[bool]
    elif isinstance(original_value, (str, int, float)):
      parser = type(original_value)
    else:
      raise TypeError(f"Type {type(original_value)} for key '{k}' not supported for ENV override.")

    try:
      raw_keys_dict[k] = parser(new_proposal)
    except (ValueError, KeyError) as e:
      raise ValueError(f"Couldn't parse value from ENV '{new_proposal}' for key '{k}'") from e


def initialize_pydantic(argv: list[str] | None = None, **kwargs) -> MaxTextConfig:
  """Initializes the configuration by loading YAML files, and applying CLI, env, and kwarg overrides.

  Precedence, lowest to highest: the base YAML (with its `base_config` parents), the
  model YAML (`models/<model_name>.yml`), then CLI/kwargs overrides (kwargs win over
  CLI). `M_<KEY>` environment variables are applied last, to keys not set via
  CLI/kwargs. `logical_axis_rules` are merged by rule name unless
  `override_logical_axis_rules` is set.

  Returns the pydantic MaxTextConfig, whereas `initialize` returns the legacy
  `HyperParameters` wrapper.

  Raises:
    ValueError: On conflicting overrides, unknown keys, unparsable env values, or when
      both `use_tokamax_splash` and `use_jax_splash` are enabled.
    TypeError: If an env override targets a key whose value type is unsupported.
  """
  # 1. Load base and inherited configs from file(s).
  config_path, cli_args = _resolve_or_infer_config(argv, **kwargs)
  base_yml_config = _load_config(config_path)

  # 2. Get overrides from CLI and kwargs.
  cli_cfg = omegaconf.OmegaConf.from_cli(cli_args)
  kwargs_cfg = omegaconf.OmegaConf.create(kwargs)
  overrides_cfg = omegaconf.OmegaConf.merge(cli_cfg, kwargs_cfg)

  # 3.1. Infer more configs if possible, and treat them as overrides.
  inferred_cfg = _resolve_or_infer_addl_config(**omegaconf.OmegaConf.merge(base_yml_config, overrides_cfg))
  overrides_cfg = omegaconf.OmegaConf.merge(overrides_cfg, inferred_cfg)
  temp_cfg = omegaconf.OmegaConf.merge(base_yml_config, overrides_cfg)
  override_model_config = temp_cfg.get("override_model_config")

  # 3.2. Handle model-specific config. The architecture of -Instruct and base models is
  # the same, so the base model name selects the model YAML. A null/missing model_name
  # is treated as "default" instead of crashing on the substring test.
  model_name = (temp_cfg.get("model_name") or "default").replace("-Instruct", "")
  model_cfg = {}
  if model_name != "default":
    model_config_path = _find_model_config_path(config_path, model_name)
    if os.path.exists(model_config_path):
      model_loaded_cfg = omegaconf.OmegaConf.load(model_config_path)
      if override_model_config:
        # Only apply model configs for keys not present in overrides.
        model_cfg = {k: v for k, v in model_loaded_cfg.items() if k not in overrides_cfg}
      else:
        model_cfg = model_loaded_cfg
        # Validate that no keys are overridden by both model config and CLI/kwargs.
        validate_no_keys_overridden_twice(model_loaded_cfg.keys(), overrides_cfg.keys())
    else:
      logger.warning("Model config for '%s' not found at %s", model_name, model_config_path)
  model_cfg_oc = omegaconf.OmegaConf.create(model_cfg)

  # 4. Manually merge logical_axis_rules to avoid OmegaConf's list replacement behavior.
  def _rules_of(cfg):
    rules_oc = cfg.get("logical_axis_rules", [])
    return omegaconf.OmegaConf.to_container(rules_oc, resolve=True) if rules_oc else []

  base_rules = _rules_of(base_yml_config)
  model_rules = _rules_of(model_cfg_oc)
  overrides_rules = _rules_of(overrides_cfg)
  merged_rules = _apply_rules(base_rules, model_rules, model_cfg_oc)
  merged_rules = _apply_rules(merged_rules, overrides_rules, overrides_cfg)

  # Remove the rules from the original configs before the main merge.
  for cfg in (base_yml_config, model_cfg_oc, overrides_cfg):
    if "logical_axis_rules" in cfg:
      del cfg["logical_axis_rules"]

  # 5. Final merge (base, model, then overrides) for all other keys.
  final_config = omegaconf.OmegaConf.merge(base_yml_config, model_cfg_oc, overrides_cfg)
  final_config["logical_axis_rules"] = merged_rules
  raw_keys_dict = omegaconf.OmegaConf.to_container(final_config, resolve=True)

  # 6. Handle environment variable overrides. Only the CLI key names are needed, so the
  # CLI config is not resolved on its own: interpolations in it may reference keys that
  # exist only in the merged config.
  explicit_keys = frozenset(cli_cfg.keys()) | frozenset(kwargs.keys())
  _apply_env_overrides(raw_keys_dict, explicit_keys, model_cfg.keys(), override_model_config)

  pydantic_kwargs = _prepare_for_pydantic(raw_keys_dict)

  if pydantic_kwargs.get("use_tokamax_splash") and pydantic_kwargs.get("use_jax_splash"):
    raise ValueError("At most one of `use_tokamax_splash` and `use_jax_splash` can be set to True.")

  # Configure JAX and initialize the distributed system before the device backend is initialized.
  if pydantic_kwargs.get("jax_debug_log_modules"):
    jax.config.update("jax_debug_log_modules", pydantic_kwargs["jax_debug_log_modules"])
  # Do not initialize jax distributed system during pytest runs.
  if "pytest" not in sys.modules:
    max_utils.maybe_initialize_jax_distributed_system(pydantic_kwargs)
  if pydantic_kwargs.get("jax_cache_dir"):
    from jax.experimental.compilation_cache import compilation_cache  # pylint: disable=import-outside-toplevel

    compilation_cache.set_cache_dir(os.path.expanduser(pydantic_kwargs["jax_cache_dir"]))

  pydantic_config = types.MaxTextConfig(**pydantic_kwargs)
  config = HyperParameters(pydantic_config)

  if config.log_config:
    for k, v in sorted(config.get_keys().items()):
      if k not in KEYS_NO_LOGGING:
        logger.info("Config param %s: %s", k, v)

  return pydantic_config
# Shim for backward compatibility with pyconfig_deprecated_test.py
validate_and_update_keys = pyconfig_deprecated.validate_and_update_keys

__all__ = ["initialize", "initialize_pydantic"]


class _CallablePyconfigModule(type(sys.modules[__name__])):
  """Module subclass that makes this module callable, i.e. `mt.pyconfig(...)`."""

  def __call__(self, argv: list[str] | None = None, **kwargs) -> HyperParameters:
    """Same as `initialize(argv, **kwargs)`."""
    return initialize(argv, **kwargs)


# Swap the class of the already-imported module object in place so calls are routed to __call__.
sys.modules[__name__].__class__ = _CallablePyconfigModule