Implementing Components of Bayesian Optimization

At this point, you should have obtained an overview of how Bayesian optimization (BO) is structured in Syne Tune, and understood how a new surrogate model can be implemented. In this section, we turn to other components of BO: the acquisition function, and the covariance kernel of the Gaussian process surrogate model. We also look inside the factory for creating Gaussian process based searchers.

Implementing an Acquisition Function

In Bayesian optimization, the next configuration to sample at is chosen by minimizing an acquisition function:

\[\mathbf{x}_* = \mathrm{argmin}_{\mathbf{x}} \alpha(\mathbf{x})\]

In general, the acquisition function \(\alpha(\mathbf{x})\) is optimized over encoded vectors \(\mathbf{x}\), and the optimal \(\mathbf{x}_*\) is rounded back to a configuration. This allows for gradient-based optimization of \(\alpha(\mathbf{x})\).

In Syne Tune, acquisition functions are subclasses of AcquisitionFunction. An acquisition function may depend on one or more surrogate models, in that it is a function of the predictive statistics returned by the predict method of Predictor. For a wide range of acquisition functions used in practice, we have that

\[\alpha(\mathbf{x}) = \alpha(\mu(\mathbf{x}), \sigma(\mathbf{x})).\]

In other words, \(\alpha(\mathbf{x})\) is a function of the predictive mean and standard deviation of a single surrogate model. This case is covered by MeanStdAcquisitionFunction. More generally, this class implements acquisition functions that depend on one or more surrogate models, each of which returns means and (optionally) standard deviations in predict. Given the generic code in Syne Tune, a new acquisition function of this type is easy to implement. As an example, consider the lower confidence bound (LCB) acquisition function:

\[\alpha_{\mathrm{LCB}}(\mathbf{x}) = \mu(\mathbf{x}) - \kappa \sigma(\mathbf{x}),\quad \kappa > 0.\]

Here is the code:

bayesopt/models/meanstd_acqfunc_impl.py
class LCBAcquisitionFunction(MeanStdAcquisitionFunction):
    r"""
    Lower confidence bound (LCB) acquisition function:

    .. math::

       h(\mu, \sigma) = \mu - \kappa * \sigma
    """

    def __init__(self, predictor: Predictor, kappa: float, active_metric: str = None):
        super().__init__(predictor, active_metric)
        assert isinstance(predictor, Predictor)
        assert kappa > 0, "kappa must be positive"
        self.kappa = kappa

    def _head_needs_current_best(self) -> bool:
        return False

    def _compute_head(
        self,
        output_to_predictions: SamplePredictionsPerOutput,
        current_best: Optional[np.ndarray],
    ) -> np.ndarray:
        means, stds = self._extract_mean_and_std(output_to_predictions)
        return np.mean(means - stds * self.kappa, axis=1)

    def _compute_head_and_gradient(
        self,
        output_to_predictions: SamplePredictionsPerOutput,
        current_best: Optional[np.ndarray],
    ) -> HeadWithGradient:
        mean, std = self._extract_mean_and_std(output_to_predictions)
        nf_mean = mean.size

        dh_dmean = np.ones_like(mean) / nf_mean
        dh_dstd = (-self.kappa) * np.ones_like(std)
        return HeadWithGradient(
            hval=np.mean(mean - std * self.kappa),
            gradient={self.active_metric: dict(mean=dh_dmean, std=dh_dstd)},
        )


  • An object is constructed by passing predictor (a Predictor) and kappa (the positive constant \(\kappa\)). The surrogate model must return means and standard deviations in its predict method.

  • _compute_head: This method computes \(\alpha(\mathbf{\mu}, \mathbf{\sigma})\), given means and standard deviations. The argument output_to_predictions is a dictionary of dictionaries. If the acquisition function depends on a dictionary of surrogate models, the first level corresponds to that. The second level corresponds to the statistics returned by predict. In the simple case here, the first level is a single entry with key INTERNAL_METRIC_NAME, and the second level uses keys “mean” and “std” for means \(\mathbf{\mu}\) and stddevs \(\mathbf{\sigma}\). Recall that due to fantasizing, the “mean” entry can be an (n, nf) matrix, in which case we compute the average along the columns. The argument current_best is needed only for acquisition functions which depend on the incumbent.

  • _compute_head_and_gradient: This method is needed for the computation of \(\partial\alpha/\partial\mathbf{x}\), for a single input \(\mathbf{x}\). Given the same arguments as _compute_head (but for \(n = 1\) inputs), it returns a HeadWithGradient object, whose hval entry is the same as the return value of _compute_head, whereas the gradient entry contains the head gradients which are passed to the backward_gradient method of the Predictor. This entry is a nested dictionary of the same structure as output_to_predictions. The head gradient for a single surrogate model (as in our example) has \(\partial\alpha/(\partial\mathbf{\mu})\) for “mean” and \(\partial\alpha/(\partial\mathbf{\sigma})\) for “std”. It is particularly simple for the LCB example.

  • _head_needs_current_best returns False, since the LCB acquisition function does not depend on the incumbent (i.e., the current best metric value), which means that the current_best arguments need not be provided.

Finally, a new acquisition function should be linked into acquisition_function_factory(), so that users can select it via arguments acq_function and acq_function_kwargs in BayesianOptimization. The factory code is:

bayesopt/models/acqfunc_factory.py
from functools import partial

from syne_tune.optimizer.schedulers.searchers.bayesopt.tuning_algorithms.base_classes import (
    AcquisitionFunctionConstructor,
)
from syne_tune.optimizer.schedulers.searchers.bayesopt.models.meanstd_acqfunc_impl import (
    EIAcquisitionFunction,
    LCBAcquisitionFunction,
)


SUPPORTED_ACQUISITION_FUNCTIONS = (
    "ei",
    "lcb",
)


def acquisition_function_factory(name: str, **kwargs) -> AcquisitionFunctionConstructor:
    assert (
        name in SUPPORTED_ACQUISITION_FUNCTIONS
    ), f"name = {name} not supported. Choose from:\n{SUPPORTED_ACQUISITION_FUNCTIONS}"
    if name == "ei":
        return EIAcquisitionFunction
    else:  # name == "lcb"
        kappa = kwargs.get("kappa", 1.0)
        return partial(LCBAcquisitionFunction, kappa=kappa)

Here, acq_function_kwargs is passed as kwargs. In our example, the acquisition function is selected by acq_function="lcb", and the user can pass a value for kappa via acq_function_kwargs={"kappa": 0.5}.
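
For illustration, here is a minimal usage sketch. config_space and the metric name are placeholders, and we assume that acq_function and acq_function_kwargs are passed as entries of search_options (which FIFOScheduler forwards to the searcher; see the factory section below):

from syne_tune.optimizer.baselines import BayesianOptimization

scheduler = BayesianOptimization(
    config_space,                   # placeholder: your configuration space
    metric="validation_error",      # placeholder metric name
    mode="min",
    search_options={
        "acq_function": "lcb",
        "acq_function_kwargs": {"kappa": 0.5},
    },
)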

A slightly more involved example is EIAcquisitionFunction, representing the expected improvement (EI) acquisition function, which is the default choice for BayesianOptimization in Syne Tune. This function depends on the incumbent, so current_best needs to be given. Note that if the means passed to _compute_head have shape (n, nf) due to fantasies, then current_best has shape (1, nf), since the incumbent depends on the fantasy sample.

Acquisition functions can depend on more than one surrogate model. In such a case, the predictor argument to their constructor is a dictionary, and the key names of the corresponding models (or outputs) are also used in the output_to_predictions arguments and head gradients:

  • EIpuAcquisitionFunction is an acquisition function for cost-aware HPO:

    \[\alpha_{\mathrm{EIpu}}(\mathbf{x}) = \frac{\alpha_{\mathrm{EI}}(\mu_y(\mathbf{x}), \sigma_y(\mathbf{x}))}{\mu_c(\mathbf{x})^{\rho}}\]

    Here, \((\mu_y, \sigma_y)\) are predictions from the surrogate model for the target function \(y(\mathbf{x})\), whereas \(\mu_c\) are mean predictions for the cost function \(c(\mathbf{x})\). The latter can be represented by a deterministic surrogate model, whose predict method only returns means as “mean”. In fact, the method _output_to_keys_predict specifies which moments are required from each surrogate model.

  • CEIAcquisitionFunction is an acquisition function for constrained HPO:

    \[\alpha_{\mathrm{CEI}}(\mathbf{x}) = \alpha_{\mathrm{EI}}(\mu_y(\mathbf{x}), \sigma_y(\mathbf{x})) \cdot \mathbb{P}(c(\mathbf{x})\le 0).\]

    Here, \(y(\mathbf{x})\) is the target function, \(c(\mathbf{x})\) is the constraint function. Both functions are represented by probabilistic surrogate models, whose predict method returns means and stddevs. We say that \(\mathbf{x}\) is feasible if \(c(\mathbf{x})\le 0\), and the goal is to minimize \(y(\mathbf{x})\) over feasible points.

    One difficulty with this acquisition function is that the incumbent in the EI term is computed only over observations which are feasible (so \(c_i\le 0\)). This means we cannot rely on the surrogate model for \(y(\mathbf{x})\) to provide the incumbent, but instead need to determine the feasible incumbent ourselves, in the _get_current_bests_internal method.

A final complication in MeanStdAcquisitionFunction arises if some or all surrogate models are MCMC ensembles. In such a case, we average over the samples of each surrogate model involved, which amounts to a sum over the Cartesian product of sample indices. Inside this sum, the incumbent depends on the sample index for each model. This is dealt with by CurrentBestProvider. In the default case of an acquisition function which needs the incumbent (such as, for example, EI), this value depends only on the model for the active (target) metric, and ActiveMetricCurrentBestProvider is used.

Note

Acquisition function implementations are independent of which auto-differentiation mechanism is used under the hood. In contrast to surrogate models, there is no acquisition function code in syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd. This is because an implementation only needs to provide the head gradients used in compute_acq_with_gradient, and these are easy to derive and compute for common acquisition functions.

Implementing a Covariance Function for GP Surrogate Models

A Gaussian process, modelling a random function \(y(\mathbf{x})\), is defined by a mean function \(\mu(\mathbf{x})\) and a covariance function (or kernel) \(k(\mathbf{x}, \mathbf{x}')\). While Syne Tune contains a number of different covariance functions for multi-fidelity HPO, where learning curves \(y(\mathbf{x}, r)\) are modelled, with \(r = 1, 2, \dots\) being the number of epochs trained (details are provided here), it currently provides only the Matern 5/2 covariance function for models of \(y(\mathbf{x})\). A few comments up front:

  • Mean and covariance functions are parts of (Gaussian process) surrogate models. For these models, gradients are required for several different purposes. First, our Bayesian optimization code supports gradient-based minimization of the acquisition function. Second, a surrogate model is fitted to observed data, which is typically done by gradient-based optimization (e.g., marginal likelihood optimization, empirical Bayes) or by gradient-based Markov Chain Monte Carlo (e.g., Hamiltonian Monte Carlo). This means that covariance function code must be written in a framework supporting automatic differentiation. In Syne Tune, this code resides in syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd. It is based on autograd.

  • Covariance functions contain parameters to be fitted to observed data. Kernels in Syne Tune typically feature an overall output scale, as well as inverse bandwidths for the input. In the (so called) automatic relevance determination (ARD) parameterization, we use one inverse bandwidth per input vector component. This allows the surrogate model to learn the relevance of individual input components: if a component is not relevant for explaining the observed data, its inverse bandwidth can be driven to a very small value. Syne Tune uses code extracted from MXNet Gluon for managing parameters. The base class KernelFunction derives from MeanFunction, which derives from Block. The main service of this class is to maintain a parameter dictionary, collecting all parameters in the current object and its members (recursively).

In order to understand how a new covariance function can be implemented, we will go through the most important parts of Matern52. This covariance function is defined as:

\[k(\mathbf{x}, \mathbf{x}') = c \left( 1 + d + d^2/3 \right) e^{-d}, \quad d = \sqrt{5} \|\mathbf{S} (\mathbf{x} - \mathbf{x}')\|.\]

Its parameters are the output scale \(c > 0\) and the inverse bandwidths \(s_j > 0\), where \(\mathbf{S}\) is the diagonal matrix with diagonal entries \(s_j\). If ARD == False, there is only a single bandwidth parameter \(s > 0\).
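
Before turning to the autograd-based implementation, it may help to see the formula spelled out in plain NumPy. The following reference sketch is not part of Syne Tune (no parameter encodings, no automatic differentiation); it can be handy, for example, when unit-testing a kernel implementation against known values:

import numpy as np


def matern52_reference(X1, X2, inverse_bandwidths, covariance_scale):
    """Plain NumPy evaluation of the Matern 5/2 formula above.

    :param X1: inputs, shape (n1, d)
    :param X2: inputs, shape (n2, d)
    :param inverse_bandwidths: vector of s_j values, shape (d,)
    :param covariance_scale: output scale c > 0
    :return: kernel matrix, shape (n1, n2)
    """
    S1 = X1 * inverse_bandwidths  # rows scaled by S = diag(s_j)
    S2 = X2 * inverse_bandwidths
    squared_distances = (
        np.sum(S1 ** 2, axis=1)[:, None]
        - 2.0 * S1 @ S2.T
        + np.sum(S2 ** 2, axis=1)[None, :]
    )
    d = np.sqrt(5.0 * np.maximum(squared_distances, 0.0))
    return covariance_scale * (1.0 + d + d ** 2 / 3.0) * np.exp(-d)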

First, we need some includes:

bayesopt/gpautograd/kernel/base.py – includes
import autograd.numpy as anp
from autograd.tracer import getval
from typing import Dict, Any

from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.constants import (
    INITIAL_COVARIANCE_SCALE,
    INITIAL_INVERSE_BANDWIDTHS,
    DEFAULT_ENCODING,
    INVERSE_BANDWIDTHS_LOWER_BOUND,
    INVERSE_BANDWIDTHS_UPPER_BOUND,
    COVARIANCE_SCALE_LOWER_BOUND,
    COVARIANCE_SCALE_UPPER_BOUND,
    NUMERICAL_JITTER,
)
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.distribution import (
    Uniform,
    LogNormal,
)
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.gluon import Block
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.gluon_blocks_helpers import (
    encode_unwrap_parameter,
    register_parameter,
    create_encoding,
)
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.mean import (
    MeanFunction,
)


Since a number of covariance functions are simple expressions of squared distances \(\|\mathbf{S} (\mathbf{x} - \mathbf{x}')\|^2\), Syne Tune provides a block for computing them:

bayesopt/gpautograd/kernel/base.py – SquaredDistance
class SquaredDistance(Block):
    r"""
    Block that is responsible for the computation of matrices of squared
    distances. The distances can possibly be weighted (e.g., ARD
    parametrization). For instance:

    .. math::

       m_{i j} = \sum_{k=1}^d ib_k^2 (x_{1: i k} - x_{2: j k})^2

       \mathbf{X}_1 = [x_{1: i j}],\quad \mathbf{X}_2 = [x_{2: i j}]

    Here, :math:`[ib_k]` is the vector :attr:`inverse_bandwidth`.
    if ``ARD == False``, ``inverse_bandwidths`` is equal to a scalar broadcast to the
    d components (with ``d = dimension``, i.e., the number of features in ``X``).

    :param dimension: Dimensionality :math:`d` of input vectors
    :param ARD: Automatic relevance determination (``inverse_bandwidth`` vector
        of size ``d``)? Defaults to ``False``
    :param encoding_type: Encoding for ``inverse_bandwidth``. Defaults to
        :const:`~syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.constants.DEFAULT_ENCODING`
    """

    def __init__(
        self,
        dimension: int,
        ARD: bool = False,
        encoding_type: str = DEFAULT_ENCODING,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.ARD = ARD
        inverse_bandwidths_dimension = 1 if not ARD else dimension
        self.encoding = create_encoding(
            encoding_type,
            INITIAL_INVERSE_BANDWIDTHS,
            INVERSE_BANDWIDTHS_LOWER_BOUND,
            INVERSE_BANDWIDTHS_UPPER_BOUND,
            inverse_bandwidths_dimension,
            Uniform(INVERSE_BANDWIDTHS_LOWER_BOUND, INVERSE_BANDWIDTHS_UPPER_BOUND),
        )

        with self.name_scope():
            self.inverse_bandwidths_internal = register_parameter(
                self.params,
                "inverse_bandwidths",
                self.encoding,
                shape=(inverse_bandwidths_dimension,),
            )

    def _inverse_bandwidths(self):
        return encode_unwrap_parameter(self.inverse_bandwidths_internal, self.encoding)

    def forward(self, X1, X2):
        """Computes matrix of squared distances

        :param X1: input matrix, shape ``(n1, d)``
        :param X2: input matrix, shape ``(n2, d)``
        """
        # In case inverse_bandwidths is of size (1, dimension), dimension > 1,
        # ARD is handled by broadcasting
        inverse_bandwidths = anp.reshape(self._inverse_bandwidths(), (1, -1))

        X1_scaled = anp.multiply(X1, inverse_bandwidths)
        X1_squared_norm = anp.sum(anp.square(X1_scaled), axis=1)
        if X2 is X1:
            D = -2.0 * anp.dot(X1_scaled, anp.transpose(X1_scaled))
            X2_squared_norm = X1_squared_norm
        else:
            X2_scaled = anp.multiply(X2, inverse_bandwidths)
            D = -2.0 * anp.matmul(X1_scaled, anp.transpose(X2_scaled))
            X2_squared_norm = anp.sum(anp.square(X2_scaled), axis=1)
        D = D + anp.reshape(X1_squared_norm, (-1, 1))
        D = D + anp.reshape(X2_squared_norm, (1, -1))

        return anp.abs(D)

    def get_params(self) -> Dict[str, Any]:
        """
        Parameter keys are "inv_bw<k> "if ``dimension > 1``, and "inv_bw" if
        ``dimension == 1``.
        """
        inverse_bandwidths = anp.reshape(self._inverse_bandwidths(), (-1,))
        if inverse_bandwidths.size == 1:
            return {"inv_bw": inverse_bandwidths[0]}
        else:
            return {
                "inv_bw{}".format(k): inverse_bandwidths[k]
                for k in range(inverse_bandwidths.size)
            }

    def set_params(self, param_dict: Dict[str, Any]):
        dimension = self.encoding.dimension
        if dimension == 1:
            inverse_bandwidths = [param_dict["inv_bw"]]
        else:
            keys = ["inv_bw{}".format(k) for k in range(dimension)]
            for k in keys:
                assert k in param_dict, "'{}' not in param_dict = {}".format(
                    k, param_dict
                )
            inverse_bandwidths = [param_dict[k] for k in keys]
        self.encoding.set(self.inverse_bandwidths_internal, inverse_bandwidths)


  • In the constructor, we create a parameter vector for the inverse bandwidths \([s_j]\), which can be just a scalar if ARD == False. In Syne Tune, each parameter has an encoding (e.g., identity or logarithmic), which includes a lower and upper bound, an initial value, as well as a prior distribution. The latter is used for regularization during optimization.

  • The most important method is forward. Given two matrices \(\mathbf{X}_1\), \(\mathbf{X}_2\), whose rows are input vectors, we compute the matrix \([\|\mathbf{S}(\mathbf{x}_{1:i} - \mathbf{x}_{2:j})\|^2]_{i, j}\) of (weighted) squared distances. Most importantly, we use anp = autograd.numpy here instead of numpy. These autograd wrappers ensure that automatic differentiation can be used to compute gradients w.r.t. leaf nodes in the computation graph spanned by the numpy computations. Also, note the use of encode_unwrap_parameter in _inverse_bandwidths to obtain the inverse bandwidth parameters as a numpy array. Finally, note that X1 and X2 can be the same object, in which case we save compute time and create a smaller computation graph.

  • Each block in Syne Tune also provides get_params and set_params methods, which are used for serialization and deserialization.

Given this code, the implementation of Matern52 is simple:

bayesopt/gpautograd/kernel/base.py – Matern52
class Matern52(KernelFunction):
    """
    Block that is responsible for the computation of Matern 5/2 kernel.

    if ``ARD == False``, ``inverse_bandwidths`` is equal to a scalar broadcast to the
    d components (with ``d = dimension``, i.e., the number of features in ``X``).

    Arguments on top of base class :class:`SquaredDistance`:

    :param has_covariance_scale: Kernel has covariance scale parameter? Defaults
        to ``True``
    """

    def __init__(
        self,
        dimension: int,
        ARD: bool = False,
        encoding_type: str = DEFAULT_ENCODING,
        has_covariance_scale: bool = True,
        **kwargs
    ):
        super(Matern52, self).__init__(dimension, **kwargs)
        self.has_covariance_scale = has_covariance_scale
        self.squared_distance = SquaredDistance(
            dimension=dimension, ARD=ARD, encoding_type=encoding_type
        )
        if has_covariance_scale:
            self.encoding = create_encoding(
                encoding_name=encoding_type,
                init_val=INITIAL_COVARIANCE_SCALE,
                constr_lower=COVARIANCE_SCALE_LOWER_BOUND,
                constr_upper=COVARIANCE_SCALE_UPPER_BOUND,
                dimension=1,
                prior=LogNormal(0.0, 1.0),
            )
            with self.name_scope():
                self.covariance_scale_internal = register_parameter(
                    self.params, "covariance_scale", self.encoding
                )

    @property
    def ARD(self) -> bool:
        return self.squared_distance.ARD

    def _covariance_scale(self):
        if self.has_covariance_scale:
            return encode_unwrap_parameter(
                self.covariance_scale_internal, self.encoding
            )
        else:
            return 1.0

    def forward(self, X1, X2):
        """Computes Matern 5/2 kernel matrix

        :param X1: input matrix, shape ``(n1,d)``
        :param X2: input matrix, shape ``(n2,d)``
        """
        covariance_scale = self._covariance_scale()
        X1 = self._check_input_shape(X1)
        if X2 is not X1:
            X2 = self._check_input_shape(X2)
        D = 5.0 * self.squared_distance(X1, X2)
        # Using the plain np.sqrt is numerically unstable for D ~ 0
        # (non-differentiability)
        # that's why we add NUMERICAL_JITTER
        B = anp.sqrt(D + NUMERICAL_JITTER)
        return anp.multiply((1.0 + B + D / 3.0) * anp.exp(-B), covariance_scale)

    def diagonal(self, X):
        X = self._check_input_shape(X)
        covariance_scale = self._covariance_scale()
        covariance_scale_times_ones = anp.multiply(
            anp.ones((getval(X.shape[0]), 1)), covariance_scale
        )

        return anp.reshape(covariance_scale_times_ones, (-1,))

    def diagonal_depends_on_X(self):
        return False

    def param_encoding_pairs(self):
        result = [
            (
                self.squared_distance.inverse_bandwidths_internal,
                self.squared_distance.encoding,
            )
        ]
        if self.has_covariance_scale:
            result.insert(0, (self.covariance_scale_internal, self.encoding))
        return result

    def get_covariance_scale(self):
        if self.has_covariance_scale:
            return self._covariance_scale()[0]
        else:
            return 1.0

    def set_covariance_scale(self, covariance_scale):
        assert self.has_covariance_scale, "covariance_scale is fixed to 1"
        self.encoding.set(self.covariance_scale_internal, covariance_scale)

    def get_params(self) -> Dict[str, Any]:
        result = self.squared_distance.get_params()
        if self.has_covariance_scale:
            result["covariance_scale"] = self.get_covariance_scale()
        return result

    def set_params(self, param_dict: Dict[str, Any]):
        self.squared_distance.set_params(param_dict)
        if self.has_covariance_scale:
            self.set_covariance_scale(param_dict["covariance_scale"])

  • In the constructor, we create an object of type SquaredDistance. A nice feature of MXNet Gluon blocks is that the parameter dictionary of an object is automatically extended by the dictionaries of its members, so we do not need to take care of this ourselves. Beware that this only works for members which are of type Block directly. If you use a list or dictionary containing such objects, you need to include their parameter dictionaries explicitly. Next, we also define a covariance scale parameter \(c > 0\), unless has_covariance_scale == False.

  • forward calls forward of the SquaredDistance object, then computes the kernel matrix, using anp = autograd.numpy once more.

  • diagonal returns the diagonal of the kernel matrix based on a matrix X of inputs. For this particular kernel, the diagonal does not depend on the content of X, but only its shape, which is why diagonal_depends_on_X returns False.

  • Besides get_params and set_params, we also need to implement param_encoding_pairs, which is required by the optimization code used for fitting the surrogate model parameters.

At this point, you should not have any major difficulties implementing a new covariance function, such as the Gaussian kernel or the Matern kernel with parameter 3/2.
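
To make this concrete, here is a hedged sketch of what a Matern 3/2 kernel, \(k(\mathbf{x}, \mathbf{x}') = c (1 + d) e^{-d}\) with \(d = \sqrt{3} \|\mathbf{S} (\mathbf{x} - \mathbf{x}')\|\), could look like. It is not part of Syne Tune; it simply reuses the helpers shown above (SquaredDistance, create_encoding, register_parameter, encode_unwrap_parameter, NUMERICAL_JITTER). Only the constructor, _covariance_scale and forward are shown; the remaining methods (ARD, diagonal, diagonal_depends_on_X, param_encoding_pairs, get_params, set_params) would mirror Matern52 unchanged:

class Matern32(KernelFunction):
    r"""
    Hypothetical Matern 3/2 kernel (sketch, not part of Syne Tune):

    .. math::

       k(x, x') = c (1 + d) e^{-d}, \quad d = \sqrt{3} \|S (x - x')\|
    """

    def __init__(
        self,
        dimension: int,
        ARD: bool = False,
        encoding_type: str = DEFAULT_ENCODING,
        has_covariance_scale: bool = True,
        **kwargs
    ):
        super().__init__(dimension, **kwargs)
        self.has_covariance_scale = has_covariance_scale
        self.squared_distance = SquaredDistance(
            dimension=dimension, ARD=ARD, encoding_type=encoding_type
        )
        if has_covariance_scale:
            # Same covariance scale parameter and encoding as in Matern52
            self.encoding = create_encoding(
                encoding_name=encoding_type,
                init_val=INITIAL_COVARIANCE_SCALE,
                constr_lower=COVARIANCE_SCALE_LOWER_BOUND,
                constr_upper=COVARIANCE_SCALE_UPPER_BOUND,
                dimension=1,
                prior=LogNormal(0.0, 1.0),
            )
            with self.name_scope():
                self.covariance_scale_internal = register_parameter(
                    self.params, "covariance_scale", self.encoding
                )

    def _covariance_scale(self):
        if self.has_covariance_scale:
            return encode_unwrap_parameter(
                self.covariance_scale_internal, self.encoding
            )
        return 1.0

    def forward(self, X1, X2):
        covariance_scale = self._covariance_scale()
        X1 = self._check_input_shape(X1)
        if X2 is not X1:
            X2 = self._check_input_shape(X2)
        # d^2 = 3 * ||S (x - x')||^2
        D = 3.0 * self.squared_distance(X1, X2)
        # NUMERICAL_JITTER avoids the non-differentiability of sqrt at 0,
        # as in Matern52
        B = anp.sqrt(D + NUMERICAL_JITTER)
        return anp.multiply((1.0 + B) * anp.exp(-B), covariance_scale)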

The Factory for Gaussian Process Searchers

Once a covariance function (or any other component of a surrogate model) has been added, how is it accessed by a user? In general, all details about the surrogate model are specified in search_options passed to FIFOScheduler or BayesianOptimization. Available options are documented in GPFIFOSearcher. Syne Tune offers a range of searchers based on various Gaussian process surrogate models (e.g., single fidelity, multi-fidelity, constrained, cost-aware). The code to generate all required components for these searchers is bundled in gp_searcher_factory. For each type of searcher, there is a factory function and a defaults function; BayesianOptimization (which is equivalent to FIFOScheduler with searcher="bayesopt") comes with its own pair.

The searcher object itself is created in searcher_factory(), which is called in the constructor of FIFOScheduler after search_options has been merged with default values. This keeps things simple for the user, who only has to specify the type of searcher via searcher and additional arguments via search_options. For any argument not provided there, a sensible default value is used.

Factory and default functions in gp_searcher_factory are based on common code in this module, which reflects the complexity of some of the searchers, but is otherwise self-explanatory. As a continuation of the previous section, suppose we had implemented a novel covariance function to be used in GP-based Bayesian optimization. The user-facing argument for selecting a kernel is gp_base_kernel; its default value is “matern52-ard” (Matern 5/2 with ARD parameters). Here is the code for creating this covariance function in gp_searcher_factory:

gp_searcher_factory.py
def _create_base_gp_kernel(hp_ranges: HyperparameterRanges, **kwargs) -> KernelFunction:
    """
    The default base kernel is :class:`Matern52` with ARD parameters.
    But in the transfer learning case, the base kernel is a product of
    two ``Matern52`` kernels, the first non-ARD over the categorical
    parameter determining the task, the second ARD over the remaining
    parameters.
    """
    input_warping = kwargs.get("input_warping", False)
    if kwargs.get("transfer_learning_task_attr") is not None:
        if input_warping:
            logger.warning(
                "Cannot use input_warping=True together with transfer_learning_task_attr. Will use input_warping=False"
            )
        # Transfer learning: Specific base kernel
        kernel = create_base_gp_kernel_for_warmstarting(hp_ranges, **kwargs)
    else:
        has_covariance_scale = kwargs.get("has_covariance_scale", True)
        kernel = base_kernel_factory(
            name=kwargs["gp_base_kernel"],
            dimension=hp_ranges.ndarray_size,
            has_covariance_scale=has_covariance_scale,
        )
        if input_warping:
            # Use input warping on all coordinates which do not belong to a
            # categorical hyperparameter
            kernel = kernel_with_warping(kernel, hp_ranges)
            if kwargs.get("debug_log", False) and isinstance(kernel, WarpedKernel):
                ranges = [(warp.lower, warp.upper) for warp in kernel.warpings]
                logger.info(
                    f"Creating base GP covariance kernel with input warping: ranges = {ranges}"
                )
    return kernel


  • Ignoring transfer_learning_task_attr, we first call base_kernel_factory to create the base kernel, passing kwargs["gp_base_kernel"] as its name.

  • Syne Tune also supports warping of the inputs to a kernel, which adds two more parameters for each component (except components coming from categorical hyperparameters, which are not warped). A sketch of how a user would select the base kernel and enable warping is given below.
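
For illustration, here is a hedged sketch of how a user would select the base kernel and switch on input warping from the outside. config_space and the metric name are placeholders; the option names gp_base_kernel and input_warping are the ones read from kwargs in the factory code above:

from syne_tune.optimizer.baselines import BayesianOptimization

scheduler = BayesianOptimization(
    config_space,                  # placeholder: your configuration space
    metric="validation_error",     # placeholder metric name
    search_options={
        "gp_base_kernel": "matern52-noard",  # name passed to base_kernel_factory
        "input_warping": True,               # wrap the base kernel with warping
    },
)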

bayesopt/models/kernel_factory.py
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.kernel import (
    KernelFunction,
    Matern52,
    ExponentialDecayResourcesKernelFunction,
    ExponentialDecayResourcesMeanFunction,
    FreezeThawKernelFunction,
    FreezeThawMeanFunction,
    CrossValidationMeanFunction,
    CrossValidationKernelFunction,
)
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.warping import (
    WarpedKernel,
    Warping,
)
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.mean import (
    MeanFunction,
)


SUPPORTED_BASE_MODELS = (
    "matern52-ard",
    "matern52-noard",
)


def base_kernel_factory(name: str, dimension: int, **kwargs) -> KernelFunction:
    assert (
        name in SUPPORTED_BASE_MODELS
    ), f"name = {name} not supported. Choose from:\n{SUPPORTED_BASE_MODELS}"
    return Matern52(
        dimension=dimension,
        ARD=name == "matern52-ard",
        has_covariance_scale=kwargs.get("has_covariance_scale", True),
    )


  • base_kernel_factory creates the base kernel, given its name (which must be in SUPPORTED_BASE_MODELS), the dimension of input vectors, as well as further parameters (has_covariance_scale in our example). Currently, Syne Tune supports only the Matern 5/2 kernel, with and without ARD.

  • Had we implemented a novel covariance function, we would have to select a new name, insert it into SUPPORTED_BASE_MODELS, and insert code into base_kernel_factory, as sketched below. Once this is done, the new base kernel can also be selected as a component in multi-fidelity or constrained Bayesian optimization.
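
To make this last step concrete, here is a hedged sketch of how the hypothetical Matern32 kernel from the previous section could be registered. The names "matern32-ard" and "matern32-noard" are made up for illustration:

SUPPORTED_BASE_MODELS = (
    "matern52-ard",
    "matern52-noard",
    "matern32-ard",    # hypothetical new entries
    "matern32-noard",
)


def base_kernel_factory(name: str, dimension: int, **kwargs) -> KernelFunction:
    assert (
        name in SUPPORTED_BASE_MODELS
    ), f"name = {name} not supported. Choose from:\n{SUPPORTED_BASE_MODELS}"
    # Dispatch on the kernel family encoded in the name
    kernel_cls = Matern52 if name.startswith("matern52") else Matern32
    return kernel_cls(
        dimension=dimension,
        ARD=name.endswith("-ard"),
        has_covariance_scale=kwargs.get("has_covariance_scale", True),
    )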