Source code for syne_tune.util

import json
import logging
import os
import random
import string
import time
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from time import perf_counter
from typing import Optional, Dict, Any, Iterable
from typing import Tuple, Union, List

import numpy as np

from syne_tune.constants import (
    SYNE_TUNE_DEFAULT_FOLDER,
    SYNE_TUNE_ENV_FOLDER,
    ST_DATETIME_FORMAT,
)

logger = logging.getLogger(__name__)


class RegularCallback:
    """
    Allows to call the callback function at most once every
    ``call_seconds_frequency`` seconds.

    Invocations arriving before the wait time has elapsed are dropped
    without calling ``callback``. The wait time is first measured from the
    moment the object is constructed, afterwards from the most recent
    invocation that was forwarded to ``callback``.

    :param callback: Callback object
    :param call_seconds_frequency: Wait time between subsequent calls
    """

    def __init__(self, callback: callable, call_seconds_frequency: float):
        self.time_last_recent_call = datetime.now()
        self.frequency = call_seconds_frequency
        self.callback = callback

    def __call__(self, *args, **kwargs):
        """
        Forwards the arguments to ``callback`` if more than ``frequency``
        seconds have passed since the last forwarded call. The return value
        of ``callback`` is discarded.
        """
        now = datetime.now()
        # ``timedelta.seconds`` is only the seconds *component* (0..86399,
        # whole seconds): it ignores full days and drops fractions of a
        # second. ``total_seconds()`` is the real elapsed time.
        seconds_since_last_call = (now - self.time_last_recent_call).total_seconds()
        if seconds_since_last_call > self.frequency:
            # Updated before invoking the callback, so the time spent inside
            # the callback counts towards the next waiting period
            self.time_last_recent_call = now
            self.callback(*args, **kwargs)
def experiment_path(
    tuner_name: Optional[str] = None, local_path: Optional[str] = None
) -> Path:
    """
    Return the path of an experiment, which is used both by
    :class:`~syne_tune.Tuner` and to collect results of experiments.

    The root directory is chosen in this order:

    * ``local_path``, if given (used as is, ``~`` is not expanded);
    * the value of the environment variable whose name is
      ``SYNE_TUNE_ENV_FOLDER``, if it is set (``~`` is expanded);
    * ``~/{SYNE_TUNE_DEFAULT_FOLDER}`` otherwise (``~`` is expanded).

    :param tuner_name: Name of a tuning experiment. If given, it is appended
        to the root directory as a sub-directory
    :param local_path: Local path where results should be saved. Optional
    :return: Path where to write logs and results for Syne Tune tuner
    """
    if local_path is not None:
        root = Path(local_path)
    elif SYNE_TUNE_ENV_FOLDER in os.environ:
        # The environment variable overrides the default location
        root = Path(os.environ[SYNE_TUNE_ENV_FOLDER]).expanduser()
    else:
        root = Path(f"~/{SYNE_TUNE_DEFAULT_FOLDER}").expanduser()
    return root if tuner_name is None else root / tuner_name
def name_from_base(base: Optional[str], max_length: int = 63) -> str:
    """Append a timestamp to the provided string.

    This function assures that the total length of the resulting string is
    not longer than the specified max length, trimming the input parameter if
    necessary. The timestamp is the current UTC time, formatted with
    ``ST_DATETIME_FORMAT`` and followed by ``-`` and three millisecond
    digits. If ``max_length`` is too small to hold the timestamp itself,
    ``base`` is dropped entirely and the timestamp is kept in full.

    :param base: String used as prefix to generate the unique name. If
        ``None``, ``"st-tuner"`` is used
    :param max_length: Maximum length for the resulting string (default: 63)
    :return: Input parameter with appended timestamp
    """
    if base is None:
        base = "st-tuner"
    moment = time.time()
    # Always three digits: slicing ``repr(moment)`` yields fewer digits when
    # the fractional part is short (e.g. ``.5`` -> "5" rather than "500")
    milliseconds = int((moment % 1) * 1000)
    timestamp_format = f"{ST_DATETIME_FORMAT}-{milliseconds:03d}"
    timestamp = time.strftime(timestamp_format, time.gmtime(moment))
    # Clamp at 0: a negative slice bound would cut characters off the *end*
    # of ``base`` instead of dropping the prefix
    prefix_length = max(0, max_length - len(timestamp) - 1)
    return "{}-{}".format(base[:prefix_length], timestamp)
