How to implement a modifiable activation function in a Neuron class in Java?

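I'm learning the concepts behind neural networks, and I decided to try writing a Neuron class myself. What is the best way to implement different activation functions in my code? Right now it only uses the binary step function. This is my first attempt at coding a neural network, so if you have any suggestions about my code, or if it is completely dumb, please let me know.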

Here is my code:

import java.util.ArrayList;

public class Neuron {

// properties
    private ArrayList<Neuron> input;
    private ArrayList<Float> weight;
    private float pot, bias, sense, out;
    private boolean checked;

// methods
    public float fire(){
        pot = 0f;
        if (input != null) {
            for (Neuron n : input){
                if (!n.getChecked()){
                    pot += n.fire()*weight.get(input.indexOf(n));
                } else {
                    pot += n.getOut()*weight.get(input.indexOf(n));
                } // end of condition (checked)
            } // end of loop (for input)
        } // end of condition (input exists)
        checked = true;
        pot -= bias;
        pot += sense;
        out = actFunc(pot);
        return out;
    } // end of fire()

    // getting properties
    public float getPot(){return pot;}
    public boolean getChecked(){return checked;}
    public float getOut(){return out;}

    // setting properties
    public void stimulate(float f){sense = f;}
    public void setBias(float b){bias = b;}
    public void setChecked(boolean c){checked = c;}
    public void setOut(float o){out = o;}

    // connection
    public void connect(Neuron n, float w){
        input.add(n);
        weight.add(w);
    }
    public void deconnect(Neuron n){
        weight.remove(input.indexOf(n));
        input.remove(n);
    }

    // activation function
    private float actFunc(float x){
        if (x < 0) {
            return 0f;
        } else {
            return 1f;
        }
    }

// constructor
    public Neuron(Neuron[] ns, float[] ws, float b, float o){
        if (ns != null){
            input = new ArrayList<Neuron>();
            weight = new ArrayList<Float>();
            for (Neuron n : ns) input.add(n);
            for (int i = 0; i < ws.length; i++) weight.add(ws[i]);
        } else {
            input = null;
            weight = null;
        }
        bias = b;
        out = o;
    }

    public Neuron(Neuron[] ns){
        if (ns != null){
            input = new ArrayList<Neuron>();
            weight = new ArrayList<Float>();
            for (Neuron n : ns) input.add(n);
            for (int i = 0; i < input.size(); i++) weight.add((float)Math.random()*2f-1f);
        } else {
            input = null;
            weight = null;
        }
        bias = (float)Math.random();
        out = (float)Math.random();
    }

}

First, define an interface for an arbitrary activation function:

public interface ActivationFunction {
    float get(float f);
}

Then write some implementations:

public class StepFunction implements ActivationFunction {
    @Override
    public float get(float x) {return (x < 0) ? 0f : 1f;}
}

public class SigmoidFunction implements ActivationFunction {
    // tanh serves as the sigmoid-shaped function here; its output lies in (-1, 1)
    @Override
    public float get(float x) {return (float) StrictMath.tanh(x);}
}

Finally, pass one of the implementations to your Neuron:

public class Neuron {
    private final ActivationFunction actFunc;
    // other fields...

    public Neuron(ActivationFunction actFunc) {
        this.actFunc = actFunc;
    }

    public float fire(){
        // ...
        out = actFunc.get(pot);
        return out;
    } 
}

Use it like this:

Neuron n = new Neuron(new SigmoidFunction());
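
As long as ActivationFunction declares exactly one abstract method, Java 8 and later also accept a lambda in place of a named class. The following is a minimal sketch under that assumption; the trimmed-down Neuron, the setPot method, the LambdaDemo class, and the ReLU-style lambda are illustrative additions, not part of the original code.

```java
// Single-method interface, so a lambda can implement it.
interface ActivationFunction {
    float get(float f);
}

// Trimmed-down stand-in for the Neuron class; only the parts needed here.
class Neuron {
    private final ActivationFunction actFunc;
    private float pot, out;

    Neuron(ActivationFunction actFunc) {
        this.actFunc = actFunc;
    }

    // Hypothetical setter so the example can set the potential directly.
    void setPot(float p) { pot = p; }

    float fire() {
        out = actFunc.get(pot);
        return out;
    }
}

class LambdaDemo {
    public static void main(String[] args) {
        // Binary step, written inline instead of as a StepFunction class.
        Neuron step = new Neuron(x -> (x < 0) ? 0f : 1f);
        // Rectifier (ReLU), added only to show a second lambda.
        Neuron relu = new Neuron(x -> Math.max(0f, x));

        step.setPot(-0.5f);
        relu.setPot(2.5f);
        System.out.println(step.fire()); // prints 0.0
        System.out.println(relu.fire()); // prints 2.5
    }
}
```

This shortcut stops working once the interface gains a second abstract method, such as a derivative; at that point a named class per activation function is needed again.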

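Note that a neural network works by propagating signals through its neurons, and the weights of the neurons are produced in that process. Computing the weights also depends on the first derivative of the activation function. I would therefore extend ActivationFunction with a method that returns the first derivative at a given point x:
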
public interface ActivationFunction {
    float get(float f);
    float firstDerivative(float x);
}

The implementations would then look like this:

public class StepFunction implements ActivationFunction {
    @Override
    public float get(float x) {return (x < 0) ? 0f : 1f;}

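    // Mathematically the step function's derivative is 0 for x != 0 and
    // undefined at x == 0; the constant 1 here is only a placeholder.
    @Override
    public float firstDerivative(float x) {return 1;}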
}

public class SigmoidFunction implements ActivationFunction {
    @Override
    public float get(float x) {return (float) StrictMath.tanh(x);}

    // derivative of tanh(x) = (4*e^(2x))/(e^(2x) + 1)^2 == 1 - tanh(x)^2
    @Override
    public float firstDerivative(float x) {
        float t = (float) StrictMath.tanh(x);
        return 1f - t * t;
    }
}
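
A quick way to gain confidence in a hand-written firstDerivative is to compare it with a numerical estimate. The sketch below uses a central difference for that; the names TanhFunction, DerivativeCheck, numericDerivative, and agrees, as well as the step size and tolerance, are assumptions chosen for illustration.

```java
interface ActivationFunction {
    float get(float x);
    float firstDerivative(float x);
}

class TanhFunction implements ActivationFunction {
    @Override
    public float get(float x) { return (float) StrictMath.tanh(x); }

    // Analytic derivative: 1 - tanh(x)^2
    @Override
    public float firstDerivative(float x) {
        float t = (float) StrictMath.tanh(x);
        return 1f - t * t;
    }
}

class DerivativeCheck {
    // Central difference (f(x+h) - f(x-h)) / (2h), evaluated in double precision.
    static double numericDerivative(ActivationFunction f, float x, float h) {
        return ((double) f.get(x + h) - (double) f.get(x - h)) / (2.0 * h);
    }

    // True if the analytic and numeric derivatives agree within a loose tolerance.
    static boolean agrees(ActivationFunction f, float x) {
        double numeric = numericDerivative(f, x, 1e-2f);
        return Math.abs(numeric - f.firstDerivative(x)) < 1e-3;
    }

    public static void main(String[] args) {
        ActivationFunction f = new TanhFunction();
        for (float x : new float[] {-2f, -0.5f, 0f, 0.5f, 2f}) {
            System.out.println("x=" + x
                    + " analytic=" + f.firstDerivative(x)
                    + " numeric=" + numericDerivative(f, x, 1e-2f)
                    + " agrees=" + agrees(f, x));
        }
    }
}
```

The tolerance is deliberately loose because the values are float and the difference quotient is itself only an approximation.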

Then use actFunc.firstDerivative(x) in the fire() method, where the weights are calculated.
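
To make the role of the derivative concrete, here is one possible sketch of a weight update for a single neuron, using the delta rule (gradient descent on the squared error). It is an assumption about how training might be done, not the method of the original code: TrainableNeuron, TanhFunction, potential, train, and the learning rate are all hypothetical, and the update is placed in a separate train method rather than inside fire(). The bias is subtracted from the potential, as in the question's code.

```java
interface ActivationFunction {
    float get(float x);
    float firstDerivative(float x);
}

class TanhFunction implements ActivationFunction {
    @Override
    public float get(float x) { return (float) StrictMath.tanh(x); }

    @Override
    public float firstDerivative(float x) {
        float t = (float) StrictMath.tanh(x);
        return 1f - t * t;
    }
}

// Hypothetical neuron that takes raw input values instead of other neurons.
class TrainableNeuron {
    final float[] weights;
    float bias;
    private final ActivationFunction actFunc;

    TrainableNeuron(float[] weights, float bias, ActivationFunction actFunc) {
        this.weights = weights.clone();
        this.bias = bias;
        this.actFunc = actFunc;
    }

    // Weighted sum of the inputs minus the bias.
    float potential(float[] in) {
        float pot = 0f;
        for (int i = 0; i < weights.length; i++) pot += weights[i] * in[i];
        return pot - bias;
    }

    float fire(float[] in) {
        return actFunc.get(potential(in));
    }

    // One gradient-descent step on the error 0.5 * (target - out)^2.
    void train(float[] in, float target, float learningRate) {
        float pot = potential(in);
        float out = actFunc.get(pot);
        // The first derivative scales the error signal.
        float delta = (target - out) * actFunc.firstDerivative(pot);
        for (int i = 0; i < weights.length; i++) {
            weights[i] += learningRate * delta * in[i];
        }
        // The bias enters the potential with a minus sign, so its update flips sign.
        bias -= learningRate * delta;
    }

    public static void main(String[] args) {
        TrainableNeuron n = new TrainableNeuron(new float[] {0.1f, -0.2f}, 0f, new TanhFunction());
        float[] in = {1f, 0.5f};
        float target = 0.8f;
        System.out.println("error before: " + Math.abs(target - n.fire(in)));
        for (int step = 0; step < 200; step++) n.train(in, target, 0.1f);
        System.out.println("error after:  " + Math.abs(target - n.fire(in)));
    }
}
```

Keeping the update in its own method leaves fire() free of side effects on the weights, which makes the forward pass easier to test in isolation.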