在 python3 中:list(iterables) 的奇怪行为

In python3: strange behaviour of list(iterables)

我有一个关于 python 中可迭代对象行为的具体问题。我的可迭代对象是 pytorch 中自定义构建的数据集 class:

import torch
from torch.utils.data import Dataset
class datasetTest(Dataset):
    def __init__(self, X):
        self.X = X

    def __len__(self):
        return len(self.X)

    def __getitem__(self, x):
        print('***********')
        print('getitem x = ', x)
        print('###########')
        y = self.X[x]
        print('getitem y = ', y)
        return y

当我初始化该 datasetTest class 的特定实例时,现在会出现奇怪的行为。根据我作为参数 X 传递的数据结构,当我调用 list(datasetTestInstance) 时它的行为会有所不同。特别是,当将 torch.tensor 作为参数传递时没有问题,但是当将字典作为参数传递时,它将抛出 KeyError。这样做的原因是 list(iterable) 不只是调用 i=0, ..., len(iterable)-1,而是调用 i=0, ..., len(iterable)。也就是说,它将迭代直到(包括)索引等于可迭代的长度。显然,该索引未在任何 python 数据结构中定义,因为最后一个元素始终具有索引 len(datastructure)-1 而不是 len(datastructure)。如果 X 是 torch.tensor 或列表,则不会出现错误,即使我认为这应该是一个错误。即使对于索引为 len(datasetTestinstance) 的(不存在的)元素,它仍然会调用 getitem,但它不会计算 y=self.X[len(datasetTestInstance)。有谁知道 pytorch 是否在内部以某种方式优雅地处理这个问题?

当将字典作为数据传递时,当 x=len(datasetTestInstance) 时,它会在最后一次迭代中抛出错误。这实际上是我猜想的预期行为。但为什么这只发生在字典而不是列表或 torch.tensor?

if __name__ == "__main__":
    a = datasetTest(torch.randn(5,2))
    print(len(a))
    print('++++++++++++')
    for i in range(len(a)):
        print(i)
        print(a[i])
    print('++++++++++++')
    print(list(a))

    print('++++++++++++')
    b = datasetTest({0: 12, 1:35, 2:99, 3:27, 4:33})
    print(len(b))
    print('++++++++++++')
    for i in range(len(b)):
        print(i)
        print(b[i])
    print('++++++++++++')
    print(list(b))

如果您想更好地理解我所观察到的内容,您可以尝试这段代码。

我的问题是:

1.) 为什么 list(iterable) 迭代直到(包括)len(iterable)? for 循环不会那样做。

2.) 在 torch.tensor 或列表作为数据 X 传递的情况下:为什么即使在为索引 len(datasetTestInstance) 调用 getitem 方法时它也不会抛出错误,实际上应该是 out范围,因为它未定义为 tensor/list 中的索引?或者,换句话说,当达到索引 len(datasetTestInstance) 然后进入 getitem 方法时,到底发生了什么?它显然不再调用 'y = self.X[x]' (否则会出现 IndexError)但它确实进入了 getitem 方法,我可以看到它从 getitem 方法中打印索引 x。那么在那个方法中会发生什么?为什么它的行为会根据是否有 torch.tensor/list 或 dict?

而有所不同?

这并不是一个特定于 pytorch 的问题,而是一个普遍的 python 问题。

您正在使用 list(iterable) where an iterable class is one which implements sequence semantics 构建列表。

在这里查看 __getitem__ 对于序列类型的预期行为(最相关的部分以粗体显示)

object.__getitem__(self, key)

Called to implement evaluation of self[key]. For sequence types, the accepted keys should be integers and slice objects. Note that the special interpretation of negative indexes (if the class wishes to emulate a sequence type) is up to the __getitem__() method. If key is of an inappropriate type, TypeError may be raised; if of a value outside the set of indexes for the sequence (after any special interpretation of negative values), IndexError should be raised. For mapping types, if key is missing (not in the container), KeyError should be raised.

Note: for loops expect that an IndexError will be raised for illegal indexes to allow proper detection of the end of the sequence.

这里的问题是,对于序列类型,python 在使用无效索引调用 __getitem__ 的情况下需要一个 IndexErrorlist 构造函数似乎依赖于此行为。在您的示例中,当 X 是字典时,尝试访问无效键会导致 __getitem__ 引发 KeyError 而不是预期的情况,因此不会被捕获并导致构建列表失败。


根据此信息,您可以执行以下操作

class datasetTest:
    def __init__(self):
        self.X = {0: 12, 1:35, 2:99, 3:27, 4:33}

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if index < 0 or index >= len(self):
            raise IndexError
        return self.X[index]

d = datasetTest()
print(list(d))

我不建议在实践中这样做,因为它依赖于你的字典 X 只包含整数键 01、...、len(X)-1 这意味着在大多数情况下它最终表现得就像一个列表,所以你最好只使用一个列表。

