Source code for maxtext.input_pipeline.tokenizer

# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Provides op for tokenizing a dataset."""

from typing import Literal, Sequence, Collection
from pathlib import Path
from maxtext.utils import max_logging
import transformers
import tiktoken
from tiktoken.load import load_tiktoken_bpe
from sentencepiece import SentencePieceProcessor


[docs] class TikTokenTokenizer: """ Tokenizing and encoding/decoding text using the Tiktoken tokenizer. """ special_tokens: dict[str, int] num_reserved_special_tokens = 256 pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # pylint: disable=line-too-long def __init__(self, model_path: str, add_bos: bool, add_eos: bool): """ Initializes the Tokenizer with a Tiktoken model. Args: model_path (str): The path to the Tiktoken model file. """ mergeable_ranks = load_tiktoken_bpe(model_path) num_base_tokens = len(mergeable_ranks) special_tokens = [ "<|begin_of_text|>", "<|end_of_text|>", "<|reserved_special_token_0|>", "<|reserved_special_token_1|>", "<|reserved_special_token_2|>", "<|reserved_special_token_3|>", "<|start_header_id|>", "<|end_header_id|>", "<|reserved_special_token_4|>", "<|eot_id|>", # end of turn ] + [f"<|reserved_special_token_{i}|>" for i in range(5, self.num_reserved_special_tokens - 5)] self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)} self.model = tiktoken.Encoding( name=Path(model_path).name, pat_str=self.pat_str, mergeable_ranks=mergeable_ranks, special_tokens=self.special_tokens, ) self.eos = add_eos self.bos = add_bos max_logging.log(f"Reloaded tiktoken model from {model_path}") self.n_words: int = self.model.n_vocab # BOS / EOS token IDs self.bos_id: int = self.special_tokens["<|begin_of_text|>"] self.eos_id: int = self.special_tokens["<|end_of_text|>"] self.pad_id: int = -1 self.stop_tokens = { self.special_tokens["<|end_of_text|>"], self.special_tokens["<|eot_id|>"], } max_logging.log(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}")
[docs] def encode( self, s: str, *, allowed_special: Literal["all"] | Collection[str] = (), disallowed_special: Literal["all"] | Collection[str] = (), ) -> list[int]: """ Encodes a string into a list of token IDs. Args: s (str): The input string to be encoded. bos (bool): Whether to prepend the beginning-of-sequence token. eos (bool): Whether to append the end-of-sequence token. allowed_tokens ("all"|set[str]): allowed special tokens in string disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string Returns: list[int]: A list of token IDs. By default, setting disallowed_special=() encodes a string by ignoring special tokens. Specifically: - Setting `disallowed_special` to () will cause all text corresponding to special tokens to be encoded as natural text (insteading of raising an error). - Setting `allowed_special` to "all" will treat all text corresponding to special tokens to be encoded as special tokens. """ assert isinstance(s, str) # The tiktoken tokenizer can handle <=400k chars without # pyo3_runtime.PanicException. TIKTOKEN_MAX_ENCODE_CHARS = 400_000 # https://github.com/openai/tiktoken/issues/195 # Here we iterate over subsequences and split if we exceed the limit # of max consecutive non-whitespace or whitespace characters. MAX_NO_WHITESPACES_CHARS = 25_000 substrs = ( substr for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) for substr in self._split_whitespaces_or_nonwhitespaces( s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS ) ) t: list[int] = [] for substr in substrs: t.extend( self.model.encode( substr, allowed_special=set(allowed_special), disallowed_special=disallowed_special, ) ) if self.bos: t.insert(0, self.bos_id) if self.eos: t.append(self.eos_id) return t
[docs] def decode(self, t) -> str: """ Decodes a list of token IDs into a string. Args: t (list[int]): The list of token IDs to be decoded. Returns: str: The decoded string. """ # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. return self.model.decode(t)
@staticmethod def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int): """ Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len` consecutive whitespaces or consecutive non-whitespaces. """ current_slice_len = 0 current_slice_is_space = s[0].isspace() if len(s) > 0 else False slice_start = 0 for i, _ in enumerate(s): is_now_space = s[i].isspace() if current_slice_is_space ^ is_now_space: current_slice_len = 1 current_slice_is_space = is_now_space else: current_slice_len += 1 if current_slice_len > max_consecutive_slice_len: yield s[slice_start:i] slice_start = i current_slice_len = 1 yield s[slice_start:]
class SentencePieceTokenizer:
    """
    Tokenizing and encoding/decoding text using the native sentencepiece library.
    Supports both local and GCS (gs://) model paths.
    """

    def __init__(self, model_path: str, add_bos: bool, add_eos: bool):
        """
        Loads a sentencepiece model and records its special-token IDs.

        Args:
            model_path (str): Local path or `gs://` URI of the sentencepiece model.
            add_bos (bool): Whether `encode` prepends the BOS token ID.
            add_eos (bool): Whether `encode` appends the EOS token ID.

        Raises:
            ValueError: If loading the model fails for any reason.
        """
        max_logging.log(f"Loading sentencepiece tokenizer: {model_path}")
        self._tokenizer_model = SentencePieceProcessor()
        try:
            if not model_path.startswith("gs://"):
                # Plain filesystem path: let sentencepiece read the file itself.
                self._tokenizer_model.Load(model_path)
            else:
                # GCS path: fetch the raw bytes, then load the serialized proto.
                from maxtext.utils.gcs_utils import read_bytes_from_gcs  # pylint: disable=import-outside-toplevel

                serialized_model = read_bytes_from_gcs(model_path)
                self._tokenizer_model.LoadFromSerializedProto(serialized_model)
        except Exception as err:
            raise ValueError(f"Failed to load sentencepiece tokenizer from {model_path}: {err}") from err

        processor = self._tokenizer_model
        self.pad_id = processor.pad_id()
        self.unk_id = processor.unk_id()
        self.bos_id = processor.bos_id()
        self.eos_id = processor.eos_id()
        self.add_bos = add_bos
        self.add_eos = add_eos

    def encode(self, s: str) -> list[int]:
        """Encodes `s` into token IDs, adding BOS/EOS IDs when configured."""
        ids = self._tokenizer_model.EncodeAsIds(s)
        if self.add_bos:
            ids = [self.bos_id] + ids
        if self.add_eos:
            ids.append(self.eos_id)
        return ids

    def decode(self, t: Sequence[int]) -> str:
        """Decodes the token IDs `t` back into a string."""
        decoded = self._tokenizer_model.DecodeIds(t)
        return decoded
