mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Clean up error handling in is_nonzero and where in TensorCompare.cpp (#38150)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/38150 Differential Revision: D21539736 Pulled By: ezyang fbshipit-source-id: e390c12f5948192a552d66dcd1bb89b2cb45f170
This commit is contained in:
parent
5a979fcb99
commit
336e1ec592
|
|
@ -115,13 +115,9 @@ Tensor isfinite(const Tensor& self) {
|
|||
|
||||
bool is_nonzero(const Tensor& self) {
|
||||
auto n = self.numel();
|
||||
AT_ASSERT(n >= 0);
|
||||
if (n == 0) {
|
||||
AT_ERROR("bool value of Tensor with no values is ambiguous");
|
||||
}
|
||||
if (n > 1) {
|
||||
AT_ERROR("bool value of Tensor with more than one value is ambiguous");
|
||||
}
|
||||
TORCH_CHECK(n != 0, "Boolean value of Tensor with no values is ambiguous");
|
||||
TORCH_CHECK(n < 2, "Boolean value of Tensor with more than one value is ambiguous");
|
||||
|
||||
Scalar localScalar = self.item();
|
||||
if (localScalar.isFloatingPoint()) {
|
||||
return localScalar.to<double>() != 0;
|
||||
|
|
@ -132,18 +128,17 @@ bool is_nonzero(const Tensor& self) {
|
|||
} else if (localScalar.isBoolean()) {
|
||||
return localScalar.to<bool>();
|
||||
}
|
||||
AT_ERROR("expected non-Tensor backend scalar");
|
||||
TORCH_INTERNAL_ASSERT(false, "Expected non-Tensor backend scalar");
|
||||
}
|
||||
|
||||
Tensor where(const Tensor& condition, const Tensor& self, const Tensor& other) {
|
||||
TORCH_CHECK(condition.device() == self.device() && self.device() == other.device(),
|
||||
"expected condition, x and y to be on the same device, but condition is on ",
|
||||
"Expected condition, x and y to be on the same device, but condition is on ",
|
||||
condition.device(), " and x and y are on ", self.device(), " and ", other.device(),
|
||||
" respectively");
|
||||
if (condition.scalar_type() != ScalarType::Byte && condition.scalar_type() != ScalarType::Bool) {
|
||||
AT_ERROR("Expected condition to have ScalarType Byte, but got ScalarType ",
|
||||
toString(condition.scalar_type()));
|
||||
}
|
||||
TORCH_CHECK(condition.scalar_type() == ScalarType::Byte || condition.scalar_type() == ScalarType::Bool,
|
||||
"Expected condition to have ScalarType Byte, but got ScalarType ",
|
||||
toString(condition.scalar_type()));
|
||||
Tensor b_condition, b_self, b_other;
|
||||
std::tie(b_condition, b_self, b_other) = expand_outplace(condition, self, other, "where");
|
||||
return at::_s_where(b_condition, b_self, b_other);
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
bool value of Tensor with no values is ambiguous
|
||||
Boolean value of Tensor with no values is ambiguous
|
||||
|
|
@ -1 +1 @@
|
|||
bool value of Tensor with more than one value is ambiguous
|
||||
Boolean value of Tensor with more than one value is ambiguous
|
||||
|
|
@ -306,7 +306,7 @@ class TestList(JitTestCase):
|
|||
test_invalid_list_equality,
|
||||
(),
|
||||
RuntimeError,
|
||||
"bool value of Tensor")
|
||||
"Boolean value of Tensor")
|
||||
|
||||
def test_list_sort(self):
|
||||
template = dedent('''
|
||||
|
|
@ -338,7 +338,7 @@ class TestList(JitTestCase):
|
|||
return x
|
||||
|
||||
self.checkScriptRaisesRegex(test_fail, (([torch.zeros([2]), torch.zeros([2])],)), Exception,
|
||||
"bool value of Tensor with more than one value")
|
||||
"Boolean value of Tensor with more than one value")
|
||||
|
||||
@torch.jit.script
|
||||
def test_mutation():
|
||||
|
|
|
|||
|
|
@ -6889,7 +6889,7 @@ a")
|
|||
self.checkScript(test_bool_cast_tensor, (torch.tensor(inp_val),))
|
||||
|
||||
self.checkScriptRaisesRegex(test_bool_cast_tensor, (torch.tensor([1, 1]),), Exception,
|
||||
"bool value of Tensor with more than one value")
|
||||
"Boolean value of Tensor with more than one value")
|
||||
|
||||
def test_not_cast(x):
|
||||
if not x:
|
||||
|
|
|
|||
|
|
@ -2214,7 +2214,7 @@ class TestSparse(TestCase):
|
|||
self.assertFalse(torch.sparse_coo_tensor(([0, 0],), (0., 0.), (1,)).is_nonzero())
|
||||
self.assertFalse(torch.sparse_coo_tensor(([0, 0],), (-1., 1.), (1,)).is_nonzero())
|
||||
self.assertTrue(torch.sparse_coo_tensor(torch.zeros(0, 1), 12.3, []).is_nonzero()) # scalar sparse tensor
|
||||
with self.assertRaisesRegex(RuntimeError, "bool value of Tensor with no values is ambiguous"):
|
||||
with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with no values is ambiguous"):
|
||||
torch.sparse_coo_tensor(([0, 1],), self.value_empty(2, 0), (4, 0)).is_nonzero()
|
||||
|
||||
def test_allow_tensor_metadata_change(self):
|
||||
|
|
|
|||
|
|
@ -410,7 +410,7 @@ class _TestTorchMixin(object):
|
|||
x = torch.rand(16, device=devices[1])
|
||||
y = torch.rand(16, device=devices[2])
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
"expected condition, x and y to be on the same device"):
|
||||
"Expected condition, x and y to be on the same device"):
|
||||
torch.where(condition, x, y)
|
||||
|
||||
def test_where_bool_tensor(self):
|
||||
|
|
@ -4721,12 +4721,12 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
|
|||
self.assertExpectedRaisesInline(
|
||||
RuntimeError,
|
||||
lambda: torch.tensor([]).is_nonzero(),
|
||||
"bool value of Tensor with no values is ambiguous",
|
||||
"Boolean value of Tensor with no values is ambiguous",
|
||||
)
|
||||
self.assertExpectedRaisesInline(
|
||||
RuntimeError,
|
||||
lambda: torch.tensor([0, 0]).is_nonzero(),
|
||||
"bool value of Tensor with more than one value is ambiguous",
|
||||
"Boolean value of Tensor with more than one value is ambiguous",
|
||||
)
|
||||
self.assertFalse(torch.tensor(0).is_nonzero())
|
||||
self.assertTrue(torch.tensor(1).is_nonzero())
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user