tf.data.Dataset.map() 用于由多个切片组成的数据集

tf.data.Dataset.map() for datasets made from multiple slices

从单个切片创建的数据集的 tf.data.Dataset.map() 看起来像 dataset.map(lambda x: x/2)。如果数据集是从两个切片创建的,它会是什么样子?例如,请参见以下代码。代码最后一行中的 map() 函数适用于从单个切片创建的数据集,但会导致我的双切片情况出错。

import tensorflow as tf, numpy as np     # tensorflow 2.0
from tensorflow import keras as kr

dataset = tf.data.Dataset.from_tensor_slices((features_int8, labels_int8)) # features, labels are numpy arrays

model = kr.Sequential()
model.add(kr.layers.InputLayer(6)
model.add(kr.layers.Dense(     8, activation=tf.nn.tanh))
model.add(kr.layers.Dense(     3, activation=tf.nn.tanh))

model.compile(optimizer = kr.optimizers.RMSprop(), loss = kr.losses.MeanSquaredError())

model.fit(dataset.batch(64).map(lambda x: x/9), epochs = 10)

如图所示,在单独的函数中传递 lambda 函数

def map_fn(x, y):
  return x / 9, y

model.fit(dataset.batch(64).map(map_fn), epochs = 10)