In this notebook, we show how to profile and optimize a canonical computer vision training task: ResNet50 on CIFAR100. By the end, we achieve close to a 10x speedup over the baseline.
The notebook is broken down into "optimization stages", where we incrementally modify our training pipeline to apply system-level optimizations to different parts of the training logic.
As you follow along, look for the comments marked ########## Optimization X ################### that highlight the code changes in each optimization stage.
# Install profiler dependencies
!pip install -q torch_tb_profiler tensorboard==2.12.0 tensorboard-plugin-profile==2.11.2 tensorflow==2.12.0 protobuf==3.20.3
import torch
from torch import nn
We include a basic implementation of ResNet50 based on torchvision's implementation, with the extra boilerplate code removed for better readability.
Please navigate to model.py to check the model details.
from model import ResNet50
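As an optional sanity check (not part of the training pipeline below), you could verify that the model accepts CIFAR-sized inputs and emits 100 logits per image; a minimal sketch:
# Optional sanity check: run a dummy CIFAR-sized batch through the model.
sanity_model = ResNet50(num_classes=100)
dummy = torch.rand(2, 3, 32, 32)           # batch of 2 RGB 32x32 images
print(sanity_model(dummy).shape)           # expected: torch.Size([2, 100])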
For the notebook to run in a reasonable time, we pick CIFAR100 as our dataset: an image dataset with 3x32x32 images from 100 classes.
We also apply a few data-augmentation techniques (random crop, horizontal flip, and Gaussian blur), as shown below:
import torchvision
import torchvision.transforms as transforms
def get_loaders(train_bs, val_bs):
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.GaussianBlur(3),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.CIFAR100(
root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
trainset, batch_size=train_bs, shuffle=True)
testset = torchvision.datasets.CIFAR100(
root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
testset, batch_size=val_bs, shuffle=False)
return trainloader, testloader
The training loop below is a standard supervised image-classification training procedure, as suggested by the PyTorch examples. Please skim over it quickly; the details (such as logging and metrics reporting) don't matter much for our workshop.
from torch import optim
from tqdm.notebook import tqdm
import sys, os
import time
start_epoch = 0
end_epoch = 2
lr = 0.1
best_acc = 0.0
criterion = nn.CrossEntropyLoss()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device", device)
Using device cuda
def train(model, optimizer, epoch, trainloader, prof=None):
print('\nEpoch: %d' % epoch)
model.train()
train_loss = 0
correct = 0
total = 0
with tqdm(total=len(trainloader), file=sys.stdout, ) as pbar:
for batch_idx, (inputs, targets) in enumerate(trainloader):
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
train_loss += loss.item()
_, predicted = outputs.max(-1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
pbar.set_description('[%3d]/[%3d]Loss: %.3f | Acc: %.3f%% (%d/%d)'
% (batch_idx, len(trainloader), train_loss/(batch_idx+1),
100.*correct/total, correct, total),)
pbar.update(1)
if prof is not None:
prof.step()
if batch_idx == 20:
return
def test(model, optimizer, epoch, testloader):
global best_acc
model.eval()
test_loss = 0
correct = 0
total = 0
with torch.no_grad():
with tqdm(total=len(testloader), file=sys.stdout) as pbar:
for batch_idx, (inputs, targets) in enumerate(testloader):
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
loss = criterion(outputs, targets)
test_loss += loss.item()
_, predicted = outputs.max(-1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
pbar.set_description('[%3d]/[%3d]Loss: %.3f | Acc: %.3f%% (%d/%d)'
% (batch_idx, len(testloader), test_loss/(batch_idx+1),
100.*correct/total, correct, total),)
pbar.update(1)
# Save checkpoint.
acc = 100.*correct/total
if acc > best_acc:
state = {
'model': model.state_dict(),
'acc': acc,
'epoch': epoch,
}
if not os.path.isdir('checkpoint'):
os.mkdir('checkpoint')
torch.save(state, './checkpoint/ckpt.pth')
best_acc = acc
The baseline simply runs the training loop for 2 epochs and reports the training time for the second epoch (when things are more stable).
We use tqdm
to report training progress for each epoch. As it trains, you should see the loss on the left of the progress bar decrease, along with the iteration time on the right of the progress bar.
The expected runtime for this snippet is about 160 seconds (roughly 80s per epoch × 2 epochs).
trainloader, valloader = get_loaders(128, 128)
import torchvision
model = ResNet50(num_classes=100).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr,weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
for epoch in range(start_epoch, end_epoch):
epoch_start_time = time.time()
train(model, optimizer, epoch, trainloader)
epoch_end_time = time.time()
test(model, optimizer, epoch, valloader)
scheduler.step()
if epoch > 0:
print("Training for one epoch takes {:.3f}s".format(epoch_end_time - epoch_start_time))
baseline_time = epoch_end_time - epoch_start_time
Files already downloaded and verified Files already downloaded and verified Epoch: 0
0%| | 0/391 [00:00<?, ?it/s]
0%| | 0/79 [00:00<?, ?it/s]
Epoch: 1
0%| | 0/391 [00:00<?, ?it/s]
0%| | 0/79 [00:00<?, ?it/s]
Training for one epoch takes 78.477s
We choose to profile with the built-in PyTorch profiler for its ease of use.
For more detailed profiling, you could also use vtune (for Intel CPUs), nsys (for NVIDIA GPUs), and other vendor-specific profiling tools.
However, they require more careful installation and launching procedures, which we do not have the time to cover here.
trainloader, valloader = get_loaders(128, 128)
model = ResNet50(num_classes=100).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr,weight_decay=5e-4)
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(skip_first=10, wait=5, warmup=1, active=1, repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler('./profile/baseline'),
record_shapes=True,
profile_memory=True,
with_stack=True
) as prof:
train(model, optimizer, 0, trainloader, prof)
Files already downloaded and verified Files already downloaded and verified Epoch: 0
0%| | 0/391 [00:00<?, ?it/s]
[W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation [W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation [W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation [W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation [W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation [W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation [W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation [W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation [W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation [W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation [W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation [W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation [W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation [W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation STAGE:2023-03-25 06:48:52 11607:11607 ActivityProfilerController.cpp:311] Completed Stage: Warm Up [W CPUAllocator.cpp:235] Memory block of unknown size was allocated before the profiling started, profiler results will not include the deallocation event STAGE:2023-03-25 06:48:53 11607:11607 ActivityProfilerController.cpp:317] Completed Stage: Collection STAGE:2023-03-25 06:48:53 11607:11607 ActivityProfilerController.cpp:321] Completed Stage: Post Processing
To view the results, navigate to asplos23-tutorial/project and run tensorboard --logdir profile --bind_all to launch TensorBoard. Open <machine_ip>:6006 in your browser and select pytorch_profiler from the dropdown menu. Under Views, select Trace to see the execution trace.
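If you prefer a quick textual summary without launching TensorBoard, you could also print aggregated statistics directly from the prof object created above; a minimal sketch (the available columns and sort keys depend on your PyTorch version):
# Quick aggregated view of the collected profile (alternative to the TensorBoard UI).
# Assumes `prof` is the torch.profiler.profile object from the cell above.
print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=10))
# Optionally export a Chrome trace that can be opened in chrome://tracing or Perfetto.
prof.export_chrome_trace("./profile/baseline_trace.json")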
The first optimization concerns data augmentation. While many augmentation techniques are largely insensitive to ordering (e.g. cropping after blurring gives essentially the same result as blurring after cropping), the performance implications of the ordering are significant.
Instead of applying ToTensor last, we apply it first in the data-augmentation pipeline. This lets the subsequent operations run on Tensor objects rather than PIL images, which utilize the hardware better thanks to their more efficient implementations.
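To see the effect in isolation, you could time a single transform on both input types. A rough micro-benchmark sketch (not part of the pipeline; absolute numbers will vary with your torchvision version and CPU):
import time
blur = transforms.GaussianBlur(3)
tensor_img = torch.rand(3, 32, 32)               # a CIFAR-sized image as a tensor
pil_img = transforms.ToPILImage()(tensor_img)    # the same image as a PIL image
def bench(fn, arg, n=1000):
    start = time.perf_counter()
    for _ in range(n):
        fn(arg)
    return time.perf_counter() - start
print("GaussianBlur on PIL image: {:.3f}s".format(bench(blur, pil_img)))
print("GaussianBlur on tensor:    {:.3f}s".format(bench(blur, tensor_img)))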
def get_loaders(train_bs, val_bs,):
transform_train = transforms.Compose([
####### OPTIMIZATION 1 #################
transforms.ToTensor(),
####### OPTIMIZATION 1 #################
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.GaussianBlur(3),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.CIFAR100(
root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
trainset, batch_size=train_bs, shuffle=True)
testset = torchvision.datasets.CIFAR100(
root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
testset, batch_size=val_bs, shuffle=False)
return trainloader, testloader
model = ResNet50(num_classes=100).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr,weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
trainloader, valloader = get_loaders(128, 128)
for epoch in range(start_epoch, end_epoch):
epoch_start_time = time.time()
train(model, optimizer, epoch, trainloader)
epoch_end_time = time.time()
test(model, optimizer, epoch, valloader)
scheduler.step()
if epoch > 0:
print("Training for one epoch takes {:.3f}s".format(epoch_end_time - epoch_start_time))
print("Speedup over baseline: {:.2f}".format(baseline_time / (epoch_end_time - epoch_start_time)))
Files already downloaded and verified Files already downloaded and verified Epoch: 0
0%| | 0/391 [00:00<?, ?it/s]
0%| | 0/79 [00:00<?, ?it/s]
Epoch: 1
0%| | 0/391 [00:00<?, ?it/s]
0%| | 0/79 [00:00<?, ?it/s]
Training for one epoch takes 60.969s Speedup over baseline: 1.29
By default, PyTorch does not enable data pre-fetching or loading data with multiple CPUs.
To enable this, we need to set the num_workers argument of the DataLoader.
For a single-GPU training setup, a good choice is the number of CPUs on your machine minus one (7 in this case).
Similarly, for GPU training, pin_memory=True allows CPU tensors to be created directly in the pinned (page-locked) memory region, from which they can be copied to the GPU; otherwise, we would incur an extra CPU-to-CPU copy.
Lastly, the batch size should be increased as far as your algorithm (for convergence) and your hardware (before running out of memory) allow, for the best GPU utilization through increased parallelism and memory reuse.
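As a side note, pin_memory pairs naturally with non-blocking host-to-device copies, and the worker count can be derived from the machine instead of being hard-coded. A minimal sketch of both ideas (the training cells below keep the explicit values from the text above):
import os
# Leave one CPU for the main training process; the rest feed the DataLoader workers.
num_workers = max(1, (os.cpu_count() or 1) - 1)
dataset = torchvision.datasets.CIFAR100(
    root='./data', train=True, download=True, transform=transforms.ToTensor())
loader = torch.utils.data.DataLoader(
    dataset, batch_size=256, shuffle=True,
    num_workers=num_workers,
    pin_memory=True,          # host batches land in pinned (page-locked) memory
)
inputs, targets = next(iter(loader))
# With pinned memory, non_blocking=True lets the copy overlap with GPU compute.
inputs = inputs.to(device, non_blocking=True)
targets = targets.to(device, non_blocking=True)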
def get_loaders(train_bs, val_bs,):
transform_train = transforms.Compose([
transforms.ToTensor(),
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.GaussianBlur(3),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.CIFAR100(
root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
####### OPTIMIZATION 2.1 #################
trainset, batch_size=train_bs, shuffle=True,
pin_memory=True,
num_workers=7,
####### OPTIMIZATION 2.1 #################
)
testset = torchvision.datasets.CIFAR100(
root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
####### OPTIMIZATION 2 #################
testset, batch_size=val_bs, shuffle=False,
pin_memory=True,
num_workers=7,
####### OPTIMIZATION 2 #################
)
return trainloader, testloader
####### OPTIMIZATION 2.2 #################
trainloader, valloader = get_loaders(256, 512)
####### OPTIMIZATION 2.2 #################
model = ResNet50(num_classes=100).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr,weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
for epoch in range(start_epoch, end_epoch):
epoch_start_time = time.time()
train(model, optimizer, epoch, trainloader)
epoch_end_time = time.time()
test(model, optimizer, epoch, valloader)
scheduler.step()
if epoch > 0:
print("Training for one epoch takes {:.3f}s".format(epoch_end_time - epoch_start_time))
print("Speedup over baseline: {:.2f}".format(baseline_time / (epoch_end_time - epoch_start_time)))
Files already downloaded and verified Files already downloaded and verified Epoch: 0
0%| | 0/196 [00:00<?, ?it/s]
0%| | 0/20 [00:00<?, ?it/s]
Epoch: 1
0%| | 0/196 [00:00<?, ?it/s]
0%| | 0/20 [00:00<?, ?it/s]
Training for one epoch takes 22.582s Speedup over baseline: 3.48
GPUs from the Volta microarchitecture onwards (V100, T4, A100, H100, etc.) feature Tensor Cores.
These are compute units that are much faster than the traditional units operating on 32-bit IEEE-754 floating-point numbers. Mixed precision leverages these Tensor Cores and applies additional numerical adjustments (such as loss scaling) to compensate for the reduced precision, producing results that are close to, although not identical with, full-precision training.
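As a toy illustration of what autocast does (separate from the training code below, and assuming a CUDA device), matmul-like operations inside an autocast region run in half precision while tensors created outside stay in float32:
# Toy illustration: operator dtypes inside an autocast region (assumes device == 'cuda').
a = torch.rand(64, 64, device=device)      # float32 inputs
b = torch.rand(64, 64, device=device)
with torch.autocast(device_type=device):
    c = a @ b                              # eligible op: runs in float16 on Tensor Cores
    print(c.dtype)                         # torch.float16
print((a @ b).dtype)                       # torch.float32 outside the autocast region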
def train(model, optimizer, grad_scalar, epoch, trainloader, prof=None):
print('\nEpoch: %d' % epoch)
model.train()
train_loss = 0
correct = 0
total = 0
with tqdm(total=len(trainloader), file=sys.stdout, ) as pbar:
for batch_idx, (inputs, targets) in enumerate(trainloader):
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
####### OPTIMIZATION 3.1 #################
with torch.autocast(device_type=device):
outputs = model(inputs)
loss = criterion(outputs, targets)
####### OPTIMIZATION 3.1 #################
####### OPTIMIZATION 3.2 #################
grad_scalar.scale(loss).backward()
grad_scalar.step(optimizer)
grad_scalar.update()
####### OPTIMIZATION 3.2 #################
train_loss += loss.item()
_, predicted = outputs.max(-1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
pbar.set_description('[%3d]/[%3d]Loss: %.3f | Acc: %.3f%% (%d/%d)'
% (batch_idx, len(trainloader), train_loss/(batch_idx+1),
100.*correct/total, correct, total),)
pbar.update(1)
if prof is not None:
prof.step()
# Let cuDNN auto-tune the fastest convolution algorithms for our fixed input shapes,
# trading determinism for speed (these flags live under torch.backends.cudnn).
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
trainloader, valloader = get_loaders(256, 512)
model = ResNet50(num_classes=100).to(device)
model.train()
####### OPTIMIZATION 3.3 #################
model.to(memory_format=torch.channels_last)
####### OPTIMIZATION 3.3 #################
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
grad_scalar = torch.cuda.amp.GradScaler()
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
for epoch in range(start_epoch, end_epoch):
epoch_start_time = time.time()
train(model, optimizer, grad_scalar, epoch, trainloader)
epoch_end_time = time.time()
test(model, optimizer, epoch, valloader)
scheduler.step()
if epoch > 0:
print("Training for one epoch takes {:.3f}s".format(epoch_end_time - epoch_start_time))
print("Speedup over baseline: {:.2f}".format(baseline_time / (epoch_end_time - epoch_start_time)))
Files already downloaded and verified Files already downloaded and verified Epoch: 0
0%| | 0/196 [00:00<?, ?it/s]
0%| | 0/20 [00:00<?, ?it/s]
Epoch: 1
0%| | 0/196 [00:00<?, ?it/s]
0%| | 0/20 [00:00<?, ?it/s]
Training for one epoch takes 11.938s Speedup over baseline: 6.57
Just-in-time (JIT) compilation is a technique for compiling PyTorch models to achieve better hardware utilization.
The details of JIT would require another session to explain, but the APIs are pretty simple; please see below.
JIT works best with static input shapes, so we set drop_last=True on the dataloaders, which keeps the batch size consistent across iterations.
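As a toy illustration of what tracing produces (unrelated to the training pipeline below), torch.jit.trace records the operations executed on an example input and returns a ScriptModule whose recovered code you can inspect:
# Toy example: trace a tiny module and inspect the recovered TorchScript code.
tiny = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
traced_tiny = torch.jit.trace(tiny, torch.rand(2, 8))
print(traced_tiny.code)                          # TorchScript source recovered from the trace
print(traced_tiny(torch.rand(2, 8)).shape)       # behaves like the original module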
def get_loaders(train_bs, val_bs,):
transform_train = transforms.Compose([
transforms.ToTensor(),
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.GaussianBlur(3),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.CIFAR100(
root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
trainset, batch_size=train_bs, shuffle=True,
pin_memory=True,
num_workers=8,
####### OPTIMIZATION 4.1 #################
drop_last=True,
####### OPTIMIZATION 4.1 #################
)
testset = torchvision.datasets.CIFAR100(
root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
testset, batch_size=val_bs, shuffle=False,
pin_memory=True,
num_workers=8,
####### OPTIMIZATION 4.2 #################
drop_last=True,
####### OPTIMIZATION 4.2 #################
)
return trainloader, testloader
trainloader, valloader = get_loaders(256, 512)
model = ResNet50(num_classes=100).to(device)
model.train()
model.to(memory_format=torch.channels_last)
####### OPTIMIZATION 4.3 #################
traced_model = torch.jit.trace(model, (torch.rand(256, 3, 32, 32, device=device),))
####### OPTIMIZATION 4.3 #################
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
grad_scalar = torch.cuda.amp.GradScaler()
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
for epoch in range(start_epoch, end_epoch):
epoch_start_time = time.time()
train(traced_model, optimizer, grad_scalar, epoch, trainloader)
epoch_end_time = time.time()
test(model, optimizer, epoch, valloader)
scheduler.step()
if epoch > 0:
print("Training for one epoch takes {:.3f}s".format(epoch_end_time - epoch_start_time))
print("Speedup over baseline: {:.2f}".format(baseline_time / (epoch_end_time - epoch_start_time)))
Files already downloaded and verified Files already downloaded and verified Epoch: 0
0%| | 0/195 [00:00<?, ?it/s]
0%| | 0/19 [00:00<?, ?it/s]
Epoch: 1
0%| | 0/195 [00:00<?, ?it/s]
0%| | 0/19 [00:00<?, ?it/s]
Training for one epoch takes 11.360s Speedup over baseline: 6.91
CUDAGraph captures a sequence of GPU operations and replays them as a single GPU operation, which reduces kernel-launch overhead significantly.
CUDAGraph also requires static shapes and computation patterns, so its use cases are limited. Please use it with caution.
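For context, make_graphed_callables (used below) wraps the lower-level capture/replay API. A minimal inference-only sketch of that raw API on a toy module, assuming a CUDA device and a reasonably recent PyTorch; new data must be copied into the captured input buffer before each replay:
# Minimal sketch of raw CUDA graph capture and replay (toy module, no autograd).
toy = nn.Linear(1024, 1024).to(device)
static_input = torch.rand(256, 1024, device=device)
# Warm up on a side stream so one-time initialization work is not captured.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.no_grad(), torch.cuda.stream(s):
    for _ in range(3):
        toy(static_input)
torch.cuda.current_stream().wait_stream(s)
g = torch.cuda.CUDAGraph()
with torch.no_grad(), torch.cuda.graph(g):
    static_output = toy(static_input)      # operations are recorded, not run eagerly
# To process new data: copy it into the captured buffer, then replay the graph.
static_input.copy_(torch.rand(256, 1024, device=device))
g.replay()
print(static_output.sum().item())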
trainloader, valloader = get_loaders(256, 512)
model = ResNet50(num_classes=100).to(device)
model.train()
model.to(memory_format=torch.channels_last)
traced_model = torch.jit.trace(model, (torch.rand(256, 3, 32, 32, device=device),))
####### OPTIMIZATION 5 #################
with torch.amp.autocast(device_type=device, cache_enabled=False):
graphed_model = torch.cuda.make_graphed_callables(traced_model, (torch.rand(256, 3, 32, 32, device=device),))
####### OPTIMIZATION 5 #################
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
grad_scalar = torch.cuda.amp.GradScaler()
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
for epoch in range(start_epoch, end_epoch):
epoch_start_time = time.time()
train(graphed_model, optimizer, grad_scalar, epoch, trainloader)
epoch_end_time = time.time()
test(model, optimizer, epoch, valloader)
scheduler.step()
if epoch > 0:
print("Training for one epoch takes {:.3f}s".format(epoch_end_time - epoch_start_time))
print("Speedup over baseline: {:.2f}".format(baseline_time / (epoch_end_time - epoch_start_time)))
Files already downloaded and verified Files already downloaded and verified Epoch: 0
0%| | 0/195 [00:00<?, ?it/s]
0%| | 0/19 [00:00<?, ?it/s]
Epoch: 1
0%| | 0/195 [00:00<?, ?it/s]
0%| | 0/19 [00:00<?, ?it/s]
Training for one epoch takes 9.217s Speedup over baseline: 8.51
Optimizations are heavily dependent on your workload and metrics, and they are fairly complex. The code changes above may look simple, but it took a significant engineering effort to iteratively profile and then modify the source code. One could argue that a performance engineer can only be as good as the profiler she uses :)
Created by Xin Li (xin@centml.ai), adapted by Yubo Gao (ybgao@centml.ai) for ASPLOS '23.
Please email xin@centml.ai for any questions!