Let's first explore a couple of problems related to replacements and runtime assertions.

#### Example problem 1

Suppose we have a runtime assertion that u0 == s0, where u0 is an input coming from mark_unbacked. A replacement u0 = s0 will be added, so the function f(u0, s0) becomes f(s0, s0). This leads to the assert not being inserted during insert_deferred_runtime_asserts. The reason is that insert_deferred_runtime_asserts inserts each assertion once all of its inputs have been seen, but u0 will never be seen. The same thing can happen when we defer an assertion on backed symbols, e.g. s0 == s2, etc.

#### Example problem 2

Consider u0 == s0, where u0 comes from a call to .item(). Imagine that later on s0 is specialized to 2. In that case s0 is not seen as an input during insert_deferred_runtime_asserts and the assertion won't be inserted in the graph. Worse, Inductor will generate code in the cpp wrapper that refers to s0 even though it does not exist, causing a failure.

internal xref: https://fb.workplace.com/groups/1075192433118967/permalink/1669766396994898/

## The solution

The runtime assertion insertion loop depends on detecting that the symbols used in the runtime assertions have been seen; those symbols are either graph inputs or generated in the graph by data-dependent ops like .item(). The issues above happen when the symbols are graph inputs. To force those symbols to exist in the graph and to be seen by the runtime assertion insertion, we do not apply replacements on placeholder expressions during codegen and during runtime assertion insertion.

This should not add performance overhead, since we already optimized the graph with replacements; the only effect is no longer mistakenly dropping graph inputs that are used in runtime assertions. I added extended testing.

One unrelated follow-up I noticed: we might want to rename unbacked symbols in runtime assertions when we do unbacked renaming, but that's a different issue.

Other approaches that did not work:

#### Ban replacements on unbacked symbols

1. Does not work when we defer runtime assertions on backed symbols, e.g. s0 == s1. We could also ban such replacements, but then problem 2 becomes more problematic.
2. It affects the quality of reasoning, in a bad way.

#### Apply specializations to runtime assertions before codegen

1. Can fix some issues, but may also turn runtime assertions into no-ops.
2. Does not fix the case where a runtime assertion is never inserted during insert_deferred_runtime_asserts because its input was not detected.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153661
Approved by: https://github.com/jansel
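For illustration only, here is a minimal sketch of how the "example problem 1" pattern arises, assuming the public `torch.compile`, `torch._check`, and `torch._dynamo.decorators.mark_unbacked` APIs; the function `f`, the tensor shapes, and the comments are made up for this sketch and are not taken from the PR or its tests.

```python
import torch


def f(x, y):
    # Defer a runtime assertion tying the unbacked size u0 (= x.size(0))
    # to the backed size s0 (= y.size(0)); symbolic shapes may then record
    # the replacement u0 -> s0, effectively turning f(u0, s0) into f(s0, s0).
    torch._check(x.size(0) == y.size(0))
    return x + y


x, y = torch.randn(4), torch.randn(4)
torch._dynamo.decorators.mark_unbacked(x, 0)  # x.size(0) traces as an unbacked u0
compiled = torch.compile(f, fullgraph=True, dynamic=True)
# With the fix described above, the u0 == s0 assert should still be emitted
# even though u0 was replaced everywhere else in the graph.
out = compiled(x, y)
```

The sketch only illustrates how such a deferred assertion can arise; the file below is unrelated unflatten test coverage shown by the mirror, not the PR's repro.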
# Owner(s): ["oncall: export"]
# flake8: noqa
import copy
import unittest
from re import escape
from typing import Any, List, Optional

import torch
import torch._dynamo as torchdynamo
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch.export import export, FlatArgsAdapter, unflatten
from torch.export.unflatten import _disable_interpreter
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)
from torch.testing._internal.torchbind_impls import init_torchbind_implementations
from torch.utils._pytree import TreeSpec


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
class TestUnflatten(TestCase):
    def compare_outputs(self, eager, unflattened, args):
        orig_output = eager(*args)
        unflattened_output = unflattened(*args)
        self.assertTrue(torch.allclose(orig_output, unflattened_output))

    def test_unflatten_nested(self):
        class NestedChild(torch.nn.Module):
            def forward(self, x):
                return x / x

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()
                self.register_parameter(
                    "child1param", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = self.nested(x)
                return x + self.child1param

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = x * self.rootparam
                x = self.foo(x)
                x = self.bar(x)
                return x

        orig_eager = MyModule()
        export_module = export(orig_eager, (torch.rand(2, 3),), {}, strict=True)
        unflattened = unflatten(export_module)

        inputs = (torch.rand(2, 3),)

        # Compare the root modules and all submodules
        self.compare_outputs(orig_eager, unflattened, inputs)
        self.compare_outputs(orig_eager.foo, unflattened.foo, inputs)
        self.compare_outputs(orig_eager.bar, unflattened.bar, inputs)
        self.compare_outputs(orig_eager.foo.nested, unflattened.foo.nested, inputs)

        # Check state dicts are equal
        orig_state_dict = orig_eager.state_dict()
        exported_state_dict = unflattened.state_dict()
        for name, value in orig_state_dict.items():
            self.assertTrue(torch.allclose(value, exported_state_dict[name]))

    def test_unflatten_buffer_mutation(self):
        class Child(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                self.child2buffer.add_(x)
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = self.foo(x)
                return x * self.rootparam

        eager_module = MyModule()
        export_module = export(eager_module, (torch.rand(2, 3),), {}, strict=True)
        unflattened_module = unflatten(export_module)

        # Buffer should look the same before and after one run
        eager_buffer = eager_module.foo.child2buffer
        unflattened_buffer = unflattened_module.foo.child2buffer
        self.assertTrue(torch.allclose(eager_buffer, unflattened_buffer))

        inputs = (torch.rand(2, 3),)
        eager_module(*inputs)
        unflattened_module(*inputs)
        self.assertTrue(torch.allclose(eager_buffer, unflattened_buffer))

    def test_unflatten_nested_access(self):
        class Child(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = x + self.foo.child2buffer
                x = self.foo(x)
                return x

        eager_module = MyModule()
        export_module = export(eager_module, (torch.rand(2, 3),), {}, strict=True)
        unflattened_module = unflatten(export_module)

        inputs = (torch.rand(2, 3),)
        self.compare_outputs(eager_module, unflattened_module, inputs)

    def test_unflatten_shared_submodule(self):
        class Shared(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                layernorm = torch.nn.LayerNorm(10)
                self.sub_net = torch.nn.Sequential(
                    layernorm,
                    torch.nn.ReLU(),
                    layernorm,
                    torch.nn.ReLU(),
                )

            def forward(self, x):
                return self.sub_net(x)

        eager_module = Shared()
        inps = (torch.rand(10),)
        export_module = export(eager_module, inps, {}, strict=True)
        unflattened_module = unflatten(export_module)
        self.compare_outputs(eager_module, unflattened_module, inps)
        self.assertTrue(hasattr(unflattened_module, "sub_net"))
        for i in range(len(eager_module.sub_net)):
            self.assertTrue(hasattr(unflattened_module.sub_net, str(i)))
        self.assertEqual(
            id(getattr(unflattened_module.sub_net, "0")),
            id(getattr(unflattened_module.sub_net, "2")),
        )

    @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
    @skipIfTorchDynamo("Non strict mode is not meant to run with dynamo")
    def test_unflatten_preserve_signature(self):
        class NestedChild(torch.nn.Module):
            def forward(self, zx, y):
                return {"x": y["key"] + zx[1], "w": y["key"] * zx[1]}

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()

            def forward(self, x, y):
                z = torch.ones_like(x)
                xw = self.nested((z, x), y={"key": y})
                return xw["w"] + z - xw["x"]

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x - 1

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()

            def forward(self, x, y):
                x = self.foo(x, y)
                x = self.bar(x)
                return x

        orig_eager = MyModule()
        inps = torch.rand(2, 3), torch.rand(2, 3)
        for strict in [True, False]:
            export_module = export(
                orig_eager,
                inps,
                {},
                preserve_module_call_signature=("foo.nested",),
                strict=strict,
            )
            unflattened = unflatten(export_module)
            self.compare_outputs(export_module.module(), unflattened, inps)
            unflattened.foo.nested = NestedChild()
            self.compare_outputs(export_module.module(), unflattened, inps)

            # Test tree spec mismatched input
            orig_outs = export_module.module()(*inps)
            new_inps = *inps, torch.rand(2, 3)
            with self.assertRaisesRegex(
                TypeError,
                "There is no flat args adapter sepcified. Are you sure you are calling this with the right arguments?",
            ):
                unflattened(new_inps)

            # With flat args adapter
            class KeepTwoFlatArgsAdapter(FlatArgsAdapter):
                def adapt(
                    self,
                    target_spec: TreeSpec,
                    input_spec: TreeSpec,
                    input_args: List[Any],
                    metadata: dict[str, Any],
                    obj: Optional[Any] = None,
                ) -> List[Any]:
                    while len(input_args) > 2:
                        input_args.pop(-1)
                    return input_args

            unflattened = unflatten(export_module, KeepTwoFlatArgsAdapter())
            new_outs = unflattened(*new_inps)
            self.assertTrue(torch.allclose(orig_outs, new_outs))

    def test_unflatten_param_list_dict(self):
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param_list = torch.nn.ParameterList()
                self.param_dict = torch.nn.ParameterDict()
                for i in range(2):
                    self.param_list.append(torch.nn.Parameter(torch.randn((2, 3))))
                    self.param_dict[f"key_{i}"] = torch.nn.Parameter(
                        torch.randn((2, 3))
                    )

            def forward(self, x):
                for i in range(2):
                    x = x + self.param_list[i]
                    x = x + self.param_dict[f"key_{i}"]
                return x

        export_module = torch.export.export(Mod(), (torch.randn((2, 3)),), strict=True)
        unflattened = unflatten(export_module)

        self.compare_outputs(
            export_module.module(), unflattened, (torch.randn((2, 3)),)
        )

    @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
    def test_unflatten_preserve_with_unused_input(self):
        class M1(torch.nn.Module):
            def forward(self, x, a, b):
                return x + a, b

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m1 = M1()

            def forward(self, x, y):
                a, b = torch.topk(y, 2)
                return self.m1(x, a, b)[0]

        ep = torch.export.export(
            M(),
            (torch.randn(2), torch.randn(5)),
            preserve_module_call_signature=("m1",),
            strict=False,
        )
        ep.graph.eliminate_dead_code()
        unflattened = unflatten(ep)
        self.compare_outputs(ep.module(), unflattened, (torch.randn(2), torch.randn(5)))

    def test_unflatten_wrong_input(self):
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param_list = torch.nn.ParameterList()
                self.param_dict = torch.nn.ParameterDict()
                for i in range(2):
                    self.param_list.append(torch.nn.Parameter(torch.randn((2, 3))))
                    self.param_dict[f"key_{i}"] = torch.nn.Parameter(
                        torch.randn((2, 3))
                    )

            def forward(self, x):
                a = x.sum()
                for i in range(2):
                    a = a + self.param_list[i].sum()
                    a = a + self.param_dict[f"key_{i}"].sum()
                return a

        export_module = torch.export.export(Mod(), (torch.randn((2, 3)),), strict=True)
        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"),
        ):
            export_module.module()(torch.randn(6, 6))

        unflattened = unflatten(export_module)
        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"),
        ):
            unflattened(torch.randn(6, 6))

    @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
    def test_unflatten_with_inplace_compile(self):
        class NestedChild(torch.nn.Module):
            def forward(self, x):
                return x / x

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()
                self.register_parameter(
                    "child1param", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = self.nested(x)
                return x + self.child1param

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = x * self.rootparam
                x = self.foo(x)
                x = self.bar(x)
                return x

        orig_eager = MyModule()
        export_module = torch.export.export(
            orig_eager, (torch.rand(2, 3),), {}, strict=True
        )
        unflattened = unflatten(export_module)

        # in-place compilation should work. Pass fullgraph to ensure no graph breaks.
        from torch._dynamo.backends.debugging import ExplainWithBackend

        eb = ExplainWithBackend("inductor")
        unflattened.foo.compile(backend=eb, fullgraph=True)
        inputs = (torch.randn(2, 3),)
        self.compare_outputs(orig_eager, unflattened, inputs)
        self.assertEqual(len(eb.graphs), 1)

        unflattened.compile()
        self.compare_outputs(orig_eager, unflattened, inputs)

    def test_fx_trace(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                x = x[0] + x[1]
                x = x + y["foo"]
                return x

        orig_eager = MyModule()
        inputs = ((torch.rand(2, 3), torch.rand(2, 3)), {"foo": torch.rand(2, 3)})
        export_module = export(orig_eager, inputs, {}, strict=True)

        unflattened = unflatten(export_module)
        torch.fx.symbolic_trace(
            unflattened, concrete_args=(torch.fx.PH, torch.fx.PH, torch.fx.PH)
        )

    def test_double_nested_submodule(self):
        class SubSubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x * x

        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.subsubmod = SubSubMod()

            def forward(self, x):
                return x - x

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return x + self.submod.subsubmod(x)

        orig_eager = MyModule()
        export_module = torch.export.export(
            orig_eager, (torch.rand(2, 3),), {}, strict=True
        )
        unflattened = unflatten(export_module)

        inputs = (torch.rand(2, 3),)
        self.compare_outputs(orig_eager, unflattened, inputs)

    def test_unflatten_container_type(self):
        class Leaf(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.linear(x)

        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.leaf = Leaf()
                self.buffer = torch.nn.Buffer(torch.randn(4, 4))

            def forward(self, x, z):
                return self.buffer.sum() + self.leaf(x).sum() + z[0].sum() + z[1].sum()

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bar = Bar()

            def forward(self, x, z):
                y = self.bar.buffer + x + z[0] + z[1]
                return self.bar(x, z) + y.sum()

        inp = (torch.randn(4, 4), [torch.randn(4, 4), torch.randn(4, 4)])
        mod = Foo()

        ep_strict = torch.export.export(mod, inp, strict=True)  # noqa: F841
        ep_non_strict = torch.export.export(mod, inp, strict=False)

        gm_unflat_non_strict = unflatten(ep_non_strict)
        ep = torch.export.export(gm_unflat_non_strict, inp, strict=False)
        self.assertTrue(torch.allclose(ep.module()(*inp), mod(*inp)))

    def test_unflattened_module_nodes_has_meta_val(self):
        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x + x, x * x

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return x + sum(self.submod(x))

        orig_eager = MyModule()
        export_module = torch.export.export(
            orig_eager, (torch.rand(2, 3),), {}, strict=True
        )
        unflattened = unflatten(export_module)

        inputs = (torch.rand(2, 3),)
        self.compare_outputs(orig_eager, unflattened, inputs)

        def check_meta(gm):
            for n in gm.graph.nodes:
                if n.op == "output":
                    continue
                self.assertTrue(n.meta.get("val") is not None)

        for m in unflattened.modules():
            check_meta(m)

    def test_unflatten_requires_grad_param(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.p = torch.nn.Parameter(torch.ones(3, 3), requires_grad=False)

            def forward(self, x):
                return self.p + x

        with torch.device("meta"):
            mod = M()

        inputs = (torch.randn(3, 3, device="meta"),)
        ep = export(mod, inputs, strict=True)
        unflattened = unflatten(ep)
        self.assertTrue(unflattened.state_dict()["p"].requires_grad is False)
        self.assertTrue(unflattened.p.requires_grad is False)

    def test_placeholder_and_get_attr_ordering_after_unflattened(self):
        class TransposeModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 1, 3, stride=2)

            def forward(self, x):
                x = self.conv(x)
                return x.transpose(0, 1)

        x = torch.randn(32, 3, 64, 64)
        exported_program = export(TransposeModule(), args=(x,), strict=True)
        unflattened_module = unflatten(exported_program)

        # Check the inputs of the created call_module node are in order
        call_module_input_order = []
        for node in unflattened_module.graph.nodes:
            if node.op == "call_module":
                transpose_module = unflattened_module.get_submodule(node.target)
                for sub_node in transpose_module.graph.nodes:
                    if sub_node.op == "placeholder" or sub_node.op == "get_attr":
                        call_module_input_order.append(sub_node.op)
        self.assertEqual(
            call_module_input_order, ["placeholder", "get_attr", "get_attr"]
        )

    def test_unflatten_constant_tensor(self):
        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.initializer = 0.1

            def forward(self, x):
                return x + torch.tensor(self.initializer)

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return x + self.submod(x)

        export_module = torch.export.export(Mod(), (torch.randn((2, 3)),), strict=True)
        unflattened = unflatten(export_module)

        self.compare_outputs(
            export_module.module(), unflattened, (torch.randn((2, 3)),)
        )

    @skipIfTorchDynamo("custom objects not supported in dynamo yet")
    def test_unflatten_constant_obj(self):
        init_torchbind_implementations()

        @torch._library.register_fake_class("_TorchScriptTesting::_Foo")
        class FakeFoo:  # noqa: F841
            def __init__(self, x: int, y: int):
                self.x = x
                self.y = y

            @classmethod
            def __obj_unflatten__(cls, flat_ctx):
                return cls(**dict(flat_ctx))

            def add_tensor(self, z):
                return (self.x + self.y) * z

        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                return x + self.attr.add_tensor(x)

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return x + self.submod(x)

        with enable_torchbind_tracing():
            export_module = torch.export.export(
                Mod(), (torch.randn((2, 3)),), strict=False
            )
        unflattened = unflatten(export_module)

        self.compare_outputs(
            export_module.module(), unflattened, (torch.randn((2, 3)),)
        )

    # skip connection is not supported yet
    @unittest.expectedFailure
    def test_unflatten_skipped_call_module(self):
        class C(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return a.d(x.cos())

        class B(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.c = C()

            def forward(self, x):
                return self.c(x) + x

        class D(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x.sin()

        class A(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.b = B()
                self.d = D()

            def forward(self, x):
                return self.b(x)

        a = A()

        # The call chain looks like this:
        # A -> B -> C -> A.d
        ep = torch.export.export(a, (torch.randn(3),), strict=False)
        unflatten(ep)

    def test_nested_leaf_non_strict(self):
        class Leaf(torch.nn.Module):
            def forward(self, x):
                return x + 1

        class Nested(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.leaf = Leaf()

            def forward(self, x):
                return self.leaf(x) + 2

        class TopLevel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = Nested()

            def forward(self, x):
                return self.nested(x) + 3

        ep = torch.export.export(
            TopLevel(),
            (torch.randn(3),),
            strict=False,
            preserve_module_call_signature=("nested",),
        )

        torch.export.unflatten(ep)

    def test_unflatten_submodule_ordering(self):
        class Module2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer = torch.nn.Buffer(torch.rand(3, 4))
                self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))

            def forward(self, x):
                return x + self.buffer + self.param

        class Module1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer = torch.nn.Buffer(torch.rand(3, 4))
                self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))

            def forward(self, x):
                return x + self.buffer + self.param

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mod2 = Module2()
                self.mod3 = self.mod2
                self.mod1 = Module1()

            def forward(self, x):
                return self.mod3(self.mod2(self.mod1(x)))

        mod = Module()

        ep = torch.export.export(mod, (torch.randn(3, 4),), strict=True)

        unflattened = torch.export.unflatten(ep)
        fqn_list = [x for x, _ in unflattened.named_modules(remove_duplicate=False)]
        self.assertEqual(len(fqn_list), 4)
        self.assertEqual(
            [x for x, _ in mod.named_modules(remove_duplicate=False)],
            fqn_list,
        )

    def test_duplicate_placeholder(self):
        N, C, H, W = 1, 2, 2, 3

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                layer = torch.nn.LayerNorm([C, H, W])
                self.norms = torch.nn.ModuleList(
                    [
                        layer,  # reuse layer norm
                        layer,
                        layer,
                    ]
                )

            def forward(self, input_):
                for i in range(len(self.norms)):
                    output = self.norms[i](input_)
                    input_ = output
                return output

        mod = MyModule()
        input_ = torch.randn(N, C, H, W)

        ep_strict = export(copy.deepcopy(mod), (input_,), strict=True)
        umod = unflatten(ep_strict)
        self.assertTrue(torch.allclose(umod(input_), mod(input_)))

        ep_non_strict = export(copy.deepcopy(mod), (input_,), strict=False)
        umod = unflatten(ep_non_strict)
        self.assertTrue(torch.allclose(umod(input_), mod(input_)))

    def test_simple_alias(self):
        # handle weight sharing, check tensor ids after unflattening
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # alias param
                self.bias = torch.nn.Parameter(torch.randn(4))
                self.m = torch.nn.Linear(4, 4)
                self.m.bias = self.bias

            def forward(self, x):
                return self.m(x) + self.bias

        m = Foo()
        inps = (torch.randn(4, 4),)
        ep = export(m, inps, strict=True)
        unep = unflatten(ep)
        self.assertTrue(id(unep.m.bias) == id(unep.bias))

        # handle aliasing where one alias is unused
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = torch.nn.Parameter(torch.randn(4))
                self.m = torch.nn.Linear(4, 4)
                self.m.bias = (
                    self.bias
                )  # self.bias is unused, aliasing should be handled

            def forward(self, x):
                return self.m(x)

        m = Foo()
        inps = (torch.randn(4, 4),)
        ep = export(m, inps, strict=True)
        unep = unflatten(ep)
        self.assertTrue(torch.allclose(unep(*inps), m(*inps)))

    def test_attr_as_submod_input(self):
        class layer(torch.nn.Module):
            def forward(self, x, const) -> torch.Tensor:
                return x + const

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.const = torch.nn.Buffer(torch.ones(4, 8))
                self.layers = torch.nn.ModuleList([layer() for _ in range(2)])

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                for layer in self.layers:
                    x = layer(x, self.const)
                return x

        mod = M()
        x = torch.randn(4, 8)
        ep = export(mod, (x,), strict=True)
        unflattened = unflatten(ep)
        torch.testing.assert_close(unflattened(x), mod(x))

    def test_dedup_sym_size(self):
        # Here, sym_size & floor div are used in 3 subgraphs (top-level, m1, m2),
        # but only one copy of sym_size is created in the initial export graph.
        # For m1, sym_size & floordiv should be copied as recompute since we preserve the call signature,
        # but for m2 floordiv should be passed in as a placeholder.
        # Test that this is preserved, and the unflattened module runs correctly.
        class M1(torch.nn.Module):
            def forward(self, x, y):
                d = x.size(0) // 2
                return y[:d]

        class M2(torch.nn.Module):
            def forward(self, x, y):
                d = x.size(0) // 2
                return y[:d]

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m1 = M1()
                self.m2 = M2()

            def forward(self, x, y):
                d = x.size(0) // 2
                m1_res = self.m1(x, y)
                m2_res = self.m2(x, y)
                return y[d:] + m1_res + m2_res

        inputs = (torch.ones(10), torch.ones(10))
        d_ = torch.export.Dim("foo", max=2048)
        d = 2 * d_
        ep = torch.export.export(
            M(),
            inputs,
            dynamic_shapes=((d,), (d,)),
            strict=False,
            preserve_module_call_signature=("m1",),
        )
        unflat = unflatten(ep)
        unflat(*inputs)

        fn_count_sym_size = lambda graph: [node.target for node in graph.nodes].count(
            torch.ops.aten.sym_size.int
        )
        self.assertEqual(fn_count_sym_size(unflat.graph), 3)
        self.assertEqual(fn_count_sym_size(unflat.m1.graph), 1)
        self.assertEqual(fn_count_sym_size(unflat.m2.graph), 0)

    def test_unflatten_eager(self):
        class NestedChild(torch.nn.Module):
            def forward(self, x):
                return x / x

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()
                self.register_parameter(
                    "child1param", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = self.nested(x)
                return x + self.child1param

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = x * self.rootparam
                x = self.foo(x)
                x = self.bar(x)
                return x

        orig_eager = MyModule()
        export_module = export(orig_eager, (torch.rand(2, 3),), {}, strict=True)
        with _disable_interpreter():
            unflattened = unflatten(export_module)

        self.assertEqual(unflattened._run_with_interpreter, False)
        self.assertEqual(unflattened.foo._run_with_interpreter, False)

        inputs = (torch.rand(2, 3),)

        # Compare the root modules and all submodules
        self.compare_outputs(orig_eager, unflattened, inputs)
        self.compare_outputs(orig_eager.foo, unflattened.foo, inputs)
        self.compare_outputs(orig_eager.bar, unflattened.bar, inputs)
        self.compare_outputs(orig_eager.foo.nested, unflattened.foo.nested, inputs)

        # Check state dicts are equal
        orig_state_dict = orig_eager.state_dict()
        exported_state_dict = unflattened.state_dict()
        for name, value in orig_state_dict.items():
            self.assertTrue(torch.allclose(value, exported_state_dict[name]))

        # Check composability with symbolic trace, as torchrec ddp uses symbolic
        # tracer
        symbolic_traced = torch.fx.symbolic_trace(unflattened, concrete_args=inputs)
        self.assertTrue(torch.allclose(orig_eager(*inputs), symbolic_traced(*inputs)))

        # torch.compile submodule
        unflattened.foo = torch.compile(unflattened.foo, fullgraph=True)
        self.compare_outputs(orig_eager, unflattened, inputs)

    def test_unflatten_none(self):
        class M2(torch.nn.Module):
            def forward(self, x, y):
                return x + x, None

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m2 = M2()

            def forward(self, x, y):
                x = x + x
                return self.m2(x, y)

        ep = export(
            M(), (torch.rand(2, 3), None), preserve_module_call_signature=("m2",)
        )
        unflattened = unflatten(ep)
        inp = (torch.randn(2, 3), None)
        self.assertTrue(torch.allclose(M()(*inp)[0], unflattened(*inp)[0]))

    def test_unflatten_empty_branch(self):
        class M(torch.nn.Module):
            def forward(self, x):
                if x is None:
                    return torch.ones(3), torch.ones(3)
                else:
                    return x + x, x * x

        class M1(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.m = M()

            def forward(self, x, y):
                a, b = self.m(x)
                c, d = self.m(y)
                return a + b + c + d

        ep = torch.export.export(M1(), (torch.randn(3), None))
        unf = torch.export.unflatten(ep)
        inp = (torch.randn(3), None)
        self.assertTrue(torch.allclose(unf(*inp), M1()(*inp)))

        ep = torch.export.export(
            M1(), (torch.randn(3), None), preserve_module_call_signature="m"
        )
        unf = torch.export.unflatten(ep)
        inp = (torch.randn(3), None)
        self.assertTrue(torch.allclose(unf(*inp), M1()(*inp)))


if __name__ == "__main__":
    run_tests()