Deprecate in docs torch.svd and change svd -> linalg_svd (#57981)
Summary:
This PR adds a note to the documentation that torch.svd is deprecated, together with an upgrade guide on how to use `torch.linalg.svd` and `torch.linalg.svdvals` (Lezcano's instructions from https://github.com/pytorch/pytorch/issues/57549).

In addition, all usage of the old svd function is replaced with the new one from the torch.linalg module, except in the `at::linalg_pinv` function, which fails the XLA CI build (https://github.com/pytorch/xla/issues/2755; see the failure in draft PR https://github.com/pytorch/pytorch/pull/57772).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/57981

Reviewed By: ngimel

Differential Revision: D28345558

Pulled By: mruberry

fbshipit-source-id: 02dd9ae6efe975026e80ca128e9b91dfc65d7213
This commit is contained in:
parent
e573987bea
commit
aaca12bcc2
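
For reference, the upgrade path described in the summary and in the new documentation warning, as a small runnable sketch (the matrix `A` and the `some` flag below are illustrative placeholders, not part of the diff):

```python
import torch

A = torch.randn(5, 3)
some = True  # torch.svd's default

# Old: U, S, V = torch.svd(A, some=some, compute_uv=True)
U, S, Vh = torch.linalg.svd(A, full_matrices=not some)
V = Vh.transpose(-2, -1).conj()  # torch.linalg.svd returns Vh; torch.svd returned V

# Old: _, S, _ = torch.svd(A, some=some, compute_uv=False)
S = torch.linalg.svdvals(A)  # singular values only, in descending order
```

Note that `some=True` in `torch.svd` corresponds to the reduced SVD, i.e. `full_matrices=False` in `torch.linalg.svd`, hence the negation.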
@@ -2907,6 +2907,21 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cpu(const Tensor& self, bool some
 }
 
 std::tuple<Tensor, Tensor, Tensor> svd(const Tensor& self, bool some, bool compute_uv) {
+  // TODO: uncomment the following when svd is deprecated not only in docs
+  // torch/xla is blocking the transition from at::svd to at::linalg_svd in at::linalg_pinv code
+  // see https://github.com/pytorch/xla/issues/2755
+  // TORCH_WARN_ONCE(
+  //     "torch.svd is deprecated in favor of torch.linalg.svd and will be ",
+  //     "removed in a future PyTorch release.\n",
+  //     "U, S, V = torch.svd(A, some=some, compute_uv=True) (default)\n",
+  //     "should be replaced with\n",
+  //     "U, S, Vh = torch.linalg.svd(A, full_matrices=not some)\n",
+  //     "V = Vh.transpose(-2, -1).conj()\n",
+  //     "and\n",
+  //     "_, S, _ = torch.svd(A, some=some, compute_uv=False)\n",
+  //     "should be replaced with\n",
+  //     "S = torch.linalg.svdvals(A)");
+
   TORCH_CHECK(self.dim() >= 2,
               "svd input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
   return at::_svd_helper(self, some, compute_uv);
@@ -2923,7 +2938,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> svd_out(const Tensor& self, bool some, boo
   checkLinalgCompatibleDtype("svd", S.scalar_type(), real_dtype, "S");
 
   Tensor U_tmp, S_tmp, V_tmp;
-  std::tie(U_tmp, S_tmp, V_tmp) = at::_svd_helper(self, some, compute_uv);
+  std::tie(U_tmp, S_tmp, V_tmp) = at::native::svd(self, some, compute_uv);
 
   at::native::resize_output(U, U_tmp.sizes());
   at::native::resize_output(S, S_tmp.sizes());
@@ -347,9 +347,7 @@ static Tensor& linalg_matrix_rank_out_helper(const Tensor& input, const Tensor&
   // that are above max(atol, rtol * max(S)) threshold
   Tensor S, max_S;
   if (!hermitian) {
-    Tensor U, V;
-    // TODO: replace input.svd with linalg_svd
-    std::tie(U, S, V) = input.svd(/*some=*/true, /*compute_uv=*/false);
+    S = at::linalg_svdvals(input);
     // singular values are sorted in descending order
     max_S = at::narrow(S, /*dim=*/-1, /*start=*/0, /*length=*/1);
   } else {
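
The thresholding this `linalg_matrix_rank_out_helper` hunk relies on (count the singular values above `max(atol, rtol * max(S))`) can be sketched in Python as follows; the tolerances are made-up illustrations, not the operator's defaults:

```python
import torch

A = torch.randn(6, 2) @ torch.randn(2, 4)   # numerical rank 2
S = torch.linalg.svdvals(A)                 # returned sorted in descending order
atol, rtol = 0.0, 1e-7                      # illustrative tolerances only
max_S = S.narrow(-1, 0, 1)                  # largest singular value (first entry)
tol = torch.clamp(rtol * max_S, min=atol)   # max(atol, rtol * max(S))
print((S > tol).sum(-1))                    # tensor(2)
```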
@@ -2171,7 +2169,7 @@ static Tensor& _linalg_norm_matrix_out(Tensor& result, const Tensor &self, const
     auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim());
     auto permutation_reverse = create_reverse_permutation(permutation);
 
-    result_ = std::get<1>(self_.permute(permutation).svd()).abs();
+    result_ = at::linalg_svdvals(self_.permute(permutation));
     result_ = _norm_min_max(result_, ord, result_.dim() - 1, keepdim);
 
     if (keepdim) {
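
Likewise, the `_linalg_norm_matrix_out` change works because the ord ±2 matrix norms are the extreme singular values; a quick hedged check (random shape chosen arbitrarily):

```python
import torch

A = torch.randn(5, 7)
S = torch.linalg.svdvals(A)
print(torch.allclose(torch.linalg.norm(A, ord=2), S.max()))    # largest singular value
print(torch.allclose(torch.linalg.norm(A, ord=-2), S.min()))   # smallest singular value
```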
@@ -7498,7 +7498,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
             # actual rank is known only for dense input
             detect_rank = (s.abs() > 1e-5).sum(axis=-1)
             self.assertEqual(actual_rank * torch.ones(batches, device=device, dtype=torch.int64), detect_rank)
-            U, S, V = torch.svd(A2)
+            S = torch.linalg.svdvals(A2)
             self.assertEqual(s[..., :actual_rank], S[..., :actual_rank])
 
         all_batches = [(), (1,), (3,), (2, 3)]
@@ -98,10 +98,10 @@ def svd_lowrank(A: Tensor, q: Optional[int] = 6, niter: Optional[int] = 2,
     .. note:: The input is assumed to be a low-rank matrix.
 
     .. note:: In general, use the full-rank SVD implementation
-              ``torch.svd`` for dense matrices due to its 10-fold
+              :func:`torch.linalg.svd` for dense matrices due to its 10-fold
               higher performance characteristics. The low-rank SVD
               will be useful for huge sparse matrices that
-              ``torch.svd`` cannot handle.
+              :func:`torch.linalg.svd` cannot handle.
 
     Args::
         A (Tensor): the input tensor of size :math:`(*, m, n)`
@@ -156,7 +156,8 @@ def _svd_lowrank(A: Tensor, q: Optional[int] = 6, niter: Optional[int] = 2,
         assert B_t.shape[-2] == m, (B_t.shape, m)
         assert B_t.shape[-1] == q, (B_t.shape, q)
         assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
-        U, S, V = torch.svd(B_t)
+        U, S, Vh = torch.linalg.svd(B_t, full_matrices=False)
+        V = Vh.conj().transpose(-2, -1)
         V = Q.matmul(V)
     else:
         Q = get_approximate_basis(A, q, niter=niter, M=M)
@@ -169,7 +170,8 @@ def _svd_lowrank(A: Tensor, q: Optional[int] = 6, niter: Optional[int] = 2,
         assert B_t.shape[-2] == q, (B_t.shape, q)
         assert B_t.shape[-1] == n, (B_t.shape, n)
         assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
-        U, S, V = torch.svd(B_t)
+        U, S, Vh = torch.linalg.svd(B_t, full_matrices=False)
+        V = Vh.conj().transpose(-2, -1)
         U = Q.matmul(U)
 
     return U, S, V
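
The `_svd_lowrank` hunks above keep the usual randomized-projection structure: build an orthonormal basis `Q` for (approximately) the range of `A`, take the SVD of the small projected matrix, then lift the left factor back with `Q`. A rough self-contained sketch of that idea, not the actual `_svd_lowrank` code (`q` and all shapes are made up):

```python
import torch

torch.manual_seed(0)
A = torch.randn(100, 30) @ torch.randn(30, 80)   # a 100 x 80 matrix of rank <= 30
q = 30

Q, _ = torch.linalg.qr(A @ torch.randn(80, q))   # orthonormal basis for (approximately) range(A)
B = Q.transpose(-2, -1).conj() @ A               # small q x 80 projection of A
U_small, S, Vh = torch.linalg.svd(B, full_matrices=False)
U = Q @ U_small                                  # lift back into the column space of A
V = Vh.conj().transpose(-2, -1)                  # same V convention that torch.svd_lowrank returns

print(torch.dist(U @ torch.diag(S) @ Vh, A))     # reconstruction error is tiny for a rank-30 input
```

The `V = Vh.conj().transpose(-2, -1)` step is exactly why the two hunks above gain one line each: `torch.linalg.svd` hands back `Vh`, while `_svd_lowrank` still returns `V`.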
@@ -8506,9 +8506,23 @@ Supports :attr:`input` of float, double, cfloat and cdouble data types.
 The dtypes of `U` and `V` are the same as :attr:`input`'s. `S` will
 always be real-valued, even if :attr:`input` is complex.
 
-.. warning:: :func:`torch.svd` is deprecated. Please use
-             :func:`torch.linalg.svd` instead, which is similar to NumPy's
-             `numpy.linalg.svd`.
+.. warning::
+
+    :func:`torch.svd` is deprecated in favor of :func:`torch.linalg.svd`
+    and will be removed in a future PyTorch release.
+
+    ``U, S, V = torch.svd(A, some=some, compute_uv=True)`` (default) should be replaced with
+
+    .. code:: python
+
+        U, S, Vh = torch.linalg.svd(A, full_matrices=not some)
+        V = Vh.transpose(-2, -1).conj()
+
+    ``_, S, _ = torch.svd(A, some=some, compute_uv=False)`` should be replaced with
+
+    .. code:: python
+
+        S = torch.linalg.svdvals(A)
 
 .. note:: Differences with :func:`torch.linalg.svd`:
 
@@ -2492,8 +2492,9 @@ Tensor linalg_qr_backward(const std::vector<torch::autograd::Variable> &grads, c
 // http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf
 Tensor linalg_det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) {
   auto singular_case_backward = [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor {
-    Tensor u, sigma, v;
-    std::tie(u, sigma, v) = self.svd();
+    Tensor u, sigma, vh;
+    std::tie(u, sigma, vh) = at::linalg_svd(self, false);
+    Tensor v = vh.conj().transpose(-2, -1);
     auto gsigma = prod_backward(grad.unsqueeze(-1), sigma, det.unsqueeze(-1));
     return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v);
   };
@@ -2545,8 +2546,9 @@ Tensor linalg_det_backward(const Tensor & grad, const Tensor& self, const Tensor
 
 Tensor logdet_backward(const Tensor & grad, const Tensor& self, const Tensor& logdet) {
   auto singular_case_backward = [&](const Tensor& grad, const Tensor& self) -> Tensor {
-    Tensor u, sigma, v;
-    std::tie(u, sigma, v) = self.svd();
+    Tensor u, sigma, vh;
+    std::tie(u, sigma, vh) = at::linalg_svd(self, false);
+    Tensor v = vh.conj().transpose(-2, -1);
     // logdet = \sum log(sigma)
     auto gsigma = grad.unsqueeze(-1).div(sigma);
     return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v);
@@ -2600,9 +2602,9 @@ Tensor slogdet_backward(const Tensor& grad_logabsdet,
                         const Tensor& self,
                         const Tensor& signdet, const Tensor& logabsdet) {
   auto singular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor {
-    Tensor u, sigma, v;
-    // TODO: replace self.svd with linalg_svd
-    std::tie(u, sigma, v) = self.svd();
+    Tensor u, sigma, vh;
+    std::tie(u, sigma, vh) = at::linalg_svd(self, false);
+    Tensor v = vh.conj().transpose(-2, -1);
     // sigma has all non-negative entries (also with at least one zero entry)
     // so logabsdet = \sum log(abs(sigma))
     // but det = 0, so backward logabsdet = \sum log(sigma)
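
The three backward hunks above lean on log|det(A)| = sum(log(sigma_i)), which is why `gsigma` is `grad / sigma` there; a quick check of that identity (random double-precision `A` for illustration):

```python
import torch

A = torch.randn(4, 4, dtype=torch.float64)
sigma = torch.linalg.svdvals(A)
sign, logabsdet = torch.linalg.slogdet(A)
print(torch.allclose(logabsdet, sigma.log().sum()))  # log|det(A)| == sum(log(sigma))
```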
@@ -1825,13 +1825,12 @@ def sample_unsqueeze(op_info, device, dtype, requires_grad, **kwargs):
 # Creates matrices with a positive nonzero determinant
 def sample_inputs_logdet(op_info, device, dtype, requires_grad, **kwargs):
     def make_nonzero_det(A, *, sign=1, min_singular_value=0.1, **kwargs):
-        u, s, v = A.svd()
+        u, s, vh = torch.linalg.svd(A, full_matrices=False)
         s.clamp_(min=min_singular_value)
-        A = torch.matmul(u, torch.matmul(torch.diag_embed(s), v.transpose(-2, -1)))
+        A = (u * s.unsqueeze(-2)) @ vh
         det = A.det()
         if sign is not None:
             if A.dim() == 2:
                 det = det.item()
                 if (det < 0) ^ (sign < 0):
                     A[0, :].neg_()
             else:
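
The `(u * s.unsqueeze(-2)) @ vh` pattern introduced in this and the following test-helper hunks is just a broadcasted way of forming `U @ diag(s) @ Vh` without materializing the diagonal matrix; a small check (batched shape picked arbitrarily):

```python
import torch

A = torch.randn(2, 5, 3)                  # arbitrary batched input
u, s, vh = torch.linalg.svd(A, full_matrices=False)
explicit = u @ torch.diag_embed(s) @ vh   # build the diagonal matrix explicitly
broadcast = (u * s.unsqueeze(-2)) @ vh    # scale the columns of u instead
print(torch.allclose(explicit, broadcast))
```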
@@ -1782,13 +1782,13 @@ def make_tensor(size, device: torch.device, dtype: torch.dtype, *, low=None, hig
 def random_square_matrix_of_rank(l, rank, dtype=torch.double, device='cpu'):
     assert rank <= l
     A = torch.randn(l, l, dtype=dtype, device=device)
-    u, s, v = A.svd()
+    u, s, vh = torch.linalg.svd(A, full_matrices=False)
     for i in range(l):
         if i >= rank:
             s[i] = 0
         elif s[i] == 0:
             s[i] = 1
-    return u.mm(torch.diag(s).to(dtype)).mm(v.transpose(0, 1))
+    return (u * s.to(dtype).unsqueeze(-2)) @ vh
 
 def random_well_conditioned_matrix(*shape, dtype, device, mean=1.0, sigma=0.001):
     """
@@ -1807,10 +1807,10 @@ def random_well_conditioned_matrix(*shape, dtype, device, mean=1.0, sigma=0.001)
     x = torch.rand(shape, dtype=dtype, device=device)
     m = x.size(-2)
     n = x.size(-1)
-    u, _, v = x.svd()
+    u, _, vh = torch.linalg.svd(x, full_matrices=False)
     s = (torch.randn(*(shape[:-2] + (min(m, n),)), dtype=primitive_dtype[dtype], device=device) * sigma + mean) \
         .sort(-1, descending=True).values.to(dtype)
-    return (u * s.unsqueeze(-2)) @ v.transpose(-2, -1).conj()
+    return (u * s.unsqueeze(-2)) @ vh
 
 # TODO: remove this (prefer make_symmetric_matrices below)
 def random_symmetric_matrix(l, *batches, **kwargs):
@@ -1896,10 +1896,10 @@ def random_fullrank_matrix_distinct_singular_value(matrix_size, *batch_dims,
         return torch.ones(matrix_size, matrix_size, dtype=dtype, device=device)
 
     A = torch.randn(batch_dims + (matrix_size, matrix_size), dtype=dtype, device=device)
-    u, _, v = A.svd()
+    u, _, vh = torch.linalg.svd(A, full_matrices=False)
     real_dtype = A.real.dtype if A.dtype.is_complex else A.dtype
-    s = torch.arange(1., matrix_size + 1, dtype=real_dtype, device=device).mul_(1.0 / (matrix_size + 1)).diag()
-    return u.matmul(s.expand(batch_dims + (matrix_size, matrix_size)).to(A.dtype).matmul(v.transpose(-2, -1)))
+    s = torch.arange(1., matrix_size + 1, dtype=real_dtype, device=device).mul_(1.0 / (matrix_size + 1))
+    return (u * s.to(A.dtype)) @ vh
 
 
 # Creates a full rank matrix with distinct signular values or
@@ -1908,12 +1908,11 @@ def make_fullrank_matrices_with_distinct_singular_values(*shape, device, dtype):
 def make_fullrank_matrices_with_distinct_singular_values(*shape, device, dtype):
     assert shape[-1] == shape[-2]
     t = make_tensor(shape, device=device, dtype=dtype)
-    u, _, v = t.svd()
+    u, _, vh = torch.linalg.svd(t, full_matrices=False)
     # TODO: improve the handling of complex tensors here
     real_dtype = t.real.dtype if t.dtype.is_complex else t.dtype
-    s = torch.arange(1., shape[-1] + 1, dtype=real_dtype, device=device).mul_(1.0 / (shape[-1] + 1)).diag()
-    u.matmul(s.expand(*shape).to(t.dtype).matmul(v.transpose(-2, -1)))
-    return t
+    s = torch.arange(1., shape[-1] + 1, dtype=real_dtype, device=device).mul_(1.0 / (shape[-1] + 1))
+    return (u * s.to(dtype)) @ vh
 
 
 def random_matrix(rows, columns, *batch_dims, **kwargs):
@@ -1932,19 +1931,17 @@ def random_matrix(rows, columns, *batch_dims, **kwargs):
         return torch.ones(rows, columns, dtype=dtype, device=device)
 
     A = torch.randn(batch_dims + (rows, columns), dtype=dtype, device=device)
-    u, _, v = A.svd(some=False)
-    s = torch.zeros(rows, columns, dtype=dtype, device=device)
+    u, _, vh = torch.linalg.svd(A, full_matrices=False)
     k = min(rows, columns)
-    for i in range(k):
-        s[i, i] = float(i + 1) / (k + 1)
+    s = torch.linspace(1 / (k + 1), 1, k, dtype=dtype, device=device)
     if singular:
         # make matrix singular
-        s[k - 1, k - 1] = 0
+        s[k - 1] = 0
         if k > 2:
             # increase the order of singularity so that the pivoting
             # in LU factorization will be non-trivial
-            s[0, 0] = 0
-    return u.matmul(s.expand(batch_dims + (rows, columns)).matmul(v.transpose(-2, -1)))
+            s[0] = 0
+    return (u * s.unsqueeze(-2)) @ vh
 
 
 def random_lowrank_matrix(rank, rows, columns, *batch_dims, **kwargs):