Keras+Tensorflow 中的混淆矩阵
Confusion Matrix in Keras+Tensorflow
Q1
我训练了一个 CNN 模型并将其保存为 model.h5
。我正在尝试检测 3 个对象。比如说,“猫”、“狗”和“其他”。我的测试集有 300 张图像,每个类别 100 张。前 100 是“猫”,第二个 100 是“狗”,第三个 100 是“其他”。我正在使用 Keras class ImageDataGenerator
和 flow_from_directory
。这是示例代码:
test_datagen = ImageDataGenerator(rescale=1./255)
test_generator = test_datagen.flow_from_directory(
test_dir,
target_size=(150, 150),
batch_size=20,
class_mode='sparse',
shuffle=False)
现在可以使用
from sklearn.metrics import confusion_matrix
cnf_matrix = confusion_matrix(y_test, y_pred)
我需要 y_test
和 y_pred
。我可以使用以下代码获得 y_pred
:
probabilities = model.predict_generator(test_generator)
y_pred = np.argmax(probabilities, axis=1)
print (y_pred)
[0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 1 0 0 0 0 0 0 1 0 0 0
0 0 0 0 1 0 0 0 0 1 2 0 2 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 1 1
0 2 0 0 0 0 1 0 0 0 0 0 0 1 0 2 0 1 0 0 1 0 0 1 0 0 1 1 1 1 1 1 1 1 1 1 2
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1
1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 2 2 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 2
1 1 1 1 1 2 1 1 1 1 1 2 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 1 2 2 2 1 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2]
这基本上是将对象预测为 0,1 和 2。现在我知道前 100 个对象(猫)是 0,第二个 100 个对象(狗)是 1,第三个 100 个对象(其他)是 2。做我使用 numpy
手动创建一个列表,其中第一个 100 点是 0,第二个 100 点是 1,第三个 100 点是 2 以获得 y_test
?有没有 Keras class 可以做到这一点(创建 y_test
)?
Q2
如何查看错误检测到的对象。如果你查看 print(y_pred)
,第 3 个点是 1,这是错误预测的。如何在不手动进入我的“test_dir”文件夹的情况下看到该图像?
由于您没有使用任何增强和 shuffle=False
,您可以简单地从生成器获取图像:
imgBatch = next(test_generator)
#it may be interesting to create the generator again if
#you're not sure it has output exactly all images before
使用 Pillow (PIL) 或 MatplotLib 等绘图库在 imgBatch 中绘制每个图像。
为了仅绘制所需的图像,请将 y_test
与 y_pred
进行比较:
compare = y_test == y_pred
position = 0
while position < len(y_test):
imgBatch = next(test_generator)
batch = imgBatch.shape[0]
for i in range(position,position+batch):
if compare[i] == False:
plot(imgBatch[i-position])
position += batch
Q1
我训练了一个 CNN 模型并将其保存为 model.h5
。我正在尝试检测 3 个对象。比如说,“猫”、“狗”和“其他”。我的测试集有 300 张图像,每个类别 100 张。前 100 是“猫”,第二个 100 是“狗”,第三个 100 是“其他”。我正在使用 Keras class ImageDataGenerator
和 flow_from_directory
。这是示例代码:
test_datagen = ImageDataGenerator(rescale=1./255)
test_generator = test_datagen.flow_from_directory(
test_dir,
target_size=(150, 150),
batch_size=20,
class_mode='sparse',
shuffle=False)
现在可以使用
from sklearn.metrics import confusion_matrix
cnf_matrix = confusion_matrix(y_test, y_pred)
我需要 y_test
和 y_pred
。我可以使用以下代码获得 y_pred
:
probabilities = model.predict_generator(test_generator)
y_pred = np.argmax(probabilities, axis=1)
print (y_pred)
[0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 1 0 0 0 0 0 0 1 0 0 0
0 0 0 0 1 0 0 0 0 1 2 0 2 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 1 1
0 2 0 0 0 0 1 0 0 0 0 0 0 1 0 2 0 1 0 0 1 0 0 1 0 0 1 1 1 1 1 1 1 1 1 1 2
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1
1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 2 2 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 2
1 1 1 1 1 2 1 1 1 1 1 2 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 1 2 2 2 1 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2]
这基本上是将对象预测为 0,1 和 2。现在我知道前 100 个对象(猫)是 0,第二个 100 个对象(狗)是 1,第三个 100 个对象(其他)是 2。做我使用 numpy
手动创建一个列表,其中第一个 100 点是 0,第二个 100 点是 1,第三个 100 点是 2 以获得 y_test
?有没有 Keras class 可以做到这一点(创建 y_test
)?
Q2
如何查看错误检测到的对象。如果你查看 print(y_pred)
,第 3 个点是 1,这是错误预测的。如何在不手动进入我的“test_dir”文件夹的情况下看到该图像?
由于您没有使用任何增强和 shuffle=False
,您可以简单地从生成器获取图像:
imgBatch = next(test_generator)
#it may be interesting to create the generator again if
#you're not sure it has output exactly all images before
使用 Pillow (PIL) 或 MatplotLib 等绘图库在 imgBatch 中绘制每个图像。
为了仅绘制所需的图像,请将 y_test
与 y_pred
进行比较:
compare = y_test == y_pred
position = 0
while position < len(y_test):
imgBatch = next(test_generator)
batch = imgBatch.shape[0]
for i in range(position,position+batch):
if compare[i] == False:
plot(imgBatch[i-position])
position += batch