# Source code for syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.custom_op

# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import autograd.numpy as anp
import autograd.scipy.linalg as aspl
from autograd.extend import primitive, defvjp
import numpy as np
import scipy.linalg as spl
import logging
import math

logger = logging.getLogger(__name__)

# Names exported by ``from ... import *``.
__all__ = [
    "AddJitterOp",
    "flatten_and_concat",
    "cholesky_factorization",
]


# Default settings of the jitter search in AddJitterOp:
# - the first nonzero jitter tried is INITIAL_JITTER_FACTOR * max(1, mean(diag(x)))
# - every further attempt multiplies the jitter by JITTER_GROWTH
# - the search gives up once the jitter exceeds
#   JITTER_UPPERBOUND_FACTOR * max(1, mean(diag(x)))
INITIAL_JITTER_FACTOR = 1.0e-9
JITTER_GROWTH = 1.0e1
JITTER_UPPERBOUND_FACTOR = 1.0e3


def flatten_and_concat(x: anp.ndarray, sigsq_init: anp.ndarray):
    """
    Flatten ``x`` into a 1-D array and append ``sigsq_init`` at the end.

    The result has the layout that ``AddJitterOp`` unpacks: all entries of
    ``x`` (row-major) followed by ``sigsq_init`` as the final entry.

    :param x: Array to be flattened (a square matrix for ``AddJitterOp``)
    :param sigsq_init: Value(s) appended after the flattened ``x``
    :return: 1-D array ``[x.ravel(), sigsq_init]``
    """
    flat_x = anp.reshape(x, (-1,))
    return anp.append(flat_x, sigsq_init)
@primitive
def AddJitterOp(
    inputs: np.ndarray,
    initial_jitter_factor=INITIAL_JITTER_FACTOR,
    jitter_growth=JITTER_GROWTH,
    debug_log="false",
):
    """
    Finds the smallest jitter (among a geometric sequence of candidates) to
    add to the diagonal of a square matrix to render the matrix positive
    definite (in that linalg.potrf works).

    Given input x (positive semi-definite matrix) and ``sigsq_init``
    (nonneg scalar), find ``sigsq_final`` (nonneg scalar), so that:

    | ``sigsq_final = sigsq_init + jitter``, ``jitter >= 0``,
    | ``x + sigsq_final * Id`` positive definite (so that ``potrf`` call works)

    We return the matrix ``x + sigsq_final * Id``, for which ``potrf`` has not
    failed. For the gradient, the dependence of jitter on the inputs is
    ignored.

    The values tried for sigsq_final are:

    | ``sigsq_init, sigsq_init + initial_jitter * (jitter_growth ** k)``,
      ``k = 0, 1, 2, ...``,
    | ``initial_jitter = initial_jitter_factor * max(mean(diag(x)), 1)``

    The search stops (with an error) once ``jitter`` exceeds
    ``JITTER_UPPERBOUND_FACTOR * max(mean(diag(x)), 1)``.

    Note: The scaling of initial_jitter with ``mean(diag(x))`` is taken from
    ``GPy``. The rationale is that the largest eigenvalue of x is
    ``>= mean(diag(x))``, and likely of this magnitude.

    There is no guarantee that the Cholesky factor returned is
    well-conditioned enough for subsequent computations to be reliable. A
    better solution would be to estimate the condition number of the Cholesky
    factor, and to add jitter until this is bounded below a threshold we
    tolerate. See

    | Higham, N.
    | A Survey of Condition Number Estimation for Triangular Matrices
    | MIMS EPrint: 2007.10

    Algorithm 4.1 could work for us.

    :param inputs: 1-D array holding the ``n * n`` entries of ``x`` followed
        by ``sigsq_init`` as its last entry (the layout produced by
        ``flatten_and_concat``)
    :param initial_jitter_factor: Positive factor defining the first nonzero
        jitter
    :param jitter_growth: Factor (> 1) by which the jitter is multiplied after
        each failed Cholesky attempt
    :param debug_log: If equal to the string ``"true"``, the sigsq values
        tried are logged
    :return: Matrix ``x + sigsq_final * Id`` of shape ``(n, n)``
    :raises AssertionError: If ``initial_jitter_factor <= 0`` or
        ``jitter_growth <= 1``, if ``inputs`` does not encode a nonempty
        square matrix, or if the jitter exceeds its upper bound before the
        Cholesky factorization succeeds
    """
    # Explicit raises (instead of ``assert``) so that these checks survive
    # ``python -O``: with ``jitter_growth <= 1`` the loop below might never
    # terminate, and on exhaustion we must not return a non-PD matrix.
    if not (initial_jitter_factor > 0.0 and jitter_growth > 1.0):
        raise AssertionError(
            "initial_jitter_factor must be > 0 and jitter_growth must be > 1"
        )
    n_square = inputs.shape[0] - 1
    # Round (rather than truncate) the square root, and guard n_square <= 0,
    # which would otherwise fail in math.sqrt or divide by zero below
    n = int(round(math.sqrt(n_square))) if n_square > 0 else 0
    if n == 0 or n * n != n_square:
        raise AssertionError("x must be square matrix, shape (n, n)")
    x = np.reshape(inputs[:-1], (n, n))
    sigsq_init = inputs[-1]

    # Scale of x, captured via the mean of its diagonal (at least 1). It sets
    # the first nonzero jitter and a generous safeguard upper bound on the
    # jitter, whose primary goal is to avoid an endless while-loop.
    scale = max(1.0, np.mean(np.diag(x)))
    initial_jitter = initial_jitter_factor * scale
    jitter_upperbound = JITTER_UPPERBOUND_FACTOR * scale

    jitter = 0.0
    while jitter <= jitter_upperbound:
        x_plus_constant = x + np.diag(np.ones(n) * (sigsq_init + jitter))
        try:
            # Note: Do not use np.linalg.cholesky here, this can cause
            # locking issues
            spl.cholesky(x_plus_constant, lower=True)
        except spl.LinAlgError:
            if debug_log == "true":
                logger.info("sigsq = %s does not work", sigsq_init + jitter)
            jitter = initial_jitter if jitter == 0.0 else jitter * jitter_growth
        else:
            if debug_log == "true":
                logger.info("sigsq_final = %s", sigsq_init + jitter)
            return x_plus_constant
    raise AssertionError(
        "The jitter ({}) has reached its upperbound ({}) while the Cholesky "
        "of the input matrix still cannot be computed.".format(
            jitter, jitter_upperbound
        )
    )


