最近 Google 位运算的面试谜题

recent Google interview puzzle on bitwise operation

这是 Google 最近的面试问题:

We define f(X, Y) as number of different corresponding bits in binary representation of X and Y. For example, f(2, 7) = 2, since binary representation of 2 and 7 are 010 and 111, respectively. The first and the third bit differ, so f(2, 7) = 2.

You are given an array of N positive integers, A1, A2 ,…, AN. Find sum of f(Ai, Aj) for all pairs (i, j) such that 1 ≤ i, j ≤ N

例如:

A=[1, 3, 5]

We return

f(1, 1) + f(1, 3) + f(1, 5) + f(3, 1) + f(3, 3) + f(3, 5) + f(5, 1) + f(5, 3) + f(5, 5) =

0 + 1 + 1 + 1 + 0 + 2 + 1 + 2 + 0 = 8

我能想到这个解决方案是 O(n^2)

int numSetBits(unsigned int A) {
    int count  = 0;

    while(A != 0) {
        A = A & (A-1);
        count++;
    }

    return count;
}

int count_diff_bits(int a, int b)
{
    int x = a ^ b;

    return numSetBits(x);
}

for (i = 0; i < n; i++)
   for (j = 0; j < n; j++) {
       sum += count_diff_bits(A[i], A[j]);
   }
}

我能想到的另一种方法是(考虑到每个元素只包含一个二进制数字):

这种做法是否正确。

遍历数组,统计每个位索引中"on"位的个数,例如[1,3,5]:

0 0 1
0 1 1
1 0 1
-----
1 1 3

现在,对于每个位计数器,计算:

[bit count] * [array size - bit count] * 2

所有位求和...

以上例:

3 * (3 - 3) * 2 = 0
1 * (3 - 1) * 2 = 4
1 * (3 - 1) * 2 = 4
          total = 8

为了说明为什么这有效,让我们使用一位来查看问题的一个子集。让我们看看如果我们有一个数组会发生什么:[1, 1, 0, 0, 1, 0, 1]。我们的计数是 4,大小是 7。如果我们检查数组中所有位的第一位(包括问题中的自身),我们得到:

1 xor 1 = 0
1 xor 1 = 0
1 xor 0 = 1
1 xor 0 = 1
1 xor 1 = 0
1 xor 0 = 1
1 xor 1 = 0

可以看出,这个位的贡献是"off"个位。这同样适用于任何其他 "on" 位。我们可以说每个 "on" 位算作 "off" 位的数量:

[bit count] * [array size - bit count]

乘以 2 从何而来?好吧,因为我们对 "off" 位做同样的事情,除了这些,贡献是 "on" 位的数量:

[array size - bit count] * [bit count]

当然和上面一样,我们可以乘...

复杂度为 O(n*k) 其中 k 是位数(代码中为 32)。

#include <bits/stdc++.h>
#define MOD 1000000007ll
using namespace std;
typedef long long LL;

int solve(int arr[], int n) {

    int ans = 0;
    // traverse over all bits
    for(int i = 0; i < 31; i++) {

        // count number of elements with ith bit = 0
        long long count = 0;
        for(int j = 0; j < n; j++) if ( ( arr[j] & ( 1 << i ) ) ) count++;

        // add to answer count * (n - count) * 2
        ans += (count * ((LL)n - count) * 2ll) % MOD;
        if(ans >= MOD) ans -= MOD;
    }

    return ans;
}

int main() {

    int arr[] = {1, 3, 5};
    int n = sizeof arr / sizeof arr[0];
    cout << solve(arr, n) << endl; 
    return 0;
}