Tensorflow shuffle iterator

I would like to retrieve 2 items belonging to different classes with a Tensorflow iterator (for BC learning)...

The solution I have been working on uses tf.while_loop, but it doesn't feel right to me. Is there another way to do this besides the one I came up with?

Here is an example with a naive dataset of random numbers belonging to 5 classes:

import tensorflow as tf
import numpy as np

dataset = np.array([(np.random.rand(), i // 20) for i in range(100)])
dataset = tf.data.Dataset.from_tensor_slices(dataset)
dataset = dataset.shuffle(100)
iterator = dataset.make_one_shot_iterator()

a = iterator.get_next()
b = iterator.get_next()
loop_vars = [a, b]

def cond(a, b):
    l1 = tf.gather(a, 1)
    l2 = tf.gather(b, 1)
    return tf.equal(l1, l2)

def body(a, b):
    a = iterator.get_next()
    b = iterator.get_next()
    return a, b

loop = tf.while_loop(cond, body, loop_vars)


with tf.Session() as sess:
    for i in range(10):
        values = sess.run([loop])
        print(values)

Thanks :)

I'm not entirely sure what you're trying to do, but if you have two tf.data.Dataset objects and you want to sample randomly from them, you can do something like the following (note that this will require upgrading to the tf-nightly package, or waiting for the TensorFlow 1.9 release):

# Define two datasets with the same structure but different values, to represent
# the different inputs. Using dummy data (a dataset of '1's and a dataset of '2's)
# to make the example clearer.
dataset_1 = tf.data.Dataset.from_tensors(1).repeat(None)
dataset_2 = tf.data.Dataset.from_tensors(2).repeat(None)

merged_dataset = tf.contrib.data.sample_from_datasets([dataset_1, dataset_2])
merged_dataset = merged_dataset.batch(2)  # Get two elements at a time.

iterator = merged_dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
  for i in range(10):
    values = sess.run([next_element])
    print(values)

# Prints: 
# [array([1, 2], dtype=int32)]
# [array([1, 2], dtype=int32)]
# [array([2, 1], dtype=int32)]
# [array([1, 2], dtype=int32)]
# [array([1, 1], dtype=int32)]
# [array([2, 2], dtype=int32)]
# [array([2, 2], dtype=int32)]
# [array([1, 1], dtype=int32)]
# [array([2, 2], dtype=int32)]
# [array([2, 2], dtype=int32)]
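To make the sampling behavior easier to picture, here is a rough plain-Python analogue of what sample_from_datasets does with uniform weights (an illustrative sketch, not the actual TensorFlow implementation; sample_from_iterators is a made-up helper name):

```python
import itertools
import random

def sample_from_iterators(iterators, seed=None):
    # Rough analogue of tf.contrib.data.sample_from_datasets with uniform
    # weights: at each step, pick one of the inputs at random and yield
    # its next element.
    rng = random.Random(seed)
    while True:
        yield next(rng.choice(iterators))

# Two endless inputs, as in the example above.
merged = sample_from_iterators([itertools.repeat(1), itertools.repeat(2)])
batch = [next(merged) for _ in range(2)]  # like .batch(2): two at a time
print(batch)  # e.g. [1, 2] or [2, 2]; classes may repeat within a batch
```

As the printed output above shows, with uniform sampling nothing prevents both elements of a batch from coming from the same dataset, which motivates the weights trick below.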

If you want to ensure that the elements do not come from the same class, you can use the ability to pass a Dataset as the weights argument to tf.contrib.data.sample_from_datasets(); the dataset yields the (in this case, one-hot) sampling distribution to use at each step, like this:

import tensorflow as tf

NUM_CLASSES = 5
NUM_DISTINCT = 2
datasets = [tf.data.Dataset.from_tensors(i).repeat(None)
            for i in range(NUM_CLASSES)]

# Define a dataset with NUM_DISTINCT distinct class IDs per element,
# then unbatch it into one class ID per element.
weight_dataset = tf.contrib.data.Counter().map(
    lambda _: tf.random_shuffle(tf.range(NUM_CLASSES))[:NUM_DISTINCT])
weight_dataset = weight_dataset.apply(tf.contrib.data.unbatch())
weight_dataset = weight_dataset.map(lambda x: tf.one_hot(x, NUM_CLASSES))

merged_dataset = tf.contrib.data.sample_from_datasets(datasets, weight_dataset)
merged_dataset = merged_dataset.batch(NUM_DISTINCT)

iterator = merged_dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
  for i in range(1000):
    values = sess.run(next_element)
    assert values[0] != values[1]
    print(values)
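The assertion holds because each element of the weight dataset is built by shuffling all class IDs and keeping the first NUM_DISTINCT, so the classes sampled within one batch are always distinct. A minimal plain-Python sketch of that logic (an illustration of the idea, not TensorFlow code):

```python
import random

NUM_CLASSES = 5
NUM_DISTINCT = 2

def distinct_class_batches(seed=None):
    # Mirrors the weight-dataset trick: shuffle the class IDs, keep the
    # first NUM_DISTINCT, and emit them as one batch. Because they come
    # from a single shuffled permutation, the IDs in a batch are distinct.
    rng = random.Random(seed)
    ids = list(range(NUM_CLASSES))
    while True:
        rng.shuffle(ids)
        yield ids[:NUM_DISTINCT]

batches = distinct_class_batches()
for _ in range(1000):
    batch = next(batches)
    assert batch[0] != batch[1]  # never two elements of the same class
```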