无法将 tensorflow.python.data.ops.dataset_ops.PrefetchDataset 类型的参数解释为迭代过程中的 TFF 值

Unable to interpret an argument of type tensorflow.python.data.ops.dataset_ops.PrefetchDataset as a TFF value in iterative process

我正在尝试在 TFF 中运行分类模拟,但出现此错误:

TypeError: Unable to interpret an argument of type tensorflow.python.data.ops.dataset_ops.PrefetchDataset as a TFF value.

这是我正在使用的代码

# Optimizer step sizes: clients train with Adam, the server applies SGD.
client_lr = 0.001
server_lr = 0.1

# Federated schedule: number of federated rounds, and how many times each
# client's dataset is repeated inside `preprocess`.
NUM_ROUNDS = 200
NUM_EPOCHS = 5

# Batch size applied to every client dataset in `preprocess`.
BATCH_SIZE = 2048

# Not referenced by the federated code shown in this script.
EPOCHS = 400

# Decision threshold shared by the F1 / precision / recall metrics.
TH = 0.5

def base_model():
    """Build the (uncompiled) Keras binary classifier.

    The network is three Dense(256, relu) layers, each followed by
    Dropout(0.5), and a final single-unit sigmoid layer. The input width is
    taken from the last dimension of the module-level ``x_train``.

    Returns:
        A new ``Sequential`` model instance.
    """
    feature_dim = x_train.shape[-1]

    # Layers are created in the same order as they are stacked.
    stack = [
        Dense(256, activation='relu', input_shape=(feature_dim,)),
        Dropout(0.5),
    ]
    for _ in range(2):
        stack.append(Dense(256, activation='relu'))
        stack.append(Dropout(0.5))
    stack.append(Dense(1, activation='sigmoid'))

    return Sequential(stack)

# Split the training arrays into contiguous shards of `samples_per_set`
# rows, one shard per simulated client ("client_1", "client_2", ...).
client_train_dataset = collections.OrderedDict()
for i in range(1, total_clients + 1):
    end = samples_per_set * i
    start = end - samples_per_set
    client_name = f"client_{i}"
    # Key order ('y' first, then 'x') is kept as in the original structure.
    data = collections.OrderedDict([
        ('y', y_train[start:end]),
        ('x', x_train[start:end]),
    ])
    client_train_dataset[client_name] = data

train_dataset = tff.simulation.FromTensorSlicesClientData(client_train_dataset)

# Pull one client's dataset (and its first element) to inspect the structure.
first_client_id = train_dataset.client_ids[0]
sample_dataset = train_dataset.create_tf_dataset_for_client(first_client_id)
sample_element = next(iter(sample_dataset))

# Buffer sizes consumed by `preprocess`.
PREFETCH_BUFFER = 10
SHUFFLE_BUFFER = samples_per_set

def preprocess(dataset):
    """Turn a per-client dataset into batches for federated training.

    The pipeline repeats the data ``NUM_EPOCHS`` times, shuffles with a
    buffer of ``SHUFFLE_BUFFER``, batches by ``BATCH_SIZE``, reformats each
    batch, and prefetches ``PREFETCH_BUFFER`` batches.

    Args:
        dataset: A dataset whose elements are mappings with 'x' and 'y' keys.

    Returns:
        The transformed dataset, yielding ``OrderedDict(x=..., y=...)`` batches
        in which 'y' has been reshaped to ``[-1, 1]``.
    """

    def to_model_batch(element):
        # Features pass through untouched; labels become a column.
        features = element['x']
        labels = tf.reshape(element['y'], [-1, 1])
        return collections.OrderedDict(x=features, y=labels)

    repeated = dataset.repeat(NUM_EPOCHS)
    shuffled = repeated.shuffle(SHUFFLE_BUFFER)
    batched = shuffled.batch(BATCH_SIZE)
    formatted = batched.map(to_model_batch)
    return formatted.prefetch(PREFETCH_BUFFER)

# Run the sample client's dataset through the training pipeline and convert
# its first batch to NumPy values.
preprocessed_sample_dataset = preprocess(sample_dataset)
first_batch = next(iter(preprocessed_sample_dataset))
sample_batch = tf.nest.map_structure(lambda tensor: tensor.numpy(), first_batch)

