仅使用乘法+移位(32 位)进行乘法+除法

Multiply+Divide using just Multiply+Shift (32 bit)

我想知道最快的比例计算方法,即y = x * a / b,其中所有值都是32位,无符号,并且ab是固定的(初始化一次,然后不改变)但在编译时不知道。结果保证不会溢出(即使中间乘法可能需要 64 位)。编程语言并不重要,但 Java 最适合我的情况。它需要尽可能快(纳秒很重要)。我目前使用:

int result = (int) ((long) x * a / b);

但是除法很慢。我知道 ,所以最好是以下类型的公式:

int result = (int) (((long) x * factor) >>> shift);

其中 factorshift 可以从 ab 计算出来(计算速度可能很慢)。

我试图简单地替换原始公式的除法部分,但它不起作用,因为两次乘法的结果不适合 64 位:

// init
int shift = 63 - Integer.numberOfLeadingZeros(b);
int factor = ((1L << shift) / b) + 1;
...
// actual calculation
int result = (int) ((long) x * a * factor) >>> shift);

在我的情况下,结果实际上不必完全准确(差一就可以)。

怎么样

long a2 = a & 0xFFFFFFFFL;
long b2 = b & 0xFFFFFFFFL;
checkArgument(b2 > 0);
double dFactor = (double) a2 / b2;
shift = 0;
while (dFactor < 1L<<32) {
   dFactor *= 2;
   shift++;
}
factor = (long) dFactor;

准备和

int result = (int) (((x & 0xFFFFFFFFL) * factor) >>> shift);

对于快的部分?现在我们有 2**32 <= factor < 2**33 并且对于任何 int x >= 0,产品 x * factor < 2**31 * 2**33 = 2**64 正好适合 unsigned long。没有比特被浪费。 dFactorlong 的转换向下舍入,这可能是次优的。


当然可以加快准备工作,尤其是可以通过先查看前导零来消除循环。我不会为消除 double 而烦恼,因为它使事情变得简单。

由于 ab 都是固定的,您可以只进行一次除法并重复使用结果(这可能已经在幕后自动发生):

int c = a / b;
int y1 = x1 * c;
int y2 = x2 * c;
...

如果您真的需要优化它,请在 GPU 上查看 运行ning 它(例如使用 java 绑定 CUDA),这将允许你可以并行化计算,尽管这更难实现。

最后,在测试时添加计时器总是一个好主意,这样您就可以 运行 进行基准测试以确保优化确实提高了性能。

我认为如果使用公式 (x * factor) >>> shift 不可能总是得到准确的结果:对于某些边缘情况,结果要么是 1 太低,要么是 1 太高。为了始终获得正确的结果,公式需要更加复杂。我找到了一个不需要浮点数的解决方案,这里是一个测试用例:

static final Set<Integer> SOME_VALUES = new TreeSet<Integer>();

static {
    Set<Integer> set = SOME_VALUES;
    for (int i = 0; i < 100; i++) {
        set.add(i);
    }
    set.add(Integer.MAX_VALUE);
    set.add(Integer.MAX_VALUE - 1);
    for (int i = 1; i > 0; i += i) {
        set.add(i - 1);
        set.add(i);
        set.add(i + 1);
    }
    for (int i = 1; i > 0; i *= 3) {
        set.add(i);
    }
    Random r = new Random(1);
    for (int i = 0; i < 100; i++) {
        set.add(r.nextInt(Integer.MAX_VALUE));
    }
}

private static void testMultiplyDelete() {
    for (int a : SOME_VALUES) {
        for (int b : SOME_VALUES) {
            if (b == 0) {
                continue;
            }
            int shift = 32;
            // sometimes 1 too low
            long factor = (1L << shift) * a / b;
            // sometimes 1 too high
            // long factor = ((1L << shift) * a / b) + 1;

            // sometimes 1 too low
            // double dFactor = (double) a / b;
            // int shift = 0;
            // while (dFactor > 0 && dFactor < (1L << 32)) {
            //     dFactor *= 2;
            //     shift++;
            // }
            // long factor = (long) dFactor;

            for (int x : SOME_VALUES) {
                long expectedResult = (long) x * a / b;
                if (expectedResult < 0 ||
                        expectedResult >= Integer.MAX_VALUE) {
                    continue;
                }
                int result = (int) ((x * factor) >>> shift);
                if (Math.abs(result - expectedResult) > 1) {
                    System.out.println(x + "*" + a + "/" + b +
                            "=" + expectedResult + "; " +
                            "(" + x + "*" + factor + ")>>>" + shift + "=" + result);
                }
            }
        }
    }
}