Implemented torch.cov (#58311)
Summary: Based on https://github.com/pytorch/pytorch/pull/50466. Adds the initial implementation of `torch.cov`, similar to `numpy.cov`. For simplicity, we removed support for several `numpy.cov` parameters that are either redundant, such as `bias`, or have simple workarounds, such as `y` and `rowvar`.

cc PandaBoi

closes https://github.com/pytorch/pytorch/issues/19037

Pull Request resolved: https://github.com/pytorch/pytorch/pull/58311

Reviewed By: jbschlosser

Differential Revision: D29431651

Pulled By: heitorschueroff

fbshipit-source-id: 167dea880f534934b145ba94291a9d634c25b01b
This commit is contained in:
parent 8f658d537d
commit ec9c03c234
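The "simple workarounds" for the dropped `numpy.cov` parameters mentioned in the summary are worth spelling out. A minimal sketch (mine, not from the PR), assuming tensors `x` and `y` with one variable per row:

    import torch

    x = torch.randn(2, 5)  # 2 variables, 5 observations
    y = torch.randn(3, 5)  # 3 more variables over the same observations

    # numpy.cov(x, y=y): stack the extra variables as additional rows
    torch.cov(torch.cat((x, y), dim=0))

    # numpy.cov(x2, rowvar=False): transpose so rows are variables again
    x2 = x.T  # observations in rows, variables in columns
    torch.cov(x2.T)

    # numpy.cov(x, bias=True): use correction=0 instead
    torch.cov(x, correction=0)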
aten/src/ATen/core/aten_interned_strings.h
@@ -260,6 +260,7 @@ _(aten, cosine_embedding_loss) \
 _(aten, cosine_similarity) \
 _(aten, count_nonzero) \
 _(aten, cross) \
+_(aten, cov) \
 _(aten, std_mean) \
 _(aten, var_mean) \
 _(aten, ctc_loss) \
aten/src/ATen/native/Correlation.cpp (new file, 109 lines)
@@ -0,0 +1,109 @@

#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>

namespace at {
namespace native {

Tensor cov(
    const Tensor& self,
    int64_t correction,
    const c10::optional<Tensor>& fweights,
    const c10::optional<Tensor>& aweights) {
  constexpr int64_t OBSERVATIONS_DIM = 1;

  TORCH_CHECK(
      self.ndimension() <= 2,
      "cov(): expected input to have two or fewer dimensions but got an input with ",
      self.ndimension(),
      " dimensions");

  TORCH_CHECK(
      self.scalar_type() != kBool, "cov(): bool dtype is not supported for input");

  // View input tensor as 2D (variables, observations)
  auto in = self.ndimension() < 2 ? self.view({1, -1}) : self;
  const auto num_observations = in.size(OBSERVATIONS_DIM);

  // The product of frequencies (fweights) and weights (aweights).
  Tensor w;

  if (fweights.has_value()) {
    w = fweights.value();
    TORCH_CHECK(
        w.ndimension() <= 1,
        "cov(): expected fweights to have one or fewer dimensions but got fweights with ",
        w.ndimension(),
        " dimensions");
    TORCH_CHECK(
        at::isIntegralType(w.scalar_type(), false),
        "cov(): expected fweights to have integral dtype but got fweights with ",
        w.scalar_type(),
        " dtype");
    TORCH_CHECK(
        w.numel() == num_observations,
        "cov(): expected fweights to have the same numel as there are observations in the input but got ",
        w.numel(),
        " != ",
        num_observations);
    TORCH_CHECK(
        num_observations == 0 || w.min().ge(0).item<bool>(),
        "cov(): fweights cannot be negative");
  }

  if (aweights.has_value()) {
    const auto& aw = aweights.value();
    TORCH_CHECK(
        aw.ndimension() <= 1,
        "cov(): expected aweights to have one or fewer dimensions but got aweights with ",
        aw.ndimension(),
        " dimensions");
    TORCH_CHECK(
        at::isFloatingType(aw.scalar_type()),
        "cov(): expected aweights to have floating point dtype but got aweights with ",
        aw.scalar_type(),
        " dtype");
    TORCH_CHECK(
        aw.numel() == num_observations,
        "cov(): expected aweights to have the same numel as there are observations in the input but got ",
        aw.numel(),
        " != ",
        num_observations);
    TORCH_CHECK(
        num_observations == 0 || aw.min().ge(0).item<bool>(),
        "cov(): aweights cannot be negative");
    w = w.defined() ? w * aw : aw;
  }

  // Compute a weighted average of the observations
  const auto w_sum = w.defined()
      ? w.sum()
      : at::scalar_tensor(num_observations, in.options().dtype(kLong));

  TORCH_CHECK(
      !w.defined() || w_sum.ne(0).item<bool>(),
      "cov(): weights sum to zero, can't be normalized");

  const auto avg = (w.defined() ? in * w : in).sum(OBSERVATIONS_DIM) / w_sum;

  // Compute the normalization factor
  Tensor norm_factor;

  if (w.defined() && aweights.has_value() && correction != 0) {
    norm_factor = w_sum - correction * (w * aweights.value()).sum() / w_sum;
  } else {
    norm_factor = w_sum - correction;
  }

  if (norm_factor.le(0).item<bool>()) {
    TORCH_WARN("cov(): degrees of freedom is <= 0");
    norm_factor.zero_();
  }

  // Compute covariance matrix
  in = in - avg.unsqueeze(1);
  const auto c = at::mm(in, (w.defined() ? in * w : in).t().conj());
  return at::true_divide(c, norm_factor).squeeze();
}

} // namespace native
} // namespace at
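To make the kernel's flow easier to follow, here is a Python sketch of the same algorithm (my paraphrase, not part of the commit), assuming a 2D float tensor `x` of shape (variables, observations), an optional integral 1D tensor `fw`, and an optional floating 1D tensor `aw`:

    import torch

    def cov_sketch(x, correction=1, fw=None, aw=None):
        # w is the product of frequencies and weights, when given
        w = fw
        if aw is not None:
            w = w * aw if w is not None else aw
        n = x.shape[1]
        w_sum = w.sum() if w is not None else torch.tensor(n)
        # weighted mean of the observations, per variable
        avg = (x * w if w is not None else x).sum(dim=1) / w_sum
        # normalization factor; the kernel zeroes it and warns when dof <= 0
        if w is not None and aw is not None and correction != 0:
            norm = w_sum - correction * (w * aw).sum() / w_sum
        else:
            norm = w_sum - correction
        norm = norm.clamp(min=0)
        # center the observations and form the (variables x variables) matrix
        xc = x - avg.unsqueeze(1)
        c = xc @ (xc * w if w is not None else xc).t().conj()
        return (c / norm).squeeze()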
aten/src/ATen/native/native_functions.yaml
@@ -1277,6 +1277,9 @@
   dispatch:
     CompositeExplicitAutograd: count_nonzero
 
+- func: cov(Tensor self, *, int correction=1, Tensor? fweights=None, Tensor? aweights=None) -> Tensor
+  variants: function, method
+
 - func: cudnn_affine_grid_generator(Tensor theta, int N, int C, int H, int W) -> Tensor grid
   dispatch:
     CUDA: cudnn_affine_grid_generator_forward
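The `*` in the schema makes all three parameters keyword-only, and `variants: function, method` generates both spellings. Calls the schema permits (illustration only):

    import torch

    x = torch.randn(3, 4)
    torch.cov(x, correction=0)  # function variant
    x.cov(correction=0)         # method variant
    # torch.cov(x, 0)           # TypeError: correction must be passed by keyword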
docs/source/tensors.rst
@@ -290,6 +290,7 @@ Tensor class reference
     Tensor.cosh
     Tensor.cosh_
     Tensor.count_nonzero
+    Tensor.cov
     Tensor.acosh
     Tensor.acosh_
     Tensor.arccosh
docs/source/torch.rst
@@ -481,6 +481,7 @@ Other Operations
     cdist
     clone
     combinations
+    cov
     cross
     cummax
     cummin
test/test_torch.py
@@ -4372,6 +4372,53 @@ else:
                 x = torch.empty(50000000, device=device, dtype=dtype).exponential_()
                 self.assertTrue(x.min() > 0)
 
+    @dtypes(torch.float, torch.cfloat)
+    def test_cov(self, device, dtype):
+        def check(t, correction=1, fweights=None, aweights=None):
+            actual = torch.cov(t, correction=correction, fweights=fweights, aweights=aweights)
+            t = t.cpu().numpy()
+            fweights = fweights.cpu().numpy() if fweights is not None else None
+            aweights = aweights.cpu().numpy() if aweights is not None else None
+            expected = np.cov(t, ddof=correction, fweights=fweights, aweights=aweights)
+            expected = torch.from_numpy(np.array(expected)).to(dtype=actual.dtype)
+            self.assertEqual(actual, expected, atol=1e-05, rtol=1e-05)
+
+        def generate_input_tensors():
+            yield make_tensor((0, 0), device, dtype)
+            yield make_tensor((1, 0), device, dtype)
+            yield make_tensor((0, 1), device, dtype)
+            yield make_tensor((2,), device, dtype)
+            yield make_tensor((2, 1), device, dtype)
+            yield make_tensor((2, 2), device, dtype)
+            yield make_tensor((2, 3), device, dtype)
+            yield make_tensor((5, 10), device, dtype)
+            yield make_tensor((5, 10), device, dtype, noncontiguous=True)
+            yield torch.tensor([0, -2, nan, 10.2, inf], dtype=dtype, device=device)
+
+        for t in generate_input_tensors():
+            check(t)
+            num_observations = t.numel() if t.ndim < 2 else t.size(1)
+            if num_observations > 0:
+                fweights = torch.randint(1, 10, (num_observations,), device=device)
+                aweights = make_tensor((num_observations,), device, torch.float, low=1)
+                for correction, fw, aw in product([0, 1, 2], [None, fweights], [None, aweights]):
+                    # pass the product's fw/aw so the None cases are actually exercised
+                    check(t, correction, fw, aw)
+
+    def test_cov_error(self, device):
+        def check(msg, *args, **kwargs):
+            with self.assertRaisesRegex(RuntimeError, r'cov\(\):.*' + msg + r'.*'):
+                torch.cov(*args, **kwargs)
+
+        a = torch.rand(2)
+        check(r'expected input to have two or fewer dimensions', torch.rand(2, 2, 2))
+        check(r'expected fweights to have one or fewer dimensions', a, fweights=torch.rand(2, 2))
+        check(r'expected aweights to have one or fewer dimensions', a, aweights=torch.rand(2, 2))
+        check(r'expected fweights to have integral dtype', a, fweights=torch.rand(2))
+        check(r'expected aweights to have floating point dtype', a, aweights=torch.tensor([1, 1]))
+        check(r'expected fweights to have the same numel', a, fweights=torch.tensor([1]))
+        check(r'expected aweights to have the same numel', a, aweights=torch.rand(1))
+        check(r'fweights cannot be negative', a, fweights=torch.tensor([-1, -2]))
+        check(r'aweights cannot be negative', a, aweights=torch.tensor([-1., -2.]))
+
     @skipIfNoSciPy
     @dtypes(*torch.testing.get_all_fp_dtypes())
tools/build_variables.bzl
@@ -974,6 +974,7 @@ aten_native_source_non_codegen_list = [
     "aten/src/ATen/native/ConvolutionMM3d.cpp",
     "aten/src/ATen/native/ConvolutionTBC.cpp",
     "aten/src/ATen/native/Copy.cpp",
+    "aten/src/ATen/native/Correlation.cpp",
     "aten/src/ATen/native/CPUFallback.cpp",
     "aten/src/ATen/native/Cross.cpp",
     "aten/src/ATen/native/DilatedMaxPool2d.cpp",
torch/_tensor_docs.py
@@ -1004,6 +1004,12 @@ count_nonzero(dim=None) -> Tensor
 See :func:`torch.count_nonzero`
 """)
 
+add_docstr_all('cov', r"""
+cov(*, correction=1, fweights=None, aweights=None) -> Tensor
+
+See :func:`torch.cov`
+""")
+
 add_docstr_all('cross',
                r"""
 cross(other, dim=-1) -> Tensor
torch/_torch_docs.py
@@ -1688,6 +1688,75 @@ Example::
     False
 """)
 
+add_docstr(torch.cov, r"""
+cov(input, *, correction=1, fweights=None, aweights=None) -> Tensor
+
+Estimates the covariance matrix of the variables given by the :attr:`input` matrix,
+where rows are the variables and columns are the observations.
+
+A covariance matrix is a square matrix giving the covariance of each pair of variables.
+The diagonal contains the variance of each variable (covariance of a variable with
+itself). By definition, if :attr:`input` represents a single variable (Scalar or 1D)
+then its variance is returned.
+
+The unbiased sample covariance of the variables :math:`x` and :math:`y` is given by:
+
+.. math::
+    \text{cov}(x,y) = \frac{\sum^{N}_{i = 1}(x_{i} - \bar{x})(y_{i} - \bar{y})}{N~-~1}
+
+where :math:`\bar{x}` and :math:`\bar{y}` are the simple means of :math:`x` and :math:`y` respectively.
+
+If :attr:`fweights` and/or :attr:`aweights` are provided, the unbiased weighted covariance
+is calculated, which is given by:
+
+.. math::
+    \text{cov}_w(x,y) = \frac{\sum^{N}_{i = 1}w_i(x_{i} - \mu_x^*)(y_{i} - \mu_y^*)}{\sum^{N}_{i = 1}w_i~-~1}
+
+where :math:`w` denotes :attr:`fweights` or :attr:`aweights` based on whichever is provided, or
+:math:`w = \text{fweights} \times \text{aweights}` if both are provided, and
+:math:`\mu_x^* = \frac{\sum^{N}_{i = 1}w_ix_{i}}{\sum^{N}_{i = 1}w_i}` is the weighted mean of the variable.
+
+Args:
+    input (Tensor): A 2D matrix containing multiple variables and observations, or a
+        Scalar or 1D vector representing a single variable.
+
+Keyword Args:
+    correction (int, optional): difference between the sample size and sample degrees of freedom.
+        Defaults to Bessel's correction, ``correction = 1``, which returns the unbiased estimate,
+        even if both :attr:`fweights` and :attr:`aweights` are specified. ``correction = 0``
+        will return the simple average.
+    fweights (Tensor, optional): A Scalar or 1D tensor of observation vector frequencies
+        representing the number of times each observation should be repeated. Its numel must
+        equal the number of columns of :attr:`input`. Must have integral dtype. Ignored if
+        ``None``. Defaults to ``None``.
+    aweights (Tensor, optional): A Scalar or 1D tensor of observation vector weights.
+        These relative weights are typically large for observations considered "important" and
+        smaller for observations considered less "important". Its numel must equal the number
+        of columns of :attr:`input`. Must have floating point dtype. Ignored if ``None``.
+        Defaults to ``None``.
+
+Returns:
+    (Tensor) The covariance matrix of the variables.
+
+Example::
+
+    >>> x = torch.tensor([[0, 2], [1, 1], [2, 0]]).T
+    >>> x
+    tensor([[0, 1, 2],
+            [2, 1, 0]])
+    >>> torch.cov(x)
+    tensor([[ 1., -1.],
+            [-1.,  1.]])
+    >>> torch.cov(x, correction=0)
+    tensor([[ 0.6667, -0.6667],
+            [-0.6667,  0.6667]])
+    >>> fw = torch.randint(1, 10, (3,))
+    >>> fw
+    tensor([1, 6, 9])
+    >>> aw = torch.rand(3)
+    >>> aw
+    tensor([0.4282, 0.0255, 0.4144])
+    >>> torch.cov(x, fweights=fw, aweights=aw)
+    tensor([[ 0.4169, -0.4169],
+            [-0.4169,  0.4169]])
+""")
+
 add_docstr(torch.cat,
            r"""
 cat(tensors, dim=0, *, out=None) -> Tensor
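Since `correction` maps onto NumPy's `ddof`, the docstring's formulas can be sanity-checked against `numpy.cov` directly, which is what the new test does. A standalone version of that check (illustrative, not from the commit):

    import numpy as np
    import torch

    x = torch.tensor([[0., 1., 2.], [2., 1., 0.]])
    fw = torch.tensor([1, 6, 9])
    aw = torch.tensor([0.4282, 0.0255, 0.4144])

    t = torch.cov(x, correction=1, fweights=fw, aweights=aw)
    n = np.cov(x.numpy(), ddof=1, fweights=fw.numpy(), aweights=aw.numpy())
    assert torch.allclose(t, torch.from_numpy(n).float(), atol=1e-5)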
torch/overrides.py
@@ -374,6 +374,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.clamp_min: lambda input, min, out=None: -1,
        torch.clamp_max: lambda input, max, out=None: -1,
        torch.column_stack: lambda tensors, out=None: -1,
+       torch.cov: lambda input, correction=1, fweights=None, aweights=None: -1,
        torch.clone: lambda input: -1,
        torch.combinations: lambda input, r=2, with_replacement=False: -1,
        torch.complex: lambda real, imag: -1,
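The lambda above only records `torch.cov`'s signature so the override machinery can test it; with the entry in place, any `__torch_function__` implementer can intercept the new op. A hypothetical subclass (not from the commit) to illustrate:

    import torch

    class LoggingTensor(torch.Tensor):
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            if func is torch.cov:
                print("torch.cov intercepted")
            return super().__torch_function__(func, types, args, kwargs or {})

    x = torch.randn(2, 5).as_subclass(LoggingTensor)
    torch.cov(x)  # prints the message, then computes as usual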
torch/testing/_internal/common_methods_invocations.py
@@ -3037,6 +3037,22 @@ def sample_inputs_std_var(op_info, device, dtype, requires_grad, **kwargs):
     ]
 
 
+def sample_inputs_cov(op_info, device, dtype, requires_grad, **kwargs):
+    shapes = [(2,), (1, 2), (3, 2), (2, 3)]
+
+    inputs = []
+    for shape in shapes:
+        t = make_tensor(shape, device, dtype, requires_grad=requires_grad)
+        inputs.append(SampleInput(t))
+        num_observations = t.numel() if t.ndimension() < 2 else t.size(1)
+        fweights = make_tensor((num_observations,), device, torch.int, low=0, high=10, requires_grad=requires_grad)
+        aweights = make_tensor((num_observations,), device, torch.float, low=0, high=1, requires_grad=requires_grad)
+        for correction, fw, aw in product(range(num_observations), [None, fweights], [None, aweights]):
+            inputs.append(SampleInput(t, kwargs={'correction': correction, 'fweights': fw, 'aweights': aw}))
+
+    return inputs
+
+
 def _sample_inputs_svd(op_info, device, dtype, requires_grad=False, is_linalg_svd=False):
     """
     This function generates input for torch.svd with distinct singular values so that autograd is always stable.

@@ -5285,6 +5301,14 @@ op_db: List[OpInfo] = [
                  SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard', device_type='cpu',
                           dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS),
              )),
+    OpInfo('cov',
+           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+           dtypesIfCUDA=all_types_and_complex_and(torch.half, *[torch.bfloat16] if CUDA11OrLater else []),
+           backward_dtypesIfCUDA=all_types_and_complex_and(torch.half, *[torch.bfloat16] if CUDA11OrLater else []),
+           sample_inputs_func=sample_inputs_cov,
+           supports_out=False,
+           # JIT test not working for tensor kwargs (https://github.com/pytorch/pytorch/issues/58507)
+           skips=(SkipInfo('TestJit', 'test_variant_consistency_jit'),)),
     OpInfo('cross',
            dtypes=all_types_and_complex(),
            dtypesIfCUDA=all_types_and_complex_and(torch.half),