DataGenerator(Sequence) - How to check batch_x and batch_y.shape?

I created this DataGenerator:

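import math

import numpy as np
from skimage.io import imread            # assumption: imread/resize come from scikit-image
from skimage.transform import resize
from tensorflow.keras.utils import Sequence


class DataGenerator(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set  # x_set: image file paths, y_set: labels
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch; the last batch may be smaller than batch_size
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        batch_x = self.x[idx*self.batch_size : (idx + 1)*self.batch_size]
        # preserve_range=True keeps pixel values in 0..255 so the division below
        # maps them to [0, 1]; without it, skimage's resize already returns [0, 1]
        batch_x = np.array([resize(imread(file_name), (224, 224), preserve_range=True)
                            for file_name in batch_x])
        batch_x = batch_x * 1./255
        batch_y = self.y[idx*self.batch_size : (idx + 1)*self.batch_size]
        batch_y = np.array(batch_y)

        return batch_x, batch_y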
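Because `__len__` rounds up with `math.ceil`, the final batch can hold fewer than `batch_size` samples, so `batch_x.shape[0]` will differ for the last index. The slicing arithmetic can be checked in isolation with a small sketch; the helper name `batch_sizes` and the numbers 10 and 4 are hypothetical, chosen only for illustration:

```python
import math


def batch_sizes(n_samples, batch_size):
    """Return the number of samples in each batch, using the same
    ceil-based length and slice bounds as DataGenerator."""
    n_batches = math.ceil(n_samples / batch_size)
    items = list(range(n_samples))
    return [len(items[i * batch_size:(i + 1) * batch_size]) for i in range(n_batches)]


print(batch_sizes(10, 4))  # [4, 4, 2]
```

So with 10 files and `batch_size=4`, the generator yields three batches and the last one has a leading dimension of 2.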

Now I want to check the shape and type of batch_x and batch_y. How can I do that?

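Just add two print lines to the __getitem__ method, right before the return statement. Every time the generator is called, you will then see the information you need: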

# .shape and .dtype are available because both batches are NumPy arrays
print('batch_x : shape = %s, type = %s' % (batch_x.shape, batch_x.dtype))
print('batch_y : shape = %s, type = %s' % (batch_y.shape, batch_y.dtype))
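If you would rather not edit the class, you can also index the generator directly: a `Sequence` is just an object with `__len__` and `__getitem__`, so `batch_x, batch_y = generator[0]` returns the first batch. The sketch below loops over every batch and collects shapes and dtypes. To keep it runnable without image files or TensorFlow, `InMemorySequence` is a hypothetical stand-in that follows the same `__len__`/`__getitem__` contract but slices in-memory NumPy arrays; `describe_batches` is likewise a hypothetical helper name.

```python
import math

import numpy as np


class InMemorySequence:
    """Hypothetical stand-in for DataGenerator: same indexing contract,
    but it slices in-memory arrays instead of reading image files."""

    def __init__(self, x, y, batch_size):
        self.x, self.y = x, y
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        sl = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        return self.x[sl], self.y[sl]


def describe_batches(seq):
    """Index seq batch by batch and record (x shape, x dtype, y shape, y dtype)."""
    report = []
    for idx in range(len(seq)):
        batch_x, batch_y = seq[idx]
        report.append((batch_x.shape, batch_x.dtype, batch_y.shape, batch_y.dtype))
    return report


x = np.zeros((10, 224, 224, 3), dtype=np.float32)  # 10 fake 224x224 RGB images
y = np.arange(10)                                   # 10 fake labels
report = describe_batches(InMemorySequence(x, y, batch_size=4))
for row in report:
    print(row)
```

With the real generator you would call `describe_batches(DataGenerator(x_set, y_set, batch_size))` instead; note that this reads every image once, so for a large dataset checking only `generator[0]` and `generator[len(generator) - 1]` is usually enough.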