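The code below runs fine in a SageMaker notebook in the cloud. Locally, I have also created AWS credentials via the AWS CLI. Personally, I don't like notebooks (unless I'm doing some EDA or similar), so I'd like to know whether I can launch this code from my local machine (e.g. in Visual Studio Code), given that it only tells SageMaker what to do. I guess it's just a matter of authenticating and getting a session object? Thanks!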
import boto3
import os
import sagemaker
from sagemaker import get_execution_role
from sagemaker.inputs import TrainingInput
from sagemaker.serializers import CSVSerializer
from sagemaker import image_uris
region_name = boto3.session.Session().region_name
s3_bucket_name = 'bucket_name'
# This image cannot be used below! Is there an issue with SageMaker?
training_image_name = image_uris.retrieve(framework='xgboost', region=region_name, version='latest')
role = get_execution_role()
s3_prefix = 'my_model'
train_file_name = 'sagemaker_train.csv'
val_file_name = 'sagemaker_val.csv'
sagemaker_session = sagemaker.Session()
s3_input_train = TrainingInput(s3_data='s3://{}/{}/{}'.format(s3_bucket_name, s3_prefix, train_file_name), content_type='csv')
s3_input_val = TrainingInput(s3_data='s3://{}/{}/{}'.format(s3_bucket_name, s3_prefix, val_file_name), content_type='csv')
hyperparameters = {
    "max_depth": "5",
    "eta": "0.2",
    "gamma": "4",
    "min_child_weight": "6",
    "subsample": "0.7",
    "objective": "reg:squarederror",
    "num_round": "10"}
output_path = 's3://{}/{}/output'.format(s3_bucket_name, s3_prefix)
estimator = sagemaker.estimator.Estimator(image_uri=sagemaker.image_uris.retrieve("xgboost", region_name, "1.2-2"),
                                          hyperparameters=hyperparameters,
                                          role=role,
                                          instance_count=1,
                                          instance_type='ml.m5.2xlarge',
                                          volume_size=1,  # 1 GB
                                          output_path=output_path)
estimator.fit({'train': s3_input_train, 'validation': s3_input_val})
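On your local machine:
- Make sure the AWS CLI is installed: https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html
- Create an access key ID and secret access key so you can access the SageMaker service locally: https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_access-keys.html
- Set these credentials locally with the aws configure command.
- The code should then work as-is, except for getting the execution role: get_execution_role() relies on the SageMaker environment and will typically fail on a local machine. You can either hard-code the SageMaker role ARN in your code (not best practice), or store it in Parameter Store and read it from there.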
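A minimal sketch of the last two steps, assuming the credentials were stored with aws configure and the role ARN has been saved in Parameter Store under a name of your choosing. ROLE_PARAM_NAME, make_sagemaker_session and get_role_arn are hypothetical names used only for illustration; they are not part of the original answer or of the SageMaker SDK. With the default profile, sagemaker.Session() already picks up these credentials by itself, so make_sagemaker_session only matters if you use a named profile or a different region.

```python
"""Sketch: obtain a SageMaker session and execution role when running locally.

Assumptions (hypothetical, not from the original answer):
- credentials were stored with `aws configure` (default or named profile),
- the role ARN is stored in SSM Parameter Store under ROLE_PARAM_NAME.
"""

# Hypothetical parameter name; create the parameter yourself beforehand.
ROLE_PARAM_NAME = "/my-project/sagemaker-execution-role-arn"


def make_sagemaker_session(profile_name=None, region_name=None):
    """Build a sagemaker.Session from locally configured AWS credentials.

    profile_name=None / region_name=None fall back to the default profile
    and region written by `aws configure`.
    """
    import boto3
    import sagemaker

    boto_session = boto3.Session(profile_name=profile_name, region_name=region_name)
    return sagemaker.Session(boto_session=boto_session)


def get_role_arn(param_name=ROLE_PARAM_NAME, ssm_client=None):
    """Read the SageMaker execution role ARN from SSM Parameter Store.

    An SSM client can be injected (e.g. for tests); otherwise one is
    created from the default boto3 credential chain.
    """
    if ssm_client is None:
        import boto3  # imported lazily so the client can be injected

        ssm_client = boto3.client("ssm")
    response = ssm_client.get_parameter(Name=param_name)
    return response["Parameter"]["Value"]
```

In the script above, role = get_execution_role() would then become role = get_role_arn(), and a non-default session can be handed to the estimator via sagemaker.estimator.Estimator(..., sagemaker_session=make_sagemaker_session(profile_name="my-profile")), where "my-profile" is a placeholder. The parameter itself could be created once with, for example, aws ssm put-parameter --name /my-project/sagemaker-execution-role-arn --type String --value <your role ARN>.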