def make_federated_data(client_data, client_ids):
    """Build the list of preprocessed datasets, one per requested client.

    Args:
        client_data: Object exposing ``create_tf_dataset_for_client(id)``.
        client_ids: Iterable of client identifiers to include.

    Returns:
        A list with one ``preprocess``-ed dataset per entry of ``client_ids``,
        in the same order.
    """
    federated = []
    for client_id in client_ids:
        raw_dataset = client_data.create_tf_dataset_for_client(client_id)
        federated.append(preprocess(raw_dataset))
    return federated

# Every client takes part in training: one preprocessed dataset per client id.
federated_train_data = make_federated_data(
    client_data=train_dataset,
    client_ids=train_dataset.client_ids,
)

def model_tff():
    """Wrap a freshly built Keras model for use by TFF.

    A new model is created by ``base_model()`` on every call. The input spec
    is taken from ``preprocessed_sample_dataset``; the loss is binary
    cross-entropy and the metrics are F1, precision and recall, all at
    threshold ``TH``.

    Returns:
        The result of ``tff.learning.from_keras_model`` for the new model.
    """
    keras_model = base_model()
    batch_spec = preprocessed_sample_dataset.element_spec
    bce_loss = tf.keras.losses.BinaryCrossentropy()
    eval_metrics = [
        tfa.metrics.F1Score(num_classes=1, threshold=TH),
        keras.metrics.Precision(name="precision", thresholds=TH),
        keras.metrics.Recall(name="recall", thresholds=TH),
    ]
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=batch_spec,
        loss=bce_loss,
        metrics=eval_metrics,
    )

def _client_optimizer():
    """Return a new Adam optimizer using the client learning rate."""
    return optimizers.Adam(learning_rate=client_lr)


def _server_optimizer():
    """Return a new SGD optimizer using the server learning rate."""
    return optimizers.SGD(learning_rate=server_lr)


# Federated averaging over `model_tff`, with separate optimizers for the
# client and server updates.
iterative_process = tff.learning.build_federated_averaging_process(
    model_tff,
    client_optimizer_fn=_client_optimizer,
    server_optimizer_fn=_server_optimizer,
)

# Initial server state, fed into the first training round.
state = iterative_process.initialize()

# Build and compile the evaluation model ONCE. The architecture, optimizer,
# loss and metrics never change between rounds; only the weights do, and they
# are overwritten from the server state after every round. Rebuilding the
# model inside the loop repeated the same work NUM_ROUNDS times and left
# `federated_model` as None (crashing the final evaluation) if no round ran.
federated_model = base_model()
federated_model.compile(
    optimizer=optimizers.Adam(learning_rate=client_lr),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[
        tfa.metrics.F1Score(num_classes=1, threshold=TH),
        keras.metrics.Precision(name="precision", thresholds=TH),
        keras.metrics.Recall(name="recall", thresholds=TH),
    ])

for round_num in range(1, NUM_ROUNDS+1):
    # One round of federated averaging over all client datasets.
    state, tff_metrics = iterative_process.next(state, federated_train_data) # THE ERROR IS HERE
    # Copy the current global weights into the Keras model and score it on
    # the validation split.
    state.model.assign_weights_to(model=federated_model)
    federated_result = federated_model.evaluate(x_val, y_val, verbose=1, return_dict=True)

# Final evaluation of the last global model on the held-out test split.
federated_test = federated_model.evaluate(x_test, y_test, verbose=1, return_dict=True)

我正在使用这个信用卡数据集:https://www.kaggle.com/mlg-ulb/creditcardfraud

federated_train_data 是由 <PrefetchDataset shapes: OrderedDict([(x, (None, 29)), (y, (None, 1))]), types: OrderedDict([(x, tf.float64), (y, tf.int64)])> 组成的列表,与 TensorFlow Federated 网站上的教程 Federated Learning for Image Classification 中的做法一样。

这可能是 issue#918。这是否仅在 Google Colab 中运行时发生?使用的是哪个版本的 TFF?

据信 commit#4e57386 已修复此问题,该修复现已包含在 tensorflow-federated-nightly pip 包中。