mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[RFC]: Create aten native op for constrain_range (#103346)
At high current implementation of constrains functions (constrain_as_**) will raise exception for the following code snippets:
```
def f(x):
a = x.item()
constrain_as_size(a, 4, 7)
return torch.empty((a, 4))
inp = torch.tensor([5])
ep = torch._export.export(f, (inp,))
```
The reason is because current constrain logic is:
1) Purely python so it won't survive AOT export (the full node is gone after AOT export since AOT export only maintains aten level op).
2) Utilize side effect to add range constraints for traced symbol's shape env ([code](9591e52880/torch/fx/experimental/symbolic_shapes.py (L370-L372))).
3) If runtime assertion is turned on (by default). [`_AddRuntimeAssertionsForConstraintsPass`](9591e52880/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py (L98-L100)) will try to append assertion node based on range constrains extracted from shape env of symbol during another interpretation round.
4). However, since 1), in the round of AOT export, range constraints logic won't run for symbols generated during this round. And later there is no range constrains information available for assertion round and caused issue.
5) As a result of above, it will failure at `torch.empty((a, 4))` (there is no constrains for `a` that it must be positive).
The fix here is just to implement range constrain logic as a native aten op (CPU implementation as no-op) to make it be able to survive AOT export.
**NOTE:**
[Logic](2d745b95d7/torch/fx/experimental/symbolic_shapes.py (L350-L365C15)) within [`constrain_range`](2d745b95d7/torch/fx/experimental/symbolic_shapes.py (LL313C74-L313C74)) is split out as `constrain_range_int` to capture case when non `SymInt` is passed in and reused in the new `_constrain_range`. The reason is when non `SymInt` is provided:
* If it directly calls `sym_constrain_range`, the C++ version will be called which will be no-op.
* So in this case it calls `constrain_range_int` instead to be able to capture issue like user provides a input whose tensor's shape could be out of range during exporting, like the following for above code example:
```
...
inp = torch.tensor([10])
ep = torch._export.export(f, (inp,)) # immediately raise error
```
Differential Revision: [D46734204](https://our.internmc.facebook.com/intern/diff/D46734204)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103346
Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
parent
df814484f4
commit
b27c3558a4
15
aten/src/ATen/native/Constraints.cpp
Normal file
15
aten/src/ATen/native/Constraints.cpp
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
void sym_constrain_range_cpu(
|
||||
const Scalar& size,
|
||||
c10::optional<int64_t> min = c10::nullopt,
|
||||
c10::optional<int64_t> max = c10::nullopt) {}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
15
aten/src/ATen/native/cuda/Constraints.cu
Normal file
15
aten/src/ATen/native/cuda/Constraints.cu
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
#define TORCH_ASSERT_NO_OPERATORS
|
||||
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
void sym_constrain_range_cuda(
|
||||
const Scalar& size,
|
||||
c10::optional<int64_t> min = c10::nullopt,
|
||||
c10::optional<int64_t> max = c10::nullopt) {}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
|
@ -177,6 +177,11 @@
|
|||
|
||||
- func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None) -> ()
|
||||
|
||||
- func: sym_constrain_range(Scalar size, int? min=None, int? max=None) -> ()
|
||||
dispatch:
|
||||
CPU: sym_constrain_range_cpu
|
||||
CUDA: sym_constrain_range_cuda
|
||||
|
||||
- func: refine_names(Tensor(a) self, Dimname[] names) -> Tensor(a)
|
||||
variants: method
|
||||
|
||||
|
|
|
|||
|
|
@ -1244,6 +1244,7 @@ aten_native_source_non_codegen_list = [
|
|||
"aten/src/ATen/native/ChanelShuffle.cpp",
|
||||
"aten/src/ATen/native/Col2Im.cpp",
|
||||
"aten/src/ATen/native/PadNd.cpp",
|
||||
"aten/src/ATen/native/Constraints.cpp",
|
||||
"aten/src/ATen/native/Convolution.cpp",
|
||||
"aten/src/ATen/native/ConvolutionMM2d.cpp",
|
||||
"aten/src/ATen/native/ConvolutionMM3d.cpp",
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from torch._export import dynamic_dim
|
|||
from torch._export.constraints import constrain_as_size, constrain_as_value
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.fx.experimental.symbolic_shapes import ConstraintViolationError
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal import common_utils
|
||||
from torch.testing._internal.common_utils import skipIfRocm
|
||||
|
||||
|
|
@ -2408,6 +2409,44 @@ def forward(self, x):
|
|||
buffer = io.BytesIO()
|
||||
torch.save(gm, buffer)
|
||||
|
||||
def test_export_with_inline_constraints(self):
|
||||
def f(x):
|
||||
a = x.item()
|
||||
constrain_as_size(a, 4, 7)
|
||||
return torch.empty((a, 4))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError, r"Invalid value 20 for range \[4:7\]"
|
||||
) as cm:
|
||||
torch._export.export(f, (torch.tensor([20]),))
|
||||
|
||||
ep = torch._export.export(f, (torch.tensor([5]),))
|
||||
self.assertEqual(ep(torch.tensor([6])).shape, (6, 4))
|
||||
|
||||
FileCheck().check_count(
|
||||
"torch.ops.aten.sym_constrain_range.default", 1, exactly=True
|
||||
).run(ep.graph_module.code)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"_local_scalar_dense_default is outside of inline constraint \[4, 7\]",
|
||||
) as cm:
|
||||
ep(torch.tensor([30]))
|
||||
|
||||
def test_export_with_inline_constraints_complex(self):
|
||||
def f(x):
|
||||
a = x.item()
|
||||
constrain_as_size(a, 4, 7)
|
||||
empty = torch.empty((a, 4))
|
||||
|
||||
return torch.cat((empty.transpose(0, 1), torch.zeros(6, a)), 0)
|
||||
|
||||
ep = torch._export.export(f, (torch.tensor([6]),))
|
||||
self.assertEqual(ep(torch.tensor([5])).shape, (10, 5))
|
||||
FileCheck().check_count(
|
||||
"torch.ops.aten.sym_constrain_range.default", 1, exactly=True
|
||||
).run(ep.graph_module.code)
|
||||
|
||||
def test_export_dynamic_dim_not_1(self):
|
||||
x = torch.randn([1, 1, 1])
|
||||
|
||||
|
|
|
|||
|
|
@ -68,7 +68,6 @@ class TestExperimentalExport(TestCase):
|
|||
|
||||
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
|
||||
class TestDynamismExpression(TestCase):
|
||||
@unittest.expectedFailure
|
||||
def test_export_inline_constraints(self):
|
||||
|
||||
def f(x):
|
||||
|
|
|
|||
|
|
@ -268,9 +268,9 @@ class TestPasses(TestCase):
|
|||
num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg)
|
||||
num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default)
|
||||
|
||||
# 2 constraints for b
|
||||
self.assertEqual(num_assert, 2)
|
||||
self.assertEqual(num_scalar_tensor, 2)
|
||||
# TODO: De-duplicate assertions for same symbol.
|
||||
self.assertEqual(num_assert, 4)
|
||||
self.assertEqual(num_scalar_tensor, 4)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"nonzero_default.shape\[0\] is outside of inline constraint \[3, 5\]."):
|
||||
ep(torch.tensor([1, 1, 0, 0, 0]))
|
||||
|
|
|
|||
|
|
@ -1,9 +1,19 @@
|
|||
from typing import Optional
|
||||
from typing import Optional, Callable, Union
|
||||
|
||||
import torch
|
||||
from torch import SymInt, SymFloat
|
||||
from torch._dynamo import allow_in_graph
|
||||
from torch.fx.experimental.symbolic_shapes import constrain_range
|
||||
from torch.fx.experimental.symbolic_shapes import constrain_range_int
|
||||
from torch.utils._sympy.value_ranges import ValueRangeError
|
||||
|
||||
# `Scalar` type used in native_functions.ymal will be translated to `Union[Number, _complex]`
|
||||
# could cause type error during since `SymInt` or `SymFloat` will be used.
|
||||
# Here manually specify the type explicitly.
|
||||
sym_constrain_range: Callable[
|
||||
[Union[int, float, SymInt, SymFloat], Optional[int], Optional[int]],
|
||||
None,
|
||||
] = torch.sym_constrain_range # type: ignore[assignment]
|
||||
|
||||
|
||||
# TODO: we want to hide this min/max stuff under some abstraction similar to
|
||||
# DynamicDim
|
||||
|
|
@ -13,7 +23,11 @@ def constrain_as_value(symbol, min: Optional[int] = None, max: Optional[int] = N
|
|||
Add min/max constraint on the intermediate symbol at tracing time
|
||||
"""
|
||||
|
||||
constrain_range(symbol, min=min, max=max)
|
||||
if not isinstance(symbol, SymInt):
|
||||
constrain_range_int(symbol, min=min, max=max)
|
||||
else:
|
||||
sym_constrain_range(symbol, min, max)
|
||||
|
||||
return symbol
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ from torch._prims_common.wrappers import (
|
|||
out_wrapper,
|
||||
)
|
||||
from torch._refs import _broadcast_shapes
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import constrain_range
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
|
||||
|
|
@ -311,6 +311,11 @@ def assert_async_meta(val, assert_msg):
|
|||
return
|
||||
|
||||
|
||||
@register_meta(aten.sym_constrain_range.default)
|
||||
def sym_constrain_range(size, min, max):
|
||||
constrain_range(size, min=min, max=max)
|
||||
|
||||
|
||||
# From aten/src/ATen/native/LinearAlgebraUtils.h
|
||||
def squareCheckInputs(self: Tensor, f_name: str):
|
||||
assert (
|
||||
|
|
|
|||
|
|
@ -1488,7 +1488,7 @@ class FakeTensorMode(TorchDispatchMode):
|
|||
nonlocal common_device
|
||||
nonlocal has_scalar_only_inputs
|
||||
|
||||
if common_device is None:
|
||||
if isinstance(e, torch.Tensor) and common_device is None:
|
||||
(
|
||||
common_device,
|
||||
has_scalar_only_inputs,
|
||||
|
|
|
|||
|
|
@ -348,20 +348,7 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
|
|||
if max is None:
|
||||
max = sympy.oo
|
||||
if not isinstance(a, SymInt):
|
||||
if not (min <= a <= max):
|
||||
raise ValueRangeError(f"Invalid value {a} for range [{min}:{max}]")
|
||||
|
||||
if (
|
||||
(fake_mode := detect_fake_mode()) is not None and
|
||||
getattr(fake_mode, "shape_env", None) is not None
|
||||
):
|
||||
# If we are tracing with a fake mode then add this integer to the
|
||||
# shape_env's var_to_range
|
||||
sym_integer = sympy.Integer(a)
|
||||
shape_env = fake_mode.shape_env
|
||||
_constrain_symbol_range(shape_env, sym_integer, min, max)
|
||||
shape_env.var_to_stack[sym_integer] = TracingContext(fake_mode).extract_stack()
|
||||
|
||||
constrain_range_int(a, min=min, max=max)
|
||||
return
|
||||
|
||||
if isinstance(a.node.expr, sympy.Integer):
|
||||
|
|
@ -376,6 +363,30 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
|
|||
# SymInt).
|
||||
_constrain_symbol_range(a.node.shape_env, a.node.expr, min, max)
|
||||
|
||||
def constrain_range_int(a, *, min, max):
|
||||
"""
|
||||
Constrain range on concrete int value.
|
||||
This can happens for the following scenarios:
|
||||
- Eager mode execution and real int value is provided.
|
||||
- During tracing the traced symbol is resolved as a static integer (see
|
||||
PR #101655 for more details).
|
||||
"""
|
||||
|
||||
assert not isinstance(a, SymInt)
|
||||
if not (min <= a <= max):
|
||||
raise ValueRangeError(f"Invalid value {a} for range [{min}:{max}]")
|
||||
|
||||
if (
|
||||
(fake_mode := detect_fake_mode()) is not None and
|
||||
getattr(fake_mode, "shape_env", None) is not None
|
||||
):
|
||||
# If we are tracing with a fake mode then add this integer to the
|
||||
# shape_env's var_to_range
|
||||
sym_integer = sympy.Integer(a)
|
||||
shape_env = fake_mode.shape_env
|
||||
_constrain_symbol_range(shape_env, sym_integer, min, max)
|
||||
shape_env.var_to_stack[sym_integer] = TracingContext(fake_mode).extract_stack()
|
||||
|
||||
def constrain_unify(a, b):
|
||||
"""
|
||||
Given two SymInts, constrain them so that they must be equal. NB:
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ _side_effectful_functions: Set[Callable] = {
|
|||
torch._assert_async,
|
||||
_ops.aten._assert_async.msg,
|
||||
_ops.aten.copy_.default,
|
||||
_ops.aten.sym_constrain_range.default,
|
||||
_ops.profiler._record_function_enter,
|
||||
_ops.profiler._record_function_enter_new,
|
||||
_ops.profiler._record_function_exit}
|
||||
|
|
|
|||
|
|
@ -185,6 +185,7 @@ def get_ignored_functions() -> Set[Callable]:
|
|||
torch.sym_max,
|
||||
torch.sym_min,
|
||||
torch.sym_not,
|
||||
torch.sym_constrain_range,
|
||||
torch.tril_indices,
|
||||
torch.triu_indices,
|
||||
torch.vander,
|
||||
|
|
|
|||
|
|
@ -73,6 +73,7 @@ FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
|
|||
"qscheme", # returns a QScheme
|
||||
"record_stream", # no return
|
||||
"sparse_dim", # returns an int
|
||||
"sym_constrain_range", # no return
|
||||
"_nested_tensor_storage_offsets", # returns a vector of ints
|
||||
"_chunk_grad_outputs_efficient_attention", # returns a bool
|
||||
"_fused_sdp_choice", # returns an int
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user