diff --git a/aten/src/ATen/native/ComparisonUtils.cpp b/aten/src/ATen/native/ComparisonUtils.cpp
index 196444509c5..4019cf2ff9b 100644
--- a/aten/src/ATen/native/ComparisonUtils.cpp
+++ b/aten/src/ATen/native/ComparisonUtils.cpp
@@ -19,15 +19,27 @@ void _assert_match(const O& original, const C& compared, const std::string& name
     if (!equal) {
       std::stringstream msg;
       msg << "Tensor " << name << " mismatch!";
-      AT_ASSERT(equal, msg.str());
+      if (!equal) {
+        throw std::runtime_error(msg.str());
+      }
     }
   }
 }

-void _assert_tensor_metadata(at::Tensor const& tensor, at::OptionalIntArrayRef sizes, at::OptionalIntArrayRef strides, std::optional<ScalarType> dtype) {
+void _assert_tensor_metadata_meta_symint(at::Tensor const& tensor, at::OptionalSymIntArrayRef sizes, at::OptionalSymIntArrayRef strides, std::optional<ScalarType> dtype, std::optional<Device> device, std::optional<Layout> layout) {
+  _assert_match(tensor.sym_sizes(), sizes, "sizes");
+  _assert_match(tensor.sym_strides(), strides, "strides");
+  _assert_match(tensor.dtype(), dtype, "dtype");
+  _assert_match(tensor.device(), device, "device");
+  _assert_match(tensor.layout(), layout, "layout");
+}
+
+void _assert_tensor_metadata(at::Tensor const& tensor, at::OptionalIntArrayRef sizes, at::OptionalIntArrayRef strides, std::optional<ScalarType> dtype, std::optional<Device> device, std::optional<Layout> layout) {
   _assert_match(tensor.sizes(), sizes, "sizes");
   _assert_match(tensor.strides(), strides, "strides");
   _assert_match(tensor.dtype(), dtype, "dtype");
+  _assert_match(tensor.device(), device, "device");
+  _assert_match(tensor.layout(), layout, "layout");
 }

 }
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index eba6b72004b..646b6aa497f 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -187,7 +187,10 @@
   dispatch:
     CPU: _functional_assert_async_msg_cpu

-- func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None) -> ()
+- func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None, *, Device? device=None, Layout? layout=None) -> ()
+  dispatch:
+    CompositeExplicitAutograd: _assert_tensor_metadata
+    Meta: _assert_tensor_metadata_meta_symint

 - func: _print(str s) -> ()
   dispatch:
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index 946e3104f56..85066e34112 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -30,6 +30,7 @@ aten::_amp_update_scale.out
 aten::_amp_update_scale_
 aten::_assert_async
 aten::_assert_async.msg
+aten::_assert_tensor_metadata
 aten::_batch_norm_no_update.out
 aten::_batch_norm_with_update.out
 aten::_cdist_backward
diff --git a/test/test_comparison_utils.py b/test/test_comparison_utils.py
index 6c5c65d1a0c..a4ebd806035 100644
--- a/test/test_comparison_utils.py
+++ b/test/test_comparison_utils.py
@@ -1,6 +1,8 @@
 #!/usr/bin/env python3
 # Owner(s): ["module: internals"]

+import unittest
+
 import torch
 from torch.testing._internal.common_utils import run_tests, TestCase

@@ -32,6 +34,19 @@ class TestComparisonUtils(TestCase):
         with self.assertRaises(RuntimeError):
             torch._assert_tensor_metadata(t, [3], [1], torch.float)

+    @unittest.skipIf(not torch.cuda.is_available(), "Requires cuda")
+    def test_assert_device(self):
+        t = torch.tensor([0.5], device="cpu")
+
+        with self.assertRaises(RuntimeError):
+            torch._assert_tensor_metadata(t, device="cuda")
+
+    def test_assert_layout(self):
+        t = torch.tensor([0.5])
+
+        with self.assertRaises(RuntimeError):
+            torch._assert_tensor_metadata(t, layout=torch.sparse_coo)
+

 if __name__ == "__main__":
     run_tests()
diff --git a/torch/fx/node.py b/torch/fx/node.py
index 469b6340384..50e96760eac 100644
--- a/torch/fx/node.py
+++ b/torch/fx/node.py
@@ -80,6 +80,7 @@ _side_effectful_functions: Set[Callable] = {
     torch._assert_async,
     _ops.aten._assert_async.msg,
     _ops.aten._assert_scalar.default,
+    _ops.aten._assert_tensor_metadata.default,
     _ops.aten.sym_constrain_range.default,
     _ops.aten.sym_constrain_range_for_size.default,
     _ops.profiler._record_function_enter,
diff --git a/torchgen/native_function_generation.py b/torchgen/native_function_generation.py
index 1ae4599407c..b73bd444736 100644
--- a/torchgen/native_function_generation.py
+++ b/torchgen/native_function_generation.py
@@ -53,6 +53,7 @@ MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
 FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
     "_assert_async",  # no return
     "_assert_async.msg",  # no return
+    "_assert_tensor_metadata",  # no return
     "_cslt_sparse_mm_search",  # returns an int
     "_assert_scalar",  # no return
     "_dimI",  # returns an int
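
For context, a minimal sketch of how the extended op behaves from the Python side, per the updated schema above: `device` and `layout` are keyword-only, omitted fields are not checked, and a mismatch now surfaces as a Python `RuntimeError` (the `std::runtime_error` thrown in `_assert_match`). The exact message text shown in the comment is illustrative, derived from the C++ `"Tensor " << name << " mismatch!"` format.

```python
import torch

t = torch.randn(2, 3)  # contiguous float32 CPU tensor

# All supplied fields match the tensor's metadata, so nothing is raised.
torch._assert_tensor_metadata(
    t, [2, 3], [3, 1], torch.float32,
    device=torch.device("cpu"), layout=torch.strided,
)

# Fields left as None are skipped, so a dtype-only call still passes.
torch._assert_tensor_metadata(t, dtype=torch.float32)

# A mismatched field raises RuntimeError rather than failing an AT_ASSERT.
try:
    torch._assert_tensor_metadata(t, layout=torch.sparse_coo)
except RuntimeError as e:
    print(e)  # e.g. "Tensor layout mismatch!"
```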