Keras+Tensorflow 中的混淆矩阵

Confusion Matrix in Keras+Tensorflow

Q1

我训练了一个 CNN 模型并将其保存为 model.h5。我正在尝试检测 3 个对象。比如说,“猫”、“狗”和“其他”。我的测试集有 300 张图像,每个类别 100 张。前 100 是“猫”,第二个 100 是“狗”,第三个 100 是“其他”。我正在使用 Keras class ImageDataGeneratorflow_from_directory。这是示例代码:

test_datagen = ImageDataGenerator(rescale=1./255)
test_generator = test_datagen.flow_from_directory(
        test_dir,
        target_size=(150, 150),
        batch_size=20,
        class_mode='sparse',
        shuffle=False)

现在可以使用

from sklearn.metrics import confusion_matrix

cnf_matrix = confusion_matrix(y_test, y_pred)

我需要 y_testy_pred。我可以使用以下代码获得 y_pred

probabilities = model.predict_generator(test_generator)
y_pred = np.argmax(probabilities, axis=1)
print (y_pred)

[0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 1 0 0 0 0 0 0 1 0 0 0
 0 0 0 0 1 0 0 0 0 1 2 0 2 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 1 1
 0 2 0 0 0 0 1 0 0 0 0 0 0 1 0 2 0 1 0 0 1 0 0 1 0 0 1 1 1 1 1 1 1 1 1 1 2
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1
 1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 2 2 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 2
 1 1 1 1 1 2 1 1 1 1 1 2 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 1 2 2 2 1 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2]

这基本上是将对象预测为 0,1 和 2。现在我知道前 100 个对象(猫)是 0,第二个 100 个对象(狗)是 1,第三个 100 个对象(其他)是 2。做我使用 numpy 手动创建一个列表,其中第一个 100 点是 0,第二个 100 点是 1,第三个 100 点是 2 以获得 y_test?有没有 Keras class 可以做到这一点(创建 y_test)?

Q2

如何查看错误检测到的对象。如果你查看 print(y_pred),第 3 个点是 1,这是错误预测的。如何在不手动进入我的“test_dir”文件夹的情况下看到该图像?

由于您没有使用任何增强和 shuffle=False,您可以简单地从生成器获取图像:

imgBatch = next(test_generator)
    #it may be interesting to create the generator again if 
    #you're not sure it has output exactly all images before

使用 Pillow (PIL) 或 MatplotLib 等绘图库在 imgBatch 中绘制每个图像。

为了仅绘制所需的图像,请将 y_testy_pred 进行比较:

compare = y_test == y_pred

position = 0
while position < len(y_test):
    imgBatch = next(test_generator)
    batch = imgBatch.shape[0]

    for i in range(position,position+batch):
        if compare[i] == False:
            plot(imgBatch[i-position])

    position += batch