# Source code for maxtext.experimental.agent.ckpt_conversion_agent.analysis
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A planning agent that analyzes the Hugging Face (HF) and MaxText model architectures and generates a conversion plan in JSON format.
"""
import json
import os
from maxtext.experimental.agent.ckpt_conversion_agent.utils.utils import load_prompt_template, load_json, load_text_file
from maxtext.experimental.agent.ckpt_conversion_agent.base import BaseAgent
class AnalysisAgent(BaseAgent):
    """Agent that asks an LLM to analyze how MaxText and HF parameters correspond.

    On construction it loads, from ``<dir_path>/context``, the MaxText and
    Hugging Face parameter descriptions for ``target_model`` and the DSL
    reference text, plus two prompt templates from ``<dir_path>/prompts``.
    ``analyze_model_structures`` fills the analysis prompt with that context,
    sends it to the model through ``BaseAgent.generate_text`` and saves the
    response to ``<dir_path>/outputs/analysis.txt``.
    """

    def __init__(
        self,
        api_key,
        dir_path: str,
        target_model: str = "gemma3",
        max_retries: int = 3,
    ):
        """Initializes the AnalysisAgent and loads its context files.

        Args:
            api_key: Credential forwarded unchanged to ``BaseAgent.__init__``.
            dir_path (str): Root directory that contains the ``context/`` and
                ``prompts/`` subdirectories; results are written under
                ``<dir_path>/outputs``.
            target_model (str): Name of the model subdirectory under
                ``<dir_path>/context`` whose parameter files are analyzed.
            max_retries (int): Maximum number of retries for generation. Stored
                on the instance; not read anywhere within this class.
        """
        super().__init__(api_key)
        self.target_model = target_model
        self.max_retries = max_retries
        self.dir_path = dir_path
        # os.path.join (rather than f"{dir_path}/...") keeps path building
        # consistent with the output path and avoids producing "/context/..."
        # when dir_path is an empty string.
        model_context_dir = os.path.join(dir_path, "context", target_model)
        self.maxtext_params = load_json(os.path.join(model_context_dir, "maxtext_params.json"))
        self.hf_params = load_json(os.path.join(model_context_dir, "hf_params.json"))
        self.dsl = load_text_file(os.path.join(dir_path, "context", "dsl.txt"))
        self.prompt_templates = self._load_prompt_templates()

    def _load_prompt_templates(self):
        """Loads the prompt templates used by this agent.

        Returns:
            dict: ``"analysis"`` maps to the main analysis prompt template and
            ``"pitfalls"`` to the pitfalls text that is substituted into it.
        """
        prompts_dir = os.path.join(self.dir_path, "prompts")
        return {
            "analysis": load_prompt_template(os.path.join(prompts_dir, "01_analysis.txt")),
            "pitfalls": load_prompt_template(os.path.join(prompts_dir, "04_pitfalls.txt")),
        }

    def analyze_model_structures(self):
        """Generates an analysis of the MaxText and Hugging Face parameter structures.

        Fills the ``"analysis"`` prompt template with the target model name,
        both parameter descriptions (pretty-printed as JSON), the DSL reference
        and the pitfalls text, sends the prompt to the model and writes the
        response to ``<dir_path>/outputs/analysis.txt``.

        Problems are reported with ``print`` rather than raised. The method
        returns early when either parameter description is falsy (e.g. missing
        or empty). An ``OSError`` while creating the output directory or
        writing the file is printed and swallowed.
        """
        # NOTE(review): presumably load_json returns a falsy value when a file
        # is missing or unreadable — confirm against utils.load_json.
        if not self.maxtext_params or not self.hf_params:
            print("Could not perform analysis due to missing parameter files.")
            return

        print("Analysis Agent: Analyzing model structures...")
        prompt = self.prompt_templates["analysis"].format(
            target_model=self.target_model,
            maxtext_params_json=json.dumps(self.maxtext_params, indent=2),
            hf_params_json=json.dumps(self.hf_params, indent=2),
            dsl=self.dsl,
            pitfalls=self.prompt_templates["pitfalls"],
        )
        # NOTE(review): assumes generate_text returns a str; a None result
        # would make f.write below raise TypeError — confirm in BaseAgent.
        analysis = self.generate_text(prompt)

        output_dir = os.path.join(self.dir_path, "outputs")
        file_path = os.path.join(output_dir, "analysis.txt")
        try:
            # exist_ok=True replaces the race-prone exists()/makedirs() pair and
            # keeps directory-creation failures in the same reporting path as
            # write failures.
            os.makedirs(output_dir, exist_ok=True)
            with open(file_path, "wt", encoding="utf-8") as f:
                f.write(analysis)
        except OSError as e:
            print(f"Error saving analysis file: {e}")
        else:
            print(f"Analysis successfully saved to {file_path}")