在 Java 中优化 Collat​​z 猜想

Optimize Collatz Conjecture in Java

我正在开发一个程序,该程序使用 Collat​​z 猜想确定数字变为 1 所需的步数(如果 n 是奇数,则为 3n+1;如果 n 为偶数,则 n/2 ).该程序每完成一次计算,就会将正在计算的数字加一,并测试它可以在几秒钟内计算出多少个数字。这是我目前的工作程序:

public class Collatz {
    static long numSteps = 0;
    public static long calculate(long c){
        if(c == 1){
            return numSteps;
        }
        else if(c % 2 == 0){
            numSteps++;
            calculate(c / 2);
        }
        else if(c % 2 != 0){
            numSteps++;
            calculate(c * 3 + 1);
        }
        return numSteps;
    }
    public static void main(String args[]){
        int n = 1;
        long startTime = System.currentTimeMillis();
        while(System.currentTimeMillis() < startTime + 60000){

            calculate(n);
            n++;
            numSteps = 0;
        }
        System.out.println("The highest number was: " + n);
    }
}

它目前可以在一分钟内计算大约 1 亿个数字,但我正在寻找有关如何进一步优化程序以使其在一分钟内计算更多数字的建议。任何和所有建议将不胜感激:)。

你可以

  • 通过假设 c % 2 == 0 为假而不是 c % 2 != 0 必须为真来优化计算方法。您还可以假设 c * 3 + 1 必须是偶数,这样您就可以计算 (c * 3 + 1)/2 并将两个加到 numSteps 中。您可以使用循环而不是递归,因为 Java 没有尾调用优化。

  • 通过记忆获得更大的进步。对于每个数字,你可以记住你得到的结果,如果这个数字是在return那个值之前计算出来的。你可能想对记忆设置一个上限,例如不高于您要计算的最后一个数字。如果不这样做,某些值将是最大值的许多倍。

为了您的兴趣

public class Collatz {
    static final int[] CALC_CACHE = new int[2_000_000_000];

    static int calculate(long n) {
        int numSteps = 0;
        long c = n;
        while (c != 1) {
            if (c < CALC_CACHE.length) {
                int steps = CALC_CACHE[(int) c];
                if (steps > 0) {
                    numSteps += steps;
                    break;
                }
            }
            if (c % 2 == 0) {
                numSteps++;
                c /= 2;
            } else {
                numSteps += 2;
                if (c > Long.MAX_VALUE / 3)
                    throw new IllegalStateException("c is too large " + c);
                c = (c * 3 + 1) / 2;
            }
        }
        if (n < CALC_CACHE.length) {
            CALC_CACHE[(int) n] = numSteps;
        }
        return numSteps;
    }

    public static void main(String args[]) {
        long n = 1, maxN = 0, maxSteps = 0;
        long startTime = System.currentTimeMillis();
        while (System.currentTimeMillis() < startTime + 60000) {
            for (int i = 0; i < 10; i++) {
                int steps = calculate(n);
                if (steps > maxSteps) {
                    maxSteps = steps;
                    maxN = n;
                }
                n++;
            }
            if (n % 10000000 == 1)
                System.out.printf("%,d%n", n);
        }
        System.out.printf("The highest number was: %,d, maxSteps: %,d for: %,d%n", n, maxSteps, maxN);
    }
}

打印

The highest number was: 1,672,915,631, maxSteps: 1,000 for: 1,412,987,847

更高级的答案是使用多线程。在这种情况下,使用带记忆的递归更容易实现。

import java.util.stream.LongStream;

public class Collatz {
    static final short[] CALC_CACHE = new short[Integer.MAX_VALUE-8];

    public static int calculate(long c) {
        if (c == 1) {
            return 0;
        }
        int steps;
        if (c < CALC_CACHE.length) {
            steps = CALC_CACHE[(int) c];
            if (steps > 0)
                return steps;
        }
        if (c % 2 == 0) {
            steps = calculate(c / 2) + 1;
        } else {
            steps = calculate((c * 3 + 1) / 2) + 2;
        }
        if (c < CALC_CACHE.length) {
            if (steps > Short.MAX_VALUE)
                throw new AssertionError();
            CALC_CACHE[(int) c] = (short) steps;
        }
        return steps;
    }

    static int calculate2(long n) {
        int numSteps = 0;
        long c = n;
        while (c != 1) {
            if (c < CALC_CACHE.length) {
                int steps = CALC_CACHE[(int) c];
                if (steps > 0) {
                    numSteps += steps;
                    break;
                }
            }
            if (c % 2 == 0) {
                numSteps++;
                c /= 2;
            } else {
                numSteps += 2;
                if (c > Long.MAX_VALUE / 3)
                    throw new IllegalStateException("c is too large " + c);
                c = (c * 3 + 1) / 2;
            }
        }
        if (n < CALC_CACHE.length) {
            CALC_CACHE[(int) n] = (short) numSteps;
        }
        return numSteps;
    }

    public static void main(String args[]) {
        long maxN = 0, maxSteps = 0;
        long startTime = System.currentTimeMillis();
        long[] res = LongStream.range(1, 6_000_000_000L).parallel().collect(
                () -> new long[2],
                (long[] arr, long n) -> {
                    int steps = calculate(n);
                    if (steps > arr[0]) {
                        arr[0] = steps;
                        arr[1] = n;
                    }
                },
                (a, b) -> {
                    if (a[0] < b[0]) {
                        a[0] = b[0];
                        a[1] = b[1];
                    }
                });
        maxN = res[1];
        maxSteps = res[0];
        long time = System.currentTimeMillis() - startTime;
        System.out.printf("After %.3f seconds, maxSteps: %,d for: %,d%n", time / 1e3, maxSteps, maxN);
    }
}

打印

After 52.461 seconds, maxSteps: 1,131 for: 4,890,328,815

注意:如果我将第二个计算调用更改为

     steps = calculate((c * 3 + 1) ) + 1;

它打印

After 63.065 seconds, maxSteps: 1,131 for: 4,890,328,815