Added torch.linalg.matrix_power (#52608)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52608

**TODO**

- [x] Add OpInfo
- [x] Update documentation
- [x] Add more tests and compare against NumPy
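
For reference, a minimal usage sketch of the new operator (illustrative only, not part of the diff; it assumes the `n = 0` and negative-`n` behavior documented in the docstring added below):

```python
import numpy as np
import torch

a = torch.randn(3, 3, dtype=torch.float64)

torch.linalg.matrix_power(a, 0)   # identity matrix of the same shape as `a`
torch.linalg.matrix_power(a, 3)   # a @ a @ a
torch.linalg.matrix_power(a, -2)  # the inverse of `a`, squared (requires `a` to be invertible)

# The new tests compare against NumPy, so the two should agree:
assert torch.allclose(torch.linalg.matrix_power(a, 3),
                      torch.from_numpy(np.linalg.matrix_power(a.numpy(), 3)))
```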

Test Plan: Imported from OSS

Reviewed By: bdhirsh

Differential Revision: D27261532

Pulled By: heitorschueroff

fbshipit-source-id: c1e4ab297da3683f6d5751be8790602f9dc37b6b
Heitor Schueroff 2021-03-23 15:08:06 -07:00 committed by Facebook GitHub Bot
parent 345b26ca08
commit f9e7f132fb
10 changed files with 227 additions and 77 deletions

View File

@ -450,7 +450,6 @@ _(aten, masked_fill) \
_(aten, masked_scatter) \
_(aten, masked_select) \
_(aten, matmul) \
_(aten, matrix_power) \
_(aten, matrix_rank) \
_(aten, matrix_exp) \
_(aten, max) \

View File

@ -196,6 +196,8 @@ namespace c10 {
_(aten, clip_) \
_(aten, det) \
_(aten, linalg_det) \
_(aten, matrix_power) \
_(aten, linalg_matrix_power) \
_(aten, linalg_norm) \
_(aten, linalg_vector_norm) \
_(aten, append) \

View File

@ -201,6 +201,102 @@ Tensor pinverse(const Tensor& self, double rcond) {
return at::linalg_pinv(self, rcond, /*hermitian=*/false);
}
// matrix_power implementation
namespace {
/**
* @brief Raises the input matrix to the given power n
*
* If the exponent n is negative, the inverse of the input
* matrix will be raised to power abs(n).
*
* @param self (batched) square matrix to raise to power n
* @param n exponent to raise matrix (or matrices in batch) to
* @param _out optional tensor to write the output to
* @return Tensor input matrix raised to power n
*/
Tensor linalg_matrix_power_impl(
const Tensor& self,
int64_t n,
c10::optional<Tensor> _out) {
auto out = _out.value_or(Tensor());
squareCheckInputs(self);
if (_out.has_value()) {
checkSameDevice("matrix_power", out, self);
checkLinalgCompatibleDtype("matrix_power", out, self);
at::native::resize_output(out, self.sizes());
}
// For n=0 we return the identity matrix of the same shape as input.
if (n == 0) {
if (!_out.has_value()) {
// Clone input to include result in the autograd graph
out = self.clone(at::MemoryFormat::Contiguous);
}
return out.copy_(at::eye(self.size(-2), self.options()));
}
if (n == 1) {
return _out.has_value() ? out.copy_(self)
: self.clone(at::MemoryFormat::Contiguous);
}
if (n == -1) {
return _out.has_value() ? at::linalg_inv_out(out, self)
: at::linalg_inv(self);
}
// For negative n we invert the input matrix before raising it to the power abs(n)
auto a = n < 0 ? at::linalg_inv(self) : self;
n = std::abs(n);
// Fast paths for small powers
if (n == 2) {
return _out.has_value() ? at::matmul_out(out, a, a) : at::matmul(a, a);
}
if (n == 3) {
return _out.has_value() ? at::matmul_out(out, at::matmul(a, a), a)
: at::matmul(at::matmul(a, a), a);
}
// This is a binary decomposition of n, moving from the least significant
// bit to the most significant bit. It reduces the number of matrix
// multiplications by raising the input matrix to powers of 2.
// The total number of matrix multiplications is
// (number of bits) + (number of bits equal to 1) ~ O(log n)
// instead of O(n).
Tensor z, result;
while (n > 0) {
const auto bit = n % 2;
n = n / 2;
z = z.defined() ? at::matmul(z, z) : a;
if (bit == 1) {
if (_out.has_value() && n <= 0) {
// Last multiplication can use the out version
return result.defined() ? at::matmul_out(out, result, z) : out.copy_(z);
}
result = result.defined() ? at::matmul(result, z) : z;
}
}
return result;
}
} // namespace
Tensor& linalg_matrix_power_out(const Tensor& self, int64_t n, Tensor& result) {
linalg_matrix_power_impl(self, n, result);
return result;
}
Tensor linalg_matrix_power(const Tensor& self, int64_t n) {
return linalg_matrix_power_impl(self, n, c10::nullopt);
}
Tensor matrix_power(const Tensor& self, int64_t n) {
return at::native::linalg_matrix_power(self, n);
}
Tensor& linalg_matrix_rank_out(Tensor& result, const Tensor& self, optional<double> tol, bool hermitian) {
checkSameDevice("linalg_matrix_rank", result, self);
ScalarType output_type = ScalarType::Long;
@ -1535,44 +1631,6 @@ Tensor matrix_exp_backward(const Tensor& self, const Tensor& grad) {
);
}
Tensor matrix_power(const Tensor& a, int64_t n) {
TORCH_CHECK(a.dim() >= 2 && (at::isFloatingType(a.scalar_type()) || at::isComplexType(a.scalar_type())),
"matrix_power(", a.scalar_type(), "{", a.sizes(), "}): expected a tensor "
"of floating types with dim at least 2");
if (n == 0) {
return a.clone(at::MemoryFormat::Contiguous).copy_(at::eye(a.size(-2), a.options()).expand_as(a));
} else if (n < 0) {
Tensor a_ = at::inverse(a);
n *= -1;
return at::native::matrix_power(a_, n);
} else if (n == 1) {
return a.clone(at::MemoryFormat::Contiguous);
} else if (n == 2) {
return at::native::matmul(a, a);
} else if (n == 3) {
return at::native::matmul(at::native::matmul(a, a), a);
}
// This is a binary decomposition of n.
// Moving from the least significant bit to the most significant bit
// This is done to reduce the number of matrix multiplications
// by raising the input matrix in powers of 2
// The total number of matrix multiplications are
// number of bits + number of bits that equal 1 ~ O(log n)
// instead of O(n)
Tensor result, z;
int64_t r;
while (n > 0) {
z = (!z.defined()) ? a.clone(at::MemoryFormat::Contiguous) : at::native::matmul(z, z);
r = n % 2;
n = n / 2;
if (r == 1) {
result = (!result.defined()) ? z.clone(at::MemoryFormat::Contiguous) : at::native::matmul(result, z);
}
}
return result;
}
Tensor frobenius_norm(const Tensor& self) {
return at::norm(self);
}
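
The comment in `linalg_matrix_power_impl` above describes exponentiation by squaring over the bits of `n`. The following Python sketch mirrors that loop for positive `n`; `matrix_power_by_squaring` is a hypothetical name used only for illustration, not part of the PR:

```python
import torch

def matrix_power_by_squaring(a: torch.Tensor, n: int) -> torch.Tensor:
    """Reference sketch of the binary-decomposition loop for n > 0."""
    assert n > 0
    z = None       # running power of two: a, a^2, a^4, ...
    result = None  # product of the powers selected by the set bits of n
    while n > 0:
        bit = n % 2
        n //= 2
        z = a if z is None else torch.matmul(z, z)  # squared on every pass after the first
        if bit:
            result = z if result is None else torch.matmul(result, z)
    return result

a = torch.randn(4, 4, dtype=torch.float64)
assert torch.allclose(matrix_power_by_squaring(a, 11), torch.linalg.matrix_power(a, 11))
```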

View File

@ -9058,6 +9058,12 @@
CPU: _linalg_qr_helper_cpu
CUDA: _linalg_qr_helper_cuda
- func: linalg_matrix_power(Tensor self, int n) -> Tensor
python_module: linalg
- func: linalg_matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!)
python_module: linalg
- func: linalg_matrix_rank(Tensor self, float? tol=None, bool hermitian=False) -> Tensor
python_module: linalg
variants: function

View File

@ -22,6 +22,7 @@ Functions
.. autofunction:: slogdet
.. autofunction:: eigh
.. autofunction:: eigvalsh
.. autofunction:: matrix_power
.. autofunction:: matrix_rank
.. autofunction:: norm
.. autofunction:: vector_norm

View File

@ -5784,50 +5784,45 @@ else:
self.assertEqual(torch.eye(matsize, dtype=dtype, device=device).expand(sizes), M.pinverse().matmul(M),
atol=1e-7, rtol=0, msg='pseudo-inverse for invertible matrix')
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.double)
def test_matrix_power(self, device, dtype):
    def run_test(M, sign=1):
        if sign == -1:
            M = M.inverse()
        MP2 = torch.matrix_power(M, 2)
        self.assertEqual(MP2, torch.matmul(M, M))

        MP3 = torch.matrix_power(M, 3)
        self.assertEqual(MP3, torch.matmul(MP2, M))

        MP4 = torch.matrix_power(M, 4)
        self.assertEqual(MP4, torch.matmul(MP2, MP2))

        MP6 = torch.matrix_power(M, 6)
        self.assertEqual(MP6, torch.matmul(MP3, MP3))

        MP0 = torch.matrix_power(M, 0)
        self.assertEqual(MP0, torch.eye(M.size(-2), dtype=dtype).expand_as(M))

    # Single matrix
    M = torch.randn(5, 5, dtype=dtype, device=device)
    run_test(M)

    # Batch matrices
    M = torch.randn(3, 3, 3, dtype=dtype, device=device)
    run_test(M)

    # Many batch matrices
    M = torch.randn(2, 3, 3, 3, dtype=dtype, device=device)
    run_test(M)

    # This is for negative powers
    M = random_fullrank_matrix_distinct_singular_value(5, dtype=dtype, device=device)
    run_test(M, sign=-1)

    M = random_fullrank_matrix_distinct_singular_value(3, 3, dtype=dtype, device=device)
    run_test(M, sign=-1)

    M = random_fullrank_matrix_distinct_singular_value(3, 2, 3, dtype=dtype, device=device)
    run_test(M, sign=-1)

@skipCUDAIfNoMagmaAndNoCusolver
@dtypes(torch.double, torch.cdouble)
def test_matrix_power_non_negative(self, device, dtype):
    def check(*size, discontiguous=False):
        t = make_tensor(size, device, dtype, discontiguous=discontiguous)
        for n in range(8):
            res = torch.linalg.matrix_power(t, n)
            ref = np.linalg.matrix_power(t.cpu().numpy(), n)
            self.assertEqual(res.cpu(), torch.from_numpy(ref))

    check(0, 0)
    check(1, 1)
    check(5, 5)
    check(5, 5, discontiguous=True)
    check(0, 3, 3)
    check(2, 3, 3)
    check(2, 3, 4, 4, discontiguous=True)

@skipCUDAIfRocm
@skipCPUIfNoLapack
@skipCUDAIfNoMagmaAndNoCusolver
@dtypes(torch.double, torch.cdouble)
def test_matrix_power_negative(self, device, dtype):
    from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value

    def check(*size):
        t = random_fullrank_matrix_distinct_singular_value(*size, dtype=dtype, device=device)
        for n in range(-7, 0):
            res = torch.linalg.matrix_power(t, n)
            ref = np.linalg.matrix_power(t.cpu().numpy(), n)
            self.assertEqual(res.cpu(), torch.from_numpy(ref))

    check(0)
    check(5)
    check(0, 2)
    check(3, 0)
    check(3, 2)
    check(5, 2, 3)
@skipCUDAIfNoMagma
@skipCPUIfNoLapack

View File

@ -80,6 +80,14 @@ inline Tensor& vector_norm_out(Tensor& result, const Tensor& self, optional<Scal
return torch::linalg_vector_norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype);
}
inline Tensor matrix_power(const Tensor& self, int64_t n) {
return torch::linalg_matrix_power(self, n);
}
inline Tensor& matrix_power_out(const Tensor& self, int64_t n, Tensor& result) {
return torch::linalg_matrix_power_out(result, self, n);
}
inline Tensor matrix_rank(const Tensor input, optional<double> tol, bool hermitian) {
return torch::linalg_matrix_rank(input, tol, hermitian);
}
@ -219,6 +227,15 @@ inline Tensor& linalg_norm_out(Tensor& result, const Tensor& self, std::string o
return detail::norm_out(result, self, ord, opt_dim, keepdim, opt_dtype);
}
/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.matrix_power
inline Tensor matrix_power(const Tensor& self, int64_t n) {
return torch::linalg_matrix_power(self, n);
}
inline Tensor& matrix_power_out(const Tensor& self, int64_t n, Tensor& result) {
return torch::linalg_matrix_power_out(result, self, n);
}
/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.matrix_rank
inline Tensor matrix_rank(const Tensor input, optional<double> tol, bool hermitian) {
return detail::matrix_rank(input, tol, hermitian);

View File

@ -519,6 +519,49 @@ Example::
tensor(5.7220e-06)
""")
matrix_power = _add_docstr(_linalg.linalg_matrix_power, r"""
matrix_power(input, n, *, out=None) -> Tensor
Raises the square matrix :attr:`input`, or each square matrix in a batched
:attr:`input`, to the integer power :attr:`n`.
If :attr:`n` is 0, the identity matrix (or batch of identity matrices) of the same shape
as :attr:`input` is returned. If :attr:`n` is negative, the inverse of each matrix
(if invertible) is computed and then raised to the integer power ``abs(n)``.
Args:
input (Tensor): the input matrix of size `(n, n)` or the batch of matrices of size
`(*, n, n)` where `*` is one or more batch dimensions.
n (int): the exponent to raise the :attr:`input` matrix to
Keyword args:
out (Tensor, optional): tensor to write the output to.
Example:
>>> a = torch.randn(3, 3)
>>> a
tensor([[-0.2270, 0.6663, -1.3515],
[-0.9838, -0.4002, -1.9313],
[-0.7886, -0.0450, 0.0528]])
>>> torch.linalg.matrix_power(a, 0)
tensor([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])
>>> torch.linalg.matrix_power(a, 3)
tensor([[ 1.0756, 0.4980, 0.0100],
[-1.6617, 1.4994, -1.9980],
[-0.4509, 0.2731, 0.8001]])
>>> torch.linalg.matrix_power(a.expand(2, -1, -1), -2)
tensor([[[ 0.2640, 0.4571, -0.5511],
[-1.0163, 0.3491, -1.5292],
[-0.4899, 0.0822, 0.2773]],
[[ 0.2640, 0.4571, -0.5511],
[-1.0163, 0.3491, -1.5292],
[-0.4899, 0.0822, 0.2773]]])
""")
matrix_rank = _add_docstr(_linalg.linalg_matrix_rank, r"""
matrix_rank(input, tol=None, hermitian=False, *, out=None) -> Tensor

View File

@ -531,6 +531,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.masked_select: lambda input, mask, out=None: -1,
torch.matmul: lambda input, other, out=None: -1,
torch.matrix_power: lambda input, n: -1,
torch.linalg.matrix_power: lambda input, n, out=None: -1,
torch.matrix_rank: lambda input, tol=None, symmetric=False: -1,
torch.linalg.matrix_rank: lambda input, tol=None, hermitian=False: -1,
torch.matrix_exp: lambda input: -1,

View File

@ -374,6 +374,27 @@ def sample_inputs_tensor_split(op_info, device, dtype, requires_grad):
args=(torch.tensor([1, 2, 3]),),
kwargs=dict(dim=1)),)
def sample_inputs_linalg_matrix_power(op_info, device, dtype, requires_grad):
# (<matrix_size>, (<batch_sizes, ...>))
test_sizes = [
(1, ()),
(2, (0,)),
(2, (2,)),
]
inputs = []
for matrix_size, batch_sizes in test_sizes:
size = batch_sizes + (matrix_size, matrix_size)
for n in (0, 3, 5):
t = make_tensor(size, device, dtype, requires_grad=requires_grad)
inputs.append(SampleInput(t, args=(n,)))
for n in [-4, -2, -1]:
t = random_fullrank_matrix_distinct_singular_value(matrix_size, *batch_sizes, device=device, dtype=dtype)
t.requires_grad = requires_grad
inputs.append(SampleInput(t, args=(n,)))
return inputs
def sample_inputs_linalg_norm(op_info, device, dtype, requires_grad):
test_sizes = [
(S,),
@ -2662,6 +2683,13 @@ op_db: List[OpInfo] = [
SkipInfo('TestGradients', 'test_fn_grad'),
SkipInfo('TestCommon', 'test_variant_consistency_jit'),
)),
OpInfo('linalg.matrix_power',
aten_name='linalg_matrix_power',
dtypes=floating_and_complex_types(),
supports_inplace_autograd=False,
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, skipCUDAIfRocm],
sample_inputs_func=sample_inputs_linalg_matrix_power,
),
OpInfo('linalg.norm',
op=torch.linalg.norm,
dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),