maxtext.integration.vllm.maxtext_vllm_adapter.adapter module

vLLM adapter for MaxText models.

maxtext.integration.vllm.maxtext_vllm_adapter.adapter.next_power_of_two(x)

Finds the smallest power of 2 >= x using bit manipulation.

Parameters:

x (int) – The input number (should be an integer).

Returns:

The smallest integer power of 2 that is >= x.

Return type:

int
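
The description above can be illustrated with a minimal sketch. This is not the module's actual implementation: the name next_power_of_two_sketch is hypothetical, and returning 1 for inputs below 1 is an assumption, since the reference does not say how such inputs are handled.

```python
def next_power_of_two_sketch(x: int) -> int:
    """Return the smallest power of 2 that is >= x (illustrative only)."""
    # Assumption: inputs below 1 map to 1 (2**0); the reference does not
    # specify the behavior for zero or negative values.
    if x <= 1:
        return 1
    # (x - 1).bit_length() is the number of bits needed to represent x - 1.
    # Shifting 1 left by that amount gives the first power of 2 that is >= x.
    return 1 << (x - 1).bit_length()


for value in (1, 2, 3, 5, 1024, 1025):
    print(value, next_power_of_two_sketch(value))
```

Subtracting 1 before taking the bit length is what keeps exact powers of 2 unchanged instead of doubling them.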

maxtext.integration.vllm.maxtext_vllm_adapter.adapter.generate_maxtext_config(vllm_config, mesh)

Generates a MaxText configuration from a vLLM configuration.

This function translates the relevant parameters of a vLLM configuration object into a MaxText HyperParameters object. It reads load paths and model names from the vLLM config and applies a base MaxText vLLM configuration file.

Parameters:
  • vllm_config (vllm.config.VllmConfig) – The vLLM configuration object containing model and load parameters.

  • mesh (Mesh) – The JAX device mesh used for model sharding.

Returns:

A pyconfig.HyperParameters object configured for MaxText.

Raises:

ValueError – If hf_config_path is not provided in the vLLM model config.

Return type:

HyperParameters
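
The ValueError condition documented above can be sketched as a standalone check. This is an illustration under assumptions, not the adapter's code: require_hf_config_path is a hypothetical helper, the attribute location vllm_config.model_config.hf_config_path is assumed, and SimpleNamespace objects stand in for a real vllm.config.VllmConfig.

```python
from types import SimpleNamespace


def require_hf_config_path(vllm_config):
    """Return hf_config_path from a vLLM-like config, or raise ValueError."""
    # Assumption: the path is exposed as vllm_config.model_config.hf_config_path.
    path = getattr(vllm_config.model_config, "hf_config_path", None)
    if not path:
        raise ValueError("hf_config_path must be provided in the vLLM model config.")
    return path


# Hypothetical stand-in for a vLLM configuration with no hf_config_path set.
missing = SimpleNamespace(model_config=SimpleNamespace(hf_config_path=None))

try:
    require_hf_config_path(missing)
except ValueError as err:
    print(f"rejected: {err}")
```

Checking the path before any translation work means a misconfigured server fails early with a clear message rather than partway through model setup.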

class maxtext.integration.vllm.maxtext_vllm_adapter.adapter.MaxTextForCausalLM(*args, **kwargs)

Bases: Module

A vLLM-compatible causal language model wrapper for MaxText.

This class serves as the primary interface for integrating MaxText models into the vLLM serving framework, specifically for causal language modeling tasks. It handles configuration generation, model initialization, and execution of the decoding step.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

forward(*args, **kwargs)

Alias for __call__ for compatibility.

Parameters:
  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns:

The result of the __call__ method.
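
The aliasing pattern described above can be sketched with a toy class. CallableModel is a hypothetical stand-in and does not reflect the internals of MaxTextForCausalLM; it only shows forward delegating to __call__.

```python
class CallableModel:
    """Toy stand-in (hypothetical) showing forward as an alias for __call__."""

    def __call__(self, *args, **kwargs):
        # A real model would run its decoding step here; this toy version
        # simply echoes its inputs so the delegation is easy to observe.
        return ("called", args, kwargs)

    def forward(self, *args, **kwargs):
        # Delegate to __call__ so both entry points behave identically.
        return self(*args, **kwargs)


model = CallableModel()
print(model.forward(1, flag=True) == model(1, flag=True))
```

Keeping forward as a thin delegate lets callers that expect a forward method and callers that invoke the object directly share a single code path.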

get_input_embeddings()

Returns the input embeddings of the model.

Returns:

A JAX array representing the input embeddings.

Return type:

Array

embed_input_ids(input_ids)

Embeds the input token IDs using the model’s token embedder.

Parameters:

input_ids (Array) – A JAX array of input token IDs.

Returns:

A JAX array of embedded input tokens.

Return type:

Array

compute_logits(hidden_states)

Computes the logits from the hidden states using the underlying decoder model.

Parameters:

hidden_states (Array) – A JAX array of hidden states.

Returns:

A JAX array of logits.

Return type:

Array

load_weights(rng_key)

Loads model weights using the underlying decoder model.

Parameters:

rng_key (Array) – A JAX random key for model initialization.

Return type:

None