# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import numpy as np
from syne_tune.optimizer.scheduler import TrialScheduler
class BracketDistribution:
    """
    Base class for the distribution over brackets used by asynchronous
    multi-fidelity schedulers such as
    :class:`~syne_tune.optimizer.schedulers.HyperbandScheduler`.

    Subclasses may return a distribution that is fixed up front, or one that
    adapts while the experiment runs. The distribution only has an effect
    when the scheduler is run with more than one bracket.
    """

    def __call__(self) -> np.ndarray:
        """
        Return the current distribution over brackets.

        Subclasses must override this method; the base implementation
        always raises.

        :return: Distribution over brackets
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError
class DefaultHyperbandBracketDistribution(BracketDistribution):
    """
    Implements default bracket distribution, where probability for each bracket
    is proportional to the number of slots in each bracket in synchronous
    Hyperband.

    ``num_brackets`` and ``rung_levels`` must be assigned and
    :meth:`_set_distribution` must have run before the object is called.
    """

    def __init__(self):
        # Number of brackets; read by ``_set_distribution``, must be set first.
        self.num_brackets = None
        # Sequence indexed by bracket ``s``; read by ``_set_distribution``
        # only when there is more than one bracket.
        self.rung_levels = None
        # Cached normalized distribution, written by ``_set_distribution``.
        self._distribution = None

    def __call__(self) -> np.ndarray:
        """
        :return: Distribution over brackets (``[1.0]`` if there is at most
            one bracket)
        :raises RuntimeError: if the distribution has not been computed yet
        """
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently return ``None`` here.
        if self._distribution is None:
            raise RuntimeError("Call 'configure' first")
        return self._distribution

    def _set_distribution(self):
        """
        Compute and cache the normalized bracket distribution.

        For ``num_brackets > 1``, bracket ``s`` gets the weight
        ``smax_plus1 / ((smax_plus1 - s) * rung_levels[s])``, where
        ``smax_plus1 = len(rung_levels)``. The weights are then normalized
        to sum to 1. With at most one bracket, the distribution is ``[1.0]``.

        :raises ValueError: if ``num_brackets > len(rung_levels)``
        """
        if self.num_brackets > 1:
            smax_plus1 = len(self.rung_levels)
            # Sanity check. This is explicit so that it survives ``python -O``;
            # otherwise the list comprehension below would fail with an
            # unhelpful IndexError on ``rung_levels[s]``.
            if self.num_brackets > smax_plus1:
                raise ValueError(
                    f"num_brackets = {self.num_brackets} must not exceed "
                    f"len(rung_levels) = {smax_plus1}"
                )
            weights = np.array(
                [
                    smax_plus1 / ((smax_plus1 - s) * self.rung_levels[s])
                    for s in range(self.num_brackets)
                ]
            )
            self._distribution = weights / weights.sum()
        else:
            self._distribution = np.ones(1)