Custom cache with iterator does not work as intended
I have the following class, where:
- iterable is the argument passed in, e.g. range(20),
- n_max is an optional value that limits how many elements the cache may hold,
- iterator is a field initialized from the iterable,
- cache is the list I want to fill, and
- finished is a bool indicating whether the iterator is "empty".

Here is an example session:
>>> iterable = range(20)
>>> cachedtuple = CachedTuple(iterable)
>>> print(cachedtuple[0])
0
>>> print(len(cachedtuple.cache))
1
>>> print(cachedtuple[10])
10
>>> print(len(cachedtuple.cache))
11
>>> print(len(cachedtuple))
20
>>> print(len(cachedtuple.cache))
20
>>> print(cachedtuple[25])
from dataclasses import dataclass, field
from typing import Iterable, Iterator, Optional

@dataclass
class CachedTuple:
    iterable: Iterable = field(init=True)
    n_max: Optional[int] = None
    iterator: Iterator = field(init=False)
    cache: list = field(default_factory=list)
    finished: bool = False

    def __post_init__(self):
        self.iterator = iter(self.iterable)

    def cache_next(self):
        if self.n_max and self.n_max <= len(self.cache):
            self.finished = True
        else:
            try:
                nxt = next(self.iterator)
                self.cache.append(nxt)
            except StopIteration:
                self.finished = True

    def __getitem__(self, item: int):
        match item:
            case item if type(item) != int:
                raise IndexError
            case item if item < 0:
                raise IndexError
            case item if self.finished or self.n_max and item > self.n_max:
                raise IndexError(f"Index {item} out of range")
            case item if item >= len(self.cache):
                while item - len(self.cache) >= 0:
                    self.cache_next()
                return self.__getitem__(item)
            case item if item < len(self.cache):
                return self.cache[item]

    def __len__(self):
        while not self.finished:
            self.cache_next()
        return len(self.cache)
This code is certainly not great, but at least it works in almost every scenario, except for example with Python's range function. If I run

cachedtuple = CachedTuple(range(20))
for element in cachedtuple:
    print(element)

I get the elements up to 19, and then the program loops forever. I think one problem might be that I never raise StopIteration in my code. So I am a bit lost on how to solve this.
Your bug is caused by these lines:

case item if item >= len(self.cache):
    while item - len(self.cache) >= 0:
        self.cache_next()

Basically, CachedTuple((1, 2, 3))[50] loops forever, because 50 is larger than the length of the cache and self.cache_next() no longer produces any new values.
A simple change that adds a check of self.finished will make it work:

case item if item >= len(self.cache):
    while item - len(self.cache) >= 0 and not self.finished:
        self.cache_next()
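With that guard in place, iteration terminates and an out-of-range lookup raises IndexError instead of hanging. A quick check (my own sketch, assuming the patched class from above):

cachedtuple = CachedTuple(range(20))
for element in cachedtuple:   # legacy iteration via __getitem__
    print(element)            # prints 0 .. 19, then stops on IndexError

try:
    CachedTuple((1, 2, 3))[50]
except IndexError as exc:
    print(exc)                # "Index 50 out of range" instead of an endless loop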
But I do believe your code has many other problems, and I think you can improve it greatly:
- Remove the match statement. It does nothing here.
- Implement iteration with __iter__ instead of relying on the legacy iteration mechanism of __getitem__ (see the sketch after this list).
- Inherit from collections.abc.Sequence and follow the Sequence protocol.
- Drop the dataclass. This is not a dataclass. You seem to enjoy shiny new language features, but unfortunately none of them are relevant here, and they make your code longer, less clear, and not working as intended.
Remember that simple, readable code matters far more than using new language features.
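To make the __iter__ point concrete, here is a minimal sketch (my own illustration, not part of your code) of the two iteration protocols. When a class defines only __getitem__, a for loop calls it with 0, 1, 2, ... until an IndexError is raised, which is exactly why your version hangs when that IndexError never comes:

class LegacyIterable:
    """Iterable only via the legacy __getitem__ protocol."""
    def __getitem__(self, index):
        if index >= 3:
            raise IndexError(index)  # the for loop stops on IndexError
        return index * 10

class ExplicitIterable:
    """Iterable via an explicit __iter__ generator."""
    def __iter__(self):
        yield from (0, 10, 20)

print(list(LegacyIterable()))    # [0, 10, 20]
print(list(ExplicitIterable()))  # [0, 10, 20]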
I took the liberty of spending a couple of hours writing an example that conforms to collections.abc.Sequence. Enjoy!
from collections.abc import Sequence
import itertools
from typing import Iterable, Iterator, Optional, TypeVar, overload

_T_co = TypeVar("_T_co", covariant=True)


class CachedIterable(Sequence[_T_co]):
    def __init__(self, iterable: Iterable[_T_co], *, max_length: Optional[int] = None) -> None:
        self._cache: list[_T_co] = []
        if max_length is not None:
            if max_length <= 0:
                raise ValueError('max_length must be > 0')
            iterable = itertools.islice(iterable, max_length)
        else:
            try:
                # Attempt to optimize and get a length.
                max_length = len(iterable)  # type: ignore
            except TypeError:
                max_length = None
        self._max_length = max_length
        self._iterator: Optional[Iterator] = iter(iterable)

    def __repr__(self) -> str:
        return (f'<{self.__class__.__name__} {self._cache!r}'
                f'{"+" if self._iterator else ""}>')

    def _exhaust_iterator(self) -> None:
        """Fully exhaust the iterator."""
        assert self._iterator
        try:
            self._cache.extend(self._iterator)
        finally:
            self._iterator = None

    def _advance_iterator(self, n: int) -> None:
        """Attempt to advance the iterator by n steps.

        May advance by less than n steps if the iterator is exhausted.
        """
        assert self._iterator
        pre_advance_length = len(self._cache)
        try:
            self._cache.extend(itertools.islice(self._iterator, n))
        except Exception:
            # Iterator threw an exception.
            self._iterator = None
            raise
        # If the iterator is exhausted, clear it.
        if pre_advance_length + n > len(self._cache):
            self._iterator = None

    def _grow_cache(self, size: int) -> None:
        """Attempt to grow the cache to at least size elements.

        May grow to less than size if the iterator is exhausted.
        """
        if size <= len(self._cache):
            return
        if self._max_length and size >= self._max_length:
            self._exhaust_iterator()
            return
        self._advance_iterator(size - len(self._cache))

    @overload
    def __getitem__(self, i: int) -> _T_co: ...
    @overload
    def __getitem__(self, s: slice) -> Sequence[_T_co]: ...

    def __getitem__(self, index):
        if not isinstance(index, (slice, int)):
            raise TypeError(f'index must be int or slice, not {index!r}')
        if not self._iterator:
            return self._cache[index]
        if isinstance(index, slice):
            # Stop might be less than start if step is negative.
            max_index = max(index.stop or 0, index.start or 0)
            # If we're counting from the end, exhaust the iterator.
            if (index.stop is not None and index.stop < 0 or
                    index.start is not None and index.start < 0):
                self._exhaust_iterator()
            else:
                self._grow_cache(max_index + 1)
            return self._cache[index]
        # Asking for an index beyond the limit.
        if self._max_length and index > self._max_length:
            raise IndexError(f'index {index} out of range')
        # If we're counting from the end, exhaust the iterator.
        if index < 0:
            self._exhaust_iterator()
        else:
            self._grow_cache(index + 1)
        return self._cache[index]

    def __iter__(self) -> Iterator[_T_co]:
        if not self._iterator:
            yield from self._cache
            return
        yield from self._cache
        while True:
            try:
                item = next(self._iterator)
            except StopIteration:
                self._iterator = None
                return
            except BaseException:
                # Iterator threw an exception.
                self._iterator = None
                raise
            self._cache.append(item)
            # Yield outside the try block to avoid capturing GeneratorExit
            # and other gen.throw() exceptions.
            yield item

    def __len__(self) -> int:
        # TODO: Can optimize for known lengths.
        if not self._iterator:
            return len(self._cache)
        self._exhaust_iterator()
        return len(self._cache)
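Assuming the class above, a short usage sketch might look like this (the concrete values are only an illustration):

import itertools

# A sized iterable: its length is detected up front, but elements are still cached lazily.
cached = CachedIterable(range(20))
print(cached[0])     # 0
print(cached)        # <CachedIterable [0]+>  (the trailing "+" means more items remain)
print(cached[10])    # 10, the cache grows to 11 elements
print(len(cached))   # 20, exhausts the underlying iterator

# An unsized generator bounded by max_length.
evens = CachedIterable((2 * n for n in itertools.count()), max_length=5)
print(list(evens))   # [0, 2, 4, 6, 8]
print(evens[2:4])    # [4, 6], slicing is served from the cache
print(3 in evens)    # False, membership tests come for free from collections.abc.Sequence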