mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add TORCH_CHECK_TENSOR_ALL (#89097)
`TORCH_CHECK_TENSOR_ALL(cond, ...)` is a wrapper around `TORCH_CHECK` which allows the condition argument to be a tensor, batched or unbatched. `cond` can be a boolean tensor of any size. If any element is False, or if `cond.numel() == 0`, then `TORCH_CHECK_TENSOR_ALL` raises an error Part of #72948 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89097 Approved by: https://github.com/zou3519
This commit is contained in:
parent
25e530083e
commit
647b8f8e3e
|
|
@ -10,6 +10,9 @@
|
|||
|
||||
// These functions are NOT in Utils.h, because this file has a dep on Tensor.h
|
||||
|
||||
#define TORCH_CHECK_TENSOR_ALL(cond, ...) \
|
||||
TORCH_CHECK((cond)._is_all_true().item<bool>(), __VA_ARGS__);
|
||||
|
||||
namespace at {
|
||||
|
||||
// The following are utility functions for checking that arguments
|
||||
|
|
|
|||
|
|
@ -252,6 +252,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
|
|||
OP_DECOMPOSE(swapdims);
|
||||
OP_DECOMPOSE(take_along_dim);
|
||||
OP_DECOMPOSE(tensordot);
|
||||
OP_DECOMPOSE(_test_check_tensor);
|
||||
OP_DECOMPOSE(tile);
|
||||
OP_DECOMPOSE2(trapezoid, x);
|
||||
OP_DECOMPOSE2(trapezoid, dx);
|
||||
|
|
|
|||
|
|
@ -20,6 +20,11 @@ Tensor sum_decomp(
|
|||
return at::sum(self, range(0, self.dim()), false, dtype);
|
||||
}
|
||||
|
||||
std::tuple<Tensor, optional<int64_t>> _is_all_true_batch_rule(
|
||||
const Tensor& self, optional<int64_t> self_bdim) {
|
||||
return std::make_tuple(at::_is_all_true(self), nullopt);
|
||||
}
|
||||
|
||||
Tensor mean_decomp(
|
||||
const Tensor& self, optional<ScalarType> dtype) {
|
||||
return at::mean(self, range(0, self.dim()), false, dtype);
|
||||
|
|
@ -502,5 +507,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
|
|||
VMAP_SUPPORT(aminmax, aminmax_batching_rule);
|
||||
VMAP_SUPPORT(_log_softmax_backward_data, _log_softmax_backward_batch_rule);
|
||||
VMAP_SUPPORT(_softmax_backward_data, _softmax_backward_batch_rule);
|
||||
VMAP_SUPPORT(_is_all_true, _is_all_true_batch_rule);
|
||||
}
|
||||
}}
|
||||
|
|
|
|||
|
|
@ -2035,6 +2035,10 @@ Tensor all(const Tensor& self, Dimname dim, bool keepdim) {
|
|||
Tensor& all_out(const Tensor &self, Dimname dim, bool keepdim, Tensor& result) {
|
||||
reportNYIDimnameOverload("all");
|
||||
}
|
||||
Tensor _is_all_true(const Tensor& self) {
|
||||
TORCH_INTERNAL_ASSERT(self.scalar_type() == at::kBool);
|
||||
return self.all();
|
||||
}
|
||||
Tensor logcumsumexp(const Tensor& self, Dimname dim) {
|
||||
return at::logcumsumexp(self, dimname_to_position(self, dim));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -106,6 +106,11 @@ Tensor _test_autograd_multiple_dispatch_view(const Tensor &self) {
|
|||
return self.view(-1);
|
||||
}
|
||||
|
||||
Tensor _test_check_tensor(const Tensor& self) {
|
||||
TORCH_CHECK_TENSOR_ALL(self, "Test message for TORCH_CHECK_TENSOR_ALL");
|
||||
return self.clone();
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
|
||||
namespace functionalization {
|
||||
|
|
|
|||
|
|
@ -617,6 +617,15 @@
|
|||
- func: affine_grid_generator_backward(Tensor grad, int[] size, bool align_corners) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: _is_all_true(Tensor self) -> Tensor
|
||||
variants: function, method
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: _is_all_true
|
||||
|
||||
# Note: this function is only for testing.
|
||||
- func: _test_check_tensor(Tensor self) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor
|
||||
device_check: NoCheck # TensorIterator
|
||||
structured_delegate: all.out
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import functools
|
|||
import itertools
|
||||
import warnings
|
||||
import unittest
|
||||
import random
|
||||
from torch.testing._internal.common_methods_invocations import op_db
|
||||
from torch.testing._internal.common_cuda import with_tf32_off
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
|
||||
|
|
@ -4902,6 +4903,83 @@ class TestTransformFailure(TestCase):
|
|||
with self.assertRaisesRegex(RuntimeError, "autograd.Function"):
|
||||
transform(input)
|
||||
|
||||
class TestVmapDeviceType(Namespace.TestVmapBase):
|
||||
def _vmap_test(self, *args, **kwargs):
|
||||
return _vmap_test(self, *args, **kwargs)
|
||||
|
||||
def test__is_all_true(self, device):
|
||||
def test():
|
||||
def f(x, *, expected_result):
|
||||
result = torch.ops.aten._is_all_true(x)
|
||||
self.assertFalse(torch._C._functorch.is_batchedtensor(result))
|
||||
self.assertEqual(result.shape, torch.Size([]))
|
||||
self.assertEqual(result.item(), expected_result)
|
||||
return result
|
||||
|
||||
x = torch.rand(10, device=device)
|
||||
vmap(f)(x >= 0, expected_result=True)
|
||||
vmap(f)(x < 0, expected_result=False)
|
||||
|
||||
x[random.choice(range(10))] *= -1
|
||||
vmap(f)(x >= 0, expected_result=False)
|
||||
vmap(f)(x < 0, expected_result=False)
|
||||
|
||||
x = -torch.rand(10, device=device)
|
||||
vmap(f)(x > 0, expected_result=False)
|
||||
vmap(f)(x <= 0, expected_result=True)
|
||||
|
||||
check_vmap_fallback(self, test, torch._is_all_true)
|
||||
|
||||
def test_check_tensor(self, device):
|
||||
def test():
|
||||
test_sizes = [
|
||||
(1,),
|
||||
(10,),
|
||||
(1, 1),
|
||||
(1, 10),
|
||||
(10, 1),
|
||||
(10, 10),
|
||||
(1, 1, 1),
|
||||
(10, 1, 1),
|
||||
(1, 10, 1),
|
||||
(10, 10, 10),
|
||||
]
|
||||
|
||||
def check_gte_0(t):
|
||||
return torch._test_check_tensor(t >= 0)
|
||||
|
||||
error_message = "Test message for TORCH_CHECK_TENSOR_ALL"
|
||||
|
||||
for size in test_sizes:
|
||||
t_all_gte_0 = torch.rand(size, device=device)
|
||||
t_all_lt_0 = t_all_gte_0 - 1
|
||||
|
||||
vmap(check_gte_0)(t_all_gte_0)
|
||||
|
||||
if len(size) >= 2:
|
||||
vmap(vmap(check_gte_0))(t_all_gte_0)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, error_message):
|
||||
vmap(check_gte_0)(t_all_lt_0)
|
||||
|
||||
if len(size) >= 2:
|
||||
with self.assertRaisesRegex(RuntimeError, error_message):
|
||||
vmap(vmap(check_gte_0))(t_all_lt_0)
|
||||
|
||||
if t_all_gte_0.numel() > 1:
|
||||
t_all_gte_0_but_one = t_all_gte_0.clone()
|
||||
idx = (random.choice(range(dim_size)) for dim_size in size)
|
||||
t_all_gte_0_but_one[(..., *idx)] = -1
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, error_message):
|
||||
vmap(check_gte_0)(t_all_gte_0_but_one)
|
||||
|
||||
if len(size) >= 2:
|
||||
with self.assertRaisesRegex(RuntimeError, error_message):
|
||||
vmap(vmap(check_gte_0))(t_all_gte_0_but_one)
|
||||
|
||||
check_vmap_fallback(self, test, torch._test_check_tensor)
|
||||
|
||||
only_for = ("cpu", "cuda")
|
||||
instantiate_device_type_tests(TestVmapOperatorsOpInfo, globals(), only_for=only_for)
|
||||
|
||||
|
|
@ -4912,6 +4990,7 @@ instantiate_device_type_tests(
|
|||
)
|
||||
instantiate_device_type_tests(TestTransformFailure, globals(), only_for=only_for)
|
||||
instantiate_device_type_tests(TestRandomness, globals(), only_for=only_for)
|
||||
instantiate_device_type_tests(TestVmapDeviceType, globals(), only_for=only_for)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -740,6 +740,40 @@ class TestTorchDeviceType(TestCase):
|
|||
self.assertEqual((), torch.nn.functional.multi_margin_loss(input, target, reduction='mean').shape)
|
||||
self.assertEqual((), torch.nn.functional.multi_margin_loss(input, target, reduction='sum').shape)
|
||||
|
||||
# Test that `TORCH_CHECK_TENSOR_ALL` raises errors that propagate from C++ to Python
|
||||
def test_check_tensor(self, device):
|
||||
test_sizes = [
|
||||
(),
|
||||
(1,),
|
||||
(10,),
|
||||
(1, 1),
|
||||
(1, 10),
|
||||
(10, 1),
|
||||
(10, 10),
|
||||
(1, 1, 1),
|
||||
(10, 1, 1),
|
||||
(1, 10, 1),
|
||||
(10, 10, 10),
|
||||
]
|
||||
for size in test_sizes:
|
||||
t_all_true = torch.ones(size, dtype=torch.bool, device=device)
|
||||
t_all_false = torch.zeros(size, dtype=torch.bool, device=device)
|
||||
|
||||
# Should not raise error
|
||||
torch._test_check_tensor(t_all_true)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Test message for TORCH_CHECK_TENSOR_ALL"):
|
||||
torch._test_check_tensor(t_all_false)
|
||||
|
||||
if t_all_true.numel() > 1:
|
||||
t_all_true_but_one = t_all_true.clone()
|
||||
# Choose a random element to set to false
|
||||
idx = (random.choice(range(dim_size)) for dim_size in size)
|
||||
t_all_true_but_one[(..., *idx)] = False
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Test message for TORCH_CHECK_TENSOR_ALL"):
|
||||
torch._test_check_tensor(t_all_true_but_one)
|
||||
|
||||
# Uses mismatched arange out size to trigger a warning
|
||||
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
|
||||
@unittest.skipIf(TEST_WITH_CROSSREF, "crossref perturbs line numbering")
|
||||
|
|
|
|||
|
|
@ -278,6 +278,9 @@
|
|||
- name: any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor
|
||||
output_differentiability: [False]
|
||||
|
||||
- name: _is_all_true(Tensor self) -> Tensor
|
||||
self: non_differentiable
|
||||
|
||||
- name: all(Tensor self) -> Tensor
|
||||
output_differentiability: [False]
|
||||
|
||||
|
|
|
|||
|
|
@ -292,6 +292,7 @@ def get_ignored_functions() -> Set[Callable]:
|
|||
Tensor._conj_physical,
|
||||
Tensor._neg_view,
|
||||
Tensor._is_zerotensor,
|
||||
Tensor._is_all_true,
|
||||
Tensor._addmm_activation,
|
||||
Tensor.to_padded_tensor,
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user