Train a TensorFlow Model on a TPU
Tensor Processing Units (TPUs) are Google's custom-developed application-specific integrated circuits (ASICs) used to accelerate machine learning workloads.
Edit your training script
Define the training strategy and create the model within `strategy.scope()`:
```python
import tensorflow as tf

try:  # detect TPUs
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:  # no TPU found, detect GPUs
    strategy = tf.distribute.MirroredStrategy()  # for GPU or multi-GPU machines

with strategy.scope():
    model = tf.keras.Sequential()  # standard tf.keras code from here
```
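Only model construction and compilation need to happen under the scope; the training calls stay the same. A minimal sketch of a full loop, assuming toy data and falling back to the default strategy when no TPU is attached (the layer sizes and batch size are illustrative, not from the original):

```python
import numpy as np
import tensorflow as tf

# For illustration only: on a TPU VM this would be the TPUStrategy created above.
strategy = tf.distribute.get_strategy()

# Scale the global batch size by the number of replicas (8 cores on a v3-8).
per_replica_batch = 16
global_batch = per_replica_batch * strategy.num_replicas_in_sync

with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(10,)),
        tf.keras.layers.Dense(32, activation="relu"),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")

# Toy data stands in for the real TFRecord input pipeline.
x = np.random.rand(128, 10).astype("float32")
y = np.random.rand(128, 1).astype("float32")
model.fit(x, y, batch_size=global_batch, epochs=1, verbose=0)
```

Scaling the batch size with `strategy.num_replicas_in_sync` keeps the per-core batch constant regardless of how many replicas the strategy provides.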
Create the TPU VM
Open a terminal and declare some variables for convenience:
```shell
export TPU_NAME="my-tpu"
export BUCKET_NAME="my-bucket"
export ZONE="europe-west4-a"
```
We will create a TPU VM on Google Cloud with the `gcloud` command:

```shell
gcloud alpha compute tpus tpu-vm create $TPU_NAME \
    --zone=$ZONE \
    --accelerator-type="v3-8" \
    --version="tpu-vm-tf-2.8.0" \
    --preemptible
```
Connect with SSH
Once created, we can connect to the VM via SSH:
```shell
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE
```
Transfer dataset
We need to have all the files for the workload within Cloud Storage or a local folder:

- the training script (i.e. `train.py`)
- the `requirements.txt` file
When you are working with TPUs, it is recommended to host your training dataset as TFRecord files on Google Cloud Storage.
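A minimal sketch of writing and reading a TFRecord file, assuming a simple float-features/integer-label schema (the feature names, values, and file name here are illustrative, not from the original):

```python
import tensorflow as tf

def serialize_example(features, label):
    # Wrap one (features, label) pair in a tf.train.Example protobuf.
    feature = {
        "features": tf.train.Feature(
            float_list=tf.train.FloatList(value=features)),
        "label": tf.train.Feature(
            int64_list=tf.train.Int64List(value=[label])),
    }
    return tf.train.Example(
        features=tf.train.Features(feature=feature)).SerializeToString()

# Write a few records; in practice the file would then be uploaded to GCS.
with tf.io.TFRecordWriter("data.tfrecord") as writer:
    for i in range(4):
        writer.write(serialize_example([0.1 * i, 0.2 * i], i))

# Read the records back as a tf.data pipeline, as the training script would.
feature_spec = {
    "features": tf.io.FixedLenFeature([2], tf.float32),
    "label": tf.io.FixedLenFeature([1], tf.int64),
}
dataset = tf.data.TFRecordDataset("data.tfrecord")
parsed = dataset.map(lambda r: tf.io.parse_single_example(r, feature_spec))
```

`TFRecordDataset` also accepts `gs://` paths directly, which is why hosting the records on Cloud Storage works well with TPU input pipelines.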
Then, we can transfer the data from Cloud Storage to the TPU VM with `gsutil`:

```shell
gsutil cp -r gs://$BUCKET_NAME/ .
```
Or from the local machine to the TPU VM with `scp`:

```shell
gcloud compute tpus tpu-vm scp ~/my-folder/ $TPU_NAME: --recurse --zone=$ZONE
```
After the training, you can transfer the data (i.e. the trained model) from the TPU VM back to your local machine:
```shell
gcloud compute tpus tpu-vm scp $TPU_NAME:~/my-folder ~/my-local-folder --recurse --zone=$ZONE
```
Clean up
To avoid incurring charges, you can delete the TPU VM:
```shell
gcloud alpha compute tpus tpu-vm delete $TPU_NAME --zone=$ZONE
```
Ensure your TPU VM is no longer running:

```shell
gcloud alpha compute tpus tpu-vm list --zone=$ZONE
```