我想使用C++来加载TensorFlow模型。我想知道模型输入的大小,这是模型中的占位符。
我用谷歌搜索这个问题,但我只是在堆栈溢出中找到了这个链接:
C++相当于python:tf。Graph.get_tensor_by_name() 在 Tensorflow 中?
虽然我可以得到节点,但是张量流文档没有告诉我如何访问节点的大小。那么有没有人知道这件事呢?
非常感谢!
好的,经过多次尝试。我找到了一个解决方法,它可能很棘手,但效果很好。
首先,我们可以使用以下代码获取占位符节点:
GraphDef mygd = graph_def.graph_def();
for (int i = 0; i < mygd.node_size(); i++)
{
if (mygd.node(i).name() == input_name)
{
auto node = mygd.node(i);
}
}
然后通过 NodeDef.pd.h(tensorflow/core/framework/node_def.pb.h),我们可以通过如下代码获取 AttrValue:
auto attr = node.attr();
然后通过 attr_value.cc(tensorflow/core/framework/attr_value.cc),我们可以通过如下代码获取形状 attr 值:
tensorflow::AttrValue shape = attr["shape"];
形状 AttrValue 是用于存储形状信息的结构。我们可以通过函数 SummarizeAttrValue in tensorflow/core/framework/attr_value_util.h 获取详细信息。
string size_summary = SummarizeAttrValue(shape);
然后我们可以得到形状的字符串格式,如下所示:
[?,1024]