# Source: syne_tune.optimizer.schedulers.searchers.bayesopt.gpautograd.custom_op

import autograd.numpy as anp
import autograd.scipy.linalg as aspl
from autograd.extend import primitive, defvjp
import numpy as np
import scipy.linalg as spl
import logging
import math

logger = logging.getLogger(__name__)

# Public API of this module.
__all__ = [
    "AddJitterOp",
    "flatten_and_concat",
    "cholesky_factorization",
]


# Default for ``initial_jitter_factor`` in ``AddJitterOp``: the first non-zero
# jitter tried is this factor times ``max(1, mean(diag(x)))``.
INITIAL_JITTER_FACTOR = 1.0e-9
# Default for ``jitter_growth`` in ``AddJitterOp``: each further attempt
# multiplies the current jitter by this factor.
JITTER_GROWTH = 1e1
# ``AddJitterOp`` stops increasing the jitter once it exceeds this factor
# times ``max(1, mean(diag(x)))``.
JITTER_UPPERBOUND_FACTOR = 1000.0


def flatten_and_concat(x: anp.ndarray, sigsq_init: anp.ndarray):
    """
    Flattens ``x`` into a vector and appends ``sigsq_init`` at its end.

    For a square matrix ``x`` of shape ``(n, n)`` and a scalar ``sigsq_init``,
    the result is the vector ``[vec(x), sigsq_init]`` of length ``n * n + 1``.
    This is the input layout expected by :func:`AddJitterOp`.

    :param x: Array to be flattened (default row-major order of ``reshape``)
    :param sigsq_init: Value(s) appended after the flattened ``x``
    :return: 1D array
    """
    return anp.append(anp.reshape(x, (-1,)), sigsq_init)
@primitive
def AddJitterOp(
    inputs: np.ndarray,
    initial_jitter_factor=INITIAL_JITTER_FACTOR,
    jitter_growth=JITTER_GROWTH,
    debug_log="false",
):
    """
    Finds the smallest jitter, among the candidate values listed below, to add
    to the diagonal of a square matrix to render the matrix positive definite
    (in that its Cholesky factorization works).

    Given input x (positive semi-definite matrix) and ``sigsq_init`` (nonneg
    scalar), find ``sigsq_final`` (nonneg scalar), so that:

    | ``sigsq_final = sigsq_init + jitter``, ``jitter >= 0``,
    | ``x + sigsq_final * Id`` positive definite (so that the Cholesky
      factorization works)

    We return the matrix ``x + sigsq_final * Id``, for which the Cholesky
    factorization has not failed. For the gradient, the dependence of jitter
    on the inputs is ignored.

    The values tried for sigsq_final are:

    | ``sigsq_init, sigsq_init + initial_jitter * (jitter_growth ** k)``,
      ``k = 0, 1, 2, ...``,
    | ``initial_jitter = initial_jitter_factor * max(mean(diag(x)), 1)``

    The jitter is not increased beyond
    ``JITTER_UPPERBOUND_FACTOR * max(mean(diag(x)), 1)``. If the factorization
    still fails at that point, an ``AssertionError`` is raised.

    Note: The scaling of initial_jitter with ``mean(diag(x))`` is taken from
    ``GPy``. The rationale is that the largest eigenvalue of x is
    ``>= mean(diag(x))``, and likely of this magnitude.

    There is no guarantee that the Cholesky factor of the returned matrix is
    well-conditioned enough for subsequent computations to be reliable. A
    better solution would be to estimate the condition number of the Cholesky
    factor, and to add jitter until this is bounded below a threshold we
    tolerate. See

    | Higham, N.
    | A Survey of Condition Number Estimation for Triangular Matrices
    | MIMS EPrint: 2007.10

    Algorithm 4.1 could work for us.

    :param inputs: Vector ``[vec(x), sigsq_init]`` of length ``n * n + 1`` with
        ``n >= 1``, as produced by :func:`flatten_and_concat`. ``x`` is
        recovered by reshaping the first ``n * n`` entries to ``(n, n)``
    :param initial_jitter_factor: Scales the first non-zero jitter, must be
        positive
    :param jitter_growth: Factor by which the jitter grows between attempts,
        must be ``> 1``
    :param debug_log: If ``"true"``, log each ``sigsq`` value that fails and
        the final value
    :return: Matrix ``x + sigsq_final * Id`` of shape ``(n, n)``
    """
    assert initial_jitter_factor > 0.0 and jitter_growth > 1.0
    n_square = inputs.shape[0] - 1
    # Guard against ``n_square <= 0``: the previous ``n_square % n`` check
    # divided by zero for inputs of length 1, and ``math.sqrt`` fails for
    # inputs of length 0
    n = int(math.sqrt(n_square)) if n_square > 0 else 0
    assert n > 0 and n * n == n_square, "x must be square matrix, shape (n, n)"
    x, sigsq_init = np.reshape(inputs[:-1], (n, -1)), inputs[-1]

    # Scale of the input, captured via the mean of its diagonal. It scales both
    # the first non-zero jitter and the safeguard upper bound. The bound is
    # generous; its primary goal is to avoid an infinite loop.
    diag_scale = max(1.0, np.mean(np.diag(x)))
    jitter_upperbound = JITTER_UPPERBOUND_FACTOR * diag_scale

    jitter = 0.0
    must_increase_jitter = True
    x_plus_constant = None
    while must_increase_jitter and jitter <= jitter_upperbound:
        x_plus_constant = x + np.diag(np.ones((n,)) * (sigsq_init + jitter))
        try:
            # Note: Do not use np.linalg.cholesky here, this can cause
            # locking issues. The factor itself is not needed; we only test
            # whether the factorization succeeds.
            spl.cholesky(x_plus_constant, lower=True)
            must_increase_jitter = False
        except spl.LinAlgError:
            if debug_log == "true":
                logger.info("sigsq = %s does not work", sigsq_init + jitter)
            if jitter == 0.0:
                jitter = initial_jitter_factor * diag_scale
            else:
                jitter = jitter * jitter_growth

    assert not must_increase_jitter, (
        "The jitter ({}) has reached its upperbound ({}) while the Cholesky of "
        "the input matrix still cannot be computed.".format(jitter, jitter_upperbound)
    )
    if debug_log == "true":
        logger.info("sigsq_final = %s", sigsq_init + jitter)
    return x_plus_constant


