我试图预测keras/tensorflow中的所有测试批次,然后绘制混淆矩阵。当前BATCH_SIZE
为:32
我的测试数据集是用以下代码从一个大数据集生成的:
test_dataset = big_dataset.skip(train_size).take(test_size)
test_dataset = test_dataset.shuffle(test_size).map(augment).batch(BATCH_SIZE)
在model.compile()
和model.fit()
之后,我用这个代码得到了预测和正确的标签:
points, labels = list(test_dataset)[0]
preds = model.predict(points)
preds = tf.math.argmax(preds, -1)
points = points.numpy()
这个方法只预测一个批——>32预测。
是否有一种方法可以预测keras/tensorflow中的所有测试批次?
提前感谢!
您可以根据文档将整个数据集传递给model.predict
:
输入样本。它可以是:Numpy数组(或类似数组的数组)或列表数组(在模型有多个输入的情况下)。一个TensorFlow张量,或张量列表(在模型有多个输入的情况下)。一个特遣部队。数据集的数据。生成器或keras.utils.Sequence实例。一个对迭代器类型解包行为的更详细描述(Dataset, generator, Sequence)在解包行为中给出Model.fit.
的类迭代器输入部分
points = test_dataset.map(lambda x, y: x)
labels = test_dataset.map(lambda x, y: y)
preds = model.predict(points)
preds = tf.math.argmax(preds, -1)
或与numpy
:
points = np.concatenate(list(test_dataset.map(lambda x, y: x))
labels = np.concatenate(list(test_dataset.map(lambda x, y: y))
preds = model.predict(points)
preds = tf.math.argmax(preds, -1)