pytorch/test/export/test_export.py
Avik Chaudhuri 59ee6cd864 fix soundness bug with unsupported constraints (#102897)
We do not raise constraint violations for complex binary conditions, such as conditions involving `%`. Moreover, while these constraints are discovered by our solver, the solver does not inject new constraint violations. This can result in cases where export passes, appropriate assertions are not added, and we get runtime crashes.

Now, when the solver discovers constraints that are too complex, we force-specialize the involved dimensions and raise a constraint violation when such dimensions are marked dynamic. This forces the user to remove the dynamic marking, and causes the appropriate specialization assertions to be added.

Differential Revision: [D46415786](https://our.internmc.facebook.com/intern/diff/D46415786/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102897
Approved by: https://github.com/tugsbayasgalan
2023-06-10 01:59:55 +00:00

264 lines
9.1 KiB
Python

# Owner(s): ["module: dynamo"]
import unittest
import torch
import torch._dynamo as torchdynamo
from torch._export import export, dynamic_dim
from torch._export.trace import do_not_use_experimental_export
from torch._export.constraints import constrain_as_size, constrain_as_value
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import run_tests, TestCase
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
class TestExperimentalExport(TestCase):
    """Tests for ``do_not_use_experimental_export``.

    Every test exports a small ``torch.nn.Module`` and compares what the
    exported program's ``fw_module`` returns against calling the module
    directly.  The whole class is skipped when
    ``torchdynamo.is_dynamo_supported()`` is false.
    """

    @unittest.skip("TypeError: <lambda>() missing 1 required positional argument")
    def test_export_simple_model_with_attr(self):
        # Module whose forward reads a plain Python float stored on ``self``.
        class Foo(torch.nn.Module):
            def __init__(self, float_val):
                super().__init__()
                self.float_val = float_val

            def forward(self, x):
                y = x + self.float_val
                return y.cos()

        inp = (torch.ones(6, 4, requires_grad=True),)
        mod = Foo(0.5)

        exported_program = do_not_use_experimental_export(mod, inp)
        # Element 0 of the fw_module result is compared with the eager output.
        self.assertEqual(exported_program.fw_module(*inp)[0], mod(*inp))

    def test_export_simple_model(self):
        # Same module shape as above, but forward ignores ``float_val``.
        class Foo(torch.nn.Module):
            def __init__(self, float_val):
                super().__init__()
                self.float_val = float_val

            def forward(self, x):
                return x.cos()

        inp = (torch.ones(6, 4, requires_grad=True),)
        mod = Foo(0.5)

        exported_program = do_not_use_experimental_export(mod, inp)
        # Element 0 of the fw_module result is compared with the eager output.
        self.assertEqual(exported_program.fw_module(*inp)[0], mod(*inp))

    @unittest.skip("TypeError: <lambda>() missing 1 required positional argument")
    def test_export_simple_model_buffer_mutation(self):
        # Module whose forward mutates a registered buffer in place before
        # using it.  ``float_val`` is accepted but not used by this module.
        class Foo(torch.nn.Module):
            def __init__(self, float_val):
                super().__init__()
                self.register_buffer("buffer1", torch.ones(6, 1))

            def forward(self, x):
                self.buffer1.add_(2)
                return x.cos() + self.buffer1.sin()

        inp = (torch.ones(6, 4, requires_grad=True),)
        mod = Foo(0.5)

        exported_program = do_not_use_experimental_export(mod, inp)
        # The unpacking itself requires fw_module to return exactly two values.
        mutated_buffer, output = exported_program.fw_module(*inp)
        # TODO (tmanlaibaatar) enable this once we figure out
        # how to do buffer mutation
        # self.assertEqual(mutated_buffer.sum().item(), 30)
        self.assertEqual(output, mod(*inp))
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
class TestDynamismExpression(TestCase):
    """Tests for inline (``constrain_as_size``) and input (``dynamic_dim``)
    constraints passed through ``export``.

    The whole class is skipped when ``torchdynamo.is_dynamo_supported()`` is
    false.
    """

    @unittest.expectedFailure
    def test_export_inline_constraints(self):
        # Marked as an expected failure: the test is recorded as passing only
        # while something in its body raises or an assertion fails.
        def f(x):
            b = x.item()
            constrain_as_size(b, min=2, max=5)
            return torch.full((b, 1), 1)

        inp = (torch.tensor([3]),)
        ref = f(*inp)

        # Compare the eager result against the exported graph ...
        gm = export(f, inp)
        res = gm(*inp)
        self.assertTrue(torchdynamo.utils.same(ref, res))

        # ... and against a symbolic make_fx trace of the same function.
        gm = make_fx(f, tracing_mode="symbolic")(*inp)
        res = gm(*inp)
        self.assertTrue(torchdynamo.utils.same(ref, res))

    def test_export_constraints_error(self):
        # Four independent scenarios, each expected to raise UserError with a
        # distinct message fragment.

        # 1. Inline size constraint declared with min=0.
        def invalid_size(x):
            b = x.item()
            constrain_as_size(b, min=0, max=5)
            return torch.full((b, 1), 1)

        inp = (torch.tensor([3]),)
        with self.assertRaisesRegex(torchdynamo.exc.UserError, "Unable to set min size"):
            export(invalid_size, inp)

        # 2. Example input value (6) lies outside the inline range [2, 5].
        def invalid_input_conflict_with_inline_constraints(x):
            b = x.item()
            constrain_as_size(b, min=2, max=5)
            return torch.full((b, 1), 1)

        inp = (torch.tensor([6]),)
        with self.assertRaisesRegex(torchdynamo.exc.UserError, "Invalid value 6 for range"):
            export(invalid_input_conflict_with_inline_constraints, inp)

        # 3. Example input has size 3 in dim 0, but the constraint demands > 5.
        def invalid_input_conflict_with_input_constraints(x):
            return x + 1

        inp = torch.zeros([3])
        inp_constraints = [
            dynamic_dim(inp, 0) > 5,
        ]
        with self.assertRaisesRegex(torchdynamo.exc.UserError, "not in range"):
            export(
                invalid_input_conflict_with_input_constraints,
                (inp,),
                constraints=inp_constraints,
            )

        # 4. Two inline constraints on the same value with disjoint ranges.
        def conflicting_constraints(x):
            b = x.item()
            constrain_as_size(b, min=2, max=3)
            constrain_as_size(b, min=4, max=5)
            return torch.full((b, 1), 1)

        inp = (torch.tensor([3]),)
        with self.assertRaisesRegex(torchdynamo.exc.UserError, "Invalid ranges"):
            export(conflicting_constraints, inp)

    def test_export_assume_static_by_default(self):
        def branch_on_shape(x: torch.Tensor):
            if x.shape[0] == 4:
                return x + 1
            else:
                return x

        inp = (torch.rand(4, 5),)

        # Being able to export means shape is preserved as static
        export(branch_on_shape, inp)
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
class TestExport(TestCase):
    """Tests for ``export`` and the ``dynamic_dim`` constraint helper.

    The whole class is skipped when ``torchdynamo.is_dynamo_supported()`` is
    false.
    """

    def test_basic(self):
        # The first positional argument is a list holding a single tensor.
        def f(x, y):
            return x[0] + y

        inp = ([torch.ones(1, 3)], torch.ones(1, 3))
        exported_program = export(f, inp)
        self.assertTrue(torch.allclose(exported_program(*inp), f(*inp)))

    def test_raise_user_error_when_guard_on_data_dependent_operation(self):
        # Branches on the size of ``nonzero()``'s result, which is derived
        # from the tensor's data.
        def fn_ddo(x):
            y = x.nonzero()
            z = y.shape[0]
            if z > 2:
                return x.cos()
            else:
                return x.sin()

        with self.assertRaisesRegex(
            torchdynamo.exc.UserError,
            "trying to get a value out of symbolic int"
        ):
            export(fn_ddo, (torch.tensor([2, 3, 5]),), constraints=None)

    def test_if_functional(self):
        # ``foo`` performs an in-place add followed by a view.
        def foo(x):
            z = x + 4
            z.add_(4)
            y = z.view(x.shape)
            return x.cos() + y.cos()

        gm = export(foo, (torch.tensor([2, 3, 5]),), constraints=None)

        view_count = 0
        for node in gm.graph.nodes:
            if node.op != "call_function":
                continue
            # No more inplace mutation: assert directly on every call node
            # instead of asserting inside a guard that already implies failure.
            self.assertNotEqual(
                node.target,
                torch.ops.aten.add_.Tensor,
                "There shouldn't be any inplace mutation node in the graph."
            )
            if node.target == torch.ops.aten.view.default:
                view_count += 1

        # There should be nonzero view nodes in the graph
        self.assertGreater(view_count, 0)

    def test_export_mod_constraints(self):
        class BasicDynamicShapeModel(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x.view(x.shape[0] - 1, -1)

        m = BasicDynamicShapeModel()
        a = torch.randn(3, 4)
        constraints = [3 <= dynamic_dim(a, 0), dynamic_dim(a, 1)]

        # With both dims marked dynamic, export is expected to reject the
        # marking and report the values each dim must be specialized to.
        with self.assertRaisesRegex(
            torchdynamo.exc.UserError,
            (
                "Some dynamic dimensions need to be specialized because "
                "the constraints inferred for them are too complex to specify"
                ".*\n.*\\[0\\], which was marked dynamic, must be specialized to 3"
                ".*\n.*\\[1\\], which was marked dynamic, must be specialized to 4"
            ),
        ):
            export(m, (a,), constraints=constraints)

        # Without constraints export succeeds; calling the result with a
        # different size in dim 1 is then expected to fail at runtime.
        em = export(m, (a,))
        x = torch.randn(3, 5)
        with self.assertRaisesRegex(RuntimeError, "\\[1\\] is specialized at 4"):
            em(x)

    def test_export_constrain_static(self):
        def f(x, y):
            b = x.item()
            constrain_as_size(b, min=2, max=5)
            c = y.dim()
            constrain_as_value(c, min=1, max=3)
            z = y[0:c]
            return torch.empty((b, y.shape[0])), z

        x = torch.tensor([3])
        y = torch.randn([8, 8, 6])
        example_inputs = (x, y)
        constraints = [dynamic_dim(y, 0) >= 6, dynamic_dim(y, 0) <= 10]

        with self.assertRaisesRegex(
            torchdynamo.exc.UserError,
            "It appears that you're trying to set a constraint "
            "on a value which we evaluated to have a static value of 3. "
        ):
            export(f, example_inputs, constraints)

    def test_not_correct_dim(self):
        # Each case only checks that ``dynamic_dim`` itself rejects the
        # argument, so no function is exported and the result is discarded.

        # 0-dimensional tensor.
        scalar_tensor = torch.tensor(5)
        with self.assertRaisesRegex(
            torchdynamo.exc.UserError,
            "Cannot mark 0-dimension tensors to be dynamic"
        ):
            dynamic_dim(scalar_tensor, 0)

        # Dimension index 2 on a 2-dimensional tensor.
        matrix = torch.ones(5, 5)
        with self.assertRaisesRegex(
            torchdynamo.exc.UserError,
            "Expected the dimension passed to dynamic_dim to be in the range \\[0:1\\]"
        ):
            dynamic_dim(matrix, 2)

        # A plain Python int instead of a tensor.
        non_tensor = 4
        with self.assertRaisesRegex(
            torchdynamo.exc.UserError,
            "Expected tensor as input to dynamic_dim"
        ):
            dynamic_dim(non_tensor, 0)
if __name__ == '__main__':
    # Run the suite through the run_tests helper imported from
    # torch.testing._internal.common_utils when executed as a script.
    run_tests()