get 'TypeError: Caught exception' for using 'accuracy' in Tensorflow Federated

get 'TypeError: Caught exception' for using 'accuracy' in Tensorflow Federated

这是我的模型,我在TensorFlow中实现过一次。

def create_compiled_keras_model():

    inputs = Input(shape=(7, 20, 1))
    l0_c = Conv2D(32, kernel_size=(7, 7), padding='valid', activation='relu')(inputs)
    l1_c = Conv2D(32, kernel_size=(1, 5), padding='same', activation='relu')(l0_c)
    l1_p = AveragePooling2D(pool_size=(1, 2), strides=2, padding='same')(l1_c)
    l2_c = Conv2D(64, kernel_size=(1, 4), padding='same', activation='relu')(l1_p)
    l2_p = AveragePooling2D(pool_size=(1, 2), strides=2, padding='same')
    l3_c = Conv2D(2, kernel_size=(1, 1), padding='valid', activation='sigmoid')(l2_p)
    predictions = Flatten()(l3_c)
    predictions = tf.cast(predictions, dtype='float32')
    model = Model(inputs=inputs, outputs=predictions)
    opt = Adam(lr=0.0005)
    print(model.summary())
    def loss_fn(y_true, y_pred):
        return tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_pred, y_true))
    model.compile(optimizer=opt,
                  loss=loss_fn,
                  metrics=['accuracy'])
    return model

我在 TensorFlow Federated 中遇到此错误。

Traceback (most recent call last):
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 270, in report
    keras_metric = metric_type.from_config(metric_config)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py", line 594, in from_config
    return cls(**config)
TypeError: __init__() missing 1 required positional argument: 'fn'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/amir/Documents/CODE/Python/FL/fl_dataset_khudemon/fl.py", line 203, in <module>
    quantization_part = FedAvgQ.build_federated_averaging_process(model_fn)
  File "/Users/amir/Documents/CODE/Python/FL/fl_dataset_khudemon/new_fedavg_keras.py", line 195, in build_federated_averaging_process
    stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/framework/optimizer_utils.py", line 351, in build_model_delta_optimizer_process
    dummy_model_for_metadata = model_utils.enhance(model_fn())
  File "/Users/amir/Documents/CODE/Python/FL/fl_dataset_khudemon/fl.py", line 196, in model_fn
    return tff.learning.from_compiled_keras_model(keras_model, sample_batch)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 216, in from_compiled_keras_model
    return model_utils.enhance(_TrainableKerasModel(keras_model, dummy_tensors))
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 491, in __init__
    inner_model.loss_weights, inner_model.metrics)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 381, in __init__
    federated_output, federated_local_outputs_type)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/api/computations.py", line 223, in federated_computation
    return computation_wrapper_instances.federated_computation_wrapper(*args)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 410, in __call__
    self._wrapper_fn)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 103, in _wrap
    concrete_fn = wrapper_fn(fn, parameter_type, unpack=None)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper_instances.py", line 78, in _federated_computation_wrapper_fn
    suggested_name=name))
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/federated_computation_utils.py", line 76, in zero_or_one_arg_fn_to_building_block
    context_stack))
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/utils/function_utils.py", line 652, in <lambda>
    return lambda arg: _call(fn, parameter_type, arg)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/utils/function_utils.py", line 645, in _call
    return fn(arg)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 377, in federated_output
    type(metric), metric.get_config(), variables)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 260, in federated_aggregate_keras_metric
    @tff.tf_computation(member_type)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 415, in <lambda>
    return lambda fn: _wrap(fn, arg_type, self._wrapper_fn)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 103, in _wrap
    concrete_fn = wrapper_fn(fn, parameter_type, unpack=None)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper_instances.py", line 44, in _tf_wrapper_fn
    target_fn, parameter_type, ctx_stack)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/tensorflow_serialization.py", line 278, in serialize_py_fn_as_tf_computation
    result = target(*args)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/utils/function_utils.py", line 652, in <lambda>
    return lambda arg: _call(fn, parameter_type, arg)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/utils/function_utils.py", line 645, in _call
    return fn(arg)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 278, in report
    t=metric_type, c=metric_config, e=e))
TypeError: Caught exception trying to call `<class 'tensorflow.python.keras.metrics.MeanMetricWrapper'>.from_config()` with config {'name': 'accuracy', 'dtype': 'float32'}. Confirm that <class 'tensorflow.python.keras.metrics.MeanMetricWrapper'>.__init__() has an argument for each member of the config.
Exception: __init__() missing 1 required positional argument: 'fn'

我的数据集的标签是一种两个标签 [0. 1.] 并且我使用 binary_crossentropy 作为损失函数。但是准确性可以恢复错误。我确定它与多个标签有关。当我删除精度时,损失计算没有任何问题。任何帮助将不胜感激。

不幸的是,TensorFlow Federated 无法理解使用字符串参数编译的 Keras 模型。 TFF 要求模型上的 compile() 调用被赋予 tf.keras.losses.Losstf.keras.metrics.Metric 的实例。应该可以将有问题的代码的最后一部分更改为:

model.compile(optimizer=opt,
              loss=tf.keras.losses.BinaryCrossentropy(),
              metrics=[tf.keras.metrics.Accuracy()])

请注意,不需要定义自定义损失函数,Keras 提供了 canned binary crossentropy