From c6d0fdd21537a2c23f94e7a44ac552d24b447906 Mon Sep 17 00:00:00 2001 From: Justin Huber Date: Fri, 17 Jul 2020 22:05:40 -0700 Subject: [PATCH] torch.isreal (#41298) Summary: https://github.com/pytorch/pytorch/issues/38349 mruberry Not entirely sure if all the changes are necessary in how functions are added to Pytorch. Should it throw an error when called with a non-complex tensor? Numpy allows non-complex arrays in its imag() function which is used in its isreal() function but Pytorch's imag() throws an error for non-complex arrays. Where does assertONNX() get its expected output to compare to? Pull Request resolved: https://github.com/pytorch/pytorch/pull/41298 Reviewed By: ngimel Differential Revision: D22610500 Pulled By: mruberry fbshipit-source-id: 817d61f8b1c3670788b81690636bd41335788439 --- aten/src/ATen/core/aten_interned_strings.h | 1 + aten/src/ATen/native/TensorCompare.cpp | 10 ++++++++ aten/src/ATen/native/native_functions.yaml | 4 ++++ docs/source/tensors.rst | 1 + docs/source/torch.rst | 1 + test/test_torch.py | 28 ++++++++++++++++++++++ torch/_overrides.py | 1 + torch/_tensor_docs.py | 7 ++++++ torch/_torch_docs.py | 19 +++++++++++++++ 9 files changed, 72 insertions(+) diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index 758b5544033..41d61844239 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -403,6 +403,7 @@ _(aten, is_set_to) \ _(aten, is_signed) \ _(aten, is_sparse) \ _(aten, isclose) \ +_(aten, isreal) \ _(aten, istft) \ _(aten, kl_div) \ _(aten, kl_div_backward) \ diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index a2193c17aaa..2aecf12c36c 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -79,6 +79,16 @@ Tensor isnan(const Tensor& self) { return self != self; } +Tensor isreal(const Tensor& self) { + // Note: Integral and Floating tensor values are always real + if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true) || + c10::isFloatingType(self.scalar_type())) { + return at::ones_like(self, at::kBool, at::MemoryFormat::Preserve); + } + + return at::imag(self) == 0; +} + Tensor isinf(const Tensor &self) { // Note: Integral tensor values are never infinite if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true)) { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 377d496f4b4..e898c9a1158 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1544,6 +1544,10 @@ variants: function, method device_guard: False +- func: isreal(Tensor self) -> Tensor + use_c10_dispatcher: full + variants: function, method + - func: is_nonzero(Tensor self) -> bool use_c10_dispatcher: full variants: function, method diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index dcb9b587424..11575b01153 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -351,6 +351,7 @@ view of a storage and defines numeric operations on it. .. automethod:: is_signed .. autoattribute:: is_sparse .. automethod:: istft + .. automethod:: isreal .. automethod:: item .. automethod:: kthvalue .. automethod:: lcm diff --git a/docs/source/torch.rst b/docs/source/torch.rst index f750aace12b..94a4e5b2531 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -352,6 +352,7 @@ Comparison Ops isfinite isinf isnan + isreal kthvalue le lt diff --git a/test/test_torch.py b/test/test_torch.py index 40aea1f12b7..4e3bec2428f 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -6473,6 +6473,34 @@ class TestTorchDeviceType(TestCase): self.compare_with_numpy(torch.isinf, np.isinf, vals, device, dtype) self.compare_with_numpy(torch.isnan, np.isnan, vals, device, dtype) + @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') + @dtypes(torch.complex64, torch.complex128) + def test_isreal_complex(self, device, dtype): + vals = (1, 1 + 1j, 2 + 0j, 3j, 2 - 1j, 2 - 0j) + self.compare_with_numpy(torch.isreal, np.isreal, vals, device, dtype) + + @dtypes(*torch.testing.get_all_dtypes()) + def test_isreal_noncomplex(self, device, dtype): + vals = (1, 2, 3) + # Manual check here since numpy doesn't support bfloat16 + result = torch.isreal(torch.tensor(vals, dtype=dtype)) + expected = torch.ones(result.size(), dtype=torch.bool, device=device) + self.assertEqual(result, expected) + + @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') + @dtypes(torch.complex64) + def test_isreal_nan_inf(self, device, dtype): + vals = ( + complex(-float('inf'), float('inf')), + complex(-float('inf'), 0), + complex(0, float('inf')), + complex(float('inf'), float('nan')), + complex(float('nan'), 0), + complex(-1, 0), + complex(0, 1) + ) + self.compare_with_numpy(torch.isreal, np.isreal, vals, device, dtype) + @onlyCPU def test_isfinite_type(self, device): with self.assertRaises(TypeError): diff --git a/torch/_overrides.py b/torch/_overrides.py index 45bb2aec6bf..027d0afd81b 100644 --- a/torch/_overrides.py +++ b/torch/_overrides.py @@ -346,6 +346,7 @@ def get_testing_overrides(): torch.index_fill: lambda input, dim, index, value: -1, torch.isfinite: lambda tensor: -1, torch.isinf: lambda tensor: -1, + torch.isreal: lambda tensor: -1, torch.instance_norm: (lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps, cudnn_enabled: -1), torch.int_repr: lambda input: -1, diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 86da9e73bce..28102964c36 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -1651,6 +1651,13 @@ isclose(other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor See :func:`torch.isclose` """) +add_docstr_all('isreal', + r""" +isreal() -> Tensor + +See :func:`torch.isreal` +""") + add_docstr_all('is_contiguous', r""" is_contiguous(memory_format=torch.contiguous_format) -> bool diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 2e0d2634861..4988286154a 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -2908,6 +2908,25 @@ Example:: tensor([False, True, False]) """) +add_docstr(torch.isreal, + r""" +isreal(input) -> Tensor + +Returns a new tensor with boolean elements representing if each element of :attr:`input` is real-valued or not. +All real-valued types are considered real. Complex values are considered real when their imaginary part is 0. + +Arguments: + {input} + +Returns: + Tensor: A boolean tensor with True where :attr:`input` is real-valued and False elsewhere. + +Example:: + + >>> torch.isreal(torch.tensor([1, 1+1j, 2+0j])) + tensor([True, False, True]) +""") + add_docstr(torch.is_floating_point, r""" is_floating_point(input) -> (bool)