From b27c3558a43a28d9185e1e953cf90128cb13e870 Mon Sep 17 00:00:00 2001 From: xuanqi Date: Thu, 15 Jun 2023 16:27:30 -0700 Subject: [PATCH] [RFC]: Create aten native op for constrain_range (#103346) At high current implementation of constrains functions (constrain_as_**) will raise exception for the following code snippets: ``` def f(x): a = x.item() constrain_as_size(a, 4, 7) return torch.empty((a, 4)) inp = torch.tensor([5]) ep = torch._export.export(f, (inp,)) ``` The reason is because current constrain logic is: 1) Purely python so it won't survive AOT export (the full node is gone after AOT export since AOT export only maintains aten level op). 2) Utilize side effect to add range constraints for traced symbol's shape env ([code](https://github.com/pytorch/pytorch/blob/9591e52880fb09cb6753d97606e6efb956075f5b/torch/fx/experimental/symbolic_shapes.py#L370-L372)). 3) If runtime assertion is turned on (by default). [`_AddRuntimeAssertionsForConstraintsPass`](https://github.com/pytorch/pytorch/blob/9591e52880fb09cb6753d97606e6efb956075f5b/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py#L98-L100) will try to append assertion node based on range constrains extracted from shape env of symbol during another interpretation round. 4). However, since 1), in the round of AOT export, range constraints logic won't run for symbols generated during this round. And later there is no range constrains information available for assertion round and caused issue. 5) As a result of above, it will failure at `torch.empty((a, 4))` (there is no constrains for `a` that it must be positive). The fix here is just to implement range constrain logic as a native aten op (CPU implementation as no-op) to make it be able to survive AOT export. **NOTE:** [Logic](https://github.com/pytorch/pytorch/blob/2d745b95d723641e575027bd4e2fff612f61cc8f/torch/fx/experimental/symbolic_shapes.py#L350-L365C15) within [`constrain_range`](https://github.com/pytorch/pytorch/blob/2d745b95d723641e575027bd4e2fff612f61cc8f/torch/fx/experimental/symbolic_shapes.py#LL313C74-L313C74) is split out as `constrain_range_int` to capture case when non `SymInt` is passed in and reused in the new `_constrain_range`. The reason is when non `SymInt` is provided: * If it directly calls `sym_constrain_range`, the C++ version will be called which will be no-op. * So in this case it calls `constrain_range_int` instead to be able to capture issue like user provides a input whose tensor's shape could be out of range during exporting, like the following for above code example: ``` ... inp = torch.tensor([10]) ep = torch._export.export(f, (inp,)) # immediately raise error ``` Differential Revision: [D46734204](https://our.internmc.facebook.com/intern/diff/D46734204) Pull Request resolved: https://github.com/pytorch/pytorch/pull/103346 Approved by: https://github.com/tugsbayasgalan --- aten/src/ATen/native/Constraints.cpp | 15 +++++++++ aten/src/ATen/native/cuda/Constraints.cu | 15 +++++++++ aten/src/ATen/native/native_functions.yaml | 5 +++ build_variables.bzl | 1 + test/dynamo/test_export.py | 39 ++++++++++++++++++++++ test/export/test_export.py | 1 - test/export/test_passes.py | 6 ++-- torch/_export/constraints.py | 20 +++++++++-- torch/_meta_registrations.py | 7 +++- torch/_subclasses/fake_tensor.py | 2 +- torch/fx/experimental/symbolic_shapes.py | 39 ++++++++++++++-------- torch/fx/node.py | 1 + torch/overrides.py | 1 + torchgen/native_function_generation.py | 1 + 14 files changed, 130 insertions(+), 23 deletions(-) create mode 100644 aten/src/ATen/native/Constraints.cpp create mode 100644 aten/src/ATen/native/cuda/Constraints.cu diff --git a/aten/src/ATen/native/Constraints.cpp b/aten/src/ATen/native/Constraints.cpp new file mode 100644 index 00000000000..a015ca19c9a --- /dev/null +++ b/aten/src/ATen/native/Constraints.cpp @@ -0,0 +1,15 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS + +#include +#include + +namespace at { +namespace native { + +void sym_constrain_range_cpu( + const Scalar& size, + c10::optional min = c10::nullopt, + c10::optional max = c10::nullopt) {} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/Constraints.cu b/aten/src/ATen/native/cuda/Constraints.cu new file mode 100644 index 00000000000..2f4bd4b67f4 --- /dev/null +++ b/aten/src/ATen/native/cuda/Constraints.cu @@ -0,0 +1,15 @@ +#define TORCH_ASSERT_NO_OPERATORS + +#include +#include + +namespace at { +namespace native { + +void sym_constrain_range_cuda( + const Scalar& size, + c10::optional min = c10::nullopt, + c10::optional max = c10::nullopt) {} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index d23abffea99..5f151ae2f26 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -177,6 +177,11 @@ - func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None) -> () +- func: sym_constrain_range(Scalar size, int? min=None, int? max=None) -> () + dispatch: + CPU: sym_constrain_range_cpu + CUDA: sym_constrain_range_cuda + - func: refine_names(Tensor(a) self, Dimname[] names) -> Tensor(a) variants: method diff --git a/build_variables.bzl b/build_variables.bzl index dc6cdd31fd3..0d7571db9a1 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -1244,6 +1244,7 @@ aten_native_source_non_codegen_list = [ "aten/src/ATen/native/ChanelShuffle.cpp", "aten/src/ATen/native/Col2Im.cpp", "aten/src/ATen/native/PadNd.cpp", + "aten/src/ATen/native/Constraints.cpp", "aten/src/ATen/native/Convolution.cpp", "aten/src/ATen/native/ConvolutionMM2d.cpp", "aten/src/ATen/native/ConvolutionMM3d.cpp", diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 00eb62b2bb5..2b43a2790dd 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -22,6 +22,7 @@ from torch._export import dynamic_dim from torch._export.constraints import constrain_as_size, constrain_as_value from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.symbolic_shapes import ConstraintViolationError +from torch.testing import FileCheck from torch.testing._internal import common_utils from torch.testing._internal.common_utils import skipIfRocm @@ -2408,6 +2409,44 @@ def forward(self, x): buffer = io.BytesIO() torch.save(gm, buffer) + def test_export_with_inline_constraints(self): + def f(x): + a = x.item() + constrain_as_size(a, 4, 7) + return torch.empty((a, 4)) + + with self.assertRaisesRegex( + torch._dynamo.exc.UserError, r"Invalid value 20 for range \[4:7\]" + ) as cm: + torch._export.export(f, (torch.tensor([20]),)) + + ep = torch._export.export(f, (torch.tensor([5]),)) + self.assertEqual(ep(torch.tensor([6])).shape, (6, 4)) + + FileCheck().check_count( + "torch.ops.aten.sym_constrain_range.default", 1, exactly=True + ).run(ep.graph_module.code) + + with self.assertRaisesRegex( + RuntimeError, + r"_local_scalar_dense_default is outside of inline constraint \[4, 7\]", + ) as cm: + ep(torch.tensor([30])) + + def test_export_with_inline_constraints_complex(self): + def f(x): + a = x.item() + constrain_as_size(a, 4, 7) + empty = torch.empty((a, 4)) + + return torch.cat((empty.transpose(0, 1), torch.zeros(6, a)), 0) + + ep = torch._export.export(f, (torch.tensor([6]),)) + self.assertEqual(ep(torch.tensor([5])).shape, (10, 5)) + FileCheck().check_count( + "torch.ops.aten.sym_constrain_range.default", 1, exactly=True + ).run(ep.graph_module.code) + def test_export_dynamic_dim_not_1(self): x = torch.randn([1, 1, 1]) diff --git a/test/export/test_export.py b/test/export/test_export.py index b1fcd3e75dd..13a34388c2a 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -68,7 +68,6 @@ class TestExperimentalExport(TestCase): @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support") class TestDynamismExpression(TestCase): - @unittest.expectedFailure def test_export_inline_constraints(self): def f(x): diff --git a/test/export/test_passes.py b/test/export/test_passes.py index 77559333c97..73a7636a6b6 100644 --- a/test/export/test_passes.py +++ b/test/export/test_passes.py @@ -268,9 +268,9 @@ class TestPasses(TestCase): num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg) num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default) - # 2 constraints for b - self.assertEqual(num_assert, 2) - self.assertEqual(num_scalar_tensor, 2) + # TODO: De-duplicate assertions for same symbol. + self.assertEqual(num_assert, 4) + self.assertEqual(num_scalar_tensor, 4) with self.assertRaisesRegex(RuntimeError, r"nonzero_default.shape\[0\] is outside of inline constraint \[3, 5\]."): ep(torch.tensor([1, 1, 0, 0, 0])) diff --git a/torch/_export/constraints.py b/torch/_export/constraints.py index ddfca1813f4..6f8d382cccc 100644 --- a/torch/_export/constraints.py +++ b/torch/_export/constraints.py @@ -1,9 +1,19 @@ -from typing import Optional +from typing import Optional, Callable, Union +import torch +from torch import SymInt, SymFloat from torch._dynamo import allow_in_graph -from torch.fx.experimental.symbolic_shapes import constrain_range +from torch.fx.experimental.symbolic_shapes import constrain_range_int from torch.utils._sympy.value_ranges import ValueRangeError +# `Scalar` type used in native_functions.ymal will be translated to `Union[Number, _complex]` +# could cause type error during since `SymInt` or `SymFloat` will be used. +# Here manually specify the type explicitly. +sym_constrain_range: Callable[ + [Union[int, float, SymInt, SymFloat], Optional[int], Optional[int]], + None, +] = torch.sym_constrain_range # type: ignore[assignment] + # TODO: we want to hide this min/max stuff under some abstraction similar to # DynamicDim @@ -13,7 +23,11 @@ def constrain_as_value(symbol, min: Optional[int] = None, max: Optional[int] = N Add min/max constraint on the intermediate symbol at tracing time """ - constrain_range(symbol, min=min, max=max) + if not isinstance(symbol, SymInt): + constrain_range_int(symbol, min=min, max=max) + else: + sym_constrain_range(symbol, min, max) + return symbol diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 98b5fb36a40..e3c4f1b3a2c 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -30,7 +30,7 @@ from torch._prims_common.wrappers import ( out_wrapper, ) from torch._refs import _broadcast_shapes - +from torch.fx.experimental.symbolic_shapes import constrain_range from torch.utils._pytree import tree_map @@ -311,6 +311,11 @@ def assert_async_meta(val, assert_msg): return +@register_meta(aten.sym_constrain_range.default) +def sym_constrain_range(size, min, max): + constrain_range(size, min=min, max=max) + + # From aten/src/ATen/native/LinearAlgebraUtils.h def squareCheckInputs(self: Tensor, f_name: str): assert ( diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 3b7dd842378..3ba091a86f7 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1488,7 +1488,7 @@ class FakeTensorMode(TorchDispatchMode): nonlocal common_device nonlocal has_scalar_only_inputs - if common_device is None: + if isinstance(e, torch.Tensor) and common_device is None: ( common_device, has_scalar_only_inputs, diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index e79a4f4d1d9..b7a1c581f31 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -348,20 +348,7 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None): if max is None: max = sympy.oo if not isinstance(a, SymInt): - if not (min <= a <= max): - raise ValueRangeError(f"Invalid value {a} for range [{min}:{max}]") - - if ( - (fake_mode := detect_fake_mode()) is not None and - getattr(fake_mode, "shape_env", None) is not None - ): - # If we are tracing with a fake mode then add this integer to the - # shape_env's var_to_range - sym_integer = sympy.Integer(a) - shape_env = fake_mode.shape_env - _constrain_symbol_range(shape_env, sym_integer, min, max) - shape_env.var_to_stack[sym_integer] = TracingContext(fake_mode).extract_stack() - + constrain_range_int(a, min=min, max=max) return if isinstance(a.node.expr, sympy.Integer): @@ -376,6 +363,30 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None): # SymInt). _constrain_symbol_range(a.node.shape_env, a.node.expr, min, max) +def constrain_range_int(a, *, min, max): + """ + Constrain range on concrete int value. + This can happens for the following scenarios: + - Eager mode execution and real int value is provided. + - During tracing the traced symbol is resolved as a static integer (see + PR #101655 for more details). + """ + + assert not isinstance(a, SymInt) + if not (min <= a <= max): + raise ValueRangeError(f"Invalid value {a} for range [{min}:{max}]") + + if ( + (fake_mode := detect_fake_mode()) is not None and + getattr(fake_mode, "shape_env", None) is not None + ): + # If we are tracing with a fake mode then add this integer to the + # shape_env's var_to_range + sym_integer = sympy.Integer(a) + shape_env = fake_mode.shape_env + _constrain_symbol_range(shape_env, sym_integer, min, max) + shape_env.var_to_stack[sym_integer] = TracingContext(fake_mode).extract_stack() + def constrain_unify(a, b): """ Given two SymInts, constrain them so that they must be equal. NB: diff --git a/torch/fx/node.py b/torch/fx/node.py index afb7d2917e8..199670db154 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -35,6 +35,7 @@ _side_effectful_functions: Set[Callable] = { torch._assert_async, _ops.aten._assert_async.msg, _ops.aten.copy_.default, + _ops.aten.sym_constrain_range.default, _ops.profiler._record_function_enter, _ops.profiler._record_function_enter_new, _ops.profiler._record_function_exit} diff --git a/torch/overrides.py b/torch/overrides.py index c51b35e2082..2b3cb0baa0c 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -185,6 +185,7 @@ def get_ignored_functions() -> Set[Callable]: torch.sym_max, torch.sym_min, torch.sym_not, + torch.sym_constrain_range, torch.tril_indices, torch.triu_indices, torch.vander, diff --git a/torchgen/native_function_generation.py b/torchgen/native_function_generation.py index 8b7429e9826..653d7b29562 100644 --- a/torchgen/native_function_generation.py +++ b/torchgen/native_function_generation.py @@ -73,6 +73,7 @@ FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [ "qscheme", # returns a QScheme "record_stream", # no return "sparse_dim", # returns an int + "sym_constrain_range", # no return "_nested_tensor_storage_offsets", # returns a vector of ints "_chunk_grad_outputs_efficient_attention", # returns a bool "_fused_sdp_choice", # returns an int