Source code for syne_tune.backend.sagemaker_backend.sagemaker_utils

# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import functools
import logging
import os
import re
import subprocess
import tarfile
from ast import literal_eval
from pathlib import Path
from typing import List, Tuple, Dict, Optional, Any, Callable

import boto3
from botocore.config import Config
from botocore.exceptions import ClientError
from sagemaker import Session
from sagemaker.estimator import Framework, EstimatorBase

import syne_tune
from syne_tune.backend.trial_status import TrialResult
from syne_tune.constants import (
    ST_SAGEMAKER_METRIC_TAG,
    MAX_METRICS_SUPPORTED_BY_SAGEMAKER,
)
from syne_tune.report import retrieve
from syne_tune.util import experiment_path, random_string, s3_experiment_path

logger = logging.getLogger(__name__)


def default_config() -> Config:
    """
    https://aws.amazon.com/premiumsupport/knowledge-center/sagemaker-python-throttlingexception/

    :return: Default config which avoids throttling
    """
    return Config(
        connect_timeout=5,
        read_timeout=60,
        retries={"max_attempts": 20, "mode": "standard"},
    )


def default_sagemaker_session():
    sagemaker_client = boto3.client(service_name="sagemaker", config=default_config())
    return Session(sagemaker_client=sagemaker_client)

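# Illustrative usage sketch (not part of the original module): the session created
# above can be used to query a training job through its underlying boto3 client.
# The job name "my-training-job" is a hypothetical placeholder.
def _example_describe_job_status() -> str:
    session = default_sagemaker_session()
    response = session.sagemaker_client.describe_training_job(
        TrainingJobName="my-training-job"
    )
    return response["TrainingJobStatus"]
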
def get_log(jobname: str, log_client=None) -> List[str]:
    """
    :param jobname: name of a sagemaker training job
    :param log_client: a log client, for instance ``boto3.client('logs')``. If None,
        the client is instantiated with the default AWS configuration
    :return: lines appearing in the log of the Sagemaker training job
    """
    if log_client is None:
        log_client = boto3.client("logs", config=default_config())
    streams = log_client.describe_log_streams(
        logGroupName="/aws/sagemaker/TrainingJobs", logStreamNamePrefix=jobname
    )
    res = []
    for stream in streams["logStreams"]:
        get_response = functools.partial(
            log_client.get_log_events,
            logGroupName="/aws/sagemaker/TrainingJobs",
            logStreamName=stream["logStreamName"],
            startFromHead=True,
        )
        response = get_response()
        for event in response["events"]:
            res.append(event["message"])
        next_token = None
        while (
            "nextForwardToken" in response
            and next_token != response["nextForwardToken"]
        ):
            next_token = response["nextForwardToken"]
            response = get_response(nextToken=next_token)
            for event in response["events"]:
                res.append(event["message"])
    return res

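# Illustrative usage sketch (not part of the original module): fetch the CloudWatch
# log of a job and keep only the lines carrying Syne Tune metric reports.
# The job name "my-training-job" is a hypothetical placeholder.
def _example_print_metric_lines():
    for line in get_log("my-training-job"):
        if ST_SAGEMAKER_METRIC_TAG in line:
            print(line)
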
def decode_sagemaker_hyperparameter(hp: str):
    # Sagemaker encodes hyperparameters as literals which are compatible with
    # Python, except for booleans, which are encoded as 'true' and 'false'.
    if hp == "true":
        return True
    elif hp == "false":
        return False
    return literal_eval(hp)

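# Illustrative usage sketch (not part of the original module): decoding encoded
# hyperparameter values back into Python values. The string case assumes the value
# arrives quoted, i.e. encoded as a Python literal.
def _example_decode_hyperparameters():
    assert decode_sagemaker_hyperparameter("true") is True
    assert decode_sagemaker_hyperparameter("64") == 64
    assert decode_sagemaker_hyperparameter("0.5") == 0.5
    assert decode_sagemaker_hyperparameter('"adam"') == "adam"
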
def metric_definitions_from_names(metrics_names: List[str]):
    """
    :param metrics_names: names of the metrics present in the log. Metrics must be
        written in the log as [metric-name]: value, for instance [accuracy]: 0.23
    :return: a list of metric dictionaries that can be passed to sagemaker so that
        metrics are parsed from logs, the list can be passed to
        ``metric_definitions`` in sagemaker.
    """

    def metric_dict(metric_name):
        """
        :param metric_name:
        :return: a sagemaker metric definition to enable Sagemaker to interpret
            metrics from logs
        """
        regex = rf".*[{ST_SAGEMAKER_METRIC_TAG}].*\"{re.escape(metric_name)}\": ([-+]?\d\.?\d*)"
        return {"Name": metric_name, "Regex": regex}

    return [metric_dict(m) for m in metrics_names]


def add_metric_definitions_to_sagemaker_estimator(
    estimator: EstimatorBase, metrics_names: List[str]
):
    """
    Adds metric definitions according to :func:`metric_definitions_from_names` to
    ``estimator`` for each name in ``metrics_names``. The regexp for each name is
    compatible with how :class:`~syne_tune.Reporter` outputs metric values.

    :param estimator: SageMaker estimator
    :param metrics_names: Names of metrics to be appended
    """
    if metrics_names:
        current_metric_definitions = estimator.metric_definitions
        if current_metric_definitions is None:
            current_metric_definitions = []
        new_names = set(metrics_names)
        current_metric_definitions = [
            x for x in current_metric_definitions if x["Name"] not in new_names
        ]
        current_metric_definitions += metric_definitions_from_names(metrics_names)
        if len(current_metric_definitions) > MAX_METRICS_SUPPORTED_BY_SAGEMAKER:
            current_metric_definitions = current_metric_definitions[
                :MAX_METRICS_SUPPORTED_BY_SAGEMAKER
            ]
            logger.warning(
                f"Sagemaker only supports {MAX_METRICS_SUPPORTED_BY_SAGEMAKER} "
                "metrics for learning curve visualization, keeping only these:\n"
                + str([x["Name"] for x in current_metric_definitions])
            )
        estimator.metric_definitions = current_metric_definitions

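# Illustrative usage sketch (not part of the original module): building metric
# definitions for two hypothetical metric names and attaching them to an estimator
# so that SageMaker parses them from the training logs.
def _example_attach_metric_definitions(estimator: EstimatorBase):
    definitions = metric_definitions_from_names(["accuracy", "loss"])
    # Each entry has the form {"Name": <metric-name>, "Regex": <pattern>}
    print(definitions)
    add_metric_definitions_to_sagemaker_estimator(estimator, ["accuracy", "loss"])
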
def add_syne_tune_dependency(sm_estimator):
    # Adds the syne_tune sources to the dependencies shipped with the estimator,
    # so that report.py and other syne_tune modules can be found on the remote side
    sm_estimator.dependencies = sm_estimator.dependencies + [
        str(Path(syne_tune.__path__[0]))
    ]


def sagemaker_fit(
    sm_estimator: Framework,
    hyperparameters: Dict[str, object],
    checkpoint_s3_uri: Optional[str] = None,
    wait: bool = False,
    job_name: Optional[str] = None,
    *sagemaker_fit_args,
    **sagemaker_fit_kwargs,
):
    """
    :param sm_estimator: sagemaker estimator to be fitted
    :param hyperparameters: dictionary of hyperparameters that are passed to
        ``entry_point_script``
    :param checkpoint_s3_uri: checkpoint_s3_uri of Sagemaker Estimator
    :param wait: whether to wait for job completion
    :param job_name: name of the sagemaker training job
    :return: name of sagemaker job
    """
    experiment = sm_estimator
    experiment._hyperparameters = hyperparameters
    experiment.checkpoint_s3_uri = checkpoint_s3_uri
    experiment.fit(
        wait=wait, job_name=job_name, *sagemaker_fit_args, **sagemaker_fit_kwargs
    )
    return experiment.latest_training_job.job_name

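# Illustrative usage sketch (not part of the original module): launching a training
# job with the helpers above. The entry point "train.py", the framework/Python
# versions, the instance type and the hyperparameters are hypothetical placeholders;
# the script is assumed to report metrics via ``syne_tune.Reporter``.
def _example_launch_training_job() -> str:
    from sagemaker.pytorch import PyTorch

    estimator = PyTorch(
        entry_point="train.py",
        framework_version="1.12",
        py_version="py38",
        instance_type="ml.m5.large",
        instance_count=1,
        role=get_execution_role(),
        sagemaker_session=default_sagemaker_session(),
    )
    add_syne_tune_dependency(estimator)
    add_metric_definitions_to_sagemaker_estimator(estimator, ["accuracy"])
    return sagemaker_fit(
        estimator, hyperparameters={"lr": 0.01, "epochs": 3}, wait=False
    )
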
def get_execution_role():
    """
    :return: sagemaker execution role that is specified with the environment
        variable ``AWS_ROLE``. If it is not specified, we infer it by searching for
        the role associated to Sagemaker. Note that
        ``import sagemaker; sagemaker.get_execution_role()`` does not return the
        right role outside of a Sagemaker notebook.
    """
    if "AWS_ROLE" in os.environ:
        aws_role = os.environ["AWS_ROLE"]
        logger.info(
            f"Using Sagemaker role {aws_role} passed as environment variable $AWS_ROLE"
        )
        return aws_role
    else:
        logger.info(
            "No Sagemaker role passed as environment variable $AWS_ROLE, inferring it."
        )
        client = boto3.client("iam", config=default_config())
        sm_roles = client.list_roles(PathPrefix="/service-role/")["Roles"]
        for role in sm_roles:
            if "AmazonSageMaker-ExecutionRole" in role["RoleName"]:
                return role["Arn"]
        raise Exception(
            "Could not infer Sagemaker role, specify it with the ``AWS_ROLE`` environment variable "
            "or refer to https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html to create a new one"
        )


def untar(filename: Path):
    if str(filename).endswith("tar.gz"):
        tar = tarfile.open(filename, "r:gz")
        tar.extractall(path=filename.parent)
        tar.close()


def download_sagemaker_results(s3_path: Optional[str] = None):
    """
    Download results obtained after running tuning remotely on Sagemaker,
    e.g. when using ``RemoteLauncher``.
    """
    if s3_path is None:
        s3_path = s3_experiment_path()
    tgt_dir = str(experiment_path())
    cmd = f"aws s3 sync {s3_path} {tgt_dir}"
    logger.info(f"downloading sagemaker results to {tgt_dir} with command {cmd}")
    subprocess.run(cmd.split(" "))


def map_identifier_limited_length(
    name: str, max_length: int = 63, rnd_digits: int = 4
) -> str:
    """
    If ``name`` is longer than ``max_length`` characters, it is mapped to a new
    identifier of length ``max_length``, being the concatenation of the first
    ``max_length - rnd_digits`` characters of ``name``, followed by a random
    string of length ``rnd_digits``.

    :param name: Identifier to be limited in length
    :param max_length: Maximum length for output
    :param rnd_digits: See above
    :return: See above
    """
    orig_length = len(name)
    if orig_length <= max_length:
        return name
    else:
        assert 1 < rnd_digits < max_length
        postfix = random_string(rnd_digits)
        return name[: (max_length - rnd_digits)] + postfix

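# Illustrative usage sketch (not part of the original module): SageMaker training
# job names are limited to 63 characters, so long tuner or trial names can be
# mapped to a valid identifier. The long name below is a hypothetical example.
def _example_shorten_job_name() -> str:
    long_name = "my-tuning-experiment-with-a-very-long-descriptive-name-" + "x" * 30
    short_name = map_identifier_limited_length(long_name, max_length=63)
    assert len(short_name) <= 63
    return short_name
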
def _s3_traverse_recursively(
    s3_client,
    action: Callable[[str], Optional[str]],
    bucket: str,
    prefix: str,
    valid_postfixes: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """
    Traverses directory from root ``prefix``. The function ``action`` is applied
    to all objects encountered, the signature is ``action(object_key)``. ``action``
    returns ``None`` if successful, otherwise an error message.

    We return a dict with "num_action_calls", "num_successful_action_calls",
    "first_error_message" (the error message for the first failed ``action`` call
    encountered).

    If ``valid_postfixes`` is given, ``action`` is only applied to such
    ``object_key`` for which ``object_key.endswith(postfix)`` for some
    ``postfix in valid_postfixes``.

    :param s3_client: S3 client
    :param action: See above
    :param bucket: S3 bucket name
    :param prefix: Prefix from where to traverse, must end with "/"
    :param valid_postfixes: See above, optional
    :return: See above
    """
    more_objects = True
    continuation_kwargs = dict()
    list_objects_kwargs = dict(Bucket=bucket, Prefix=prefix, Delimiter="/")
    all_next_prefixes = []
    num_action_calls = 0
    num_successful_action_calls = 0
    first_error_message = None
    while more_objects:
        response = s3_client.list_objects_v2(
            **list_objects_kwargs, **continuation_kwargs
        )
        # Subdirectories
        for next_prefix in response.get("CommonPrefixes", []):
            all_next_prefixes.append(next_prefix["Prefix"])
        # Objects
        for source in response.get("Contents", []):
            object_key = source["Key"]
            if valid_postfixes is not None and not any(
                object_key.endswith(postfix) for postfix in valid_postfixes
            ):
                continue  # Skip this key
            ret_msg = action(object_key)
            num_action_calls += 1
            if ret_msg is None:
                num_successful_action_calls += 1
            elif first_error_message is None:
                first_error_message = ret_msg
        more_objects = "NextContinuationToken" in response
        if more_objects:
            continuation_kwargs = {
                "ContinuationToken": response["NextContinuationToken"]
            }
    # Recursive calls
    for next_prefix in all_next_prefixes:
        result = _s3_traverse_recursively(
            s3_client,
            action,
            bucket,
            prefix=next_prefix,
            valid_postfixes=valid_postfixes,
        )
        num_action_calls += result["num_action_calls"]
        num_successful_action_calls += result["num_successful_action_calls"]
        if first_error_message is None:
            first_error_message = result["first_error_message"]
    return dict(
        num_action_calls=num_action_calls,
        num_successful_action_calls=num_successful_action_calls,
        first_error_message=first_error_message,
    )


def _split_bucket_prefix(s3_path: str) -> Tuple[str, str]:
    assert s3_path[:5] == "s3://", s3_path
    parts = s3_path[5:].split("/")
    bucket = parts[0]
    prefix = "/".join(parts[1:])
    if prefix[-1] != "/":
        prefix += "/"
    return bucket, prefix

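# Illustrative usage sketch (not part of the original module): how an S3 path is
# split into bucket and prefix by the helper above. Bucket and path are
# hypothetical placeholders.
def _example_split_bucket_prefix():
    bucket, prefix = _split_bucket_prefix("s3://my-bucket/syne-tune/experiment-1")
    assert bucket == "my-bucket"
    assert prefix == "syne-tune/experiment-1/"
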
def s3_copy_objects_recursively(
    s3_source_path: str, s3_target_path: str
) -> Dict[str, Any]:
    """
    Recursively copies objects from ``s3_source_path`` to ``s3_target_path``.

    We return a dict with 'num_action_calls', 'num_successful_action_calls',
    'first_error_message' (the error message for the first failed ``action`` call
    encountered).

    .. note::
       This function should not be used to copy a large number of objects, as
       it is rather slow (one API call per object)

    :param s3_source_path: S3 path to copy objects from
    :param s3_target_path: S3 path to copy objects to
    :return: See above
    """
    src_bucket, src_prefix = _split_bucket_prefix(s3_source_path)
    trg_bucket, trg_prefix = _split_bucket_prefix(s3_target_path)
    s3_client = boto3.client("s3")

    def copy_action(object_key: str) -> Optional[str]:
        assert object_key.startswith(
            src_prefix
        ), f"object_key = {object_key} must start with {src_prefix}"
        target_key = trg_prefix + object_key[len(src_prefix) :]
        copy_source = dict(Bucket=src_bucket, Key=object_key)
        ret_msg = None
        try:
            s3_client.copy_object(
                CopySource=copy_source, Bucket=trg_bucket, Key=target_key
            )
        except ClientError as ex:
            ret_msg = str(ex)
        return ret_msg

    return _s3_traverse_recursively(
        s3_client=s3_client, action=copy_action, bucket=src_bucket, prefix=src_prefix
    )


def s3_delete_objects_recursively(s3_path: str) -> Dict[str, Any]:
    """
    Recursively deletes objects from ``s3_path``.

    We return a dict with 'num_action_calls', 'num_successful_action_calls',
    'first_error_message' (the error message for the first failed ``action`` call
    encountered).

    .. note::
       This function should not be used to delete a large number of objects, as
       it is rather slow (one API call per object)

    :param s3_path: S3 path to delete objects from
    :return: See above
    """
    bucket_name, prefix = _split_bucket_prefix(s3_path)
    s3_client = boto3.client("s3")

    def delete_action(object_key: str) -> Optional[str]:
        ret_msg = None
        try:
            s3_client.delete_object(Bucket=bucket_name, Key=object_key)
        except ClientError as ex:
            ret_msg = str(ex)
        return ret_msg

    return _s3_traverse_recursively(
        s3_client=s3_client, action=delete_action, bucket=bucket_name, prefix=prefix
    )


def s3_download_files_recursively(
    s3_source_path: str,
    target_path: str,
    valid_postfixes: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """
    Recursively downloads objects from ``s3_source_path`` and stores them locally
    as files below ``target_path``.

    We return a dict with 'num_action_calls', 'num_successful_action_calls',
    'first_error_message' (the error message for the first failed ``action`` call
    encountered).

    If ``valid_postfixes`` is given, only such objects are downloaded for which
    ``object_key.endswith(postfix)`` for some ``postfix in valid_postfixes``.

    .. note::
       This function should not be used to download a large number of objects, as
       it is rather slow (one API call per object). In this case, running
       ``aws s3 sync`` can be much faster.

    :param s3_source_path: See above
    :param target_path: See above
    :param valid_postfixes: See above, optional
    :return: See above
    """
    src_bucket, src_prefix = _split_bucket_prefix(s3_source_path)
    if target_path[-1] != "/":
        target_path = target_path + "/"
    s3_client = boto3.client("s3")

    def download_action(object_key: str) -> Optional[str]:
        assert object_key.startswith(
            src_prefix
        ), f"object_key = {object_key} must start with {src_prefix}"
        target_file = target_path + object_key[len(src_prefix) :]
        # Ensure that target directory exists
        Path(target_file).parent.mkdir(exist_ok=True, parents=True)
        ret_msg = None
        try:
            s3_client.download_file(src_bucket, object_key, target_file)
        except ClientError as ex:
            ret_msg = str(ex)
        return ret_msg

    return _s3_traverse_recursively(
        s3_client=s3_client,
        action=download_action,
        bucket=src_bucket,
        prefix=src_prefix,
        valid_postfixes=valid_postfixes,
    )

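# Illustrative usage sketch (not part of the original module): downloading only
# result files from a hypothetical experiment path on S3 and checking for failures.
def _example_download_results() -> Dict[str, Any]:
    result = s3_download_files_recursively(
        s3_source_path="s3://my-bucket/syne-tune/experiment-1/",
        target_path="/tmp/experiment-1",
        valid_postfixes=[".csv", ".json"],
    )
    if result["first_error_message"] is not None:
        logger.warning(f"Some downloads failed: {result['first_error_message']}")
    return result
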
def backend_path_not_synced_to_s3() -> Path:
    """
    When an experiment with the local backend is run remotely (as SageMaker
    training job), we do not want checkpoints to be synced to S3, since this is
    expensive and error-prone (several trials may write checkpoints at the same
    time). Pass the returned path to ``trial_backend_path`` when constructing the
    :class:`~syne_tune.Tuner`.

    Here, we direct checkpoint writing to /opt/ml/input/data/, which is mounted
    on a partition with sufficient space. Different to /opt/ml/checkpoints, this
    directory is not synced to S3.

    :return: Path to set in local backend
    """
    path = Path("/opt/ml/input/data/")
    path.mkdir(parents=True, exist_ok=True)
    return path

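# Illustrative usage sketch (not part of the original module): passing the path
# above to the Tuner when the experiment runs remotely as a SageMaker training job.
# ``trial_backend`` and ``scheduler`` are hypothetical objects created elsewhere,
# and the stopping criterion is an arbitrary placeholder.
def _example_tuner_with_local_backend(trial_backend, scheduler):
    from syne_tune import StoppingCriterion, Tuner

    return Tuner(
        trial_backend=trial_backend,
        scheduler=scheduler,
        stop_criterion=StoppingCriterion(max_wallclock_time=600),
        n_workers=4,
        trial_backend_path=str(backend_path_not_synced_to_s3()),
    )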