# Source code for syne_tune.optimizer.schedulers.searchers.bayesopt.models.model_base

# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import List, Optional
import numpy as np
import logging

from syne_tune.optimizer.schedulers.searchers.bayesopt.datatypes.tuning_job_state import (
    TuningJobState,
)
from syne_tune.optimizer.schedulers.searchers.bayesopt.tuning_algorithms.base_classes import (
    Predictor,
)
from syne_tune.optimizer.schedulers.searchers.utils.common import ConfigurationFilter

logger = logging.getLogger(__name__)


class BasePredictor(Predictor):
    """
    Base class for (most)
    :class:`~syne_tune.optimizer.schedulers.searchers.bayesopt.tuning_algorithms.base_classes.Predictor`
    implementations, provides common code.

    :param state: Tuning job state the surrogate model was fit on
    :param active_metric: Name of the metric predictions refer to
    :param filter_observed_data: Optional predicate; if given, ``current_best``
        is computed only over candidates for which it returns ``True``
    """

    def __init__(
        self,
        state: TuningJobState,
        active_metric: Optional[str] = None,
        filter_observed_data: Optional[ConfigurationFilter] = None,
    ):
        super().__init__(state, active_metric)
        # Lazily computed cache for ``current_best``; invalidated whenever
        # the filter changes (see ``set_filter_observed_data``)
        self._current_best = None
        self._filter_observed_data = filter_observed_data

    @property
    def filter_observed_data(self) -> Optional[ConfigurationFilter]:
        return self._filter_observed_data

    def set_filter_observed_data(
        self, filter_observed_data: Optional[ConfigurationFilter]
    ):
        self._filter_observed_data = filter_observed_data
        # The cached incumbent depends on the filter, so it must be recomputed
        self._current_best = None

    def predict_mean_current_candidates(self) -> List[np.ndarray]:
        """
        Returns the predictive mean (signal with key 'mean') at all current
        candidates in the state (observed, pending).

        If the hyperparameters of the surrogate model are being optimized
        (e.g., by empirical Bayes), the returned list has length 1. If its
        hyperparameters are averaged over by MCMC, the returned list has one
        entry per MCMC sample.

        :return: List of predictive means
        """
        candidates, _ = self.state.observed_data_for_metric(self.active_metric)
        # Concatenate (rather than ``+=``) so we do not mutate in place a
        # list the state object may still hold a reference to
        candidates = candidates + self.state.pending_configurations()
        candidates = self._current_best_filter_candidates(candidates)
        assert (
            len(candidates) > 0
        ), "Cannot predict means at current candidates with no candidates at all"
        inputs = self.hp_ranges_for_prediction().to_ndarray_matrix(candidates)
        all_means = []
        # Loop over MCMC samples (if any)
        for prediction in self.predict(inputs):
            means = prediction["mean"]
            if means.ndim == 1:  # In case of no fantasizing
                means = means.reshape((-1, 1))
            all_means.append(means)
        return all_means

    def current_best(self) -> List[np.ndarray]:
        """
        Incumbent (minimum predictive mean over current candidates), one
        entry per surrogate model sample (see
        :meth:`predict_mean_current_candidates`). The result is cached and
        reused until the filter is changed.
        """
        if self._current_best is None:
            all_means = self.predict_mean_current_candidates()
            # Minimum over candidates (rows); one value per fantasy column
            self._current_best = [np.min(means, axis=0) for means in all_means]
        return self._current_best

    def _current_best_filter_candidates(self, candidates):
        """
        In some subclasses, 'current_best' is not computed over all (observed
        and pending) candidates: they need to implement this filter.
        """
        if self._filter_observed_data is None:
            return candidates  # Default: no filtering
        return [
            config for config in candidates if self._filter_observed_data(config)
        ]