将批量大小增加到超过 1 时 Pytorch 中的 RuntimeError

RuntimeError in Pytorch when increasing batch size to more than 1

import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use("TkAgg")
import os, h5py
import PIL
#------------------------------
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
#------------------------------
from data_augmentation import *
#------------------------------
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

class NiftiDataset(Dataset):

    def __init__(self,transformation_params,data_path, mode='train',transforms=None ):
        """
        Parameters:
            data_path (string): Root directory of the preprocessed dataset.
            mode (string, optional): Select the image_set to use, ``train``, ``valid``
            transforms (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.data_path = data_path
        self.mode = mode
        self.images = []
        self.labels = []
        self.W_maps = []
        self.centers = []
        self.radiuss = []
        self.pixel_spacings = []
        self.transformation_params = transformation_params
        self.transforms = transforms
        #-------------------------------------------------------------------------------------
        if self.mode == 'train': 
            self.data_path = os.path.join(self.data_path,'train_set')  
        elif self.mode == 'valid':  
            self.data_path = os.path.join(self.data_path,'validation_set') 
        #-------------------------------------------------------------------------------------
        for _, _, f in os.walk(self.data_path):
            for file in f:
                hdf_file = os.path.join(self.data_path,file)
                data = h5py.File(hdf_file,'r') # Dictionary
                # Preprocessing of Input Image and Label
                patch_img, patch_gt, patch_wmap = PreProcessData(file, data, self.mode, self.transformation_params)
                #print(type(data))
                self.images.append(patch_img) # 2D image
                #print('image shape is : ',patch_img.shape)
                self.labels.append(patch_gt) # 2D label
                #print('label shape is : ',patch_img.shape)
                self.W_maps.append(patch_wmap) # Weight_Map
                # self.centers.append(data['roi_center'][:]) # [x,y]
                # self.radiuss.append(data['roi_radii'][:]) # [R_min,R_max]
                # self.pixel_spacings.append(data['pixel_spacing'][:]) # [x , y , z]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = self.images[index]
        label = self.labels[index]
        W_map = self.W_maps[index]
        if self.transforms is not None:
            image, label, W_maps = self.transforms(image, label, W_map)

        return image, label, W_map
#=================================================================================================

if __name__ == '__main__':
    # Test Routinue to check your threaded dataloader
    # ACDC dataset has 4 labels
    n_labels = 4
    path = './hdf5_files'
    batch_size = 1
    # Data Augmentation Parameters
    # Set patch extraction parameters
    size1 = (128, 128)
    patch_size = size1
    mm_patch_size = size1
    max_size = size1

    train_transformation_params = {
        'patch_size': patch_size,
        'mm_patch_size': mm_patch_size,
        'add_noise': ['gauss', 'none1', 'none2'],
        'rotation_range': (-5, 5),
        'translation_range_x': (-5, 5),
        'translation_range_y': (-5, 5),
        'zoom_range': (0.8, 1.2),
        'do_flip': (False, False),
        }

    valid_transformation_params = {
        'patch_size': patch_size,
        'mm_patch_size': mm_patch_size}

    transformation_params = { 'train': train_transformation_params,
                              'valid': valid_transformation_params,
                              'n_labels': 4,
                              'data_augmentation': True,
                              'full_image': False,
                              'data_deformation': False,
                              'data_crop_pad': max_size}
#====================================================================                              
dataset = NiftiDataset(transformation_params=transformation_params,data_path=path,mode='train')
dataloader = DataLoader(dataset=dataset,batch_size=2,shuffle=True,num_workers=0)

dataiter = iter(dataloader)
data = dataiter.next()
images, labels,W_map = data

#===============================================================================
# Data Visualization 
#===============================================================================
print('image: ',images.shape,images.type(),'label: ',labels.shape,labels.type(),
    'W_map: ',W_map.shape,W_map.type())

img   = transforms.ToPILImage()(images[0,0,:,:,0].float())
lbl   = transforms.ToPILImage()(labels[0,0,:,:].float())
W_mp  = transforms.ToPILImage()(W_map [0,0,:,:].float())

plt.subplot(1,3,1)
plt.imshow(img,cmap='gray',interpolation=None)
plt.title('image')
plt.subplot(1,3,2)
plt.imshow(lbl,cmap='gray',interpolation=None)
plt.title('label')
plt.subplot(1,3,3)
plt.imshow(W_mp,cmap='gray',interpolation=None)
plt.title('Weight Map')
plt.show()

我注意到一些奇怪的事情,例如 Tensor 类型不同,即使图像和标签以及权重图是具有相同类型和大小的图像。 错误回溯:

Traceback (most recent call last):
  File "D:\Saudi_CV\Vibot\Smester_2_Medical Image analysis\Project_2020\OUR_Project\data_loader.py", line 118, in <module>
    data = dataiter.next()
  File "F:\Download_2019\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 345, in __next__
    data = self._next_data()
  File "F:\Download_2019\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 385, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "F:\Download_2019\Anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "F:\Download_2019\Anaconda3\lib\site-packages\torch\utils\data\_utils\collate.py", line 79, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "F:\Download_2019\Anaconda3\lib\site-packages\torch\utils\data\_utils\collate.py", line 79, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "F:\Download_2019\Anaconda3\lib\site-packages\torch\utils\data\_utils\collate.py", line 64, in default_collate
    return default_collate([torch.as_tensor(b) for b in batch])
  File "F:\Download_2019\Anaconda3\lib\site-packages\torch\utils\data\_utils\collate.py", line 55, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: Expected object of scalar type Double but got scalar type Long for sequence element 1 in sequence argument at position #1 'tensors'
[Finished in 19.9s with exit code 1]

问题已通过本页解释的解决方案解决link

        image = torch.from_numpy(self.images[index]).type(torch.FloatTensor)
        label = torch.from_numpy(self.labels[index]).type(torch.FloatTensor)
        W_map = torch.from_numpy(self.W_maps[index]).type(torch.FloatTensor)