Added torch.linalg.matrix_power (#52608)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52608

**TODO**

- [x] Add OpInfo
- [x] Update documentation
- [x] Add more tests and compare against NumPy

Test Plan: Imported from OSS

Reviewed By: bdhirsh

Differential Revision: D27261532

Pulled By: heitorschueroff

fbshipit-source-id: c1e4ab297da3683f6d5751be8790602f9dc37b6b
This commit is contained in:
parent 345b26ca08
commit f9e7f132fb
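The new function is intended to match numpy.linalg.matrix_power semantics (the tests in this commit compare against it): n = 0 yields the identity, positive n is repeated matrix multiplication, and negative n raises the inverse to abs(n). A minimal usage sketch written against that described behaviour, not code taken from this commit:

    import torch

    a = torch.randn(3, 3, dtype=torch.double)  # almost surely invertible

    # n = 0: identity matrix of the same shape as the input
    assert torch.equal(torch.linalg.matrix_power(a, 0), torch.eye(3, dtype=torch.double))

    # positive n: repeated matrix multiplication
    assert torch.allclose(torch.linalg.matrix_power(a, 3), a @ a @ a)

    # negative n: the inverse raised to abs(n)
    assert torch.allclose(torch.linalg.matrix_power(a, -2),
                          torch.linalg.inv(a) @ torch.linalg.inv(a))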
@@ -450,7 +450,6 @@ _(aten, masked_fill) \
_(aten, masked_scatter) \
_(aten, masked_select) \
_(aten, matmul) \
_(aten, matrix_power) \
_(aten, matrix_rank) \
_(aten, matrix_exp) \
_(aten, max) \
@@ -196,6 +196,8 @@ namespace c10 {
_(aten, clip_) \
_(aten, det) \
_(aten, linalg_det) \
_(aten, matrix_power) \
_(aten, linalg_matrix_power) \
_(aten, linalg_norm) \
_(aten, linalg_vector_norm) \
_(aten, append) \
@@ -201,6 +201,102 @@ Tensor pinverse(const Tensor& self, double rcond) {
  return at::linalg_pinv(self, rcond, /*hermitian=*/false);
}

// matrix_power implementation
namespace {

/**
 * @brief Raises the input matrix to the given power n
 *
 * If the exponent n is negative, the inverse of the input
 * matrix will be raised to power abs(n).
 *
 * @param self (batched) square matrix to raise to power n
 * @param n exponent to raise matrix (or matrices in batch) to
 * @param _out optional tensor to write the output to
 * @return Tensor input matrix raised to power n
 */
Tensor linalg_matrix_power_impl(
    const Tensor& self,
    int64_t n,
    c10::optional<Tensor> _out) {
  auto out = _out.value_or(Tensor());

  squareCheckInputs(self);
  if (_out.has_value()) {
    checkSameDevice("matrix_power", out, self);
    checkLinalgCompatibleDtype("matrix_power", out, self);
    at::native::resize_output(out, self.sizes());
  }

  // For n=0 we return the identity matrix of the same shape as input.
  if (n == 0) {
    if (!_out.has_value()) {
      // Clone input to include result in the autograd graph
      out = self.clone(at::MemoryFormat::Contiguous);
    }
    return out.copy_(at::eye(self.size(-2), self.options()));
  }
  if (n == 1) {
    return _out.has_value() ? out.copy_(self)
                            : self.clone(at::MemoryFormat::Contiguous);
  }
  if (n == -1) {
    return _out.has_value() ? at::linalg_inv_out(out, self)
                            : at::linalg_inv(self);
  }

  // For negative n we invert the input matrix before raising it to power abs(n)
  auto a = n < 0 ? at::linalg_inv(self) : self;
  n = std::abs(n);

  // Fast paths for small powers
  if (n == 2) {
    return _out.has_value() ? at::matmul_out(out, a, a) : at::matmul(a, a);
  }
  if (n == 3) {
    return _out.has_value() ? at::matmul_out(out, at::matmul(a, a), a)
                            : at::matmul(at::matmul(a, a), a);
  }

  // This is a binary decomposition of n,
  // moving from the least significant bit to the most significant bit.
  // This is done to reduce the number of matrix multiplications
  // by raising the input matrix in powers of 2.
  // The total number of matrix multiplications is
  // number of bits + number of bits that equal 1 ~ O(log n)
  // instead of O(n)
  Tensor z, result;
  while (n > 0) {
    const auto bit = n % 2;
    n = n / 2;
    z = z.defined() ? at::matmul(z, z) : a;
    if (bit == 1) {
      if (_out.has_value() && n <= 0) {
        // Last multiplication can use the out version
        return result.defined() ? at::matmul_out(out, result, z) : out.copy_(z);
      }
      result = result.defined() ? at::matmul(result, z) : z;
    }
  }

  return result;
}

} // namespace

Tensor& linalg_matrix_power_out(const Tensor& self, int64_t n, Tensor& result) {
  linalg_matrix_power_impl(self, n, result);
  return result;
}

Tensor linalg_matrix_power(const Tensor& self, int64_t n) {
  return linalg_matrix_power_impl(self, n, c10::nullopt);
}

Tensor matrix_power(const Tensor& self, int64_t n) {
  return at::native::linalg_matrix_power(self, n);
}

Tensor& linalg_matrix_rank_out(Tensor& result, const Tensor& self, optional<double> tol, bool hermitian) {
  checkSameDevice("linalg_matrix_rank", result, self);
  ScalarType output_type = ScalarType::Long;
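The binary-decomposition loop above is ordinary exponentiation by squaring. A minimal Python sketch of the same idea, using a made-up helper name for illustration (not part of this change):

    import torch

    def matpow_by_squaring(a: torch.Tensor, n: int) -> torch.Tensor:
        # Assumes a is a (batched) square matrix and n >= 1.
        z = None       # holds a, a^2, a^4, ... on successive iterations
        result = None  # product of the z's whose corresponding bit of n is set
        while n > 0:
            bit = n % 2
            n //= 2
            z = a if z is None else z @ z
            if bit == 1:
                result = z if result is None else result @ z
        return result

    a = torch.randn(4, 4, dtype=torch.double)
    assert torch.allclose(matpow_by_squaring(a, 6), a @ a @ a @ a @ a @ a)

Each of the roughly log2(n) iterations squares z once and folds it into the running result when the corresponding bit of n is 1, so the total work is O(log n) matrix products instead of O(n).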
@@ -1535,44 +1631,6 @@ Tensor matrix_exp_backward(const Tensor& self, const Tensor& grad) {
  );
}

Tensor matrix_power(const Tensor& a, int64_t n) {
  TORCH_CHECK(a.dim() >= 2 && (at::isFloatingType(a.scalar_type()) || at::isComplexType(a.scalar_type())),
              "matrix_power(", a.scalar_type(), "{", a.sizes(), "}): expected a tensor "
              "of floating types with dim at least 2");
  if (n == 0) {
    return a.clone(at::MemoryFormat::Contiguous).copy_(at::eye(a.size(-2), a.options()).expand_as(a));
  } else if (n < 0) {
    Tensor a_ = at::inverse(a);
    n *= -1;
    return at::native::matrix_power(a_, n);
  } else if (n == 1) {
    return a.clone(at::MemoryFormat::Contiguous);
  } else if (n == 2) {
    return at::native::matmul(a, a);
  } else if (n == 3) {
    return at::native::matmul(at::native::matmul(a, a), a);
  }

  // This is a binary decomposition of n.
  // Moving from the least significant bit to the most significant bit
  // This is done to reduce the number of matrix multiplications
  // by raising the input matrix in powers of 2
  // The total number of matrix multiplications are
  // number of bits + number of bits that equal 1 ~ O(log n)
  // instead of O(n)
  Tensor result, z;
  int64_t r;
  while (n > 0) {
    z = (!z.defined()) ? a.clone(at::MemoryFormat::Contiguous) : at::native::matmul(z, z);
    r = n % 2;
    n = n / 2;
    if (r == 1) {
      result = (!result.defined()) ? z.clone(at::MemoryFormat::Contiguous) : at::native::matmul(result, z);
    }
  }
  return result;
}

Tensor frobenius_norm(const Tensor& self) {
  return at::norm(self);
}
@@ -9058,6 +9058,12 @@
    CPU: _linalg_qr_helper_cpu
    CUDA: _linalg_qr_helper_cuda

- func: linalg_matrix_power(Tensor self, int n) -> Tensor
  python_module: linalg

- func: linalg_matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!)
  python_module: linalg

- func: linalg_matrix_rank(Tensor self, float? tol=None, bool hermitian=False) -> Tensor
  python_module: linalg
  variants: function
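The two schemas declared here surface in Python as the functional call and the out= variant. A hedged usage sketch (dtype and shapes chosen arbitrarily; the out tensor must live on the same device and have a compatible dtype):

    import torch

    a = torch.randn(3, 3, dtype=torch.double)

    # linalg_matrix_power(Tensor self, int n) -> Tensor
    b = torch.linalg.matrix_power(a, 2)

    # linalg_matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!)
    out = torch.empty_like(a)
    torch.linalg.matrix_power(a, 2, out=out)
    assert torch.equal(b, out)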
@@ -22,6 +22,7 @@ Functions
.. autofunction:: slogdet
.. autofunction:: eigh
.. autofunction:: eigvalsh
.. autofunction:: matrix_power
.. autofunction:: matrix_rank
.. autofunction:: norm
.. autofunction:: vector_norm
@@ -5784,50 +5784,45 @@ else:
            self.assertEqual(torch.eye(matsize, dtype=dtype, device=device).expand(sizes), M.pinverse().matmul(M),
                             atol=1e-7, rtol=0, msg='pseudo-inverse for invertible matrix')

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.double)
    def test_matrix_power(self, device, dtype):
        def run_test(M, sign=1):
            if sign == -1:
                M = M.inverse()
            MP2 = torch.matrix_power(M, 2)
            self.assertEqual(MP2, torch.matmul(M, M))
    @skipCUDAIfNoMagmaAndNoCusolver
    @dtypes(torch.double, torch.cdouble)
    def test_matrix_power_non_negative(self, device, dtype):
        def check(*size, discontiguous=False):
            t = make_tensor(size, device, dtype, discontiguous=discontiguous)
            for n in range(8):
                res = torch.linalg.matrix_power(t, n)
                ref = np.linalg.matrix_power(t.cpu().numpy(), n)
                self.assertEqual(res.cpu(), torch.from_numpy(ref))

            MP3 = torch.matrix_power(M, 3)
            self.assertEqual(MP3, torch.matmul(MP2, M))
        check(0, 0)
        check(1, 1)
        check(5, 5)
        check(5, 5, discontiguous=True)
        check(0, 3, 3)
        check(2, 3, 3)
        check(2, 3, 4, 4, discontiguous=True)

            MP4 = torch.matrix_power(M, 4)
            self.assertEqual(MP4, torch.matmul(MP2, MP2))

            MP6 = torch.matrix_power(M, 6)
            self.assertEqual(MP6, torch.matmul(MP3, MP3))

            MP0 = torch.matrix_power(M, 0)
            self.assertEqual(MP0, torch.eye(M.size(-2), dtype=dtype).expand_as(M))

        # Single matrix
        M = torch.randn(5, 5, dtype=dtype, device=device)
        run_test(M)

        # Batch matrices
        M = torch.randn(3, 3, 3, dtype=dtype, device=device)
        run_test(M)

        # Many batch matrices
        M = torch.randn(2, 3, 3, 3, dtype=dtype, device=device)
        run_test(M)

        # This is for negative powers
    @skipCUDAIfRocm
    @skipCPUIfNoLapack
    @skipCUDAIfNoMagmaAndNoCusolver
    @dtypes(torch.double, torch.cdouble)
    def test_matrix_power_negative(self, device, dtype):
        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
        M = random_fullrank_matrix_distinct_singular_value(5, dtype=dtype, device=device)
        run_test(M, sign=-1)

        M = random_fullrank_matrix_distinct_singular_value(3, 3, dtype=dtype, device=device)
        run_test(M, sign=-1)
        def check(*size):
            t = random_fullrank_matrix_distinct_singular_value(*size, dtype=dtype, device=device)
            for n in range(-7, 0):
                res = torch.linalg.matrix_power(t, n)
                ref = np.linalg.matrix_power(t.cpu().numpy(), n)
                self.assertEqual(res.cpu(), torch.from_numpy(ref))

        M = random_fullrank_matrix_distinct_singular_value(3, 2, 3, dtype=dtype, device=device)
        run_test(M, sign=-1)
        check(0)
        check(5)
        check(0, 2)
        check(3, 0)
        check(3, 2)
        check(5, 2, 3)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
@@ -80,6 +80,14 @@ inline Tensor& vector_norm_out(Tensor& result, const Tensor& self, optional<Scal
  return torch::linalg_vector_norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype);
}

inline Tensor matrix_power(const Tensor& self, int64_t n) {
  return torch::linalg_matrix_power(self, n);
}

inline Tensor& matrix_power_out(const Tensor& self, int64_t n, Tensor& result) {
  return torch::linalg_matrix_power_out(result, self, n);
}

inline Tensor matrix_rank(const Tensor input, optional<double> tol, bool hermitian) {
  return torch::linalg_matrix_rank(input, tol, hermitian);
}
@@ -219,6 +227,15 @@ inline Tensor& linalg_norm_out(Tensor& result, const Tensor& self, std::string o
  return detail::norm_out(result, self, ord, opt_dim, keepdim, opt_dtype);
}

/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.matrix_power
inline Tensor matrix_power(const Tensor& self, int64_t n) {
  return torch::linalg_matrix_power(self, n);
}

inline Tensor& matrix_power_out(const Tensor& self, int64_t n, Tensor& result) {
  return torch::linalg_matrix_power_out(result, self, n);
}

/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.matrix_rank
inline Tensor matrix_rank(const Tensor input, optional<double> tol, bool hermitian) {
  return detail::matrix_rank(input, tol, hermitian);
@@ -519,6 +519,49 @@ Example::
    tensor(5.7220e-06)
""")

matrix_power = _add_docstr(_linalg.linalg_matrix_power, r"""
matrix_power(input, n, *, out=None) -> Tensor

Raises the square matrix :attr:`input`, or each square matrix in a batched
:attr:`input`, to the integer power :attr:`n`.

If :attr:`n` is 0, the identity matrix (or batch of identity matrices) of the same shape
as :attr:`input` is returned. If :attr:`n` is negative, the inverse of each matrix
(if invertible) is computed and then raised to the integer power ``abs(n)``.

Args:
    input (Tensor): the input matrix of size `(n, n)` or the batch of matrices of size
                    `(*, n, n)` where `*` is one or more batch dimensions.
    n (int): the exponent to raise the :attr:`input` matrix to

Keyword args:
    out (Tensor, optional): tensor to write the output to.

Example:

    >>> a = torch.randn(3, 3)
    >>> a
    tensor([[-0.2270,  0.6663, -1.3515],
            [-0.9838, -0.4002, -1.9313],
            [-0.7886, -0.0450,  0.0528]])
    >>> torch.linalg.matrix_power(a, 0)
    tensor([[1., 0., 0.],
            [0., 1., 0.],
            [0., 0., 1.]])
    >>> torch.linalg.matrix_power(a, 3)
    tensor([[ 1.0756,  0.4980,  0.0100],
            [-1.6617,  1.4994, -1.9980],
            [-0.4509,  0.2731,  0.8001]])
    >>> torch.linalg.matrix_power(a.expand(2, -1, -1), -2)
    tensor([[[ 0.2640,  0.4571, -0.5511],
             [-1.0163,  0.3491, -1.5292],
             [-0.4899,  0.0822,  0.2773]],

            [[ 0.2640,  0.4571, -0.5511],
             [-1.0163,  0.3491, -1.5292],
             [-0.4899,  0.0822,  0.2773]]])
""")

matrix_rank = _add_docstr(_linalg.linalg_matrix_rank, r"""
matrix_rank(input, tol=None, hermitian=False, *, out=None) -> Tensor

@@ -531,6 +531,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.masked_select: lambda input, mask, out=None: -1,
        torch.matmul: lambda input, other, out=None: -1,
        torch.matrix_power: lambda input, n: -1,
        torch.linalg.matrix_power: lambda input, n, out=None: -1,
        torch.matrix_rank: lambda input, tol=None, symmetric=False: -1,
        torch.linalg.matrix_rank: lambda input, tol=None, hermitian=False: -1,
        torch.matrix_exp: lambda input: -1,
@@ -374,6 +374,27 @@ def sample_inputs_tensor_split(op_info, device, dtype, requires_grad):
                        args=(torch.tensor([1, 2, 3]),),
                        kwargs=dict(dim=1)),)

def sample_inputs_linalg_matrix_power(op_info, device, dtype, requires_grad):
    # (<matrix_size>, (<batch_sizes, ...>))
    test_sizes = [
        (1, ()),
        (2, (0,)),
        (2, (2,)),
    ]

    inputs = []
    for matrix_size, batch_sizes in test_sizes:
        size = batch_sizes + (matrix_size, matrix_size)
        for n in (0, 3, 5):
            t = make_tensor(size, device, dtype, requires_grad=requires_grad)
            inputs.append(SampleInput(t, args=(n,)))
        for n in [-4, -2, -1]:
            t = random_fullrank_matrix_distinct_singular_value(matrix_size, *batch_sizes, device=device, dtype=dtype)
            t.requires_grad = requires_grad
            inputs.append(SampleInput(t, args=(n,)))

    return inputs

def sample_inputs_linalg_norm(op_info, device, dtype, requires_grad):
    test_sizes = [
        (S,),
@@ -2662,6 +2683,13 @@ op_db: List[OpInfo] = [
               SkipInfo('TestGradients', 'test_fn_grad'),
               SkipInfo('TestCommon', 'test_variant_consistency_jit'),
           )),
    OpInfo('linalg.matrix_power',
           aten_name='linalg_matrix_power',
           dtypes=floating_and_complex_types(),
           supports_inplace_autograd=False,
           decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, skipCUDAIfRocm],
           sample_inputs_func=sample_inputs_linalg_matrix_power,
           ),
    OpInfo('linalg.norm',
           op=torch.linalg.norm,
           dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),