# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Dict, Any
import autograd.numpy as anp
from autograd.builtins import isinstance
from autograd.tracer import getval
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.kernel.base import (
KernelFunction,
)
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.kernel.exponential_decay import (
ExponentialDecayResourcesKernelFunction,
)
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.constants import (
DEFAULT_ENCODING,
)
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.gluon_blocks_helpers import (
register_parameter,
create_encoding,
)
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.mean import (
MeanFunction,
)
class FreezeThawKernelFunction(KernelFunction):
    """
    Variant of the kernel function for modeling exponentially decaying
    learning curves, proposed in:

        Swersky, K., Snoek, J., & Adams, R. P. (2014).
        Freeze-Thaw Bayesian Optimization.
        ArXiv:1406.3896 [Cs, Stat].
        Retrieved from http://arxiv.org/abs/1406.3896

    The argument in that paper actually justifies using a non-zero mean
    function (see :class:`ExponentialDecayResourcesMeanFunction`) and
    centralizing the kernel proposed there. This is done here.

    As in the Freeze-Thaw paper, learning curves for different configs are
    conditionally independent.

    This class is configured with a kernel and a mean function over
    inputs x (dimension d) and represents a kernel (and mean function) over
    inputs (x, r) (dimension d + 1), where the resource attribute r >= 0 is
    last.

    Note: This kernel is mostly for debugging! Its conditional independence
    assumptions allow for faster inference, as implemented in
    :class:`GaussProcExpDecayPosteriorState`.
    """

    def __init__(
        self,
        kernel_x: KernelFunction,
        mean_x: MeanFunction,
        encoding_type=DEFAULT_ENCODING,
        alpha_init=1.0,
        mean_lam_init=0.5,
        gamma_init=0.5,
        max_metric_value=1.0,
        **kwargs
    ):
        """
        :param kernel_x: Kernel k_x(x, x') over configs
        :param mean_x: Mean function mu_x(x) over configs
        :param encoding_type: Encoding used for alpha, mean_lam, gamma (positive
            values)
        :param alpha_init: Initial value alpha
        :param mean_lam_init: Initial value mean_lam
        :param gamma_init: Initial value gamma
        :param max_metric_value: Maximum value which the metric can attain.
            This is used as upper bound on gamma
        """
        # Inputs are (x, r): the d config features of kernel_x plus one
        # resource column
        super().__init__(dimension=kernel_x.dimension + 1, **kwargs)
        self.kernel_x = kernel_x
        self.mean_x = mean_x
        alpha_lower, alpha_upper = 1e-6, 250.0
        alpha_init = ExponentialDecayResourcesKernelFunction._wrap_initvals(
            alpha_init, alpha_lower, alpha_upper
        )
        self.encoding_alpha = create_encoding(
            encoding_type, alpha_init, alpha_lower, alpha_upper, 1, None
        )
        mean_lam_lower, mean_lam_upper = 1e-4, 50.0
        mean_lam_init = ExponentialDecayResourcesKernelFunction._wrap_initvals(
            mean_lam_init, mean_lam_lower, mean_lam_upper
        )
        self.encoding_mean_lam = create_encoding(
            encoding_type, mean_lam_init, mean_lam_lower, mean_lam_upper, 1, None
        )
        # gamma is bounded in terms of the largest value the metric can take
        gamma_lower = max_metric_value * 0.0001
        gamma_upper = max_metric_value
        gamma_init = ExponentialDecayResourcesKernelFunction._wrap_initvals(
            gamma_init, gamma_lower, gamma_upper
        )
        self.encoding_gamma = create_encoding(
            encoding_type, gamma_init, gamma_lower, gamma_upper, 1, None
        )
        with self.name_scope():
            self.alpha_internal = register_parameter(
                self.params, "alpha", self.encoding_alpha
            )
            self.mean_lam_internal = register_parameter(
                self.params, "mean_lam", self.encoding_mean_lam
            )
            self.gamma_internal = register_parameter(
                self.params, "gamma", self.encoding_gamma
            )

    def _compute_terms(self, X, alpha, mean_lam, ret_mean=False):
        """
        Split inputs into configs and resources, and compute kappa(r).

        :param X: Inputs (x, r), one per row; the first d columns are the
            config x, the remaining column is the resource r
        :param alpha: Parameter alpha
        :param mean_lam: Parameter mean_lam
        :param ret_mean: If True, also compute mu_x(x) for the configs
        :return: (cfg, res, kappa, mean), where mean is None unless
            ``ret_mean`` is True
        """
        dim = self.kernel_x.dimension
        cfg = X[:, :dim]
        res = X[:, dim:]
        kappa = ExponentialDecayResourcesKernelFunction._compute_kappa(
            res, alpha, mean_lam
        )
        if ret_mean:
            mean = self.mean_x(cfg)
        else:
            mean = None
        return cfg, res, kappa, mean

    def _get_params(self, X, **kwargs):
        """
        Resolve the current values of (alpha, mean_lam, gamma).

        Each value is obtained through
        ``ExponentialDecayResourcesKernelFunction._unwrap``, which is passed
        ``kwargs`` (presumably so callers can override values; confirm there).

        :return: Tuple (alpha, mean_lam, gamma)
        """
        alpha = ExponentialDecayResourcesKernelFunction._unwrap(
            X, kwargs, "alpha", self.encoding_alpha, self.alpha_internal
        )
        mean_lam = ExponentialDecayResourcesKernelFunction._unwrap(
            X, kwargs, "mean_lam", self.encoding_mean_lam, self.mean_lam_internal
        )
        gamma = ExponentialDecayResourcesKernelFunction._unwrap(
            X, kwargs, "gamma", self.encoding_gamma, self.gamma_internal
        )
        return (alpha, mean_lam, gamma)

    @staticmethod
    def _to_tuples(cfg):
        """
        Convert each row of the 2D array ``cfg`` into a tuple, so that rows
        can be hashed and compared by value.

        :param cfg: Configs, one per row
        :return: List with one tuple per row of ``cfg`` (empty if ``cfg`` has
            no rows)
        """
        num_rows = getval(cfg.shape[0])
        if num_rows == 0:
            # ``anp.split`` cannot split an array into zero sections
            return []
        return [tuple(anp.ravel(x)) for x in anp.split(cfg, num_rows, axis=0)]

    def forward(self, X1, X2, **kwargs):
        """
        Kernel matrix between inputs X1 and X2.

        Entry (i, j) is k_x(x1_i, x2_j), plus, only if the configs x1_i and
        x2_j are identical, the centralized exponential decay term over
        resources:

            gamma^2 * (kappa(r1_i + r2_j) - kappa(r1_i) * kappa(r2_j))

        :param X1: Inputs (x, r), one per row
        :param X2: Inputs (x, r), one per row. If this is the same object as
            X1, terms computed for X1 are reused
        :param kwargs: Passed on when resolving alpha, mean_lam, gamma
        :return: Kernel matrix with one row per row of X1 and one column per
            row of X2
        """
        alpha, mean_lam, gamma = self._get_params(X1, **kwargs)
        gamma = anp.reshape(gamma, (1, 1))
        cfg1, res1, kappa1, _ = self._compute_terms(X1, alpha, mean_lam)
        cfg1_tpls = self._to_tuples(cfg1)
        if X2 is not X1:
            cfg2, res2, kappa2, _ = self._compute_terms(X2, alpha, mean_lam)
            cfg2_tpls = self._to_tuples(cfg2)
        else:
            cfg2, res2, kappa2, cfg2_tpls = cfg1, res1, kappa1, cfg1_tpls
        # Give every distinct config an integer index. The resource term is
        # masked to pairs with equal index, so that learning curves of
        # different configs are independent
        cfg_map = {
            tpl: idx for idx, tpl in enumerate(set(cfg1_tpls).union(cfg2_tpls))
        }
        cfg1_ind = anp.reshape(anp.array([cfg_map[x] for x in cfg1_tpls]), (-1, 1))
        cfg2_ind = anp.reshape(anp.array([cfg_map[x] for x in cfg2_tpls]), (1, -1))
        # res1 / kappa1 are columns, res2 / kappa2 become rows, so the
        # operations below broadcast over all pairs (i, j)
        res2 = anp.reshape(res2, (1, -1))
        kappa2 = anp.reshape(kappa2, (1, -1))
        kappa12 = ExponentialDecayResourcesKernelFunction._compute_kappa(
            anp.add(res1, res2), alpha, mean_lam
        )
        kmat_res = anp.subtract(kappa12, anp.multiply(kappa1, kappa2))
        kmat_res = kmat_res * anp.square(gamma)
        kmat_res = kmat_res * (cfg1_ind == cfg2_ind)
        kmat_x = self.kernel_x(cfg1, cfg2)
        return kmat_x + kmat_res

    def diagonal(self, X):
        """
        Diagonal of the kernel matrix ``forward(X, X)``.

        On the diagonal the configs always coincide, so the resource term is
        gamma^2 * (kappa(2 r) - kappa(r)^2).

        :param X: Inputs (x, r), one per row
        :return: Flat vector of diagonal entries
        """
        alpha, mean_lam, gamma = self._get_params(X)
        gamma = anp.reshape(gamma, (1, 1))
        cfg, res, kappa, _ = self._compute_terms(X, alpha, mean_lam)
        kappa2 = ExponentialDecayResourcesKernelFunction._compute_kappa(
            res * 2, alpha, mean_lam
        )
        kdiag_res = anp.subtract(kappa2, anp.square(kappa))
        kdiag_res = anp.reshape(kdiag_res * anp.square(gamma), (-1,))
        kdiag_x = self.kernel_x.diagonal(cfg)
        return kdiag_x + kdiag_res

    def diagonal_depends_on_X(self):
        """The diagonal depends on the resource values, hence on X."""
        return True

    def param_encoding_pairs(self):
        """
        :return: (parameter, encoding) pairs of this kernel (alpha, mean_lam,
            gamma), followed by those of kernel_x and of mean_x
        """
        enc_list = [
            (self.alpha_internal, self.encoding_alpha),
            (self.mean_lam_internal, self.encoding_mean_lam),
            (self.gamma_internal, self.encoding_gamma),
        ]
        enc_list.extend(self.kernel_x.param_encoding_pairs())
        enc_list.extend(self.mean_x.param_encoding_pairs())
        return enc_list

    def mean_function(self, X):
        """
        Mean function over inputs (x, r): mu_x(x) + gamma * kappa(r).

        :param X: Inputs (x, r), one per row
        :return: Mean values for the rows of X
        """
        alpha, mean_lam, gamma = self._get_params(X)
        gamma = anp.reshape(gamma, (1, 1))
        cfg, res, kappa, mean = self._compute_terms(X, alpha, mean_lam, ret_mean=True)
        return anp.add(mean, anp.multiply(kappa, gamma))

    def get_params(self) -> Dict[str, Any]:
        """
        Parameter keys are alpha, mean_lam, gamma, as well as those of
        self.kernel_x (prefix 'kernelx_') and of self.mean_x (prefix
        'meanx_').
        """
        values = list(self._get_params(None))
        keys = ["alpha", "mean_lam", "gamma"]
        result = {k: anp.reshape(v, (1,))[0] for k, v in zip(keys, values)}
        for pref, func in [("kernelx_", self.kernel_x), ("meanx_", self.mean_x)]:
            result.update({(pref + k): v for k, v in func.get_params().items()})
        return result

    def set_params(self, param_dict: Dict[str, Any]):
        """
        Set parameters from a dictionary in the format returned by
        :meth:`get_params`. Keys with prefix 'kernelx_' / 'meanx_' are
        forwarded, with the prefix stripped, to kernel_x / mean_x.

        :param param_dict: Must contain 'alpha', 'mean_lam', 'gamma'
        """
        for pref, func in [("kernelx_", self.kernel_x), ("meanx_", self.mean_x)]:
            len_pref = len(pref)
            stripped_dict = {
                k[len_pref:]: v for k, v in param_dict.items() if k.startswith(pref)
            }
            func.set_params(stripped_dict)
        self.encoding_alpha.set(self.alpha_internal, param_dict["alpha"])
        self.encoding_mean_lam.set(self.mean_lam_internal, param_dict["mean_lam"])
        self.encoding_gamma.set(self.gamma_internal, param_dict["gamma"])
