# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utilities for OLMo-core-style numpy FSL datasets.
Dependency-free layer. AI2's mix files describe a virtual concatenation of
flat token-ID arrays; instances are non-overlapping ``sequence_length``-token
windows of that stream. This module builds the index that maps a global
instance index to (file, byte-offset), and ports OLMo-core's repeated-n-gram
filter (``olmo_core/data/utils.py::find_periodic_sequences``).
"""
from __future__ import annotations
import ast
import bisect
import dataclasses
import hashlib
import json
import os
import struct
from dataclasses import dataclass, field
from typing import BinaryIO, Generator, List, NamedTuple, Optional, Sequence, Tuple
import numpy as np
# Bumped whenever the on-disk index format or fingerprint inputs change.
INDEX_FORMAT_VERSION = "1"
[docs]
@dataclass(frozen=True)
class OlmoNpyFileEntry:
"""One file in the mix: ``n_tokens // sequence_length`` instances starting
at global index ``instance_offset``. Trailing tokens are dropped (matches
OLMo-core)."""
path: str
label: str
n_tokens: int
n_instances: int
instance_offset: int
[docs]
@dataclass
class OlmoNpyIndex:
"""Index over the files in an OLMo data mix. Build via
:func:`build_index`, persist via :meth:`save`, restore via
:func:`load_index`. Mutating fields invalidates :attr:`fingerprint`."""
format_version: str
sequence_length: int
dtype: str # numpy dtype string, e.g. "uint32"
tokenizer: str # informational, e.g. "allenai/dolma3-tokenizer"
files: Tuple[OlmoNpyFileEntry, ...]
total_instances: int
total_tokens: int
fingerprint: str = ""
# Lazily computed bisect helper: cumulative instance_offsets+sentinel for
# binary search in ``global_to_local``. Not serialized.
_instance_offset_starts: Optional[List[int]] = field(default=None, repr=False, compare=False)
def __post_init__(self):
starts = [f.instance_offset for f in self.files]
starts.append(self.total_instances) # sentinel for bisect
object.__setattr__(self, "_instance_offset_starts", starts)
[docs]
def to_json_dict(self) -> dict:
"""Return a JSON-serializable view (drops cached lookup helpers)."""
return {
"format_version": self.format_version,
"sequence_length": self.sequence_length,
"dtype": self.dtype,
"tokenizer": self.tokenizer,
"total_instances": self.total_instances,
"total_tokens": self.total_tokens,
"fingerprint": self.fingerprint,
"files": [dataclasses.asdict(f) for f in self.files],
}
[docs]
def save(self, path: str) -> None:
"""Write the index as JSON to ``path`` (local filesystem)."""
with open(path, "w", encoding="utf-8") as fh:
json.dump(self.to_json_dict(), fh, indent=2)
[docs]
def load_index(path: str) -> OlmoNpyIndex:
"""Load an index from JSON written by :meth:`OlmoNpyIndex.save`.
Args:
path: Local filesystem path to the JSON file.
Returns:
The materialized :class:`OlmoNpyIndex`.
Raises:
ValueError: If ``format_version`` doesn't match this code's expectation.
"""
with open(path, encoding="utf-8") as fh:
data = json.load(fh)
if data.get("format_version") != INDEX_FORMAT_VERSION:
raise ValueError(
f"Index format version mismatch: file has "
f"{data.get('format_version')!r}, code expects {INDEX_FORMAT_VERSION!r}."
)
files = tuple(OlmoNpyFileEntry(**entry) for entry in data["files"])
return OlmoNpyIndex(
format_version=data["format_version"],
sequence_length=data["sequence_length"],
dtype=data["dtype"],
tokenizer=data["tokenizer"],
files=files,
total_instances=data["total_instances"],
total_tokens=data["total_tokens"],
fingerprint=data["fingerprint"],
)
[docs]
def global_to_local(index: OlmoNpyIndex, instance_id: int) -> Tuple[int, int]:
"""Global instance index → ``(file_idx, token_offset)``.
``token_offset`` is in *tokens* (not bytes); the slice
``arr[token_offset : token_offset + sequence_length]`` is the instance.
"""
if instance_id < 0 or instance_id >= index.total_instances:
raise IndexError(f"instance_id {instance_id} out of range " f"[0, {index.total_instances})")
starts = index._instance_offset_starts # type: ignore[attr-defined] # pylint: disable=protected-access
file_idx = bisect.bisect_right(starts, instance_id) - 1
local_instance = instance_id - index.files[file_idx].instance_offset
token_offset = local_instance * index.sequence_length
return file_idx, token_offset
[docs]
def compute_fingerprint(
sequence_length: int,
dtype: str,
tokenizer: str,
files: Sequence[OlmoNpyFileEntry],
) -> str:
"""Stable hash over the fields a restart must preserve.
If any of these change, the global instance ordering changes and resuming
training from a checkpoint would silently produce different batches.
"""
h = hashlib.sha256()
h.update(INDEX_FORMAT_VERSION.encode("utf-8"))
h.update(b"\x00")
h.update(str(sequence_length).encode("utf-8"))
h.update(b"\x00")
h.update(dtype.encode("utf-8"))
h.update(b"\x00")
h.update(tokenizer.encode("utf-8"))
h.update(b"\x00")
for f in files:
h.update(f.path.encode("utf-8"))
h.update(b"\x00")
h.update(str(f.n_tokens).encode("utf-8"))
h.update(b"\x00")
return f"sha256:{h.hexdigest()}"
_NPY_MAGIC = b"\x93NUMPY"
[docs]
def has_npy_magic(first_bytes: bytes) -> bool:
"""Quick check: does this look like a real .npy file?"""
return len(first_bytes) >= 6 and first_bytes[:6] == _NPY_MAGIC
def _file_entry_from_header(
path: str,
label: str,
dtype: str,
shape: Tuple[int, ...],
sequence_length: int,
instance_offset: int,
) -> OlmoNpyFileEntry:
"""Build a file entry from a parsed .npy header (validates shape is 1-D)."""
if len(shape) != 1:
raise ValueError(f"Expected 1-D .npy array for {path}, got shape {shape}.")
n_tokens = int(shape[0])
n_instances = n_tokens // sequence_length
return OlmoNpyFileEntry(
path=path,
label=label,
n_tokens=n_tokens,
n_instances=n_instances,
instance_offset=instance_offset,
)
[docs]
def build_index(
paths_and_labels: Sequence[Tuple[str, str]],
sequence_length: int,
*,
tokenizer: str,
header_reader=read_npy_header_from_path,
) -> OlmoNpyIndex:
"""Build an :class:`OlmoNpyIndex` from ``(path, label)`` entries.
Order matters — global instance ordering is the concatenation in this
order. ``header_reader`` is the seam tests use to avoid disk; production
paths pass a GCS-aware reader.
"""
if sequence_length <= 0:
raise ValueError(f"sequence_length must be positive, got {sequence_length}")
if not paths_and_labels:
raise ValueError("paths_and_labels must be non-empty")
entries: List[OlmoNpyFileEntry] = []
observed_dtype: Optional[str] = None
cum_offset = 0
for path, label in paths_and_labels:
dtype, shape = header_reader(path)
if observed_dtype is None:
observed_dtype = dtype
elif dtype != observed_dtype:
raise ValueError(f"Heterogeneous dtypes across mix files: {observed_dtype!r} " f"and {dtype!r} (at {path}).")
entry = _file_entry_from_header(
path=path,
label=label,
dtype=dtype,
shape=shape,
sequence_length=sequence_length,
instance_offset=cum_offset,
)
entries.append(entry)
cum_offset += entry.n_instances
files = tuple(entries)
total_instances = cum_offset
total_tokens = sum(f.n_tokens for f in files)
fingerprint = compute_fingerprint(
sequence_length=sequence_length,
dtype=observed_dtype or "",
tokenizer=tokenizer,
files=files,
)
return OlmoNpyIndex(
format_version=INDEX_FORMAT_VERSION,
sequence_length=sequence_length,
dtype=observed_dtype or "",
tokenizer=tokenizer,
files=files,
total_instances=total_instances,
total_tokens=total_tokens,
fingerprint=fingerprint,
)
[docs]
class RepetitionTuple(NamedTuple):
"""``arr[start:end]`` is a periodic span of length ``period``,
``times = (end - start) // period``."""
start: int
end: int
period: int
times: int
def _find_end_first_consecutive_true(arr: np.ndarray) -> int:
"""End offset (exclusive) of the leading run of True in ``arr``.
Returns 0 if ``arr[0]`` is False, ``len(arr)`` if all True.
"""
if not arr[0]:
return 0
prog = np.cumsum(arr)
if prog[-1] == len(arr):
return int(len(arr))
# First index where the cumulative sum stops increasing == start of False run.
true_locs = np.where(prog[:-1] == prog[1:])[0]
return int(true_locs[0] + 1)
def _find_start_last_consecutive_true(arr: np.ndarray) -> int:
"""Start offset of the trailing run of True in ``arr``, or -1 if none."""
reverse = _find_end_first_consecutive_true(arr[::-1])
return len(arr) - reverse if reverse > 0 else -1
def _group_consecutive_values(arr: np.ndarray, stepsize: int = 1) -> List[np.ndarray]:
"""Split a 1-D array of ints into runs of consecutive values."""
if len(arr) == 0:
return []
return np.split(arr, np.where(np.diff(arr) != stepsize)[0] + 1)
[docs]
def find_periodic_sequences(
arr: np.ndarray,
max_period: int,
min_period: int = 1,
mask_value: int = -1,
) -> Generator[RepetitionTuple, None, None]:
"""Yield :class:`RepetitionTuple` for periodic spans of length ≥ 3 in ``arr``.
``mask_value`` is reshape padding and must not appear in ``arr``. Default
-1 is the max uint32 value, above any realistic vocab; pass an
out-of-vocab sentinel if your vocab hits that id.
"""
if (arr == mask_value).sum() > 0:
raise ValueError("`mask_value` is in the array")
max_period = min(max_period, len(arr) // 3)
for period in range(min_period, max_period + 1):
pad = (period - (len(arr) % period)) % period
padded_arr = np.pad(arr, (0, pad), constant_values=mask_value) if pad else arr
shaped_arr = padded_arr.reshape(-1, period)
is_equal_to_prev_row = shaped_arr == np.roll(shaped_arr, shift=1, axis=0)
rows_with_period = np.where(is_equal_to_prev_row.all(axis=1))[0]
if len(rows_with_period) == 0:
continue
for sequence in _group_consecutive_values(rows_with_period):
start_row = int(sequence[0])
end_row = int(sequence[-1])
start_offset = _find_start_last_consecutive_true(is_equal_to_prev_row[start_row - 1])
start_offset = period - start_offset if start_offset > 0 else 0
end_offset = _find_end_first_consecutive_true(is_equal_to_prev_row[(end_row + 1) % shaped_arr.shape[0]])
start_pos = (start_row - 1) * period - start_offset
end_pos = ((end_row + 1) * period) + end_offset
out = RepetitionTuple(
start=start_pos,
end=end_pos,
period=period,
times=(end_pos - start_pos) // period,
)
if out.times > 2:
yield out
[docs]
def is_clean_instance(
input_ids: np.ndarray,
*,
repetition_max_period: int = 13,
repetition_min_period: int = 1,
repetition_max_count: int = 32,
mask_value: int = -1,
) -> bool:
"""``False`` iff ``input_ids`` has any periodic span (period ∈
[min, max]) that repeats ≥ ``repetition_max_count`` times. Defaults
match OLMo-core's ``_validate_instance``."""
for m in find_periodic_sequences(
input_ids,
max_period=repetition_max_period,
min_period=repetition_min_period,
mask_value=mask_value,
):
if m.times >= repetition_max_count:
return False
return True