# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import autograd.numpy as anp
import autograd.scipy.linalg as aspl
from autograd.extend import primitive, defvjp
import numpy as np
import scipy.linalg as spl
import logging
import math
logger = logging.getLogger(__name__)
__all__ = ["AddJitterOp", "flatten_and_concat", "cholesky_factorization"]
INITIAL_JITTER_FACTOR = 1e-9
JITTER_GROWTH = 10.0
JITTER_UPPERBOUND_FACTOR = 1e3
def flatten_and_concat(x: anp.ndarray, sigsq_init: anp.ndarray):
    # Pack the matrix x (flattened) and the scalar sigsq_init into a single
    # vector, which is the input format expected by AddJitterOp
    return anp.append(anp.reshape(x, (-1,)), sigsq_init)
@primitive
def AddJitterOp(
inputs: np.ndarray,
initial_jitter_factor=INITIAL_JITTER_FACTOR,
jitter_growth=JITTER_GROWTH,
debug_log="false",
):
"""
    Finds a small jitter to add to the diagonal of a square matrix so that the
    matrix becomes positive definite (i.e., so that ``linalg.potrf`` succeeds).

    Given input ``x`` (positive semi-definite matrix) and ``sigsq_init``
    (non-negative scalar), find ``sigsq_final`` (non-negative scalar) so that:

    | ``sigsq_final = sigsq_init + jitter``, ``jitter >= 0``,
    | ``x + sigsq_final * Id`` positive definite (so that the ``potrf`` call works)

    We return the matrix ``x + sigsq_final * Id``, for which ``potrf`` does not
    fail. For the gradient, the dependence of the jitter on the inputs is ignored.
    The values tried for ``sigsq_final`` are:

    | ``sigsq_init, sigsq_init + initial_jitter * (jitter_growth ** k)``,
      ``k = 0, 1, 2, ...``,
    | ``initial_jitter = initial_jitter_factor * max(mean(diag(x)), 1)``
Note: The scaling of initial_jitter with ``mean(diag(x))`` is taken from ``GPy``.
The rationale is that the largest eigenvalue of x is ``>= mean(diag(x))``, and
likely of this magnitude.
There is no guarantee that the Cholesky factor returned is well-conditioned
enough for subsequent computations to be reliable. A better solution
would be to estimate the condition number of the Cholesky factor, and to add
jitter until this is bounded below a threshold we tolerate. See
| Higham, N.
| A Survey of Condition Number Estimation for Triangular Matrices
| MIMS EPrint: 2007.10
Algorithm 4.1 could work for us.
"""
assert initial_jitter_factor > 0.0 and jitter_growth > 1.0
n_square = inputs.shape[0] - 1
n = int(math.sqrt(n_square))
assert (
n_square % n == 0 and n_square // n == n
), "x must be square matrix, shape (n, n)"
x, sigsq_init = np.reshape(inputs[:-1], (n, -1)), inputs[-1]
def _get_constant_identity(x, constant):
n, _ = x.shape
return np.diag(np.ones((n,)) * constant)
    def _get_jitter_upperbound(x):
        # To safeguard the while-loop of the forward pass, we define an upper
        # bound on the jitter we can reasonably add. The bound is quite
        # generous and depends on the scale of the input x (captured via the
        # mean of its diagonal). The primary goal is to avoid an infinite
        # while-loop.
        return JITTER_UPPERBOUND_FACTOR * max(1.0, np.mean(np.diag(x)))
jitter = 0.0
jitter_upperbound = _get_jitter_upperbound(x)
must_increase_jitter = True
x_plus_constant = None
while must_increase_jitter and jitter <= jitter_upperbound:
try:
x_plus_constant = x + _get_constant_identity(x, sigsq_init + jitter)
# Note: Do not use np.linalg.cholesky here, this can cause
# locking issues
L = spl.cholesky(x_plus_constant, lower=True)
must_increase_jitter = False
except spl.LinAlgError:
if debug_log == "true":
logger.info("sigsq = {} does not work".format(sigsq_init + jitter))
if jitter == 0.0:
jitter = initial_jitter_factor * max(1.0, np.mean(np.diag(x)))
else:
jitter = jitter * jitter_growth
assert (
not must_increase_jitter
), "The jitter ({}) has reached its upperbound ({}) while the Cholesky of the input matrix still cannot be computed.".format(
jitter, jitter_upperbound
)
if debug_log == "true":
logger.info("sigsq_final = {}".format(sigsq_init + jitter))
return x_plus_constant
def AddJitterOp_vjp(
ans: np.ndarray,
inputs: np.ndarray,
initial_jitter_factor=INITIAL_JITTER_FACTOR,
jitter_growth=JITTER_GROWTH,
debug_log="false",
):
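    # The forward pass returns ``x + sigsq_final * Id`` for the packed input
    # ``(x.reshape(-1), sigsq_init)``. Since the dependence of the jitter on
    # the inputs is ignored, the adjoint of ``x`` is ``g`` and the adjoint of
    # ``sigsq_init`` is ``trace(g)``, packed in the same layout as the input.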
return lambda g: anp.append(anp.reshape(g, (-1,)), anp.sum(anp.diag(g)))
defvjp(AddJitterOp, AddJitterOp_vjp)
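# Illustrative usage (a minimal sketch, not part of the module's API; the
# matrix and noise variance below are made-up values): a rank-deficient PSD
# matrix cannot be Cholesky-factorized directly, but the matrix returned by
# ``AddJitterOp`` (with jitter added to its diagonal) can.
def _example_add_jitter():
    x = np.outer(np.ones(3), np.ones(3))  # rank-1 PSD matrix, not PD
    inputs = flatten_and_concat(x, 0.0)  # pack x and sigsq_init = 0
    a_stable = AddJitterOp(inputs, debug_log="true")  # x + sigsq_final * Id
    return spl.cholesky(a_stable, lower=True)  # succeeds on jittered matrix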
@primitive
def cholesky_factorization(a):
"""
Replacement for :func:`autograd.numpy.linalg.cholesky`. Our backward (vjp)
is faster and simpler, while somewhat less general (only works if
``a.ndim == 2``).
See https://arxiv.org/abs/1710.08717 for derivation of backward (vjp)
expression.
    :param a: Symmetric positive definite matrix A
:return: Lower-triangular Cholesky factor L of A
"""
# Note: Do not use np.linalg.cholesky here, this can cause locking issues
return spl.cholesky(a, lower=True)
def copyltu(x):
    # Symmetrize from the lower triangle: returns tril(x) + tril(x, -1)^T
    return anp.tril(x) + anp.transpose(anp.tril(x, -1))
def cholesky_factorization_backward(l, lbar):
    # Computes abar = 0.5 * L^{-T} copyltu(L^T lbar) L^{-1}, the vector-Jacobian
    # product of the Cholesky factorization (see https://arxiv.org/abs/1710.08717)
    abar = copyltu(anp.matmul(anp.transpose(l), lbar))
    abar = anp.transpose(aspl.solve_triangular(l, abar, lower=True, trans="T"))
    abar = aspl.solve_triangular(l, abar, lower=True, trans="T")
    return 0.5 * abar
def cholesky_factorization_vjp(l, a):
return lambda lbar: cholesky_factorization_backward(l, lbar)
defvjp(cholesky_factorization, cholesky_factorization_vjp)
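# A minimal self-check (an illustrative sketch, not part of the module): the
# vjp registered above should agree with autograd's built-in gradient through
# ``anp.linalg.cholesky`` on a random symmetric positive definite matrix. The
# helper name, seed, and test matrix below are made up for illustration.
def _check_cholesky_vjp(n=4, seed=0):
    from autograd import grad

    rng = np.random.RandomState(seed)
    b = rng.randn(n, n)
    spd = np.matmul(b, b.T) + n * np.eye(n)  # well-conditioned SPD matrix
    grad_custom = grad(lambda a: anp.sum(cholesky_factorization(a)))(spd)
    grad_builtin = grad(lambda a: anp.sum(anp.linalg.cholesky(a)))(spd)
    return np.allclose(grad_custom, grad_builtin)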