Torch XLA: multi-host TPU run with profiling

Run a PyTorch XLA script on a TPU Pod

We assume that the gcloud CLI is already installed.

Create a Cloud Storage bucket for the profiling outputs

On your machine:

export OUTPUT_BUCKET_NAME=torch-xla-xprof-outputs
gcloud storage buckets create gs://${OUTPUT_BUCKET_NAME}

Create multi-host TPU VMs and run a training script

export ACCELERATOR_TYPE=v4-16
export RUNTIME_VERSION=tpu-ubuntu2204-base
export TPU_NAME=torch-xla-multihost-example-${ACCELERATOR_TYPE}

export OUTPUT_BUCKET_NAME=torch-xla-xprof-outputs
export DOCKER_IMAGE=us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.7.0_3.10_tpuvm

# Create a start-up script:
cat << EOF > example_startup.sh
#!/bin/bash
set -eux

# get torch_xla docker image
docker pull $DOCKER_IMAGE

# Install gcsfuse
export GCSFUSE_REPO=gcsfuse-\`lsb_release -c -s\`
echo "deb https://packages.cloud.google.com/apt \$GCSFUSE_REPO main" | tee /etc/apt/sources.list.d/gcsfuse.list
curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add -
apt-get update
apt-get install -y fuse gcsfuse --no-install-recommends

# Mount GCS bucket
mkdir -p /root/logs
gcsfuse ${OUTPUT_BUCKET_NAME} /root/logs

EOF

gcloud compute tpus tpu-vm create $TPU_NAME --spot \
    --accelerator-type=$ACCELERATOR_TYPE \
    --version=$RUNTIME_VERSION \
    --metadata-from-file=startup-script=example_startup.sh

Copy the local script to all TPU VMs:

gcloud compute tpus tpu-vm ssh $TPU_NAME --worker=all --command "mkdir -p /root/example/"
gcloud compute tpus tpu-vm scp --worker=all mnist_xla.py $TPU_NAME:/root/example/mnist_xla.py

Define the command to run inside the Docker container:

docker_run_args="--rm --privileged --net=host --ipc=host -e PJRT_DEVICE=TPU -v /root:/root -w /root/example"
docker_run_cmd="docker run ${docker_run_args} $DOCKER_IMAGE"
full_command="cd /root/example && ${docker_run_cmd} python -u mnist_xla.py"
mkdir -p logs
gcloud compute tpus tpu-vm ssh $TPU_NAME --worker=all --command "${full_command}" &> logs/mnist_xla.log

We can check the run output locally at logs/mnist_xla.log. Profiling logs are stored in the Cloud Storage bucket torch-xla-xprof-outputs.
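Besides the programmatic xp.start_trace / xp.stop_trace calls in the training script (shown at the end of this page), the profiler server that the script starts on port 9012 also allows on-demand capture while training is running. A minimal sketch, assuming it is run on one of the TPU VMs; both the service address and the 10-second duration are illustrative choices, not values from this gist:

# On-demand trace capture against the profiler server that the training
# script starts via xp.start_server(9012). The trace is written into the
# gcsfuse-mounted bucket at /root/logs.
import torch_xla.debug.profiler as xp

xp.trace('localhost:9012', '/root/logs/', duration_ms=10000)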

Open profiler logs in TensorBoard

We can open TensorBoard on a TPU VM, or on a local machine with the bucket mounted via gcsfuse.

We first install TensorBoard and its profiling plugin:

pip install -U tensorboard-plugin-profile tensorboard

On your machine, mount the bucket using gcsfuse:

mkdir -p remote_logs
gcsfuse torch-xla-xprof-outputs remote_logs

Open TensorBoard:

tensorboard --logdir=remote_logs/ --port 7007
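If the Profile tab stays empty, it is worth checking that traces actually landed in the mounted bucket. The profile plugin looks for runs under <logdir>/plugins/profile/<run>/, where the run directory name is timestamp-based; a small sketch of such a check (this layout is the TensorBoard profile plugin convention, not something this gist guarantees):

# List the captured profile runs under the gcsfuse mount.
import os

profile_root = os.path.join('remote_logs', 'plugins', 'profile')
for run in sorted(os.listdir(profile_root)):
    num_files = len(os.listdir(os.path.join(profile_root, run)))
    print(f'{run}: {num_files} trace files')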
Example run output captured in logs/mnist_xla.log (output from both workers is interleaved):

Using ssh batch size of 1. Attempting to SSH into 1 nodes with a total of 2 workers.
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
Running on XLA device: xla:0
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Running on XLA device: xla:0
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
[... interleaved download-progress percentages from both workers trimmed ...]
100.0%
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
100.0%
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
100.0%
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
[... interleaved download-progress percentages from both workers trimmed ...]
100.0%
Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
100.0%
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
100.0%
Starting training...
Starting training...
Train Epoch: 1 [0/60000 (0%)] Loss: 2.326566
Train Epoch: 1 [0/60000 (0%)] Loss: 2.326566
Train Epoch: 1 [6400/60000 (11%)] Loss: 1.838014
Train Epoch: 1 [6400/60000 (11%)] Loss: 1.838014
Train Epoch: 1 [12800/60000 (21%)] Loss: 1.902737
Train Epoch: 1 [12800/60000 (21%)] Loss: 1.902737
Train Epoch: 1 [19200/60000 (32%)] Loss: 1.888527
Train Epoch: 1 [19200/60000 (32%)] Loss: 1.888527
Train Epoch: 1 [25600/60000 (43%)] Loss: 1.797652
Train Epoch: 1 [25600/60000 (43%)] Loss: 1.797652
Train Epoch: 1 [32000/60000 (53%)] Loss: 1.761881
Train Epoch: 1 [32000/60000 (53%)] Loss: 1.761881
Train Epoch: 1 [38400/60000 (64%)] Loss: 1.610887
Train Epoch: 1 [38400/60000 (64%)] Loss: 1.610887
Train Epoch: 1 [44800/60000 (75%)] Loss: 1.588187
Train Epoch: 1 [44800/60000 (75%)] Loss: 1.588187
Train Epoch: 1 [51200/60000 (85%)] Loss: 1.547786
Train Epoch: 1 [51200/60000 (85%)] Loss: 1.547786
Train Epoch: 1 [57600/60000 (96%)] Loss: 1.436790
Train Epoch: 1 [57600/60000 (96%)] Loss: 1.436790
Train Epoch: 2 [0/60000 (0%)] Loss: 1.581392
Train Epoch: 2 [0/60000 (0%)] Loss: 1.581392
Train Epoch: 2 [6400/60000 (11%)] Loss: 1.562857
Train Epoch: 2 [6400/60000 (11%)] Loss: 1.562857
Train Epoch: 2 [12800/60000 (21%)] Loss: 1.432041
Train Epoch: 2 [12800/60000 (21%)] Loss: 1.432041
Train Epoch: 2 [19200/60000 (32%)] Loss: 1.458446
Train Epoch: 2 [19200/60000 (32%)] Loss: 1.458446
Train Epoch: 2 [25600/60000 (43%)] Loss: 1.431952
Train Epoch: 2 [25600/60000 (43%)] Loss: 1.431952
Train Epoch: 2 [32000/60000 (53%)] Loss: 1.420807
Train Epoch: 2 [32000/60000 (53%)] Loss: 1.420807
Train Epoch: 2 [38400/60000 (64%)] Loss: 1.450929
Train Epoch: 2 [38400/60000 (64%)] Loss: 1.450929
Train Epoch: 2 [44800/60000 (75%)] Loss: 1.333202
Train Epoch: 2 [44800/60000 (75%)] Loss: 1.333202
Train Epoch: 2 [51200/60000 (85%)] Loss: 1.522884
Train Epoch: 2 [51200/60000 (85%)] Loss: 1.522884
Train Epoch: 2 [57600/60000 (96%)] Loss: 1.220276
Train Epoch: 2 [57600/60000 (96%)] Loss: 1.220276
Training finished!
Training finished!
The training script mnist_xla.py:

import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.distributed.spmd as xs
from torch_xla import runtime as xr

os.environ["XLA_IR_DEBUG"] = "1"
os.environ["XLA_HLO_DEBUG"] = "1"

# Enable SPMD mode
xr.use_spmd()

# Declare the device meshes
num_devices = xr.global_runtime_device_count()
device_ids = np.arange(num_devices)

conv_mesh_shape = (num_devices // 2, 2, 1, 1)
conv_mesh = xs.Mesh(device_ids, conv_mesh_shape, ('data', 'dim1', 'dim2', 'dim3'))

linear_mesh_shape = (num_devices // 2, 2)
linear_mesh = xs.Mesh(device_ids, linear_mesh_shape, ('data', 'model'))


# Define the CNN model
class MNISTNet(nn.Module):
    def __init__(self):
        super(MNISTNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1).to(xm.xla_device())
        xs.mark_sharding(self.conv1.weight, conv_mesh, ('data', None, None, None))
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1).to(xm.xla_device())
        xs.mark_sharding(self.conv2.weight, conv_mesh, ('data', None, None, None))
        self.fc1 = nn.Linear(7 * 7 * 64, 128).to(xm.xla_device())  # Adjusted for 28x28 images and 2 pooling layers
        xs.mark_sharding(self.fc1.weight, linear_mesh, ('data', None))
        self.fc2 = nn.Linear(128, 10).to(xm.xla_device())
        xs.mark_sharding(self.fc2.weight, linear_mesh, ('data', 'model'))

    def forward(self, x):
        with xp.Trace('forward'):
            x = F.relu(F.max_pool2d(self.conv1(x), 2))
            x = F.relu(F.max_pool2d(self.conv2(x), 2))
            x = x.view(-1, 7 * 7 * 64)  # Flatten the tensor
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            return F.log_softmax(x, dim=1)


def train_mnist():
    # Training parameters
    epochs = 2
    learning_rate = 0.01
    momentum = 0.5
    batch_size = 64

    # 1. Acquire the XLA device
    device = xm.xla_device()
    print(f"Running on XLA device: {device}")

    # Load the MNIST dataset
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size, shuffle=True)

    # 2. Initialize the model and move it to the XLA device
    model = MNISTNet().to(device)

    # Define the loss function and optimizer
    loss_fn = nn.NLLLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)

    print("Starting training...")
    for epoch in range(1, epochs + 1):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            with torch_xla.step():
                with xp.Trace('train_step_data_prep_and_forward'):
                    optimizer.zero_grad()
                    # 3. Move data and target to the XLA device
                    data, target = data.to(device), target.to(device)
                    # 4. Shard the input batch
                    xs.mark_sharding(data, conv_mesh, ('data', 'dim1', None, None))
                    output = model(data)
                with xp.Trace('train_step_loss_and_backward'):
                    loss = loss_fn(output, target)
                    loss.backward()
                with xp.Trace('train_step_optimizer_step_host'):
                    optimizer.step()
            torch_xla.sync()
            if batch_idx % 100 == 0:
                print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                      f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
    print("Training finished!")


if __name__ == '__main__':
    server = xp.start_server(9012)
    xp.start_trace('/root/logs/')
    train_mnist()
    xp.stop_trace()
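For reference, a quick sketch of what the mesh declarations above evaluate to on this pod. A v4-16 slice reports 8 global XLA devices spread over the 2 hosts (the run log above shows 2 workers), so num_devices // 2 is 4; the snippet below just fixes num_devices to 8 for illustration:

# Stand-alone illustration of the mesh shapes used above,
# with num_devices fixed to the 8 devices a v4-16 slice reports.
import numpy as np

num_devices = 8
device_ids = np.arange(num_devices)
# linear_mesh layout, axes ('data', 'model'):
print(device_ids.reshape((num_devices // 2, 2)))
# [[0 1]
#  [2 3]
#  [4 5]
#  [6 7]]
# conv_mesh reshapes the same 8 devices to (4, 2, 1, 1),
# axes ('data', 'dim1', 'dim2', 'dim3').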