mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124978 Approved by: https://github.com/jansel
714 lines
24 KiB
Python
# Owner(s): ["oncall: export"]
|
|
# flake8: noqa
|
|
import copy
|
|
import dataclasses
|
|
import unittest
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass
|
|
from re import escape
|
|
from typing import Any, List
|
|
|
|
import torch
|
|
import torch._dynamo as torchdynamo
|
|
from functorch.experimental.control_flow import cond, map
|
|
from torch import Tensor
|
|
from torch._export.utils import (
|
|
get_buffer,
|
|
get_param,
|
|
is_buffer,
|
|
is_param,
|
|
register_dataclass_as_pytree_node,
|
|
)
|
|
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
|
|
from torch.export import (
|
|
Constraint,
|
|
Dim,
|
|
dynamic_dim,
|
|
export,
|
|
FlatArgsAdapter,
|
|
unflatten,
|
|
)
|
|
from torch.export._trace import DEFAULT_EXPORT_DYNAMO_CONFIG
|
|
from torch.fx.experimental.proxy_tensor import make_fx
|
|
from torch.testing import FileCheck
|
|
from torch.testing._internal.common_utils import (
|
|
find_library_location,
|
|
IS_FBCODE,
|
|
IS_MACOS,
|
|
IS_SANDCASTLE,
|
|
IS_WINDOWS,
|
|
run_tests,
|
|
skipIfTorchDynamo,
|
|
TestCase,
|
|
)
|
|
|
|
from torch.testing._internal.torchbind_impls import init_torchbind_implementations
|
|
from torch.utils._pytree import (
|
|
LeafSpec,
|
|
tree_flatten,
|
|
tree_unflatten,
|
|
TreeSpec,
|
|
treespec_dumps,
|
|
treespec_loads,
|
|
)
|
|
|
|
|
|
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
|
|
class TestUnflatten(TestCase):
|
|
def compare_outputs(self, eager, unflattened, args):
|
|
orig_output = eager(*args)
|
|
unflattened_output = unflattened(*args)
|
|
self.assertTrue(torch.allclose(orig_output, unflattened_output))
|
|
|
|
def test_unflatten_nested(self):
|
|
class NestedChild(torch.nn.Module):
|
|
def forward(self, x):
|
|
return x / x
|
|
|
|
class Child1(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.nested = NestedChild()
|
|
self.register_parameter(
|
|
"child1param", torch.nn.Parameter(torch.ones(2, 3))
|
|
)
|
|
|
|
def forward(self, x):
|
|
x = self.nested(x)
|
|
return x + self.child1param
|
|
|
|
class Child2(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.register_buffer("child2buffer", torch.ones(2, 3))
|
|
|
|
def forward(self, x):
|
|
return x - self.child2buffer
|
|
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.foo = Child1()
|
|
self.bar = Child2()
|
|
self.register_parameter(
|
|
"rootparam", torch.nn.Parameter(torch.ones(2, 3))
|
|
)
|
|
|
|
def forward(self, x):
|
|
x = x * self.rootparam
|
|
x = self.foo(x)
|
|
x = self.bar(x)
|
|
return x
|
|
|
|
orig_eager = MyModule()
|
|
export_module = export(orig_eager, (torch.rand(2, 3),), {})
|
|
unflattened = unflatten(export_module)
|
|
|
|
inputs = (torch.rand(2, 3),)
|
|
|
|
# Compare the root modules and all submodules
|
|
self.compare_outputs(orig_eager, unflattened, inputs)
|
|
self.compare_outputs(orig_eager.foo, unflattened.foo, inputs)
|
|
self.compare_outputs(orig_eager.bar, unflattened.bar, inputs)
|
|
self.compare_outputs(orig_eager.foo.nested, unflattened.foo.nested, inputs)
|
|
|
|
# Check state dicts are equal
|
|
orig_state_dict = orig_eager.state_dict()
|
|
exported_state_dict = unflattened.state_dict()
|
|
for name, value in orig_state_dict.items():
|
|
self.assertTrue(torch.allclose(value, exported_state_dict[name]))
|
|
|
|
def test_unflatten_buffer_mutation(self):
|
|
class Child(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.register_buffer("child2buffer", torch.ones(2, 3))
|
|
|
|
def forward(self, x):
|
|
self.child2buffer.add_(x)
|
|
return x - self.child2buffer
|
|
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.foo = Child()
|
|
self.register_parameter(
|
|
"rootparam", torch.nn.Parameter(torch.ones(2, 3))
|
|
)
|
|
|
|
def forward(self, x):
|
|
x = self.foo(x)
|
|
return x * self.rootparam
|
|
|
|
eager_module = MyModule()
|
|
export_module = export(eager_module, (torch.rand(2, 3),), {})
|
|
unflattened_module = unflatten(export_module)
|
|
|
|
# Buffer should look the same before and after one run
|
|
eager_buffer = eager_module.foo.child2buffer
|
|
unflattened_buffer = unflattened_module.foo.child2buffer
|
|
self.assertTrue(torch.allclose(eager_buffer, unflattened_buffer))
|
|
|
|
inputs = (torch.rand(2, 3),)
|
|
eager_module(*inputs)
|
|
unflattened_module(*inputs)
|
|
self.assertTrue(torch.allclose(eager_buffer, unflattened_buffer))
|
|
|
|
def test_unflatten_nested_access(self):
|
|
class Child(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.register_buffer("child2buffer", torch.ones(2, 3))
|
|
|
|
def forward(self, x):
|
|
return x - self.child2buffer
|
|
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.foo = Child()
|
|
self.register_parameter(
|
|
"rootparam", torch.nn.Parameter(torch.ones(2, 3))
|
|
)
|
|
|
|
def forward(self, x):
|
|
x = x + self.foo.child2buffer
|
|
x = self.foo(x)
|
|
return x
|
|
|
|
eager_module = MyModule()
|
|
export_module = export(eager_module, (torch.rand(2, 3),), {})
|
|
unflattened_module = unflatten(export_module)
|
|
|
|
inputs = (torch.rand(2, 3),)
|
|
self.compare_outputs(eager_module, unflattened_module, inputs)
|
|
|
|
def test_unflatten_shared_submodule(self):
|
|
class Shared(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
layernorm = torch.nn.LayerNorm(10)
|
|
self.sub_net = torch.nn.Sequential(
|
|
layernorm,
|
|
torch.nn.ReLU(),
|
|
layernorm,
|
|
torch.nn.ReLU(),
|
|
)
|
|
|
|
def forward(self, x):
|
|
return self.sub_net(x)
|
|
|
|
eager_module = Shared()
|
|
inps = (torch.rand(10),)
|
|
export_module = export(eager_module, inps, {})
|
|
unflattened_module = unflatten(export_module)
|
|
self.compare_outputs(eager_module, unflattened_module, inps)
|
|
self.assertTrue(hasattr(unflattened_module, "sub_net"))
|
|
for i in range(len(eager_module.sub_net)):
|
|
self.assertTrue(hasattr(unflattened_module.sub_net, str(i)))
|
|
self.assertEqual(
|
|
id(getattr(unflattened_module.sub_net, "0")),
|
|
id(getattr(unflattened_module.sub_net, "2")),
|
|
)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
|
|
@skipIfTorchDynamo("Non strict mode is not meant to run with dynamo")
|
|
def test_unflatten_preserve_signature(self):
|
|
class NestedChild(torch.nn.Module):
|
|
def forward(self, zx, y):
|
|
return {"x": y["key"] + zx[1], "w": y["key"] * zx[1]}
|
|
|
|
class Child1(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.nested = NestedChild()
|
|
|
|
def forward(self, x, y):
|
|
z = torch.ones_like(x)
|
|
xw = self.nested((z, x), y={"key": y})
|
|
return xw["w"] + z - xw["x"]
|
|
|
|
class Child2(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
return x - 1
|
|
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.foo = Child1()
|
|
self.bar = Child2()
|
|
|
|
def forward(self, x, y):
|
|
x = self.foo(x, y)
|
|
x = self.bar(x)
|
|
return x
|
|
|
|
orig_eager = MyModule()
|
|
inps = torch.rand(2, 3), torch.rand(2, 3)
|
|
for strict in [True, False]:
|
|
export_module = export(
|
|
orig_eager,
|
|
inps,
|
|
{},
|
|
preserve_module_call_signature=("foo.nested",),
|
|
strict=strict,
|
|
)
|
|
unflattened = unflatten(export_module)
|
|
self.compare_outputs(export_module.module(), unflattened, inps)
|
|
unflattened.foo.nested = NestedChild()
|
|
self.compare_outputs(export_module.module(), unflattened, inps)
|
|
|
|
# Test tree spec mismatched input
|
|
orig_outs = export_module.module()(*inps)
|
|
new_inps = *inps, torch.rand(2, 3)
|
|
with self.assertRaisesRegex(
|
|
TypeError,
|
|
"There is no flat args adapter sepcified. Are you sure you are calling this with the right arguments?",
|
|
):
|
|
unflattened(new_inps)
|
|
|
|
# With flat args adapter
|
|
class KeepTwoFlatArgsAdapter(FlatArgsAdapter):
|
|
def adapt(
|
|
self,
|
|
target_spec: TreeSpec,
|
|
input_spec: TreeSpec,
|
|
input_args: List[Any],
|
|
) -> List[Any]:
|
|
while len(input_args) > 2:
|
|
input_args.pop(-1)
|
|
return input_args
|
|
|
|
unflattened = unflatten(export_module, KeepTwoFlatArgsAdapter())
|
|
new_outs = unflattened(*new_inps)
|
|
self.assertTrue(torch.allclose(orig_outs, new_outs))
|
|
|
|
def test_unflatten_param_list_dict(self):
|
|
class Mod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.param_list = torch.nn.ParameterList()
|
|
self.param_dict = torch.nn.ParameterDict()
|
|
for i in range(2):
|
|
self.param_list.append(torch.nn.Parameter(torch.randn((2, 3))))
|
|
self.param_dict[f"key_{i}"] = torch.nn.Parameter(
|
|
torch.randn((2, 3))
|
|
)
|
|
|
|
def forward(self, x):
|
|
for i in range(2):
|
|
x = x + self.param_list[i]
|
|
x = x + self.param_dict[f"key_{i}"]
|
|
return x
|
|
|
|
export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
|
|
unflattened = unflatten(export_module)
|
|
|
|
self.compare_outputs(
|
|
export_module.module(), unflattened, (torch.randn((2, 3)),)
|
|
)
|
|
|
|
def test_unflatten_wrong_input(self):
|
|
class Mod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.param_list = torch.nn.ParameterList()
|
|
self.param_dict = torch.nn.ParameterDict()
|
|
for i in range(2):
|
|
self.param_list.append(torch.nn.Parameter(torch.randn((2, 3))))
|
|
self.param_dict[f"key_{i}"] = torch.nn.Parameter(
|
|
torch.randn((2, 3))
|
|
)
|
|
|
|
def forward(self, x):
|
|
a = x.sum()
|
|
for i in range(2):
|
|
a = a + self.param_list[i].sum()
|
|
a = a + self.param_dict[f"key_{i}"].sum()
|
|
return a
|
|
|
|
export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"),
|
|
):
|
|
export_module.module()(torch.randn(6, 6))
|
|
|
|
unflattened = unflatten(export_module)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"),
|
|
):
|
|
unflattened(torch.randn(6, 6))
|
|
|
|
def test_unflatten_with_inplace_compile(self):
|
|
class NestedChild(torch.nn.Module):
|
|
def forward(self, x):
|
|
return x / x
|
|
|
|
class Child1(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.nested = NestedChild()
|
|
self.register_parameter(
|
|
"child1param", torch.nn.Parameter(torch.ones(2, 3))
|
|
)
|
|
|
|
def forward(self, x):
|
|
x = self.nested(x)
|
|
return x + self.child1param
|
|
|
|
class Child2(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.register_buffer("child2buffer", torch.ones(2, 3))
|
|
|
|
def forward(self, x):
|
|
return x - self.child2buffer
|
|
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.foo = Child1()
|
|
self.bar = Child2()
|
|
self.register_parameter(
|
|
"rootparam", torch.nn.Parameter(torch.ones(2, 3))
|
|
)
|
|
|
|
def forward(self, x):
|
|
x = x * self.rootparam
|
|
x = self.foo(x)
|
|
x = self.bar(x)
|
|
return x
|
|
|
|
orig_eager = MyModule()
|
|
export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
|
|
unflattened = unflatten(export_module)
|
|
|
|
# in-place compilation should work. Pass fullgraph to ensure no graph breaks.
|
|
unflattened.foo.compile(fullgraph=True)
|
|
|
|
inputs = (torch.rand(2, 3),)
|
|
self.compare_outputs(orig_eager, unflattened, inputs)
|
|
|
|
def test_fx_trace(self):
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
|
|
def forward(self, x, y):
|
|
x = x[0] + x[1]
|
|
x = x + y["foo"]
|
|
return x
|
|
|
|
orig_eager = MyModule()
|
|
inputs = ((torch.rand(2, 3), torch.rand(2, 3)), {"foo": torch.rand(2, 3)})
|
|
export_module = export(orig_eager, inputs, {})
|
|
|
|
unflattened = unflatten(export_module)
|
|
torch.fx.symbolic_trace(
|
|
unflattened, concrete_args=(torch.fx.PH, torch.fx.PH, torch.fx.PH)
|
|
)
|
|
|
|
def test_double_nested_submodule(self):
|
|
class SubSubMod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
return x * x
|
|
|
|
class SubMod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.subsubmod = SubSubMod()
|
|
|
|
def forward(self, x):
|
|
return x - x
|
|
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.submod = SubMod()
|
|
|
|
def forward(self, x):
|
|
return x + self.submod.subsubmod(x)
|
|
|
|
orig_eager = MyModule()
|
|
export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
|
|
unflattened = unflatten(export_module)
|
|
|
|
inputs = (torch.rand(2, 3),)
|
|
self.compare_outputs(orig_eager, unflattened, inputs)
|
|
|
|
def test_unflatten_container_type(self):
|
|
class Leaf(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(4, 4)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
class Bar(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.leaf = Leaf()
|
|
self.register_buffer("buffer", torch.randn(4, 4))
|
|
|
|
def forward(self, x, z):
|
|
return self.buffer.sum() + self.leaf(x).sum() + z[0].sum() + z[1].sum()
|
|
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.bar = Bar()
|
|
|
|
def forward(self, x, z):
|
|
y = self.bar.buffer + x + z[0] + z[1]
|
|
return self.bar(x, z) + y.sum()
|
|
|
|
inp = (torch.randn(4, 4), [torch.randn(4, 4), torch.randn(4, 4)])
|
|
mod = Foo()
|
|
ep_strict = torch.export.export(mod, inp)
|
|
ep_non_strict = torch.export.export(mod, inp, strict=False)
|
|
|
|
gm_unflat_non_strict = unflatten(ep_non_strict)
|
|
ep = torch.export.export(gm_unflat_non_strict, inp, strict=False)
|
|
self.assertTrue(torch.allclose(ep.module()(*inp), mod(*inp)))
|
|
|
|
def test_unflattened_module_nodes_has_meta_val(self):
|
|
class SubMod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
return x + x, x * x
|
|
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.submod = SubMod()
|
|
|
|
def forward(self, x):
|
|
return x + sum(self.submod(x))
|
|
|
|
orig_eager = MyModule()
|
|
export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
|
|
unflattened = unflatten(export_module)
|
|
|
|
inputs = (torch.rand(2, 3),)
|
|
self.compare_outputs(orig_eager, unflattened, inputs)
|
|
|
|
def check_meta(gm):
|
|
for n in gm.graph.nodes:
|
|
if n.op == "output":
|
|
continue
|
|
self.assertTrue(n.meta.get("val") is not None)
|
|
|
|
for m in unflattened.modules():
|
|
check_meta(m)
|
|
|
|
def test_placeholder_and_get_attr_ordering_after_unflattened(self):
|
|
class TransposeModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(3, 1, 3, stride=2)
|
|
|
|
def forward(self, x):
|
|
x = self.conv(x)
|
|
return x.transpose(0, 1)
|
|
|
|
x = torch.randn(32, 3, 64, 64)
|
|
exported_program = export(TransposeModule(), args=(x,))
|
|
unflattened_module = unflatten(exported_program)
|
|
|
|
# Check the inputs of the created call_module node are in order
|
|
call_module_input_order = []
|
|
for node in unflattened_module.graph.nodes:
|
|
if node.op == "call_module":
|
|
transpose_module = unflattened_module.get_submodule(node.target)
|
|
for sub_node in transpose_module.graph.nodes:
|
|
if sub_node.op == "placeholder" or sub_node.op == "get_attr":
|
|
call_module_input_order.append(sub_node.op)
|
|
self.assertEqual(
|
|
call_module_input_order, ["placeholder", "get_attr", "get_attr"]
|
|
)
|
|
|
|
def test_unflatten_constant_tensor(self):
|
|
class SubMod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.initializer = 0.1
|
|
|
|
def forward(self, x):
|
|
return x + torch.tensor(self.initializer)
|
|
|
|
class Mod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.submod = SubMod()
|
|
|
|
def forward(self, x):
|
|
return x + self.submod(x)
|
|
|
|
export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
|
|
unflattened = unflatten(export_module)
|
|
|
|
self.compare_outputs(
|
|
export_module.module(), unflattened, (torch.randn((2, 3)),)
|
|
)
|
|
|
|
@skipIfTorchDynamo("custom objects not supported in dynamo yet")
|
|
def test_unflatten_constant_obj(self):
|
|
init_torchbind_implementations()
|
|
|
|
@torch._library.register_fake_class("_TorchScriptTesting::_Foo")
|
|
class FakeFoo:
|
|
def __init__(self, x: int, y: int):
|
|
self.x = x
|
|
self.y = y
|
|
|
|
@classmethod
|
|
def __obj_unflatten__(cls, flat_ctx):
|
|
return cls(**dict(flat_ctx))
|
|
|
|
def add_tensor(self, z):
|
|
return (self.x + self.y) * z
|
|
|
|
class SubMod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
|
|
|
|
def forward(self, x):
|
|
return x + self.attr.add_tensor(x)
|
|
|
|
class Mod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.submod = SubMod()
|
|
|
|
def forward(self, x):
|
|
return x + self.submod(x)
|
|
|
|
with enable_torchbind_tracing():
|
|
export_module = torch.export.export(
|
|
Mod(), (torch.randn((2, 3)),), strict=False
|
|
)
|
|
unflattened = unflatten(export_module)
|
|
|
|
self.compare_outputs(
|
|
export_module.module(), unflattened, (torch.randn((2, 3)),)
|
|
)
|
|
|
|
def test_nested_leaf_non_strict(self):
|
|
class Leaf(torch.nn.Module):
|
|
def forward(self, x):
|
|
return x + 1
|
|
|
|
class Nested(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.leaf = Leaf()
|
|
|
|
def forward(self, x):
|
|
return self.leaf(x) + 2
|
|
|
|
class TopLevel(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.nested = Nested()
|
|
|
|
def forward(self, x):
|
|
return self.nested(x) + 3
|
|
|
|
ep = torch.export.export(
|
|
TopLevel(),
|
|
(torch.randn(3),),
|
|
strict=False,
|
|
preserve_module_call_signature=("nested",),
|
|
)
|
|
|
|
torch.export.unflatten(ep)
|
|
|
|
def test_unflatten_submodule_ordering(self):
|
|
class Module2(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.register_buffer("buffer", torch.rand(3, 4))
|
|
self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))
|
|
|
|
def forward(self, x):
|
|
return x + self.buffer + self.param
|
|
|
|
class Module1(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.register_buffer("buffer", torch.rand(3, 4))
|
|
self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))
|
|
|
|
def forward(self, x):
|
|
return x + self.buffer + self.param
|
|
|
|
class Module(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.mod2 = Module2()
|
|
self.mod3 = self.mod2
|
|
self.mod1 = Module1()
|
|
|
|
def forward(self, x):
|
|
return self.mod3(self.mod2(self.mod1(x)))
|
|
|
|
mod = Module()
|
|
|
|
ep = torch.export.export(mod, (torch.randn(3, 4),))
|
|
|
|
unflattened = torch.export.unflatten(ep)
|
|
fqn_list = [x for x, _ in unflattened.named_modules(remove_duplicate=False)]
|
|
self.assertEqual(len(fqn_list), 4)
|
|
self.assertEqual(
|
|
[x for x, _ in mod.named_modules(remove_duplicate=False)],
|
|
fqn_list,
|
|
)
|
|
|
|
def test_duplicate_placeholder(self):
|
|
N, C, H, W = 1, 2, 2, 3
|
|
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
layer = torch.nn.LayerNorm([C, H, W])
|
|
self.norms = torch.nn.ModuleList(
|
|
[
|
|
layer, # reuse layer norm
|
|
layer,
|
|
layer,
|
|
]
|
|
)
|
|
|
|
def forward(self, input_):
|
|
for i in range(len(self.norms)):
|
|
output = self.norms[i](input_)
|
|
input_ = output
|
|
return output
|
|
|
|
mod = MyModule()
|
|
input_ = torch.randn(N, C, H, W)
|
|
|
|
ep_strict = export(copy.deepcopy(mod), (input_,), strict=True)
|
|
umod = unflatten(ep_strict)
|
|
self.assertTrue(torch.allclose(umod(input_), mod(input_)))
|
|
|
|
ep_non_strict = export(copy.deepcopy(mod), (input_,), strict=False)
|
|
umod = unflatten(ep_non_strict)
|
|
self.assertTrue(torch.allclose(umod(input_), mod(input_)))
|
|
|
|
|
|
# Allow running this file directly (e.g. `python test_unflatten.py`).
if __name__ == "__main__":
    run_tests()