如何在 Encog 中 pause/serialize 遗传算法?

How can I pause/serialize a genetic algorithm in Encog?

如何在 Encog 3.4(Github 中目前正在开发的版本)中暂停遗传算法?

我正在使用 Java 版本的 Encog。

我正在尝试修改 Encog 附带的 Lunar 示例。我想 pause/serialize 遗传算法,然后在稍后阶段 continue/deserialize。

当我调用 train.pause(); 时,它只是 returns null - 这从代码中很明显,因为该方法总是 returns null

我认为这会非常简单,因为在某些情况下我想训练一个神经网络,用它进行一些预测,然后在我获得更多数据之前继续使用遗传算法进行训练恢复更多预测 - 无需从头开始重新开始训练。

请注意,我不是要序列化或持久化神经网络,而是整个遗传算法。

并非 Encog 中的所有培训师都支持简单 pause/resume。如果他们不支持,他们 return null,就像这个。遗传算法训练器比支持pause/resume的简单传播训练器复杂得多。要保存遗传算法的状态,您必须保存整个种群,以及评分函数(可以序列化也可以不序列化)。我修改了月球着陆器示例,以向您展示如何 save/reload 您的神经网络群体可以做到这一点。

您可以看到它训练了 50 次迭代,然后往返 (load/saves) 遗传算法,然后再训练 50 次。

package org.encog.examples.neural.lunar;

import java.io.File;
import java.io.IOException;

import org.encog.Encog;
import org.encog.engine.network.activation.ActivationTANH;
import org.encog.ml.MLMethod;
import org.encog.ml.MLResettable;
import org.encog.ml.MethodFactory;
import org.encog.ml.ea.population.Population;
import org.encog.ml.genetic.MLMethodGeneticAlgorithm;
import org.encog.ml.genetic.MLMethodGenomeFactory;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.pattern.FeedForwardPattern;
import org.encog.util.obj.SerializeObject;

public class LunarLander {

    public static BasicNetwork createNetwork()
    {
        FeedForwardPattern pattern = new FeedForwardPattern();
        pattern.setInputNeurons(3);
        pattern.addHiddenLayer(50);
        pattern.setOutputNeurons(1);
        pattern.setActivationFunction(new ActivationTANH());
        BasicNetwork network = (BasicNetwork)pattern.generate();
        network.reset();
        return network;
    }

    public static void saveMLMethodGeneticAlgorithm(String file, MLMethodGeneticAlgorithm ga ) throws IOException
    {
        ga.getGenetic().getPopulation().setGenomeFactory(null);
        SerializeObject.save(new File(file),ga.getGenetic().getPopulation());   
    }

    public static MLMethodGeneticAlgorithm loadMLMethodGeneticAlgorithm(String filename) throws ClassNotFoundException, IOException {
        Population pop = (Population) SerializeObject.load(new File(filename));
        pop.setGenomeFactory(new MLMethodGenomeFactory(new MethodFactory(){
            @Override
            public MLMethod factor() {
                final BasicNetwork result = createNetwork();
                ((MLResettable)result).reset();
                return result;
            }},pop));

        MLMethodGeneticAlgorithm result = new MLMethodGeneticAlgorithm(new MethodFactory(){
            @Override
            public MLMethod factor() {
                return createNetwork();
            }},new PilotScore(),1);

        result.getGenetic().setPopulation(pop);

        return result;
    }


    public static void main(String args[])
    {
        BasicNetwork network = createNetwork();

        MLMethodGeneticAlgorithm train;


        train = new MLMethodGeneticAlgorithm(new MethodFactory(){
            @Override
            public MLMethod factor() {
                final BasicNetwork result = createNetwork();
                ((MLResettable)result).reset();
                return result;
            }},new PilotScore(),500);

        try {
            int epoch = 1;

            for(int i=0;i<50;i++) {
                train.iteration();
                System.out
                        .println("Epoch #" + epoch + " Score:" + train.getError());
                epoch++;
            } 
            train.finishTraining();

            // Round trip the GA and then train again
            LunarLander.saveMLMethodGeneticAlgorithm("/Users/jeff/projects/trainer.bin",train);
            train = LunarLander.loadMLMethodGeneticAlgorithm("/Users/jeff/projects/trainer.bin");

            // Train again
            for(int i=0;i<50;i++) {
                train.iteration();
                System.out
                        .println("Epoch #" + epoch + " Score:" + train.getError());
                epoch++;
            } 
            train.finishTraining();

        } catch(IOException ex) {
            ex.printStackTrace();
        } catch (ClassNotFoundException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }

        int epoch = 1;

        for(int i=0;i<50;i++) {
            train.iteration();
            System.out
                    .println("Epoch #" + epoch + " Score:" + train.getError());
            epoch++;
        } 
        train.finishTraining();

        System.out.println("\nHow the winning network landed:");
        network = (BasicNetwork)train.getMethod();
        NeuralPilot pilot = new NeuralPilot(network,true);
        System.out.println(pilot.scorePilot());
        Encog.getInstance().shutdown();
    }
}