Train a TensorFlow Model on TPU

Tensor Processing Units (TPUs) are Google's custom-developed application-specific integrated circuits (ASICs) used to accelerate machine learning workloads.

Edit your training script

Define the distribution strategy and create the model within strategy.scope():

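import tensorflow as tf

try:  # detect a TPU
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # on a TPU VM, Google's docs use tpu="local"
  tf.config.experimental_connect_to_cluster(tpu)
  tf.tpu.experimental.initialize_tpu_system(tpu)  # initialize the TPU system before building the strategy
  strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:  # no TPU found: fall back to GPUs
  strategy = tf.distribute.MirroredStrategy()  # single- or multi-GPU machines (or the CPU)

with strategy.scope():
  model = tf.keras.Sequential()  # standard tf.keras code from here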
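The tf.keras.Sequential() placeholder above can be fleshed out along these lines. This is a minimal sketch, assuming a toy regression model on synthetic data; the build_and_train helper, the layer sizes and the per-replica batch size of 64 are illustrative choices, not part of the original script. It shows the two things that matter when moving Keras code to a TPU: the model is created and compiled inside strategy.scope(), and the global batch size is scaled by strategy.num_replicas_in_sync (8 on a v3-8, one replica per core).

```python
import numpy as np
import tensorflow as tf


def build_and_train(strategy: tf.distribute.Strategy, per_replica_batch: int = 64):
    # Each replica (TPU core or GPU) processes per_replica_batch examples per step,
    # so the global batch fed to tf.data grows with the number of replicas.
    global_batch = per_replica_batch * strategy.num_replicas_in_sync

    with strategy.scope():
        # Model and optimizer variables must be created inside the scope
        # so that the strategy places and replicates them.
        model = tf.keras.Sequential([
            tf.keras.Input(shape=(10,)),
            tf.keras.layers.Dense(32, activation="relu"),
            tf.keras.layers.Dense(1),
        ])
        model.compile(optimizer="adam", loss="mse")

    # Synthetic regression data, for illustration only.
    x = np.random.rand(1024, 10).astype("float32")
    y = np.random.rand(1024, 1).astype("float32")
    dataset = (
        tf.data.Dataset.from_tensor_slices((x, y))
        .shuffle(1024)
        .batch(global_batch, drop_remainder=True)  # keep batch shapes static
    )
    return model.fit(dataset, epochs=2, verbose=0)


# In the real script, pass the `strategy` detected above. MirroredStrategy() is used
# here only so the sketch also runs on a machine without a TPU (it falls back to the CPU).
history = build_and_train(tf.distribute.MirroredStrategy())
```

Dropping the last partial batch keeps every step the same shape, which suits the XLA compilation used on TPUs.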

Create the TPU VM

Open a terminal and declare some variables for convenience:

export TPU_NAME="my-tpu"
export BUCKET_NAME="my-bucket"
export ZONE="europe-west4-a"

We create the TPU VM on Google Cloud with the gcloud command; this example requests a v3-8 TPU with the TensorFlow 2.8.0 runtime:

gcloud alpha compute tpus tpu-vm create $TPU_NAME \
  --zone=$ZONE \
  --accelerator-type="v3-8" \
  --version="tpu-vm-tf-2.8.0"

Connect with SSH

Once the VM is created, we can connect to it via SSH:

gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE

Transfer dataset

All the files for the workload must be available in Cloud Storage or in a local folder:

  • the training script
  • the requirements.txt file

When working with TPUs, it is recommended to store the training dataset as TFRecord files hosted on Google Cloud Storage.
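A TFRecord file is a simple sequence of length-prefixed, checksummed records. In practice you would write these files with tf.io.TFRecordWriter (usually serializing tf.train.Example protos) and read them with tf.data.TFRecordDataset. The stdlib-only sketch below just makes the on-disk framing concrete, assuming the layout used by TensorFlow's record writer: a little-endian 64-bit length, a masked CRC-32C of that length, the payload, then a masked CRC-32C of the payload. The function names are illustrative and not part of any library.

```python
import io
import struct
from typing import BinaryIO, Iterator, List


def _make_crc32c_table() -> List[int]:
    """Lookup table for CRC-32C (Castagnoli polynomial, reflected form)."""
    table = []
    for i in range(256):
        crc = i
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table


_CRC32C_TABLE = _make_crc32c_table()


def crc32c(data: bytes) -> int:
    crc = 0xFFFFFFFF
    for byte in data:
        crc = _CRC32C_TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc32c(data: bytes) -> int:
    """Rotate the CRC right by 15 bits and add a constant, as TFRecord does."""
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF


def write_record(f: BinaryIO, payload: bytes) -> None:
    header = struct.pack("<Q", len(payload))            # uint64 length
    f.write(header)
    f.write(struct.pack("<I", masked_crc32c(header)))   # uint32 masked CRC of length
    f.write(payload)
    f.write(struct.pack("<I", masked_crc32c(payload)))  # uint32 masked CRC of payload


def read_records(f: BinaryIO) -> Iterator[bytes]:
    while True:
        header = f.read(8)
        if not header:
            return  # clean end of file
        if len(header) < 8:
            raise ValueError("truncated record header")
        (length,) = struct.unpack("<Q", header)
        (length_crc,) = struct.unpack("<I", f.read(4))
        if length_crc != masked_crc32c(header):
            raise ValueError("corrupted record length")
        payload = f.read(length)
        (payload_crc,) = struct.unpack("<I", f.read(4))
        if payload_crc != masked_crc32c(payload):
            raise ValueError("corrupted record payload")
        yield payload


buf = io.BytesIO()
for item in (b"first example", b"second example"):
    write_record(buf, item)
buf.seek(0)
print(list(read_records(buf)))  # [b'first example', b'second example']
```

The per-record checksums let a reader detect corruption without parsing the payload, which is why the format carries no other index or header.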

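From the TPU VM, we can then transfer the data from Cloud Storage with gsutil. The variables exported earlier only exist in your local shell, so declare BUCKET_NAME again on the VM first:

export BUCKET_NAME="my-bucket"
gsutil cp -r gs://$BUCKET_NAME/ .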

Alternatively, copy the files from your local machine to the TPU VM with scp (run this from your local terminal):

gcloud compute tpus tpu-vm scp ~/my-folder/ $TPU_NAME: --recurse --zone=$ZONE

After training, you can copy the results (e.g. the trained model) from the TPU VM back to your local machine:

gcloud compute tpus tpu-vm scp $TPU_NAME:~/my-folder ~/my-local-folder --recurse --zone=$ZONE

Clean up

To avoid incurring further charges, delete the TPU VM once you are done:

gcloud alpha compute tpus tpu-vm delete $TPU_NAME --zone=$ZONE

Then check that your TPU VM no longer appears in the list:

gcloud alpha compute tpus tpu-vm list --zone=$ZONE
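Because a forgotten TPU VM keeps billing, it can help to script the whole lifecycle so the VM is deleted even when training fails. The following is a minimal sketch in Python, assuming gcloud is installed and authenticated and reusing the example names from above; train_on_tpu, tpu_vm and the runner parameter are hypothetical helpers. It relies on the --command flag of tpu-vm ssh to run a single remote command and on gcloud's global --quiet flag to skip the delete confirmation prompt.

```python
import shlex
import subprocess
from typing import Callable, List, Sequence

# Example values matching the shell variables defined earlier; adjust as needed.
TPU_NAME = "my-tpu"
ZONE = "europe-west4-a"
ACCELERATOR_TYPE = "v3-8"
RUNTIME_VERSION = "tpu-vm-tf-2.8.0"

Runner = Callable[[Sequence[str]], None]


def run(cmd: Sequence[str]) -> None:
    """Echo and run a command; raises CalledProcessError on a non-zero exit."""
    print("+", shlex.join(cmd))
    subprocess.run(cmd, check=True)


def tpu_vm(*args: str) -> List[str]:
    """Build a `gcloud alpha compute tpus tpu-vm ...` command for ZONE."""
    return ["gcloud", "alpha", "compute", "tpus", "tpu-vm", *args, f"--zone={ZONE}"]


def train_on_tpu(remote_command: str, runner: Runner = run) -> None:
    """Create the TPU VM, run one command on it over SSH, then delete it."""
    runner(tpu_vm("create", TPU_NAME,
                  f"--accelerator-type={ACCELERATOR_TYPE}",
                  f"--version={RUNTIME_VERSION}"))
    try:
        runner(tpu_vm("ssh", TPU_NAME, f"--command={remote_command}"))
    finally:
        # Runs even if the remote command fails or the script is interrupted.
        # --quiet is gcloud's global flag for skipping the confirmation prompt.
        runner(tpu_vm("delete", TPU_NAME, "--quiet"))
```

A call such as train_on_tpu("python3 train.py") would run a hypothetical train.py on the fresh VM, so the remote command has to fetch its own inputs (for example from Cloud Storage with gsutil). If the create step itself fails, nothing is deleted, so the list command above remains a useful final check. The runner parameter exists so the sequence of commands can be inspected without calling gcloud.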