Top K Frequent Elements - time complexity: Bucket Sort vs Heap
I was working on a LeetCode problem (https://leetcode.com/problems/top-k-frequent-elements/), which is:
Given an integer array nums and an integer k, return the k most frequent elements. You may return the answer in any order.
I solved it using a min-heap (my time-complexity calculations are in the comments; please correct me if I got anything wrong):
```python
from collections import Counter
from heapq import heapify, heappushpop
from itertools import islice


# Body of LeetCode's topKFrequent, written as a plain function
def topKFrequent(nums, k):
    if k == len(nums):
        return nums
    # O(N)
    c = Counter(nums)
    it = iter([(x[1], x[0]) for x in c.items()])
    # O(K)
    result = list(islice(it, k))
    heapify(result)
    # O(N-K)
    for elem in it:
        # O(log K)
        heappushpop(result, elem)
    # O(K)
    return [pair[1] for pair in result]

# O(K) + O(N) + O((N - K) log K) + O(K log K)
# if k < N :
# O(N log K)
```
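For reference, here is a tighter accounting of the heap version's steps, writing d for the number of distinct values (k <= d <= N). It relies on two standard facts rather than on anything stated in the question: `heapify` runs in linear time, and the `heappushpop` loop runs once per remaining distinct value, i.e. d - k times rather than N - k:

```latex
\begin{aligned}
T &= \underbrace{O(N)}_{\texttt{Counter}}
   + \underbrace{O(d)}_{\text{pair list}}
   + \underbrace{O(k)}_{\texttt{islice},\ \texttt{heapify}}
   + \underbrace{O\big((d-k)\log k\big)}_{\texttt{heappushpop}\ \text{loop}}
   + \underbrace{O(k)}_{\text{output list}} \\
  &= O(N + d\log k) \subseteq O(N\log k) \qquad \text{for } k \ge 2
\end{aligned}
```

So the O(N log K) bound in the comments holds, the O(K log K) term can be dropped, and when d is much smaller than N the cost approaches O(N).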
Then I saw a solution using bucket sort that is supposed to beat the heap solution with O(N):
```python
import collections
import itertools


def topKFrequent(nums, k):
    bucket = [[] for _ in nums]
    # O(N)
    c = collections.Counter(nums)
    # O(d) where d is the number of distinct numbers. d <= N
    for num, freq in c.items():
        bucket[-freq].append(num)
    # O(?)
    return list(itertools.chain(*bucket))[:k]
```
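To see why `bucket[-freq]` puts the most frequent values first, here is a small illustration. `bucket_layout` is a hypothetical helper that only returns the bucket list the snippet above builds, so its layout can be inspected:

```python
from collections import Counter


def bucket_layout(nums):
    """Return the bucket list built by the bucket-sort snippet.

    bucket has len(nums) slots and every frequency lies in 1..len(nums),
    so -freq is always a valid index: freq == len(nums) maps to slot 0 and
    freq == 1 maps to the last slot. Reading the slots front to back
    therefore visits values in descending order of frequency.
    """
    bucket = [[] for _ in nums]
    for num, freq in Counter(nums).items():
        bucket[-freq].append(num)
    return bucket


layout = bucket_layout([1, 1, 1, 2, 2, 3])
# len(nums) == 6: freq 3 -> index -3 == 3, freq 2 -> index 4, freq 1 -> index 5
print(layout)  # [[], [], [], [1], [2], [3]]
```

Most slots stay empty when values are rare, which is also why the flattening step has to walk past many empty lists.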
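How do we work out the time complexity of the `itertools.chain` call here? Is it because we chain at most N elements? Is that enough to conclude that it's O(N)?
In any case, at least on the LeetCode test cases, the first one performs better. What could be the reason for that?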
`list(itertools.chain(*bucket))` is O(N), where N is the total number of elements in the nested list `bucket`. The `chain` function is roughly equivalent to:
```python
def chain(*iterables):
    for iterable in iterables:
        for item in iterable:
            yield item
```
The `yield` statement dominates the running time: each `yield` is O(1) and it executes N times, hence the result.
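One detail worth making explicit: the outer loop also runs once per bucket, and here `bucket` holds `len(nums)` lists, most of them empty. The work is therefore O(len(bucket) + total elements), and both terms are at most len(nums). A minimal instrumented sketch of the same control flow (the `counting_chain` name and its `counter` parameter are hypothetical, added only to count iterations):

```python
from collections import Counter


def counting_chain(*iterables, counter):
    """Same loops as the chain() equivalent above, but tallies the work.

    counter["outer"] counts outer-loop iterations (one per bucket, empty or
    not); counter["yields"] counts items actually yielded.
    """
    for iterable in iterables:
        counter["outer"] += 1
        for item in iterable:
            counter["yields"] += 1
            yield item


nums = [1, 1, 1, 2, 2, 3]
bucket = [[] for _ in nums]
for num, freq in Counter(nums).items():
    bucket[-freq].append(num)

work = {"outer": 0, "yields": 0}
flat = list(counting_chain(*bucket, counter=work))
print(flat)  # [1, 2, 3]
print(work)  # {'outer': 6, 'yields': 3}
```

Six outer iterations (one per slot) plus three yields (one per distinct value): linear in `len(nums)` either way.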
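The reason your O(N log k) algorithm may end up faster in practice is that log k is probably not very large. LeetCode says k is at most the number of distinct elements in the array, but I suspect that for most test cases k is much smaller, and log k is of course smaller still.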
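The O(N) algorithm probably has a relatively high constant factor, because it allocates N lists and then accesses them by index in random order, which causes many cache misses; the `append` operations may also cause many of the lists to be reallocated as they grow.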
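If allocation is part of the constant factor, one possible variant sizes the buckets by the largest count instead of by `len(nums)` and stops as soon as k elements are collected. This is a hypothetical sketch (`top_k_frequent_buckets` is not from either answer); it is still O(N), and whether it actually narrows the gap to the heap version has not been measured:

```python
from collections import Counter


def top_k_frequent_buckets(nums, k):
    """Bucket-sort variant: max_freq + 1 buckets, early exit after k items.

    Assumption: this only trims allocation (max_freq + 1 lists instead of
    len(nums)) and skips building the full flattened list.
    """
    counts = Counter(nums)                        # O(N)
    if not counts or k <= 0:
        return []
    max_freq = max(counts.values())               # O(d)
    buckets = [[] for _ in range(max_freq + 1)]   # index == frequency
    for num, freq in counts.items():              # O(d)
        buckets[freq].append(num)
    result = []
    for freq in range(max_freq, 0, -1):           # highest frequency first
        for num in buckets[freq]:
            result.append(num)
            if len(result) == k:
                return result
    return result


print(top_k_frequent_buckets([1, 1, 1, 2, 2, 3], 2))  # [1, 2]
```

When every value is distinct, max_freq is 1 and only two lists are allocated; when one value dominates, the saving disappears.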
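Despite my comment, using `nlargest` seems to run slower than using `heapify` etc. (see below). But bucket sort, at least for this input, is definitely faster. It also seems that building the full list of elements with bucket sort, just to take the first `k` of them, doesn't incur much of a penalty.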
```python
from collections import Counter
from heapq import heapify, heappushpop, nlargest
from itertools import chain, islice


def most_frequent_1a(nums, k):
    if k == len(nums):
        return nums
    # O(N)
    c = Counter(nums)
    it = iter([(x[1], x[0]) for x in c.items()])
    # O(K)
    result = list(islice(it, k))
    heapify(result)
    # O(N-K)
    for elem in it:
        # O(log K)
        heappushpop(result, elem)
    # O(K)
    return [pair[1] for pair in result]


def most_frequent_1b(nums, k):
    if k == len(nums):
        return nums
    c = Counter(nums)
    return [pair[1] for pair in nlargest(k, [(x[1], x[0]) for x in c.items()])]


def most_frequent_2a(nums, k):
    bucket = [[] for _ in nums]
    # O(N)
    c = Counter(nums)
    # O(d) where d is the number of distinct numbers. d <= N
    for num, freq in c.items():
        bucket[-freq].append(num)
    # O(?)
    return list(chain(*bucket))[:k]


def most_frequent_2b(nums, k):
    bucket = [[] for _ in nums]
    # O(N)
    c = Counter(nums)
    # O(d) where d is the number of distinct numbers. d <= N
    for num, freq in c.items():
        bucket[-freq].append(num)
    # O(?)
    # don't create full list:
    i = 0
    for elem in chain(*bucket):
        yield elem
        i += 1
        if i == k:
            break


import timeit

nums = [i for i in range(1000)]
nums.append(7)
nums.append(88)
nums.append(723)

print(most_frequent_1a(nums, 3))
print(most_frequent_1b(nums, 3))
print(most_frequent_2a(nums, 3))
print(list(most_frequent_2b(nums, 3)))
print(timeit.timeit(stmt='most_frequent_1a(nums, 3)', number=10000, globals=globals()))
print(timeit.timeit(stmt='most_frequent_1b(nums, 3)', number=10000, globals=globals()))
print(timeit.timeit(stmt='most_frequent_2a(nums, 3)', number=10000, globals=globals()))
print(timeit.timeit(stmt='list(most_frequent_2b(nums, 3))', number=10000, globals=globals()))
```
Prints:
```
[7, 723, 88]
[723, 88, 7]
[7, 88, 723]
[7, 88, 723]
3.180169899998873
4.487235299999156
2.710413699998753
2.62860400000136
```
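For comparison, the standard library already packages the heap approach: `Counter.most_common(k)` returns the k (element, count) pairs with the highest counts, and CPython implements it with `heapq.nlargest` when k is given, so it is close in spirit to `most_frequent_1b`. The name `most_frequent_3` is hypothetical, chosen to match the naming above; no timing is claimed for it here, so it would need to be run alongside the others to compare:

```python
import timeit
from collections import Counter


def most_frequent_3(nums, k):
    # most_common(k) -> list of (element, count), highest counts first
    return [num for num, _ in Counter(nums).most_common(k)]


nums = list(range(1000)) + [7, 88, 723]
print(most_frequent_3(nums, 3))
print(timeit.timeit(stmt='most_frequent_3(nums, 3)', number=10000, globals=globals()))
```

It trades the hand-rolled `heappushpop` loop for a single library call, which may or may not be faster depending on the input.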