No val_loss and val_accuracy keys even though validation_data is passed to model.fit()

Here is the image augmentation code:

from tensorflow.keras.preprocessing.image import ImageDataGenerator

batch_size = 16

train_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

# test_datagen = ImageDataGenerator(rescale=1./255)

# Use flow from dataframe
train_generator = train_datagen.flow_from_dataframe(
        dataframe=train,
        directory="train_images",
        x_col="id",
        y_col=["not_ready", "ready"],
        target_size=(300, 300),
        batch_size=batch_size,
        class_mode="raw",
        color_mode="grayscale",
        subset="training")

validation_generator = train_datagen.flow_from_dataframe(
        dataframe=train,
        directory="train_images",
        x_col="id",
        y_col=["not_ready", "ready"],
        target_size=(300, 300),
        batch_size=batch_size,
        class_mode="raw",
        color_mode="grayscale",
        subset="validation")
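With validation_split=0.2, the validation subset gets only a fifth of the rows, so a small dataframe can leave it with fewer rows than one batch. The hypothetical helper split_sizes below is a sketch that estimates both subset sizes. It assumes, based on my understanding of Keras's DataFrameIterator and not on anything stated in this question, that the validation subset is the first int(validation_split * N) rows (N counting only rows whose image files were found) and that training gets the rest. The 70-row dataframe is made up for illustration.

```python
from typing import Tuple


def split_sizes(num_rows: int, validation_split: float) -> Tuple[int, int]:
    """Estimate (training_rows, validation_rows) for a validation_split.

    Assumption: Keras takes rows [0, int(validation_split * num_rows)) as the
    "validation" subset and the remaining rows as the "training" subset.
    """
    cut = int(validation_split * num_rows)
    return num_rows - cut, cut


batch_size = 16
train_rows, val_rows = split_sizes(70, 0.2)  # hypothetical 70-row dataframe
print(train_rows, val_rows)    # 56 14
print(val_rows // batch_size)  # 0, the validation_steps the fit call would compute
```

In that situation the floor division used for validation_steps in the fit call yields 0.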

Model setup:

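from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers import Adam

early_stopping = EarlyStopping(monitor='loss', mode='min', verbose=1, patience=7, restore_best_weights=True)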

opt = Adam(learning_rate=0.0002)

model.compile(optimizer=opt, loss='binary_crossentropy', metrics=['accuracy'])

history = model.fit(train_generator,
        steps_per_epoch=train_generator.n // batch_size,
        epochs=100,
        validation_data=validation_generator,
        validation_steps=validation_generator.n // batch_size,
        callbacks=[early_stopping])

Then I print the history keys:

print(history.history.keys())

But the result is:

dict_keys(['loss', 'accuracy'])

There are no val_loss and val_accuracy keys, even though I passed validation_data. Why does this happen, and how do I make them appear?

First, make sure your model runs without your validation_generator. Second, make sure your validation_generator actually yields data by iterating over a few samples. Here is a working example:

import tensorflow as tf

flowers = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)

img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, validation_split=0.2)
BATCH_SIZE = 32

train_generator = img_gen.flow_from_directory(flowers, class_mode='sparse', batch_size=BATCH_SIZE, target_size=(300, 300), shuffle=True, subset="training", color_mode="grayscale")
validation_generator = img_gen.flow_from_directory(flowers, class_mode='sparse', batch_size=BATCH_SIZE, target_size=(300, 300), shuffle=True, subset="validation", color_mode="grayscale")

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(filters=16, kernel_size=(3, 3), activation='relu', input_shape=(300, 300, 1)),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(5)
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

epochs=10
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='loss',mode='min',verbose=1,patience=7, restore_best_weights=True)

history = model.fit(train_generator,
        steps_per_epoch=train_generator.n // BATCH_SIZE,
        epochs=1,
        validation_data=validation_generator,
        validation_steps=validation_generator.n // BATCH_SIZE,
        callbacks=[early_stopping])
print(history.history.keys())

Output:
Found 2939 images belonging to 5 classes.
Found 731 images belonging to 5 classes.
91/91 [==============================] - 44s 462ms/step - loss: 1.8690 - accuracy: 0.2298 - val_loss: 1.6060 - val_accuracy: 0.2443
dict_keys(['loss', 'accuracy', 'val_loss', 'val_accuracy'])
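For the second check (iterating over a few samples), a minimal sketch could look like the hypothetical helper peek_batches below. It pulls the first few (x, y) batches with itertools.islice and fails loudly if there are no batches or a batch is empty. It assumes the generator is a Python iterable yielding (x, y) pairs, which I believe holds for the Keras iterators returned by flow_from_directory and flow_from_dataframe; the list of fake batches stands in for a real generator.

```python
from itertools import islice


def peek_batches(generator, num_batches=3):
    """Return (len(x), len(y)) for the first few batches of a generator.

    Raises ValueError if the generator yields nothing or an empty batch,
    which would mean there is no validation data to evaluate on.
    """
    sizes = []
    # islice stops after num_batches even if the generator loops forever
    for x_batch, y_batch in islice(generator, num_batches):
        if len(x_batch) == 0:
            raise ValueError("Generator yielded an empty batch")
        sizes.append((len(x_batch), len(y_batch)))
    if not sizes:
        raise ValueError("Generator yielded no batches")
    return sizes


# Stand-in for a real generator: two batches of 16 and 4 samples.
fake_batches = [([0.0] * 16, [1] * 16), ([0.0] * 4, [1] * 4)]
print(peek_batches(fake_batches))  # [(16, 16), (4, 4)]
```

With a real generator you would call something like peek_batches(validation_generator) and compare the batch sizes against the "Found N images" line printed when the generator was created.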

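Also check the validation_steps argument of model.fit. If it is 0, for example, you will not see the validation loss and accuracy in history.history.keys(). This happens when there are fewer validation samples than one batch, because validation_generator.n // BATCH_SIZE then rounds down to 0. If that is the case, try not setting the argument at all: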

history = model.fit(train_generator,
        steps_per_epoch=train_generator.n // BATCH_SIZE,
        epochs=1,
        validation_data=validation_generator,
        callbacks=[early_stopping])
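If you prefer to keep validation_steps, rounding up instead of down avoids the 0. The helper steps_for below is a hypothetical sketch of ceiling division. As far as I know, len(generator) on a Keras iterator performs the same ceiling division, so validation_steps=len(validation_generator) should be equivalent, but treat that as an assumption to verify against your Keras version.

```python
def steps_for(num_samples, batch_size):
    """Batches needed to see every sample once, rounding up (ceiling division)."""
    return (num_samples + batch_size - 1) // batch_size


# Floor division (as in the question) versus ceiling division.
print(14 // 16, steps_for(14, 16))    # 0 1
print(731 // 32, steps_for(731, 32))  # 22 23 (the 731 validation images above)
```

If validation_generator.n itself is 0, both versions give 0, so in that case look at why the generator found no images rather than at validation_steps.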

For details, see the model.fit docs.