import autograd.numpy as anp
import autograd.scipy.linalg as aspl
from autograd.extend import primitive, defvjp
import numpy as np
import scipy.linalg as spl
import logging
import math
logger = logging.getLogger(__name__)
__all__ = ["AddJitterOp", "flatten_and_concat", "cholesky_factorization"]
INITIAL_JITTER_FACTOR = 1e-9
JITTER_GROWTH = 10.0
JITTER_UPPERBOUND_FACTOR = 1e3
[docs]
def flatten_and_concat(x: anp.ndarray, sigsq_init: anp.ndarray):
return anp.append(anp.reshape(x, (-1,)), sigsq_init)
@primitive
def AddJitterOp(
inputs: np.ndarray,
initial_jitter_factor=INITIAL_JITTER_FACTOR,
jitter_growth=JITTER_GROWTH,
debug_log="false",
):
"""
Finds smaller jitter to add to diagonal of square matrix to render the
matrix positive definite (in that linalg.potrf works).
Given input x (positive semi-definite matrix) and ``sigsq_init`` (nonneg
scalar), find ``sigsq_final`` (nonneg scalar), so that:
| ``sigsq_final = sigsq_init + jitter``, ``jitter >= 0``,
| ``x + sigsq_final * Id`` positive definite (so that ``potrf`` call works)
We return the matrix ``x + sigsq_final * Id``, for which ``potrf`` has not failed.
For the gradient, the dependence of jitter on the inputs is ignored.
The values tried for sigsq_final are:
| ``sigsq_init, sigsq_init + initial_jitter * (jitter_growth ** k)``,
``k = 0, 1, 2, ...``,
| ``initial_jitter = initial_jitter_factor * max(mean(diag(x)), 1)``
Note: The scaling of initial_jitter with ``mean(diag(x))`` is taken from ``GPy``.
The rationale is that the largest eigenvalue of x is ``>= mean(diag(x))``, and
likely of this magnitude.
There is no guarantee that the Cholesky factor returned is well-conditioned
enough for subsequent computations to be reliable. A better solution
would be to estimate the condition number of the Cholesky factor, and to add
jitter until this is bounded below a threshold we tolerate. See
| Higham, N.
| A Survey of Condition Number Estimation for Triangular Matrices
| MIMS EPrint: 2007.10
Algorithm 4.1 could work for us.
"""
assert initial_jitter_factor > 0.0 and jitter_growth > 1.0
n_square = inputs.shape[0] - 1
n = int(math.sqrt(n_square))
assert (
n_square % n == 0 and n_square // n == n
), "x must be square matrix, shape (n, n)"
x, sigsq_init = np.reshape(inputs[:-1], (n, -1)), inputs[-1]
def _get_constant_identity(x, constant):
n, _ = x.shape
return np.diag(np.ones((n,)) * constant)
def _get_jitter_upperbound(x):
# To define a safeguard in the while-loop of the forward,
# we define an upperbound on the jitter we can reasonably add
# the bound is quite generous, and is dependent on the scale of the input x
# (the scale is captured via the trace of x)
# the primary goal is avoid any infinite while-loop.
return JITTER_UPPERBOUND_FACTOR * max(1.0, np.mean(np.diag(x)))
jitter = 0.0
jitter_upperbound = _get_jitter_upperbound(x)
must_increase_jitter = True
x_plus_constant = None
while must_increase_jitter and jitter <= jitter_upperbound:
try:
x_plus_constant = x + _get_constant_identity(x, sigsq_init + jitter)
# Note: Do not use np.linalg.cholesky here, this can cause
# locking issues
L = spl.cholesky(x_plus_constant, lower=True)
must_increase_jitter = False
except spl.LinAlgError:
if debug_log == "true":
logger.info("sigsq = {} does not work".format(sigsq_init + jitter))
if jitter == 0.0:
jitter = initial_jitter_factor * max(1.0, np.mean(np.diag(x)))
else:
jitter = jitter * jitter_growth
assert (
not must_increase_jitter
), "The jitter ({}) has reached its upperbound ({}) while the Cholesky of the input matrix still cannot be computed.".format(
jitter, jitter_upperbound
)
if debug_log == "true":
logger.info("sigsq_final = {}".format(sigsq_init + jitter))
return x_plus_constant
def AddJitterOp_vjp(
ans: np.ndarray,
inputs: np.ndarray,
initial_jitter_factor=INITIAL_JITTER_FACTOR,
jitter_growth=JITTER_GROWTH,
debug_log="false",
):
return lambda g: anp.append(anp.reshape(g, (-1,)), anp.sum(anp.diag(g)))
defvjp(AddJitterOp, AddJitterOp_vjp)
@primitive
def cholesky_factorization(a):
"""
Replacement for :func:`autograd.numpy.linalg.cholesky`. Our backward (vjp)
is faster and simpler, while somewhat less general (only works if
``a.ndim == 2``).
See https://arxiv.org/abs/1710.08717 for derivation of backward (vjp)
expression.
:param a: Symmmetric positive definite matrix A
:return: Lower-triangular Cholesky factor L of A
"""
# Note: Do not use np.linalg.cholesky here, this can cause locking issues
return spl.cholesky(a, lower=True)
def copyltu(x):
return anp.tril(x) + anp.transpose(anp.tril(x, -1))
def cholesky_factorization_backward(l, lbar):
abar = copyltu(anp.matmul(anp.transpose(l), lbar))
abar = anp.transpose(aspl.solve_triangular(l, abar, lower=True, trans="T"))
abar = aspl.solve_triangular(l, abar, lower=True, trans="T")
return 0.5 * abar
def cholesky_factorization_vjp(l, a):
return lambda lbar: cholesky_factorization_backward(l, lbar)
defvjp(cholesky_factorization, cholesky_factorization_vjp)