# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for encoder layers."""
import jax
from flax import nnx
from jax.sharding import Mesh
from maxtext.common.common_types import Config
from maxtext.layers import nnx_wrappers
from maxtext.layers import initializers
class VisionEncoder(nnx.Module):
  """Vision encoder to encode images into soft tokens.

  The concrete encoder and projector sub-modules are selected from
  ``config.model_name`` at construction time and attached to ``self`` as
  attributes; their attribute names are recorded in ``self.encoder_name`` and
  ``self.projector_name``.
  """

  def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs):
    self.config = config
    self.mesh = mesh
    self.rngs = rngs
    self.encoder_name, self.projector_name = self._setup_vision_encoder_layers()

  def _setup_vision_encoder_layers(self):
    """Setup vision encoder layers specific to the model, instantiate NNX modules.

    Returns:
      ``(encoder_name, projector_name)``: the attribute names under which the
      encoder and projector modules were attached to ``self``.

    Raises:
      ValueError: if ``config.model_name`` has no vision encoder implementation.
    """
    if self.config.model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b"]:
      from maxtext.models import gemma3  # pylint: disable=import-outside-toplevel

      encoder_name = "Gemma3VisionEncoderLayer_0"
      projector_name = "VisionEmbedder_0"
      setattr(self, encoder_name, gemma3.Gemma3VisionEncoderLayer(config=self.config, mesh=self.mesh, rngs=self.rngs))
      setattr(self, projector_name, gemma3.VisionEmbedder(config=self.config, mesh=self.mesh, rngs=self.rngs))
      return encoder_name, projector_name
    elif self.config.model_name in ["llama4-17b-16e", "llama4-17b-128e"]:
      from maxtext.models import llama4  # pylint: disable=import-outside-toplevel

      encoder_name = "Llama4VisionModel_0"
      projector_name = "Llama4MultiModalProjector_0"
      setattr(self, encoder_name, llama4.Llama4VisionModel(config=self.config, mesh=self.mesh, rngs=self.rngs))
      setattr(self, projector_name, llama4.Llama4MultiModalProjector(config=self.config, mesh=self.mesh, rngs=self.rngs))
      return encoder_name, projector_name
    elif self.config.model_name in ["qwen3-omni-30b-a3b"]:
      from maxtext.models import qwen3  # pylint: disable=import-outside-toplevel

      encoder_name = "Qwen3OmniMoeVisionEncoder_0"
      projector_name = "Qwen3OmniMoeVisionProjector_0"
      setattr(self, encoder_name, qwen3.Qwen3OmniMoeVisionEncoder(config=self.config, mesh=self.mesh, rngs=self.rngs))
      # Unlike the other projectors, this one is constructed without a mesh argument.
      setattr(self, projector_name, qwen3.Qwen3OmniMoeVisionProjector(config=self.config, rngs=self.rngs))
      return encoder_name, projector_name
    elif self.config.model_name in ["gemma4-26b", "gemma4-31b"]:
      from maxtext.models import gemma4_vision  # pylint: disable=import-outside-toplevel

      encoder_name = "Gemma4VisionEncoderLayer_0"
      projector_name = "Gemma4VisionProjector_0"
      setattr(
          self, encoder_name, gemma4_vision.Gemma4VisionEncoderLayer(config=self.config, mesh=self.mesh, rngs=self.rngs)
      )
      setattr(
          self, projector_name, gemma4_vision.Gemma4VisionProjector(config=self.config, mesh=self.mesh, rngs=self.rngs)
      )
      return encoder_name, projector_name
    else:
      raise ValueError(f"No VisionEncoder implemented for {self.config.model_name} yet")

  def __call__(self, input_images, deterministic=False):
    """Encode ``input_images`` and project the result into soft tokens.

    Args:
      input_images: Images forwarded unchanged to the model-specific encoder.
      deterministic: Forwarded to the encoder's ``deterministic`` argument.

    Returns:
      ``(embeddings, deep_feats)``: ``embeddings`` is the projector output;
      ``deep_feats`` is the second element of the encoder output when the
      encoder returns a tuple, otherwise ``None``. Its container structure is
      the same whether or not the encoder is frozen.
    """
    # vision encoder output, frozen params in many cases
    encoder = getattr(self, self.encoder_name)
    encoder_output = encoder(input_images, deterministic=deterministic)

    # Some encoders return (embeddings, deep_feats, ...); others return embeddings only.
    deep_feats = None
    if isinstance(encoder_output, tuple):
      embeddings = encoder_output[0]
      deep_feats = encoder_output[1]
    else:
      embeddings = encoder_output

    if self.config.freeze_vision_encoder_params:
      # Block gradients from reaching the encoder's parameters.
      embeddings = jax.lax.stop_gradient(embeddings)
      if deep_feats is not None:
        # stop_gradient maps over pytrees, so deep_feats keeps the container type
        # the encoder produced (matching the unfrozen path). Rebuilding it as a list
        # by iteration would turn tuples into lists and split a bare array into
        # leading-axis slices.
        deep_feats = jax.lax.stop_gradient(deep_feats)

    # vision embedder / projection layer, not frozen in most cases, trained / finetuned together with main model
    projector = getattr(self, self.projector_name)
    embeddings = projector(embeddings)
    return embeddings, deep_feats
