找到整数排序列表发生变化的索引

Find the indices where a sorted list of integer changes

假设一个排序的整数列表如下:

data = [1] * 3 + [4] * 5 + [5] * 2 + [9] * 3
# [1, 1, 1, 4, 4, 4, 4, 4, 5, 5, 9, 9, 9]

我想找到值发生变化的索引,即

[3, 8, 10, 13]

一种方法是使用 itertools.groupby:

cursor = 0
result = []
for key, group in groupby(data):
    cursor += sum(1 for _ in group)
    result.append(cursor)
print(result)

输出

[3, 8, 10, 13]

这种方法是 O(n)。另一种可能的方法是使用 bisect.bisect_left:

cursor = 0
result = []
while cursor < len(data):
    cursor = bisect_left(data, data[cursor] + 1, cursor, len(data))
    result.append(cursor)
print(result)

输出

[3, 8, 10, 13]

这种方法是 O(k*log n),其中 k 是不同元素的数量。这种方法的一个变体是使用 exponential search.

有没有更快或更高效的方法?

我在两组数据上测试了你的方法的执行时间,并使用 numpy

添加了第三组
data1 = [1] * 30000000 + [2] * 30000000 + [4] * 50000000 + [5] * 20000000 + [7] * 40000000 + [9] * 30000000 + [11] * 10000000 + [15] * 30000000
data2 = list(range(10000000))

cursor = 0
result = []
start_time = time.time()
for key, group in groupby(data):
    cursor += sum(1 for _ in group)
    result.append(cursor)
print(f'groupby {time.time() - start_time} seconds')

cursor = 0
result = []
start_time = time.time()
while cursor < len(data):
    cursor = bisect_left(data, data[cursor] + 1, cursor, len(data))
    result.append(cursor)
print(f'bisect_left {time.time() - start_time} seconds')

data = np.array(data)
start_time = time.time()
[i + 1 for i in np.where(data[:-1] != data[1:])[0]] + [len(data)]
print(f'numpy {time.time() - start_time} seconds')

# We need to iterate over the results array to add 1 to each index for your expected results.

data1

groupby 8.864859104156494 seconds
bisect_left 0.0 seconds
numpy 0.27180027961730957 seconds

data2

groupby 3.602466583251953 seconds
bisect_left 5.440978765487671 seconds
numpy 2.2847368717193604 seconds

正如您所提到的,bisect_left 在很大程度上取决于唯一元素的数量,但使用 numpy 似乎比 itertools.groupby 具有更好的性能,即使对索引进行了额外的迭代列表。

当谈到渐近复杂性时,我认为当您应用更均匀分布的分而治之方法时,平均而言,您可以在二进制搜索上略有改进:尝试首先查明更接近输入列表的中间,从而将范围分成大约两半,这将减少下一个二进制搜索操作路径。

然而,由于这是 Python,增益可能不明显,因为 Python 代码开销(如 yieldyield from、递归,...)。对于您使用的列表大小,它甚至可能表现更差:

from bisect import bisect_left

def locate(data, start, end):
    if start >= end or data[start] == data[end - 1]:
        return
    mid = (start + end) // 2
    val = data[mid] 
    if val == data[start]:
        start = mid
        val += 1
    i = bisect_left(data, val, start + 1, end)
    yield from locate(data, start, i)
    yield i
    yield from locate(data, i, end)

data = [1] * 3 + [4] * 5 + [5] * 2 + [9] * 3
print(*locate(data, 0, len(data)))  # 3 8 10

请注意,这仅输出有效索引,因此此示例输入不包括 13。

既然你说“我对运行时更感兴趣”,这里有一些更快的 groupby 解决方案的 cursor += sum(1 for _ in group) 替代品。

使用@Guy 的 data1 但所有段长度除以 10:

             normal  optimized
original     870 ms  871 ms
zip_count    612 ms  611 ms
count_of     344 ms  345 ms
list_index   387 ms  386 ms
length_hint  223 ms  222 ms

改用list(range(1000000))

             normal  optimized
original     385 ms  331 ms
zip_count    463 ms  401 ms
count_of     197 ms  165 ms
list_index   175 ms  143 ms
length_hint  226 ms  127 ms

优化版本使用更多局部变量或列表理解。

我不认为 __length_hint__ 保证 是准确的,即使对于列表迭代器也是如此,但它似乎是(通过了我的正确性检查)并且我不明白为什么它不会。

代码(Try it online!,但你必须减少一些东西才能不超过时间限制):

from timeit import default_timer as timer
from itertools import groupby, count
from collections import deque
from operator import countOf

def original(data):
    cursor = 0
    result = []
    for key, group in groupby(data):
        cursor += sum(1 for _ in group)
        result.append(cursor)
    return result

def original_opti(data):
    cursor = 0
    sum_ = sum
    return [cursor := cursor + sum_(1 for _ in group)
            for _, group in groupby(data)]

def zip_count(data):
    cursor = 0
    result = []
    for key, group in groupby(data):
        index = count()
        deque(zip(group, index), 0)
        cursor += next(index)
        result.append(cursor)
    return result

def zip_count_opti(data):
    cursor = 0
    result = []
    append = result.append
    count_, deque_, zip_, next_ = count, deque, zip, next
    for key, group in groupby(data):
        index = count_()
        deque_(zip_(group, index), 0)
        cursor += next_(index)
        append(cursor)
    return result

def count_of(data):
    cursor = 0
    result = []
    for key, group in groupby(data):
        cursor += countOf(group, key)
        result.append(cursor)
    return result

def count_of_opti(data):
    cursor = 0
    countOf_ = countOf
    result = [cursor := cursor + countOf_(group, key)
              for key, group in groupby(data)]
    return result

def list_index(data):
    cursor = 0
    result = []
    for key, _ in groupby(data):
        cursor = data.index(key, cursor)
        result.append(cursor)
    del result[0]
    result.append(len(data))
    return result

def list_index_opti(data):
    cursor = 0
    index = data.index
    groups = groupby(data)
    next(groups, None)
    result = [cursor := index(key, cursor)
              for key, _ in groups]
    result.append(len(data))
    return result

def length_hint(data):
    result = []
    it = iter(data)
    for _ in groupby(it):
        result.append(len(data) - (1 + it.__length_hint__()))
    del result[0]
    result.append(len(data))
    return result

def length_hint_opti(data):
    it = iter(data)
    groups = groupby(it)
    next(groups, None)
    n_minus_1 = len(data) - 1
    length_hint = it.__length_hint__
    result = [n_minus_1 - length_hint()
              for _ in groups]
    result.append(len(data))
    return result

funcss = [
    (original, original_opti),
    (zip_count, zip_count_opti),
    (count_of, count_of_opti),
    (list_index, list_index_opti),
    (length_hint, length_hint_opti),
]

data1 = [1] * 3 + [2] * 3 + [4] * 5 + [5] * 2 + [7] * 4 + [9] * 3 + [11] * 1 + [15] * 3
data1 = [x for x in data1 for _ in range(1000000)]
data2 = list(range(1000000))

for _ in range(3):
    for name in 'data1', 'data2':
        print(name)
        data = eval(name)
        expect = None
        for funcs in funcss:
            print(f'{funcs[0].__name__:11}', end='')
            for func in funcs:
                times = []
                for _ in range(5):
                    start_time = timer()
                    result = func(data)
                    end_time = timer()
                    times.append(end_time - start_time)
                print(f'{round(min(times) * 1e3):5d} ms', end='')
                if expect is None:
                    expect = result
                else:
                    assert result == expect
            print()
        print()