pytorch/test/export/test_export.py
xuanqi b27c3558a4 [RFC]: Create aten native op for constrain_range (#103346)
At a high level, the current implementation of the constraint functions (`constrain_as_*`) will raise an exception for the following code snippet:
```
def f(x):
    a = x.item()
    constrain_as_size(a, 4, 7)
    return torch.empty((a, 4))

inp = torch.tensor([5])
ep = torch._export.export(f, (inp,))
```

The reason is that the current constraint logic:
1) Is purely Python, so it won't survive AOT export (the whole node is gone after AOT export, since AOT export only keeps aten-level ops).
2) Relies on a side effect to add range constraints to the traced symbol's shape env ([code](9591e52880/torch/fx/experimental/symbolic_shapes.py (L370-L372))); a toy version of this side effect is sketched after this list.
3) If runtime assertions are turned on (the default), [`_AddRuntimeAssertionsForConstraintsPass`](9591e52880/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py (L98-L100)) will try to append assertion nodes based on the range constraints extracted from each symbol's shape env during another round of interpretation.
4) However, because of 1), the range-constraint logic does not run for symbols generated during the AOT export round, so no range-constraint information is available for the assertion round, which causes the issue.
5) As a result of the above, export fails at `torch.empty((a, 4))` (there is no constraint saying that `a` must be positive).
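
To make point 2 concrete, here is a minimal, self-contained toy sketch of that side effect; `ToyShapeEnv` and `ValueRange` are illustrative stand-ins for the real `ShapeEnv` and its `var_to_range` bookkeeping, not the actual implementation:
```
import sympy
from dataclasses import dataclass, field

@dataclass
class ValueRange:
    # Illustrative stand-in for the real value-range record.
    lower: int
    upper: int

@dataclass
class ToyShapeEnv:
    # Mirrors the idea of ShapeEnv.var_to_range: symbol -> allowed range.
    var_to_range: dict = field(default_factory=dict)

def toy_constrain_symbol_range(shape_env, s, min_val, max_val):
    # Pure-Python side effect: the range is recorded only in the shape env,
    # never as a node in the traced graph, so it cannot survive AOT export.
    shape_env.var_to_range[s] = ValueRange(min_val, max_val)

env = ToyShapeEnv()
toy_constrain_symbol_range(env, sympy.Symbol("s0"), 4, 7)
```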

The fix here is simply to implement the range-constraint logic as a native aten op (with a no-op CPU implementation) so that it survives AOT export.
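
For illustration, here is a rough sketch of what "a native op whose CPU implementation is a no-op" means, using the Python `torch.library` API in a hypothetical `mylib` namespace (the actual op added by this PR is registered in the aten namespace in C++, and its schema may differ):
```
import torch
from torch.library import Library

lib = Library("mylib", "DEF")  # hypothetical namespace, for illustration only
lib.define("sym_constrain_range(Scalar size, *, int? min=None, int? max=None) -> ()")

def sym_constrain_range_cpu(size, *, min=None, max=None):
    # No-op at runtime: the op exists only so that a call node is recorded
    # in the graph and therefore survives AOT export at the aten level.
    pass

lib.impl("sym_constrain_range", sym_constrain_range_cpu, "CPU")
```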

**NOTE:**
The [logic](2d745b95d7/torch/fx/experimental/symbolic_shapes.py (L350-L365C15)) within [`constrain_range`](2d745b95d7/torch/fx/experimental/symbolic_shapes.py (LL313C74-L313C74)) is split out as `constrain_range_int` to handle the case where a non-`SymInt` is passed in, and it is reused in the new `_constrain_range`. The reason is that when a non-`SymInt` is provided:
* If `sym_constrain_range` is called directly, the C++ version will be called, which is a no-op.
* So in this case `constrain_range_int` is called instead, so that we can still catch issues such as the user providing an input tensor whose shape is out of range during export, e.g. the following variation of the code example above:
```
...
inp = torch.tensor([10])
ep = torch._export.export(f, (inp,))  # immediately raises an error
```
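
A condensed sketch of the dispatch this NOTE describes (the real functions live in `torch/fx/experimental/symbolic_shapes.py`; exact signatures and error messages may differ):
```
import torch

def constrain_range_int(a, *, min=None, max=None):
    # Eager range check for plain Python ints, so an out-of-range example
    # input fails immediately at export time.
    if (min is not None and a < min) or (max is not None and a > max):
        raise ValueError(f"Invalid value {a} for range [{min}:{max}]")

def _constrain_range(a, *, min=None, max=None):
    if isinstance(a, torch.SymInt):
        # Traced path: emit the new aten op so the constraint survives
        # AOT export (it is a no-op when actually executed on CPU).
        torch.ops.aten.sym_constrain_range(a, min=min, max=max)
    else:
        # Concrete int path: the C++ op would be a no-op here, so fall
        # back to the eager Python check.
        constrain_range_int(a, min=min, max=max)
```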

Differential Revision: [D46734204](https://our.internmc.facebook.com/intern/diff/D46734204)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103346
Approved by: https://github.com/tugsbayasgalan
2023-06-16 14:55:40 +00:00


# Owner(s): ["module: dynamo"]
import unittest

import torch
import torch._dynamo as torchdynamo
from torch._export import export, dynamic_dim
from torch._export.trace import do_not_use_experimental_export
from torch._export.constraints import constrain_as_size, constrain_as_value
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import run_tests, TestCase


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestExperimentalExport(TestCase):
    @unittest.skip("TypeError: <lambda>() missing 1 required positional argument")
    def test_export_simple_model_with_attr(self):
        class Foo(torch.nn.Module):
            def __init__(self, float_val):
                super().__init__()
                self.float_val = float_val

            def forward(self, x):
                y = x + self.float_val
                return y.cos()

        inp = (torch.ones(6, 4, requires_grad=True),)
        mod = Foo(0.5)
        exported_program = do_not_use_experimental_export(mod, inp)
        self.assertEqual(exported_program.fw_module(*inp)[0], mod(*inp))

    def test_export_simple_model(self):
        class Foo(torch.nn.Module):
            def __init__(self, float_val):
                super().__init__()
                self.float_val = float_val

            def forward(self, x):
                return x.cos()

        inp = (torch.ones(6, 4, requires_grad=True),)
        mod = Foo(0.5)
        exported_program = do_not_use_experimental_export(mod, inp)
        self.assertEqual(exported_program.fw_module(*inp)[0], mod(*inp))

    @unittest.skip("TypeError: <lambda>() missing 1 required positional argument")
    def test_export_simple_model_buffer_mutation(self):
        class Foo(torch.nn.Module):
            def __init__(self, float_val):
                super().__init__()
                self.register_buffer("buffer1", torch.ones(6, 1))

            def forward(self, x):
                self.buffer1.add_(2)
                return x.cos() + self.buffer1.sin()

        inp = (torch.ones(6, 4, requires_grad=True),)
        mod = Foo(0.5)
        exported_program = do_not_use_experimental_export(mod, inp)
        mutated_buffer, output = exported_program.fw_module(*inp)
        # TODO (tmanlaibaatar) enable this once we figure out
        # how to do buffer mutation
        # self.assertEqual(mutated_buffer.sum().item(), 30)
        self.assertEqual(output, mod(*inp))


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestDynamismExpression(TestCase):
    def test_export_inline_constraints(self):
        def f(x):
            b = x.item()
            constrain_as_size(b, min=2, max=5)
            return torch.full((b, 1), 1)

        inp = (torch.tensor([3]),)
        ref = f(*inp)

        gm = export(f, inp)
        res = gm(*inp)
        self.assertTrue(torchdynamo.utils.same(ref, res))

        gm = make_fx(f, tracing_mode="symbolic")(*inp)
        res = gm(*inp)
        self.assertTrue(torchdynamo.utils.same(ref, res))

    def test_export_constraints_error(self):
        def invalid_size(x):
            b = x.item()
            constrain_as_size(b, min=0, max=5)
            return torch.full((b, 1), 1)

        inp = (torch.tensor([3]),)
        with self.assertRaisesRegex(torchdynamo.exc.UserError, "Unable to set min size"):
            export(invalid_size, inp)

        def invalid_input_conflict_with_inline_constraints(x):
            b = x.item()
            constrain_as_size(b, min=2, max=5)
            return torch.full((b, 1), 1)

        inp = (torch.tensor([6]),)
        with self.assertRaisesRegex(torchdynamo.exc.UserError, "Invalid value 6 for range"):
            export(invalid_input_conflict_with_inline_constraints, inp)

        def invalid_input_conflict_with_input_constraints(x):
            return x + 1

        inp = torch.zeros([3])
        inp_constraints = [
            dynamic_dim(inp, 0) > 5,
        ]
        with self.assertRaisesRegex(torchdynamo.exc.UserError, "not in range"):
            export(
                invalid_input_conflict_with_input_constraints,
                (inp,),
                constraints=inp_constraints,
            )

        def conflicting_constraints(x):
            b = x.item()
            constrain_as_size(b, min=2, max=3)
            constrain_as_size(b, min=4, max=5)
            return torch.full((b, 1), 1)

        inp = (torch.tensor([3]),)
        with self.assertRaisesRegex(torchdynamo.exc.UserError, "Invalid ranges"):
            export(conflicting_constraints, inp)

    def test_export_assume_static_by_default(self):
        def branch_on_shape(x: torch.Tensor):
            if x.shape[0] == 4:
                return x + 1
            else:
                return x

        inp = (torch.rand(4, 5),)
        # Being able to export means the shape is preserved as static
        export(branch_on_shape, inp)


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestExport(TestCase):
    def test_basic(self):
        def f(x, y):
            return x[0] + y

        inp = ([torch.ones(1, 3)], torch.ones(1, 3))
        exported_program = export(f, inp)
        self.assertTrue(torch.allclose(exported_program(*inp), f(*inp)))

    def test_raise_user_error_when_guard_on_data_dependent_operation(self):
        def fn_ddo(x):
            y = x.nonzero()
            z = y.shape[0]
            if z > 2:
                return x.cos()
            else:
                return x.sin()

        with self.assertRaisesRegex(
            torchdynamo.exc.UserError,
            "trying to get a value out of symbolic int"
        ):
            _ = export(fn_ddo, (torch.tensor([2, 3, 5]),), constraints=None)

    def test_if_functional(self):
        def foo(x):
            z = x + 4
            z.add_(4)
            y = z.view(x.shape)
            return x.cos() + y.cos()

        gm = export(foo, (torch.tensor([2, 3, 5]),), constraints=None)

        view_count = 0
        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.add_.Tensor:
                # No more inplace mutation
                self.assertNotEqual(
                    node.target,
                    torch.ops.aten.add_.Tensor,
                    "There shouldn't be any inplace mutation node in the graph."
                )
            if node.op == "call_function" and node.target == torch.ops.aten.view.default:
                view_count += 1

        # There should be nonzero view nodes in the graph
        self.assertTrue(view_count > 0)

    def test_export_mod_constraints(self):
        class BasicDynamiShapeModel(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x.view(x.shape[0] - 1, -1)

        m = BasicDynamiShapeModel()
        a = torch.randn(3, 4)
        constraints = [3 <= dynamic_dim(a, 0), dynamic_dim(a, 1)]
        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            (
                "Some dynamic dimensions need to be specialized because "
                "the constraints inferred for them are too complex to specify"
                ".*\n.*\\[0\\], which was marked dynamic, must be specialized to 3"
                ".*\n.*\\[1\\], which was marked dynamic, must be specialized to 4"
            ),
        ):
            torch._export.export(m, (a,), constraints=constraints)
        em = torch._export.export(m, (a,))
        x = torch.randn(3, 5)
        with self.assertRaisesRegex(RuntimeError, "\\[1\\] is specialized at 4"):
            em(x)

    def test_export_constrain_static(self):
        def f(x, y):
            b = x.item()
            constrain_as_size(b, min=2, max=5)
            c = y.dim()
            constrain_as_value(c, min=1, max=3)
            z = y[0:c]
            return torch.empty((b, y.shape[0])), z

        x = torch.tensor([3])
        y = torch.randn([8, 8, 6])
        example_inputs = (x, y)
        constraints = [dynamic_dim(y, 0) >= 6, dynamic_dim(y, 0) <= 10]
        with self.assertRaisesRegex(
            torchdynamo.exc.UserError,
            "It appears that you're trying to set a constraint " +
            "on a value which we evaluated to have a static value of 3. "
        ):
            export(f, example_inputs, constraints)

    def test_not_correct_dim(self):
        def f(x):
            return x.cos()

        def g(x):
            return x + 4

        inp_for_f = torch.tensor(5)
        with self.assertRaisesRegex(torchdynamo.exc.UserError, "Cannot mark 0-dimension tensors to be dynamic"):
            constraints = [dynamic_dim(inp_for_f, 0)]

        inp_for_f_mul_dim = torch.ones(5, 5)
        with self.assertRaisesRegex(
            torchdynamo.exc.UserError,
            "Expected the dimension passed to dynamic_dim to be in the range \\[0:1\\]"
        ):
            constraints = [dynamic_dim(inp_for_f_mul_dim, 2)]

        inp_for_g = 4
        with self.assertRaisesRegex(torchdynamo.exc.UserError, "Expected tensor as input to dynamic_dim"):
            constraints = [dynamic_dim(inp_for_g, 0)]


if __name__ == '__main__':
    run_tests()