我按照教程使用tf.estimator.DNNLinearCombinedRegressor来训练TF模型来预测芝加哥出租车费用。该模型在 Google Cloud Storage 存储桶中生成以下输出。
- 检查站
- eval_estimator-eval/
- events.out.tfevents.1590293793.taxi-train-job-5-5qj5l
- graph.pbtxt
- model.ckpt-0.data-00000-of-00002
- model.ckpt-0.data-00001-of-00002
- model.ckpt-0.index
- model.ckpt-0.meta
- 型号.ckpt-15165.data-00000-of-00002
- 型号.ckpt-15165.data-00001-of-00002
- 型号.ckpt-15165.index
- 型号.ckpt-15165.meta
- model.ckpt-25000.data-00000-of-00002
- 型号.ckpt-25000.data-00001-of-00002
- 模型.ckpt-25000.index
- 模型.ckpt-25000.meta
问:如何使用这些文件加载模型、输入新数据点并进行预测? 假设这些是变量和相应的值:
- hour_of_day: 8
- pickup_latitude: 41.5
- pickup_longitude: -87.2
- dropoff_latitude: 41.9
- dropoff_longitude: -87.1
我在网上搜索了答案,这就是我现在所拥有的。它给了我错误消息。任何人都可以帮忙,让我知道如何纠正它?谢谢!
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('gs://folder_name/model.ckpt-15165.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('gs://folder_name/'))
graph = tf.get_default_graph()
x = graph.get_tensor_by_name("Placeholder:0")
feed_dict = {x:[8, 41.5, -87.2, 41.9, -87.1]}
print(sess.run(feed_dict))
我想通了。
有一个内置函数,可以轻松进行预测。
pred_input_func = tf.estimator.inputs.pandas_input_fn(x = x_test, shuffle = False)
predictions = estimator.predict(pred_input_func)
for result in predictions:
print (result)