Hidet is a high-performance inference engine for deep learning models. It is designed to be easy to use and integrates with the PyTorch and ONNX frontends. In this tutorial, we walk through how to optimize PyTorch models with Hidet.
import torch
# torch.compile used below requires PyTorch 2.0 or later
assert torch.__version__ >= '2.0'
The most convenient way to use Hidet is through the torch.compile API, which automatically optimizes the model by dispatching it to the specified backend. This feature was introduced in PyTorch 2.0.
model = torch.hub.load(
'pytorch/vision:v0.6.0',
model='resnet18',
pretrained=True,
verbose=False
).cuda().eval()
x = torch.randn(1, 3, 224, 224).cuda()
y1 = model(x)
We use the ResNet18 model as an example. The model is downloaded from the PyTorch model zoo and moved to the GPU; we first run it in eager mode to get the reference output y1.
import hidet
# Uncomment the next line to search for faster kernels (longer compilation time):
# hidet.torch.dynamo_config.search_space(2)
model_opt = torch.compile(model, backend='hidet')
y2 = model_opt(x)
Compiling cuda task rearrange(x=float32[1000, 512])...
Compiling cuda task mul(x=float32[128, 64, 3, 3], y=float32[128, 1, 1, 1])...
Compiling cuda task mul(x=float32[128, 64, 1, 1], y=float32[128, 1, 1, 1])...
Compiling cuda task mul(x=float32[256, 128, 3, 3], y=float32[256, 1, 1, 1])...
Compiling cuda task mul(x=float32[256, 128, 1, 1], y=float32[256, 1, 1, 1])...
Compiling cuda task mul(x=float32[512, 256, 3, 3], y=float32[512, 1, 1, 1])...
Compiling cuda task broadcast(data=float32[512, 1000])...
Compiling cuda task rearrange(x=float32[1, 8, 64, 1000])...
Compiling cuda task rearrange(x=float32[1, 128, 64, 3, 3])...
Compiling cuda task rearrange(x=float32[1, 128, 64, 1, 1])...
Compiling cuda task rearrange(x=float32[1, 256, 128, 3, 3])...
Compiling cuda task rearrange(x=float32[1, 256, 128, 1, 1])...
Compiling cuda task rearrange(x=float32[1, 512, 256, 3, 3])...
Compiling cuda task rearrange(x=float32[1, 4, 144, 128])...
Compiling cuda task rearrange(x=float32[1, 4, 16, 128])...
Compiling cuda task rearrange(x=float32[1, 8, 144, 256])...
Compiling cuda task rearrange(x=float32[1, 8, 16, 256])...
Compiling cuda task rearrange(x=float32[1, 8, 288, 512])...
Compiling cuda task rearrange(x=float32[1, 8, 32, 512])...
Compiling cuda task batch_matmul(a=float32[4, 3136, 144], b=float32[4, 144, 64], batch_size=4, m_size=3136, n_size=64, k_size=144, mma=simt) (5 fused)...
Compiling cuda task reduce_sum(x=float32[1, 4, 3136, 64], dims=[1], keep_dim=False, reduce_type=sum, accumulate_dtype=float32) (5 fused)...
Compiling cuda task batch_matmul(a=float32[4, 784, 144], b=float32[4, 144, 128], batch_size=4, m_size=784, n_size=128, k_size=144, mma=simt) (5 fused)...
Compiling cuda task reduce_sum(x=float32[1, 4, 784, 128], dims=[1], keep_dim=False, reduce_type=sum, accumulate_dtype=float32) (2 fused)...
Compiling cuda task max_pool2d(x=float32[1, 64, 112, 112])...
Compiling cuda task batch_matmul(a=float32[4, 784, 16], b=float32[4, 16, 128], batch_size=4, m_size=784, n_size=128, k_size=16, mma=simt) (4 fused)...
Compiling cuda task reduce_sum(x=float32[1, 4, 784, 128], dims=[1], keep_dim=False, reduce_type=sum, accumulate_dtype=float32) (5 fused)...
Compiling cuda task reduce_sum(x=float32[1, 4, 784, 128], dims=[1], keep_dim=False, reduce_type=sum, accumulate_dtype=float32) (5 fused)...
Compiling cuda task reduce_sum(x=float32[1, 8, 196, 256], dims=[1], keep_dim=False, reduce_type=sum, accumulate_dtype=float32) (2 fused)...
Compiling cuda task batch_matmul(a=float32[4, 784, 288], b=float32[4, 288, 128], batch_size=4, m_size=784, n_size=128, k_size=288, mma=simt) (5 fused)...
Compiling cuda task batch_matmul(a=float32[8, 196, 144], b=float32[8, 144, 256], batch_size=8, m_size=196, n_size=256, k_size=144, mma=simt) (5 fused)...
Compiling cuda task batch_matmul(a=float32[8, 196, 16], b=float32[8, 16, 256], batch_size=8, m_size=196, n_size=256, k_size=16, mma=simt) (4 fused)...
Compiling cuda task reduce_sum(x=float32[1, 8, 196, 256], dims=[1], keep_dim=False, reduce_type=sum, accumulate_dtype=float32) (5 fused)...
Compiling cuda task batch_matmul(a=float32[8, 196, 288], b=float32[8, 288, 256], batch_size=8, m_size=196, n_size=256, k_size=288, mma=simt) (5 fused)...
Compiling cuda task reduce_sum(x=float32[1, 8, 196, 256], dims=[1], keep_dim=False, reduce_type=sum, accumulate_dtype=float32) (5 fused)...
Compiling cuda task batch_matmul(a=float32[8, 49, 288], b=float32[8, 288, 512], batch_size=8, m_size=49, n_size=512, k_size=288, mma=simt) (5 fused)...
Compiling cuda task batch_matmul(a=float32[8, 49, 32], b=float32[8, 32, 512], batch_size=8, m_size=49, n_size=512, k_size=32, mma=simt) (4 fused)...
Compiling cuda task reduce_sum(x=float32[1, 8, 49, 512], dims=[1], keep_dim=False, reduce_type=sum, accumulate_dtype=float32) (2 fused)...
Compiling cuda task reduce_sum(x=float32[1, 8, 49, 512], dims=[1], keep_dim=False, reduce_type=sum, accumulate_dtype=float32) (5 fused)...
Compiling cuda task batch_matmul(a=float32[8, 49, 576], b=float32[8, 576, 512], batch_size=8, m_size=49, n_size=512, k_size=576, mma=simt) (5 fused)...
Compiling cuda task reduce_sum(x=float32[1, 8, 49, 512], dims=[1], keep_dim=False, reduce_type=sum, accumulate_dtype=float32) (5 fused)...
Compiling cuda task adaptive_avg_pool2d(x=float32[1, 512, 7, 7], output_size=[1, 1]) (1 fused)...
Compiling cuda task batch_matmul(a=float32[8, 1, 64], b=float32[8, 64, 1000], batch_size=8, m_size=1, n_size=1000, k_size=64, mma=simt) (4 fused)...
We optimize the model by calling torch.compile with the backend argument set to 'hidet'. The optimized model is returned as model_opt, and running it on the same input gives the output y2.
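To see the effect of the optimization, we can time both models. Below is a minimal sketch using CUDA events; bench is an illustrative helper (not part of Hidet), and the measured numbers will vary across GPUs:

def bench(fn, x, warmup=10, repeats=100):
    # measure the average latency (in milliseconds) of fn(x) on the GPU
    with torch.no_grad():
        for _ in range(warmup):
            fn(x)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(repeats):
            fn(x)
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / repeats

print('eager: {:.3f} ms'.format(bench(model, x)))
print('hidet: {:.3f} ms'.format(bench(model_opt, x)))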
Hidet automatically conducts a series of graph-level optimizations and generates a kernel for each operator. We can search for the most performant kernel for each operator and its input shapes by setting the search space. By default, the search space level is 0, which means no search is performed. The level can also be set to 1 or 2: the larger the search space, the more kernel schedules are searched and the longer the compilation takes. The search space is set by calling hidet.torch.dynamo_config.search_space(level) before compiling, as in the sketch below.
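For example, a minimal sketch of enabling the largest search space; the setting must be applied before torch.compile, the extra search can add considerably to the compilation time, and model_searched is an illustrative name:

import hidet
hidet.torch.dynamo_config.search_space(2)  # 0: no search (default); 1 and 2: larger spaces
model_searched = torch.compile(model, backend='hidet')
y_searched = model_searched(x)  # the first call triggers the (longer) compilation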
# Check the results
torch.testing.assert_close(actual=y2, expected=y1, rtol=1e-2, atol=1e-2)
The output y2 is compared against the reference output y1 to make sure the optimization preserves correctness within the given tolerances.
Besides search_space, Hidet provides several options to control the optimization. All of them can be set by calling hidet.torch.dynamo_config.xxx(), where xxx is the option name. The following options are available:
config = hidet.torch.dynamo_config
config.search_space(level=2) # set the search space, level can be 0, 1 or 2.
config.correctness_report() # print the correctness report (the error between the optimized model and the reference pytorch model)
config.use_fp16() # convert the model to FP16
config.dump_graph_ir(output_dir='./outs') # debug: dump the hidet graph IR
config.print_input_graph() # debug: print the received torch.fx.Graph from torch dynamo
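For example, a sketch combining two of these options, assuming the settings take effect at the next compilation (model_fp16 is an illustrative name):

config.use_fp16()            # run the model in FP16
config.correctness_report()  # report the error against the eager PyTorch model
model_fp16 = torch.compile(model, backend='hidet')
y3 = model_fp16(x)  # the first call compiles with the new settings and prints the report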