diff --git a/llmebench/benchmark.py b/llmebench/benchmark.py
index b1b996bd..79180c98 100644
--- a/llmebench/benchmark.py
+++ b/llmebench/benchmark.py
@@ -1,5 +1,3 @@
-import argparse
-
 import importlib
 import json
 import logging
@@ -324,10 +322,18 @@ def find_assets(self, filter_str="*.py"):
 
 
 def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("benchmark_dir", nargs="?", type=Path)
-    parser.add_argument("results_dir", nargs="?", type=Path)
-    parser.add_argument(
+    parser = utils.ArgumentParserWithDefaultSubcommand()
+    parser.set_default_subparser("benchmark")
+    subparsers = parser.add_subparsers(
+        help="Defaults to 'benchmark'. Specify a command before the help flag to see detailed usage for each command.",
+        dest="subparser_name",
+    )
+
+    parser_main = subparsers.add_parser("benchmark", help="Run the benchmark")
+
+    parser_main.add_argument("benchmark_dir", type=Path)
+    parser_main.add_argument("results_dir", type=Path)
+    parser_main.add_argument(
         "-f",
         "--filter",
         default="*.py",
@@ -335,8 +341,8 @@ def main():
         " Examples are '*ZeroShot*', 'Demography*', '*.py' (default)."
         " The .py extension is added automatically if missing.",
     )
-    parser.add_argument("--ignore_cache", action="store_true")
-    parser.add_argument(
+    parser_main.add_argument("--ignore_cache", action="store_true")
+    parser_main.add_argument(
         "-l",
         "--limit",
         default=-1,
@@ -344,40 +350,34 @@ def main():
         help="Limit the number of input instances that will be processed",
     )
 
-    parser.add_argument(
+    parser_main.add_argument(
         "-e", "--env", type=Path, help="Path to an .env file to load model parameters"
     )
 
-    parser.add_argument(
+    parser_main.add_argument(
         "--dry-run",
         action="store_true",
         help="Do not run any actual models, but load all the data and process"
         " few shots. Existing cache will be ignored and overwritten.",
     )
 
-    data_args = parser.add_argument_group("Data")
-    data_args.add_argument(
-        "--data_dir",
-        default="data/",
-        type=Path,
-        help="Default path for data. All relative paths will be resolved by"
-        " using this as the base path",
+    parser_download = subparsers.add_parser(
+        "download", help="Download specific dataset"
     )
 
-    data_args.add_argument(
+    parser_download.add_argument(
         "--download_server",
         type=str,
         default="https://llmebench.qcri.org/data/",
         help="URL to server containing dataset archives",
     )
-
-    data_args.add_argument(
-        "--download",
+    parser_download.add_argument(
+        "dataset_name",
         type=str,
         help="Download the dataset with the given name (e.g Aqmar)",
     )
 
-    few_shot_args = parser.add_argument_group("Few Shot Experiments")
+    few_shot_args = parser_main.add_argument_group("Few Shot Experiments")
     few_shot_args.add_argument(
         "-n",
         "--n_shots",
@@ -389,6 +389,16 @@ def main():
         " and when it is non-zero, only few shot experiments will be run.",
     )
 
+    # Common options
+    for subparser in [parser_main, parser_download]:
+        subparser.add_argument(
+            "--data_dir",
+            default="data/",
+            type=Path,
+            help="Default path for data. All relative paths will be resolved by"
+            " using this as the base path",
+        )
+
     args = parser.parse_args()
 
     logging.basicConfig(
@@ -397,11 +407,9 @@ def main():
         format="%(asctime)s %(levelname)s %(message)s",
     )
 
-    if args.env:
-        load_dotenv(args.env)
-
-    if args.download:
-        dataset_name = args.download
+    # Handle downloading of datasets
+    if args.subparser_name == "download":
+        dataset_name = args.dataset_name
         if not dataset_name.endswith("Dataset"):
             dataset_name = f"{dataset_name}Dataset"
         try:
@@ -412,13 +420,18 @@ def main():
             return
         dataset.download_dataset(args.data_dir, default_url=args.download_server)
         return
-    else:
-        if args.benchmark_dir is None or args.results_dir is None:
-            logging.error(parser.print_usage())
-            logging.error(
-                "The following arguments are required: benchmark_dir, results_dir"
-            )
-            return
+
+    if args.env:
+        load_dotenv(args.env)
+
+    # Defensive check: both positionals are declared without nargs="?", so
+    # argparse itself rejects a missing value before this point is reached.
+    if args.benchmark_dir is None or args.results_dir is None:
+        logging.error(parser.format_usage())
+        logging.error(
+            "The following arguments are required: benchmark_dir, results_dir"
+        )
+        return
 
     benchmark = Benchmark(args.benchmark_dir)
 
diff --git a/llmebench/utils.py b/llmebench/utils.py
index 65fbe7b9..70334b4c 100644
--- a/llmebench/utils.py
+++ b/llmebench/utils.py
@@ -1,3 +1,4 @@
+import argparse
 import importlib.util
 import sys
 
@@ -7,6 +8,46 @@
 from typing import TYPE_CHECKING
 
 
+# https://stackoverflow.com/a/4575792
+class ArgumentParserWithDefaultSubcommand(argparse.ArgumentParser):
+    """ArgumentParser that falls back to a default subcommand.
+
+    When a default has been registered with `set_default_subparser` and the
+    command line names none of the registered subcommands, the default name
+    is prepended to the arguments before regular parsing. The fallback is
+    skipped when -h/--help is present, so top-level help still works.
+    """
+
+    __default_subparser = None
+
+    def set_default_subparser(self, name):
+        """Register `name` as the subcommand used when none is given."""
+        self.__default_subparser = name
+
+    def _parse_known_args(self, arg_strings, *args, **kwargs):
+        """Prepend the default subcommand if needed, then parse as usual.
+
+        Overrides a private argparse hook; extra positional/keyword
+        arguments are forwarded untouched.
+        """
+        in_args = set(arg_strings)
+        default_name = self.__default_subparser
+        help_requested = bool({"-h", "--help"} & in_args)
+        if default_name is not None and not help_requested:
+            # NOTE(review): any token equal to a subcommand name counts as a
+            # match, including option values -- confirm this is acceptable.
+            subcommand_given = any(
+                isinstance(action, argparse._SubParsersAction)
+                and in_args.intersection(action._name_parser_map)
+                for action in self._subparsers._actions
+            )
+            if not subcommand_given:
+                # insert default in first position, this implies no
+                # global options without a sub_parsers specified
+                arg_strings = [default_name] + arg_strings
+        return super()._parse_known_args(arg_strings, *args, **kwargs)
+
+
 # https://stackoverflow.com/a/41595552
 def import_source_file(fname: Path, modname: str) -> "types.ModuleType":
     """