Remove random_fullrank_matrix_distinct_singular_value (#68183)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68183

We do so in favour of
`make_fullrank_matrices_with_distinct_singular_values`. Not only does this
latter helper have an even longer name, it also generates inputs
correctly so that they work with the PR later in this stack that tests
noncontiguous inputs.
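
For reference, the two helpers use different calling conventions: the old
one took the matrix size first and the batch dims after it (and only
produced square matrices), while the new one takes the full shape
directly, batch dims first. A minimal sketch of the new call pattern, as
adopted by the tests in this diff (the concrete shapes are just
illustrative):

    from functools import partial

    import torch
    from torch.testing._internal.common_utils import (
        make_fullrank_matrices_with_distinct_singular_values,
    )

    # Old (removed) style: size first, batch dims after, always square:
    #   A = random_fullrank_matrix_distinct_singular_value(n, *batches, dtype=dtype, device=device)
    # New style: pass the full shape, batch dims first:
    make_arg = partial(
        make_fullrank_matrices_with_distinct_singular_values,
        dtype=torch.float64,
        device="cpu",
    )
    A = make_arg(2, 5, 5)  # batch of two full-rank 5x5 matrices
    B = make_arg(5, 3)     # a single full-rank 5x3 matrix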

We also heavily simplified the generation of samples for the SVD, as it was
fairly convoluted and did not generate the inputs correctly for
the noncontiguous tests.

To make this transition, we also needed to fix the following issue, which
was popping up in the tests:

Fixes https://github.com/pytorch/pytorch/issues/66856

cc jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D32684853

Pulled By: mruberry

fbshipit-source-id: e88189c8b67dbf592eccdabaf2aa6d2e2f7b95a4
lezcano 2022-01-05 20:30:12 -08:00 committed by Facebook GitHub Bot
parent 08ef4ae0bc
commit baeca11a21
3 changed files with 131 additions and 187 deletions


@ -17,7 +17,8 @@ from functools import reduce, partial, wraps
from torch.testing._internal.common_utils import \
(TestCase, run_tests, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest,
TEST_WITH_ASAN, TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU,
iter_indices, gradcheck, gradgradcheck)
iter_indices, gradcheck, gradgradcheck,
make_fullrank_matrices_with_distinct_singular_values)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, dtypes,
onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride,
@ -3213,7 +3214,8 @@ class TestLinalg(TestCase):
@precisionOverride({torch.float32: 2e-3, torch.complex64: 2e-3,
torch.float64: 1e-8, torch.complex128: 1e-8})
def test_inverse(self, device, dtype):
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
make_fullrank = make_fullrank_matrices_with_distinct_singular_values
make_arg = partial(make_fullrank, device=device, dtype=dtype)
def run_test(torch_inverse, matrix, batches, n):
matrix_inverse = torch_inverse(matrix)
@ -3265,7 +3267,7 @@ class TestLinalg(TestCase):
[[], [0], [2], [2, 1]],
[0, 5]
):
matrices = random_fullrank_matrix_distinct_singular_value(n, *batches, dtype=dtype, device=device)
matrices = make_arg(*batches, n, n)
run_test(torch_inverse, matrices, batches, n)
# test non-contiguous input
@ -3273,7 +3275,7 @@ class TestLinalg(TestCase):
if n > 0:
run_test(
torch_inverse,
random_fullrank_matrix_distinct_singular_value(n * 2, *batches, dtype=dtype, device=device)
make_arg(*batches, 2 * n, 2 * n)
.view(-1, n * 2, n * 2)[:, ::2, ::2].view(*batches, n, n),
batches, n
)
@ -3321,10 +3323,11 @@ class TestLinalg(TestCase):
@precisionOverride({torch.float32: 2e-3, torch.complex64: 2e-3,
torch.float64: 1e-5, torch.complex128: 1e-5})
def test_inverse_many_batches(self, device, dtype):
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
make_fullrank = make_fullrank_matrices_with_distinct_singular_values
make_arg = partial(make_fullrank, device=device, dtype=dtype)
def test_inverse_many_batches_helper(torch_inverse, b, n):
matrices = random_fullrank_matrix_distinct_singular_value(b, n, n, dtype=dtype, device=device)
matrices = make_arg(b, n, n)
matrices_inverse = torch_inverse(matrices)
# Compare against NumPy output
@ -3542,10 +3545,11 @@ class TestLinalg(TestCase):
self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
def solve_test_helper(self, A_dims, b_dims, device, dtype):
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
make_fullrank = make_fullrank_matrices_with_distinct_singular_values
make_A = partial(make_fullrank, device=device, dtype=dtype)
b = torch.randn(*b_dims, dtype=dtype, device=device)
A = random_fullrank_matrix_distinct_singular_value(*A_dims, dtype=dtype, device=device)
A = make_A(*A_dims)
return b, A
@skipCUDAIfNoMagma
@ -3554,7 +3558,7 @@ class TestLinalg(TestCase):
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3})
def test_solve(self, device, dtype):
def run_test(n, batch, rhs):
A_dims = (n, *batch)
A_dims = (*batch, n, n)
b_dims = (*batch, n, *rhs)
b, A = self.solve_test_helper(A_dims, b_dims, device, dtype)
@ -3600,8 +3604,10 @@ class TestLinalg(TestCase):
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3})
def test_solve_batched_non_contiguous(self, device, dtype):
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
A = random_fullrank_matrix_distinct_singular_value(2, 2, dtype=dtype, device=device).permute(1, 0, 2)
make_fullrank = make_fullrank_matrices_with_distinct_singular_values
make_A = partial(make_fullrank, device=device, dtype=dtype)
A = make_A(2, 2, 2).permute(1, 0, 2)
b = torch.randn(2, 2, 2, dtype=dtype, device=device).permute(2, 1, 0)
self.assertFalse(A.is_contiguous())
self.assertFalse(b.is_contiguous())
@ -3680,7 +3686,7 @@ class TestLinalg(TestCase):
@dtypes(*floating_and_complex_types())
def test_old_solve(self, device, dtype):
for (k, n) in zip([2, 3, 5], [3, 5, 7]):
b, A = self.solve_test_helper((n,), (n, k), device, dtype)
b, A = self.solve_test_helper((n, n), (n, k), device, dtype)
x = torch.solve(b, A)[0]
self.assertEqual(b, np.matmul(A.cpu(), x.cpu()))
@ -3700,15 +3706,18 @@ class TestLinalg(TestCase):
self.assertEqual(b, Ax)
for batchsize in [1, 3, 4]:
solve_batch_helper((5, batchsize), (batchsize, 5, 10))
solve_batch_helper((batchsize, 5, 5), (batchsize, 5, 10))
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(*floating_and_complex_types())
def test_old_solve_batched_non_contiguous(self, device, dtype):
from numpy.linalg import solve
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
A = random_fullrank_matrix_distinct_singular_value(2, 2, dtype=dtype, device=device).permute(1, 0, 2)
make_fullrank = make_fullrank_matrices_with_distinct_singular_values
make_A = partial(make_fullrank, device=device, dtype=dtype)
A = make_A(2, 2, 2).permute(1, 0, 2)
b = torch.randn(2, 2, 2, dtype=dtype, device=device).permute(2, 1, 0)
x, _ = torch.solve(b, A)
x_exp = solve(A.cpu().numpy(), b.cpu().numpy())
@ -3719,7 +3728,7 @@ class TestLinalg(TestCase):
@skipCPUIfNoLapack
@dtypes(*floating_and_complex_types())
def test_old_solve_batched_many_batches(self, device, dtype):
for A_dims, b_dims in zip([(5, 256, 256), (3, )], [(5, 1), (512, 512, 3, 1)]):
for A_dims, b_dims in zip([(256, 256, 5, 5), (3, 3)], [(5, 1), (512, 512, 3, 1)]):
b, A = self.solve_test_helper(A_dims, b_dims, device, dtype)
x, _ = torch.solve(b, A)
Ax = torch.matmul(A, x)
@ -3734,7 +3743,7 @@ class TestLinalg(TestCase):
def run_test(A_dims, b_dims):
A_matrix_size = A_dims[-1]
A_batch_dims = A_dims[:-2]
b, A = self.solve_test_helper((A_matrix_size,) + A_batch_dims, b_dims, device, dtype)
b, A = self.solve_test_helper(A_batch_dims + (A_matrix_size, A_matrix_size), b_dims, device, dtype)
x, _ = torch.solve(b, A)
x_exp = solve(A.cpu().numpy(), b.cpu().numpy())
self.assertEqual(x, x_exp)
@ -4196,26 +4205,27 @@ class TestLinalg(TestCase):
@skipCPUIfNoLapack
@dtypes(torch.float64)
def test_matrix_rank_atol_rtol(self, device, dtype):
from torch.testing._internal.common_utils import make_fullrank_matrices_with_distinct_singular_values
make_fullrank = make_fullrank_matrices_with_distinct_singular_values
make_arg = partial(make_fullrank, device=device, dtype=dtype)
# creates a matrix with singular values arange(1/(n+1), 1, 1/(n+1)) and rank=n
# creates a matrix with rank=n and singular values in the range [2/3, 3/2]
# the singular values are 1 + 1/2, 1 - 1/3, 1 + 1/4, 1 - 1/5, ...
n = 9
a = make_fullrank_matrices_with_distinct_singular_values(n, n, dtype=dtype, device=device)
a = make_arg(n, n)
# test float and tensor variants
for tol_value in [0.51, torch.tensor(0.51, device=device)]:
# using rtol (relative tolerance) takes into account the largest singular value (0.9 in this case)
for tol_value in [0.81, torch.tensor(0.81, device=device)]:
# using rtol (relative tolerance) takes into account the largest singular value (1.5 in this case)
result = torch.linalg.matrix_rank(a, rtol=tol_value)
self.assertEqual(result, 5) # there are 5 singular values above 0.9*0.51=0.459
self.assertEqual(result, 2) # there are 2 singular values above 1.5*0.81 = 1.215
# atol is used directly to compare with singular values
result = torch.linalg.matrix_rank(a, atol=tol_value)
self.assertEqual(result, 4) # there are 4 singular values above 0.51
self.assertEqual(result, 7) # there are 7 singular values above 0.81
# when both are specified the maximum tolerance is used
result = torch.linalg.matrix_rank(a, atol=tol_value, rtol=tol_value)
self.assertEqual(result, 4) # there are 4 singular values above max(0.51, 0.9*0.51)
self.assertEqual(result, 2) # there are 2 singular values above max(0.81, 1.5*0.81)
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@ -6832,7 +6842,8 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
@skipCPUIfNoLapack
@dtypes(*floating_and_complex_types())
def test_pinverse(self, device, dtype):
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value as fullrank
make_fullrank = make_fullrank_matrices_with_distinct_singular_values
make_arg = partial(make_fullrank, device=device, dtype=dtype)
def run_test(M):
# Testing against definition for pseudo-inverses
@ -6857,7 +6868,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
for sizes in [(5, 5), (3, 5, 5), (3, 7, 5, 5)]:
matsize = sizes[-1]
batchdims = sizes[:-2]
M = fullrank(matsize, *batchdims, dtype=dtype, device=device)
M = make_arg(*batchdims, matsize, matsize)
self.assertEqual(torch.eye(matsize, dtype=dtype, device=device).expand(sizes), M.pinverse().matmul(M),
atol=1e-7, rtol=0, msg='pseudo-inverse for invertible matrix')
@ -6884,21 +6895,22 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
@skipCUDAIfNoMagmaAndNoCusolver
@dtypes(torch.double, torch.cdouble)
def test_matrix_power_negative(self, device, dtype):
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
make_fullrank = make_fullrank_matrices_with_distinct_singular_values
make_arg = partial(make_fullrank, device=device, dtype=dtype)
def check(*size):
t = random_fullrank_matrix_distinct_singular_value(*size, dtype=dtype, device=device)
t = make_arg(*size)
for n in range(-7, 0):
res = torch.linalg.matrix_power(t, n)
ref = np.linalg.matrix_power(t.cpu().numpy(), n)
self.assertEqual(res.cpu(), torch.from_numpy(ref))
check(0)
check(5)
check(0, 2)
check(3, 0)
check(3, 2)
check(5, 2, 3)
check(0, 0)
check(5, 5)
check(2, 0, 0)
check(0, 3, 3)
check(2, 3, 3)
check(2, 3, 5, 5)
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@ -7761,9 +7773,10 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
@skipCPUIfNoLapack
@dtypes(*floating_and_complex_types())
def test_lu_solve_batched_non_contiguous(self, device, dtype):
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
make_fullrank = make_fullrank_matrices_with_distinct_singular_values
make_A = partial(make_fullrank, device=device, dtype=dtype)
A = random_fullrank_matrix_distinct_singular_value(2, 2, dtype=dtype, device=device)
A = make_A(2, 2, 2)
b = torch.randn(2, 2, 2, dtype=dtype, device=device)
x_exp = np.linalg.solve(A.cpu().permute(0, 2, 1).numpy(), b.cpu().permute(2, 1, 0).numpy())
A = A.permute(0, 2, 1)
@ -7774,10 +7787,11 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
self.assertEqual(x, x_exp)
def lu_solve_test_helper(self, A_dims, b_dims, pivot, device, dtype):
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
make_fullrank = make_fullrank_matrices_with_distinct_singular_values
make_A = partial(make_fullrank, device=device, dtype=dtype)
b = torch.randn(*b_dims, dtype=dtype, device=device)
A = random_fullrank_matrix_distinct_singular_value(*A_dims, dtype=dtype, device=device)
A = make_A(*A_dims)
LU_data, LU_pivots, info = torch.lu(A, get_infos=True, pivot=pivot)
self.assertEqual(info, torch.zeros_like(info))
return b, A, LU_data, LU_pivots
@ -7790,7 +7804,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
def test_lu_solve(self, device, dtype):
def sub_test(pivot):
for k, n in zip([2, 3, 5], [3, 5, 7]):
b, A, LU_data, LU_pivots = self.lu_solve_test_helper((n,), (n, k), pivot, device, dtype)
b, A, LU_data, LU_pivots = self.lu_solve_test_helper((n, n), (n, k), pivot, device, dtype)
x = torch.lu_solve(b, LU_data, LU_pivots)
self.assertEqual(b, np.matmul(A.cpu(), x.cpu()))
@ -7817,7 +7831,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
self.assertEqual(b, Ax)
for batchsize in [1, 3, 4]:
lu_solve_batch_test_helper((5, batchsize), (batchsize, 5, 10), pivot)
lu_solve_batch_test_helper((batchsize, 5, 5), (batchsize, 5, 10), pivot)
# Tests tensors with 0 elements
b = torch.randn(3, 0, 3, dtype=dtype, device=device)
@ -7840,19 +7854,20 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
Ax = torch.matmul(A, x)
self.assertEqual(Ax, b.expand_as(Ax))
run_test((5, 65536), (65536, 5, 10))
run_test((5, 262144), (262144, 5, 10))
run_test((65536, 5, 5), (65536, 5, 10))
run_test((262144, 5, 5), (262144, 5, 10))
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(*floating_and_complex_types())
def test_lu_solve_batched_broadcasting(self, device, dtype):
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
make_fullrank = make_fullrank_matrices_with_distinct_singular_values
make_A = partial(make_fullrank, device=device, dtype=dtype)
def run_test(A_dims, b_dims, pivot=True):
A_matrix_size = A_dims[-1]
A_batch_dims = A_dims[:-2]
A = random_fullrank_matrix_distinct_singular_value(A_matrix_size, *A_batch_dims, dtype=dtype, device=device)
A = make_A(*A_batch_dims, A_matrix_size, A_matrix_size)
b = make_tensor(b_dims, dtype=dtype, device=device)
x_exp = np.linalg.solve(A.cpu(), b.cpu())
LU_data, LU_pivots = torch.lu(A, pivot=pivot)


@ -33,7 +33,6 @@ from torch.testing._internal.common_utils import \
make_fullrank_matrices_with_distinct_singular_values,
random_symmetric_pd_matrix, make_symmetric_matrices,
make_symmetric_pd_matrices, random_square_matrix_of_rank,
random_fullrank_matrix_distinct_singular_value,
TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY,
torch_to_numpy_dtype_dict, TEST_WITH_ASAN,
GRADCHECK_NONDET_TOL, slowTest, noncontiguous_like,
@ -1279,7 +1278,7 @@ def sample_inputs_tensor_split(op_info, device, dtype, requires_grad, **kwargs):
yield SampleInput(make_input((S, S, S)), args=args)
def sample_inputs_linalg_det(op_info, device, dtype, requires_grad):
def sample_inputs_linalg_det(op_info, device, dtype, requires_grad, **kwargs):
kw = dict(device=device, dtype=dtype)
inputs = [
make_tensor((S, S), **kw),
@ -1292,13 +1291,13 @@ def sample_inputs_linalg_det(op_info, device, dtype, requires_grad):
random_square_matrix_of_rank(S, 1, **kw), # rank1
random_square_matrix_of_rank(S, 2, **kw), # rank2
random_fullrank_matrix_distinct_singular_value(S, **kw), # distinct_singular_value
make_fullrank_matrices_with_distinct_singular_values(S, S, **kw), # full rank
make_tensor((3, 3, S, S), **kw), # batched
make_tensor((3, 3, 1, 1), **kw), # batched_1x1
random_symmetric_matrix(S, 3, **kw), # batched_symmetric
random_symmetric_psd_matrix(S, 3, **kw), # batched_symmetric_psd
random_symmetric_pd_matrix(S, 3, **kw), # batched_symmetric_pd
random_fullrank_matrix_distinct_singular_value(S, 3, 3, **kw), # batched_distinct_singular_values
make_fullrank_matrices_with_distinct_singular_values(S, 3, 3, **kw), # batched fullrank
make_tensor((0, 0), **kw),
make_tensor((0, S, S), **kw),
]
@ -1344,6 +1343,9 @@ def sample_inputs_linalg_det_singular(op_info, device, dtype, requires_grad, **k
def sample_inputs_linalg_matrix_power(op_info, device, dtype, requires_grad):
make_fullrank = make_fullrank_matrices_with_distinct_singular_values
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
make_arg_fullrank = partial(make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad)
# (<matrix_size>, (<batch_sizes, ...>))
test_sizes = [
(1, ()),
@ -1351,18 +1353,15 @@ def sample_inputs_linalg_matrix_power(op_info, device, dtype, requires_grad):
(2, (2,)),
]
inputs = []
for matrix_size, batch_sizes in test_sizes:
size = batch_sizes + (matrix_size, matrix_size)
for n in (0, 3, 5):
t = make_tensor(size, device, dtype, requires_grad=requires_grad)
inputs.append(SampleInput(t, args=(n,)))
for n in [-4, -2, -1]:
t = random_fullrank_matrix_distinct_singular_value(matrix_size, *batch_sizes, device=device, dtype=dtype)
t.requires_grad = requires_grad
inputs.append(SampleInput(t, args=(n,)))
def generate_inputs():
for matrix_size, batch_sizes in test_sizes:
size = batch_sizes + (matrix_size, matrix_size)
for n in (0, 3, 5):
yield SampleInput(make_arg(size), args=(n,))
for n in [-4, -2, -1]:
yield SampleInput(make_arg_fullrank(*size), args=(n,))
return inputs
return list(generate_inputs())
def sample_inputs_hsplit(op_info, device, dtype, requires_grad):
return (SampleInput(make_tensor((6,), device, dtype,
@ -2592,8 +2591,7 @@ def sample_inputs_T(self, device, dtype, requires_grad, **kwargs):
def sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad=False, **kwargs):
"""
This function generates always invertible input for linear algebra ops using
random_fullrank_matrix_distinct_singular_value.
This function generates invertible inputs for linear algebra ops
The input is generated as the itertools.product of 'batches' and 'ns'.
In total this function generates 8 SampleInputs
'batches' cases include:
@ -2604,16 +2602,17 @@ def sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad=False,
'ns' gives 0x0 and 5x5 matrices.
Zeros in dimensions are edge cases in the implementation and important to test for in order to avoid unexpected crashes.
"""
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
make_fn = make_fullrank_matrices_with_distinct_singular_values
make_arg = partial(make_fn, dtype=dtype, device=device, requires_grad=requires_grad)
batches = [(), (0, ), (2, ), (1, 1)]
ns = [5, 0]
out = []
for batch, n in product(batches, ns):
a = random_fullrank_matrix_distinct_singular_value(n, *batch, dtype=dtype, device=device)
a.requires_grad = requires_grad
out.append(SampleInput(a))
return out
def generate_samples():
for batch, n in product(batches, ns):
yield SampleInput(make_arg(*batch, n, n))
return list(generate_samples())
def sample_inputs_linalg_pinv_singular(op_info, device, dtype, requires_grad=False, **kwargs):
"""
@ -5198,7 +5197,7 @@ def sample_inputs_linalg_pinv_hermitian(op_info, device, dtype, requires_grad=Fa
def sample_inputs_linalg_solve(op_info, device, dtype, requires_grad=False, vector_rhs_allowed=True, **kwargs):
"""
This function generates always solvable input for torch.linalg.solve
Using random_fullrank_matrix_distinct_singular_value gives a non-singular (=invertible, =solvable) matrices 'a'.
We sample a fullrank square matrix (i.e. invertible) A
The first input to torch.linalg.solve is generated as the itertools.product of 'batches' and 'ns'.
The second input is generated as the product of 'batches', 'ns' and 'nrhs'.
In total this function generates 18 SampleInputs
@ -5218,7 +5217,9 @@ def sample_inputs_linalg_solve(op_info, device, dtype, requires_grad=False, vect
Once torch.solve / triangular_solve / cholesky_solve and its testing are removed,
'vector_rhs_allowed' may be removed here as well.
"""
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
make_fullrank = make_fullrank_matrices_with_distinct_singular_values
make_a = partial(make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad)
make_b = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
batches = [(), (0, ), (2, )]
ns = [5, 0]
@ -5226,14 +5227,12 @@ def sample_inputs_linalg_solve(op_info, device, dtype, requires_grad=False, vect
nrhs = [(), (1,), (3,)]
else:
nrhs = [(1,), (3,)]
out = []
for n, batch, rhs in product(ns, batches, nrhs):
a = random_fullrank_matrix_distinct_singular_value(n, *batch, dtype=dtype, device=device)
a.requires_grad = requires_grad
b = torch.randn(*batch, n, *rhs, dtype=dtype, device=device)
b.requires_grad = requires_grad
out.append(SampleInput(a, args=(b,)))
return out
def generate_samples():
for n, batch, rhs in product(ns, batches, nrhs):
yield SampleInput(make_a(*batch, n, n), args=(make_b((batch + (n,) + rhs)),))
return list(generate_samples())
def sample_inputs_linalg_solve_triangular(op_info, device, dtype, requires_grad=False, **kwargs):
make_arg = partial(make_tensor, dtype=dtype, device=device)
@ -5450,93 +5449,34 @@ def sample_inputs_cov(op_info, device, dtype, requires_grad, **kwargs):
return inputs
def _sample_inputs_svd(op_info, device, dtype, requires_grad=False, is_linalg_svd=False):
"""
This function generates input for torch.svd with distinct singular values so that autograd is always stable.
Matrices of different size:
square matrix - S x S size
tall marix - S x (S-2)
wide matrix - (S-2) x S
and batched variants of above are generated.
Each SampleInput has a function 'output_process_fn_grad' attached to it that is applied on the output of torch.svd
It is needed for autograd checks, because backward of svd doesn't work for an arbitrary loss function.
"""
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
def sample_inputs_svd(op_info, device, dtype, requires_grad=False, **kwargs):
make_fullrank = make_fullrank_matrices_with_distinct_singular_values
make_arg = partial(make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad)
# svd and linalg.svd returns V and V.conj().T, respectively. So we need to slice
# along different dimensions when needed (this is used by
# test_cases2:wide_all and wide_all_batched below)
if is_linalg_svd:
def slice_V(v):
return v[..., :(S - 2), :]
is_linalg_svd = (op_info.name == "linalg.svd")
def uv_loss(usv):
u00 = usv[0][0, 0]
v00_conj = usv[2][0, 0]
return u00 * v00_conj
else:
def slice_V(v):
return v[..., :, :(S - 2)]
batches = [(), (0, ), (2, )]
ns = [0, 2, 5]
def uv_loss(usv):
u00 = usv[0][0, 0]
v00_conj = usv[2][0, 0].conj()
return u00 * v00_conj
# The .abs() is to make these functions invariant under multiplication by e^{i\theta}
# since the complex SVD is unique up to multiplication of the columns of U / V by a z \in C with \norm{z} = 1
# See the docs of torch.linalg.svd for more info
def check_grads(usv):
S = usv[1]
k = S.shape[-1]
U = usv[0][..., :k]
Vh = usv[2] if is_linalg_svd else usv[2].mH
Vh = Vh[..., :k, :]
return (U.abs(), S, Vh.abs())
test_cases1 = ( # some=True (default)
# loss functions for complex-valued svd have to be "gauge invariant",
# i.e. loss functions shouldn't change when sigh of the singular vectors change.
# the simplest choice to satisfy this requirement is to apply 'abs'.
(random_fullrank_matrix_distinct_singular_value(S, dtype=dtype, device=device),
lambda usv: usv[1]), # 'check_grad_s'
(random_fullrank_matrix_distinct_singular_value(S, dtype=dtype, device=device),
lambda usv: abs(usv[0])), # 'check_grad_u'
(random_fullrank_matrix_distinct_singular_value(S, dtype=dtype, device=device),
lambda usv: abs(usv[2])), # 'check_grad_v'
# this test is important as it checks the additional term that is non-zero only for complex-valued inputs
# and when the loss function depends both on 'u' and 'v'
(random_fullrank_matrix_distinct_singular_value(S, dtype=dtype, device=device),
uv_loss), # 'check_grad_uv'
(random_fullrank_matrix_distinct_singular_value(S, dtype=dtype, device=device)[:(S - 2)],
lambda usv: (abs(usv[0]), usv[1], abs(usv[2][..., :, :(S - 2)]))), # 'wide'
(random_fullrank_matrix_distinct_singular_value(S, dtype=dtype, device=device)[:, :(S - 2)],
lambda usv: (abs(usv[0]), usv[1], abs(usv[2]))), # 'tall'
(random_fullrank_matrix_distinct_singular_value(S, 2, dtype=dtype, device=device),
lambda usv: (abs(usv[0]), usv[1], abs(usv[2]))), # 'batched'
(random_fullrank_matrix_distinct_singular_value(S, 2, dtype=dtype, device=device)[..., :(S - 2), :],
lambda usv: (abs(usv[0]), usv[1], abs(usv[2]))), # 'wide_batched'
(random_fullrank_matrix_distinct_singular_value(S, 2, dtype=dtype, device=device)[..., :, :(S - 2)],
lambda usv: (abs(usv[0]), usv[1], abs(usv[2]))), # 'tall_batched'
)
test_cases2 = ( # some=False
(random_fullrank_matrix_distinct_singular_value(S, dtype=dtype, device=device)[:(S - 2)],
lambda usv: (abs(usv[0]), usv[1], abs(slice_V(usv[2])))), # 'wide_all'
(random_fullrank_matrix_distinct_singular_value(S, dtype=dtype, device=device)[:, :(S - 2)],
lambda usv: (abs(usv[0][:, :(S - 2)]), usv[1], abs(usv[2]))), # 'tall_all'
(random_fullrank_matrix_distinct_singular_value(S, 2, dtype=dtype, device=device)[..., :(S - 2), :],
lambda usv: (abs(usv[0]), usv[1], abs(slice_V(usv[2])))), # 'wide_all_batched'
(random_fullrank_matrix_distinct_singular_value(S, 2, dtype=dtype, device=device)[..., :, :(S - 2)],
lambda usv: (abs(usv[0][..., :, :(S - 2)]), usv[1], abs(usv[2]))), # 'tall_all_batched'
)
fullmat = 'full_matrices' if is_linalg_svd else 'some'
out = []
for a, out_fn in test_cases1:
a.requires_grad = requires_grad
if is_linalg_svd:
kwargs = {'full_matrices': False}
else:
kwargs = {'some': True}
out.append(SampleInput(a, kwargs=kwargs, output_process_fn_grad=out_fn))
def generate_inputs():
for batch, n, k, fullmat_val in product(batches, ns, ns, (True, False)):
shape = batch + (n, k)
yield SampleInput(make_arg(*shape), kwargs={fullmat: fullmat_val}, output_process_fn_grad=check_grads)
for a, out_fn in test_cases2:
a.requires_grad = requires_grad
if is_linalg_svd:
kwargs = {'full_matrices': True}
else:
kwargs = {'some': False}
out.append(SampleInput(a, kwargs=kwargs, output_process_fn_grad=out_fn))
return out
return list(generate_inputs())
def sample_inputs_permute(op_info, device, dtype, requires_grad, **kwargs):
@ -5610,12 +5550,6 @@ def sample_inputs_pow(op_info, device, dtype, requires_grad, **kwargs):
args=(make_arg((2, 2), requires_grad=requires_grad),)))
return tuple(samples)
def sample_inputs_svd(op_info, device, dtype, requires_grad=False, **kwargs):
return _sample_inputs_svd(op_info, device, dtype, requires_grad, is_linalg_svd=False)
def sample_inputs_linalg_svd(op_info, device, dtype, requires_grad=False, **kwargs):
return _sample_inputs_svd(op_info, device, dtype, requires_grad, is_linalg_svd=True)
def sample_inputs_linalg_svdvals(op_info, device, dtype, requires_grad=False, **kwargs):
batches = [(), (0, ), (2, ), (1, 1)]
ns = [5, 2, 0]
@ -12737,6 +12671,7 @@ op_db: List[OpInfo] = [
op=torch.svd,
dtypes=floating_and_complex_types(),
sample_inputs_func=sample_inputs_svd,
check_batched_gradgrad=False,
decorators=[
skipCUDAIfNoMagmaAndNoCusolver,
skipCUDAIfRocm,
@ -12746,7 +12681,8 @@ op_db: List[OpInfo] = [
op=torch.linalg.svd,
aten_name='linalg_svd',
dtypes=floating_and_complex_types(),
sample_inputs_func=sample_inputs_linalg_svd,
sample_inputs_func=sample_inputs_svd,
check_batched_gradgrad=False,
decorators=[
skipCUDAIfNoMagmaAndNoCusolver,
skipCUDAIfRocm,


@ -2641,23 +2641,6 @@ def random_hermitian_pd_matrix(matrix_size, *batch_dims, dtype, device):
dtype=dtype, device=device)
return A @ A.mH + torch.eye(matrix_size, dtype=dtype, device=device)
# TODO: remove this (prefer make_fullrank_matrices_with_distinct_singular_values below)
def random_fullrank_matrix_distinct_singular_value(matrix_size, *batch_dims,
**kwargs):
dtype = kwargs.get('dtype', torch.double)
device = kwargs.get('device', 'cpu')
silent = kwargs.get("silent", False)
if silent and not torch._C.has_lapack:
return torch.ones(matrix_size, matrix_size, dtype=dtype, device=device)
A = torch.randn(batch_dims + (matrix_size, matrix_size), dtype=dtype, device=device)
u, _, vh = torch.linalg.svd(A, full_matrices=False)
real_dtype = A.real.dtype if A.dtype.is_complex else A.dtype
s = torch.arange(1., matrix_size + 1, dtype=real_dtype, device=device).mul_(1.0 / (matrix_size + 1))
return (u * s.to(A.dtype)) @ vh
# Creates a full rank matrix with distinct singular values or
# a batch of such matrices
def make_fullrank_matrices_with_distinct_singular_values(*shape, device, dtype, requires_grad=False):
@ -2667,8 +2650,18 @@ def make_fullrank_matrices_with_distinct_singular_values(*shape, device, dtype,
# TODO: improve the handling of complex tensors here
real_dtype = t.real.dtype if t.dtype.is_complex else t.dtype
k = min(shape[-1], shape[-2])
s = torch.arange(1., k + 1, dtype=real_dtype, device=device).mul_(1.0 / (k + 1))
x = (u * s.to(dtype)) @ vh
# We choose the singular values to be "around one"
# This is to make the matrix well conditioned
# s = [2, 3, ..., k+1]
s = torch.arange(2, k + 2, dtype=real_dtype, device=device)
# s = [2, -3, 4, ..., (-1)^k k+1]
s[1::2] *= -1.
# 1 + 1/s so that the singular values are in the range [2/3, 3/2]
# This gives a condition number of 9/4, which should be good enough
s.reciprocal_().add_(1.)
# Note that the singular values need not be ordered in an SVD so
# we don't need to sort S
x = (u * s.to(u.dtype)) @ vh
x.requires_grad_(requires_grad)
return x
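
As a quick sanity check of the conditioning comment above (an illustrative
sketch, not part of the change), the constructed singular values for k = 5
can be reproduced directly:

    import torch

    # Singular-value schedule used above, for k = 5:
    # s = [2, -3, 4, -5, 6]  ->  1 + 1/s = [1.5, 0.667, 1.25, 0.8, 1.167]
    k = 5
    s = torch.arange(2, k + 2, dtype=torch.float64)
    s[1::2] *= -1
    s.reciprocal_().add_(1.)
    print(s)                  # every value lies in [2/3, 3/2]
    print(s.max() / s.min())  # condition number 1.5 / (2/3) = 9/4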