torch.isreal (#41298)

Summary:
https://github.com/pytorch/pytorch/issues/38349

mruberry
Not entirely sure whether all of these changes are necessary for how functions are added to PyTorch.

Should it throw an error when called with a non-complex tensor? NumPy allows non-complex arrays in its imag() function, which its isreal() function uses, but PyTorch's imag() throws an error for non-complex arrays.

Where does assertONNX() get its expected output to compare to?

Pull Request resolved: https://github.com/pytorch/pytorch/pull/41298

Reviewed By: ngimel

Differential Revision: D22610500

Pulled By: mruberry

fbshipit-source-id: 817d61f8b1c3670788b81690636bd41335788439
This commit is contained in:
Justin Huber 2020-07-17 22:05:40 -07:00 committed by Facebook GitHub Bot
parent 581e9526bb
commit c6d0fdd215
9 changed files with 72 additions and 0 deletions

View File

@ -403,6 +403,7 @@ _(aten, is_set_to) \
_(aten, is_signed) \ _(aten, is_signed) \
_(aten, is_sparse) \ _(aten, is_sparse) \
_(aten, isclose) \ _(aten, isclose) \
_(aten, isreal) \
_(aten, istft) \ _(aten, istft) \
_(aten, kl_div) \ _(aten, kl_div) \
_(aten, kl_div_backward) \ _(aten, kl_div_backward) \

View File

@ -79,6 +79,16 @@ Tensor isnan(const Tensor& self) {
return self != self; return self != self;
} }
Tensor isreal(const Tensor& self) {
  // Determine whether every element of this dtype is trivially real:
  // integral (including bool) and floating-point tensors carry no
  // imaginary component, so the answer is all-true without any compute.
  const auto stype = self.scalar_type();
  const bool trivially_real =
      c10::isIntegralType(stype, /*include_bool=*/true) ||
      c10::isFloatingType(stype);
  if (!trivially_real) {
    // Complex (or other) dtypes: an element is real exactly when its
    // imaginary part compares equal to zero.
    return at::imag(self) == 0;
  }
  return at::ones_like(self, at::kBool, at::MemoryFormat::Preserve);
}
Tensor isinf(const Tensor &self) { Tensor isinf(const Tensor &self) {
// Note: Integral tensor values are never infinite // Note: Integral tensor values are never infinite
if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true)) { if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true)) {

View File

@ -1544,6 +1544,10 @@
variants: function, method variants: function, method
device_guard: False device_guard: False
- func: isreal(Tensor self) -> Tensor
use_c10_dispatcher: full
variants: function, method
- func: is_nonzero(Tensor self) -> bool - func: is_nonzero(Tensor self) -> bool
use_c10_dispatcher: full use_c10_dispatcher: full
variants: function, method variants: function, method

View File

@ -351,6 +351,7 @@ view of a storage and defines numeric operations on it.
.. automethod:: is_signed .. automethod:: is_signed
.. autoattribute:: is_sparse .. autoattribute:: is_sparse
.. automethod:: istft .. automethod:: istft
.. automethod:: isreal
.. automethod:: item .. automethod:: item
.. automethod:: kthvalue .. automethod:: kthvalue
.. automethod:: lcm .. automethod:: lcm

View File

@ -352,6 +352,7 @@ Comparison Ops
isfinite isfinite
isinf isinf
isnan isnan
isreal
kthvalue kthvalue
le le
lt lt

View File

@ -6473,6 +6473,34 @@ class TestTorchDeviceType(TestCase):
self.compare_with_numpy(torch.isinf, np.isinf, vals, device, dtype) self.compare_with_numpy(torch.isinf, np.isinf, vals, device, dtype)
self.compare_with_numpy(torch.isnan, np.isnan, vals, device, dtype) self.compare_with_numpy(torch.isnan, np.isnan, vals, device, dtype)
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
@dtypes(torch.complex64, torch.complex128)
def test_isreal_complex(self, device, dtype):
    # Mix of purely-real values and values with a nonzero imaginary part,
    # including an explicit negative-zero imaginary component (2 - 0j).
    sample_values = (1, 1 + 1j, 2 + 0j, 3j, 2 - 1j, 2 - 0j)
    self.compare_with_numpy(torch.isreal, np.isreal, sample_values, device, dtype)
@dtypes(*torch.testing.get_all_dtypes())
def test_isreal_noncomplex(self, device, dtype):
    """isreal() on any non-complex dtype is all-True.

    Checked manually (not via compare_with_numpy) because NumPy has no
    bfloat16, which this dtype sweep includes.
    """
    vals = (1, 2, 3)
    # Fix: build the input on `device` — previously torch.tensor() used the
    # default (CPU) device while `expected` was created on `device`, so the
    # comparison mixed devices on non-CPU instantiations of this test.
    result = torch.isreal(torch.tensor(vals, device=device, dtype=dtype))
    expected = torch.ones(result.size(), dtype=torch.bool, device=device)
    self.assertEqual(result, expected)
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
@dtypes(torch.complex64)
def test_isreal_nan_inf(self, device, dtype):
    # Non-finite components: a complex value counts as real only when its
    # imaginary part is exactly 0, regardless of nan/inf in the real part.
    inf = float('inf')
    nan = float('nan')
    vals = (
        complex(-inf, inf),
        complex(-inf, 0),
        complex(0, inf),
        complex(inf, nan),
        complex(nan, 0),
        complex(-1, 0),
        complex(0, 1),
    )
    self.compare_with_numpy(torch.isreal, np.isreal, vals, device, dtype)
@onlyCPU @onlyCPU
def test_isfinite_type(self, device): def test_isfinite_type(self, device):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):

