Figure out the bug in my Uneaten Leaves algorithm

I came across this problem in an interview challenge.

K caterpillars are eating their way through N leaves. Each caterpillar falls from leaf to leaf in a unique sequence. All caterpillars start at a twig at position 0 and fall onto the leaves at positions between 1 and N. Each caterpillar j has an associated jump number Aj. A caterpillar with jump number j eats the leaves at positions that are multiples of j. It proceeds in the order j, 2j, 3j, … until it reaches the end of the leaves, where it stops and builds its cocoon. Given a set A of K elements, we need to determine the number of uneaten leaves.

Constraints:

1 <= N <= 10^9

1 <= K <= 15

1 <= A[i] <= 10^9

Input format:

N = Number of leaves.

K = Number of caterpillars.

A = Array of integer jump numbers.

Output:

The number of uneaten leaves, as an integer.

Sample Input:

10
3
2
4
5

Output:

4

Explanation:

[2, 4, 5] is the 3-member set of jump numbers. All leaves at positions that are multiples of 2, 4, or 5 are eaten. Only 4 leaves, numbered 1, 3, 7, and 9, are left.

The naive approach to this problem is to keep a Boolean array covering all N positions, iterate over every caterpillar, and mark the leaves it eats.

int uneatenusingNaive(int N, vector<int> A)
{
    int eaten = 0;
    vector<bool>seen(N+1, false);
    for (int i = 0; i < A.size(); i++)
    {
        long Ai = A[i];
        long j = A[i];
        while (j <= N && j>0)
        {
            if (!seen[j])
            {
                seen[j] = true;
                eaten++;
            }
            j += Ai;
        }
    }
    return N - eaten;
}

This approach passed 8 of the 10 test cases and gave a wrong answer for the other 2.

Another approach uses the inclusion-exclusion principle; explanations of it can be found here and here.
Below is my code for the second approach:

 int gcd(int a, int b)
    {
        if (b == 0)
            return a;
        return gcd(b, a%b);
    }
    int lcm(int i, int j)
    {
        return i*j / gcd(i, j);
    }
    
    vector<vector<int>> mixStr(vector<vector<int>> & mix, vector<int>& A, unordered_map<int, int> & maxStart)
    {
        vector<vector<int>> res;
        if (mix.size() == 0)
        {
            for (int i = 0; i < A.size(); i++)
            {
                vector<int> tmp;
                tmp.push_back(A[i]);
                res.push_back(tmp);
            }
            return res;
        }
        
        
        for (int i = 0; i<mix.size(); i++)
        {
            int currSlotSize = mix[i].size();
            int currSlotMax = mix[i][currSlotSize - 1];
            
            for (int j = maxStart[currSlotMax]; j < A.size(); j++)
            {
                vector<int> tmp(mix[i]);
                tmp.push_back(A[j]);
                res.push_back(tmp);
            }
        }
        return res;
    }
    int uneatenLeavs(int N, int k, vector<int> A)
    {
        int i = 0;
        vector<vector<int>> mix;
        bool sign = true;
        int res = N;
        sort(A.begin(), A.end());
        unordered_map<int,int> maxStart;
        for (int i = 0; i < A.size(); i++)
        {
            maxStart[A[i]] = i + 1;
        }
        int eaten = 0;
        
    
        while (mix.size() != 1)
        {   
            
            mix = mixStr(mix, A, maxStart);
            for (int j = 0; j < mix.size(); j++)
            {
                int _lcm = mix[j][0];
                for (int s = 1; s < mix[j].size(); s++)
                {
                    _lcm = lcm(mix[j][s], _lcm);
                }
                if (sign)
                {
                    res -= N / _lcm;
                }
                else
                {
                    res += N / _lcm;
                }
            }
            sign = !sign;
            i++;
        }
        return res;
    }

This approach passed only 1 of the 10 test cases. The remaining test cases either exceeded the time limit or produced a wrong answer.

Question:
What am I missing in the first or second approach to make it 100% correct?

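Using the inclusion-exclusion principle is the right approach; however, your implementation appears to be too slow. We can use a bitmask technique to get O(K*2^K) time complexity.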
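As a sanity check, here is the inclusion-exclusion sum worked out by hand for the sample input (N = 10, A = [2, 4, 5]). Each term is the number of multiples of the LCM of one subset of jump numbers; subsets of odd size are added and subsets of even size are subtracted.

```latex
\begin{align*}
\text{eaten} &= \left\lfloor \tfrac{10}{2} \right\rfloor
              + \left\lfloor \tfrac{10}{4} \right\rfloor
              + \left\lfloor \tfrac{10}{5} \right\rfloor
              - \left\lfloor \tfrac{10}{\operatorname{lcm}(2,4)} \right\rfloor
              - \left\lfloor \tfrac{10}{\operatorname{lcm}(2,5)} \right\rfloor
              - \left\lfloor \tfrac{10}{\operatorname{lcm}(4,5)} \right\rfloor
              + \left\lfloor \tfrac{10}{\operatorname{lcm}(2,4,5)} \right\rfloor \\
             &= (5 + 2 + 2) - (2 + 1 + 0) + 0 = 6 \\
\text{uneaten} &= 10 - 6 = 4
\end{align*}
```

This matches the expected output of 4.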

Take a look at this:

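long long result = 0; // number of eaten leaves; the answer is N - result

for (int i = 1; i < (1 << K); i++) {
    long long lcm = 1;
    int bits = 0; // number of bits set in i
    for (int j = 0; j < K; j++)
        if (((1 << j) & i) != 0) { // if bit j is set, compute new LCM after including A[j]
            bits++;
            if (lcm <= N) // once lcm exceeds N, N/lcm is 0; stop multiplying to avoid overflow
                lcm *= A[j] / gcd(lcm, A[j]);
        }
    if (bits % 2 == 1)
        result += N / lcm;
    else
        result -= N / lcm;
}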

As for your first approach, an O(N*K) algorithm with N = 10^9 and K = 15 will be too slow, and it can exceed the memory limit or the time limit.
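A rough estimate of the cost of the first approach at the upper limits, assuming `vector<bool>` packs one bit per element (as common standard library implementations do) and taking the worst case where every jump number is 1:

```latex
\begin{align*}
\text{memory} &\approx \frac{10^{9} + 1\ \text{bits}}{8\ \text{bits/byte}}
               \approx 1.25 \times 10^{8}\ \text{bytes} \approx 125\ \text{MB} \\
\text{iterations} &= \sum_{i=1}^{K} \left\lfloor \frac{N}{A_i} \right\rfloor
               \le K \cdot N = 15 \times 10^{9} = 1.5 \times 10^{10}
\end{align*}
```

Whether this fits depends on the judge's limits, which the problem statement does not give.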

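Note that lcm can grow larger than N, so an extra check is needed: in that case N/lcm is 0, and multiplying further can overflow.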
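Putting the pieces together, here is a minimal self-contained sketch of the bitmask approach with that check in place. The function name `countUneaten` and the use of 64-bit integers are my own choices, not part of the original problem, and it assumes C++17 for `std::gcd`. One thing worth noting about the second approach in the question: `i*j` in its `lcm` is computed in `int`, which overflows for jump numbers near 10^9 on platforms where `int` is 32 bits, so that is a likely source of wrong answers there.

```cpp
#include <cstdint>
#include <numeric>  // std::gcd (C++17)
#include <vector>

// Hypothetical helper: counts positions in 1..N that are not a multiple
// of any jump number in A, using inclusion-exclusion over all subsets.
std::int64_t countUneaten(std::int64_t N, const std::vector<std::int64_t>& A) {
    const int K = static_cast<int>(A.size());
    std::int64_t eaten = 0;
    for (int mask = 1; mask < (1 << K); ++mask) {
        std::int64_t l = 1;  // LCM of the jump numbers selected by mask
        int bits = 0;        // size of the selected subset
        for (int j = 0; j < K; ++j) {
            if (!(mask & (1 << j))) continue;
            ++bits;
            // Multiply only while l <= N: both factors are then at most 1e9,
            // so the product fits in 64 bits. Once l > N the term is 0 anyway.
            if (l <= N) l = l / std::gcd(l, A[j]) * A[j];
        }
        const std::int64_t multiples = (l <= N) ? N / l : 0;
        // Odd-sized subsets are added, even-sized subsets are subtracted.
        eaten += (bits % 2 == 1) ? multiples : -multiples;
    }
    return N - eaten;
}
```

For the sample input, `countUneaten(10, {2, 4, 5})` returns 4. With K <= 15 the loop visits at most 2^15 - 1 subsets, each in O(K) steps plus the cost of the gcd calls.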