Skip to content

Commit

Permalink
trainScript: workaround MLproject opts type limitation
Browse files Browse the repository at this point in the history
  • Loading branch information
raehik committed Aug 31, 2023
1 parent adf1b45 commit fff986c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 22 deletions.
9 changes: 4 additions & 5 deletions MLproject
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@ entry_points:
factor: {type: float, default: 0}
chunk_size: {type: string, default: 50}
global: {type: str, default: 0}
command: "python src/gz21_ocean_momentum/cmip26.py {lat_min} {lat_max} {long_min} {long_max} --CO2 {CO2} --ntimes {ntimes} --factor {factor} --chunk_size {chunk_size} --global_ {global}"
command: "python src/gz21_ocean_momentum/cmip26.py {lat_min} {lat_max} {long_min} {long_max} --CO2 {CO2} --ntimes {ntimes} --factor {factor} --chunk_size {chunk_size} --global_ {global}"

train:
parameters:
forcing_data_path: {type: path, default: None}
#exp_id : {type: float, default: 0}
#run_id : {type: string}
forcing_data_path: {type: string, default: None}
run_id : {type: string, default: None}
batchsize : {type : float, default : 8}
learning_rate : {type : string, default : 0\1e-3}
n_epochs : {type : float, default : 100}
Expand All @@ -36,8 +35,8 @@ entry_points:
submodel : {type: string, default : transform3}
features_transform_cls_name : {type : string, default : None}
targets_transform_cls_name : {type : string, default : None}
#{exp_id} {run_id}
command: "python src/gz21_ocean_momentum/trainScript.py
--run-id {run_id}
--forcing-data-path {forcing_data_path}
--batchsize {batchsize} --learning_rate {learning_rate}
--n_epochs {n_epochs} --train_split {train_split}
Expand Down
42 changes: 25 additions & 17 deletions src/gz21_ocean_momentum/trainScript.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@

from utils import TaskInfo

from typing import Any

torch.autograd.set_detect_anomaly(True)


Expand Down Expand Up @@ -163,33 +165,39 @@ def check_str_is_None(string_in: str):
)
params = parser.parse_args()

def try_get_forcing_data_filepath(params):
"""
Try to get the forcing data filepath to use from the provided command line
options.
def argparse_get_mlflow_artifact_path_or_direct_or_fail(
mlflow_artifact_name: str, params: dict[str, Any]
) -> str:
"""Obtain a filepath either from an MLflow run ID and artifact name, or a
direct path if provided.
params must have keys run_id and forcing_data_path.
Returns a filepath which should be a zarr dataset. Correctness is not
asserted, so the filepath may not exist or may not be a valid zarr dataset.
Only one of run_id and path should be non-None.
Note that the filepath is not checked for validity (but for run_id, MLflow
probably will assert that it exists).
Effectful: errors result in immediate program exit.
"""
if params.run_id is not None:
if params.forcing_data_path is not None:
# got both --run-id and --forcing-data-path: bad
raise argparse.ArgumentError("overlapping options provided (--forcing-data-path and --exp-id)")
if params.run_id is not None and params.run_id != "None":
if params.forcing_data_path is not None and params.forcing_data_path != "None":
# got run ID and direct path: bad
raise TypeError("overlapping options provided (--forcing-data-path and --exp-id)")

# got only --run-id: obtain path via MLflow
# got only run ID: obtain path via MLflow
mlflow.log_param("source.run-id", params.run_id)

mlflow_client = mlflow.tracking.MlflowClient()
return mlflow_client.download_artifacts(params.run_id, "forcing")
return mlflow_client.download_artifacts(params.run_id, mlflow_artifact_name)

if params.forcing_data_path is not None:
# got only --forcing-data-path: use
if params.forcing_data_path is not None and params.forcing_data_path != "None":
# got only direct path: use
return params.forcing_data_path

# if we get here, neither options were provided
raise argparse.ArgumentError("require one of --forcing-data-path or --run-id")
raise TypeError("require one of --run-id or --forcing-data-path")

forcings_path = try_get_forcing_data_filepath(params)
forcings_path = argparse_get_mlflow_artifact_path_or_direct_or_fail("forcing", params)

# --------------------------
# SET UP TRAINING PARAMETERS
Expand Down

0 comments on commit fff986c

Please sign in to comment.