Reading a TensorFlow Dataset changes the behaviour of `take()` and `skip()`

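I am trying to inspect the labels in my TensorFlow dataset. However, after using `take()` and `skip()`, the label values change unexpectedly depending on whether or not I inspect the data first. (It looks like some of the ones in the labels turn into zeros.) I don't see any way my inspection function could change the dataset. What am I missing?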

To reproduce the behaviour, toggle the `LOOK_AT_DATA_TWICE` variable.

# python 3.9.4, tensorflow 2.5.0-rc1
import numpy as np
import tensorflow as tf

tf.random.set_seed(42)


def inspect_dataset(ds, msg="", print_all=True):
    sum_ones = 0
    sum_zeros = 0
    for (sig, label) in ds.as_numpy_iterator():
        if print_all:
            print(msg, label, np.histogram(label, bins=2)[0])
        sum_ones += np.sum(label)
        sum_zeros += np.sum(label - 1)  # each 0 label adds -1, so this ends up as minus the zero count

    print(msg, "SUM of ones=", sum_ones)
    print(msg, "SUM of zero=", sum_zeros)


all_pattern = np.random.random((4000, 1000))
all_labels = np.array(2000 * [0] + 2000 * [1])

print(f"all_pattern.shape={all_pattern.shape}")
print(f"all_labels.shape={all_labels.shape}, sum(all_labels)={np.sum(all_labels)}")
print(f"Creating dataset from labels hist: {np.histogram(all_labels, bins=2)[0]}")

complete_ds = tf.data.Dataset.from_tensor_slices((all_pattern, all_labels))
complete_ds = complete_ds.shuffle(len(all_labels))

LOOK_AT_DATA_TWICE = True  # This changes the numbers output below
if LOOK_AT_DATA_TWICE:
    inspect_dataset(complete_ds, msg="complete_ds in gerneration", print_all=False)
inspect_dataset(complete_ds, msg="complete_ds in gerneration", print_all=False)

validation_split = 0.5
num_test_samples = int(validation_split * len(all_labels))
train_ds = complete_ds.skip(num_test_samples)
val_ds = complete_ds.take(num_test_samples)

inspect_dataset(train_ds, msg="train_ds in generation", print_all=False)
inspect_dataset(val_ds, msg="val_ds in generation", print_all=False)

Output with `LOOK_AT_DATA_TWICE = True`:

all_pattern.shape=(4000, 1000)
all_labels.shape=(4000,), sum(all_labels)=2000
Creating dataset from labels hist: [2000 2000]

complete_ds in gerneration SUM of ones= 2000
complete_ds in gerneration SUM of zero= -2000
complete_ds in gerneration SUM of ones= 2000
complete_ds in gerneration SUM of zero= -2000
train_ds in generation SUM of ones= 997
train_ds in generation SUM of zero= -1003
val_ds in generation SUM of ones= 988
val_ds in generation SUM of zero= -1012

Output with `LOOK_AT_DATA_TWICE = False`:

all_pattern.shape=(4000, 1000)
all_labels.shape=(4000,), sum(all_labels)=2000
Creating dataset from labels hist: [2000 2000]
   
complete_ds in gerneration SUM of ones= 2000
complete_ds in gerneration SUM of zero= -2000
train_ds in generation SUM of ones= 1031
train_ds in generation SUM of zero= -969
val_ds in generation SUM of ones= 1003
val_ds in generation SUM of zero= -997

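When a dataset is exhausted (i.e. after you have iterated over it once), iterating over it again re-runs the whole pipeline. In your case, because you are shuffling, the shuffle in the first epoch will differ from the shuffle in the second epoch.

This means your training and test sets are not actually consistent from one epoch to the next. Moreover, `skip()` and `take()` each start their own pass over `complete_ds`, so `train_ds` and `val_ds` are drawn from two different shuffles and can overlap; that is why the numbers of ones in the two halves do not add up to 2000. The extra inspection pass also changes which shuffles the later passes get, which is why toggling `LOOK_AT_DATA_TWICE` changes the numbers.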
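A minimal sketch of the effect, assuming TensorFlow 2.x in eager mode. `Dataset.range(100)` stands in for the labelled data so that each element is its own index, and `seed=1` is an arbitrary choice for repeatability; neither comes from the question. With the default `reshuffle_each_iteration=True`, every pass sees a new order, and the `take()` and `skip()` halves come from different passes, so they overlap.

```python
import tensorflow as tf

# Each element is its own index, so overlaps between splits are easy to spot.
# reshuffle_each_iteration is left at its default (True).
ds = tf.data.Dataset.range(100).shuffle(100, seed=1)

first_pass = list(ds.as_numpy_iterator())
second_pass = list(ds.as_numpy_iterator())  # a new iterator re-runs the shuffle

# take() and skip() each trigger their own pass, i.e. their own shuffle.
val_ids = set(ds.take(50).as_numpy_iterator())
train_ids = set(ds.skip(50).as_numpy_iterator())

print("passes identical:", first_pass == second_pass)
print("val/train overlap:", len(val_ids & train_ids))
print("ids missing from both:", 100 - len(val_ids | train_ids))
```

Because both halves contain 50 distinct ids, every id that appears in both splits is matched by one id that appears in neither, which mirrors the label sums in the question.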

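You can pass `reshuffle_each_iteration=False` to the `shuffle` call so that it produces the same order on every iteration. If you still want your training set to be shuffled differently each epoch, call `shuffle` on it again: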

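import tensorflow as tf

# `data` and `train_size` come from your own code (e.g. (all_pattern, all_labels) and a split size)
ds = tf.data.Dataset.from_tensor_slices(data)
# Shuffle once and keep that order on every iteration, so take() and skip() stay complementary
shuffled_ds = ds.shuffle(buffer_size=len(ds), reshuffle_each_iteration=False)
train_ds = shuffled_ds.take(train_size)
# Reshuffle only the training split on every epoch
train_ds = train_ds.shuffle(buffer_size=train_size)
test_ds = shuffled_ds.skip(train_size)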
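To check the fix, here is a small self-contained sketch, again assuming TensorFlow 2.x in eager mode. The `Dataset.range(100)` data, the 70/30 split and the seeds are illustrative choices, not part of the answer above. With the order fixed, the two splits are disjoint and together cover every element, while the second `shuffle()` still changes the training order from one epoch to the next without changing which elements it contains.

```python
import tensorflow as tf

ds = tf.data.Dataset.range(100)
# Shuffle once; the same order is reused on every iteration, so take()/skip() agree.
shuffled_ds = ds.shuffle(100, seed=1, reshuffle_each_iteration=False)

train_size = 70  # illustrative split size
train_ds = shuffled_ds.take(train_size)
test_ds = shuffled_ds.skip(train_size)
# Reshuffle only the training half on every pass (epoch).
train_ds = train_ds.shuffle(train_size, seed=2)

train_epoch1 = list(train_ds.as_numpy_iterator())
train_epoch2 = list(train_ds.as_numpy_iterator())
test_ids = list(test_ds.as_numpy_iterator())

print("train/test overlap:", len(set(train_epoch1) & set(test_ids)))
print("splits cover all ids:", sorted(train_epoch1 + test_ids) == list(range(100)))
print("same training ids each epoch:", set(train_epoch1) == set(train_epoch2))
print("training order changed:", train_epoch1 != train_epoch2)
```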