如何修复 org.apache.commons.math3.exception.ConvergenceException

How do I fix an org.apache.commons.math3.exception.ConvergenceException?

我正在尝试使用 Commons Math 3.6.1 版本拟合一个四参数 Hill 方程。截至 2018 年 6 月 20 日,我还用 4.0-SNAPSHOT 版本试过,结果相同。我有一个简单的测试可以正常运行,不会抛出异常;然而,一组更复杂的数据却会失败。这些数据取自几个介绍 Hill / S 形(Sigmoidal)曲线拟合的网页。我不确定下一步该如何解决这个问题,有什么建议吗?

我得到这个:

org.apache.commons.math3.exception.ConvergenceException: illegal state: unable to perform Q.R decomposition on the 9x4 jacobian matrix

    at org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer.qrDecomposition(LevenbergMarquardtOptimizer.java:975)
    at org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer.optimize(LevenbergMarquardtOptimizer.java:342)
    at org.apache.commons.math3.fitting.AbstractCurveFitter.fit(AbstractCurveFitter.java:63)
    at com.adarza.curve.fit.FourParamHillFitterTest.largerDataTest(FourParamHillFitterTest.java:42)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at org.junit.runners.model.FrameworkMethod.runReflectiveCall(FrameworkMethod.java:50)
    at org.junit.internal.runners.model.ReflectiveCallable.run(ReflectiveCallable.java:12)
    at org.junit.runners.model.FrameworkMethod.invokeExplosively(FrameworkMethod.java:47)
    at org.junit.internal.runners.statements.InvokeMethod.evaluate(InvokeMethod.java:17)
    at org.junit.runners.ParentRunner.runLeaf(ParentRunner.java:325)
    at org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:78)
    at org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:57)
    at org.junit.runners.ParentRunner.run(ParentRunner.java:290)
    at org.junit.runners.ParentRunner.schedule(ParentRunner.java:71)
    at org.junit.runners.ParentRunner.runChildren(ParentRunner.java:288)
    at org.junit.runners.ParentRunner.access$000(ParentRunner.java:58)
    at org.junit.runners.ParentRunner.evaluate(ParentRunner.java:268)
    at org.junit.runners.ParentRunner.run(ParentRunner.java:363)
    at org.junit.runner.JUnitCore.run(JUnitCore.java:137)
    at com.intellij.junit4.JUnit4IdeaTestRunner.startRunnerWithArgs(JUnit4IdeaTestRunner.java:68)
    at com.intellij.rt.execution.junit.IdeaTestRunner$Repeater.startRunnerWithArgs(IdeaTestRunner.java:47)
    at com.intellij.rt.execution.junit.JUnitStarter.prepareStreamsAndStart(JUnitStarter.java:242)
    at com.intellij.rt.execution.junit.JUnitStarter.main(JUnitStarter.java:70)

我的代码如下。

初始参数:

import lombok.Data;

/**
 * Mutable holder for the starting parameter values of a four-parameter Hill fit
 * {@code y = d + (a - d) / (1 + (x / c)^b)}.
 *
 * <p>{@code initialHighVarD} and {@code initialLowVarA} default to sentinels intended for a
 * running max/min scan over observed y values: the high starts at the most negative finite
 * double and the low at the largest finite double, so any finite observation replaces them.
 */
@Data
public class FourParamHillEqInitParams {
    // Guess for d. Must start at -Double.MAX_VALUE, not Double.MIN_VALUE: MIN_VALUE is the
    // smallest *positive* double, so a max-scan over all-negative y values would never update it.
    double initialHighVarD = -Double.MAX_VALUE;
    // Guess for a; starts at the largest finite double so any observed y is smaller.
    double initialLowVarA = Double.MAX_VALUE;
    // Guess for c (an x-axis position: the model divides x by it).
    double midInflectionPointVarC = 0.0;
    // Guess for b (the exponent applied to x / c).
    double initialHillSlopeVarB = 0.0;
}

曲线拟合器:

