Skip to content

Commit

Permalink
trainScript: --run_id -> --run-id; remove --exp_id
Browse files Browse the repository at this point in the history
--exp_id was never used, due to how MLflow searches through its runs.
  • Loading branch information
raehik committed Aug 31, 2023
1 parent bcca14f commit adf1b45
Showing 1 changed file with 7 additions and 14 deletions.
21 changes: 7 additions & 14 deletions src/gz21_ocean_momentum/trainScript.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,11 @@ def check_str_is_None(string_in: str):
description = (
"Trains a model on a chosen dataset from the store."
"Allows to set training parameters via the CLI."
"Use one of either --run-id or --forcing-data-path."
)
parser = argparse.ArgumentParser(description=description)

# access input forcing data via MLflow
parser.add_argument(
"--exp_id",
type=int,
help="Experiment id of the source dataset containing the " "training data.",
)
parser.add_argument("--run_id", type=str, help="Run id of the source dataset")
parser.add_argument("--run-id", type=str, help="MLflow run ID of data step containing forcing data to use")

# access input forcing data via absolute filepath
parser.add_argument("--forcing-data-path", type=str, help="Filepath of the forcing data")
Expand Down Expand Up @@ -175,22 +170,20 @@ def try_get_forcing_data_filepath(params):
Returns a filepath which should be a zarr dataset. Correctness is not
asserted, so the filepath may not exist or may not be a valid zarr dataset.
TODO: the way we enforce mutual exclusion between the options is a bit wonky.
"""
if params.run_id is not None and params.exp_id is not None:
if params.run_id is not None:
if params.forcing_data_path is not None:
# got both --run-id and --forcing-data-path: bad
raise argparse.ArgumentError("overlapping options provided (--forcing-data-path and --exp-id)")

# got run_id and exp_id: try to use
# TODO we don't actually use exp_id
mlflow.log_param("source.exp_id", params.exp_id)
mlflow.log_param("source.run_id", params.run_id)
# got only --run-id: obtain path via MLflow
mlflow.log_param("source.run-id", params.run_id)

mlflow_client = mlflow.tracking.MlflowClient()
return mlflow_client.download_artifacts(params.run_id, "forcing")

if params.forcing_data_path is not None:
# got only --forcing-data-path: use
return params.forcing_data_path

# if we get here, neither option was provided
Expand Down

0 comments on commit adf1b45

Please sign in to comment.