我已经使用 Sagemaker 在 Pytorch 中训练并部署了一个模型。我能够调用端点并获得预测。我正在使用默认的 input_fn(( 函数(即未在我的 serve.py 中定义(。
model = PyTorchModel(model_data=trained_model_location,
role=role,
framework_version='1.0.0',
entry_point='serve.py',
source_dir='source')
predictor = model.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')
可以按如下方式进行预测:
input ="0.12787057, 1.0612601, -1.1081504"
predictor.predict(np.genfromtxt(StringIO(input), delimiter=",").reshape(1,3) )
我希望能够使用 REST API 为模型提供服务,并且使用 lambda 和 API 网关进行 HTTP POST。我能够以这种方式将invoke_endpoint((与Sagemaker中的XGBOOST模型一起使用。我不确定该为Pytorch发送什么。
client = boto3.client('sagemaker-runtime')
response = client.invoke_endpoint(EndpointName=ENDPOINT ,
ContentType='text/csv',
Body=???)
我相信我需要了解如何编写客户input_fn,以接受和处理我能够通过invoke_client发送的数据类型。我是否走在正确的轨道上,如果是这样,如何编写input_fn以接受来自invoke_endpoint的 csv?
是的,你走在正确的轨道上。您可以将 csv 序列化输入发送到终端节点,而无需使用 SageMaker 开发工具包中的predictor
,也可以使用其他开发工具包,例如安装在 lambda 中的boto3
:
import boto3
runtime = boto3.client('sagemaker-runtime')
payload = '0.12787057, 1.0612601, -1.1081504'
response = runtime.invoke_endpoint(
EndpointName=ENDPOINT_NAME,
ContentType='text/csv',
Body=payload.encode('utf-8'))
result = json.loads(response['Body'].read().decode())
这会将 csv 格式的输入传递给端点,您可能需要在input_fn
中重新调整该输入,以放入模型所需的适当维度。
例如:
def input_fn(request_body, request_content_type):
if request_content_type == 'text/csv':
return torch.from_numpy(
np.genfromtxt(StringIO(request_body), delimiter=',').reshape(1,3))
注意:我无法使用您的输入内容和形状测试上面的特定input_fn
,但我在 Sklearn RandomForest 上使用了几次该方法,并且查看 Pytorch SageMaker 服务文档,上述基本原理应该有效。
不要犹豫,使用 Cloudwatch 中的终端节点日志来诊断任何推理错误(可从控制台的终端节点 UI 获得(,这些日志通常比推理开发工具包返回的高级日志详细得多