Top K 频繁元素 - 时间复杂度:桶排序与堆排序

Top K Frequent Elements - time complexity: Bucket Sort vs Heap

我正在研究一个 leetcode 问题 (https://leetcode.com/problems/top-k-frequent-elements/),它是:

Given an integer array nums and an integer k, return the k most frequent elements. You may return the answer in any order.

我使用 min-heap 解决了这个问题(我的时间复杂度计算在评论中 - 如果我做错了请纠正我):

        from collections import Counter
        
        if k == len(nums):
            return nums
        
        # O(N)
        c = Counter(nums)
        
        it = iter([(x[1], x[0]) for x in c.items()])
        
        # O(K)
        result = list(islice(it, k))
        heapify(result)
        
        # O(N-K)
        for elem in it:
            # O(log K)
            heappushpop(result, elem)
            
        # O(K)
        return [pair[1] for pair in result]
    
    # O(K) + O(N) + O((N - K) log K) + O(K log K)
    # if k < N :
    #   O(N log K)

然后我看到了一个使用 Bucket Sort 的解决方案,它假设用 O(N):

击败了堆解决方案
        bucket = [[] for _ in nums]

        # O(N)
        c = collections.Counter(nums)

        # O(d) where d is the number of distinct numbers. d <= N
        for num, freq in c.items():
            bucket[-freq].append(num)
            
        # O(?)
        return list(itertools.chain(*bucket))[:k]

我们如何计算此处 itertools.chain 调用的时间复杂度? 是因为我们最多会链接 N 个元素吗?这足以推断它是 O(N) 吗?

无论如何,至少在 leetcode 测试用例上,第一个性能更好 - 这可能是什么原因?

list(itertools.chain(*bucket)) 的时间复杂度为 O(N),其中 N 是嵌套列表 bucket 中元素的总数。 chain 函数大致等同于:

def chain(*iterables):
    for iterable in iterables:
        for item in iterable:
            yield item

yield语句占了运行次,复杂度为O(1),执行了N次,所以结果


您的 O(N log k) 算法在实践中可能最终变得更快的原因是因为 log k 可能不是很大; LeetCode 说 k 最多是数组中不同元素的数量,但我怀疑对于大多数测试用例来说 k 小得多,当然 log k 比那个小。

O(N)算法可能常数因子比较高,因为它分配N个列表,然后按索引随机访问它们,导致缓存未命中很多; append 操作也可能导致许多列表在变大时被重新分配。

尽管我的评论使用 nlargest 似乎 运行 比使用 heapify 等慢(见下文)。但是桶排序,至少对于这个输入,肯定是更高效的。似乎使用桶排序创建 num 元素的完整列表以获取第一个 k 元素不会导致太多的惩罚。

from collections import Counter
from heapq import nlargest
from itertools import chain

def most_frequent_1a(nums, k):
    if k == len(nums):
        return nums

    # O(N)
    c = Counter(nums)

    it = iter([(x[1], x[0]) for x in c.items()])

    # O(K)
    result = list(islice(it, k))
    heapify(result)

    # O(N-K)
    for elem in it:
        # O(log K)
        heappushpop(result, elem)

    # O(K)
    return [pair[1] for pair in result]

def most_frequent_1b(nums, k):        
    if k == len(nums):
        return nums

    c = Counter(nums)        
    return [pair[1] for pair in nlargest(k, [(x[1], x[0]) for x in c.items()])]


def most_frequent_2a(nums, k):
    bucket = [[] for _ in nums]

    # O(N)
    c = Counter(nums)

    # O(d) where d is the number of distinct numbers. d <= N
    for num, freq in c.items():
        bucket[-freq].append(num)

    # O(?)
    return list(chain(*bucket))[:k]


def most_frequent_2b(nums, k):
    bucket = [[] for _ in nums]

    # O(N)
    c = Counter(nums)

    # O(d) where d is the number of distinct numbers. d <= N
    for num, freq in c.items():
        bucket[-freq].append(num)

    # O(?)
    # don't create full list:
    i = 0
    for elem in chain(*bucket):
        yield elem
        i += 1
        if i == k:
            break

import timeit
nums = [i for i in range(1000)]
nums.append(7)
nums.append(88)
nums.append(723)
print(most_frequent_1a(nums, 3))
print(most_frequent_1b(nums, 3))
print(most_frequent_2a(nums, 3))
print(list(most_frequent_2b(nums, 3)))
print(timeit.timeit(stmt='most_frequent_1a(nums, 3)', number=10000, globals=globals()))
print(timeit.timeit(stmt='most_frequent_1b(nums, 3)', number=10000, globals=globals()))
print(timeit.timeit(stmt='most_frequent_2a(nums, 3)', number=10000, globals=globals()))
print(timeit.timeit(stmt='list(most_frequent_2b(nums, 3))', number=10000, globals=globals()))

打印:

[7, 723, 88]
[723, 88, 7]
[7, 88, 723]
[7, 88, 723]
3.180169899998873
4.487235299999156
2.710413699998753
2.62860400000136