我有一个2gb Tensorflow模型,我想把它添加到我在App Engine上的Flask项目中,但我似乎找不到任何说明我试图做什么的文档。
由于应用程序引擎不允许写入文件系统,我将模型的文件存储在Google Bucket中,并试图从中恢复模型。这些是那里的文件:
- 型号.ckpt.data00000-of-00001
- 型号.ckpt.index
- 模型.ckpt.meta
- 检查点
在本地工作,我只能使用
with tf.Session() as sess:
logger.info("Importing model into TF")
saver = tf.train.import_meta_graph('model.ckpt.meta')
saver.restore(sess, model.ckpt)
其中,使用Flask的@before_first_request
将模型加载到内存中。
一旦它在应用程序引擎上,我认为我可以做到这一点:
blob = bucket.get_blob('blob_name')
filename = os.path.join(model_dir, blob.name)
blob.download_to_filename(filename)
然后执行相同的恢复。但应用引擎不允许。
有没有办法将这些文件流式传输到Tensorflow的恢复函数中,这样文件就不必写入文件系统?
Dan Cornelscu提供了一些技巧并深入研究后,我发现Tensorflow使用一个名为ParseFromString
的函数构建了MetaGraphDef
,因此我最终做了以下操作:
from google.cloud import storage
from tensorflow import MetaGraphDef
client = storage.Client()
bucket = client.get_bucket(Config.MODEL_BUCKET)
blob = bucket.get_blob('model.ckpt.meta')
model_graph = blob.download_as_string()
mgd = MetaGraphDef()
mgd.ParseFromString(model_graph)
with tf.Session() as sess:
saver = tf.train.import_meta_graph(mgd)
我实际上并没有使用Tensorflow,答案是基于文档和GAE相关知识。
一般来说,使用GCS对象作为GAE中的文件以避免缺乏可写文件系统访问依赖于两种替代方法之一,而不是仅通过文件名由您的应用程序代码(和/或它可能使用的任何第三方实用程序/库(直接读/写(这不能用GCS对象完成(:
-
使用已打开的文件类处理程序从GCS读取/写入数据。您的应用程序将通过使用以下任一项获得:
- 来自GCS客户端库的
open
调用,而不是通常用于常规文件系统的通用调用。例如,请参阅编写CSV以存储在Google Cloud Storage中,或将python对象pickle到Google Cloud Storage - 一些内存中的文件伪造,使用类似
StringIO
的东西,请参阅如何在不向python中的文件系统写入任何内容的情况下压缩或标记静态文件夹?。内存中的伪文件还可以方便地访问原始数据,以防需要在GCS中持久化,见下文
- 来自GCS客户端库的
-
直接使用或仅生成相应的原始数据,您的应用程序将完全负责从GCS读取/写入GCS(再次使用GCS客户端库的
open
调用(,请参阅如何在gae云上打开gzip文件?
在您的特定情况下,tf.train.import_meta_graph()
调用似乎支持传递MetaGraphDef
协议缓冲区(即原始数据(,而不是应该从中加载的文件名:
Args:
meta_graph_or_file
:MetaGraphDef
协议缓冲区或包含MetaGraphDef
的文件名(包括路径(
因此,从GCS恢复模型应该是可能的,大致如下:
import cloudstorage
with cloudstorage.open('gcs_path_to_meta_graph_file', 'r') as fd:
meta_graph = fd.read()
# and later:
saver = tf.train.import_meta_graph(meta_graph)
然而,从快速文档扫描保存/检查点到GCS的模式可能很棘手,save()
似乎想将数据写入磁盘本身。但我没有挖得太深。