pytorch/test/export/test_experimental.py
Tugsbayasgalan Manlaibaatar d7fe3c4123 [RELAND] Switch default behavior of export IR to be predispatch (#125860)
This PR switches export IR from aot-dispatch to pre-dispatch IR.

**What is pre-dispatch IR and why should you care?**

Currently, the default IR returned by torch.export contains only functional ATen operators, produced after ALL PyTorch dispatcher decompositions (for example, CompositeImplicitAutograd) have run.

In contrast, pre-dispatch IR refers to an IR that can contain any functional ATen operator (i.e., not just the core subset) before any decomposition happens, as well as operators that manipulate autograd state. Pre-dispatch IR closely resembles eager PyTorch computation, but it is still functional and serializable by torch.export. As a result:

- You can train the pre-dispatch IR in eager mode, as the IR contains the information the autograd engine needs to automatically generate a backward graph.
- You can write sound graph transformations more easily, as the IR is functional.
- Since it is an ATen IR, it is still normalized. For example, torch.add has multiple overloads, but only aten.add.Tensor appears in this IR.
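
For a concrete picture, here is a minimal sketch of inspecting the pre-dispatch graph (the module, shapes, and exact printed op names are illustrative and depend on your PyTorch version); composite ops such as aten.linear are expected to show up as-is rather than as their decompositions:
```
import torch

class Mod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        return self.linear(x).relu()

# With pre-dispatch IR as the default, composite ops such as aten.linear
# should appear undecomposed in the printed graph.
ep = torch.export.export(Mod(), (torch.randn(2, 3),))
print(ep.graph_module.code)
```
Because the IR still carries the information the autograd engine needs, running ep.module() in eager mode and calling backward() on its outputs is expected to work, which is what makes this IR trainable.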
If you want to get the core ATen IR out of torch.export, you will need to:
```
ep = torch.export.export(M(), inputs)
ep_for_core_aten = ep.run_decompositions()
```
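
As a rough follow-on sketch (reusing the illustrative Mod from above; the exact decompositions vary across PyTorch versions), run_decompositions() lowers the remaining composite ops into the core ATen opset:
```
ep = torch.export.export(Mod(), (torch.randn(2, 3),))
ep_for_core_aten = ep.run_decompositions()
# After decomposition, composite ops such as aten.linear are expected to be
# replaced by core ATen ops (e.g. permute/addmm), matching the previous
# default behavior of torch.export.
print(ep_for_core_aten.graph_module.code)
```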

Differential Revision: [D57172986](https://our.internmc.facebook.com/intern/diff/D57172986)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125860
Approved by: https://github.com/zhxchen17
2024-05-10 17:36:53 +00:00

# Owner(s): ["oncall: export"]
# flake8: noqa
import unittest

import torch
import torch._dynamo
from torch._dynamo.test_case import run_tests, TestCase
from torch._export.wrappers import _mark_strict_experimental
from torch._functorch.aot_autograd import aot_export_module
from torch.testing import FileCheck


@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestExperiment(TestCase):
    def test_with_buffer_as_submodule(self):
        @_mark_strict_experimental
        class B(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buffer1", torch.ones(3))

            def forward(self, x):
                y = x + 2
                y.add_(4)
                # this doesn't work today with HOO
                # self.buffer1.add_(6)
                buffer_updated = self.buffer1 + 6
                return x.sum() + y.sum() + buffer_updated.sum()

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.submodule = B()

            def forward(self, x):
                x_v2 = x.sin()
                return (self.submodule(x_v2), x + 3)

        inp = torch.randn(3)
        ep = torch.export.export(M(), (inp,), strict=False)
        self.assertExpectedInline(
            str(ep.graph_module.code.strip()),
            """\
def forward(self, b_submodule_buffer1, x):
    sin = torch.ops.aten.sin.default(x)
    strict_graph_0 = self.strict_graph_0
    strict_mode = torch.ops.higher_order.strict_mode(strict_graph_0, (sin, b_submodule_buffer1)); strict_graph_0 = sin = b_submodule_buffer1 = None
    getitem_2 = strict_mode[0]; strict_mode = None
    add = torch.ops.aten.add.Tensor(x, 3); x = None
    return (getitem_2, add)""",
        )
        self.assertExpectedInline(
            str(ep.graph_module.strict_graph_0.code.strip()),
            """\
def forward(self, arg0_1, arg1_1):
    add = torch.ops.aten.add.Tensor(arg0_1, 2)
    add_1 = torch.ops.aten.add.Tensor(add, 4); add = None
    add_2 = torch.ops.aten.add.Tensor(arg1_1, 6); arg1_1 = None
    sum_1 = torch.ops.aten.sum.default(arg0_1); arg0_1 = None
    sum_2 = torch.ops.aten.sum.default(add_1); add_1 = None
    add_3 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
    sum_3 = torch.ops.aten.sum.default(add_2); add_2 = None
    add_4 = torch.ops.aten.add.Tensor(add_3, sum_3); add_3 = sum_3 = None
    return (add_4,)""",
        )

        eager_mod = M()
        ep = torch.export.export(eager_mod, (inp,), strict=True)
        # Run the exported module twice and check it keeps matching eager results.
        graph_res_1, graph_res_2 = ep.module()(inp)
        eager_res_1, eager_res_2 = eager_mod(inp)
        self.assertTrue(torch.allclose(graph_res_2, eager_res_2))
        self.assertTrue(torch.allclose(graph_res_1, eager_res_1))

        graph_res_1, graph_res_2 = ep.module()(inp)
        eager_res_1, eager_res_2 = eager_mod(inp)
        self.assertTrue(torch.allclose(graph_res_2, eager_res_2))
        self.assertTrue(torch.allclose(graph_res_1, eager_res_1))

    def test_mark_strict_with_container_type(self):
        @_mark_strict_experimental
        class B(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x0 = x[0][0]
                return x0.sum()

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.submodule = B()

            def forward(self, x):
                return self.submodule(x)

        inp = ((torch.randn(3),),)
        with self.assertRaisesRegex(
            RuntimeError, "strict_mode HOO doesn't work unless"
        ):
            ep = torch.export.export(M(), inp, strict=False)


if __name__ == "__main__":
    run_tests()