# Owner(s): ["oncall: export"] import copy import tempfile import unittest import torch from torch._subclasses.fake_tensor import FakeTensorMode from torch.export import Dim, export from torch.export._draft_export import draft_export, FailureType from torch.fx.experimental.symbolic_shapes import ShapeEnv from torch.testing import FileCheck from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase from torch.testing._internal.torchbind_impls import ( _empty_tensor_queue, init_torchbind_implementations, ) from torch.utils._pytree import tree_leaves class TestDraftExport(TestCase): def setUp(self): super().setUp() init_torchbind_implementations() self.torch_bind_ops = [ torch.ops._TorchScriptTesting.queue_pop, torch.ops._TorchScriptTesting.queue_push, torch.ops._TorchScriptTesting.queue_size, ] def tearDown(self): return def test_missing_meta_kernel_custom_op_basic(self): with torch.library._scoped_library("mylib", "FRAGMENT"): @torch.library.custom_op("mylib::foo2", mutates_args={}) def foo2_impl(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: return a + b class M(torch.nn.Module): def forward(self, a, b): res = torch.ops.mylib.foo2(a, b) return res inp = (torch.ones(3, 3), torch.ones(3, 3)) ep = draft_export(M(), inp) report = ep._report self.assertEqual(len(report.failures), 1) self.assertEqual( report.failures[0].failure_type, FailureType.MISSING_FAKE_KERNEL ) inp = (torch.randn(3, 3), torch.randn(3, 3)) self.assertEqual(ep.module()(*inp), M()(*inp)) with torch._library.fake_profile.register_fake_profile(report.op_profiles): ep.run_decompositions() def test_missing_meta_kernel_impl(self): with torch.library._scoped_library("mylib", "FRAGMENT") as lib: torch.library.define( "mylib::foo", "(Tensor a, Tensor b) -> Tensor", tags=torch.Tag.pt2_compliant_tag, lib=lib, ) @torch.library.impl("mylib::foo", "cpu", lib=lib) def foo_impl(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: return a + b class M(torch.nn.Module): def forward(self, a, b): res = torch.ops.mylib.foo(a, b) res = torch.ops.mylib.foo(res, b) return res inp = (torch.ones(3, 3), torch.ones(3, 3)) ep = draft_export(M(), inp) report = ep._report self.assertEqual(len(report.failures), 1) self.assertEqual( report.failures[0].failure_type, FailureType.MISSING_FAKE_KERNEL ) inp = (torch.randn(3, 3), torch.randn(3, 3)) self.assertEqual(ep.module()(*inp), M()(*inp)) self.assertEqual(len(report.op_profiles), 1) self.assertEqual(len(report.op_profiles["mylib.foo.default"]), 1) print(report.op_profiles) with torch._library.fake_profile.register_fake_profile(report.op_profiles): ep = ep.run_decompositions() self.assertEqual(ep.module()(*inp), M()(*inp)) def test_missing_meta_kernel_custom_op_multiple_profiles(self): with torch.library._scoped_library("mylib", "FRAGMENT"): @torch.library.custom_op("mylib::foo3", mutates_args={}) def foo3_impl(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: return a + b class M(torch.nn.Module): def forward(self, a, b, c, d): res1 = torch.ops.mylib.foo3(a, b) res2 = torch.ops.mylib.foo3(c, d) return res1, res2 inp = ( torch.ones(3, 4), torch.ones(3, 4), torch.ones(2, 3, 4), torch.ones(2, 3, 4), ) ep = draft_export(M(), inp) report = ep._report self.assertEqual(len(report.failures), 1) self.assertEqual( report.failures[0].failure_type, FailureType.MISSING_FAKE_KERNEL ) self.assertEqual(len(report.op_profiles), 1) self.assertEqual(len(report.op_profiles["mylib.foo3.default"]), 2) with torch._library.fake_profile.register_fake_profile(report.op_profiles): ep.run_decompositions() def 
test_missing_meta_kernel_custom_op_update_profile(self): with torch.library._scoped_library("mylib", "FRAGMENT"): @torch.library.custom_op("mylib::foo8", mutates_args={}) def foo8_impl(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: return a + b class M(torch.nn.Module): def forward(self, a, b): res = torch.ops.mylib.foo8(a, b) return res inp = ( torch.ones(3, 4), torch.ones(3, 4), ) ep = draft_export(M(), inp) report = ep._report self.assertEqual(len(report.op_profiles), 1) self.assertEqual(len(report.op_profiles["mylib.foo8.default"]), 1) new_inp = ( torch.ones(2, 3, 4), torch.ones(2, 3, 4), ) with torch._library.fake_profile.register_fake_profile(report.op_profiles): with FakeTensorMode(allow_non_fake_inputs=True, shape_env=ShapeEnv()): torch.ops.mylib.foo8(*inp) with self.assertRaisesRegex( RuntimeError, "no profiles match the given inputs" ): torch.ops.mylib.foo8(*new_inp) ep = draft_export(M(), new_inp) report = ep._report self.assertEqual(len(report.op_profiles), 1) self.assertEqual(len(report.op_profiles["mylib.foo8.default"]), 1) with torch._library.fake_profile.register_fake_profile( report.op_profiles ), FakeTensorMode(allow_non_fake_inputs=True, shape_env=ShapeEnv()): torch.ops.mylib.foo8(*new_inp) # Existing registration has been updated to match the new # profile traced with draft-export with self.assertRaisesRegex( RuntimeError, "no profiles match the given inputs" ): torch.ops.mylib.foo8(*inp) @unittest.skipIf(not torch.cuda.is_available(), "Requires cuda") def test_missing_meta_kernel_guard(self): with torch.library._scoped_library("mylib", "FRAGMENT"): @torch.library.custom_op("mylib::foo4", mutates_args={}) def foo4_impl(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: return a + b class M(torch.nn.Module): def forward(self, a, b): res1 = torch.ops.mylib.foo4(a, b) return res1 inp = ( torch.ones(3, 4), torch.ones(3, 4), ) ep = draft_export( M(), inp, dynamic_shapes={ "a": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC}, "b": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC}, }, ) inp = (torch.randn(2, 3), torch.randn(2, 3)) self.assertEqual(ep.module()(*inp), M()(*inp)) m = ep.module() with self.assertRaisesRegex(RuntimeError, "Tensor device mismatch!"): bad_device_inps = ( torch.randn(2, 3, device=torch.device("cuda")), torch.randn(2, 3, device=torch.device("cuda")), ) m(*bad_device_inps) with self.assertRaisesRegex(RuntimeError, "Tensor dtype mismatch!"): bad_dtype_inps = ( torch.randn(2, 3, dtype=torch.float16), torch.randn(2, 3, dtype=torch.float16), ) m(*bad_dtype_inps) def test_fake_infer_dense_in_memory_check(self): with torch.library._scoped_library("mylib", "FRAGMENT"): @torch.library.custom_op("mylib::foo5", mutates_args={}) def foo5_impl(a: torch.Tensor) -> torch.Tensor: return a * 2 @torch.library.custom_op("mylib::foo6", mutates_args={}) def foo6_impl(a: torch.Tensor) -> torch.Tensor: return (a * 2)[:, :-1, :-1] # not dense in memory @torch.library.custom_op("mylib::foo7", mutates_args={}) def foo7_impl(a: torch.Tensor) -> torch.Tensor: return (a * 2)[:, 1:-1, :] # non-zero storage offset class Foo(torch.nn.Module): def forward(self, x, opt): if opt == 0: return torch.ops.mylib.foo5(x) elif opt == 1: return torch.ops.mylib.foo6(x) else: return torch.ops.mylib.foo7(x) draft_export(Foo(), (torch.randn(80, 4, 4), 0)) draft_export(Foo(), (torch.randn(80, 1, 4), 0)) draft_export(Foo(), (torch.randn(1, 4, 1, 1, 4, 1, 4), 0)) with self.assertRaisesRegex( RuntimeError, "a return was not dense in memory", ): draft_export(Foo(), (torch.randn(4, 6, 8), 1)) with self.assertRaisesRegex( 
RuntimeError, "a return has a non-zero storage offset", ): draft_export(Foo(), (torch.randn(4, 6, 8), 2)) def test_data_dependent_failure(self): with torch.library._scoped_library("mylib", "FRAGMENT") as lib: torch.library.define( "mylib::foo1", "(Tensor a, Tensor b) -> Tensor", tags=torch.Tag.pt2_compliant_tag, lib=lib, ) @torch.library.impl("mylib::foo1", "cpu", lib=lib) def foo_impl(a, b): return a + b class M(torch.nn.Module): def forward(self, a, b, c): res = torch.ops.mylib.foo1(a, b) c_item = c.item() return res[:c_item] inp = (torch.ones(3, 3), torch.ones(3, 3), torch.tensor(3)) ep = draft_export(M(), inp) report = ep._report self.assertTrue(len(report.failures) > 0) self.assertEqual( report.failures[0].failure_type, FailureType.MISSING_FAKE_KERNEL ) self.assertEqual( report.failures[1].failure_type, FailureType.DATA_DEPENDENT_ERROR ) inp = (torch.randn(3, 3), torch.randn(3, 3), torch.tensor(2)) self.assertEqual(ep.module()(*inp), M()(*inp)) def test_unbacked_div_mod_replacement(self): class M(torch.nn.Module): def forward(self, x): x = torch.zeros(x.item()) x = x.unsqueeze(0).repeat(10, 2) return x.view(-1, 2, 2345) ep = draft_export(M(), (torch.tensor([938]),)) report = ep._report self.assertEqual(len(report.failures), 1) self.assertEqual( report.failures[0].failure_type, FailureType.DATA_DEPENDENT_ERROR ) self.assertEqual(report.failures[0].data["expr"], "Eq(2*u1, 10)") def test_dedup_data_dependent_failure(self): class M(torch.nn.Module): def forward(self, x, y, z): res = 0 for v in [x, y]: b = v.item() if b > 10: res += v * b else: res += v + b return z * res inp = (torch.tensor(5), torch.tensor(3), torch.tensor(2)) ep = draft_export(M(), inp) report = ep._report self.assertEqual(len(report.failures), 1) self.assertEqual( report.failures[0].failure_type, FailureType.DATA_DEPENDENT_ERROR ) inp = (torch.tensor(4), torch.tensor(2), torch.tensor(6)) self.assertEqual(ep.module()(*inp), M()(*inp)) # the fake tensors on node.meta["val"] should have real_tensor gm = ep.module() tensors = [ node.meta.get("val").real_tensor for node in gm.graph.nodes if node.op == "placeholder" ] self.assertTrue(all(isinstance(t, torch.Tensor) for t in tensors)) def test_complex_data_dependent_expr(self): class M(torch.nn.Module): def forward(self, x, y): a = x.item() a = -a a = a // 3 a = a + 5 z = torch.cat([y, y]) return z[:a] ep = draft_export( M(), (torch.tensor(6), torch.randn(5)), dynamic_shapes={"x": None, "y": {0: Dim.DYNAMIC}}, ) report = ep._report self.assertTrue(len(report.failures) > 0) self.assertEqual( report.failures[0].failure_type, FailureType.DATA_DEPENDENT_ERROR ) for _ep in [ep, ep.run_decompositions()]: # check data-dependent asserts assert_scalar_nodes = [ node for node in _ep.graph.nodes if node.target == torch.ops.aten._assert_scalar.default ] self.assertEqual(len(assert_scalar_nodes), 5) # unbacked bindings unbacked_binding_symbols = set() for node in _ep.graph.nodes: if bindings := node.meta.get("unbacked_bindings"): unbacked_binding_symbols.update(bindings.keys()) self.assertEqual(len(unbacked_binding_symbols), 1) def test_offsets(self): class M(torch.nn.Module): def forward(self, x): a = x.item() if a == 0: raise RuntimeError("bad") return x * a inp = (torch.tensor(3),) draft_export(M(), inp) def test_shape_failure(self): class M(torch.nn.Module): def forward(self, a): assert a.shape[0] == 3 return a * a inp = (torch.ones(3, 3),) ep = draft_export(M(), inp, dynamic_shapes={"a": {0: Dim("a0")}}) report = ep._report self.assertEqual(len(report.failures), 1) 
self.assertEqual(report.failures[0].failure_type, FailureType.GUARD_ADDED) inp = (torch.randn(3, 3),) self.assertEqual(ep.module()(*inp), M()(*inp)) inp = (torch.randn(4, 3),) with self.assertRaises(RuntimeError): ep.module()(*inp) def test_side_effect1(self): class M(torch.nn.Module): def __init__(self): super().__init__() self.register_buffer("a", torch.tensor(2)) def forward(self, b): a_item = self.a.item() if a_item == 2: res = a_item * b else: res = (a_item + 1) * b self.a.add_(1) a_item = self.a.item() if a_item == 3: res = a_item * res else: res = (a_item + 1) * res return res inp = (torch.ones(3, 3),) mod = M() ep = draft_export(mod, inp) self.assertEqual(mod.a, torch.tensor(2)) FileCheck().check_count("torch.ops.aten.add.default", 0, exactly=True).run( ep.graph_module.code ) def test_side_effect_inps(self): class M(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): x.sin_() return x inp = (torch.ones(3, 3),) ep = draft_export(M(), inp) report = ep._report self.assertTrue(report.successful()) self.assertEqual(inp[0], torch.ones(3, 3)) def test_masked_linear(self): class M(torch.nn.Module): def forward(self, x, mask, weight, bias): masked = x[mask != 0, :, :] return torch.nn.functional.linear(masked, weight, bias) x = torch.zeros(10) inp = (torch.randn(10, 8, 7), x, torch.randn(25, 7), torch.randn(25)) draft_ep = draft_export(M(), inp) ep = export(M(), inp) self.assertEqual(draft_ep.module()(*inp), ep.module()(*inp)) x[2] += 1 x[3] += 1 self.assertEqual(draft_ep.module()(*inp), ep.module()(*inp)) def test_torchbind(self): class Model(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(2, 2) def forward(self, tq, x): x_cos = tq.pop() + tq.float_size() + self.linear(x) if tq.is_empty(): x_sin = self.linear(tq.pop()) - tq.size() + x else: x_sin = tq.pop() + tq.size() + x return x_sin, x_cos, tq mod = Model() tq = _empty_tensor_queue() tq2 = copy.deepcopy(tq) a = torch.randn(2, 2) b = torch.randn(2, 2) tq.push(a) tq.push(b) tq3 = copy.deepcopy(tq) inp = (tq, torch.randn(2, 2)) ep = draft_export(mod, inp) report = ep._report self.assertTrue(report.successful()) self.assertEqual(tq2.size(), 0) self.assertEqual(tq3.size(), 2) self.assertEqual(tq.size(), 2) def test_override_size_and_dtype_mismatched_fake_kernels(self): with torch.library._scoped_library("mylib", "FRAGMENT"): class M(torch.nn.Module): def forward(self, a): return torch.ops.mylib.foo9(a) @torch.library.custom_op("mylib::foo9", mutates_args={}) def foo(a: torch.Tensor) -> list[torch.Tensor]: x = a * 2 y = a.repeat(2, 2) z = a.to(torch.bfloat16) return [x, y, z] @torch.library.register_fake("mylib::foo9") def foo_fake_impl(a): x = torch.empty_like(a) # good y = torch.empty_like(a) # size mismatch z = torch.empty_like(a) # dtype mismatch return [x, y, z] mod = M() inputs = (torch.randn(3, 3),) with self.assertRaises(RuntimeError): with torch._functorch.config.patch( fake_tensor_propagate_real_tensors=True ): export(mod, inputs, strict=True) ep = draft_export(mod, inputs) report = ep._report for ep_out, eager_out in zip(ep.module()(*inputs), mod(*inputs)): self.assertTrue(torch.allclose(ep_out, eager_out)) self.assertEqual(ep_out.dtype, eager_out.dtype) self.assertEqual(len(report.failures), 2) self.assertEqual( report.failures[0].failure_type, FailureType.MISMATCHED_FAKE_KERNEL ) self.assertEqual( report.failures[1].failure_type, FailureType.MISMATCHED_FAKE_KERNEL ) self.assertEqual( sorted([f.data["reason"] for f in report.failures]), [ "Dtypes 
torch.bfloat16 and torch.float32 are not equal!", "mismatch between fake value 3 and real value 6 ", ], ) with torch._library.fake_profile.register_fake_profile(report.op_profiles): ep.run_decompositions() def test_override_incorrectly_aliasing_kernel(self): with torch.library._scoped_library("mylib", "FRAGMENT"): @torch.library.custom_op("mylib::foo10", mutates_args={}) def foo(a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: return a * 2, a + 2 @torch.library.register_fake("mylib::foo10") def foo_fake_impl(a): return a, torch.empty_like(a) # incorrectly aliasing class M(torch.nn.Module): def forward(self, a): return torch.ops.mylib.foo10(a) mod = M() inputs = (torch.randn(3, 3),) with self.assertRaisesRegex( RuntimeError, "Real tensor propagation found an aliasing mismatch", ): with torch._functorch.config.patch( fake_tensor_propagate_real_tensors=True ): export(mod, inputs, strict=True) ep = draft_export(mod, inputs) report = ep._report for ep_out, eager_out in zip( tree_leaves(ep.module()(*inputs)), tree_leaves(mod(*inputs)) ): self.assertTrue(torch.allclose(ep_out, eager_out)) self.assertEqual(ep_out.dtype, eager_out.dtype) self.assertEqual(len(report.failures), 1) self.assertEqual( report.failures[0].failure_type, FailureType.MISMATCHED_FAKE_KERNEL ) self.assertTrue( "Mismatched aliasing spec between fake kernel and real kernel" in report.failures[0].data["reason"] ) def test_override_mismatched_fake_kernel_with_unbacked_symbols(self): with torch.library._scoped_library("mylib", "FRAGMENT"): @torch.library.custom_op("mylib::foo11", mutates_args={}) def foo11(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: return a[b.item()].to(torch.bfloat16) @torch.library.register_fake("mylib::foo11") def foo_fake_impl(a, b): ctx = torch.library.get_ctx() u = ctx.new_dynamic_size() return torch.empty(u, a.shape[1], dtype=a.dtype) class M(torch.nn.Module): def forward(self, a, b): return torch.ops.mylib.foo11(a, b) mod = M() inputs = (torch.randn(100, 4), torch.tensor(10)) ep = draft_export(mod, inputs) report = ep._report for ep_out, eager_out in zip(ep.module()(*inputs), mod(*inputs)): self.assertTrue(torch.allclose(ep_out, eager_out)) self.assertEqual(ep_out.dtype, eager_out.dtype) self.assertEqual(len(report.failures), 1) self.assertEqual( report.failures[0].failure_type, FailureType.MISMATCHED_FAKE_KERNEL ) self.assertEqual( report.failures[0].data["reason"], "Dtypes torch.bfloat16 and torch.float32 are not equal!", ) with torch._library.fake_profile.register_fake_profile(report.op_profiles): ep.run_decompositions() # https://github.com/pytorch/pytorch/issues/140625 @unittest.skipIf(IS_WINDOWS, "aoti_compile_and_package not supported on Windows") def test_constantify_unbacked_symbol(self): class M(torch.nn.Module): def forward(self, x, y): xt = torch.tensor(x.shape) u0 = xt[0].item() return y * torch.arange(u0) mod = M() example_inputs = (torch.randn(3, 5), torch.randn(3)) draft_ep = draft_export(mod, example_inputs) with tempfile.NamedTemporaryFile(suffix=".pt2") as f: torch._inductor.aoti_compile_and_package( draft_ep, package_path=f.name, ) if __name__ == "__main__": run_tests()