Source code for syne_tune.optimizer.schedulers.synchronous.dehb_bracket_manager

# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import List, Tuple, Optional

from syne_tune.optimizer.schedulers.synchronous.hyperband_bracket_manager import (
    SynchronousHyperbandBracketManager,
)
from syne_tune.optimizer.schedulers.synchronous.dehb_bracket import (
    DifferentialEvolutionHyperbandBracket,
)


class DifferentialEvolutionHyperbandBracketManager(SynchronousHyperbandBracketManager):
    """
    Bracket manager for DEHB, built on top of
    :class:`SynchronousHyperbandBracketManager`. The brackets it creates are
    of type :class:`DifferentialEvolutionHyperbandBracket`.

    All brackets of an iteration follow from the first one: the bracket at
    offset ``k`` consists of the rungs of the first bracket with the lowest
    ``k`` rungs dropped. The size of a rung is therefore tied to its level and
    does not depend on the bracket, so later brackets consume less total
    budget. Synchronous Hyperband behaves differently: there, rungs of later
    brackets are larger, so that each bracket has the same total budget.

    On top of the base class, this manager offers access to trial_id's in
    specific rungs and to entries of the sorted top lists of completed rungs.
    DEHB needs these because it controls how new configurations are created
    at higher rungs, whereas synchronous Hyperband relies on automatic
    promotion from lower rungs.

    :param rungs_first_bracket: Rungs of the first bracket, as list of pairs
        of two integers. The second entry of each pair is used as rung level
        here
    :param mode: Passed on to the base class and to every bracket created
    :param num_brackets_per_iteration: Number of brackets per iteration.
        Must lie in ``[1, len(rungs_first_bracket)]``. Defaults to
        ``len(rungs_first_bracket)``
    """

    def __init__(
        self,
        rungs_first_bracket: List[Tuple[int, int]],
        mode: str,
        num_brackets_per_iteration: Optional[int] = None,
    ):
        max_num_offsets = len(rungs_first_bracket)
        assert max_num_offsets > 0
        if num_brackets_per_iteration is not None:
            assert 1 <= num_brackets_per_iteration <= max_num_offsets, (
                f"num_brackets_per_iteration = {num_brackets_per_iteration}"
                + f", must be in [1, {max_num_offsets}]"
            )
        else:
            num_brackets_per_iteration = max_num_offsets
        # In DEHB, bracket number ``start`` is the first bracket with its
        # ``start`` lowest rungs removed
        rungs_per_bracket = []
        for start in range(num_brackets_per_iteration):
            rungs_per_bracket.append(rungs_first_bracket[start:])
        super().__init__(rungs_per_bracket, mode)
        # Cache: (bracket_id, rung_index) -> top list of the previous rung,
        # i.e. the result of
        # ``DifferentialEvolutionHyperbandBracket.top_list_for_previous_rung``
        # while ``rung_index`` is the current rung of that bracket. Caching
        # avoids sorting the same rung over and over
        self._top_list_of_previous_rung_cache = {}
        # Lookup: (offset, level) -> (bracket_delta, rung_index). Locates the
        # parent rung (same level) of the rung at ``level`` in a bracket with
        # ``offset``: go ``bracket_delta`` brackets back, then pick rung
        # ``rung_index`` there
        self._parent_rung = self._set_parent_rung()

    def _set_parent_rung(self):
        """
        Builds the table mapping ``(offset, level)`` to
        ``(bracket_delta, rung_index)``, which is used to locate the parent
        rung of a rung.

        :return: Dictionary with one entry per rung of every bracket offset
        """
        table = {}
        for offset, rungs in enumerate(self._bracket_rungs):
            for rung_index, (_, level) in enumerate(rungs):
                if offset == 0:
                    # Offset 0: The parent rung is the base rung (index 0)
                    # of a bracket further to the left
                    # NOTE(review): this delta is only positive as long as
                    # ``rung_index < num_bracket_offsets`` -- confirm for
                    # ``num_brackets_per_iteration < len(rungs_first_bracket)``
                    location = (self.num_bracket_offsets - rung_index, 0)
                else:
                    # Offset > 0: The parent rung sits one position higher
                    # in the bracket directly to the left
                    location = (1, rung_index + 1)
                table[(offset, level)] = location
        return table

    def _create_new_bracket(self) -> int:
        """
        Appends a new bracket, whose offset is the next bracket ID modulo the
        number of bracket offsets.

        :return: ID of the bracket just created
        """
        # Sanity check:
        assert len(self._brackets) == len(self._bracket_id_to_offset)
        new_id = self._next_bracket_id
        new_offset = new_id % self.num_bracket_offsets
        self._bracket_id_to_offset.append(new_offset)
        new_bracket = DifferentialEvolutionHyperbandBracket(
            rungs=self._bracket_rungs[new_offset], mode=self.mode
        )
        self._brackets.append(new_bracket)
        return new_id

    def size_of_current_rung(self, bracket_id: int) -> int:
        """
        :param bracket_id: ID of bracket
        :return: Size of the current rung in bracket ``bracket_id``
        """
        bracket = self._brackets[bracket_id]
        return bracket.size_of_current_rung()

    def trial_id_from_parent_slot(
        self, bracket_id: int, level: int, slot_index: int
    ) -> Optional[int]:
        """
        Walks the chain of parent rungs (same rung level, same slot index,
        smaller bracket ID) and returns the first trial_id found which is not
        None. If the chain ends at bracket 0 without such an entry, None is
        returned.

        The target for a cross-over or selection operation is chosen from
        the parent slot.

        :param bracket_id: ID of bracket to start from
        :param level: Rung level
        :param slot_index: Index of slot in rung
        :return: trial_id of parent slot, or None if there is none
        """
        while bracket_id > 0:
            offset = self._bracket_id_to_offset[bracket_id]
            bracket_delta, rung_index = self._parent_rung[(offset, level)]
            bracket_id -= bracket_delta
            trial_id = self._brackets[bracket_id].trial_id_for_slot(
                rung_index=rung_index, slot_index=slot_index
            )
            if trial_id is not None:
                return trial_id
        return None

    def top_of_previous_rung(self, bracket_id: int, pos: int) -> int:
        """
        Looks at the rung below the current rung of bracket ``bracket_id``,
        with its slots in sorted order, and returns the trial_id at position
        ``pos`` (``pos=0`` is the best entry). The sorted list is computed
        once per ``(bracket_id, current rung)`` and cached afterwards.

        :param bracket_id: ID of bracket
        :param pos: Position in the sorted list of the previous rung
        :return: trial_id at position ``pos``
        """
        bracket = self._brackets[bracket_id]
        cache_key = (bracket_id, bracket.current_rung)
        cache = self._top_list_of_previous_rung_cache
        ranked = cache.get(cache_key)
        if ranked is None:
            ranked = bracket.top_list_for_previous_rung()
            cache[cache_key] = ranked
        return ranked[pos]