tff:定义Tensorflow.take()函数的用法

Tff: define the usage of Tensorflow.take() function

我正在尝试模仿 here: Working with tff's clientData 提供的联合学习实现,以便清楚地理解代码。 我已经到了需要澄清的地步。

def preprocess_dataset(dataset):
  """Create batches of 5 examples, and limit to 3 batches."""

  def map_fn(input):
    return collections.OrderedDict(
        x=tf.reshape(input['pixels'], shape=(-1, 784)),
        y=tf.cast(tf.reshape(input['label'], shape=(-1, 1)), tf.int64),
    )

  return dataset.batch(5).map(
      map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)
  1. dataset.batch(5)指的是什么?这些批次是从数据中提取训练而 3 个批次用于测试吗?
  2. .take(5) 是什么意思?

这一行:

dataset.batch(5).map(
      map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)

您首先将 dataset 中的样本分成 5 个批次。之后,您将 map_fn 函数应用于 dataset 中的每个批次(一次 5 个样本) .最后,使用 dataset.take(5),您将从 dataset 返回 5 个批次,其中每个批次有 5 个样本。

在您链接的示例中,client_data 包含多个 tf 数据集。