Source code for syne_tune.optimizer.schedulers.scheduler_searcher

# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Optional, Dict, Any
import logging

from syne_tune.optimizer.schedulers.searchers import BaseSearcher
from syne_tune.optimizer.scheduler import TrialSuggestion, TrialScheduler
from syne_tune.optimizer.schedulers.random_seeds import (
    RandomSeedGenerator,
    generate_random_seed,
)
from syne_tune.backend.trial_status import Trial

logger = logging.getLogger(__name__)


class TrialSchedulerWithSearcher(TrialScheduler):
    """
    Base class for trial schedulers which have a
    :class:`~syne_tune.optimizer.schedulers.searchers.BaseSearcher` member
    ``searcher``. This searcher has a method
    :meth:`~syne_tune.optimizer.schedulers.searchers.BaseSearcher.configure_scheduler`
    which has to be called before the searcher is first used.

    We also collect common code here:

    * Determine ``max_resource_level`` if not explicitly given
    * Master seed, :attr:`random_seed_generator`
    """

    def __init__(self, config_space: Dict[str, Any], **kwargs):
        super().__init__(config_space)
        self._searcher_initialized = False
        # Generator for random seeds
        random_seed = kwargs.get("random_seed")
        if random_seed is None:
            random_seed = generate_random_seed()
        logger.info(f"Master random_seed = {random_seed}")
        self.random_seed_generator = RandomSeedGenerator(random_seed)

    @property
    def searcher(self) -> Optional[BaseSearcher]:
        raise NotImplementedError

    def _initialize_searcher(self):
        """Callback to initialize searcher based on scheduler, if not done already"""
        if not self._searcher_initialized:
            if self.searcher is not None:
                self.searcher.configure_scheduler(self)
            self._searcher_initialized = True

    def suggest(self, trial_id: int) -> Optional[TrialSuggestion]:
        self._initialize_searcher()
        return super().suggest(trial_id)

    def on_trial_error(self, trial: Trial):
        self._initialize_searcher()
        if self.searcher is not None:
            trial_id = str(trial.trial_id)
            self.searcher.evaluation_failed(trial_id)
            if self.searcher.debug_log is not None:
                logger.info(f"trial_id {trial_id}: Evaluation failed!")

    def on_trial_complete(self, trial: Trial, result: Dict[str, Any]):
        self._initialize_searcher()
        if self.searcher is not None:
            config = self._preprocess_config(trial.config)
            self.searcher.on_trial_result(
                str(trial.trial_id), config, result=result, update=True
            )

    def _infer_max_resource_level_getval(self, name):
        if name in self.config_space and name not in self._hyperparameter_keys:
            return self.config_space[name]
        else:
            return None

    def _infer_max_resource_level(
        self, max_resource_level: Optional[int], max_resource_attr: Optional[str]
    ):
        """Infer ``max_resource_level`` if not explicitly given.

        :param max_resource_level: Value explicitly provided, or None
        :param max_resource_attr: Name of max resource attribute in
            ``self.config_space`` (optional)
        :return: Inferred value for ``max_resource_level``
        """
        inferred_max_t = None
        names = ("epochs", "max_t", "max_epochs")
        if max_resource_attr is not None:
            names = (max_resource_attr,) + names
        for name in names:
            inferred_max_t = self._infer_max_resource_level_getval(name)
            if inferred_max_t is not None:
                break
        if max_resource_level is not None:
            if inferred_max_t is not None and max_resource_level != inferred_max_t:
                logger.warning(
                    f"max_resource_level = {max_resource_level} is different "
                    f"from the value {inferred_max_t} inferred from "
                    "config_space"
                )
        else:
            # It is OK if max_resource_level cannot be inferred
            if inferred_max_t is not None:
                logger.info(
                    f"max_resource_level = {inferred_max_t}, as inferred "
                    "from config_space"
                )
                max_resource_level = inferred_max_t
        return max_resource_level
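
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library source): a minimal subclass
# showing how ``TrialSchedulerWithSearcher`` is meant to be used. The class
# ``MyScheduler`` and its ``searcher`` constructor argument are hypothetical
# names chosen for this example; the remaining abstract methods of
# ``TrialScheduler`` (e.g. how new configurations are actually suggested) are
# omitted, so this is a sketch of the contract rather than a working scheduler.


class MyScheduler(TrialSchedulerWithSearcher):
    def __init__(self, config_space: Dict[str, Any], searcher: BaseSearcher, **kwargs):
        super().__init__(config_space, **kwargs)
        self._searcher = searcher
        # If not given explicitly, ``max_resource_level`` is inferred from
        # ``config_space`` entries named "epochs", "max_t", or "max_epochs"
        self._max_resource_level = self._infer_max_resource_level(
            max_resource_level=kwargs.get("max_resource_level"),
            max_resource_attr=kwargs.get("max_resource_attr"),
        )

    @property
    def searcher(self) -> Optional[BaseSearcher]:
        # The base class calls ``configure_scheduler`` on this searcher the
        # first time ``suggest``, ``on_trial_error``, or ``on_trial_complete``
        # is invoked
        return self._searcher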