通过C API访问tensorflow 2.0 SavedModel的输入和输出张量



从加载了C_API的tensorflow 2.0 SavedModel运行推理时遇到问题,因为我无法按名称访问输入和输出操作。

我通过TF_LoadSessionFromSavedModel(…(成功加载会话:

#include <tensorflow/c/c_api>
...
TF_Status* status = TF_NewStatus();
TF_Graph*  graph  = TF_NewGraph();
TF_Buffer* r_opts = TF_NewBufferFromString("",0);
TF_Buffer* meta_g = TF_NewBuffer();
TF_SessionOptions* opts = TF_NewSessionOptions();
const char* tags[] = {"serve"};
TF_Session* session = TF_LoadSessionFromSavedModel(opts, r_opts, "saved_model/tf2_model", tags, 1, graph, meta_g, status);
if ( TF_GetCode(status) != TF_OK ) exit(-1); //does not happen

然而,我在尝试使用设置输入和输出张量时遇到了一个错误

TF_Operation* inputOp  = TF_GraphOperationByName(graph, "input"); //works with "serving_default_input"
TF_Operation* outputOp = TF_GraphOperationByName(graph, "prediction"); //does not work

我作为参数传递的名称被分配给保存模型的输入和输出keras层,但不在加载的graph中。运行saved_model_cli(遵循这里的tfSavedModel教程(显示具有这些名称的腱存在于SignatureDefserving_default下,所以我想我需要将serving_default实例化到图中(换句话说,根据签名创建图(,但是我找不到使用C API实现这一点的方法。

请注意,tensorflows的C_API测试使用C++tensorflow/core/功能从元图加载签名定义映射,并使用它来查找输入和输出操作名称,但我希望避免对C++的依赖。

还要注意,按名称访问操作适用于冻结的.pb图,但这种格式已被弃用。

提前感谢您的任何想法和提示!

目前(截至2020年5月(Tensorflow C API并不正式支持SavedModel(Tensorflow 2.0(格式,尽管他们可能很快就会发布该功能。

无论如何,您可以使用导出模型时定义的默认SignatureDefs,并使用saved_model_cli工具查找输入和输出张量的名称。

假设您使用保存了您的模型

model.save('/path/to/model/folder')

然后你打开一个狂欢节并进行

cd /python/folder/bin/
saved_model_cli show --dir /path/to/model/folder --tag_set serve --signature_def serving_default

(saved_model_cli的实际位置各不相同,但在bin/文件夹上使用anaconda时默认安装(

默认情况下会产生类似的结果

serving_default
The given SavedModel SignatureDef contains the following input(s):
inputs['graph_input'] tensor_info:
dtype: DT_DOUBLE
shape: (-1, 28, 28)
name: serving_default_graph_input:0
The given SavedModel SignatureDef contains the following output(s):
outputs['graph_output'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 10)
name: StatefulPartitionedCall:0
Method name is: tensorflow/serving/predict

在这种情况下,serving_default_graph_input是输入张量名称,而StatefulPartitionedCall则是输出张量名称。然后,您可以使用TF_GraphOperationByName()加载这些。

有了对Tensorflow 2的C API支持,您就可以使用一组定义的SignatureDef保存模型,然后加载所需的concrete_function(),而不必担心张量名称。然而,目前的这种方法应该仍然有效。

相关内容

  • 没有找到相关文章

最新更新