Hidet is designed to be extensible. It is easy to add new operators to Hidet. There are two ways to add and schedule an operator.
In this tutorial, we will walk through how to define the computation of an operator and schedule it using the two methods.
import hidet
Each operator takes a list of input tensors and produces a list of output tensors:
inputs: List[Tensor]
outputs: List[Tensor] = operator(inputs)
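For example, calling one of Hidet's built-in operators follows this pattern (a minimal sketch; hidet.ops.matmul is assumed to be available, as in Hidet's operator library):

a = hidet.randn([2, 3])
b = hidet.randn([3, 4])
c = hidet.ops.matmul(a, b)  # the operator takes input tensors and produces output tensor(s)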
The precise mathematical semantics of each operator in Hidet are defined through a domain-specific language (DSL). In this tutorial, we will show how to write the mathematical definition of a new operator using this DSL, which lives in the hidet.ir.compute module.
Hidet provides compute primitives to define the mathematical computation of an operator.
tensor_input(name: str, dtype: str, shape: List[int])
The tensor_input() primitive defines a tensor input by specifying the name hint, scalar data type, and shape of the tensor.
Examples:
a = tensor_input('a', dtype='float32', shape=[10, 10])
b = tensor_input('b', dtype='float32', shape=[])
data = tensor_input('data', dtype='float16', shape=[1, 3, 224, 224])
compute(name: str, shape: List[int], fcompute: Callable[[Var, ...], Expr])
The compute() primitive defines a tensor by specifying the shape of the tensor and a function that maps each index to the expression computing the element at that index. The elements of the tensor are independent of each other and can be computed in parallel.
Semantics:
# compute primitive
out = compute(
    name='hint_name',
    shape=[n1, n2, ..., nk],
    fcompute=lambda i1, i2, ..., ik: f(i1, i2, ..., ik)
)

# semantics
for i1 in range(n1):
    for i2 in range(n2):
        ...
            for ik in range(nk):
                out[i1, i2, ..., ik] = f(i1, i2, ..., ik)
Examples:
# define an input tensor
a = tensor_input('a', dtype='float32', shape=[10, 10])
# example 1: slice the first column of a
b = compute('slice', shape=[10], fcompute=lambda i: a[i, 0])
# example 2: reverse the rows of matrix a
c = compute('reverse', shape=[10, 10], fcompute=lambda i, j: a[9 - i, j])
# example 3: add 1 to the diagonal elements of a
from hidet.ir.expr import if_then_else
d = compute(
    name='diag_add',
    shape=[10, 10],
    fcompute=lambda i, j: if_then_else(i == j, then_expr=a[i, j] + 1.0, else_expr=a[i, j])
)
reduce(shape: List[int], fcompute: Callable[[Var, ...], Expr], reduce_type='sum')
The reduce() primitive conducts a reduction operation over a domain with the given shape. It returns a scalar value and can be used inside the compute() primitive.
Semantics:
# reduce primitive
out = reduce(
    name='hint_name',
    shape=[n1, n2, ..., nk],
    fcompute=lambda i1, i2, ..., ik: f(i1, i2, ..., ik),
    reduce_type='sum' | 'max' | 'min' | 'avg'
)

# semantics
values = []
for i1 in range(n1):
    for i2 in range(n2):
        ...
            for ik in range(nk):
                values.append(f(i1, i2, ..., ik))
out = reduce_type(values)
Examples:
# define an input tensor
a = tensor_input('a', dtype='float32', shape=[10, 10])
# example 1: sum all elements of a
c = reduce(shape=[10, 10], fcompute=lambda i, j: a[i, j], reduce_type='sum')
# example 2: sum the first column of a
d = reduce(shape=[10], fcompute=lambda i: a[i, 0], reduce_type='sum')
# example 3: matrix multiplication
b = tensor_input('b', dtype='float32', shape=[10, 10])
e = compute(
    name='e',
    shape=[10, 10],
    fcompute=lambda i, j: reduce(
        shape=[10],
        fcompute=lambda k: a[i, k] * b[k, j],
        reduce_type='sum'
    )
)
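Besides reduce(), the hidet.ir.compute module also provides an arg_reduce() primitive (it appears in an import in a later example) that returns the index of the minimum or maximum element over a reduction domain. A hedged sketch of its usage, assuming it takes an extent, an fcompute function, and a reduce_type, as in recent Hidet versions:

from hidet.ir.compute import tensor_input, compute, arg_reduce

a = tensor_input('a', dtype='float32', shape=[10, 10])
# for each row of a, compute the index of its maximum element
b = compute(
    'argmax',
    shape=[10],
    fcompute=lambda i: arg_reduce(extent=10, fcompute=lambda j: a[i, j], reduce_type='max')
)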
The computation of each operator can be described as a directed acyclic graph (DAG). The DAG is composed of tensor nodes; both the tensor_input() and compute() primitives create tensor nodes. The edges of the DAG are the dependencies between the tensor nodes. Such a DAG is stored in a Task object.
class Task(name: str, inputs: List[TensorNode], outputs: List[TensorNode])
Each task has a name, a list of inputs, and a list of outputs, corresponding to the inputs and outputs of the operator. The following example shows how to create a task.
def demo_task():
    from hidet.ir.compute import tensor_input, compute
    from hidet.ir.task import Task

    # define the computation DAG through the compute primitives
    a = tensor_input('a', dtype='float32', shape=[10])
    b = tensor_input('b', dtype='float32', shape=[10])
    c = compute('c', [10], lambda i: a[i] + i)
    d = compute('d', [10], lambda i: c[9 - i])
    e = compute('e', [10], lambda i: a[i] + b[i])

    # create a task object
    task = Task(name='task', inputs=[a, b], outputs=[d, e])
    print(task)

demo_task()
Task(
  name: task
  parameters:
    a: tensor(float32, [10])
    b: tensor(float32, [10])
    d: tensor(float32, [10])
    e: tensor(float32, [10])
  inputs: [a, b]
  outputs: [d, e]
  computations:
    b: tensor(float32, [10])
    e: float32[10] where e[v] = (a[v] + b[v])
    a: tensor(float32, [10])
    c: float32[10] where c[v_1] = (a[v_1] + v_1)
    d: float32[10] where d[v_2] = c[(9 - v_2)]
  attributes: {}
)
In the above example, there are five tensor nodes: nodes a and b are the inputs, and nodes d and e are the outputs. The computation of node c depends on node a, node d depends on node c, and node e depends on both nodes a and b.
We provide a driver function hidet.driver.build_task() to build a task into a callable function. The build_task() function performs the following steps to lower the task into a callable function:

1. Lower the task into a tensor program, defined with an IRModule.
2. Lower and optimize the IRModule.
3. Generate the target source code (e.g., source.cu).
4. Call the compiler (e.g., nvcc) to compile the source code into a dynamic library (i.e., lib.so).
5. Load the dynamic library and wrap it into a CompiledFunction that can be directly called.

We can define the following function to build and run a task.
from typing import List
from hidet.ir.task import Task

def run_task(task: Task, inputs: List[hidet.Tensor], outputs: List[hidet.Tensor]):
    """Run the given task and print its inputs and outputs."""
    from hidet.runtime import CompiledFunction

    # build the task
    func: CompiledFunction = hidet.driver.build_task(task, target_device='cpu')
    params = inputs + outputs

    # run the compiled task
    func(*params)

    print('Task:', task.name)
    print('Inputs:')
    for tensor in inputs:
        print(tensor)
    print('Output:')
    for tensor in outputs:
        print(tensor)
    print()
The following code shows how to 1) define the computation, 2) define the task, and 3) build and run the task.
from hidet.ir.compute import tensor_input, reduce, compute, arg_reduce, TensorNode
def add_example():
    a: TensorNode = tensor_input(name='a', dtype='float32', shape=[5])
    b: TensorNode = tensor_input(name='b', dtype='float32', shape=[5])
    c: TensorNode = compute(name='c', shape=[5], fcompute=lambda i: a[i] + b[i])
    task = Task(name='add', inputs=[a, b], outputs=[c])
    run_task(task, [hidet.randn([5]), hidet.randn([5])], [hidet.empty([5])])

add_example()
Task: add
Inputs:
Tensor(shape=(5,), dtype='float32', device='cpu')
[-0.6754078   0.31349733 -1.6539606   0.9715885  -0.63018745]
Tensor(shape=(5,), dtype='float32', device='cpu')
[-0.31293884 -0.46955818  0.5940415  -2.2681134  -0.5740183 ]
Output:
Tensor(shape=(5,), dtype='float32', device='cpu')
[-0.9883467  -0.15606084 -1.0599191  -1.2965249  -1.2042058 ]
def reduce_sum_example():
    a = tensor_input('a', dtype='float32', shape=[4, 3])
    b = compute(
        'b',
        shape=[4],
        fcompute=lambda i: reduce(
            shape=[3], fcompute=lambda j: a[i, j], reduce_type='sum'
        ),
    )
    task = Task('reduce_sum', inputs=[a], outputs=[b])
    run_task(task, [hidet.randn([4, 3])], [hidet.empty([4])])

reduce_sum_example()
Task: reduce_sum
Inputs:
Tensor(shape=(4, 3), dtype='float32', device='cpu')
[[-1.6562179  -0.21710683  0.784083  ]
 [-1.0092926   0.19702992 -0.6048088 ]
 [ 0.2052615   0.22812256 -0.6434674 ]
 [ 0.62891465 -0.14734371 -2.736242  ]]
Output:
Tensor(shape=(4,), dtype='float32', device='cpu')
[-1.0892417  -1.4170715  -0.21008337 -2.254671  ]
def matmul_example():
    a = tensor_input('a', dtype='float32', shape=[3, 3])
    b = tensor_input('b', dtype='float32', shape=[3, 3])
    c = compute(
        'c',
        shape=[3, 3],
        fcompute=lambda i, j: reduce(
            shape=[3], fcompute=lambda k: a[i, k] * b[k, j], reduce_type='sum'
        ),
    )
    task = Task('matmul', inputs=[a, b], outputs=[c])
    run_task(task, [hidet.randn([3, 3]), hidet.randn([3, 3])], [hidet.empty([3, 3])])

matmul_example()
Task: matmul
Inputs:
Tensor(shape=(3, 3), dtype='float32', device='cpu')
[[-1.4392755  -1.1094588   0.21107866]
 [-0.39272723 -0.36380586  0.32664967]
 [ 0.56057245 -1.1792772   0.01077561]]
Tensor(shape=(3, 3), dtype='float32', device='cpu')
[[ 1.5182452  -0.65522957  0.8335112 ]
 [ 0.86957115  0.6943332   0.26812094]
 [ 0.68326145 -0.47217155 -0.41662267]]
Output:
Tensor(shape=(3, 3), dtype='float32', device='cpu')
[[-3.0057044   0.07305647 -1.5850616 ]
 [-0.6894243  -0.14951068 -0.56097615]
 [-0.16701637 -1.1912029   0.14656514]]
So far, we have learned how to define the computation using compute primitives and wrap it into a Task. In this section, we will learn how to add an Operator with the given computation definition, and use Hidet's provided rule-based scheduler to automatically schedule the computation into a tensor program.
There are three steps to define a new operator in Hidet:

1. Define the computation task by inheriting the Task class.
2. Define the operator class by inheriting the Operator class.
3. Define a function that creates the operator instance and returns its output tensor.

We will take batch matrix multiplication as an example to illustrate the three steps.
We define the computation task class BatchMatmulTask by inheriting the Task class. The BatchMatmulTask constructor takes two arguments, a and b, which are the input tensor nodes of the batch matrix multiplication.
from hidet.ir.compute import TensorNode, compute, reduce
from hidet.ir.task import Task

class BatchMatmulTask(Task):
    def __init__(self, a: TensorNode, b: TensorNode):
        # get the input sizes
        batch_size, m_size, k_size = a.const_shape()
        batch_size, k_size, n_size = b.const_shape()

        # define the computation
        c = compute(
            name='c',
            shape=[batch_size, m_size, n_size],
            fcompute=lambda p, i, j: reduce(
                shape=[k_size],
                fcompute=lambda k: a[p, i, k] * b[p, k, j],
                reduce_type='sum',
            ),
        )

        # call the parent class constructor to initialize the task
        super().__init__(
            name='batch_matmul',  # the name of the task
            inputs=[a, b],  # the input tensor nodes
            outputs=[c],  # the output tensor nodes
        )
Next, we define the operator class BatchMatmulOp by inheriting the Operator class.

from hidet.graph import Operator, Tensor
from hidet.graph.ops.definitions.utils import input_like

class BatchMatmulOp(Operator):
    def __init__(self, a: Tensor, b: Tensor):
        # call the parent class constructor to initialize the operator
        super().__init__(
            inputs=[a, b],  # the input tensors
            task=BatchMatmulTask(  # the task of the operator
                # create tensor nodes (TensorNode) with the same shape and dtype as the tensors (Tensor)
                input_like(a, 'a'),
                input_like(b, 'b'),
            ),
        )
Finally, we define a function batch_matmul that creates a BatchMatmulOp instance and returns the output tensor.
def batch_matmul(a: Tensor, b: Tensor) -> Tensor:
    # get_output(0) returns the first output tensor of the operator
    return BatchMatmulOp(a, b).get_output(0)
The new operator is no different from the operators Hidet provides, as Hidet's own operators are defined in the same way. For example, when we optimize the flow graph, this new operator can also be fused with surrounding operators.
def demo_usage():
    a = hidet.randn([2, 2, 3])
    b = hidet.randn([2, 3, 2])
    c = batch_matmul(a, b)
    print(a)
    print(b)
    print(c)

demo_usage()
Tensor(shape=(2, 2, 3), dtype='float32', device='cpu')
[[[ 0.53905475  1.0962611  -1.6998249 ]
  [-0.86538684  0.18080942  0.99257547]]

 [[ 1.7569121  -2.3756104   1.0889981 ]
  [-2.5179853  -1.078981   -0.48675174]]]
Tensor(shape=(2, 3, 2), dtype='float32', device='cpu')
[[[-0.1587849  -3.019243  ]
  [-0.13630112  1.0938233 ]
  [-0.16682196  1.531905  ]]

 [[ 0.22785178 -0.98469055]
  [ 1.9682748   0.08787339]
  [ 2.2873135   0.24584681]]]
Tensor(shape=(2, 2, 2), dtype='float32', device='cpu')
[[[ 0.04855275 -3.0323915 ]
  [-0.05281755  4.331118  ]]

 [[-1.7846584  -1.6710411 ]
  [-3.8108125   2.2649562 ]]]
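To see the fusion mentioned above in action, we can trace the new operator symbolically and optimize the resulting flow graph. The sketch below assumes the hidet.symbol, hidet.trace_from, and hidet.graph.optimize APIs (exact names may differ across Hidet versions):

x = hidet.symbol([2, 2, 3])  # symbolic inputs used to trace a flow graph
y = hidet.symbol([2, 3, 2])
z = hidet.ops.relu(batch_matmul(x, y))  # surround the new operator with an element-wise operator
graph = hidet.trace_from(z, inputs=[x, y])
graph_opt = hidet.graph.optimize(graph)  # the relu may be fused into the batch_matmul epilogue
print(graph_opt)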
Here we only defined the computation of the operator and left the scheduling to the rule-based scheduler provided by Hidet; we call this method rule-based scheduling. Most Hidet operators use the same rule-based scheduler as in this example. Our experience shows that the rule-based scheduler achieves good performance for operators that do not contain large amounts of reduction. However, for operators like matrix multiplication and convolution, it may not reach the best performance, as it does not use shared memory to cache data loading. Thus, Hidet also provides another scheduling mechanism: template-based scheduling.
Template-based scheduling allows us to define a tensor program template, and the template will be instantiated for different input shapes and tunable hyper-parameters.
Override the implement_cuda() method

The Task class has two methods, implement_cpu() and implement_cuda(), that we can override when we define a new task.
from hidet.ir.compute import TensorNode, compute, reduce
from hidet.ir.task import Task
from hidet.ir.func import IRModule

class BatchMatmulFp16Task(Task):
    def __init__(self, a: TensorNode, b: TensorNode):
        batch_size, m_size, k_size = a.const_shape()
        batch_size, k_size, n_size = b.const_shape()
        c = compute(
            name='c',
            shape=[batch_size, m_size, n_size],
            fcompute=lambda p, i, j: reduce(
                shape=[k_size],
                fcompute=lambda k: a[p, i, k] * b[p, k, j],
                reduce_type='sum',
            ),
        )
        super().__init__(
            name='batch_matmul_fp16',
            inputs=[a, b],
            outputs=[c],
            attributes={
                'batch_size': batch_size,
                'm_size': m_size,
                'n_size': n_size,
                'k_size': k_size,
            },
        )

    def allow_epilogue(self) -> bool:
        return False

    def implement_cuda(self, working_dir: str) -> IRModule:
        # override this method to use template-based scheduling
        return batch_matmul_mma_fp16_schedule(self)
In the above task definition, we override the implement_cuda() method to use template-based scheduling. Inside implement_cuda(), we call the batch_matmul_mma_fp16_schedule() function, which we will write next, to get a tensor program that implements the computation defined in the task.
We can implement the batch_matmul_mma_fp16_schedule() function in the following way. It is written in Hidet Script, a DSL for writing tensor programs that we will explore in detail in the next section. Understanding the code below requires knowledge of Hidet Script and efficient CUDA programming; for now, we can skip the details of this implementation.
def batch_matmul_mma_fp16_schedule(task: BatchMatmulFp16Task) -> IRModule:
    from hidet.lang import f16, spatial, repeat, tensor, attr, grid, printf, cast
    from hidet.lang.mapping import repeat, spatial
    from hidet.lang.cuda import blockIdx, threadIdx, syncthreads
    from hidet.lang.cuda import MmaConfig, mma_sync
    from hidet.transforms.tools import add_packed_func

    # get the workload size
    bs = task.attributes['batch_size']
    m_size = task.attributes['m_size']
    n_size = task.attributes['n_size']
    k_size = task.attributes['k_size']

    # define the template hyper-parameters
    mma_config = MmaConfig.m16n8k8_f16_f16()
    block_m, block_n, block_k = 128, 128, 8
    warp_m, warp_n, warp_k = 64, 64, 8
    warp_count_m, warp_count_n, warp_count_k = 2, 2, 1
    mma_m, mma_n, mma_k = mma_config.m, mma_config.n, mma_config.k  # 16, 8, 8
    mma_count_m, mma_count_n, mma_count = 4, 8, 1
    threads = warp_count_m * warp_count_n * warp_count_k * 32

    # define the tensor program
    with hidet.script_module() as module:

        @hidet.script
        def load_regs_a(
            smem_a: f16[block_m, block_k], regs_a: f16[4, mma_config.a_elements]
        ):
            """Load A registers from shared memory."""
            warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32
            for wi, wj, wk in spatial(warp_count_m, warp_count_n, warp_count_k).on(warp_id):
                for mi in range(mma_count_m):
                    p = 0
                    for i, k in mma_config.a_load_map.on(lane_id):
                        regs_a[mi, p] = smem_a[wi * warp_m + mi * mma_m + i, wk * warp_k + k]
                        p += 1

        @hidet.script
        def load_regs_b(
            smem_b: f16[block_k, block_n], regs_b: f16[8, mma_config.b_elements]
        ):
            """Load B registers from shared memory."""
            warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32
            for wi, wj, wk in spatial(warp_count_m, warp_count_n, warp_count_k).on(warp_id):
                for mj in range(mma_count_n):
                    p = 0
                    for k, j in mma_config.b_load_map.on(lane_id):
                        regs_b[mj, p] = smem_b[wk * warp_k + k, wj * warp_n + mj * mma_n + j]
                        p += 1

        @hidet.script
        def warp_mma(
            regs_a: f16[4, mma_config.a_elements],
            regs_b: f16[8, mma_config.b_elements],
            regs_c: f16[4, 8, mma_config.c_elements],
        ):
            """Perform warp-level matrix multiplication."""
            for mi, mj in repeat(mma_count_m, mma_count_n).on(0):
                mma_sync(mma_config, ~regs_a[mi, 0], ~regs_b[mj, 0], ~regs_c[mi, mj, 0])

        @hidet.script
        def store_c(regs_c: f16[4, 8, mma_config.c_elements], c: f16[bs, m_size, n_size]):
            """Store C registers to global memory."""
            warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32
            offset_m, offset_n = blockIdx.x * block_m, blockIdx.y * block_n
            gmem_c = c[blockIdx.z, offset_m:, offset_n:]
            for k_round in range(warp_count_k):
                for wi, wj, wk in spatial(warp_count_m, warp_count_n, warp_count_k).on(warp_id):
                    if wk == k_round:
                        for mi, mj in repeat(mma_count_m, mma_count_n).on(0):
                            p = 0
                            for i, j in mma_config.c_store_map.on(lane_id):
                                gmem_c.write(
                                    [
                                        wi * warp_m + mi * mma_m + i,
                                        wj * warp_n + mj * mma_n + j,
                                    ],
                                    regs_c[mi, mj, p],
                                    protected=True,
                                )
                                p += 1

        @hidet.script
        def batch_matmul_kernel(
            a: f16[bs, m_size, k_size],
            b: f16[bs, k_size, n_size],
            c: f16[bs, m_size, n_size],
        ):
            """Batch matrix multiplication kernel."""
            attr.cuda_grid_dim = (
                (m_size + block_m - 1) // block_m,
                (n_size + block_n - 1) // block_n,
                bs,
            )
            attr.cuda_block_dim = threads
            offset_m, offset_n = blockIdx.x * block_m, blockIdx.y * block_n
            smem_a = tensor('shared', 'float16', [block_m, block_k])
            smem_b = tensor('shared', 'float16', [block_k, block_n])
            regs_a = tensor('register', 'float16', [4, mma_config.a_elements])
            regs_b = tensor('register', 'float16', [8, mma_config.b_elements])
            regs_c = tensor('register', 'float16', [4, 8, mma_config.c_elements])
            for i, j, p in grid(4, 8, mma_config.c_elements):
                regs_c[i, j, p] = 0.0
            for k0 in range((k_size + block_k - 1) // block_k):
                offset_k = k0 * block_k
                gmem_a = a[blockIdx.z, offset_m:, offset_k:]
                gmem_b = b[blockIdx.z, offset_k:, offset_n:]
                for i, k in repeat(8, 1).spatial(16, 8).on(threadIdx.x):
                    smem_a[i, k] = gmem_a.read([i, k], protected=True)
                for k, j in repeat(8, 1).spatial(1, 128).on(threadIdx.x):
                    smem_b[k, j] = gmem_b.read([k, j], protected=True)
                syncthreads()
                load_regs_a(smem_a, regs_a)
                load_regs_b(smem_b, regs_b)
                warp_mma(regs_a, regs_b, regs_c)
                syncthreads()
            store_c(regs_c, c)

    ir_module = module.ir_module()
    # conduct the fusion (when the task has prologue or epilogue) and generate the packed function
    # ir_module = fuse_and_pack(ir_module, kernel_func=batch_matmul_kernel, task=task)
    add_packed_func(ir_module, func=batch_matmul_kernel, pack_func_name=task.name)
    return ir_module
The remaining part is the same as in the rule-based scheduling method for adding a new operator.
from hidet.graph import Operator, Tensor
from hidet.graph.ops.definitions.utils import input_like

class BatchMatmulFp16Op(Operator):
    def __init__(self, a: Tensor, b: Tensor):
        assert a.dtype == hidet.float16 and b.dtype == hidet.float16
        super().__init__(
            inputs=[a, b],
            task=BatchMatmulFp16Task(input_like(a, 'a'), input_like(b, 'b')),
        )

def batch_matmul_fp16(a: Tensor, b: Tensor) -> Tensor:
    return BatchMatmulFp16Op(a, b).get_output(0)
def demo_usage():
    a = hidet.randn([1, 2, 2], dtype='float16', device='cuda')
    b = hidet.randn([1, 2, 2], dtype='float16', device='cuda')
    c = batch_matmul_fp16(a, b)
    print(a)
    print(b)
    print(c)

demo_usage()
In this section, we have learned how to add a new operator to Hidet. We first define the computation task of the operator through compute primitives. We then define the operator class and either use rule-based scheduling or write our own schedule template and use template-based scheduling to implement the operator. In the next section, we will learn about Hidet Script, a DSL which allows us to conveniently write our own efficient schedules.