使用 Dataset 和 .batch 后展开张量

Exploding tensor after using Dataset and .batch

我有一个形状为 (100,4,30) 的 numpy 数组。这表示 4 个长度为 30 的编码样本中的 100 个样本。每行的 4 个样本是相关的。

我想获取一个 TensorFlow 数据集,批处理,其中相关样本在同一批中。

我正在努力:

首先,使用np.vsplit得到一个长度为100的列表,其中列表中的每个元素都是4个相关样本的列表。

现在,如果我在这个列表列表上调用 tf.data.Dataset.from_tensor_slices(...).batch(1),我会得到一个包含形状张量 (4,1,30) 的批次。

我希望这批包含 4 个形状为 (1,30) 的张量。

我怎样才能做到这一点?

我可能误解了你的意思,但如果你只是省略了“vsplit”:

data = np.zeros((100, 4, 30))
data_ds = tf.data.Dataset.from_tensor_slices(data).batch(1)
for element in data_ds.take(1):
    print(element.shape)

您将获得:

(1, 4, 30)

(所以一批包含所有 4 个相关编码)。

如果你真的希望批次内的维度是 4 倍 (1, 30) 你可以这样做:

data = np.expand_dims(data, axis=2)

创建数据集之前。

编辑:

我想我刚刚明白了你的问题。您希望每个批次都有 4 个元素,而这些是相关的编码?您可以通过以下方式实现:

data = np.swapaxes(data, 0, 1)
data = np.reshape(data, (100*4, -1))
data_ds = tf.data.Dataset.from_tensor_slices(data).batch(4)