Source code for maxtext.input_pipeline.instruction_data_processing
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Preprocessing for instruction dataset."""
import importlib
import json
import os
import re
import string

from maxtext.utils import max_logging
[docs]
def load_data_template_from_file(template_path):
  """Reads a JSON data template located relative to the repository root.

  `template_path` is joined onto the directory two levels above this module.
  Because `os.path.join` discards the prefix for absolute paths, an absolute
  `template_path` is used as-is.

  Args:
    template_path: Relative (to the repository root) or absolute file path.

  Returns:
    The parsed JSON value, or None when `template_path` is falsy, the file
    does not exist, its name does not end in ".json", or it is not valid JSON.
  """
  if not template_path:
    return None
  module_dir = os.path.dirname(os.path.abspath(__file__))
  root_dir = os.path.abspath(os.path.join(module_dir, "..", ".."))
  full_path = os.path.join(root_dir, template_path)
  if not (os.path.isfile(full_path) and full_path.endswith(".json")):
    return None
  with open(full_path, "r", encoding="utf-8") as handle:
    try:
      parsed = json.load(handle)
    except json.JSONDecodeError:
      # Best effort: an unparsable template behaves like a missing one.
      parsed = None
  return parsed
[docs]
def load_chat_template_from_file(template_path):
  """Reads a chat template located relative to the repository root.

  `template_path` is joined onto the directory two levels above this module
  (an absolute path is used as-is by `os.path.join`). Supported files:

    * ".jinja", ".j2", ".txt": the whole file content is the template.
    * ".json": the value stored under the top-level "chat_template" key.

  Args:
    template_path: Relative (to the repository root) or absolute file path.

  Returns:
    The chat template, or None when `template_path` is falsy, the file is
    missing, has an unsupported extension, is not valid JSON, or is JSON
    without a top-level "chat_template" key.
  """
  if not template_path:
    return None
  module_dir = os.path.dirname(os.path.abspath(__file__))
  root_dir = os.path.abspath(os.path.join(module_dir, "..", ".."))
  full_path = os.path.join(root_dir, template_path)
  if not os.path.isfile(full_path):
    return None

  if full_path.endswith((".jinja", ".j2", ".txt")):
    with open(full_path, "r", encoding="utf-8") as handle:
      return handle.read()

  if not full_path.endswith(".json"):
    return None
  with open(full_path, "r", encoding="utf-8") as handle:
    try:
      config = json.load(handle)
    except json.JSONDecodeError:
      return None
  # Only a JSON object can carry the "chat_template" entry.
  return config.get("chat_template") if isinstance(config, dict) else None
[docs]
def get_template_placeholders(template):
  """Returns the names of the `str.format` fields referenced by `template`.

  Fields are discovered with `string.Formatter().parse`, the same parser that
  `str.format` uses. Therefore:

    * escaped braces (`{{x}}`) are not fields;
    * fields with a format spec or conversion (`{answer:>8}`, `{answer!r}`)
      and fields wrapped in literal braces (`{{{answer}}}`) are found;
    * for attribute/index access (`{question.upper}`, `{question[0]}`) the
      root name (`question`) is reported, since that is the keyword
      `str.format` looks up;
    * auto-numbered fields (`{}`) are skipped.

  Args:
    template: A `str.format`-style template string.

  Returns:
    A set of field names, e.g. {"question"} for "Q: {question}".
  """
  names = set()
  try:
    for _, field_name, _, _ in string.Formatter().parse(template):
      if not field_name:
        continue  # Literal-only chunk (None) or an auto-numbered "{}" field ("").
      root = field_name.partition(".")[0].partition("[")[0]
      if root:
        names.add(root)
  except ValueError:
    # Malformed template (unmatched "{" or "}"). Fall back to a lenient scan
    # for simple {name} fields so callers can still decide whether to apply it.
    return set(re.findall(r"(?<!{){([a-zA-Z0-9_]+)}(?!})", template))
  return names
[docs]
def extract_reasoning_and_answer(text, separator):
  """Splits `text` into the reasoning before and the answer after `separator`.

  Example: with separator "####", "3 + 4 = 7\\n#### 7" -> ("3 + 4 = 7\\n", " 7").

  The split happens at the LAST occurrence of `separator`, so the returned
  answer never contains the separator. This also keeps the function from
  failing when the separator text appears more than once.

  Args:
    text: The full answer text.
    separator: Marker placed between the reasoning and the final answer.

  Returns:
    A `(reasoning, answer)` tuple of unstripped strings, or `(None, None)`
    when `separator` is empty or does not occur in `text`.
  """
  # An empty separator would make str.rpartition raise ValueError.
  if not separator or separator not in text:
    return None, None
  reasoning, _, answer = text.rpartition(separator)
  return reasoning, answer
[docs]
def _format_template(template, template_name, **values):
  """Returns `template.format(**values)`, or None if the template needs other values.

  `str.format` raises KeyError/IndexError when the template references a field
  that is not supplied. Two examples:

    * a COMPLETION_TEMPLATE containing {reasoning} while no
      REASONING_ANSWER_SEPARATOR is configured;
    * any extra placeholder in either template.

  Such a failure is logged and reported as None, so the caller falls back to
  the untemplated text instead of aborting the whole dataset map.
  """
  try:
    return template.format(**values)
  except (KeyError, IndexError) as err:
    max_logging.log(
        f"{template_name} could not be formatted ({err!r}); only {sorted(values)} placeholders are available."
        " No template will be applied."
    )
    return None


def _build_prompt(example, template_config):
  """Returns the user message for `example`, templated with PROMPT_TEMPLATE when usable."""
  prompt = {"role": "user", "content": example["question"]}
  if "PROMPT_TEMPLATE" not in template_config:
    max_logging.log("PROMPT_TEMPLATE is empty. No template will be applied to prompt.")
    return prompt
  template = template_config["PROMPT_TEMPLATE"]
  if "question" not in get_template_placeholders(template):
    max_logging.log("PROMPT_TEMPLATE has no 'question' placeholder. No template will be applied to prompt.")
    return prompt
  content = _format_template(template, "PROMPT_TEMPLATE", question=example["question"].strip())
  if content is not None:
    prompt["content"] = content
  return prompt


def _build_completion(example, template_config):
  """Returns the assistant message for `example`, templated with COMPLETION_TEMPLATE when usable."""
  completion = {"role": "assistant", "content": example["answer"]}
  if "COMPLETION_TEMPLATE" not in template_config:
    max_logging.log("COMPLETION_TEMPLATE is empty. No template will be applied to completion.")
    return completion
  template = template_config["COMPLETION_TEMPLATE"]
  placeholders = get_template_placeholders(template)

  if "REASONING_ANSWER_SEPARATOR" in template_config:
    # Separator configured: the raw answer is split into reasoning + final answer.
    if "reasoning" not in placeholders or "answer" not in placeholders:
      max_logging.log(
          "COMPLETION_TEMPLATE is missing 'reasoning' or 'answer' placeholder."
          " No template will be applied to completion."
          " Remove REASONING_ANSWER_SEPARATOR from template or update COMPLETION_TEMPLATE."
      )
      return completion
    reasoning, answer = extract_reasoning_and_answer(example["answer"], template_config["REASONING_ANSWER_SEPARATOR"])
    if reasoning is None or answer is None:
      max_logging.log(
          "REASONING_ANSWER_SEPARATOR is present in template but not found in answer."
          " No template will be applied to completion."
          " Update REASONING_ANSWER_SEPARATOR in the template."
      )
      return completion
    values = {"reasoning": reasoning.strip(), "answer": answer.strip()}
  else:
    max_logging.log(
        "REASONING_ANSWER_SEPARATOR not found in chat template."
        " Using only 'answer' placeholder for COMPLETION_TEMPLATE."
    )
    if "answer" not in placeholders:
      max_logging.log("COMPLETION_TEMPLATE is missing 'answer' placeholder. No template will be applied to completion.")
      return completion
    values = {"answer": example["answer"].strip()}

  content = _format_template(template, "COMPLETION_TEMPLATE", **values)
  if content is not None:
    completion["content"] = content
  return completion


def math_qa_formatting(example, template_config=None):
  """Maps a question/answer example to a two-turn conversational format.

  Args:
    example: A dataset row with "question" and "answer" entries. Both are
      expected to be strings, since `.strip()` is called on them when a
      template is applied.
    template_config: Optional dict (e.g. from `load_data_template_from_file`).
      Recognized keys:
        "PROMPT_TEMPLATE": `str.format` template using {question}.
        "COMPLETION_TEMPLATE": `str.format` template using {answer}, or
          {reasoning} and {answer} when a separator is configured.
        "REASONING_ANSWER_SEPARATOR": string that splits the raw answer into
          reasoning and final answer.

  Returns:
    The same `example` dict with a "messages" entry holding
    [{"role": "user", ...}, {"role": "assistant", ...}]. If a template is
    absent or cannot be applied, the untemplated question/answer text is used
    and a message is logged.
  """
  if template_config:
    messages = [_build_prompt(example, template_config), _build_completion(example, template_config)]
  else:
    messages = [
        {"role": "user", "content": example["question"]},
        {"role": "assistant", "content": example["answer"]},
    ]
  example["messages"] = messages
  return example
[docs]
def load_formatter(formatting_func_path, **kwargs):
  """Loads a per-example formatting function and wraps it for `dataset.map`.

  Args:
    formatting_func_path: Dotted path "package.module.function" of the
      function to import.
    **kwargs: Keyword arguments forwarded to the function through `fn_kwargs`.
      The special key "remove_columns" is not forwarded; it is passed to
      `.map()` as the columns to drop.

  Returns:
    A callable `formatter(dataset, dataset_features)` that returns
    `dataset.map(func, fn_kwargs=..., features=dataset_features,
    remove_columns=...)`. It may be called any number of times.

  Raises:
    ValueError: If `formatting_func_path` contains no ".".
    ImportError: If the module cannot be imported.
    AttributeError: If the module has no attribute with the function's name.
  """
  if "." not in formatting_func_path:
    raise ValueError(
        f"formatting_func_path must be a dotted path like 'package.module.function', got {formatting_func_path!r}."
    )
  module_path, method_name = formatting_func_path.rsplit(".", 1)
  func = getattr(importlib.import_module(module_path), method_name)

  # Split the options once, outside the closure. Every formatter() call then
  # sees the same remove_columns, rather than a dict drained by an earlier call.
  fn_kwargs = dict(kwargs)
  remove_cols = fn_kwargs.pop("remove_columns", None)

  def formatter(dataset, dataset_features):
    return dataset.map(
        func,
        fn_kwargs=fn_kwargs or None,
        features=dataset_features,
        remove_columns=remove_cols,
    )

  return formatter
[docs]
def convert_to_conversational_format(
    dataset,
    data_columns,
    formatting_func_path=None,
    formatting_func_kwargs=None,
):
  """Converts an instruction dataset to conversational ("messages") format.

  Args:
    dataset: A `datasets` dataset to convert via `.map`.
    data_columns: Names of the dataset's current data columns. They are
      removed once the "messages" column has been produced.
    formatting_func_path: Optional dotted path ("package.module.function") of
      a custom per-example formatting function. When set, it takes precedence
      over the built-in question/answer formatting.
    formatting_func_kwargs: Optional dict of keyword arguments for the
      formatting function. A "template_path" entry is loaded with
      `load_data_template_from_file` and passed on as "template_config".
      The caller's dict is not modified.

  Returns:
    A `(dataset, data_columns)` tuple. After a conversion `data_columns` is
    ["messages"]; when no conversion applies both inputs are returned as-is.
  """
  import datasets  # pylint: disable=import-outside-toplevel

  dataset_features = datasets.Features(
      {"messages": [{"content": datasets.Value("string"), "role": datasets.Value("string")}]}
  )

  if formatting_func_path:
    # Work on a copy: the caller's kwargs may come from a config that is reused
    # for several pipelines and must not gain/lose keys as a side effect.
    fn_kwargs = dict(formatting_func_kwargs or {})
    fn_kwargs["remove_columns"] = data_columns
    template_path = fn_kwargs.pop("template_path", None)
    if template_path:
      template_config = load_data_template_from_file(template_path)
      if template_config is None:
        max_logging.log(f"Could not load data template from '{template_path}'. No template will be applied.")
      fn_kwargs["template_config"] = template_config
    formatter = load_formatter(formatting_func_path, **fn_kwargs)
    return formatter(dataset, dataset_features), ["messages"]

  if "question" in data_columns and "answer" in data_columns:
    dataset = dataset.map(
        math_qa_formatting,
        fn_kwargs={},
        remove_columns=data_columns,
        features=dataset_features,
    )
    return dataset, ["messages"]

  return dataset, data_columns