PyTorch custom dataset's __getitem__ calls itself indefinitely when handling an exception

I am writing a script for my CustomDataset class, but whenever I access the data with a for loop like this, I get an Index out of range error:

cd = CustomDataset(df)
for img, target in cd:
   pass

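I realized that I might not be able to read some images (if they are corrupted), so I implemented a random_on_error feature that picks a random image if something is wrong with the current one. I'm sure this is where the problem is. I noticed that all 2160 images in the dataset are read without any problem (I print the index number on every iteration), but the loop does not stop and goes on to read a 2161st image. That causes an Index out of range exception, which is handled by reading a random image. This goes on forever.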

Here is my class:

class CustomDataset(Dataset):
    def __init__(self, data: pd.DataFrame, augmentations=None, exit_on_error=False, random_on_error: bool = True):
        """
        :param data: Pandas dataframe with paths as first column and target as second column
        :param augmentations: Image transformations
        :param exit_on_error: Stop execution once an exception rises. Cannot be used in conjunction with random_on_error
        :param random_on_error: Upon an exception while reading an image, pick a random image and process it instead.
        Cannot be used in conjunction with exit_on_error.
        """
 
        if exit_on_error and random_on_error:
            raise ValueError("Only one of 'exit_on_error' and 'random_on_error' can be true")
 
        self.image_paths = data.iloc[:, 0].to_numpy()
        self.targets = data.iloc[:, 1].to_numpy()
        self.augmentations = augmentations
        self.exit_on_error = exit_on_error
        self.random_on_error = random_on_error
 
    def __len__(self):
        return self.image_paths.shape[0]
 
    def __getitem__(self, index):
        image, target = None, None
        try:
            image, target = self.read_image_data(index)
        except:
            print(f"Exception occurred while reading image, {index}")
            if self.exit_on_error:
                print(self.image_paths[index])
                raise
            if self.random_on_error:
                random_index = np.random.randint(0, self.__len__())
                print(f"Replacing with random image, {random_index}")
                image, target = self.read_image_data(random_index)
 
            else:  # todo implement return logic when self.random_on_error is false
                return
 
        if self.augmentations is not None:
            aug_image = self.augmentations(image=image)
            image = aug_image["image"]
 
        image = np.transpose(image, (2, 0, 1))
 
        return (
            torch.tensor(image, dtype=torch.float),
            torch.tensor(target, dtype=torch.long)
        )
 
    def read_image_data(self, index: int) -> ImagePlusTarget: 
        # reads image, converts to 3 channel ndarray if image is grey scale and converts rgba to rgb (if applicable)
        target = self.targets[index]
        image = io.imread(self.image_paths[index])
        if image.ndim == 2:
            image = np.expand_dims(image, 2)
        if image.shape[2] > 3:
            image = color.rgba2rgb(image)
 
        return image, target

I think the problem is the except block in __getitem__() (line 27), because the code works fine when I remove it. But I can't see what is wrong with it.

Any help is appreciated, thanks.

How do you expect Python to know when to stop reading from your CustomDataset?

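Defining a __getitem__ method in CustomDataset makes it an iterable object in Python. That is, Python can iterate over CustomDataset's items one by one, calling __getitem__ with 0, 1, 2, and so on. However, the iterable object must raise either StopIteration or IndexError so that Python knows it has reached the end of the iteration. Your bare except catches the IndexError raised for index 2160 and substitutes a random image, so the loop never receives that signal.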
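To make that concrete, here is a minimal sketch of the behaviour with plain Python objects instead of images. The class names Bounded and Swallowing are made up for this illustration and are not part of PyTorch or of the question's code.

```python
from itertools import islice


class Bounded:
    """Sequence-style class: __getitem__ raises IndexError past the end."""

    def __init__(self, items):
        self.items = list(items)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        # List indexing raises IndexError when index is out of range.
        return self.items[index]


class Swallowing(Bounded):
    """Same data, but every exception (IndexError included) is replaced by a fallback."""

    def __getitem__(self, index):
        try:
            return self.items[index]
        except Exception:
            # The fallback hides the end-of-data signal from the for loop.
            return self.items[0]


bounded_result = list(Bounded("abc"))
# Cap the second iteration with islice, because it would otherwise never end.
swallowing_result = list(islice(Swallowing("abc"), 10))

print(bounded_result)
print(len(swallowing_result))
```

The first iteration stops after three items. The second one keeps producing the fallback item for as long as it is asked, which is the same thing that happens with random_on_error in the question.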

You can change your loop to explicitly use the dataset's __len__:

for i in range(len(cd)):
  img, target = cd[i] 

Alternatively, you should make sure your dataset raises IndexError when index is out of range. This can be done using multiple except clauses. Something like:

try:
    image, target = self.read_image_data(index)
except IndexError:
    raise  # do not handle this error
except Exception:
    # treat all other exceptions (corrupt images) here
    ...
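Another way to get the same effect is to check the index before the try block, so the end-of-data signal can never be confused with a read error. The following is only a sketch under assumptions: SafeDataset, fake_loader and max_retries are hypothetical names invented here, the loader stands in for read_image_data, and there is no torch dependency so that it runs on its own.

```python
import random


class SafeDataset:
    """Hypothetical stand-in for CustomDataset; `loader` plays the role of read_image_data."""

    def __init__(self, paths, loader, max_retries=100):
        self.paths = list(paths)
        self.loader = loader
        self.max_retries = max_retries  # assumption: give up after this many failed reads

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        # Signal end-of-data first, outside any try block.
        # (This sketch also rejects negative indices, for simplicity.)
        if not 0 <= index < len(self):
            raise IndexError(f"index {index} out of range for dataset of size {len(self)}")
        for _ in range(self.max_retries):
            try:
                return self.loader(self.paths[index])
            except Exception:
                # Treat any loader failure as a corrupt item and try a random index.
                index = random.randrange(len(self))
        raise RuntimeError("too many consecutive unreadable items")


def fake_loader(path):
    # Hypothetical loader: paths containing "bad" simulate corrupt images.
    if "bad" in path:
        raise OSError(f"cannot read {path}")
    return path.upper()


ds = SafeDataset(["a.png", "bad.png", "c.png"], fake_loader)
items = list(ds)  # terminates, because index 3 raises IndexError
print(items)
```

The retry cap is a design choice added for this sketch: without it, a dataset in which every image is unreadable would loop forever inside __getitem__.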