From 4e29e270e6413fe759e64e20c294561690fa1620 Mon Sep 17 00:00:00 2001 From: Gonzalo Gasca Meza Date: Mon, 19 Aug 2019 15:11:27 -0700 Subject: [PATCH] Add tracking URI support --- tutorials/tensorflow/mlflow_gcp/README.md | 24 +++++++++++++++++++ .../tensorflow/mlflow_gcp/trainer/task.py | 10 +++++++- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/tutorials/tensorflow/mlflow_gcp/README.md b/tutorials/tensorflow/mlflow_gcp/README.md index 273d33c1..283e5a4b 100644 --- a/tutorials/tensorflow/mlflow_gcp/README.md +++ b/tutorials/tensorflow/mlflow_gcp/README.md @@ -181,6 +181,30 @@ gcloud ai-platform local train --package-path trainer \ --eval-steps $EVAL_STEPS ``` +#### Run via the `gcloud` command in AI Platform: + +``` +DATE=`date '+%Y%m%d_%H%M%S'` +export JOB_NAME=mlflow_$DATE +export REGION=us-central1 +export GCS_JOB_DIR=gs://mlflow_gcp/jobs/$JOB_NAME + +gcloud ai-platform jobs submit training $JOB_NAME \ + --stream-logs \ + --runtime-version 1.14 \ + --package-path trainer \ + --module-name trainer.task \ + --region $REGION \ + -- \ + --train-files $TRAIN_FILE \ + --eval-files $EVAL_FILE \ + --job-dir $GCS_JOB_DIR \ + --train-steps $TRAIN_STEPS \ + --eval-steps $EVAL_STEPS \ + --mlflow-tracking-uri http://MLFLOW_SERVER_HOST:5000 +``` + + #### Hyperparameter tuning: You can optionally perform hyperparameter tuning by using the included diff --git a/tutorials/tensorflow/mlflow_gcp/trainer/task.py b/tutorials/tensorflow/mlflow_gcp/trainer/task.py index 6501ff94..2e63d9ba 100644 --- a/tutorials/tensorflow/mlflow_gcp/trainer/task.py +++ b/tutorials/tensorflow/mlflow_gcp/trainer/task.py @@ -102,6 +102,12 @@ def get_args(): '--project-id', type=str, help='AI Platform project id') + # + parser.add_argument( + '--mlflow-tracking-uri', + type=str, + default='mlflow tracking URI', + help='MLFlow tracking URI') parser.add_argument( '--gcs-bucket', type=str, @@ -187,8 +193,10 @@ def train_and_evaluate(args): batch_size=num_eval_examples) start_time = time() + # Set tracking URI + 
if args.mlflow_tracking_uri: + mlflow.set_tracking_uri(args.mlflow_tracking_uri) # Train model - with mlflow.start_run() as active_run: run_id = active_run.info.run_id # Setup Learning Rate decay.