'Error While Encoding with Hub.KerasLayer' while using TFF

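I get an error while training a federated model that uses hub.KerasLayer. The error and stack trace are shown below, and the complete code is in this gist: https://gist.github.com/aksingh2411/60796ee58c88e0c3f074c8909b17b5a1. Any help or suggestions would be much appreciated. Thanks!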

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_federated as tff
from tensorflow import keras

def create_keras_model():
    encoder = hub.load("https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim/1")
    return tf.keras.models.Sequential([
        hub.KerasLayer(encoder, input_shape=[], dtype=tf.string, trainable=True),
        keras.layers.Dense(32, activation='relu'),
        keras.layers.Dense(16, activation='relu'),
        keras.layers.Dense(1, activation='sigmoid'),
    ])

def model_fn():
    # We _must_ create a new model here, and _not_ capture it from an external
    # scope. TFF will call this within different graph contexts.
    keras_model = create_keras_model()
    return tff.learning.from_keras_model(
        keras_model,
        # preprocessed_example_dataset is defined earlier in the gist.
        input_spec=preprocessed_example_dataset.element_spec,
        loss=tf.keras.losses.BinaryCrossentropy(),
        # BinaryAccuracy thresholds the sigmoid output at 0.5; plain Accuracy
        # would compare raw probabilities with the labels for exact equality.
        metrics=[tf.keras.metrics.BinaryAccuracy()])

# Building the Federated Averaging Process
iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

print(iterative_process.initialize.type_signature)
state = iterative_process.initialize()

# federated_train_data (a list of client datasets) is also defined in the gist.
state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))

UnimplementedError                        Traceback (most recent call last)
<ipython-input-80-39d62fa827ea> in <module>()
----> 1 state, metrics = iterative_process.next(state, federated_train_data)
  2 print('round  1, metrics={}'.format(metrics))

119 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py in 
quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
 58     ctx.ensure_initialized()
 59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
 61   except core._NotOkStatusException as e:
 62     if name is not None:

UnimplementedError:    Cast string to float is not supported
 [[{{node StatefulPartitionedCall_1/StatefulPartitionedCall/Cast_1}}]]
 [[StatefulPartitionedCall_1]]
 [[import/StatefulPartitionedCall_3/ReduceDataset]] [Op:__inference_wrapped_function_65986]

Function call stack:
wrapped_function -> wrapped_function -> wrapped_function

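The issue is now resolved. The error was raised because the 'label' feature was being passed as tf.string instead of tf.int32, which is why training failed at a Cast op with "Cast string to float is not supported". Explicitly converting the label to an integer type fixed the problem.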
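A minimal sketch of such a conversion, assuming the labels are numeric strings such as "0"/"1" stored under a y key next to the text under x. The toy data and the preprocess helper below are hypothetical, not taken from the gist. Note that tf.cast itself cannot turn a string into a number (that is exactly the operation the error rejects), so this sketch parses the strings with tf.strings.to_number instead.

```python
import collections

import tensorflow as tf

# Hypothetical toy data: text features with labels stored as strings,
# reproducing the situation described above.
texts = ["great movie", "terrible plot", "loved it"]
string_labels = ["1", "0", "1"]

raw_dataset = tf.data.Dataset.from_tensor_slices(
    collections.OrderedDict(x=texts, y=string_labels))


def preprocess(element):
    # tf.cast(string_tensor, tf.int32) would fail with the same kind of
    # "Cast string to ... is not supported" error, so parse the string instead.
    label = tf.strings.to_number(element["y"], out_type=tf.int32)
    return collections.OrderedDict(
        x=element["x"],
        # Shape [1] per example so batched labels match the Dense(1) output.
        y=tf.reshape(label, [1]),
    )


preprocessed = raw_dataset.map(preprocess).batch(2)
print(preprocessed.element_spec)
```

The resulting element_spec describes y as a tf.int32 tensor of shape (None, 1), the kind of spec the question passes as input_spec to tff.learning.from_keras_model; BinaryCrossentropy can then cast these integer labels to float without error.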