Find the confusion matrix for image classification in TensorFlow

I am training a model to classify images into 2 classes, following this tutorial: https://www.tensorflow.org/tutorials/images/classification

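After model.fit(), I want to evaluate the accuracy of the model's predictions using a test set whose images were not part of the training or validation sets. The test set contains 2 folders, each holding the images of the corresponding class:
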
├── test_data/
│   ├── class1/
│   ├── class2/

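I want to use a confusion matrix to find the recall, precision, and accuracy for each class. However, I am new to deep learning and TensorFlow, and I don't know how to get the confusion matrix for each class. I am also not sure whether the way I pass images to the model is correct.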

Below is my current implementation for predicting on new data with the model.

# get the list of class names in the training set
train_class_names = train_ds.class_names

# load the test data
test_data_dir = pathlib.Path('test_data/')
test_data_list = list(test_data_dir.glob('*/*.jpg'))

test_ds = tf.keras.utils.image_dataset_from_directory(
  test_data_dir,
  image_size=(img_height, img_width),
  batch_size=batch_size)

predicted_img = []

# for every image in the test_data folder, pass it to the model to predict its class
for path in test_data_list:
    img = tf.keras.utils.load_img(
        path, target_size=(img_height, img_width)
    )
    img_array = tf.keras.utils.img_to_array(img)
    img_array = tf.expand_dims(img_array, 0)
    
    test_class_name = path.parent.name

    predictions = model.predict(img_array)
    score = tf.nn.softmax(predictions[0])
    
    # append the image, predicted class and actual class to a list 
    # so that I can print them out to see if the prediction is correct
    predicted_img.append([img, train_class_names[np.argmax(score)], test_class_name])

Try this:

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report

sns.set_style('darkgrid')
classes = test_ds.class_names  # ordered list of class names
ytrue = []
ypred = []
# image_dataset_from_directory shuffles by default, so labels and predictions
# are collected in the same pass to keep them aligned
for images, labels in test_ds:
    preds = model.predict_on_batch(images)  # predict on one batch of test images
    for label, p in zip(labels.numpy(), preds):
        ytrue.append(classes[label])  # actual class name of the image
        ypred.append(classes[np.argmax(p)])  # class name with the highest score
ytrue = np.array(ytrue)
ypred = np.array(ypred)
count = len(ytrue)
errors = int(np.sum(ytrue != ypred))
acc = (count - errors) * 100 / count
msg = f'there were {count - errors} correct predictions in {count} tests for an accuracy of {acc:6.2f} %'
print(msg)
if len(classes) <= 30:  # with more than 30 classes the plot is too crammed to be useful
    # create a confusion matrix (rows = actual, columns = predicted)
    cm = confusion_matrix(ytrue, ypred, labels=classes)
    length = len(classes)
    if length < 8:
        fig_width = 8
        fig_height = 8
    else:
        fig_width = int(length * .5)
        fig_height = int(length * .5)
    plt.figure(figsize=(fig_width, fig_height))
    sns.heatmap(cm, annot=True, vmin=0, fmt='g', cmap='Blues', cbar=False)
    plt.xticks(np.arange(length) + .5, classes, rotation=90)
    plt.yticks(np.arange(length) + .5, classes, rotation=0)
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.title("Confusion Matrix")
    plt.show()
clr = classification_report(ytrue, ypred, target_names=classes)
print("Classification Report:\n----------------------\n", clr)
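
Since the question asks for per-class recall, precision, and accuracy, these can also be read directly off the confusion matrix. Below is a minimal sketch, assuming rows are actual classes and columns are predicted classes (the layout sklearn's confusion_matrix returns). The helper name per_class_metrics and the example counts are made up for illustration and are not results from the model above.

```python
import numpy as np

def per_class_metrics(cm):
    """Per-class precision, recall and one-vs-rest accuracy from a confusion matrix.

    Assumes cm[i, j] counts samples whose actual class is i and predicted class is j.
    """
    cm = np.asarray(cm, dtype=float)
    total = cm.sum()
    tp = np.diag(cm)              # samples of each class predicted correctly
    fp = cm.sum(axis=0) - tp      # predicted as the class, actually another
    fn = cm.sum(axis=1) - tp      # actually the class, predicted as another
    tn = total - tp - fp - fn     # everything not involving the class
    # guard against division by zero when a class is never predicted or never present
    precision = np.divide(tp, tp + fp, out=np.zeros_like(tp), where=(tp + fp) > 0)
    recall = np.divide(tp, tp + fn, out=np.zeros_like(tp), where=(tp + fn) > 0)
    accuracy = (tp + tn) / total
    return precision, recall, accuracy

# hypothetical counts for a 2-class test set (illustration only)
example_cm = np.array([[40, 10],
                       [5, 45]])
precision, recall, accuracy = per_class_metrics(example_cm)
for name, p, r, a in zip(["class1", "class2"], precision, recall, accuracy):
    print(f"{name}: precision={p:.3f} recall={r:.3f} accuracy={a:.3f}")
```

With the cm computed in the answer, the call would be per_class_metrics(cm). For a 2-class problem the one-vs-rest accuracy is the same for both classes and equals the overall accuracy.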