pytorch/test/export/test_experimental.py
Tugsbayasgalan Manlaibaatar c933af2709 Switch to predispatch (#123573)
This PR switches the default IR returned by torch.export from aot-dispatch IR to pre-dispatch IR.

**What is pre-dispatch IR and why should you care?**

Currently, the default IR returned by torch.export contains only functional ATen operators, produced after ALL PyTorch dispatcher decompositions (for example, CompositeImplicitAutograd) have run.

In contrast, pre-dispatch IR refers to an IR that can contain all functional ATen operators (i.e., not just from the core subset), before any decomposition happens, as well as operators that manipulate autograd state. Pre-dispatch IR closely resembles eager PyTorch computation, but is still functional and serializable by torch.export. As a result:
- You can train the pre-dispatch IR in eager mode, as the IR carries the information the autograd engine needs to automatically generate a backward graph.
- You can write sound graph transformations more easily, as the IR is functional.
- Since it is an ATen IR, it is still normalized. For example, torch.add has multiple overloads, but only aten.add.Tensor appears in this IR (see the sketch below).
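
To make the normalization point concrete, here is a minimal sketch (the module and input are illustrative, not from this PR):
```
import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.add(x, 1)  # eager torch.add has several overloads

ep = torch.export.export(M(), (torch.randn(3),))
# The exported graph contains the single normalized overload
# aten.add.Tensor, which makes graph passes easier to write.
print(ep.graph)
```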

If you want to get the core ATen IR out of `torch.export`, you will need to:
```
import torch

# M is your nn.Module and inputs is its tuple of example arguments
ep = torch.export.export(M(), inputs)
ep_for_core_aten = ep.run_decompositions()
```
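
To see what `run_decompositions()` changes, here is a minimal sketch (the `silu` module is an illustrative assumption, not part of this PR, chosen because `silu` has a core ATen decomposition):
```
import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.silu(x)

ep = torch.export.export(M(), (torch.randn(3),))
print(ep.graph)  # pre-dispatch IR: aten.silu.default appears as-is
core_ep = ep.run_decompositions()
print(core_ep.graph)  # core ATen IR: silu decomposed into sigmoid/mul
```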

Differential Revision: [D56273267](https://our.internmc.facebook.com/intern/diff/D56273267)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123573
Approved by: https://github.com/gmagogsfm
2024-04-24 00:51:09 +00:00

# Owner(s): ["oncall: export"]
# flake8: noqa
import unittest

import torch
import torch._dynamo
from torch._dynamo.test_case import run_tests, TestCase
from torch._export.wrappers import _mark_strict_experimental
from torch._functorch.aot_autograd import aot_export_module
from torch.testing import FileCheck


@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestExperiment(TestCase):
    # TODO AssertionError: Unknown tensor output kind: getitem_2
    @unittest.expectedFailure
    def test_with_buffer_as_submodule(self):
        @_mark_strict_experimental
        class B(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buffer1", torch.ones(3))

            def forward(self, x):
                y = x + 2
                y.add_(4)
                # this doesn't work today with HOO
                # self.buffer1.add_(6)
                buffer_updated = self.buffer1 + 6
                return x.sum() + y.sum() + buffer_updated.sum()

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.submodule = B()

            def forward(self, x):
                x_v2 = x.sin()
                return (self.submodule(x_v2), x + 3)

        inp = torch.randn(3)
        ep = torch.export.export(M(), (inp,), strict=False)
        self.assertExpectedInline(
            str(ep.graph_module.code.strip()),
            """\
def forward(self, b_submodule_buffer1, x):
    sin = torch.ops.aten.sin.default(x)
    strict_graph_0 = self.strict_graph_0
    strict_mode = torch.ops.higher_order.strict_mode(strict_graph_0, (sin, b_submodule_buffer1)); strict_graph_0 = sin = b_submodule_buffer1 = None
    getitem = strict_mode[0]; strict_mode = None
    add = torch.ops.aten.add.Tensor(x, 3); x = None
    return (getitem, add)""",
        )
        self.assertExpectedInline(
            str(ep.graph_module.strict_graph_0.code.strip()),
            """\
def forward(self, arg0_1, arg1_1):
    add = torch.ops.aten.add.Tensor(arg0_1, 2)
    add_1 = torch.ops.aten.add.Tensor(add, 4); add = None
    add_2 = torch.ops.aten.add.Tensor(arg1_1, 6); arg1_1 = None
    sum_1 = torch.ops.aten.sum.default(arg0_1); arg0_1 = None
    sum_2 = torch.ops.aten.sum.default(add_1); add_1 = None
    add_3 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
    sum_3 = torch.ops.aten.sum.default(add_2); add_2 = None
    add_4 = torch.ops.aten.add.Tensor(add_3, sum_3); add_3 = sum_3 = None
    return (add_4,)""",
        )

        eager_mod = M()
        ep = torch.export.export(eager_mod, (inp,), strict=True)

        # Running the exported module should match eager, including on repeated calls.
        graph_res_1, graph_res_2 = ep.module()(inp)
        eager_res_1, eager_res_2 = eager_mod(inp)
        self.assertTrue(torch.allclose(graph_res_2, eager_res_2))
        self.assertTrue(torch.allclose(graph_res_1, eager_res_1))

        graph_res_1, graph_res_2 = ep.module()(inp)
        eager_res_1, eager_res_2 = eager_mod(inp)
        self.assertTrue(torch.allclose(graph_res_2, eager_res_2))
        self.assertTrue(torch.allclose(graph_res_1, eager_res_1))
    def test_mark_strict_with_container_type(self):
        @_mark_strict_experimental
        class B(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x0 = x[0][0]
                return x0.sum()

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.submodule = B()

            def forward(self, x):
                return self.submodule(x)

        # The strict_mode HOO currently rejects nested container inputs
        # in non-strict export.
        inp = ((torch.randn(3),),)
        with self.assertRaisesRegex(
            RuntimeError, "strict_mode HOO doesn't work unless"
        ):
            ep = torch.export.export(M(), inp, strict=False)


if __name__ == "__main__":
    run_tests()