Issue splitting a dataset with the TensorFlow Dataset API
I am using tf.contrib.data.make_csv_dataset to read a csv file and form a dataset, and then I call take() to form another dataset with only one element, but it still returns all elements.
What is wrong here? My code is below:
import tensorflow as tf
import os
tf.enable_eager_execution()
# Constants
column_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species']
class_names = ['Iris setosa', 'Iris versicolor', 'Iris virginica']
batch_size = 1
feature_names = column_names[:-1]
label_name = column_names[-1]
# to reorient the data structure
def pack_features_vector(features, labels):
    """Pack the features into a single array."""
    features = tf.stack(list(features.values()), axis=1)
    return features, labels

# Download the file
train_dataset_url = "http://download.tensorflow.org/data/iris_training.csv"
train_dataset_fp = tf.keras.utils.get_file(fname=os.path.basename(train_dataset_url),
                                           origin=train_dataset_url)

# form the dataset
train_dataset = tf.contrib.data.make_csv_dataset(
    train_dataset_fp,
    batch_size,
    column_names=column_names,
    label_name=label_name,
    num_epochs=1)

# perform the mapping
train_dataset = train_dataset.map(pack_features_vector)

# construct a dataset with one element
train_dataset = train_dataset.take(1)

# inspect elements
for step in range(10):
    features, labels = next(iter(train_dataset))
    print(list(features))
Based on this answer, we can split a dataset with Dataset.take() and Dataset.skip():
train_size = int(0.7 * DATASET_SIZE)
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
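As a fuller sketch along the same lines, here is a three-way split; DATASET_SIZE, the seed, and the 70/15/15 ratios are assumptions, since tf.data cannot cheaply infer a dataset's size:

DATASET_SIZE = 120  # assumed known up front; iris_training.csv has 120 rows

train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)

# Shuffle once, with a fixed seed and no reshuffling between iterations;
# otherwise take() and skip() may draw overlapping records.
full_dataset = full_dataset.shuffle(DATASET_SIZE, seed=42,
                                    reshuffle_each_iteration=False)
train_dataset = full_dataset.take(train_size)
rest_dataset = full_dataset.skip(train_size)
val_dataset = rest_dataset.take(val_size)
test_dataset = rest_dataset.skip(val_size)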
How to fix your code?
Instead of creating an iterator several times inside the loop, use one iterator:
# inspect elements
for feature, label in train_dataset:
    print(feature)
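Alternatively, if the goal is for take(1) to always yield the same record, a sketch is to build the dataset without shuffling (shuffle is a documented argument of make_csv_dataset):

# Disable shuffling so the "first" element of the pipeline is stable.
train_dataset = tf.contrib.data.make_csv_dataset(
    train_dataset_fp,
    batch_size,
    column_names=column_names,
    label_name=label_name,
    num_epochs=1,
    shuffle=False)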
What is happening in your code that causes this behavior?
1) The built-in Python iter function gets an iterator from an object, or the object itself must supply its own iterator. So calling iter(train_dataset) is equivalent to calling Dataset.make_one_shot_iterator().
2) By default, in tf.contrib.data.make_csv_dataset(), the shuffle argument is True (shuffle=True). Therefore, every call to iter(train_dataset) creates a new iterator over differently shuffled data.
3) Finally, by looping through for step in range(10), you create 10 different iterators of size 1, each holding its own data because the records are re-shuffled each time; the sketch after this list demonstrates the effect.
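A minimal sketch of this pitfall (it assumes the question's train_dataset, built with the default shuffle=True, and eager execution):

# Each iter() call builds a fresh, re-shuffled iterator over the pipeline,
# so the "first" element usually differs between the two calls.
features_a, _ = next(iter(train_dataset))
features_b, _ = next(iter(train_dataset))
print(features_a.numpy())  # typically not equal to features_b.numpy()
print(features_b.numpy())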
Suggestion: if you want to avoid things like this, initialize (create) the iterator outside the loop:
train_dataset = train_dataset.take(1)
iterator = train_dataset.make_one_shot_iterator()
# inspect elements
for step in range(10):
    features, labels = next(iterator)
    print(list(features))
    # throws an exception, because the iterator holds only 1 element
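If you actually want to see the same single batch ten times, one sketch is to cache and repeat the one-element dataset before creating the iterator (cache() pins the element that take(1) produced; without it, repeat() would re-run the shuffled pipeline and could yield a different element each pass):

train_dataset = train_dataset.take(1).cache().repeat()
iterator = train_dataset.make_one_shot_iterator()

# inspect elements: the same batch is printed 10 times, no exception
for step in range(10):
    features, labels = next(iterator)
    print(list(features))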