# Source code for syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.kernel.freeze_thaw

# Copyright 2021, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
# or in the "license" file accompanying this file. This file is distributed
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Dict, Any
import autograd.numpy as anp
from autograd.builtins import isinstance
from autograd.tracer import getval

from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.kernel.base import (
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.kernel.exponential_decay import (
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.constants import (
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.gluon_blocks_helpers import (
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.mean import (

class FreezeThawKernelFunction(KernelFunction):
    """
    Variant of the kernel function for modeling exponentially decaying
    learning curves, proposed in:

        Swersky, K., Snoek, J., & Adams, R. P. (2014).
        Freeze-Thaw Bayesian Optimization.
        arXiv:1406.3896 [cs, stat].

    The argument in that paper actually justifies using a non-zero mean
    function (see :class:`ExponentialDecayResourcesMeanFunction`) and
    centralizing the kernel proposed there. This is done here.

    As in the Freeze-Thaw paper, learning curves for different configs are
    conditionally independent.

    This class is configured with a kernel and a mean function over inputs
    x (dimension d) and represents a kernel (and mean function) over inputs
    (x, r) (dimension d + 1), where the resource attribute r >= 0 is last.

    Note: This kernel is mostly for debugging! Its conditional independence
    assumptions allow for faster inference, as implemented in
    :class:`GaussProcExpDecayPosteriorState`.
    """

    def __init__(
        self,
        kernel_x: KernelFunction,
        mean_x: MeanFunction,
        encoding_type=DEFAULT_ENCODING,
        alpha_init=1.0,
        mean_lam_init=0.5,
        gamma_init=0.5,
        max_metric_value=1.0,
        **kwargs
    ):
        """
        :param kernel_x: Kernel k_x(x, x') over configs
        :param mean_x: Mean function mu_x(x) over configs
        :param encoding_type: Encoding used for alpha, mean_lam, gamma
            (positive values)
        :param alpha_init: Initial value alpha
        :param mean_lam_init: Initial value mean_lam
        :param gamma_init: Initial value gamma
        :param max_metric_value: Maximum value which the metric can attain.
            This is used as upper bound on gamma
        """
        # Inputs are (x, r): config features followed by one resource column
        super().__init__(dimension=kernel_x.dimension + 1, **kwargs)
        self.kernel_x = kernel_x
        self.mean_x = mean_x
        alpha_lower, alpha_upper = 1e-6, 250.0
        alpha_init = ExponentialDecayResourcesKernelFunction._wrap_initvals(
            alpha_init, alpha_lower, alpha_upper
        )
        self.encoding_alpha = create_encoding(
            encoding_type, alpha_init, alpha_lower, alpha_upper, 1, None
        )
        mean_lam_lower, mean_lam_upper = 1e-4, 50.0
        mean_lam_init = ExponentialDecayResourcesKernelFunction._wrap_initvals(
            mean_lam_init, mean_lam_lower, mean_lam_upper
        )
        self.encoding_mean_lam = create_encoding(
            encoding_type, mean_lam_init, mean_lam_lower, mean_lam_upper, 1, None
        )
        # gamma scales the learning-curve term, so it is bounded by the
        # largest value the metric can take
        gamma_lower = max_metric_value * 0.0001
        gamma_upper = max_metric_value
        gamma_init = ExponentialDecayResourcesKernelFunction._wrap_initvals(
            gamma_init, gamma_lower, gamma_upper
        )
        self.encoding_gamma = create_encoding(
            encoding_type, gamma_init, gamma_lower, gamma_upper, 1, None
        )
        with self.name_scope():
            self.alpha_internal = register_parameter(
                self.params, "alpha", self.encoding_alpha
            )
            self.mean_lam_internal = register_parameter(
                self.params, "mean_lam", self.encoding_mean_lam
            )
            self.gamma_internal = register_parameter(
                self.params, "gamma", self.encoding_gamma
            )

    def _compute_terms(self, X, alpha, mean_lam, ret_mean=False):
        """
        Splits ``X`` into config part ``cfg`` (first ``kernel_x.dimension``
        columns) and resource part ``res`` (remaining columns), and computes
        ``kappa(res)``.

        :return: ``(cfg, res, kappa, mean)``, where ``mean`` is
            ``mean_x(cfg)`` if ``ret_mean`` is True, otherwise None
        """
        dim = self.kernel_x.dimension
        cfg = X[:, :dim]
        res = X[:, dim:]
        kappa = ExponentialDecayResourcesKernelFunction._compute_kappa(
            res, alpha, mean_lam
        )
        if ret_mean:
            mean = self.mean_x(cfg)
        else:
            mean = None
        return cfg, res, kappa, mean

    def _get_params(self, X, **kwargs):
        """
        :return: Tuple ``(alpha, mean_lam, gamma)`` of current parameter
            values, resolved via ``ExponentialDecayResourcesKernelFunction._unwrap``
        """
        alpha = ExponentialDecayResourcesKernelFunction._unwrap(
            X, kwargs, "alpha", self.encoding_alpha, self.alpha_internal
        )
        mean_lam = ExponentialDecayResourcesKernelFunction._unwrap(
            X, kwargs, "mean_lam", self.encoding_mean_lam, self.mean_lam_internal
        )
        gamma = ExponentialDecayResourcesKernelFunction._unwrap(
            X, kwargs, "gamma", self.encoding_gamma, self.gamma_internal
        )
        return (alpha, mean_lam, gamma)

    @staticmethod
    def _to_tuples(cfg):
        """
        Converts each row of ``cfg`` into a hashable tuple, so that rows with
        identical configs can be matched through a dict.

        :return: List with one tuple per row of ``cfg`` (empty if ``cfg`` has
            no rows)
        """
        num_rows = getval(cfg.shape[0])
        if num_rows == 0:
            # ``anp.split`` into zero sections fails, so handle this upfront
            return []
        return [tuple(anp.ravel(x)) for x in anp.split(cfg, num_rows, axis=0)]

    def forward(self, X1, X2, **kwargs):
        """
        Computes the kernel matrix between inputs ``X1`` and ``X2``:

            k((x1, r1), (x2, r2)) = k_x(x1, x2)
                + [x1 == x2] * gamma^2 * (kappa(r1 + r2) - kappa(r1) * kappa(r2))

        The learning-curve term is only added for pairs of rows with
        identical configs, so curves of different configs are independent.
        """
        alpha, mean_lam, gamma = self._get_params(X1, **kwargs)
        gamma = anp.reshape(gamma, (1, 1))
        cfg1, res1, kappa1, _ = self._compute_terms(X1, alpha, mean_lam)
        cfg1_tpls = self._to_tuples(cfg1)
        if X2 is not X1:
            cfg2, res2, kappa2, _ = self._compute_terms(X2, alpha, mean_lam)
            cfg2_tpls = self._to_tuples(cfg2)
            cfg_set = set(cfg1_tpls + cfg2_tpls)
        else:
            cfg2, res2, kappa2, cfg2_tpls = cfg1, res1, kappa1, cfg1_tpls
            cfg_set = set(cfg1_tpls)
        # Map each distinct config to an integer index; equal indices mark
        # rows belonging to the same learning curve
        cfg_map = {cfg: idx for idx, cfg in enumerate(cfg_set)}
        cfg1_ind = anp.reshape(anp.array([cfg_map[x] for x in cfg1_tpls]), (-1, 1))
        if X2 is not X1:
            cfg2_ind = anp.reshape(
                anp.array([cfg_map[x] for x in cfg2_tpls]), (1, -1)
            )
        else:
            cfg2_ind = anp.reshape(cfg1_ind, (1, -1))
        # Column vs. row shapes make the following broadcast to (n1, n2)
        res2 = anp.reshape(res2, (1, -1))
        kappa2 = anp.reshape(kappa2, (1, -1))
        kappa12 = ExponentialDecayResourcesKernelFunction._compute_kappa(
            anp.add(res1, res2), alpha, mean_lam
        )
        kmat_res = anp.subtract(kappa12, anp.multiply(kappa1, kappa2))
        kmat_res = kmat_res * anp.square(gamma)
        kmat_res = kmat_res * (cfg1_ind == cfg2_ind)
        kmat_x = self.kernel_x(cfg1, cfg2)
        return kmat_x + kmat_res

    def diagonal(self, X):
        """
        Diagonal of the kernel matrix for ``X``:
        ``k_x.diagonal(x) + gamma^2 * (kappa(2 r) - kappa(r)^2)``, as a flat
        vector.
        """
        alpha, mean_lam, gamma = self._get_params(X)
        gamma = anp.reshape(gamma, (1, 1))
        cfg, res, kappa, _ = self._compute_terms(X, alpha, mean_lam)
        kappa2 = ExponentialDecayResourcesKernelFunction._compute_kappa(
            res * 2, alpha, mean_lam
        )
        kdiag_res = anp.subtract(kappa2, anp.square(kappa))
        kdiag_res = anp.reshape(kdiag_res * anp.square(gamma), (-1,))
        kdiag_x = self.kernel_x.diagonal(cfg)
        return kdiag_x + kdiag_res

    def diagonal_depends_on_X(self):
        return True

    def param_encoding_pairs(self):
        """
        :return: (parameter, encoding) pairs of this kernel, followed by those
            of ``kernel_x`` and of ``mean_x``
        """
        enc_list = [
            (self.alpha_internal, self.encoding_alpha),
            (self.mean_lam_internal, self.encoding_mean_lam),
            (self.gamma_internal, self.encoding_gamma),
        ]
        enc_list.extend(self.kernel_x.param_encoding_pairs())
        enc_list.extend(self.mean_x.param_encoding_pairs())
        return enc_list

    def mean_function(self, X):
        """
        Mean function over inputs (x, r): ``mu_x(x) + gamma * kappa(r)``.
        """
        alpha, mean_lam, gamma = self._get_params(X)
        gamma = anp.reshape(gamma, (1, 1))
        cfg, res, kappa, mean = self._compute_terms(X, alpha, mean_lam, ret_mean=True)
        return anp.add(mean, anp.multiply(kappa, gamma))

    def get_params(self) -> Dict[str, Any]:
        """
        Parameter keys are alpha, mean_lam, gamma, as well as those of
        self.kernel_x (prefix 'kernelx_') and of self.mean_x (prefix 'meanx_').
        """
        values = list(self._get_params(None))
        keys = ["alpha", "mean_lam", "gamma"]
        result = {k: anp.reshape(v, (1,))[0] for k, v in zip(keys, values)}
        for pref, func in [("kernelx_", self.kernel_x), ("meanx_", self.mean_x)]:
            result.update({(pref + k): v for k, v in func.get_params().items()})
        return result

    def set_params(self, param_dict: Dict[str, Any]):
        """
        Inverse of :meth:`get_params`. Entries with prefix 'kernelx_' and
        'meanx_' are forwarded (with the prefix stripped) to ``kernel_x`` and
        ``mean_x``. The keys alpha, mean_lam and gamma must be present.
        """
        for pref, func in [("kernelx_", self.kernel_x), ("meanx_", self.mean_x)]:
            len_pref = len(pref)
            stripped_dict = {
                k[len_pref:]: v for k, v in param_dict.items() if k.startswith(pref)
            }
            func.set_params(stripped_dict)
        self.encoding_alpha.set(self.alpha_internal, param_dict["alpha"])
        self.encoding_mean_lam.set(self.mean_lam_internal, param_dict["mean_lam"])
        self.encoding_gamma.set(self.gamma_internal, param_dict["gamma"])
class FreezeThawMeanFunction(MeanFunction):
    """
    Mean function paired with a :class:`FreezeThawKernelFunction`.

    Evaluation is delegated to ``kernel.mean_function``. This object has no
    parameters of its own: it reports no encodings and no parameter values,
    and ignores values passed to :meth:`set_params`.
    """

    def __init__(self, kernel: FreezeThawKernelFunction, **kwargs):
        """
        :param kernel: Kernel whose ``mean_function`` is evaluated
        """
        super().__init__(**kwargs)
        assert isinstance(kernel, FreezeThawKernelFunction)
        self.kernel = kernel

    def forward(self, X):
        """Mean at inputs ``X``, as computed by the wrapped kernel."""
        owner = self.kernel
        return owner.mean_function(X)

    def param_encoding_pairs(self):
        # Nothing to encode here: the parameters belong to ``self.kernel``
        return list()

    def get_params(self) -> Dict[str, Any]:
        return {}

    def set_params(self, param_dict: Dict[str, Any]):
        """No-op, since there are no parameters to set."""