Example scripts showing how to use TensorFlow's `Estimator` class.

This repository has an accompanying blogpost at https://medium.com/@peter.roelants/tensorflow-estimator-dataset-apis-caeb71e6e196

The main file of interest is `src/mnist_estimator.py`, which defines an example Estimator to train a network on MNIST.
With Anaconda Python:

```shell
conda env create -f env.yml
source activate tensorflow
```
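For reference, `env.yml` might look roughly like the following. This is a sketch, not the repository's actual file: the environment name `tensorflow` is taken from the `source activate` command above, and the TensorFlow version is assumed from the `--runtime-version 1.8` used later for the cloud job.

```yaml
# Hypothetical conda environment definition for this repository.
name: tensorflow
channels:
  - defaults
dependencies:
  - python=3.6
  - pip
  - pip:
    - tensorflow==1.8.0
```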
After setting up the environment you can run the training locally with:
```shell
./src/mnist_estimator.py
```
Training can be monitored with TensorBoard:
```shell
tensorboard --logdir=./mnist_training
```
After training you can check the inference with:
```shell
./src/mnist_inference.py
```
- Create a new project in the Cloud Resource Manager as described in the Google Cloud documentation. (I named my project `mnist-estimator`.)
- Install the Google Cloud SDK.
- Enable the ML Engine APIs.
- Set up a Google Cloud Storage (GCS) bucket as described in the GCS documentation. This will be needed to save our model checkpoints. I named my bucket `estimator-data`.
Run the training job on Google Cloud with:
```shell
gcloud ml-engine jobs submit training mnist_estimator_`date +%s` \
    --project mnist-estimator \
    --runtime-version 1.8 \
    --python-version 3.5 \
    --job-dir gs://estimator-data/train \
    --scale-tier BASIC \
    --region europe-west1 \
    --module-name src.mnist_estimator \
    --package-path src/ \
    -- \
    --train-steps 6000 \
    --batch-size 128
```
Note:

- Replace `gs://estimator-data/` with the link to the bucket you created.
- The latest Python version supported on gcloud is 3.5 (although I'm using 3.6 locally).
- The `--project` flag refers to the gcloud project (`mnist-estimator` in my case). To avoid using this flag you can set the default project with `gcloud config set core/project mnist-estimator`.
- You can feed arguments to the script by adding an empty `--` after the gcloud parameters and adding your custom arguments after it, like `--train-steps` and `--batch-size` in this case.
- Note that the `--job-dir` argument will be fed into the arguments of `mnist_estimator`. The script should thus always accept this parameter.
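Since gcloud forwards `--job-dir` (plus everything after the bare `--`) to the module, the script's argument parsing could look like the following sketch. The defaults and help strings here are assumptions; only the flag names come from the gcloud command above.

```python
import argparse

# Minimal argument parser mirroring the flags used in the gcloud command above.
parser = argparse.ArgumentParser()
parser.add_argument('--job-dir', type=str, default='./mnist_training',
                    help='Where to write checkpoints (local path or GCS bucket).')
parser.add_argument('--train-steps', type=int, default=6000,
                    help='Number of training steps.')
parser.add_argument('--batch-size', type=int, default=128,
                    help='Batch size used during training.')

# gcloud passes --job-dir together with the custom flags after `--`,
# so the script must accept all three.
args = parser.parse_args([
    '--job-dir', 'gs://estimator-data/train',
    '--train-steps', '6000',
    '--batch-size', '128',
])
print(args.job_dir, args.train_steps, args.batch_size)
```

Note that `argparse` converts the dashes in flag names to underscores, so `--job-dir` becomes `args.job_dir`.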
You can follow the training with TensorBoard:
```shell
tensorboard --logdir=gs://estimator-data/train
```
However, TensorBoard seems to update very slowly when connected to a gcloud bucket, and sometimes it didn't even display all the data.
After training you can download the checkpoint files from the gcloud bucket.
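For example, the checkpoint files can be copied to a local directory with `gsutil` (the bucket path assumes the `estimator-data` bucket used above; the local directory name is arbitrary):

```shell
# Copy the training output from the GCS bucket to a local directory.
# -m enables parallel transfers, -r copies the directory recursively.
gsutil -m cp -r gs://estimator-data/train ./mnist_training_cloud
```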
If you're interested, there is a Google Cloud blogpost that goes into more detail on training an Estimator in the cloud.