`fit()` 方法上的 TensorFlow `AssertionError`
TensorFlow `AssertionError` on `fit()` method
将我的 tf.Dataset
传递到 tf.Keras
模型的 fit()
方法时,我得到一个 AssertionError
。
我正在使用 tensorflow==2.0.0
。
我检查了我的数据集是否有效:
# for x,y in dataset:
# print(x.shape, y.shape)
这会为模型输入数据生成正确的形状。
完整的跟踪是:
Traceback (most recent call last):
File "/anaconda3/envs/ml36/lib/python3.6/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/anaconda3/envs/ml36/lib/python3.6/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/me/train.py", line 102, in <module>
start_training(**arguments)
File "/me/train.py", line 66, in start_training
steps_per_epoch=TRAIN_STEPS_PER_EPOCH,
File "/anaconda3/envs/ml36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py", line 728, in fit
use_multiprocessing=use_multiprocessing)
File "/anaconda3/envs/ml36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_distributed.py", line 789, in fit
*args, **kwargs)
File "/anaconda3/envs/ml36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_distributed.py", line 776, in wrapper
mode=dc.CoordinatorMode.INDEPENDENT_WORKER)
File "/anaconda3/envs/ml36/lib/python3.6/site-packages/tensorflow_core/python/distribute/distribute_coordinator.py", line 782, in run_distribute_coordinator
rpc_layer)
File "/anaconda3/envs/ml36/lib/python3.6/site-packages/tensorflow_core/python/distribute/distribute_coordinator.py", line 344, in _run_single_worker
assert strategy
AssertionError
在 tensorflow 2.0.0 的最终版本 运行ning gcloud ai-platform local train
时,我遇到了同样的错误。但是,它正在处理早期版本。尝试降级到 2.0.0b1:
pip install tensorflow==2.0.0b1
--
还发现,如果您直接在 python 中 运行 或在云中 运行 就不会出现此错误。
如果您在本地训练而不使用任何分布式策略,您可以在代码中添加以下行来解决此问题:
TF_CONFIG = os.environ.get('TF_CONFIG')
if TF_CONFIG:
os.environ.pop('TF_CONFIG')
将我的 tf.Dataset
传递到 tf.Keras
模型的 fit()
方法时,我得到一个 AssertionError
。
我正在使用 tensorflow==2.0.0
。
我检查了我的数据集是否有效:
# for x,y in dataset:
# print(x.shape, y.shape)
这会为模型输入数据生成正确的形状。
完整的跟踪是:
Traceback (most recent call last):
File "/anaconda3/envs/ml36/lib/python3.6/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/anaconda3/envs/ml36/lib/python3.6/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/me/train.py", line 102, in <module>
start_training(**arguments)
File "/me/train.py", line 66, in start_training
steps_per_epoch=TRAIN_STEPS_PER_EPOCH,
File "/anaconda3/envs/ml36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py", line 728, in fit
use_multiprocessing=use_multiprocessing)
File "/anaconda3/envs/ml36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_distributed.py", line 789, in fit
*args, **kwargs)
File "/anaconda3/envs/ml36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_distributed.py", line 776, in wrapper
mode=dc.CoordinatorMode.INDEPENDENT_WORKER)
File "/anaconda3/envs/ml36/lib/python3.6/site-packages/tensorflow_core/python/distribute/distribute_coordinator.py", line 782, in run_distribute_coordinator
rpc_layer)
File "/anaconda3/envs/ml36/lib/python3.6/site-packages/tensorflow_core/python/distribute/distribute_coordinator.py", line 344, in _run_single_worker
assert strategy
AssertionError
在 tensorflow 2.0.0 的最终版本 运行ning gcloud ai-platform local train
时,我遇到了同样的错误。但是,它正在处理早期版本。尝试降级到 2.0.0b1:
pip install tensorflow==2.0.0b1
--
还发现,如果您直接在 python 中 运行 或在云中 运行 就不会出现此错误。
如果您在本地训练而不使用任何分布式策略,您可以在代码中添加以下行来解决此问题:
TF_CONFIG = os.environ.get('TF_CONFIG')
if TF_CONFIG:
os.environ.pop('TF_CONFIG')