预测keras / tensorflow中的所有测试批次
predict all test batches in keras / tensorflow
我正在尝试预测我在 keras/tensorflow 中的所有测试批次,然后绘制一个混淆矩阵。
当前BATCH_SIZE
为:32
我的测试数据集是使用以下代码从一个大数据集中生成的:
test_dataset = big_dataset.skip(train_size).take(test_size)
test_dataset = test_dataset.shuffle(test_size).map(augment).batch(BATCH_SIZE)
在 model.compile()
和 model.fit()
之后,我使用以下代码得到了预测和正确的标签:
points, labels = list(test_dataset)[0]
preds = model.predict(points)
preds = tf.math.argmax(preds, -1)
points = points.numpy()
此方法仅预测一批 --> 32 个预测。
有没有办法预测 keras / tensorflow 中的所有测试批次?
提前致谢!
您可以根据 docs:
将整个数据集传递给 model.predict
Input samples. It could be: A Numpy array (or array-like), or a list
of arrays (in case the model has multiple inputs). A TensorFlow
tensor, or a list of tensors (in case the model has multiple inputs).
A tf.data dataset. A generator or keras.utils.Sequence instance. A
more detailed description of unpacking behavior for iterator types
(Dataset, generator, Sequence) is given in the Unpacking behavior for
iterator-like inputs section of Model.fit.
points = test_dataset.map(lambda x, y: x)
labels = test_dataset.map(lambda x, y: y)
preds = model.predict(points)
preds = tf.math.argmax(preds, -1)
或 numpy
:
points = np.concatenate(list(test_dataset.map(lambda x, y: x))
labels = np.concatenate(list(test_dataset.map(lambda x, y: y))
preds = model.predict(points)
preds = tf.math.argmax(preds, -1)
我正在尝试预测我在 keras/tensorflow 中的所有测试批次,然后绘制一个混淆矩阵。
当前BATCH_SIZE
为:32
我的测试数据集是使用以下代码从一个大数据集中生成的:
test_dataset = big_dataset.skip(train_size).take(test_size)
test_dataset = test_dataset.shuffle(test_size).map(augment).batch(BATCH_SIZE)
在 model.compile()
和 model.fit()
之后,我使用以下代码得到了预测和正确的标签:
points, labels = list(test_dataset)[0]
preds = model.predict(points)
preds = tf.math.argmax(preds, -1)
points = points.numpy()
此方法仅预测一批 --> 32 个预测。
有没有办法预测 keras / tensorflow 中的所有测试批次?
提前致谢!
您可以根据 docs:
将整个数据集传递给model.predict
Input samples. It could be: A Numpy array (or array-like), or a list of arrays (in case the model has multiple inputs). A TensorFlow tensor, or a list of tensors (in case the model has multiple inputs). A tf.data dataset. A generator or keras.utils.Sequence instance. A more detailed description of unpacking behavior for iterator types (Dataset, generator, Sequence) is given in the Unpacking behavior for iterator-like inputs section of Model.fit.
points = test_dataset.map(lambda x, y: x)
labels = test_dataset.map(lambda x, y: y)
preds = model.predict(points)
preds = tf.math.argmax(preds, -1)
或 numpy
:
points = np.concatenate(list(test_dataset.map(lambda x, y: x))
labels = np.concatenate(list(test_dataset.map(lambda x, y: y))
preds = model.predict(points)
preds = tf.math.argmax(preds, -1)