Loading a frozen TF model - no value for placeholder tensor



I am trying to freeze my TensorFlow graph and restore it, but when I try to run a prediction I get the error:

You must feed a value for placeholder tensor 'DQNetwork/actions' with dtype float and shape [?,10] 

My restore code is:

import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

sess = tf.Session()
graph = tf.get_default_graph()
with graph.as_default():
    with sess.as_default():
        GRAPH_PB_PATH = "./frozentensorflowModel.pb"
        with gfile.FastGFile(GRAPH_PB_PATH, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())  # read while the file is still open
        tf.import_graph_def(graph_def, name='')
        x_tensor = graph.get_tensor_by_name("DQNetwork/inputs:0")
        op_to_restore = graph.get_tensor_by_name("DQNetwork/actions:0")
        new_state(cards.copy())  # application-specific helpers from my game code
        state = game_state.state
        feed_dict = {x_tensor: state.reshape((1, *state.shape))}
        opt = sess.run(op_to_restore, feed_dict)  # the error is thrown here
        predictions = np.argmax(opt, 1)

I define my DQNetwork inputs like this:

DQNetwork.inputs = tf.placeholder(tf.float32, [None, state_size], name="inputs") 
DQNetwork.actions = tf.placeholder(tf.float32, [None, action_size], name="actions")

More information:

>>>op_to_restore
<tf.Tensor 'DQNetwork/actions:0' shape=(?, 10) dtype=float32>
>>>op_to_restore.op
<tf.Operation 'DQNetwork/actions' type=Placeholder>

The line used during training:

results = sess.run(DQNetwork.output, feed_dict = {DQNetwork.inputs: input_batch})
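The training line fetches DQNetwork.output and feeds only DQNetwork.inputs. A reloaded frozen graph can presumably be queried the same way, by looking up the output tensor by name rather than the actions placeholder. The sketch below writes a small stand-in .pb file (weights already stored as constants, as in a frozen graph), reloads it with the file read inside the with block, lists its placeholders, and fetches the output. The tensor name DQNetwork/output:0 and the helpers write_stand_in_pb and load_frozen_graph are hypothetical; in a real frozen model the output name depends on how DQNetwork.output was created, which is why listing the graph's operations is useful.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
STATE_SIZE, ACTION_SIZE = 4, 10  # assumed sizes for this sketch


def write_stand_in_pb(path):
    """Hypothetical helper: save a tiny frozen-style graph (constant weights) to path."""
    graph = tf.Graph()
    with graph.as_default():
        with tf1.name_scope("DQNetwork"):
            inputs = tf1.placeholder(tf.float32, [None, STATE_SIZE], name="inputs")
            tf1.placeholder(tf.float32, [None, ACTION_SIZE], name="actions")
            weights = tf.constant(np.ones((STATE_SIZE, ACTION_SIZE), np.float32))
            tf.matmul(inputs, weights, name="output")
    with tf.io.gfile.GFile(path, "wb") as f:
        f.write(graph.as_graph_def().SerializeToString())


def load_frozen_graph(path):
    """Read a serialized GraphDef and import it into a fresh graph."""
    with tf.io.gfile.GFile(path, "rb") as f:
        graph_def = tf1.GraphDef()
        graph_def.ParseFromString(f.read())  # read while the file is still open
    graph = tf.Graph()
    with graph.as_default():
        tf1.import_graph_def(graph_def, name="")
    return graph


pb_path = os.path.join(tempfile.mkdtemp(), "frozentensorflowModel.pb")
write_stand_in_pb(pb_path)
graph = load_frozen_graph(pb_path)

# Inspect the imported graph to find the input placeholders.
placeholders = [op.name for op in graph.get_operations() if op.type == "Placeholder"]
print(placeholders)

x_tensor = graph.get_tensor_by_name("DQNetwork/inputs:0")
output_tensor = graph.get_tensor_by_name("DQNetwork/output:0")  # not "DQNetwork/actions:0"

state = np.arange(STATE_SIZE, dtype=np.float32)
with tf1.Session(graph=graph) as sess:
    q_values = sess.run(output_tensor, {x_tensor: state.reshape((1, *state.shape))})
predictions = np.argmax(q_values, 1)
print(predictions)
```

Importing into a fresh tf.Graph instead of the default graph keeps the frozen model separate from anything else already built in the same process, which may avoid name clashes.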

This may help you:

import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

sess = tf.Session()
graph = tf.get_default_graph()
with graph.as_default():
    with sess.as_default():
        GRAPH_PB_PATH = "./frozentensorflowModel.pb"
        with gfile.FastGFile(GRAPH_PB_PATH, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())  # read while the file is still open
        tf.import_graph_def(graph_def, name='')
        x_tensor = graph.get_tensor_by_name("DQNetwork/inputs:0")
        # look up the Operation instead of the Tensor
        op_to_restore = graph.get_operation_by_name("DQNetwork/actions")
        new_state(cards.copy())
        state = game_state.state
        feed_dict = {x_tensor: state.reshape((1, *state.shape))}
        opt = sess.run(op_to_restore, feed_dict)
        predictions = np.argmax(opt, 1)

That is my suggestion.
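The answer above replaces get_tensor_by_name with get_operation_by_name. The two lookups return different objects: a tf.Tensor (name with the :0 suffix) and a tf.Operation (name without it). When passed to sess.run, a Tensor fetch returns the computed value, while an Operation fetch only executes the op and returns None, so that result cannot be passed to np.argmax. A minimal sketch using a hypothetical constant named scores (not part of the DQNetwork graph):

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1

graph = tf.Graph()
with graph.as_default():
    tf.constant([[1.0, 3.0, 2.0]], name="scores")  # hypothetical example values

with tf1.Session(graph=graph) as sess:
    # A tf.Tensor fetch returns its value as a NumPy array.
    tensor_result = sess.run(graph.get_tensor_by_name("scores:0"))
    # A tf.Operation fetch only runs the op and returns None.
    op_result = sess.run(graph.get_operation_by_name("scores"))

print(tensor_result, op_result)
print(np.argmax(tensor_result, 1))
```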

I see you are using:

feed_dict={x_tensor: state.reshape((1, *state.shape))}

Instead of sess.run(op_to_restore, feed_dict), try op_to_restore.eval(feed_dict).
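For reference, in the TF1 API Tensor.eval(feed_dict) is shorthand for sess.run(tensor, feed_dict) using the default session (for example, inside with sess.as_default():), so both calls require the same placeholders to be fed. A minimal sketch with a hypothetical doubling graph:

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1

graph = tf.Graph()
with graph.as_default():
    x = tf1.placeholder(tf.float32, [None, 3], name="x")
    doubled = tf.multiply(x, 2.0, name="doubled")  # hypothetical example op

feed_dict = {x: np.array([[1.0, 2.0, 3.0]], np.float32)}

sess = tf1.Session(graph=graph)
with graph.as_default(), sess.as_default():
    via_run = sess.run(doubled, feed_dict)
    # eval() uses the default session installed by sess.as_default().
    via_eval = doubled.eval(feed_dict=feed_dict)
sess.close()

print(via_run, via_eval)
```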
