TensorFlow:数据集应用方法的简单自定义 transformation_func 的示例实现
TensorFlow: Example implementation of a simple custom transformation_func for Dataset's apply method
我正在尝试为数据集 API 中的 apply 方法实现一个简单的自定义 transformation_func
,但没有发现文档特别有用。
具体来说,我的 dataset
包含视频帧和相应的标签:{[frame_0, label_0], [frame_1, label_1], [frame_2, label_2],...}
。
我想对其进行转换,以便它另外包含每个标签的前一帧:{[frame_0, frame_1, label_1], [frame_1, frame_2, label_2], [frame_2, frame_3, label_3],...}
。
这可能可以通过做类似 tf.data.Dataset.zip(dataset, dataset.skip(1))
的事情来实现,但那样我就会有重复的标签。
我没能找到 transformation_func
的参考实现。有人能让我开始做这件事吗?
apply
只是为了方便与现有的转换函数一起使用,ds.apply(func)
与 func(ds)
几乎相同,只是更 "chainable" 的方式。这是一种可能的方法来做你想做的事:
import tensorflow as tf
frames = tf.constant([ 1, 2, 3, 4, 5, 6], dtype=tf.int32)
labels = tf.constant(['a', 'b', 'c', 'd', 'e', 'f'], dtype=tf.string)
# Create dataset
ds = tf.data.Dataset.from_tensor_slices((frames, labels))
# Zip it with itself but skipping the first one
ds = tf.data.Dataset.zip((ds, ds.skip(1)))
# Make desired output structure
ds = ds.map(lambda fl1, fl2: (fl1[0], fl2[0], fl2[1]))
# Iterate
it = ds.make_one_shot_iterator()
elem = it.get_next()
# Test
with tf.Session() as sess:
while True:
try: print(sess.run(elem))
except tf.errors.OutOfRangeError: break
输出:
(1, 2, b'b')
(2, 3, b'c')
(3, 4, b'd')
(4, 5, b'e')
(5, 6, b'f')
我正在尝试为数据集 API 中的 apply 方法实现一个简单的自定义 transformation_func
,但没有发现文档特别有用。
具体来说,我的 dataset
包含视频帧和相应的标签:{[frame_0, label_0], [frame_1, label_1], [frame_2, label_2],...}
。
我想对其进行转换,以便它另外包含每个标签的前一帧:{[frame_0, frame_1, label_1], [frame_1, frame_2, label_2], [frame_2, frame_3, label_3],...}
。
这可能可以通过做类似 tf.data.Dataset.zip(dataset, dataset.skip(1))
的事情来实现,但那样我就会有重复的标签。
我没能找到 transformation_func
的参考实现。有人能让我开始做这件事吗?
apply
只是为了方便与现有的转换函数一起使用,ds.apply(func)
与 func(ds)
几乎相同,只是更 "chainable" 的方式。这是一种可能的方法来做你想做的事:
import tensorflow as tf
frames = tf.constant([ 1, 2, 3, 4, 5, 6], dtype=tf.int32)
labels = tf.constant(['a', 'b', 'c', 'd', 'e', 'f'], dtype=tf.string)
# Create dataset
ds = tf.data.Dataset.from_tensor_slices((frames, labels))
# Zip it with itself but skipping the first one
ds = tf.data.Dataset.zip((ds, ds.skip(1)))
# Make desired output structure
ds = ds.map(lambda fl1, fl2: (fl1[0], fl2[0], fl2[1]))
# Iterate
it = ds.make_one_shot_iterator()
elem = it.get_next()
# Test
with tf.Session() as sess:
while True:
try: print(sess.run(elem))
except tf.errors.OutOfRangeError: break
输出:
(1, 2, b'b')
(2, 3, b'c')
(3, 4, b'd')
(4, 5, b'e')
(5, 6, b'f')