# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Dict, Any
from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.tuning_job_state import (
TuningJobState,
)
from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.common import (
INTERNAL_METRIC_NAME,
)


class SkipOptimizationPredicate:
"""
Interface for ``skip_optimization`` predicate in
:class:`~syne_tune.optimizer.schedulers.searchers.bayesopt.models.model_transformer.ModelStateTransformer`.
"""
def reset(self):
"""
If there is an internal state, reset it to its initial value
"""
pass
def __call__(self, state: TuningJobState) -> bool:
"""
:param state: Current tuning job state
:return: Skip hyperparameter optimization?
"""
raise NotImplementedError
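

# A minimal sketch of a custom predicate (hypothetical, not part of this
# module): skip hyperparameter optimization once more than ``max_cases``
# labeled cases have been observed. A subclass only needs to implement
# ``__call__`` (and ``reset`` if it keeps internal state):
#
#     class SkipAboveSizePredicate(SkipOptimizationPredicate):
#         def __init__(self, max_cases: int, metric_name: str = INTERNAL_METRIC_NAME):
#             self.max_cases = max_cases
#             self.metric_name = metric_name
#
#         def __call__(self, state: TuningJobState) -> bool:
#             return state.num_observed_cases(self.metric_name) > self.max_cases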


class NeverSkipPredicate(SkipOptimizationPredicate):
"""
Hyperparameter optimization is never skipped.
"""
def __call__(self, state: TuningJobState) -> bool:
return False


class AlwaysSkipPredicate(SkipOptimizationPredicate):
"""
Hyperparameter optimization is always skipped.
"""
def __call__(self, state: TuningJobState) -> bool:
return True


class SkipPeriodicallyPredicate(SkipOptimizationPredicate):
"""
Let ``N`` be the number of labeled points for metric ``metric_name``.
Optimizations are not skipped if ``N < init_length``. Afterwards,
we increase a counter whenever ``N`` is larger than in the previous
call. With respect to this counter, optimizations are done every
``period`` times, in between they are skipped.
:param init_length: See above
:param period: See above
:param metric_name: Name of internal metric. Defaults to
:const:`~syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.common.INTERNAL_METRIC_NAME`.
"""
def __init__(
self, init_length: int, period: int, metric_name: str = INTERNAL_METRIC_NAME
):
assert init_length >= 0
assert period > 1
self.init_length = init_length
self.period = period
self.metric_name = metric_name
self.reset()
def reset(self):
self._counter = 0
# Need to make sure that if called several times with the same state,
# we return the same value
self._last_size = None
        self._last_retval = None

def __call__(self, state: TuningJobState) -> bool:
num_labeled = state.num_observed_cases(self.metric_name)
if num_labeled == self._last_size:
return self._last_retval
if self._last_size is not None:
assert (
num_labeled > self._last_size
), "num_labeled = {} < {} = _last_size".format(num_labeled, self._last_size)
if num_labeled < self.init_length:
ret_value = False
else:
ret_value = self._counter % self.period != 0
self._counter += 1
self._last_size = num_labeled
        self._last_retval = ret_value
return ret_value
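

# A minimal usage sketch (assumption: a stand-in object exposing
# ``num_observed_cases`` suffices to exercise the predicate; in production,
# a real ``TuningJobState`` is passed in):
#
#     class _FakeState:
#         def __init__(self, num_labeled):
#             self._num_labeled = num_labeled
#
#         def num_observed_cases(self, metric_name):
#             return self._num_labeled
#
#     predicate = SkipPeriodicallyPredicate(init_length=3, period=3)
#     for n in range(1, 8):
#         print(n, predicate(_FakeState(n)))
#     # n = 1, 2: below init_length, never skip -> False, False
#     # n = 3: counter == 0, optimize -> False; n = 4, 5: skip -> True, True
#     # n = 6: counter == 3, optimize -> False; n = 7: skip -> True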


class SkipNoMaxResourcePredicate(SkipOptimizationPredicate):
"""
This predicate works for multi-fidelity HPO, see for example
:class:`~syne_tune.optimizer.schedulers.searchers.GPMultiFidelitySearcher`.
We track the number of labeled datapoints at resource level ``max_resource``.
HP optimization is skipped if the total number ``N`` of labeled cases is
``N >= init_length``, and if the number of ``max_resource`` cases has not
increased since the last recent optimization.
This means that as long as the dataset only grows w.r.t. cases at lower
resources than ``max_resource``, this does not trigger HP optimization.
:param init_length: See above
:param max_resource: See above
:param metric_name: Name of internal metric. Defaults to
:const:`~syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.common.INTERNAL_METRIC_NAME`.
"""
def __init__(
self,
init_length: int,
max_resource: int,
metric_name: str = INTERNAL_METRIC_NAME,
):
assert init_length >= 0
self.init_length = init_length
self.metric_name = metric_name
        # Resource levels appear as string-valued keys in the metrics dicts
        self.max_resource = str(max_resource)
self.reset()
def reset(self):
self.lastrec_max_resource_cases = None

    def _num_max_resource_cases(self, state: TuningJobState):
        # Count trials which have an observation at resource level
        # ``max_resource``
        def is_max_resource(metrics: Dict[str, Any]) -> int:
            return int(self.max_resource in metrics[self.metric_name])

        return sum(is_max_resource(ev.metrics) for ev in state.trials_evaluations)

def __call__(self, state: TuningJobState) -> bool:
if state.num_observed_cases(self.metric_name) < self.init_length:
return False
num_max_resource_cases = self._num_max_resource_cases(state)
if (
self.lastrec_max_resource_cases is None
or num_max_resource_cases > self.lastrec_max_resource_cases
):
self.lastrec_max_resource_cases = num_max_resource_cases
return False
else:
return True
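

# A minimal usage sketch (assumption: ``ev.metrics[INTERNAL_METRIC_NAME]`` is
# a dict keyed by resource level as a string, which is what
# ``_num_max_resource_cases`` relies on; ``_FakeEval`` and ``_FakeState`` are
# hypothetical stand-ins for ``TrialEvaluations`` and ``TuningJobState``):
#
#     class _FakeEval:
#         def __init__(self, metrics):
#             self.metrics = metrics
#
#     class _FakeState:
#         def __init__(self, evals):
#             self.trials_evaluations = evals
#
#         def num_observed_cases(self, metric_name):
#             return sum(len(ev.metrics[metric_name]) for ev in self.trials_evaluations)
#
#     predicate = SkipNoMaxResourcePredicate(init_length=2, max_resource=9)
#     evals = [_FakeEval({INTERNAL_METRIC_NAME: {"1": 0.5, "9": 0.3}})]
#     print(predicate(_FakeState(evals)))  # False: records 1 case at r = 9
#     evals.append(_FakeEval({INTERNAL_METRIC_NAME: {"3": 0.4}}))
#     print(predicate(_FakeState(evals)))  # True: no new case at max_resource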