Python Dataset Class + PyTorch Dataloader: 卡在__getitem__,如何在测试期间获取Index、Label等?
Python Dataset Class + PyTorch Dataloader: Stuck at __getitem__, how to get Index, Label and so on during Testing?
我有一个小问题,但我现在被困了很长一段时间。希望有人可以帮助我。我目前正在使用 Kddcup99 数据集,我喜欢通过 DeepLearning (CNN Network)
进行训练
我有一个 "Dataset" Class,其中包含 Panda Dataframe。因此我分成正常和验证数据集。到目前为止,没问题。
我将它加载到一个 Numpy 向量中,将它连接到 Tensor,然后将它定向到 DataLoader。
数据集 Class 有两个重要的 class 用于迭代:
def __len__(self):
return len(self.val_df)
def __getitem__(self, index):
img, target = self.val_df[index][:-1], self.val_df[index][-1]
return img, target, index
不在 class 中的是 DataLoader 字符串:
test_dataloader = DataLoader(datat.val_df, batch_size=10, shuffle=True)
在我的 Trainer Class 中,我有一个 for 循环,它应该遍历 Dataloader:
with torch.no_grad():
for data in dataloader:
inputs, labels, idx = data
inputs = inputs.to(self.device)
但不会。我无法访问标签、索引等。
我现在的问题是:为什么?
如何通过数据加载器访问给定数据集中的标签和索引?
谢谢大家的帮助!
非常感谢。
DataLoader
is the dataset from which you want to load the data, that's usually a Dataset
的第一个参数,但不限于 Dataset
的任何实例。只要它定义了长度(__len__
)并且可以被索引(__getitem__
允许)它是可以接受的。
您正在将 datat.val_df
传递给 DataLoader
,这可能是一个 NumPy 数组。 NumPy 数组有长度并且可以被索引,所以它可以用在 DataLoader
中。由于您直接传递该数组,因此永远不会调用数据集的 __getitem__
,但数组本身已被索引,因此每个项目都只是 data.val_df[index]
.
您必须使用数据集本身 (datat
),而不是使用 DataLoader
的基础数据:
test_dataloader = DataLoader(datat, batch_size=10, shuffle=True)
我有一个小问题,但我现在被困了很长一段时间。希望有人可以帮助我。我目前正在使用 Kddcup99 数据集,我喜欢通过 DeepLearning (CNN Network)
进行训练我有一个 "Dataset" Class,其中包含 Panda Dataframe。因此我分成正常和验证数据集。到目前为止,没问题。 我将它加载到一个 Numpy 向量中,将它连接到 Tensor,然后将它定向到 DataLoader。
数据集 Class 有两个重要的 class 用于迭代:
def __len__(self):
return len(self.val_df)
def __getitem__(self, index):
img, target = self.val_df[index][:-1], self.val_df[index][-1]
return img, target, index
不在 class 中的是 DataLoader 字符串:
test_dataloader = DataLoader(datat.val_df, batch_size=10, shuffle=True)
在我的 Trainer Class 中,我有一个 for 循环,它应该遍历 Dataloader:
with torch.no_grad():
for data in dataloader:
inputs, labels, idx = data
inputs = inputs.to(self.device)
但不会。我无法访问标签、索引等。
我现在的问题是:为什么? 如何通过数据加载器访问给定数据集中的标签和索引?
谢谢大家的帮助! 非常感谢。
DataLoader
is the dataset from which you want to load the data, that's usually a Dataset
的第一个参数,但不限于 Dataset
的任何实例。只要它定义了长度(__len__
)并且可以被索引(__getitem__
允许)它是可以接受的。
您正在将 datat.val_df
传递给 DataLoader
,这可能是一个 NumPy 数组。 NumPy 数组有长度并且可以被索引,所以它可以用在 DataLoader
中。由于您直接传递该数组,因此永远不会调用数据集的 __getitem__
,但数组本身已被索引,因此每个项目都只是 data.val_df[index]
.
您必须使用数据集本身 (datat
),而不是使用 DataLoader
的基础数据:
test_dataloader = DataLoader(datat, batch_size=10, shuffle=True)