From fff986c83e0b8288d84db5a302fe0ee8b30ee562 Mon Sep 17 00:00:00 2001 From: Ben Orchard Date: Fri, 25 Aug 2023 12:36:53 +0100 Subject: [PATCH] trainScript: workaround MLproject opts type limitation --- MLproject | 9 +++--- src/gz21_ocean_momentum/trainScript.py | 42 +++++++++++++++----------- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/MLproject b/MLproject index dacabac0..d3eee54e 100755 --- a/MLproject +++ b/MLproject @@ -14,13 +14,12 @@ entry_points: factor: {type: float, default: 0} chunk_size: {type: string, default: 50} global: {type: str, default: 0} - command: "python src/gz21_ocean_momentum/cmip26.py {lat_min} {lat_max} {long_min} {long_max} --CO2 {CO2} --ntimes {ntimes} --factor {factor} --chunk_size {chunk_size} --global_ {global}" + command: "python src/gz21_ocean_momentum/cmip26.py {lat_min} {lat_max} {long_min} {long_max} --CO2 {CO2} --ntimes {ntimes} --factor {factor} --chunk_size {chunk_size} --global_ {global}" train: parameters: - forcing_data_path: {type: path, default: None} - #exp_id : {type: float, default: 0} - #run_id : {type: string} + forcing_data_path: {type: string, default: None} + run_id : {type: string, default: None} batchsize : {type : float, default : 8} learning_rate : {type : string, default : 0\1e-3} n_epochs : {type : float, default : 100} @@ -36,8 +35,8 @@ entry_points: submodel : {type: string, default : transform3} features_transform_cls_name : {type : string, default : None} targets_transform_cls_name : {type : string, default : None} - #{exp_id} {run_id} command: "python src/gz21_ocean_momentum/trainScript.py + --run-id {run_id} --forcing-data-path {forcing_data_path} --batchsize {batchsize} --learning_rate {learning_rate} --n_epochs {n_epochs} --train_split {train_split} diff --git a/src/gz21_ocean_momentum/trainScript.py b/src/gz21_ocean_momentum/trainScript.py index 116348e1..dd374cba 100755 --- a/src/gz21_ocean_momentum/trainScript.py +++ b/src/gz21_ocean_momentum/trainScript.py @@ -43,6 +43,8 @@ from utils import TaskInfo +from typing import Any + torch.autograd.set_detect_anomaly(True) @@ -163,33 +165,39 @@ def check_str_is_None(string_in: str): ) params = parser.parse_args() -def try_get_forcing_data_filepath(params): - """ - Try to get the forcing data filepath to use from the provided command line - options. +def argparse_get_mlflow_artifact_path_or_direct_or_fail( + mlflow_artifact_name: str, params: dict[str, Any] + ) -> str: + """Obtain a filepath either from an MLflow run ID and artifact name, or a + direct path if provided. + + params must have keys run_id and forcing_data_path. - Returns a filepath which should be a zarr dataset. Correctness is not - asserted, so the filepath may not exist or may not be a valid zarr dataset. + Only one of run_id and path should be non-None. + + Note that the filepath is not checked for validity (but for run_id, MLflow + probably will assert that it exists). + + Effectful: errors result in immediate program exit. """ - if params.run_id is not None: - if params.forcing_data_path is not None: - # got both --run-id and --forcing-data-path: bad - raise argparse.ArgumentError("overlapping options provided (--forcing-data-path and --exp-id)") + if params.run_id is not None and params.run_id != "None": + if params.forcing_data_path is not None and params.forcing_data_path != "None": + # got run ID and direct path: bad + raise TypeError("overlapping options provided (--forcing-data-path and --exp-id)") - # got only --run-id: obtain path via MLflow + # got only run ID: obtain path via MLflow mlflow.log_param("source.run-id", params.run_id) - mlflow_client = mlflow.tracking.MlflowClient() - return mlflow_client.download_artifacts(params.run_id, "forcing") + return mlflow_client.download_artifacts(params.run_id, mlflow_artifact_name) - if params.forcing_data_path is not None: - # got only --forcing-data-path: use + if params.forcing_data_path is not None and params.forcing_data_path != "None": + # got only direct path: use return params.forcing_data_path # if we get here, neither options were provided - raise argparse.ArgumentError("require one of --forcing-data-path or --run-id") + raise TypeError("require one of --run-id or --forcing-data-path") -forcings_path = try_get_forcing_data_filepath(params) +forcings_path = argparse_get_mlflow_artifact_path_or_direct_or_fail("forcing", params) # -------------------------- # SET UP TRAINING PARAMETERS