# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import json
import logging
import os
import random
import re
import string
import time
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from time import perf_counter
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    Optional,
    Tuple,
    Union,
)

import numpy as np

from syne_tune.constants import (
    SYNE_TUNE_DEFAULT_FOLDER,
    SYNE_TUNE_ENV_FOLDER,
    ST_DATETIME_FORMAT,
)
from syne_tune.try_import import try_import_aws_message
# Module-level logger (used e.g. by ``metric_name_mode`` to emit warnings).
logger = logging.getLogger(__name__)
# ``sagemaker`` is an optional dependency, used here by ``s3_experiment_path``.
# If it cannot be imported, the message from ``try_import_aws_message()``
# (presumably an installation hint -- confirm) is printed instead of failing at
# import time. NOTE(review): in that case the name ``sagemaker`` stays
# undefined, so code paths that use it raise ``NameError``.
try:
    import sagemaker
except ImportError:
    print(try_import_aws_message())
[docs]
class RegularCallback:
    """
    Allows to call the callback function at most once every
    ``call_seconds_frequency`` seconds.

    Calls arriving too early are dropped, not queued. The first call is only
    forwarded once ``call_seconds_frequency`` seconds have passed since
    construction.

    :param callback: Callback object
    :param call_seconds_frequency: Wait time between subsequent calls
    """

    def __init__(self, callback: callable, call_seconds_frequency: float):
        self.time_last_recent_call = datetime.now()
        self.frequency = call_seconds_frequency
        self.callback = callback

    def __call__(self, *args, **kwargs):
        """
        Forward ``*args``/``**kwargs`` to the callback if strictly more than
        ``frequency`` seconds elapsed since the last forwarded call (or since
        construction); otherwise do nothing.
        """
        # ``total_seconds()`` rather than ``.seconds``: the latter is only the
        # seconds *component* of the timedelta. It drops fractions (so a
        # frequency of 0.5 behaved like 1) and whole days (so after 24 hours
        # the elapsed time wrapped around to 0).
        elapsed = (datetime.now() - self.time_last_recent_call).total_seconds()
        if elapsed > self.frequency:
            self.time_last_recent_call = datetime.now()
            self.callback(*args, **kwargs)
[docs]
def experiment_path(
    tuner_name: Optional[str] = None, local_path: Optional[str] = None
) -> Path:
    """
    Return the path of an experiment which is used both by :class:`~syne_tune.Tuner`
    and to collect results of experiments.

    :param tuner_name: Name of a tuning experiment
    :param local_path: Local path where results should be saved when running
        locally outside of SageMaker. A leading ``~`` is expanded to the home
        directory. If not specified, then the environment variable
        ``"SYNETUNE_FOLDER"`` is used if defined otherwise ``~/syne-tune/``
        is used. Defining the environment variable ``"SYNETUNE_FOLDER"`` allows
        to override the default path.
    :return: Path where to write logs and results for Syne Tune tuner. On
        SageMaker, results are written to ``"/opt/ml/checkpoints/"`` so that files
        are persisted continuously to S3 by SageMaker.
    """
    if "SM_MODEL_DIR" in os.environ:
        # SM_MODEL_DIR is present when running on SageMaker; this folder is
        # persisted by SageMaker.
        result_path = Path("/opt/ml/checkpoints")
    elif local_path is not None:
        # Expand "~" like the environment variable and default below do;
        # otherwise "~/results" would create a literal "~" folder in the cwd.
        result_path = Path(local_path).expanduser()
    elif SYNE_TUNE_ENV_FOLDER in os.environ:
        result_path = Path(os.environ[SYNE_TUNE_ENV_FOLDER]).expanduser()
    else:
        result_path = Path(f"~/{SYNE_TUNE_DEFAULT_FOLDER}").expanduser()
    if tuner_name is not None:
        result_path = result_path / tuner_name
    return result_path
