如何在GCP AI平台上使用TFRecord文件进行批量预测



TL;DR谷歌云AI平台在进行批量预测时如何解压缩TFRecord文件?

我已经在谷歌云人工智能平台上部署了一个经过训练的Keras模型,但我在批量预测的文件格式方面遇到了问题。对于训练,我使用tf.data.TFRecordDataset来读取TFRecord的列表,如下所示,一切都很好。

def unpack_tfrecord(record):
parsed = tf.io.parse_example(record, {
'chunk': tf.io.FixedLenFeature([128, 2, 3], tf.float32),  # Input
'class': tf.io.FixedLenFeature([2], tf.int64),            # One-hot classification (binary)
})
return (parsed['chunk'], parsed['class'])
files = [str(p) for p in training_chunks_path.glob('*.tfrecord')]
dataset = tf.data.TFRecordDataset(files).batch(32).map(unpack_tfrecord)
model.fit(x=dataset, epochs=train_epochs)
tf.saved_model.save(model, model_save_path)

我将保存的模型上传到云存储,并在AI平台中创建一个新模型。AI平台文档指出;Batch with gcloud工具[支持]具有JSON实例字符串的文本文件或TFRecord文件(可压缩(";(https://cloud.google.com/ai-platform/prediction/docs/overview#prediction_input_data)。但是当我提供一个TFRecord文件时,我得到了错误:

("'utf-8' codec can't decode byte 0xa4 in position 1: invalid start byte", 8)

我的TFRecord文件包含一堆Protobuf编码的tf.train.Example。我没有向AI平台提供unpack_tfrecord功能,所以我想它不能正确地打开它是有道理的,但我有节点的想法。我对使用JSON格式不感兴趣,因为数据太大了。

我不知道这是否是最好的方法,但对于TF 2.x,你可以做一些类似的事情:

import tensorflow as tf
def make_serving_input_fn():
# your feature spec
feature_spec = {
'chunk': tf.io.FixedLenFeature([128, 2, 3], tf.float32),  
'class': tf.io.FixedLenFeature([2], tf.int64),
}
serialized_tf_examples = tf.keras.Input(
shape=[], name='input_example_tensor', dtype=tf.string)
examples = tf.io.parse_example(serialized_tf_examples, feature_spec)
# any processing 
processed_chunks = tf.map_fn(
<PROCESSING_FN>, 
examples['chunk'], # ?
dtype=tf.float32)
return tf.estimator.export.ServingInputReceiver(
features={<MODEL_FIRST_LAYER_NAME>: processed_chunks},
receiver_tensors={"input_example_tensor": serialized_tf_examples}
)

estimator = tf.keras.estimator.model_to_estimator(
keras_model=model,
model_dir=<ESTIMATOR_SAVE_DIR>)
estimator.export_saved_model(
export_dir_base=<WORKING_DIR>,
serving_input_receiver_fn=make_serving_input_fn)

相关内容

  • 没有找到相关文章

最新更新