从 tensorflow 数据集中提取元素

Extract elements from tensorflow dataset

我有一个包含我所有数据和标签的张量流数据集。 使用以下代码将前 20 个元素提取到另一个数据集中:

train_dataset = big_dataset.take(20)

但是我如何将 big_dataset 中的最后 20 个元素提取到新数据集中?

谢谢我提前!

编辑: 以下代码显示了我如何定义 big_dataset:

big_dataset = tf.data.Dataset.from_tensor_slices((all_points, all_labels))

现在获取第一个元素的方法是以下代码(其中 train_size 例如 20):

train_dataset = big_dataset.take(train_size)
train_dataset = train_dataset.shuffle(train_size).map(augment).batch(BATCH_SIZE)

但是使用 .skip().take() 会导致数据库为空

尝试使用 skip。例如,假设您有 120 个数据样本,batch_size 为 1,并且您没有打乱数据,那么您可以尝试如下操作:

train_dataset = big_dataset.skip(100).take(20)

对于您的特定数据集,尝试:

import tensorflow as tf

samples = 29
all_points  = tf.random.normal((samples, 5))
all_labels  = tf.random.normal((samples, 1))
big_dataset = tf.data.Dataset.from_tensor_slices((all_points, all_labels))
train_size = 20 
train_dataset = big_dataset.skip(9).take(train_size)
print(len(train_dataset))
20