65.9K
CodeProject 正在变化。 阅读更多。
Home

Java 中高斯-牛顿算法的实现

starIconstarIconstarIconstarIconstarIcon

5.00/5 (8投票s)

2017年3月12日

CPOL

3分钟阅读

viewsIcon

21950

downloadIcon

468

用Java实现高斯-牛顿算法来解决非线性最小二乘问题;即,找到一个函数的最小值。

引言

高斯-牛顿算法是一种解决非线性函数的数学模型。下面给出一个简单的非线性函数

其中a1a2是该函数的未知参数。为了找到这两个参数,在不同的x值上测量y值;y可以是化学反应的速率,x是影响速率的化学物质的浓度。结果可能如下所示

x0.0380.1940.4250.6261.2532.5003.740
0.0500.1270.0940.21220.27290.26650.3317

高斯-牛顿算法用于使用上述观测数据来估计a1 a2的值。

这个示例函数和数据取自 https://en.wikipedia.org/wiki/Gauss-Newton_algorithm

背景

高斯-牛顿算法的几个步骤

  1. 创建一个函数中未知参数的数组

    b = (b1, b2,...,bn)

  2. 初始化参数,并针对矩阵x中的每个数据点,计算预测值(y')。
  3. 计算残差

    ri = y'i - yi

  4. 找到残差对参数的偏导数,并生成雅可比矩阵

  5. 遵循一个迭代过程,使用以下方程计算参数的新值

    s 是迭代次数,J 是雅可比矩阵,JT 是 J 的转置。

有关更多信息,请访问 https://en.wikipedia.org/wiki/Gauss-Newton_algorithm

Using the Code

在这个实现中,我将使用已经在另一篇文章中实现的一些矩阵运算。有关更多信息,请查看 这篇文章

  1. 给定矩阵 xy,优化从对 b 矩阵的初始猜测开始。默认情况下,为非线性函数中的所有参数赋予值 1.0
  2. 计算残差
    public double[][] calculateResiduals(double[][] x, double[] y, double[] b) {
         double[][] res = new double[y.length][1];
    
         for (int i = 0; i < res.length; i++) {
             res[i][0] = findY(x[i][0], b) - y[i];
         }
         return res;
    }

    对于 x 矩阵中的每个数据点,调用函数 findY() 来计算 y 的预测值。在代码中,您将发现此函数是 abstract(抽象的)

    public abstract double findY(double x, double[] b);

    用户在使用优化之前需要实现此函数。例如,对于引言中的函数,实现将如下所示

            GaussNewton gaussNewton = new GaussNewton() {
    
                @Override
                public double findY(double x, double[] b) {
                    // y = (x * a1) / (a2 + x)
                    return (x * b[0]) / (b[1] + x);
                }
            };
  3. 计算雅可比矩阵,它是残差对函数中参数的偏导数
        public double[][] jacob(double[] b, double[][] x, int numberOfObservations) {
            int numberOfVariables = b.length;
            double[][] jc = new double[numberOfObservations][numberOfVariables];
    
            for (int i = 0; i < numberOfObservations; i++) {
                for (int j = 0; j < numberOfVariables; j++) {
                    jc[i][j] = derivative(x[i][0], b, j);
                }
            }
            return jc;
        }

    调用函数 derivative() 来计算偏导数

        public double derivative(double x, double[] b, int bIndex) {
            double[] bCopy = b.clone();
            bCopy[bIndex] += alpha;
            double y1 = findY(x, bCopy);
            bCopy = b.clone();
            bCopy[bIndex] -= alpha;
            double y2 = findY(x, bCopy);
            return (y1 - y2) / (2 * alpha);
        }

    此函数给出了偏导数的良好近似值;即,在变量发生微小变化后 y 的变化。

  4. 执行 (J JT)-1 JT r 运算
        public double[][] transjacob(double[][] JArray, double[][] res) throws NoSquareException {
            Matrix r = new Matrix(res); // r
            Matrix J = new Matrix(JArray); // J
            Matrix JT = MatrixMathematics.transpose(J); // JT
            Matrix JTJ = MatrixMathematics.multiply(JT, J); // JT * J
            Matrix JTJ_1 = MatrixMathematics.inverse(JTJ); // (JT * J)^-1
            Matrix JTJ_1JT = MatrixMathematics.multiply(JTJ_1, JT); // (JT * J)^-1 * JT
            Matrix JTJ_1JTr = MatrixMathematics.multiply(JTJ_1JT, r); // (JT * J)^-1 * JT * r
            return JTJ_1JTr.getValues();
        }
  5. 使用步骤 4 的结果,计算参数的新值
    IntStream.range(0, values.length).forEach(j -> b2[j] = b2[j] - gamma * values[j][0]);

    b2 是新的 b 矩阵。gamma 是来自雅可比的数值的一部分。如果 b 的初始值与最优值相差甚远,则存在收敛问题。应用这个简单的分数似乎可以解决这个问题。应用这个分数的缺点是迭代次数会增加。

  6. 有了新的 b 矩阵,在下一次迭代中重复步骤 2-5。所有的优化步骤都在下面的函数中给出
    public double[] optimise(double[][] x, double[] y, double[] b) throws NoSquareException {
            int maxIteration = 1000;
            double oldError = 100;
            double precision = 1e-6;
            double[] b2 = b.clone();
            double gamma = .01;
            for (int i = 0; i < maxIteration; i++) {
                double[][] res = calculateResiduals(x, y, b2);
                double error = calculateError(res);
                System.out.println("Iteration : " + i + ", Error-diff: " + 
                        (Math.abs(oldError - error)) + ", b = "+ Arrays.toString(b2));
                if (Math.abs(oldError - error) <= precision) {
                    break;
                } 
                oldError = error;
                double[][] jacobs = jacob(b2, x, y.length);
                double[][] values = transjacob(jacobs, res);
                IntStream.range(0, values.length).forEach(j -> b2[j] = b2[j] - gamma * values[j][0]);
            }
            return b2;
    
        }

示例

函数

观察

x0.0380.1940.4250.6261.2532.5003.740
0.0500.1270.0940.21220.27290.26650.3317

目标

使用上述数据和高斯-牛顿算法找到 a1 和 a2 

    @Test
    public void optimiseWithInitialValueOf1() throws NoSquareException {
        double[][] x = new double[7][1];
        x[0][0] = 0.038;
        x[1][0] = 0.194;
        x[2][0] = 0.425;
        x[3][0] = 0.626;
        x[4][0] = 1.253;
        x[5][0] = 2.500;
        x[6][0] = 3.740;
        double[] y = new double[] { 0.050, 0.127, 0.094, 0.2122, 0.2729, 0.2665, 0.3317 };
        GaussNewton gaussNewton = new GaussNewton() {

            @Override
            public double findY(double x, double[] b) {
                // y = (x * a1) / (a2 + x)
                return (x * b[0]) / (b[1] + x);
            }
        };
        double[] b = gaussNewton.optimise(x, y, 2);
        Assert.assertArrayEquals(new double[]{0.36, 0.56}, b, 0.01);
    }

在上面的测试中,参数的初始值是默认值 (1.0);但是,使用起始值 100,我们将得到相同的值;请参阅随附的代码。我还对一组大型的随机生成数据测试了该算法,它在时间和内存效率方面表现良好;请参阅随附的测试代码。

历史

这是本文的第一个版本。

Java 中高斯-牛顿算法的实现 - CodeProject - 代码之家
© . All rights reserved.