假设我们的 ops.cpp 已经编写好了如下代码,主要是根据 Tensor 的后端,判断使用 CUDA 还是 CPU 进行计算。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#include<torch/extension.h> #include<vector>
// Forward declarations of the backend-specific kernels.
// Their definitions live in separate translation units (e.g. a .cu file
// for the CUDA kernel and a .cpp file for the CPU kernel).
torch::Tensor myadd_cuda(torch::Tensor a, torch::Tensor b);
torch::Tensor myadd_cpu(torch::Tensor a, torch::Tensor b);
// Dispatch `myadd` to the backend-specific kernel based on the input device.
//
// Validates that `a` and `b` agree on shape, dtype and device (TORCH_CHECK
// raises a c10::Error on failure), then routes CUDA tensors to the CUDA
// kernel and everything else to the CPU kernel.
//
// @param a first operand tensor
// @param b second operand tensor; must match `a` in shape, dtype and device
// @return the element-wise sum computed by the selected backend kernel
torch::Tensor myadd(torch::Tensor a, torch::Tensor b) {
  TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have same shape");
  TORCH_CHECK(a.scalar_type() == b.scalar_type(), "dtype mismatch");
  TORCH_CHECK(a.device() == b.device(), "device mismatch");

  if (a.is_cuda()) {
    // FIX: original read "returnmyadd_cuda(a, b);" — the missing space after
    // `return` made this (and the CPU path below) a compile error.
    return myadd_cuda(a, b);
  }
  return myadd_cpu(a, b);
}
接下来,我们需要在 ops.cpp 文件里,将其注册为算子。
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// 定义 schema TORCH_LIBRARY(myops, m) { m.def("myadd(Tensor a, Tensor b) -> Tensor"); }
// 注册 CPU 实现 TORCH_LIBRARY_IMPL(myops, CPU, m) { m.impl("myadd", &myadd_cpu); }
// 注册 CUDA 实现 TORCH_LIBRARY_IMPL(myops, CUDA, m) { m.impl("myadd", &myadd_cuda); }