合并 Python3 中的 k 个排序列表,内存和时间之间的权衡问题

Merging k sorted lists in Python3, problem with trade-off between memory and time

输入是: 第一行——数组的数量(k); 下一行 - 第一个数字是数组大小,接下来的数字是元素。

最大 k 为 1024。最大数组大小为 10*k。 0 到 100 之间的所有数字。内存限制 - 10MB,时间限制 - 1s。 建议的复杂度为 k ⋅ log(k) ⋅ n,其中 n 是数组长度。

示例输入:

4            
6 2 26 64 88 96 96
4 8 20 65 86
7 1 4 16 42 58 61 69
1 84

示例输出:

1 2 4 8 16 20 26 42 58 61 64 65 69 84 86 88 96 96 

我有4个解决方案。一种使用 heapq 并按块读取输入行,一种使用 heapq,一种使用 Counter,一种什么都不用。

这个使用 heapq(有利于时间但不利于内存,我认为堆是正确的方式,但是如果我按部分读取行,也许可以优化它,这样我就不需要内存了整个输入):

from heapq import merge


if __name__ == '__main__':
    print(*merge(*[[int(el) for el in input().split(' ')[1:]] for _ in range(int(input()))]), sep=' ')

这个是上一个的进阶版。它按块读取行,但是这是非常复杂的解决方案,我不知道如何优化这些读取:

from heapq import merge
from functools import reduce


def read_block(n, fd, cursors, offset, has_unused_items):
    MEMORY_LIMIT = 10240000
    block_size = MEMORY_LIMIT / n
    result = []

    for i in range(n):
        if has_unused_items[i]:
            if i == 0:
                fd.seek(cursors[i] + offset)
            else:
                fd.read(cursors[i])

            block = ''
            c = 0
            char = ''

            while c < block_size or char != ' ':
                if cursors[i] == 0:
                    while char != ' ':
                        char = fd.read(1)
                        cursors[i] += 1

                char = fd.read(1)

                if char != '\n':
                    block += char
                    cursors[i] += 1
                    c += 1
                else:
                    has_unused_items[i] = False
                    break

            result.append([int(i) for i in block.split(' ')])

            while char != '\n':
                char = fd.read(1)

    return result


def to_output(fd, iter):
    fd.write(' '.join([str(el) for el in iter]))


if __name__ == '__main__':
    with open('input.txt') as fd_input:
        with open('output.txt', 'w') as fd_output:
            n = int(fd_input.readline())
            offset = fd_input.tell()
            cursors = [0] * n
            has_unused_items = [True] * n
            result = []

            while reduce(lambda x, p: x or p, has_unused_items):
                result = merge(
                    result,
                    *read_block(n, fd_input, cursors, offset, has_unused_items)
                )

            to_output(fd_output, result)

这个比较好记(用计数器排序,但是我没有用到所有数组都排序的信息):

from collections import Counter


def solution():
    A = Counter()

    for _ in range(int(input())):
        A.update(input().split(' ')[1:])

    for k in sorted([int(el) for el in A]):
        for _ in range(A[str(k)]):
            yield k

这个适合时间(但可能不够好):

def solution():
    A = tuple(tuple(int(el) for el in input().split(' ')[1:]) for _ in range(int(input())) # input data
    c = [0] * len(A) # cursors for each array

    for i in range(101):
        for j, a in enumerate(A):
            for item in a[c[j]:]:
                if item == i:
                    yield i
                    c[j] += 1
                else:
                    break 

完美,如果我在第一个示例中按部分排列数组,那么整个输入就不需要内存,但我不知道如何正确地按块读取行。

你能提出一些解决问题的建议吗?

O Deep Thought computer,生命宇宙万物的答案是什么

这是我用于测试的代码

"""4
6 2 26 64 88 96 96
4 8 20 65 86
7 1 4 16 42 58 61 69
1 84"""

from heapq import merge
from io import StringIO
from timeit import timeit

def solution():
    pass

times = []
for i in range(5000):
    f = StringIO(__doc__)
    times.append(timeit(solution, number=1))

print(min(times))

这是结果,我测试了评论中提出的解决方案:

6.5e-06 秒

def solution():
    A = []
    A = merge(A, *((int(i)
                    for i in line.split(' ')[1:])
                    for line in f.readlines()))
    return A

7.1e-06 秒

def solution():
    A = []
    for _ in range(int(f.readline())):
        A = merge(A, (int(i) for i in f.readline().split(' ')[1:]))
    return A

7.9e-07 秒

def solution():
    A = Counter()
    for _ in range(int(f.readline())):
        A.update(f.readline().split(' ')[1:])
    for k in sorted([int(el) for el in A]):
        for _ in range(A[str(k)]):
            yield k

8.3e-06 秒

def solution():
    A = []
    for _ in range(int(f.readline())):
        for i in f.readline().split(' ')[1:]:
            insort(A, i)
    return A

6.2e-07 秒

def solution():
    A = Counter()
    for _ in range(int(f.readline())):
        A.update(f.readline().split(' ')[1:])
    l = [int(el) for el in A]
    l.sort()
    for k in l:
        for _ in range(A[str(k)]):
            yield k

你的代码很棒,不要使用 sorted(数组越大影响越大)。你应该用更大的输入来测试它(我用了你给的)。

这只有前一个的获胜者(加上解决方案 6,这是您给出的第二个)。速度限制似乎是由程序的 I/O 给出的,而不是排序本身。

请注意,我生成正方形(行数 == 每行数)

如果整数行已经排序,那么你只需要关注如何将这些片段拼接在一起。

为了实现这一点,我的解决方案在元组列表中跟踪问题的 state

每个元组记录该行的offsetnum_elements是该行中待处理的元素个数,next_elem是下一个要处理的元素的值被处理,last_elem 是行中最后一个元素的值。

算法遍历 state 元组列表,这些元组根据 next_elemlast_elem 的值排序,将下一个最低值附加到 A 列表. state 已更新,列表已排序,冲洗并重复直到列表为空。

我很想知道它相对于其他解决方案的表现如何。

from operator import itemgetter

def solution():
    state = []
    A = []
    k = int(f.readline())
    for _ in range(k):
        offset = f.tell()
        line = f.readline().split()
        # Store the state data for processing each line in a tuple
        # Append tuple to the state list: (offset, num_elements, next_elem, last_elem)
        state.append((offset, int(line[0]), int(line[1]), int(line[-1])))
    # Sort the list of stat tuples by value of next and last elements
    state.sort(key=itemgetter(2, 3))
    # [
    #    (34, 7, 1, 69),
    #    (2, 6, 2, 96),
    #    (21, 4, 8, 86),
    #    (55, 1, 84, 84)
    # ]
    while (len(state) > 0):
        offset, num_elements, _, last = state[0]
        _ = f.seek(offset)
        line = f.readline().split()
        if ((len(state) == 1) or (last <= state[1][2])):
            # Add the remaining line elements to the `result`
            A += line[-(num_elements):]
            # Delete the line from state
            del state[0]
        else:
            while (int(line[-(num_elements)]) <= state[1][2]):
                # Append the element to the `result`
                A.append(line[-(num_elements)])
                # Decrement the number of elements in the line to be processed
                num_elements -= 1
            if (num_elements > 0):
                # Update the tuple
                state[0] = (offset, (num_elements), int(
                    line[-(num_elements)]), int(line[-1]))
                # Sort the list of tuples
                state.sort(key=itemgetter(2, 3))
            else:
                # Delete the depleted line from state
                del state[0]
    # Return the result
    return A