我下载了MLP的示例代码,并试图了解其工作原理。我来了这条线,它计算了测试数据集的准确性。
accuracy = sess.run(accuracy, feed_dict={X: data_test, y: labels_test, dropout_keep_prob:1.})
现在,我也想获得预测的标签。我如何获得预测的标签?
您需要获取预测张量。如果您有代码,则是您与y
相比的代码来计算准确性。说这就是prediction
,然后您可以写:
accuracy, prediction = sess.run([accuracy, prediction], feed_dict={X: data_set, y:labels_test, ...})