dl4j lstm 不成功

dl4j lstm not successful

我正在尝试将练习复制到此 link 页面的一半左右: https://d2l.ai/chapter_recurrent-neural-networks/sequence.html

该练习使用正弦函数在 -1 到 1 之间创建 1000 个数据点,并使用循环网络来逼近该函数。

下面是我使用的代码。我要回去研究更多为什么这不起作用,因为当我能够轻松地使用前馈网络来近似这个函数时,它对我来说没有多大意义。

      //get data
    ArrayList<DataSet> list = new ArrayList();
   
    DataSet dss = DataSetFetch.getDataSet(Constants.DataTypes.math, "sine", 20, 500, 0, 0);

    DataSet dsMain = dss.copy();

    if (!dss.isEmpty()){
        list.add(dss);
    }

   
    if (list.isEmpty()){

        return;
    }

    //format dataset
   list = DataSetFormatter.formatReccurnent(list, 0);

    //get network
    int history = 10;
    ArrayList<LayerDescription> ldlist = new ArrayList<>();
    LayerDescription l = new LayerDescription(1,history, Activation.RELU);
    ldlist.add(l);     
    LayerDescription ll = new LayerDescription(history, 1, Activation.IDENTITY, LossFunctions.LossFunction.MSE);
    ldlist.add(ll);

    ListenerDescription ld = new ListenerDescription(20, true, false);

    MultiLayerNetwork network = Reccurent.getLstm(ldlist, 123, WeightInit.XAVIER, new RmsProp(), ld);


    //train network
    final List<DataSet> lister = list.get(0).asList();
    DataSetIterator iter = new ListDataSetIterator<>(lister, 50);
    network.fit(iter, 50);
    network.rnnClearPreviousState();


    //test network
    ArrayList<DataSet> resList = new ArrayList<>();
    DataSet result = new DataSet();
    INDArray arr = Nd4j.zeros(lister.size()+1);     
    INDArray holder;

    if (list.size() > 1){
        //test on training data
        System.err.println("oops");

    }else{
        //test on original or scaled data
        for (int i = 0; i < lister.size(); i++) {

            holder = network.rnnTimeStep(lister.get(i).getFeatures());
            arr.putScalar(i,holder.getFloat(0));

        }
    }


    //add originaldata
    resList.add(dsMain);
    //result       
    result.setFeatures(dsMain.getFeatures());
  
    result.setLabels(arr);
    resList.add(result);

    //display
    DisplayData.plot2DScatterGraph(resList);

你能解释一下我需要的代码吗?10 隐藏 1 输出 lstm 网络近似正弦函数?

我没有使用任何归一化,因为函数已经是 -1:1 并且我使用 Y 输入作为特征,随后的 Y 输入作为标签来训练网络。

你注意到我正在构建一个 class 可以更轻松地构建网络,我已经尝试对问题进行许多更改,但我厌倦了猜测。

以下是我的一些结果示例。蓝色是数据红色是结果

如果没有看到完整的代码,很难说出发生了什么。首先,我没有看到指定的 RnnOutputLayer。您可以看一下 this,它向您展示了如何在 DL4J 中构建 RNN。 如果您的 RNN 设置正确,这可能是一个调整问题。您可以找到更多关于调整 here 的信息。 Adam 可能是比 RMSProp 更好的更新程序选择。 tanh 可能是激活输出层的不错选择,因为它的范围是 (-1,1)。 check/tweak 的其他事情 - 学习率、时期数、数据设置(比如你是否试图预测太远?)。

这是你从想知道为什么这不起作用到我的原始结果到底是怎么这么好的时候之一。

我的失败是没有清楚地理解文档,也没有理解 BPTT。

对于前馈网络,每次迭代存储为一行,每个输入存储为一列。一个例子是[dataset.size,网络inputs.size]

然而,对于循环输入,它是相反的,每一行都是一个输入,每一列都是激活 lstm 事件链状态所必需的时间迭代。至少我的输入需要是 [0, networkinputs.size, dataset.size] 但也可以是 [dataset.size, networkinputs.size, statelength.size]

在我之前的示例中,我使用 [dataset.size, networkinputs.size, 1] 格式的数据训练网络。因此,根据我对低分辨率的理解,lstm 网络根本不应该工作,但至少会以某种方式产生一些东西。

将数据集转换为列表可能也存在一些问题,因为我还更改了我向网络提供数据的方式,但我认为大部分问题是数据结构问题。

以下是我的新结果