diff --git a/aten/src/ATen/native/Constraints.cpp b/aten/src/ATen/native/Constraints.cpp
new file mode 100644
index 00000000000..a015ca19c9a
--- /dev/null
+++ b/aten/src/ATen/native/Constraints.cpp
@@ -0,0 +1,15 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+
+#include
+#include
+
+namespace at {
+namespace native {
+
+void sym_constrain_range_cpu(
+    const Scalar& size,
+    c10::optional<int64_t> min = c10::nullopt,
+    c10::optional<int64_t> max = c10::nullopt) {}
+
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/cuda/Constraints.cu b/aten/src/ATen/native/cuda/Constraints.cu
new file mode 100644
index 00000000000..2f4bd4b67f4
--- /dev/null
+++ b/aten/src/ATen/native/cuda/Constraints.cu
@@ -0,0 +1,15 @@
+#define TORCH_ASSERT_NO_OPERATORS
+
+#include
+#include
+
+namespace at {
+namespace native {
+
+void sym_constrain_range_cuda(
+    const Scalar& size,
+    c10::optional<int64_t> min = c10::nullopt,
+    c10::optional<int64_t> max = c10::nullopt) {}
+
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index d23abffea99..5f151ae2f26 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -177,6 +177,11 @@
 - func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None) -> ()
 
+- func: sym_constrain_range(Scalar size, int? min=None, int? max=None) -> ()
+  dispatch:
+    CPU: sym_constrain_range_cpu
+    CUDA: sym_constrain_range_cuda
+
 - func: refine_names(Tensor(a) self, Dimname[] names) -> Tensor(a)
   variants: method
diff --git a/build_variables.bzl b/build_variables.bzl
index dc6cdd31fd3..0d7571db9a1 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -1244,6 +1244,7 @@ aten_native_source_non_codegen_list = [
     "aten/src/ATen/native/ChanelShuffle.cpp",
     "aten/src/ATen/native/Col2Im.cpp",
     "aten/src/ATen/native/PadNd.cpp",
+    "aten/src/ATen/native/Constraints.cpp",
     "aten/src/ATen/native/Convolution.cpp",
     "aten/src/ATen/native/ConvolutionMM2d.cpp",
     "aten/src/ATen/native/ConvolutionMM3d.cpp",
diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py
index 00eb62b2bb5..2b43a2790dd 100644
--- a/test/dynamo/test_export.py
+++ b/test/dynamo/test_export.py
@@ -22,6 +22,7 @@ from torch._export import dynamic_dim
 from torch._export.constraints import constrain_as_size, constrain_as_value
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.experimental.symbolic_shapes import ConstraintViolationError
+from torch.testing import FileCheck
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import skipIfRocm
 
@@ -2408,6 +2409,44 @@ def forward(self, x):
         buffer = io.BytesIO()
         torch.save(gm, buffer)
 
+    def test_export_with_inline_constraints(self):
+        def f(x):
+            a = x.item()
+            constrain_as_size(a, 4, 7)
+            return torch.empty((a, 4))
+
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.UserError, r"Invalid value 20 for range \[4:7\]"
+        ) as cm:
+            torch._export.export(f, (torch.tensor([20]),))
+
+        ep = torch._export.export(f, (torch.tensor([5]),))
+        self.assertEqual(ep(torch.tensor([6])).shape, (6, 4))
+
+        FileCheck().check_count(
+            "torch.ops.aten.sym_constrain_range.default", 1, exactly=True
+        ).run(ep.graph_module.code)
+
+        with self.assertRaisesRegex(
+            RuntimeError,
+            r"_local_scalar_dense_default is outside of inline constraint \[4, 7\]",
+        ) as cm:
+            ep(torch.tensor([30]))
+
+    def test_export_with_inline_constraints_complex(self):
+        def f(x):
+            a = x.item()
+            constrain_as_size(a, 4, 7)
+            empty = torch.empty((a, 4))
+
+            return torch.cat((empty.transpose(0, 1), torch.zeros(6, a)), 0)
+
+        ep = torch._export.export(f, (torch.tensor([6]),))
+        self.assertEqual(ep(torch.tensor([5])).shape, (10, 5))
+        FileCheck().check_count(
+            "torch.ops.aten.sym_constrain_range.default", 1, exactly=True
+        ).run(ep.graph_module.code)
+
     def test_export_dynamic_dim_not_1(self):
         x = torch.randn([1, 1, 1])
diff --git a/test/export/test_export.py b/test/export/test_export.py
index b1fcd3e75dd..13a34388c2a 100644
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -68,7 +68,6 @@ class TestExperimentalExport(TestCase):
 
 
 @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
 class TestDynamismExpression(TestCase):
-    @unittest.expectedFailure
     def test_export_inline_constraints(self):
         def f(x):
diff --git a/test/export/test_passes.py b/test/export/test_passes.py
index 77559333c97..73a7636a6b6 100644
--- a/test/export/test_passes.py
+++ b/test/export/test_passes.py
@@ -268,9 +268,9 @@ class TestPasses(TestCase):
         num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg)
         num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default)
 
-        # 2 constraints for b
-        self.assertEqual(num_assert, 2)
-        self.assertEqual(num_scalar_tensor, 2)
+        # TODO: De-duplicate assertions for same symbol.
+        self.assertEqual(num_assert, 4)
+        self.assertEqual(num_scalar_tensor, 4)
 
         with self.assertRaisesRegex(RuntimeError, r"nonzero_default.shape\[0\] is outside of inline constraint \[3, 5\]."):
             ep(torch.tensor([1, 1, 0, 0, 0]))
diff --git a/torch/_export/constraints.py b/torch/_export/constraints.py
index ddfca1813f4..6f8d382cccc 100644
--- a/torch/_export/constraints.py
+++ b/torch/_export/constraints.py
@@ -1,9 +1,19 @@
-from typing import Optional
+from typing import Optional, Callable, Union
 
+import torch
+from torch import SymInt, SymFloat
 from torch._dynamo import allow_in_graph
-from torch.fx.experimental.symbolic_shapes import constrain_range
+from torch.fx.experimental.symbolic_shapes import constrain_range_int
 from torch.utils._sympy.value_ranges import ValueRangeError
 
+# The `Scalar` type used in native_functions.yaml is translated to `Union[Number, _complex]`,
+# which can cause type errors during type checking since `SymInt` or `SymFloat` will be passed in.
+# Here we manually specify the type explicitly.
+sym_constrain_range: Callable[
+    [Union[int, float, SymInt, SymFloat], Optional[int], Optional[int]],
+    None,
+] = torch.sym_constrain_range  # type: ignore[assignment]
+
 
 # TODO: we want to hide this min/max stuff under some abstraction similar to
 # DynamicDim
@@ -13,7 +23,11 @@ def constrain_as_value(symbol, min: Optional[int] = None, max: Optional[int] = N
     """
     Add min/max constraint on the intermediate symbol at tracing time
     """
 
-    constrain_range(symbol, min=min, max=max)
+    if not isinstance(symbol, SymInt):
+        constrain_range_int(symbol, min=min, max=max)
+    else:
+        sym_constrain_range(symbol, min, max)
+
     return symbol
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 98b5fb36a40..e3c4f1b3a2c 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -30,7 +30,7 @@ from torch._prims_common.wrappers import (
     out_wrapper,
 )
 from torch._refs import _broadcast_shapes
-
+from torch.fx.experimental.symbolic_shapes import constrain_range
 from torch.utils._pytree import tree_map
 
 
@@ -311,6 +311,11 @@ def assert_async_meta(val, assert_msg):
     return
 
 
+@register_meta(aten.sym_constrain_range.default)
+def sym_constrain_range(size, min, max):
+    constrain_range(size, min=min, max=max)
+
+
 # From aten/src/ATen/native/LinearAlgebraUtils.h
 def squareCheckInputs(self: Tensor, f_name: str):
     assert (
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 3b7dd842378..3ba091a86f7 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -1488,7 +1488,7 @@ class FakeTensorMode(TorchDispatchMode):
             nonlocal common_device
             nonlocal has_scalar_only_inputs
 
-            if common_device is None:
+            if isinstance(e, torch.Tensor) and common_device is None:
                 (
                     common_device,
                     has_scalar_only_inputs,
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index e79a4f4d1d9..b7a1c581f31 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -348,20 +348,7 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
     if max is None:
         max = sympy.oo
     if not isinstance(a, SymInt):
-        if not (min <= a <= max):
-            raise ValueRangeError(f"Invalid value {a} for range [{min}:{max}]")
-
-        if (
-            (fake_mode := detect_fake_mode()) is not None and
-            getattr(fake_mode, "shape_env", None) is not None
-        ):
-            # If we are tracing with a fake mode then add this integer to the
-            # shape_env's var_to_range
-            sym_integer = sympy.Integer(a)
-            shape_env = fake_mode.shape_env
-            _constrain_symbol_range(shape_env, sym_integer, min, max)
-            shape_env.var_to_stack[sym_integer] = TracingContext(fake_mode).extract_stack()
-
+        constrain_range_int(a, min=min, max=max)
         return
 
     if isinstance(a.node.expr, sympy.Integer):
@@ -376,6 +363,30 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
         # SymInt).
     _constrain_symbol_range(a.node.shape_env, a.node.expr, min, max)
 
+def constrain_range_int(a, *, min, max):
+    """
+    Constrain the range of a concrete int value.
+    This can happen in the following scenarios:
+    - Eager mode execution, where a real int value is provided.
+    - During tracing, when the traced symbol is resolved to a static integer (see
+      PR #101655 for more details).
+    """
+
+    assert not isinstance(a, SymInt)
+    if not (min <= a <= max):
+        raise ValueRangeError(f"Invalid value {a} for range [{min}:{max}]")
+
+    if (
+        (fake_mode := detect_fake_mode()) is not None and
+        getattr(fake_mode, "shape_env", None) is not None
+    ):
+        # If we are tracing with a fake mode then add this integer to the
+        # shape_env's var_to_range
+        sym_integer = sympy.Integer(a)
+        shape_env = fake_mode.shape_env
+        _constrain_symbol_range(shape_env, sym_integer, min, max)
+        shape_env.var_to_stack[sym_integer] = TracingContext(fake_mode).extract_stack()
+
 def constrain_unify(a, b):
     """
     Given two SymInts, constrain them so that they must be equal. NB:
diff --git a/torch/fx/node.py b/torch/fx/node.py
index afb7d2917e8..199670db154 100644
--- a/torch/fx/node.py
+++ b/torch/fx/node.py
@@ -35,6 +35,7 @@ _side_effectful_functions: Set[Callable] = {
     torch._assert_async,
     _ops.aten._assert_async.msg,
     _ops.aten.copy_.default,
+    _ops.aten.sym_constrain_range.default,
     _ops.profiler._record_function_enter,
     _ops.profiler._record_function_enter_new,
     _ops.profiler._record_function_exit}
diff --git a/torch/overrides.py b/torch/overrides.py
index c51b35e2082..2b3cb0baa0c 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -185,6 +185,7 @@ def get_ignored_functions() -> Set[Callable]:
         torch.sym_max,
         torch.sym_min,
         torch.sym_not,
+        torch.sym_constrain_range,
         torch.tril_indices,
         torch.triu_indices,
         torch.vander,
diff --git a/torchgen/native_function_generation.py b/torchgen/native_function_generation.py
index 8b7429e9826..653d7b29562 100644
--- a/torchgen/native_function_generation.py
+++ b/torchgen/native_function_generation.py
@@ -73,6 +73,7 @@ FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
     "qscheme",  # returns a QScheme
     "record_stream",  # no return
     "sparse_dim",  # returns an int
+    "sym_constrain_range",  # no return
     "_nested_tensor_storage_offsets",  # returns a vector of ints
     "_chunk_grad_outputs_efficient_attention",  # returns a bool
     "_fused_sdp_choice",  # returns an int
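
For reference, a minimal usage sketch of the inline-constraint flow exercised by the new tests above (illustrative only, not part of the diff; it relies on the internal torch._export APIs as they exist at this commit and they may change):

import torch
from torch._export.constraints import constrain_as_size

def f(x):
    a = x.item()                # data-dependent value becomes an unbacked SymInt during export
    constrain_as_size(a, 4, 7)  # recorded as torch.ops.aten.sym_constrain_range.default in the graph
    return torch.empty((a, 4))

ep = torch._export.export(f, (torch.tensor([5]),))
print(ep(torch.tensor([6])).shape)   # torch.Size([6, 4])
# ep(torch.tensor([30]))             # raises: value is outside of inline constraint [4, 7]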