tf.create_partitioned_variables 是如何工作的?

How does tf.create_partitioned_variables work?

我正在想办法使用 tf.create_partitioned_variables 我正在阅读文档,但我很难理解。

谁能解释一下它是如何工作的并举例说明它的用法?

据我了解,我可以使用它从变量中获取切片列表。 我只是不明白我是怎么得到切片的

例如: 我如何从 tf.Variable(np.array([[1.0],[3.0]]), dtype=tf.float32)

中获取 [[1.],[3.]] 的列表

列表
[[[1 0] [3 0]], [[0 5] [0 7]]]

来自

[[[1 0]
  [3 0]]

 [[0 5]
  [0 7]]]

前 3 个参数是必需的。第一个是输入张量的形状。二是拆分规范。 API 目前仅支持一维分割。拆分规范与形状具有相同的维数,一个拆分 >= 1,其他拆分为 1。最后一个参数是张量本身,或者 returns 它的可调用对象。

第一个例子:

tf.create_partitioned_variables(v.shape, [2, 1], v)

第二个例子:

[tf.squeeze(v) 
    for v in tf.create_partitioned_variables(
        v.shape, [2, 1, 1], v)]