class FreezeThawMeanFunction(MeanFunction):
    """
    Mean function belonging to :class:`FreezeThawKernelFunction`.

    It has no parameters of its own: :meth:`forward` delegates to
    ``kernel.mean_function``, and the parameters involved are owned by the
    kernel.
    """

    def __init__(self, kernel: FreezeThawKernelFunction, **kwargs):
        """
        :param kernel: Kernel whose ``mean_function`` is evaluated
        :param kwargs: Passed to the :class:`MeanFunction` constructor
        :raises TypeError: If ``kernel`` is not a FreezeThawKernelFunction
        """
        super().__init__(**kwargs)
        # Explicit check rather than ``assert``, which is stripped under -O
        if not isinstance(kernel, FreezeThawKernelFunction):
            raise TypeError(
                f"kernel must be a FreezeThawKernelFunction, got {type(kernel)}"
            )
        self.kernel = kernel

    def forward(self, X):
        """
        :param X: Inputs (x, r), one per row
        :return: ``self.kernel.mean_function(X)``
        """
        return self.kernel.mean_function(X)

    def param_encoding_pairs(self):
        """No parameters of its own; they are owned by the kernel."""
        return []

    def get_params(self) -> Dict[str, Any]:
        """No parameters of its own; they are owned by the kernel."""
        return dict()

    def set_params(self, param_dict: Dict[str, Any]):
        """No parameters of its own, so this is a no-op."""
        pass