ninetoothed 是 tensor-oriented meta-programming,这意味着 ninetoothed 可以分别定义出张量的布局 (arrangement) 和计算逻辑 (application),最后再生成相对应的 kernel.

Kernel Generation

在生成 kernel 的过程中,最重要的步骤是 generation.py,其工作主要是

  1. Parse Python into AST
  2. 将 Helper functions 内联
  3. 将 AST 改写为 Symbolic Tensor Annotations
  4. 将剩余的 ninetoothed reference 转换为 Triton reference
  5. 将 AST 重新转换为 Python 并 cache
  6. 调用 jit.py 导入转换后的 Python 代码并导出最终的 Kernel

可以看到,对 Kernel 的优化也是主要对 AST 进行的.

AST 解析

这一块由 CodeGenerator.__call__._get_tree() 生成 AST,通过调用 Python 的自带模块 ast.parse 完成

1
2
3
4
5
6
def _get_tree(func):
func_def = ast.parse(textwrap.dedent(inspect.getsource(func)))

# ...

return func_def

内联化

这一块由 _Inliner 这个类完成,主要作用就是把各种函数内联进 kernel 里

1
2
3
4
5
6
7
8
9
10
11
12
13
14
inliner = _Inliner(func.__globals__)
inliner.visit(func_def)

if inliner.libdevice_used:
libdevice_alias = ast.alias(
name="libdevice", asname=inliner.LIBDEVICE_ALIAS
)
libdevice_import = ast.ImportFrom(
module="triton.language.extra",
names=[libdevice_alias],
level=0,
)

func_def.body.insert(0, libdevice_import)

Kernel 函数的参数准备

对于函数签名,进行解析,方便生成 Triton Kernel Function 的参数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def visit_FunctionDef():
# ...
symbols = {
name.node.id: name
for arg in self._args
for name in arg.names()
if name != "ninetoothed"
}

meta_names = {name for name in names if naming.is_meta(name)}
non_meta_names = {name for name in names if name not in meta_names}

node.args.args = [
ast.arg(arg=name) if not naming.is_constexpr(name)
else ast.arg(arg=name, annotation=attribute("constexpr").node)
for name in non_meta_names
] + [
ast.arg(arg=name, annotation=attribute("constexpr").node)
for name in meta_names
]

Auto Tuning

接下来,ninetoothed 根据是否有 meta parameter,用 self._generate_autotune() 创建 @autotune(...) wrapper

这个函数的工作原理大致是:检查所有的 meta parameters,然后枚举值,根据 device limit 等等条件进行过滤,然后 combine 上 num_stages, num_wraps,生成 autotune 装饰器

1
2
3
4
5
6
7
8
9
10
11
# example decorator
@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_warps=8,
num_stages=4),
triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_warps=8,
num_stages=4),
...
],
key=["input_size_0", "input_size_1", "other_size_0", "other_size_1", ...],
)

张量重写

这里算是很重要的部分了,涉及到张量的 loadstore

Bare Tensor Name

如果出现了裸的张量名,说明需要预先进行加载,visit_Name() 函数会直接将其改写为 triton.load

1
2
3
# visit_Name():
if self._in_context(node) and isinstance(node.ctx, ast.Load):
return self._generate_load(self._context[node.id])

Indexed Tensor

出现例如 input[k] 这样的张量读取则由 visit_Subscript() 处理

1
2
if self._in_context(node.value) and isinstance(tensor := self._context[node.value.id], Tensor):
return self._generate_load(tensor, indices=...)

那么 _generate_load() 是如何工作的呢?简而言之,它生成 pointermask 和默认值 other,然后直接转写为 tl.load(...)

1
2
3
pointers, mask = self._generate_pointers_and_mask(tensor, indices)
other = type(self)._generate_other(tensor)
return call("load", pointers, mask=mask, other=other).node

举例来说,input[k] 会生成这样的东西:

1
tl.load(input_pointers + computed_offsets, mask=mask, other=0)

具体的地址由 _generate_pointers_and_mask()_generate_overall_offsets_and_mask() 计算得到,本质是 Symbolic 代数运算:

1
2
3
4
5
overall_offsets = sum(
offsets[source_dim] * Symbol(tensor.source.stride_string(source_dim))
for source_dim in range(tensor.source.ndim)
)
pointers = name_for_pointers + overall_offsets

因而,对于 matmul 运算中的 input[k] tile,其生成的代码类似:

1
2
3
4
5
6
input_ptrs = (
input_pointer
+ input_row_offsets * input_stride_0
+ input_k_offsets * input_stride_1
)
input_tile = tl.load(input_ptrs, mask=input_mask, other=0)

Program ID