如何将 tf.train.batch 与 enqueue_many=true 一起使用

How to use tf.train.batch with enqueue_many=true

我正在寻找将 tf.train.batchenqueue_many=True 结合使用的示例。

在我的例子中,我有一个形状为 [299,299,3] 的图像张量,当我调用函数 get_distortions(image) 时,它将 return 一个形状为 [10,299,299,3] 的新张量(在这个例子中,它将对图像应用 10 次扭曲,并且 return 它们全部作为一个新的张量)。然后我想通过调用 tf.train.batch.

将所有这些排队

我试过这个:

example_batch = tf.train.batch(tf.unpack(distortions), 5, enqueue_many=True)

但是当我 sess.run(example_batch) 时,我得到了一个长度为 10 的列表(我期望一批大小为 5)。

此外,在这种情况下,如何将标签添加到 tf.train.batch?所有 10 个扭曲的标签都相同。

不要解压 distortionsenqueue_many 的语义是你给它一个张量,第一维是批处理维度,所以 [10, 299, 299, 3] 张量 enqueue_many 将产生十个单独的项目,每个项目的形状为 299, 299, 3 正在排队——这就是你想要的。

tf.train.batch 的文档告诉您:

If enqueue_many is True, tensors is assumed to represent a batch of examples, where the first dimension is indexed by example, and all members of tensors should have the same size in the first dimension. If an input tensor has shape [*, x, y, z], the output will have shape [batch_size, x, y, z]. The capacity argument controls the how long the prefetching is allowed to grow the queues.

这正是您的情况:[10, 299, 299, 3],其中 10 是批量大小。所以你不需要做任何解包,tf.train.batch(distortions, 5, enqueue_many=True) 会完成这项工作。