Source code for syne_tune.optimizer.schedulers.ray_scheduler

# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://aws.amazon.com/apache2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Dict, Optional, List
import logging

from syne_tune.optimizer.scheduler import TrialScheduler, TrialSuggestion
from syne_tune.backend.trial_status import Trial
import syne_tune.config_space as sp

# Logger for this module, named after the module path.
logger = logging.getLogger(name=__name__)

# Public API: only the scheduler class is exported by ``import *``.
__all__ = ["RayTuneScheduler"]

class RayTuneScheduler(TrialScheduler):
    """
    Allow using Ray scheduler and searcher. Any searcher/scheduler should
    work, except such which need access to ``TrialRunner`` (e.g., PBT), this
    feature is not implemented in Syne Tune.

    If ``ray_searcher`` is not given (defaults to random searcher), initial
    configurations to evaluate can be passed in ``points_to_evaluate``. If
    ``ray_searcher`` is given, this argument is ignored (needs to be passed
    to ``ray_searcher`` at construction). Note: Use
    :func:`~syne_tune.optimizer.schedulers.searchers.impute_points_to_evaluate`
    in order to preprocess ``points_to_evaluate`` specified by the user or the
    benchmark.

    :param config_space: Configuration space
    :param ray_scheduler: Ray scheduler, defaults to FIFO scheduler
    :param ray_searcher: Ray searcher, defaults to random search
    :param points_to_evaluate: See above
    """

    # Imported in the class body, so they are available as class attributes
    # (``self.RT_FIFOScheduler``) and can be used as the base class of
    # ``RandomSearch`` and in the ``__init__`` annotations below.
    from ray.tune.schedulers import FIFOScheduler as RT_FIFOScheduler
    from ray.tune.search import Searcher as RT_Searcher

    class RandomSearch(RT_Searcher):
        """
        Ray searcher which first returns the configurations in
        ``points_to_evaluate`` (in the given order), and afterwards samples
        configurations from ``config_space``.

        :param config_space: Configuration space, in Ray Tune format
        :param points_to_evaluate: Initial configurations, may be ``None`` or
            empty. The list is copied, so the caller's list is not consumed.
        :param mode: Passed to the Ray ``Searcher`` base class
        """

        def __init__(
            self,
            config_space: Dict,
            points_to_evaluate: Optional[List[Dict]],
            mode: str,
        ):
            super().__init__(mode=mode)
            self.config_space = config_space
            # Copy: configs are popped from this list one by one, which must
            # not mutate the list owned by the caller
            self._points_to_evaluate = (
                list(points_to_evaluate) if points_to_evaluate else []
            )

        def _next_initial_config(self) -> Optional[Dict]:
            """
            :return: Next initial configuration (removed from the queue), or
                ``None`` if there are no more
            """
            if self._points_to_evaluate:
                return self._points_to_evaluate.pop(0)
            return None  # No more initial configs

        def suggest(self, trial_id: str) -> Optional[Dict]:
            """
            :param trial_id: ID of the trial to suggest a config for (unused)
            :return: Next initial configuration if any remain, otherwise a
                random configuration: entries of ``config_space`` with a
                ``sample`` method are sampled, all others are copied as
                constants
            """
            config = self._next_initial_config()
            if config is None:
                config = {
                    k: v.sample() if hasattr(v, "sample") else v
                    for k, v in self.config_space.items()
                }
            return config

        def on_trial_complete(
            self, trial_id: str, result: Optional[Dict] = None, error: bool = False
        ):
            # Random search does not use trial results
            pass

    def __init__(
        self,
        config_space: Dict,
        ray_scheduler=None,
        ray_searcher: Optional[RT_Searcher] = None,
        points_to_evaluate: Optional[List[Dict]] = None,
    ):
        super().__init__(config_space)
        if ray_scheduler is None:
            ray_scheduler = self.RT_FIFOScheduler()
        self.scheduler = ray_scheduler

        # Mode: taken from the searcher if given, otherwise from the
        # scheduler's private ``_mode`` attribute if present, else "min"
        if ray_searcher is not None:
            self.mode = ray_searcher.mode
        elif hasattr(ray_scheduler, "_mode"):
            self.mode = ray_scheduler._mode
        else:
            self.mode = "min"

        if ray_searcher is None:
            ray_searcher = self.RandomSearch(
                config_space=self.convert_config_space(config_space),
                points_to_evaluate=points_to_evaluate,
                mode=self.mode,
            )
        elif points_to_evaluate is not None:
            logger.warning(
                "points_to_evaluate specified here will not be used. Pass this"
                " argument when creating ray_searcher"
            )
        self.searcher = ray_searcher

        # TODO: this one is not implemented yet, PBT would require it
        self.trial_runner_wrapper = None

        # Either metric may be unset (``None``); only compare if both are set
        if self.searcher.metric is not None and self.scheduler.metric is not None:
            assert (
                self.scheduler.metric == self.searcher.metric
            ), "searcher and scheduler must have the same metric."

    def on_trial_add(self, trial: Trial):
        self.scheduler.on_trial_add(
            trial_runner=self.trial_runner_wrapper,
            trial=trial,
        )

    def on_trial_error(self, trial: Trial):
        self.scheduler.on_trial_error(
            trial_runner=self.trial_runner_wrapper,
            trial=trial,
        )

    def on_trial_result(self, trial: Trial, result: Dict) -> str:
        """
        Forwards the result to the Ray searcher and scheduler.

        :return: Decision returned by the Ray scheduler's ``on_trial_result``
        """
        self._check_valid_result(result=result)
        self.searcher.on_trial_result(trial_id=str(trial.trial_id), result=result)
        return self.scheduler.on_trial_result(
            trial_runner=self.trial_runner_wrapper, trial=trial, result=result
        )

    def on_trial_complete(self, trial: Trial, result: Dict):
        self._check_valid_result(result=result)
        self.searcher.on_trial_complete(trial_id=str(trial.trial_id), result=result)
        self.scheduler.on_trial_complete(
            trial_runner=self.trial_runner_wrapper, trial=trial, result=result
        )

    def _check_valid_result(self, result: Dict):
        """
        Asserts that every metric named by :meth:`metric_names` is a key of
        ``result``.
        """
        for m in self.metric_names():
            # The Ray scheduler may have no metric set (``metric`` is
            # ``None``); then there is nothing to check for
            if m is None:
                continue
            assert m in result, (
                f"metric {m} is not present in reported results {result},"
                f" the metrics present when calling ``report(...)`` in your training functions should"
                f" be identical to the ones passed as metrics/time_attr to the scheduler and searcher"
            )

    def on_trial_remove(self, trial: Trial):
        return self.scheduler.on_trial_remove(
            trial_runner=self.trial_runner_wrapper, trial=trial
        )

    def _suggest(self, trial_id: int) -> Optional[TrialSuggestion]:
        """
        :param trial_id: ID of the new trial
        :return: Suggestion to start a new trial with the configuration from
            the Ray searcher, or ``None`` if the searcher returns no config
        """
        config = self.searcher.suggest(trial_id=str(trial_id))
        if config is None:
            # Do not start a trial without a configuration
            return None
        return TrialSuggestion.start_suggestion(config)

    def metric_names(self) -> List[str]:
        return [self.scheduler.metric]

    def metric_mode(self) -> str:
        return self.mode

    @staticmethod
    def convert_config_space(config_space):
        """
        Converts config_space from our type to the one of Ray Tune.

        Note: ``randint(lower, upper)`` in Ray Tune has exclusive ``upper``, while
        this is inclusive for us. On the other hand, ``lograndint(lower, upper)``
        has inclusive ``upper`` in Ray Tune as well.

        :param config_space: Configuration space
        :return: ``config_space`` converted into Ray Tune type
        """
        import ray.tune.search.sample as ray_sp

        ray_config_space = dict()
        for name, hp_range in config_space.items():
            assert not isinstance(
                hp_range, sp.FiniteRange
            ), f"'{name}' has type FiniteRange, not supported by Ray Tune"
            if isinstance(hp_range, sp.Domain):
                cls_mapping = {
                    sp.Integer: ray_sp.Integer,
                    sp.Float: ray_sp.Float,
                    sp.LogUniform: ray_sp.LogUniform,
                    sp.Categorical: ray_sp.Categorical,
                    sp.Normal: ray_sp.Normal,
                }
                sampler_mapping = {
                    sp.Integer._Uniform: ray_sp.Integer._Uniform,
                    sp.Integer._LogUniform: ray_sp.Integer._LogUniform,
                    sp.Float._Uniform: ray_sp.Float._Uniform,
                    sp.Float._LogUniform: ray_sp.Float._LogUniform,
                    sp.Categorical._Uniform: ray_sp.Categorical._Uniform,
                    sp.Float._Normal: ray_sp.Float._Normal,
                }
                ray_cls = cls_mapping[type(hp_range)]
                # Domain constructor arguments are the domain's attributes,
                # except for the sampler, which is set separately below
                domain_kwargs = {
                    k: v for k, v in hp_range.__dict__.items() if k != "sampler"
                }
                # Note: ``tune.randint`` has exclusive upper while we have inclusive.
                # NOTE(review): this is applied to every ``sp.Integer``,
                # including log-sampled ones; confirm against the docstring
                # remark on ``lograndint``.
                if isinstance(hp_range, sp.Integer):
                    domain_kwargs["upper"] = domain_kwargs["upper"] + 1
                ray_domain = ray_cls(**domain_kwargs)
                sampler = hp_range.get_sampler()
                ray_sampler = sampler_mapping[type(sampler)](**sampler.__dict__)
                ray_domain.set_sampler(ray_sampler)
                ray_config_space[name] = ray_domain
            else:
                # Constants are passed through unchanged
                ray_config_space[name] = hp_range
        return ray_config_space