mirror of https://github.com/zebrajr/tensorflow.git

Add tf.linalg.eigh_tridiagonal, which computes the eigenvalues of a Hermitian tridiagonal matrix.

PiperOrigin-RevId: 372601772
Change-Id: Idc6471cd851a88b17ab657a567dd018768a1cd8d

parent 5d637e92bf
commit 65219fa1b1
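For orientation, here is a minimal usage sketch of the op this commit introduces (illustrative only, not part of the diff; it assumes a TF 2.x build with eager execution that already contains this change). It mirrors the docstring example added below: the 3x3 tridiagonal matrix with zero diagonal and unit off-diagonals has eigenvalues -sqrt(2), 0, and sqrt(2).

```python
import numpy as np
import tensorflow as tf

# The Hermitian tridiagonal matrix is given by its diagonal (alpha) and its
# first super-diagonal (beta); only the eigenvalues are computed by default.
alpha = [0.0, 0.0, 0.0]
beta = [1.0, 1.0]
eigvals = tf.linalg.eigh_tridiagonal(alpha, beta, eigvals_only=True)

# Known spectrum of this 3x3 tridiagonal Toeplitz matrix.
expected = np.array([-np.sqrt(2.0), 0.0, np.sqrt(2.0)])
np.testing.assert_allclose(eigvals.numpy(), expected, atol=1e-5)
```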
@@ -52,10 +52,12 @@
        lower overall memory usage, and a cleaner API. It does not require
        specifying a `delete_key` and `empty_key` that cannot be inserted into
        the table.
    *   Added support for specifying number of subdivisions in all reduce host
        collective. This parallelizes work on CPU and speeds up the collective
        performance. Default behavior is unchanged.
    *   Added `tf.linalg.eigh_tridiagonal` that computes the eigenvalues of a
        Hermitian tridiagonal matrix.
*   SavedModel
    *   Added `tf.saved_model.experimental.TrackableResource`, which allows
        the creation of custom wrapper objects for resource tensors.
    *   Added a SavedModel load option to allow restoring partial
@@ -2188,6 +2188,7 @@ cuda_py_test(
    name = "linalg_ops_test",
    size = "medium",
    srcs = ["linalg_ops_test.py"],
    shard_count = 4,
    tags = ["no_windows_gpu"],
    deps = [
        "//tensorflow/python:array_ops",
@@ -22,7 +22,9 @@ import itertools

from absl.testing import parameterized
import numpy as np
import scipy.linalg

from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -35,13 +37,6 @@ from tensorflow.python.ops.linalg import linalg
from tensorflow.python.platform import test


def _AddTest(test_class, op_name, testcase_name, fn):
  test_name = "_".join(["test", op_name, testcase_name])
  if hasattr(test_class, test_name):
    raise RuntimeError("Test %s defined more than once" % test_name)
  setattr(test_class, test_name, fn)


def _RandomPDMatrix(n, rng, dtype=np.float64):
  """Random positive definite matrix."""
  temp = rng.randn(n, n).astype(dtype)
@@ -562,5 +557,98 @@ class LUSolveDynamic(test.TestCase, _LUSolve):
  use_static_shape = False


@test_util.run_all_in_graph_and_eager_modes
class EighTridiagonalTest(test.TestCase):

  # This op is rather slow in pure eager mode, so we force it to be
  # a function here.
  @def_function.function
  def run_eigh_tridiagonal(self, alpha, beta, **kwargs):
    return linalg.eigh_tridiagonal(alpha, beta, **kwargs)

  def run_test(self, alpha, beta):
    n = alpha.shape[0]
    # scipy.linalg.eigh_tridiagonal doesn't support complex inputs, so for
    # this we call the slower numpy.linalg.eigh.
    if np.issubdtype(alpha.dtype, np.complexfloating):
      tridiagonal = np.diag(alpha) + np.diag(beta, 1) + np.diag(
          np.conj(beta), -1)
      eigvals_expected, _ = np.linalg.eigh(tridiagonal)
    else:
      eigvals_expected = scipy.linalg.eigh_tridiagonal(
          alpha, beta, eigvals_only=True)
    eigvals = self.run_eigh_tridiagonal(alpha, beta)
    eps = np.finfo(alpha.dtype).eps
    atol = 2 * n * eps * np.amax(np.abs(eigvals_expected))
    self.assertAllClose(eigvals_expected, eigvals, atol=atol)

  def run_toeplitz_tests(self, dtype):
    for n in [1, 2, 3, 7, 8, 100]:
      for a, b in [[2, -1], [1, 0], [0, 1], [-1e10, 1e10], [-1e-10, 1e-10]]:
        alpha = a * np.ones([n], dtype=dtype)
        beta = b * np.ones([n - 1], dtype=dtype)
        if np.issubdtype(alpha.dtype, np.complexfloating):
          beta += 1j * beta
        self.run_test(alpha, beta)

  def run_random_uniform_tests(self, dtype):
    for n in [2, 3, 7, 8, 100]:
      alpha = np.random.uniform(size=(n,)).astype(dtype)
      beta = np.random.uniform(size=(n - 1,)).astype(dtype)
      if np.issubdtype(alpha.dtype, np.complexfloating):
        beta += 1j * beta
      self.run_test(alpha, beta)

  def run_select_tests(self, dtype):
    n = 5
    alpha = np.random.uniform(size=(n,)).astype(dtype)
    beta = np.random.uniform(size=(n - 1,)).astype(dtype)
    eigvals_all = self.run_eigh_tridiagonal(alpha, beta, select="a")
    eps = np.finfo(alpha.dtype).eps
    atol = 2 * n * eps
    for first in range(n - 1):
      for last in range(first + 1, n - 1):
        # Check that we get the expected eigenvalues by selecting by
        # index range.
        eigvals_index = self.run_eigh_tridiagonal(
            alpha, beta, select="i", select_range=(first, last))
        self.assertAllClose(
            eigvals_all[first:(last + 1)], eigvals_index, atol=atol)

        # Check that we get the expected eigenvalues by selecting by
        # value range.
        eigvals_value = self.run_eigh_tridiagonal(
            alpha,
            beta,
            select="v",
            select_range=(eigvals_all[first], eigvals_all[last]))
        self.assertAllClose(
            eigvals_all[first:(last + 1)], eigvals_value, atol=atol)

  def test_float32(self):
    dtype = np.float32
    self.run_toeplitz_tests(dtype)
    self.run_random_uniform_tests(dtype)
    self.run_select_tests(dtype)

  def test_float64(self):
    dtype = np.float64
    self.run_toeplitz_tests(dtype)
    self.run_random_uniform_tests(dtype)
    self.run_select_tests(dtype)

  def test_complex64(self):
    dtype = np.complex64
    self.run_toeplitz_tests(dtype)
    self.run_random_uniform_tests(dtype)
    self.run_select_tests(dtype)

  def test_complex128(self):
    dtype = np.complex128
    self.run_toeplitz_tests(dtype)
    self.run_random_uniform_tests(dtype)
    self.run_select_tests(dtype)


if __name__ == "__main__":
  test.main()
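A side note on the complex-input branch of `run_test` above: scipy.linalg.eigh_tridiagonal handles only real matrices, so the test densifies the matrix and falls back to NumPy for the reference spectrum. The same reference construction can be reproduced standalone with the NumPy-only sketch below (illustrative, not part of the commit).

```python
import numpy as np

def dense_hermitian_tridiagonal(alpha, beta):
  # Dense matrix with diagonal alpha, super-diagonal beta,
  # and sub-diagonal conj(beta), so the result is Hermitian.
  return np.diag(alpha) + np.diag(beta, 1) + np.diag(np.conj(beta), -1)

alpha = np.array([2.0, 2.0, 2.0])
beta = np.array([1.0 + 1.0j, 1.0 - 1.0j])
t = dense_hermitian_tridiagonal(alpha, beta)
assert np.allclose(t, t.conj().T)      # Hermitian by construction
ref_eigvals = np.linalg.eigvalsh(t)    # ascending order, as in the test
print(ref_eigvals)
```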
@@ -1205,3 +1205,228 @@ def _lu_solve_assertions(lower_upper, perm, rhs, validate_args):
            message=message))

  return assertions


@tf_export('linalg.eigh_tridiagonal')
@dispatch.add_dispatch_support
def eigh_tridiagonal(alpha,
                     beta,
                     eigvals_only=True,
                     select='a',
                     select_range=None,
                     tol=None,
                     name=None):
  """Computes the eigenvalues of a Hermitian tridiagonal matrix.

  Args:
    alpha: A real or complex tensor of shape (n), the diagonal elements of the
      matrix. NOTE: If alpha is complex, the imaginary part is ignored (assumed
      zero) to satisfy the requirement that the matrix be Hermitian.
    beta: A real or complex tensor of shape (n-1), containing the elements of
      the first super-diagonal of the matrix. If beta is complex, the first
      sub-diagonal of the matrix is assumed to be the conjugate of beta to
      satisfy the requirement that the matrix be Hermitian.
    eigvals_only: If False, both eigenvalues and corresponding eigenvectors are
      computed. If True, only eigenvalues are computed. Default is True.
    select: Optional string with values in {'a', 'v', 'i'} (default is 'a')
      that determines which eigenvalues to calculate:
        'a': all eigenvalues.
        'v': eigenvalues in the interval (min, max] given by select_range.
        'i': eigenvalues with indices min <= i <= max.
    select_range: Size 2 tuple or list or tensor specifying the range of
      eigenvalues to compute together with select. If select is 'a',
      select_range is ignored.
    tol: Optional scalar. The absolute tolerance to which each eigenvalue is
      required. An eigenvalue (or cluster) is considered to have converged if
      it lies in an interval of this width. If tol is None (default), the value
      eps*|T|_2 is used where eps is the machine precision, and |T|_2 is the
      2-norm of the matrix T.
    name: Optional name of the op.

  Returns:
    eig_vals: The eigenvalues of the matrix in non-decreasing order.
    eig_vectors: If `eigvals_only` is False the eigenvectors are returned in
      the second output argument.

  Raises:
    ValueError: If input values are invalid.
    NotImplementedError: Computing eigenvectors for `eigvals_only` = False is
      not implemented yet.

  This op implements a subset of the functionality of
  scipy.linalg.eigh_tridiagonal.

  TODO(b/187527398):
    a) Complete scipy.linalg compatibility:
       1. Add support for computing eigenvectors.
    b) Add support for outer batch dimensions.

  #### Examples

  ```python
  import numpy
  eigvals = tf.linalg.eigh_tridiagonal([0.0, 0.0, 0.0], [1.0, 1.0])
  eigvals_expected = [-numpy.sqrt(2.0), 0.0, numpy.sqrt(2.0)]
  tf.assert_near(eigvals_expected, eigvals)
  # ==> True
  ```
  """
  with ops.name_scope(name or 'eigh_tridiagonal'):

    def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x):
      """Implements the Sturm sequence recurrence."""
      n = alpha.shape[0]
      zeros = array_ops.zeros(array_ops.shape(x), dtype=dtypes.int32)
      ones = array_ops.ones(array_ops.shape(x), dtype=dtypes.int32)

      # The first step in the Sturm sequence recurrence
      # requires special care if x is equal to alpha[0].
      def sturm_step0():
        q = alpha[0] - x
        count = array_ops.where(q < 0, ones, zeros)
        q = array_ops.where(math_ops.equal(alpha[0], x), alpha0_perturbation, q)
        return q, count

      # Subsequent steps all take this form:
      def sturm_step(i, q, count):
        q = alpha[i] - beta_sq[i - 1] / q - x
        count = array_ops.where(q <= pivmin, count + 1, count)
        q = array_ops.where(q <= pivmin, math_ops.minimum(q, -pivmin), q)
        return q, count

      # The first step initializes q and count.
      q, count = sturm_step0()

      # Peel off ((n-1) % blocksize) steps from the main loop, so we can run
      # the bulk of the iterations unrolled by a factor of blocksize.
      blocksize = 16
      i = 1
      peel = (n - 1) % blocksize
      unroll_cnt = peel

      def unrolled_steps(start, q, count):
        for j in range(unroll_cnt):
          q, count = sturm_step(start + j, q, count)
        return start + unroll_cnt, q, count

      i, q, count = unrolled_steps(i, q, count)

      # Run the remaining steps of the Sturm sequence using a partially
      # unrolled while loop.
      unroll_cnt = blocksize
      cond = lambda i, q, count: math_ops.less(i, n)
      _, _, count = control_flow_ops.while_loop(
          cond, unrolled_steps, [i, q, count], back_prop=False)
      return count

    if not eigvals_only:
      raise NotImplementedError(
          '`eigvals_only` = False is not implemented yet.')
    alpha = ops.convert_to_tensor(alpha, name='alpha')
    n = alpha.shape[0]
    if n <= 1:
      return math_ops.real(alpha)
    beta = ops.convert_to_tensor(beta, name='beta')

    if alpha.dtype != beta.dtype:
      raise ValueError("'alpha' and 'beta' must have the same type.")

    if alpha.dtype.is_complex:
      alpha = math_ops.real(alpha)
      beta_sq = math_ops.real(math_ops.conj(beta) * beta)
      beta_abs = math_ops.sqrt(beta_sq)
    else:
      beta_sq = math_ops.square(beta)
      beta_abs = math_ops.abs(beta)

    # Estimate the largest and smallest eigenvalues of T using the Gershgorin
    # circle theorem.
    off_diag_abs_row_sum = array_ops.concat(
        [beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]], axis=0)
    lambda_est_max = math_ops.reduce_max(alpha + off_diag_abs_row_sum)
    lambda_est_min = math_ops.reduce_min(alpha - off_diag_abs_row_sum)
    # Upper bound on 2-norm of T.
    t_norm = math_ops.maximum(
        math_ops.abs(lambda_est_min), math_ops.abs(lambda_est_max))

    # Compute the smallest allowed pivot in the Sturm sequence to avoid
    # overflow.
    finfo = np.finfo(alpha.dtype.as_numpy_dtype)
    one = np.ones([], dtype=alpha.dtype.as_numpy_dtype)
    safemin = np.maximum(one / finfo.max, (one + finfo.eps) * finfo.tiny)
    pivmin = safemin * math_ops.maximum(one, math_ops.reduce_max(beta_sq))
    alpha0_perturbation = math_ops.square(finfo.eps * beta_abs[0])
    abs_tol = finfo.eps * t_norm
    if tol:
      abs_tol = math_ops.maximum(tol, abs_tol)
    # In the worst case, when the absolute tolerance is eps*lambda_est_max and
    # lambda_est_max = -lambda_est_min, we have to take as many bisection steps
    # as there are bits in the mantissa plus 1.
    max_it = finfo.nmant + 1

    # Determine the indices of the desired eigenvalues, based on select and
    # select_range.
    asserts = None
    if select == 'a':
      target_counts = math_ops.range(n)
    elif select == 'i':
      asserts = check_ops.assert_less_equal(
          select_range[0],
          select_range[1],
          message='Got empty index range in select_range.')
      target_counts = math_ops.range(select_range[0], select_range[1] + 1)
    elif select == 'v':
      asserts = check_ops.assert_less(
          select_range[0],
          select_range[1],
          message='Got empty interval in select_range.')
    else:
      raise ValueError("'select' must have a value in {'a', 'i', 'v'}.")

    if asserts:
      with ops.control_dependencies([asserts]):
        alpha = array_ops.identity(alpha)

    # Run binary search for all desired eigenvalues in parallel, starting from
    # an interval slightly wider than the estimated
    # [lambda_est_min, lambda_est_max].
    fudge = 2.1  # We widen the starting Gershgorin interval a bit.
    norm_slack = math_ops.cast(n, alpha.dtype) * fudge * finfo.eps * t_norm
    if select in {'a', 'i'}:
      lower = lambda_est_min - norm_slack - 2 * fudge * pivmin
      upper = lambda_est_max + norm_slack + fudge * pivmin
    else:
      # Count the number of eigenvalues in the given range.
      lower = select_range[0] - norm_slack - 2 * fudge * pivmin
      upper = select_range[1] + norm_slack + fudge * pivmin
      first = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, lower)
      last = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, upper)
      target_counts = math_ops.range(first, last)

    # Pre-broadcast the scalars used in the Sturm sequence for improved
    # performance.
    target_shape = array_ops.shape(target_counts)
    lower = array_ops.broadcast_to(lower, shape=target_shape)
    upper = array_ops.broadcast_to(upper, shape=target_shape)
    mid = 0.5 * (upper + lower)
    pivmin = array_ops.broadcast_to(pivmin, target_shape)
    alpha0_perturbation = array_ops.broadcast_to(alpha0_perturbation,
                                                 target_shape)

    # Start parallel binary searches.
    def cond(i, lower, _, upper):
      return math_ops.logical_and(
          math_ops.less(i, max_it),
          math_ops.less(abs_tol, math_ops.reduce_max(upper - lower)))

    def body(i, lower, mid, upper):
      counts = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, mid)
      lower = array_ops.where(counts <= target_counts, mid, lower)
      upper = array_ops.where(counts > target_counts, mid, upper)
      mid = 0.5 * (lower + upper)
      return i + 1, lower, mid, upper

    _, _, mid, _ = control_flow_ops.while_loop(cond, body,
                                               [0, lower, mid, upper])
    return mid
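To make the algorithm in the hunk above easier to follow: `_sturm` counts, for a shift x, how many eigenvalues of T lie below x, and the while loop runs one bisection per requested eigenvalue in parallel until every bracket is narrower than `abs_tol`. The NumPy sketch below restates that idea in scalar form; it omits the `pivmin` safeguards, the `alpha0_perturbation` handling, the loop unrolling, and the `select`/`select_range` logic of the real op, and all names in it are illustrative.

```python
import numpy as np

def sturm_count(alpha, beta_sq, x):
  """Number of eigenvalues of the tridiagonal matrix that lie below x."""
  q = alpha[0] - x
  count = 1 if q < 0 else 0
  for i in range(1, len(alpha)):
    if q == 0.0:                       # crude guard; the real op uses pivmin
      q = -np.finfo(np.float64).tiny
    q = alpha[i] - beta_sq[i - 1] / q - x
    if q < 0:
      count += 1
  return count

def eigvals_by_bisection(alpha, beta, max_it=100):
  """All eigenvalues via Sturm-count bisection; assumes n >= 2."""
  alpha = np.asarray(alpha, dtype=np.float64)
  beta_sq = np.abs(np.asarray(beta))**2
  beta_abs = np.sqrt(beta_sq)
  n = len(alpha)
  # Gershgorin bounds on the spectrum, as in the op.
  row_sum = np.concatenate(
      [beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]])
  lower = np.full(n, np.min(alpha - row_sum))
  upper = np.full(n, np.max(alpha + row_sum))
  target = np.arange(n)                # the k-th search brackets the k-th eigenvalue
  for _ in range(max_it):
    mid = 0.5 * (lower + upper)
    counts = np.array([sturm_count(alpha, beta_sq, x) for x in mid])
    lower = np.where(counts <= target, mid, lower)
    upper = np.where(counts > target, mid, upper)
  return 0.5 * (lower + upper)

# Same example as the docstring: eigenvalues are -sqrt(2), 0, sqrt(2).
print(eigvals_by_bisection([0.0, 0.0, 0.0], [1.0, 1.0]))
```

With a fixed number of bisection steps on the Gershgorin bracket this sketch converges well past the needed precision; the op instead stops after `finfo.nmant + 1` steps or once all brackets are narrower than `abs_tol`.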
@@ -124,6 +124,10 @@ tf_module {
    name: "eigh"
    argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "eigh_tridiagonal"
    argspec: "args=[\'alpha\', \'beta\', \'eigvals_only\', \'select\', \'select_range\', \'tol\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'a\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "eigvalsh"
    argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -132,6 +132,10 @@ tf_module {
    name: "eigh"
    argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "eigh_tridiagonal"
    argspec: "args=[\'alpha\', \'beta\', \'eigvals_only\', \'select\', \'select_range\', \'tol\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'a\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "eigvals"
    argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "