如何使 TensorFlow 中的自定义 Op 可在 Python 中导入?

How to make custom Op in TensorFlow importable in Python?

我已经为我的自定义 Op 实现了一个内核,并将其作为 custom_op.cc 放入 /tensorflow/core/user_ops。在 Op 中,我做了所有的注册工作,比如 REGISTER_OPREGISTER_KERNEL_BUILDER.

然后我在 Python 中为这个 Op 实现了渐变,并将它放在与 custom_op_grad.py 相同的文件夹中。我也在这里完成了所有注册 (@ops.RegisterGradient)。

我已经创建了包含以下内容的 BUILD 文件:

load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
tf_custom_op_library(
        name = "custom_op.so",
        srcs = ["custom_op.cc"],
)

py_library(
        name = "custom_op_grad",
        srcs = ["custom_op_grad.py"],
        srcs_version = "PY2",
        deps = [
        ":custom_op_grad",
        "//tensorflow:tensorflow_py",
        ],
)

之后,我重建了 Tensorflow:

pip uninstall tensorflow
bazel clean
bazel build -c opt //tensorflow/tools/pip_package:build_pip_package
cp -r bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/__main__/* bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/
bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
pip install /tmp/tensorflow_pkg/tensorflow-0.8.0-py2-none-any.whl

当我在这一切之后尝试使用我的 Op 时,通过调用 tf.user_ops.custom_op 它告诉我模块没有它。

也许我还需要执行一些额外的步骤?或者我对 BUILD 文件做错了什么?

好的,我找到了解决方案。我刚刚删除了 BUILD 文件,我的自定义 Op 已成功构建,并且可以使用 tensorflow.user_ops.custom_op() 在 Python 中导入。

要使用渐变,我必须将它的代码直接放在 tensorflow/python/user_ops/user_ops.py 中。不是最优雅的解决方案,但目前有效。