# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from syne_tune.remote.constants import (
DEFAULT_CPU_INSTANCE,
PYTORCH_LATEST_FRAMEWORK,
PYTORCH_LATEST_PY_VERSION,
HUGGINGFACE_LATEST_FRAMEWORK_VERSION,
HUGGINGFACE_LATEST_PYTORCH_VERSION,
HUGGINGFACE_LATEST_TRANSFORMERS_VERSION,
HUGGINGFACE_LATEST_PY_VERSION,
MXNET_LATEST_PY_VERSION,
MXNET_LATEST_VERSION,
)
from syne_tune.try_import import try_import_aws_message
try:
from sagemaker.pytorch import PyTorch
from sagemaker.huggingface import HuggingFace
from sagemaker.mxnet import MXNet
from sagemaker.tensorflow import TensorFlow
from sagemaker.sklearn import SKLearn
from sagemaker.chainer import Chainer
from sagemaker.xgboost import XGBoost
except ImportError:
print(try_import_aws_message())
[docs]
def instance_sagemaker_estimator(**kwargs):
"""
Returns SageMaker estimator to be used for simulator back-end experiments
and for remote launching of SageMaker back-end experiments.
:param kwargs: Extra arguments to SageMaker estimator
:return: SageMaker estimator
"""
return pytorch_estimator(
**kwargs,
)
[docs]
def basic_cpu_instance_sagemaker_estimator(**kwargs):
"""
Returns SageMaker estimator to be used for simulator back-end experiments
and for remote launching of SageMaker back-end experiments.
:param kwargs: Extra arguments to SageMaker estimator
:return: SageMaker estimator
"""
return pytorch_estimator(
instance_type=DEFAULT_CPU_INSTANCE,
instance_count=1,
**kwargs,
)
[docs]
def pytorch_estimator(**estimator_kwargs) -> PyTorch:
"""
Get the PyTorch sagemaker estimator with the most up-to-date framework version.
List of available containers: https://github.com/aws/deep-learning-containers/blob/master/available_images.md
:param estimator_kwargs: Estimator parameters as discussed in
https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/sagemaker.pytorch.html
:return: PyTorch estimator
"""
return PyTorch(
py_version=PYTORCH_LATEST_PY_VERSION,
framework_version=PYTORCH_LATEST_FRAMEWORK,
**estimator_kwargs,
)
[docs]
def huggingface_estimator(**estimator_kwargs) -> HuggingFace:
"""
Get the Huggingface sagemaker estimator with the most up-to-date framework version.
List of available containers: https://github.com/aws/deep-learning-containers/blob/master/available_images.md
:param estimator_kwargs: Estimator parameters as discussed in
https://sagemaker.readthedocs.io/en/stable/frameworks/huggingface/sagemaker.huggingface.html
:return: PyTorch estimator
"""
return HuggingFace(
framework_version=HUGGINGFACE_LATEST_FRAMEWORK_VERSION,
transformers_version=HUGGINGFACE_LATEST_TRANSFORMERS_VERSION,
pytorch_version=HUGGINGFACE_LATEST_PYTORCH_VERSION,
py_version=HUGGINGFACE_LATEST_PY_VERSION,
**estimator_kwargs,
)
[docs]
def sklearn_estimator(**estimator_kwargs) -> SKLearn:
"""
Get the Scikit-learn sagemaker estimator with the most up-to-date framework version.
List of available containers: https://github.com/aws/deep-learning-containers/blob/master/available_images.md
:param estimator_kwargs: Estimator parameters as discussed in
https://sagemaker.readthedocs.io/en/stable/frameworks/sklearn/sagemaker.sklearn.html
:return: PyTorch estimator
"""
return SKLearn(
framework_version="1.0-1",
py_version="py3",
**estimator_kwargs,
)
[docs]
def mxnet_estimator(**estimator_kwargs) -> MXNet:
"""
Get the MXNet sagemaker estimator with the most up-to-date framework version.
List of available containers: https://github.com/aws/deep-learning-containers/blob/master/available_images.md
:param estimator_kwargs: Estimator parameters as discussed in
https://sagemaker.readthedocs.io/en/stable/frameworks/mxnet/sagemaker.mxnet.html
:return: PyTorch estimator
"""
return MXNet(
framework_version=MXNET_LATEST_VERSION,
py_version=MXNET_LATEST_PY_VERSION,
**estimator_kwargs,
)
sagemaker_estimator = {
"PyTorch": pytorch_estimator,
"HuggingFace": huggingface_estimator,
"BasicCPU": basic_cpu_instance_sagemaker_estimator,
"MXNet": MXNet,
"TensorFlow": TensorFlow,
"SKLearn": sklearn_estimator,
"Chainer": Chainer,
"XGBoost": XGBoost,
}