mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
C++ API parity: isfinite
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/28918 Test Plan: Imported from OSS Differential Revision: D18233037 Pulled By: pbelevich fbshipit-source-id: c76b9467bbc1fbb2c9bf49855895c98438b36c12
This commit is contained in:
parent
5d69bc1eda
commit
8df5e10ee9
|
|
@ -87,6 +87,16 @@ Tensor isnan(const Tensor& self) {
|
|||
return self != self;
|
||||
}
|
||||
|
||||
Tensor isfinite(const Tensor& self) {
|
||||
// Integral tensor types are finite
|
||||
if (!self.is_floating_point()) {
|
||||
return at::ones_like(self, at::kBool, at::MemoryFormat::Preserve);
|
||||
}
|
||||
return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "isfinite", [&]() {
|
||||
return (self == self) * (self.abs() != std::numeric_limits<scalar_t>::infinity());
|
||||
});
|
||||
}
|
||||
|
||||
bool is_nonzero(const Tensor& self) {
|
||||
auto n = self.numel();
|
||||
AT_ASSERT(n >= 0);
|
||||
|
|
|
|||
|
|
@ -6483,3 +6483,9 @@
|
|||
dispatch:
|
||||
CPU: im2col_backward_cpu
|
||||
CUDA: im2col_backward_cuda
|
||||
|
||||
- func: isfinite(Tensor self) -> Tensor
|
||||
use_c10_dispatcher: full
|
||||
variants: function
|
||||
device_guard: False
|
||||
supports_named_tensor: True
|
||||
|
|
|
|||
|
|
@ -1960,3 +1960,63 @@ TEST_F(FunctionalTest, Dropout3d) {
|
|||
ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05));
|
||||
ASSERT_TRUE((input_std <= output.std()).all().item<bool>());
|
||||
}
|
||||
|
||||
template<c10::ScalarType S, typename T>
|
||||
void test_isfinite(const at::Device& device) {
|
||||
const std::vector<T> values = {
|
||||
std::numeric_limits<T>::lowest(),
|
||||
0, 1, 42,
|
||||
std::numeric_limits<T>::min(),
|
||||
std::numeric_limits<T>::max()
|
||||
};
|
||||
for (const auto value : values) {
|
||||
const auto x = torch::full({3, 3}, value, torch::TensorOptions().dtype(S).device(device));
|
||||
ASSERT_TRUE(torch::isfinite(x).all().template item<bool>());
|
||||
}
|
||||
if (std::numeric_limits<T>::has_infinity) {
|
||||
const auto inf = std::numeric_limits<T>::infinity();
|
||||
const auto x = torch::tensor({
|
||||
-inf,
|
||||
std::numeric_limits<T>::lowest(),
|
||||
static_cast<T>(0),
|
||||
static_cast<T>(1),
|
||||
static_cast<T>(42),
|
||||
std::numeric_limits<T>::min(),
|
||||
std::numeric_limits<T>::max(),
|
||||
inf
|
||||
}, torch::TensorOptions().dtype(S).device(device));
|
||||
ASSERT_TRUE(torch::allclose(
|
||||
// torch::allclose does not support comparing torch::kBool
|
||||
torch::isfinite(x).toType(torch::kInt),
|
||||
torch::tensor(
|
||||
{false, true, true, true, true, true, true, false},
|
||||
torch::TensorOptions().device(device)
|
||||
).toType(torch::kInt)
|
||||
));
|
||||
}
|
||||
if (std::numeric_limits<T>::has_quiet_NaN) {
|
||||
const auto x = torch::tensor({
|
||||
std::numeric_limits<T>::quiet_NaN()
|
||||
}, torch::TensorOptions().dtype(S).device(device));
|
||||
ASSERT_FALSE(torch::isfinite(x).all().template item<bool>());
|
||||
}
|
||||
if (std::numeric_limits<T>::has_signaling_NaN) {
|
||||
const auto x = torch::tensor({
|
||||
std::numeric_limits<T>::signaling_NaN()
|
||||
}, torch::TensorOptions().dtype(S).device(device));
|
||||
ASSERT_FALSE(torch::isfinite(x).all().template item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(FunctionalTest, isfinite) {
|
||||
for (const auto device : {at::Device("cpu"), at::Device("cuda")}) {
|
||||
test_isfinite<torch::kUInt8, unsigned char>(device);
|
||||
test_isfinite<torch::kInt8, char>(device);
|
||||
test_isfinite<torch::kInt16, short>(device);
|
||||
test_isfinite<torch::kInt32, int>(device);
|
||||
test_isfinite<torch::kInt64, long>(device);
|
||||
test_isfinite<torch::kFloat32, float>(device);
|
||||
test_isfinite<torch::kFloat64, double>(device);
|
||||
}
|
||||
test_isfinite<torch::kFloat16, c10::Half>(at::Device("cuda"));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2412,6 +2412,22 @@ Example::
|
|||
tensor(1.9073e-06)
|
||||
""".format(**common_args))
|
||||
|
||||
add_docstr(torch.isfinite,
|
||||
r"""
|
||||
Returns a new tensor with boolean elements representing if each element is `Finite` or not.
|
||||
|
||||
Arguments:
|
||||
tensor (Tensor): A tensor to check
|
||||
|
||||
Returns:
|
||||
Tensor: ``A torch.Tensor with dtype torch.bool`` containing a True at each location of finite elements and False otherwise
|
||||
|
||||
Example::
|
||||
|
||||
>>> torch.isfinite(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]))
|
||||
tensor([True, False, True, False, False])
|
||||
""")
|
||||
|
||||
add_docstr(torch.isnan,
|
||||
r"""
|
||||
Returns a new tensor with boolean elements representing if each element is `NaN` or not.
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ __all__ = [
|
|||
'cdist',
|
||||
'chain_matmul',
|
||||
'einsum',
|
||||
'isfinite',
|
||||
'isinf',
|
||||
'lu',
|
||||
'lu_unpack',
|
||||
|
|
@ -242,32 +241,6 @@ Examples::
|
|||
return torch._C._VariableFunctions.einsum(equation, operands)
|
||||
|
||||
|
||||
def isfinite(tensor):
|
||||
r"""Returns a new tensor with boolean elements representing if each element is `Finite` or not.
|
||||
|
||||
Arguments:
|
||||
tensor (Tensor): A tensor to check
|
||||
|
||||
Returns:
|
||||
Tensor: ``A torch.Tensor with dtype torch.bool`` containing a True at each location of finite elements and False otherwise
|
||||
|
||||
Example::
|
||||
|
||||
>>> torch.isfinite(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]))
|
||||
tensor([True, False, True, False, False])
|
||||
"""
|
||||
if not isinstance(tensor, torch.Tensor):
|
||||
raise TypeError("The argument is not a tensor: {}".format(repr(tensor)))
|
||||
|
||||
# Support int input, nan and inf are concepts in floating point numbers.
|
||||
# Numpy uses type 'Object' when the int overflows long, but we don't
|
||||
# have a similar concept. It's safe to assume any created LongTensor doesn't
|
||||
# overflow and it's finite.
|
||||
if not tensor.is_floating_point():
|
||||
return torch.ones_like(tensor, dtype=torch.bool, memory_format=torch.legacy_contiguous_format)
|
||||
return (tensor == tensor) & (tensor.abs() != inf)
|
||||
|
||||
|
||||
def isinf(tensor):
|
||||
r"""Returns a new tensor with boolean elements representing if each element is `+/-INF` or not.
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user