张量流 - 停止恢复网络参数



我正在尝试从张量流网络进行多个顺序预测,但性能似乎很差(对于 2 层 8x8 卷积网络,每个预测 ~500 毫秒(,即使对于 CPU。 我怀疑部分问题在于它似乎每次都在重新加载网络参数。 在下面的代码中,每次调用classifier.predict都会产生以下输出行 - 因此我看到数百次。

INFO:tensorflow:Restoring parameters from /tmp/model_data/model.ckpt-102001

如何重用已加载的检查点?

(我不能在这里进行批量预测,因为网络的输出是在游戏中玩的动作,然后需要在馈送新游戏状态之前将其应用于当前状态。

这是进行预测的循环。

def rollout(classifier, state):
while not state.terminated:
predict_input_fn = tf.estimator.inputs.numpy_input_fn(x={"x": state.as_nn_input()}, shuffle=False)
prediction = next(classifier.predict(input_fn=predict_input_fn))
index = np.random.choice(NUM_ACTIONS, p=prediction["probabilities"]) # Select a move according to the network's output probabilities
state.apply_move(index)

classifiertf.estimator.Estimator用...

classifier = tf.estimator.Estimator(
model_fn=cnn_model_fn, model_dir=os.path.join(tempfile.gettempdir(), 'model_data'))

估算器 API 是一个高级 API。

tf.estimator 框架使构建和训练变得容易 通过其高级估算器 API 的机器学习模型。估计 提供可实例化以快速配置通用模型的类 回归器和分类器等类型。

估算器 API 抽象化了 TensorFlow 的许多复杂性,但在此过程中失去了一些通用性。 阅读代码后,很明显,如果不每次重新加载模型,就无法运行多个顺序预测。 低级 TensorFlow API 允许这种行为。 但。。。

Keras 是一个支持此用例的高级框架。 简单定义模型,然后重复调用predict

def rollout(model, state):
while not state.terminated:
predictions = model.predict(state.as_nn_input())
for _, prediction in enumerate(predictions):
index = np.random.choice(bt.ACTIONS, p=prediction)
state.apply_mode(index)

不科学的基准测试表明,这快了~100倍。

最新更新