当收到“DirectedInterleave selected an exhausted input”警告时,TensorFlow 的“sample_from_datasets”是否仍然从数据集中采样?

Does TensorFlow's `sample_from_datasets` still sample from a Dataset when getting a `DirectedInterleave selected an exhausted input` warning?

当使用 TensorFlow 的 tf.data.experimental.sample_from_datasets to equally sample from two very unbalanced Datasets, I end up getting a DirectedInterleave selected an exhausted input: 0 warning. Based on this GitHub issue 时,这似乎是在 sample_from_datasets 中的一个数据集已经耗尽示例时发生的,并且需要对已经看到的示例进行采样。

然后耗尽的数据集是否仍然产生样本(从而保持所需的平衡训练比率),或者数据集是否不采样因此训练再次变得不平衡?如果是后者,是否有一种方法可以用 sample_from_datasets 产生所需的平衡训练比率?

注意:正在使用 TensorFlow 2 Beta

较小的数据集不会重复 - 一旦用完,其余部分将仅来自仍有示例的较大数据集。

您可以通过执行以下操作来验证此行为:

def data1():
  for i in range(5):
    yield "data1-{}".format(i)

def data2():
  for i in range(10000):
    yield "data2-{}".format(i)

ds1 = tf.data.Dataset.from_generator(data1, tf.string)
ds2 = tf.data.Dataset.from_generator(data2, tf.string)

sampled_ds = tf.data.experimental.sample_from_datasets([ds2, ds1], seed=1)

然后,如果我们遍历 sampled_ds,我们会发现 data1 中的样本一旦耗尽就不会生成:

tf.Tensor(b'data1-0', shape=(), dtype=string)
tf.Tensor(b'data2-0', shape=(), dtype=string)
tf.Tensor(b'data2-1', shape=(), dtype=string)
tf.Tensor(b'data2-2', shape=(), dtype=string)
tf.Tensor(b'data2-3', shape=(), dtype=string)
tf.Tensor(b'data2-4', shape=(), dtype=string)
tf.Tensor(b'data1-1', shape=(), dtype=string)
tf.Tensor(b'data1-2', shape=(), dtype=string)
tf.Tensor(b'data1-3', shape=(), dtype=string)
tf.Tensor(b'data2-5', shape=(), dtype=string)
tf.Tensor(b'data1-4', shape=(), dtype=string)
tf.Tensor(b'data2-6', shape=(), dtype=string)
tf.Tensor(b'data2-7', shape=(), dtype=string)
tf.Tensor(b'data2-8', shape=(), dtype=string)
tf.Tensor(b'data2-9', shape=(), dtype=string)
tf.Tensor(b'data2-10', shape=(), dtype=string)
tf.Tensor(b'data2-11', shape=(), dtype=string)
tf.Tensor(b'data2-12', shape=(), dtype=string)
...
---[no more 'data1-x' examples]--
...

当然,你可以data1重复这样的事情:

sampled_ds = tf.data.experimental.sample_from_datasets([ds2, ds1.repeat()], seed=1)

但从评论看来您已经意识到这一点并且它不适用于您的场景。

If the latter, is there a method to produce the desired balanced training ratio with sample_from_datasets?

嗯,如果你有 2 个不同长度的数据集,并且你从那里均匀采样,那么你似乎只有 2 个选择:

  • 重复较小的数据集 n 次(其中 n ≃ len(ds2)/len(ds1)
  • 一旦较小的数据集用完就停止采样

要实现第一个你可以使用ds1.repeat(n)

要实现第二个,您可以使用 ds2.take(m) where m=len(ds1).