# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Tuple, Optional, List
import copy
from syne_tune.optimizer.schedulers.synchronous.hyperband_bracket import (
SynchronousHyperbandBracket,
SlotInRung,
)
from syne_tune.optimizer.schedulers.synchronous.hyperband_rung_system import (
RungSystemsPerBracket,
)
class SynchronousHyperbandBracketManager:
    """
    Maintains all brackets, relays requests for another job and report of
    result to one of the brackets.

    Each bracket contains a number of rungs, the largest one ``max_num_rungs``.
    A bracket with k rungs has offset ``max_num_rungs - k``. Hyperband cycles
    through brackets with offset ``0, ..., num_brackets - 1``, where
    ``num_brackets <= max_num_rungs``.

    At any given time, one bracket is primary, all other active brackets are
    secondary. Jobs are preferentially assigned to the primary bracket, but
    if its current rung has no free slots (all are pending), secondary
    brackets are considered.

    Each bracket has a ``bracket_id`` (nonnegative int). The primary bracket
    always has the lowest id of all active ones. For job assignment, we iterate
    over active brackets starting from the primary, and assign the job to the
    first bracket which has a free slot. If none of the active brackets have
    a free slot, a new bracket is created.

    :param bracket_rungs: Rungs for successive brackets, from largest to
        smallest
    :param mode: Criterion is minimized ('min') or maximized ('max')
    """

    def __init__(self, bracket_rungs: RungSystemsPerBracket, mode: str):
        self.num_bracket_offsets = len(bracket_rungs)
        assert self.num_bracket_offsets > 0
        assert mode in {"min", "max"}
        self.mode = mode
        self.max_num_rungs = len(bracket_rungs[0])
        # Bracket at position ``offset`` must have exactly ``offset`` fewer
        # rungs than the first (largest) one
        for offset, rungs in enumerate(bracket_rungs):
            assert len(rungs) == self.max_num_rungs - offset, (
                f"bracket_rungs[{offset}] has size {len(rungs)}, should "
                + f"have size {self.max_num_rungs - offset}"
            )
            SynchronousHyperbandBracket.assert_check_rungs(rungs)
        # Keep a private copy, so that later changes made by the caller to
        # its own argument do not affect brackets created from here on
        self._bracket_rungs = copy.deepcopy(bracket_rungs)
        # List of all brackets. We do not delete brackets which are
        # complete, but just keep them for a record
        self._brackets: List[SynchronousHyperbandBracket] = []
        # Maps bracket_id to offset
        self._bracket_id_to_offset: List[int] = []
        # Maps (offset, level), level a rung level in the bracket, to
        # the previous rung level (or 0)
        self._level_to_prev_level = {}
        for offset, rungs in enumerate(bracket_rungs):
            prev_level = 0
            # Each rung entry is a pair whose second item is the rung level
            for _, level in rungs:
                self._level_to_prev_level[(offset, level)] = prev_level
                prev_level = level
        # Create primary bracket
        self._primary_bracket_id = self._create_new_bracket()

    @property
    def bracket_rungs(self) -> RungSystemsPerBracket:
        """
        :return: Internal copy of the rungs for successive brackets, as
            passed at construction
        """
        return self._bracket_rungs

    @property
    def _next_bracket_id(self) -> int:
        # Bracket ids are positions in ``_brackets``, so the id of the next
        # bracket to be created is the current number of brackets
        return len(self._brackets)

    def level_to_prev_level(self, bracket_id: int, level: int) -> int:
        """
        :param bracket_id: ID of an existing bracket
        :param level: Level in bracket
        :return: Previous level; or 0
        """
        offset = self._bracket_id_to_offset[bracket_id]
        return self._level_to_prev_level[(offset, level)]

    def _create_new_bracket(self) -> int:
        """
        Appends a new bracket. Offsets are assigned round-robin over
        ``0, ..., num_bracket_offsets - 1``, in order of bracket creation.

        :return: ID of the newly created bracket
        """
        # Sanity check:
        assert len(self._brackets) == len(self._bracket_id_to_offset)
        bracket_id = self._next_bracket_id
        offset = bracket_id % self.num_bracket_offsets
        self._bracket_id_to_offset.append(offset)
        self._brackets.append(
            SynchronousHyperbandBracket(self._bracket_rungs[offset], self.mode)
        )
        return bracket_id

    def next_job(self) -> Tuple[int, SlotInRung]:
        """
        Called by scheduler to request a new job. Jobs are preferentially
        assigned to the primary bracket, which has the lowest id among all
        active brackets. If the primary bracket does not accept jobs (because
        all remaining slots are already pending), further active brackets are
        polled. If none of the active brackets accept jobs, a new bracket is
        created.

        The job description returned is (bracket_id, slot_in_rung), where
        ``slot_in_rung`` is :class:`SlotInRung`, containing the info of what
        is to be done (``trial_id``, ``level`` fields). It is this entry which
        has to be returned in ``on_result``, with the ``metric_val`` field set.
        If the job returned here has ``trial_id == None``, it comes from the
        lowest rung of its bracket, and the ``trial_id`` has to be set as well
        when returning the record in ``on_result``.

        :return: Tuple ``(bracket_id, slot_in_rung)``
        """
        # Try to assign job to active bracket. There must be at least one,
        # the primary one
        for bracket_id in range(self._primary_bracket_id, self._next_bracket_id):
            slot_in_rung = self._brackets[bracket_id].next_free_slot()
            if slot_in_rung is not None:
                return bracket_id, slot_in_rung
        # None of the existing brackets accept jobs. Create a new one
        bracket_id = self._create_new_bracket()
        slot_in_rung = self._brackets[bracket_id].next_free_slot()
        assert slot_in_rung is not None, "Newly created bracket has to have a free slot"
        return bracket_id, slot_in_rung

    def on_result(self, result: Tuple[int, SlotInRung]) -> Optional[List[int]]:
        """
        Called by scheduler to provide result for previously requested job.
        See :meth:`next_job`.

        :param result: Tuple ``(bracket_id, slot_in_rung)``
        :return: See :meth:`~syne_tune.optimizer.schedulers.synchronous.hyperband_bracket.SynchronousBracket.on_result`
        """
        bracket_id, slot_in_rung = result
        assert self._primary_bracket_id <= bracket_id < self._next_bracket_id, (
            f"Invalid bracket_id = {bracket_id}, must be in "
            + f"[{self._primary_bracket_id}, {self._next_bracket_id})"
        )
        bracket = self._brackets[bracket_id]
        trials_not_promoted = bracket.on_result(slot_in_rung)
        # Only a result for the primary bracket can change which bracket
        # is the primary one
        if bracket_id == self._primary_bracket_id:
            # Primary bracket is complete: Move to next one. While very
            # unlikely, brackets after the primary one could be complete
            # as well
            last_bracket_id = self._next_bracket_id - 1
            while (
                bracket.is_bracket_complete()
                and self._primary_bracket_id < last_bracket_id
            ):
                self._primary_bracket_id += 1
                bracket = self._brackets[self._primary_bracket_id]
            # May have to create a new bracket
            if bracket.is_bracket_complete():
                self._primary_bracket_id = self._create_new_bracket()
        return trials_not_promoted