在 Tensorflow 图像分类中获取标签

Getting Labels in a Tensorflow Image Classification

我正在按照 this TensorFlow tutorial 进行图像分类,并从 Gdrive 加载我自己的数据集。 现在我想绘制混淆矩阵。首先,我预测了验证数据集的标签:

val_preds = model.predict(val_ds)

但我不确定如何获取原始标签以将预测与它们进行比较。我尝试了不同的方法,但准确率很低,所以我知道标签不是它们应该的样子。

val_ds_labels = np.concatenate([y for x, y in val_ds], axis=0)

这给了我 0.067 的准确度,而下面给了我大约 0.70 的准确度。

epochs = 10
history=model.fit(train_ds, epochs=epochs, validation_data=val_ds)

以下是我创建验证和训练数据集的方式:

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "images",
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=image_size,
    batch_size=batch_size,
    label_mode='int'
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "images",
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=image_size,
    batch_size=batch_size,
    label_mode='int'
)
train_ds = train_ds.prefetch(buffer_size=32)
val_ds = val_ds.prefetch(buffer_size=32)

然后创建模型并编译它:

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseTopKCategoricalAccuracy(k=1)],
)

适合

epochs = 10
history=model.fit(train_ds, epochs=epochs, validation_data=val_ds)

我有 22 个标签。

val_preds = model.predict(val_ds)

训练完成后得到验证集的真实标签如下:

epochs=5
history = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs
)

....
....
Epoch 4/5
20ms/step - loss: 0.6368 - accuracy: 0.7613 - val_loss: 0.9294 - val_accuracy: 0.6185
Epoch 5/5
20ms/step - loss: 0.4307 - accuracy: 0.8531 - val_loss: 0.9552 - val_accuracy: 0.6635

# get the labels 
predictions = np.array([])
labels =  np.array([])

for x, y in val_ds:
  predictions = np.concatenate([predictions, np.argmax(model.predict(x), axis=-1)])
  labels = np.concatenate([labels, y.numpy()])

predictions[:10]
array([0., 4., 3., 0., 3., 4., 2., 4., 4., 0.])

labels[:10]
array([0., 4., 3., 0., 3., 4., 1., 2., 4., 0.])

m = tf.keras.metrics.Accuracy()
m(labels, predictions).numpy()
# 0.66348773