This example implements a naive matrix multiplication kernel without any optimizations; it is roughly 40x slower than the carefully optimized matrix multiplication kernel.
import numpy as np
import numpy.testing
import hidet
def matmul_func(m_size, n_size, k_size):
    """Build and compile a naive CUDA matrix-multiplication kernel for fixed sizes.

    The kernel computes ``c = a @ b`` where ``a`` is ``f32[m_size, k_size]``,
    ``b`` is ``f32[k_size, n_size]`` and ``c`` is ``f32[m_size, n_size]``.
    Every CUDA thread computes exactly one element of ``c`` by looping over the
    whole reduction dimension; there is no shared-memory tiling or other reuse.

    Args:
        m_size: Number of rows of ``a`` and ``c``.
        n_size: Number of columns of ``b`` and ``c``.
        k_size: Reduction dimension (columns of ``a`` / rows of ``b``).

    Returns:
        The module produced by ``hidet.driver.build_ir_module`` with the packed
        host function ``'matmul'`` as its entry point. It is called as
        ``matmul(a, b, c)`` and writes the product into ``c``.
    """
    from hidet.lang import attr, f32
    from hidet.lang.cuda import threadIdx, blockIdx, blockDim
    from hidet.transforms.tools import add_packed_func

    def ceil_div(a, b):
        # Integer ceiling of a / b: the number of tiles needed to cover a elements.
        return (a + b - 1) // b

    # Each thread block is a tile_size x tile_size square of threads
    # (16 * 16 = 256 threads per block).
    tile_size = 16

    with hidet.script_module() as script_module:

        @hidet.script
        def kernel(
            a: f32[m_size, k_size],
            b: f32[k_size, n_size],
            c: f32[m_size, n_size]
        ):
            attr.func_kind = 'cuda_kernel'
            attr.cuda_block_dim = (tile_size, tile_size)
            # One block per tile of c; rounding up covers sizes that are not
            # multiples of tile_size.
            attr.cuda_grid_dim = ceil_div(m_size, tile_size), ceil_div(n_size, tile_size)

            # current thread only works on the c[i, j] element
            # NOTE: the row index i comes from threadIdx.x. Since a and c are
            # indexed row-major, neighbouring threads along x touch different
            # rows (stride k_size / n_size elements), so these global accesses
            # are not coalesced -- acceptable for this deliberately naive kernel.
            i = threadIdx.x + blockIdx.x * blockDim.x
            j = threadIdx.y + blockIdx.y * blockDim.y
            # Threads in edge blocks may fall outside c when a size is not a
            # multiple of tile_size; they do nothing.
            if i < m_size and j < n_size:
                acc = f32(0.0)
                for k in range(k_size):
                    acc += a[i, k] * b[k, j]
                c[i, j] = acc

    ir_module = script_module.ir_module()
    # Wrap the kernel in a host-side function named 'matmul' that checks the
    # argument count/types and launches the kernel.
    add_packed_func(ir_module, func=kernel, pack_func_name='matmul')
    # Compile the module; func_name selects the packed 'matmul' function as
    # the callable entry point.
    return hidet.driver.build_ir_module(ir_module, func_name='matmul')
# Square problem: c[m, n] = a[m, k] @ b[k, n] with m = n = k = 1024.
m_size = n_size = k_size = 1024
matmul = matmul_func(m_size, n_size, k_size)

# Print the CUDA source generated for the compiled module (with color=True).
print(matmul.source(color=True))
#include <stdint.h> #include <cuda_fp16.h> #include <cuda_bf16.h> #include <hidet/runtime/cuda_context.h> #include <hidet/runtime/cpu_context.h> typedef float tfloat32_t; #define __float_to_tf32(x) (x) extern "C" { __global__ void __launch_bounds__(256) hidet_kernel(float * __restrict__ a, float * __restrict__ b, float * __restrict__ c) { int32_t i = ((int)threadIdx.x + ((int)blockIdx.x * 16)); int32_t j = ((int)threadIdx.y + ((int)blockIdx.y * 16)); float acc = 0.0f; for (int32_t k = 0; (k < 1024); k = (k + 1)) { acc = (acc + (a[((i * 1024) + k)] * b[((k * 1024) + j)])); } c[((i * 1024) + j)] = acc; } __host__ void hidet_matmul(int32_t num_args, int32_t * __restrict__ arg_types, void* * __restrict__ args) { assert(((void)"Expect 3 arguments", (num_args == 3))); assert(((void)"The 0-th argument should be TensorPointerType(tensor(float32, [1024, 1024]))", (arg_types[0] == 3))); assert(((void)"The 1-th argument should be TensorPointerType(tensor(float32, [1024, 1024]))", (arg_types[1] == 3))); assert(((void)"The 2-th argument should be TensorPointerType(tensor(float32, [1024, 1024]))", (arg_types[2] == 3))); hidet_kernel<<<dim3(64, 64, 1), dim3(16, 16, 1), 0, (cudaStream_t)get_cuda_stream()>>>(((float*)(args[0])), ((float*)(args[1])), ((float*)(args[2]))); } }
# Random GPU inputs. c does not need initialising because the kernel writes
# every element of it.
a = hidet.randn([m_size, k_size]).cuda()
b = hidet.randn([k_size, n_size]).cuda()
c = hidet.empty([m_size, n_size]).cuda()
matmul(a, b, c)

# Compare the kernel's output with a host-side NumPy reference product.
np_a, np_b = a.cpu().numpy(), b.cpu().numpy()
np_c = np_a @ np_b
numpy.testing.assert_allclose(
    actual=c.cpu().numpy(),
    desired=np_c,
    rtol=1e-4,
    atol=1e-4,
)
print('Correctness: Pass')
Correctness: Pass
# Time repeated runs of the compiled kernel on the tensors from the
# correctness check; the result is reported in milliseconds.
latency = hidet.utils.benchmark_func(lambda: matmul(a, b, c))
print(f'Latency: {latency:.2f} ms')
Latency: 4.12 ms