# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Input pipeline for synthetic dataset."""
from collections.abc import Callable
from typing import Any
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from maxtext.input_pipeline import multihost_dataloading
from maxtext.configs import pyconfig
from maxtext.utils import sharding
[docs]
class SyntheticDataIterator:
"""Creates a synthetic data iterator for performance testing work"""
data_generator: Callable[[pyconfig.HyperParameters, tuple[Any, ...]], dict]
def __init__(self, config, mesh):
self.mesh = mesh
self.config = config
data_pspec = sharding.remove_size_one_mesh_axis(P(*config.data_sharding), mesh)
data_pspec_shardings = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
self.data_generator = jax.jit(
SyntheticDataIterator.raw_generate_synthetic_data, out_shardings=data_pspec_shardings, static_argnums=0
)
tokens = jax.random.randint(
jax.random.PRNGKey(0),
(config.global_batch_size_to_load, config.max_target_length + 1),
0,
config.vocab_size,
dtype=jnp.int32,
)
sequence_positions = jnp.arange(0, config.max_target_length + 1, dtype=jnp.int32).reshape(1, -1)
batch_positions = jnp.broadcast_to(
sequence_positions, (config.global_batch_size_to_load, config.max_target_length + 1)
)
segmentation = jnp.ones((config.global_batch_size_to_load, config.max_target_length), dtype=jnp.int32)
self.data = (tokens, batch_positions, segmentation)
[docs]
def reset(self):
pass # Synthetic data is stateless; nothing to reset.
def __iter__(self):
return self
def __next__(self):
with self.mesh:
return self.data_generator(self.config, self.data) # pylint: disable=not-callable
[docs]
@staticmethod
def raw_generate_synthetic_data(config: pyconfig.HyperParameters, data):
"""Generates a single batch of synthetic data"""
tokens, positions, segmentation = data
output = {}
output["inputs"] = tokens[:, :-1]
output["inputs_position"] = positions[:, :-1]
output["inputs_segmentation"] = segmentation
output["targets"] = tokens[:, 1:]
output["targets_position"] = positions[:, 1:]
output["targets_segmentation"] = segmentation
return output
[docs]
class PlaceHolderDataIterator:
"""Creates a place holder synthetic data iterator for loading on subset of hosts"""
def __init__(self, config: pyconfig.HyperParameters, mesh):
self.mesh = mesh
dataset = PlaceHolderDataIterator.get_place_holder_synthetic_data(config)
self.data_generator = multihost_dataloading.MultiHostDataLoadIterator(dataset, self.mesh)
def __iter__(self):
return self.data_generator
def __next__(self):
return next(self.data_generator)
[docs]
@staticmethod
def get_place_holder_synthetic_data(config: pyconfig.HyperParameters):
"""fill negative value in synthetic data"""
batch_size = config.global_batch_size_to_load // jax.process_count()
neg_ones = np.full((batch_size, config.max_target_length), -1, dtype=np.int32)
batch = {
"inputs": neg_ones,
"inputs_position": neg_ones,
"inputs_segmentation": neg_ones,
"targets": neg_ones,
"targets_position": neg_ones,
"targets_segmentation": neg_ones,
}
def infinite_iterator():
while True:
yield batch
return infinite_iterator()