How to skip some iterations when using itertools.product?

Suppose there are three sorted lists, A, B, and C.

A = [1, 2, 3, 4]
B = [3, 4, 5]
C = [2, 3, 4]

I am using itertools.product to find all possible combinations whose sum is less than 10.

If I only had three lists, I would use the code below.

A = [1, 2, 3, 4]
B = [3, 4, 5]
C = [2, 3, 4]

for a in A:
    for b in B:
        for c in C:
            if a + b + c < 10:
                print(a, b, c)
            else:
                break

Here every list is sorted, so I use break to stop the innermost loop early and improve efficiency.

But when I use itertools.product, how can I use break? In other words, how can I jump directly to a specific iteration (for example, a = 3, b = 3, c = 3)?

for a, b, c in itertools.product(A, B, C):
   ....?

You can try the following:

from itertools import product, dropwhile

A = [1, 2, 3, 4] 
B = [3, 4, 5] 
C = [2, 3, 4]

for a, b, c in dropwhile(lambda x: x != (3,3,3), product(A, B, C)):
    print(a, b, c)

It gives:

3 3 3
3 3 4
3 4 2
3 4 3
3 4 4
. . .

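Note that this does not really jump directly to the given iteration. Instead, itertools.dropwhile runs the underlying iterator, discarding items, until the target tuple is reached, and only then starts returning its values.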
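To make that cost visible, here is a small sketch that records every tuple the predicate is called on. The names `seen` and `not_target` are introduced here for illustration only; they are not part of the answer above.

```python
from itertools import product, dropwhile

A = [1, 2, 3, 4]
B = [3, 4, 5]
C = [2, 3, 4]

seen = []  # every tuple the predicate was asked about


def not_target(item):
    # Record the call, then report whether we should keep dropping.
    seen.append(item)
    return item != (3, 3, 3)


remaining = list(dropwhile(not_target, product(A, B, C)))

print(len(seen))       # 20: the 19 dropped tuples plus the target itself
print(remaining[0])    # (3, 3, 3)
print(len(remaining))  # 17 of the 36 tuples are returned
```

The 19 tuples before (3, 3, 3) are still generated by product and tested one by one, so nothing is actually skipped.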

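You cannot skip iterations in itertools.product. However, since the lists are sorted, you can reduce the number of iterations by using binary search to find the items that fall below the required difference, and by memoizing the search results: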

import itertools
import bisect


def bisect_fast(A, B, C, threshold):
    seen_b_diff = {}
    seen_c_diff = {}

    for a in A:
        b_diff = threshold - a
        if b_diff not in seen_b_diff:
            index =  bisect.bisect_left(B, b_diff)
            seen_b_diff[b_diff] = index

        # In B we are only interested in items that are less than `b_diff`
        for ind in range(seen_b_diff[b_diff]):
            b = B[ind]
            c_diff = threshold - (a + b)
            # In `C` we are only interested in items that are less than `c_diff`
            if c_diff not in seen_c_diff:
                index = bisect.bisect_left(C, c_diff)
                seen_c_diff[c_diff] = index

            for ind in range(seen_c_diff[c_diff]):
                yield a, b, C[ind] 


def naive(A, B, C, threshold):
    for a, b, c in itertools.product(A, B, C):
        if a + b + c < threshold:
            yield a, b, c
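The key step in bisect_fast is the call to bisect.bisect_left. On a sorted list it returns the index of the first item that is greater than or equal to the limit, which is the same as the number of items strictly below the limit. A minimal illustration follows; the helper name `count_below` is hypothetical and used only here.

```python
import bisect

B = [3, 4, 5]


def count_below(sorted_items, limit):
    # bisect_left returns the first index whose item is >= limit,
    # i.e. the count of items strictly less than limit.
    return bisect.bisect_left(sorted_items, limit)


print(count_below(B, 5))  # 2, because only 3 and 4 are below 5
print(count_below(B, 6))  # 3, every item qualifies
print(count_below(B, 3))  # 0, no item qualifies
```

Iterating over range(count_below(...)) therefore visits only the items that can still keep the sum under the threshold.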

Comparison of the two functions:

>>> from random import choice
>>> A, B, C = [sorted([choice(list(range(1000))) for _ in range(250)]) for _ in range(3)]
>>> list(naive(A, B, C, 1675)) == list(bisect_fast(A, B, C, 1675))
True
>>> %timeit list(bisect_fast(A, B, C, 1675))
1.59 s ± 32.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
>>> %timeit list(naive(A, B, C, 1675))
3.09 s ± 55.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
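If the number of lists is not fixed at three, the nested-loop-with-break idea from the question can be written recursively. The following is only a sketch under the assumption that every list is sorted in ascending order; the name `pruned_product` and its signature are made up for this example and do not come from itertools.

```python
from itertools import product


def pruned_product(lists, threshold):
    """Yield tuples from the product of ascending-sorted lists with sum < threshold."""
    if any(len(lst) == 0 for lst in lists):
        return  # the product with an empty list is empty

    # suffix_min[i] is the smallest sum obtainable from lists[i:].
    suffix_min = [0] * (len(lists) + 1)
    for i in range(len(lists) - 1, -1, -1):
        suffix_min[i] = lists[i][0] + suffix_min[i + 1]

    def walk(depth, prefix, partial):
        if depth == len(lists):
            if partial < threshold:
                yield prefix
            return
        for x in lists[depth]:
            # Even the smallest completion would reach the threshold,
            # and later x values are larger, so stop this loop.
            if partial + x + suffix_min[depth + 1] >= threshold:
                break
            yield from walk(depth + 1, prefix + (x,), partial + x)

    yield from walk(0, (), 0)


A = [1, 2, 3, 4]
B = [3, 4, 5]
C = [2, 3, 4]

result = list(pruned_product([A, B, C], 10))
expected = [t for t in product(A, B, C) if sum(t) < 10]
print(result == expected)  # True
```

Unlike the original triple loop, the break here is applied at every nesting level, so whole branches are abandoned as soon as no completion can stay under the threshold.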