Source code for syne_tune.optimizer.schedulers.searchers.bayesopt.models.model_transformer

# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Dict, Optional, Callable, Union
import logging
import copy

from numpy.random import RandomState

from syne_tune.optimizer.schedulers.searchers.bayesopt.models.estimator import (
    Estimator,
    OutputEstimator,
)
from syne_tune.optimizer.schedulers.searchers.bayesopt.tuning_algorithms.base_classes import (
    OutputPredictor,
)
from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.tuning_job_state import (
    TuningJobState,
)
from syne_tune.optimizer.schedulers.searchers.bayesopt.models.model_skipopt import (
    SkipOptimizationPredicate,
    NeverSkipPredicate,
)
from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.common import (
    PendingEvaluation,
    TrialEvaluations,
    dictionarize_objective,
    INTERNAL_METRIC_NAME,
)
from syne_tune.optimizer.schedulers.searchers.utils.common import Configuration

logger = logging.getLogger(__name__)


def _assert_same_keys(dict1, dict2):
    assert set(dict1.keys()) == set(
        dict2.keys()
    ), f"{list(dict1.keys())} and {list(dict2.keys())} must contain the same keys."


# Convenience type allowing for multi-output HPO. It is used for methods
# that work both in the standard case of a single output model and in the
# multi-output case.

SkipOptimizationOutputPredicate = Union[
    SkipOptimizationPredicate, Dict[str, SkipOptimizationPredicate]
]


class StateForModelConverter:
    """
    Interface for state converters (optionally) used in
    :class:`~syne_tune.optimizer.schedulers.searchers.bayesopt.models.model_transformer.ModelStateTransformer`.
    These are applied to a state before being passed to the model for fitting
    and predictions. The main use case is to filter down data if fitting the
    model scales super-linearly.
    """

    def __call__(self, state: TuningJobState) -> TuningJobState:
        raise NotImplementedError

    def set_random_state(self, random_state: RandomState):
        """
        Some state converters use random sampling. For these, the random state
        has to be set before first usage.

        :param random_state: Random state to be used internally
        """
        pass
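

# Illustrative sketch (not part of the library): a minimal
# :class:`StateForModelConverter` which randomly subsamples observed trials so
# that model fitting cost stays bounded. The name ``SubsampleStateConverter``,
# the ``max_size`` argument, and the assumption that
# ``TuningJobState.trials_evaluations`` can be replaced on a shallow copy are
# choices made for this example only.
class SubsampleStateConverter(StateForModelConverter):
    def __init__(self, max_size: int):
        self._max_size = max_size
        self._random_state: Optional[RandomState] = None

    def set_random_state(self, random_state: RandomState):
        self._random_state = random_state

    def __call__(self, state: TuningJobState) -> TuningJobState:
        assert self._random_state is not None, "Call set_random_state first"
        trials = state.trials_evaluations
        if len(trials) <= self._max_size:
            return state  # Nothing to filter
        # Keep a random subset of labeled trials, sampled without replacement
        keep = self._random_state.choice(
            len(trials), size=self._max_size, replace=False
        )
        new_state = copy.copy(state)
        new_state.trials_evaluations = [trials[int(i)] for i in keep]
        return new_state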


class ModelStateTransformer:
    """
    This class maintains the
    :class:`~syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.tuning_job_state.TuningJobState`
    object alongside an HPO experiment, and manages the reaction to changes of
    this state. In particular, it provides a fitted surrogate model on demand,
    which encapsulates the GP posterior.

    The state transformer is generic; it uses :class:`Estimator` for anything
    specific to the model type.

    ``skip_optimization`` is a predicate depending on the state, which
    determines what is done at the next call of :meth:`fit`. If ``False``, the
    model parameters are refit, otherwise the current ones are not changed
    (which is usually faster, but risks staleness).

    We also track the observed data ``state.trials_evaluations``. If this did
    not change since the most recent :meth:`fit` call, we do not refit the
    model parameters. This is based on the assumption that model parameter
    fitting only depends on ``state.trials_evaluations`` (observed data), not
    on other fields (e.g., pending evaluations).

    If given, ``state_converter`` maps the state to another one, which is then
    passed to the model for fitting and predictions. One important use case is
    filtering down data when model fitting scales super-linearly. Another is
    to convert multi-fidelity setups to be used with single-fidelity models
    inside.

    Note that ``estimator`` and ``skip_optimization`` can also be dictionaries
    mapping output names to models. In that case, the state is shared, but the
    models for each output metric are updated independently.

    :param estimator: Surrogate model(s)
    :param init_state: Initial tuning job state
    :param skip_optimization: Skip optimization predicate (see above). Defaults
        to ``None`` (fitting is never skipped)
    :param state_converter: See above, optional
    """

    def __init__(
        self,
        estimator: OutputEstimator,
        init_state: TuningJobState,
        skip_optimization: Optional[SkipOptimizationOutputPredicate] = None,
        state_converter: Optional[StateForModelConverter] = None,
    ):
        self._use_single_model = False
        if isinstance(estimator, Estimator):
            self._use_single_model = True
        if not self._use_single_model:
            assert isinstance(estimator, dict), (
                f"{estimator} is not an instance of Estimator. "
                f"It is assumed that we are in the multi-output case and that it "
                f"must be a Dict. No other types are supported. "
            )
            # Default: Always refit model parameters for each output model
            if skip_optimization is None:
                skip_optimization = {
                    output_name: NeverSkipPredicate()
                    for output_name in estimator.keys()
                }
            else:
                assert isinstance(skip_optimization, Dict), (
                    f"{skip_optimization} must be a Dict, consistently "
                    f"with {estimator}."
                )
                _assert_same_keys(estimator, skip_optimization)
                skip_optimization = {
                    output_name: skip_optimization[output_name]
                    if skip_optimization.get(output_name) is not None
                    else NeverSkipPredicate()
                    for output_name in estimator.keys()
                }
            # debug_log is shared by all output models
            self._debug_log = next(iter(estimator.values())).debug_log
        else:
            if skip_optimization is None:
                # Default: Always refit model parameters
                skip_optimization = NeverSkipPredicate()
            assert isinstance(skip_optimization, SkipOptimizationPredicate)
            self._debug_log = estimator.debug_log
            # Make ``estimator`` and ``skip_optimization`` single-key
            # dictionaries for convenience, so that we can treat the single
            # model and multi-model case in the same way
            estimator = dictionarize_objective(estimator)
            skip_optimization = dictionarize_objective(skip_optimization)
        self._estimator = estimator
        self._skip_optimization = skip_optimization
        self._state_converter = state_converter
        self._state = copy.copy(init_state)
        # OutputPredictor computed on demand
        self._predictor: Optional[OutputPredictor] = None
        # Observed data for which model parameters were re-fit most recently,
        # separately for each model
        self._num_evaluations = {output_name: 0 for output_name in estimator.keys()}

    @property
    def state(self) -> TuningJobState:
        return self._state

    def _unwrap_from_dict(self, x):
        if self._use_single_model:
            return next(iter(x.values()))
        else:
            return x

    @property
    def use_single_model(self) -> bool:
        return self._use_single_model

    @property
    def estimator(self) -> OutputEstimator:
        return self._unwrap_from_dict(self._estimator)

    @property
    def skip_optimization(self) -> SkipOptimizationOutputPredicate:
        return self._unwrap_from_dict(self._skip_optimization)

    def fit(self, **kwargs) -> OutputPredictor:
        """
        If ``skip_optimization`` is passed in ``kwargs``, it overrides the
        ``self._skip_optimization`` predicate.

        :return: Fitted surrogate model for current state in the standard
            single model case; in the multi-model case, it returns a dictionary
            mapping output names to surrogate model instances for current state
            (shared across models).
        """
        if self._predictor is None:
            skip_optimization = kwargs.get("skip_optimization")
            self._compute_predictor(skip_optimization=skip_optimization)
        return self._unwrap_from_dict(self._predictor)

    def get_params(self):
        params = {
            output_name: output_estimator.get_params()
            for output_name, output_estimator in self._estimator.items()
        }
        return self._unwrap_from_dict(params)

    def set_params(self, param_dict):
        if self._use_single_model:
            param_dict = dictionarize_objective(param_dict)
        _assert_same_keys(self._estimator, param_dict)
        for output_name in self._estimator:
            self._estimator[output_name].set_params(param_dict[output_name])

    def append_trial(
        self,
        trial_id: str,
        config: Optional[Configuration] = None,
        resource: Optional[int] = None,
    ):
        """
        Appends new pending evaluation to the state.

        :param trial_id: ID of trial
        :param config: Must be given if this trial does not yet feature in the
            state
        :param resource: Must be given in the multi-fidelity case, to specify
            at which resource level the evaluation is pending
        """
        self._predictor = None  # Invalidate
        self._state.append_pending(trial_id, config=config, resource=resource)

    def drop_pending_evaluation(
        self, trial_id: str, resource: Optional[int] = None
    ) -> bool:
        """
        Drop pending evaluation from state. If it is not listed as pending,
        nothing is done.

        :param trial_id: ID of trial
        :param resource: Must be given in the multi-fidelity case, to specify
            at which resource level the evaluation is pending
        """
        return self._state.remove_pending(trial_id, resource)

    def remove_observed_case(
        self,
        trial_id: str,
        metric_name: str = INTERNAL_METRIC_NAME,
        key: Optional[str] = None,
    ):
        """
        Removes specific observation from the state.

        :param trial_id: ID of trial
        :param metric_name: Name of internal metric
        :param key: Must be given in the multi-fidelity case
        """
        pos = self._state._find_labeled(trial_id)
        assert pos != -1, f"Trial trial_id = {trial_id} has no observations"
        metrics = self._state.trials_evaluations[pos].metrics
        assert metric_name in metrics, (
            f"state.trials_evaluations entry for trial_id = {trial_id} "
            + f"does not contain metric {metric_name}"
        )
        if key is None:
            del metrics[metric_name]
        else:
            metric_vals = metrics[metric_name]
            assert isinstance(metric_vals, dict) and key in metric_vals, (
                f"state.trials_evaluations entry for trial_id = {trial_id} "
                + f"and metric {metric_name} does not contain case for "
                + f"key {key}"
            )
            del metric_vals[key]

    def label_trial(
        self, data: TrialEvaluations, config: Optional[Configuration] = None
    ):
        """
        Adds observed data for a trial. If it has observations in the state
        already, ``data.metrics`` are appended. Otherwise, a new entry is
        appended. If new observations replace pending evaluations, these are
        removed.

        ``config`` must be passed if the trial has not yet been registered in
        the state (this happens normally with the ``append_trial`` call). If
        already registered, ``config`` is ignored.
        """
        # Drop pending candidate if it exists
        trial_id = data.trial_id
        metric_vals = data.metrics.get(INTERNAL_METRIC_NAME)
        if metric_vals is not None:
            resource_attr_name = self._state.hp_ranges.name_last_pos
            if resource_attr_name is not None:
                assert isinstance(
                    metric_vals, dict
                ), f"Metric {INTERNAL_METRIC_NAME} must be dict-valued"
                for resource in metric_vals.keys():
                    self.drop_pending_evaluation(trial_id, resource=int(resource))
            else:
                self.drop_pending_evaluation(trial_id)
        # Assign / append new labels
        metrics = self._state.metrics_for_trial(trial_id, config=config)
        for name, new_labels in data.metrics.items():
            if name not in metrics or not isinstance(new_labels, dict):
                metrics[name] = new_labels
            else:
                metrics[name].update(new_labels)
        self._predictor = None  # Invalidate

    def filter_pending_evaluations(
        self, filter_pred: Callable[[PendingEvaluation], bool]
    ):
        """
        Filters ``state.pending_evaluations`` with ``filter_pred``.

        :param filter_pred: Filtering predicate
        """
        new_pending_evaluations = list(
            filter(filter_pred, self._state.pending_evaluations)
        )
        if len(new_pending_evaluations) != len(self._state.pending_evaluations):
            self._predictor = None  # Invalidate
            del self._state.pending_evaluations[:]
            self._state.pending_evaluations.extend(new_pending_evaluations)

    def mark_trial_failed(self, trial_id: str):
        failed_trials = self._state.failed_trials
        if trial_id not in failed_trials:
            failed_trials.append(trial_id)

    def _compute_predictor(self, skip_optimization=None):
        if skip_optimization is None:
            skip_optimization = dict()
            for (
                output_name,
                output_skip_optimization,
            ) in self._skip_optimization.items():
                skip_optimization[output_name] = output_skip_optimization(self._state)
        elif self._use_single_model:
            skip_optimization = dictionarize_objective(skip_optimization)
        if self._debug_log is not None:
            for output_name, skip_opt in skip_optimization.items():
                if skip_opt:
                    logger.info(
                        f"Skipping the refitting of model parameters for {output_name}"
                    )
        _assert_same_keys(skip_optimization, self._estimator)
        state_for_model = (
            self._state
            if self._state_converter is None
            else self._state_converter(self._state)
        )
        output_predictors = dict()
        for output_name, output_skip_optimization in skip_optimization.items():
            update_params = not output_skip_optimization
            if update_params:
                # Did the labeled data really change since the most recent
                # refit? If not, skip the refitting
                num_evaluations = self._state.num_observed_cases(output_name)
                if num_evaluations == self._num_evaluations[output_name]:
                    update_params = False
                    if self._debug_log is not None:
                        logger.info(
                            f"Skipping the refitting of model parameters for "
                            f"{output_name}, since the labeled data did not "
                            f"change since the most recent fit"
                        )
                else:
                    # Model will be refitted: Update
                    self._num_evaluations[output_name] = num_evaluations
            output_predictors[output_name] = self._estimator[
                output_name
            ].fit_from_state(state=state_for_model, update_params=update_params)
        self._predictor = output_predictors
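

# Usage sketch (illustrative only, not part of the library): how a scheduler
# or searcher might drive ``ModelStateTransformer``. The arguments
# ``estimator``, ``init_state``, and ``config`` are assumed to be constructed
# elsewhere, and it is assumed that ``TrialEvaluations`` can be created with
# ``trial_id`` and ``metrics`` keyword arguments, as suggested by its use in
# ``label_trial`` above.
def _example_usage(
    estimator: Estimator, init_state: TuningJobState, config: Configuration
) -> OutputPredictor:
    state_transformer = ModelStateTransformer(
        estimator=estimator, init_state=init_state
    )
    # Register trial "0" as pending before it is evaluated
    state_transformer.append_trial("0", config=config)
    state_transformer.fit()  # fit the surrogate model on the current state
    # Once the observation for trial "0" arrives, label it. This removes the
    # pending evaluation and invalidates the cached predictor.
    state_transformer.label_trial(
        TrialEvaluations(trial_id="0", metrics=dictionarize_objective(0.42))
    )
    return state_transformer.fit()  # refits, since the observed data changed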