Add tf.linalg.eigh_tridiagonal, which computes the eigenvalues of a Hermitian tridiagonal matrix.

PiperOrigin-RevId: 372601772
Change-Id: Idc6471cd851a88b17ab657a567dd018768a1cd8d
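For reference, a minimal usage sketch based on the example in the new op's docstring (illustrative only; it assumes a TensorFlow build that includes this change and eager execution, and the `select`/`select_range` call mirrors the API documented below):

```python
import tensorflow as tf

# Tridiagonal matrix with zero diagonal and unit off-diagonals:
#     [[0, 1, 0],
#      [1, 0, 1],
#      [0, 1, 0]]
alpha = [0.0, 0.0, 0.0]  # diagonal
beta = [1.0, 1.0]        # first super-diagonal (sub-diagonal is its conjugate)

eigvals = tf.linalg.eigh_tridiagonal(alpha, beta)
print(eigvals.numpy())   # approximately [-sqrt(2), 0, sqrt(2)]

# A subset of eigenvalues can also be requested, e.g. by index:
two_smallest = tf.linalg.eigh_tridiagonal(
    alpha, beta, select='i', select_range=(0, 1))
```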
A. Unique TensorFlower 2021-05-07 12:04:38 -07:00 committed by TensorFlower Gardener
parent 5d637e92bf
commit 65219fa1b1
6 changed files with 333 additions and 9 deletions


@@ -52,10 +52,12 @@
lower overall memory usage, and a cleaner API. It does not require
specifying a `delete_key` and `empty_key` that cannot be inserted into
the table.
* Added support for specifying number of subdivisions in all reduce host
collective. This parallelizes work on CPU and speeds up the collective
performance. Default behavior is unchanged.
* Added `tf.linalg.eigh_tridiagonal` that computes the eigenvalues of a
Hermitian tridiagonal matrix.
* SavedModel
* Added `tf.saved_model.experimental.TrackableResource`, which allows
the creation of custom wrapper objects for resource tensors.
* Added a SavedModel load option to allow restoring partial


@@ -2188,6 +2188,7 @@ cuda_py_test(
    name = "linalg_ops_test",
    size = "medium",
    srcs = ["linalg_ops_test.py"],
    shard_count = 4,
    tags = ["no_windows_gpu"],
    deps = [
        "//tensorflow/python:array_ops",


@@ -22,7 +22,9 @@ import itertools
from absl.testing import parameterized
import numpy as np
import scipy.linalg
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -35,13 +37,6 @@ from tensorflow.python.ops.linalg import linalg
from tensorflow.python.platform import test


def _AddTest(test_class, op_name, testcase_name, fn):
  test_name = "_".join(["test", op_name, testcase_name])
  if hasattr(test_class, test_name):
    raise RuntimeError("Test %s defined more than once" % test_name)
  setattr(test_class, test_name, fn)


def _RandomPDMatrix(n, rng, dtype=np.float64):
  """Random positive definite matrix."""
  temp = rng.randn(n, n).astype(dtype)
@@ -562,5 +557,98 @@ class LUSolveDynamic(test.TestCase, _LUSolve):
  use_static_shape = False


@test_util.run_all_in_graph_and_eager_modes
class EighTridiagonalTest(test.TestCase):

  # This op is rather slow in pure eager mode, so we force it to be
  # a function here.
  @def_function.function
  def run_eigh_tridiagonal(self, alpha, beta, **kwargs):
    return linalg.eigh_tridiagonal(alpha, beta, **kwargs)

  def run_test(self, alpha, beta):
    n = alpha.shape[0]
    # scipy.linalg.eigh_tridiagonal doesn't support complex inputs, so for
    # this we call the slower numpy.linalg.eigh.
    if np.issubdtype(alpha.dtype, np.complexfloating):
      tridiagonal = np.diag(alpha) + np.diag(beta, 1) + np.diag(
          np.conj(beta), -1)
      eigvals_expected, _ = np.linalg.eigh(tridiagonal)
    else:
      eigvals_expected = scipy.linalg.eigh_tridiagonal(
          alpha, beta, eigvals_only=True)
    eigvals = self.run_eigh_tridiagonal(alpha, beta)
    eps = np.finfo(alpha.dtype).eps
    atol = 2 * n * eps * np.amax(np.abs(eigvals_expected))
    self.assertAllClose(eigvals_expected, eigvals, atol=atol)

  def run_toeplitz_tests(self, dtype):
    for n in [1, 2, 3, 7, 8, 100]:
      for a, b in [[2, -1], [1, 0], [0, 1], [-1e10, 1e10], [-1e-10, 1e-10]]:
        alpha = a * np.ones([n], dtype=dtype)
        beta = b * np.ones([n - 1], dtype=dtype)
        if np.issubdtype(alpha.dtype, np.complexfloating):
          beta += 1j * beta
        self.run_test(alpha, beta)

  def run_random_uniform_tests(self, dtype):
    for n in [2, 3, 7, 8, 100]:
      alpha = np.random.uniform(size=(n,)).astype(dtype)
      beta = np.random.uniform(size=(n - 1,)).astype(dtype)
      if np.issubdtype(alpha.dtype, np.complexfloating):
        beta += 1j * beta
      self.run_test(alpha, beta)

  def run_select_tests(self, dtype):
    n = 5
    alpha = np.random.uniform(size=(n,)).astype(dtype)
    beta = np.random.uniform(size=(n - 1,)).astype(dtype)
    eigvals_all = self.run_eigh_tridiagonal(alpha, beta, select="a")
    eps = np.finfo(alpha.dtype).eps
    atol = 2 * n * eps
    for first in range(n - 1):
      for last in range(first + 1, n - 1):
        # Check that we get the expected eigenvalues by selecting by
        # index range.
        eigvals_index = self.run_eigh_tridiagonal(
            alpha, beta, select="i", select_range=(first, last))
        self.assertAllClose(
            eigvals_all[first:(last + 1)], eigvals_index, atol=atol)

        # Check that we get the expected eigenvalues by selecting by
        # value range.
        eigvals_value = self.run_eigh_tridiagonal(
            alpha,
            beta,
            select="v",
            select_range=(eigvals_all[first], eigvals_all[last]))
        self.assertAllClose(
            eigvals_all[first:(last + 1)], eigvals_value, atol=atol)

  def test_float32(self):
    dtype = np.float32
    self.run_toeplitz_tests(dtype)
    self.run_random_uniform_tests(dtype)
    self.run_select_tests(dtype)

  def test_float64(self):
    dtype = np.float64
    self.run_toeplitz_tests(dtype)
    self.run_random_uniform_tests(dtype)
    self.run_select_tests(dtype)

  def test_complex64(self):
    dtype = np.complex64
    self.run_toeplitz_tests(dtype)
    self.run_random_uniform_tests(dtype)
    self.run_select_tests(dtype)

  def test_complex128(self):
    dtype = np.complex128
    self.run_toeplitz_tests(dtype)
    self.run_random_uniform_tests(dtype)
    self.run_select_tests(dtype)


if __name__ == "__main__":
  test.main()
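For a standalone check outside the TF test harness, the same cross-check the tests perform against SciPy can be sketched as follows (the sizes and the absolute tolerance mirror `run_test` above; availability of `scipy` and an eager TF session are assumed):

```python
import numpy as np
import scipy.linalg
import tensorflow as tf

n = 50
alpha = np.random.uniform(size=(n,)).astype(np.float64)   # diagonal
beta = np.random.uniform(size=(n - 1,)).astype(np.float64)  # off-diagonal

# Reference eigenvalues from SciPy for the same tridiagonal matrix.
expected = scipy.linalg.eigh_tridiagonal(alpha, beta, eigvals_only=True)
actual = tf.linalg.eigh_tridiagonal(alpha, beta).numpy()

eps = np.finfo(np.float64).eps
atol = 2 * n * eps * np.max(np.abs(expected))
np.testing.assert_allclose(actual, expected, rtol=0, atol=atol)
```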


@@ -1205,3 +1205,228 @@ def _lu_solve_assertions(lower_upper, perm, rhs, validate_args):
          message=message))
  return assertions


@tf_export('linalg.eigh_tridiagonal')
@dispatch.add_dispatch_support
def eigh_tridiagonal(alpha,
                     beta,
                     eigvals_only=True,
                     select='a',
                     select_range=None,
                     tol=None,
                     name=None):
  """Computes the eigenvalues of a Hermitian tridiagonal matrix.

  Args:
    alpha: A real or complex tensor of shape (n), the diagonal elements of
      the matrix. NOTE: If alpha is complex, the imaginary part is ignored
      (assumed zero) to satisfy the requirement that the matrix be Hermitian.
    beta: A real or complex tensor of shape (n-1), containing the elements of
      the first super-diagonal of the matrix. If beta is complex, the first
      sub-diagonal of the matrix is assumed to be the conjugate of beta to
      satisfy the requirement that the matrix be Hermitian.
    eigvals_only: If False, both eigenvalues and corresponding eigenvectors
      are computed. If True, only eigenvalues are computed. Default is True.
    select: Optional string with values in {'a', 'v', 'i'} (default is 'a')
      that determines which eigenvalues to calculate:
        'a': all eigenvalues.
        'v': eigenvalues in the interval (min, max] given by select_range.
        'i': eigenvalues with indices min <= i <= max.
    select_range: Size 2 tuple or list or tensor specifying the range of
      eigenvalues to compute together with select. If select is 'a',
      select_range is ignored.
    tol: Optional scalar. The absolute tolerance to which each eigenvalue is
      required. An eigenvalue (or cluster) is considered to have converged if
      it lies in an interval of this width. If tol is None (default), the
      value eps*|T|_2 is used where eps is the machine precision, and |T|_2
      is the 2-norm of the matrix T.
    name: Optional name of the op.

  Returns:
    eig_vals: The eigenvalues of the matrix in non-decreasing order.
    eig_vectors: If `eigvals_only` is False the eigenvectors are returned in
      the second output argument.

  Raises:
    ValueError: If input values are invalid.
    NotImplementedError: Computing eigenvectors for `eigvals_only` = False is
      not implemented yet.

  This op implements a subset of the functionality of
  scipy.linalg.eigh_tridiagonal.

  TODO(b/187527398):
    a) Complete scipy.linalg.eigh_tridiagonal compatibility:
       1. Add support for computing eigenvectors.
    b) Add support for outer batch dimensions.

  #### Examples

  ```python
  import numpy
  eigvals = tf.linalg.eigh_tridiagonal([0.0, 0.0, 0.0], [1.0, 1.0])
  eigvals_expected = [-numpy.sqrt(2.0), 0.0, numpy.sqrt(2.0)]
  tf.assert_near(eigvals_expected, eigvals)
  # ==> True
  ```
  """
  with ops.name_scope(name or 'eigh_tridiagonal'):

    def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x):
      """Implements the Sturm sequence recurrence."""
      n = alpha.shape[0]
      zeros = array_ops.zeros(array_ops.shape(x), dtype=dtypes.int32)
      ones = array_ops.ones(array_ops.shape(x), dtype=dtypes.int32)

      # The first step in the Sturm sequence recurrence
      # requires special care if x is equal to alpha[0].
      def sturm_step0():
        q = alpha[0] - x
        count = array_ops.where(q < 0, ones, zeros)
        q = array_ops.where(
            math_ops.equal(alpha[0], x), alpha0_perturbation, q)
        return q, count

      # Subsequent steps all take this form:
      def sturm_step(i, q, count):
        q = alpha[i] - beta_sq[i - 1] / q - x
        count = array_ops.where(q <= pivmin, count + 1, count)
        q = array_ops.where(q <= pivmin, math_ops.minimum(q, -pivmin), q)
        return q, count

      # The first step initializes q and count.
      q, count = sturm_step0()

      # Peel off ((n-1) % blocksize) steps from the main loop, so we can run
      # the bulk of the iterations unrolled by a factor of blocksize.
      blocksize = 16
      i = 1
      peel = (n - 1) % blocksize
      unroll_cnt = peel

      def unrolled_steps(start, q, count):
        for j in range(unroll_cnt):
          q, count = sturm_step(start + j, q, count)
        return start + unroll_cnt, q, count

      i, q, count = unrolled_steps(i, q, count)

      # Run the remaining steps of the Sturm sequence using a partially
      # unrolled while loop.
      unroll_cnt = blocksize
      cond = lambda i, q, count: math_ops.less(i, n)
      _, _, count = control_flow_ops.while_loop(
          cond, unrolled_steps, [i, q, count], back_prop=False)
      return count

    if not eigvals_only:
      raise NotImplementedError(
          '`eigvals_only` = False is not implemented yet.')

    alpha = ops.convert_to_tensor(alpha, name='alpha')
    n = alpha.shape[0]
    if n <= 1:
      return math_ops.real(alpha)
    beta = ops.convert_to_tensor(beta, name='beta')

    if alpha.dtype != beta.dtype:
      raise ValueError("'alpha' and 'beta' must have the same type.")

    if alpha.dtype.is_complex:
      alpha = math_ops.real(alpha)
      beta_sq = math_ops.real(math_ops.conj(beta) * beta)
      beta_abs = math_ops.sqrt(beta_sq)
    else:
      beta_sq = math_ops.square(beta)
      beta_abs = math_ops.abs(beta)

    # Estimate the largest and smallest eigenvalues of T using the Gershgorin
    # circle theorem.
    off_diag_abs_row_sum = array_ops.concat(
        [beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]], axis=0)
    lambda_est_max = math_ops.reduce_max(alpha + off_diag_abs_row_sum)
    lambda_est_min = math_ops.reduce_min(alpha - off_diag_abs_row_sum)
    # Upper bound on 2-norm of T.
    t_norm = math_ops.maximum(
        math_ops.abs(lambda_est_min), math_ops.abs(lambda_est_max))

    # Compute the smallest allowed pivot in the Sturm sequence to avoid
    # overflow.
    finfo = np.finfo(alpha.dtype.as_numpy_dtype)
    one = np.ones([], dtype=alpha.dtype.as_numpy_dtype)
    safemin = np.maximum(one / finfo.max, (one + finfo.eps) * finfo.tiny)
    pivmin = safemin * math_ops.maximum(one, math_ops.reduce_max(beta_sq))
    alpha0_perturbation = math_ops.square(finfo.eps * beta_abs[0])
    abs_tol = finfo.eps * t_norm
    if tol:
      abs_tol = math_ops.maximum(tol, abs_tol)

    # In the worst case, when the absolute tolerance is eps*lambda_est_max
    # and lambda_est_max = -lambda_est_min, we have to take as many bisection
    # steps as there are bits in the mantissa plus 1.
    max_it = finfo.nmant + 1

    # Determine the indices of the desired eigenvalues, based on select and
    # select_range.
    asserts = None
    if select == 'a':
      target_counts = math_ops.range(n)
    elif select == 'i':
      asserts = check_ops.assert_less_equal(
          select_range[0],
          select_range[1],
          message='Got empty index range in select_range.')
      target_counts = math_ops.range(select_range[0], select_range[1] + 1)
    elif select == 'v':
      asserts = check_ops.assert_less(
          select_range[0],
          select_range[1],
          message='Got empty interval in select_range.')
    else:
      raise ValueError("'select' must have a value in {'a', 'i', 'v'}.")

    if asserts:
      with ops.control_dependencies([asserts]):
        alpha = array_ops.identity(alpha)

    # Run binary search for all desired eigenvalues in parallel, starting
    # from an interval slightly wider than the estimated
    # [lambda_est_min, lambda_est_max].
    fudge = 2.1  # We widen the Gershgorin interval a bit.
    norm_slack = math_ops.cast(n, alpha.dtype) * fudge * finfo.eps * t_norm
    if select in {'a', 'i'}:
      lower = lambda_est_min - norm_slack - 2 * fudge * pivmin
      upper = lambda_est_max + norm_slack + fudge * pivmin
    else:
      # Count the number of eigenvalues in the given range.
      lower = select_range[0] - norm_slack - 2 * fudge * pivmin
      upper = select_range[1] + norm_slack + fudge * pivmin
      first = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, lower)
      last = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, upper)
      target_counts = math_ops.range(first, last)

    # Pre-broadcast the scalars used in the Sturm sequence for improved
    # performance.
    target_shape = array_ops.shape(target_counts)
    lower = array_ops.broadcast_to(lower, shape=target_shape)
    upper = array_ops.broadcast_to(upper, shape=target_shape)
    mid = 0.5 * (upper + lower)
    pivmin = array_ops.broadcast_to(pivmin, target_shape)
    alpha0_perturbation = array_ops.broadcast_to(alpha0_perturbation,
                                                 target_shape)

    # Start parallel binary searches.
    def cond(i, lower, _, upper):
      return math_ops.logical_and(
          math_ops.less(i, max_it),
          math_ops.less(abs_tol, math_ops.reduce_max(upper - lower)))

    def body(i, lower, mid, upper):
      counts = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, mid)
      lower = array_ops.where(counts <= target_counts, mid, lower)
      upper = array_ops.where(counts > target_counts, mid, upper)
      mid = 0.5 * (lower + upper)
      return i + 1, lower, mid, upper

    _, _, mid, _ = control_flow_ops.while_loop(cond, body,
                                               [0, lower, mid, upper])
    return mid
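For readers following the algorithm, here is a simplified, sequential NumPy sketch of what the op does: count the eigenvalues below a point with a Sturm sequence, then bisect inside Gershgorin bounds until each bracket is narrower than the tolerance. The helper names (`sturm_count`, `eigvals_by_bisection`), the chosen `pivmin`/`perturb` constants, and the scalar per-index loop are illustrative only; the implementation above vectorizes the bisection across all requested eigenvalues and uses the properly scaled safeguards and loop unrolling.

```python
import numpy as np


def sturm_count(alpha, beta_sq, x, pivmin=1e-300, perturb=1e-30):
  # Number of eigenvalues of the real symmetric tridiagonal matrix that are
  # smaller than x, via the Sturm sequence recurrence. `pivmin` and `perturb`
  # play the roles of `pivmin` and `alpha0_perturbation` above: they keep the
  # recurrence away from zero pivots.
  q = alpha[0] - x
  count = 1 if q < 0 else 0
  if q == 0.0:
    q = perturb
  for i in range(1, len(alpha)):
    q = alpha[i] - beta_sq[i - 1] / q - x
    if q <= pivmin:
      count += 1
      q = min(q, -pivmin)
  return count


def eigvals_by_bisection(alpha, beta, tol=1e-12, max_it=60):
  # All eigenvalues of the matrix with diagonal `alpha` and off-diagonal
  # `beta` (n >= 2), one bisection per index; the op above runs the same
  # search in parallel for every requested index.
  alpha = np.asarray(alpha, dtype=np.float64)
  beta = np.asarray(beta, dtype=np.float64)
  beta_sq, beta_abs = beta * beta, np.abs(beta)
  # Gershgorin circle theorem bounds on the spectrum.
  row_sum = np.concatenate(
      [beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]])
  lo, hi = np.min(alpha - row_sum), np.max(alpha + row_sum)
  eigvals = []
  for k in range(len(alpha)):
    lower, upper = lo, hi
    for _ in range(max_it):
      if upper - lower <= tol:
        break
      mid = 0.5 * (lower + upper)
      if sturm_count(alpha, beta_sq, mid) <= k:
        # At most k eigenvalues lie below mid, so the k-th one is >= mid.
        lower = mid
      else:
        upper = mid
    eigvals.append(0.5 * (lower + upper))
  return np.array(eigvals)


# Same example as in the docstring: eigenvalues are -sqrt(2), 0, sqrt(2).
print(eigvals_by_bisection([0.0, 0.0, 0.0], [1.0, 1.0]))
```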


@@ -124,6 +124,10 @@ tf_module {
    name: "eigh"
    argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "eigh_tridiagonal"
    argspec: "args=[\'alpha\', \'beta\', \'eigvals_only\', \'select\', \'select_range\', \'tol\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'a\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "eigvalsh"
    argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "


@@ -132,6 +132,10 @@ tf_module {
    name: "eigh"
    argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "eigh_tridiagonal"
    argspec: "args=[\'alpha\', \'beta\', \'eigvals_only\', \'select\', \'select_range\', \'tol\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'a\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "eigvals"
    argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "