在具有严格内存限制的 O(n^3) 中求解 5SUM

Solving 5SUM in O(n^3) with strict memory limits

我需要一种方法来解决经典的 5SUM 问题,无需散列或使用内存有效的散列方法。

题目要求你找出给定长度为N的数组中有多少个子序列的和等于S

例如:

Input
6 5
1 1 1 1 1 1
Output
6

限制是:

N <= 1000 ( size of the array )
S <= 400000000 ( the sum of the subsequence )
Memory usage <= 5555 kbs
Execution time 2.2s

我很确定例外的复杂度是 O(N^3)。由于内存限制,哈希不提供实际的 O(1) 时间。

使用此代码我得到的最好成绩是 70 分。 (我在 6 次测试中得了 TLE)

#include <iostream>
#include <fstream>
#include <algorithm>
#include <vector>
#define MAX 1003
#define MOD 10472

using namespace std;

ifstream in("take5.in");
ofstream out("take5.out");

vector<pair<int, int>> has[MOD];
int v[MAX];
int pnt;
vector<pair<int, int>>::iterator it;

inline void ins(int val) {
    pnt = val%MOD;
    it = lower_bound(has[pnt].begin(), has[pnt].end(), make_pair(val, -1));
    if(it == has[pnt].end() || it->first != val) {
        has[pnt].push_back({val, 1});
        sort(has[pnt].begin(), has[pnt].end());
        return;
    }
    it->second++;
}

inline int get(int val) {
    pnt = val%MOD;
    it = lower_bound(has[pnt].begin(), has[pnt].end(), make_pair(val, -1));
    if(it == has[pnt].end() || it->first != val)
        return 0;
    return it->second;
}

int main() {

    int n,S;
    int ach = 0;
    int am = 0;
    int rez = 0;
    in >> n >> S;

    for(int i = 1; i <= n; i++)
        in >> v[i];

    sort(v+1, v+n+1);

    for(int i = n; i >= 1; i--) {

        if(v[i] > S)
            continue;

        for(int j = i+1; j <= n; j++) {
            if(v[i]+v[j] > S)
                break;
            ins(v[i]+v[j]);
        }

        int I = i-1;

        if(S-v[I] < 0)
            continue;

        for(int j = 1; j <= I-1; j++) {

            if(S-v[I]-v[j] < 0)
                break;

            for(int k = 1; k <= j-1; k++) {

                if(S-v[I]-v[j]-v[k] < 0)
                    break;

                ach = S-v[I]-v[j]-v[k];
                rez += get(ach);

            }
        }
    }

    out << rez << '\n';

    return 0;
}

我觉得可以。我们正在寻找数组 arr 中具有正确 SUM 的 5 项的所有子集。我们有索引为 0..N-1 的数组。这五项中的第三项可以在 2..N-3 范围内具有索引 i。我们循环遍历所有这些索引。对于每个索引 i,我们生成索引 i 左侧范围 0..i-1 中索引的两个数字的所有组合以及索引 i+1..N-1 范围中索引的两个数字的所有组合在索引 i 的右侧。对于每个索引 i,左侧加右侧的组合少于 N*N。我们将只存储每个组合的总和,因此它不会超过 1000 * 1000 * 4 = 4MB。

现在我们有两个数字序列(和),任务是:从第一个序列中取出一个数字,从第二个序列中取出一个数字,得到等于 Si = SUM - arr[i] 的和。有多少种组合?为了有效地做到这一点,必须对序列进行排序。首先说是升序排列并且有数字a, a, a, b, c ,...。第二个是降序排列并有数字 Z, Z, Y, X, W, ...。如果 a + Z > Si 那么我们可以丢弃 Z ,因为我们没有更小的数字可以匹配。如果 a + Z < Si 我们可以丢弃 a,因为我们没有更大的数字可以匹配。如果 a + Z = Si 我们有 2 * 3 = 6 个新组合并去掉 aZ。如果我们得到免费排序,这是很好的 O(N^3) 算法。

虽然排序不是免费的,但它是 O(N * N^2 * log(N^2)) = O(N^3 * log(N))。我们需要在线性时间内进行排序,这是不可能的。或者是吗?在索引 i+1 中,我们可以重用索引 i 中的序列。 i+1 只有很少的新组合 - 只有那些涉及数字 arr[i] 以及索引 0..i-1 中的一些数字的组合。如果我们对它们进行排序(我们可以,因为它们没有 N*N 个,但最多 N 个),我们所需要的只是合并两个已排序的序列。这可以在线性时间内完成。如果我们在开始时对 arr 进行排序,我们甚至可以避免完全排序。我们只是合并。

对于第二个序列,合并不是添加而是删除,但是非常相似。

实施似乎可行,但我预计某处会出现一个错误 ;-)

#include <iostream>
#include <fstream>
#include <algorithm>
#include <vector>

using namespace std;


int Generate(int arr[], int i, int sums[], int N, int NN)
{
    int p1 = 0;
    for (int i1 = 0; i1 < i - 1; ++i1)
    {
        int ai = arr[i1];
        for (int i2 = i1 + 1; i2 < i; ++i2)
        {
            sums[p1++] = ai + arr[i2];
        }
    }
    sort(sums, sums + p1);
    return p1;
}

int Combinations(int n, int sums[], int p1, int p2, int NN)
{
    int cnt = 0;
    int a = 0;
    int b = NN - p2;

    do
    {
        int state = sums[a] + sums[b] - n;

        if (state > 0) { ++b; }
        else if (state < 0) { ++a; }
        else
        {
            int cnta = 0;
            int lastA = sums[a];
            while (a < p1 && sums[a] == lastA) { a++; cnta++; }

            int cntb = 0;
            int lastB = sums[b];
            while (b < NN && sums[b] == lastB) { b++; cntb++; }

            cnt += cnta * cntb;
        }
    } while (b < NN && a < p1);

    return cnt;
}

int Add(int arr[], int i, int sums[], int p2, int N, int NN)
{
    int ii = N - 1;
    int n = arr[i];
    int nn = n + arr[ii--];
    int ip = NN - p2;
    int newP2 = p2 + N - i - 1;

    for (int p = NN - newP2; p < NN; ++p)
    {
        if (ip < NN && (ii < i || sums[ip] > nn))
        {
            sums[p] = sums[ip++];
        }
        else
        {
            sums[p] = nn;
            nn = n + arr[ii--];
        }
    }
    return newP2;
}

int Remove(int arr[], int i, int sums[], int p1)
{
    int ii = 0;
    int n = arr[i];
    int nn = n + arr[ii++];
    int pp = 0;
    int p = 0;
    for (; p < p1 - i; ++p)
    {
        while (ii <= i && sums[pp] == nn)
        {
            ++pp;
            nn = n + arr[ii++];
        }
        sums[p] = sums[pp++];
    }
    return p;
}

int main() {
    ifstream in("take5.in");
    ofstream out("take5.out");

    int N, SUM;
    in >> N >> SUM;

    int* arr = new int[N];

    for (int i = 0; i < N; i++)
        in >> arr[i];

    sort(arr, arr + N);

    int NN = (N - 3) * (N - 4) / 2 + 1;
    int* sums = new int[NN];
    int combinations = 0;
    int p1 = 0;
    int p2 = 1;

    for (int i = N - 3; i >= 2; --i)
    {
        if (p1 == 0)
        {
            p1 = Generate(arr, i, sums, N, NN);
            sums[NN - 1] = arr[N - 1] + arr[N - 2];
        }
        else
        {
            p1 = Remove(arr, i, sums, p1);
            p2 = Add(arr, i + 1, sums, p2, N, NN);
        }

        combinations += Combinations(SUM - arr[i], sums, p1, p2, NN);
    }

    out << combinations << '\n';

    return 0;
}