mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
torch.isreal (#41298)
Summary: https://github.com/pytorch/pytorch/issues/38349 mruberry Not entirely sure if all the changes are necessary in how functions are added to Pytorch. Should it throw an error when called with a non-complex tensor? Numpy allows non-complex arrays in its imag() function which is used in its isreal() function but Pytorch's imag() throws an error for non-complex arrays. Where does assertONNX() get its expected output to compare to? Pull Request resolved: https://github.com/pytorch/pytorch/pull/41298 Reviewed By: ngimel Differential Revision: D22610500 Pulled By: mruberry fbshipit-source-id: 817d61f8b1c3670788b81690636bd41335788439
This commit is contained in:
parent
581e9526bb
commit
c6d0fdd215
|
|
@ -403,6 +403,7 @@ _(aten, is_set_to) \
|
|||
_(aten, is_signed) \
|
||||
_(aten, is_sparse) \
|
||||
_(aten, isclose) \
|
||||
_(aten, isreal) \
|
||||
_(aten, istft) \
|
||||
_(aten, kl_div) \
|
||||
_(aten, kl_div_backward) \
|
||||
|
|
|
|||
|
|
@ -79,6 +79,16 @@ Tensor isnan(const Tensor& self) {
|
|||
return self != self;
|
||||
}
|
||||
|
||||
Tensor isreal(const Tensor& self) {
|
||||
// Note: Integral and Floating tensor values are always real
|
||||
if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true) ||
|
||||
c10::isFloatingType(self.scalar_type())) {
|
||||
return at::ones_like(self, at::kBool, at::MemoryFormat::Preserve);
|
||||
}
|
||||
|
||||
return at::imag(self) == 0;
|
||||
}
|
||||
|
||||
Tensor isinf(const Tensor &self) {
|
||||
// Note: Integral tensor values are never infinite
|
||||
if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true)) {
|
||||
|
|
|
|||
|
|
@ -1544,6 +1544,10 @@
|
|||
variants: function, method
|
||||
device_guard: False
|
||||
|
||||
- func: isreal(Tensor self) -> Tensor
|
||||
use_c10_dispatcher: full
|
||||
variants: function, method
|
||||
|
||||
- func: is_nonzero(Tensor self) -> bool
|
||||
use_c10_dispatcher: full
|
||||
variants: function, method
|
||||
|
|
|
|||
|
|
@ -351,6 +351,7 @@ view of a storage and defines numeric operations on it.
|
|||
.. automethod:: is_signed
|
||||
.. autoattribute:: is_sparse
|
||||
.. automethod:: istft
|
||||
.. automethod:: isreal
|
||||
.. automethod:: item
|
||||
.. automethod:: kthvalue
|
||||
.. automethod:: lcm
|
||||
|
|
|
|||
|
|
@ -352,6 +352,7 @@ Comparison Ops
|
|||
isfinite
|
||||
isinf
|
||||
isnan
|
||||
isreal
|
||||
kthvalue
|
||||
le
|
||||
lt
|
||||
|
|
|
|||
|
|
@ -6473,6 +6473,34 @@ class TestTorchDeviceType(TestCase):
|
|||
self.compare_with_numpy(torch.isinf, np.isinf, vals, device, dtype)
|
||||
self.compare_with_numpy(torch.isnan, np.isnan, vals, device, dtype)
|
||||
|
||||
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
|
||||
@dtypes(torch.complex64, torch.complex128)
|
||||
def test_isreal_complex(self, device, dtype):
|
||||
vals = (1, 1 + 1j, 2 + 0j, 3j, 2 - 1j, 2 - 0j)
|
||||
self.compare_with_numpy(torch.isreal, np.isreal, vals, device, dtype)
|
||||
|
||||
@dtypes(*torch.testing.get_all_dtypes())
|
||||
def test_isreal_noncomplex(self, device, dtype):
|
||||
vals = (1, 2, 3)
|
||||
# Manual check here since numpy doesn't support bfloat16
|
||||
result = torch.isreal(torch.tensor(vals, dtype=dtype))
|
||||
expected = torch.ones(result.size(), dtype=torch.bool, device=device)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
|
||||
@dtypes(torch.complex64)
|
||||
def test_isreal_nan_inf(self, device, dtype):
|
||||
vals = (
|
||||
complex(-float('inf'), float('inf')),
|
||||
complex(-float('inf'), 0),
|
||||
complex(0, float('inf')),
|
||||
complex(float('inf'), float('nan')),
|
||||
complex(float('nan'), 0),
|
||||
complex(-1, 0),
|
||||
complex(0, 1)
|
||||
)
|
||||
self.compare_with_numpy(torch.isreal, np.isreal, vals, device, dtype)
|
||||
|
||||
@onlyCPU
|
||||
def test_isfinite_type(self, device):
|
||||
with self.assertRaises(TypeError):
|
||||
|
|
|
|||
|
|
@ -346,6 +346,7 @@ def get_testing_overrides():
|
|||
torch.index_fill: lambda input, dim, index, value: -1,
|
||||
torch.isfinite: lambda tensor: -1,
|
||||
torch.isinf: lambda tensor: -1,
|
||||
torch.isreal: lambda tensor: -1,
|
||||
torch.instance_norm: (lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps,
|
||||
cudnn_enabled: -1),
|
||||
torch.int_repr: lambda input: -1,
|
||||
|
|
|
|||
|
|
@ -1651,6 +1651,13 @@ isclose(other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor
|
|||
See :func:`torch.isclose`
|
||||
""")
|
||||
|
||||
add_docstr_all('isreal',
|
||||
r"""
|
||||
isreal() -> Tensor
|
||||
|
||||
See :func:`torch.isreal`
|
||||
""")
|
||||
|
||||
add_docstr_all('is_contiguous',
|
||||
r"""
|
||||
is_contiguous(memory_format=torch.contiguous_format) -> bool
|
||||
|
|
|
|||
|
|
@ -2908,6 +2908,25 @@ Example::
|
|||
tensor([False, True, False])
|
||||
""")
|
||||
|
||||
add_docstr(torch.isreal,
|
||||
r"""
|
||||
isreal(input) -> Tensor
|
||||
|
||||
Returns a new tensor with boolean elements representing if each element of :attr:`input` is real-valued or not.
|
||||
All real-valued types are considered real. Complex values are considered real when their imaginary part is 0.
|
||||
|
||||
Arguments:
|
||||
{input}
|
||||
|
||||
Returns:
|
||||
Tensor: A boolean tensor with True where :attr:`input` is real-valued and False elsewhere.
|
||||
|
||||
Example::
|
||||
|
||||
>>> torch.isreal(torch.tensor([1, 1+1j, 2+0j]))
|
||||
tensor([True, False, True])
|
||||
""")
|
||||
|
||||
add_docstr(torch.is_floating_point,
|
||||
r"""
|
||||
is_floating_point(input) -> (bool)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user