当 a 和 b 都小于 c 但 a * b 溢出时,如何计算 a * b / c?
How can I compute a * b / c when both a and b are smaller than c, but a * b overflows?
假设uint
是我定点平台上最大的整数类型,我有:
uint func(uint a, uint b, uint c);
这需要 return a * b / c
的良好近似值。
c
的值大于 a
的值和 b
的值。
所以我们肯定知道 a * b / c
的值适合 uint
.
然而,a * b
的值本身溢出了uint
的大小。
因此计算 a * b / c
值的一种方法是:
return a / c * b;
甚至:
if (a > b)
return a / c * b;
return b / c * a;
但是,c
的值大于 a
的值和 b
的值。
所以上面的建议只是 return 零。
我需要按比例减少 a * b
和 c
,但同样 - 问题是 a * b
溢出。
理想情况下,我将能够:
- 将
a * b
替换为uint(-1)
- 将
c
替换为 uint(-1) / a / b * c
。
但是无论我如何排序表达式uint(-1) / a / b * c
,我都遇到了一个问题:
uint(-1) / a / b * c
由于 uint(-1) / a / b
被截断为零
uint(-1) / a * c / b
由于 uint(-1) / a * c
而溢出
uint(-1) * c / a / b
由于 uint(-1) * c
溢出
如何解决这种情况才能找到 a * b / c
的良好近似值?
编辑 1
我的平台上没有_umul128
这样的东西,最大的整数类型是uint64
。我最大的类型是 uint
,我不支持任何比它更大的类型(无论是在硬件层面,还是在一些预先存在的标准库中)。
我最大的类型是uint
。
编辑 2
回应大量重复的建议和评论:
我手头没有可以用来解决这个问题的“更大的类型”。这就是为什么问题的开头陈述是:
Assuming that uint
is the largest integral type on my fixed-point platform
我假设不存在其他类型,无论是在 SW 层(通过一些内置标准库)还是在 HW 层。
您可以按如下方式取消质因数:
uint gcd(uint a, uint b)
{
uint c;
while (b)
{
a %= b;
c = a;
a = b;
b = c;
}
return a;
}
uint func(uint a, uint b, uint c)
{
uint temp = gcd(a, c);
a = a/temp;
c = c/temp;
temp = gcd(b, c);
b = b/temp;
c = c/temp;
// Since you are sure the result will fit in the variable, you can simply
// return the expression you wanted after having those terms canceled.
return a * b / c;
}
needs to return a good approximation of a * b / c
My largest type is uint
both a and b are smaller than c
的变化:
Algorithm: Scale a, b to not overflow
SQRT_MAX_P1 as a compile time constant of sqrt(uint_MAX + 1)
sh = 0;
if (c >= SQRT_MAX_P1) {
while (|a| >= SQRT_MAX_P1) a/=2, sh++
while (|b| >= SQRT_MAX_P1) b/=2, sh++
while (|c| >= SQRT_MAX_P1) c/=2, sh--
}
result = a*b/c
shift result by sh.
对于 n-bit uint
,我希望结果至少正确到 n/2
位有效数字。
可以利用 a,b
中较小的小于 SQRT_MAX_P1
的优势来改进事情。如果感兴趣,稍后会详细介绍。
例子
#include <inttypes.h>
#define IMAX_BITS(m) ((m)/((m)%255+1) / 255%255*8 + 7-86/((m)%255+12))
//
#define UINTMAX_WIDTH (IMAX_BITS(UINTMAX_MAX))
#define SQRT_UINTMAX_P1 (((uintmax_t)1ull) << (UINTMAX_WIDTH/2))
uintmax_t muldiv_about(uintmax_t a, uintmax_t b, uintmax_t c) {
int shift = 0;
if (c > SQRT_UINTMAX_P1) {
while (a >= SQRT_UINTMAX_P1) {
a /= 2; shift++;
}
while (b >= SQRT_UINTMAX_P1) {
b /= 2; shift++;
}
while (c >= SQRT_UINTMAX_P1) {
c /= 2; shift--;
}
}
uintmax_t r = a * b / c;
if (shift > 0) r <<= shift;
if (shift < 0) r >>= shift;
return r;
}
#include <stdio.h>
int main() {
uintmax_t a = 12345678;
uintmax_t b = 4235266395;
uintmax_t c = 4235266396;
uintmax_t r = muldiv_about(a,b,c);
printf("%ju\n", r);
}
输出 32 位数学(精确答案是 12345677)
12345600
使用 64 位数学输出
12345677
这是另一种使用递归和最小近似来实现高精度的方法。
首先是代码,下面是解释。
代码:
uint32_t bp(uint32_t a) {
uint32_t b = 0;
while (a!=0)
{
++b;
a >>= 1;
};
return b;
}
int mul_no_ovf(uint32_t a, uint32_t b)
{
return ((bp(a) + bp(b)) <= 32);
}
uint32_t f(uint32_t a, uint32_t b, uint32_t c)
{
if (mul_no_ovf(a, b))
{
return (a*b) / c;
}
uint32_t m = c / b;
++m;
uint32_t x = m*b - c;
// So m * b == c + x where x < b and m >= 2
uint32_t n = a/m;
uint32_t r = a % m;
// So a*b == n * (c + x) + r*b == n*c + n*x + r*b where r*b < c
// Approximation: get rid of the r*b part
uint32_t res = n;
if (r*b > c/2) ++res;
return res + f(n, x, c);
}
解释:
The multiplication a * b can be written as a sum of b
a * b = b + b + .... + b
Since b < c we can take a number m of these b so that (m-1)*b < c <= m*b, like
(b + b + ... + b) + (b + b + ... + b) + .... + b + b + b
\---------------/ \---------------/ + \-------/
m*b + m*b + .... + r*b
\-------------------------------------/
n times m*b
so we have
a*b = n*m*b + r*b
where r*b < c and m*b > c. Consequently, m*b is equal to c + x, so we have
a*b = n*(c + x) + r*b = n*c + n*x + r*b
Divide by c :
a*b/c = (n*c + n*x + r*b)/c = n + n*x/c + r*b/c
The values m, n, x, r can all be calculated from a, b and c without any loss of
precision using integer division (/) and remainder (%).
The approximation is to look at r*b (which is less than c) and "add zero" when r*b<=c/2
and "add one" when r*b>c/2.
So now there are two possibilities:
1) a*b = n + n*x/c
2) a*b = (n + 1) + n*x/c
So the problem (i.e. calculating a*b/c) has been changed to the form
MULDIV(a1,b1,c) = NUMBER + MULDIV(a2,b2,c)
where a2,b2 is less than a1,b2. Consequently, recursion can be used until
a2*b2 no longer overflows (and the calculation can be done directly).
我已经建立了一个在 O(1)
复杂度(无循环)下工作的解决方案:
typedef unsigned long long uint;
typedef struct
{
uint n;
uint d;
}
fraction;
uint func(uint a, uint b, uint c);
fraction reducedRatio(uint n, uint d, uint max);
fraction normalizedRatio(uint a, uint b, uint scale);
fraction accurateRatio(uint a, uint b, uint scale);
fraction toFraction(uint n, uint d);
uint roundDiv(uint n, uint d);
uint func(uint a, uint b, uint c)
{
uint hi = a > b ? a : b;
uint lo = a < b ? a : b;
fraction f = reducedRatio(hi, c, (uint)(-1) / lo);
return f.n * lo / f.d;
}
fraction reducedRatio(uint n, uint d, uint max)
{
fraction f = toFraction(n, d);
if (n > max || d > max)
f = normalizedRatio(n, d, max);
if (f.n != f.d)
return f;
return toFraction(1, 1);
}
fraction normalizedRatio(uint a, uint b, uint scale)
{
if (a <= b)
return accurateRatio(a, b, scale);
fraction f = accurateRatio(b, a, scale);
return toFraction(f.d, f.n);
}
fraction accurateRatio(uint a, uint b, uint scale)
{
uint maxVal = (uint)(-1) / scale;
if (a > maxVal)
{
uint c = a / (maxVal + 1) + 1;
a /= c; // we can now safely compute `a * scale`
b /= c;
}
if (a != b)
{
uint n = a * scale;
uint d = a + b; // can overflow
if (d >= a) // no overflow in `a + b`
{
uint x = roundDiv(n, d); // we can now safely compute `scale - x`
uint y = scale - x;
return toFraction(x, y);
}
if (n < b - (b - a) / 2)
{
return toFraction(0, scale); // `a * scale < (a + b) / 2 < MAXUINT256 < a + b`
}
return toFraction(1, scale - 1); // `(a + b) / 2 < a * scale < MAXUINT256 < a + b`
}
return toFraction(scale / 2, scale / 2); // allow reduction to `(1, 1)` in the calling function
}
fraction toFraction(uint n, uint d)
{
fraction f = {n, d};
return f;
}
uint roundDiv(uint n, uint d)
{
return n / d + n % d / (d - d / 2);
}
这是我的测试:
#include <stdio.h>
int main()
{
uint a = (uint)(-1) / 3; // 0x5555555555555555
uint b = (uint)(-1) / 2; // 0x7fffffffffffffff
uint c = (uint)(-1) / 1; // 0xffffffffffffffff
printf("0x%llx", func(a, b, c)); // 0x2aaaaaaaaaaaaaaa
return 0;
}
假设uint
是我定点平台上最大的整数类型,我有:
uint func(uint a, uint b, uint c);
这需要 return a * b / c
的良好近似值。
c
的值大于 a
的值和 b
的值。
所以我们肯定知道 a * b / c
的值适合 uint
.
然而,a * b
的值本身溢出了uint
的大小。
因此计算 a * b / c
值的一种方法是:
return a / c * b;
甚至:
if (a > b)
return a / c * b;
return b / c * a;
但是,c
的值大于 a
的值和 b
的值。
所以上面的建议只是 return 零。
我需要按比例减少 a * b
和 c
,但同样 - 问题是 a * b
溢出。
理想情况下,我将能够:
- 将
a * b
替换为uint(-1)
- 将
c
替换为uint(-1) / a / b * c
。
但是无论我如何排序表达式uint(-1) / a / b * c
,我都遇到了一个问题:
uint(-1) / a / b * c
由于uint(-1) / a / b
被截断为零
uint(-1) / a * c / b
由于uint(-1) / a * c
而溢出
uint(-1) * c / a / b
由于uint(-1) * c
溢出
如何解决这种情况才能找到 a * b / c
的良好近似值?
编辑 1
我的平台上没有_umul128
这样的东西,最大的整数类型是uint64
。我最大的类型是 uint
,我不支持任何比它更大的类型(无论是在硬件层面,还是在一些预先存在的标准库中)。
我最大的类型是uint
。
编辑 2
回应大量重复的建议和评论:
我手头没有可以用来解决这个问题的“更大的类型”。这就是为什么问题的开头陈述是:
Assuming that
uint
is the largest integral type on my fixed-point platform
我假设不存在其他类型,无论是在 SW 层(通过一些内置标准库)还是在 HW 层。
您可以按如下方式取消质因数:
uint gcd(uint a, uint b)
{
uint c;
while (b)
{
a %= b;
c = a;
a = b;
b = c;
}
return a;
}
uint func(uint a, uint b, uint c)
{
uint temp = gcd(a, c);
a = a/temp;
c = c/temp;
temp = gcd(b, c);
b = b/temp;
c = c/temp;
// Since you are sure the result will fit in the variable, you can simply
// return the expression you wanted after having those terms canceled.
return a * b / c;
}
needs to return a good approximation of
a * b / c
My largest type isuint
both a and b are smaller than c
Algorithm: Scale a, b to not overflow
SQRT_MAX_P1 as a compile time constant of sqrt(uint_MAX + 1)
sh = 0;
if (c >= SQRT_MAX_P1) {
while (|a| >= SQRT_MAX_P1) a/=2, sh++
while (|b| >= SQRT_MAX_P1) b/=2, sh++
while (|c| >= SQRT_MAX_P1) c/=2, sh--
}
result = a*b/c
shift result by sh.
对于 n-bit uint
,我希望结果至少正确到 n/2
位有效数字。
可以利用 a,b
中较小的小于 SQRT_MAX_P1
的优势来改进事情。如果感兴趣,稍后会详细介绍。
例子
#include <inttypes.h>
#define IMAX_BITS(m) ((m)/((m)%255+1) / 255%255*8 + 7-86/((m)%255+12))
//
#define UINTMAX_WIDTH (IMAX_BITS(UINTMAX_MAX))
#define SQRT_UINTMAX_P1 (((uintmax_t)1ull) << (UINTMAX_WIDTH/2))
uintmax_t muldiv_about(uintmax_t a, uintmax_t b, uintmax_t c) {
int shift = 0;
if (c > SQRT_UINTMAX_P1) {
while (a >= SQRT_UINTMAX_P1) {
a /= 2; shift++;
}
while (b >= SQRT_UINTMAX_P1) {
b /= 2; shift++;
}
while (c >= SQRT_UINTMAX_P1) {
c /= 2; shift--;
}
}
uintmax_t r = a * b / c;
if (shift > 0) r <<= shift;
if (shift < 0) r >>= shift;
return r;
}
#include <stdio.h>
int main() {
uintmax_t a = 12345678;
uintmax_t b = 4235266395;
uintmax_t c = 4235266396;
uintmax_t r = muldiv_about(a,b,c);
printf("%ju\n", r);
}
输出 32 位数学(精确答案是 12345677)
12345600
使用 64 位数学输出
12345677
这是另一种使用递归和最小近似来实现高精度的方法。
首先是代码,下面是解释。
代码:
uint32_t bp(uint32_t a) {
uint32_t b = 0;
while (a!=0)
{
++b;
a >>= 1;
};
return b;
}
int mul_no_ovf(uint32_t a, uint32_t b)
{
return ((bp(a) + bp(b)) <= 32);
}
uint32_t f(uint32_t a, uint32_t b, uint32_t c)
{
if (mul_no_ovf(a, b))
{
return (a*b) / c;
}
uint32_t m = c / b;
++m;
uint32_t x = m*b - c;
// So m * b == c + x where x < b and m >= 2
uint32_t n = a/m;
uint32_t r = a % m;
// So a*b == n * (c + x) + r*b == n*c + n*x + r*b where r*b < c
// Approximation: get rid of the r*b part
uint32_t res = n;
if (r*b > c/2) ++res;
return res + f(n, x, c);
}
解释:
The multiplication a * b can be written as a sum of b
a * b = b + b + .... + b
Since b < c we can take a number m of these b so that (m-1)*b < c <= m*b, like
(b + b + ... + b) + (b + b + ... + b) + .... + b + b + b
\---------------/ \---------------/ + \-------/
m*b + m*b + .... + r*b
\-------------------------------------/
n times m*b
so we have
a*b = n*m*b + r*b
where r*b < c and m*b > c. Consequently, m*b is equal to c + x, so we have
a*b = n*(c + x) + r*b = n*c + n*x + r*b
Divide by c :
a*b/c = (n*c + n*x + r*b)/c = n + n*x/c + r*b/c
The values m, n, x, r can all be calculated from a, b and c without any loss of
precision using integer division (/) and remainder (%).
The approximation is to look at r*b (which is less than c) and "add zero" when r*b<=c/2
and "add one" when r*b>c/2.
So now there are two possibilities:
1) a*b = n + n*x/c
2) a*b = (n + 1) + n*x/c
So the problem (i.e. calculating a*b/c) has been changed to the form
MULDIV(a1,b1,c) = NUMBER + MULDIV(a2,b2,c)
where a2,b2 is less than a1,b2. Consequently, recursion can be used until
a2*b2 no longer overflows (and the calculation can be done directly).
我已经建立了一个在 O(1)
复杂度(无循环)下工作的解决方案:
typedef unsigned long long uint;
typedef struct
{
uint n;
uint d;
}
fraction;
uint func(uint a, uint b, uint c);
fraction reducedRatio(uint n, uint d, uint max);
fraction normalizedRatio(uint a, uint b, uint scale);
fraction accurateRatio(uint a, uint b, uint scale);
fraction toFraction(uint n, uint d);
uint roundDiv(uint n, uint d);
uint func(uint a, uint b, uint c)
{
uint hi = a > b ? a : b;
uint lo = a < b ? a : b;
fraction f = reducedRatio(hi, c, (uint)(-1) / lo);
return f.n * lo / f.d;
}
fraction reducedRatio(uint n, uint d, uint max)
{
fraction f = toFraction(n, d);
if (n > max || d > max)
f = normalizedRatio(n, d, max);
if (f.n != f.d)
return f;
return toFraction(1, 1);
}
fraction normalizedRatio(uint a, uint b, uint scale)
{
if (a <= b)
return accurateRatio(a, b, scale);
fraction f = accurateRatio(b, a, scale);
return toFraction(f.d, f.n);
}
fraction accurateRatio(uint a, uint b, uint scale)
{
uint maxVal = (uint)(-1) / scale;
if (a > maxVal)
{
uint c = a / (maxVal + 1) + 1;
a /= c; // we can now safely compute `a * scale`
b /= c;
}
if (a != b)
{
uint n = a * scale;
uint d = a + b; // can overflow
if (d >= a) // no overflow in `a + b`
{
uint x = roundDiv(n, d); // we can now safely compute `scale - x`
uint y = scale - x;
return toFraction(x, y);
}
if (n < b - (b - a) / 2)
{
return toFraction(0, scale); // `a * scale < (a + b) / 2 < MAXUINT256 < a + b`
}
return toFraction(1, scale - 1); // `(a + b) / 2 < a * scale < MAXUINT256 < a + b`
}
return toFraction(scale / 2, scale / 2); // allow reduction to `(1, 1)` in the calling function
}
fraction toFraction(uint n, uint d)
{
fraction f = {n, d};
return f;
}
uint roundDiv(uint n, uint d)
{
return n / d + n % d / (d - d / 2);
}
这是我的测试:
#include <stdio.h>
int main()
{
uint a = (uint)(-1) / 3; // 0x5555555555555555
uint b = (uint)(-1) / 2; // 0x7fffffffffffffff
uint c = (uint)(-1) / 1; // 0xffffffffffffffff
printf("0x%llx", func(a, b, c)); // 0x2aaaaaaaaaaaaaaa
return 0;
}