Add matrix power (#11421)

Summary:
vishwakftw: your patch needed some updates because the default native function dispatches changed from `[function, method]` to `[function]`. CI ran before that change landed, so it still shows green, but the internal test caught the breakage.

I made some changes while rebasing and updating, so I didn't just force-push to your branch. Let's see if this passes CI and the internal test. If it does, let me know whether you want me to force-push to your branch or use this PR instead.

Note to reviewers: the patch was already approved in #10068.

cc yf225
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11421

Differential Revision: D9733407

Pulled By: SsnL

fbshipit-source-id: cf2ed293bb9942dcc5158934ff4def2f63252599
Authored by Tongzhou Wang on 2018-09-08 15:21:14 -07:00, committed by Facebook Github Bot
Commit: d3f98b5ffc (parent: 802380ac93)
10 changed files with 142 additions and 3 deletions


@@ -347,5 +347,44 @@ Tensor& matmul_out(Tensor &result, const Tensor & tensor1, const Tensor & tensor2)
  return result;
}

Tensor matrix_power(const Tensor& a, int64_t n) {
  AT_CHECK(a.dim() >= 2 && at::isFloatingType(a.type().scalarType()),
           "matrix_power(", a.type(), "{", a.sizes(), "}): expected a tensor "
           "of floating types with dim at least 2");
  if (n == 0) {
    return a.clone().copy_(at::eye(a.size(-2), a.options()).expand_as(a));
  } else if (n < 0) {
    AT_CHECK(a.dim() == 2, "Negative powers for batch matrices are currently not supported");
    Tensor a_ = at::inverse(a);
    n *= -1;
    return at::native::matrix_power(a_, n);
  } else if (n == 1) {
    return a.clone();
  } else if (n == 2) {
    return at::native::matmul(a, a);
  } else if (n == 3) {
    return at::native::matmul(at::native::matmul(a, a), a);
  }

  // Binary decomposition of n: walk the bits from least significant to most
  // significant, squaring the running power of the input matrix at each step.
  // This reduces the work to (number of bits) + (number of set bits)
  // matrix multiplications, i.e. O(log n) instead of O(n).
  Tensor result, z;
  int64_t r;
  while (n > 0) {
    z = (!z.defined()) ? a.clone() : at::native::matmul(z, z);
    r = n % 2;
    n = n / 2;
    if (r == 1) {
      result = (!result.defined()) ? z.clone() : at::native::matmul(result, z);
    }
  }
  return result;
}
}
} // namespace native
} // namespace at
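
The loop above is exponentiation by squaring. As a cross-check, here is a minimal standalone Python sketch of the same bit-by-bit decomposition (the helper name is made up for illustration; it is not part of the patch):

    import torch

    def matrix_power_by_squaring(a, n):
        # Scan the bits of n from least to most significant, squaring a running
        # power of `a` and folding it into the result whenever the bit is set.
        assert n > 0
        result, z = None, None
        while n > 0:
            z = a.clone() if z is None else torch.matmul(z, z)
            if n % 2 == 1:
                result = z.clone() if result is None else torch.matmul(result, z)
            n //= 2
        return result

    # e.g. torch.allclose(matrix_power_by_squaring(m, 5), torch.matrix_power(m, 5))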


@@ -993,7 +993,6 @@
- func: logsumexp_out(Tensor result, Tensor self, int64_t dim, bool keepdim=False) -> Tensor
- func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, double margin=0.0, int64_t reduction=Reduction::ElementwiseMean) -> Tensor
- func: matmul(Tensor self, Tensor other) -> Tensor
@@ -1005,6 +1004,9 @@
- func: matrix_rank(Tensor self, bool symmetric=false) -> Tensor

- func: matrix_power(Tensor self, int64_t n) -> Tensor
  variants: function, method

- func: max(Tensor self, int64_t dim, bool keepdim=false) -> (Tensor, Tensor)
  variants: function, method
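
The `variants: function, method` line exposes the op both as `torch.matrix_power` and as a `Tensor` method. A quick sketch of the two equivalent call forms (values are arbitrary):

    import torch

    a = torch.randn(3, 3)
    b = torch.matrix_power(a, 2)   # function variant
    c = a.matrix_power(2)          # method variant
    assert torch.allclose(b, c)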


@@ -285,6 +285,7 @@ view of a storage and defines numeric operations on it.
.. automethod:: masked_fill_
.. automethod:: masked_select
.. automethod:: matmul
.. automethod:: matrix_power
.. automethod:: max
.. automethod:: mean
.. automethod:: median


@@ -169,7 +169,7 @@ Pointwise Ops
.. autofunction:: cos
.. autofunction:: cosh
.. autofunction:: div
.. autofunction:: digamma
.. autofunction:: erf
.. autofunction:: erfc
.. autofunction:: erfinv
@@ -296,6 +296,7 @@ BLAS and LAPACK Operations
.. autofunction:: logdet
.. autofunction:: slogdet
.. autofunction:: matmul
.. autofunction:: matrix_power
.. autofunction:: matrix_rank
.. autofunction:: mm
.. autofunction:: mv


@@ -2976,6 +2976,14 @@ method_tests = [
    ('matmul', (S, S, M, M), ((S, S, M, S),), "4d_4d"),
    ('matmul', (S, S, M, M), ((M,),), "4d_1d"),
    ('matmul', (M,), ((S, S, M, S),), "1d_4d"),
    ('matrix_power', (S, S), [2], "n=2"),
    ('matrix_power', (S, S, S), [3], "n=3"),
    ('matrix_power', (S, S, S), [1], "n=1"),
    ('matrix_power', (S, S, S), [0], "n=0"),
    ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S), [-1], "n=-1",
     NO_ARGS, [skipIfNoLapack]),
    ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S), [-3], "n=-3",
     NO_ARGS, [skipIfNoLapack]),
    ('addcmul', (S, S), ((S, S), (S, S))),
    ('addcmul', (S, S), ((S, 1), (1, S)), 'broadcast_rhs'),
    ('addcmul', (1,), ((S, S, 1), (1, S)), 'broadcast_all'),
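
Each `method_tests` entry above is expanded into forward and gradient checks by the autograd test harness. Roughly the same coverage for a single case could be obtained with a manual `gradcheck` call (a sketch only, not the harness itself):

    import torch
    from torch.autograd import gradcheck

    # gradcheck compares analytical and numerical Jacobians; it expects
    # double-precision inputs with requires_grad=True.
    m = torch.randn(4, 4, dtype=torch.double, requires_grad=True)
    assert gradcheck(lambda x: torch.matrix_power(x, 2), (m,))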


@@ -1460,6 +1460,10 @@ class TestCuda(TestCase):
    def test_matrix_rank(self):
        TestTorch._test_matrix_rank(self, lambda x: x.cuda())

    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
    def test_matrix_power(self):
        TestTorch._test_matrix_power(self, conv_fn=lambda t: t.cuda())

    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
    def test_det_logdet_slogdet(self):
        TestTorch._test_det_logdet_slogdet(self, lambda t: t.cuda())


@@ -7094,6 +7094,8 @@ EXCLUDE_SCRIPT = {
    'test_split_dim_neg0',
    'test_gesv',
    'test_inverse',
    'test_matrix_power_n=-1', # involves inverse
    'test_matrix_power_n=-3', # involves inverse
    # skipped nn functional tests
    # ops involves sampling which could not test
    'test_nn_dropout',


@@ -4595,6 +4595,49 @@ class TestTorch(TestCase):
    def test_pinverse(self):
        self._test_pinverse(self, conv_fn=lambda x: x)

    @staticmethod
    def _test_matrix_power(self, conv_fn):
        def run_test(M, sign=1):
            if sign == -1:
                M = M.inverse()
            MP2 = torch.matrix_power(M, 2)
            self.assertEqual(MP2, torch.matmul(M, M))
            MP3 = torch.matrix_power(M, 3)
            self.assertEqual(MP3, torch.matmul(MP2, M))
            MP4 = torch.matrix_power(M, 4)
            self.assertEqual(MP4, torch.matmul(MP2, MP2))
            MP6 = torch.matrix_power(M, 6)
            self.assertEqual(MP6, torch.matmul(MP3, MP3))
            MP0 = torch.matrix_power(M, 0)
            self.assertEqual(MP0, torch.eye(M.size(-2)).expand_as(M))

        # Single matrix
        M = conv_fn(torch.randn(5, 5))
        run_test(M)

        # Batch matrices
        M = conv_fn(torch.randn(3, 3, 3))
        run_test(M)

        # Many batch matrices
        M = conv_fn(torch.randn(2, 3, 3, 3))
        run_test(M)

        # Single matrix, but full rank
        # This is for negative powers
        from test_autograd import random_fullrank_matrix_distinct_singular_value
        M = conv_fn(random_fullrank_matrix_distinct_singular_value(5))
        run_test(M)
        run_test(M, sign=-1)

    @skipIfNoLapack
    def test_matrix_power(self):
        self._test_matrix_power(self, conv_fn=lambda x: x)

    @staticmethod
    def _test_det_logdet_slogdet(self, conv_fn):
        def reference_det(M):


@@ -1317,6 +1317,13 @@ masked_select(mask) -> Tensor
See :func:`torch.masked_select`
""")

add_docstr_all('matrix_power',
               r"""
matrix_power(n) -> Tensor

See :func:`torch.matrix_power`
""")

add_docstr_all('max',
               r"""
max(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor)


@@ -2385,6 +2385,38 @@ Example::
    tensor(9)
""")

add_docstr(torch.matrix_power,
           r"""
matrix_power(input, n) -> Tensor

Returns the matrix raised to the power :attr:`n` for square matrices.
For a batch of matrices, each individual matrix is raised to the power :attr:`n`.

If :attr:`n` is negative, then the inverse of the matrix (if invertible) is
raised to the power :math:`|n|`. If :attr:`n` is 0, then an identity matrix
is returned.

Args:
    input (Tensor): the input tensor
    n (int): the power to raise the matrix to

Example::

    >>> a = torch.randn(2, 2, 2)
    >>> a
    tensor([[[-1.9975, -1.9610],
             [ 0.9592, -2.3364]],

            [[-1.2534, -1.3429],
             [ 0.4153, -1.4664]]])
    >>> torch.matrix_power(a, 3)
    tensor([[[  3.9392, -23.9916],
             [ 11.7357,  -0.2070]],

            [[  0.2468,  -6.7168],
             [  2.0774,  -0.8187]]])
""")
add_docstr(torch.max,
r"""
.. function:: max(input) -> Tensor
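
Returning to the `matrix_power` docstring added above: a small usage sketch (not part of the patch) exercising the zero and negative powers it describes; the input here is arbitrary and assumed to be well-conditioned:

    import torch

    m = torch.randn(3, 3)

    # n = 0 yields the identity matrix of matching shape.
    assert torch.allclose(torch.matrix_power(m, 0), torch.eye(3))

    # For a 2-D invertible input and negative n, the inverse is raised to |n|.
    assert torch.allclose(torch.matrix_power(m, -2),
                          torch.matrix_power(m.inverse(), 2))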