Find the indices where a sorted list of integers changes
Assume a sorted list of integers like the following:
data = [1] * 3 + [4] * 5 + [5] * 2 + [9] * 3
# [1, 1, 1, 4, 4, 4, 4, 4, 5, 5, 9, 9, 9]
I want to find the indices where the value changes, i.e.
[3, 8, 10, 13]
One approach is to use itertools.groupby:
from itertools import groupby

cursor = 0
result = []
for key, group in groupby(data):
    cursor += sum(1 for _ in group)
    result.append(cursor)
print(result)
Output:
[3, 8, 10, 13]
This approach is O(n). Another possible approach is to use bisect.bisect_left:
from bisect import bisect_left

cursor = 0
result = []
while cursor < len(data):
    cursor = bisect_left(data, data[cursor] + 1, cursor, len(data))
    result.append(cursor)
print(result)
Output:
[3, 8, 10, 13]
This approach is O(k log n), where k is the number of distinct elements. A variant of this approach is to use exponential search, as sketched below.
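For illustration, here is a minimal sketch of what that exponential (galloping) search variant could look like; the helper name change_indices_exponential and the galloping loop are my own illustration, not code from any answer:

from bisect import bisect_left

def change_indices_exponential(data):
    # Hypothetical sketch: gallop forward in doubling steps to bracket the
    # end of the current run, then binary-search only inside that bracket.
    result = []
    cursor, n = 0, len(data)
    while cursor < n:
        target = data[cursor] + 1
        lo, step = cursor, 1
        while cursor + step < n and data[cursor + step] < target:
            lo = cursor + step          # still inside the current run
            step *= 2
        hi = min(cursor + step, n)      # first candidate outside the run (or n)
        cursor = bisect_left(data, target, lo + 1, hi)
        result.append(cursor)
    return result

print(change_indices_exponential([1] * 3 + [4] * 5 + [5] * 2 + [9] * 3))
# [3, 8, 10, 13]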
Is there a faster or more efficient approach?
I tested the execution time of your approaches on two sets of data, and added a third one using numpy:
import time
from bisect import bisect_left
from itertools import groupby
import numpy as np

data1 = [1] * 30000000 + [2] * 30000000 + [4] * 50000000 + [5] * 20000000 + [7] * 40000000 + [9] * 30000000 + [11] * 10000000 + [15] * 30000000
data2 = list(range(10000000))
data = data1  # run once with data = data1, then again with data = data2

cursor = 0
result = []
start_time = time.time()
for key, group in groupby(data):
    cursor += sum(1 for _ in group)
    result.append(cursor)
print(f'groupby {time.time() - start_time} seconds')

cursor = 0
result = []
start_time = time.time()
while cursor < len(data):
    cursor = bisect_left(data, data[cursor] + 1, cursor, len(data))
    result.append(cursor)
print(f'bisect_left {time.time() - start_time} seconds')

data = np.array(data)
start_time = time.time()
result = [i + 1 for i in np.where(data[:-1] != data[1:])[0]] + [len(data)]
# We need to iterate over the results array to add 1 to each index for your expected results.
print(f'numpy {time.time() - start_time} seconds')
With data1:
groupby 8.864859104156494 seconds
bisect_left 0.0 seconds
numpy 0.27180027961730957 seconds
and with data2:
groupby 3.602466583251953 seconds
bisect_left 5.440978765487671 seconds
numpy 2.2847368717193604 seconds
As you mentioned, bisect_left depends heavily on the number of unique elements, but numpy seems to perform better than itertools.groupby, even with the extra iteration over the list of indices.
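As a side note, the +1 shift and the trailing len(data) can also stay inside numpy instead of a Python list comprehension. A small sketch (the helper name is mine and it was not part of the timings above; np.flatnonzero is equivalent to np.where(...)[0]):

import numpy as np

def change_indices_numpy(data):
    arr = np.asarray(data)
    # Indices i where arr[i] != arr[i+1], shifted by one to point at the
    # first element of the next run, plus the final boundary len(arr).
    changes = np.flatnonzero(arr[:-1] != arr[1:]) + 1
    return np.append(changes, len(arr))

print(change_indices_numpy([1] * 3 + [4] * 5 + [5] * 2 + [9] * 3).tolist())
# [3, 8, 10, 13]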
When it comes to asymptotic complexity, I think you can improve slightly on the binary search approach on average by applying a more evenly distributed divide-and-conquer strategy: first try to pinpoint a value change that lies closer to the middle of the input list, splitting the range into roughly two halves, which shortens the search paths of the subsequent binary search operations.
However, since this is Python, the gain may not be noticeable because of the overhead of Python code (yield, yield from, recursion, ...). For the list sizes you are using, it might even perform worse:
from bisect import bisect_left

def locate(data, start, end):
    # No change inside an empty range or a range holding a single value.
    if start >= end or data[start] == data[end - 1]:
        return
    mid = (start + end) // 2
    val = data[mid]
    if val == data[start]:
        # The middle still belongs to the first run, so the nearest
        # change lies beyond mid; search for the next larger value.
        start = mid
        val += 1
    # A change index near the middle, splitting the range roughly in half.
    i = bisect_left(data, val, start + 1, end)
    yield from locate(data, start, i)
    yield i
    yield from locate(data, i, end)

data = [1] * 3 + [4] * 5 + [5] * 2 + [9] * 3
print(*locate(data, 0, len(data)))  # 3 8 10
Note that this only outputs valid indices, so 13 is not included for this example input.
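If you also want the final boundary from the question's expected output, you can append len(data) yourself; a small usage sketch, assuming the locate generator defined above:

data = [1] * 3 + [4] * 5 + [5] * 2 + [9] * 3
print([*locate(data, 0, len(data)), len(data)])  # [3, 8, 10, 13]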
Since you said "I'm more interested in the runtime", here are some faster replacements for the cursor += sum(1 for _ in group) in your groupby solution.
Using @Guy's data1 but with all segment lengths divided by 10:
             normal     optimized
original     870 ms     871 ms
zip_count    612 ms     611 ms
count_of     344 ms     345 ms
list_index   387 ms     386 ms
length_hint  223 ms     222 ms
Using list(range(1000000)) instead:
             normal     optimized
original     385 ms     331 ms
zip_count    463 ms     401 ms
count_of     197 ms     165 ms
list_index   175 ms     143 ms
length_hint  226 ms     127 ms
The optimized versions use more local variables or a list comprehension.
I don't think __length_hint__ is guaranteed to be exact, even for list iterators, but it appears to be (it passed my correctness checks), and I don't see why it wouldn't be.
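To make the idea concrete, here is a tiny illustration (my own, assuming CPython's list iterator keeps its remaining-length report exact) of how the hint translates into a consumed position:

data = [10, 20, 30, 40]
it = iter(data)
next(it)                                     # consume the element at index 0
print(it.__length_hint__())                  # 3 items remain
print(len(data) - 1 - it.__length_hint__())  # 0, the index just consumed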
Code (Try it online!, but you have to reduce something in order to not exceed the time limit):
from timeit import default_timer as timer
from itertools import groupby, count
from collections import deque
from operator import countOf

def original(data):
    cursor = 0
    result = []
    for key, group in groupby(data):
        cursor += sum(1 for _ in group)
        result.append(cursor)
    return result

def original_opti(data):
    cursor = 0
    sum_ = sum
    return [cursor := cursor + sum_(1 for _ in group)
            for _, group in groupby(data)]

def zip_count(data):
    # Exhaust each group through zip() into a zero-length deque;
    # the count() iterator then reports how many items were consumed.
    cursor = 0
    result = []
    for key, group in groupby(data):
        index = count()
        deque(zip(group, index), 0)
        cursor += next(index)
        result.append(cursor)
    return result

def zip_count_opti(data):
    cursor = 0
    result = []
    append = result.append
    count_, deque_, zip_, next_ = count, deque, zip, next
    for key, group in groupby(data):
        index = count_()
        deque_(zip_(group, index), 0)
        cursor += next_(index)
        append(cursor)
    return result

def count_of(data):
    # Every element of a group equals its key, so countOf gives
    # the group length at C speed.
    cursor = 0
    result = []
    for key, group in groupby(data):
        cursor += countOf(group, key)
        result.append(cursor)
    return result

def count_of_opti(data):
    cursor = 0
    countOf_ = countOf
    result = [cursor := cursor + countOf_(group, key)
              for key, group in groupby(data)]
    return result

def list_index(data):
    # data.index(key, cursor) finds the start index of each group;
    # drop the leading 0 and append len(data) to get the change indices.
    cursor = 0
    result = []
    for key, _ in groupby(data):
        cursor = data.index(key, cursor)
        result.append(cursor)
    del result[0]
    result.append(len(data))
    return result

def list_index_opti(data):
    cursor = 0
    index = data.index
    groups = groupby(data)
    next(groups, None)
    result = [cursor := index(key, cursor)
              for key, _ in groups]
    result.append(len(data))
    return result

def length_hint(data):
    # The list iterator reports how many items remain; since groupby has
    # already consumed the first element of the current group,
    # len(data) - (1 + hint) is the group's start index.
    result = []
    it = iter(data)
    for _ in groupby(it):
        result.append(len(data) - (1 + it.__length_hint__()))
    del result[0]
    result.append(len(data))
    return result

def length_hint_opti(data):
    it = iter(data)
    groups = groupby(it)
    next(groups, None)
    n_minus_1 = len(data) - 1
    length_hint = it.__length_hint__
    result = [n_minus_1 - length_hint()
              for _ in groups]
    result.append(len(data))
    return result

funcss = [
    (original, original_opti),
    (zip_count, zip_count_opti),
    (count_of, count_of_opti),
    (list_index, list_index_opti),
    (length_hint, length_hint_opti),
]

data1 = [1] * 3 + [2] * 3 + [4] * 5 + [5] * 2 + [7] * 4 + [9] * 3 + [11] * 1 + [15] * 3
data1 = [x for x in data1 for _ in range(1000000)]
data2 = list(range(1000000))

# Benchmark: best of 5 runs per function, all results cross-checked.
for _ in range(3):
    for name in 'data1', 'data2':
        print(name)
        data = eval(name)
        expect = None
        for funcs in funcss:
            print(f'{funcs[0].__name__:11}', end='')
            for func in funcs:
                times = []
                for _ in range(5):
                    start_time = timer()
                    result = func(data)
                    end_time = timer()
                    times.append(end_time - start_time)
                print(f'{round(min(times) * 1e3):5d} ms', end='')
                if expect is None:
                    expect = result
                else:
                    assert result == expect
            print()
        print()