ninetoothed 是 tensor-oriented meta-programming,这意味着 ninetoothed 可以分别定义出张量的布局 (arrangement) 和计算逻辑 (application),最后再生成相对应的 kernel.
Kernel Generation
在生成 kernel 的过程中,最重要的步骤是 generation.py,其工作主要是
- Parse Python into AST
- 将 Helper functions 内联
- 将 AST 改写为 Symbolic Tensor Annotations
- 将剩余的 ninetoothed reference 转换为 Triton reference
- 将 AST 重新转换为 Python 并 cache
- 调用
jit.py导入转换后的 Python 代码并导出最终的 Kernel
可以看到,对 Kernel 的优化也是主要对 AST 进行的.
AST 解析
这一块由 CodeGenerator.__call__._get_tree() 生成 AST,通过调用 Python 的自带模块 ast.parse 完成
1 | def _get_tree(func): |
内联化
这一块由 _Inliner 这个类完成,主要作用就是把各种函数内联进 kernel 里
1 | inliner = _Inliner(func.__globals__) |
Kernel 函数的参数准备
对于函数签名,进行解析,方便生成 Triton Kernel Function 的参数
1 | def visit_FunctionDef(): |
Auto Tuning
接下来,ninetoothed 根据是否有 meta parameter,用 self._generate_autotune() 创建 @autotune(...) wrapper
这个函数的工作原理大致是:检查所有的 meta parameters,然后枚举值,根据 device limit 等等条件进行过滤,然后 combine 上 num_stages, num_wraps,生成 autotune 装饰器
1 | # example decorator |
张量重写
这里算是很重要的部分了,涉及到张量的 load 与 store.
Bare Tensor Name
如果出现了裸的张量名,说明需要预先进行加载,visit_Name() 函数会直接将其改写为 triton.load
1 | # visit_Name(): |
Indexed Tensor
出现例如 input[k] 这样的张量读取则由 visit_Subscript() 处理
1 | if self._in_context(node.value) and isinstance(tensor := self._context[node.value.id], Tensor): |
那么 _generate_load() 是如何工作的呢?简而言之,它生成 pointer、mask 和默认值 other,然后直接转写为 tl.load(...)
1 | pointers, mask = self._generate_pointers_and_mask(tensor, indices) |
举例来说,
input[k]会生成这样的东西:
1 tl.load(input_pointers + computed_offsets, mask=mask, other=0)具体的地址由
_generate_pointers_and_mask()和_generate_overall_offsets_and_mask()计算得到,本质是 Symbolic 代数运算:
1
2
3
4
5 overall_offsets = sum(
offsets[source_dim] * Symbol(tensor.source.stride_string(source_dim))
for source_dim in range(tensor.source.ndim)
)
pointers = name_for_pointers + overall_offsets因而,对于 matmul 运算中的
input[k]tile,其生成的代码类似:
1
2
3
4
5
6 input_ptrs = (
input_pointer
+ input_row_offsets * input_stride_0
+ input_k_offsets * input_stride_1
)
input_tile = tl.load(input_ptrs, mask=input_mask, other=0)