了解 C++ 中新 Tensorflow 运算符的定义
Understanding the Definition of New Tensorflow Operators in C++
我正在尝试按照官方指南在 tensorflow 中定义新的运算符。 https://www.tensorflow.org/extend/adding_an_op
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
using namespace tensorflow;
REGISTER_OP("ZeroOut")
.Input("to_zero: int32")
.Output("zeroed: int32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c){
c->set_output(0, c->input(0));
return Status::OK();
});
但是我找不到这段代码的逐行解释,特别是我不明白 .SetShapeFn([](::tensorflow::shape_inference::InferenceContext 的作用是什么* c) 及其语法。我也对 InferenceContext 感到困惑,我猜这是一种将任何数组的元素一个接一个地连续传递的方法。我在任何地方都找不到明确的定义,也许我正在寻找错误的地方,有人可以帮助我解释或参考吗?
我想深入了解这段代码在幕后做了什么。
你看到这里关于形状推理函数的部分了吗?
https://www.tensorflow.org/extend/adding_an_op#shape_functions_in_c
其中对 ShapeInferenceContext class 和编写您自己的函数的机制进行了大量讨论。如果那没有涵盖您感兴趣的内容,您能否提供更多详细信息?
我正在尝试按照官方指南在 tensorflow 中定义新的运算符。 https://www.tensorflow.org/extend/adding_an_op
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
using namespace tensorflow;
REGISTER_OP("ZeroOut")
.Input("to_zero: int32")
.Output("zeroed: int32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c){
c->set_output(0, c->input(0));
return Status::OK();
});
但是我找不到这段代码的逐行解释,特别是我不明白 .SetShapeFn([](::tensorflow::shape_inference::InferenceContext 的作用是什么* c) 及其语法。我也对 InferenceContext 感到困惑,我猜这是一种将任何数组的元素一个接一个地连续传递的方法。我在任何地方都找不到明确的定义,也许我正在寻找错误的地方,有人可以帮助我解释或参考吗? 我想深入了解这段代码在幕后做了什么。
你看到这里关于形状推理函数的部分了吗? https://www.tensorflow.org/extend/adding_an_op#shape_functions_in_c
其中对 ShapeInferenceContext class 和编写您自己的函数的机制进行了大量讨论。如果那没有涵盖您感兴趣的内容,您能否提供更多详细信息?