# Source code for maxtext.integration.tunix.tunix_adapter

# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Adapter for integrating MaxText Transformer models with Tunix.

This module provides the `TunixMaxTextAdapter` class, which wraps a MaxText
Transformer model to expose a call signature compatible with Tunix Trainers.
It also handles weight mapping for compatibility with Hugging Face models.
"""

from __future__ import annotations

from typing import Any, Optional, Tuple

from flax import nnx
from jax import Array
from maxtext.checkpoint_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS  # pylint: disable=ungrouped-imports
from maxtext.integration.tunix.utils import VllmWeightMapping
from maxtext.models.models import Transformer


class TunixMaxTextAdapter(nnx.Module):
  """Adapter exposing Tunix Trainer call signature over a Transformer model.

  The adapter holds the wrapped model in `self.base`, forwards calls to it
  using the keyword names the base model is invoked with here
  (`decoder_input_tokens`, `decoder_positions`, `decoder_segment_ids`), and
  delegates the Hugging Face weight-mapping queries to a `VllmWeightMapping`
  built from the model's name and its entry in `HF_MODEL_CONFIGS`.
  """

  def __init__(
      self,
      base_model: Transformer,
      use_standalone_mappings: bool = True,
      use_no_op_mappings: bool = False,
  ):
    """Wraps `base_model` and builds its weight-mapping helper.

    Args:
      base_model: The Transformer to wrap. Its `config.model_name` is used
        both as the key into `HF_MODEL_CONFIGS` and as the first argument to
        `VllmWeightMapping`.
      use_standalone_mappings: Passed through unchanged to
        `VllmWeightMapping`; its meaning is defined there.
      use_no_op_mappings: When True, every mapping accessor on this adapter
        (`to_hf_mappings`, `to_hf_transpose_keys`, `to_hf_hook_fns`,
        `lora_to_hf_mappings`) returns an empty dict instead of consulting
        the weight mapping.
    """
    super().__init__()
    self.base = base_model

    # NOTE: the HF config lookup and the mapping construction happen
    # unconditionally, i.e. also when `use_no_op_mappings` is True, so the
    # model name must be present in HF_MODEL_CONFIGS either way.
    model_name = self.base.config.model_name
    self._vllm_weight_mapping = VllmWeightMapping(
        model_name,
        HF_MODEL_CONFIGS[model_name].to_dict(),
        use_standalone_mappings,
    )
    self.use_no_op_mappings = use_no_op_mappings

  # ------------------------------------------------------------------ #
  # Tunix call signature
  # ------------------------------------------------------------------ #
  def __call__(
      self,
      input_tokens: Array,  # [B, L]
      positions: Array,  # [B, L]
      cache: Optional[Any],  # Tunix currently passes None from Trainers
      attention_mask: Optional[Array],  # [B, L, L] or None
      decoder_segment_ids: Optional[Array] = None,
      output_hidden_states: bool = False,  # ignored
  ) -> Tuple[Array, None]:
    """Forward compatible with Tunix Trainers default loss.

    Only `input_tokens`, `positions` and `decoder_segment_ids` are forwarded
    to the wrapped model. `cache`, `attention_mask` and
    `output_hidden_states` are accepted so the signature matches what Tunix
    passes, but they are not used by this method.

    Args:
      input_tokens: Token ids, forwarded as `decoder_input_tokens`.
      positions: Token positions, forwarded as `decoder_positions`.
      cache: Unused.
      attention_mask: Unused.
      decoder_segment_ids: Optional segment ids, forwarded as-is.
      output_hidden_states: Unused.

    Returns:
      A `(logits, None)` tuple, where `logits` is whatever the wrapped model
      returns and the second element is always None.
    """
    logits = self.base(
        decoder_input_tokens=input_tokens,
        decoder_positions=positions,
        decoder_segment_ids=decoder_segment_ids,
    )
    return logits, None

  def to_hf_mappings(self):
    """Returns the weight mapping's `to_hf_mapping()`, or `{}` in no-op mode."""
    if self.use_no_op_mappings:
      return {}
    return self._vllm_weight_mapping.to_hf_mapping()

  def to_hf_transpose_keys(self):
    """Returns the weight mapping's `to_hf_transpose_keys()`, or `{}` in no-op mode."""
    if self.use_no_op_mappings:
      return {}
    return self._vllm_weight_mapping.to_hf_transpose_keys()

  def to_hf_hook_fns(self):
    """Returns the weight mapping's `to_hf_hook_fns()`, or `{}` in no-op mode."""
    if self.use_no_op_mappings:
      return {}
    return self._vllm_weight_mapping.to_hf_hook_fns()

  def lora_to_hf_mappings(self):
    """Returns the weight mapping's `lora_to_hf_mappings()`, or `{}` in no-op mode."""
    if self.use_no_op_mappings:
      return {}
    return self._vllm_weight_mapping.lora_to_hf_mappings()