unsigned long long int 值中的 C++ 奇怪跳转

C++ strange jump in unsigned long long int values

我有以下问题,实际上来自我最近参加的编码测试:

问题:

函数 f(n) = a*n + b*n*(floor(log(n)/log(2))) + c*n*n*n 存在。

在特定值下,令f(n) = k

给定 k, a, b, c,找到 n

对于给定的 k 值,如果不存在 n 值,则 return 0.

限制:

1 <= n < 2^63-1
0 < a, b < 100
0 <= c < 100
0 < k < 2^63-1

这里的逻辑是,由于 f(n) 对于给定的 a、b 和 c 是纯递增的,所以我可以通过二进制搜索找到 n

我写的代码如下:

#include<iostream>
#include<stdlib.h>
#include<math.h>
using namespace std;

unsigned long long logToBase2Floor(unsigned long long n){
    return (unsigned long long)(double(log(n))/double(log(2)));
}

#define f(n, a, b, c) (a*n + b*n*(logToBase2Floor(n)) + c*n*n*n)


unsigned long long findNByBinarySearch(unsigned long long k, unsigned long long a, unsigned long long b, unsigned long long c){
    unsigned long long low = 1;
    unsigned long long high = (unsigned long long)(pow(2, 63)) - 1;
    unsigned long long n;
    while(low<=high){
        n = (low+high)/2;
        cout<<"\n\n          k= "<<k;
        cout<<"\n f(n,a,b,c)= "<<f(n,a,b,c)<<"  low = "<<low<<"  mid="<<n<<"  high = "<<high;
        if(f(n,a,b,c) == k)
            return n;
        else if(f(n,a,b,c) < k)
            low = n+1;
        else high = n-1;
    }
    return 0;
}

然后我用几个测试用例进行了尝试:

int main(){
    unsigned long long n, a, b, c;
    n = (unsigned long long)pow(2,63)-1;
    a = 99;
    b = 99;
    c = 99;
    cout<<"\nn="<<n<<"  a="<<a<<"  b="<<b<<"  c="<<c<<"    k = "<<f(n, a, b, c);
    cout<<"\nANSWER: "<<findNByBinarySearch(f(n, a, b, c), a, b, c)<<endl;
    n = 1000;
    cout<<"\nn="<<n<<"  a="<<a<<"  b="<<b<<"  c="<<c<<"    k = "<<f(n, a, b, c);
    cout<<"\nANSWER: "<<findNByBinarySearch(f(n, a, b, c), a, b, c)<<endl;
    return 0;
}

然后奇怪的事情发生了。

代码适用于测试用例 n = (unsigned long long)pow(2,63)-1;,正确地 returning 了 n 的值。但它不适用于 n=1000。我打印输出并看到以下内容:

n=1000  a=99  b=99  c=99    k = 99000990000

          k= 99000990000
 f(n,a,b,c)= 4611686018427387904  low = 1  mid=4611686018427387904  high = 9223372036854775807
 ...
 ...
          k= 99000990000
 f(n,a,b,c)= 172738215936  low = 1  mid=67108864  high = 134217727

          k= 99000990000
 f(n,a,b,c)= 86369107968  low = 1  mid=33554432  high = 67108863

          k= 99000990000
 f(n,a,b,c)= 129553661952  low = 33554433  mid=50331648  high = 67108863**
 ...
 ...
          k= 99000990000
 f(n,a,b,c)= 423215328047139441  low = 37748737  mid=37748737  high = 37748737
ANSWER: 0

数学上似乎有些不对。为什么 f(1000) 的值大于 f(33554432) 的值?

所以我在 Python 中尝试了相同的代码,并得到了以下值:

>>> f(1000, 99, 99, 99)
99000990000L
>>> f(33554432, 99, 99, 99)
3740114254432845378355200L

所以,价值肯定更大。

问题:


到底发生了什么?

问题在这里:

unsigned long long low = 1;
// Side note: This is simply (2ULL << 62) - 1
unsigned long long high = (unsigned long long)(pow(2, 63)) - 1;
unsigned long long n;
while (/* irrelevant */) {
    n = (low + high) / 2;
    // Some stuff that do not modify n... 
    f(n, a, b, c) // <-- Here!
}

在第一次迭代中,您有 low = 1high = 2^63 - 1,这意味着 n = 2^63 / 2 = 2^62。现在,让我们看看 f:

#define f(n, a, b, c) (/* I do not care about this... */ + c*n*n*n)

您在 f 中有 n^3,所以对于 n = 2^62n^3 = 2^186,这对于您的 unsigned long long 来说可能太大了(这很可能为 64 位长)。

我该如何解决?

这里的主要问题是在进行二分查找时溢出,所以你应该单独处理溢出的情况。

序言: 我正在使用 ull_t 因为我很懒,你应该避免在 C++ 中使用宏,更喜欢使用函数并让编译器内联它。此外,我更喜欢循环而不是使用 log 函数来计算 unsigned long long 的 log2(请参阅此答案的底部以了解 log2is_overflow 的实现)。

