tf.data.experimental.sample_from_datasets 未按预期抽样

tf.data.experimental.sample_from_datasets not sampling as expected

文档似乎很简单,他们的标准 TF 教程中给出的示例没有突出显示我看到的行为。假设您有一个 1 和 0(pos 和 neg)的不平衡数据集,并且您希望以权重 [0.5, 0.5] 进行采样,以便更频繁地看到正值。你会这样做:

pos_ds = tf.data.Dataset.from_tensor_slices(np.ones(shape=(16, 1)))
neg_ds = tf.data.Dataset.from_tensor_slices(np.zeros(shape=(128, 1)))

resampled_ds = tf.data.experimental.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])

如果我想在浏览数据集时查看 pos 和 neg 的分布情况:

xs = []
for x in resampled_ds:
  xs.append(int(x.numpy()[0]))

xs = np.array(xs)
print(xs)

np.bincount(xs)

我看到了这个:

[1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 1 0 0 1 1 1 0 1 0 0 1 0 0 0 0 1 1 0 0 1
 0 1 0 1 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]

array([128,  16])

有 128 个负数和 16 个正数。如果我将其用作我的 train_ds,这将相当于没有进行采样,更糟糕的是,底片不再均匀分布在步骤/时期。我猜 0.5 采样是在开始时发生的,一旦它“运行 超出”1 秒,它就开始仅对零采样。它显然不会对 1 进行替换采样。我认为如果在所有 1 都被采样后停止,1 和 0 只会是 0.5/0.5。

看起来这是一种行为,但它不是唯一明智的行为。我想在 1 个 epoch 中对正样本进行多次采样(即有放回的采样),pos 和 negs 的数量大致相等,这个 API 有什么选择吗?另外,我有数据扩充,所以在训练时积极因素实际上是不一样的。

对于更换问题,您可以这样做:

resampled_ds = tf.data.experimental.sample_from_datasets([pos_ds.repeat(128 // 16), neg_ds], weights=[0.5, 0.5])

结果是:

[1 1 1 0 0 1 1 1 1 1 0 1 0 0 1 1 0 1 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1
 1 0 0 1 1 1 1 0 1 1 0 1 0 0 0 0 1 0 1 1 0 1 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1
 1 0 0 1 0 1 0 1 1 1 0 1 0 1 0 1 0 1 0 1 0 1 1 1 1 1 1 1 0 0 0 0 0 1 1 0 0
 0 0 0 0 1 0 1 0 1 0 0 1 1 0 0 1 0 1 0 1 0 0 0 1 1 1 0 1 0 0 1 1 0 1 1 0 1
 1 0 0 1 1 0 0 0 0 0 0 1 1 0 0 1 0 1 0 0 0 1 0 1 0 0 1 1 0 0 0 1 0 1 0 1 1
 1 1 0 1 0 0 0 0 0 1 1 0 0 0 0 0 1 1 0 1 1 1 0 0 0 1 0 1 1 1 0 0 0 0 1 1 0
 0 0 1 0 1 0 0 0 0 1 0 0 0 0 1 0 1 1 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
Out[2]: array([128, 128], dtype=int64)

实际上,我还在那个 TF 教程 imbalanced_data.ipynb 上找到了解决方案(我在自己的笔记本中完全错过了这个)。

pos_ds = pos_ds.shuffle(BUFFER_SIZE).repeat()
neg_ds = neg_ds.shuffle(BUFFER_SIZE).repeat()

resampled_ds = tf.data.experimental.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])

本教程进一步建议了一种启发式方法来设置 resampled_steps_per_epoch。

然而,洗牌+重复,仍然不等同于对少数class进行替换的真实采样。一个 repeat() 后跟一个 shuffle() 可能是这样做的。我可以通过尝试两种方式来更新它。