模乘函数:在模数下将两个整数相乘
Modulo Multiplication Function: Multiplying two integers under a modulus
我在 miller-rabin 素数测试的代码中遇到了这个模乘函数。这应该可以消除计算 ( a * b ) % m
.
时发生的整数溢出
我需要一些帮助来理解这里发生的事情。为什么这行得通?该数字文字 0x8000000000000000ULL
的意义是什么?
unsigned long long mul_mod(unsigned long long a, unsigned long long b, unsigned long long m) {
unsigned long long d = 0, mp2 = m >> 1;
if (a >= m) a %= m;
if (b >= m) b %= m;
for (int i = 0; i < 64; i++)
{
d = (d > mp2) ? (d << 1) - m : d << 1;
if (a & 0x8000000000000000ULL)
d += b;
if (d >= m) d -= m;
a <<= 1;
}
return d;
}
此代码目前出现在 the modular arithmetic Wikipedia page 上,仅适用于最多 63 位的参数——见底部。
概述
计算普通乘法a * b
的一种方法是添加b
的左移副本——[中的每个1位一个=16=]。这类似于我们大多数人在学校做长乘法的方式,但被简化了:因为我们只需要将 b
的每个副本“乘以”1 或 0,我们需要做的就是添加移位后的副本b
的(当 a
的相应位为 1 时)或什么都不做(当它为 0 时)。
这段代码做了类似的事情。但是,为了避免溢出(主要是;见下文),它不是移动 b
的每个副本然后将其添加到总数中,而是将 b
的未移动副本添加到总数中,并依赖于稍后左移 总数 以将其移到正确的位置。您可以认为这些转变“作用于”到目前为止添加到总数中的所有被加数。例如,第一个循环迭代检查 a
的最高位,即位 63,是否为 1(a & 0x8000000000000000ULL
就是这样做的),如果是,则将 b
的未移位副本添加到总数;到循环完成时,前一行代码会将总数 d
左移 1 位 63 次。
这样做的主要优点是我们总是将我们已知小于 m
的两个数字(即 b
和 d
)相加,因此处理模环绕很便宜:我们知道 b + d < 2 * m
,所以为了确保我们的总和仍然小于 m
,检查是否 b + d < m
就足够了,如果不是,则减去 m
。如果我们改用先移位后相加的方法,我们将需要每位 %
模运算,这与除法一样昂贵——而且通常比减法昂贵得多。
模运算的一个特性是,每当我们想对某个数字 m
执行一系列算术运算时,都以通常的算术方式执行它们并取余数模 m
最后总是产生与对每个中间结果取余数模 m
相同的结果(前提是没有发生溢出)。
代码
在循环体的第一行之前,我们有不变量d < m
和b < m
。
行
d = (d > mp2) ? (d << 1) - m : d << 1;
是将总数 d
左移 1 位的谨慎方法,同时将其保持在 0 .. m
范围内并避免溢出。我们不是先移动它然后测试结果是否为 m
或更大,而是测试它当前是否严格高于 RoundDown(m/2)
-- 因为如果是这样,加倍后, 它肯定会严格高于 2 * RoundDown(m/2) >= m - 1
,因此需要减去 m
才能回到范围内。请注意,即使 (d << 1) - m
中的 (d << 1)
可能会溢出并丢失 d
的最高位,但这并没有什么坏处,因为它不会影响减法结果的最低 64 位,即我们唯一感兴趣的。(另请注意,如果 d == m/2
完全正确,我们之后会得到 d == m
,这稍微超出了范围——但是将测试从 d > mp2
更改为d >= mp2
解决这个问题会打破 m
是奇数和 d == RoundDown(m/2)
的情况,所以我们必须忍受这个。没关系,因为它会在下面修复。)
为什么不直接写 d <<= 1; if (d >= m) d -= m;
呢?假设,在无限精度运算中,d << 1 >= m
,所以我们应该执行减法——但是 d
的高位打开并且 d << 1
的其余部分小于 m
:这种情况下,初始移位会丢失高位,if
执行失败。
限制为 63 位或更少的输入
上述边缘情况只有在d
的高位打开时才会发生,只有在m
的高位也打开时才会发生(因为我们保持不变d < m
).因此,即使 m
的值非常高,代码看起来也能正常工作。不幸的是,事实证明它仍然可以在 其他地方 溢出,导致某些设置最高位的输入的答案不正确。例如,当 a = 3
、b = 0x7FFFFFFFFFFFFFFFULL
和 m = 0xFFFFFFFFFFFFFFFFULL
时,正确答案应该是 0x7FFFFFFFFFFFFFFEULL
,但是 the code will return 0x7FFFFFFFFFFFFFFDULL
(查看正确答案的简单方法是重新运行 a
和 b
的值交换)。具体来说,每当行 d += b
溢出并使截断的 d
小于 m
时,就会发生此行为,从而导致错误地跳过减法。
如果记录了此行为(如维基百科页面所示),这只是一个限制,而不是错误。
取消限制
如果我们替换行
if (a & 0x8000000000000000ULL)
d += b;
if (d >= m) d -= m;
和
unsigned long long x = -(a >> 63) & b;
if (d >= m - x) d -= m;
d += x;
the code will work for all inputs,包括设置了最高位的那些。神秘的第一行只是一种无条件(因此通常更快)的写作方式
unsigned long long x = (a & 0x8000000000000000ULL) ? b : 0;
测试 d >= m - x
在 d
上运行 在 它被修改之前 -- 它就像旧的 d >= m
测试,但是 b
(当a
的最高位打开时)或0(否则)已经从两边减去。这测试 d
是否 是 m
或更大,一旦 x
添加到它。我们知道 RHS m - x
永远不会下溢,因为最大的 x
可以是 b
并且我们已经在函数顶部建立了 b < m
。
我在 miller-rabin 素数测试的代码中遇到了这个模乘函数。这应该可以消除计算 ( a * b ) % m
.
我需要一些帮助来理解这里发生的事情。为什么这行得通?该数字文字 0x8000000000000000ULL
的意义是什么?
unsigned long long mul_mod(unsigned long long a, unsigned long long b, unsigned long long m) {
unsigned long long d = 0, mp2 = m >> 1;
if (a >= m) a %= m;
if (b >= m) b %= m;
for (int i = 0; i < 64; i++)
{
d = (d > mp2) ? (d << 1) - m : d << 1;
if (a & 0x8000000000000000ULL)
d += b;
if (d >= m) d -= m;
a <<= 1;
}
return d;
}
此代码目前出现在 the modular arithmetic Wikipedia page 上,仅适用于最多 63 位的参数——见底部。
概述
计算普通乘法a * b
的一种方法是添加b
的左移副本——[中的每个1位一个=16=]。这类似于我们大多数人在学校做长乘法的方式,但被简化了:因为我们只需要将 b
的每个副本“乘以”1 或 0,我们需要做的就是添加移位后的副本b
的(当 a
的相应位为 1 时)或什么都不做(当它为 0 时)。
这段代码做了类似的事情。但是,为了避免溢出(主要是;见下文),它不是移动 b
的每个副本然后将其添加到总数中,而是将 b
的未移动副本添加到总数中,并依赖于稍后左移 总数 以将其移到正确的位置。您可以认为这些转变“作用于”到目前为止添加到总数中的所有被加数。例如,第一个循环迭代检查 a
的最高位,即位 63,是否为 1(a & 0x8000000000000000ULL
就是这样做的),如果是,则将 b
的未移位副本添加到总数;到循环完成时,前一行代码会将总数 d
左移 1 位 63 次。
这样做的主要优点是我们总是将我们已知小于 m
的两个数字(即 b
和 d
)相加,因此处理模环绕很便宜:我们知道 b + d < 2 * m
,所以为了确保我们的总和仍然小于 m
,检查是否 b + d < m
就足够了,如果不是,则减去 m
。如果我们改用先移位后相加的方法,我们将需要每位 %
模运算,这与除法一样昂贵——而且通常比减法昂贵得多。
模运算的一个特性是,每当我们想对某个数字 m
执行一系列算术运算时,都以通常的算术方式执行它们并取余数模 m
最后总是产生与对每个中间结果取余数模 m
相同的结果(前提是没有发生溢出)。
代码
在循环体的第一行之前,我们有不变量d < m
和b < m
。
行
d = (d > mp2) ? (d << 1) - m : d << 1;
是将总数 d
左移 1 位的谨慎方法,同时将其保持在 0 .. m
范围内并避免溢出。我们不是先移动它然后测试结果是否为 m
或更大,而是测试它当前是否严格高于 RoundDown(m/2)
-- 因为如果是这样,加倍后, 它肯定会严格高于 2 * RoundDown(m/2) >= m - 1
,因此需要减去 m
才能回到范围内。请注意,即使 (d << 1) - m
中的 (d << 1)
可能会溢出并丢失 d
的最高位,但这并没有什么坏处,因为它不会影响减法结果的最低 64 位,即我们唯一感兴趣的。(另请注意,如果 d == m/2
完全正确,我们之后会得到 d == m
,这稍微超出了范围——但是将测试从 d > mp2
更改为d >= mp2
解决这个问题会打破 m
是奇数和 d == RoundDown(m/2)
的情况,所以我们必须忍受这个。没关系,因为它会在下面修复。)
为什么不直接写 d <<= 1; if (d >= m) d -= m;
呢?假设,在无限精度运算中,d << 1 >= m
,所以我们应该执行减法——但是 d
的高位打开并且 d << 1
的其余部分小于 m
:这种情况下,初始移位会丢失高位,if
执行失败。
限制为 63 位或更少的输入
上述边缘情况只有在d
的高位打开时才会发生,只有在m
的高位也打开时才会发生(因为我们保持不变d < m
).因此,即使 m
的值非常高,代码看起来也能正常工作。不幸的是,事实证明它仍然可以在 其他地方 溢出,导致某些设置最高位的输入的答案不正确。例如,当 a = 3
、b = 0x7FFFFFFFFFFFFFFFULL
和 m = 0xFFFFFFFFFFFFFFFFULL
时,正确答案应该是 0x7FFFFFFFFFFFFFFEULL
,但是 the code will return 0x7FFFFFFFFFFFFFFDULL
(查看正确答案的简单方法是重新运行 a
和 b
的值交换)。具体来说,每当行 d += b
溢出并使截断的 d
小于 m
时,就会发生此行为,从而导致错误地跳过减法。
如果记录了此行为(如维基百科页面所示),这只是一个限制,而不是错误。
取消限制
如果我们替换行
if (a & 0x8000000000000000ULL)
d += b;
if (d >= m) d -= m;
和
unsigned long long x = -(a >> 63) & b;
if (d >= m - x) d -= m;
d += x;
the code will work for all inputs,包括设置了最高位的那些。神秘的第一行只是一种无条件(因此通常更快)的写作方式
unsigned long long x = (a & 0x8000000000000000ULL) ? b : 0;
测试 d >= m - x
在 d
上运行 在 它被修改之前 -- 它就像旧的 d >= m
测试,但是 b
(当a
的最高位打开时)或0(否则)已经从两边减去。这测试 d
是否 是 m
或更大,一旦 x
添加到它。我们知道 RHS m - x
永远不会下溢,因为最大的 x
可以是 b
并且我们已经在函数顶部建立了 b < m
。