Source code for syne_tune.optimizer.schedulers.searchers.searcher_factory

# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import logging

from syne_tune.try_import import (
    try_import_gpsearchers_message,
    try_import_kde_message,
    try_import_bore_message,
    try_import_botorch_message,
)
from syne_tune.optimizer.schedulers.searchers import (
    BaseSearcher,
    RandomSearcher,
    GridSearcher,
)

# Restrict ``from <module> import *`` to the factory function only
__all__ = ["searcher_factory"]

# Module-level logger, named after this module
logger = logging.getLogger(__name__)


# Searcher names which appear in both the FIFO and the Hyperband list below
_SEARCHERS_FIFO_AND_HYPERBAND = {
    "random",
    "grid",
    "kde",
    "bore",
    "cqr",
    "bayesopt",
    "bayesopt_cost",
}


# Searcher names listed as supported for FIFO schedulers: the shared ones
# plus those only listed for FIFO
SUPPORTED_SEARCHERS_FIFO = _SEARCHERS_FIFO_AND_HYPERBAND | {
    "botorch",
    "bayesopt_constrained",
}


# Searcher names listed as supported for Hyperband schedulers: the shared
# ones plus those only listed for Hyperband
SUPPORTED_SEARCHERS_HYPERBAND = _SEARCHERS_FIFO_AND_HYPERBAND | {
    "hypertune",
    "dyhpo",
}


# Scheduler names (all of the form ``hyperband_<variant>``) against which
# multi-fidelity searchers are validated in ``searcher_factory``
_OUR_MULTIFIDELITY_SCHEDULERS = {
    f"hyperband_{variant}"
    for variant in (
        "stopping",
        "promotion",
        "cost_promotion",
        "pasha",
        "synchronous",
    )
}


