# Source code for syne_tune.optimizer.schedulers.searchers.dyhpo.hyperband_dyhpo

# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import List, Dict, Any, Tuple
from numpy.random import RandomState
import logging
from collections import Counter

from syne_tune.optimizer.schedulers.hyperband_promotion import (
    PromotionRungSystem,
)
from syne_tune.optimizer.schedulers.searchers import BaseSearcher

logger = logging.getLogger(__name__)


# Default for the ``probability_sh`` argument of :class:`DyHPORungSystem`
# (presumably applied by the caller; the constant is not referenced in this
# module chunk — confirm at call sites)
DEFAULT_SH_PROBABILITY = 0.25


# Key under which :meth:`DyHPORungSystem.on_task_schedule` returns the
# configuration of a new trial to be started (instead of promoting a paused
# one); the scheduler passes this on to ``searcher.get_config``
KEY_NEW_CONFIGURATION = "new_configuration"


[docs] class ScheduleDecision: PROMOTE_SH = 0 PROMOTE_DYHPO = 1 START_DYHPO = 2
_SUMMARY_SCHEDULE_RECORDS = [ ("promoted_by_sh", ScheduleDecision.PROMOTE_SH), ("promoted_by_dyhpo", ScheduleDecision.PROMOTE_DYHPO), ("started_by_dyhpo", ScheduleDecision.START_DYHPO), ]
class DyHPORungSystem(PromotionRungSystem):
    """
    Implements the logic which decides which paused trial to promote to the
    next resource level, or alternatively which configuration to start as a
    new trial, proposed in:

        | Wistuba, M. and Kadra, A. and Grabocka, J.
        | Dynamic and Efficient Gray-Box Hyperparameter Optimization for Deep Learning
        | https://arxiv.org/abs/2202.09774

    We do promotion-based scheduling, as in
    :class:`~syne_tune.optimizer.schedulers.hyperband_promotion.PromotionRungSystem`.
    In fact, we run the successive halving rule in :meth:`on_task_schedule`
    with probability ``probability_sh``, and the DyHPO logic otherwise, or if
    the SH rule does not promote a trial. This mechanism (not contained in the
    paper) ensures that trials are promoted eventually, even if DyHPO only
    starts new trials.

    Since :class:`~syne_tune.optimizer.schedulers.HyperbandScheduler` was
    designed for promotion decisions to be separate from decisions about new
    configs, the overall workflow is a bit tricky:

    * In :meth:`FIFOScheduler._suggest`, we first call
      :code:`promote_trial_id, extra_kwargs = self._promote_trial()`. If
      ``promote_trial_id != None``, this trial is promoted. Otherwise, we call
      :code:`config = self.searcher.get_config(**extra_kwargs, trial_id=trial_id)`
      and start a new trial with this config. In most cases,
      :meth:`_promote_trial` makes a promotion decision without using the
      searcher.
    * Here, we use the fact that information can be passed from
      :meth:`_promote_trial` to ``self.searcher.get_config`` via
      ``extra_kwargs``. Namely, :meth:`HyperbandScheduler._promote_trial`
      calls :meth:`on_task_schedule` here, which calls
      :meth:`~syne_tune.optimizer.schedulers.searchers.dyhpo.DynamicHPOSearcher.score_paused_trials_and_new_configs`,
      where everything happens.
    * First, all paused trials are scored w.r.t. the value of running them for
      one more unit of resource. Also, a number of random configs are scored
      w.r.t. the value of running them to the minimum resource.
    * If the winning config is from a paused trial, this is resumed. If the
      winning config is a new one, :meth:`on_task_schedule` returns this
      config using a special key :const:`KEY_NEW_CONFIGURATION`. This dict
      becomes part of ``extra_kwargs`` and is passed to
      ``self.searcher.get_config``
    * :meth:`~syne_tune.optimizer.schedulers.searchers.dyhpo.DynamicHPOSearcher.get_config`
      is trivial. It obtains an argument of name
      :const:`KEY_NEW_CONFIGURATION` and returns its value, which is the
      winning config to be started as new trial

    We can ignore ``rung_levels`` and ``promote_quantiles``, they are not
    used. For each trial, we only need to maintain the resource level at
    which it is paused.
    """

    def __init__(
        self,
        rung_levels: List[int],
        promote_quantiles: List[float],
        metric: str,
        mode: str,
        resource_attr: str,
        max_t: int,
        searcher: BaseSearcher,
        # FIX: was annotated ``bool``, but this is a probability in [0, 1)
        # (asserted below and compared against ``random_state.rand()``)
        probability_sh: float,
        random_state: RandomState,
    ):
        assert len(rung_levels) > 0, "rung_levels must not be empty"
        assert (
            0 <= probability_sh < 1
        ), f"probability_sh = {probability_sh}, must be in [0, 1)"
        super().__init__(
            rung_levels, promote_quantiles, metric, mode, resource_attr, max_t
        )
        self._check_rung_levels(rung_levels)
        self._searcher = searcher
        self._min_resource = rung_levels[0]
        self._probability_sh = probability_sh
        self._random_state = random_state
        # Maps rung level to the one below
        self._previous_rung_level = dict(zip(rung_levels[1:] + [max_t], rung_levels))
        # Keeps a record of outcomes of :meth:`on_task_schedule` calls. Entries
        # are ``(trial_id, decision, milestone)``, where ``decision`` is a
        # constant from :class:`ScheduleDecision`, and ``milestone`` is the
        # rung level which the trial reaches next
        self._schedule_records: List[Tuple[str, int, int]] = []

    @staticmethod
    def _check_rung_levels(rung_levels: List[int]):
        """
        Warns (does not fail) if ``rung_levels`` are not linearly spaced with
        ``rung_levels[0] == step``, the setup DyHPO is designed for.
        """
        if len(rung_levels) > 1:
            rmin = rung_levels[0]
            step = rung_levels[1] - rmin
            should_be = list(range(rmin, rung_levels[-1] + 1, step))
            if rmin != step or rung_levels != should_be:
                logger.warning(
                    "DyHPO should be run with linearly spaced rung levels, in "
                    "that reduction_factor is not used, and grace_period == "
                    "rung_increment, bracket == 1. Running with rung_levels = "
                    f"{rung_levels} is not recommended"
                )

    def _paused_trials_and_milestones(self) -> List[Tuple[str, int, int]]:
        """
        Return list of all trials which are paused. Entries are
        ``(trial_id, pos, resource)``, where ``pos`` is the position of the
        trial in its rung, and ``resource`` is the next rung level the trial
        reaches after being resumed.

        :return: See above
        """
        paused_trials = []
        # ``self._rungs`` (from the base class) is iterated from the highest
        # rung downwards; ``next_level`` is the level one above the current
        # rung, i.e. the milestone a resumed trial reaches next
        next_level = self._max_t
        for rung in self._rungs:
            level = rung.level
            paused_trials.extend(
                (entry.trial_id, pos, next_level)
                for pos, entry in enumerate(rung.data)
                if self._is_promotable_trial(entry, level)
            )
            next_level = level
        return paused_trials

    def on_task_schedule(self, new_trial_id: str) -> Dict[str, Any]:
        """
        The main decision making happens here. We collect
        ``(trial_id, resource)`` for all paused trials and call ``searcher``.
        The searcher scores all these trials along with a certain number of
        randomly drawn new configurations. If one of the paused trials has the
        best score, we return its ``trial_id`` along with extra information,
        so it gets promoted. If one of the new configurations has the best
        score, we return this configuration. In this case, a new trial is
        started with this configuration.

        Note: For this scheduler type, ``kwargs`` must contain the trial ID
        of the new trial to be started, in case none can be promoted.

        :param new_trial_id: ID for a new trial, used if no paused trial is
            promoted
        :return: Dict with ``trial_id``, ``resume_from``, ``milestone`` if a
            trial is promoted; dict with key :const:`KEY_NEW_CONFIGURATION`
            if a new trial is to be started
        """
        if self._random_state.rand() <= self._probability_sh:
            # Try to promote trial based on successive halving logic
            result = super().on_task_schedule(new_trial_id)
            if result.get("trial_id") is not None:
                self._schedule_records.append(
                    (
                        result["trial_id"],
                        ScheduleDecision.PROMOTE_SH,
                        result["milestone"],
                    )
                )
                return result
        # Follow DyHPO logic
        paused_trials = self._paused_trials_and_milestones()
        assert new_trial_id is not None, (
            "Internal error: kwargs must contain 'trial_id', the ID for a new "
            "trial to be started if no paused one is resumed. Make sure to "
            "pass this to the _promote_trial method when calling it in "
            "_suggest"
        )
        result = self._searcher.score_paused_trials_and_new_configs(
            paused_trials,
            min_resource=self._min_resource,
            new_trial_id=new_trial_id,
        )
        trial_id = result.get("trial_id")
        if trial_id is not None:
            # Trial is to be promoted
            pos = result["pos"]  # Position of trial in its rung
            milestone = next(r for i, _, r in paused_trials if i == trial_id)
            resume_from = self._previous_rung_level[milestone]
            rung = next(rung for rung in self._rungs if rung.level == resume_from)
            self._mark_as_promoted(rung, pos, trial_id=trial_id)
            ret_dict = {
                "trial_id": trial_id,
                "resume_from": resume_from,
                "milestone": milestone,
            }
            self._schedule_records.append(
                (trial_id, ScheduleDecision.PROMOTE_DYHPO, milestone)
            )
        else:
            # New trial is to be started
            ret_dict = {KEY_NEW_CONFIGURATION: result["config"]}
            self._schedule_records.append(
                (new_trial_id, ScheduleDecision.START_DYHPO, self._min_resource)
            )
        return ret_dict

    @property
    def schedule_records(self) -> List[Tuple[str, int, int]]:
        return self._schedule_records

    @staticmethod
    def summary_schedule_keys() -> List[str]:
        """
        :return: Keys of the dict returned by :meth:`summary_schedule_records`
        """
        return [key for key, _ in _SUMMARY_SCHEDULE_RECORDS]

    def summary_schedule_records(self) -> Dict[str, Any]:
        """
        :return: Counts of schedule decisions made so far, keyed by the names
            in :const:`_SUMMARY_SCHEDULE_RECORDS`
        """
        histogram = Counter([x[1] for x in self._schedule_records])
        return {name: histogram[value] for name, value in _SUMMARY_SCHEDULE_RECORDS}

    def support_early_checkpoint_removal(self) -> bool:
        """
        Early checkpoint removal currently not supported for DyHPO
        """
        return False