一堆有用的链接:

  1. [Python 3.Docs]: Data model - Emulating container types
  2. [Python 3.Docs]: Built-in Types - Iterator Types
  3. [Python 3.Docs]: Built-in Functions - iter(object[, sentinel])
  4. (所有答案)

关键点是 list 构造函数使用(可迭代的)参数的 __len__ ((如果提供的话)来计算新的容器长度),然后对其进行迭代(通过迭代器协议)。

由于 可怕的巧合 (记住 dict 支持迭代器协议,这发生在它的键(这是一个序列)上:

  • 您的字典只有 int 个键(以及更多)
  • 它们的值与它们的索引相同(按顺序)

改变以上 2 个项目符号表示的任何条件,会使实际错误更多 eloquent。

两个对象(dictlisttensors)都支持迭代器协议.为了使事情正常进行,您应该将其包装在 Dataset class 中,并稍微调整映射类型之一(以使用值而不是键)。
代码(key_func相关部分)有点复杂,但只是为了易于配置(如果你想改变一些东西 - 对于demo 用途).

code00.py:

#!/usr/bin/env python3

import sys
import torch
from torch.utils.data import Dataset
from random import randint


class SimpleDataset(Dataset):

    def __init__(self, x):
        self.__iter = None
        self.x = x

    def __len__(self):
        print("    __len__()")
        return len(self.x)

    def __getitem__(self, key):
        print("    __getitem__({0:}({1:s}))".format(key, key.__class__.__name__))
        try:
            val = self.x[key]
            print("    {0:}".format(val))
            return val
        except:
            print("    exc")
            raise #IndexError

    def __iter__(self):
        print("    __iter__()")
        self.__iter = iter(self.x)
        return self

    def __next__(self):
        print("    __next__()")
        if self.__iter is None:
            raise StopIteration
        val = next(self.__iter)
        if isinstance(self.x, (dict,)):  # Special handling for dictionaries
            val = self.x[val]
        return val


def key_transformer(int_key):
    return str(int_key)  # You could `return int_key` to see that it also works on your original example


def dataset_example(inner, key_func=None):
    if key_func is None:
        key_func = lambda x: x
    print("\nInner object: {0:}".format(inner))
    sd = SimpleDataset(inner)
    print("Dataset length: {0:d}".format(len(sd)))
    print("\nIterating (old fashion way):")
    for i in range(len(sd)):
        print("  {0:}: {1:}".format(key_func(i), sd[key_func(i)]))
    print("\nIterating (Python (iterator protocol) way):")
    for element in sd:
        print("  {0:}".format(element))
    print("\nTry building the list:")
    l = list(sd)
    print("  List: {0:}\n".format(l))


def main():
    dict_size = 2

    for inner, func in [
        (torch.randn(2, 2), None),
        ({key_transformer(i): randint(0, 100) for i in reversed(range(dict_size))}, key_transformer),  # Reversed the key order (since Python 3.7, dicts are ordered), to test int keys
    ]:
        dataset_example(inner, key_func=func)


if __name__ == "__main__":
    print("Python {0:s} {1:d}bit on {2:s}\n".format(" ".join(item.strip() for item in sys.version.split("\n")), 64 if sys.maxsize > 0x100000000 else 32, sys.platform))
    main()
    print("\nDone.")

输出:

[cfati@CFATI-5510-0:e:\Work\Dev\Whosebug\q059091544]> "e:\Work\Dev\VEnvs\py_064_03.07.03_test0\Scripts\python.exe" code00.py
Python 3.7.3 (v3.7.3:ef4ec6ed12, Mar 25 2019, 22:22:05) [MSC v.1916 64 bit (AMD64)] 64bit on win32


Inner object: tensor([[ 0.6626,  0.1107],
        [-0.1118,  0.6177]])
    __len__()
Dataset length: 2

Iterating (old fashion way):
    __len__()
    __getitem__(0(int))
    tensor([0.6626, 0.1107])
  0: tensor([0.6626, 0.1107])
    __getitem__(1(int))
    tensor([-0.1118,  0.6177])
  1: tensor([-0.1118,  0.6177])

Iterating (Python (iterator protocol) way):
    __iter__()
    __next__()
  tensor([0.6626, 0.1107])
    __next__()
  tensor([-0.1118,  0.6177])
    __next__()

Try building the list:
    __iter__()
    __len__()
    __next__()
    __next__()
    __next__()
  List: [tensor([0.6626, 0.1107]), tensor([-0.1118,  0.6177])]


Inner object: {'1': 86, '0': 25}
    __len__()
Dataset length: 2

Iterating (old fashion way):
    __len__()
    __getitem__(0(str))
    25
  0: 25
    __getitem__(1(str))
    86
  1: 86

Iterating (Python (iterator protocol) way):
    __iter__()
    __next__()
  86
    __next__()
  25
    __next__()

Try building the list:
    __iter__()
    __len__()
    __next__()
    __next__()
    __next__()
  List: [86, 25]


Done.

您可能还想查看 [PyTorch]: SOURCE CODE FOR TORCH.UTILS.DATA.DATASET (IterableDataset)。