Source code for syne_tune.optimizer.schedulers.ask_tell_scheduler

from typing import Dict
import datetime

from syne_tune.backend.trial_status import Trial, Status, TrialResult
from syne_tune.optimizer.scheduler import TrialScheduler


[docs] class AskTellScheduler: base_scheduler: TrialScheduler trial_counter: int completed_experiments: Dict[int, TrialResult] def __init__(self, base_scheduler: TrialScheduler): """ Simple interface to use SyneTune schedulers in a custom loop, for example: .. code-block:: python scheduler = AskTellScheduler(base_scheduler=RandomSearch(config_space, metric=metric, mode=mode)) for iter in range(max_iterations): trial_suggestion = scheduler.ask() test_result = target_function(**trial_suggestion.config) scheduler.tell(trial_suggestion, {metric: test_result}) :param base_scheduler: Scheduler to be wrapped """ self.base_scheduler = base_scheduler self.trial_counter = 0 self.completed_experiments = {}
[docs] def ask(self) -> Trial: """ Ask the scheduler for new trial to run :return: Trial to run """ trial_suggestion = self.base_scheduler.suggest() trial = Trial( trial_id=self.trial_counter, config=trial_suggestion.config, creation_time=datetime.datetime.now(), ) self.trial_counter += 1 return trial
[docs] def tell(self, trial: Trial, experiment_result: Dict[str, float]): """ Feed experiment results back to the Scheduler :param trial: Trial that was run :param experiment_result: {metric: value} dictionary with experiment results """ trial_result = trial.add_results( metrics=experiment_result, status=Status.completed, training_end_time=datetime.datetime.now(), ) self.base_scheduler.on_trial_complete(trial=trial, result=experiment_result) self.completed_experiments[trial_result.trial_id] = trial_result
[docs] def best_trial(self, metric: str) -> TrialResult: """ Return the best trial according to the provided metric. :param metric: Metric to use for comparison """ if self.base_scheduler.mode == "max": sign = 1.0 else: sign = -1.0 return max( [value for key, value in self.completed_experiments.items()], key=lambda trial: sign * trial.metrics[metric], )