pytorch/test/export/test_experimental.py
Tugsbayasgalan Manlaibaatar 81f98f1082 Experimental non-strict mode (#114658)
This is a proof-of-concept implementation of how people can use a marker `mark_strict` to enable torchdynamo while exporting under non-strict mode. The main idea is that `mark_strict` will turn into a higher-order op (HOO) which then uses dynamo to do correctness analysis in the same way torch.cond works today; a minimal usage sketch follows the limitations below. There are some notable limitations:
1. This API is not meant for public use yet.
2. The strict region can't work with arbitrary container inputs.
3. We don't preserve `nn_module_stack` and other node metadata for the strict region.
4. The strict_mode HOO will show up in the final graph. This is undesirable in the long term, but for short-term experiments it should be good enough. This will be fixed in a follow-up PR.
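
A minimal sketch of the intended shape (`StrictBlock` and `Outer` are hypothetical names; the decorator is the private one exercised in the test below):

```python
import torch
from torch._export.wrappers import _mark_strict_DO_NOT_USE

@_mark_strict_DO_NOT_USE  # private marker, not a public API (limitation 1)
class StrictBlock(torch.nn.Module):
    def forward(self, x):
        return x.sin() + 1

class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.block = StrictBlock()

    def forward(self, x):
        return self.block(x)

# Non-strict export: the decorated region is traced by dynamo and appears
# as a strict_mode HOO call in the exported graph (limitation 4).
ep = torch.export.export(Outer(), (torch.randn(3),), strict=False)
```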

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114658
Approved by: https://github.com/ydwu4
2024-01-04 12:24:58 +00:00


# Owner(s): ["module: dynamo"]
# flake8: noqa
import unittest

import torch
import torch._dynamo
from torch._dynamo.test_case import run_tests, TestCase
from torch._export.wrappers import _mark_strict_DO_NOT_USE
from torch._functorch.aot_autograd import aot_export_module
from torch.testing import FileCheck

@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestExperiment(TestCase):
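    # Non-strict export of a module whose submodule is decorated with
    # _mark_strict_DO_NOT_USE: the decorated region should be traced by
    # dynamo and lowered to a strict_mode HOO call (see assertions below).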
    def test_with_buffer_as_submodule(self):
        @_mark_strict_DO_NOT_USE
        class B(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buffer1", torch.ones(3))

            def forward(self, x):
                y = x + 2
                y.add_(4)
                # this doesn't work today with HOO
                # self.buffer1.add_(6)
                buffer_updated = self.buffer1 + 6
                return x.sum() + y.sum() + buffer_updated.sum()
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.submodule = B()

            def forward(self, x):
                x_v2 = x.sin()
                return (self.submodule(x_v2), x + 3)

        inp = torch.randn(3)
        ep = torch.export.export(M(), (inp,), strict=False)
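        # The top-level graph should contain a single strict_mode HOO call that
        # takes the dynamo-traced subgraph plus the flattened inputs (the sin
        # output and the lifted buffer); per limitation 3 in the commit message,
        # nn_module_stack metadata is not preserved for the strict region.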
        self.assertExpectedInline(
            str(ep.graph_module.code.strip()),
            """\
def forward(self, arg0_1, arg1_1):
    sin = torch.ops.aten.sin.default(arg1_1)
    strict_graph_1 = self.strict_graph_1
    strict_mode_1 = torch.ops.higher_order.strict_mode(strict_graph_1, (sin, arg0_1)); strict_graph_1 = sin = arg0_1 = None
    getitem_1 = strict_mode_1[0]; strict_mode_1 = None
    add = torch.ops.aten.add.Tensor(arg1_1, 3); arg1_1 = None
    return (getitem_1, add)""",
        )
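        # The strict_mode subgraph holds the decorated region's ops; note the
        # in-place y.add_(4) has been functionalized into an out-of-place add.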
        self.assertExpectedInline(
            str(ep.graph_module.strict_graph_1.code.strip()),
            """\
def forward(self, arg0_1, arg1_1):
    add = torch.ops.aten.add.Tensor(arg0_1, 2)
    add_1 = torch.ops.aten.add.Tensor(add, 4); add = None
    add_2 = torch.ops.aten.add.Tensor(arg1_1, 6); arg1_1 = None
    sum_1 = torch.ops.aten.sum.default(arg0_1); arg0_1 = None
    sum_2 = torch.ops.aten.sum.default(add_1); add_1 = None
    add_3 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
    sum_3 = torch.ops.aten.sum.default(add_2); add_2 = None
    add_4 = torch.ops.aten.add.Tensor(add_3, sum_3); add_3 = sum_3 = None
    return (add_4,)""",
        )
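        # Strict export of the same module should agree numerically with eager;
        # run both twice to check the exported program carries no hidden state
        # across calls.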
        eager_mod = M()
        ep = torch.export.export(eager_mod, (inp,), strict=True)

        graph_res_1, graph_res_2 = ep(inp)
        eager_res_1, eager_res_2 = eager_mod(inp)
        self.assertTrue(torch.allclose(graph_res_2, eager_res_2))
        self.assertTrue(torch.allclose(graph_res_1, eager_res_1))

        graph_res_1, graph_res_2 = ep(inp)
        eager_res_1, eager_res_2 = eager_mod(inp)
        self.assertTrue(torch.allclose(graph_res_2, eager_res_2))
        self.assertTrue(torch.allclose(graph_res_1, eager_res_1))
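
    # Limitation 2 from the commit message: the strict region rejects arbitrary
    # container inputs, so a nested-tuple input should raise at export time.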
    def test_mark_strict_with_container_type(self):
        @_mark_strict_DO_NOT_USE
        class B(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x0 = x[0][0]
                return x0.sum()

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.submodule = B()

            def forward(self, x):
                return self.submodule(x)

        inp = ((torch.randn(3),),)
        with self.assertRaisesRegex(
            RuntimeError, "strict_mode HOO doesn't work unless"
        ):
            ep = torch.export.export(M(), inp, strict=False)

if __name__ == "__main__":
    run_tests()