using ull_t = unsigned long long;

constexpr auto f (ull_t n, ull_t a, ull_t b, ull_t c) {
    if (n == 0ULL) { // Avoid log2(0)
        return 0ULL;
    }
    if (is_overflow(n, a, b, c)) {
        return 0ULL;
    }
    return a * n + b * n * log2(n) + c * n * n * n;
}

这里是稍微修改过的二进制搜索版本:

constexpr auto find_n (ull_t k, ull_t a, ull_t b, ull_t c) {
    constexpr ull_t max = std::numeric_limits<ull_t>::max();
    auto lb = 1ULL, ub = (1ULL << 63) - 1;
    while (lb <= ub) {
        if (ub > max - lb) {
            // This should never happens since ub < 2^63 and lb <= ub so lb + ub < 2^64
            return 0ULL;
        }
        // Compute middle point (no overflow guarantee).
        auto tn = (lb + ub) / 2;
        // If there is an overflow, then change the upper bound.
        if (is_overflow(tn, a, b, c)) {
            ub = tn - 1;
        }
        // Otherwize, do a standard binary search...
        else {
            auto val = f(tn, a, b, c);
            if (val < k) {
                lb = tn + 1;
            }
            else if (val > k) {
                ub = tn - 1;
            }
            else {
                return tn;
            }
        }
    }
    return 0ULL;
}

如您所见,这里只有一个相关的测试,即 is_overflow(tn, a, b, c)(关于 lb + ub 的第一个测试在这里不相关,因为 ub < 2^63lb <= ub < 2^63 所以 ub + lb < 2^64 在我们的例子中对于 unsigned long long 是可以的。

完成实施:

#include <limits>
#include <type_traits>

using ull_t = unsigned long long;

template <typename T, 
          typename = std::enable_if_t<std::is_integral<T>::value>>
constexpr auto log2 (T n) {
    T log = 0;
    while (n >>= 1) ++log;
    return log;
}

constexpr bool is_overflow (ull_t n, ull_t a, ull_t b, ull_t c) {
    ull_t max = std::numeric_limits<ull_t>::max();
    if (n > max / a) {
        return true;
    }
    if (n > max / b) {
        return true;
    }
    if (b * n > max / log2(n)) {
        return true;
    }
    if (c != 0) {
        if (n > max / c) return true;
        if (c * n > max / n) return true;
        if (c * n * n > max / n) return true;
    }
    if (a * n > max - c * n * n * n) {
        return true;
    }
    if (a * n + c * n * n * n > max - b * n * log2(n)) {
        return true;
    }
    return false;
}

constexpr auto f (ull_t n, ull_t a, ull_t b, ull_t c) {
    if (n == 0ULL) {
        return 0ULL;
    }
    if (is_overflow(n, a, b, c)) {
        return 0ULL;
    }
    return a * n + b * n * log2(n) + c * n * n * n;
}

constexpr auto find_n (ull_t k, ull_t a, ull_t b, ull_t c) {
    constexpr ull_t max = std::numeric_limits<ull_t>::max();
    auto lb = 1ULL, ub = (1ULL << 63) - 1;
    while (lb <= ub) {
        if (ub > max - lb) {
            return 0ULL; // Problem here
        }
        auto tn = (lb + ub) / 2;
        if (is_overflow(tn, a, b, c)) {
            ub = tn - 1;
        }
        else {
            auto val = f(tn, a, b, c);
            if (val < k) {
                lb = tn + 1;
            }
            else if (val > k) {
                ub = tn - 1;
            }
            else {
                return tn;
            }
        }
    }
    return 0ULL;
}

编译时检查:

下面是一小段代码,您可以用它来在编译时检查上面的代码(因为一切都是 constexpr):

template <unsigned long long n, unsigned long long a, 
          unsigned long long b, unsigned long long c>
struct check: public std::true_type {
    enum {
        k = f(n, a, b, c)
    };
    static_assert(k != 0, "Value out of bound for (n, a, b, c).");
    static_assert(n == find_n(k, a, b, c), "");
};

template <unsigned long long a, 
          unsigned long long b, 
          unsigned long long c>
struct check<0, a, b, c>: public std::true_type {
    static_assert(a != a, "Ambiguous values for n when k = 0.");
};

template <unsigned long long n>
struct check<n, 0, 0, 0>: public std::true_type {
    static_assert(n != n, "Ambiguous values for n when a = b = c = 0.");
};

#define test(n, a, b, c) static_assert(check<n, a, b, c>::value, "");

test(1000, 99, 99, 0);
test(1000, 99, 99, 99);
test(453333, 99, 99, 99);
test(495862, 99, 99, 9);
test(10000000, 1, 1, 0);

注:k的最大值约为2^63,所以对于给定的三元组(a, b, c)(a, b, c)的最大值=40=] 就是 f(n, a, b, c) < 2 ^ 63f(n + 1, a, b, c) >= 2 ^ 63 之类的。对于a = b = c = 99,这个最大值是n = 453333(凭经验找到的),这也是我上面测试的原因。