diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index f0d51161c08..ddd6bb80345 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -2383,11 +2383,37 @@ std::tuple<Tensor, Tensor> _symeig_helper_cpu(const Tensor& self, bool eigenvect
 }
 
 std::tuple<Tensor, Tensor> symeig(const Tensor& self, bool eigenvectors, bool upper) {
+  TORCH_WARN_ONCE(
+    "torch.symeig is deprecated in favor of torch.linalg.eigh and will be removed in a future ",
+    "PyTorch release.\n",
+    "The default behavior has changed from using the upper triangular portion of the matrix ",
+    "to using the lower triangular portion.\n",
+    "L, _ = torch.symeig(A, upper=upper)\n",
+    "should be replaced with\n",
+    "L = torch.linalg.eigvalsh(A, UPLO='U' if upper else 'L')\n",
+    "and\n",
+    "L, V = torch.symeig(A, eigenvectors=True, upper=upper)\n",
+    "should be replaced with\n",
+    "L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L')"
+  );
   squareCheckInputs(self);
   return at::_symeig_helper(self, eigenvectors, upper);
 }
 
 std::tuple<Tensor&, Tensor&> symeig_out(const Tensor& self, bool eigenvectors, bool upper, Tensor& vals, Tensor& vecs) {
+  TORCH_WARN_ONCE(
+    "torch.symeig is deprecated in favor of torch.linalg.eigh and will be removed in a future ",
+    "PyTorch release.\n",
+    "The default behavior has changed from using the upper triangular portion of the matrix ",
+    "to using the lower triangular portion.\n",
+    "L, _ = torch.symeig(A, upper=upper)\n",
+    "should be replaced with\n",
+    "L = torch.linalg.eigvalsh(A, UPLO='U' if upper else 'L')\n",
+    "and\n",
+    "L, V = torch.symeig(A, eigenvectors=True, upper=upper)\n",
+    "should be replaced with\n",
+    "L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L')"
+  );
   checkSameDevice("symeig", vals, self, "eigenvalues");
   checkSameDevice("symeig", vecs, self, "eigenvectors");
   checkLinalgCompatibleDtype("symeig", vecs, self, "eigenvectors");
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 91d369b5552..43d83a647cc 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -3730,7 +3730,7 @@ class TestAutograd(TestCase):
     def test_symeig_no_eigenvectors(self):
         A = torch.tensor([[1., 2.], [2., 4.]], dtype=torch.float32, requires_grad=True)
         w, v = torch.symeig(A, eigenvectors=False)
-        with self.assertRaisesRegex(RuntimeError, 'cannot compute backward'):
+        with self.assertRaisesRegex(RuntimeError, 'is not differentiable'):
             torch.autograd.backward([w, v], [torch.ones_like(w), torch.ones_like(v)])
 
     @skipIfNoLapack
diff --git a/test/test_linalg.py b/test/test_linalg.py
index e48d51f51f0..a595eb77e90 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -7463,8 +7463,6 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
         with warnings.catch_warnings(record=True) as w:
             # Trigger warning
             torch.symeig(a, out=(out_w, out_v))
-            # Check warning occurs
-            self.assertEqual(len(w), 2)
             self.assertTrue("An output with one or more elements was resized" in str(w[-2].message))
             self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
 
diff --git a/torch/_linalg_utils.py b/torch/_linalg_utils.py
index 2d166e24134..568ae8b74aa 100644
--- a/torch/_linalg_utils.py
+++ b/torch/_linalg_utils.py
@@ -88,14 +88,12 @@ def basis(A):
     return Q
 
 
-def symeig(A: Tensor, largest: Optional[bool] = False, eigenvectors: Optional[bool] = True) -> Tuple[Tensor, Tensor]:
+def symeig(A: Tensor, largest: Optional[bool] = False) -> Tuple[Tensor, Tensor]:
     """Return eigenpairs of A with specified ordering.
     """
     if largest is None:
         largest = False
-    if eigenvectors is None:
-        eigenvectors = True
-    E, Z = torch.symeig(A, eigenvectors, True)
+    E, Z = torch.linalg.eigh(A, UPLO='U')
     # assuming that E is ordered
     if largest:
         E = torch.flip(E, dims=(-1,))
diff --git a/torch/_lobpcg.py b/torch/_lobpcg.py
index 586741c728c..135ff146d91 100644
--- a/torch/_lobpcg.py
+++ b/torch/_lobpcg.py
@@ -998,7 +998,7 @@ class LOBPCG(object):
         # The original algorithm 4 from [DuerschPhD2015].
         d_col = (d ** -0.5).reshape(d.shape[0], 1)
         DUBUD = (UBU * d_col) * _utils.transpose(d_col)
-        E, Z = _utils.symeig(DUBUD, eigenvectors=True)
+        E, Z = _utils.symeig(DUBUD)
         t = tau * abs(E).max()
         if drop:
             keep = torch.where(E > t)
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 9b266c96775..418f325f13f 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -8721,6 +8721,27 @@
 only the upper triangular portion is used by default.
 If :attr:`upper` is ``False``, then the lower triangular portion is used.
 
+.. warning::
+
+    :func:`torch.symeig` is deprecated in favor of :func:`torch.linalg.eigh`
+    and will be removed in a future PyTorch release. The default behavior has
+    changed from using the upper triangular portion of the matrix to using the
+    lower triangular portion.
+
+    ``L, _ = torch.symeig(A, upper=upper)`` should be replaced with
+
+    .. code:: python
+
+        UPLO = "U" if upper else "L"
+        L = torch.linalg.eigvalsh(A, UPLO=UPLO)
+
+    ``L, V = torch.symeig(A, eigenvectors=True, upper=upper)`` should be replaced with
+
+    .. code:: python
+
+        UPLO = "U" if upper else "L"
+        L, V = torch.linalg.eigh(A, UPLO=UPLO)
+
 .. note:: The eigenvalues are returned in ascending order. If :attr:`input` is a batch of
           matrices, then the eigenvalues of each matrix in the batch are returned in ascending order.
 
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 0c5e1ed41fe..78fab3d0400 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -2339,8 +2339,8 @@ Tensor symeig_backward(const std::vector<torch::autograd::Variable> &grads, cons
   // leads to stable gradient updates, and retains symmetry of the updated matrix if it
   // were updated by a gradient based algorithm.
   TORCH_CHECK(eigenvectors,
-      "symeig_backward: Setting eigenvectors to false in torch.symeig doesn't compute eigenvectors ",
-      "and hence we cannot compute backward. Please use torch.symeig(eigenvectors=True)");
+      "symeig_backward: torch.symeig(A, eigenvectors=False) is not differentiable. ",
+      "Use torch.linalg.eigvalsh(A) instead.");
 
   auto glambda = grads[0];
   auto gv = grads[1];
diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py
index fb05d17ac27..99808b6b80b 100644
--- a/torch/distributions/constraints.py
+++ b/torch/distributions/constraints.py
@@ -463,13 +463,9 @@ class _PositiveDefinite(Constraint):
     event_dim = 2
 
     def check(self, value):
-        matrix_shape = value.shape[-2:]
-        batch_shape = value.unsqueeze(0).shape[:-2]
-        # TODO: replace with batched linear algebra routine when one becomes available
-        # note that `symeig()` returns eigenvalues in ascending order
-        flattened_value = value.reshape((-1,) + matrix_shape)
-        return torch.stack([v.symeig(eigenvectors=False)[0][:1] > 0.0
-                            for v in flattened_value]).view(batch_shape)
+        # Assumes that the matrix (or batch of matrices) in value is symmetric.
+        # info == 0 means the Cholesky factorization succeeded, i.e. value is SPD.
+        return torch.linalg.cholesky_ex(value).info.eq(0).unsqueeze(0)
 
 
 class _Cat(Constraint):