Tensorflow gives "ValueError: Error when checking input"
Tensorflow gives "ValueError: Error when checking input"
我正在尝试使用深度 Q 网络代理解决 OpenAI gym Breakout-V0。
每当我的代理到达以下位置时:
- replay_memory已填满,可以开始训练了
- 第一次达到copy_target_network区间
- target_network第一次预测
Tensorflow 抛出以下错误:
Error when checking input: expected dense_3_input to have shape (33600,) but got array with shape (1,)
当我在调用 predict(state)
之前的 1 行打印传入 state
数组的形状时,它确认 state
的形状是 (33600,)
在显示此错误之前,模型能够 predict_on_batch()
在训练循环中使用完全相同的数据(但已批处理)
有人知道怎么解决吗?如果我遗漏任何内容,我很乐意提供更多详细信息和信息
版本:
Python3.8.7
TensorFlow 2.4.1
健身房 0.18.0
正如Dr.Snoopy所说,这是一个简单的解决方案
只需要做np.reshape(state, (1, 33600))
我正在尝试使用深度 Q 网络代理解决 OpenAI gym Breakout-V0。
每当我的代理到达以下位置时:
- replay_memory已填满,可以开始训练了
- 第一次达到copy_target_network区间
- target_network第一次预测
Tensorflow 抛出以下错误:
Error when checking input: expected dense_3_input to have shape (33600,) but got array with shape (1,)
当我在调用 predict(state)
之前的 1 行打印传入 state
数组的形状时,它确认 state
的形状是 (33600,)
在显示此错误之前,模型能够 predict_on_batch()
在训练循环中使用完全相同的数据(但已批处理)
有人知道怎么解决吗?如果我遗漏任何内容,我很乐意提供更多详细信息和信息
版本:
Python3.8.7
TensorFlow 2.4.1
健身房 0.18.0
正如Dr.Snoopy所说,这是一个简单的解决方案
只需要做np.reshape(state, (1, 33600))