maxtext.integration.vllm.maxtext_vllm_adapter.adapter module
vLLM adapter for MaxText models.
- maxtext.integration.vllm.maxtext_vllm_adapter.adapter.next_power_of_two(x)
Finds the smallest power of 2 >= x using bit manipulation.
- Parameters:
x (int) – The input number.
- Returns:
The smallest integer power of 2 that is >= x.
- Return type:
int
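One common bit-manipulation approach is sketched below. It is an illustration, not the adapter's own source. The result for x <= 0 is not documented for the adapter, so this sketch assumes such inputs return 1, the smallest power of two.

```python
def next_power_of_two(x: int) -> int:
    """Return the smallest power of two >= x (this sketch returns 1 for x <= 1)."""
    if x <= 1:
        # Assumption: 2**0 == 1 is returned for any x <= 1, including non-positive x.
        return 1
    # (x - 1).bit_length() is the number of bits needed to represent x - 1,
    # so shifting 1 left by that many bits yields the next power of two >= x.
    return 1 << (x - 1).bit_length()


print([next_power_of_two(n) for n in (1, 3, 4, 5, 1000)])  # [1, 4, 4, 8, 1024]
```

Subtracting 1 before taking the bit length keeps exact powers of two unchanged. For example, 4 maps to 4 rather than 8.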
- maxtext.integration.vllm.maxtext_vllm_adapter.adapter.generate_maxtext_config(vllm_config, mesh)
Generates a MaxText configuration from a vLLM configuration.
This function takes a vLLM configuration object and translates the relevant parameters into a MaxText HyperParameters object. It reads the load paths and the model name from the vLLM config and applies a base MaxText vLLM configuration file.
- Parameters:
vllm_config (vllm.config.VllmConfig) – The vLLM configuration object containing model and load parameters.
mesh (Mesh) – The JAX device mesh used for model sharding.
- Returns:
A pyconfig.HyperParameters object configured for MaxText.
- Raises:
ValueError – If hf_config_path is not provided in the vLLM model config.
- Return type:
HyperParameters
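The validation and translation step can be sketched as follows. This is a minimal sketch that makes several assumptions. The *Stub dataclasses are hypothetical stand-ins for the parts of vllm.config.VllmConfig being read. The override keys (model_name, hf_config_path, load_dir) are also hypothetical. The last step of the real function, building a pyconfig.HyperParameters object from the base MaxText vLLM configuration file, is omitted because its exact call is not documented here. Only the documented behavior is modeled: raising ValueError when hf_config_path is missing.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


# Hypothetical stand-ins for the vLLM config objects this sketch reads.
@dataclass
class ModelConfigStub:
    model: str
    hf_config_path: Optional[str] = None


@dataclass
class LoadConfigStub:
    download_dir: Optional[str] = None


@dataclass
class VllmConfigStub:
    model_config: ModelConfigStub
    load_config: LoadConfigStub = field(default_factory=LoadConfigStub)


def build_maxtext_overrides(vllm_config: VllmConfigStub) -> Dict[str, Any]:
    """Collect MaxText overrides from a vLLM-style config (all keys hypothetical)."""
    hf_config_path = vllm_config.model_config.hf_config_path
    if not hf_config_path:
        # Mirrors the documented behavior: hf_config_path is required.
        raise ValueError("hf_config_path must be provided in the vLLM model config.")
    return {
        "model_name": vllm_config.model_config.model,      # hypothetical key
        "hf_config_path": hf_config_path,                  # hypothetical key
        "load_dir": vllm_config.load_config.download_dir,  # hypothetical key
    }
```

Validating the required field before collecting any overrides means a misconfigured deployment fails immediately with a clear error, before any MaxText configuration work starts.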
- class maxtext.integration.vllm.maxtext_vllm_adapter.adapter.MaxTextForCausalLM(*args, **kwargs)
Bases: Module
A vLLM-compatible causal language model wrapper for MaxText.
This class serves as the primary interface for integrating MaxText models into the vLLM serving framework, specifically for causal language modeling tasks. It handles configuration generation, model initialization, and execution of the decoding step.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- forward(*args, **kwargs)
Alias for __call__ for compatibility.
- Parameters:
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns:
The result of the __call__ method.
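The alias pattern is illustrated below with a hypothetical toy class. It is not MaxTextForCausalLM, and its __call__ body is a placeholder rather than a real decoding step.

```python
from typing import Any


class AliasedModule:
    """Hypothetical toy module whose forward() is an alias for __call__."""

    def __call__(self, x: int, scale: int = 1) -> int:
        # Placeholder computation standing in for the real model step.
        return x * scale

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        # Pass every argument through unchanged so both call styles agree.
        return self.__call__(*args, **kwargs)


m = AliasedModule()
print(m.forward(3, scale=2), m(3, scale=2))  # 6 6
```

Because forward forwards *args and **kwargs verbatim, it never needs updating when the signature of __call__ changes.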
- get_input_embeddings()
Returns the input embeddings of the model.
- Returns:
A JAX array representing the input embeddings.
- Return type:
Array
- embed_input_ids(input_ids)
Embeds the input token IDs using the model’s token embedder.
- Parameters:
input_ids (Array) – A JAX array of input token IDs.
- Returns:
A JAX array of embedded input tokens.
- Return type:
Array
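Conceptually, a token embedder performs a row lookup into a (vocab_size, embed_dim) table. The sketch below assumes the input embeddings take that form. It uses NumPy with a tiny made-up table, so it is not MaxText's sharded embedder. jax.numpy.take accepts the same (array, indices, axis) arguments, so the same lookup applies to JAX arrays.

```python
import numpy as np

# Hypothetical embedding table: 5-token vocabulary, embedding width 3.
vocab_size, embed_dim = 5, 3
embedding_table = np.arange(vocab_size * embed_dim, dtype=np.float32).reshape(vocab_size, embed_dim)


def embed_input_ids(table: np.ndarray, input_ids: np.ndarray) -> np.ndarray:
    """Gather one table row per token ID; result shape is input_ids.shape + (embed_dim,)."""
    return np.take(table, input_ids, axis=0)


input_ids = np.array([[0, 4], [2, 2]])  # batch of 2 sequences, each of length 2
embedded = embed_input_ids(embedding_table, input_ids)
print(embedded.shape)  # (2, 2, 3)
```

Gathering along axis 0 preserves the batch and sequence dimensions of input_ids and appends the embedding dimension.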