# Source code for maxtext.integration.tunix.tunix_adapter
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Adapter for integrating MaxText Transformer models with Tunix.
This module provides the `TunixMaxTextAdapter` class, which wraps a MaxText
Transformer model to expose a call signature compatible with Tunix Trainers.
It also handles weight mapping for compatibility with Hugging Face models.
"""
from __future__ import annotations
from typing import Any, Optional, Tuple
from flax import nnx
from jax import Array
from maxtext.checkpoint_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS # pylint: disable=ungrouped-imports
from maxtext.integration.tunix.utils import VllmWeightMapping
from maxtext.models.models import Transformer
[docs]
class TunixMaxTextAdapter(nnx.Module):
  """Wrap a MaxText ``Transformer`` behind the call signature Tunix Trainers use.

  The adapter stores the wrapped model as ``self.base`` and builds a
  ``VllmWeightMapping`` for the model named in ``base_model.config.model_name``.
  The ``to_hf_*`` / ``lora_to_hf_mappings`` accessors delegate to that mapping
  object unless ``use_no_op_mappings`` is set, in which case they all return an
  empty dict.
  """

  def __init__(
      self,
      base_model: Transformer,
      use_standalone_mappings: bool = True,
      use_no_op_mappings: bool = False,
  ):
    """Store the base model and build its weight-mapping helper.

    Args:
      base_model: The Transformer to wrap; its ``config.model_name`` is used as
        the key into ``HF_MODEL_CONFIGS``.
      use_standalone_mappings: Forwarded unchanged as the third positional
        argument of ``VllmWeightMapping``.
      use_no_op_mappings: When True, every mapping accessor on this adapter
        returns ``{}`` instead of consulting the weight mapping.

    Raises:
      KeyError: If ``HF_MODEL_CONFIGS`` raises it for the model name (as a plain
        dict would for a missing key).
    """
    super().__init__()
    self.base = base_model
    model_name = self.base.config.model_name
    hf_config = HF_MODEL_CONFIGS[model_name].to_dict()
    # The mapping is built even when no-op mappings are requested.
    self._vllm_weight_mapping = VllmWeightMapping(
        model_name,
        hf_config,
        use_standalone_mappings,
    )
    self.use_no_op_mappings = use_no_op_mappings

  # ------------------------------------------------------------------ #
  # Call signature expected by Tunix
  # ------------------------------------------------------------------ #
  def __call__(
      self,
      input_tokens: Array,  # [B, L]
      positions: Array,  # [B, L]
      cache: Optional[Any],  # Tunix currently passes None from Trainers
      attention_mask: Optional[Array],  # [B, L, L] or None
      decoder_segment_ids: Optional[Array] = None,
      output_hidden_states: bool = False,  # ignored
  ) -> Tuple[Array, None]:
    """Run the wrapped model and return ``(logits, None)``.

    Only ``input_tokens``, ``positions`` and ``decoder_segment_ids`` are
    forwarded to the base model. ``cache``, ``attention_mask`` and
    ``output_hidden_states`` are accepted for signature compatibility with
    Tunix Trainers but are not used here.

    Returns:
      A 2-tuple of the base model's output and ``None``.
    """
    return (
        self.base(
            decoder_input_tokens=input_tokens,
            decoder_positions=positions,
            decoder_segment_ids=decoder_segment_ids,
        ),
        None,
    )

  def to_hf_mappings(self):
    """Return the weight mapping's ``to_hf_mapping()``, or ``{}`` in no-op mode."""
    return {} if self.use_no_op_mappings else self._vllm_weight_mapping.to_hf_mapping()

  def to_hf_transpose_keys(self):
    """Return the weight mapping's ``to_hf_transpose_keys()``, or ``{}`` in no-op mode."""
    return {} if self.use_no_op_mappings else self._vllm_weight_mapping.to_hf_transpose_keys()

  def to_hf_hook_fns(self):
    """Return the weight mapping's ``to_hf_hook_fns()``, or ``{}`` in no-op mode."""
    return {} if self.use_no_op_mappings else self._vllm_weight_mapping.to_hf_hook_fns()

  def lora_to_hf_mappings(self):
    """Return the weight mapping's ``lora_to_hf_mappings()``, or ``{}`` in no-op mode."""
    return {} if self.use_no_op_mappings else self._vllm_weight_mapping.lora_to_hf_mappings()