tf.split 的输出是什么?

What is the output of tf.split?

所以假设我有这个:

TensorShape([Dimension(None), Dimension(32)])

然后我在张量 _X 上使用 tf.split,维度在上面:

_X = tf.split(_X, 128, 0) 

这个新张量的形状是什么?输出是一个列表,所以很难知道这个新张量的形状。

tf.split() returns张量对象列表。你可以知道每个张量对象的形状如下

import tensorflow as tf

X = tf.random_uniform([256, 32]);
Y = tf.split(X,128,0)
Y_shape = tf.shape(Y[1])

sess = tf.Session()
X_v,Y_v,Y_shape_v = sess.run([X,Y,Y_shape]) 
# numpy style
print X_v.shape
print len(Y_v)
print Y_v[100].shape
# TF style
print len(Y)
print Y_shape_v

输出:

(256, 32)
128
(2, 32)
128
[ 2 32]

希望对您有所帮助!

tf.split(X, row = n, column = m) 用于将变量的数据集拆分为 n 行方向的片段和 m 列方向的片段。

例如,我们有 data_set x 大小 (10,10), 然后 tf.split(x, 2, 0) 将打破 x 的 data_set 在 2 组大小 (5, 10)

但是如果我们取 tf.split(x, 2, 2), 那么我们会得到4组大小为(5, 5).

的数据

新版tensorflow定义split函数如下:

tf.split( 价值, num_or_size_splits, 轴=0, 数=None, 姓名='split' )

然而,当我尝试在 R 中 运行 时:

X = tf$random_uniform(minval=0,
                      maxval=10,shape(256, 32),name = "X");

Y = tf$split(X,num_or_size_splits = 2,axis = 0)

它报告错误消息:

Error in py_call_impl(callable, dots$args, dots$keywords) : 
  ValueError: Rank-0 tensors are not supported as the num_or_size_splits argument to split. Argument provided: 2.0