zip iterators asserting for equal length in python

I'm looking for a nice way to zip several iterables that raises an exception if the lengths of the iterables are not equal.

When the iterables are lists or otherwise support len(), this solution is clean and easy:

def zip_equal(it1, it2):
    if len(it1) != len(it2):
        raise ValueError("Lengths of iterables are different")
    return zip(it1, it2)

However, if it1 and it2 are generators, the previous function fails because their length is not defined: TypeError: object of type 'generator' has no len()

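I imagine the itertools module offers a simple way to implement this, but so far I have not been able to find it. I have come up with this home-made solution: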

def zip_equal(it1, it2):
    it1, it2 = iter(it1), iter(it2)  # also accept lists and other non-iterator iterables
    exhausted = False
    while True:
        try:
            el1 = next(it1)
            if exhausted: # in a previous iteration it2 was exhausted but it1 still has elements
                raise ValueError("it1 and it2 have different lengths")
        except StopIteration:
            exhausted = True
            # it2 must be exhausted too.
        try:
            el2 = next(it2)
            # here it2 is not exhausted.
            if exhausted:  # it1 was exhausted => raise
                raise ValueError("it1 and it2 have different lengths")
        except StopIteration:
            # here it2 is exhausted
            if not exhausted:
                # but it1 was not exhausted => raise
                raise ValueError("it1 and it2 have different lengths")
            exhausted = True
        if not exhausted:
            yield (el1, el2)
        else:
            return

The solution can be tested with the following code:

it1 = (x for x in ['a', 'b', 'c'])  # it1 has length 3
it2 = (x for x in [0, 1, 2, 3])     # it2 has length 4
list(zip_equal(it1, it2))           # len(it1) < len(it2) => raise
it1 = (x for x in ['a', 'b', 'c'])  # it1 has length 3
it2 = (x for x in [0, 1, 2, 3])     # it2 has length 4
list(zip_equal(it2, it1))           # len(it2) > len(it1) => raise
it1 = (x for x in ['a', 'b', 'c', 'd'])  # it1 has length 4
it2 = (x for x in [0, 1, 2, 3])          # it2 has length 4
list(zip_equal(it1, it2))                # like zip (or izip in python2)
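The same element-by-element logic can be condensed by passing a default to next(), which then returns a private sentinel object instead of raising StopIteration. Below is a minimal sketch; zip_equal_pairwise and check are hypothetical names, not from the question. It also runs the three test cases above without stopping at the first exception:

```python
def zip_equal_pairwise(it1, it2):
    """Yield pairs from it1 and it2; raise ValueError if their lengths differ."""
    sentinel = object()  # private marker that cannot be a real element
    it1, it2 = iter(it1), iter(it2)
    while True:
        el1 = next(it1, sentinel)  # returns sentinel instead of raising StopIteration
        el2 = next(it2, sentinel)
        if el1 is sentinel or el2 is sentinel:
            if el1 is not el2:  # exactly one of the two ran out
                raise ValueError("it1 and it2 have different lengths")
            return  # both ran out on the same round
        yield (el1, el2)


def check(it1, it2):
    """Return the zipped list, or the string 'ValueError' if the lengths differ."""
    try:
        return list(zip_equal_pairwise(it1, it2))
    except ValueError:
        return 'ValueError'


print(check((x for x in ['a', 'b', 'c']), (x for x in [0, 1, 2, 3])))       # ValueError
print(check((x for x in [0, 1, 2, 3]), (x for x in ['a', 'b', 'c'])))       # ValueError
print(check((x for x in ['a', 'b', 'c', 'd']), (x for x in [0, 1, 2, 3])))  # [('a', 0), ('b', 1), ('c', 2), ('d', 3)]
```

Like the question's version, this pulls one element too many from the second iterator when the first runs out early, which is harmless because an exception is raised anyway.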

Am I overlooking an alternative solution? Is there a simpler implementation of my zip_equal function?

Update:

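I can think of a simpler solution: use itertools.zip_longest() and raise an exception if the sentinel value used to pad out the shorter iterables is present in a produced tuple: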

from itertools import zip_longest

def zip_equal(*iterables):
    sentinel = object()
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        if sentinel in combo:
            raise ValueError('Iterables have different lengths')
        yield combo
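A quick usage check of the zip_longest() approach with three inputs, one of them a generator, plus a mismatched pair whose ValueError is caught. The demo values are arbitrary:

```python
from itertools import zip_longest


def zip_equal(*iterables):
    sentinel = object()  # unique fill value that cannot appear in the inputs
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        if sentinel in combo:
            raise ValueError('Iterables have different lengths')
        yield combo


# Works with any number of iterables, including generators.
print(list(zip_equal('ab', range(2), (x * 10 for x in range(2)))))
# [('a', 0, 0), ('b', 1, 10)]

try:
    list(zip_equal('abc', range(2)))
except ValueError as exc:
    print('caught:', exc)
# caught: Iterables have different lengths
```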

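Unfortunately, we can't use zip() with yield from to avoid a Python-code loop with a test on every iteration: once the shortest iterator runs out, zip() will already have advanced all the preceding iterators, swallowing the evidence if they have just one extra item.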
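The "swallowed evidence" can be observed directly. In this small illustration (zip_and_leftover is a hypothetical helper), a 3-item iterator is zipped with a 2-item one, in both argument orders, and we look at what is left of the longer one afterwards:

```python
def zip_and_leftover(longer_first):
    """Zip a 3-item iterator with a 2-item one; return the pairs and what is left of the longer one."""
    longer = iter([1, 2, 3])
    shorter = iter([10, 20])
    if longer_first:
        # Round 3: zip() takes 3 from `longer`, then finds `shorter` empty and drops the 3.
        pairs = list(zip(longer, shorter))
    else:
        # Round 3: zip() finds `shorter` empty first and never touches `longer` again.
        pairs = list(zip(shorter, longer))
    return pairs, list(longer)


print(zip_and_leftover(True))   # ([(1, 10), (2, 20)], [])
print(zip_and_leftover(False))  # ([(10, 1), (20, 2)], [3])
```

With the longer iterator first, nothing remains afterwards that would reveal the length mismatch.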

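Here is an approach that doesn't require any extra checks on each loop of the iteration. This could be desirable, especially for long iterables.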

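The idea is to pad each iterable at the end with a "value" that raises an exception when it is reached, and then do the needed verification only at the end. The approach uses zip() and itertools.chain().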

The code below was written for Python 3.5.
import itertools

class ExhaustedError(Exception):
    def __init__(self, index):
        """The index is the 0-based index of the exhausted iterable."""
        self.index = index

def raising_iter(i):
    """Return an iterator that raises an ExhaustedError."""
    raise ExhaustedError(i)
    yield

def terminate_iter(i, iterable):
    """Return an iterator that raises an ExhaustedError at the end."""
    return itertools.chain(iterable, raising_iter(i))

def zip_equal(*iterables):
    iterators = [terminate_iter(*args) for args in enumerate(iterables)]
    try:
        yield from zip(*iterators)
    except ExhaustedError as exc:
        index = exc.index
        if index > 0:
            raise RuntimeError('iterable {} exhausted first'.format(index)) from None
        # Check that all other iterators are also exhausted.
        for i, iterator in enumerate(iterators[1:], start=1):
            try:
                next(iterator)
            except ExhaustedError:
                pass
            else:
                raise RuntimeError('iterable {} is longer'.format(i)) from None

Here is how it is used:

>>> list(zip_equal([1, 2], [3, 4], [5, 6]))
[(1, 3, 5), (2, 4, 6)]

>>> list(zip_equal([1, 2], [3], [4]))
RuntimeError: iterable 1 exhausted first

>>> list(zip_equal([1], [2, 3], [4]))
RuntimeError: iterable 1 is longer

>>> list(zip_equal([1], [2], [3, 4]))
RuntimeError: iterable 2 is longer
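The core mechanism, a generator that raises as soon as it is first advanced and is chained after the real data, can be seen in isolation. In this sketch EndReached and items_before_end are hypothetical names standing in for ExhaustedError and the zip machinery above:

```python
from itertools import chain


class EndReached(Exception):
    """Hypothetical marker exception, playing the role of ExhaustedError above."""


def raise_when_reached():
    raise EndReached
    yield  # never executed; its presence makes this a generator function


def items_before_end(iterable):
    """Read from chain(iterable, tail) until the raising tail fires; return what was read."""
    it = chain(iterable, raise_when_reached())  # creating the tail generator does not raise yet
    items = []
    while True:
        try:
            items.append(next(it))
        except EndReached:
            return items


print(items_before_end(x * x for x in range(3)))  # [0, 1, 4]
print(items_before_end([]))                       # []
```

Because chain() only starts the tail once the real items are used up, the exception marks exactly the end of that iterable, and its index (in the original code) tells which input ran out.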

FYI, I came up with a solution using a sentinel iterable:

import itertools


class _SentinelException(Exception):
    def __iter__(self):
        raise _SentinelException


def zip_equal(iterable1, iterable2):
    i1 = iter(itertools.chain(iterable1, _SentinelException()))
    i2 = iter(iterable2)
    try:
        while True:
            yield (next(i1), next(i2))
    except _SentinelException:  # i1 reached its end
        try:
            next(i2)  # check whether i2 has also reached its end
        except StopIteration:
            pass
        else:
            raise ValueError('the second iterable is longer than the first one')
    except StopIteration:  # i2 reached its end; next(i1) was already called, so i1 is longer than i2
        raise ValueError('the first iterable is longer than the second one')
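This trick relies on itertools.chain() calling iter() on each of its arguments lazily, only when it reaches that argument. A small sketch of just that behavior, using a hypothetical Boom class that mirrors _SentinelException:

```python
from itertools import chain


class Boom(Exception):
    """Same trick as _SentinelException: iterating an instance raises the class itself."""
    def __iter__(self):
        raise Boom


def read_until_sentinel(iterable):
    """Collect items from chain(iterable, Boom()) until iter(Boom()) is attempted."""
    it = chain(iterable, Boom())  # no error here: chain() has not called iter() on Boom() yet
    items = []
    try:
        while True:
            items.append(next(it))
    except Boom:
        return items


print(read_until_sentinel([1, 2]))  # [1, 2]
print(read_until_sentinel(()))      # []
```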

Use more_itertools.zip_equal (v8.3.0+):

Code

import more_itertools as mit

Demo

list(mit.zip_equal(range(3), "abc"))
# [(0, 'a'), (1, 'b'), (2, 'c')]

list(mit.zip_equal(range(3), "abcd"))
# UnequalIterablesError

more_itertools is a third-party package, installed via pip install more_itertools

PEP 618 introduces an optional boolean keyword argument, strict, to the built-in zip function.

Quoting What’s New In Python 3.10:

The zip() function now has an optional strict flag, used to require that all the iterables have an equal length.

When enabled, a ValueError is raised if one of the arguments is exhausted before the others:

>>> list(zip('ab', range(3)))
[('a', 0), ('b', 1)]
>>> list(zip('ab', range(3), strict=True))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: zip() argument 2 is longer than argument 1

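Here is a new solution, even faster than cjerdonek's solution that it is based on, along with a benchmark. The benchmark comes first; my solution is the green line. Note that the "total size" is the same in all cases: two million values. The x-axis is the number of iterables, from 1 iterable with two million values, then 2 iterables with one million values each, all the way up to 100,000 iterables with 20 values each.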

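The black line is Python's zip. I used Python 3.8 here, so it doesn't do this question's task of checking for equal lengths, but I include it as a reference for the maximum speed one can hope for. You can see that my solution is pretty close.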

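For the probably most common case of zipping two iterables, mine is almost three times as fast as the previous fastest solution by cjerdonek, and not much slower than zip. Times in milliseconds, as text: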

         number of iterables     1     2     3     4     5    10   100  1000 10000 50000 100000
-----------------------------------------------------------------------------------------------
       more_itertools__pylang 209.3 132.1 105.8  93.7  87.4  74.4  54.3  51.9  53.9  66.9  84.5
   fillvalue__Martijn_Pieters 159.1 101.5  85.6  74.0  68.8  59.0  44.1  43.0  44.9  56.9  72.0
     chain_raising__cjerdonek  58.5  35.1  26.3  21.9  19.7  16.6  10.4  12.7  34.4 115.2 223.2
     ziptail__Stefan_Pochmann  10.3  12.4  10.4   9.2   8.7   7.8   6.7   6.8   9.4  22.6  37.8
                          zip  10.3   8.5   7.8   7.4   7.4   7.1   6.4   6.8   9.0  19.4  32.3
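The two-iterable claims can be checked against the table with a little arithmetic; the numbers below are copied from the "2" column above:

```python
# Times in ms for 2 iterables, taken from the "2" column of the table above.
times_2 = {
    'chain_raising__cjerdonek': 35.1,
    'ziptail__Stefan_Pochmann': 12.4,
    'zip': 8.5,
}
speedup_vs_cjerdonek = times_2['chain_raising__cjerdonek'] / times_2['ziptail__Stefan_Pochmann']
overhead_vs_zip = times_2['ziptail__Stefan_Pochmann'] / times_2['zip']
print(f'{speedup_vs_cjerdonek:.2f}x faster than chain_raising')  # 2.83x faster than chain_raising
print(f'{overhead_vs_zip:.2f}x the time of plain zip')           # 1.46x the time of plain zip
```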

My code:

from itertools import chain

def zip_equal(*iterables):

    # For trivial cases, use pure zip.
    if len(iterables) < 2:
        return zip(*iterables)

    # Tail for the first iterable
    first_stopped = False
    def first_tail():
        nonlocal first_stopped 
        first_stopped = True
        return
        yield

    # Tail for the zip
    def zip_tail():
        if not first_stopped:
            raise ValueError('zip_equal: first iterable is longer')
        for _ in chain.from_iterable(rest):
            raise ValueError('zip_equal: first iterable is shorter')
            yield

    # Put the pieces together
    iterables = iter(iterables)
    first = chain(next(iterables), first_tail())
    rest = list(map(iter, iterables))
    return chain(zip(first, *rest), zip_tail())

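The basic idea is to let zip(*iterables) do all the work, and then, after it has stopped because some iterable was exhausted, check whether all iterables were equally long. They were if and only if: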

  1. zip stopped because the first iterable had no more elements (i.e., no other iterable is shorter).
  2. None of the other iterables has any further elements (i.e., no other iterable is longer).

How I check these criteria:

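  • Since I need to check these criteria after zip ends, I can't simply return the zip object. Instead, I chain an empty zip_tail iterator behind it that does the checking.
  • To support checking the first criterion, I chain an empty first_tail iterator behind the first iterable. Its only job is to record that the first iterable's iteration stopped (i.e., it was asked for another element and didn't have one, so the first_tail iterator was asked instead).
  • To support checking the second criterion, I fetch iterators for all the other iterables and keep them in a list before giving them to zip, so that afterwards I can check whether any of them still has an element.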
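The two checks can be made visible with a stripped-down version of the same pieces. This sketch uses a hypothetical stop_reason helper that, instead of raising, reports whether the first iterable stopped and what is left in the others:

```python
from itertools import chain


def stop_reason(first, *rest):
    """Hypothetical helper: zip everything, then report the two facts zip_tail() checks.

    Returns (first_stopped, leftovers): whether zip() stopped because `first`
    ran out, and what remains in each of the other iterators afterwards.
    """
    first_stopped = False

    def first_tail():
        nonlocal first_stopped
        first_stopped = True  # only runs once `first` has no more elements
        return
        yield  # makes first_tail a generator function

    rest_iters = [iter(r) for r in rest]  # keep references so they can be inspected afterwards
    for _ in zip(chain(first, first_tail()), *rest_iters):
        pass
    return first_stopped, [list(it) for it in rest_iters]


print(stop_reason([1, 2], [3, 4]))     # (True, [[]])   equal lengths
print(stop_reason([1, 2, 3], [4, 5]))  # (False, [[]])  first is longer
print(stop_reason([1], [2, 3]))        # (True, [[3]])  first is shorter
```

The lengths are equal exactly when first_stopped is True and every leftover list is empty, which is what zip_tail() tests.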

Side note: more-itertools uses pretty much the same method as Martijn, but does proper is checks instead of Martijn's sentinel in combo. That's probably the main reason it's slower.
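Why an is check is the "proper" one: the in operator falls back to == after its identity check, so an element with a permissive __eq__ can be mistaken for the sentinel. A small illustration with a hypothetical AlwaysEqual element type:

```python
from itertools import zip_longest


class AlwaysEqual:
    """Hypothetical element type whose __eq__ claims to equal everything."""
    def __eq__(self, other):
        return True


def has_fill_in(combo, sentinel):
    return sentinel in combo  # identity check first, then falls back to ==


def has_fill_is(combo, sentinel):
    return any(value is sentinel for value in combo)  # identity only


sentinel = object()
# Both inputs have length 1, so no fill value is actually present.
combo = next(zip_longest([AlwaysEqual()], [1], fillvalue=sentinel))
print(has_fill_in(combo, sentinel))  # True (a false alarm)
print(has_fill_is(combo, sentinel))  # False
```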

Benchmark code:

import timeit
import itertools
from itertools import repeat, chain, zip_longest
from collections import deque
from sys import hexversion, maxsize
import warnings  # used by the copied more_itertools code below

#-----------------------------------------------------------------------------
# Solution by Martijn Pieters
#-----------------------------------------------------------------------------

def zip_equal__fillvalue__Martijn_Pieters(*iterables):
    sentinel = object()
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        if sentinel in combo:
            raise ValueError('Iterables have different lengths')
        yield combo

#-----------------------------------------------------------------------------
# Solution by pylang
#-----------------------------------------------------------------------------

def zip_equal__more_itertools__pylang(*iterables):
    return more_itertools__zip_equal(*iterables)

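class UnequalIterablesError(ValueError):
    # Local stand-in for more_itertools' exception so this copied code runs standalone
    def __init__(self, details=None):
        msg = 'Iterables have different lengths'
        if details is not None:
            msg += ': index 0 has length {}; index {} has length {}'.format(*details)
        super().__init__(msg)

_marker = object()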

def _zip_equal_generator(iterables):
    for combo in zip_longest(*iterables, fillvalue=_marker):
        for val in combo:
            if val is _marker:
                raise UnequalIterablesError()
        yield combo

def more_itertools__zip_equal(*iterables):
    """``zip`` the input *iterables* together, but raise
    ``UnequalIterablesError`` if they aren't all the same length.

        >>> it_1 = range(3)
        >>> it_2 = iter('abc')
        >>> list(zip_equal(it_1, it_2))
        [(0, 'a'), (1, 'b'), (2, 'c')]

        >>> it_1 = range(3)
        >>> it_2 = iter('abcd')
        >>> list(zip_equal(it_1, it_2)) # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        more_itertools.more.UnequalIterablesError: Iterables have different
        lengths

    """
    if hexversion >= 0x30A00A6:
        warnings.warn(
            (
                'zip_equal will be removed in a future version of '
                'more-itertools. Use the builtin zip function with '
                'strict=True instead.'
            ),
            DeprecationWarning,
        )
    # Check whether the iterables are all the same size.
    try:
        first_size = len(iterables[0])
        for i, it in enumerate(iterables[1:], 1):
            size = len(it)
            if size != first_size:
                break
        else:
            # If we didn't break out, we can use the built-in zip.
            return zip(*iterables)

        # If we did break out, there was a mismatch.
        raise UnequalIterablesError(details=(first_size, i, size))
    # If any one of the iterables didn't have a length, start reading
    # them until one runs out.
    except TypeError:
        return _zip_equal_generator(iterables)

#-----------------------------------------------------------------------------
# Solution by cjerdonek
#-----------------------------------------------------------------------------

class ExhaustedError(Exception):
    def __init__(self, index):
        """The index is the 0-based index of the exhausted iterable."""
        self.index = index

def raising_iter(i):
    """Return an iterator that raises an ExhaustedError."""
    raise ExhaustedError(i)
    yield

def terminate_iter(i, iterable):
    """Return an iterator that raises an ExhaustedError at the end."""
    return itertools.chain(iterable, raising_iter(i))

def zip_equal__chain_raising__cjerdonek(*iterables):
    iterators = [terminate_iter(*args) for args in enumerate(iterables)]
    try:
        yield from zip(*iterators)
    except ExhaustedError as exc:
        index = exc.index
        if index > 0:
            raise RuntimeError('iterable {} exhausted first'.format(index)) from None
        # Check that all other iterators are also exhausted.
        for i, iterator in enumerate(iterators[1:], start=1):
            try:
                next(iterator)
            except ExhaustedError:
                pass
            else:
                raise RuntimeError('iterable {} is longer'.format(i)) from None
            
#-----------------------------------------------------------------------------
# Solution by Stefan Pochmann
#-----------------------------------------------------------------------------

def zip_equal__ziptail__Stefan_Pochmann(*iterables):

    # For trivial cases, use pure zip.
    if len(iterables) < 2:
        return zip(*iterables)

    # Tail for the first iterable
    first_stopped = False
    def first_tail():
        nonlocal first_stopped 
        first_stopped = True
        return
        yield

    # Tail for the zip
    def zip_tail():
        if not first_stopped:
            raise ValueError(f'zip_equal: first iterable is longer')
        for _ in chain.from_iterable(rest):
            raise ValueError(f'zip_equal: first iterable is shorter')
            yield

    # Put the pieces together
    iterables = iter(iterables)
    first = chain(next(iterables), first_tail())
    rest = list(map(iter, iterables))
    return chain(zip(first, *rest), zip_tail())

#-----------------------------------------------------------------------------
# List of solutions to be speedtested
#-----------------------------------------------------------------------------

solutions = [
    zip_equal__more_itertools__pylang,
    zip_equal__fillvalue__Martijn_Pieters,
    zip_equal__chain_raising__cjerdonek,
    zip_equal__ziptail__Stefan_Pochmann,
    zip,
]

def name(solution):
    return solution.__name__[11:] or 'zip'

#-----------------------------------------------------------------------------
# The speedtest code
#-----------------------------------------------------------------------------

def test(m, n):
    """Speedtest all solutions with m iterables of n elements each."""

    all_times = {solution: [] for solution in solutions}
    def show_title():
        print(f'{m} iterators of length {n:,}:')
    if verbose: show_title()
    def show_times(times, solution):
        print(*('%3d ms ' % t for t in times),
              name(solution))
        
    for _ in range(3):
        for solution in solutions:
            times = sorted(timeit.repeat(lambda: deque(solution(*(repeat(i, n) for i in range(m))), 0), number=1, repeat=5))[:3]
            times = [round(t * 1e3, 3) for t in times]
            all_times[solution].append(times)
            if verbose: show_times(times, solution)
        if verbose: print()
        
    if verbose:
        print('best by min:')
        show_title()
        for solution in solutions:
            show_times(min(all_times[solution], key=min), solution)
        print('best by max:')
    show_title()
    for solution in solutions:
        show_times(min(all_times[solution], key=max), solution)
    print()

    stats.append((m,
                  [min(all_times[solution], key=min)
                   for solution in solutions]))

#-----------------------------------------------------------------------------
# Run the speedtest for several numbers of iterables
#-----------------------------------------------------------------------------

stats = []
verbose = False
total_elements = 2 * 10**6
for m in 1, 2, 3, 4, 5, 10, 100, 1000, 10000, 50000, 100000:
    test(m, total_elements // m)

#-----------------------------------------------------------------------------
# Print the speedtest results for use in the plotting script
#-----------------------------------------------------------------------------

print('data for plotting by https://replit.com/@pochmann/zipequal-plot')
names = [name(solution) for solution in solutions]
print(f'{names = }')
print(f'{stats = }')

Code for the plotting/table (also at Replit):

import matplotlib.pyplot as plt

names = ['more_itertools__pylang', 'fillvalue__Martijn_Pieters', 'chain_raising__cjerdonek', 'ziptail__Stefan_Pochmann', 'zip']
stats = [(1, [[208.762, 211.211, 214.189], [159.568, 162.233, 162.24], [57.668, 58.94, 59.23], [10.418, 10.583, 10.723], [10.057, 10.443, 10.456]]), (2, [[130.065, 130.26, 130.52], [100.314, 101.206, 101.276], [34.405, 34.998, 35.188], [12.152, 12.473, 12.773], [8.671, 8.857, 9.395]]), (3, [[106.417, 107.452, 107.668], [90.693, 91.154, 91.386], [26.908, 27.863, 28.145], [10.457, 10.461, 10.789], [8.071, 8.157, 8.228]]), (4, [[97.547, 98.686, 98.726], [77.076, 78.31, 79.381], [23.134, 23.176, 23.181], [9.321, 9.4, 9.581], [7.541, 7.554, 7.635]]), (5, [[86.393, 88.046, 88.222], [68.633, 69.649, 69.742], [19.845, 20.006, 20.135], [8.726, 8.935, 9.016], [7.201, 7.26, 7.304]]), (10, [[70.384, 71.762, 72.473], [57.87, 58.149, 58.411], [15.808, 16.252, 16.262], [7.568, 7.57, 7.864], [6.732, 6.888, 6.911]]), (100, [[53.108, 54.245, 54.465], [44.436, 44.601, 45.226], [10.502, 11.073, 11.109], [6.721, 6.733, 6.847], [6.753, 6.774, 6.815]]), (1000, [[52.119, 52.476, 53.341], [42.775, 42.808, 43.649], [12.538, 12.853, 12.862], [6.802, 6.971, 7.002], [6.679, 6.724, 6.838]]), (10000, [[54.802, 55.006, 55.187], [45.981, 46.066, 46.735], [34.416, 34.672, 35.009], [9.485, 9.509, 9.626], [9.036, 9.042, 9.112]]), (50000, [[66.681, 66.98, 67.441], [56.593, 57.341, 57.631], [113.988, 114.022, 114.106], [22.088, 22.412, 22.595], [19.412, 19.431, 19.934]]), (100000, [[86.846, 88.111, 88.258], [74.796, 75.431, 75.927], [218.977, 220.182, 223.343], [38.89, 39.385, 39.88], [32.332, 33.117, 33.594]])]

colors = {
    'more_itertools__pylang': 'm',
    'fillvalue__Martijn_Pieters': 'red',
    'chain_raising__cjerdonek': 'gold',
    'ziptail__Stefan_Pochmann': 'lime',
    'zip': 'black',
}

ns = [n for n, _ in stats]
print('%28s' % 'number of iterables', *('%5d' % n for n in ns))
print('-' * 95)
x = range(len(ns))
for i, name in enumerate(names):
    ts = [min(tss[i]) for _, tss in stats]
    color = colors[name]
    if color:
        plt.plot(x, ts, '.-', color=color, label=name)
        print('%29s' % name, *('%5.1f' % t for t in ts))
plt.xticks(x, ns, size=9)
plt.ylim(0, 133)
plt.title('zip_equal(m iterables with 2,000,000/m values each)', weight='bold')
plt.xlabel('Number of zipped *iterables* (not their lengths)', weight='bold')
plt.ylabel('Time (for complete iteration) in milliseconds', weight='bold')
plt.legend(loc='upper center')
#plt.show()
plt.savefig('zip_equal_plot.png', dpi=200)