def AddJitterOp_vjp(
    ans: np.ndarray,
    inputs: np.ndarray,
    initial_jitter_factor=INITIAL_JITTER_FACTOR,
    jitter_growth=JITTER_GROWTH,
    debug_log="false",
):
    """
    Vector-Jacobian product for ``AddJitterOp``.

    With the jitter treated as a constant, the output is
    ``x + sigsq_init * Id`` (plus a constant). So the cotangent of the
    entries of ``x`` is ``g`` itself (flattened), and the cotangent of
    ``sigsq_init`` is the sum of the diagonal of ``g``. These are packed in
    the same layout as ``inputs``.
    """
    return lambda g: anp.append(anp.reshape(g, (-1,)), anp.sum(anp.diag(g)))


defvjp(AddJitterOp, AddJitterOp_vjp)


@primitive
def cholesky_factorization(a):
    """
    Replacement for :func:`autograd.numpy.linalg.cholesky`. Our backward
    (vjp) is faster and simpler, while somewhat less general (only works if
    ``a.ndim == 2``).

    See https://arxiv.org/abs/1710.08717 for derivation of backward (vjp)
    expression.

    :param a: Symmetric positive definite matrix A
    :return: Lower-triangular Cholesky factor L of A
    """
    # Note: Do not use np.linalg.cholesky here, this can cause locking issues
    return spl.cholesky(a, lower=True)


def copyltu(x):
    """
    Return the symmetric matrix whose lower triangle (including the diagonal)
    equals that of ``x``; the upper triangle of ``x`` is ignored.
    """
    return anp.tril(x) + anp.transpose(anp.tril(x, -1))


def cholesky_factorization_backward(l, lbar):
    """
    Backward map for ``l = cholesky(a)``. Given the cotangent ``lbar`` of
    ``l``, returns the cotangent of ``a``:

        abar = 0.5 * l^{-T} copyltu(l^T lbar) l^{-1}
    """
    abar = copyltu(anp.matmul(anp.transpose(l), lbar))
    # copyltu(...) is symmetric, so these two triangular solves compute
    # l^{-T} abar l^{-1} without explicitly inverting l
    abar = anp.transpose(aspl.solve_triangular(l, abar, lower=True, trans="T"))
    abar = aspl.solve_triangular(l, abar, lower=True, trans="T")
    return 0.5 * abar


def cholesky_factorization_vjp(l, a):
    """
    Vector-Jacobian product for ``cholesky_factorization``; ``l`` is the
    forward output (Cholesky factor), ``a`` the forward input.
    """
    return lambda lbar: cholesky_factorization_backward(l, lbar)


defvjp(cholesky_factorization, cholesky_factorization_vjp)