pytorch/test/export/test_unflatten.py
Laith Sakka 853958f82c Fix: Replacements can cause runtime assertions to disappear and can cause invalid inductor code. (#153661)
Let's first explore a couple of problems related to replacements and runtime assertions.

#### Example problem 1
Suppose we have a runtime assertion u0 == s0, where u0 is an input coming from mark_unbacked. A replacement u0 = s0 will be added, so the function f(u0, s0) effectively becomes f(s0, s0), and the assert is never inserted during insert_deferred_runtime_asserts.
The reason is that insert_deferred_runtime_asserts inserts each assertion only once all of its input symbols have been seen, but u0 will never be seen. The same thing can happen when we defer an assertion on backed symbols, e.g. s0 == s2.
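As a rough mental model (a toy sketch in sympy; this is not the real insert_deferred_runtime_asserts code), the pass emits an assertion only after every symbol it mentions has been seen among the placeholder size expressions. Applying the replacement u0 -> s0 to those expressions means u0 is never seen, so the assertion is silently dropped:

```python
import sympy

u0, s0 = sympy.symbols("u0 s0", integer=True)
deferred_asserts = [sympy.Eq(u0, s0)]  # the deferred runtime assertion u0 == s0


def placeholder_size_exprs(apply_replacements: bool):
    # Two size-like graph inputs; with the replacement u0 -> s0 applied to
    # placeholder expressions, the first input's size is rewritten to s0.
    sizes = [u0, s0]
    if apply_replacements:
        sizes = [expr.subs({u0: s0}) for expr in sizes]
    return sizes


def emitted_asserts(apply_replacements: bool):
    seen, emitted = set(), []
    for size in placeholder_size_exprs(apply_replacements):
        seen |= size.free_symbols
        for ra in deferred_asserts:
            # emit an assertion only once all of its symbols have been seen
            if ra not in emitted and ra.free_symbols <= seen:
                emitted.append(ra)
    return emitted


print(emitted_asserts(apply_replacements=False))  # [Eq(u0, s0)]: the assert is inserted
print(emitted_asserts(apply_replacements=True))   # []: u0 is never seen, the assert is dropped
```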

#### Example problem 2
Consider u0 == s0, where u0 comes from a call to .item(). Imagine that later on s0 gets specialized to 2. In that case s0 is no longer seen as an input during insert_deferred_runtime_asserts, and the assertion won't be inserted in the graph. Worse, Inductor will generate code that refers to s0 in the cpp wrapper even though s0 no longer exists, causing a failure.
internal xref: https://fb.workplace.com/groups/1075192433118967/permalink/1669766396994898/
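A hedged sketch of the pattern (the module, sizes, and dynamic-shape spec below are illustrative assumptions, not the PR's actual repro):

```python
import torch


class M(torch.nn.Module):
    def forward(self, x, t):
        u0 = t.item()                    # data-dependent: unbacked u0
        torch._check(u0 == x.shape[0])   # deferred assertion u0 == s0
        return x[:u0] * 2


x, t = torch.randn(2, 3), torch.tensor(2)
ep = torch.export.export(M(), (x, t), dynamic_shapes=({0: torch.export.Dim.AUTO}, None))
# If s0 later specializes to 2 (which Dim.AUTO permits), the chain u0 -> s0 -> 2 can leave
# the u0 == s0 assertion with no surviving input symbol, so it is never inserted, while
# Inductor's cpp wrapper may still reference the now-nonexistent s0.
```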

## The solution
The runtime-assertion insertion loop depends on detecting that the symbols used in each assertion have been seen; those symbols are either graph inputs or are generated in the graph by data-dependent ops like .item().

The issues above happen when the symbols are graph inputs. To force those symbols to exist in the graph and to be seen by the runtime-assertion pass, we do not apply replacements to placeholder expressions during codegen or during runtime-assertion insertion.
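Continuing the toy model from above (an illustration of the idea, not the actual implementation), placeholder size expressions are resolved without replacements while scanning for runtime assertions, so u0 is still seen and the assertion survives:

```python
import sympy

u0, s0 = sympy.symbols("u0 s0", integer=True)
replacements = {u0: s0}
deferred_asserts = [sympy.Eq(u0, s0)]


def size_expr(expr, *, resolving_for_asserts: bool):
    # Hypothetical helper: keep the original placeholder symbol during codegen and
    # runtime-assertion insertion; use the replaced (optimized) form elsewhere.
    return expr if resolving_for_asserts else expr.subs(replacements)


seen, emitted = set(), []
for ph in (u0, s0):  # placeholder size symbols
    seen |= size_expr(ph, resolving_for_asserts=True).free_symbols
    emitted += [ra for ra in deferred_asserts if ra not in emitted and ra.free_symbols <= seen]
print(emitted)  # [Eq(u0, s0)]: the assertion is inserted because u0 is still seen
```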

This should not add performance overhead: the graph has already been optimized with replacements, and the only effect is that we no longer mistakenly drop graph inputs that are used in runtime assertions.
I added extended testing. One unrelated follow-up I noticed: we might want to rename unbacked symbols in runtime assertions when we do unbacked renaming, but that's a separate issue.

Other approaches that did not work:
#### Ban replacements on unbacked symbols
1. This does not work when we defer runtime assertions on backed symbols, e.g. s0 == s1. We could ban those replacements as well, but then problem 2 becomes even more problematic.
2. It also degrades the quality of symbolic reasoning.

#### Apply specializations to runtime assertions before codegen
1. This can fix some issues, but it may also turn runtime assertions into no-ops (see the toy sketch below).
2. It does not fix the case where a runtime assertion is never inserted during insert_deferred_runtime_asserts because an input symbol is not detected.
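A toy sympy illustration of that first drawback (not PyTorch code): once the replacement and the specialization are substituted into the assertion, nothing is left to check.

```python
import sympy

u0, s0 = sympy.symbols("u0 s0", integer=True)
print(sympy.Eq(u0, s0).subs({u0: s0}))  # True: the replacement already simplifies the assert away
print(sympy.Eq(s0, 2).subs({s0: 2}))    # True: specialization turns it into a no-op as well
```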

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153661
Approved by: https://github.com/jansel
2025-05-28 09:08:05 +00:00


# Owner(s): ["oncall: export"]
# flake8: noqa
import copy
import unittest
from re import escape
from typing import Any, List, Optional
import torch
import torch._dynamo as torchdynamo
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch.export import export, FlatArgsAdapter, unflatten
from torch.export.unflatten import _disable_interpreter
from torch.testing._internal.common_utils import (
IS_WINDOWS,
run_tests,
skipIfTorchDynamo,
TestCase,
)
from torch.testing._internal.torchbind_impls import init_torchbind_implementations
from torch.utils._pytree import TreeSpec
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestUnflatten(TestCase):
def compare_outputs(self, eager, unflattened, args):
orig_output = eager(*args)
unflattened_output = unflattened(*args)
self.assertTrue(torch.allclose(orig_output, unflattened_output))
def test_unflatten_nested(self):
class NestedChild(torch.nn.Module):
def forward(self, x):
return x / x
class Child1(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.nested = NestedChild()
self.register_parameter(
"child1param", torch.nn.Parameter(torch.ones(2, 3))
)
def forward(self, x):
x = self.nested(x)
return x + self.child1param
class Child2(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))
def forward(self, x):
return x - self.child2buffer
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.foo = Child1()
self.bar = Child2()
self.register_parameter(
"rootparam", torch.nn.Parameter(torch.ones(2, 3))
)
def forward(self, x):
x = x * self.rootparam
x = self.foo(x)
x = self.bar(x)
return x
orig_eager = MyModule()
export_module = export(orig_eager, (torch.rand(2, 3),), {}, strict=True)
unflattened = unflatten(export_module)
inputs = (torch.rand(2, 3),)
# Compare the root modules and all submodules
self.compare_outputs(orig_eager, unflattened, inputs)
self.compare_outputs(orig_eager.foo, unflattened.foo, inputs)
self.compare_outputs(orig_eager.bar, unflattened.bar, inputs)
self.compare_outputs(orig_eager.foo.nested, unflattened.foo.nested, inputs)
# Check state dicts are equal
orig_state_dict = orig_eager.state_dict()
exported_state_dict = unflattened.state_dict()
for name, value in orig_state_dict.items():
self.assertTrue(torch.allclose(value, exported_state_dict[name]))
def test_unflatten_buffer_mutation(self):
class Child(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))
def forward(self, x):
self.child2buffer.add_(x)
return x - self.child2buffer
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.foo = Child()
self.register_parameter(
"rootparam", torch.nn.Parameter(torch.ones(2, 3))
)
def forward(self, x):
x = self.foo(x)
return x * self.rootparam
eager_module = MyModule()
export_module = export(eager_module, (torch.rand(2, 3),), {}, strict=True)
unflattened_module = unflatten(export_module)
# Buffer should look the same before and after one run
eager_buffer = eager_module.foo.child2buffer
unflattened_buffer = unflattened_module.foo.child2buffer
self.assertTrue(torch.allclose(eager_buffer, unflattened_buffer))
inputs = (torch.rand(2, 3),)
eager_module(*inputs)
unflattened_module(*inputs)
self.assertTrue(torch.allclose(eager_buffer, unflattened_buffer))
def test_unflatten_nested_access(self):
class Child(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))
def forward(self, x):
return x - self.child2buffer
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.foo = Child()
self.register_parameter(
"rootparam", torch.nn.Parameter(torch.ones(2, 3))
)
def forward(self, x):
x = x + self.foo.child2buffer
x = self.foo(x)
return x
eager_module = MyModule()
export_module = export(eager_module, (torch.rand(2, 3),), {}, strict=True)
unflattened_module = unflatten(export_module)
inputs = (torch.rand(2, 3),)
self.compare_outputs(eager_module, unflattened_module, inputs)
def test_unflatten_shared_submodule(self):
class Shared(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
layernorm = torch.nn.LayerNorm(10)
self.sub_net = torch.nn.Sequential(
layernorm,
torch.nn.ReLU(),
layernorm,
torch.nn.ReLU(),
)
def forward(self, x):
return self.sub_net(x)
eager_module = Shared()
inps = (torch.rand(10),)
export_module = export(eager_module, inps, {}, strict=True)
unflattened_module = unflatten(export_module)
self.compare_outputs(eager_module, unflattened_module, inps)
self.assertTrue(hasattr(unflattened_module, "sub_net"))
for i in range(len(eager_module.sub_net)):
self.assertTrue(hasattr(unflattened_module.sub_net, str(i)))
self.assertEqual(
id(getattr(unflattened_module.sub_net, "0")),
id(getattr(unflattened_module.sub_net, "2")),
)
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
@skipIfTorchDynamo("Non strict mode is not meant to run with dynamo")
def test_unflatten_preserve_signature(self):
class NestedChild(torch.nn.Module):
def forward(self, zx, y):
return {"x": y["key"] + zx[1], "w": y["key"] * zx[1]}
class Child1(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.nested = NestedChild()
def forward(self, x, y):
z = torch.ones_like(x)
xw = self.nested((z, x), y={"key": y})
return xw["w"] + z - xw["x"]
class Child2(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return x - 1
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.foo = Child1()
self.bar = Child2()
def forward(self, x, y):
x = self.foo(x, y)
x = self.bar(x)
return x
orig_eager = MyModule()
inps = torch.rand(2, 3), torch.rand(2, 3)
for strict in [True, False]:
export_module = export(
orig_eager,
inps,
{},
preserve_module_call_signature=("foo.nested",),
strict=strict,
)
unflattened = unflatten(export_module)
self.compare_outputs(export_module.module(), unflattened, inps)
unflattened.foo.nested = NestedChild()
self.compare_outputs(export_module.module(), unflattened, inps)
# Test tree spec mismatched input
orig_outs = export_module.module()(*inps)
new_inps = *inps, torch.rand(2, 3)
with self.assertRaisesRegex(
TypeError,
"There is no flat args adapter sepcified. Are you sure you are calling this with the right arguments?",
):
unflattened(new_inps)
# With flat args adapter
class KeepTwoFlatArgsAdapter(FlatArgsAdapter):
def adapt(
self,
target_spec: TreeSpec,
input_spec: TreeSpec,
input_args: List[Any],
metadata: dict[str, Any],
obj: Optional[Any] = None,
) -> List[Any]:
while len(input_args) > 2:
input_args.pop(-1)
return input_args
unflattened = unflatten(export_module, KeepTwoFlatArgsAdapter())
new_outs = unflattened(*new_inps)
self.assertTrue(torch.allclose(orig_outs, new_outs))
def test_unflatten_param_list_dict(self):
class Mod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param_list = torch.nn.ParameterList()
self.param_dict = torch.nn.ParameterDict()
for i in range(2):
self.param_list.append(torch.nn.Parameter(torch.randn((2, 3))))
self.param_dict[f"key_{i}"] = torch.nn.Parameter(
torch.randn((2, 3))
)
def forward(self, x):
for i in range(2):
x = x + self.param_list[i]
x = x + self.param_dict[f"key_{i}"]
return x
export_module = torch.export.export(Mod(), (torch.randn((2, 3)),), strict=True)
unflattened = unflatten(export_module)
self.compare_outputs(
export_module.module(), unflattened, (torch.randn((2, 3)),)
)
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
def test_unflatten_preserve_with_unused_input(self):
class M1(torch.nn.Module):
def forward(self, x, a, b):
return x + a, b
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.m1 = M1()
def forward(self, x, y):
a, b = torch.topk(y, 2)
return self.m1(x, a, b)[0]
ep = torch.export.export(
M(),
(torch.randn(2), torch.randn(5)),
preserve_module_call_signature=("m1",),
strict=False,
)
ep.graph.eliminate_dead_code()
unflattened = unflatten(ep)
self.compare_outputs(ep.module(), unflattened, (torch.randn(2), torch.randn(5)))
def test_unflatten_wrong_input(self):
class Mod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param_list = torch.nn.ParameterList()
self.param_dict = torch.nn.ParameterDict()
for i in range(2):
self.param_list.append(torch.nn.Parameter(torch.randn((2, 3))))
self.param_dict[f"key_{i}"] = torch.nn.Parameter(
torch.randn((2, 3))
)
def forward(self, x):
a = x.sum()
for i in range(2):
a = a + self.param_list[i].sum()
a = a + self.param_dict[f"key_{i}"].sum()
return a
export_module = torch.export.export(Mod(), (torch.randn((2, 3)),), strict=True)
with self.assertRaisesRegex(
RuntimeError,
escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"),
):
export_module.module()(torch.randn(6, 6))
unflattened = unflatten(export_module)
with self.assertRaisesRegex(
RuntimeError,
escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"),
):
unflattened(torch.randn(6, 6))
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
def test_unflatten_with_inplace_compile(self):
class NestedChild(torch.nn.Module):
def forward(self, x):
return x / x
class Child1(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.nested = NestedChild()
self.register_parameter(
"child1param", torch.nn.Parameter(torch.ones(2, 3))
)
def forward(self, x):
x = self.nested(x)
return x + self.child1param
class Child2(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))
def forward(self, x):
return x - self.child2buffer
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.foo = Child1()
self.bar = Child2()
self.register_parameter(
"rootparam", torch.nn.Parameter(torch.ones(2, 3))
)
def forward(self, x):
x = x * self.rootparam
x = self.foo(x)
x = self.bar(x)
return x
orig_eager = MyModule()
export_module = torch.export.export(
orig_eager, (torch.rand(2, 3),), {}, strict=True
)
unflattened = unflatten(export_module)
# in-place compilation should work. Pass fullgraph to ensure no graph breaks.
from torch._dynamo.backends.debugging import ExplainWithBackend
eb = ExplainWithBackend("inductor")
unflattened.foo.compile(backend=eb, fullgraph=True)
inputs = (torch.randn(2, 3),)
self.compare_outputs(orig_eager, unflattened, inputs)
self.assertEqual(len(eb.graphs), 1)
unflattened.compile()
self.compare_outputs(orig_eager, unflattened, inputs)
def test_fx_trace(self):
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
x = x[0] + x[1]
x = x + y["foo"]
return x
orig_eager = MyModule()
inputs = ((torch.rand(2, 3), torch.rand(2, 3)), {"foo": torch.rand(2, 3)})
export_module = export(orig_eager, inputs, {}, strict=True)
unflattened = unflatten(export_module)
torch.fx.symbolic_trace(
unflattened, concrete_args=(torch.fx.PH, torch.fx.PH, torch.fx.PH)
)
def test_double_nested_submodule(self):
class SubSubMod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return x * x
class SubMod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.subsubmod = SubSubMod()
def forward(self, x):
return x - x
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.submod = SubMod()
def forward(self, x):
return x + self.submod.subsubmod(x)
orig_eager = MyModule()
export_module = torch.export.export(
orig_eager, (torch.rand(2, 3),), {}, strict=True
)
unflattened = unflatten(export_module)
inputs = (torch.rand(2, 3),)
self.compare_outputs(orig_eager, unflattened, inputs)
def test_unflatten_container_type(self):
class Leaf(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(4, 4)
def forward(self, x):
return self.linear(x)
class Bar(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.leaf = Leaf()
self.buffer = torch.nn.Buffer(torch.randn(4, 4))
def forward(self, x, z):
return self.buffer.sum() + self.leaf(x).sum() + z[0].sum() + z[1].sum()
class Foo(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.bar = Bar()
def forward(self, x, z):
y = self.bar.buffer + x + z[0] + z[1]
return self.bar(x, z) + y.sum()
inp = (torch.randn(4, 4), [torch.randn(4, 4), torch.randn(4, 4)])
mod = Foo()
ep_strict = torch.export.export(mod, inp, strict=True) # noqa: F841
ep_non_strict = torch.export.export(mod, inp, strict=False)
gm_unflat_non_strict = unflatten(ep_non_strict)
ep = torch.export.export(gm_unflat_non_strict, inp, strict=False)
self.assertTrue(torch.allclose(ep.module()(*inp), mod(*inp)))
def test_unflattened_module_nodes_has_meta_val(self):
class SubMod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return x + x, x * x
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.submod = SubMod()
def forward(self, x):
return x + sum(self.submod(x))
orig_eager = MyModule()
export_module = torch.export.export(
orig_eager, (torch.rand(2, 3),), {}, strict=True
)
unflattened = unflatten(export_module)
inputs = (torch.rand(2, 3),)
self.compare_outputs(orig_eager, unflattened, inputs)
def check_meta(gm):
for n in gm.graph.nodes:
if n.op == "output":
continue
self.assertTrue(n.meta.get("val") is not None)
for m in unflattened.modules():
check_meta(m)
def test_unflatten_requires_grad_param(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.p = torch.nn.Parameter(torch.ones(3, 3), requires_grad=False)
def forward(self, x):
return self.p + x
with torch.device("meta"):
mod = M()
inputs = (torch.randn(3, 3, device="meta"),)
ep = export(mod, inputs, strict=True)
unflattened = unflatten(ep)
self.assertTrue(unflattened.state_dict()["p"].requires_grad is False)
self.assertTrue(unflattened.p.requires_grad is False)
def test_placeholder_and_get_attr_ordering_after_unflattened(self):
class TransposeModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(3, 1, 3, stride=2)
def forward(self, x):
x = self.conv(x)
return x.transpose(0, 1)
x = torch.randn(32, 3, 64, 64)
exported_program = export(TransposeModule(), args=(x,), strict=True)
unflattened_module = unflatten(exported_program)
# Check the inputs of the created call_module node are in order
call_module_input_order = []
for node in unflattened_module.graph.nodes:
if node.op == "call_module":
transpose_module = unflattened_module.get_submodule(node.target)
for sub_node in transpose_module.graph.nodes:
if sub_node.op == "placeholder" or sub_node.op == "get_attr":
call_module_input_order.append(sub_node.op)
self.assertEqual(
call_module_input_order, ["placeholder", "get_attr", "get_attr"]
)
def test_unflatten_constant_tensor(self):
class SubMod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.initializer = 0.1
def forward(self, x):
return x + torch.tensor(self.initializer)
class Mod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.submod = SubMod()
def forward(self, x):
return x + self.submod(x)
export_module = torch.export.export(Mod(), (torch.randn((2, 3)),), strict=True)
unflattened = unflatten(export_module)
self.compare_outputs(
export_module.module(), unflattened, (torch.randn((2, 3)),)
)
@skipIfTorchDynamo("custom objects not supported in dynamo yet")
def test_unflatten_constant_obj(self):
init_torchbind_implementations()
@torch._library.register_fake_class("_TorchScriptTesting::_Foo")
class FakeFoo: # noqa: F841
def __init__(self, x: int, y: int):
self.x = x
self.y = y
@classmethod
def __obj_unflatten__(cls, flat_ctx):
return cls(**dict(flat_ctx))
def add_tensor(self, z):
return (self.x + self.y) * z
class SubMod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
def forward(self, x):
return x + self.attr.add_tensor(x)
class Mod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.submod = SubMod()
def forward(self, x):
return x + self.submod(x)
with enable_torchbind_tracing():
export_module = torch.export.export(
Mod(), (torch.randn((2, 3)),), strict=False
)
unflattened = unflatten(export_module)
self.compare_outputs(
export_module.module(), unflattened, (torch.randn((2, 3)),)
)
# skip connection is not supported yet
@unittest.expectedFailure
def test_unflatten_skipped_call_module(self):
class C(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return a.d(x.cos())
class B(torch.nn.Module):
def __init__(self):
super().__init__()
self.c = C()
def forward(self, x):
return self.c(x) + x
class D(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.sin()
class A(torch.nn.Module):
def __init__(self):
super().__init__()
self.b = B()
self.d = D()
def forward(self, x):
return self.b(x)
a = A()
# The call chain looks like this:
# A -> B -> C -> A.d
ep = torch.export.export(a, (torch.randn(3),), strict=False)
unflatten(ep)
def test_nested_leaf_non_strict(self):
class Leaf(torch.nn.Module):
def forward(self, x):
return x + 1
class Nested(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.leaf = Leaf()
def forward(self, x):
return self.leaf(x) + 2
class TopLevel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.nested = Nested()
def forward(self, x):
return self.nested(x) + 3
ep = torch.export.export(
TopLevel(),
(torch.randn(3),),
strict=False,
preserve_module_call_signature=("nested",),
)
torch.export.unflatten(ep)
def test_unflatten_submodule_ordering(self):
class Module2(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.buffer = torch.nn.Buffer(torch.rand(3, 4))
self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))
def forward(self, x):
return x + self.buffer + self.param
class Module1(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.buffer = torch.nn.Buffer(torch.rand(3, 4))
self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))
def forward(self, x):
return x + self.buffer + self.param
class Module(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.mod2 = Module2()
self.mod3 = self.mod2
self.mod1 = Module1()
def forward(self, x):
return self.mod3(self.mod2(self.mod1(x)))
mod = Module()
ep = torch.export.export(mod, (torch.randn(3, 4),), strict=True)
unflattened = torch.export.unflatten(ep)
fqn_list = [x for x, _ in unflattened.named_modules(remove_duplicate=False)]
self.assertEqual(len(fqn_list), 4)
self.assertEqual(
[x for x, _ in mod.named_modules(remove_duplicate=False)],
fqn_list,
)
def test_duplicate_placeholder(self):
N, C, H, W = 1, 2, 2, 3
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
layer = torch.nn.LayerNorm([C, H, W])
self.norms = torch.nn.ModuleList(
[
layer, # reuse layer norm
layer,
layer,
]
)
def forward(self, input_):
for i in range(len(self.norms)):
output = self.norms[i](input_)
input_ = output
return output
mod = MyModule()
input_ = torch.randn(N, C, H, W)
ep_strict = export(copy.deepcopy(mod), (input_,), strict=True)
umod = unflatten(ep_strict)
self.assertTrue(torch.allclose(umod(input_), mod(input_)))
ep_non_strict = export(copy.deepcopy(mod), (input_,), strict=False)
umod = unflatten(ep_non_strict)
self.assertTrue(torch.allclose(umod(input_), mod(input_)))
def test_simple_alias(self):
# handle weight sharing, check tensor ids after unflattening
class Foo(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
# alias param
self.bias = torch.nn.Parameter(torch.randn(4))
self.m = torch.nn.Linear(4, 4)
self.m.bias = self.bias
def forward(self, x):
return self.m(x) + self.bias
m = Foo()
inps = (torch.randn(4, 4),)
ep = export(m, inps, strict=True)
unep = unflatten(ep)
self.assertTrue(id(unep.m.bias) == id(unep.bias))
# handle aliasing where one alias is unused
class Foo(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.bias = torch.nn.Parameter(torch.randn(4))
self.m = torch.nn.Linear(4, 4)
self.m.bias = (
self.bias
) # self.bias is unused, aliasing should be handled
def forward(self, x):
return self.m(x)
m = Foo()
inps = (torch.randn(4, 4),)
ep = export(m, inps, strict=True)
unep = unflatten(ep)
self.assertTrue(torch.allclose(unep(*inps), m(*inps)))
def test_attr_as_submod_input(self):
class layer(torch.nn.Module):
def forward(self, x, const) -> torch.Tensor:
return x + const
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.const = torch.nn.Buffer(torch.ones(4, 8))
self.layers = torch.nn.ModuleList([layer() for _ in range(2)])
def forward(self, x: torch.Tensor) -> torch.Tensor:
for layer in self.layers:
x = layer(x, self.const)
return x
mod = M()
x = torch.randn(4, 8)
ep = export(mod, (x,), strict=True)
unflattened = unflatten(ep)
torch.testing.assert_close(unflattened(x), mod(x))
def test_dedup_sym_size(self):
# Here, sym_size & floor div are used in 3 subgraphs (top-level, m1, m2),
# but only one copy of sym_size is created in the initial export graph.
# For m1, sym_size & floordiv should be copied as recompute since we preserve the call signature,
# but for m2 floordiv should be passed in as a placeholder.
# Test that this is preserved, and the unflattened module runs correctly.
class M1(torch.nn.Module):
def forward(self, x, y):
d = x.size(0) // 2
return y[:d]
class M2(torch.nn.Module):
def forward(self, x, y):
d = x.size(0) // 2
return y[:d]
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.m1 = M1()
self.m2 = M2()
def forward(self, x, y):
d = x.size(0) // 2
m1_res = self.m1(x, y)
m2_res = self.m2(x, y)
return y[d:] + m1_res + m2_res
inputs = (torch.ones(10), torch.ones(10))
d_ = torch.export.Dim("foo", max=2048)
d = 2 * d_
ep = torch.export.export(
M(),
inputs,
dynamic_shapes=((d,), (d,)),
strict=False,
preserve_module_call_signature=("m1",),
)
unflat = unflatten(ep)
unflat(*inputs)
fn_count_sym_size = lambda graph: [node.target for node in graph.nodes].count(
torch.ops.aten.sym_size.int
)
self.assertEqual(fn_count_sym_size(unflat.graph), 3)
self.assertEqual(fn_count_sym_size(unflat.m1.graph), 1)
self.assertEqual(fn_count_sym_size(unflat.m2.graph), 0)
def test_unflatten_eager(self):
class NestedChild(torch.nn.Module):
def forward(self, x):
return x / x
class Child1(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.nested = NestedChild()
self.register_parameter(
"child1param", torch.nn.Parameter(torch.ones(2, 3))
)
def forward(self, x):
x = self.nested(x)
return x + self.child1param
class Child2(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))
def forward(self, x):
return x - self.child2buffer
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.foo = Child1()
self.bar = Child2()
self.register_parameter(
"rootparam", torch.nn.Parameter(torch.ones(2, 3))
)
def forward(self, x):
x = x * self.rootparam
x = self.foo(x)
x = self.bar(x)
return x
orig_eager = MyModule()
export_module = export(orig_eager, (torch.rand(2, 3),), {}, strict=True)
with _disable_interpreter():
unflattened = unflatten(export_module)
self.assertEqual(unflattened._run_with_interpreter, False)
self.assertEqual(unflattened.foo._run_with_interpreter, False)
inputs = (torch.rand(2, 3),)
# Compare the root modules and all submodules
self.compare_outputs(orig_eager, unflattened, inputs)
self.compare_outputs(orig_eager.foo, unflattened.foo, inputs)
self.compare_outputs(orig_eager.bar, unflattened.bar, inputs)
self.compare_outputs(orig_eager.foo.nested, unflattened.foo.nested, inputs)
# Check state dicts are equal
orig_state_dict = orig_eager.state_dict()
exported_state_dict = unflattened.state_dict()
for name, value in orig_state_dict.items():
self.assertTrue(torch.allclose(value, exported_state_dict[name]))
# Check composability with symbolic trace, as torchrec ddp uses symbolic
# tracer
symbolic_traced = torch.fx.symbolic_trace(unflattened, concrete_args=inputs)
self.assertTrue(torch.allclose(orig_eager(*inputs), symbolic_traced(*inputs)))
# torch.compile submodule
unflattened.foo = torch.compile(unflattened.foo, fullgraph=True)
self.compare_outputs(orig_eager, unflattened, inputs)
def test_unflatten_none(self):
class M2(torch.nn.Module):
def forward(self, x, y):
return x + x, None
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.m2 = M2()
def forward(self, x, y):
x = x + x
return self.m2(x, y)
ep = export(
M(), (torch.rand(2, 3), None), preserve_module_call_signature=("m2",)
)
unflattened = unflatten(ep)
inp = (torch.randn(2, 3), None)
self.assertTrue(torch.allclose(M()(*inp)[0], unflattened(*inp)[0]))
def test_unflatten_empty_branch(self):
class M(torch.nn.Module):
def forward(self, x):
if x is None:
return torch.ones(3), torch.ones(3)
else:
return x + x, x * x
class M1(torch.nn.Module):
def __init__(self):
super().__init__()
self.m = M()
def forward(self, x, y):
a, b = self.m(x)
c, d = self.m(y)
return a + b + c + d
ep = torch.export.export(M1(), (torch.randn(3), None))
unf = torch.export.unflatten(ep)
inp = (torch.randn(3), None)
self.assertTrue(torch.allclose(unf(*inp), M1()(*inp)))
ep = torch.export.export(
M1(), (torch.randn(3), None), preserve_module_call_signature="m"
)
unf = torch.export.unflatten(ep)
inp = (torch.randn(3), None)
self.assertTrue(torch.allclose(unf(*inp), M1()(*inp)))
if __name__ == "__main__":
run_tests()