DataLoader returns 连续多个值而不是列表或元组

DataLoader returns multiple values sequentially instead of a list or tuple

def __init__():

def __len__():

def __getitem__(self, idx):    
    cat_cols = (self.cat_cols.values.astype(np.float32))
    cont_cols = (self.cont_cols.values.astype(np.float32))
    label = (self.label.astype(np.int32))
    return (cont_cols[idx], cat_cols[idx], label[idx])

当我在上面 class 中使用数据加载器时,我得到 cont_cols、cat_cols 和标签作为索引为 0、1 和 2 的输出。而我希望它们在一起.我曾尝试将值作为字典返回,但后来我遇到了索引问题。

我必须将数据加载器的输出读取为

dl = DataLoader(dataset[0], batch_size = 1)



for i, data in enumerate(dl):
    if i == 0:
       cont = data
    if i == 1:
       cat = data
    if i == 2:
       label = data

目前我的输出为

for i, data in enumerate(dl):
   print(i, data) 

0张量([[3.2800e+02, 4.8000e+01, 1.0000e+03, 1.4069e+03, 4.6613e+05, 5.3300e+04, 0.0000e+00, 5.0000e+00, 1.0000e+00, 1.0000e+00, 2.0000e+00, 7.1610e+04, 6.5100e+03, 1.3020e+04, 5.2080e+04, 2.0040e+03]])

1张量([[ 2., 1., 1., 4., 2., 17., 0., 2., 3., 0., 4., 4., 1., 2 ., 2., 10., 1.]])

2张量([1], dtype=torch.int32)

我想要的是数据[0]、数据[1] 和数据[2] 访问的输出,但数据加载器只返回数据[0]。它首先返回 cont_cols,然后是 cat_cols,然后是标签。

我认为你在这里感到困惑,你的数据集确实可以 return 元组 s 但你必须以不同的方式处理它。

您的数据集定义为:

class MyDataset(Dataset):
    def __init__(self):
        pass

    def __len__():
        pass

    def __getitem__(self, idx):    
        cat_cols = (self.cat_cols.values.astype(np.float32))
        cont_cols = (self.cont_cols.values.astype(np.float32))
        label = (self.label.astype(np.int32))
        return (cont_cols[idx], cat_cols[idx], label[idx])

然后定义数据集和数据加载器。请注意,您不应在此处提供 dataset[0],而应提供 dataset:

>>> dataset = Dataset()
>>> dl = DataLoader(dataset, batch_size=1)

然后循环访问您的数据加载器内容:

>>> for cont, cat, label in dl:
...   print(cont, cat, label)