import org.apache.commons.math3.fitting.AbstractCurveFitter;
import org.apache.commons.math3.fitting.WeightedObservedPoint;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem;
import org.apache.commons.math3.linear.DiagonalMatrix;

import java.util.*;

public class FourParamHillFitter  extends AbstractCurveFitter {
    protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> points) {
        final int len = points.size();
        final double[] target  = new double[len];
        final double[] weights = new double[len];

        FourParamHillEqInitParams initialGuesses = guessInitialCoefficents(points);

        final double[] initialGuess = { initialGuesses.initialLowVarA,
                                        initialGuesses.initialHillSlopeVarB,
                                        initialGuesses.midInflectionPointVarC,
                                        initialGuesses.initialHighVarD };

        System.out.println("Initial Guesses: " + Arrays.toString(initialGuess));

        int i = 0;
        for(WeightedObservedPoint point : points) {
            target[i]  = point.getY();
            weights[i] = point.getWeight();
            i += 1;
        }

        final AbstractCurveFitter.TheoreticalValuesFunction model = new
                AbstractCurveFitter.TheoreticalValuesFunction(new FourParamHillFunction(), points);

        return new LeastSquaresBuilder().
                maxEvaluations(Integer.MAX_VALUE).
                maxIterations(Integer.MAX_VALUE).
                start(initialGuess).
                target(target).
                weight(new DiagonalMatrix(weights)).
                model(model.getModelFunction(), model.getModelFunctionJacobian()).
                build();
    }

    private FourParamHillEqInitParams guessInitialCoefficents(Collection<WeightedObservedPoint> points) {
        FourParamHillEqInitParams initParams = new FourParamHillEqInitParams();
        double sum = 0.0;
        for (Iterator<WeightedObservedPoint> iterator = points.iterator(); iterator.hasNext(); ) {
            WeightedObservedPoint p =  iterator.next();
            if (p.getY() > initParams.initialHighVarD) {
                initParams.initialHighVarD = p.getY();
            }
            if (p.getY() < initParams.initialLowVarA){
                initParams.initialLowVarA = p.getY();
            }
            sum += p.getY();
        }
        initParams.midInflectionPointVarC = sum / points.size(); // mean
        initParams.initialHillSlopeVarB = 25.0;
        return initParams;
    }
}

函数:

import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;

public class FourParamHillFunction implements ParametricUnivariateFunction {
    public double value(double x, double... parm) {
//        return parameters[0] * Math.pow(t, parameters[1]) * Math.exp(-parameters[2] * t);
        double a = parm[0];
        double b = parm[1];
        double c = parm[2];
        double d = parm[3];

        return d+ ((a-d)/ (1 + Math.pow( (x/c), b)));
    }


    // Jacobian matrix of the above. In this case, this is just an array of
    // partial derivatives of the above function, with one element for each parameter.
    public double[] gradient(double t, double... parameters) {
        final double a = parameters[0];
        final double b = parameters[1];
        final double c = parameters[2];
        final double d = parameters[3];

        // Jacobian Matrix Edit

        // Using Derivative Structures...
        // constructor takes 4 arguments - the number of parameters in your
        // equation to be differentiated (4 in this case), the order of
        // differentiation for the DerivativeStructure, the index of the
        // parameter represented by the DS, and the value of the parameter itself
        DerivativeStructure aDev = new DerivativeStructure(4, 1, 0, a);
        DerivativeStructure bDev = new DerivativeStructure(4, 1, 1, b);
        DerivativeStructure cDev = new DerivativeStructure(4, 1, 2, c);
        DerivativeStructure dDev = new DerivativeStructure(4, 1, 3, d);

        // define the equation to be differentiated using another DerivativeStructure
//        DerivativeStructure y = aDev.multiply(DerivativeStructure.pow(t, bDev))
//                .multiply(cDev.negate().multiply(t).exp());

        //y = d+(a-d)/(1+(x/c)^b)
        DerivativeStructure numerator = aDev.subtract(dDev);
        DerivativeStructure xPart = cDev.reciprocal().multiply(t).pow(bDev);
        DerivativeStructure denominator = xPart.add(1.0);
        DerivativeStructure y = dDev.add( numerator.divide(denominator) );

        // then return the partial derivatives required
        // notice the format, 4 arguments for the method since 4 parameters were
        // specified first order derivative of the first parameter, then the second,
        // then the third
        return new double[] {
                y.getPartialDerivative(1, 0, 0, 0),
                y.getPartialDerivative(0, 1, 0, 0),
                y.getPartialDerivative(0, 0, 1, 0),
                y.getPartialDerivative(0, 0, 0, 1)
        };
    }
}

