如何使用 Tensorflow Cifar10 教程代码进行推理?



我是TensorFlow的绝对初学者。

如果我有一张(或一组图片(想尝试使用 Cifar10 TensorFlow 教程中的代码进行分类,我将如何做?

我完全不知道从哪里开始。

  1. 完全按照教程使用基本数据集CIFAR10数据集训练模型。
  2. 使用您自己的输入创建一个新图表 - 可能最简单的方法是使用tf.placeholder并按如下方式提供数据,但还有很多其他方法。
  3. 启动会话,加载以前保存的权重。
  4. 运行会话(如果您使用的是上述placeholder,则使用feed_dict(。

.

import tensorflow as tf
train_dir = '/tmp/cifar10_train'  # or use FLAGS as in the train example
batch_size = 8
height = 32
width = 32
image = tf.placeholder(shape=(batch_size, height, width, 3), dtype=tf.uint8)
std_img = tf.image.per_image_standardization(image)
logits = cifar10.inference(std_img)
predictions = tf.argmax(logits, axis=-1)
def get_image_data_batches():
n_batchs = 100
for i in range(n_batchs):
yield (np.random.uniform(size=(batch_size, height, width, 3)*255).astype(np.uint8)
def do_stuff_with(logit_vals, prediction_vals):
pass
with tf.Session() as sess:
# restore variables
saver = tf.train.Saver()
saver.restore(sess, tf.train.latest_checkpoint(train_dir))
# run inference
for batch_data in get_image_data_batches():
logit_vals, prediction_vals = sess.run([logits, predictions], feed_dict={image: image_data})
do_stuff_with(logit_vals, prediction_vals)

有更好的方法将数据放入图形中(见tf.data.Dataset(,但我相信tf.placeholder是学习和启动和运行某些东西的最简单方法。

另请查看tf.estimator.Estimator,了解更简洁的会话管理方式。它与本教程中的方式非常不同,并且稍微不那么灵活,但对于标准网络,它们可以节省您编写大量样板代码的时间。

最新更新