Allow complex valued input for Cholesky decomposition.

PiperOrigin-RevId: 157572536
A. Unique TensorFlower 2017-05-31 04:23:30 -07:00 committed by TensorFlower Gardener
parent 2d1860859a
commit 473a590c9c
11 changed files with 305 additions and 38 deletions

View File

@@ -14,8 +14,7 @@ limitations under the License.
==============================================================================*/
// See docs in ../ops/linalg_ops.cc.
// TODO(konstantinos): Enable complex inputs. This will require additional tests
// and OP_REQUIRES.
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA
@@ -85,8 +84,10 @@ namespace functor {
typename TTypes<T, 3>::Tensor output); \
extern template struct MatrixBandPart<GPUDevice, T>;
TF_CALL_float(DECLARE_GPU_SPEC);
TF_CALL_double(DECLARE_GPU_SPEC);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_complex64(DECLARE_GPU_SPEC);
TF_CALL_complex128(DECLARE_GPU_SPEC);
} // namespace functor
template <class Scalar>
@@ -171,11 +172,15 @@ class CholeskyOpGpu : public AsyncOpKernel {
REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<float>), float);
REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<double>), double);
REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<complex64>), complex64);
REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<complex128>), complex128);
#endif // GOOGLE_CUDA
REGISTER_LINALG_OP("Cholesky", (CholeskyOp<float>), float);
REGISTER_LINALG_OP("Cholesky", (CholeskyOp<double>), double);
REGISTER_LINALG_OP("Cholesky", (CholeskyOp<complex64>), complex64);
REGISTER_LINALG_OP("Cholesky", (CholeskyOp<complex128>), complex128);
REGISTER_LINALG_OP("BatchCholesky", (CholeskyOp<float>), float);
REGISTER_LINALG_OP("BatchCholesky", (CholeskyOp<double>), double);

View File

@@ -187,6 +187,8 @@ namespace functor {
extern template struct MatrixDiagPart<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_complex64(DECLARE_GPU_SPEC);
TF_CALL_complex128(DECLARE_GPU_SPEC);
} // namespace functor
@@ -199,6 +201,8 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
Name("MatrixDiagPart").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
MatrixDiagPartOp<GPUDevice, type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_DIAG_GPU);
TF_CALL_complex64(REGISTER_MATRIX_DIAG_GPU);
TF_CALL_complex128(REGISTER_MATRIX_DIAG_GPU);
#undef REGISTER_MATRIX_DIAG_GPU
// Registration of the deprecated kernel.

View File

@@ -31,6 +31,8 @@ typedef Eigen::GpuDevice GPUDevice;
template struct functor::MatrixDiagPart<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
TF_CALL_complex64(DEFINE_GPU_SPEC);
TF_CALL_complex128(DEFINE_GPU_SPEC);
} // end namespace tensorflow

View File

@@ -147,6 +147,8 @@ namespace functor {
extern template struct MatrixSetDiag<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_complex64(DECLARE_GPU_SPEC);
TF_CALL_complex128(DECLARE_GPU_SPEC);
} // namespace functor
@@ -156,6 +158,8 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
Name("MatrixSetDiag").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
MatrixSetDiagOp<GPUDevice, type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_SET_DIAG_GPU);
TF_CALL_complex64(REGISTER_MATRIX_SET_DIAG_GPU);
TF_CALL_complex128(REGISTER_MATRIX_SET_DIAG_GPU);
#undef REGISTER_MATRIX_SET_DIAG_GPU
// Registration of the deprecated kernel.

View File

@@ -29,6 +29,8 @@ typedef Eigen::GpuDevice GPUDevice;
template struct functor::MatrixSetDiag<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
TF_CALL_complex64(DEFINE_GPU_SPEC);
TF_CALL_complex128(DEFINE_GPU_SPEC);
} // end namespace tensorflow

View File

@@ -97,8 +97,9 @@ class MatrixTriangularSolveOp : public LinearAlgebraOp<Scalar> {
// an empty set of equation as the empty matrix.
return;
}
const Scalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff();
OP_REQUIRES(context, min_abs_pivot > Scalar(0),
using RealScalar = typename Base::RealScalar;
const RealScalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff();
OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
errors::InvalidArgument("Input matrix is not invertible."));
if (lower_) {
auto triangle = matrix.template triangularView<Eigen::Lower>();
@@ -128,6 +129,10 @@ REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
(MatrixTriangularSolveOp<float>), float);
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
(MatrixTriangularSolveOp<double>), double);
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
(MatrixTriangularSolveOp<complex64>), complex64);
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
(MatrixTriangularSolveOp<complex128>), complex128);
REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve",
(MatrixTriangularSolveOp<float>), float);
REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve",
@@ -215,7 +220,8 @@ class MatrixTriangularSolveOpGPU : public LinearAlgebraOp<Scalar> {
upper_lower_matrix = perftools::gputools::blas::UpperLower::kLower;
}
if (adjoint_) {
transpose_matrix = perftools::gputools::blas::Transpose::kTranspose;
transpose_matrix =
perftools::gputools::blas::Transpose::kConjugateTranspose;
} else {
transpose_matrix = perftools::gputools::blas::Transpose::kNoTranspose;
}
@@ -249,6 +255,10 @@ REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
(MatrixTriangularSolveOpGPU<float>), float);
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
(MatrixTriangularSolveOpGPU<double>), double);
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
(MatrixTriangularSolveOpGPU<complex64>), complex64);
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
(MatrixTriangularSolveOpGPU<complex128>), complex128);
REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve",
(MatrixTriangularSolveOpGPU<float>), float);
REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve",

View File

@@ -245,16 +245,25 @@ Equivalent to np.linalg.inv
REGISTER_OP("Cholesky")
.Input("input: T")
.Output("output: T")
.Attr("T: {double, float}")
.Attr("T: {double, float, complex64, complex128}")
.SetShapeFn(BatchUnchangedSquareShapeFn)
.Doc(R"doc(
Computes the Cholesky decomposition of one or more square matrices.
The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
form square matrices, with the same constraints as the single matrix Cholesky
decomposition above. The output is a tensor of the same shape as the input
form square matrices.
The input has to be symmetric and positive definite. Only the lower-triangular
part of the input will be used for this operation. The upper-triangular part
will not be read.
The output is a tensor of the same shape as the input
containing the Cholesky decompositions for all input submatrices `[..., :, :]`.
**Note**: The gradient computation on GPU is faster for large matrices but
not for large batch dimensions when the submatrices are small. In this
case it might be faster to use the CPU.
input: Shape is `[..., M, M]`.
output: Shape is `[..., M, M]`.
)doc");
@@ -373,7 +382,7 @@ REGISTER_OP("MatrixTriangularSolve")
.Output("output: T")
.Attr("lower: bool = True")
.Attr("adjoint: bool = False")
.Attr("T: {double, float}")
.Attr("T: {double, float, complex64, complex128}")
.SetShapeFn([](InferenceContext* c) {
return MatrixSolveShapeFn(c, true /* square (*/);
})
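
For illustration (not part of the diff above), a minimal usage sketch of the widened `T` attr, assuming the TF 1.x Python API and the usual `tf`/`np` import aliases; the matrix construction is only one way to produce a Hermitian positive-definite input.

import numpy as np
import tensorflow as tf  # assumed TF 1.x import alias

a = np.random.randn(4, 4) + 1j * np.random.randn(4, 4)
a = a.dot(np.conj(a.T)) + 4.0 * np.eye(4)  # Hermitian positive definite
with tf.Session() as sess:
    l = sess.run(tf.cholesky(a.astype(np.complex128)))
# l is lower triangular and l @ conj(l).T reconstructs a.
np.testing.assert_allclose(l.dot(np.conj(l.T)), a, atol=1e-8)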

View File

@@ -120,7 +120,7 @@ tf_py_test(
],
)
tf_py_test(
cuda_py_test(
name = "cholesky_op_test",
size = "small",
srcs = ["cholesky_op_test.py"],

View File

@@ -21,16 +21,70 @@ from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
# Different gradient implementations for benchmark purposes
def SpecializedGrad(l, grad):
return gen_linalg_ops.cholesky_grad(l, grad)
def _GradWithInverseL(l, l_inverse, grad):
middle = math_ops.matmul(l, grad, adjoint_a=True)
middle = array_ops.matrix_set_diag(middle,
0.5 * array_ops.matrix_diag_part(middle))
middle = array_ops.matrix_band_part(middle, -1, 0)
grad_a = math_ops.matmul(
math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)
grad_a += math_ops.conj(array_ops.matrix_transpose(grad_a))
return grad_a * 0.5
def TriAngSolveCompositeGrad(l, grad):
# Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
# Compute ((l^{H} @ grad) * (tril(ones)-1/2*eye)) = middle
middle = math_ops.matmul(l, grad, adjoint_a=True)
middle = array_ops.matrix_set_diag(middle,
0.5 * array_ops.matrix_diag_part(middle))
middle = array_ops.matrix_band_part(middle, -1, 0)
# Compute l^{-H} @ middle = z
l_inverse_middle = linalg_ops.matrix_triangular_solve(l, middle, adjoint=True)
# We need to compute z @ l^{-1}. With matrix_triangular_solve we
# actually compute l^{-H} @ z^{H} = grad. Since we later add grad^{H}
# we can omit the conjugate transpose here.
z_h = math_ops.conj(array_ops.matrix_transpose(l_inverse_middle))
grad_a = linalg_ops.matrix_triangular_solve(l, z_h, adjoint=True)
grad_a += math_ops.conj(array_ops.matrix_transpose(grad_a))
return grad_a * 0.5
def MatrixInverseCompositeGrad(l, grad):
l_inverse = linalg_ops.matrix_inverse(l)
return _GradWithInverseL(l, l_inverse, grad)
def TriAngInvCompositeGrad(l, grad):
num_rows = array_ops.shape(l)[-1]
batch_shape = array_ops.shape(l)[:-2]
l_inverse = linalg_ops.matrix_triangular_solve(
l, linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=l.dtype))
return _GradWithInverseL(l, l_inverse, grad)
class CholeskyOpTest(test.TestCase):
def _verifyCholeskyBase(self, sess, x, chol, verification):
@@ -54,9 +108,14 @@ class CholeskyOpTest(test.TestCase):
self._verifyCholeskyBase(sess, x, chol, verification)
def testBasic(self):
data = np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]])
for dtype in (np.float32, np.float64):
self._verifyCholesky(
np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]).astype(dtype))
self._verifyCholesky(data.astype(dtype))
for dtype in (np.complex64, np.complex128):
complex_data = np.tril(1j * data, -1).astype(dtype)
complex_data += np.triu(-1j * data, 1).astype(dtype)
complex_data += data
self._verifyCholesky(complex_data)
def testBatch(self):
simple_array = np.array([[[1., 0.], [0., 5.]]]) # shape (1, 2, 2)
@@ -65,12 +124,18 @@ class CholeskyOpTest(test.TestCase):
odd_sized_array = np.array([[[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]])
self._verifyCholesky(np.vstack((odd_sized_array, odd_sized_array)))
# Generate random positive-definite matrices.
matrices = np.random.rand(10, 5, 5)
for i in xrange(10):
matrices[i] = np.dot(matrices[i].T, matrices[i])
self._verifyCholesky(matrices)
# Generate random complex valued positive-definite matrices.
matrices = np.random.rand(10, 5, 5) + 1j * np.random.rand(10, 5, 5)
for i in xrange(10):
matrices[i] = np.dot(matrices[i].T.conj(), matrices[i])
self._verifyCholesky(matrices)
def testNonSquareMatrix(self):
with self.assertRaises(ValueError):
linalg_ops.cholesky(np.array([[1., 2., 3.], [3., 4., 5.]]))
@@ -110,7 +175,14 @@ class CholeskyGradTest(test.TestCase):
def testSmallMatrices(self):
np.random.seed(0)
shapes = self.getShapes([1, 2, 10])
self.runFiniteDifferences(shapes)
self.runFiniteDifferences(
shapes, dtypes=(dtypes_lib.float32, dtypes_lib.float64))
def testSmallMatricesComplex(self):
np.random.seed(0)
shapes = self.getShapes([1, 2, 10])
self.runFiniteDifferences(
shapes, dtypes=(dtypes_lib.complex64, dtypes_lib.complex128))
def testOneBlockMatrices(self):
np.random.seed(0)
@@ -132,25 +204,61 @@ class CholeskyGradTest(test.TestCase):
self.runFiniteDifferences(
shapes, dtypes=(dtypes_lib.float64,), scalarTest=True)
def testTwoBlockMatrixComplexFloat(self):
np.random.seed(0)
shapes = self.getShapes([2 * self._backprop_block_size + 1])
self.runFiniteDifferences(
shapes, dtypes=(dtypes_lib.complex64,), scalarTest=True)
def testTwoBlockMatrixComplexDouble(self):
np.random.seed(0)
shapes = self.getShapes([2 * self._backprop_block_size + 1])
self.runFiniteDifferences(
shapes, dtypes=(dtypes_lib.complex128,), scalarTest=True)
def testAgainstSpecialized(self):
np.random.seed(0)
data = np.random.randn(33, 33).astype(np.float32)
data = np.matmul(data, data.T)
grad_data = np.random.randn(*data.shape).astype(np.float32)
with ops.Graph().as_default(), self.test_session(use_gpu=False) as s:
x = constant_op.constant(data, dtypes_lib.float32)
chol = linalg_ops.cholesky(x)
composite_grad = gradients_impl.gradients(chol, x, grad_data)[0]
specialized_grad = SpecializedGrad(chol, grad_data)
reference, actual = s.run([specialized_grad, composite_grad])
self.assertAllClose(reference, actual)
def runFiniteDifferences(self,
shapes,
dtypes=(dtypes_lib.float32, dtypes_lib.float64),
dtypes=(dtypes_lib.float32, dtypes_lib.float64,
dtypes_lib.complex64, dtypes_lib.complex128),
scalarTest=False):
with self.test_session(use_gpu=False):
with self.test_session(use_gpu=True):
for shape in shapes:
for batch in False, True:
for dtype in dtypes:
if not scalarTest:
x = constant_op.constant(
np.random.randn(shape[0], shape[1]), dtype)
tensor = math_ops.matmul(x, array_ops.transpose(x)) / shape[0]
data = np.random.randn(shape[0], shape[1])
if dtype.is_complex:
data = data.astype(np.complex64)
data += 1j * np.random.randn(shape[0], shape[1])
x = constant_op.constant(data, dtype)
tensor = math_ops.matmul(
x, math_ops.conj(array_ops.transpose(x))) / shape[0]
else:
# This is designed to be a faster test for larger matrices.
x = constant_op.constant(np.random.randn(), dtype)
data = np.random.randn()
if dtype.is_complex:
data = np.complex64(data)
data += 1j * np.random.randn()
x = constant_op.constant(data, dtype)
R = constant_op.constant(
np.random.randn(shape[0], shape[1]), dtype)
e = math_ops.multiply(R, x)
tensor = math_ops.matmul(e, array_ops.transpose(e)) / shape[0]
tensor = math_ops.matmul(
e, math_ops.conj(array_ops.transpose(e))) / shape[0]
# Inner-most matrices in tensor are positive definite.
if batch:
@@ -159,15 +267,87 @@ class CholeskyGradTest(test.TestCase):
y = linalg_ops.cholesky(tensor)
if scalarTest:
y = math_ops.reduce_mean(y)
error = gradient_checker.compute_gradient_error(x,
x._shape_as_list(),
y,
y._shape_as_list())
error = gradient_checker.compute_gradient_error(
x, x._shape_as_list(), y, y._shape_as_list())
tf_logging.info("error = %f", error)
if dtype == dtypes_lib.float64:
self.assertLess(error, 1e-5)
elif dtype == dtypes_lib.complex128:
self.assertLess(error, 5e-5)
else:
self.assertLess(error, 3e-3)
self.assertLess(error, 5e-3)
class CholeskyBenchmark(test.Benchmark):
sizes = [
(4, 4), (16, 16), (256, 256), (1024, 1024), (2048, 2048),
(513, 2, 2), (513, 8, 8), (4, 513, 2, 2)
]
def _GenerateData(self, size):
batch_shape = size[:-2]
size = size[-2:]
assert size[0] == size[1]
n = size[0]
data = np.ones(size).astype(np.float32) / (2.0 * n) + np.diag(
np.ones(n).astype(np.float32))
return np.tile(data, batch_shape + (1, 1))
def benchmarkCholeskyOp(self):
for size in self.sizes:
data = self._GenerateData(size)
with ops.Graph().as_default(), \
session.Session() as sess, \
ops.device("/cpu:0"):
l = linalg_ops.cholesky(data)
self.run_op_benchmark(
sess, l,
min_iters=25,
name="cholesky_cpu_{size}".format(size=size))
if test.is_gpu_available(True):
with ops.Graph().as_default(), \
session.Session() as sess, \
ops.device("/gpu:0"):
l = linalg_ops.cholesky(data)
self.run_op_benchmark(
sess, l,
min_iters=25,
name="cholesky_gpu_{size}".format(size=size))
def benchmarkGradVariants(self):
def _BenchmarkGrad(grad_fn, name, device):
for size in self.sizes:
data = self._GenerateData(size)
l = np.linalg.cholesky(data)
grad_data = np.random.randn(*data.shape).astype(np.float32)
with ops.Graph().as_default(), \
session.Session() as sess, \
ops.device(device):
grad = grad_fn(l, grad_data)
self.run_op_benchmark(
sess, grad,
min_iters=25,
name="{name}_{dev}_{size}".format(
name=name, dev=grad.device, size=size))
if test.is_gpu_available(True):
_BenchmarkGrad(
MatrixInverseCompositeGrad, "composite_matrix_inverse", "/gpu:0")
_BenchmarkGrad(
TriAngInvCompositeGrad, "composite_tri_ang_inverse", "/gpu:0")
_BenchmarkGrad(
TriAngSolveCompositeGrad, "composite_triangular_solve", "/gpu:0")
_BenchmarkGrad(
MatrixInverseCompositeGrad, "composite_matrix_inverse", "/cpu:0")
_BenchmarkGrad(
TriAngInvCompositeGrad, "composite_tri_ang_inverse", "/cpu:0")
_BenchmarkGrad(
TriAngSolveCompositeGrad, "composite_triangular_solve", "/cpu:0")
_BenchmarkGrad(SpecializedGrad, "specialized", "/cpu:0")
if __name__ == "__main__":
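
For reference (not part of the diff above), a NumPy-only sketch of the building block shared by the gradient variants defined earlier in this file; the helper names are illustrative. Each variant forms middle = Phi(l^H @ grad), where Phi keeps the lower triangle and halves the diagonal, and the variants differ only in how l^{-1} is then applied.

import numpy as np

def _phi(m):
    # Keep the lower triangle and halve the diagonal.
    m = np.tril(m)
    return m - 0.5 * np.diag(np.diag(m))

def reference_cholesky_grad(l, grad):
    # Same algebra as _GradWithInverseL, written with explicit inverses.
    middle = _phi(np.conj(l.T).dot(grad))
    l_inv = np.linalg.inv(l)
    grad_a = np.conj(l_inv.T).dot(middle).dot(l_inv)
    return 0.5 * (grad_a + np.conj(grad_a.T))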

View File

@@ -28,7 +28,7 @@ from tensorflow.python.platform import test
class MatrixTriangularSolveOpTest(test.TestCase):
def _verifySolveAllWays(self, x, y, batch_dims=None):
def _verifySolveAllWays(self, x, y, dtypes, batch_dims=None):
for lower in True, False:
for adjoint in True, False:
for use_placeholder in True, False:
@@ -38,7 +38,14 @@ class MatrixTriangularSolveOpTest(test.TestCase):
lower=lower,
adjoint=adjoint,
batch_dims=batch_dims,
use_placeholder=use_placeholder)
use_placeholder=use_placeholder,
dtypes=dtypes)
def _verifySolveAllWaysReal(self, x, y, batch_dims=None):
self._verifySolveAllWays(x, y, (np.float32, np.float64), batch_dims)
def _verifySolveAllWaysComplex(self, x, y, batch_dims=None):
self._verifySolveAllWays(x, y, (np.complex64, np.complex128), batch_dims)
def _verifySolve(self,
x,
@@ -46,8 +53,9 @@ class MatrixTriangularSolveOpTest(test.TestCase):
lower=True,
adjoint=False,
batch_dims=None,
use_placeholder=False):
for np_type in [np.float32, np.float64]:
use_placeholder=False,
dtypes=(np.float32, np.float64)):
for np_type in dtypes:
a = x.astype(np_type)
b = y.astype(np_type)
# For numpy.solve we have to explicitly zero out the strictly
@ -89,22 +97,48 @@ class MatrixTriangularSolveOpTest(test.TestCase):
# 1x1 matrix, single rhs.
matrix = np.array([[0.1]])
rhs0 = np.array([[1.]])
self._verifySolveAllWays(matrix, rhs0)
self._verifySolveAllWaysReal(matrix, rhs0)
# 2x2 matrices, single right-hand side.
matrix = np.array([[1., 2.], [3., 4.]])
rhs0 = np.array([[1.], [1.]])
self._verifySolveAllWays(matrix, rhs0)
self._verifySolveAllWaysReal(matrix, rhs0)
# 2x2 matrices, 3 right-hand sides.
rhs1 = np.array([[1., 0., 1.], [0., 1., 1.]])
self._verifySolveAllWays(matrix, rhs1)
self._verifySolveAllWaysReal(matrix, rhs1)
def testSolveComplex(self):
# 1x1 matrix, single rhs.
matrix = np.array([[0.1 + 1j * 0.1]])
rhs0 = np.array([[1. + 1j]])
self._verifySolveAllWaysComplex(matrix, rhs0)
# 2x2 matrices, single right-hand side.
matrix = np.array([[1., 2.], [3., 4.]]).astype(np.complex64)
matrix += 1j * matrix
rhs0 = np.array([[1.], [1.]]).astype(np.complex64)
rhs0 += 1j * rhs0
self._verifySolveAllWaysComplex(matrix, rhs0)
# 2x2 matrices, 3 right-hand sides.
rhs1 = np.array([[1., 0., 1.], [0., 1., 1.]]).astype(np.complex64)
rhs1 += 1j * rhs1
self._verifySolveAllWaysComplex(matrix, rhs1)
def testSolveBatch(self):
matrix = np.array([[1., 2.], [3., 4.]])
rhs = np.array([[1., 0., 1.], [0., 1., 1.]])
# Batch of 2x3x2x2 matrices, 2x3x2x3 right-hand sides.
self._verifySolveAllWays(matrix, rhs, batch_dims=[2, 3])
self._verifySolveAllWaysReal(matrix, rhs, batch_dims=[2, 3])
# Batch of 3x2x2x2 matrices, 3x2x2x3 right-hand sides.
self._verifySolveAllWays(matrix, rhs, batch_dims=[3, 2])
self._verifySolveAllWaysReal(matrix, rhs, batch_dims=[3, 2])
def testSolveBatchComplex(self):
matrix = np.array([[1., 2.], [3., 4.]]).astype(np.complex64)
matrix += 1j * matrix
rhs = np.array([[1., 0., 1.], [0., 1., 1.]]).astype(np.complex64)
rhs += 1j * rhs
# Batch of 2x3x2x2 matrices, 2x3x2x3 right-hand sides.
self._verifySolveAllWaysComplex(matrix, rhs, batch_dims=[2, 3])
# Batch of 3x2x2x2 matrices, 3x2x2x3 right-hand sides.
self._verifySolveAllWaysComplex(matrix, rhs, batch_dims=[3, 2])
def testNonSquareMatrix(self):
# A non-square matrix should cause an error.

View File

@@ -57,7 +57,24 @@ def _MatrixDeterminantGrad(op, grad):
@ops.RegisterGradient("Cholesky")
def _CholeskyGrad(op, grad):
"""Gradient for Cholesky."""
return linalg_ops.cholesky_grad(op.outputs[0], grad)
# Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
l = op.outputs[0]
num_rows = array_ops.shape(l)[-1]
batch_shape = array_ops.shape(l)[:-2]
l_inverse = linalg_ops.matrix_triangular_solve(
l, linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=l.dtype))
middle = math_ops.matmul(l, grad, adjoint_a=True)
middle = array_ops.matrix_set_diag(middle,
0.5 * array_ops.matrix_diag_part(middle))
middle = array_ops.matrix_band_part(middle, -1, 0)
grad_a = math_ops.matmul(
math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)
grad_a += math_ops.conj(array_ops.matrix_transpose(grad_a))
return grad_a * 0.5
@ops.RegisterGradient("MatrixSolve")
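
For illustration (not part of the diff above), an end-to-end sketch of the composite Cholesky gradient registered above, assuming the TF 1.x Python API and the usual `tf`/`np` import aliases; the input matrix and upstream gradient are arbitrary examples. With the complex kernels and this gradient in place, tf.gradients backpropagates through tf.cholesky for complex128 inputs.

import numpy as np
import tensorflow as tf  # assumed TF 1.x import alias

a_np = np.random.randn(3, 3) + 1j * np.random.randn(3, 3)
a_np = a_np.dot(np.conj(a_np.T)) + 3.0 * np.eye(3)  # Hermitian positive definite
a = tf.constant(a_np, dtype=tf.complex128)
l = tf.cholesky(a)
grad_l = np.ones_like(a_np)  # arbitrary upstream gradient w.r.t. l
grad_a = tf.gradients(l, a, grad_ys=grad_l)[0]
with tf.Session() as sess:
    print(sess.run(grad_a))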