In [8]:
import numpy as np
import numpy.testing
import hidet


def matmul_func(m_size, n_size, k_size):
    from hidet.lang import attr, f32, tensor
    from hidet.lang import spatial, repeat
    from hidet.lang.cuda import threadIdx, blockIdx, blockDim, syncthreads
    from hidet.transforms.tools import add_packed_func

    def ceil_div(a, b):
        # round-up division; unused below because the asserts require exact
        # divisibility, but handy when sizes are not tile-aligned
        return (a + b - 1) // b

    # each thread block computes a tm x tn tile of c, marching over the
    # reduction dimension in chunks of tk elements
    tm, tn, tk = 32, 32, 128

    # the cooperative loads below assign tk // tn elements of the a tile and
    # tk // tm elements of the b tile to each thread, so tk must be a
    # multiple of both tm and tn
    assert tk % tm == 0
    assert tk % tn == 0
    # make sure the matrix sizes are divisible by the tile sizes
    assert m_size % tm == 0 and n_size % tn == 0 and k_size % tk == 0

    with hidet.script_module() as script_module:

        @hidet.script
        def kernel(
            a: f32[m_size, k_size],
            b: f32[k_size, n_size],
            c: f32[m_size, n_size]
        ):
            attr.func_kind = 'cuda_kernel'
            attr.cuda_block_dim = tn, tm  # threadIdx.x indexes columns, threadIdx.y rows
            attr.cuda_grid_dim = n_size // tn, m_size // tm
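
            # each block computes the tm x tn tile of c starting at row
            # blockIdx.y * tm and column blockIdx.x * tn; every thread owns
            # one element of that tile and accumulates it over k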

            smem_a = tensor(scope='shared', dtype='float32', shape=[tm, tk])
            smem_b = tensor(scope='shared', dtype='float32', shape=[tk, tn])

            acc = f32(0.0)
            for k_tile in range(k_size // tk):
                gmem_a = a[blockIdx.y * tm:, k_tile * tk:]
                gmem_b = b[k_tile * tk:, blockIdx.x * tn:]

                # cooperatively load the a and b tiles from global into shared memory
                tid = threadIdx.x + threadIdx.y * blockDim.x

                # distribute the tm x tk elements of the a tile over the
                # tm x tn threads: each thread loads tk // tn elements
                for i, k in repeat(1, tk // tn).spatial(tm, tn).on(tid):
                    smem_a[i, k] = gmem_a[i, k]

                # likewise, each thread loads tk // tm elements of the b tile
                for k, j in repeat(tk // tm, 1).spatial(tm, tn).on(tid):
                    smem_b[k, j] = gmem_b[k, j]

                # wait until both tiles are fully loaded
                syncthreads()

                # each thread accumulates one element of its c tile
                for k in range(tk):
                    acc += smem_a[threadIdx.y, k] * smem_b[k, threadIdx.x]
                # wait before the next iteration overwrites the shared tiles
                syncthreads()

            # write result
            gi, gj = blockIdx.y * tm + threadIdx.y, blockIdx.x * tn + threadIdx.x
            c[gi, gj] = acc

    ir_module = script_module.ir_module()
    add_packed_func(ir_module, func=kernel, pack_func_name='matmul')
    return hidet.driver.build_ir_module(ir_module, func_name='matmul')


m_size, n_size, k_size = 1024, 1024, 1024
matmul = matmul_func(m_size, n_size, k_size)
print(matmul.source(color=True))
#include <stdint.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <hidet/runtime/cuda_context.h>
#include <hidet/runtime/cpu_context.h>
typedef float tfloat32_t;
#define __float_to_tf32(x) (x)
extern "C" {

__global__ void __launch_bounds__(1024) hidet_kernel(float * __restrict__ a, float * __restrict__ b, float * __restrict__ c) {
  __shared__ float smem_a[4096];
  __shared__ float smem_b[4096];
  float acc = 0.0f;
  for (int32_t k_tile = 0; (k_tile < 8); k_tile = (k_tile + 1)) {
    int32_t tid = ((int)threadIdx.x + ((int)threadIdx.y * 32));
    for (int32_t i = 0; (i < 4); i = (i + 1)) {
      smem_a[(((tid / 32) * 128) + ((i * 32) + (tid % 32)))] = a[(((((int)blockIdx.y * 32) + (tid / 32)) * 1024) + ((k_tile * 128) + ((i * 32) + (tid % 32))))];
    } 
    for (int32_t i_1 = 0; (i_1 < 4); i_1 = (i_1 + 1)) {
      smem_b[((((i_1 * 32) + (tid / 32)) * 32) + (tid % 32))] = b[((((k_tile * 128) + ((i_1 * 32) + (tid / 32))) * 1024) + (((int)blockIdx.x * 32) + (tid % 32)))];
    } 
    __syncthreads();
    for (int32_t k = 0; (k < 128); k = (k + 1)) {
      acc = (acc + (smem_a[(((int)threadIdx.y * 128) + k)] * smem_b[((k * 32) + (int)threadIdx.x)]));
    } 
    __syncthreads();
  } 
  c[(((((int)blockIdx.y * 32) + (int)threadIdx.y) * 1024) + (((int)blockIdx.x * 32) + (int)threadIdx.x))] = acc;
}

__host__ void hidet_matmul(int32_t num_args, int32_t * __restrict__ arg_types, void* * __restrict__ args) {
  assert(((void)"Expect 3 arguments", (num_args == 3)));
  assert(((void)"The 0-th argument should be TensorPointerType(tensor(float32, [1024, 1024]))", (arg_types[0] == 3)));
  assert(((void)"The 1-th argument should be TensorPointerType(tensor(float32, [1024, 1024]))", (arg_types[1] == 3)));
  assert(((void)"The 2-th argument should be TensorPointerType(tensor(float32, [1024, 1024]))", (arg_types[2] == 3)));
  hidet_kernel<<<dim3(32, 32, 1), dim3(32, 32, 1), 0, (cudaStream_t)get_cuda_stream()>>>(((float*)(args[0])), ((float*)(args[1])), ((float*)(args[2])));
}

}
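
The two task mappings in the kernel decide which tile elements each thread loads: `spatial(tm, tn)` lays the tm * tn threads out row-major over a tile, and composing it with an outer `repeat` makes each thread cover several cells of a larger task grid. The sketch below re-implements the composed-mapping semantics in plain Python to show which `(i, k)` elements of the `a` tile a given thread loads (illustrative only; `composed_tasks` is a hypothetical helper, not part of the hidet API):

In [ ]:
def composed_tasks(r0, r1, s0, s1, tid):
    # spatial(s0, s1) assigns worker tid the cell (tid // s1, tid % s1) of an
    # s0 x s1 grid; composing with repeat(r0, r1) replicates that assignment
    # so each worker covers r0 * r1 cells of the (r0 * s0) x (r1 * s1) grid
    si, sj = tid // s1, tid % s1
    for ri in range(r0):
        for rj in range(r1):
            yield ri * s0 + si, rj * s1 + sj


tm, tn, tk = 32, 32, 128
# tasks of thread 0 for the smem_a load: repeat(1, tk // tn).spatial(tm, tn)
print(list(composed_tasks(1, tk // tn, tm, tn, tid=0)))
[(0, 0), (0, 32), (0, 64), (0, 96)]

The column offsets 0, 32, 64 and 96 match the `(i * 32) + (tid % 32)` addressing of `smem_a` in the generated source above.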

In [19]:
a = hidet.randn([m_size, k_size]).cuda()
b = hidet.randn([k_size, n_size]).cuda()
c = hidet.empty([m_size, n_size]).cuda()
matmul(a, b, c)

np_a = a.cpu().numpy()
np_b = b.cpu().numpy()
np_c = np.matmul(np_a, np_b)

numpy.testing.assert_allclose(c.cpu().numpy(), np_c, rtol=1e-4, atol=1e-4)
print('Correctness: Pass')

latency = hidet.utils.benchmark_func(lambda: matmul(a, b, c), number=20, repeat=20)
print('    Latency: {:.2f} ms'.format(latency))
Correctness: Pass
    Latency: 0.69 ms
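
For reference, the same product can be checked and timed against hidet's built-in matmul operator (a minimal sketch; it assumes `hidet.ops.matmul` is available under that name in this hidet version):

In [ ]:
# sanity-check the hand-written kernel against hidet's built-in operator
c_ref = hidet.ops.matmul(a, b)
numpy.testing.assert_allclose(c.cpu().numpy(), c_ref.cpu().numpy(), rtol=1e-4, atol=1e-4)

ref_latency = hidet.utils.benchmark_func(lambda: hidet.ops.matmul(a, b), number=20, repeat=20)
print('Built-in op latency: {:.2f} ms'.format(ref_latency))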