# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import json
import logging
import os
from pathlib import Path
from typing import List, Optional, Tuple, Dict, Any
import boto3
from botocore.exceptions import ClientError
import time
from sagemaker import LocalSession
from sagemaker.estimator import Framework
from sagemaker.interactive_apps import TensorBoardApp
from syne_tune.backend.trial_backend import TrialBackend, BUSY_STATUS
from syne_tune.constants import (
ST_INSTANCE_TYPE,
ST_INSTANCE_COUNT,
ST_CHECKPOINT_DIR,
ST_CONFIG_JSON_FNAME_ARG,
)
from syne_tune.util import s3_experiment_path, dump_json_with_numpy
from syne_tune.backend.trial_status import TrialResult, Status
from syne_tune.backend.sagemaker_backend.sagemaker_utils import (
sagemaker_search,
get_log,
sagemaker_fit,
add_syne_tune_dependency,
map_identifier_limited_length,
s3_copy_objects_recursively,
s3_delete_objects_recursively,
default_config,
default_sagemaker_session,
add_metric_definitions_to_sagemaker_estimator,
)
# Module-level logger, named after this module
logger = logging.getLogger(__name__)
# Filename prefix of the per-trial JSON config files. The suffix
# ``_{trial_id}.json`` is appended in ``SageMakerBackend._config_json_filename``
CONFIG_JSON_FILENAME = "syne_tune_sm_backend_config_31415927"
[docs]
class SageMakerBackend(TrialBackend):
"""
This backend executes each trial evaluation as a separate SageMaker
training job, using ``sm_estimator`` as estimator.
Checkpoints are written to and loaded from S3, using ``checkpoint_s3_uri``
of the estimator.
Compared to :class:`LocalBackend`, this backend can run any number of
jobs in parallel (given sufficient resources), and any instance type can
be used.
This backend allows selecting the instance type and count for a trial
evaluation, by passing values in the configuration, using names
:const:`~syne_tune.constants.ST_INSTANCE_TYPE` and
:const:`~syne_tune.constants.ST_INSTANCE_COUNT`. If these are given in the
configuration, they overwrite the default in ``sm_estimator``. This allows
for tuning instance type and count along with the hyperparameter
configuration.
Additional arguments on top of parent class
:class:`~syne_tune.backend.trial_backend.TrialBackend`:
:param sm_estimator: SageMaker estimator for trial evaluations.
:param metrics_names: Names of metrics passed to ``report``, used to plot
live curve in SageMaker (optional, only used for visualization)
:param s3_path: S3 base path used for checkpointing. The full path
also involves the tuner name and the ``trial_id``. The default base
path is the S3 bucket associated with the SageMaker account
:param sagemaker_fit_kwargs: Extra arguments that are passed to
:class:`sagemaker.estimator.Framework` when fitting the job, for instance
:code:`{'train': 's3://my-data-bucket/path/to/my/training/data'}`
"""
def __init__(
    self,
    sm_estimator: Framework,
    metrics_names: Optional[List[str]] = None,
    s3_path: Optional[str] = None,
    delete_checkpoints: bool = False,
    pass_args_as_json: bool = False,
    **sagemaker_fit_kwargs,
):
    """
    Sets up the backend around ``sm_estimator`` and initializes all
    bookkeeping containers.

    :param sm_estimator: SageMaker estimator for trial evaluations. It is
        modified here: metric definitions are added, ``base_job_name`` is
        prefixed with "st-", and the Syne Tune dependency is attached
    :param metrics_names: Names of metrics for which metric definitions are
        added to the estimator. Defaults to no metrics
    :param s3_path: S3 base path used for checkpointing. Defaults to the
        value returned by ``s3_experiment_path()``. Trailing "/" are removed
    :param delete_checkpoints: Forwarded to the parent class constructor
    :param pass_args_as_json: Forwarded to the parent class constructor
    :param sagemaker_fit_kwargs: Stored and passed on to ``sagemaker_fit``
        whenever a trial is scheduled
    """
    super(SageMakerBackend, self).__init__(
        delete_checkpoints=delete_checkpoints, pass_args_as_json=pass_args_as_json
    )
    self.sm_estimator = sm_estimator
    # Edit the SageMaker estimator, so that the user's metrics can be plotted
    # over time by SageMaker, and so that the report.py code is available
    self.add_metric_definitions_to_sagemaker_estimator(
        [] if metrics_names is None else metrics_names
    )
    user_job_name = self.sm_estimator.base_job_name
    prefixed_job_name = "st-" if user_job_name is None else "st-" + user_job_name
    # Make sure len(base_job_name) <= 63
    self.sm_estimator.base_job_name = map_identifier_limited_length(
        prefixed_job_name
    )
    add_syne_tune_dependency(self.sm_estimator)

    # Maps trial_id to the name of its most recent SageMaker training job
    self.job_id_mapping = {}
    self.sagemaker_fit_kwargs = sagemaker_fit_kwargs
    # Jobs which were paused/stopped are tracked here, because the SageMaker
    # training job status is not changed immediately after stopping a job
    self.paused_jobs = set()
    self.stopped_jobs = set()
    # Counts how often a trial has been resumed
    self.resumed_counter = {}

    base_path = s3_experiment_path() if s3_path is None else s3_path
    self.s3_path = base_path.rstrip("/")
    # ``tuner_name`` has to be set before the backend can be used. This is
    # typically done when the ``Tuner`` is created
    self.tuner_name = None
    # Trials which may currently be busy (status in ``BUSY_STATUS``). The
    # corresponding jobs are polled for status in ``busy_trial_ids``, and
    # new trials are added in :meth:`_schedule`.
    # Note: A trial can be in ``paused_jobs`` or ``stopped_jobs`` and still
    # be busy, because the underlying SM training job is still not completed
    self._busy_trial_id_candidates = set()
    # Used to estimate the stopping time for a trial (useful information
    # for now, can be removed once stop delays are reduced).
    # Note: Trials with a very short stop delay may be missed. This is fine,
    # because we mainly want to highlight long stop delays.
    self._stopping_time = {}
    # Trial IDs for which checkpoints have been deleted (see
    # :meth:`delete_checkpoint`)
    self._trial_ids_deleted_checkpoints = set()