def random_string(length: int) -> str:
    """
    Draws ``length`` characters independently from ASCII letters and digits,
    using the global :mod:`random` generator.

    :param length: Number of characters to generate
    :return: Random string of size ``length`` (empty if ``length <= 0``)
    """
    alphabet = string.ascii_letters + string.digits
    characters = []
    for _ in range(length):
        characters.append(random.choice(alphabet))
    return "".join(characters)
def repository_root_path() -> Path:
    """
    :return: Directory two levels above this module file. Presumably the
        repository root containing ``syne_tune``, ``examples``,
        ``benchmarking`` (holds for a source checkout; verify for installed
        packages)
    """
    module_file = Path(__file__)
    package_dir = module_file.parent
    return package_dir.parent
def script_checkpoint_example_path() -> Path:
    """
    Locates ``examples/training_scripts/checkpoint_example/checkpoint_example.py``
    relative to :func:`repository_root_path`. Asserts that the file exists.

    :return: Path of checkpoint example
    """
    script = repository_root_path().joinpath(
        "examples",
        "training_scripts",
        "checkpoint_example",
        "checkpoint_example.py",
    )
    assert script.exists()
    return script
def script_height_example_path() -> Path:
    """
    Locates ``examples/training_scripts/height_example/train_height.py``
    relative to :func:`repository_root_path`. Asserts that the file exists.

    :return: Path of ``train_height`` example
    """
    script = repository_root_path().joinpath(
        "examples",
        "training_scripts",
        "height_example",
        "train_height.py",
    )
    assert script.exists()
    return script
@contextmanager
def catchtime(name: str) -> float:
    """
    Context manager which prints ``start: {name}`` on entry and the elapsed
    time on exit (also if the body raises). Times are measured with
    :func:`time.perf_counter`.

    The object bound by ``with catchtime(...) as elapsed`` is a callable
    without arguments, which returns the seconds passed since entry.

    :param name: Label used in the printed messages
    """
    started_at = perf_counter()

    def elapsed() -> float:
        return perf_counter() - started_at

    try:
        print(f"start: {name}")
        yield elapsed
    finally:
        print(f"Time for {name}: {elapsed():.4f} secs")
def is_increasing(lst: List[Union[float, int]]) -> bool:
    """
    Compares every entry with its successor. Lists with fewer than two
    entries count as increasing.

    :param lst: List of float or int entries
    :return: Is ``lst`` strictly increasing?
    """
    for current, following in zip(lst, lst[1:]):
        if not (current < following):
            return False
    return True
def is_positive_integer(lst: List[int]) -> bool:
    """
    Note that entries are compared against ``int(x)``, so integer-valued
    floats such as ``2.0`` are accepted as well.

    :param lst: List of int entries
    :return: Are all entries of ``lst`` integer-valued and ``>= 1``?
    """
    for value in lst:
        if not (value == int(value)):
            return False
        if not (value >= 1):
            return False
    return True
def is_integer(lst: list) -> bool:
    """
    Note that entries are compared against ``int(x)``, so integer-valued
    floats such as ``2.0`` are accepted as well.

    :param lst: List of entries
    :return: Are all entries of ``lst`` integer-valued?
    """
    for value in lst:
        if not (value == int(value)):
            return False
    return True
def dump_json_with_numpy(
    x: dict, filename: Optional[Union[str, Path]] = None
) -> Optional[str]:
    """
    Serializes dictionary ``x`` in JSON, taking into account NumPy specific
    value types such as ``np.int64``. NumPy scalars are converted to the
    corresponding Python scalars, NumPy arrays to (nested) lists.

    :param x: Dictionary to serialize or encode
    :param filename: Name of file to store JSON to. Optional. If not given,
        the JSON encoding is returned as string
    :return: If ``filename is None``, JSON encoding is returned
    :raises TypeError: If ``x`` contains a value which is neither JSON
        serializable nor a NumPy scalar or array
    """

    def np_encoder(obj):
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # A ``default`` hook must raise for objects it cannot convert.
        # Falling through would return ``None``, which ``json`` silently
        # writes as ``null``, losing the value without any warning
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    if filename is None:
        return json.dumps(x, default=np_encoder)
    with open(filename, "w") as f:
        json.dump(x, f, default=np_encoder)
    return None
