我已经在Tensorflow上训练了一个CNN模型,我想重用来执行分类并测试它。 这是我目前正在做的事情:
def test(trained_model):
# returns a iterator.get.next()
x_test, y_test = inputs('test_set.tfrecords', batch_size=128, training_size=10000, shuffle=False, num_epochs=1)
# get the output of the cnn
predictions = tf.nn.softmax(AlexNet(x_test))
with tf.name_scope('Accuracy'):
# Accuracy
acc = tf.equal(tf.argmax(predictions, 1), tf.argmax(y_test, 1))
acc = tf.reduce_mean(tf.cast(acc, tf.float32))
# Initializing the variables
init = tf.global_variables_initializer()
with tf.Session() as new_sess:
saver = tf.train.import_meta_graph(trained_model)
saver.restore(new_sess,tf.train.latest_checkpoint('./'))
graph = tf.get_default_graph()
cnt = 1
try:
while(True):
new_sess.run(init)
print(acc.eval(), cnt)
cnt+=1
except tf.errors.OutOfRangeError:
print('Finished batch')
它似乎有效,但它与我找到的其他答案不同,人们正在使用graph.get_tensor_by_name("y_:0")
,以及我不明白feed_dict
。 谁能告诉我我正在做的事情是否正确,工作流程是正确的?
你正在做的事情是正确的,没有"正确的工作流程"(tl;DR:它们在逻辑上是等价的(。
当您使用Saver
保存模型时,Tensorflow 会自动为您创建.meta
和.ckpt
文件,其中.meta
包含图形定义(节点及其连接的列表(,.ckpt
文件包含模型参数。
tf.train.import_meta_graph
当前默认图形中加载保存在.meta
文件中的图形定义,restore()
调用将使用ckpt
文件的权重集填充图形。
显然,如果当前默认图形已经具有import_meta_graph
尝试定义的相同定义,则跳过定义步骤。
这意味着,如果您在导入元图之前已经定义了相同的图,则可以使用 python 变量(例如predictions
(来引用图中的节点。
相反,如果您尚未定义图形,则import_meta_graph
将为您定义图形,但您将没有任何可供使用的 python 变量。
因此,您必须从图中提取对所需节点的引用并创建一个要使用的python变量(例如input = graph.get_tensor_by_name("logits:0")
(