[docs]
def s3_experiment_path(
    s3_bucket: Optional[str] = None,
    experiment_name: Optional[str] = None,
    tuner_name: Optional[str] = None,
) -> str:
    """Build the S3 prefix under which results and checkpoints are stored.

    The prefix has the form
    ``s3://<bucket>/<SYNE_TUNE_DEFAULT_FOLDER>/[<experiment_name>/][<tuner_name>/]``.

    :param s3_bucket: Bucket name. Defaults to the default bucket of a new
        SageMaker session
    :param experiment_name: Optional first directory below the Syne Tune folder
    :param tuner_name: Optional second directory
    :return: S3 path, always terminated by ``"/"``
    """
    bucket = s3_bucket
    if bucket is None:
        bucket = sagemaker.Session().default_bucket()
    segments = [f"s3://{bucket}", SYNE_TUNE_DEFAULT_FOLDER]
    # Components given as None are skipped entirely
    segments.extend(
        segment for segment in (experiment_name, tuner_name) if segment is not None
    )
    return "/".join(segments) + "/"
[docs]
def check_valid_sagemaker_name(name: str):
    """
    Check that ``name`` is a valid SageMaker name.

    A valid name starts and ends with a letter or digit and contains at most
    63 letters/digits. Consecutive letters/digits may be separated by runs of
    hyphens ``-``.

    :param name: Name to validate
    :raises AssertionError: If ``name`` is not valid
    """
    # ``re.fullmatch`` instead of ``match`` + ``$``: ``$`` also matches just
    # before a trailing newline, so e.g. "job\n" used to be accepted.
    # An explicit ``raise`` (instead of ``assert``) keeps the check active
    # under ``python -O``, while still raising ``AssertionError`` as before.
    if re.fullmatch(r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}", name) is None:
        raise AssertionError(
            f"{name} should consists in alpha-digits possibly separated by character -"
        )
[docs]
def sanitize_sagemaker_name(name: str) -> str:
    """
    Turn ``name`` into a SageMaker-compatible name by mapping every underscore
    to a hyphen, then validate the result.

    :param name: Original name
    :return: ``name`` with ``_`` replaced by ``-``
    :raises AssertionError: From :func:`check_valid_sagemaker_name` if the
        result is still not a valid SageMaker name
    """
    sanitized = "-".join(name.split("_"))
    check_valid_sagemaker_name(sanitized)
    return sanitized
[docs]
def name_from_base(base: Optional[str], default: str, max_length: int = 63) -> str:
    """Append a timestamp to the provided string.

    The timestamp is the current UTC time formatted with ``ST_DATETIME_FORMAT``,
    followed by ``-`` and the three-digit millisecond of the current second.
    The input is trimmed if necessary, so the total length of the result does
    not exceed ``max_length``. If ``max_length`` cannot even hold the
    timestamp, the input is dropped entirely and the result is longer than
    ``max_length``.

    :param base: String used as prefix to generate the unique name
    :param default: String used if :code:`base is None`
    :param max_length: Maximum length for the resulting string (default: 63)
    :return: Input parameter with appended timestamp
    :raises AssertionError: From :func:`check_valid_sagemaker_name`, if the
        prefix is not a valid SageMaker name
    """
    if base is None:
        base = default
    check_valid_sagemaker_name(base)
    moment = time.time()
    # Always three digits ("005", "500"). Slicing digits out of
    # ``repr(moment)`` gave fewer digits whenever the fraction had trailing
    # zeros (e.g. ".5" -> "5").
    milliseconds = int((moment % 1.0) * 1000)
    date_format = f"{ST_DATETIME_FORMAT}-{milliseconds:03d}"
    timestamp = time.strftime(date_format, time.gmtime(moment))
    # Clamp at 0: a negative slice bound would drop characters from the *end*
    # of ``base`` instead of producing an empty prefix.
    prefix_length = max(0, max_length - len(timestamp) - 1)
    return f"{base[:prefix_length]}-{timestamp}"
[docs]
def random_string(length: int) -> str:
    """
    Draw a random string of ASCII letters and digits.

    Uses the global :mod:`random` generator, so it is reproducible under
    ``random.seed`` and not meant for security-sensitive tokens.

    :param length: Number of characters
    :return: Random string of size ``length``
    """
    alphabet = string.ascii_letters + string.digits
    characters = [random.choice(alphabet) for _ in range(length)]
    return "".join(characters)
