RL4J A3C 深度学习从网络抛出输出不是概率分布
RL4J A3C DeepLearning Throwing a Output from network is not a probability distribution
所以现在我正在痛苦地探索使用深度学习 4j 的深度学习,特别是 RL4j 和强化学习。我在教我的电脑如何玩贪吃蛇方面相对失败,但我坚持了下来。
无论如何,我 运行 遇到了一个我无法解决的问题 我会在睡觉或工作时将我的程序设置为 运行(是的,我工作在一个重要的行业),当我回来查看时,它在所有 运行ning 线程上抛出了这个错误,并且程序已经完全停止,请注意,这通常发生在训练后大约一个小时。
Exception in thread "Thread-8" java.lang.RuntimeException: Output from network is not a probability distribution: [[ ?, ?, ?]]
at org.deeplearning4j.rl4j.policy.ACPolicy.nextAction(ACPolicy.java:82)
at org.deeplearning4j.rl4j.policy.ACPolicy.nextAction(ACPolicy.java:37)
at org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete.trainSubEpoch(AsyncThreadDiscrete.java:96)
at org.deeplearning4j.rl4j.learning.async.AsyncThread.handleTraining(AsyncThread.java:144)
at org.deeplearning4j.rl4j.learning.async.AsyncThread.run(AsyncThread.java:121)
这是我设置网络的方式
private static A3CDiscrete.A3CConfiguration CARTPOLE_A3C =
new A3CDiscrete.A3CConfiguration(
(new java.util.Random()).nextInt(), //Random seed
220, //Max step By epoch
500000, //Max step
6, //Number of threads
50, //t_max
75, //num step noop warmup
0.1, //reward scaling
0.987, //gamma
1.0 //td-error clipping
);
private static final ActorCriticFactorySeparateStdDense.Configuration CARTPOLE_NET_A3C = ActorCriticFactorySeparateStdDense.Configuration
.builder().updater(new Adam(.005)).l2(.01).numHiddenNodes(32).numLayer(3).build();
此外,我网络的输入是将我的蛇游戏 16x16 的整个网格放入一个双数组中。
如果它与我的奖励功能有关,那就是
if(!snake.inGame()) {
return -5.3; //snake dies
}
if(snake.gotApple()) {
return 5.0+.37*(snake.getLength()); //snake gets apple
}
return 0; //survives
我的问题是
如何阻止此错误的发生?
我真的不知道发生了什么,它让我的网络建设变得相当困难,是的,我已经在网上查看了答案所有出现的都是 2018 年的 2 GitHub 票。
如果您对它感兴趣,那么您不必深入挖掘这里是 ACPolicy 中抛出错误的函数
public Integer nextAction(INDArray input) {
INDArray output = actorCritic.outputAll(input)[1];
if (rnd == null) {
return Learning.getMaxAction(output);
}
float rVal = rnd.nextFloat();
for (int i = 0; i < output.length(); i++) {
//System.out.println(i + " " + rVal + " " + output.getFloat(i));
if (rVal < output.getFloat(i)) {
return i;
} else
rVal -= output.getFloat(i);
}
throw new RuntimeException("Output from network is not a probability distribution: " + output);
}
非常感谢您提供的任何帮助
您看到的是您的网络 运行 变成了 NaN。这就是异常中问号的意思。发生这种情况的原因有很多。你说,你 运行 它已经有一段时间了,所以你可能在某个时候变得不足或溢出。一些正则化或一些梯度裁剪可能有帮助。
但是,RL4J 本身正在从 beta6 开始重新设计,下一个版本应该会处于更好的状态。
如果您想尝试当前状态,可以使用快照,https://github.com/RobAltena/cartpole/blob/master/src/main/java/A3CCartpole.java
上还有一个可用的 A3C 示例
如需更全面的帮助,您可能应该查看位于 community.konduit.ai 的 DL4J 社区论坛。它更适合帮助您为贪吃蛇游戏构建成功的 AI 所需的来回操作。
所以现在我正在痛苦地探索使用深度学习 4j 的深度学习,特别是 RL4j 和强化学习。我在教我的电脑如何玩贪吃蛇方面相对失败,但我坚持了下来。
无论如何,我 运行 遇到了一个我无法解决的问题 我会在睡觉或工作时将我的程序设置为 运行(是的,我工作在一个重要的行业),当我回来查看时,它在所有 运行ning 线程上抛出了这个错误,并且程序已经完全停止,请注意,这通常发生在训练后大约一个小时。
Exception in thread "Thread-8" java.lang.RuntimeException: Output from network is not a probability distribution: [[ ?, ?, ?]]
at org.deeplearning4j.rl4j.policy.ACPolicy.nextAction(ACPolicy.java:82)
at org.deeplearning4j.rl4j.policy.ACPolicy.nextAction(ACPolicy.java:37)
at org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete.trainSubEpoch(AsyncThreadDiscrete.java:96)
at org.deeplearning4j.rl4j.learning.async.AsyncThread.handleTraining(AsyncThread.java:144)
at org.deeplearning4j.rl4j.learning.async.AsyncThread.run(AsyncThread.java:121)
这是我设置网络的方式
private static A3CDiscrete.A3CConfiguration CARTPOLE_A3C =
new A3CDiscrete.A3CConfiguration(
(new java.util.Random()).nextInt(), //Random seed
220, //Max step By epoch
500000, //Max step
6, //Number of threads
50, //t_max
75, //num step noop warmup
0.1, //reward scaling
0.987, //gamma
1.0 //td-error clipping
);
private static final ActorCriticFactorySeparateStdDense.Configuration CARTPOLE_NET_A3C = ActorCriticFactorySeparateStdDense.Configuration
.builder().updater(new Adam(.005)).l2(.01).numHiddenNodes(32).numLayer(3).build();
此外,我网络的输入是将我的蛇游戏 16x16 的整个网格放入一个双数组中。
如果它与我的奖励功能有关,那就是
if(!snake.inGame()) {
return -5.3; //snake dies
}
if(snake.gotApple()) {
return 5.0+.37*(snake.getLength()); //snake gets apple
}
return 0; //survives
我的问题是 如何阻止此错误的发生? 我真的不知道发生了什么,它让我的网络建设变得相当困难,是的,我已经在网上查看了答案所有出现的都是 2018 年的 2 GitHub 票。
如果您对它感兴趣,那么您不必深入挖掘这里是 ACPolicy 中抛出错误的函数
public Integer nextAction(INDArray input) {
INDArray output = actorCritic.outputAll(input)[1];
if (rnd == null) {
return Learning.getMaxAction(output);
}
float rVal = rnd.nextFloat();
for (int i = 0; i < output.length(); i++) {
//System.out.println(i + " " + rVal + " " + output.getFloat(i));
if (rVal < output.getFloat(i)) {
return i;
} else
rVal -= output.getFloat(i);
}
throw new RuntimeException("Output from network is not a probability distribution: " + output);
}
非常感谢您提供的任何帮助
您看到的是您的网络 运行 变成了 NaN。这就是异常中问号的意思。发生这种情况的原因有很多。你说,你 运行 它已经有一段时间了,所以你可能在某个时候变得不足或溢出。一些正则化或一些梯度裁剪可能有帮助。
但是,RL4J 本身正在从 beta6 开始重新设计,下一个版本应该会处于更好的状态。
如果您想尝试当前状态,可以使用快照,https://github.com/RobAltena/cartpole/blob/master/src/main/java/A3CCartpole.java
上还有一个可用的 A3C 示例如需更全面的帮助,您可能应该查看位于 community.konduit.ai 的 DL4J 社区论坛。它更适合帮助您为贪吃蛇游戏构建成功的 AI 所需的来回操作。