预测keras / tensorflow中的所有测试批次

predict all test batches in keras / tensorflow

我正在尝试预测我在 keras/tensorflow 中的所有测试批次,然后绘制一个混淆矩阵。 当前BATCH_SIZE为:32

我的测试数据集是使用以下代码从一个大数据集中生成的:

test_dataset = big_dataset.skip(train_size).take(test_size)
test_dataset = test_dataset.shuffle(test_size).map(augment).batch(BATCH_SIZE)

model.compile()model.fit() 之后,我使用以下代码得到了预测和正确的标签:

points, labels = list(test_dataset)[0]
preds = model.predict(points)
preds = tf.math.argmax(preds, -1)
points = points.numpy()

此方法仅预测一批 --> 32 个预测。

有没有办法预测 keras / tensorflow 中的所有测试批次?

提前致谢!

您可以根据 docs:

将整个数据集传递给 model.predict

Input samples. It could be: A Numpy array (or array-like), or a list of arrays (in case the model has multiple inputs). A TensorFlow tensor, or a list of tensors (in case the model has multiple inputs). A tf.data dataset. A generator or keras.utils.Sequence instance. A more detailed description of unpacking behavior for iterator types (Dataset, generator, Sequence) is given in the Unpacking behavior for iterator-like inputs section of Model.fit.

points = test_dataset.map(lambda x, y: x)
labels = test_dataset.map(lambda x, y: y)
preds = model.predict(points)
preds = tf.math.argmax(preds, -1)

numpy:

points = np.concatenate(list(test_dataset.map(lambda x, y: x))
labels = np.concatenate(list(test_dataset.map(lambda x, y: y))
preds = model.predict(points)
preds = tf.math.argmax(preds, -1)