class HFTokenizer:
    """
    Tokenizing using huggingface tokenizer
    """

    def __init__(self, model_path: str, add_bos: bool, add_eos: bool, hf_access_token: str):
        """
        Loads a tokenizer through `transformers.AutoTokenizer.from_pretrained`.

        Args:
            model_path (str): Identifier or path passed to `from_pretrained`.
            add_bos (bool): Forwarded as `add_bos_token`.
            add_eos (bool): Forwarded as `add_eos_token`.
            hf_access_token (str): Forwarded as `token`.
        """
        max_logging.log(f"Loading HF tokenizer: {model_path}")
        load_options = {
            "add_bos_token": add_bos,
            "add_eos_token": add_eos,
            "token": hf_access_token,
        }
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_path, **load_options)

        # Mirror the loaded tokenizer's special-token IDs on this wrapper.
        loaded = self.tokenizer
        self.pad_id = loaded.pad_token_id
        self.unk_id = loaded.unk_token_id
        self.bos_id = loaded.bos_token_id
        self.eos_id = loaded.eos_token_id

    def encode(self, s: str) -> list[int]:
        """Encodes `s` by delegating to the wrapped tokenizer's `encode`."""
        token_ids = self.tokenizer.encode(s)
        return token_ids

    def decode(self, t: Sequence[int]) -> str:
        """Decodes `t` by delegating to the wrapped tokenizer's `decode`."""
        text = self.tokenizer.decode(t)
        return text
def build_tokenizer(tokenizer_path, tokenizer_type, add_bos, add_eos, hf_access_token):
    """Loads the tokenizer at `tokenizer_path`

    Args:
        tokenizer_path: Location of the tokenizer model; forwarded to the chosen class.
        tokenizer_type: One of "tiktoken", "huggingface" or "sentencepiece".
        add_bos: Forwarded to the tokenizer constructor.
        add_eos: Forwarded to the tokenizer constructor.
        hf_access_token: Only used for the "huggingface" tokenizer type.

    Returns:
        A TikTokenTokenizer, HFTokenizer or SentencePieceTokenizer instance.

    Raises:
        ValueError: If `tokenizer_type` is not one of the supported values.
    """
    max_logging.log(f"Tokenizer path: {tokenizer_path}")

    if tokenizer_type == "tiktoken":
        # Sanity check that the path looks like a tiktoken model.
        assert "tiktoken" in tokenizer_path, f"Invalid tokenizer type: {tokenizer_type} chosen for {tokenizer_path}"
        return TikTokenTokenizer(tokenizer_path, add_bos, add_eos)

    if tokenizer_type == "huggingface":
        return HFTokenizer(tokenizer_path, add_bos, add_eos, hf_access_token)

    if tokenizer_type == "sentencepiece":
        return SentencePieceTokenizer(tokenizer_path, add_bos, add_eos)

    raise ValueError(f"Invalid tokenizer_type:{tokenizer_type} chosen in config")