diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index 8f8aac253ed..f0d51161c08 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -2907,6 +2907,21 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cpu(const Tensor& self, bool some, bool compute_uv) {
 }
 
 std::tuple<Tensor, Tensor, Tensor> svd(const Tensor& self, bool some, bool compute_uv) {
+  // TODO: uncomment the following when svd is deprecated not only in docs
+  // torch/xla is blocking the transition from at::svd to at::linalg_svd in at::linalg_pinv code
+  // see https://github.com/pytorch/xla/issues/2755
+  // TORCH_WARN_ONCE(
+  //     "torch.svd is deprecated in favor of torch.linalg.svd and will be ",
+  //     "removed in a future PyTorch release.\n",
+  //     "U, S, V = torch.svd(A, some=some, compute_uv=True) (default)\n",
+  //     "should be replaced with\n",
+  //     "U, S, Vh = torch.linalg.svd(A, full_matrices=not some)\n",
+  //     "V = Vh.transpose(-2, -1).conj()\n",
+  //     "and\n",
+  //     "_, S, _ = torch.svd(A, some=some, compute_uv=False)\n",
+  //     "should be replaced with\n",
+  //     "S = torch.linalg.svdvals(A)");
+
   TORCH_CHECK(self.dim() >= 2,
               "svd input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
   return at::_svd_helper(self, some, compute_uv);
@@ -2923,7 +2938,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> svd_out(const Tensor& self, bool some, bool compute_uv,
   checkLinalgCompatibleDtype("svd", S.scalar_type(), real_dtype, "S");
 
   Tensor U_tmp, S_tmp, V_tmp;
-  std::tie(U_tmp, S_tmp, V_tmp) = at::_svd_helper(self, some, compute_uv);
+  std::tie(U_tmp, S_tmp, V_tmp) = at::native::svd(self, some, compute_uv);
 
   at::native::resize_output(U, U_tmp.sizes());
   at::native::resize_output(S, S_tmp.sizes());
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index c69dec0625b..3ffe47afb1a 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -347,9 +347,7 @@ static Tensor& linalg_matrix_rank_out_helper(const Tensor& input, const Tensor&
   // that are above max(atol, rtol * max(S)) threshold
   Tensor S, max_S;
   if (!hermitian) {
-    Tensor U, V;
-    // TODO: replace input.svd with linalg_svd
-    std::tie(U, S, V) = input.svd(/*some=*/true, /*compute_uv=*/false);
+    S = at::linalg_svdvals(input);
     // singular values are sorted in descending order
     max_S = at::narrow(S, /*dim=*/-1, /*start=*/0, /*length=*/1);
   } else {
@@ -2171,7 +2169,7 @@ static Tensor& _linalg_norm_matrix_out(Tensor& result, const Tensor &self, const
   auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim());
   auto permutation_reverse = create_reverse_permutation(permutation);
 
-  result_ = std::get<1>(self_.permute(permutation).svd()).abs();
+  result_ = at::linalg_svdvals(self_.permute(permutation));
   result_ = _norm_min_max(result_, ord, result_.dim() - 1, keepdim);
 
   if (keepdim) {
diff --git a/test/test_linalg.py b/test/test_linalg.py
index e8b3d68d9a1..c7175e57705 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -7498,7 +7498,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
             # actual rank is known only for dense input
             detect_rank = (s.abs() > 1e-5).sum(axis=-1)
             self.assertEqual(actual_rank * torch.ones(batches, device=device, dtype=torch.int64), detect_rank)
-            U, S, V = torch.svd(A2)
+            S = torch.linalg.svdvals(A2)
             self.assertEqual(s[..., :actual_rank], S[..., :actual_rank])
 
         all_batches = [(), (1,), (3,), (2, 3)]
diff --git a/torch/_lowrank.py b/torch/_lowrank.py
index 33982f17846..060935e66ab 100644
--- a/torch/_lowrank.py
+++ b/torch/_lowrank.py
@@ -98,10 +98,10 @@ def svd_lowrank(A: Tensor, q: Optional[int] = 6, niter: Optional[int] = 2,
     .. note:: The input is assumed to be a low-rank matrix.
 
     .. note:: In general, use the full-rank SVD implementation
-              ``torch.svd`` for dense matrices due to its 10-fold
+              :func:`torch.linalg.svd` for dense matrices due to its 10-fold
               higher performance characteristics. The low-rank SVD
               will be useful for huge sparse matrices that
-              ``torch.svd`` cannot handle.
+              :func:`torch.linalg.svd` cannot handle.
 
     Args::
         A (Tensor): the input tensor of size :math:`(*, m, n)`
@@ -156,7 +156,8 @@ def _svd_lowrank(A: Tensor, q: Optional[int] = 6, niter: Optional[int] = 2,
         assert B_t.shape[-2] == m, (B_t.shape, m)
         assert B_t.shape[-1] == q, (B_t.shape, q)
         assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
-        U, S, V = torch.svd(B_t)
+        U, S, Vh = torch.linalg.svd(B_t, full_matrices=False)
+        V = Vh.conj().transpose(-2, -1)
         V = Q.matmul(V)
     else:
         Q = get_approximate_basis(A, q, niter=niter, M=M)
@@ -169,7 +170,8 @@ def _svd_lowrank(A: Tensor, q: Optional[int] = 6, niter: Optional[int] = 2,
         assert B_t.shape[-2] == q, (B_t.shape, q)
         assert B_t.shape[-1] == n, (B_t.shape, n)
         assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
-        U, S, V = torch.svd(B_t)
+        U, S, Vh = torch.linalg.svd(B_t, full_matrices=False)
+        V = Vh.conj().transpose(-2, -1)
         U = Q.matmul(U)
 
     return U, S, V
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 0a255e33437..c35e3fb5078 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -8506,9 +8506,23 @@ Supports :attr:`input` of float, double, cfloat and cdouble data types.
 The dtypes of `U` and `V` are the same as :attr:`input`'s. `S` will
 always be real-valued, even if :attr:`input` is complex.
 
-.. warning:: :func:`torch.svd` is deprecated. Please use
-             :func:`torch.linalg.svd` instead, which is similar to NumPy's
-             `numpy.linalg.svd`.
+.. warning::
+
+    :func:`torch.svd` is deprecated in favor of :func:`torch.linalg.svd`
+    and will be removed in a future PyTorch release.
+
+    ``U, S, V = torch.svd(A, some=some, compute_uv=True)`` (default) should be replaced with
+
+    .. code:: python
+
+        U, S, Vh = torch.linalg.svd(A, full_matrices=not some)
+        V = Vh.transpose(-2, -1).conj()
+
+    ``_, S, _ = torch.svd(A, some=some, compute_uv=False)`` should be replaced with
+
+    .. code:: python
+
+        S = torch.linalg.svdvals(A)
 
 .. note:: Differences with :func:`torch.linalg.svd`:
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index dd75b9dccad..7b87e4357f0 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -2492,8 +2492,9 @@ Tensor linalg_qr_backward(const std::vector &grads, c
 // http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf
 Tensor linalg_det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) {
   auto singular_case_backward = [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor {
-    Tensor u, sigma, v;
-    std::tie(u, sigma, v) = self.svd();
+    Tensor u, sigma, vh;
+    std::tie(u, sigma, vh) = at::linalg_svd(self, false);
+    Tensor v = vh.conj().transpose(-2, -1);
     auto gsigma = prod_backward(grad.unsqueeze(-1), sigma, det.unsqueeze(-1));
     return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v);
   };
@@ -2545,8 +2546,9 @@ Tensor logdet_backward(const Tensor & grad, const Tensor& self, const Tensor& logdet) {
   auto singular_case_backward = [&](const Tensor& grad, const Tensor& self) -> Tensor {
-    Tensor u, sigma, v;
-    std::tie(u, sigma, v) = self.svd();
+    Tensor u, sigma, vh;
+    std::tie(u, sigma, vh) = at::linalg_svd(self, false);
+    Tensor v = vh.conj().transpose(-2, -1);
     // logdet = \sum log(sigma)
     auto gsigma = grad.unsqueeze(-1).div(sigma);
     return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v);
@@ -2600,9 +2602,9 @@ Tensor slogdet_backward(const Tensor& grad_logabsdet,
                         const Tensor& self,
                         const Tensor& signdet, const Tensor& logabsdet) {
   auto singular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor {
-    Tensor u, sigma, v;
-    // TODO: replace self.svd with linalg_svd
-    std::tie(u, sigma, v) = self.svd();
+    Tensor u, sigma, vh;
+    std::tie(u, sigma, vh) = at::linalg_svd(self, false);
+    Tensor v = vh.conj().transpose(-2, -1);
     // sigma has all non-negative entries (also with at least one zero entry)
     // so logabsdet = \sum log(abs(sigma))
     // but det = 0, so backward logabsdet = \sum log(sigma)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index f2778a41208..ce1b144961c 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -1825,13 +1825,12 @@ def sample_unsqueeze(op_info, device, dtype, requires_grad, **kwargs):
 # Creates matrices with a positive nonzero determinant
 def sample_inputs_logdet(op_info, device, dtype, requires_grad, **kwargs):
     def make_nonzero_det(A, *, sign=1, min_singular_value=0.1, **kwargs):
-        u, s, v = A.svd()
+        u, s, vh = torch.linalg.svd(A, full_matrices=False)
         s.clamp_(min=min_singular_value)
-        A = torch.matmul(u, torch.matmul(torch.diag_embed(s), v.transpose(-2, -1)))
+        A = (u * s.unsqueeze(-2)) @ vh
         det = A.det()
         if sign is not None:
             if A.dim() == 2:
-                det = det.item()
                 if (det < 0) ^ (sign < 0):
                     A[0, :].neg_()
             else:
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index c5db6ada80c..ce944ac5c08 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -1782,13 +1782,13 @@ def make_tensor(size, device: torch.device, dtype: torch.dtype, *, low=None, high=None,
 def random_square_matrix_of_rank(l, rank, dtype=torch.double, device='cpu'):
     assert rank <= l
     A = torch.randn(l, l, dtype=dtype, device=device)
-    u, s, v = A.svd()
+    u, s, vh = torch.linalg.svd(A, full_matrices=False)
     for i in range(l):
         if i >= rank:
             s[i] = 0
         elif s[i] == 0:
             s[i] = 1
-    return u.mm(torch.diag(s).to(dtype)).mm(v.transpose(0, 1))
+    return (u * s.to(dtype).unsqueeze(-2)) @ vh
 
 def random_well_conditioned_matrix(*shape, dtype, device, mean=1.0, sigma=0.001):
     """
@@ -1807,10 +1807,10 @@ def random_well_conditioned_matrix(*shape, dtype, device, mean=1.0, sigma=0.001):
     x = torch.rand(shape, dtype=dtype, device=device)
     m = x.size(-2)
     n = x.size(-1)
-    u, _, v = x.svd()
+    u, _, vh = torch.linalg.svd(x, full_matrices=False)
    s = (torch.randn(*(shape[:-2] + (min(m, n),)), dtype=primitive_dtype[dtype], device=device) * sigma + mean) \
        .sort(-1, descending=True).values.to(dtype)
-    return (u * s.unsqueeze(-2)) @ v.transpose(-2, -1).conj()
+    return (u * s.unsqueeze(-2)) @ vh
 
 # TODO: remove this (prefer make_symmetric_matrices below)
 def random_symmetric_matrix(l, *batches, **kwargs):
@@ -1896,10 +1896,10 @@ def random_fullrank_matrix_distinct_singular_value(matrix_size, *batch_dims,
         return torch.ones(matrix_size, matrix_size, dtype=dtype, device=device)
 
     A = torch.randn(batch_dims + (matrix_size, matrix_size), dtype=dtype, device=device)
-    u, _, v = A.svd()
+    u, _, vh = torch.linalg.svd(A, full_matrices=False)
     real_dtype = A.real.dtype if A.dtype.is_complex else A.dtype
-    s = torch.arange(1., matrix_size + 1, dtype=real_dtype, device=device).mul_(1.0 / (matrix_size + 1)).diag()
-    return u.matmul(s.expand(batch_dims + (matrix_size, matrix_size)).to(A.dtype).matmul(v.transpose(-2, -1)))
+    s = torch.arange(1., matrix_size + 1, dtype=real_dtype, device=device).mul_(1.0 / (matrix_size + 1))
+    return (u * s.to(A.dtype)) @ vh
 
 
 # Creates a full rank matrix with distinct signular values or
@@ -1908,12 +1908,11 @@
 def make_fullrank_matrices_with_distinct_singular_values(*shape, device, dtype):
     assert shape[-1] == shape[-2]
     t = make_tensor(shape, device=device, dtype=dtype)
-    u, _, v = t.svd()
+    u, _, vh = torch.linalg.svd(t, full_matrices=False)
     # TODO: improve the handling of complex tensors here
     real_dtype = t.real.dtype if t.dtype.is_complex else t.dtype
-    s = torch.arange(1., shape[-1] + 1, dtype=real_dtype, device=device).mul_(1.0 / (shape[-1] + 1)).diag()
-    u.matmul(s.expand(*shape).to(t.dtype).matmul(v.transpose(-2, -1)))
-    return t
+    s = torch.arange(1., shape[-1] + 1, dtype=real_dtype, device=device).mul_(1.0 / (shape[-1] + 1))
+    return (u * s.to(dtype)) @ vh
 
 
 def random_matrix(rows, columns, *batch_dims, **kwargs):
@@ -1932,19 +1931,17 @@ def random_matrix(rows, columns, *batch_dims, **kwargs):
         return torch.ones(rows, columns, dtype=dtype, device=device)
 
     A = torch.randn(batch_dims + (rows, columns), dtype=dtype, device=device)
-    u, _, v = A.svd(some=False)
-    s = torch.zeros(rows, columns, dtype=dtype, device=device)
+    u, _, vh = torch.linalg.svd(A, full_matrices=False)
     k = min(rows, columns)
-    for i in range(k):
-        s[i, i] = float(i + 1) / (k + 1)
+    s = torch.linspace(1 / (k + 1), 1, k, dtype=dtype, device=device)
     if singular:
         # make matrix singular
-        s[k - 1, k - 1] = 0
+        s[k - 1] = 0
         if k > 2:
             # increase the order of singularity so that the pivoting
             # in LU factorization will be non-trivial
-            s[0, 0] = 0
-    return u.matmul(s.expand(batch_dims + (rows, columns)).matmul(v.transpose(-2, -1)))
+            s[0] = 0
+    return (u * s.unsqueeze(-2)) @ vh
 
 
 def random_lowrank_matrix(rank, rows, columns, *batch_dims, **kwargs):
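Not part of the patch above: a minimal standalone sketch of the replacement pattern this diff applies everywhere, checking that the `torch.linalg.svd` / `torch.linalg.svdvals` calls reproduce what `torch.svd` returned. The shapes and dtype are arbitrary illustration choices.

import torch

A = torch.randn(3, 5, 4, dtype=torch.cdouble)

# old (deprecated): U, S, V = torch.svd(A, some=True)
U, S, Vh = torch.linalg.svd(A, full_matrices=False)
V = Vh.conj().transpose(-2, -1)  # recovers V as torch.svd returned it

# reconstruction without materializing diag_embed(S), as used in the test helpers
assert torch.allclose((U * S.unsqueeze(-2)) @ Vh, A)
assert torch.allclose(U @ torch.diag_embed(S).to(A.dtype) @ Vh, A)

# old (deprecated): _, S, _ = torch.svd(A, compute_uv=False)
assert torch.allclose(torch.linalg.svdvals(A), S)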