mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add matrix power (#11421)
Summary: vishwakftw Your patch needed some updates because the default native function dispatches changed from `[function, method]` to `[function]`. The CI was run before that change happened so it still shows green, but the internal test caught it. I did some changes when rebasing and updating so I didn't just force push to your branch. Let's see if this passes CI and internal test. If it does, let me know if you want me to force push to your branch or use this PR instead. Note to reviewers: patch was already approved at #10068 . cc yf225 Pull Request resolved: https://github.com/pytorch/pytorch/pull/11421 Differential Revision: D9733407 Pulled By: SsnL fbshipit-source-id: cf2ed293bb9942dcc5158934ff4def2f63252599
This commit is contained in:
parent
802380ac93
commit
d3f98b5ffc
|
|
@ -347,5 +347,44 @@ Tensor& matmul_out(Tensor &result, const Tensor & tensor1, const Tensor & tensor
|
|||
return result;
|
||||
}
|
||||
|
||||
Tensor matrix_power(const Tensor& a, int64_t n) {
|
||||
AT_CHECK(a.dim() >= 2 && at::isFloatingType(a.type().scalarType()),
|
||||
"matrix_power(", a.type(), "{", a.sizes(), "}): expected a tensor "
|
||||
"of floating types with dim at least 2");
|
||||
if (n == 0) {
|
||||
return a.clone().copy_(at::eye(a.size(-2), a.options()).expand_as(a));
|
||||
} else if (n < 0) {
|
||||
AT_CHECK(a.dim() == 2, "Negative powers for batch matrices are currently not supported");
|
||||
Tensor a_ = at::inverse(a);
|
||||
n *= -1;
|
||||
return at::native::matrix_power(a_, n);
|
||||
} else if (n == 1) {
|
||||
return a.clone();
|
||||
} else if (n == 2) {
|
||||
return at::native::matmul(a, a);
|
||||
} else if (n == 3) {
|
||||
return at::native::matmul(at::native::matmul(a, a), a);
|
||||
}
|
||||
|
||||
// This is a binary decomposition of n.
|
||||
// Moving from the least significant bit to the most significant bit
|
||||
// This is done to reduce the number of matrix multiplications
|
||||
// by raising the input matrix in powers of 2
|
||||
// The total number of matrix multiplications are
|
||||
// number of bits + number of bits that equal 1 ~ O(log n)
|
||||
// instead of O(n)
|
||||
Tensor result, z;
|
||||
int64_t r;
|
||||
while (n > 0) {
|
||||
z = (!z.defined()) ? a.clone() : at::native::matmul(z, z);
|
||||
r = n % 2;
|
||||
n = n / 2;
|
||||
if (r == 1) {
|
||||
result = (!result.defined()) ? z.clone() : at::native::matmul(result, z);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
|
|
|||
|
|
@ -993,7 +993,6 @@
|
|||
|
||||
- func: logsumexp_out(Tensor result, Tensor self, int64_t dim, bool keepdim=False) -> Tensor
|
||||
|
||||
|
||||
- func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, double margin=0.0, int64_t reduction=Reduction::ElementwiseMean) -> Tensor
|
||||
|
||||
- func: matmul(Tensor self, Tensor other) -> Tensor
|
||||
|
|
@ -1005,6 +1004,9 @@
|
|||
|
||||
- func: matrix_rank(Tensor self, bool symmetric=false) -> Tensor
|
||||
|
||||
- func: matrix_power(Tensor self, int64_t n) -> Tensor
|
||||
variants: function, method
|
||||
|
||||
- func: max(Tensor self, int64_t dim, bool keepdim=false) -> (Tensor, Tensor)
|
||||
variants: function, method
|
||||
|
||||
|
|
|
|||
|
|
@ -285,6 +285,7 @@ view of a storage and defines numeric operations on it.
|
|||
.. automethod:: masked_fill_
|
||||
.. automethod:: masked_select
|
||||
.. automethod:: matmul
|
||||
.. automethod:: matrix_power
|
||||
.. automethod:: max
|
||||
.. automethod:: mean
|
||||
.. automethod:: median
|
||||
|
|
|
|||
|
|
@ -169,7 +169,7 @@ Pointwise Ops
|
|||
.. autofunction:: cos
|
||||
.. autofunction:: cosh
|
||||
.. autofunction:: div
|
||||
.. autofunction:: digamma
|
||||
.. autofunction:: digamma
|
||||
.. autofunction:: erf
|
||||
.. autofunction:: erfc
|
||||
.. autofunction:: erfinv
|
||||
|
|
@ -296,6 +296,7 @@ BLAS and LAPACK Operations
|
|||
.. autofunction:: logdet
|
||||
.. autofunction:: slogdet
|
||||
.. autofunction:: matmul
|
||||
.. autofunction:: matrix_power
|
||||
.. autofunction:: matrix_rank
|
||||
.. autofunction:: mm
|
||||
.. autofunction:: mv
|
||||
|
|
|
|||
|
|
@ -2976,6 +2976,14 @@ method_tests = [
|
|||
('matmul', (S, S, M, M), ((S, S, M, S),), "4d_4d"),
|
||||
('matmul', (S, S, M, M), ((M,),), "4d_1d"),
|
||||
('matmul', (M,), ((S, S, M, S),), "1d_4d"),
|
||||
('matrix_power', (S, S), [2], "n=2"),
|
||||
('matrix_power', (S, S, S), [3], "n=3"),
|
||||
('matrix_power', (S, S, S), [1], "n=1"),
|
||||
('matrix_power', (S, S, S), [0], "n=0"),
|
||||
('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S), [-1], "n=-1",
|
||||
NO_ARGS, [skipIfNoLapack]),
|
||||
('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S), [-3], "n=-3",
|
||||
NO_ARGS, [skipIfNoLapack]),
|
||||
('addcmul', (S, S), ((S, S), (S, S))),
|
||||
('addcmul', (S, S), ((S, 1), (1, S)), 'broadcast_rhs'),
|
||||
('addcmul', (1,), ((S, S, 1), (1, S)), 'broadcast_all'),
|
||||
|
|
|
|||
|
|
@ -1460,6 +1460,10 @@ class TestCuda(TestCase):
|
|||
def test_matrix_rank(self):
|
||||
TestTorch._test_matrix_rank(self, lambda x: x.cuda())
|
||||
|
||||
@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
|
||||
def test_matrix_power(self):
|
||||
TestTorch._test_matrix_power(self, conv_fn=lambda t: t.cuda())
|
||||
|
||||
@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
|
||||
def test_det_logdet_slogdet(self):
|
||||
TestTorch._test_det_logdet_slogdet(self, lambda t: t.cuda())
|
||||
|
|
|
|||
|
|
@ -7094,6 +7094,8 @@ EXCLUDE_SCRIPT = {
|
|||
'test_split_dim_neg0',
|
||||
'test_gesv',
|
||||
'test_inverse',
|
||||
'test_matrix_power_n=-1', # involves inverse
|
||||
'test_matrix_power_n=-3', # involves inverse
|
||||
# skipped nn functional tests
|
||||
# ops involves sampling which could not test
|
||||
'test_nn_dropout',
|
||||
|
|
|
|||
|
|
@ -4595,6 +4595,49 @@ class TestTorch(TestCase):
|
|||
def test_pinverse(self):
|
||||
self._test_pinverse(self, conv_fn=lambda x: x)
|
||||
|
||||
@staticmethod
|
||||
def _test_matrix_power(self, conv_fn):
|
||||
def run_test(M, sign=1):
|
||||
if sign == -1:
|
||||
M = M.inverse()
|
||||
MP2 = torch.matrix_power(M, 2)
|
||||
self.assertEqual(MP2, torch.matmul(M, M))
|
||||
|
||||
MP3 = torch.matrix_power(M, 3)
|
||||
self.assertEqual(MP3, torch.matmul(MP2, M))
|
||||
|
||||
MP4 = torch.matrix_power(M, 4)
|
||||
self.assertEqual(MP4, torch.matmul(MP2, MP2))
|
||||
|
||||
MP6 = torch.matrix_power(M, 6)
|
||||
self.assertEqual(MP6, torch.matmul(MP3, MP3))
|
||||
|
||||
MP0 = torch.matrix_power(M, 0)
|
||||
self.assertEqual(MP0, torch.eye(M.size(-2)).expand_as(M))
|
||||
|
||||
# Single matrix
|
||||
M = conv_fn(torch.randn(5, 5))
|
||||
run_test(M)
|
||||
|
||||
# Batch matrices
|
||||
M = conv_fn(torch.randn(3, 3, 3))
|
||||
run_test(M)
|
||||
|
||||
# Many batch matrices
|
||||
M = conv_fn(torch.randn(2, 3, 3, 3))
|
||||
run_test(M)
|
||||
|
||||
# Single matrix, but full rank
|
||||
# This is for negative powers
|
||||
from test_autograd import random_fullrank_matrix_distinct_singular_value
|
||||
M = conv_fn(random_fullrank_matrix_distinct_singular_value(5))
|
||||
run_test(M)
|
||||
run_test(M, sign=-1)
|
||||
|
||||
@skipIfNoLapack
|
||||
def test_matrix_power(self):
|
||||
self._test_matrix_power(self, conv_fn=lambda x: x)
|
||||
|
||||
@staticmethod
|
||||
def _test_det_logdet_slogdet(self, conv_fn):
|
||||
def reference_det(M):
|
||||
|
|
|
|||
|
|
@ -1317,6 +1317,13 @@ masked_select(mask) -> Tensor
|
|||
See :func:`torch.masked_select`
|
||||
""")
|
||||
|
||||
add_docstr_all('matrix_power',
|
||||
r"""
|
||||
matrix_power(n) -> Tensor
|
||||
|
||||
See :func:`torch.matrix_power`
|
||||
""")
|
||||
|
||||
add_docstr_all('max',
|
||||
r"""
|
||||
max(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor)
|
||||
|
|
|
|||
|
|
@ -2385,6 +2385,38 @@ Example::
|
|||
tensor(9)
|
||||
""")
|
||||
|
||||
add_docstr(torch.matrix_power,
|
||||
r"""
|
||||
matrix_power(input, n) -> Tensor
|
||||
|
||||
Returns the matrix raised to the power :attr:`n` for square matrices.
|
||||
For batch of matrices, each individual matrix is raised to the power :attr:`n`.
|
||||
|
||||
If :attr:`n` is negative, then the inverse of the matrix (if invertible) is
|
||||
raised to the power :attr:`n`. If :attr:`n` is 0, then an identity matrix
|
||||
is returned.
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
n (int): the power to raise the matrix to
|
||||
|
||||
Example::
|
||||
|
||||
>>> a = torch.randn(2, 2, 2)
|
||||
>>> a
|
||||
tensor([[[-1.9975, -1.9610],
|
||||
[ 0.9592, -2.3364]],
|
||||
|
||||
[[-1.2534, -1.3429],
|
||||
[ 0.4153, -1.4664]]])
|
||||
>>> torch.matrix_power(a, 3)
|
||||
tensor([[[ 3.9392, -23.9916],
|
||||
[ 11.7357, -0.2070]],
|
||||
|
||||
[[ 0.2468, -6.7168],
|
||||
[ 2.0774, -0.8187]]])
|
||||
""")
|
||||
|
||||
add_docstr(torch.max,
|
||||
r"""
|
||||
.. function:: max(input) -> Tensor
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user