View File

@ -346,6 +346,7 @@ def get_testing_overrides():
torch.index_fill: lambda input, dim, index, value: -1, torch.index_fill: lambda input, dim, index, value: -1,
torch.isfinite: lambda tensor: -1, torch.isfinite: lambda tensor: -1,
torch.isinf: lambda tensor: -1, torch.isinf: lambda tensor: -1,
torch.isreal: lambda tensor: -1,
torch.instance_norm: (lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps, torch.instance_norm: (lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps,
cudnn_enabled: -1), cudnn_enabled: -1),
torch.int_repr: lambda input: -1, torch.int_repr: lambda input: -1,

View File

@ -1651,6 +1651,13 @@ isclose(other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor
See :func:`torch.isclose` See :func:`torch.isclose`
""") """)
# Attach the Tensor-method docstring for `Tensor.isreal`; it simply points
# readers at the functional form. The r""" payload is a runtime string
# consumed by add_docstr_all, not a Python docstring — do not edit casually.
add_docstr_all('isreal',
r"""
isreal() -> Tensor
See :func:`torch.isreal`
""")
add_docstr_all('is_contiguous', add_docstr_all('is_contiguous',
r""" r"""
is_contiguous(memory_format=torch.contiguous_format) -> bool is_contiguous(memory_format=torch.contiguous_format) -> bool

View File

@ -2908,6 +2908,25 @@ Example::
tensor([False, True, False]) tensor([False, True, False])
""") """)
# Functional API documentation for torch.isreal, rendered into the public
# docs. The r""" payload is a runtime string consumed by add_docstr (with
# {input} substituted from the shared argument templates) — treat it as code.
add_docstr(torch.isreal,
r"""
isreal(input) -> Tensor
Returns a new tensor with boolean elements representing if each element of :attr:`input` is real-valued or not.
All real-valued types are considered real. Complex values are considered real when their imaginary part is 0.
Arguments:
{input}
Returns:
Tensor: A boolean tensor with True where :attr:`input` is real-valued and False elsewhere.
Example::
>>> torch.isreal(torch.tensor([1, 1+1j, 2+0j]))
tensor([True, False, True])
""")
add_docstr(torch.is_floating_point, add_docstr(torch.is_floating_point,
r""" r"""
is_floating_point(input) -> (bool) is_floating_point(input) -> (bool)