# Source code for maxtext.input_pipeline.synthetic_data_processing

# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input pipeline for synthetic dataset."""

from collections.abc import Callable
from typing import Any

import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

from maxtext.input_pipeline import multihost_dataloading
from maxtext.configs import pyconfig
from maxtext.utils import sharding


class SyntheticDataIterator:
  """Iterator that yields sharded synthetic batches for performance testing.

  One batch of random token ids is materialized at construction time and cached
  on ``self.data``. Every ``__next__`` call feeds that same cached tuple through
  a jitted slicing function whose outputs carry the sharding derived from
  ``config.data_sharding``, so each step returns an identical batch.
  """

  data_generator: Callable[[pyconfig.HyperParameters, tuple[Any, ...]], dict]

  def __init__(self, config, mesh):
    self.mesh = mesh
    self.config = config

    # Build the output sharding from the configured data PartitionSpec; the helper
    # presumably strips mesh axes of size one (per its name) -- see `sharding`.
    pspec = sharding.remove_size_one_mesh_axis(P(*config.data_sharding), mesh)
    out_shardings = jax.tree_util.tree_map(lambda spec: jax.sharding.NamedSharding(mesh, spec), pspec)
    # `config` (positional arg 0) is static, so it is a compile-time constant of the
    # jitted function; jax.jit requires static args to be hashable.
    self.data_generator = jax.jit(
        SyntheticDataIterator.raw_generate_synthetic_data,
        out_shardings=out_shardings,
        static_argnums=0,
    )

    batch_size = config.global_batch_size_to_load
    seq_len = config.max_target_length
    # One extra column so inputs/targets can be produced by a one-token shift.
    extended_shape = (batch_size, seq_len + 1)

    token_ids = jax.random.randint(
        jax.random.PRNGKey(0),
        extended_shape,
        minval=0,
        maxval=config.vocab_size,
        dtype=jnp.int32,
    )
    row_positions = jnp.arange(0, seq_len + 1, dtype=jnp.int32).reshape(1, -1)
    positions = jnp.broadcast_to(row_positions, extended_shape)
    # Every token belongs to segment 1 (a single, unpacked sequence per row).
    segment_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
    self.data = (token_ids, positions, segment_ids)

  def reset(self):
    """No-op: the cached synthetic batch never changes, so there is nothing to rewind."""

  def __iter__(self):
    return self

  def __next__(self):
    # Run the jitted generator inside the mesh context.
    with self.mesh:
      return self.data_generator(self.config, self.data)  # pylint: disable=not-callable

  @staticmethod
  def raw_generate_synthetic_data(config: pyconfig.HyperParameters, data):
    """Splits one cached synthetic batch into input and target fields.

    Args:
      config: Not read here; accepted so the jitted wrapper can take it as its
        static argument.
      data: ``(tokens, positions, segmentation)``. As built by ``__init__``,
        ``tokens``/``positions`` have ``max_target_length + 1`` columns and
        ``segmentation`` has ``max_target_length`` columns.

    Returns:
      dict with ``inputs*`` taken from all but the last column and ``targets*``
      taken from all but the first column (a one-token shift). Both segmentation
      entries are the same ``segmentation`` array.
    """
    tokens, positions, segmentation = data
    return {
        "inputs": tokens[:, :-1],
        "inputs_position": positions[:, :-1],
        "inputs_segmentation": segmentation,
        "targets": tokens[:, 1:],
        "targets_position": positions[:, 1:],
        "targets_segmentation": segmentation,
    }
class PlaceHolderDataIterator:
  """Placeholder iterator for setups where only a subset of hosts loads data.

  Every batch it produces is filled with -1. The per-host batches come from
  ``get_place_holder_synthetic_data`` and are wrapped in
  ``multihost_dataloading.MultiHostDataLoadIterator`` over ``self.mesh``.
  """

  def __init__(self, config: pyconfig.HyperParameters, mesh):
    self.mesh = mesh
    placeholder_batches = PlaceHolderDataIterator.get_place_holder_synthetic_data(config)
    self.data_generator = multihost_dataloading.MultiHostDataLoadIterator(placeholder_batches, self.mesh)

  def __iter__(self):
    # Iteration is handed off to the wrapped multi-host iterator itself.
    return self.data_generator

  def __next__(self):
    return next(self.data_generator)

  def reset(self):
    """No-op; this iterator keeps no state of its own to reset."""

  @staticmethod
  def get_place_holder_synthetic_data(config: pyconfig.HyperParameters):
    """Returns an endless generator that yields the same -1-filled batch dict.

    Args:
      config: Supplies ``global_batch_size_to_load`` (split evenly across
        ``jax.process_count()`` processes via floor division) and
        ``max_target_length``.

    Returns:
      A generator that yields, forever, one dict whose six input/target fields
      all refer to the same int32 array of shape
      ``(global_batch_size_to_load // process_count, max_target_length)``
      filled with -1.
    """
    per_host_batch = config.global_batch_size_to_load // jax.process_count()
    filler = np.full((per_host_batch, config.max_target_length), -1, dtype=np.int32)
    # All fields intentionally share the single `filler` array object.
    batch = dict.fromkeys(
        (
            "inputs",
            "inputs_position",
            "inputs_segmentation",
            "targets",
            "targets_position",
            "targets_segmentation",
        ),
        filler,
    )

    def repeat_forever():
      while True:
        yield batch

    return repeat_forever()