def dict_get(params: Dict[str, Any], key: str, default: Any) -> Any:
    """
    Returns ``params[key]`` if this exists and is not None, and ``default``
    otherwise.

    This differs from ``params.get(key, default)`` in the case where ``key``
    is present with value None: ``params.get`` would return None, while this
    function returns ``default``.

    This is particularly helpful when dealing with a dict returned by
    :class:`argparse.ArgumentParser`. Whenever ``key`` is added as argument
    to the parser, but a value is not provided, this leads to
    ``params[key] = None``.

    :param params: Dictionary to look up
    :param key: Key to look up in ``params``
    :param default: Value returned if ``key`` is missing or maps to None
    :return: ``params[key]`` or ``default``
    """
    value = params.get(key)
    if value is not None:
        return value
    return default
def recursive_merge(
    a: Dict[str, Any],
    b: Dict[str, Any],
    stop_keys: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """
    Merge dictionaries ``a`` and ``b``, where ``b`` takes precedence. We
    typically use this to modify a dictionary ``a``, so ``b`` is smaller
    than ``a``. Further recursion is stopped on any node with key in
    ``stop_keys``, at any nesting depth. Use this for dictionary-valued
    entries not to be merged, but to be replaced by what is in ``b``.

    If ``b`` is empty (or None), ``a`` itself is returned, not a copy.
    Otherwise, a new dictionary is returned: keys of ``a`` come first, in
    their original order, followed by keys only in ``b``, in the order of
    ``b``.

    :param a: Dictionary
    :param b: Dictionary (can be empty)
    :param stop_keys: See above, optional
    :return: Merged dictionary
    :raises AssertionError: If for a key in both ``a`` and ``b``, not in
        ``stop_keys``, one value is a dict and the other is not
    """
    if not b:
        return a
    if stop_keys is None:
        stop_keys = []
    result = dict()
    for k, va in a.items():
        if k not in b:
            result[k] = va
            continue
        vb = b[k]
        stop_recursion = k in stop_keys
        if isinstance(va, dict) and not stop_recursion:
            assert isinstance(
                vb, dict
            ), f"k={k} has dict value in a, but not in b:\n{va}\n{vb}"
            # ``stop_keys`` must be handed down, otherwise stop keys are
            # honoured at the top level only
            result[k] = recursive_merge(va, vb, stop_keys=stop_keys)
        else:
            assert stop_recursion or not isinstance(
                vb, dict
            ), f"k={k} has dict value in b, but not in a:\n{va}\n{vb}"
            result[k] = vb
    # Entries only in ``b``, in the order of ``b``
    result.update({k: vb for k, vb in b.items() if k not in a})
    return result
def find_first_of_type(a: Iterable[Any], typ) -> Optional[Any]:
    """
    :param a: Iterable to search. It is consumed up to and including the
        first match
    :param typ: Type, or tuple of types, as accepted by :func:`isinstance`
    :return: First entry of ``a`` which is an instance of ``typ``, or None
        if there is no such entry
    """
    for entry in a:
        if isinstance(entry, typ):
            return entry
    return None
def metric_name_mode(
    metric_names: List[str], metric_mode: Union[str, List[str]], metric: Union[str, int]
) -> Tuple[str, str]:
    """
    Retrieve the name and mode of a metric queried by either index or name.

    If ``metric_mode`` is a list, the entry at the position of the selected
    metric is returned, otherwise ``metric_mode`` is returned as is. A
    warning is logged if ``metric_names`` has more than one entry.

    :param metric_names: metrics names defined in a scheduler
    :param metric_mode: metric mode or modes of a scheduler
    :param metric: Index or name of the selected metric
    :return: the name and the mode of the queried metric
    :raises AssertionError: If ``metric`` is a name not in ``metric_names``,
        or an index not smaller than ``len(metric_names)``
    :raises AttributeError: If ``metric`` is neither ``str`` nor ``int``
    """
    queried_by_name = isinstance(metric, str)
    if queried_by_name:
        assert (
            metric in metric_names
        ), f"Attempted to use {metric} while available metrics are {metric_names}"
        metric_name = metric
    elif isinstance(metric, int):
        assert metric < len(
            metric_names
        ), f"Attempted to use metric index={metric} with {len(metric_names)} available metrics"
        metric_name = metric_names[metric]
    else:
        raise AttributeError(
            f"metric must be <int> or <str> but {type(metric)} was provided"
        )

    if len(metric_names) > 1:
        logger.warning(
            "Several metrics exist, this will "
            f"use metric={metric_name} (index={metric}) out of {metric_names}."
        )

    if not isinstance(metric_mode, list):
        # A single mode applies to all metrics
        return metric_name, metric_mode
    if queried_by_name:
        position = metric_names.index(metric_name)
    else:
        position = metric
    return metric_name, metric_mode[position]