Clean up error handling in is_nonzero and where in TensorCompare.cpp (#38150)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/38150

Differential Revision: D21539736

Pulled By: ezyang

fbshipit-source-id: e390c12f5948192a552d66dcd1bb89b2cb45f170
This commit is contained in:
Hong Xu 2020-05-13 20:17:29 -07:00 committed by Facebook GitHub Bot
parent 5a979fcb99
commit 336e1ec592
7 changed files with 17 additions and 22 deletions

View File

@ -115,13 +115,9 @@ Tensor isfinite(const Tensor& self) {
bool is_nonzero(const Tensor& self) {
auto n = self.numel();
AT_ASSERT(n >= 0);
if (n == 0) {
AT_ERROR("bool value of Tensor with no values is ambiguous");
}
if (n > 1) {
AT_ERROR("bool value of Tensor with more than one value is ambiguous");
}
TORCH_CHECK(n != 0, "Boolean value of Tensor with no values is ambiguous");
TORCH_CHECK(n < 2, "Boolean value of Tensor with more than one value is ambiguous");
Scalar localScalar = self.item();
if (localScalar.isFloatingPoint()) {
return localScalar.to<double>() != 0;
@ -132,18 +128,17 @@ bool is_nonzero(const Tensor& self) {
} else if (localScalar.isBoolean()) {
return localScalar.to<bool>();
}
AT_ERROR("expected non-Tensor backend scalar");
TORCH_INTERNAL_ASSERT(false, "Expected non-Tensor backend scalar");
}
Tensor where(const Tensor& condition, const Tensor& self, const Tensor& other) {
TORCH_CHECK(condition.device() == self.device() && self.device() == other.device(),
"expected condition, x and y to be on the same device, but condition is on ",
"Expected condition, x and y to be on the same device, but condition is on ",
condition.device(), " and x and y are on ", self.device(), " and ", other.device(),
" respectively");
if (condition.scalar_type() != ScalarType::Byte && condition.scalar_type() != ScalarType::Bool) {
AT_ERROR("Expected condition to have ScalarType Byte, but got ScalarType ",
toString(condition.scalar_type()));
}
TORCH_CHECK(condition.scalar_type() == ScalarType::Byte || condition.scalar_type() == ScalarType::Bool,
"Expected condition to have ScalarType Byte, but got ScalarType ",
toString(condition.scalar_type()));
Tensor b_condition, b_self, b_other;
std::tie(b_condition, b_self, b_other) = expand_outplace(condition, self, other, "where");
return at::_s_where(b_condition, b_self, b_other);

View File

@ -1 +1 @@
bool value of Tensor with no values is ambiguous
Boolean value of Tensor with no values is ambiguous

View File

@ -1 +1 @@
bool value of Tensor with more than one value is ambiguous
Boolean value of Tensor with more than one value is ambiguous

View File

@ -306,7 +306,7 @@ class TestList(JitTestCase):
test_invalid_list_equality,
(),
RuntimeError,
"bool value of Tensor")
"Boolean value of Tensor")
def test_list_sort(self):
template = dedent('''
@ -338,7 +338,7 @@ class TestList(JitTestCase):
return x
self.checkScriptRaisesRegex(test_fail, (([torch.zeros([2]), torch.zeros([2])],)), Exception,
"bool value of Tensor with more than one value")
"Boolean value of Tensor with more than one value")
@torch.jit.script
def test_mutation():

View File

@ -6889,7 +6889,7 @@ a")
self.checkScript(test_bool_cast_tensor, (torch.tensor(inp_val),))
self.checkScriptRaisesRegex(test_bool_cast_tensor, (torch.tensor([1, 1]),), Exception,
"bool value of Tensor with more than one value")
"Boolean value of Tensor with more than one value")
def test_not_cast(x):
if not x:

View File

@ -2214,7 +2214,7 @@ class TestSparse(TestCase):
self.assertFalse(torch.sparse_coo_tensor(([0, 0],), (0., 0.), (1,)).is_nonzero())
self.assertFalse(torch.sparse_coo_tensor(([0, 0],), (-1., 1.), (1,)).is_nonzero())
self.assertTrue(torch.sparse_coo_tensor(torch.zeros(0, 1), 12.3, []).is_nonzero()) # scalar sparse tensor
with self.assertRaisesRegex(RuntimeError, "bool value of Tensor with no values is ambiguous"):
with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with no values is ambiguous"):
torch.sparse_coo_tensor(([0, 1],), self.value_empty(2, 0), (4, 0)).is_nonzero()
def test_allow_tensor_metadata_change(self):

View File

@ -410,7 +410,7 @@ class _TestTorchMixin(object):
x = torch.rand(16, device=devices[1])
y = torch.rand(16, device=devices[2])
with self.assertRaisesRegex(RuntimeError,
"expected condition, x and y to be on the same device"):
"Expected condition, x and y to be on the same device"):
torch.where(condition, x, y)
def test_where_bool_tensor(self):
@ -4721,12 +4721,12 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
self.assertExpectedRaisesInline(
RuntimeError,
lambda: torch.tensor([]).is_nonzero(),
"bool value of Tensor with no values is ambiguous",
"Boolean value of Tensor with no values is ambiguous",
)
self.assertExpectedRaisesInline(
RuntimeError,
lambda: torch.tensor([0, 0]).is_nonzero(),
"bool value of Tensor with more than one value is ambiguous",
"Boolean value of Tensor with more than one value is ambiguous",
)
self.assertFalse(torch.tensor(0).is_nonzero())
self.assertTrue(torch.tensor(1).is_nonzero())