Custom cache with iterator does not work as intended

I have the following class, where:

- iterable is the argument passed in, for example range(20)
- n_max is an optional value that limits how many elements the cache should hold
- iterator is a field initialized from the iterable
- cache is the list I want to fill
- finished is a bool that indicates whether the iterator is "empty"

Here is an example input:

>>> iterable = range(20)
>>> cachedtuple = CachedTuple(iterable)
>>> print(cachedtuple[0])
0
>>> print(len(cachedtuple.cache))
1
>>> print(cachedtuple[10])
10
>>> print(len(cachedtuple.cache))
11
>>> print(len(cachedtuple))
20
>>> print(len(cachedtuple.cache))
20
>>> print(cachedtuple[25])


from dataclasses import dataclass, field
from typing import Iterable, Iterator, Optional


@dataclass
class CachedTuple:
    iterable: Iterable = field(init=True)
    n_max: Optional[int] = None
    iterator: Iterator = field(init=False)
    cache: list = field(default_factory=list)
    finished: bool = False

    def __post_init__(self):
        self.iterator = iter(self.iterable)

    def cache_next(self):
        
        if self.n_max and self.n_max <= len(self.cache):
            self.finished = True
        else:
            try:
                nxt = next(self.iterator)
                self.cache.append(nxt)

            except StopIteration:
                self.finished = True


    def __getitem__(self, item: int):

        match item:
            case item if type(item) != int:
                raise IndexError

            case item if item < 0:
                raise IndexError

            case item if self.finished or self.n_max and item > self.n_max:
                raise IndexError(f"Index {item} out of range")

            case item if item >= len(self.cache):
                while item - len(self.cache) >= 0:
                    self.cache_next()

                return self.__getitem__(item)

            case item if item < len(self.cache):
                return self.cache[item]


    def __len__(self):

        while not self.finished:
            self.cache_next()
        return len(self.cache)

Although this code is certainly not good, at least it works in almost every scenario, except for cases like Python's range function. For example, if I use

cachedtuple = CachedTuple(range(20))
for element in cachedtuple:
    print(element)

I get the elements up to 19 and then the program loops forever. I think one problem might be that I never raise StopIteration anywhere in my code, so I am somewhat at a loss as to how to fix this.

Your bug is caused by these lines:

case item if item >= len(self.cache):
    while item - len(self.cache) >= 0:
        self.cache_next()

Basically, CachedTuple((1,2,3))[50] loops forever, because 50 is greater than the length of the cache and self.cache_next() no longer produces any new values.
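This is also why your for loop never terminates: since CachedTuple defines __getitem__ but no __iter__, Python falls back to the legacy iteration protocol, which calls __getitem__(0), __getitem__(1), ... and only stops once an IndexError is raised. For range(20), the call __getitem__(20) never raises; it gets stuck in that while loop instead. A minimal sketch of the fallback behaviour (the class here is purely illustrative):

class LegacyIteration:
    """No __iter__: iteration falls back to calling __getitem__ with 0, 1, 2, ..."""

    def __getitem__(self, index):
        if index >= 3:
            # The for loop only stops when IndexError is raised.
            raise IndexError(index)
        return index * 10


for value in LegacyIteration():
    print(value)  # prints 0, 10, 20, then stops cleanly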

A simple change that adds a check for self.finished will fix it:

case item if item >= len(self.cache):
    while item - len(self.cache) >= 0 and not self.finished:
        self.cache_next()
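With that guard, cache_next() is no longer called once the iterator is exhausted, the recursive call then hits the case item if self.finished ... branch and raises IndexError, so both out-of-range lookups and for loops terminate. A quick sanity check, assuming only this one change to the class above:

cachedtuple = CachedTuple(range(20))

for element in cachedtuple:
    print(element)        # prints 0 through 19, then stops

print(cachedtuple[25])    # raises IndexError: Index 25 out of range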

That said, I believe your code has quite a few other problems, and I think you can improve it considerably:

  1. Drop the match statement. It does nothing useful here.
  2. Implement iteration with __iter__ instead of relying on the legacy __getitem__-based iteration mechanism.
  3. Inherit from collections.abc.Sequence and follow the Sequence protocol (see the minimal sketch below).
  4. Drop the dataclass. This is not a dataclass. You seem to enjoy shiny new language features, but unfortunately none of them is relevant here, which makes your code longer, less clear, and not working as intended.

Remember that simple, readable code matters far more than using new language features.
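Before the full implementation below, here is a minimal sketch of points 2 and 3 (the class name is just for illustration, and it caches eagerly rather than lazily): once you subclass collections.abc.Sequence and supply __getitem__ and __len__, the ABC gives you __iter__, __contains__, __reversed__, index and count for free, and iteration terminates through the normal StopIteration machinery.

from collections.abc import Sequence


class EagerCachedTuple(Sequence):
    """Minimal Sequence subclass: materializes everything up front."""

    def __init__(self, iterable):
        self._items = tuple(iterable)

    def __getitem__(self, index):
        return self._items[index]

    def __len__(self):
        return len(self._items)


ct = EagerCachedTuple(range(20))
print(ct[10])       # 10
print(len(ct))      # 20
print(15 in ct)     # True, __contains__ comes from Sequence
for element in ct:  # __iter__ comes from Sequence and terminates correctly
    pass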


I took the liberty of spending a few hours writing sample code that conforms to collections.abc.Sequence. Enjoy!

from collections.abc import Sequence
import itertools
from typing import Iterable, Iterator, Optional, TypeVar, overload

_T_co = TypeVar("_T_co", covariant=True)

class CachedIterable(Sequence[_T_co]):
    def __init__(self, iterable: Iterable[_T_co], *, max_length: Optional[int] = None) -> None:
        self._cache: list[_T_co] = []
        
        if max_length is not None:
            if max_length <= 0:
                raise ValueError('max_length must be > 0')
            iterable = itertools.islice(iterable, max_length)
        else:
            try:
                # Attempt to optimize and get a length.
                max_length = len(iterable)  # type: ignore
            except TypeError:
                max_length = None

        self._max_length = max_length
        self._iterator: Optional[Iterator] = iter(iterable)
    
    def __repr__(self) -> str:
        return (f'<{self.__class__.__name__} {self._cache!r}'
                f'{"+" if self._iterator else ""}>')
    
    def _exhaust_iterator(self) -> None:
        """Fully exhaust the iterator."""
        assert self._iterator
        try:
            self._cache.extend(self._iterator)
        finally:
            self._iterator = None

    def _advance_iterator(self, n: int) -> None:
        """Attempt to advance the iterator by n steps.

        May advance by less than n steps if the iterator is exhausted.
        """
        assert self._iterator
        
        pre_advance_length = len(self._cache)

        try:
            self._cache.extend(itertools.islice(self._iterator, n))
        except Exception:
            # Iterator threw an exception.
            self._iterator = None
            raise

        # If iterator exhausted, clear it.
        if pre_advance_length + n > len(self._cache):
            self._iterator = None
        
    def _grow_cache(self, size: int) -> None:
        """Atttempt grow the cache to be at least size.
        
        May grow to less than size if the iterator is exhausted.
        """
        if size <= len(self._cache):
            return

        if self._max_length and size >= self._max_length:
            self._exhaust_iterator()
            return
        
        self._advance_iterator(size - len(self._cache))
    
    @overload
    def __getitem__(self, i: int) -> _T_co: ...

    @overload
    def __getitem__(self, s: slice) -> Sequence[_T_co]: ...
        
    def __getitem__(self, index):
        if not isinstance(index, (slice, int)):
            raise TypeError(f'index must be int or slice, not {index!r}')

        if not self._iterator:
            return self._cache[index]

        if isinstance(index, slice):
            # Stop might be less than start if step is negative.
            max_index = max(index.stop or 0, index.start or 0)
            
            # If we're counting from the end, exhaust the iterator.
            if (index.stop is not None and index.stop < 0 or
                    index.start is not None and index.start < 0):
                self._exhaust_iterator()
            
            else:
                self._grow_cache(max_index + 1)

            return self._cache[index]

        # Asking for a number beyond the limit.
        if self._max_length and index > self._max_length:
            raise IndexError(f'index {index} out of range')

        # If we're counting from the end, exhaust the iterator.
        if index < 0:
            self._exhaust_iterator()
        else:
            self._grow_cache(index + 1)

        return self._cache[index]
    
    def __iter__(self) -> Iterator[_T_co]:
        if not self._iterator:
            yield from self._cache
            return
        
        yield from self._cache
        while True:
            try:
                item = next(self._iterator)
            except StopIteration:
                self._iterator = None
                return
            except BaseException:
                # Iterator threw an exception.
                self._iterator = None
                raise

            self._cache.append(item)
            # Yield outside the try block so we do not capture GeneratorExit
            # or other exceptions thrown into the generator via gen.throw().
            yield item


    def __len__(self) -> int:
        # TODO: Can optimize for known lengths.
        if not self._iterator:
            return len(self._cache)

        self._exhaust_iterator()
        return len(self._cache)
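For reference, here is roughly how the class behaves with the example from the question (expected output in comments; the private _cache attribute is inspected only for demonstration):

ci = CachedIterable(range(20))

print(ci[0])           # 0
print(len(ci._cache))  # 1, only one element cached so far
print(ci[10])          # 10
print(len(ci._cache))  # 11
print(len(ci))         # 20, len() exhausts the iterator
print(ci[5:8])         # [5, 6, 7], slices are supported too
ci[25]                 # raises IndexError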