Deprecate symeig (#57732)

Summary:
This PR deprecates `torch.symeig` in favor of `torch.linalg.eigh` / `torch.linalg.eigvalsh`. This one had a tricky usage of `torch.symeig` that had to be replaced; I tested the replacement locally.
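
For reference, a minimal migration sketch (editor's illustration, not part of the diff) following the replacement the new warning text recommends; `A` is any symmetric matrix:

    import torch

    A = torch.randn(4, 4, dtype=torch.float64)
    A = A + A.T  # symmetrize so upper/lower triangles agree

    # Deprecated: eigenvalues only, upper triangle by default.
    L_old, _ = torch.symeig(A, upper=True)
    # Replacement: torch.linalg.eigvalsh defaults to the *lower* triangle,
    # so pass UPLO='U' to reproduce symeig's old default.
    L_new = torch.linalg.eigvalsh(A, UPLO='U')
    assert torch.allclose(L_old, L_new)

    # Deprecated: eigenvalues and eigenvectors.
    L_old, V_old = torch.symeig(A, eigenvectors=True, upper=True)
    # Replacement (eigenvectors are only defined up to sign,
    # so compare eigenvalues rather than eigenvectors).
    L_new, V_new = torch.linalg.eigh(A, UPLO='U')
    assert torch.allclose(L_old, L_new)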

Pull Request resolved: https://github.com/pytorch/pytorch/pull/57732

Reviewed By: bdhirsh

Differential Revision: D28328189

Pulled By: mruberry

fbshipit-source-id: 7f000fcbf2b029beabc76e5a89ff158b47977474
lezcano 2021-05-12 02:20:31 -07:00 committed by Facebook GitHub Bot
parent e18f5f1d13
commit db13119fc4
8 changed files with 56 additions and 17 deletions


@@ -2383,11 +2383,37 @@ std::tuple<Tensor, Tensor> _symeig_helper_cpu(const Tensor& self, bool eigenvect
 }
 
 std::tuple<Tensor, Tensor> symeig(const Tensor& self, bool eigenvectors, bool upper) {
+  TORCH_WARN_ONCE(
+    "torch.symeig is deprecated in favor of torch.linalg.eigh and will be removed in a future ",
+    "PyTorch release.\n",
+    "The default behavior has changed from using the upper triangular portion of the matrix by default ",
+    "to using the lower triangular portion.\n",
+    "L, _ = torch.symeig(A, upper=upper)\n",
+    "should be replaced with\n",
+    "L = torch.linalg.eigvalsh(A, UPLO='U' if upper else 'L')\n",
+    "and\n",
+    "L, V = torch.symeig(A, eigenvectors=True)\n"
+    "should be replaced with\n",
+    "L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L')"
+  );
   squareCheckInputs(self);
   return at::_symeig_helper(self, eigenvectors, upper);
 }
 
 std::tuple<Tensor&, Tensor&> symeig_out(const Tensor& self, bool eigenvectors, bool upper, Tensor& vals, Tensor& vecs) {
+  TORCH_WARN_ONCE(
+    "torch.symeig is deprecated in favor of torch.linalg.eigh and will be removed in a future ",
+    "PyTorch release.\n",
+    "The default behavior has changed from using the upper triangular portion of the matrix by default ",
+    "to using the lower triangular portion.\n",
+    "L, _ = torch.symeig(A, upper=upper)\n",
+    "should be replaced with\n",
+    "L = torch.linalg.eigvalsh(A, UPLO='U' if upper else 'L')\n",
+    "and\n",
+    "L, V = torch.symeig(A, eigenvectors=True)\n"
+    "should be replaced with\n",
+    "L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L')"
+  );
   checkSameDevice("symeig", vals, self, "eigenvalues");
   checkSameDevice("symeig", vecs, self, "eigenvectors");
   checkLinalgCompatibleDtype("symeig", vecs, self, "eigenvectors");
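
The UPLO caveat in the warning is only observable on inputs that are not exactly symmetric; a quick check of the semantic difference (editor's sketch, not from the PR):

    import torch

    # Deliberately non-symmetric: the two triangles disagree.
    A = torch.tensor([[1.0, 100.0],
                      [0.0,   1.0]])

    L_upper = torch.linalg.eigvalsh(A, UPLO='U')  # acts on [[1, 100], [100, 1]]
    L_lower = torch.linalg.eigvalsh(A, UPLO='L')  # acts on [[1, 0], [0, 1]]
    assert not torch.allclose(L_upper, L_lower)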


@@ -3730,7 +3730,7 @@ class TestAutograd(TestCase):
     def test_symeig_no_eigenvectors(self):
         A = torch.tensor([[1., 2.], [2., 4.]], dtype=torch.float32, requires_grad=True)
         w, v = torch.symeig(A, eigenvectors=False)
-        with self.assertRaisesRegex(RuntimeError, 'cannot compute backward'):
+        with self.assertRaisesRegex(RuntimeError, 'is not differentiable'):
             torch.autograd.backward([w, v], [torch.ones_like(w), torch.ones_like(v)])
 
     @skipIfNoLapack
@skipIfNoLapack


@@ -7463,8 +7463,6 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
         with warnings.catch_warnings(record=True) as w:
             # Trigger warning
             torch.symeig(a, out=(out_w, out_v))
-            # Check warning occurs
-            self.assertEqual(len(w), 2)
             self.assertTrue("An output with one or more elements was resized" in str(w[-2].message))
             self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
 


@@ -88,14 +88,12 @@ def basis(A):
     return Q
 
 
-def symeig(A: Tensor, largest: Optional[bool] = False, eigenvectors: Optional[bool] = True) -> Tuple[Tensor, Tensor]:
+def symeig(A: Tensor, largest: Optional[bool] = False) -> Tuple[Tensor, Tensor]:
     """Return eigenpairs of A with specified ordering.
     """
     if largest is None:
         largest = False
-    if eigenvectors is None:
-        eigenvectors = True
-    E, Z = torch.symeig(A, eigenvectors, True)
+    E, Z = torch.linalg.eigh(A, UPLO='U')
     # assuming that E is ordered
     if largest:
         E = torch.flip(E, dims=(-1,))
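
The drop-in is safe because torch.linalg.eigh, like torch.symeig before it, returns eigenvalues in ascending order, which the "# assuming that E is ordered" comment relies on. A small sanity check (editor's sketch):

    import torch

    A = torch.diag(torch.tensor([3.0, 1.0, 2.0]))
    E, Z = torch.linalg.eigh(A, UPLO='U')
    assert torch.allclose(E, torch.tensor([1.0, 2.0, 3.0]))  # ascending

    # Largest-first ordering, as the helper produces for largest=True:
    E_desc = torch.flip(E, dims=(-1,))
    Z_desc = torch.flip(Z, dims=(-1,))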


@@ -998,7 +998,7 @@ class LOBPCG(object):
         # The original algorithm 4 from [DuerschPhD2015].
         d_col = (d ** -0.5).reshape(d.shape[0], 1)
         DUBUD = (UBU * d_col) * _utils.transpose(d_col)
-        E, Z = _utils.symeig(DUBUD, eigenvectors=True)
+        E, Z = _utils.symeig(DUBUD)
         t = tau * abs(E).max()
         if drop:
             keep = torch.where(E > t)


@@ -8721,6 +8721,27 @@ only the upper triangular portion is used by default.
 
 If :attr:`upper` is ``False``, then the lower triangular portion is used.
 
+.. warning::
+
+    :func:`torch.symeig` is deprecated in favor of :func:`torch.linalg.eigh`
+    and will be removed in a future PyTorch release. The default behavior has changed
+    from using the upper triangular portion of the matrix by default to using the
+    lower triangular portion.
+
+    ``L, _ = torch.symeig(A, upper=upper)`` should be replaced with
+
+    .. code:: python
+
+        UPLO = "U" if upper else "L"
+        L = torch.linalg.eigvalsh(A, UPLO=UPLO)
+
+    ``L, V = torch.symeig(A, eigenvectors=True, upper=upper)`` should be replaced with
+
+    .. code:: python
+
+        UPLO = "U" if upper else "L"
+        L, V = torch.linalg.eigh(A, UPLO=UPLO)
+
 .. note:: The eigenvalues are returned in ascending order. If :attr:`input` is a batch of matrices,
           then the eigenvalues of each matrix in the batch are returned in ascending order.
 
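
The ascending-order note holds per matrix for batched input as well; a quick illustration (editor's sketch, not from the docs):

    import torch

    batch = torch.stack([torch.diag(torch.tensor([2.0, 1.0])),
                         torch.diag(torch.tensor([5.0, 4.0]))])
    L = torch.linalg.eigvalsh(batch)
    # Each row holds one matrix's eigenvalues, each in ascending order.
    assert torch.allclose(L, torch.tensor([[1.0, 2.0], [4.0, 5.0]]))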


@@ -2339,8 +2339,8 @@ Tensor symeig_backward(const std::vector<torch::autograd::Variable> &grads, cons
   // leads to stable gradient updates, and retains symmetry of the updated matrix if it
   // were updated by a gradient based algorithm.
   TORCH_CHECK(eigenvectors,
-              "symeig_backward: Setting eigenvectors to false in torch.symeig doesn't compute eigenvectors ",
-              "and hence we cannot compute backward. Please use torch.symeig(eigenvectors=True)");
+              "symeig_backward: torch.symeig(A, eigenvectors=False) is not differentiable. ",
+              "Use torch.linalg.eigvalsh(A) instead.");
 
   auto glambda = grads[0];
   auto gv = grads[1];
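
The new message points at torch.linalg.eigvalsh because, unlike torch.symeig(A, eigenvectors=False), it is differentiable on its own. A quick check of that claim (editor's sketch):

    import torch

    A = torch.tensor([[2.0, 1.0],
                      [1.0, 3.0]], requires_grad=True)
    L = torch.linalg.eigvalsh(A)
    L.sum().backward()
    # The sum of eigenvalues equals trace(A), so for a symmetric input
    # the gradient is the identity matrix.
    assert torch.allclose(A.grad, torch.eye(2))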


@@ -463,13 +463,9 @@ class _PositiveDefinite(Constraint):
     event_dim = 2
 
     def check(self, value):
-        matrix_shape = value.shape[-2:]
-        batch_shape = value.unsqueeze(0).shape[:-2]
-        # TODO: replace with batched linear algebra routine when one becomes available
-        # note that `symeig()` returns eigenvalues in ascending order
-        flattened_value = value.reshape((-1,) + matrix_shape)
-        return torch.stack([v.symeig(eigenvectors=False)[0][:1] > 0.0
-                            for v in flattened_value]).view(batch_shape)
+        # Assumes that the matrix or batch of matrices in value are symmetric
+        # info == 0 means no error, that is, it's SPD
+        return torch.linalg.cholesky_ex(value).info.eq(0).unsqueeze(0)
 
 
 class _Cat(Constraint):
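
For context on the new check: torch.linalg.cholesky_ex returns a per-matrix LAPACK `info` code alongside the factor, and `info == 0` means the factorization succeeded, i.e. the (assumed symmetric) matrix is positive definite. This also batches natively, replacing the old Python-level loop. A small demonstration (editor's sketch):

    import torch

    spd = 2.0 * torch.eye(3)   # positive definite
    not_pd = -torch.eye(3)     # negative definite
    batch = torch.stack([spd, not_pd])

    info = torch.linalg.cholesky_ex(batch).info
    assert info.eq(0).tolist() == [True, False]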