diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 3b50923abb4..2371d82efc6 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -347,5 +347,44 @@ Tensor& matmul_out(Tensor &result, const Tensor & tensor1, const Tensor & tensor
   return result;
 }
 
+Tensor matrix_power(const Tensor& a, int64_t n) {
+  AT_CHECK(a.dim() >= 2 && at::isFloatingType(a.type().scalarType()),
+           "matrix_power(", a.type(), "{", a.sizes(), "}): expected a tensor "
+           "of floating types with dim at least 2");
+  if (n == 0) {
+    return a.clone().copy_(at::eye(a.size(-2), a.options()).expand_as(a));
+  } else if (n < 0) {
+    AT_CHECK(a.dim() == 2, "Negative powers for batch matrices are currently not supported");
+    Tensor a_ = at::inverse(a);
+    n *= -1;
+    return at::native::matrix_power(a_, n);
+  } else if (n == 1) {
+    return a.clone();
+  } else if (n == 2) {
+    return at::native::matmul(a, a);
+  } else if (n == 3) {
+    return at::native::matmul(at::native::matmul(a, a), a);
+  }
+
+  // This is a binary decomposition of n, moving from the least significant
+  // bit to the most significant bit. It reduces the number of matrix
+  // multiplications by raising the input matrix to powers of 2.
+  // The total number of matrix multiplications is
+  // (number of bits + number of bits that equal 1) ~ O(log n),
+  // instead of O(n).
+  Tensor result, z;
+  int64_t r;
+  while (n > 0) {
+    z = (!z.defined()) ? a.clone() : at::native::matmul(z, z);
+    r = n % 2;
+    n = n / 2;
+    if (r == 1) {
+      result = (!result.defined()) ? z.clone() : at::native::matmul(result, z);
+    }
+  }
+  return result;
 }
-}
+
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index ca7db1f8915..4e14e8f8fb5 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -993,7 +993,6 @@
 - func: logsumexp_out(Tensor result, Tensor self, int64_t dim, bool keepdim=False) -> Tensor
 
-
 - func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, double margin=0.0, int64_t reduction=Reduction::ElementwiseMean) -> Tensor
 
 - func: matmul(Tensor self, Tensor other) -> Tensor
@@ -1005,6 +1004,9 @@
 - func: matrix_rank(Tensor self, bool symmetric=false) -> Tensor
 
+- func: matrix_power(Tensor self, int64_t n) -> Tensor
+  variants: function, method
+
 - func: max(Tensor self, int64_t dim, bool keepdim=false) -> (Tensor, Tensor)
   variants: function, method
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index 468de8f5b98..85f6232ff44 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -285,6 +285,7 @@ view of a storage and defines numeric operations on it.
    .. automethod:: masked_fill_
    .. automethod:: masked_select
    .. automethod:: matmul
+   .. automethod:: matrix_power
    .. automethod:: max
    .. automethod:: mean
    .. automethod:: median
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 18d21f0e2a1..31585d4a969 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -169,7 +169,7 @@ Pointwise Ops
 .. autofunction:: cos
 .. autofunction:: cosh
 .. autofunction:: div
-.. autofunction:: digamma 
+.. autofunction:: digamma
 .. autofunction:: erf
 .. autofunction:: erfc
 .. autofunction:: erfinv
@@ -296,6 +296,7 @@ BLAS and LAPACK Operations
 .. autofunction:: logdet
 .. autofunction:: slogdet
 .. autofunction:: matmul
+.. autofunction:: matrix_power
 .. autofunction:: matrix_rank
 .. autofunction:: mm
 .. autofunction:: mv
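The loop added above is ordinary exponentiation by squaring. For readers who want the algorithm without the ATen plumbing, here is a minimal Python sketch of the same binary decomposition (an illustration only, not part of the patch; `matrix_power_sketch` is a hypothetical name, and like the patch it handles negative powers by going through the inverse):

```python
import torch

def matrix_power_sketch(a, n):
    # Walk the bits of n from least to most significant: at step k, `z` holds
    # a**(2**k), and it is multiplied into `result` whenever bit k of n is set.
    # Total matmuls ~ O(log n) instead of O(n).
    if n == 0:
        return torch.eye(a.size(-2), dtype=a.dtype, device=a.device).expand_as(a).clone()
    if n < 0:
        a, n = torch.inverse(a), -n
    result = None
    z = None
    while n > 0:
        z = a.clone() if z is None else torch.matmul(z, z)
        if n % 2 == 1:
            result = z.clone() if result is None else torch.matmul(result, z)
        n //= 2
    return result
```

For n = 6 (binary 110), for example, this costs three matmuls (two squarings plus one combine) instead of the five needed by naive repeated multiplication.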
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 0dd1801da10..faba5efb184 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -2976,6 +2976,14 @@ method_tests = [
     ('matmul', (S, S, M, M), ((S, S, M, S),), "4d_4d"),
     ('matmul', (S, S, M, M), ((M,),), "4d_1d"),
     ('matmul', (M,), ((S, S, M, S),), "1d_4d"),
+    ('matrix_power', (S, S), [2], "n=2"),
+    ('matrix_power', (S, S, S), [3], "n=3"),
+    ('matrix_power', (S, S, S), [1], "n=1"),
+    ('matrix_power', (S, S, S), [0], "n=0"),
+    ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S), [-1], "n=-1",
+     NO_ARGS, [skipIfNoLapack]),
+    ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S), [-3], "n=-3",
+     NO_ARGS, [skipIfNoLapack]),
     ('addcmul', (S, S), ((S, S), (S, S))),
     ('addcmul', (S, S), ((S, 1), (1, S)), 'broadcast_rhs'),
     ('addcmul', (1,), ((S, S, 1), (1, S)), 'broadcast_all'),
diff --git a/test/test_cuda.py b/test/test_cuda.py
index b174a63201e..1ca7155dd09 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -1460,6 +1460,10 @@ class TestCuda(TestCase):
     def test_matrix_rank(self):
         TestTorch._test_matrix_rank(self, lambda x: x.cuda())
 
+    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
+    def test_matrix_power(self):
+        TestTorch._test_matrix_power(self, conv_fn=lambda t: t.cuda())
+
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
     def test_det_logdet_slogdet(self):
         TestTorch._test_det_logdet_slogdet(self, lambda t: t.cuda())
diff --git a/test/test_jit.py b/test/test_jit.py
index 55af9cfeae3..5eb7d4649bd 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -7094,6 +7094,8 @@ EXCLUDE_SCRIPT = {
     'test_split_dim_neg0',
     'test_gesv',
     'test_inverse',
+    'test_matrix_power_n=-1',  # involves inverse
+    'test_matrix_power_n=-3',  # involves inverse
     # skipped nn functional tests
     # ops involves sampling which could not test
     'test_nn_dropout',
diff --git a/test/test_torch.py b/test/test_torch.py
index 9bee3af682e..e0af3075c43 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -4595,6 +4595,49 @@ class TestTorch(TestCase):
     def test_pinverse(self):
         self._test_pinverse(self, conv_fn=lambda x: x)
 
+    @staticmethod
+    def _test_matrix_power(self, conv_fn):
+        def run_test(M, sign=1):
+            if sign == -1:
+                M = M.inverse()
+            MP2 = torch.matrix_power(M, 2)
+            self.assertEqual(MP2, torch.matmul(M, M))
+
+            MP3 = torch.matrix_power(M, 3)
+            self.assertEqual(MP3, torch.matmul(MP2, M))
+
+            MP4 = torch.matrix_power(M, 4)
+            self.assertEqual(MP4, torch.matmul(MP2, MP2))
+
+            MP6 = torch.matrix_power(M, 6)
+            self.assertEqual(MP6, torch.matmul(MP3, MP3))
+
+            MP0 = torch.matrix_power(M, 0)
+            self.assertEqual(MP0, torch.eye(M.size(-2)).expand_as(M))
+
+        # Single matrix
+        M = conv_fn(torch.randn(5, 5))
+        run_test(M)
+
+        # Batch matrices
+        M = conv_fn(torch.randn(3, 3, 3))
+        run_test(M)
+
+        # Many batch matrices
+        M = conv_fn(torch.randn(2, 3, 3, 3))
+        run_test(M)
+
+        # Single matrix, but full rank
+        # This is for negative powers
+        from test_autograd import random_fullrank_matrix_distinct_singular_value
+        M = conv_fn(random_fullrank_matrix_distinct_singular_value(5))
+        run_test(M)
+        run_test(M, sign=-1)
+
+    @skipIfNoLapack
+    def test_matrix_power(self):
+        self._test_matrix_power(self, conv_fn=lambda x: x)
+
     @staticmethod
     def _test_det_logdet_slogdet(self, conv_fn):
         def reference_det(M):
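The `_test_matrix_power` helper above verifies the operator against repeated `matmul` and against the identity for n = 0. The same properties written as a tiny standalone check against the public API (a sketch only, assuming a build that already contains this patch; the tolerance for the negative-power case is a guess, since that path goes through `inverse`):

```python
import torch

m = torch.randn(3, 3)

# n = 0 returns the identity matrix (expanded across any batch dimensions).
assert torch.equal(torch.matrix_power(m, 0), torch.eye(3))

# Positive n is repeated matrix multiplication; also exposed as a method.
assert torch.allclose(torch.matrix_power(m, 3), m @ m @ m)
assert torch.allclose(m.matrix_power(2), torch.matmul(m, m))

# Negative n raises the inverse to the corresponding positive power,
# so it is only meaningful for invertible (and, per the patch, non-batched) inputs.
assert torch.allclose(torch.matrix_power(m, -2),
                      torch.matmul(m.inverse(), m.inverse()), atol=1e-5)
```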
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index ff4c0702753..f7ace5e0edc 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -1317,6 +1317,13 @@ masked_select(mask) -> Tensor
 
 See :func:`torch.masked_select`
 """)
 
+add_docstr_all('matrix_power',
+               r"""
+matrix_power(n) -> Tensor
+
+See :func:`torch.matrix_power`
+""")
+
 add_docstr_all('max',
                r"""
 max(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor)
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 3c7da8cc246..ea6016f778f 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -2385,6 +2385,38 @@ Example::
 
     tensor(9)
 """)
 
+add_docstr(torch.matrix_power,
+           r"""
+matrix_power(input, n) -> Tensor
+
+Returns the matrix raised to the power :attr:`n` for square matrices.
+For batches of matrices, each individual matrix is raised to the power :attr:`n`.
+
+If :attr:`n` is negative, then the inverse of the matrix (if invertible) is
+raised to the power :attr:`n`. If :attr:`n` is 0, then an identity matrix
+is returned.
+
+Args:
+    input (Tensor): the input tensor
+    n (int): the power to raise the matrix to
+
+Example::
+
+    >>> a = torch.randn(2, 2, 2)
+    >>> a
+    tensor([[[-1.9975, -1.9610],
+             [ 0.9592, -2.3364]],
+
+            [[-1.2534, -1.3429],
+             [ 0.4153, -1.4664]]])
+    >>> torch.matrix_power(a, 3)
+    tensor([[[  3.9392, -23.9916],
+             [ 11.7357,  -0.2070]],
+
+            [[  0.2468,  -6.7168],
+             [  2.0774,  -0.8187]]])
+""")
+
 add_docstr(torch.max,
            r"""
 .. function:: max(input) -> Tensor