def searcher_factory(searcher_name: str, **kwargs) -> BaseSearcher:
    """Factory for searcher objects

    This function creates searcher objects from string argument name and
    additional kwargs. It is typically called in the constructor of a
    scheduler (see :class:`~syne_tune.optimizer.schedulers.FIFOScheduler`),
    which provides most of the required ``kwargs``.

    The searcher class is selected from ``searcher_name`` and, for several
    searchers, from ``kwargs["scheduler"]`` (``"fifo"`` versus anything else).
    Optional dependencies are imported lazily inside the respective branch,
    so that only the packages needed for the requested searcher are required.

    Validation failures are raised explicitly (rather than via ``assert``
    statements), so they are still enforced when Python runs with ``-O``.

    :param searcher_name: Value of ``searcher`` argument to scheduler (see
        :class:`~syne_tune.optimizer.schedulers.FIFOScheduler`)
    :param kwargs: Argument to
        :class:`~syne_tune.optimizer.schedulers.searchers.BaseSearcher`
        constructor. The entries ``"scheduler"``, ``"model"`` and
        ``"gp_resource_kernel"`` are also read here in order to select and
        validate the searcher class
    :return: New searcher object
    :raises AssertionError: If ``searcher_name`` is not supported, or if the
        selected searcher restricts the scheduler and ``kwargs["scheduler"]``
        is missing or not among the supported ones
    :raises ImportError: If dependencies of the requested searcher are not
        installed (an installation hint is logged before re-raising)
    """
    # ``None`` means: the selected searcher does not restrict the scheduler
    supported_schedulers = None
    scheduler = kwargs.get("scheduler")
    model = kwargs.get("model", "gp_multitask")
    if searcher_name == "random":
        searcher_cls = RandomSearcher
    elif searcher_name == "grid":
        searcher_cls = GridSearcher
    elif searcher_name == "kde":
        try:
            from syne_tune.optimizer.schedulers.searchers.kde import (
                KernelDensityEstimator,
                MultiFidelityKernelDensityEstimator,
            )
        except ImportError:
            logger.info(try_import_kde_message())
            raise

        if scheduler == "fifo":
            searcher_cls = KernelDensityEstimator
        else:
            supported_schedulers = _OUR_MULTIFIDELITY_SCHEDULERS
            searcher_cls = MultiFidelityKernelDensityEstimator
    elif searcher_name == "bore":
        try:
            from syne_tune.optimizer.schedulers.searchers.bore import (
                Bore,
                MultiFidelityBore,
            )
        except ImportError:
            logger.info(try_import_bore_message())
            raise

        if scheduler == "fifo":
            searcher_cls = Bore
        else:
            supported_schedulers = _OUR_MULTIFIDELITY_SCHEDULERS
            searcher_cls = MultiFidelityBore
    elif searcher_name == "cqr":
        try:
            from syne_tune.optimizer.schedulers.searchers.conformal.surrogate_searcher import (
                SurrogateSearcher,
            )
        except ImportError:
            # NOTE(review): this logs the BORE installation hint for the
            # "cqr" searcher -- confirm this is intended
            logger.info(try_import_bore_message())
            raise

        searcher_cls = SurrogateSearcher
    elif searcher_name == "botorch":
        try:
            from syne_tune.optimizer.schedulers.searchers.botorch import (
                BoTorchSearcher,
            )
        except ImportError:
            logger.info(try_import_botorch_message())
            raise

        searcher_cls = BoTorchSearcher
        supported_schedulers = {"fifo"}
    else:
        gp_searchers = {
            "bayesopt",
            "bayesopt_constrained",
            "bayesopt_cost",
            "hypertune",
            "dyhpo",
        }
        # Explicit raise instead of ``assert``: under ``python -O`` an
        # ``assert`` is removed, and an unknown name would silently end up
        # in the final ``bayesopt_cost`` branch below
        if searcher_name not in gp_searchers:
            raise AssertionError(f"searcher '{searcher_name}' is not supported")
        try:
            from syne_tune.optimizer.schedulers.searchers import (
                GPFIFOSearcher,
                GPMultiFidelitySearcher,
            )
            from syne_tune.optimizer.schedulers.searchers.constrained import (
                ConstrainedGPFIFOSearcher,
            )
            from syne_tune.optimizer.schedulers.searchers.cost_aware import (
                CostAwareGPFIFOSearcher,
                CostAwareGPMultiFidelitySearcher,
            )
            from syne_tune.optimizer.schedulers.searchers.hypertune import (
                HyperTuneSearcher,
            )
            from syne_tune.optimizer.schedulers.searchers.dyhpo import (
                DynamicHPOSearcher,
            )
        except ImportError:
            logger.info(try_import_gpsearchers_message())
            raise

        if searcher_name == "bayesopt":
            if scheduler == "fifo":
                searcher_cls = GPFIFOSearcher
            else:
                supported_schedulers = _OUR_MULTIFIDELITY_SCHEDULERS
                if (
                    model == "gp_multitask"
                    and kwargs.get("gp_resource_kernel") == "freeze-thaw"
                ):
                    logger.warning(
                        "You are combining model = gp_multitask with "
                        "gp_resource_kernel = freeze-thaw. This is mainly "
                        "for debug purposes. The same surrogate model is "
                        "obtained with model = gp_expdecay, but computations "
                        "are faster then."
                    )
                searcher_cls = GPMultiFidelitySearcher
        elif searcher_name == "hypertune":
            supported_schedulers = _OUR_MULTIFIDELITY_SCHEDULERS
            searcher_cls = HyperTuneSearcher
        elif searcher_name == "bayesopt_constrained":
            supported_schedulers = {"fifo"}
            searcher_cls = ConstrainedGPFIFOSearcher
        elif searcher_name == "dyhpo":
            supported_schedulers = {"hyperband_dyhpo"}
            searcher_cls = DynamicHPOSearcher
        else:  # bayesopt_cost
            if scheduler == "fifo":
                searcher_cls = CostAwareGPFIFOSearcher
            else:
                supported_schedulers = _OUR_MULTIFIDELITY_SCHEDULERS
                searcher_cls = CostAwareGPMultiFidelitySearcher

    # Searchers which restrict the scheduler are validated here. Raised
    # explicitly so the check also runs under ``python -O``
    if supported_schedulers is not None:
        if scheduler is None:
            raise AssertionError("Scheduler must set search_options['scheduler']")
        if scheduler not in supported_schedulers:
            raise AssertionError(
                f"Searcher '{searcher_name}' only works with schedulers "
                + f"{supported_schedulers} (not with '{scheduler}')"
            )
    searcher = searcher_cls(**kwargs)
    return searcher