Source code for syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.kernel.range_kernel

# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Dict, Any
from syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.kernel.base import (
    KernelFunction,
)


class RangeKernelFunction(KernelFunction):
    r"""
    Kernel function which evaluates a wrapped kernel ``K`` on a contiguous
    range ``R`` of input columns only:

    .. math::

       (x, y) \mapsto K(x_R, y_R)

    All parameter handling is delegated to the wrapped kernel.
    """

    def __init__(self, dimension: int, kernel: KernelFunction, start: int, **kwargs):
        """
        :param dimension: Input dimension
        :param kernel: Kernel function K, applied to the selected columns
        :param start: First column of the range, which is
            ``range(start, start + kernel.dimension)``. The range must lie
            within ``range(dimension)``
        """
        super().__init__(dimension, **kwargs)
        stop = start + kernel.dimension
        assert 0 <= start and stop <= dimension, (
            start,
            dimension,
            kernel.dimension,
        )
        self.start = start
        self.kernel = kernel

    def forward(self, X1, X2):
        """
        Evaluates the wrapped kernel on columns
        ``start : start + kernel.dimension`` of ``X1`` and ``X2``.

        :param X1: First input; indexed as ``X1[:, a:b]``
        :param X2: Second input; may be the same object as ``X1``
        :return: Result of calling the wrapped kernel on the two column slices
        """
        columns = slice(self.start, self.start + self.kernel.dimension)
        lhs = X1[:, columns]
        # Hand the very same object to the wrapped kernel when both inputs
        # are identical, so that an identity check there still succeeds
        rhs = lhs if X2 is X1 else X2[:, columns]
        return self.kernel(lhs, rhs)

    def diagonal(self, X):
        """
        :param X: Input; indexed as ``X[:, a:b]``
        :return: ``diagonal`` of the wrapped kernel on the selected columns
        """
        first = self.start
        return self.kernel.diagonal(X[:, first : first + self.kernel.dimension])

    def diagonal_depends_on_X(self):
        """
        :return: Value reported by the wrapped kernel's
            ``diagonal_depends_on_X``
        """
        return self.kernel.diagonal_depends_on_X()

    def param_encoding_pairs(self):
        """
        :return: The parameter/encoding pairs of the wrapped kernel, unchanged
        """
        return self.kernel.param_encoding_pairs()

    def get_params(self) -> Dict[str, Any]:
        """
        :return: Parameter dictionary of the wrapped kernel
        """
        return self.kernel.get_params()

    def set_params(self, param_dict: Dict[str, Any]):
        """
        :param param_dict: Parameter dictionary, passed on to the wrapped
            kernel's ``set_params``
        """
        self.kernel.set_params(param_dict)