1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
// graph_builder.h
class GraphBuilder {
public
/**
* 添加自定义算子
* @param inputs 输入张量列表
* @param attr 算子属性
* @return 输出张量句柄
*/
Tensor xxx(
Tensor input_1,
Tensor input_2, ……
std::optional<Tensor> output,
float attr = 1.0f
);
/**
* 添加带多个属性的算子
*/
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
// graph_builder.cc
Tensor GraphBuilder::xxx(
Tensor input_1,
Tensor input_2, ……
std::optional<Tensor> output,
float attr)
{
if (output.has_value()) {
g->addOpWithOutputs<XXXObj>(
std::move(input_1),
std::move(input_2), ……
std::move(output.value()),
attr);
return output.value();
} else {
return g->addOp<XXXObj>(
std::move(input_1),
std::move(input_2), ……
nullptr,
attr)
->getOutput(0);
}
}

Pybind11 绑定

1
2
3
4
5
6
7
8
void bind_graph_builder(py::module &m) {
py::class_<GraphObj, std::shared_ptr<GraphObj>>(m, "Graph");
// GraphBuilder
py::class_<GraphBuilderObj>(m, "GraphBuilder")
.def(py::init<Runtime>())
//新增函数
.def("function name", &GraphBuilderObj::xxx, py::arg("参数"), ......);
}