@property
def sm_client(self):
    """
    :return: boto3 client for the "sagemaker" service, configured with
        ``default_config()``. A new client is created on every access
    """
    client_config = default_config()
    return boto3.client(service_name="sagemaker", config=client_config)
[docs]
def add_metric_definitions_to_sagemaker_estimator(self, metrics_names: List[str]):
    """
    Adds metric definitions to ``sm_estimator`` for the metrics passed to
    ``report`` which the user wants to track. This allows live learning
    curves of these metrics to be plotted in the SageMaker console. The
    names are requested from the user, because they have to be known
    beforehand for live plotting to work.

    :param metrics_names: Names of metrics to be tracked
    """
    estimator = self.sm_estimator
    add_metric_definitions_to_sagemaker_estimator(estimator, metrics_names)
def _all_trial_results(self, trial_ids: List[int]) -> List[TrialResult]:
    """
    Queries SageMaker for the results of the given trials. Trials without
    an entry in ``job_id_mapping`` are skipped.

    :param trial_ids: IDs of trials to query
    :return: Results returned by ``sagemaker_search``, with the status
        overridden for trials which have been paused or stopped here
    """
    id_name_pairs = (
        (trial_id, self.job_id_mapping.get(trial_id)) for trial_id in trial_ids
    )
    known_jobs = [
        (trial_id, job_name)
        for trial_id, job_name in id_name_pairs
        if job_name is not None
    ]
    if not known_jobs:
        results = []
    else:
        results = sagemaker_search(
            trial_ids_and_names=known_jobs,
            sm_client=self.sm_client,
        )
    # Override the status returned by SageMaker, since the stopping decision
    # may not have been propagated yet. "stopped" wins over "paused"
    for trial_result in results:
        if trial_result.trial_id in self.stopped_jobs:
            trial_result.status = Status.stopped
        elif trial_result.trial_id in self.paused_jobs:
            trial_result.status = Status.paused
    return results
