# Mirror of https://github.com/FoundationAgents/MetaGPT.git
# Synced 2026-05-19 16:22:15 +08:00
# 137 lines, 4.6 KiB, Python
# -*- coding: utf-8 -*-
|
|
# @Date : 8/23/2024 20:00
|
|
# @Author : didi
|
|
# @Desc : Entrance of AFlow.
|
|
|
|
import argparse
|
|
from typing import Dict, List
|
|
|
|
from metagpt.configs.models_config import ModelsConfig
|
|
from metagpt.ext.aflow.data.download_data import download
|
|
from metagpt.ext.aflow.scripts.optimizer import Optimizer
|
|
|
|
|
|
class ExperimentConfig:
    """Bundle of the per-dataset settings used to launch one experiment.

    Attributes:
        dataset: Name of the dataset.
        question_type: Category label for the dataset's questions.
        operators: Names of the operators associated with the dataset. The
            list is stored by reference, not copied.
    """

    def __init__(self, dataset: str, question_type: str, operators: List[str]):
        """Keep the three settings on the instance exactly as given."""
        self.dataset, self.question_type, self.operators = dataset, question_type, operators
|
|
|
|
|
|
# Registry of the supported experiments, keyed by dataset name (the keys are
# also the accepted values of the --dataset command-line option).
# Each row below is (dataset name, question type, operator names); the dataset
# name doubles as the registry key, and every entry receives its own list.
EXPERIMENT_CONFIGS: Dict[str, ExperimentConfig] = {
    name: ExperimentConfig(dataset=name, question_type=kind, operators=list(ops))
    for name, kind, ops in (
        ("DROP", "qa", ("Custom", "AnswerGenerate", "ScEnsemble")),
        ("HotpotQA", "qa", ("Custom", "AnswerGenerate", "ScEnsemble")),
        ("MATH", "math", ("Custom", "ScEnsemble", "Programmer")),
        ("GSM8K", "math", ("Custom", "ScEnsemble", "Programmer")),
        ("MBPP", "code", ("Custom", "CustomCodeGenerate", "ScEnsemble", "Test")),
        ("HumanEval", "code", ("Custom", "CustomCodeGenerate", "ScEnsemble", "Test")),
    )
}
|
|
|
|
|
|
def _str_to_bool(value):
    """Convert a command-line token into a real boolean.

    argparse passes the raw string to ``type``, and ``bool("False")`` is
    ``True`` because every non-empty string is truthy, so ``type=bool`` cannot
    express an on/off option.

    Args:
        value: Raw token from the command line (an actual ``bool`` is returned
            unchanged).

    Returns:
        ``True`` for true/t/yes/y/on/1 and ``False`` for false/f/no/n/off/0 or
        an empty string, compared case-insensitively.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"true", "t", "yes", "y", "on", "1"}:
        return True
    # The empty string stays falsy, as it was under the previous ``type=bool``.
    if token in {"false", "f", "no", "n", "off", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse the command-line options of the AFlow optimizer entry point.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``dataset``, ``sample``,
        ``optimized_path``, ``initial_round``, ``max_rounds``,
        ``check_convergence``, ``validation_rounds``, ``if_first_optimize``,
        ``opt_model_name`` and ``exec_model_name``.
    """
    parser = argparse.ArgumentParser(description="AFlow Optimizer")
    parser.add_argument(
        "--dataset",
        type=str,
        choices=list(EXPERIMENT_CONFIGS.keys()),
        required=True,
        help="Dataset type",
    )
    parser.add_argument("--sample", type=int, default=4, help="Sample count")
    parser.add_argument(
        "--optimized_path",
        type=str,
        default="metagpt/ext/aflow/scripts/optimized",
        help="Optimized result save path",
    )
    parser.add_argument("--initial_round", type=int, default=1, help="Initial round")
    parser.add_argument("--max_rounds", type=int, default=20, help="Max iteration rounds")
    # Was ``type=bool``, which turned "--check_convergence False" into True.
    parser.add_argument(
        "--check_convergence",
        type=_str_to_bool,
        default=True,
        help="Whether to enable early stop",
    )
    parser.add_argument("--validation_rounds", type=int, default=5, help="Validation rounds")
    # Kept as in the original: only the literal "true" (any case) enables it.
    parser.add_argument(
        "--if_first_optimize",
        type=lambda x: x.lower() == "true",
        default=True,
        help="Whether to download dataset for the first time",
    )
    parser.add_argument(
        "--opt_model_name",
        type=str,
        default="claude-3-5-sonnet-20240620",
        help="Specifies the name of the model used for optimization tasks.",
    )
    parser.add_argument(
        "--exec_model_name",
        type=str,
        default="gpt-4o-mini",
        help="Specifies the name of the model used for execution tasks.",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
def main():
    """Run one optimization from the command line.

    Parses the options, looks up the experiment for the chosen dataset,
    resolves the optimization and execution model configurations, calls
    ``download``, then builds an ``Optimizer`` and runs it in 'Graph' mode.

    Raises:
        ValueError: If either model name is missing from the models
            configuration (``ModelsConfig.get`` returned ``None``).
    """
    args = parse_args()
    experiment = EXPERIMENT_CONFIGS[args.dataset]

    llm_registry = ModelsConfig.default()

    # Fail early, before anything is downloaded, when a model is not configured.
    opt_llm_config = llm_registry.get(args.opt_model_name)
    if opt_llm_config is None:
        raise ValueError(
            f"The optimization model '{args.opt_model_name}' was not found in the 'models' section of the configuration file. "
            "Please add it to the configuration file or specify a valid model using the --opt_model_name flag. "
        )

    exec_llm_config = llm_registry.get(args.exec_model_name)
    if exec_llm_config is None:
        raise ValueError(
            f"The execution model '{args.exec_model_name}' was not found in the 'models' section of the configuration file. "
            "Please add it to the configuration file or specify a valid model using the --exec_model_name flag. "
        )

    # NOTE(review): presumably fetches the datasets and initial rounds on a first run -- confirm in download().
    download(["datasets", "initial_rounds"], if_first_download=args.if_first_optimize)

    optimizer_kwargs = {
        "dataset": experiment.dataset,
        "question_type": experiment.question_type,
        "opt_llm_config": opt_llm_config,
        "exec_llm_config": exec_llm_config,
        "check_convergence": args.check_convergence,
        "operators": experiment.operators,
        "optimized_path": args.optimized_path,
        "sample": args.sample,
        "initial_round": args.initial_round,
        "max_rounds": args.max_rounds,
        "validation_rounds": args.validation_rounds,
    }
    optimizer = Optimizer(**optimizer_kwargs)

    # 'Graph' mode optimizes the workflow.
    optimizer.optimize("Graph")

    # To test the workflow instead, switch the optimizer's mode to 'Test':
    # optimizer.optimize("Test")


if __name__ == "__main__":
    main()
|