Skip to content

Commit

Permalink
renamed run_pipline args to evaluate pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
Meganton committed Nov 8, 2024
1 parent 1ce8f4a commit 42af969
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 14 deletions.
6 changes: 3 additions & 3 deletions neps/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,14 +286,14 @@ def run_optimization(args: argparse.Namespace) -> None:
run_args = Path("run_config.yaml")
else:
run_args = args.run_args
if not isinstance(args.run_pipeline, Default):
module_path, function_name = args.run_pipeline.split(":")
if not isinstance(args.evaluate_pipeline, Default):
module_path, function_name = args.evaluate_pipeline.split(":")
evaluate_pipeline = load_and_return_object(
module_path, function_name, EVALUATE_PIPELINE
)

else:
evaluate_pipeline = args.run_pipeline
evaluate_pipeline = args.evaluate_pipeline

kwargs = {}
if args.searcher_kwargs:
Expand Down
6 changes: 3 additions & 3 deletions neps/utils/run_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def get_run_args_from_yaml(path: str | Path) -> dict:
validates these arguments, and then returns them in a dictionary. It checks for the
presence and validity of expected parameters, and distinctively handles more complex
configurations, specifically those that are dictionaries(e.g. pipeline_space) or
objects(e.g. run_pipeline) requiring loading.
objects(e.g. evaluate_pipeline) requiring loading.
Args:
path (str): The file path to the YAML configuration file.
Expand All @@ -67,7 +67,7 @@ def get_run_args_from_yaml(path: str | Path) -> dict:
settings = {}

# List allowed NePS run arguments with simple types (e.g., string, int). Parameters
# like 'run_pipeline', 'preload_hooks', 'pipeline_space',
# like 'evaluate_pipeline', 'preload_hooks', 'pipeline_space',
# and 'searcher' are excluded due to needing specialized processing.
expected_parameters = [
ROOT_DIRECTORY,
Expand Down Expand Up @@ -146,7 +146,7 @@ def config_loader(path: str | Path) -> dict:

def extract_leaf_keys(d: dict, special_keys: dict | None = None) -> tuple[dict, dict]:
"""Recursive function to extract leaf keys and their values from a nested dictionary.
Special keys (e.g.'run_pipeline') are also extracted if present
Special keys (e.g.'evaluate_pipeline') are also extracted if present
and their corresponding values (dict) at any level in the nested structure.
Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ def evaluate_pipeline_constant(learning_rate, optimizer, epochs, batch_size):
description="Run NEPS optimization with run_args.yml."
)
parser.add_argument("run_args", type=str, help="Path to the YAML configuration file.")
parser.add_argument("--run_pipeline", action="store_true")
parser.add_argument("--evaluate_pipeline", action="store_true")
args = parser.parse_args()

if args.run_pipeline:
neps.run(run_args=args.run_args, evaluate_pipeline=run_pipeline_constant)
if args.evaluate_pipeline:
neps.run(run_args=args.run_args, evaluate_pipeline=evaluate_pipeline_constant)
else:
neps.run(run_args=args.run_args)
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_run_with_yaml_and_run_pipeline() -> None:

try:
subprocess.check_call(
[sys.executable, BASE_PATH / "neps_run.py", yaml_path, "--run_pipeline"]
[sys.executable, BASE_PATH / "neps_run.py", yaml_path, "--evaluate_pipeline"]
)
except subprocess.CalledProcessError as e:
pytest.fail(
Expand Down
8 changes: 4 additions & 4 deletions tests/test_yaml_run_args/test_yaml_run_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def are_functions_equivalent(
# Compare keys with a function/list of functions as their values
# Special because they include a module loading procedure by a path and the name of
# the function
for special_key in ["run_pipeline", "pre_load_hooks"]:
for special_key in ["evaluate_pipeline", "pre_load_hooks"]:
if special_key in expected_output:
func_expected = expected_output.pop(special_key)
func_output = output.pop(special_key)
Expand All @@ -105,7 +105,7 @@ def are_functions_equivalent(
(
"run_args_full.yaml",
{
"run_pipeline": evaluate_pipeline,
"evaluate_pipeline": evaluate_pipeline,
"pipeline_space": pipeline_space,
"root_directory": "test_yaml",
"max_evaluations_total": 20,
Expand All @@ -129,7 +129,7 @@ def are_functions_equivalent(
(
"run_args_full_same_level.yaml",
{
"run_pipeline": evaluate_pipeline,
"evaluate_pipeline": evaluate_pipeline,
"pipeline_space": pipeline_space,
"root_directory": "test_yaml",
"max_evaluations_total": 20,
Expand Down Expand Up @@ -181,7 +181,7 @@ def are_functions_equivalent(
(
"run_args_optional_loading_format.yaml",
{
"run_pipeline": evaluate_pipeline,
"evaluate_pipeline": evaluate_pipeline,
"pipeline_space": pipeline_space,
"root_directory": "test_yaml",
"max_evaluations_total": 20,
Expand Down

0 comments on commit 42af969

Please sign in to comment.