[docs]
def repository_root_path() -> Path:
    """
    :return: Parent of the directory containing this module. This is the
        repository root, which includes ``syne_tune``, ``examples`` and
        ``benchmarking``
    """
    module_file = Path(__file__)
    module_dir = module_file.parent
    return module_dir.parent
[docs]
def script_checkpoint_example_path() -> Path:
    """
    :return: Path of the ``checkpoint_example.py`` training script under
        ``examples/training_scripts/checkpoint_example``; asserts that it exists
    """
    script = repository_root_path().joinpath(
        "examples",
        "training_scripts",
        "checkpoint_example",
        "checkpoint_example.py",
    )
    assert script.exists()
    return script
[docs]
def script_height_example_path() -> Path:
    """
    :return: Path of the ``train_height.py`` example script under
        ``examples/training_scripts/height_example``; asserts that it exists
    """
    relative_parts = ("examples", "training_scripts", "height_example", "train_height.py")
    script = repository_root_path().joinpath(*relative_parts)
    assert script.exists()
    return script
[docs]
@contextmanager
def catchtime(name: str) -> Iterator[Callable[[], float]]:
    """
    Context manager which prints ``start: <name>`` on entry. On exit it prints
    the elapsed time measured with :func:`time.perf_counter`, also if the body
    raises.

    :param name: Label used in both printed messages
    :return: The ``as`` target is a function returning the seconds elapsed
        since entering the context
    """
    start = perf_counter()
    try:
        print(f"start: {name}")
        yield lambda: perf_counter() - start
    finally:
        print(f"Time for {name}: {perf_counter() - start:.4f} secs")
[docs]
def is_increasing(lst: List[Union[float, int]]) -> bool:
    """
    :param lst: Sequence of numbers
    :return: ``True`` iff every entry is strictly smaller than its successor
        (vacuously true for fewer than two entries)
    """
    for left, right in zip(lst, lst[1:]):
        if not left < right:
            return False
    return True
[docs]
def is_positive_integer(lst: List[int]) -> bool:
    """
    :param lst: List of numbers
    :return: Is every entry integral (equal to ``int(entry)``, so ``2.0``
        counts as well) and at least 1? Vacuously true for an empty list.
        Non-finite floats make ``int`` raise
    """
    for value in lst:
        if value != int(value) or value < 1:
            return False
    return True
[docs]
def is_integer(lst: list) -> bool:
    """
    :param lst: List of numbers
    :return: Is every entry integral, i.e. equal to ``int(entry)``? This holds
        for ``int`` values but also e.g. for ``3.0``. Vacuously true for an
        empty list
    """
    for entry in lst:
        if entry != int(entry):
            return False
    return True
[docs]
def dump_json_with_numpy(
    x: dict, filename: Optional[Union[str, Path]] = None
) -> Optional[str]:
    """
    Serializes dictionary ``x`` in JSON, taking into account NumPy specific
    value types such as ``np.int64`` and NumPy arrays.

    :param x: Dictionary to serialize or encode
    :param filename: Name of file to store JSON to. Optional. If not given,
        the JSON encoding is returned as string
    :return: If ``filename is None``, JSON encoding is returned
    :raises TypeError: If ``x`` contains a value that is neither natively
        JSON-serializable nor a NumPy scalar or array
    """

    def np_encoder(obj):
        # Called by ``json`` only for objects it cannot serialize by itself
        if isinstance(obj, np.generic):
            # NumPy scalar (np.int64, np.float32, np.bool_, ...) -> Python scalar
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Falling through used to return ``None``, which silently encoded
        # unsupported objects as ``null``. Fail like ``json`` does by default.
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    encoded = json.dumps(x, default=np_encoder)
    if filename is None:
        return encoded
    # Encode before opening the file, so a failure leaves no partial file
    with open(filename, "w") as f:
        f.write(encoded)
    return None
[docs]
def dict_get(params: Dict[str, Any], key: str, default: Any) -> Any:
    """
    Look up ``key`` in ``params``, treating a stored ``None`` like a missing key.

    Unlike ``params.get(key, default)``, which returns ``None`` when
    ``params[key] is None``, this returns ``default`` in that case too. This
    suits dictionaries built from :class:`argparse.ArgumentParser` results,
    where arguments that were declared but not provided are set to ``None``.

    :param params: Dictionary to read from
    :param key: Key to look up
    :param default: Value returned if ``key`` is missing or maps to ``None``
    :return: ``params[key]`` or ``default``
    """
    value = params.get(key)
    if value is None:
        return default
    return value
