RL4J A3C 深度学习从网络抛出输出不是概率分布

RL4J A3C DeepLearning Throwing a Output from network is not a probability distribution

所以现在我正在痛苦地探索使用深度学习 4j 的深度学习,特别是 RL4j 和强化学习。我在教我的电脑如何玩贪吃蛇方面相对失败,但我坚持了下来。

无论如何,我 运行 遇到了一个我无法解决的问题 我会在睡觉或工作时将我的程序设置为 运行(是的,我工作在一个重要的行业),当我回来查看时,它在所有 运行ning 线程上抛出了这个错误,并且程序已经完全停止,请注意,这通常发生在训练后大约一个小时。

Exception in thread "Thread-8" java.lang.RuntimeException: Output from network is not a probability distribution: [[         ?,         ?,         ?]]
at org.deeplearning4j.rl4j.policy.ACPolicy.nextAction(ACPolicy.java:82)
at org.deeplearning4j.rl4j.policy.ACPolicy.nextAction(ACPolicy.java:37)
at org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete.trainSubEpoch(AsyncThreadDiscrete.java:96)
at org.deeplearning4j.rl4j.learning.async.AsyncThread.handleTraining(AsyncThread.java:144)
at org.deeplearning4j.rl4j.learning.async.AsyncThread.run(AsyncThread.java:121)

这是我设置网络的方式

    private static A3CDiscrete.A3CConfiguration CARTPOLE_A3C =
        new A3CDiscrete.A3CConfiguration(
                (new java.util.Random()).nextInt(),            //Random seed
                220,            //Max step By epoch
                500000,         //Max step
                6,              //Number of threads
                50,              //t_max
                75,             //num step noop warmup
                0.1,           //reward scaling
                0.987,           //gamma
                1.0           //td-error clipping
        );


private static final ActorCriticFactorySeparateStdDense.Configuration CARTPOLE_NET_A3C =  ActorCriticFactorySeparateStdDense.Configuration
.builder().updater(new Adam(.005)).l2(.01).numHiddenNodes(32).numLayer(3).build();

此外,我网络的输入是将我的蛇游戏 16x16 的整个网格放入一个双数组中。

如果它与我的奖励功能有关,那就是

if(!snake.inGame()) {
        return -5.3; //snake dies 
    }
    if(snake.gotApple()) {
        return 5.0+.37*(snake.getLength()); //snake gets apple
    }
    return 0; //survives

我的问题是 如何阻止此错误的发生? 我真的不知道发生了什么,它让我的网络建设变得相当困难,是的,我已经在网上查看了答案所有出现的都是 2018 年的 2 GitHub 票。

如果您对它感兴趣,那么您不必深入挖掘这里是 ACPolicy 中抛出错误的函数

 public Integer nextAction(INDArray input) {
    INDArray output = actorCritic.outputAll(input)[1];
    if (rnd == null) {
        return Learning.getMaxAction(output);
    }
    float rVal = rnd.nextFloat();
    for (int i = 0; i < output.length(); i++) {
        //System.out.println(i + " " + rVal + " " + output.getFloat(i));
        if (rVal < output.getFloat(i)) {
            return i;
        } else
            rVal -= output.getFloat(i);
    }

    throw new RuntimeException("Output from network is not a probability distribution: " + output);
}

非常感谢您提供的任何帮助

您看到的是您的网络 运行 变成了 NaN。这就是异常中问号的意思。发生这种情况的原因有很多。你说,你 运行 它已经有一段时间了,所以你可能在某个时候变得不足或溢出。一些正则化或一些梯度裁剪可能有帮助。

但是,RL4J 本身正在从 beta6 开始重新设计,下一个版本应该会处于更好的状态。

如果您想尝试当前状态,可以使用快照,https://github.com/RobAltena/cartpole/blob/master/src/main/java/A3CCartpole.java

上还有一个可用的 A3C 示例

如需更全面的帮助,您可能应该查看位于 community.konduit.ai 的 DL4J 社区论坛。它更适合帮助您为贪吃蛇游戏构建成功的 AI 所需的来回操作。