Custom Dataset not accepting argument in PyTorch

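I'm trying to create a custom dataset in PyTorch using this dataset. It has shape (X, 785), where X is the number of samples; each row holds the label at index 0 followed by 784 pixel values. Here is my code: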

import pandas as pd
from torch.utils.data import Dataset
def SignMNISTDataset(Dataset):

  def __init__(self, csv_file_path, mode='Train'):
    self.labels = []
    self.pixels = []
    self.mode = mode

    data = pd.read_csv(csv_file_path).values
    if self.mode == 'Train':
      self.labels = data[:,0].tolist()
      print("Training labels acquired")

    for idx in range(len(self.labels)):
      self.pixels.append(data[idx][1:].tolist())

  def __len__(self):
    return len(self.labels)

  def __getitem__(self, idx):
    pixels = self.pixels[idx]
    if self.mode == 'Train':
      labels = self.labels[idx]
      content = {"pixels":pixels, "label":labels}
    else:
      content = {"pixels":pixels}
    return content

training_data = SignMNISTDataset('sign_mnist_train/sign_mnist_train.csv', 'Train')

When I run it, I get the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-46-0173199f8794> in <module>()
     27     return content
     28 
---> 29 training_data = SignMNISTDataset('sign_mnist_train/sign_mnist_train.csv', 'Train')
     30 from torch.utils.data import DataLoader
     31 

TypeError: SignMNISTDataset() takes 1 positional argument but 2 were given

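Where exactly is this coming from? Is the mode argument somehow not being read when the object is created? My end goal is to build a neural network that classifies sign-language characters, following this tutorial.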

I tried passing mode explicitly as a keyword argument when creating the object. This is what I got:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-48-fd796c48dc67> in <module>()
     27     return content
     28 
---> 29 training_data = SignMNISTDataset('sign_mnist_train/sign_mnist_train.csv', mode='Train')

TypeError: SignMNISTDataset() got an unexpected keyword argument 'mode'
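Both messages can be reproduced without PyTorch or the CSV file. The sketch below is an illustration under stated assumptions: Dataset here is a hypothetical empty stand-in for torch.utils.data.Dataset, and BrokenDataset / FixedDataset are hypothetical names. It checks what kind of object each definition produces and what happens when it is called the way the question calls it.

```python
import inspect


# Hypothetical stand-in for torch.utils.data.Dataset so this runs without PyTorch.
class Dataset:
    pass


# As written in the question: a *function* whose only parameter is named Dataset.
def BrokenDataset(Dataset):
    def __init__(self, csv_file_path, mode='Train'):  # nested function, never called
        self.mode = mode


# As intended: a *class* that inherits from Dataset.
class FixedDataset(Dataset):
    def __init__(self, csv_file_path, mode='Train'):
        self.csv_file_path = csv_file_path
        self.mode = mode


print(inspect.isfunction(BrokenDataset))  # True
print(inspect.signature(BrokenDataset))   # (Dataset)
print(BrokenDataset('train.csv'))         # None: the body only defines a nested __init__

# Both calls from the question raise TypeError for the same reason:
# the function accepts exactly one parameter, and it is not called mode.
for args, kwargs in ((('train.csv', 'Train'), {}), (('train.csv',), {'mode': 'Train'})):
    try:
        BrokenDataset(*args, **kwargs)
    except TypeError as exc:
        print('TypeError:', exc)

print(FixedDataset('train.csv', mode='Train').mode)  # Train
```

The signature (Dataset) is the giveaway: the name in parentheses after def is a parameter, not a base class.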

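The problem is the keyword on the first line of your class. def SignMNISTDataset(Dataset): does not declare a class at all; it declares an ordinary function with a single parameter that happens to be named Dataset (shadowing the imported class). The indented __init__, __len__ and __getitem__ become nested functions that are never executed, so SignMNISTDataset accepts exactly one argument, rejects mode= as an unknown keyword, and returns None. Use

class SignMNISTDataset(Dataset):

instead of

def SignMNISTDataset(Dataset):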
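A fuller sketch of the corrected dataset follows, assuming torch, pandas and numpy are installed. Beyond switching def to class, it makes a few choices that are not in the original: pixels are stored as a float tensor reshaped to 1x28x28 (the shape follows from the 784 pixel columns) and scaled to [0, 1] (optional), and the length is taken from the pixel rows rather than the labels. That last change matters because in the original code any mode other than 'Train' leaves self.labels empty, so both the pixel loop and __len__ see zero rows. As in the question, column 0 is dropped from the pixels in every mode. The demo CSV is random synthetic data with placeholder column names, used only so the example runs on its own.

```python
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset


class SignMNISTDataset(Dataset):
    """Rows are: label in column 0, then 784 pixel values (a 28x28 image)."""

    def __init__(self, csv_file_path, mode='Train'):
        self.mode = mode
        data = pd.read_csv(csv_file_path).values
        # Labels are only kept in training mode, as in the question.
        self.labels = torch.tensor(data[:, 0], dtype=torch.long) if mode == 'Train' else None
        # Column 0 is dropped in every mode; reshape each row to 1x28x28.
        # Dividing by 255 (scaling to [0, 1]) is an optional assumption.
        self.pixels = torch.tensor(data[:, 1:], dtype=torch.float32).reshape(-1, 1, 28, 28) / 255.0

    def __len__(self):
        # Count pixel rows so the length is correct even when there are no labels.
        return len(self.pixels)

    def __getitem__(self, idx):
        if self.labels is not None:
            return {"pixels": self.pixels[idx], "label": self.labels[idx]}
        return {"pixels": self.pixels[idx]}


if __name__ == "__main__":
    import os
    import tempfile

    # Synthetic CSV in the same layout: a header row, then label + 784 pixels per row.
    rng = np.random.default_rng(0)
    rows = np.hstack([rng.integers(0, 24, size=(8, 1)), rng.integers(0, 256, size=(8, 784))])
    header = ["label"] + [f"pixel{i}" for i in range(1, 785)]  # placeholder names
    path = os.path.join(tempfile.mkdtemp(), "demo.csv")
    pd.DataFrame(rows, columns=header).to_csv(path, index=False)

    train = SignMNISTDataset(path, mode='Train')
    loader = DataLoader(train, batch_size=4, shuffle=True)
    batch = next(iter(loader))
    print(batch["pixels"].shape, batch["label"].shape)  # torch.Size([4, 1, 28, 28]) torch.Size([4])
```

With the real file you would call SignMNISTDataset('sign_mnist_train/sign_mnist_train.csv', mode='Train') and wrap it in a DataLoader the same way; the default collate function stacks the entries of each returned dict into batched tensors.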