Getting Protobuf Message for every TensorFlow Operation

From here:

There are a few ways to get a list of the OpDefs for the registered ops:

  • TF_GetAllOpList in the C API retrieves all registered OpDef protocol messages. This can be used to write the generator in the client language. This requires that the client language have protocol buffer support in order to interpret the OpDef messages.
  • The C++ function OpRegistry::Global()->GetRegisteredOps() returns the same list of all registered OpDefs (defined in [tensorflow/core/framework/op.h]). This can be used to write the generator in C++ (particularly useful for languages that do not have protocol buffer support).
  • The ASCII-serialized version of that list is periodically checked in to [tensorflow/core/ops/ops.pbtxt] by an automated process.

But alas, I would like to do this in Python, along these lines:

import tensorflow as tf
from google.protobuf import json_format
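# Hypothetical: tf.GetAllOpsList() is not a real TensorFlow function;
# it stands in for the API I am looking for.
json_string = json_format.MessageToJson(tf.GetAllOpsList())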

I want a way to get the Protobuf message for every operation in TensorFlow so that I can dump them as JSON.

It's in ops.pbtxt. The following example lists the OpDef messages of all ops that produce string output.

import tensorflow as tf

from tensorflow.core.framework import op_def_pb2
from google.protobuf import text_format

def get_op_types(op):
    """Return the allowed types of the op's first attr of type 'type', or []."""
    for attr in op.attr:
        if attr.type != 'type':
            continue
        return list(attr.allowed_values.list.type)
    return []

# Directory where you ran "git clone"; adjust to your own checkout.
tensorflow_git_base = "/Users/yaroslav/tensorflow.git"
ops_file = tensorflow_git_base+"/tensorflow/tensorflow/core/ops/ops.pbtxt"
ops = op_def_pb2.OpList()
# Parse the text-format protobuf into the OpList message.
with open(ops_file) as f:
    text_format.Merge(f.read(), ops)

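for op in ops.op:
    # Ops whose type attr (e.g. "T") lists string among its allowed types.
    if tf.string in get_op_types(op):
        print(op.name, op.summary)
    # Ops with an output whose type is fixed to string.
    # (Swap in op.input_arg to look at inputs instead.)
    for arg in op.output_arg:
        if arg.type == tf.string:
            print(op.name, op.summary)
            break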
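
Since the question asks for JSON, here is a minimal sketch of that last step, assuming the same op_def_pb2 and protobuf modules as above are importable. To stay self-contained it parses a short inline excerpt (the BatchMatrixBandPart entry quoted later in this answer) instead of the full ops.pbtxt; the names OPS_PBTXT and op_list_to_json are made up for illustration and are not part of TensorFlow.

```python
import json

from google.protobuf import json_format, text_format
from tensorflow.core.framework import op_def_pb2

# Illustrative excerpt in the same text format as ops.pbtxt, so that no
# git checkout is needed. In practice you would read the real file instead.
OPS_PBTXT = """
op {
  name: "BatchMatrixBandPart"
  input_arg { name: "input" type_attr: "T" }
  input_arg { name: "num_lower" type: DT_INT64 }
  input_arg { name: "num_upper" type: DT_INT64 }
  output_arg { name: "band" type_attr: "T" }
  attr { name: "T" type: "type" }
  deprecation { version: 14 explanation: "Use MatrixBandPart" }
}
"""


def op_list_to_json(pbtxt):
    """Parse text-format OpList data and return it as a JSON string."""
    ops = op_def_pb2.OpList()
    text_format.Parse(pbtxt, ops)
    # preserving_proto_field_name keeps "input_arg" instead of "inputArg".
    return json_format.MessageToJson(ops, preserving_proto_field_name=True)


json_string = op_list_to_json(OPS_PBTXT)
first_op = json.loads(json_string)["op"][0]
print(first_op["name"], [arg["name"] for arg in first_op["input_arg"]])
```

Passing the contents of the real ops.pbtxt to the same helper should give one JSON document with an entry per op. Enum fields such as the argument type come out as their names (for example "DT_INT64").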

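**Added** If you want to pick up newly added ops as well, you can reverse engineer how the current Python wrappers work. For example, consider the gen_array_ops.py file. It contains the following snippet: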

def _InitOpDefLibrary():
  op_list = _op_def_pb2.OpList()
  _text_format.Merge(_InitOpDefLibrary.op_list_ascii, op_list)
  _op_def_registry.register_op_list(op_list)
  op_def_lib = _op_def_library.OpDefLibrary()
  op_def_lib.add_op_list(op_list)
  return op_def_lib


_InitOpDefLibrary.op_list_ascii = """op {
  name: "BatchMatrixBandPart"
  input_arg {
    name: "input"
    type_attr: "T"
  }
  input_arg {
    name: "num_lower"
    type: DT_INT64
  }
  input_arg {
    name: "num_upper"
    type: DT_INT64
  }
  output_arg {
    name: "band"
    type_attr: "T"
  }
  attr {
    name: "T"
    type: "type"
  }
  deprecation {
    version: 14
    explanation: "Use MatrixBandPart"
  }
}

So these protobuf messages are generated from the underlying C code when gen_array_ops is generated. To understand how they are generated, see