Add TORCH_CHECK_TENSOR_ALL (#89097)

`TORCH_CHECK_TENSOR_ALL(cond, ...)` is a wrapper around `TORCH_CHECK` that allows the condition argument to be a tensor, batched or unbatched. `cond` can be a boolean tensor of any size. If any element of `cond` is `False`, `TORCH_CHECK_TENSOR_ALL` raises an error. (A `cond` with `cond.numel() == 0` does not raise: the underlying reduction is `Tensor.all()`, which is vacuously true for zero-element tensors.)
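As an illustration (not part of this diff), a minimal sketch of how a C++ operator might use the new macro; the op name, the error message, and the assumed `<ATen/TensorUtils.h>` include are hypothetical here:

```cpp
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>  // assumed home of TORCH_CHECK_TENSOR_ALL

// Hypothetical op: raise a readable error if any element is negative.
at::Tensor nonneg_sqrt(const at::Tensor& self) {
  // `self >= 0` is a bool tensor; the macro reduces it via _is_all_true().
  TORCH_CHECK_TENSOR_ALL(self >= 0, "nonneg_sqrt: expected all elements >= 0");
  return self.sqrt();
}
```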

Part of #72948
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89097
Approved by: https://github.com/zou3519
Kurt Mohler 2023-01-19 21:04:09 +00:00 committed by PyTorch MergeBot
parent 25e530083e
commit 647b8f8e3e
10 changed files with 145 additions and 0 deletions

View File

@@ -10,6 +10,9 @@
// These functions are NOT in Utils.h, because this file has a dep on Tensor.h
#define TORCH_CHECK_TENSOR_ALL(cond, ...) \
  TORCH_CHECK((cond)._is_all_true().item<bool>(), __VA_ARGS__);
namespace at {
// The following are utility functions for checking that arguments
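For reference, per the definition above, a call such as `TORCH_CHECK_TENSOR_ALL(t, "some message")` expands to:

```cpp
TORCH_CHECK((t)._is_all_true().item<bool>(), "some message");
```

`_is_all_true()` returns a 0-dim bool tensor, and `item<bool>()` copies that single value to the host so the ordinary `TORCH_CHECK` can branch on it; the batched `_is_all_true` kernel added below keeps this expansion valid when `cond` is a vmapped tensor.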

View File

@@ -252,6 +252,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
  OP_DECOMPOSE(swapdims);
  OP_DECOMPOSE(take_along_dim);
  OP_DECOMPOSE(tensordot);
  OP_DECOMPOSE(_test_check_tensor);
  OP_DECOMPOSE(tile);
  OP_DECOMPOSE2(trapezoid, x);
  OP_DECOMPOSE2(trapezoid, dx);
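Here `OP_DECOMPOSE` tells functorch not to look for a dedicated batch rule for `_test_check_tensor` and instead re-run its composite implementation on batched inputs, which bottoms out in the `_is_all_true` batch rule added below. Roughly, the registration the macro performs amounts to the following (a simplified sketch; the real macro is defined elsewhere in this file and is not part of this diff):

```cpp
// Simplified sketch of OP_DECOMPOSE (assumption, not shown in this diff):
// register the op's plain native implementation as its batched kernel.
#define OP_DECOMPOSE(op) m.impl(#op, static_cast<decltype(&ATEN_FN(op))>(native::op));
```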

View File

@@ -20,6 +20,11 @@ Tensor sum_decomp(
  return at::sum(self, range(0, self.dim()), false, dtype);
}

std::tuple<Tensor, optional<int64_t>> _is_all_true_batch_rule(
    const Tensor& self, optional<int64_t> self_bdim) {
  return std::make_tuple(at::_is_all_true(self), nullopt);
}

Tensor mean_decomp(
    const Tensor& self, optional<ScalarType> dtype) {
  return at::mean(self, range(0, self.dim()), false, dtype);
@@ -502,5 +507,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  VMAP_SUPPORT(aminmax, aminmax_batching_rule);
  VMAP_SUPPORT(_log_softmax_backward_data, _log_softmax_backward_batch_rule);
  VMAP_SUPPORT(_softmax_backward_data, _softmax_backward_batch_rule);
  VMAP_SUPPORT(_is_all_true, _is_all_true_batch_rule);
}
}}
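The batch rule returns `nullopt` for the output batch dimension: `_is_all_true` reduces over everything, batch dimension included, so vmapping it yields a single unbatched scalar, and one `False` element in any slice fails the whole check. A small ATen-only sketch of that reduction semantics (illustrative, not part of this diff):

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Three boolean values with one false: [false, true, true]
  at::Tensor cond = at::arange(3).ge(1);
  at::Tensor r = at::_is_all_true(cond);  // 0-dim bool tensor
  std::cout << r.item<bool>() << "\n";    // prints 0: any false sinks the check
}
```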

View File

@@ -2035,6 +2035,10 @@ Tensor all(const Tensor& self, Dimname dim, bool keepdim) {
Tensor& all_out(const Tensor &self, Dimname dim, bool keepdim, Tensor& result) {
  reportNYIDimnameOverload("all");
}

Tensor _is_all_true(const Tensor& self) {
  TORCH_INTERNAL_ASSERT(self.scalar_type() == at::kBool);
  return self.all();
}

Tensor logcumsumexp(const Tensor& self, Dimname dim) {
  return at::logcumsumexp(self, dimname_to_position(self, dim));
}
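`_is_all_true` is simply `Tensor::all()` behind a debug assertion that the input is boolean; giving the reduction its own name is what provides vmap a hook for the batch rule above. Note also that `all()` is vacuously true for zero-element tensors, which is why an empty `cond` passes the check (illustrative sketch, not part of this diff):

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor empty = at::empty({0}, at::kBool);               // zero elements
  std::cout << at::_is_all_true(empty).item<bool>() << "\n";  // prints 1: vacuously true
}
```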

View File

@@ -106,6 +106,11 @@ Tensor _test_autograd_multiple_dispatch_view(const Tensor &self) {
  return self.view(-1);
}

Tensor _test_check_tensor(const Tensor& self) {
  TORCH_CHECK_TENSOR_ALL(self, "Test message for TORCH_CHECK_TENSOR_ALL");
  return self.clone();
}

} // namespace native
namespace functionalization {

View File

@@ -617,6 +617,15 @@
- func: affine_grid_generator_backward(Tensor grad, int[] size, bool align_corners) -> Tensor
  variants: function

- func: _is_all_true(Tensor self) -> Tensor
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: _is_all_true

# Note: this function is only for testing.
- func: _test_check_tensor(Tensor self) -> Tensor
  variants: function

- func: all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor
  device_check: NoCheck # TensorIterator
  structured_delegate: all.out
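`variants: function, method` generates both call forms sketched below; the macro relies on the method form. `_test_check_tensor` declares no `dispatch` entry, so it should fall under the default `CompositeImplicitAutograd` treatment, which is also what lets the `OP_DECOMPOSE` registration above reuse its plain implementation under vmap. A quick sketch of the two generated spellings:

```cpp
#include <ATen/ATen.h>

int main() {
  at::Tensor t = at::ones({4}, at::kBool);
  at::Tensor a = at::_is_all_true(t);  // function variant
  at::Tensor b = t._is_all_true();     // method variant, as used by the macro
}
```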

View File

@@ -16,6 +16,7 @@ import functools
import itertools
import warnings
import unittest
import random
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_cuda import with_tf32_off
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
@@ -4902,6 +4903,83 @@ class TestTransformFailure(TestCase):
        with self.assertRaisesRegex(RuntimeError, "autograd.Function"):
            transform(input)


class TestVmapDeviceType(Namespace.TestVmapBase):
    def _vmap_test(self, *args, **kwargs):
        return _vmap_test(self, *args, **kwargs)

    def test__is_all_true(self, device):
        def test():
            def f(x, *, expected_result):
                result = torch.ops.aten._is_all_true(x)
                self.assertFalse(torch._C._functorch.is_batchedtensor(result))
                self.assertEqual(result.shape, torch.Size([]))
                self.assertEqual(result.item(), expected_result)
                return result

            x = torch.rand(10, device=device)
            vmap(f)(x >= 0, expected_result=True)
            vmap(f)(x < 0, expected_result=False)

            x[random.choice(range(10))] *= -1
            vmap(f)(x >= 0, expected_result=False)
            vmap(f)(x < 0, expected_result=False)

            x = -torch.rand(10, device=device)
            vmap(f)(x > 0, expected_result=False)
            vmap(f)(x <= 0, expected_result=True)

        check_vmap_fallback(self, test, torch._is_all_true)

    def test_check_tensor(self, device):
        def test():
            test_sizes = [
                (1,),
                (10,),
                (1, 1),
                (1, 10),
                (10, 1),
                (10, 10),
                (1, 1, 1),
                (10, 1, 1),
                (1, 10, 1),
                (10, 10, 10),
            ]

            def check_gte_0(t):
                return torch._test_check_tensor(t >= 0)

            error_message = "Test message for TORCH_CHECK_TENSOR_ALL"

            for size in test_sizes:
                t_all_gte_0 = torch.rand(size, device=device)
                t_all_lt_0 = t_all_gte_0 - 1

                vmap(check_gte_0)(t_all_gte_0)

                if len(size) >= 2:
                    vmap(vmap(check_gte_0))(t_all_gte_0)

                with self.assertRaisesRegex(RuntimeError, error_message):
                    vmap(check_gte_0)(t_all_lt_0)

                if len(size) >= 2:
                    with self.assertRaisesRegex(RuntimeError, error_message):
                        vmap(vmap(check_gte_0))(t_all_lt_0)

                if t_all_gte_0.numel() > 1:
                    t_all_gte_0_but_one = t_all_gte_0.clone()
                    idx = (random.choice(range(dim_size)) for dim_size in size)
                    t_all_gte_0_but_one[(..., *idx)] = -1

                    with self.assertRaisesRegex(RuntimeError, error_message):
                        vmap(check_gte_0)(t_all_gte_0_but_one)

                    if len(size) >= 2:
                        with self.assertRaisesRegex(RuntimeError, error_message):
                            vmap(vmap(check_gte_0))(t_all_gte_0_but_one)

        check_vmap_fallback(self, test, torch._test_check_tensor)
only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestVmapOperatorsOpInfo, globals(), only_for=only_for)
@@ -4912,6 +4990,7 @@
)
instantiate_device_type_tests(TestTransformFailure, globals(), only_for=only_for)
instantiate_device_type_tests(TestRandomness, globals(), only_for=only_for)
instantiate_device_type_tests(TestVmapDeviceType, globals(), only_for=only_for)
if __name__ == '__main__':
    run_tests()

View File

@@ -740,6 +740,40 @@ class TestTorchDeviceType(TestCase):
        self.assertEqual((), torch.nn.functional.multi_margin_loss(input, target, reduction='mean').shape)
        self.assertEqual((), torch.nn.functional.multi_margin_loss(input, target, reduction='sum').shape)

    # Test that `TORCH_CHECK_TENSOR_ALL` raises errors that propagate from C++ to Python
    def test_check_tensor(self, device):
        test_sizes = [
            (),
            (1,),
            (10,),
            (1, 1),
            (1, 10),
            (10, 1),
            (10, 10),
            (1, 1, 1),
            (10, 1, 1),
            (1, 10, 1),
            (10, 10, 10),
        ]
        for size in test_sizes:
            t_all_true = torch.ones(size, dtype=torch.bool, device=device)
            t_all_false = torch.zeros(size, dtype=torch.bool, device=device)

            # Should not raise error
            torch._test_check_tensor(t_all_true)

            with self.assertRaisesRegex(RuntimeError, "Test message for TORCH_CHECK_TENSOR_ALL"):
                torch._test_check_tensor(t_all_false)

            if t_all_true.numel() > 1:
                t_all_true_but_one = t_all_true.clone()
                # Choose a random element to set to False
                idx = (random.choice(range(dim_size)) for dim_size in size)
                t_all_true_but_one[(..., *idx)] = False

                with self.assertRaisesRegex(RuntimeError, "Test message for TORCH_CHECK_TENSOR_ALL"):
                    torch._test_check_tensor(t_all_true_but_one)

    # Uses mismatched arange out size to trigger a warning
    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    @unittest.skipIf(TEST_WITH_CROSSREF, "crossref perturbs line numbering")

View File

@@ -278,6 +278,9 @@
- name: any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor
  output_differentiability: [False]

- name: _is_all_true(Tensor self) -> Tensor
  self: non_differentiable

- name: all(Tensor self) -> Tensor
  output_differentiability: [False]

View File

@@ -292,6 +292,7 @@ def get_ignored_functions() -> Set[Callable]:
        Tensor._conj_physical,
        Tensor._neg_view,
        Tensor._is_zerotensor,
        Tensor._is_all_true,
        Tensor._addmm_activation,
        Tensor.to_padded_tensor,
    }