如何重用经过训练的模型来执行分类 - Tensorflow



我已经在Tensorflow上训练了一个CNN模型,我想重用来执行分类并测试它。 这是我目前正在做的事情:

def test(trained_model):
# returns a iterator.get.next()
x_test, y_test = inputs('test_set.tfrecords', batch_size=128, training_size=10000, shuffle=False, num_epochs=1)
# get the output of the cnn
predictions = tf.nn.softmax(AlexNet(x_test))
with tf.name_scope('Accuracy'):
# Accuracy
acc = tf.equal(tf.argmax(predictions, 1), tf.argmax(y_test, 1))
acc = tf.reduce_mean(tf.cast(acc, tf.float32))
# Initializing the variables
init = tf.global_variables_initializer()
with tf.Session() as new_sess:
saver = tf.train.import_meta_graph(trained_model)
saver.restore(new_sess,tf.train.latest_checkpoint('./'))
graph = tf.get_default_graph()
cnt = 1
try:
while(True):
new_sess.run(init)
print(acc.eval(), cnt)
cnt+=1
except tf.errors.OutOfRangeError:
print('Finished batch')

它似乎有效,但它与我找到的其他答案不同,人们正在使用graph.get_tensor_by_name("y_:0"),以及我不明白feed_dict。 谁能告诉我我正在做的事情是否正确,工作流程是正确的?

你正在做的事情是正确的,没有"正确的工作流程"(tl;DR:它们在逻辑上是等价的(。

当您使用Saver保存模型时,Tensorflow 会自动为您创建.meta.ckpt文件,其中.meta包含图形定义(节点及其连接的列表(,.ckpt文件包含模型参数。

tf.train.import_meta_graph当前默认图形中加载保存在.meta文件中的图形定义,restore()调用将使用ckpt文件的权重集填充图形。

显然,如果当前默认图形已经具有import_meta_graph尝试定义的相同定义,则跳过定义步骤。

这意味着,如果您在导入元图之前已经定义了相同的图,则可以使用 python 变量(例如predictions(来引用图中的节点。

相反,如果您尚未定义图形,则import_meta_graph将为您定义图形,但您将没有任何可供使用的 python 变量。

因此,您必须从图中提取对所需节点的引用并创建一个要使用的python变量(例如input = graph.get_tensor_by_name("logits:0")(

相关内容

最新更新