from typing import Optional, List
import logging
import numpy as np
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
def _is_positive_int(x) -> bool:
    """Return ``True`` if ``x`` is a positive integral value (e.g. ``3`` or ``3.0``).

    Values that cannot be converted by :func:`int` (``None``, non-numeric
    strings, NaN, infinity, ...) yield ``False`` instead of raising. This
    helper is used inside ``assert`` statements, so returning ``False`` lets
    those assertions fail with their own descriptive messages.
    """
    try:
        # ``int(x) == x`` rejects non-integral values such as 2.5 and strings
        # such as "3" (since ``3 != "3"``)
        return bool(int(x) == x and x >= 1)
    except (TypeError, ValueError, OverflowError):
        return False
def successive_halving_rung_levels(
    rung_levels: Optional[List[int]],
    grace_period: int,
    reduction_factor: Optional[float],
    rung_increment: Optional[int],
    max_t: int,
) -> List[int]:
    """Creates ``rung_levels`` from ``grace_period``, ``reduction_factor``,
    ``rung_increment`` and ``max_t``, or validates given ``rung_levels``.

    If ``rung_levels`` is not given, levels are generated starting at
    ``grace_period``:

    * if ``reduction_factor`` is given, geometrically:
      ``round(grace_period * reduction_factor ** k)`` for ``k = 0, 1, ...``
      as long as the unrounded value is ``< max_t`` (``rung_increment`` is
      then ignored, with a warning if it was also given);
    * otherwise arithmetically:
      ``grace_period, grace_period + rung_increment, ...`` below ``max_t``.

    Note: If the final level equals ``max_t`` (given ``rung_levels``, or a
    geometric level rounded up to ``max_t``), we strip off this final entry,
    so that all rung levels are ``< max_t``.

    :param rung_levels: If given, this is validated and returned (but see above)
    :param grace_period: See :class:`~syne_tune.optimizer.schedulers.HyperbandScheduler`
    :param reduction_factor: See :class:`~syne_tune.optimizer.schedulers.HyperbandScheduler`
    :param rung_increment: See :class:`~syne_tune.optimizer.schedulers.HyperbandScheduler`
    :param max_t: See :class:`~syne_tune.optimizer.schedulers.HyperbandScheduler`
    :return: Non-empty, strictly increasing list of positive integer rung
        levels, all ``< max_t``
    :raises AssertionError: If the arguments are invalid
    """
    if rung_levels is not None:
        assert (
            isinstance(rung_levels, list) and len(rung_levels) > 1
        ), "rung_levels must be list of size >= 2"
        assert all(
            _is_positive_int(x) for x in rung_levels
        ), "rung_levels must be list of positive integers"
        # New list: normalizes integral floats to int, caller's list untouched
        rung_levels = [int(x) for x in rung_levels]
        assert all(
            x < y for x, y in zip(rung_levels, rung_levels[1:])
        ), "rung_levels must be strictly increasing sequence"
        assert (
            rung_levels[-1] <= max_t
        ), f"Last entry of rung_levels ({rung_levels[-1]}) must be <= max_t ({max_t})"
    else:
        # Rung levels given by grace_period, reduction_factor (or
        # rung_increment), max_t
        assert _is_positive_int(
            grace_period
        ), f"grace_period ({grace_period}) must be a positive integer"
        assert _is_positive_int(max_t), f"max_t ({max_t}) must be a positive integer"
        assert (
            max_t > grace_period
        ), f"max_t ({max_t}) must be greater than grace_period ({grace_period})"
        # Integral floats (e.g. 10.0) pass the checks above, but ``range``
        # below requires genuine integers
        grace_period = int(grace_period)
        max_t = int(max_t)
        if reduction_factor is not None:
            assert (
                reduction_factor >= 2
            ), f"reduction_factor ({reduction_factor}) must be >= 2"
            # Geometric levels; plain ``**`` avoids NumPy's fixed-width integer
            # arithmetic, which can silently overflow for integer arguments.
            # Since reduction_factor >= 2, the sequence is strictly increasing
            # and rounding cannot produce duplicates.
            rung_levels = []
            k = 0
            while grace_period * reduction_factor**k < max_t:
                rung_levels.append(int(round(grace_period * reduction_factor**k)))
                k += 1
            assert rung_levels[-1] <= max_t  # Sanity check
            if rung_increment is not None:
                logger.warning(
                    f"You specified both reduction_factor = {reduction_factor} "
                    f"and rung_increment = {rung_increment}. The former takes "
                    "precedence, the latter will be ignored"
                )
        else:
            assert (
                rung_increment is not None
            ), "Either reduction_factor or rung_increment must be given"
            assert _is_positive_int(
                rung_increment
            ), f"rung_increment ({rung_increment}) must be a positive integer"
            rung_levels = list(range(grace_period, max_t, int(rung_increment)))
    # Only the last entry can equal max_t (levels are strictly increasing and
    # <= max_t); drop it so that all rung levels are < max_t
    if rung_levels[-1] == max_t:
        rung_levels = rung_levels[:-1]
    return rung_levels