为什么我们需要自定义数据集 class 以及在 NLP、BERT 微调等中使用 _getitem_ 方法

Why do we need the custom dataset class and use of _getitem_ method in NLP, BERT fine tuning etc

本人是NLP新手,一直在研究BERT在NLP任务中的用法。在许多笔记本中,我看到定义了自定义数据集 class 并定义了 getitem 方法(连同 len)。

此笔记本中的推文数据集 class - https://www.kaggle.com/abhishek/roberta-inference-5-folds

此笔记本中的

和 text_Dataset class - https://engineering.wootric.com/when-bert-meets-pytorch

有人能解释一下原因吗,需要定义自定义数据集class和getitem(和len)方法。谢谢

在pytorch中推荐通过继承torch.utils.data.Dataset来定义datasets的抽象。这些对象定义了有多少个元素(__len__ 方法)以及如何通过指定的索引(__getitem__(index))获取单个项目。

source code:

class Dataset(object):   
    def __getitem__(self, index):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

所以它基本上是一个薄包装器,增加了连接两个 Dataset 对象的可能性。为了可读性和 API 兼容性,您应该继承它(与 kaggle 中提供的不同)。

您可以阅读有关 PyTorch 数据功能的更多信息here