如提到的文档/教程,我们可以调用Estimator.fit()
来启动训练作业。
该方法的必需参数将是对训练文件的 s3/文件引用的inputs
。例:
estimator.fit({'train':'s3://my-bucket/training_data})
training-script.py
parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
我希望os.environ['SM_CHANNEL_TRAIN']
是 S3 路径。但相反,它会返回/opt/ml/input/data/train
.
有人知道为什么吗?
更新
我还尝试调用estimator.fit('s3://my-bucket/training_data'(。 不知何故,训练实例没有得到SM_CHANNEL_TRAIN环境变量。事实上,我根本没有在环境变量中看到 s3 URI。
在 SageMaker 中运行训练作业时,包含您提供的训练数据的 S3 URL 最终会从指定的 URL 复制到 docker 容器(也称为训练作业(中。因此,环境变量SM_CHANNEL_TRAIN指向从提供的 S3 URL 复制的训练数据的本地路径。
https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTrainingJob.html#SageMaker-CreateTrainingJob-request-InputDataConfig
这很可能是因为您的参数os.environ['SM_CHANNEL_TRAIN']
没有给出带有s3://
前缀的路径,如果您希望它从 s3 中提取数据。如果没有该前缀,它会在映像中搜索自己的本地文件系统以查找该路径。