@staticmethod
def _numpy_serialize(mydict):
    """
    Round-trips ``mydict`` through ``dump_json_with_numpy`` and
    ``json.loads``.

    :param mydict: Dictionary to convert
    :return: Result of parsing the JSON dump of ``mydict``
    """
    as_json = dump_json_with_numpy(mydict)
    return json.loads(as_json)
def _assert_tuner_name_is_set(self):
assert (
self.tuner_name is not None
), "tuner_name has to be set (by calling set_path) before the backend can be used"
def _checkpoint_s3_uri_for_trial(self, trial_id: int) -> str:
self._assert_tuner_name_is_set()
res_path = f"{self.s3_path}/{self.tuner_name}"
return f"{res_path}/{str(trial_id)}/checkpoints/"
def _config_json_filename(self, trial_id: int, with_path: bool) -> str:
    """
    :param trial_id: ID of trial
    :param with_path: If ``True`` and ``source_dir`` is set, the filename
        is prefixed by ``source_dir``
    :return: Name of the JSON config file for this trial
    """
    basename = f"{CONFIG_JSON_FILENAME}_{trial_id}.json"
    if not with_path or self.source_dir is None:
        return basename
    return str(Path(self.source_dir) / basename)
def _hyperparameters_from_config(
self, trial_id: int, config: Dict[str, Any]
) -> Dict[str, Any]:
"""
Prepares hyperparameters, to be sent to the entry point script as
command line arguments, given the configuration ``config``. If
``pass_args_as_json == False``, this is just a copy of ``config``.
But otherwise, the configuration is written to a JSON file, whose
name becomes a hyperparameter, but entries of the config are not
hyperparameters. Note that some default entries attached to the
config by Syne Tune are always passed as command line arguments, so if
``pass_args_as_json == True``, they are removed from the config before
this is written as JSON file.
:param trial_id: ID of trial
:param config: Configuration
:return: Hyperparameters to be passed to estimator entry point
"""
config_copy = config.copy()
if not self.pass_args_as_json:
return config_copy
else:
self._set_source_dir() # Make sure that ``source_dir`` attribute is set
result = self._prepare_hyperparameters_if_args_as_json(
trial_id, config_copy
)
dump_json_with_numpy(
config_copy, self._config_json_filename(trial_id, with_path=True)
)
return result
def _set_source_dir(self):
if self.source_dir is None:
entrypoint_path = self.entrypoint_path()
source_dir = str(entrypoint_path.parent)
entrypoint_name = entrypoint_path.name
logger.warning(
"sm_estimator.source_dir is not set, but is needed for "
"pass_args_as_json == True. Setting them to:\n"
f"source_dir = {source_dir}, entry_point = {entrypoint_name}"
)
self.sm_estimator.source_dir = source_dir
self.sm_estimator.entry_point = entrypoint_name
def _prepare_hyperparameters_if_args_as_json(
    self, trial_id: int, config: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Builds the hyperparameters for the case ``pass_args_as_json == True``.

    :param trial_id: ID of trial
    :param config: Configuration. Entries which remain command line
        arguments are removed from this dictionary
    :return: Hyperparameters containing the JSON filename argument and the
        entries removed from ``config``
    """
    # The filename depends on the trial ID. Otherwise, there would be
    # clashes between trials which run at overlapping times
    relative_filename = "./" + self._config_json_filename(trial_id, with_path=False)
    hyperparameters = {ST_CONFIG_JSON_FNAME_ARG: relative_filename}
    # These arguments remain command line parameters
    for reserved_name in (ST_INSTANCE_TYPE, ST_INSTANCE_COUNT):
        if reserved_name in config:
            hyperparameters[reserved_name] = config.pop(reserved_name)
    return hyperparameters
def _schedule(self, trial_id: int, config: Dict[str, Any]):
    """
    Starts a SageMaker training job for trial ``trial_id`` with
    configuration ``config``, records its job name in ``job_id_mapping``
    and marks the trial as busy.

    :param trial_id: ID of trial
    :param config: Configuration. May contain
        :const:`~syne_tune.constants.ST_INSTANCE_TYPE` and
        :const:`~syne_tune.constants.ST_INSTANCE_COUNT`, which then
        overwrite the corresponding attributes of ``sm_estimator``
    :raises ClientError: Re-raised from ``sagemaker_fit``. If the error
        mentions "ResourceLimitExceeded", hints are logged first
    """
    hyperparameters = self._hyperparameters_from_config(trial_id, config)
    hyperparameters[ST_CHECKPOINT_DIR] = "/opt/ml/checkpoints"
    # This passes the instance type and instance count to the training function in Sagemaker as hyperparameters
    # with reserved names ``st_instance_type`` and ``st_instance_count``.
    # We pass them as hyperparameters as it is not easy to get efficiently from inside Sagemaker training script
    # (this information is not given for instance as Sagemaker environment variables).
    # This allows to: 1) measure cost in the worker 2) tune instance_type and instance_count by having
    # ``st_instance_type`` or ``st_instance_count`` in the config space.
    # TODO once we have a multiobjective scheduler, we should add an example on how to tune instance-type/count.
    # Note: Values given in ``config`` are assigned to ``sm_estimator``, so
    # they remain in place on the estimator after this call
    if ST_INSTANCE_TYPE not in config:
        hyperparameters[ST_INSTANCE_TYPE] = self.sm_estimator.instance_type
    else:
        self.sm_estimator.instance_type = config[ST_INSTANCE_TYPE]
    if ST_INSTANCE_COUNT not in config:
        hyperparameters[ST_INSTANCE_COUNT] = self.sm_estimator.instance_count
    else:
        self.sm_estimator.instance_count = config[ST_INSTANCE_COUNT]
    if self.sm_estimator.instance_type != "local":
        checkpoint_s3_uri = self._checkpoint_s3_uri_for_trial(trial_id)
        # Use the module logger (not the root logger), consistent with all
        # other log messages of this backend
        logger.info(
            f"Trial {trial_id} will checkpoint results to {checkpoint_s3_uri}."
        )
    else:
        # checkpointing is not supported in local mode. When using local mode with remote tuner (for instance for
        # debugging), results are not stored.
        checkpoint_s3_uri = None
    # Once a trial gets resumed, the running job number has to feature in
    # the SM job_name
    try:
        jobname = sagemaker_fit(
            sm_estimator=self.sm_estimator,
            hyperparameters=self._numpy_serialize(hyperparameters),
            checkpoint_s3_uri=checkpoint_s3_uri,
            job_name=self._make_sagemaker_jobname(
                trial_id=trial_id,
                job_running_number=self.resumed_counter.get(trial_id, 0),
            ),
            **self.sagemaker_fit_kwargs,
        )
    except ClientError as ex:
        if "ResourceLimitExceeded" in str(ex):
            logger.warning(
                "Your resource limit has been exceeded. Here are some hints:\n"
                "- Choose Tuner.n_workers <= your limit for the instance type\n"
                "- Use Tuner.start_jobs_without_delay = False. Setting this to "
                "True (default) means that more than Tuner.n_workers jobs "
                "will run at certain times"
            )
        raise
    logger.info(f"scheduled {jobname} for trial-id {trial_id}")
    self.job_id_mapping[trial_id] = jobname
    self._busy_trial_id_candidates.add(trial_id)  # Mark trial as busy
def _make_sagemaker_jobname(self, trial_id: int, job_running_number: int) -> str:
"""
:param trial_id: ID of trial
:param job_running_number: Number of times the trial was resumed
:return: sagemaker job name with the form
``[trial_id]-[job_running_number]-[tuner_name]``. ``trial_id`` is put
first to avoid mismatch when searching for job information in
SageMaker from prefix.
"""
self._assert_tuner_name_is_set()
job_name = f"{trial_id}"
if job_running_number > 0:
job_name += f"-{job_running_number}"
job_name += f"-{self.tuner_name}"
return job_name
def _pause_trial(self, trial_id: int, result: Optional[dict]):
self._stop_trial_job(trial_id)
self.paused_jobs.add(trial_id)
def _stop_trial(self, trial_id: int, result: Optional[dict]):
    """
    Stops the training job of the trial and records the trial as stopped.

    :param trial_id: ID of trial. Must have an entry in ``job_id_mapping``
    :param result: Not used here
    """
    job_name = self.job_id_mapping[trial_id]
    logger.info(f"stopping {trial_id} ({job_name})")
    self._stop_trial_job(trial_id)
    stopped = self.stopped_jobs
    stopped.add(trial_id)
def _stop_trial_job(self, trial_id: int):
training_job_name = self.job_id_mapping[trial_id]
try:
self.sm_client.stop_training_job(TrainingJobName=training_job_name)
except ClientError:
# the scheduler may have decided to stop a job that finished already
pass
def _resume_trial(self, trial_id: int):
assert (
trial_id in self.paused_jobs
), f"Try to resume trial {trial_id} that was not paused before."
self.paused_jobs.remove(trial_id)
if trial_id in self.resumed_counter:
self.resumed_counter[trial_id] += 1
else:
self.resumed_counter[trial_id] = 1
def _get_busy_trial_ids(
    self, trial_results: List[TrialResult]
) -> List[Tuple[int, str]]:
    """
    Filters ``trial_results`` for busy trials, and adds candidates for
    which no result was obtained. Also maintains ``_stopping_time`` in
    order to log estimates of stopping delays.

    :param trial_results: Current results for the busy trial candidates
    :return: List of ``(trial_id, status)`` for trials considered busy
    """
    busy_trials = []
    seen_trial_ids = set()
    for trial_result in trial_results:
        trial_id = trial_result.trial_id
        status = trial_result.status
        seen_trial_ids.add(trial_id)
        if status not in BUSY_STATUS:
            if trial_id in self._stopping_time:
                # Trial just stopped being busy
                delay = time.time() - self._stopping_time[trial_id]
                logger.info(
                    f"Estimated stopping delay for trial_id {trial_id}: {delay:.2f} secs"
                )
                del self._stopping_time[trial_id]
            continue
        busy_trials.append((trial_id, trial_result.status))
        if status == Status.stopping and trial_id not in self._stopping_time:
            # First time we see ``Status.stopping`` for this ``trial_id``
            self._stopping_time[trial_id] = time.time()
    # Note: It can happen that the result of ``sagemaker_search`` does
    # not contain all trial_id's requested. We keep such trial_id's in
    # the busy list.
    # Assume that their status is "in_progress": If ``sagemaker_search``
    # drops jobs, they are the ones that have just been started
    unreported_trial_ids = list(
        self._busy_trial_id_candidates.difference(seen_trial_ids)
    )
    busy_trials.extend(
        (trial_id, Status.in_progress) for trial_id in unreported_trial_ids
    )
    if unreported_trial_ids:
        logger.info(
            f"Did not obtain status for these trial ids: [{unreported_trial_ids}]. "
            f"Will count them as busy with status {Status.in_progress}"
        )
    return busy_trials
[docs]
def busy_trial_ids(self) -> List[Tuple[int, str]]:
    """
    Queries the current status of all busy trial candidates and updates
    the internal candidate set accordingly.

    :return: List of ``(trial_id, status)`` for trials considered busy
    """
    # Note that at this point, ``self._busy_trial_id_candidates`` contains
    # trials whose jobs have been busy in the past, but they may have
    # stopped or terminated since. It can happen that the status for such a
    # trial is not returned (if it has just been started). In this case, the
    # trial is kept in the list and treated as busy.
    if not self._busy_trial_id_candidates:
        return []
    candidates_and_names = [
        (trial_id, self.job_id_mapping[trial_id])
        for trial_id in self._busy_trial_id_candidates
    ]
    # This is calling the SageMaker API in order to query the current
    # status for all trials in ``_busy_trial_id_candidates``
    current_results = sagemaker_search(
        candidates_and_names, sm_client=self.sm_client
    )
    still_busy = self._get_busy_trial_ids(current_results)
    # Update internal candidate list
    self._busy_trial_id_candidates = {trial_id for trial_id, _ in still_busy}
    return still_busy
[docs]
def stdout(self, trial_id: int) -> List[str]:
    """
    :param trial_id: ID of trial. Must have an entry in ``job_id_mapping``
    :return: Result of ``get_log`` for the training job of the trial
    """
    job_name = self.job_id_mapping[trial_id]
    return get_log(job_name)
[docs]
def stderr(self, trial_id: int) -> List[str]:
    """
    :param trial_id: ID of trial. Must have an entry in ``job_id_mapping``
    :return: Result of ``get_log`` for the training job of the trial (the
        same call as in :meth:`stdout`)
    """
    job_name = self.job_id_mapping[trial_id]
    return get_log(job_name)
@property
def source_dir(self) -> Optional[str]:
    """
    :return: ``source_dir`` attribute of ``sm_estimator``
    """
    estimator = self.sm_estimator
    return estimator.source_dir
[docs]
def set_entrypoint(self, entry_point: str):
    """
    Sets the ``entry_point`` attribute of ``sm_estimator``.

    :param entry_point: New entry point
    """
    estimator = self.sm_estimator
    estimator.entry_point = entry_point
[docs]
def entrypoint_path(self) -> Path:
    """
    :return: Path of the entry point script. If ``source_dir`` is set, the
        entry point is taken relative to it
    """
    base_dir = self.source_dir
    script = self.sm_estimator.entry_point
    if base_dir is not None:
        return Path(base_dir) / script
    return Path(script)
def __getstate__(self):
# dont store sagemaker client or members that cannot be serialized (for instance because SSLContext cannot
# be serialized), we could remove it by changing our interface
# and having kwargs/args of SagemakerFramework in the constructor of this class (that would be serializable)
# plus the class (for instance PyTorch)
self.sm_estimator.sagemaker_session = None
self.sm_estimator.latest_training_job = None
if hasattr(self.sm_estimator, "tensorboard_app"):
self.sm_estimator.tensorboard_app = None
self.sm_estimator.jobs = []
return self.__dict__
def __setstate__(self, state):
    """
    Restores the object from ``state`` and re-creates the members of
    ``sm_estimator`` which are reset in :meth:`__getstate__`.

    :param state: Instance dictionary to restore
    """
    self.__dict__ = state
    self.initialize_sagemaker_session()
    estimator = self.sm_estimator
    if hasattr(estimator, "tensorboard_app"):
        region = estimator.sagemaker_session.boto_region_name
        estimator.tensorboard_app = TensorBoardApp(region=region)
    # Adjust the dependencies when running the SageMaker backend on SageMaker
    # with the remote launcher, since they are in a different path
    if "SM_OUTPUT_DIR" in os.environ:
        # todo support dependencies on Sagemaker estimator, one way would be to ship them with the remote
        # dependencies
        estimator.dependencies = [
            Path(dependency).name for dependency in estimator.dependencies
        ]
[docs]
def initialize_sagemaker_session(self):
    """
    Assigns a new session to ``sm_estimator.sagemaker_session``: a
    ``LocalSession`` for instance types "local" and "local_gpu", and
    ``default_sagemaker_session()`` otherwise.

    :raises RuntimeError: If the instance type is "local_gpu" and the
        instance count is larger than 1
    """
    if boto3.Session().region_name is None:
        # Avoids error "Must setup local AWS configuration with a region
        # supported by SageMaker." in case no region is explicitly configured
        os.environ["AWS_DEFAULT_REGION"] = "us-west-2"
    estimator = self.sm_estimator
    if estimator.instance_type not in ("local", "local_gpu"):
        # Use SageMaker boto3 client with default config. This is important
        # to configure automatic retry options properly
        estimator.sagemaker_session = default_sagemaker_session()
        return
    if estimator.instance_type == "local_gpu" and estimator.instance_count > 1:
        raise RuntimeError("Distributed Training in Local GPU is not supported")
    estimator.sagemaker_session = LocalSession()
[docs]
def copy_checkpoint(self, src_trial_id: int, tgt_trial_id: int):
    """
    Copies the checkpoint files of trial ``src_trial_id`` to the checkpoint
    location of trial ``tgt_trial_id`` on S3. Asserts that all files found
    are copied successfully.

    :param src_trial_id: ID of trial to copy the checkpoint from
    :param tgt_trial_id: ID of trial to copy the checkpoint to
    """
    source_uri = self._checkpoint_s3_uri_for_trial(src_trial_id)
    target_uri = self._checkpoint_s3_uri_for_trial(tgt_trial_id)
    logger.info(f"Copying checkpoint files from {source_uri} to " + target_uri)
    outcome = s3_copy_objects_recursively(source_uri, target_uri)
    num_calls = outcome["num_action_calls"]
    if num_calls == 0:
        logger.info(f"No checkpoint files found at {source_uri}")
        return
    num_successful = outcome["num_successful_action_calls"]
    assert num_successful == num_calls, (
        f"{num_successful} files copied successfully, "
        + f"{num_calls - num_successful} failures. "
        + "Error:\n"
        + outcome["first_error_message"]
    )
[docs]
def delete_checkpoint(self, trial_id: int):
    """
    Deletes the checkpoint files of the trial from S3 and logs the outcome.
    Nothing is done if this method has been called for ``trial_id`` before.

    :param trial_id: ID of trial
    """
    if trial_id in self._trial_ids_deleted_checkpoints:
        return
    checkpoint_uri = self._checkpoint_s3_uri_for_trial(trial_id)
    outcome = s3_delete_objects_recursively(checkpoint_uri)
    # The trial is recorded independent of whether all deletions succeeded
    self._trial_ids_deleted_checkpoints.add(trial_id)
    num_calls = outcome["num_action_calls"]
    if num_calls <= 0:
        return
    num_successful = outcome["num_successful_action_calls"]
    if num_successful != num_calls:
        logger.warning(
            f"Successfully deleted {num_successful} "
            f"checkpoint files for trial_id {trial_id} from "
            f"{checkpoint_uri}, but failed to delete "
            f"{num_calls - num_successful} "
            "files. Error:\n" + outcome["first_error_message"]
        )
    else:
        logger.info(
            f"Deleted {num_calls} checkpoint files for "
            f"trial_id {trial_id} from {checkpoint_uri}"
        )
[docs]
def set_path(
    self, results_root: Optional[str] = None, tuner_name: Optional[str] = None
):
    """
    For this backend, it is mandatory to call this method passing ``tuner_name``
    before the backend is used. ``results_root`` is ignored here.

    :param results_root: Not used
    :param tuner_name: If given, ``tuner_name`` is set to this value
    """
    if tuner_name is None:
        return
    self.tuner_name = tuner_name
[docs]
def on_tuner_save(self):
    """
    Re-initializes the SageMaker session after :class:`~syne_tune.Tuner`
    is serialized.
    """
    self.initialize_sagemaker_session()
def _cleanup_after_trial(self, trial_id: int):
if self.pass_args_as_json:
filename = self._config_json_filename(trial_id, with_path=True)
Path(filename).unlink(missing_ok=True)