了解 C++ 中新 Tensorflow 运算符的定义

Understanding the Definition of New Tensorflow Operators in C++

我正在尝试按照官方指南在 tensorflow 中定义新的运算符。 https://www.tensorflow.org/extend/adding_an_op

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c){
      c->set_output(0, c->input(0));
      return Status::OK();
    });

但是我找不到这段代码的逐行解释,特别是我不明白 .SetShapeFn([](::tensorflow::shape_inference::InferenceContext 的作用是什么* c) 及其语法。我也对 InferenceContext 感到困惑,我猜这是一种将任何数组的元素一个接一个地连续传递的方法。我在任何地方都找不到明确的定义,也许我正在寻找错误的地方,有人可以帮助我解释或参考吗? 我想深入了解这段代码在幕后做了什么。

你看到这里关于形状推理函数的部分了吗? https://www.tensorflow.org/extend/adding_an_op#shape_functions_in_c

其中对 ShapeInferenceContext class 和编写您自己的函数的机制进行了大量讨论。如果那没有涵盖您感兴趣的内容,您能否提供更多详细信息?