Deprecate in docs torch.svd and change svd -> linalg_svd (#57981)

Summary:
This PR adds a note to the documentation that `torch.svd` is deprecated, together with an upgrade guide on how to use `torch.linalg.svd` and `torch.linalg.svdvals` instead (Lezcano's instructions from https://github.com/pytorch/pytorch/issues/57549).
In addition, all usages of the old svd function are replaced with the new `torch.linalg` functions, except inside `at::linalg_pinv`, where the switch currently breaks the XLA CI build (https://github.com/pytorch/xla/issues/2755; see the failure in draft PR https://github.com/pytorch/pytorch/pull/57772).
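For reference, the upgrade path spelled out in the new docs boils down to the following sketch (illustrative only; `A` stands for any real or complex matrix):

    import torch

    A = torch.randn(5, 3, dtype=torch.double)
    some = True  # torch.svd's default

    # Old: U, S, V = torch.svd(A, some=some, compute_uv=True)  (the default)
    # New: torch.linalg.svd returns V^H rather than V
    U, S, Vh = torch.linalg.svd(A, full_matrices=not some)
    V = Vh.transpose(-2, -1).conj()

    # Old: _, S, _ = torch.svd(A, some=some, compute_uv=False)
    # New: compute only the singular values
    S = torch.linalg.svdvals(A)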

Pull Request resolved: https://github.com/pytorch/pytorch/pull/57981

Reviewed By: ngimel

Differential Revision: D28345558

Pulled By: mruberry

fbshipit-source-id: 02dd9ae6efe975026e80ca128e9b91dfc65d7213
Ivan Yashchuk 2021-05-11 18:03:02 -07:00 committed by Facebook GitHub Bot
parent e573987bea
commit aaca12bcc2
8 changed files with 68 additions and 41 deletions


@@ -2907,6 +2907,21 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cpu(const Tensor& self, bool some
}
std::tuple<Tensor, Tensor, Tensor> svd(const Tensor& self, bool some, bool compute_uv) {
// TODO: uncomment the following when svd is deprecated not only in docs
// torch/xla is blocking the transition from at::svd to at::linalg_svd in at::linalg_pinv code
// see https://github.com/pytorch/xla/issues/2755
// TORCH_WARN_ONCE(
// "torch.svd is deprecated in favor of torch.linalg.svd and will be ",
// "removed in a future PyTorch release.\n",
// "U, S, V = torch.svd(A, some=some, compute_uv=True) (default)\n",
// "should be replaced with\n",
// "U, S, Vh = torch.linalg.svd(A, full_matrices=not some)\n",
// "V = Vh.transpose(-2, -1).conj()\n",
// "and\n",
// "_, S, _ = torch.svd(A, some=some, compute_uv=False)\n",
// "should be replaced with\n",
// "S = torch.linalg.svdvals(A)");
TORCH_CHECK(self.dim() >= 2,
"svd input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
return at::_svd_helper(self, some, compute_uv);
@@ -2923,7 +2938,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> svd_out(const Tensor& self, bool some, boo
checkLinalgCompatibleDtype("svd", S.scalar_type(), real_dtype, "S");
Tensor U_tmp, S_tmp, V_tmp;
std::tie(U_tmp, S_tmp, V_tmp) = at::_svd_helper(self, some, compute_uv);
std::tie(U_tmp, S_tmp, V_tmp) = at::native::svd(self, some, compute_uv);
at::native::resize_output(U, U_tmp.sizes());
at::native::resize_output(S, S_tmp.sizes());


@@ -347,9 +347,7 @@ static Tensor& linalg_matrix_rank_out_helper(const Tensor& input, const Tensor&
// that are above max(atol, rtol * max(S)) threshold
Tensor S, max_S;
if (!hermitian) {
Tensor U, V;
// TODO: replace input.svd with linalg_svd
std::tie(U, S, V) = input.svd(/*some=*/true, /*compute_uv=*/false);
S = at::linalg_svdvals(input);
// singular values are sorted in descending order
max_S = at::narrow(S, /*dim=*/-1, /*start=*/0, /*length=*/1);
} else {
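In Python terms, the non-Hermitian branch now amounts to the following thresholding sketch (illustrative tolerances; not the actual C++ implementation):

    import torch

    A = torch.randn(4, 6, dtype=torch.double)
    atol, rtol = 0.0, 1e-7              # illustrative tolerances

    S = torch.linalg.svdvals(A)         # singular values, sorted in descending order
    max_S = S.narrow(-1, 0, 1)          # largest singular value, as in the at::narrow call above
    tol = torch.clamp(rtol * max_S, min=atol)
    rank = (S > tol).sum(dim=-1)
    print(rank)                         # expected: tensor(4) for a generic 4 x 6 matrix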
@@ -2171,7 +2169,7 @@ static Tensor& _linalg_norm_matrix_out(Tensor& result, const Tensor &self, const
auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim());
auto permutation_reverse = create_reverse_permutation(permutation);
result_ = std::get<1>(self_.permute(permutation).svd()).abs();
result_ = at::linalg_svdvals(self_.permute(permutation));
result_ = _norm_min_max(result_, ord, result_.dim() - 1, keepdim);
if (keepdim) {

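The same substitution works for the spectral-norm path, since the `ord=2` matrix norm is the largest singular value and `ord=-2` the smallest. A small check (illustrative, not the C++ code):

    import torch

    A = torch.randn(3, 4, 5, dtype=torch.double)   # a batch of 4 x 5 matrices
    S = torch.linalg.svdvals(A)

    # ord=2 -> largest singular value, ord=-2 -> smallest singular value
    print(torch.allclose(S.max(dim=-1).values,
                         torch.linalg.norm(A, ord=2, dim=(-2, -1))))   # expected: True
    print(torch.allclose(S.min(dim=-1).values,
                         torch.linalg.norm(A, ord=-2, dim=(-2, -1))))  # expected: True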

@@ -7498,7 +7498,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
# actual rank is known only for dense input
detect_rank = (s.abs() > 1e-5).sum(axis=-1)
self.assertEqual(actual_rank * torch.ones(batches, device=device, dtype=torch.int64), detect_rank)
U, S, V = torch.svd(A2)
S = torch.linalg.svdvals(A2)
self.assertEqual(s[..., :actual_rank], S[..., :actual_rank])
all_batches = [(), (1,), (3,), (2, 3)]


@@ -98,10 +98,10 @@ def svd_lowrank(A: Tensor, q: Optional[int] = 6, niter: Optional[int] = 2,
.. note:: The input is assumed to be a low-rank matrix.
.. note:: In general, use the full-rank SVD implementation
``torch.svd`` for dense matrices due to its 10-fold
:func:`torch.linalg.svd` for dense matrices due to its 10-fold
higher performance characteristics. The low-rank SVD
will be useful for huge sparse matrices that
``torch.svd`` cannot handle.
:func:`torch.linalg.svd` cannot handle.
Args::
A (Tensor): the input tensor of size :math:`(*, m, n)`
@@ -156,7 +156,8 @@ def _svd_lowrank(A: Tensor, q: Optional[int] = 6, niter: Optional[int] = 2,
assert B_t.shape[-2] == m, (B_t.shape, m)
assert B_t.shape[-1] == q, (B_t.shape, q)
assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
U, S, V = torch.svd(B_t)
U, S, Vh = torch.linalg.svd(B_t, full_matrices=False)
V = Vh.conj().transpose(-2, -1)
V = Q.matmul(V)
else:
Q = get_approximate_basis(A, q, niter=niter, M=M)
@@ -169,7 +170,8 @@ def _svd_lowrank(A: Tensor, q: Optional[int] = 6, niter: Optional[int] = 2,
assert B_t.shape[-2] == q, (B_t.shape, q)
assert B_t.shape[-1] == n, (B_t.shape, n)
assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
U, S, V = torch.svd(B_t)
U, S, Vh = torch.linalg.svd(B_t, full_matrices=False)
V = Vh.conj().transpose(-2, -1)
U = Q.matmul(U)
return U, S, V
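As a sanity check for this conversion (an illustrative usage example, not part of the patch), the leading singular values returned by `torch.svd_lowrank` should track the exact ones from `torch.linalg.svdvals` whenever `q` is at least the true rank:

    import torch

    torch.manual_seed(0)
    # build a dense matrix of rank at most 6
    A = torch.randn(100, 6, dtype=torch.double) @ torch.randn(6, 80, dtype=torch.double)

    U, S, V = torch.svd_lowrank(A, q=6)
    S_exact = torch.linalg.svdvals(A)
    print(torch.allclose(S, S_exact[:6]))  # expected: True (up to numerical error)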


@@ -8506,9 +8506,23 @@ Supports :attr:`input` of float, double, cfloat and cdouble data types.
The dtypes of `U` and `V` are the same as :attr:`input`'s. `S` will
always be real-valued, even if :attr:`input` is complex.
.. warning:: :func:`torch.svd` is deprecated. Please use
:func:`torch.linalg.svd` instead, which is similar to NumPy's
`numpy.linalg.svd`.
.. warning::
:func:`torch.svd` is deprecated in favor of :func:`torch.linalg.svd`
and will be removed in a future PyTorch release.
``U, S, V = torch.svd(A, some=some, compute_uv=True)`` (default) should be replaced with
.. code:: python
U, S, Vh = torch.linalg.svd(A, full_matrices=not some)
V = Vh.transpose(-2, -1).conj()
``_, S, _ = torch.svd(A, some=some, compute_uv=False)`` should be replaced with
.. code:: python
S = torch.linalg.svdvals(A)
.. note:: Differences with :func:`torch.linalg.svd`:


@@ -2492,8 +2492,9 @@ Tensor linalg_qr_backward(const std::vector<torch::autograd::Variable> &grads, c
// http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf
Tensor linalg_det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) {
auto singular_case_backward = [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor {
Tensor u, sigma, v;
std::tie(u, sigma, v) = self.svd();
Tensor u, sigma, vh;
std::tie(u, sigma, vh) = at::linalg_svd(self, false);
Tensor v = vh.conj().transpose(-2, -1);
auto gsigma = prod_backward(grad.unsqueeze(-1), sigma, det.unsqueeze(-1));
return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v);
};
@@ -2545,8 +2546,9 @@ Tensor logdet_backward(const Tensor & grad, const Tensor& self, const Tensor
Tensor logdet_backward(const Tensor & grad, const Tensor& self, const Tensor& logdet) {
auto singular_case_backward = [&](const Tensor& grad, const Tensor& self) -> Tensor {
Tensor u, sigma, v;
std::tie(u, sigma, v) = self.svd();
Tensor u, sigma, vh;
std::tie(u, sigma, vh) = at::linalg_svd(self, false);
Tensor v = vh.conj().transpose(-2, -1);
// logdet = \sum log(sigma)
auto gsigma = grad.unsqueeze(-1).div(sigma);
return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v);
@@ -2600,9 +2602,9 @@ Tensor slogdet_backward(const Tensor& grad_logabsdet,
const Tensor& self,
const Tensor& signdet, const Tensor& logabsdet) {
auto singular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor {
Tensor u, sigma, v;
// TODO: replace self.svd with linalg_svd
std::tie(u, sigma, v) = self.svd();
Tensor u, sigma, vh;
std::tie(u, sigma, vh) = at::linalg_svd(self, false);
Tensor v = vh.conj().transpose(-2, -1);
// sigma has all non-negative entries (also with at least one zero entry)
// so logabsdet = \sum log(abs(sigma))
// but det = 0, so backward logabsdet = \sum log(sigma)
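The SVD route shared by these singular-case branches can be checked numerically: for a real invertible `A` (where of course the regular branch would be taken), the expression `U @ diag(grad / sigma) @ Vh` built above equals the analytic logdet gradient `grad * A^{-T}` (illustrative check, not part of the patch):

    import torch

    A = torch.randn(4, 4, dtype=torch.double)
    A = A @ A.T + 4 * torch.eye(4, dtype=torch.double)  # well-conditioned, invertible
    grad = 1.0

    U, sigma, Vh = torch.linalg.svd(A, full_matrices=False)
    via_svd = (U * (grad / sigma).unsqueeze(-2)) @ Vh    # U @ diag(grad / sigma) @ Vh
    analytic = grad * torch.inverse(A).transpose(-2, -1)
    print(torch.allclose(via_svd, analytic))             # expected: True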


@@ -1825,13 +1825,12 @@ def sample_unsqueeze(op_info, device, dtype, requires_grad, **kwargs):
# Creates matrices with a positive nonzero determinant
def sample_inputs_logdet(op_info, device, dtype, requires_grad, **kwargs):
def make_nonzero_det(A, *, sign=1, min_singular_value=0.1, **kwargs):
u, s, v = A.svd()
u, s, vh = torch.linalg.svd(A, full_matrices=False)
s.clamp_(min=min_singular_value)
A = torch.matmul(u, torch.matmul(torch.diag_embed(s), v.transpose(-2, -1)))
A = (u * s.unsqueeze(-2)) @ vh
det = A.det()
if sign is not None:
if A.dim() == 2:
det = det.item()
if (det < 0) ^ (sign < 0):
A[0, :].neg_()
else:
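This rewrite, and the similar ones in the test helpers that follow, relies on a small broadcasting identity: scaling the columns of `u` by `s` is the same as multiplying by a diagonal matrix, so `(u * s.unsqueeze(-2)) @ vh` replaces `u @ diag_embed(s) @ vh` without materializing the diagonal. A minimal check (illustrative):

    import torch

    u = torch.randn(3, 5, 4, dtype=torch.double)
    s = torch.randn(3, 4, dtype=torch.double)
    vh = torch.randn(3, 4, 6, dtype=torch.double)

    lhs = (u * s.unsqueeze(-2)) @ vh
    rhs = u @ torch.diag_embed(s) @ vh
    print(torch.allclose(lhs, rhs))  # expected: True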


@@ -1782,13 +1782,13 @@ def make_tensor(size, device: torch.device, dtype: torch.dtype, *, low=None, hig
def random_square_matrix_of_rank(l, rank, dtype=torch.double, device='cpu'):
assert rank <= l
A = torch.randn(l, l, dtype=dtype, device=device)
u, s, v = A.svd()
u, s, vh = torch.linalg.svd(A, full_matrices=False)
for i in range(l):
if i >= rank:
s[i] = 0
elif s[i] == 0:
s[i] = 1
return u.mm(torch.diag(s).to(dtype)).mm(v.transpose(0, 1))
return (u * s.to(dtype).unsqueeze(-2)) @ vh
def random_well_conditioned_matrix(*shape, dtype, device, mean=1.0, sigma=0.001):
"""
@@ -1807,10 +1807,10 @@ def random_well_conditioned_matrix(*shape, dtype, device, mean=1.0, sigma=0.001)
x = torch.rand(shape, dtype=dtype, device=device)
m = x.size(-2)
n = x.size(-1)
u, _, v = x.svd()
u, _, vh = torch.linalg.svd(x, full_matrices=False)
s = (torch.randn(*(shape[:-2] + (min(m, n),)), dtype=primitive_dtype[dtype], device=device) * sigma + mean) \
.sort(-1, descending=True).values.to(dtype)
return (u * s.unsqueeze(-2)) @ v.transpose(-2, -1).conj()
return (u * s.unsqueeze(-2)) @ vh
# TODO: remove this (prefer make_symmetric_matrices below)
def random_symmetric_matrix(l, *batches, **kwargs):
@@ -1896,10 +1896,10 @@ def random_fullrank_matrix_distinct_singular_value(matrix_size, *batch_dims,
return torch.ones(matrix_size, matrix_size, dtype=dtype, device=device)
A = torch.randn(batch_dims + (matrix_size, matrix_size), dtype=dtype, device=device)
u, _, v = A.svd()
u, _, vh = torch.linalg.svd(A, full_matrices=False)
real_dtype = A.real.dtype if A.dtype.is_complex else A.dtype
s = torch.arange(1., matrix_size + 1, dtype=real_dtype, device=device).mul_(1.0 / (matrix_size + 1)).diag()
return u.matmul(s.expand(batch_dims + (matrix_size, matrix_size)).to(A.dtype).matmul(v.transpose(-2, -1)))
s = torch.arange(1., matrix_size + 1, dtype=real_dtype, device=device).mul_(1.0 / (matrix_size + 1))
return (u * s.to(A.dtype)) @ vh
# Creates a full rank matrix with distinct singular values or
@@ -1908,12 +1908,11 @@ def make_fullrank_matrices_with_distinct_singular_values(*shape, device, dtype):
def make_fullrank_matrices_with_distinct_singular_values(*shape, device, dtype):
assert shape[-1] == shape[-2]
t = make_tensor(shape, device=device, dtype=dtype)
u, _, v = t.svd()
u, _, vh = torch.linalg.svd(t, full_matrices=False)
# TODO: improve the handling of complex tensors here
real_dtype = t.real.dtype if t.dtype.is_complex else t.dtype
s = torch.arange(1., shape[-1] + 1, dtype=real_dtype, device=device).mul_(1.0 / (shape[-1] + 1)).diag()
u.matmul(s.expand(*shape).to(t.dtype).matmul(v.transpose(-2, -1)))
return t
s = torch.arange(1., shape[-1] + 1, dtype=real_dtype, device=device).mul_(1.0 / (shape[-1] + 1))
return (u * s.to(dtype)) @ vh
def random_matrix(rows, columns, *batch_dims, **kwargs):
@@ -1932,19 +1931,17 @@ def random_matrix(rows, columns, *batch_dims, **kwargs):
return torch.ones(rows, columns, dtype=dtype, device=device)
A = torch.randn(batch_dims + (rows, columns), dtype=dtype, device=device)
u, _, v = A.svd(some=False)
s = torch.zeros(rows, columns, dtype=dtype, device=device)
u, _, vh = torch.linalg.svd(A, full_matrices=False)
k = min(rows, columns)
for i in range(k):
s[i, i] = float(i + 1) / (k + 1)
s = torch.linspace(1 / (k + 1), 1, k, dtype=dtype, device=device)
if singular:
# make matrix singular
s[k - 1, k - 1] = 0
s[k - 1] = 0
if k > 2:
# increase the order of singularity so that the pivoting
# in LU factorization will be non-trivial
s[0, 0] = 0
return u.matmul(s.expand(batch_dims + (rows, columns)).matmul(v.transpose(-2, -1)))
s[0] = 0
return (u * s.unsqueeze(-2)) @ vh
def random_lowrank_matrix(rank, rows, columns, *batch_dims, **kwargs):
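Zeroing selected singular values before re-assembling the matrix is what lets `random_matrix(..., singular=True)` produce genuinely rank-deficient inputs. A quick illustrative check of that construction (not taken from the patch):

    import torch

    torch.manual_seed(0)
    A = torch.randn(5, 7, dtype=torch.double)
    u, s, vh = torch.linalg.svd(A, full_matrices=False)
    s[-1] = 0   # make the matrix singular
    s[0] = 0    # increase the order of singularity
    B = (u * s.unsqueeze(-2)) @ vh
    print(torch.linalg.matrix_rank(B))  # expected: tensor(3)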