从 tfrecords 数据集生成跨条切片数据集
Produce a dataset of stridded slices from a tfrecords dataset
从 question and the discussion here 继续 - 我正在尝试使用数据集 API 获取可变长度张量的数据集并将它们切成等长的切片(段)。类似于:
Dataset = tf.contrib.data.Dataset
segment_len = 6
batch_size = 16
with tf.Graph().as_default() as g:
# get the tfrecords dataset
dataset = tf.contrib.data.TFRecordDataset(filenames).map(
partial(record_type.parse_single_example, graph=g)).batch(batch_size)
# zip it with the number of segments we need to slice each tensor
dataset2 = Dataset.zip((dataset, Dataset.from_tensor_slices(
tf.constant(num_segments, dtype=tf.int64))))
it2 = dataset2.make_initializable_iterator()
def _dataset_generator():
with g.as_default():
while True:
try:
(im, length), count = sess.run(it2.get_next())
dataset3 = Dataset.zip((
# repeat each tensor then use map to take a stridded slice
Dataset.from_tensors((im, length)).repeat(count),
Dataset.range(count))).map(lambda x, c: (
x[0][:, c: c + segment_len],
x[0][:, c + 1: (c + 1) + segment_len],
))
it = dataset3.make_initializable_iterator()
it_init = it.initializer
try:
yield it_init
while True:
yield sess.run(it.get_next())
except tf.errors.OutOfRangeError:
continue
except tf.errors.OutOfRangeError:
return
# Dataset.from_generator need tensorflow > 1.3 !
das_dataset = Dataset.from_generator(
_dataset_generator,
(tf.float32, tf.float32),
# (tf.TensorShape([]), tf.TensorShape([]))
)
das_dataset_it = das_dataset.make_one_shot_iterator()
with tf.Session(graph=g) as sess:
while True:
print(sess.run(it2.initializer))
print(sess.run(das_dataset_it.get_next()))
当然我不想在生成器中传递会话,但这应该通过 link 中给出的技巧解决(创建一个虚拟数据集并映射另一个的迭代器)。上面的代码因圣经而失败:
tensorflow.python.framework.errors_impl.InvalidArgumentError: TypeError: If shallow structure is a sequence, input must also be a sequence. Input has type: <class 'tensorflow.python.framework.ops.Operation'>.
[[Node: PyFunc = PyFunc[Tin=[DT_INT64], Tout=[DT_FLOAT, DT_FLOAT], token="pyfunc_1"](arg0)]]
[[Node: IteratorGetNext = IteratorGetNext[output_shapes=[<unknown>, <unknown>], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](OneShotIterator)]]
我猜这是因为我尝试生成迭代器的初始值设定项,但我的问题基本上是我是否可以使用数据集 API 实现我正在尝试的所有目标。
从嵌套 Dataset
构建 Dataset
的最简单方法是使用 Dataset.flat_map()
转换。此转换将一个函数应用于输入数据集的每个元素(在您的示例中为 dataset2
),该函数 returns 嵌套的 Dataset
(在您的示例中很可能为 dataset3
),然后转换将所有嵌套数据集展平为单个 Dataset
.
dataset2 = ... # As above.
def get_slices(im_and_length, count):
im, length = im_and_length
# Repeat each tensor then use map to take a strided slice.
return Dataset.zip((
Dataset.from_tensors((im, length)).repeat(count),
Dataset.range(count))).map(lambda x, c: (
x[0][:, c + segment_len: (c + 1) + segment_len],
x[0][:, c + 1 + segment_len: (c + 2) + segment_len],
))
das_dataset = dataset2.flat_map(get_slices)
从
Dataset = tf.contrib.data.Dataset
segment_len = 6
batch_size = 16
with tf.Graph().as_default() as g:
# get the tfrecords dataset
dataset = tf.contrib.data.TFRecordDataset(filenames).map(
partial(record_type.parse_single_example, graph=g)).batch(batch_size)
# zip it with the number of segments we need to slice each tensor
dataset2 = Dataset.zip((dataset, Dataset.from_tensor_slices(
tf.constant(num_segments, dtype=tf.int64))))
it2 = dataset2.make_initializable_iterator()
def _dataset_generator():
with g.as_default():
while True:
try:
(im, length), count = sess.run(it2.get_next())
dataset3 = Dataset.zip((
# repeat each tensor then use map to take a stridded slice
Dataset.from_tensors((im, length)).repeat(count),
Dataset.range(count))).map(lambda x, c: (
x[0][:, c: c + segment_len],
x[0][:, c + 1: (c + 1) + segment_len],
))
it = dataset3.make_initializable_iterator()
it_init = it.initializer
try:
yield it_init
while True:
yield sess.run(it.get_next())
except tf.errors.OutOfRangeError:
continue
except tf.errors.OutOfRangeError:
return
# Dataset.from_generator need tensorflow > 1.3 !
das_dataset = Dataset.from_generator(
_dataset_generator,
(tf.float32, tf.float32),
# (tf.TensorShape([]), tf.TensorShape([]))
)
das_dataset_it = das_dataset.make_one_shot_iterator()
with tf.Session(graph=g) as sess:
while True:
print(sess.run(it2.initializer))
print(sess.run(das_dataset_it.get_next()))
当然我不想在生成器中传递会话,但这应该通过 link 中给出的技巧解决(创建一个虚拟数据集并映射另一个的迭代器)。上面的代码因圣经而失败:
tensorflow.python.framework.errors_impl.InvalidArgumentError: TypeError: If shallow structure is a sequence, input must also be a sequence. Input has type: <class 'tensorflow.python.framework.ops.Operation'>.
[[Node: PyFunc = PyFunc[Tin=[DT_INT64], Tout=[DT_FLOAT, DT_FLOAT], token="pyfunc_1"](arg0)]]
[[Node: IteratorGetNext = IteratorGetNext[output_shapes=[<unknown>, <unknown>], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](OneShotIterator)]]
我猜这是因为我尝试生成迭代器的初始值设定项,但我的问题基本上是我是否可以使用数据集 API 实现我正在尝试的所有目标。
从嵌套 Dataset
构建 Dataset
的最简单方法是使用 Dataset.flat_map()
转换。此转换将一个函数应用于输入数据集的每个元素(在您的示例中为 dataset2
),该函数 returns 嵌套的 Dataset
(在您的示例中很可能为 dataset3
),然后转换将所有嵌套数据集展平为单个 Dataset
.
dataset2 = ... # As above.
def get_slices(im_and_length, count):
im, length = im_and_length
# Repeat each tensor then use map to take a strided slice.
return Dataset.zip((
Dataset.from_tensors((im, length)).repeat(count),
Dataset.range(count))).map(lambda x, c: (
x[0][:, c + segment_len: (c + 1) + segment_len],
x[0][:, c + 1 + segment_len: (c + 2) + segment_len],
))
das_dataset = dataset2.flat_map(get_slices)