自定义数据集和数据加载器

Custom dataset and dataloader

我是 pytorch 新手。 我的大数据集由两个 txt 文件组成,一个用于数据,另一个用于目标数据。 在训练文件中,每一行都是长度为 340 的列表,在目标中,每一行都是长度为 136 的列表。

我想问一下如何定义我的数据集,以便我可以使用 Dataloader 加载我的数据来训练 pytorch 模型?

我给你答案

来自 torch.utils.data

Dataset 是表示数据集的抽象 class。您的自定义数据集应继承 Dataset 并覆盖以下方法:

__len__() 这样 len(dataset) returns 数据集的大小。
__getitem__() 支持索引,这样 dataset[i] 可用于获取第 i 个样本

例如编写自定义数据集
我已经为您编写了一个通用的自定义数据加载器作为您的问题陈述。
这里 data.txt 有数据,label.txt 有标签。

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self):
        
       
        with open('data.txt', 'r') as f:
                self.data_info = f.readlines()
        
        with open('label.txt', 'r') as f:
                self.label_info = f.readlines()        


    def __getitem__(self, index):
        
        single_data = self.data_info[index].rstrip('\n')
        

        single_label = self.label_info[index].rstrip('\n')

        return ( single_data , single_label)

    def __len__(self):
        return len(self.data_info)
# Testing 
d = CustomDataset()
print(d[1]) # should output data along with label

这将是您案例的基础,但必须根据您的案例进行一些更改。

注意:您必须根据数据集进行必要的更改