如何使用带 preprocessing_function 的数据生成器正确加载图像?

How can i load image correctly using Data Generator with preprocessing_function?

我正在尝试将 CLAHE 应用到我的数据集,但我不确定它是否应用正确。我试图在应用 CLAHE 后可视化结果,但我只能在使用 plt.imshow(img[0].astype('uint8')) 时可视化。所以我不确定它是否被正确应用。如果有人能帮助我,我将不胜感激。

BSZ=64
tsize=(48,48)    

clahe = cv2.createCLAHE(clipLimit=0.01, tileGridSize=(8,8))    
def claheImage(img):
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    gray = gray.astype(np.uint16)
    eq = clahe.apply(gray)
    eq = cv2.cvtColor(eq, cv2.COLOR_GRAY2BGR)
    eq = eq.astype(np.float32)
    return eq

val_generator = ImageDataGenerator(rescale=1/255,
                                  preprocessing_function=claheImage)

tr_generator = ImageDataGenerator(rescale = 1/255,
                                  zoom_range=0.3,
                                  shear_range=0.3,
                                  horizontal_flip=True,
                                  rotation_range=15,
                                  fill_mode="nearest",
                                  preprocessing_function=claheImage)

val_data = val_generator.flow_from_directory("./test",
                                             batch_size= BSZ,
                                             target_size=tsize,
                                             color_mode="rgb",
                                             interpolation="nearest")
tr_data = tr_generator.flow_from_directory("./train",
                                          batch_size= BSZ,
                                          target_size=tsize,
                                          color_mode="rgb",
                                          interpolation="nearest")

当我使用这段代码时,我得到一个空图像和警告: “使用 RGB 数据将输入数据剪切到 imshow 的有效范围([0..1] 用于浮点数或 [0..255] 用于整数)”

for _ in range(1):
    img, label = next(val_data)
    print(img.shape)
    plt.imshow(img[0])
    type(img[0])
    plt.show()

但是当我使用 astype('unit8') 时我可以显示它。所以,我不确定是否正确应用了 CLAHE。

for _ in range(1):
    img, label = next(val_data)
    print(img.shape)
    #plt.imshow(img[0])
    plt.imshow(img[0].astype('uint8'))
    type(img[0])
    plt.show()

更多详细信息可以在下面找到link。

https://colab.research.google.com/drive/1RguAZ9_9pREQNDkP6ort2mY1ffsFOLms?usp=sharing

这是因为 Matplotlib。像素值的范围应介于整数 [0, 255] 和浮点数 [0., 1.] 之间。您的预处理函数不得遵守这一点。

我建议您删除 rescale=1/255,但要确保函数输出的值范围在 0 到 1 之间。

def claheImage(img):
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    gray = gray.astype(np.uint16)
    eq = clahe.apply(gray)
    eq = cv2.cvtColor(eq, cv2.COLOR_GRAY2BGR)
    eq = eq.astype(np.float32)
    eq = eq / np.max(eq) # this makes sure the output is between 0 and 1
    return eq