如何从 tf.data.Dataset.zip((images, labels)) 中获取两个 tf.dataset

How to get two tf.dataset from tf.data.Dataset.zip((images, labels))

我正在编写 Python/tensorflow/mnist 教程。

自从使用来自 tensorflow 网站的原始代码几周以来,我收到警告说图像数据集很快就会被弃用,我应该使用以下数据集: https://github.com/tensorflow/models/blob/master/official/mnist/dataset.py

我使用我的代码加载它:

from tensorflow.models.official.mnist import dataset
trainfile = dataset.train(data_dir)

哪个 returns :

tf.data.Dataset.zip((images, labels))

问题是我找不到以下列方式将它们分开的方法,例如:

  trainfile = dataset.train(data_dir)
  train_data= trainfile.images
  train_label= trainfile.label

但这显然是行不通的,因为属性图像和标签不存在。训练文件是 tf.dataset.

知道 tf.dataset 是由 int32 和 float32 组成的我试过了:

  train_data = trainfile.map(lambda x,y : x.dtype == tf.float32)

但它 returns 和空数据集。

我坚持(但会公开)这样做(两批完整的图像和标签),因为这就是教程的工作方式:

https://www.tensorflow.org/tutorials/estimators/cnn

我看到了很多从数据集中获取元素的解决方案,但没有什么可以从以下代码中完成的 zip 操作返回

tf.data.Dataset.zip((images, labels))

预先感谢您的帮助。

与其分成两个数据集,一个用于图像,另一个用于标签,最好制作一个迭代器,其中 returns 图像和标签。

这是首选的原因是,即使在一系列复杂的洗牌、重新排序、过滤等之后,也更容易确保您将每个示例与其标签相匹配,就像您在非平凡的输入管道中可能遇到的那样.

希望对您有所帮助:

inputs = tf.placeholder(tf.float32, shape=(None, 784), name='inputs')
outputs = tf.placeholder(tf.float32, shape=(None,), name='outputs')

#Prepare a tensorflow dataset
ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))

ds = ds.shuffle(buffer_size=10, reshuffle_each_iteration=True).batch(batch_size=batch_size, drop_remainder=True).repeat()
iter = ds.make_one_shot_iterator()
next = iter.get_next()

inputs = next[0]
outputs = next[1]

您可以可视化图像并找到其关联的标签

ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))

ds = ds.shuffle(buffer_size=10).batch(batch_size=batch_size)
iter = ds.make_one_shot_iterator()
next = iter.get_next()

def display(image, label):
# display image
   ...
   plt.imshow(image)
   ...

with tf.Session() as sess:
    try:
        while True:
             image, label = sess.run(next) 
             # image = numpy array (batch, image_size)
             # label = numpy array (batch, label)
        display(image[0], label[0]) #display first image in batch
    except:
        pass

TensorFlow 的 get_single_element() 终于 around 可用于从数据集中解压缩特征和标签。

这避免了使用 .map()iter()one_shot_iterator() 生成和使用迭代器的需要(这对于大数据集来说可能代价高昂)。

get_single_element() returns 封装数据集所有成员的张量(或张量的元组或字典)。我们需要将批处理的数据集的所有成员传递到一个元素中。

这可用于获取特征作为张量数组,或特征和标签作为元组或字典(张量数组),具体取决于原始数据集的方式已创建。

在 SO 上查看此 以获取将特征和标签解包到张量数组元组中的示例。