从导出的TensorFlow模型调用函数时,我收到两个字符串("output_0", "output_1")而不是实际的模型。我怎样才能返回与这个字符串相关联的张量来访问输出?
导出模型:
class OneStep(tf.keras.Model):
...
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string), tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
def generate_one_step(self, inputs, states):
...
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
def generate_one_step_none(self, inputs):
...
tf.saved_model.save(one_step_model, 'one_Step', signatures={ "generate_one_step": one_step_model.generate_one_step, "generate_one_step_none": one_step_model.generate_one_step_none})
要导入的代码:
one_step = tf.saved_model.load('one_step')
step_gen = one_step.signatures["generate_one_step"]
step_gen_none = one_step.signatures["generate_one_step_none"]
next_char = tf.constant(['Test'], tf.string)
a, b = step_gen_none(inputs=next_char)
print(a,b) # returns "input_0", "input_1"
必须存储函数调用的结果,然后将其作为数组访问。正确答案是
res = step_gen_none(inputs=next_char)
a = res["output_0"]
b = res["output_1"]
名称也可以像这里描述的那样更改为不那么通用的名称。我在一本移民指南中找到了答案。虽然与答案无关,但值得指出的是,Tensorflow 2没有很好的文档,你通常不应该相信网络上的任何Tensorflow源,除非它们明确提到了v2。