How to implement rescaling in custom DataGenerator?
I created a custom DataGenerator using tf.keras.utils.Sequence (https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence).
This is the custom DataGenerator:
# Here, `x_set` is a list of paths to the images
# and `y_set` are the associated classes.
class DataGenerator(Sequence):
    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) *
                         self.batch_size]
        batch_x = np.array(batch_x)
        batch_x = batch_x*1/255
        batch_y = self.y[idx * self.batch_size:(idx + 1) *
                         self.batch_size]
        return np.array([
            resize(imread(file_name), (64, 128))
            for file_name in batch_x]), np.array(batch_y)
x_set is a list of paths to my images, and y_set holds the associated classes.
I now want to add a function that rescales every pixel of the images by multiplying it by rescale = 1./255, as the ImageDataGenerator class does: https://keras.io/api/preprocessing/image/#ImageDataGenerator%20class
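For reference, rescale = 1./255 just multiplies every pixel value, so 8-bit values in [0, 255] become floats in [0, 1]. A minimal NumPy illustration on made-up pixel values (hypothetical data, not from the question):

```python
import numpy as np

# Hypothetical 8-bit pixel values standing in for a decoded image.
pixels = np.array([0, 51, 255], dtype=np.uint8)

rescale = 1. / 255
# Cast to float32 first (Keras' default float type), then scale.
scaled = pixels.astype(np.float32) * rescale
print(scaled)  # values now lie in [0, 1]
```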
When I use this code with model.fit_generator:
model.fit_generator(generator=training_generator,
                    validation_data=validation_generator,
                    steps_per_epoch = num_train_samples // 128,
                    validation_steps = num_val_samples // 128,
                    epochs = 10)
I get this error:
---------------------------------------------------------------------------
UFuncTypeError Traceback (most recent call last)
<ipython-input-62-571a868b2d2a> in <module>()
3 steps_per_epoch = num_train_samples // 128,
4 validation_steps = num_val_samples // 128,
----> 5 epochs = 10)
8 frames
<ipython-input-54-d98c3b0c7c56> in __getitem__(self, idx)
15 self.batch_size]
16 batch_x = np.array(batch_x)
---> 17 batch_x = batch_x*1/255
18 batch_y = self.y[idx * self.batch_size:(idx + 1) *
19 self.batch_size]
UFuncTypeError: ufunc 'multiply' did not contain a loop with signature matching types (dtype('<U77'), dtype('<U77')) -> dtype('<U77')
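The dtype in the message, '<U77', is NumPy's fixed-width Unicode string type (strings of up to 77 characters), which means batch_x still holds file paths rather than pixel values at that point. A minimal reproduction with made-up paths (hypothetical, not the asker's data):

```python
import numpy as np

# Hypothetical file paths standing in for one slice of x_set.
batch_x = np.array(["data/img_0001.png", "data/img_0002.png"])
print(batch_x.dtype)  # <U17: Unicode strings, not numbers

try:
    batch_x * 1 / 255  # the same expression as in the question
except TypeError as err:  # UFuncTypeError is a subclass of TypeError
    print(type(err).__name__, err)
```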
How should I modify my code?
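Answer: the multiplication fails because batch_x still contains the file-path strings when it is rescaled; the images have to be loaded into a numeric array first. Try it like this: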
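import math
import numpy as np
from skimage.io import imread
from skimage.transform import resize
from tensorflow.keras.utils import Sequence


class DataGenerator(Sequence):
    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        batch_x = self.x[idx*self.batch_size : (idx + 1)*self.batch_size]
        # Load and resize the images first; only numeric arrays can be rescaled.
        # preserve_range=True keeps pixels in [0, 255]; by default skimage's resize
        # already converts them to floats in [0, 1], which would scale them twice.
        batch_x = np.array([resize(imread(file_name), (64, 128), preserve_range=True)
                            for file_name in batch_x])
        batch_x = batch_x * 1./255  # rescale pixels to [0, 1]
        batch_y = np.array(self.y[idx*self.batch_size : (idx + 1)*self.batch_size])
        return batch_x, batch_y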
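If you want rescaling to be an option of the generator, like the rescale argument of ImageDataGenerator, one possible sketch is to pass the factor to the constructor and apply it after loading. The names RescalingGenerator, load_fn and fake_load are hypothetical. To keep the sketch runnable without TensorFlow or scikit-image, it does not subclass tf.keras.utils.Sequence and uses a fake loader; in real code you would subclass Sequence and load with imread/resize as in the answer above.

```python
import math
from typing import Callable, Optional

import numpy as np


class RescalingGenerator:
    """Sequence-style batch generator with an optional `rescale` factor (sketch)."""

    def __init__(self, x_set, y_set, batch_size: int,
                 load_fn: Callable[[str], np.ndarray],
                 rescale: Optional[float] = None):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        self.load_fn = load_fn   # hypothetical hook: path -> image array
        self.rescale = rescale   # e.g. 1. / 255, like ImageDataGenerator(rescale=...)

    def __len__(self) -> int:
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx: int):
        sl = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        # Load pixels first; only numeric arrays can be rescaled.
        batch_x = np.stack([self.load_fn(p) for p in self.x[sl]]).astype(np.float32)
        if self.rescale is not None:
            batch_x *= self.rescale
        return batch_x, np.asarray(self.y[sl])


def fake_load(path: str) -> np.ndarray:
    """Hypothetical stand-in for imread/resize: a 2x2 image of 255s."""
    return np.full((2, 2), 255, dtype=np.uint8)


gen = RescalingGenerator(["a.png", "b.png", "c.png"], [0, 1, 0],
                         batch_size=2, load_fn=fake_load, rescale=1. / 255)
xb, yb = gen[1]  # last, partial batch: only "c.png"
print(len(gen), xb.shape, xb.max(), yb)
```

Casting to float32 before the in-place `*=` matters: NumPy refuses to store float results into a uint8 array in place, so scaling the raw loaded images directly would raise an error of its own.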