Source code for syne_tune.optimizer.schedulers.utils.successive_halving

from typing import Optional, List
import logging
import numpy as np

logger = logging.getLogger(__name__)


def _is_positive_int(x):
    return int(x) == x and x >= 1


[docs] def successive_halving_rung_levels( rung_levels: Optional[List[int]], grace_period: int, reduction_factor: Optional[float], rung_increment: Optional[int], max_t: int, ) -> List[int]: """Creates ``rung_levels`` from ``grace_period``, ``reduction_factor`` Note: If ``rung_levels`` is given and ``rung_levels[-1] == max_t``, we strip off this final entry, so that all rung levels are ``< max_t``. :param rung_levels: If given, this is returned (but see above) :param grace_period: See :class:`~syne_tune.optimizer.schedulers.HyperbandScheduler` :param reduction_factor: See :class:`~syne_tune.optimizer.schedulers.HyperbandScheduler` :param rung_increment: See :class:`~syne_tune.optimizer.schedulers.HyperbandScheduler` :param max_t: See :class:`~syne_tune.optimizer.schedulers.HyperbandScheduler` :return: List of rung levels """ if rung_levels is not None: assert ( isinstance(rung_levels, list) and len(rung_levels) > 1 ), "rung_levels must be list of size >= 2" assert all( _is_positive_int(x) for x in rung_levels ), "rung_levels must be list of positive integers" rung_levels = [int(x) for x in rung_levels] assert all( x < y for x, y in zip(rung_levels, rung_levels[1:]) ), "rung_levels must be strictly increasing sequence" assert ( rung_levels[-1] <= max_t ), f"Last entry of rung_levels ({rung_levels[-1]}) must be <= max_t ({max_t})" else: # Rung levels given by grace_period, reduction_factor, max_t assert _is_positive_int(grace_period) assert _is_positive_int(max_t) assert ( max_t > grace_period ), f"max_t ({max_t}) must be greater than grace_period ({grace_period})" if reduction_factor is not None: assert reduction_factor >= 2 rf = reduction_factor min_t = grace_period max_rungs = 0 while min_t * np.power(rf, max_rungs) < max_t: max_rungs += 1 rung_levels = [ int(round(min_t * np.power(rf, k))) for k in range(max_rungs) ] assert rung_levels[-1] <= max_t # Sanity check if rung_increment is not None: logger.warning( f"You specified both reduction_factor = {reduction_factor} " f"and rung_increment = {rung_increment}. The former takes " "precedence, the latter will be ignored" ) else: assert rung_increment is not None assert _is_positive_int(rung_increment) rung_levels = list(range(grace_period, max_t, rung_increment)) if rung_levels[-1] == max_t: rung_levels = rung_levels[:-1] return rung_levels