Implemented torch.cov (#58311)

Summary:
Based on https://github.com/pytorch/pytorch/pull/50466

Adds the initial implementation of `torch.cov`, modeled on `numpy.cov`. For simplicity, we omitted support for several `numpy.cov` parameters that are either redundant, such as `bias`, or have simple workarounds, such as `y` and `rowvar` (see the sketch below).
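
For reference, the omitted parameters have short equivalents in terms of `torch.cov` itself. A minimal sketch of the workarounds (the tensors `x`, `y`, and `data` here are illustrative, not part of this PR):

```python
import torch

x = torch.randn(3, 5)   # 3 variables, 5 observations each
y = torch.randn(2, 5)   # 2 extra variables over the same observations

# numpy.cov(x, y=y): stack the extra variables as additional rows
torch.cov(torch.cat((x, y), dim=0))

# numpy.cov(data, rowvar=False), where data is (observations, variables):
# transpose so that rows are variables again
data = x.T
torch.cov(data.T)

# numpy.cov(x, bias=True): normalize by N instead of N - 1
torch.cov(x, correction=0)
```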

cc PandaBoi

closes https://github.com/pytorch/pytorch/issues/19037

Pull Request resolved: https://github.com/pytorch/pytorch/pull/58311

Reviewed By: jbschlosser

Differential Revision: D29431651

Pulled By: heitorschueroff

fbshipit-source-id: 167dea880f534934b145ba94291a9d634c25b01b
Author: Heitor Schueroff (committed by Facebook GitHub Bot)
Date: 2021-06-29 13:59:46 -07:00
Commit: ec9c03c234, parent 8f658d537d
11 changed files, 263 insertions(+), 0 deletions(-)

@@ -260,6 +260,7 @@ _(aten, cosine_embedding_loss) \
 _(aten, cosine_similarity) \
 _(aten, count_nonzero) \
 _(aten, cross) \
+_(aten, cov) \
 _(aten, std_mean) \
 _(aten, var_mean) \
 _(aten, ctc_loss) \

@@ -0,0 +1,109 @@
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+
+namespace at {
+namespace native {
+
+Tensor cov(
+    const Tensor& self,
+    int64_t correction,
+    const c10::optional<Tensor>& fweights,
+    const c10::optional<Tensor>& aweights) {
+  constexpr int64_t OBSERVATIONS_DIM = 1;
+
+  TORCH_CHECK(
+      self.ndimension() <= 2,
+      "cov(): expected input to have two or fewer dimensions but got an input with ",
+      self.ndimension(),
+      " dimensions");
+
+  TORCH_CHECK(
+      self.scalar_type() != kBool,
+      "cov(): bool dtype is not supported for input");
+
+  // View input tensor as 2D (variables, observations)
+  auto in = self.ndimension() < 2 ? self.view({1, -1}) : self;
+  const auto num_observations = in.size(OBSERVATIONS_DIM);
+
+  // The product of frequencies (fweights) and weights (aweights).
+  Tensor w;
+
+  if (fweights.has_value()) {
+    w = fweights.value();
+    TORCH_CHECK(
+        w.ndimension() <= 1,
+        "cov(): expected fweights to have one or fewer dimensions but got fweights with ",
+        w.ndimension(),
+        " dimensions");
+    TORCH_CHECK(
+        at::isIntegralType(w.scalar_type(), false),
+        "cov(): expected fweights to have integral dtype but got fweights with ",
+        w.scalar_type(),
+        " dtype");
+    TORCH_CHECK(
+        w.numel() == num_observations,
+        "cov(): expected fweights to have the same numel as there are observations in the input but got ",
+        w.numel(),
+        " != ",
+        num_observations);
+    TORCH_CHECK(
+        num_observations == 0 || w.min().ge(0).item<bool>(),
+        "cov(): fweights cannot be negative");
+  }
+
+  if (aweights.has_value()) {
+    const auto& aw = aweights.value();
+    TORCH_CHECK(
+        aw.ndimension() <= 1,
+        "cov(): expected aweights to have one or fewer dimensions but got aweights with ",
+        aw.ndimension(),
+        " dimensions");
+    TORCH_CHECK(
+        at::isFloatingType(aw.scalar_type()),
+        "cov(): expected aweights to have floating point dtype but got aweights with ",
+        aw.scalar_type(),
+        " dtype");
+    TORCH_CHECK(
+        aw.numel() == num_observations,
+        "cov(): expected aweights to have the same numel as there are observations in the input but got ",
+        aw.numel(),
+        " != ",
+        num_observations);
+    TORCH_CHECK(
+        num_observations == 0 || aw.min().ge(0).item<bool>(),
+        "cov(): aweights cannot be negative");
+    w = w.defined() ? w * aw : aw;
+  }
+
+  // Compute a weighted average of the observations
+  const auto w_sum = w.defined()
+      ? w.sum()
+      : at::scalar_tensor(num_observations, in.options().dtype(kLong));
+  TORCH_CHECK(
+      !w.defined() || w_sum.ne(0).item<bool>(),
+      "cov(): weights sum to zero, can't be normalized");
+  const auto avg = (w.defined() ? in * w : in).sum(OBSERVATIONS_DIM) / w_sum;
+
+  // Compute the normalization factor
+  Tensor norm_factor;
+  if (w.defined() && aweights.has_value() && correction != 0) {
+    norm_factor = w_sum - correction * (w * aweights.value()).sum() / w_sum;
+  } else {
+    norm_factor = w_sum - correction;
+  }
+  if (norm_factor.le(0).item<bool>()) {
+    TORCH_WARN("cov(): degrees of freedom is <= 0");
+    norm_factor.zero_();
+  }
+
+  // Compute covariance matrix
+  in = in - avg.unsqueeze(1);
+  const auto c = at::mm(in, (w.defined() ? in * w : in).t().conj());
+  return at::true_divide(c, norm_factor).squeeze();
+}
+
+} // namespace native
+} // namespace at
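
For readers who don't want to parse ATen, the kernel above is a direct weighted-covariance computation. Below is a rough Python transliteration of the same steps (illustrative only, not part of the commit; the input validation and the weights-sum-to-zero check are omitted):

```python
import torch

def cov_sketch(x, correction=1, fweights=None, aweights=None):
    # View input as 2D (variables, observations), as the kernel does
    x = x.reshape(1, -1) if x.ndim < 2 else x
    n = x.shape[1]

    # w is the product of fweights and aweights, if given
    w = fweights
    if aweights is not None:
        w = aweights if w is None else w * aweights

    w_sum = torch.as_tensor(n) if w is None else w.sum()
    avg = (x * w if w is not None else x).sum(dim=1) / w_sum

    # Normalization factor, matching the aweights branch in the C++ code
    if w is not None and aweights is not None and correction != 0:
        norm = w_sum - correction * (w * aweights).sum() / w_sum
    else:
        norm = w_sum - correction
    norm = torch.clamp(norm, min=0)  # df <= 0 then yields inf/nan, as warned above

    d = x - avg.unsqueeze(1)
    c = d @ (d * w if w is not None else d).t().conj()
    return (c / norm).squeeze()
```

The one subtle branch is the normalization factor: when `aweights` are given, the correction term is scaled by `sum(w * aweights) / sum(w)`, which matches `numpy.cov`'s handling of `ddof` with analytic weights.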

@@ -1277,6 +1277,9 @@
   dispatch:
     CompositeExplicitAutograd: count_nonzero
 
+- func: cov(Tensor self, *, int correction=1, Tensor? fweights=None, Tensor? aweights=None) -> Tensor
+  variants: function, method
+
 - func: cudnn_affine_grid_generator(Tensor theta, int N, int C, int H, int W) -> Tensor grid
   dispatch:
     CUDA: cudnn_affine_grid_generator_forward
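
For context, everything after the `*` in the schema above is keyword-only, and `variants: function, method` exposes `cov` both as a free function and as a Tensor method. A hypothetical call of each form:

```python
import torch

x = torch.randn(3, 5)
torch.cov(x, correction=0)                         # function variant
x.cov(fweights=torch.ones(5, dtype=torch.long))    # method variant
```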

@@ -290,6 +290,7 @@ Tensor class reference
     Tensor.cosh
     Tensor.cosh_
     Tensor.count_nonzero
+    Tensor.cov
     Tensor.acosh
     Tensor.acosh_
     Tensor.arccosh

@@ -481,6 +481,7 @@ Other Operations
     cdist
     clone
     combinations
+    cov
     cross
     cummax
     cummin

@@ -4372,6 +4372,53 @@ else:
             x = torch.empty(50000000, device=device, dtype=dtype).exponential_()
             self.assertTrue(x.min() > 0)
 
+    @dtypes(torch.float, torch.cfloat)
+    def test_cov(self, device, dtype):
+        def check(t, correction=1, fweights=None, aweights=None):
+            actual = torch.cov(t, correction=correction, fweights=fweights, aweights=aweights)
+            t = t.cpu().numpy()
+            fweights = fweights.cpu().numpy() if fweights is not None else None
+            aweights = aweights.cpu().numpy() if aweights is not None else None
+            expected = np.cov(t, ddof=correction, fweights=fweights, aweights=aweights)
+            expected = torch.from_numpy(np.array(expected)).to(dtype=actual.dtype)
+            self.assertEqual(actual, expected, atol=1e-05, rtol=1e-05)
+
+        def generate_input_tensors():
+            yield make_tensor((0, 0), device, dtype)
+            yield make_tensor((1, 0), device, dtype)
+            yield make_tensor((0, 1), device, dtype)
+            yield make_tensor((2,), device, dtype)
+            yield make_tensor((2, 1), device, dtype)
+            yield make_tensor((2, 2), device, dtype)
+            yield make_tensor((2, 3), device, dtype)
+            yield make_tensor((5, 10), device, dtype)
+            yield make_tensor((5, 10), device, dtype, noncontiguous=True)
+            yield torch.tensor([0, -2, nan, 10.2, inf], dtype=dtype, device=device)
+
+        for t in generate_input_tensors():
+            check(t)
+            num_observations = t.numel() if t.ndim < 2 else t.size(1)
+            if num_observations > 0:
+                fweights = torch.randint(1, 10, (num_observations,), device=device)
+                aweights = make_tensor((num_observations,), device, torch.float, low=1)
+                for correction, fw, aw in product([0, 1, 2], [None, fweights], [None, aweights]):
+                    check(t, correction, fw, aw)
+
+    def test_cov_error(self, device):
+        def check(msg, *args, **kwargs):
+            with self.assertRaisesRegex(RuntimeError, r'cov\(\):.*' + msg + r'.*'):
+                torch.cov(*args, **kwargs)
+
+        a = torch.rand(2)
+        check(r'expected input to have two or fewer dimensions', torch.rand(2, 2, 2))
+        check(r'expected fweights to have one or fewer dimensions', a, fweights=torch.rand(2, 2))
+        check(r'expected aweights to have one or fewer dimensions', a, aweights=torch.rand(2, 2))
+        check(r'expected fweights to have integral dtype', a, fweights=torch.rand(2))
+        check(r'expected aweights to have floating point dtype', a, aweights=torch.tensor([1, 1]))
+        check(r'expected fweights to have the same numel', a, fweights=torch.tensor([1]))
+        check(r'expected aweights to have the same numel', a, aweights=torch.rand(1))
+        check(r'fweights cannot be negative', a, fweights=torch.tensor([-1, -2]))
+        check(r'aweights cannot be negative', a, aweights=torch.tensor([-1., -2.]))
+
     @skipIfNoSciPy
     @dtypes(*torch.testing.get_all_fp_dtypes())

@@ -974,6 +974,7 @@ aten_native_source_non_codegen_list = [
     "aten/src/ATen/native/ConvolutionMM3d.cpp",
     "aten/src/ATen/native/ConvolutionTBC.cpp",
     "aten/src/ATen/native/Copy.cpp",
+    "aten/src/ATen/native/Correlation.cpp",
     "aten/src/ATen/native/CPUFallback.cpp",
     "aten/src/ATen/native/Cross.cpp",
     "aten/src/ATen/native/DilatedMaxPool2d.cpp",

@@ -1004,6 +1004,12 @@ count_nonzero(dim=None) -> Tensor
 
 See :func:`torch.count_nonzero`
 """)
 
+add_docstr_all('cov', r"""
+cov(*, correction=1, fweights=None, aweights=None) -> Tensor
+
+See :func:`torch.cov`
+""")
+
 add_docstr_all('cross',
                r"""
 cross(other, dim=-1) -> Tensor

@@ -1688,6 +1688,75 @@ Example::
     False
 """)
 
+add_docstr(torch.cov, r"""
+cov(input, *, correction=1, fweights=None, aweights=None) -> Tensor
+
+Estimates the covariance matrix of the variables given by the :attr:`input` matrix,
+where rows are the variables and columns are the observations.
+
+A covariance matrix is a square matrix giving the covariance of each pair of variables.
+The diagonal contains the variance of each variable (covariance of a variable with itself).
+By definition, if :attr:`input` represents a single variable (Scalar or 1D) then its
+variance is returned.
+
+The unbiased sample covariance of the variables :math:`x` and :math:`y` is given by:
+
+.. math::
+    \text{cov}_w(x,y) = \frac{\sum^{N}_{i = 1}(x_{i} - \bar{x})(y_{i} - \bar{y})}{N~-~1}
+
+where :math:`\bar{x}` and :math:`\bar{y}` are the simple means of :math:`x` and :math:`y` respectively.
+
+If :attr:`fweights` and/or :attr:`aweights` are provided, the unbiased weighted covariance
+is calculated, which is given by:
+
+.. math::
+    \text{cov}_w(x,y) = \frac{\sum^{N}_{i = 1}w_i(x_{i} - \mu_x^*)(y_{i} - \mu_y^*)}{\sum^{N}_{i = 1}w_i~-~1}
+
+where :math:`w` denotes :attr:`fweights` or :attr:`aweights` based on whichever is provided, or
+:math:`w = \text{fweights} \times \text{aweights}` if both are provided, and
+:math:`\mu_x^* = \frac{\sum^{N}_{i = 1}w_ix_{i}}{\sum^{N}_{i = 1}w_i}` is the weighted mean of the variable.
+
+Args:
+    input (Tensor): A 2D matrix containing multiple variables and observations, or a
+        Scalar or 1D vector representing a single variable.
+
+Keyword Args:
+    correction (int, optional): difference between the sample size and sample degrees of freedom.
+        Defaults to Bessel's correction, ``correction = 1``, which returns the unbiased estimate,
+        even if both :attr:`fweights` and :attr:`aweights` are specified. ``correction = 0``
+        will return the simple average.
+    fweights (tensor, optional): A Scalar or 1D tensor of observation vector frequencies representing
+        the number of times each observation should be repeated. Its numel must equal the number of
+        columns of :attr:`input`. Must have integral dtype. Ignored if ``None``. Defaults to ``None``.
+    aweights (tensor, optional): A Scalar or 1D tensor of observation vector weights.
+        These relative weights are typically large for observations considered important and smaller
+        for observations considered less important. Its numel must equal the number of columns of
+        :attr:`input`. Must have floating point dtype. Ignored if ``None``. Defaults to ``None``.
+
+Returns:
+    (Tensor) The covariance matrix of the variables.
+
+Example::
+
+    >>> x = torch.tensor([[0, 2], [1, 1], [2, 0]]).T
+    >>> x
+    tensor([[0, 1, 2],
+            [2, 1, 0]])
+    >>> torch.cov(x)
+    tensor([[ 1., -1.],
+            [-1.,  1.]])
+    >>> torch.cov(x, correction=0)
+    tensor([[ 0.6667, -0.6667],
+            [-0.6667,  0.6667]])
+    >>> fw = torch.randint(1, 10, (3,))
+    >>> fw
+    tensor([1, 6, 9])
+    >>> aw = torch.rand(3)
+    >>> aw
+    tensor([0.4282, 0.0255, 0.4144])
+    >>> torch.cov(x, fweights=fw, aweights=aw)
+    tensor([[ 0.4169, -0.4169],
+            [-0.4169,  0.4169]])
+""")
+
 add_docstr(torch.cat,
            r"""
 cat(tensors, dim=0, *, out=None) -> Tensor
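
One equivalence implied by the docstring but easy to miss: integer ``fweights`` behave exactly like repeating each observation that many times. A quick illustrative check (not part of the commit):

```python
import torch

x = torch.tensor([[0., 1., 2.],
                  [2., 1., 0.]])
fw = torch.tensor([1, 6, 9])

# Repeating column i of x exactly fw[i] times gives the same covariance
repeated = x.repeat_interleave(fw, dim=1)
assert torch.allclose(torch.cov(x, fweights=fw), torch.cov(repeated))
```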

@@ -374,6 +374,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.clamp_min: lambda input, min, out=None: -1,
         torch.clamp_max: lambda input, max, out=None: -1,
         torch.column_stack: lambda tensors, out=None: -1,
+        torch.cov: lambda input, correction=1, fweights=None, aweights=None: -1,
         torch.clone: lambda input: -1,
         torch.combinations: lambda input, r=2, with_replacement=False: -1,
         torch.complex: lambda real, imag: -1,

@@ -3037,6 +3037,22 @@ def sample_inputs_std_var(op_info, device, dtype, requires_grad, **kwargs):
     ]
 
+
+def sample_inputs_cov(op_info, device, dtype, requires_grad, **kwargs):
+    shapes = [(2,), (1, 2), (3, 2), (2, 3)]
+
+    inputs = []
+    for shape in shapes:
+        t = make_tensor(shape, device, dtype, requires_grad=requires_grad)
+        inputs.append(SampleInput(t))
+        num_observations = t.numel() if t.ndimension() < 2 else t.size(1)
+        fweights = make_tensor((num_observations,), device, torch.int, low=0, high=10, requires_grad=requires_grad)
+        aweights = make_tensor((num_observations,), device, torch.float, low=0, high=1, requires_grad=requires_grad)
+        for correction, fw, aw in product(range(num_observations), [None, fweights], [None, aweights]):
+            inputs.append(SampleInput(t, kwargs={'correction': correction, 'fweights': fw, 'aweights': aw}))
+
+    return inputs
+
+
 def _sample_inputs_svd(op_info, device, dtype, requires_grad=False, is_linalg_svd=False):
     """
     This function generates input for torch.svd with distinct singular values so that autograd is always stable.
@@ -5285,6 +5301,14 @@ op_db: List[OpInfo] = [
                  SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard', device_type='cpu',
                           dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS),
              )),
+    OpInfo('cov',
+           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+           dtypesIfCUDA=all_types_and_complex_and(torch.half, *[torch.bfloat16] if CUDA11OrLater else []),
+           backward_dtypesIfCUDA=all_types_and_complex_and(torch.half, *[torch.bfloat16] if CUDA11OrLater else []),
+           sample_inputs_func=sample_inputs_cov,
+           supports_out=False,
+           # JIT test not working for tensor kwargs (https://github.com/pytorch/pytorch/issues/58507)
+           skips=(SkipInfo('TestJit', 'test_variant_consistency_jit'),)),
     OpInfo('cross',
            dtypes=all_types_and_complex(),
            dtypesIfCUDA=all_types_and_complex_and(torch.half),