Allow complex valued input for Cholesky decomposition.
PiperOrigin-RevId: 157572536
parent 2d1860859a
commit 473a590c9c
@@ -14,8 +14,7 @@ limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.
// TODO(konstantinos): Enable complex inputs. This will require additional tests
// and OP_REQUIRES.

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA

@@ -85,8 +84,10 @@ namespace functor {
      typename TTypes<T, 3>::Tensor output); \
  extern template struct MatrixBandPart<GPUDevice, T>;

TF_CALL_float(DECLARE_GPU_SPEC);
TF_CALL_double(DECLARE_GPU_SPEC);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_complex64(DECLARE_GPU_SPEC);
TF_CALL_complex128(DECLARE_GPU_SPEC);

} // namespace functor

template <class Scalar>

@@ -171,11 +172,15 @@ class CholeskyOpGpu : public AsyncOpKernel {

REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<float>), float);
REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<double>), double);
REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<complex64>), complex64);
REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<complex128>), complex128);

#endif // GOOGLE_CUDA

REGISTER_LINALG_OP("Cholesky", (CholeskyOp<float>), float);
REGISTER_LINALG_OP("Cholesky", (CholeskyOp<double>), double);
REGISTER_LINALG_OP("Cholesky", (CholeskyOp<complex64>), complex64);
REGISTER_LINALG_OP("Cholesky", (CholeskyOp<complex128>), complex128);
REGISTER_LINALG_OP("BatchCholesky", (CholeskyOp<float>), float);
REGISTER_LINALG_OP("BatchCholesky", (CholeskyOp<double>), double);

@@ -187,6 +187,8 @@ namespace functor {
  extern template struct MatrixDiagPart<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_complex64(DECLARE_GPU_SPEC);
TF_CALL_complex128(DECLARE_GPU_SPEC);

} // namespace functor

@@ -199,6 +201,8 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
      Name("MatrixDiagPart").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      MatrixDiagPartOp<GPUDevice, type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_DIAG_GPU);
TF_CALL_complex64(REGISTER_MATRIX_DIAG_GPU);
TF_CALL_complex128(REGISTER_MATRIX_DIAG_GPU);
#undef REGISTER_MATRIX_DIAG_GPU

// Registration of the deprecated kernel.

@@ -31,6 +31,8 @@ typedef Eigen::GpuDevice GPUDevice;
  template struct functor::MatrixDiagPart<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
TF_CALL_complex64(DEFINE_GPU_SPEC);
TF_CALL_complex128(DEFINE_GPU_SPEC);

} // end namespace tensorflow

@@ -147,6 +147,8 @@ namespace functor {
  extern template struct MatrixSetDiag<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_complex64(DECLARE_GPU_SPEC);
TF_CALL_complex128(DECLARE_GPU_SPEC);

} // namespace functor

@@ -156,6 +158,8 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
      Name("MatrixSetDiag").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      MatrixSetDiagOp<GPUDevice, type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_SET_DIAG_GPU);
TF_CALL_complex64(REGISTER_MATRIX_SET_DIAG_GPU);
TF_CALL_complex128(REGISTER_MATRIX_SET_DIAG_GPU);
#undef REGISTER_MATRIX_SET_DIAG_GPU

// Registration of the deprecated kernel.

@@ -29,6 +29,8 @@ typedef Eigen::GpuDevice GPUDevice;
  template struct functor::MatrixSetDiag<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
TF_CALL_complex64(DEFINE_GPU_SPEC);
TF_CALL_complex128(DEFINE_GPU_SPEC);

} // end namespace tensorflow

@@ -97,8 +97,9 @@ class MatrixTriangularSolveOp : public LinearAlgebraOp<Scalar> {
      // an empty set of equation as the empty matrix.
      return;
    }
    const Scalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff();
    OP_REQUIRES(context, min_abs_pivot > Scalar(0),
    using RealScalar = typename Base::RealScalar;
    const RealScalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff();
    OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
                errors::InvalidArgument("Input matrix is not invertible."));
    if (lower_) {
      auto triangle = matrix.template triangularView<Eigen::Lower>();

@@ -128,6 +129,10 @@ REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOp<float>), float);
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOp<double>), double);
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOp<complex64>), complex64);
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOp<complex128>), complex128);
REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve",
                       (MatrixTriangularSolveOp<float>), float);
REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve",

@@ -215,7 +220,8 @@ class MatrixTriangularSolveOpGPU : public LinearAlgebraOp<Scalar> {
      upper_lower_matrix = perftools::gputools::blas::UpperLower::kLower;
    }
    if (adjoint_) {
      transpose_matrix = perftools::gputools::blas::Transpose::kTranspose;
      transpose_matrix =
          perftools::gputools::blas::Transpose::kConjugateTranspose;
    } else {
      transpose_matrix = perftools::gputools::blas::Transpose::kNoTranspose;
    }

@@ -249,6 +255,10 @@ REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<float>), float);
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<double>), double);
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<complex64>), complex64);
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<complex128>), complex128);
REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<float>), float);
REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve",

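The switch to blas::Transpose::kConjugateTranspose above is the semantic core of the GPU kernel change: for complex inputs, `adjoint=True` must solve against the conjugate transpose rather than the plain transpose. A minimal NumPy sketch of the expected behaviour (illustration only; the matrix values are made up and are not part of this diff):

```python
import numpy as np

# Lower-triangular complex matrix and right-hand side (made-up example data).
l = np.array([[2. + 1.j, 0.], [1. - 2.j, 3. + 0.5j]], dtype=np.complex64)
rhs = np.array([[1. + 1.j], [2. - 1.j]], dtype=np.complex64)

# With adjoint=True the op is expected to solve  conj(L).T @ x = rhs,
# i.e. the system cuBLAS sees when passed Transpose::kConjugateTranspose.
x = np.linalg.solve(l.conj().T, rhs)
np.testing.assert_allclose(l.conj().T @ x, rhs, rtol=1e-5)
```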
@@ -245,16 +245,25 @@ Equivalent to np.linalg.inv
REGISTER_OP("Cholesky")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float}")
    .Attr("T: {double, float, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn)
    .Doc(R"doc(
Computes the Cholesky decomposition of one or more square matrices.

The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
form square matrices, with the same constraints as the single matrix Cholesky
decomposition above. The output is a tensor of the same shape as the input
form square matrices.

The input has to be symmetric and positive definite. Only the lower-triangular
part of the input will be used for this operation. The upper-triangular part
will not be read.

The output is a tensor of the same shape as the input
containing the Cholesky decompositions for all input submatrices `[..., :, :]`.

**Note**: The gradient computation on GPU is faster for large matrices but
not for large batch dimensions when the submatrices are small. In this
case it might be faster to use the CPU.

input: Shape is `[..., M, M]`.
output: Shape is `[..., M, M]`.
)doc");

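With the `T` attr widened to include complex64/complex128, the op accepts Hermitian positive-definite inputs. A minimal usage sketch against the TF 1.x-era Python API used elsewhere in this diff (the input matrix here is made up for illustration and is not from the commit):

```python
import numpy as np
import tensorflow as tf

np.random.seed(0)
# Hermitian positive-definite input: A = B @ conj(B).T + n * I (made-up data).
b = np.random.randn(4, 4) + 1j * np.random.randn(4, 4)
a = (b @ b.conj().T + 4.0 * np.eye(4)).astype(np.complex64)

chol = tf.cholesky(a)  # lower-triangular L with A = L @ conj(L).T
with tf.Session() as sess:
  l = sess.run(chol)
  np.testing.assert_allclose(l @ l.conj().T, a, rtol=1e-4, atol=1e-4)
```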
@@ -373,7 +382,7 @@ REGISTER_OP("MatrixTriangularSolve")
    .Output("output: T")
    .Attr("lower: bool = True")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float}")
    .Attr("T: {double, float, complex64, complex128}")
    .SetShapeFn([](InferenceContext* c) {
      return MatrixSolveShapeFn(c, true /* square (*/);
    })

@@ -120,7 +120,7 @@ tf_py_test(
    ],
)

tf_py_test(
cuda_py_test(
    name = "cholesky_op_test",
    size = "small",
    srcs = ["cholesky_op_test.py"],

@@ -21,16 +21,70 @@ from __future__ import print_function
import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


# Different gradient implementations for benchmark purposes
def SpecializedGrad(l, grad):
  return gen_linalg_ops.cholesky_grad(l, grad)


def _GradWithInverseL(l, l_inverse, grad):
  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(middle,
                                     0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)
  grad_a = math_ops.matmul(
      math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)
  grad_a += math_ops.conj(array_ops.matrix_transpose(grad_a))
  return grad_a * 0.5


def TriAngSolveCompositeGrad(l, grad):
  # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}

  # Compute ((l^{H} @ grad) * (tril(ones)-1/2*eye)) = middle
  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(middle,
                                     0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)

  # Compute l^{-H} @ middle = z
  l_inverse_middle = linalg_ops.matrix_triangular_solve(l, middle, adjoint=True)

  # We need to compute z @ l^{-1}. With matrix_triangular_solve we
  # actually compute l^{-H} @ z^{H} = grad. Since we later add grad^{H}
  # we can omit the conjugate transpose here.
  z_h = math_ops.conj(array_ops.matrix_transpose(l_inverse_middle))
  grad_a = linalg_ops.matrix_triangular_solve(l, z_h, adjoint=True)
  grad_a += math_ops.conj(array_ops.matrix_transpose(grad_a))
  return grad_a * 0.5


def MatrixInverseCompositeGrad(l, grad):
  l_inverse = linalg_ops.matrix_inverse(l)
  return _GradWithInverseL(l, l_inverse, grad)


def TriAngInvCompositeGrad(l, grad):
  num_rows = array_ops.shape(l)[-1]
  batch_shape = array_ops.shape(l)[:-2]
  l_inverse = linalg_ops.matrix_triangular_solve(
      l, linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=l.dtype))
  return _GradWithInverseL(l, l_inverse, grad)


class CholeskyOpTest(test.TestCase):

  def _verifyCholeskyBase(self, sess, x, chol, verification):

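The one-line comment in TriAngSolveCompositeGrad (and in _CholeskyGrad further down) abbreviates the standard reverse-mode identity for the Cholesky factorization; spelled out for readability (a restatement of that comment, not text from the commit):

```latex
% A = L L^H,  \bar{L} = incoming gradient w.r.t. L,  \bar{A} = returned gradient.
% P(X) keeps the lower triangle of X and halves its diagonal:
%   P(X) = \operatorname{tril}(X) - \tfrac{1}{2}\operatorname{diag}(X)
\bar{A} = \tfrac{1}{2}\left( L^{-H}\, P\!\left(L^{H}\bar{L}\right) L^{-1}
        + \left( L^{-H}\, P\!\left(L^{H}\bar{L}\right) L^{-1} \right)^{H} \right)
```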
@@ -54,9 +108,14 @@ class CholeskyOpTest(test.TestCase):
      self._verifyCholeskyBase(sess, x, chol, verification)

  def testBasic(self):
    data = np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]])
    for dtype in (np.float32, np.float64):
      self._verifyCholesky(
          np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]).astype(dtype))
      self._verifyCholesky(data.astype(dtype))
    for dtype in (np.complex64, np.complex128):
      complex_data = np.tril(1j * data, -1).astype(dtype)
      complex_data += np.triu(-1j * data, 1).astype(dtype)
      complex_data += data
      self._verifyCholesky(complex_data)

  def testBatch(self):
    simple_array = np.array([[[1., 0.], [0., 5.]]])  # shape (1, 2, 2)

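For reference, the complex test data built in testBasic is Hermitian with a positive-definite spectrum: the symmetric `data` matrix supplies the real part and the antisymmetric `+1j`/`-1j` triangles supply the imaginary part. A quick NumPy check of that construction (illustrative only, not part of the test):

```python
import numpy as np

data = np.array([[4., -1., 2.], [-1., 6., 0.], [2., 0., 5.]])
complex_data = np.tril(1j * data, -1).astype(np.complex64)
complex_data += np.triu(-1j * data, 1).astype(np.complex64)
complex_data += data

# Hermitian: equal to its own conjugate transpose.
np.testing.assert_allclose(complex_data, complex_data.conj().T)
# Positive definite: all eigenvalues of the Hermitian matrix are positive.
assert np.all(np.linalg.eigvalsh(complex_data) > 0)
```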
@@ -65,12 +124,18 @@ class CholeskyOpTest(test.TestCase):
    odd_sized_array = np.array([[[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]])
    self._verifyCholesky(np.vstack((odd_sized_array, odd_sized_array)))

    # Generate random positive-definite matrices.
    matrices = np.random.rand(10, 5, 5)
    for i in xrange(10):
      matrices[i] = np.dot(matrices[i].T, matrices[i])
    self._verifyCholesky(matrices)

    # Generate random complex valued positive-definite matrices.
    matrices = np.random.rand(10, 5, 5) + 1j * np.random.rand(10, 5, 5)
    for i in xrange(10):
      matrices[i] = np.dot(matrices[i].T.conj(), matrices[i])
    self._verifyCholesky(matrices)

  def testNonSquareMatrix(self):
    with self.assertRaises(ValueError):
      linalg_ops.cholesky(np.array([[1., 2., 3.], [3., 4., 5.]]))

@@ -110,7 +175,14 @@ class CholeskyGradTest(test.TestCase):
  def testSmallMatrices(self):
    np.random.seed(0)
    shapes = self.getShapes([1, 2, 10])
    self.runFiniteDifferences(shapes)
    self.runFiniteDifferences(
        shapes, dtypes=(dtypes_lib.float32, dtypes_lib.float64))

  def testSmallMatricesComplex(self):
    np.random.seed(0)
    shapes = self.getShapes([1, 2, 10])
    self.runFiniteDifferences(
        shapes, dtypes=(dtypes_lib.complex64, dtypes_lib.complex128))

  def testOneBlockMatrices(self):
    np.random.seed(0)

@@ -132,25 +204,61 @@ class CholeskyGradTest(test.TestCase):
    self.runFiniteDifferences(
        shapes, dtypes=(dtypes_lib.float64,), scalarTest=True)

  def testTwoBlockMatrixComplexFloat(self):
    np.random.seed(0)
    shapes = self.getShapes([2 * self._backprop_block_size + 1])
    self.runFiniteDifferences(
        shapes, dtypes=(dtypes_lib.complex64,), scalarTest=True)

  def testTwoBlockMatrixComplexDouble(self):
    np.random.seed(0)
    shapes = self.getShapes([2 * self._backprop_block_size + 1])
    self.runFiniteDifferences(
        shapes, dtypes=(dtypes_lib.complex128,), scalarTest=True)

  def testAgainstSpecialized(self):
    np.random.seed(0)
    data = np.random.randn(33, 33).astype(np.float32)
    data = np.matmul(data, data.T)
    grad_data = np.random.randn(*data.shape).astype(np.float32)

    with ops.Graph().as_default(), self.test_session(use_gpu=False) as s:
      x = constant_op.constant(data, dtypes_lib.float32)
      chol = linalg_ops.cholesky(x)
      composite_grad = gradients_impl.gradients(chol, x, grad_data)[0]
      specialized_grad = SpecializedGrad(chol, grad_data)
      reference, actual = s.run([specialized_grad, composite_grad])
      self.assertAllClose(reference, actual)

  def runFiniteDifferences(self,
                           shapes,
                           dtypes=(dtypes_lib.float32, dtypes_lib.float64),
                           dtypes=(dtypes_lib.float32, dtypes_lib.float64,
                                   dtypes_lib.complex64, dtypes_lib.complex128),
                           scalarTest=False):
    with self.test_session(use_gpu=False):
    with self.test_session(use_gpu=True):
      for shape in shapes:
        for batch in False, True:
          for dtype in dtypes:
            if not scalarTest:
              x = constant_op.constant(
                  np.random.randn(shape[0], shape[1]), dtype)
              tensor = math_ops.matmul(x, array_ops.transpose(x)) / shape[0]
              data = np.random.randn(shape[0], shape[1])
              if dtype.is_complex:
                data = data.astype(np.complex64)
                data += 1j * np.random.randn(shape[0], shape[1])
              x = constant_op.constant(data, dtype)
              tensor = math_ops.matmul(
                  x, math_ops.conj(array_ops.transpose(x))) / shape[0]
            else:
              # This is designed to be a faster test for larger matrices.
              x = constant_op.constant(np.random.randn(), dtype)
              data = np.random.randn()
              if dtype.is_complex:
                data = np.complex64(data)
                data += 1j * np.random.randn()
              x = constant_op.constant(data, dtype)
              R = constant_op.constant(
                  np.random.randn(shape[0], shape[1]), dtype)
              e = math_ops.multiply(R, x)
              tensor = math_ops.matmul(e, array_ops.transpose(e)) / shape[0]
              tensor = math_ops.matmul(
                  e, math_ops.conj(array_ops.transpose(e))) / shape[0]

            # Inner-most matrices in tensor are positive definite.
            if batch:

@@ -159,15 +267,87 @@ class CholeskyGradTest(test.TestCase):
            y = linalg_ops.cholesky(tensor)
            if scalarTest:
              y = math_ops.reduce_mean(y)
            error = gradient_checker.compute_gradient_error(x,
                                                            x._shape_as_list(),
                                                            y,
                                                            y._shape_as_list())
            error = gradient_checker.compute_gradient_error(
                x, x._shape_as_list(), y, y._shape_as_list())
            tf_logging.info("error = %f", error)
            if dtype == dtypes_lib.float64:
              self.assertLess(error, 1e-5)
            elif dtype == dtypes_lib.complex128:
              self.assertLess(error, 5e-5)
            else:
              self.assertLess(error, 3e-3)
              self.assertLess(error, 5e-3)


class CholeskyBenchmark(test.Benchmark):

  sizes = [
      (4, 4), (16, 16), (256, 256), (1024, 1024), (2048, 2048),
      (513, 2, 2), (513, 8, 8), (4, 513, 2, 2)
  ]

  def _GenerateData(self, size):
    batch_shape = size[:-2]
    size = size[-2:]
    assert size[0] == size[1]
    n = size[0]
    data = np.ones(size).astype(np.float32) / (2.0 * n) + np.diag(
        np.ones(n).astype(np.float32))
    return np.tile(data, batch_shape + (1, 1))

  def benchmarkCholeskyOp(self):
    for size in self.sizes:
      data = self._GenerateData(size)

      with ops.Graph().as_default(), \
          session.Session() as sess, \
          ops.device("/cpu:0"):
        l = linalg_ops.cholesky(data)
        self.run_op_benchmark(
            sess, l,
            min_iters=25,
            name="cholesky_cpu_{size}".format(size=size))

      if test.is_gpu_available(True):
        with ops.Graph().as_default(), \
            session.Session() as sess, \
            ops.device("/gpu:0"):
          l = linalg_ops.cholesky(data)
          self.run_op_benchmark(
              sess, l,
              min_iters=25,
              name="cholesky_gpu_{size}".format(size=size))

  def benchmarkGradVariants(self):
    def _BenchmarkGrad(grad_fn, name, device):
      for size in self.sizes:
        data = self._GenerateData(size)
        l = np.linalg.cholesky(data)
        grad_data = np.random.randn(*data.shape).astype(np.float32)
        with ops.Graph().as_default(), \
            session.Session() as sess, \
            ops.device(device):
          grad = grad_fn(l, grad_data)
          self.run_op_benchmark(
              sess, grad,
              min_iters=25,
              name="{name}_{dev}_{size}".format(
                  name=name, dev=grad.device, size=size))

    if test.is_gpu_available(True):
      _BenchmarkGrad(
          MatrixInverseCompositeGrad, "composite_matrix_inverse", "/gpu:0")
      _BenchmarkGrad(
          TriAngInvCompositeGrad, "composite_tri_ang_inverse", "/gpu:0")
      _BenchmarkGrad(
          TriAngSolveCompositeGrad, "composite_triangular_solve", "/gpu:0")

    _BenchmarkGrad(
        MatrixInverseCompositeGrad, "composite_matrix_inverse", "/cpu:0")
    _BenchmarkGrad(
        TriAngInvCompositeGrad, "composite_tri_ang_inverse", "/cpu:0")
    _BenchmarkGrad(
        TriAngSolveCompositeGrad, "composite_triangular_solve", "/cpu:0")
    _BenchmarkGrad(SpecializedGrad, "specialized", "/cpu:0")


if __name__ == "__main__":

@@ -28,7 +28,7 @@ from tensorflow.python.platform import test

class MatrixTriangularSolveOpTest(test.TestCase):

  def _verifySolveAllWays(self, x, y, batch_dims=None):
  def _verifySolveAllWays(self, x, y, dtypes, batch_dims=None):
    for lower in True, False:
      for adjoint in True, False:
        for use_placeholder in True, False:

@@ -38,7 +38,14 @@ class MatrixTriangularSolveOpTest(test.TestCase):
              lower=lower,
              adjoint=adjoint,
              batch_dims=batch_dims,
              use_placeholder=use_placeholder)
              use_placeholder=use_placeholder,
              dtypes=dtypes)

  def _verifySolveAllWaysReal(self, x, y, batch_dims=None):
    self._verifySolveAllWays(x, y, (np.float32, np.float64), batch_dims)

  def _verifySolveAllWaysComplex(self, x, y, batch_dims=None):
    self._verifySolveAllWays(x, y, (np.complex64, np.complex128), batch_dims)

  def _verifySolve(self,
                   x,

@@ -46,8 +53,9 @@ class MatrixTriangularSolveOpTest(test.TestCase):
                   lower=True,
                   adjoint=False,
                   batch_dims=None,
                   use_placeholder=False):
    for np_type in [np.float32, np.float64]:
                   use_placeholder=False,
                   dtypes=(np.float32, np.float64)):
    for np_type in dtypes:
      a = x.astype(np_type)
      b = y.astype(np_type)
      # For numpy.solve we have to explicitly zero out the strictly

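The truncated comment above refers to the NumPy reference computation used by this test: before calling np.linalg.solve, the triangle the op is told to ignore has to be zeroed out, so the reference solves the same triangular system. A sketch of that idea with hypothetical data and variable names (not the test's exact code):

```python
import numpy as np

a = np.array([[1. + 1j, 2.], [3., 4. - 2j]], dtype=np.complex64)
b = np.array([[1. + 1j], [1. - 1j]], dtype=np.complex64)

lower = True
# Zero out the unused triangle, then use a dense solve as the reference
# for matrix_triangular_solve(a, b, lower=lower).
a_tri = np.tril(a) if lower else np.triu(a)
x = np.linalg.solve(a_tri, b)
np.testing.assert_allclose(a_tri @ x, b, rtol=1e-5)
```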
@@ -89,22 +97,48 @@ class MatrixTriangularSolveOpTest(test.TestCase):
    # 1x1 matrix, single rhs.
    matrix = np.array([[0.1]])
    rhs0 = np.array([[1.]])
    self._verifySolveAllWays(matrix, rhs0)
    self._verifySolveAllWaysReal(matrix, rhs0)
    # 2x2 matrices, single right-hand side.
    matrix = np.array([[1., 2.], [3., 4.]])
    rhs0 = np.array([[1.], [1.]])
    self._verifySolveAllWays(matrix, rhs0)
    self._verifySolveAllWaysReal(matrix, rhs0)
    # 2x2 matrices, 3 right-hand sides.
    rhs1 = np.array([[1., 0., 1.], [0., 1., 1.]])
    self._verifySolveAllWays(matrix, rhs1)
    self._verifySolveAllWaysReal(matrix, rhs1)

  def testSolveComplex(self):
    # 1x1 matrix, single rhs.
    matrix = np.array([[0.1 + 1j * 0.1]])
    rhs0 = np.array([[1. + 1j]])
    self._verifySolveAllWaysComplex(matrix, rhs0)
    # 2x2 matrices, single right-hand side.
    matrix = np.array([[1., 2.], [3., 4.]]).astype(np.complex64)
    matrix += 1j * matrix
    rhs0 = np.array([[1.], [1.]]).astype(np.complex64)
    rhs0 += 1j * rhs0
    self._verifySolveAllWaysComplex(matrix, rhs0)
    # 2x2 matrices, 3 right-hand sides.
    rhs1 = np.array([[1., 0., 1.], [0., 1., 1.]]).astype(np.complex64)
    rhs1 += 1j * rhs1
    self._verifySolveAllWaysComplex(matrix, rhs1)

  def testSolveBatch(self):
    matrix = np.array([[1., 2.], [3., 4.]])
    rhs = np.array([[1., 0., 1.], [0., 1., 1.]])
    # Batch of 2x3x2x2 matrices, 2x3x2x3 right-hand sides.
    self._verifySolveAllWays(matrix, rhs, batch_dims=[2, 3])
    self._verifySolveAllWaysReal(matrix, rhs, batch_dims=[2, 3])
    # Batch of 3x2x2x2 matrices, 3x2x2x3 right-hand sides.
    self._verifySolveAllWays(matrix, rhs, batch_dims=[3, 2])
    self._verifySolveAllWaysReal(matrix, rhs, batch_dims=[3, 2])

  def testSolveBatchComplex(self):
    matrix = np.array([[1., 2.], [3., 4.]]).astype(np.complex64)
    matrix += 1j * matrix
    rhs = np.array([[1., 0., 1.], [0., 1., 1.]]).astype(np.complex64)
    rhs += 1j * rhs
    # Batch of 2x3x2x2 matrices, 2x3x2x3 right-hand sides.
    self._verifySolveAllWaysComplex(matrix, rhs, batch_dims=[2, 3])
    # Batch of 3x2x2x2 matrices, 3x2x2x3 right-hand sides.
    self._verifySolveAllWaysComplex(matrix, rhs, batch_dims=[3, 2])

  def testNonSquareMatrix(self):
    # A non-square matrix should cause an error.

@@ -57,7 +57,24 @@ def _MatrixDeterminantGrad(op, grad):
@ops.RegisterGradient("Cholesky")
def _CholeskyGrad(op, grad):
  """Gradient for Cholesky."""
  return linalg_ops.cholesky_grad(op.outputs[0], grad)

  # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
  l = op.outputs[0]
  num_rows = array_ops.shape(l)[-1]
  batch_shape = array_ops.shape(l)[:-2]
  l_inverse = linalg_ops.matrix_triangular_solve(
      l, linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=l.dtype))

  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(middle,
                                     0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)

  grad_a = math_ops.matmul(
      math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)

  grad_a += math_ops.conj(array_ops.matrix_transpose(grad_a))
  return grad_a * 0.5


@ops.RegisterGradient("MatrixSolve")
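As a sanity check on the rewritten gradient (not part of the commit; a NumPy-only sketch with made-up data), the inverse-based form used in _CholeskyGrad above and the triangular-solve form used in TriAngSolveCompositeGrad should compute the same value of l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}:

```python
import numpy as np

np.random.seed(0)
n = 4
# Hermitian positive-definite A and its Cholesky factor (made-up data).
b = np.random.randn(n, n) + 1j * np.random.randn(n, n)
a = b @ b.conj().T + n * np.eye(n)
l = np.linalg.cholesky(a)
grad = np.random.randn(n, n) + 1j * np.random.randn(n, n)

# middle = (L^H @ grad) * (tril(ones) - 1/2 * eye): lower triangle, halved diagonal.
middle = l.conj().T @ grad
middle = np.tril(middle) - 0.5 * np.diag(np.diag(middle))

# Variant 1: explicit inverse, grad_A = L^{-H} @ middle @ L^{-1}.
l_inv = np.linalg.inv(l)
grad_a1 = l_inv.conj().T @ middle @ l_inv

# Variant 2: triangular solves only, mirroring TriAngSolveCompositeGrad.
z = np.linalg.solve(l.conj().T, middle)                     # L^{-H} @ middle
grad_a2 = np.linalg.solve(l.conj().T, z.conj().T).conj().T  # (...) @ L^{-1}

np.testing.assert_allclose(grad_a1, grad_a2, rtol=1e-9, atol=1e-9)
```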