预测在模型执行期间使用Tensorflow 1.6和ML-engine失败.它们应该兼容



我使用tensorflow1.6从头开始训练SSD-Inception-V2模型。没有警告或错误。然后,我使用以下标志导出了模型:

--pipeline_config_path experiments/ssd_inception_v2/ssd_inception_v2.config
--trained_checkpoint_prefix experiments/ssd_inception_v2/train/model.ckpt-400097
--output_directory experiments/ssd_inception_v2/frozen_graphs/

之后,我将saved_mode.pb上传到Google Cloud Storage Bucket,在ML-ingine中创建了一个模型并创建了一个版本(我确实使用了--runtime-version=1.6(。

最后,我使用gcloud命令要求进行在线预测,但获得了以下错误:

{
"error": "Prediction failed: Error during model execution: AbortionError(code=StatusCode.INVALID_ARGUMENT, details="The second input must be a scalar, but it has shape [1]nt [[Node: map/while/decode_image/cond_jpeg/cond_png/DecodePng/Switch = Switch[T=DT_STRING, _class=["loc:/TensorArrayReadV3"], _device="/job:localhost/replica:0/task:0/device:CPU:0"](map/while/TensorArrayReadV3, map/while/decode_image/is_jpeg)]]")"
}

日志描述了模型执行时出现的问题。

预测请求的格式为(cf官方文档(:

{
  "instances": [
    ...
  ]
}

根据有关对象检测的此博客文章,FLAG --input_type encoded_image_string_tensor产生了一个名为inputs的单个输入的模型,该模型接受了一批JPG或PNG图像。这些图像必须是基本64编码。因此,将所有内容放在一起,实际请求应该看起来像:

{
  "instances": [
    {
      "inputs": {
        "b64": "..."
      }
    }
  ]
}

由于只有一个输入,因此我们可以使用一个速记,该速记是对object/dictionary {" inputs":{" b64":...}}的插入的速记,只是字典的值,即{" B64":...}:

{
  "instances": [
    {
      "b64": "..."
    }
  ]
}

请注意,如果完全有一个模型输入,则可以接受。

即使以上是服务接受的请求的格式,gcloud命令行工具实际上并不期望请求的整个主体。它期望 实际的"实例",即JSON中的[]之间的事物,被新线分开。这意味着您的文件应该看起来像这样:

{"b64": "..."}

或此

{"inputs": {"b64": "..."}}

如果要发送多个图像,则文件中的每行之一。

尝试类似以下代码的内容来产生输出:

json_data = []
for index, image in enumerate(images, 1):
    with open(image, "rb") as open_file:
        byte_content = open_file.read()
    # Convert to base64
    base64_bytes = b64encode(byte_content)
    # Decode bytes to text
    base64_string = base64_bytes.decode("utf-8")
    # Create dictionary
    raw_data = {"b64": base64_string}
    # Put data to json
    json_data.append(json.dumps(raw_data))
# Write to the file
with open(predict_instance_json, "w") as fp:
    fp.write('n'.join(json_data))

最新更新