Source code for syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.tuning_job_state

# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import List, Dict, Optional, Tuple

from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.common import (
    TrialEvaluations,
    PendingEvaluation,
    MetricValues,
    INTERNAL_METRIC_NAME,
)
from syne_tune.optimizer.schedulers.searchers.utils.common import (
    Configuration,
    ConfigurationFilter,
)
from syne_tune.optimizer.schedulers.searchers.utils.hp_ranges import (
    HyperparameterRanges,
)


class TuningJobState:
    """
    Collects all data determining the state of a tuning experiment. Trials
    are indexed by ``trial_id``. The configurations associated with trials
    are listed in ``config_for_trial``. ``trials_evaluations`` contains
    observations, ``failed_trials`` lists trials for which evaluations have
    failed, and ``pending_evaluations`` lists trials for which observations
    are pending.

    ``trials_evaluations`` may store values for different metrics in each
    record, and each such value may be a dict (see
    :class:`TrialEvaluations`). For example, for multi-fidelity schedulers,
    :code:`trials_evaluations[i].metrics[k][str(r)]` is the value for metric
    ``k`` and trial :code:`trials_evaluations[i].trial_id` observed at
    resource level ``r``.
    """

    def __init__(
        self,
        hp_ranges: HyperparameterRanges,
        config_for_trial: Dict[str, Configuration],
        trials_evaluations: List[TrialEvaluations],
        failed_trials: Optional[List[str]] = None,
        pending_evaluations: Optional[List[PendingEvaluation]] = None,
    ):
        if failed_trials is None:
            failed_trials = []
        if pending_evaluations is None:
            pending_evaluations = []
        self._check_trial_ids(
            config_for_trial, trials_evaluations, failed_trials, pending_evaluations
        )
        self.hp_ranges = hp_ranges
        self.config_for_trial = config_for_trial
        self.trials_evaluations = trials_evaluations
        self.failed_trials = failed_trials
        self.pending_evaluations = pending_evaluations

    @staticmethod
    def _check_all_string(trial_ids: List[str], name: str):
        assert all(
            isinstance(x, str) for x in trial_ids
        ), f"trial_ids in {name} contain non-string values:\n{trial_ids}"

    @staticmethod
    def _check_trial_ids(
        config_for_trial, trials_evaluations, failed_trials, pending_evaluations
    ):
        observed_trials = [x.trial_id for x in trials_evaluations]
        pending_trials = [x.trial_id for x in pending_evaluations]
        TuningJobState._check_all_string(observed_trials, "trials_evaluations")
        TuningJobState._check_all_string(failed_trials, "failed_trials")
        TuningJobState._check_all_string(pending_trials, "pending_evaluations")
        trial_ids = set(observed_trials + failed_trials + pending_trials)
        for trial_id in trial_ids:
            assert (
                trial_id in config_for_trial
            ), f"trial_id {trial_id} not contained in config_for_trial"
    @staticmethod
    def empty_state(hp_ranges: HyperparameterRanges) -> "TuningJobState":
        """
        Returns a new state without any trials.
        """
        return TuningJobState(
            hp_ranges=hp_ranges,
            config_for_trial=dict(),
            trials_evaluations=[],
            failed_trials=[],
            pending_evaluations=[],
        )
    def _find_labeled(self, trial_id: str) -> int:
        # Returns position of ``trial_id`` in ``trials_evaluations``, or -1
        # if the trial has no observations yet
        try:
            return next(
                i
                for i, x in enumerate(self.trials_evaluations)
                if x.trial_id == trial_id
            )
        except StopIteration:
            return -1

    def _find_pending(self, trial_id: str, resource: Optional[int] = None) -> int:
        # Returns position of the pending evaluation for ``trial_id`` (and
        # ``resource``, in the multi-fidelity case), or -1 if there is none
        try:
            return next(
                i
                for i, x in enumerate(self.pending_evaluations)
                if x.trial_id == trial_id and x.resource == resource
            )
        except StopIteration:
            return -1

    def _register_config_for_trial(
        self, trial_id: str, config: Optional[Configuration] = None
    ):
        if config is None:
            assert trial_id in self.config_for_trial, (
                f"trial_id = {trial_id} not yet registered in "
                + "config_for_trial, so config must be given"
            )
        elif trial_id not in self.config_for_trial:
            self.config_for_trial[trial_id] = config.copy()
    def metrics_for_trial(
        self, trial_id: str, config: Optional[Configuration] = None
    ) -> MetricValues:
        """
        Helper for inserting a new entry into ``trials_evaluations``. If
        ``trial_id`` is already contained there, the corresponding
        ``eval.metrics`` is returned. Otherwise, a new entry ``new_eval`` is
        appended to ``trials_evaluations`` and its ``new_eval.metrics`` is
        returned (an empty ``dict``). In the latter case, ``config`` needs to
        be passed, because ``trial_id`` may not yet feature in
        ``config_for_trial``.
        """
        # NOTE: If ``trial_id`` exists in ``config_for_trial`` and ``config``
        # is given, we do not check that ``config`` is correct. In fact, we
        # ignore ``config`` in this case.
        self._register_config_for_trial(trial_id, config)
        pos = self._find_labeled(trial_id)
        if pos != -1:
            metrics = self.trials_evaluations[pos].metrics
        else:
            # New entry
            metrics = dict()
            new_eval = TrialEvaluations(trial_id=trial_id, metrics=metrics)
            self.trials_evaluations.append(new_eval)
        return metrics
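    # Example (sketch, not part of the original source): given the
    # dict-valued metric layout described in the class docstring, one way a
    # caller can record a multi-fidelity observation for trial "0" at
    # resource level ``r`` is
    #
    #     metrics = state.metrics_for_trial("0", config=config)
    #     metrics.setdefault(INTERNAL_METRIC_NAME, dict())[str(r)] = val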
    def num_observed_cases(
        self, metric_name: str = INTERNAL_METRIC_NAME, resource: Optional[int] = None
    ) -> int:
        """
        Counts the number of observations for metric ``metric_name``.

        :param metric_name: Defaults to :const:`INTERNAL_METRIC_NAME`
        :param resource: In the multi-fidelity case, we only count
            observations at this resource level
        :return: Number of observations
        """
        return sum(
            ev.num_cases(metric_name, resource) for ev in self.trials_evaluations
        )
    def observed_data_for_metric(
        self,
        metric_name: str = INTERNAL_METRIC_NAME,
        resource_attr_name: Optional[str] = None,
    ) -> Tuple[List[Configuration], List[float]]:
        """
        Extracts datapoints from ``trials_evaluations`` for the metric
        ``metric_name``, in the form of a list of configs and a list of
        metric values. If ``metric_name`` is a dict-valued metric, the dict
        keys must be resource values, and the returned configs are extended
        by the resource attribute. Its name can be passed in
        ``resource_attr_name`` (if not given, it is obtained from
        ``hp_ranges``, provided the latter is extended).

        Note: This implements the default behaviour, namely to return
        extended configs for dict-valued metrics, which also requires
        ``hp_ranges`` to be extended. This is not correct for some specific
        multi-fidelity surrogate models, which should access the data
        directly.

        :param metric_name: Name of metric to extract data for
        :param resource_attr_name: Name of resource attribute, optional
        :return: configs, metric_values
        """
        if resource_attr_name is None:
            resource_attr_name = self.hp_ranges.name_last_pos
        configs = []
        metric_values = []
        for ev in self.trials_evaluations:
            config = self.config_for_trial[ev.trial_id]
            metric_entry = ev.metrics.get(metric_name)
            if metric_entry is not None:
                if isinstance(metric_entry, dict):
                    assert resource_attr_name is not None, (
                        "Need resource_attr_name for dict-valued metric "
                        + metric_name
                    )
                    for resource, metric_val in metric_entry.items():
                        config_ext = dict(
                            config, **{resource_attr_name: int(resource)}
                        )
                        configs.append(config_ext)
                        metric_values.append(metric_val)
                else:
                    configs.append(config)
                    metric_values.append(metric_entry)
        return configs, metric_values
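    # Example (sketch, not part of the original source): with a dict-valued
    # entry ``{"1": 0.9, "3": 0.7}`` for trial "0" and
    # ``resource_attr_name="epoch"``, this returns the two extended configs
    # ``dict(config_0, epoch=1)`` and ``dict(config_0, epoch=3)``, paired
    # with the metric values ``[0.9, 0.7]``.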
    def is_pending(self, trial_id: str, resource: Optional[int] = None) -> bool:
        """
        Checks whether ``trial_id`` has a pending evaluation (at resource
        level ``resource``, if given).
        """
        return self._find_pending(trial_id, resource) != -1
    def is_labeled(
        self,
        trial_id: str,
        metric_name: str = INTERNAL_METRIC_NAME,
        resource: Optional[int] = None,
    ) -> bool:
        """
        Checks whether ``trial_id`` has observed data under ``metric_name``.
        If ``resource`` is given, the observation must be at that resource
        level.
        """
        pos = self._find_labeled(trial_id)
        result = False
        if pos != -1:
            metric_entry = self.trials_evaluations[pos].metrics.get(metric_name)
            if metric_entry is not None:
                if resource is None:
                    result = True
                elif isinstance(metric_entry, dict):
                    result = str(resource) in metric_entry
        return result
    def append_pending(
        self,
        trial_id: str,
        config: Optional[Configuration] = None,
        resource: Optional[int] = None,
    ):
        """
        Appends a new pending evaluation. If the trial has not been
        registered here, ``config`` must be given; otherwise, it is ignored.
        """
        self._register_config_for_trial(trial_id, config)
        assert not self.is_pending(trial_id, resource)
        self.pending_evaluations.append(
            PendingEvaluation(trial_id=trial_id, resource=resource)
        )
    def remove_pending(self, trial_id: str, resource: Optional[int] = None) -> bool:
        """
        Removes the pending evaluation for ``trial_id`` (and ``resource``, if
        given). Returns ``True`` if such an evaluation existed.
        """
        pos = self._find_pending(trial_id, resource)
        if pos != -1:
            self.pending_evaluations.pop(pos)
            return True
        else:
            return False
    def pending_configurations(
        self, resource_attr_name: Optional[str] = None
    ) -> List[Configuration]:
        """
        Returns the list of configurations corresponding to pending
        evaluations. If the latter have resource values, the configs are
        extended by the resource attribute.
        """
        if resource_attr_name is None:
            resource_attr_name = self.hp_ranges.name_last_pos
        configs = []
        for pend_eval in self.pending_evaluations:
            config = self.config_for_trial[pend_eval.trial_id]
            resource = pend_eval.resource
            if resource is not None:
                assert (
                    resource_attr_name is not None
                ), "Need resource_attr_name, or hp_ranges to be extended"
                config = dict(config, **{resource_attr_name: int(resource)})
            configs.append(config)
        return configs
    def _map_configs_for_matching(
        self, config_for_trial: Dict[str, Configuration]
    ) -> Dict[str, str]:
        # Maps each config to its match string, so that configs compare
        # independently of their value representation
        return {
            trial_id: self.hp_ranges.config_to_match_string(config)
            for trial_id, config in config_for_trial.items()
        }

    def __eq__(self, other) -> bool:
        if not isinstance(other, TuningJobState):
            return False
        if (
            self.failed_trials != other.failed_trials
            or self.pending_evaluations != other.pending_evaluations
        ):
            return False
        if self.hp_ranges != other.hp_ranges:
            return False
        if self.trials_evaluations != other.trials_evaluations:
            return False
        return self._map_configs_for_matching(
            self.config_for_trial
        ) == self._map_configs_for_matching(other.config_for_trial)
    def all_configurations(
        self, filter_observed_data: Optional[ConfigurationFilter] = None
    ) -> List[Configuration]:
        """
        Returns list of configurations for all trials represented here,
        whether observed, pending, or failed. If ``filter_observed_data`` is
        given, the configurations for observed trials are filtered with this
        predicate.

        :param filter_observed_data: See above, optional
        :return: List of all configurations
        """
        _elist = [x.trial_id for x in self.pending_evaluations] + self.failed_trials
        observed_trial_ids = [x.trial_id for x in self.trials_evaluations]
        if filter_observed_data is not None:
            observed_trial_ids = [
                trial_id
                for trial_id in observed_trial_ids
                if filter_observed_data(self.config_for_trial[trial_id])
            ]
        _elist = set(_elist + observed_trial_ids)
        return [self.config_for_trial[trial_id] for trial_id in _elist]
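

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original source): how a searcher typically
# drives this state object. It assumes ``make_hyperparameter_ranges`` is
# importable from ``syne_tune.optimizer.schedulers.searchers.utils`` and
# ``uniform`` from ``syne_tune.config_space`` (check your syne_tune version);
# everything else uses only the names defined or imported above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from syne_tune.config_space import uniform
    from syne_tune.optimizer.schedulers.searchers.utils import (
        make_hyperparameter_ranges,
    )

    hp_ranges = make_hyperparameter_ranges({"lr": uniform(1e-4, 1e-1)})
    state = TuningJobState.empty_state(hp_ranges)
    # Trial "0" is started: register its config and mark it as pending
    config = {"lr": 0.01}
    state.append_pending("0", config=config)
    assert state.is_pending("0")
    # Its observation arrives: remove the pending marker, record the value
    state.remove_pending("0")
    state.metrics_for_trial("0")[INTERNAL_METRIC_NAME] = 0.42
    assert state.is_labeled("0")
    configs, values = state.observed_data_for_metric()
    print(configs, values)  # [{'lr': 0.01}] [0.42]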