TensorFlow `AssertionError` on `fit()` method

When I pass my tf.data.Dataset to the fit() method of a tf.keras model, I get an AssertionError.

I am using tensorflow==2.0.0.

I checked that my dataset is valid:

for x, y in dataset:
    print(x.shape, y.shape)

This prints the correct shapes for the model's input data.
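As a rough sketch, the manual shape check can be turned into an explicit comparison. The helper name `shapes_compatible` is hypothetical, and the commented usage assumes a single-input Keras model whose `input_shape` is a tuple with `None` for the batch dimension.

```python
def shapes_compatible(expected, actual):
    """Return True if `actual` matches `expected`, treating None as a wildcard."""
    if len(expected) != len(actual):
        return False
    return all(e is None or e == a for e, a in zip(expected, actual))


# Hypothetical usage with a built Keras model and the dataset from the question
# (assumes a single-input model):
# for x, y in dataset.take(1):
#     assert shapes_compatible(model.input_shape, x.shape.as_list())

# Illustrative shapes only; they are not taken from the question.
print(shapes_compatible((None, 28, 28, 1), [32, 28, 28, 1]))  # True
print(shapes_compatible((None, 28, 28, 1), [32, 28, 28]))     # False
```

A passing check like this only confirms that the data is fine, which is consistent with the traceback below: the failure is in the distribution coordinator, not in the dataset.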

The full traceback is:

Traceback (most recent call last):
  File "/anaconda3/envs/ml36/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/anaconda3/envs/ml36/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/me/train.py", line 102, in <module>
    start_training(**arguments)
  File "/me/train.py", line 66, in start_training
    steps_per_epoch=TRAIN_STEPS_PER_EPOCH,
  File "/anaconda3/envs/ml36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py", line 728, in fit
    use_multiprocessing=use_multiprocessing)
  File "/anaconda3/envs/ml36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_distributed.py", line 789, in fit
    *args, **kwargs)
  File "/anaconda3/envs/ml36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_distributed.py", line 776, in wrapper
    mode=dc.CoordinatorMode.INDEPENDENT_WORKER)
  File "/anaconda3/envs/ml36/lib/python3.6/site-packages/tensorflow_core/python/distribute/distribute_coordinator.py", line 782, in run_distribute_coordinator
    rpc_layer)
  File "/anaconda3/envs/ml36/lib/python3.6/site-packages/tensorflow_core/python/distribute/distribute_coordinator.py", line 344, in _run_single_worker
    assert strategy
AssertionError

I hit the same error when running gcloud ai-platform local train with the final release of tensorflow 2.0.0. It worked with earlier versions, though. Try downgrading to 2.0.0b1:

pip install tensorflow==2.0.0b1

--

I also found that the error does not occur if you run the script directly with python or run it in the cloud.

If you are training locally without any distribution strategy, you can add the following lines to your code to work around the issue:

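  import os

  # Drop TF_CONFIG, if set, before calling fit() so Keras does not enter
  # the distributed code path shown in the traceback.
  if os.environ.get('TF_CONFIG'):
    os.environ.pop('TF_CONFIG')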
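A slightly more defensive variant of the same workaround, as a sketch only: it removes TF_CONFIG and returns what was there, so you can log what the local runner had set. The function name `drop_tf_config` and the optional `environ` parameter are hypothetical, added here so the helper can be tried on a plain dict.

```python
import json
import os


def drop_tf_config(environ=os.environ):
    """Remove TF_CONFIG from `environ` and return its parsed content, or None."""
    raw = environ.pop("TF_CONFIG", None)
    if not raw:
        return None
    try:
        return json.loads(raw)
    except ValueError:
        # Not valid JSON; return the raw string so it can still be logged.
        return raw


if __name__ == "__main__":
    removed = drop_tf_config()
    print("Removed TF_CONFIG:", removed)
```

Call it before building the model and calling fit(). Only do this when you really are training on a single machine, since a real multi-worker setup relies on that variable.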