tf.data.experimental.group_by_window() 如何在 Tensorflow 2.0 中运行

How tf.data.experimental.group_by_window() operates in Tensorflow 2.0

我正在尝试理解 Tensorflow 2 中的 tf.data.experimental.group_by_window() 方法,但我遇到了一些困难。

对于可重现的示例,我使用文档中提供的示例:

components = np.arange(100).astype(np.int64)
dataset20 = tf.data.Dataset.from_tensor_slices(components)
dataset20 = dataset.apply(tf.data.experimental.group_by_window(key_func=lambda x: x%2, reduce_func=lambda _,\
                                                          els: els.batch(10), window_size=100))

i = 0

for elem in dataset20:

    print('i is {0}\n'.format(i))

    print('elem is {0}'.format(elem.numpy()))

    i += 1

    print('\n--------------------------------\n')

i is 0

elem is [0 2 4 6 8]

--------------------------------

i is 1

elem is [1 3 5 7 9]

--------------------------------

部分混淆可能是输出与示例代码不对应。实际输出:

components = np.arange(100).astype(np.int64)
dataset20 = tf.data.Dataset.from_tensor_slices(components)
dataset20 = dataset20.apply(tf.data.experimental.group_by_window(key_func=lambda x: x%2, reduce_func=lambda _,els: els.batch(10), window_size=100))
for i, d in enumerate(dataset20): 
    print(i, d.numpy())

0 [ 0  2  4  6  8 10 12 14 16 18]
1 [20 22 24 26 28 30 32 34 36 38]
2 [40 42 44 46 48 50 52 54 56 58]
3 [60 62 64 66 68 70 72 74 76 78]
4 [80 82 84 86 88 90 92 94 96 98]
5 [ 1  3  5  7  9 11 13 15 17 19]
6 [21 23 25 27 29 31 33 35 37 39]
7 [41 43 45 47 49 51 53 55 57 59]
8 [61 63 65 67 69 71 73 75 77 79]
9 [81 83 85 87 89 91 93 95 97 99]

如文档 here 中所述,key func 将数据分成具有关联键值的组。在示例中,key func 将数据 [0, 99] 分成偶数组和奇数组。 reduce_func 然后对键、组对进行操作以生成另一个数据集。请注意,尽管 reduce_func 仅对不大于 window_size 的数据组进行操作。在该示例中,window 大小大于两个组大小(100 对 50 个元素),因此没有影响,所有偶数以 10 为一组给出,然后是所有赔率。如果 window size 更改为小于 50 的值,那么它确实有效果。例如,如果 window 大小更改为 5,并且批处理也移至 group_by_window 函数之外:

dataset20 = dataset20.apply(tf.data.experimental.group_by_window(key_func=lambda x: x%2, reduce_func=lambda _, els: els, window_size=5)).batch(10)

然后产生以下输出:

0 [0 2 4 6 8 1 3 5 7 9]
1 [10 12 14 16 18 11 13 15 17 19]
2 [20 22 24 26 28 21 23 25 27 29]
3 [30 32 34 36 38 31 33 35 37 39]
4 [40 42 44 46 48 41 43 45 47 49]
5 [50 52 54 56 58 51 53 55 57 59]
6 [60 62 64 66 68 61 63 65 67 69]
7 [70 72 74 76 78 71 73 75 77 79]
8 [80 82 84 86 88 81 83 85 87 89]
9 [90 92 94 96 98 91 93 95 97 99]