def AddJitterOp_vjp(
    ans: np.ndarray,
    inputs: np.ndarray,
    initial_jitter_factor=INITIAL_JITTER_FACTOR,
    jitter_growth=JITTER_GROWTH,
    debug_log="false",
):
    """
    Vector-Jacobian product of :func:`AddJitterOp`.

    The output is ``x + sigsq_final * Id``, and the dependence of the jitter on
    the inputs is ignored. For an output gradient ``g`` of shape ``(n, n)``,
    the gradient w.r.t. ``inputs`` is therefore ``[vec(g), trace(g)]``: ``vec(g)``
    for ``x`` and ``trace(g)`` for ``sigsq_init``.
    """
    return lambda g: anp.append(anp.reshape(g, (-1,)), anp.sum(anp.diag(g)))


defvjp(AddJitterOp, AddJitterOp_vjp)


@primitive
def cholesky_factorization(a):
    """
    Replacement for :func:`autograd.numpy.linalg.cholesky`. Our backward
    (vjp) is faster and simpler, while somewhat less general (only works if
    ``a.ndim == 2``).

    See https://arxiv.org/abs/1710.08717 for derivation of backward (vjp)
    expression.

    :param a: Symmetric positive definite matrix A
    :return: Lower-triangular Cholesky factor L of A
    """
    # Note: Do not use np.linalg.cholesky here, this can cause locking issues
    return spl.cholesky(a, lower=True)


def copyltu(x):
    """
    Returns the symmetric matrix obtained by keeping the lower triangle
    (including the diagonal) of ``x`` and mirroring its strictly lower triangle
    into the upper triangle.
    """
    return anp.tril(x) + anp.transpose(anp.tril(x, -1))


def cholesky_factorization_backward(l, lbar):
    """
    Backward pass of :func:`cholesky_factorization`.

    Computes ``abar = 0.5 * L^{-T} copyltu(L^T lbar) L^{-1}``.

    :param l: Lower-triangular Cholesky factor ``L`` computed in the forward
    :param lbar: Gradient w.r.t. ``L``
    :return: Gradient w.r.t. the input matrix ``A``
    """
    abar = copyltu(anp.matmul(anp.transpose(l), lbar))
    # (L^{-T} abar)^T = abar L^{-1}, since abar is symmetric after copyltu
    abar = anp.transpose(aspl.solve_triangular(l, abar, lower=True, trans="T"))
    # L^{-T} (abar L^{-1})
    abar = aspl.solve_triangular(l, abar, lower=True, trans="T")
    return 0.5 * abar


def cholesky_factorization_vjp(l, a):
    # autograd passes the forward output first (here the Cholesky factor
    # ``l``), followed by the forward input ``a``
    return lambda lbar: cholesky_factorization_backward(l, lbar)


defvjp(cholesky_factorization, cholesky_factorization_vjp)