烦人的生成器错误

Annoying generator bug

此错误的原始上下文是一段代码太大,无法 post 在这样的问题中。我不得不将这段代码缩减为仍然存在错误的最小片段。这就是下面显示的代码看起来有些奇怪的原因。

在下面的代码中,class Foo 可能被认为是一种复杂的方式来获得类似 xrange.

的东西
class Foo(object):
    def __init__(self, n):
        self.generator = (x for x in range(n))

    def __iter__(self):
        for e in self.generator:
            yield e

的确,Foo 的行为似乎很像 xrange:

for c in Foo(3):
    print c
# 0
# 1
# 2

print list(Foo(3))
# [0, 1, 2]

现在Foo的子classBar只增加了一个__len__方法:

class Bar(Foo):
    def __len__(self):
        return sum(1 for _ in self.generator)

Barfor 循环中使用时的行为与 Foo 相同:

for c in Bar(3):
    print c
# 0
# 1
# 2

但是:

print list(Bar(3))
# []

我的猜测是,在 list(Bar(3)) 的计算中,Bar(3)__len__ 方法被调用,从而用完了生成器。

(如果这个猜测是正确的,就不需要调用 Bar(3).__len__;毕竟,即使 Foo 没有 __len__ 方法,list(Foo(3)) 也会产生正确的结果.)

这种情况很烦人:list(Foo(3))list(Bar(3)) 没有充分的理由产生不同的结果。

是否可以修复 Bar(当然,不用摆脱它的 __len__ 方法)使得 list(Bar(3)) returns [0, 1, 2]

你的问题是 Foo 的行为与 xrange 不同:每次你询问它的 iter 方法时,xrange 都会给你一个新的迭代器,而 Foo 总是给你相同的,这意味着一旦它耗尽对象太:

>>> a = Foo(3)
>>> list(a)
[0, 1, 2]
>>> list(a)
[]
>>> a = range(3)
>>> list(a)
[0, 1, 2]
>>> list(a)
[0, 1, 2]

我可以很容易地确认 __len__ 方法被 list 调用,方法是在你的方法中添加间谍:

class Bar(Foo):
    def __len__(self):
        print "LEN"
        return sum(1 for _ in self.generator)

(我在 Foo.__iter__ 中添加了一个 print "ITERATOR")。它产生:

>>> list(Bar(3))
LEN
ITERATOR
[]

我只能想象两种解决方法:

  1. 我的首选:return 每次在 Foo 级别调用 __iter__ 时使用一个新的迭代器来模仿 xrange:

    class Foo(object):
        def __init__(self, n):
            self.n = n
    
        def __iter__(self):
            print "ITERATOR"
            return ( x for x in range(self.n))
    
    class Bar(Foo):
        def __len__(self):
            print "LEN"
            return sum(1 for _ in self.generator)
    

    我们得到正确的:

    >>> list(Bar(3))
    ITERATOR
    LEN
    ITERATOR
    [0, 1, 2]
    
  2. 备选方案:将 len 更改为不调用迭代器并让 Foo 保持不变:

    class Bar(Foo):
        def __init__(self, n):
            self.len  = n
            super(Bar, self).__init__(n)
        def __len__(self):
            print "LEN"
            return self.len
    

    我们又得到了:

    >>> list(Bar(3))
    LEN
    ITERATOR
    [0, 1, 2]
    

    但是 Foo 和 Bar 对象一旦第一个迭代器到达其终点就被耗尽。

但是我必须承认我不知道你真实的上下文类...

这种行为可能很烦人,但实际上是可以理解的。在内部 list 只是一个数组,而数组是一个固定大小的数据结构。这样做的结果是,如果你有一个大小为 nlist 并且你想添加一个额外的项目以达到 n+1 它将必须创建一个全新的数组并完全复制旧的到新的。实际上,您的 list.append(x) 现在是 O(n) 操作,而不是常规的 O(1).

为防止这种情况,list() 会尝试获取输入的大小,以便它可以猜测数组需要的大小。

所以这个问题的一个解决方案是使用 iter:

强制它猜测
list(iter(Bar(3)))