如何使用 tfrecords 和 steps_per_epoch 控制读取哪些样本
How can I control which samples are read using tfrecords and steps_per_epoch
我目前正在将我的 tf 代码转换为 tfrecords 和 tf 数据集。在我的应用程序中,经过训练的模型通常在看到所有训练样本之前很久就收敛了。因此,我通常将数据生成器的长度自己设置为我想要在一个时期内适合的批次数,并确保在我的生成器中,在下一个时期,生成器在前一个时期的最后一个样本之后拾取。这允许所有回调按预期工作(尤其是提前停止),同时我仍然可以在每个时期用看不见的数据训练我的模型。
如何使用 tf 数据集和 tfrecords 实现此行为?我已经阅读了 tensorflow Github 上的数据集定义,但不确定这是否可行。
如果我设置 steps_per_epoch
:
,我认为有两种可能的解决方案
- 覆盖指定从何处读取下一个样本的代码部分,以仅在前一个时期的最后一个样本之后的样本处拾取。
- 尝试使用自定义 tf 数据集实现来模仿上述行为。我担心这会对并行化和性能产生无法预料的影响。
但是我也不知道怎么实现。因此,如果您对此有任何见解,我将不胜感激。
现在我可以使用一个不优雅的解决方法,我总是训练一个时期,然后用新的 tfrecord 文件初始化一个新的数据集,但我希望有更好的方法,尤其是在回调方面。
我不确定我是否完全理解您想要实现的目标。你想要那个:
- 在一个时期内,您的模型看不到整个数据集
- 接下来的 epoch 不使用之前的样本
就这些了?
在我看来,steps_per_epoch
论点是您最好的选择。例如,如果您有一个包含 100 个项目(样本或批次)的数据集,并且您设置了 steps_per_epoch=20
,那么在第一个时期,您的模型将看到项目 0 到 19,在第二个时期看到项目 20 到 39,依此类推在。无需覆盖任何部分代码。
尝试模仿数据集的行为可能不是一个好主意(需要处理的事情太多,涉及许多(艰苦的)工作)。
从你的最后一段,我了解到你希望每个时期都使用来自特定 TFRecord 文件的数据。也许你可以看看tf.data.Dataset.flat_map
。建立一个 TFRecord 文件列表(同一个文件可以出现多次)和“flat_map
” TFRecordDataset
在上面:
files = tf.data.Dataset.from_tensor_slices([
"file1.tfrecord", "file2.tfrecord",
"file1.tfrecord", "file3.tfrecord"
])
dataset = file.flat_map(TFRecordDataset)
遍历数据集会给你 Example
s 来自 file1,然后来自 file2,然后再次来自 file1,然后来自 file3。
希望对您有所帮助。
我目前正在将我的 tf 代码转换为 tfrecords 和 tf 数据集。在我的应用程序中,经过训练的模型通常在看到所有训练样本之前很久就收敛了。因此,我通常将数据生成器的长度自己设置为我想要在一个时期内适合的批次数,并确保在我的生成器中,在下一个时期,生成器在前一个时期的最后一个样本之后拾取。这允许所有回调按预期工作(尤其是提前停止),同时我仍然可以在每个时期用看不见的数据训练我的模型。
如何使用 tf 数据集和 tfrecords 实现此行为?我已经阅读了 tensorflow Github 上的数据集定义,但不确定这是否可行。
如果我设置 steps_per_epoch
:
- 覆盖指定从何处读取下一个样本的代码部分,以仅在前一个时期的最后一个样本之后的样本处拾取。
- 尝试使用自定义 tf 数据集实现来模仿上述行为。我担心这会对并行化和性能产生无法预料的影响。
但是我也不知道怎么实现。因此,如果您对此有任何见解,我将不胜感激。
现在我可以使用一个不优雅的解决方法,我总是训练一个时期,然后用新的 tfrecord 文件初始化一个新的数据集,但我希望有更好的方法,尤其是在回调方面。
我不确定我是否完全理解您想要实现的目标。你想要那个:
- 在一个时期内,您的模型看不到整个数据集
- 接下来的 epoch 不使用之前的样本
就这些了?
在我看来,steps_per_epoch
论点是您最好的选择。例如,如果您有一个包含 100 个项目(样本或批次)的数据集,并且您设置了 steps_per_epoch=20
,那么在第一个时期,您的模型将看到项目 0 到 19,在第二个时期看到项目 20 到 39,依此类推在。无需覆盖任何部分代码。
尝试模仿数据集的行为可能不是一个好主意(需要处理的事情太多,涉及许多(艰苦的)工作)。
从你的最后一段,我了解到你希望每个时期都使用来自特定 TFRecord 文件的数据。也许你可以看看tf.data.Dataset.flat_map
。建立一个 TFRecord 文件列表(同一个文件可以出现多次)和“flat_map
” TFRecordDataset
在上面:
files = tf.data.Dataset.from_tensor_slices([
"file1.tfrecord", "file2.tfrecord",
"file1.tfrecord", "file3.tfrecord"
])
dataset = file.flat_map(TFRecordDataset)
遍历数据集会给你 Example
s 来自 file1,然后来自 file2,然后再次来自 file1,然后来自 file3。
希望对您有所帮助。