Source code for syne_tune.stopping_criterion

# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import logging
import numpy as np

from dataclasses import dataclass
from typing import Optional, Dict

from syne_tune.tuning_status import TuningStatus

logger = logging.getLogger(__name__)


@dataclass
class StoppingCriterion:
    """
    Stopping criterion that can be used in a Tuner, for instance
    :code:`Tuner(stop_criterion=StoppingCriterion(max_wallclock_time=3600), ...)`.
    If several arguments are used, the combined criterion is true whenever
    one of the atomic criteria is true.

    In principle, ``stop_criterion`` for ``Tuner`` can be any lambda function,
    but this class should be used with remote launching in order to ensure
    proper serialization.

    :param max_wallclock_time: Stop once this wallclock time is reached
    :param max_num_evaluations: Stop once more than this number of metric
        records have been reported
    :param max_num_trials_started: Stop once more than this number of trials
        have been started
    :param max_num_trials_completed: Stop once more than this number of trials
        have been completed. This does not include trials which were stopped
        or failed
    :param max_cost: Stop once the total cost of all evaluations exceeds this
        value
    :param max_num_trials_finished: Stop once more than this number of trials
        have finished (i.e., completed, stopped, failed, or stopping)
    :param min_metric_value: Dictionary with thresholds for selected metrics.
        Stop once an evaluation reports a metric value below a threshold
    :param max_metric_value: Dictionary with thresholds for selected metrics.
        Stop once an evaluation reports a metric value above a threshold
    """

    max_wallclock_time: Optional[float] = None
    max_num_evaluations: Optional[int] = None
    max_num_trials_started: Optional[int] = None
    max_num_trials_completed: Optional[int] = None
    max_cost: Optional[float] = None
    max_num_trials_finished: Optional[int] = None
    # minimum value for metrics, any value below one of these thresholds
    # triggers a stop
    min_metric_value: Optional[Dict[str, float]] = None
    # maximum value for metrics, any value above one of these thresholds
    # triggers a stop
    max_metric_value: Optional[Dict[str, float]] = None

    # TODO: We should have unit tests for all these cases.
    def __call__(self, status: TuningStatus) -> bool:
        if (
            self.max_wallclock_time is not None
            and status.wallclock_time > self.max_wallclock_time
        ):
            logger.info(
                f"reaching max wallclock time ({self.max_wallclock_time}), stopping there."
            )
            return True
        if (
            self.max_num_trials_started is not None
            and status.num_trials_started > self.max_num_trials_started
        ):
            logger.info(
                f"reaching max number of trials started ({self.max_num_trials_started + 1}), stopping there."
            )
            return True
        if (
            self.max_num_trials_completed is not None
            and status.num_trials_completed > self.max_num_trials_completed
        ):
            logger.info(
                f"reaching max number of trials completed ({self.max_num_trials_completed + 1}), stopping there."
            )
            return True
        if (
            self.max_num_trials_finished is not None
            and status.num_trials_finished > self.max_num_trials_finished
        ):
            logger.info(
                f"reaching max number of trials finished ({self.max_num_trials_finished + 1}), stopping there."
            )
            return True
        if self.max_cost is not None and status.cost > self.max_cost:
            logger.info(f"reaching max cost ({self.max_cost}), stopping there.")
            return True
        if (
            self.max_num_evaluations is not None
            and status.overall_metric_statistics.count > self.max_num_evaluations
        ):
            logger.info(
                f"reaching {status.overall_metric_statistics.count + 1} evaluations, stopping there."
            )
            return True
        if (
            self.max_metric_value is not None
            and status.overall_metric_statistics.count > 0
        ):
            max_metrics_observed = status.overall_metric_statistics.max_metrics
            for metric, max_metric_accepted in self.max_metric_value.items():
                if (
                    metric in max_metrics_observed
                    and max_metrics_observed[metric] > max_metric_accepted
                ):
                    logger.info(
                        f"found {metric} with value ({max_metrics_observed[metric]}), "
                        f"above the provided threshold {max_metric_accepted}, stopping there."
                    )
                    return True
        if (
            self.min_metric_value is not None
            and status.overall_metric_statistics.count > 0
        ):
            min_metrics_observed = status.overall_metric_statistics.min_metrics
            for metric, min_metric_accepted in self.min_metric_value.items():
                if (
                    metric in min_metrics_observed
                    and min_metrics_observed[metric] < min_metric_accepted
                ):
                    logger.info(
                        f"found {metric} with value ({min_metrics_observed[metric]}), "
                        f"below the provided threshold {min_metric_accepted}, stopping there."
                    )
                    return True
        return False


class PlateauStopper:
    """
    Stops the experiment when the best metric value found across trials has
    plateaued: the standard deviation of the ``num_trials`` most recent best
    values stays below ``std`` for more than ``patience`` consecutive
    iterations. This code is inspired by Ray Tune.
    """

    def __init__(
        self,
        metric: str,
        std: float = 0.001,
        num_trials: int = 10,
        mode: str = "min",
        patience: int = 0,
    ):
        """
        :param metric: The metric to be monitored.
        :param std: The metric is considered to have plateaued once the
            standard deviation of the most recent best values falls below
            this threshold.
        :param num_trials: The number of consecutive best values over which
            the standard deviation is computed.
        :param mode: The mode to select the top results ("min" or "max").
        :param patience: Number of consecutive plateaued iterations to wait
            for before stopping.
        """
        if mode not in ("min", "max"):
            raise ValueError("The mode parameter can only be either min or max.")
        if not isinstance(num_trials, int) or num_trials <= 1:
            raise ValueError(
                "Top results to consider must be an integer greater than one."
            )
        if not isinstance(patience, int) or patience < 0:
            raise ValueError("Patience must be a non-negative integer.")
        if not isinstance(std, float) or std <= 0:
            raise ValueError(
                "The standard deviation must be a strictly positive float number."
            )
        self._mode = mode
        self._metric = metric
        self._patience = patience
        self._iterations = 0
        self._std = std
        self._num_trials = num_trials
        # maximization is turned into minimization by flipping the sign
        if self._mode == "min":
            self.multiplier = 1
        else:
            self.multiplier = -1

    def __call__(self, status: TuningStatus) -> bool:
        """Return a boolean representing whether the tuning has to stop."""
        if status.num_trials_finished == 0:
            return False
        trials = status.trial_rows
        # trajectory of the best metric value seen so far, with one entry
        # per trial which reported the metric
        trajectory = []
        curr_best = None
        for ti in trials.values():
            if self._metric in ti:
                y = self.multiplier * ti[self._metric]
                if curr_best is None or y < curr_best:
                    curr_best = y
                trajectory.append(curr_best)
        top_values = trajectory[-self._num_trials :]
        # the metric has plateaued if the last ``num_trials`` best values
        # all lie within ``std`` of each other
        has_plateaued = (
            len(top_values) == self._num_trials and np.std(top_values) <= self._std
        )
        if has_plateaued:
            # increment the counter of consecutive plateaued iterations
            self._iterations += 1
        else:
            # otherwise, reset the counter
            self._iterations = 0
        # stop only once the plateau has persisted for at least ``patience``
        # iterations
        return has_plateaued and self._iterations >= self._patience
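

# Usage sketch for ``PlateauStopper`` as a ``stop_criterion`` (assumptions:
# ``trial_backend`` and ``scheduler`` are placeholders for objects constructed
# elsewhere, as in the Syne Tune examples, so this snippet is left commented
# out rather than executed at import time):
#
#     from syne_tune import Tuner
#
#     stop_criterion = PlateauStopper(
#         metric="validation_loss", std=1e-3, num_trials=10, mode="min"
#     )
#     tuner = Tuner(
#         trial_backend=trial_backend,
#         scheduler=scheduler,
#         stop_criterion=stop_criterion,
#         n_workers=4,
#     )
#     tuner.run()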