class AudioEncoder(nnx.Module):
  """Turns audio features into soft tokens via a model-specific encoder + projector.

  Both sub-modules are built from ``config.model_name`` when the module is
  constructed; the attribute names that hold them are kept in
  ``self.encoder_name`` and ``self.projector_name``.
  """

  def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs):
    self.config = config
    self.mesh = mesh
    self.rngs = rngs
    self.encoder_name, self.projector_name = self._setup_audio_encoder_layers()

  def _setup_audio_encoder_layers(self):
    """Build and attach the audio encoder and projector for the configured model.

    Only ``"qwen3-omni-30b-a3b"`` is currently supported.

    Returns:
      A pair ``(encoder_attr, projector_attr)`` naming the attributes on ``self``
      that now hold the encoder and projector modules.

    Raises:
      ValueError: for any other ``config.model_name``.
    """
    if self.config.model_name not in ("qwen3-omni-30b-a3b",):
      raise ValueError(f"No AudioEncoder implemented for {self.config.model_name} yet")

    from maxtext.models import qwen3  # pylint: disable=import-outside-toplevel

    encoder_attr = "Qwen3OmniAudioEncoder_0"
    projector_attr = "Qwen3OmniAudioProjector_0"
    # The encoder is created (and attached) before the projector.
    setattr(self, encoder_attr, qwen3.Qwen3OmniAudioEncoder(config=self.config, mesh=self.mesh, rngs=self.rngs))
    # The projector takes no mesh argument.
    setattr(self, projector_attr, qwen3.Qwen3OmniAudioProjector(config=self.config, rngs=self.rngs))
    return encoder_attr, projector_attr

  def __call__(self, input_audio, deterministic=False):
    """Encode ``input_audio`` and project the encoder output into soft tokens.

    Args:
      input_audio: Audio input forwarded unchanged to the encoder.
      deterministic: Forwarded to the encoder's ``deterministic`` argument.

    Returns:
      The projector's output.
    """
    # Encoder output is taken before the projector is applied.
    hidden = getattr(self, self.encoder_name)(input_audio, deterministic=deterministic)
    if self.config.freeze_audio_encoder_params:
      # Keep gradients from flowing back into the encoder's parameters.
      hidden = jax.lax.stop_gradient(hidden)
    return getattr(self, self.projector_name)(hidden)
def vision_encoder_as_linen(config: Config, mesh: Mesh):
  """Wrap :class:`VisionEncoder` with ``nnx_wrappers.to_linen``.

  The wrapped module is named ``"vision_encoder"`` and is built with
  ``abstract_init=False``. Variable metadata is passed through
  ``initializers.variable_to_logically_partitioned`` (presumably to attach
  logical-partitioning annotations; confirm against that helper).

  Args:
    config: Model config forwarded to the ``VisionEncoder`` constructor.
    mesh: Device mesh forwarded to the ``VisionEncoder`` constructor.

  Returns:
    Whatever ``nnx_wrappers.to_linen`` returns for ``VisionEncoder``.
  """
  return nnx_wrappers.to_linen(
      VisionEncoder,
      config=config,
      mesh=mesh,
      name="vision_encoder",
      abstract_init=False,
      metadata_fn=initializers.variable_to_logically_partitioned,
  )
def audio_encoder_as_linen(config: Config, mesh: Mesh):
  """Wrap :class:`AudioEncoder` with ``nnx_wrappers.to_linen``.

  The wrapped module is named ``"audio_encoder"`` and is built with
  ``abstract_init=False``. Variable metadata is passed through
  ``initializers.variable_to_logically_partitioned`` (presumably to attach
  logical-partitioning annotations; confirm against that helper).

  Args:
    config: Model config forwarded to the ``AudioEncoder`` constructor.
    mesh: Device mesh forwarded to the ``AudioEncoder`` constructor.

  Returns:
    Whatever ``nnx_wrappers.to_linen`` returns for ``AudioEncoder``.
  """
  return nnx_wrappers.to_linen(
      AudioEncoder,
      config=config,
      mesh=mesh,
      name="audio_encoder",
      abstract_init=False,
      metadata_fn=initializers.variable_to_logically_partitioned,
  )