测试:

import org.apache.commons.math3.fitting.WeightedObservedPoint;
import org.junit.Test;
import java.util.ArrayList;
import java.util.Arrays;
import static org.junit.Assert.assertEquals;

/** Tests for {@link FourParamHillFitter}. */
public class FourParamHillFitterTest {

    @Test
    public void basicTest() {
        FourParamHillFitter fitter = new FourParamHillFitter();

        double[] xValues = { 1.0, 2.0, 3.0, 4.0, 5.0 };
        double[] yValues = { 1.0, 1.2, 3.0, 7.0, 7.0 };

        ArrayList<WeightedObservedPoint> points = createPointsFromArray(xValues, yValues);

        double[] coeffs = fitter.fit(points);

        System.out.println(Arrays.toString(coeffs));

        assertEquals(4, coeffs.length);
        assertEquals(1.099995, coeffs[0], 0.1);
        assertEquals(31.03071, coeffs[1], 0.01);
        assertEquals(3.072862, coeffs[2], 0.01);
        assertEquals(7.000825, coeffs[3], 0.01);
    }

    @Test
    public void largerDataTest() {
        FourParamHillFitter fitter = new FourParamHillFitter();

        // Includes x == 0, which exercises the gradient's t == 0 limit.
        double[] xValues = { 0, 1.3, 2.8, 5, 10.2, 16.5, 21.3, 31.8, 52.2 };
        double[] yValues = { 0.1, 0.5, 0.9, 2.6, 7.1, 12.3, 15.3, 20.4, 24.4 };

        ArrayList<WeightedObservedPoint> points = createPointsFromArray(xValues, yValues);

        final double[] coeffs = fitter.fit(points);

        System.out.println(Arrays.toString(coeffs));

        assertEquals(4, coeffs.length);
        assertEquals(0.1536, coeffs[0], 0.01);
        assertEquals(1.7718, coeffs[1], 0.01);
        assertEquals(19.3494, coeffs[2], 0.01);
        assertEquals(28.4479, coeffs[3], 0.01);
    }

    /**
     * Converts parallel x/y arrays into observations with unit weight.
     *
     * @param xs x values
     * @param ys y values, same length as {@code xs}
     * @return one point per index, in array order
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public ArrayList<WeightedObservedPoint> createPointsFromArray(double[] xs, double[] ys) {
        if (xs.length != ys.length) {
            throw new IllegalArgumentException(
                    "xs and ys must have the same length: " + xs.length + " != " + ys.length);
        }
        ArrayList<WeightedObservedPoint> points = new ArrayList<>(xs.length);
        for (int i = 0; i < xs.length; i++) {
            // The constructor is (weight, x, y). The previous weight of 0 zeroed every residual
            // in the weighted problem, so the optimizer had nothing to fit.
            points.add(new WeightedObservedPoint(1.0, xs[i], ys[i]));
        }
        return points;
    }

}

我不确定,但我推测原因如下。计算该函数的梯度:

d+((a-d)/(1 + Math.pow( (x/c), b)))

对 b 的偏导数包含一个 Log()(自然对数)项:

-(((a - d)*(x/c)^b*Log(x/c))/((x/c)^b + 1)^2)

Log(0)-Infinity

因此,避免使用等于 0 的 x 值(例如改用 0.0001 这样接近 0 的值)可能会有帮助。

在我自己的项目中,我在 FindFit 函数中实现了符号梯度,这也可能会改善结果。