[docs]
def recursive_merge(
    a: Dict[str, Any],
    b: Dict[str, Any],
    stop_keys: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """
    Merge dictionaries ``a`` and ``b``, where ``b`` takes precedence. We
    typically use this to modify a dictionary ``a``, so ``b`` is smaller
    than ``a``. Further recursion is stopped on any node, at any depth, whose
    key is in ``stop_keys``. Use this for dictionary-valued entries not to be
    merged, but to be replaced by what is in ``b``.

    Neither input is modified. If ``b`` is empty, ``a`` itself is returned.
    Otherwise a new dictionary is returned: keys of ``a`` come first in their
    order, followed by keys only in ``b`` in ``b``'s order.

    :param a: Dictionary
    :param b: Dictionary (can be empty)
    :param stop_keys: See above, optional
    :return: Merged dictionary
    :raises AssertionError: If a key maps to a dict in one input but not in
        the other (unless the key is in ``stop_keys``)
    """
    if not b:
        return a
    if stop_keys is None:
        stop_keys = []
    result = dict()
    for k, va in a.items():
        if k not in b:
            result[k] = va
            continue
        vb = b[k]
        stop_recursion = k in stop_keys
        if isinstance(va, dict) and not stop_recursion:
            assert isinstance(
                vb, dict
            ), f"k={k} has dict value in a, but not in b:\n{va}\n{vb}"
            # Pass ``stop_keys`` down so they are honored at every depth
            result[k] = recursive_merge(va, vb, stop_keys=stop_keys)
        else:
            assert stop_recursion or not isinstance(
                vb, dict
            ), f"k={k} has dict value in b, but not in a:\n{va}\n{vb}"
            result[k] = vb
    # Iterate ``b`` (not a set of its keys) so new keys keep a deterministic order
    result.update((k, vb) for k, vb in b.items() if k not in a)
    return result
[docs]
def find_first_of_type(a: Iterable[Any], typ) -> Optional[Any]:
    """
    :param a: Iterable to scan
    :param typ: Type (or tuple of types) as accepted by :func:`isinstance`
    :return: First entry of ``a`` which is an instance of ``typ``, or ``None``
        if there is none
    """
    for item in a:
        if isinstance(item, typ):
            return item
    return None
[docs]
def metric_name_mode(
    metric_names: List[str], metric_mode: Union[str, List[str]], metric: Union[str, int]
) -> Tuple[str, str]:
    """
    Resolve a metric, selected by name or by index, into its name and mode.

    A warning is logged whenever more than one metric is defined.

    :param metric_names: Names of the metrics defined in a scheduler
    :param metric_mode: Mode of the scheduler. Either a single value shared by
        all metrics, or a list with one entry per metric, aligned with
        ``metric_names``
    :param metric: Name of the selected metric, or its index in ``metric_names``
    :return: Tuple ``(name, mode)`` of the queried metric
    :raises AssertionError: If the name is unknown or the index is too large
    :raises AttributeError: If ``metric`` is neither ``int`` nor ``str``
    """
    if isinstance(metric, str):
        assert (
            metric in metric_names
        ), f"Attempted to use {metric} while available metrics are {metric_names}"
        metric_name = metric
    elif isinstance(metric, int):
        num_metrics = len(metric_names)
        assert (
            metric < num_metrics
        ), f"Attempted to use metric index={metric} with {num_metrics} available metrics"
        metric_name = metric_names[metric]
    else:
        raise AttributeError(
            f"metric must be <int> or <str> but {type(metric)} was provided"
        )

    if len(metric_names) > 1:
        logger.warning(
            f"Several metrics exist, this will use metric={metric_name} "
            f"(index={metric}) out of {metric_names}."
        )

    # A single (non-list) mode applies to every metric
    if not isinstance(metric_mode, list):
        return metric_name, metric_mode
    if isinstance(metric, str):
        return metric_name, metric_mode[metric_names.index(metric_name)]
    return metric_name, metric_mode[metric]