# Owner(s): ["oncall: export"]
# flake8: noqa
import unittest

import torch
import torch._dynamo
from torch._dynamo.test_case import run_tests, TestCase
from torch._export.wrappers import _mark_strict_experimental

from torch._functorch.aot_autograd import aot_export_module

from torch.testing import FileCheck


@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestExperiment(TestCase):
    def test_with_buffer_as_submodule(self):
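        # B is marked strict, so its call should be traced under the
        # strict_mode higher-order op even though the outer export below
        # runs with strict=False.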
        @_mark_strict_experimental
        class B(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buffer1", torch.ones(3))

            def forward(self, x):
                y = x + 2
                y.add_(4)
                # this doesn't work today with HOO
                # self.buffer1.add_(6)
                buffer_updated = self.buffer1 + 6
                return x.sum() + y.sum() + buffer_updated.sum()

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.submodule = B()

            def forward(self, x):
                x_v2 = x.sin()
                return (self.submodule(x_v2), x + 3)

        inp = torch.randn(3)
        ep = torch.export.export(M(), (inp,), strict=False)
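        # The call into B should show up as a single strict_mode HOO in the
        # top-level graph, with the submodule body stored as self.strict_graph_1.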
        self.assertExpectedInline(
            str(ep.graph_module.code.strip()),
            """\
def forward(self, arg0_1, arg1_1):
    sin = torch.ops.aten.sin.default(arg1_1)
    strict_graph_1 = self.strict_graph_1
    strict_mode_1 = torch.ops.higher_order.strict_mode(strict_graph_1, (sin, arg0_1)); strict_graph_1 = sin = arg0_1 = None
    getitem_1 = strict_mode_1[0]; strict_mode_1 = None
    add = torch.ops.aten.add.Tensor(arg1_1, 3); arg1_1 = None
    return (getitem_1, add)""",
        )
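
        # Inside strict_graph_1, the buffer is passed in as a placeholder
        # (arg1_1) rather than being mutated in place.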
        self.assertExpectedInline(
            str(ep.graph_module.strict_graph_1.code.strip()),
            """\
def forward(self, arg0_1, arg1_1):
    add = torch.ops.aten.add.Tensor(arg0_1, 2)
    add_1 = torch.ops.aten.add.Tensor(add, 4); add = None
    add_2 = torch.ops.aten.add.Tensor(arg1_1, 6); arg1_1 = None
    sum_1 = torch.ops.aten.sum.default(arg0_1); arg0_1 = None
    sum_2 = torch.ops.aten.sum.default(add_1); add_1 = None
    add_3 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
    sum_3 = torch.ops.aten.sum.default(add_2); add_2 = None
    add_4 = torch.ops.aten.add.Tensor(add_3, sum_3); add_3 = sum_3 = None
    return (add_4,)""",
        )

        eager_mod = M()
        ep = torch.export.export(eager_mod, (inp,), strict=True)
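
        # The strict=True export should match eager numerics, and running the
        # exported program twice should keep giving the same results.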
        graph_res_1, graph_res_2 = ep(inp)
        eager_res_1, eager_res_2 = eager_mod(inp)

        self.assertTrue(torch.allclose(graph_res_2, eager_res_2))
        self.assertTrue(torch.allclose(graph_res_1, eager_res_1))

        graph_res_1, graph_res_2 = ep(inp)
        eager_res_1, eager_res_2 = eager_mod(inp)

        self.assertTrue(torch.allclose(graph_res_2, eager_res_2))
        self.assertTrue(torch.allclose(graph_res_1, eager_res_1))

    def test_mark_strict_with_container_type(self):
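        # Nested container inputs are expected to be rejected by the
        # strict_mode HOO at export time.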
        @_mark_strict_experimental
        class B(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x0 = x[0][0]
                return x0.sum()

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.submodule = B()

            def forward(self, x):
                return self.submodule(x)

        inp = ((torch.randn(3),),)
        with self.assertRaisesRegex(
            RuntimeError, "strict_mode HOO doesn't work unless"
        ):
            ep = torch.export.export(M(), inp, strict=False)


if __name__ == "__main__":
    run_tests()