Source code for syne_tune.optimizer.schedulers.synchronous.hyperband

# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Optional, List, Set, Dict, Any, Tuple
import logging
import numpy as np

from syne_tune.callbacks.remove_checkpoints_callback import (
    DefaultRemoveCheckpointsSchedulerMixin,
)
from syne_tune.optimizer.schedulers.random_seeds import RANDOM_SEED_UPPER_BOUND
from syne_tune.optimizer.schedulers.synchronous.hyperband_bracket_manager import (
    SynchronousHyperbandBracketManager,
)
from syne_tune.optimizer.schedulers.synchronous.hyperband_bracket import SlotInRung
from syne_tune.optimizer.schedulers.synchronous.hyperband_rung_system import (
    RungSystemsPerBracket,
)
from syne_tune.optimizer.scheduler import TrialSuggestion, SchedulerDecision
from syne_tune.optimizer.schedulers.scheduler_searcher import TrialSchedulerWithSearcher
from syne_tune.optimizer.schedulers.multi_fidelity import MultiFidelitySchedulerMixin
from syne_tune.backend.trial_status import Trial
from syne_tune.config_space import cast_config_values
from syne_tune.optimizer.schedulers.searchers.utils.default_arguments import (
    check_and_merge_defaults,
    Categorical,
    String,
    assert_no_invalid_options,
    Integer,
)
from syne_tune.optimizer.schedulers.searchers.searcher import BaseSearcher
from syne_tune.optimizer.schedulers.searchers.searcher_factory import searcher_factory

logger = logging.getLogger(__name__)


_ARGUMENT_KEYS = {
    "searcher",
    "search_options",
    "metric",
    "mode",
    "points_to_evaluate",
    "random_seed",
    "max_resource_attr",
    "max_resource_level",
    "resource_attr",
    "searcher_data",
}

_DEFAULT_OPTIONS = {
    "searcher": "random",
    "mode": "min",
    "resource_attr": "epoch",
    "searcher_data": "rungs",
}

_CONSTRAINTS = {
    "metric": String(),
    "mode": Categorical(choices=("min", "max")),
    "random_seed": Integer(0, RANDOM_SEED_UPPER_BOUND),
    "max_resource_attr": String(),
    "max_resource_level": Integer(1, None),
    "resource_attr": String(),
    "searcher_data": Categorical(("rungs", "all")),
}


class SynchronousHyperbandCommon(
    TrialSchedulerWithSearcher, MultiFidelitySchedulerMixin
):
    """
    Common code for :meth:`_create_internal` in
    :class:`~syne_tune.optimizer.schedulers.synchronous.SynchronousHyperbandScheduler`
    and
    :class:`~syne_tune.optimizer.schedulers.synchronous.DifferentialEvolutionHyperbandScheduler`
    """

    def _create_internal_common(
        self, skip_searchers: Optional[Set[str]] = None, **kwargs
    ) -> Dict[str, Any]:
        self.metric = kwargs.get("metric")
        assert self.metric is not None, (
            "Argument 'metric' is mandatory. Pass the name of the metric "
            + "reported by your training script, which you'd like to "
            + "optimize, and use 'mode' to specify whether it should "
            + "be minimized or maximized"
        )
        self.mode = kwargs["mode"]
        self.max_resource_attr = kwargs.get("max_resource_attr")
        self._resource_attr = kwargs["resource_attr"]
        if self.max_resource_attr is None:
            logger.warning(
                "You do not specify max_resource_attr, but use max_resource_level "
                "instead. This is not recommended best practice and may lead to a "
                "loss of efficiency. Consider using max_resource_attr instead.\n"
                "See https://syne-tune.readthedocs.io/en/latest/tutorials/multifidelity/mf_setup.html#the-launcher-script "
                "for details."
            )
        self._max_resource_level = self._infer_max_resource_level(
            max_resource_level=kwargs.get("max_resource_level"),
            max_resource_attr=self.max_resource_attr,
        )
        assert self._max_resource_level is not None, (
            "Maximum resource level has to be specified, please provide "
            "max_resource_attr or max_resource_level argument."
        )
        self._searcher_data = kwargs["searcher_data"]
        # Generate searcher
        searcher = kwargs["searcher"]
        assert isinstance(
            searcher, str
        ), f"searcher must be of type string, but has type {type(searcher)}"
        search_options = kwargs.get("search_options")
        if search_options is None:
            search_options = dict()
        else:
            search_options = search_options.copy()
        if skip_searchers is None or searcher not in skip_searchers:
            search_options.update(
                {
                    "config_space": self.config_space.copy(),
                    "metric": self.metric,
                    "points_to_evaluate": kwargs.get("points_to_evaluate"),
                    "mode": kwargs["mode"],
                    "random_seed_generator": self.random_seed_generator,
                    "resource_attr": self._resource_attr,
                    "scheduler": "hyperband_synchronous",
                }
            )
            if searcher == "bayesopt":
                search_options["max_epochs"] = self._max_resource_level
            self._searcher: BaseSearcher = searcher_factory(searcher, **search_options)
        else:
            self._searcher = None
        return search_options

    @property
    def searcher(self) -> Optional[BaseSearcher]:
        return self._searcher

    @property
    def resource_attr(self) -> str:
        return self._resource_attr

    @property
    def max_resource_level(self) -> int:
        return self._max_resource_level

    @property
    def searcher_data(self) -> str:
        return self._searcher_data


class SynchronousHyperbandScheduler(
    SynchronousHyperbandCommon, DefaultRemoveCheckpointsSchedulerMixin
):
    """
    Synchronous Hyperband. Compared to
    :class:`~syne_tune.optimizer.schedulers.HyperbandScheduler`, this scheduler
    also runs jobs asynchronously, but decision-making is synchronized, in that
    trials are only promoted to the next milestone once the rung they are
    currently paused at is completely occupied.

    Our implementation never delays scheduling of a job. If the currently
    active bracket does not accept jobs, we assign the job to a later bracket.
    This means that at any point in time, several brackets can be active, but
    jobs are preferentially assigned to the first one (the "primary" active
    bracket).

    :param config_space: Configuration space for trial evaluation function
    :param bracket_rungs: Determines rung level systems for each bracket, see
        :class:`~syne_tune.optimizer.schedulers.synchronous.hyperband_bracket_manager.SynchronousHyperbandBracketManager`
    :param metric: Name of metric to optimize, key in results obtained via
        :meth:`on_trial_result`
    :type metric: str
    :param searcher: Searcher for ``get_config`` decisions. Passed to
        :func:`~syne_tune.optimizer.schedulers.searchers.searcher_factory`
        along with ``search_options`` and extra information. Supported values:
        :const:`~syne_tune.optimizer.schedulers.searchers.searcher_factory.SUPPORTED_SEARCHERS_HYPERBAND`.
        Defaults to "random" (i.e., random search)
    :type searcher: str, optional
    :param search_options: Passed to
        :func:`~syne_tune.optimizer.schedulers.searchers.searcher_factory`.
    :type search_options: Dict[str, Any], optional
    :param mode: Mode to use for the metric given, can be "min" (default) or
        "max"
    :type mode: str, optional
    :param points_to_evaluate: List of configurations to be evaluated
        initially (in that order). Each config in the list can be partially
        specified, or even be an empty dict. For each hyperparameter not
        specified, the default value is determined using a midpoint heuristic.
        If ``None`` (default), this is mapped to ``[dict()]``, a single
        default config determined by the midpoint heuristic. If ``[]``
        (empty list), no initial configurations are specified.
    :type points_to_evaluate: ``List[dict]``, optional
    :param random_seed: Master random seed. Generators used in the scheduler
        or searcher are seeded using
        :class:`~syne_tune.optimizer.schedulers.random_seeds.RandomSeedGenerator`.
        If not given, the master random seed is drawn at random here.
    :type random_seed: int, optional
    :param max_resource_attr: Key name in config for fixed attribute
        containing the maximum resource. If given, trials need not be
        stopped, which can be more efficient.
    :type max_resource_attr: str, optional
    :param max_resource_level: Largest rung level, corresponds to ``max_t``
        in :class:`~syne_tune.optimizer.schedulers.FIFOScheduler`. Must be
        positive int larger than ``grace_period``. If this is not given, it
        is inferred like in
        :class:`~syne_tune.optimizer.schedulers.FIFOScheduler`. In
        particular, it is not needed if ``max_resource_attr`` is given.
    :type max_resource_level: int, optional
    :param resource_attr: Name of resource attribute in results obtained via
        :meth:`on_trial_result`. The type of resource must be int. Defaults
        to "epoch"
    :type resource_attr: str, optional
    :param searcher_data: Relevant only if a model-based searcher is used.
        Example: For NN tuning and ``resource_attr == "epoch"``, we receive
        a result for each epoch, but not all epoch values are also rung
        levels. ``searcher_data`` determines which of these results are
        passed to the searcher. As a rule, the more data the searcher
        receives, the better its fit, but also the more expensive
        ``get_config`` may become. Choices:

        * "rungs" (default): Only results at rung levels. Cheapest
        * "all": All results. Most expensive

        Note: For a Gaussian additive learning curve surrogate model, this
        has to be set to "all".
    :type searcher_data: str, optional
    """

    def __init__(
        self,
        config_space: Dict[str, Any],
        bracket_rungs: RungSystemsPerBracket,
        **kwargs,
    ):
        super().__init__(config_space, **kwargs)
        self._create_internal(bracket_rungs, **kwargs)

    def _create_internal(self, bracket_rungs: RungSystemsPerBracket, **kwargs):
        # Check values and impute default values
        assert_no_invalid_options(
            kwargs, _ARGUMENT_KEYS, name="SynchronousHyperbandScheduler"
        )
        kwargs = check_and_merge_defaults(
            kwargs, set(), _DEFAULT_OPTIONS, _CONSTRAINTS, dict_name="scheduler_options"
        )
        self._create_internal_common(**kwargs)
        # Bracket manager
        self.bracket_manager = SynchronousHyperbandBracketManager(
            bracket_rungs,
            mode=self.mode,
        )
        # Maps trial_id to tuples ``(bracket_id, slot_in_rung)``, as returned
        # by ``bracket_manager.next_job``, and required by
        # ``bracket_manager.on_result``. Entries are removed once passed to
        # ``on_result``. Note that a trial_id can be associated with different
        # job descriptions in its lifetime
        self._trial_to_pending_slot = dict()
        # Maps trial_id (active) to config
        self._trial_to_config = dict()
        self._rung_levels = [level for _, level in bracket_rungs[0]]
        self._trials_checkpoints_can_be_removed = []

    @property
    def rung_levels(self) -> List[int]:
        return self._rung_levels

    @property
    def num_brackets(self) -> int:
        return len(self.bracket_manager.bracket_rungs)

    def _suggest(self, trial_id: int) -> Optional[TrialSuggestion]:
        do_debug_log = self.searcher.debug_log is not None
        if do_debug_log and trial_id == 0:
            # This is printed at the start of the experiment. Cannot do this
            # at construction, because with ``RemoteLauncher`` this does not
            # end up in the right log
            parts = ["Rung systems for each bracket:"] + [
                f"Bracket {bracket}: {rungs}"
                for bracket, rungs in enumerate(self.bracket_manager.bracket_rungs)
            ]
            logger.info("\n".join(parts))
        # Ask bracket manager for job
        bracket_id, slot_in_rung = self.bracket_manager.next_job()
        suggestion = None
        if slot_in_rung.trial_id is not None:
            # Paused trial to be resumed (``trial_id`` passed in is ignored)
            trial_id = slot_in_rung.trial_id
            _config = self._trial_to_config[trial_id]
            if self.max_resource_attr is not None:
                config = dict(_config, **{self.max_resource_attr: slot_in_rung.level})
            else:
                config = _config
            suggestion = TrialSuggestion.resume_suggestion(
                trial_id=trial_id, config=config
            )
            if do_debug_log:
                logger.info(f"trial_id {trial_id} promoted to {slot_in_rung.level}")
        else:
            # New trial to be started (id is ``trial_id`` passed in)
            config = self.searcher.get_config(trial_id=str(trial_id))
            if config is not None:
                config = cast_config_values(config, self.config_space)
                self.searcher.register_pending(
                    trial_id=str(trial_id), config=config, milestone=slot_in_rung.level
                )
                if self.max_resource_attr is not None:
                    config[self.max_resource_attr] = slot_in_rung.level
                self._trial_to_config[trial_id] = config
                suggestion = TrialSuggestion.start_suggestion(config=config)
                # Assign trial id to job descriptor
                slot_in_rung.trial_id = trial_id
                if do_debug_log:
                    logger.info(
                        f"trial_id {trial_id} starts (milestone = "
                        f"{slot_in_rung.level})"
                    )
        if suggestion is not None:
            assert trial_id not in self._trial_to_pending_slot, (
                f"Trial for trial_id = {trial_id} is already registered as "
                + "pending, cannot resume or start it"
            )
            self._trial_to_pending_slot[trial_id] = (bracket_id, slot_in_rung)
        else:
            # Searcher failed to return a config for a new ``trial_id``. We
            # report the corresponding job as failed, so that in case the
            # experiment is continued, the bracket is not blocked with a slot
            # which remains pending forever
            logger.warning(
                "Searcher failed to suggest a configuration for new trial "
                f"{trial_id}. The corresponding rung slot is marked as failed."
            )
            self._report_as_failed(bracket_id, slot_in_rung)
        return suggestion

    def _on_result(self, result: Tuple[int, SlotInRung]):
        trials_not_promoted = self.bracket_manager.on_result(result)
        if trials_not_promoted is not None:
            self._trials_checkpoints_can_be_removed.extend(trials_not_promoted)

    def _report_as_failed(self, bracket_id: int, slot_in_rung: SlotInRung):
        result_failed = SlotInRung(
            rung_index=slot_in_rung.rung_index,
            level=slot_in_rung.level,
            slot_index=slot_in_rung.slot_index,
            trial_id=slot_in_rung.trial_id,
            metric_val=np.nan,
        )
        self._on_result((bracket_id, result_failed))

    def on_trial_result(self, trial: Trial, result: Dict[str, Any]) -> str:
        trial_id = trial.trial_id
        if trial_id in self._trial_to_pending_slot:
            bracket_id, slot_in_rung = self._trial_to_pending_slot[trial_id]
            assert slot_in_rung.trial_id == trial_id  # Sanity check
            assert self.metric in result, (
                f"Result for trial_id {trial_id} does not contain "
                + f"'{self.metric}' field"
            )
            metric_val = float(result[self.metric])
            assert self._resource_attr in result, (
                f"Result for trial_id {trial_id} does not contain "
                + f"'{self._resource_attr}' field"
            )
            resource = int(result[self._resource_attr])
            milestone = slot_in_rung.level
            trial_decision = SchedulerDecision.CONTINUE
            if resource >= milestone:
                assert resource == milestone, (
                    f"Trial trial_id {trial_id}: Obtained result for "
                    + f"resource = {resource}, but not for {milestone}. "
                    + "Training script must not skip rung levels!"
                )
                # Reached rung level: Pass result to bracket manager
                slot_in_rung.metric_val = metric_val
                self._on_result((bracket_id, slot_in_rung))
                # Remove it from pending slots
                del self._trial_to_pending_slot[trial_id]
                # Trial should be paused
                trial_decision = SchedulerDecision.PAUSE
            prev_level = self.bracket_manager.level_to_prev_level(bracket_id, milestone)
            if resource > prev_level:
                # If the training script does not implement checkpointing, each
                # trial starts from scratch. In this case, the condition
                # ``resource > prev_level`` ensures that the searcher does not
                # receive multiple reports for the same resource
                update = self.searcher_data == "all" or resource == milestone
                self.searcher.on_trial_result(
                    trial_id=str(trial_id),
                    config=self._trial_to_config[trial_id],
                    result=result,
                    update=update,
                )
        else:
            trial_decision = SchedulerDecision.STOP
            logger.warning(
                f"Received result for trial_id {trial_id}, which is not "
                f"pending. This result is not used:\n{result}"
            )
        return trial_decision

    def on_trial_error(self, trial: Trial):
        """
        If the ``trial`` is currently pending, we send a result at its
        milestone with metric value NaN. Such trials are ranked after all
        others and will most likely not be promoted.
        """
        super().on_trial_error(trial)
        trial_id = trial.trial_id
        if trial_id in self._trial_to_pending_slot:
            bracket_id, slot_in_rung = self._trial_to_pending_slot[trial_id]
            self._report_as_failed(bracket_id, slot_in_rung)
            # A failed trial is not pending anymore
            del self._trial_to_pending_slot[trial_id]
        else:
            logger.warning(
                f"Trial trial_id {trial_id} not registered at pending: "
                "on_trial_error call is ignored"
            )

    def metric_names(self) -> List[str]:
        return [self.metric]

    def metric_mode(self) -> str:
        return self.mode

    def trials_checkpoints_can_be_removed(self) -> List[int]:
        result = self._trials_checkpoints_can_be_removed
        self._trials_checkpoints_can_be_removed = []
        return result
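

# The block below is a minimal construction sketch, not part of the library
# code: it illustrates the expected shape of ``bracket_rungs`` (a list of
# brackets, each a list of ``(rung_size, level)`` tuples) and how the
# scheduler is typically instantiated. The config space, metric name
# ("validation_error"), and the geometric rung sizes are illustrative
# assumptions only; real bracket rungs are usually produced by a rung system
# helper rather than written by hand.
if __name__ == "__main__":
    from syne_tune.config_space import loguniform, randint

    # Toy config space; "epochs" carries the maximum resource per trial
    config_space = {
        "learning_rate": loguniform(1e-4, 1e-1),
        "num_layers": randint(1, 4),
        "epochs": 9,
    }
    # Two brackets of a geometric rung system (reduction factor 3, maximum
    # resource 9): the first bracket runs 9 trials for 1 epoch, promotes 3 of
    # them to 3 epochs, and 1 of those to 9 epochs
    bracket_rungs: RungSystemsPerBracket = [
        [(9, 1), (3, 3), (1, 9)],
        [(3, 3), (1, 9)],
    ]
    scheduler = SynchronousHyperbandScheduler(
        config_space,
        bracket_rungs,
        metric="validation_error",
        mode="min",
        resource_attr="epoch",
        max_resource_attr="epochs",
    )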