如何加载 .ckpt、.meta 输出 ffiles rom 一个 TF 模型,输入预测数据并进行预测?



我按照教程使用tf.estimator.DNNLinearCombinedRegressor来训练TF模型来预测芝加哥出租车费用。该模型在 Google Cloud Storage 存储桶中生成以下输出。

  • 检查站
  • eval_estimator-eval/
  • events.out.tfevents.1590293793.taxi-train-job-5-5qj5l
  • graph.pbtxt
  • model.ckpt-0.data-00000-of-00002
  • model.ckpt-0.data-00001-of-00002
  • model.ckpt-0.index
  • model.ckpt-0.meta
  • 型号.ckpt-15165.data-00000-of-00002
  • 型号.ckpt-15165.data-00001-of-00002
  • 型号.ckpt-15165.index
  • 型号.ckpt-15165.meta
  • model.ckpt-25000.data-00000-of-00002
  • 型号.ckpt-25000.data-00001-of-00002
  • 模型.ckpt-25000.index
  • 模型.ckpt-25000.meta

问:如何使用这些文件加载模型、输入新数据点并进行预测? 假设这些是变量和相应的值:

  • hour_of_day: 8
  • pickup_latitude: 41.5
  • pickup_longitude: -87.2
  • dropoff_latitude: 41.9
  • dropoff_longitude: -87.1

我在网上搜索了答案,这就是我现在所拥有的。它给了我错误消息。任何人都可以帮忙,让我知道如何纠正它?谢谢!

with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('gs://folder_name/model.ckpt-15165.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('gs://folder_name/'))
graph = tf.get_default_graph()
x = graph.get_tensor_by_name("Placeholder:0")
feed_dict = {x:[8, 41.5, -87.2, 41.9, -87.1]}

print(sess.run(feed_dict))

我想通了。

有一个内置函数,可以轻松进行预测。

pred_input_func = tf.estimator.inputs.pandas_input_fn(x = x_test, shuffle = False)
predictions = estimator.predict(pred_input_func)
for result in predictions:
print (result)

最新更新