Update MaxText dependencies
Introduction
This document explains how to update MaxText dependencies using the seed-env tool. seed-env generates deterministic, reproducible Python environments by producing fully pinned requirements.txt files from a base set of requirements.
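In a fully pinned file, every requirement names an exact version with `==`, while a base file may use ranges such as `>=` or no version at all. The helper below is a hypothetical bash sketch (not part of MaxText or seed-env) for spotting unpinned lines. It is only a heuristic: it ignores comments, environment markers, blank lines, and pip option lines such as `--extra-index-url` or `--hash`.

```shell
# Hypothetical helper (not part of MaxText or seed-env): print each
# requirement that is not pinned to an exact version with "==".
# Returns 0 if every requirement line is pinned, 1 otherwise.
# Heuristic only: comments, environment markers, blank lines, and pip
# option lines (starting with "-") are ignored.
list_unpinned() {
  awk '
    {
      sub(/#.*/, "")                # drop comments
      sub(/;.*/, "")                # drop environment markers
      sub(/[[:space:]]+$/, "")      # drop trailing whitespace
    }
    /^[[:space:]]*$/ { next }       # skip blank lines
    /^[[:space:]]*-/ { next }       # skip options such as -r or --hash
    !/==/ { print; bad = 1 }        # report lines without an exact pin
    END { exit (bad ? 1 : 0) }
  ' "$1"
}
```

Run against a generated file, for example `list_unpinned src/dependencies/requirements/generated_requirements/tpu-requirements.txt`, it lists any requirement that does not carry an exact `==` pin.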
Please keep dependencies up to date throughout development, so that each commit works correctly from both a feature and a dependency perspective. We periodically publish stable releases to PyPI, but users who install MaxText from source rely on every commit having its dependencies in sync.
Overview of the process
To update dependencies, follow these general steps:
1. Modify base requirements: Update the desired dependencies in src/dependencies/requirements/base_requirements/requirements.txt, the hardware-specific pre-training files (src/dependencies/requirements/base_requirements/tpu-requirements.txt, src/dependencies/requirements/base_requirements/cuda12-requirements.txt), or the post-training files (src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt).
2. Find the JAX build commit hash: The dependency generation process is pinned to a specific nightly build of JAX, so you need the commit hash of the desired JAX build.
3. Generate the requirements files: Run src/dependencies/scripts/generate_requirements.sh, which internally invokes seed-env to produce fully pinned requirements files.
4. Verify the new dependencies: Test the new dependencies to ensure the project installs and runs correctly.
The following sections provide detailed instructions for each step.
Step 0: Install seed-env
First, install the seed-env command-line tool. We recommend first installing uv by following uv’s official installation instructions, and then using it to install seed-env:
uv venv --python 3.12 --seed seed_venv
source seed_venv/bin/activate
uv pip install seed-env
Alternatively, if you want to build seed-env from source, follow the instructions in the seed-env repository.
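Step 3 calls the seed-env CLI through generate_requirements.sh, presumably from whatever environment is active in your shell, so it can help to fail fast if the tool is not on your PATH. The hypothetical bash helper below is not part of MaxText; it only uses the shell builtin `command -v`.

```shell
# Hypothetical helper (not part of MaxText): check that a CLI is on PATH.
# Prints a hint to stderr and returns 1 if the command is missing.
require_cmd() {
  if ! command -v "$1" >/dev/null 2>&1; then
    echo "error: '$1' not found on PATH (is seed_venv activated?)" >&2
    return 1
  fi
}
```

After `source seed_venv/bin/activate`, running `require_cmd seed-env` should succeed silently before you move on.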
Step 1: Modify base requirements
Update the desired dependencies in src/dependencies/requirements/base_requirements/requirements.txt, in the hardware-specific pre-training files (src/dependencies/requirements/base_requirements/tpu-requirements.txt, src/dependencies/requirements/base_requirements/cuda12-requirements.txt), or in the post-training files (src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt).
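Before editing, it may help to see which base files already declare a package, since the same package could appear in more than one of them. The bash helper below is a hypothetical sketch built on grep, not part of MaxText. It matches names case-insensitively but does not apply pip's equivalence of `-`, `_`, and `.`.

```shell
# Hypothetical helper (not part of MaxText): show every line in the base
# requirements files that declares a given package, with file and line number.
# Names are matched case-insensitively; pip's equivalence of "-", "_" and "."
# is not handled, and the name is used as a regular expression.
# Usage: find_requirement <package> [base-requirements-dir]
find_requirement() {
  local pkg="$1"
  local dir="${2:-src/dependencies/requirements/base_requirements}"
  # The name must be followed by extras "[", a version operator, "@",
  # an environment marker ";", or the end of the line.
  grep -H -n -i -E "^[[:space:]]*${pkg}[[:space:]]*([[<>=!~;@]|$)" "$dir"/*.txt
}
```

For example, `find_requirement jax` lists lines such as `jax[tpu]...` or a bare `jax` in any base file, but not `jaxlib`.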
Step 2: Find the JAX build commit hash
The dependency generation process is pinned to a specific nightly build of JAX. Find the desired JAX build in the build/ folder of the JAX repository and copy its full commit hash.
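One possible way to browse candidate hashes, assuming the build you want corresponds to a commit in the history of that folder, is from a local clone of the JAX repository (assumed here to be https://github.com/jax-ml/jax). Both bash helpers below are hypothetical and not part of MaxText or seed-env. The first lists recent commits that touched build/ with their full hashes. The second rejects abbreviated hashes, since this step asks for the full one.

```shell
# Hypothetical helpers (not part of MaxText or seed-env) for Step 2.

# List recent commits that touched build/ in a local JAX clone, showing the
# full hash, commit date, and subject. Assumes the clone was created with:
#   git clone https://github.com/jax-ml/jax.git
# Usage: recent_jax_build_commits <jax-clone-dir> [count]
recent_jax_build_commits() {
  git -C "$1" log -n "${2:-5}" --format='%H %cs %s' -- build/
}

# Succeed only for a full 40-character lowercase hex SHA-1 hash, so an
# abbreviated hash such as e0d2967 is rejected.
is_full_commit_hash() {
  [ "${#1}" -eq 40 ] || return 1
  case "$1" in
    *[!0-9a-f]*) return 1 ;;
  esac
}
```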
Step 3: Generate the requirements files
Next, run generate_requirements.sh to generate the new requirements files. This script wraps the seed-env CLI and handles exporting the lock file and applying any overrides. Run it separately for each environment whose base requirements you changed: TPU pre-training, TPU post-training, and GPU pre-training.
Note: The files currently in src/dependencies/requirements/generated_requirements/ were generated using JAX build commit hash e0d2967b50abbefd651d563dbcd7afbcb963d08c.
TPU Pre-Training
If you changed the TPU pre-training dependencies in src/dependencies/requirements/base_requirements/tpu-requirements.txt, regenerate the pinned pre-training requirements in the generated_requirements/ directory. Run the following commands, replacing <jax-build-commit-hash> with the hash you copied in Step 2:
bash src/dependencies/scripts/generate_requirements.sh \
--base-requirements src/dependencies/requirements/base_requirements/tpu-requirements.txt \
--generated-requirements tpu-requirements.txt \
--override-requirements src/dependencies/extra_deps/tpu_overrides.txt \
--seed-commit <jax-build-commit-hash>
# Move the generated file into src/dependencies/requirements/generated_requirements
mv generated_artifacts/python3_12/tpu-requirements.txt \
src/dependencies/requirements/generated_requirements/tpu-requirements.txt
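Before committing, it can help to review which pins actually changed. The mv above overwrites the tracked file, but the previous version is still available from git. The bash helper below is a hypothetical sketch, not part of MaxText. It compares two pinned files and prints every package whose version changed, was added, or was removed, and it assumes simple `name==version` lines.

```shell
# Hypothetical helper (not part of MaxText): compare two pinned requirements
# files and print "<package> <old> -> <new>" for every package whose "=="
# pin differs, using "(none)" when a package is present on only one side.
# Only simple name==version lines are considered; comments, environment
# markers, line continuations, and option lines such as --hash are ignored.
# If the first file is empty, the second is mistaken for it.
diff_pins() {
  awk '
    function parse(line,   parts) {
      sub(/[#;].*/, "", line)           # drop comments and markers
      gsub(/[[:space:]\\]/, "", line)   # drop whitespace and continuations
      if (split(line, parts, "==") != 2) return ""
      return tolower(parts[1]) " " parts[2]
    }
    FNR == 1 { file++ }
    {
      kv = parse($0)
      if (kv == "") next
      split(kv, f, " ")
      if (file == 1) { old[f[1]] = f[2] } else { cur[f[1]] = f[2] }
      seen[f[1]] = 1
    }
    END {
      for (p in seen) {
        o = (p in old) ? old[p] : "(none)"
        n = (p in cur) ? cur[p] : "(none)"
        if (o != n) print p, o, "->", n
      }
    }
  ' "$1" "$2" | sort
}

# Example (bash process substitution): compare the committed TPU file with
# the regenerated one after the mv above.
# f=src/dependencies/requirements/generated_requirements/tpu-requirements.txt
# diff_pins <(git show "HEAD:$f") "$f"
```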
TPU Post-Training
If you changed the post-training dependencies in src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt, regenerate the pinned post-training requirements in the generated_requirements/ directory. Run the following commands, replacing <jax-build-commit-hash> with the hash you copied in Step 2:
bash src/dependencies/scripts/generate_requirements.sh \
--base-requirements src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt \
--generated-requirements tpu-post-train-requirements.txt \
--override-requirements src/dependencies/extra_deps/tpu_post_train_overrides.txt \
--seed-commit <jax-build-commit-hash>
# Move the generated file into src/dependencies/requirements/generated_requirements
mv generated_artifacts/python3_12/tpu-post-train-requirements.txt \
src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt
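The pre-training and post-training files come from separate runs with different override files (tpu_overrides.txt and tpu_post_train_overrides.txt). It may therefore be worth confirming that packages they share, such as jax, resolved to the versions you expect in both. The bash helper below is a hypothetical sketch, not part of MaxText. It prints the `==` pin for one package, matching names case-insensitively and handling only simple `name==version` lines.

```shell
# Hypothetical helper (not part of MaxText): print the "==" version pinned
# for a package in a requirements file; prints nothing if it is absent.
pinned_version() {
  awk -v pkg="$2" '
    {
      line = $0
      sub(/[#;].*/, "", line)           # drop comments and environment markers
      gsub(/[[:space:]\\]/, "", line)   # drop whitespace and line continuations
      n = split(line, parts, "==")
      if (n == 2 && tolower(parts[1]) == tolower(pkg)) { print parts[2]; exit }
    }
  ' "$1"
}

# Example (paths from this guide): compare the jax pin in both TPU files.
# for f in tpu-requirements.txt tpu-post-train-requirements.txt; do
#   echo "$f: jax==$(pinned_version src/dependencies/requirements/generated_requirements/$f jax)"
# done
```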
GPU Pre-Training
If you changed the GPU pre-training dependencies in src/dependencies/requirements/base_requirements/cuda12-requirements.txt, regenerate the pinned pre-training requirements in the generated_requirements/ directory. Run the following commands, replacing <jax-build-commit-hash> with the hash you copied in Step 2:
bash src/dependencies/scripts/generate_requirements.sh \
--base-requirements src/dependencies/requirements/base_requirements/cuda12-requirements.txt \
--generated-requirements cuda12-requirements.txt \
--seed-commit <jax-build-commit-hash> \
--hardware cuda12
# Move the generated file into src/dependencies/requirements/generated_requirements
mv generated_artifacts/python3_12/cuda12-requirements.txt \
src/dependencies/requirements/generated_requirements/cuda12-requirements.txt
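The three subsections above each run the same pair of commands: generate_requirements.sh, then mv. The bash wrapper below is a hypothetical sketch, not part of MaxText. It uses only the flags shown in this guide and passes any extra flags (such as `--override-requirements` or `--hardware`) through after `--seed-commit`. This assumes the script accepts its flags in any order, as the differing orders in the commands above suggest. Setting `DRY_RUN=1` prints the commands instead of running them.

```shell
# Hypothetical wrapper (not part of MaxText): run the two commands from each
# subsection above, generate_requirements.sh and then mv, for one file.
# Set DRY_RUN=1 to print the commands instead of executing them.

_run() {
  if [ "${DRY_RUN:-0}" = 1 ]; then echo "$*"; else "$@"; fi
}

# Usage: regen_requirements <file-name> <jax-build-commit-hash> [extra flags...]
regen_requirements() {
  local name="$1" commit="$2"
  shift 2
  _run bash src/dependencies/scripts/generate_requirements.sh \
    --base-requirements "src/dependencies/requirements/base_requirements/$name" \
    --generated-requirements "$name" \
    --seed-commit "$commit" "$@" &&
    _run mv "generated_artifacts/python3_12/$name" \
      "src/dependencies/requirements/generated_requirements/$name"
}

# Examples mirroring this guide (JAX_COMMIT holds the hash from Step 2):
# regen_requirements tpu-requirements.txt "$JAX_COMMIT" \
#   --override-requirements src/dependencies/extra_deps/tpu_overrides.txt
# regen_requirements cuda12-requirements.txt "$JAX_COMMIT" --hardware cuda12
```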
Step 4: Verify the new dependencies
Finally, test that the new dependencies install correctly and that MaxText runs as expected.
Install MaxText and dependencies: For instructions on installing MaxText on your VM, refer to the official documentation.
Run tests: Run the MaxText tests to check for regressions.
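For a quick check before the full test run, you could install a generated file into a throwaway environment with uv and confirm that JAX imports. This is only a hypothetical bash sketch, not part of MaxText. It is not the official installation method referenced above and does not install MaxText itself. The environment directory name is arbitrary, so pass a fresh one for each run.

```shell
# Hypothetical smoke test (not part of MaxText): install one generated
# requirements file into a fresh uv virtual environment, then import jax.
# This does not install MaxText itself and is not the official method.
# On the target VM, jax.devices() shows which accelerators JAX can see.
# Usage: smoke_test_requirements <generated-requirements-file> [venv-dir]
smoke_test_requirements() {
  local reqs="$1" venv="${2:-smoke_venv}"
  uv venv --python 3.12 "$venv" &&
    uv pip install --python "$venv/bin/python" -r "$reqs" &&
    "$venv/bin/python" -c 'import jax; print(jax.__version__, jax.devices())'
}
```

For example, `smoke_test_requirements src/dependencies/requirements/generated_requirements/tpu-requirements.txt` on a TPU VM.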