# Owner(s): ["oncall: export"]
# flake8: noqa
import unittest
from typing import Dict, List, Tuple

import torch
import torch._dynamo
from torch._dynamo.test_case import run_tests, TestCase
from torch._export.wrappers import _mark_strict_experimental
from torch._functorch.aot_autograd import aot_export_module
from torch.export import export
from torch.export._trace import _convert_ts_to_export_experimental
from torch.export.experimental import _export_forward_backward
from torch.testing import FileCheck


@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestExperiment(TestCase):
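    # Exporting with strict=False should wrap the call into the
    # _mark_strict_experimental submodule in a torch.ops.higher_order.strict_mode
    # HOO; the expected graphs below reflect that wrapping, and the exported
    # module should still match eager results.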
    def test_with_buffer_as_submodule(self):
        @_mark_strict_experimental
        class B(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer1 = torch.nn.Buffer(torch.ones(3))

            def forward(self, x):
                y = x + 2
                y.add_(4)
                # this doesn't work today with HOO
                # self.buffer1.add_(6)
                buffer_updated = self.buffer1 + 6
                return x.sum() + y.sum() + buffer_updated.sum()

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submodule = B()

            def forward(self, x):
                x_v2 = x.sin()
                return (self.submodule(x_v2), x + 3)

        inp = torch.randn(3)
        ep = torch.export.export(M(), (inp,), strict=False)
        self.assertExpectedInline(
            str(ep.graph_module.code.strip()),
            """\
def forward(self, b_submodule_buffer1, x):
    sin = torch.ops.aten.sin.default(x)
    strict_graph_0 = self.strict_graph_0
    strict_mode = torch.ops.higher_order.strict_mode(strict_graph_0, (sin, b_submodule_buffer1)); strict_graph_0 = sin = b_submodule_buffer1 = None
    getitem = strict_mode[0]; strict_mode = None
    add = torch.ops.aten.add.Tensor(x, 3); x = None
    return (getitem, add)""",
        )

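        # strict_graph_0 is the traced body of B: arg0_1 is the sin'd input and
        # arg1_1 is the lifted buffer1, so the adds of 2, 4, and 6 below mirror
        # B.forward.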
        self.assertExpectedInline(
            str(ep.graph_module.strict_graph_0.code.strip()),
            """\
def forward(self, arg0_1, arg1_1):
    add = torch.ops.aten.add.Tensor(arg0_1, 2)
    add_1 = torch.ops.aten.add.Tensor(add, 4); add = None
    add_2 = torch.ops.aten.add.Tensor(arg1_1, 6); arg1_1 = None
    sum_1 = torch.ops.aten.sum.default(arg0_1); arg0_1 = None
    sum_2 = torch.ops.aten.sum.default(add_1); add_1 = None
    add_3 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
    sum_3 = torch.ops.aten.sum.default(add_2); add_2 = None
    add_4 = torch.ops.aten.add.Tensor(add_3, sum_3); add_3 = sum_3 = None
    return (add_4,)""",
        )

        eager_mod = M()
        ep = torch.export.export(eager_mod, (inp,), strict=True)

        graph_res_1, graph_res_2 = ep.module()(inp)
        eager_res_1, eager_res_2 = eager_mod(inp)

        self.assertTrue(torch.allclose(graph_res_2, eager_res_2))
        self.assertTrue(torch.allclose(graph_res_1, eager_res_1))

        graph_res_1, graph_res_2 = ep.module()(inp)
        eager_res_1, eager_res_2 = eager_mod(inp)

        self.assertTrue(torch.allclose(graph_res_2, eager_res_2))
        self.assertTrue(torch.allclose(graph_res_1, eager_res_1))

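    # The strict_mode HOO currently only supports flat tensor inputs, so a
    # nested container input to the marked submodule is expected to raise.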
    def test_mark_strict_with_container_type(self):
        @_mark_strict_experimental
        class B(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                x0 = x[0][0]
                return x0.sum()

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submodule = B()

            def forward(self, x):
                return self.submodule(x)

        inp = ((torch.randn(3),),)
        with self.assertRaisesRegex(
            RuntimeError, "strict_mode HOO doesn't work unless"
        ):
            ep = torch.export.export(M(), inp, strict=False)

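    # _convert_ts_to_export_experimental converts a torch.jit.trace'd module
    # for use with torch.export; the returned callable should match the
    # original eager module on the example inputs.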
    def test_torchscript_module_export(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return x.cos() + x.sin()

        model_to_trace = M()
        inps = (torch.randn(4, 4),)
        traced_module_by_torchscript = torch.jit.trace(M(), example_inputs=inps)

        exported_module = _convert_ts_to_export_experimental(
            traced_module_by_torchscript, inps
        )

        self.assertTrue(torch.allclose(exported_module(*inps), model_to_trace(*inps)))

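    # Same as above, but the example input is a single tensor rather than a
    # tuple, exercising input normalization in the conversion helper.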
    def test_torchscript_module_export_single_input(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return x.cos() + x.sin()

        model_to_trace = M()
        inps = torch.randn(4, 4)
        traced_module_by_torchscript = torch.jit.trace(M(), example_inputs=inps)

        exported_module = _convert_ts_to_export_experimental(
            traced_module_by_torchscript, inps
        )

        self.assertTrue(torch.allclose(exported_module(inps), model_to_trace(inps)))

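    # For tuple/list/dict inputs with annotated forward signatures, the
    # converted module should produce the same placeholder names and the same
    # outputs as exporting the original eager module directly.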
    def test_torchscript_module_export_various_inputs_with_annotated_input_names(self):
        def _check_equality_and_annotations(m_func, inps):
            # Original module.
            model_to_trace = m_func()

            # ExportedProgram from TorchScript module.
            traced_module_by_torchscript = torch.jit.trace(
                m_func(), example_inputs=inps
            )
            exported_module = _convert_ts_to_export_experimental(
                traced_module_by_torchscript, inps
            )

            # ExportedProgram from original module.
            original_exported_module = torch.export.export(m_func(), inps)

            # Check whether input annotations are the same as tracing the original module.
            orig_ph_name_list = [
                n.name
                for n in original_exported_module.graph.nodes
                if n.op == "placeholder"
            ]
            ph_name_list = [
                n.name for n in exported_module.graph.nodes if n.op == "placeholder"
            ]
            self.assertEqual(orig_ph_name_list, ph_name_list)

            # Check results equality.
            self.assertTrue(
                torch.allclose(exported_module(*inps), model_to_trace(*inps))
            )

        # Tuple
        class MTuple(torch.nn.Module):
            def forward(self, x: Tuple[torch.Tensor]):
                return x[0] + x[1]

        _check_equality_and_annotations(MTuple, ((torch.randn(4), torch.randn(4)),))

        # List
        class MList(torch.nn.Module):
            def forward(self, x: List[torch.Tensor]):
                return x[0] + x[1]

        _check_equality_and_annotations(MList, ([torch.randn(4), torch.randn(4)],))

        # Dict
        class MDict(torch.nn.Module):
            def forward(self, x: Dict[str, torch.Tensor]):
                return x["0"] + x["1"]

        _check_equality_and_annotations(
            MDict, ({"0": torch.randn(4), "1": torch.randn(4)},)
        )

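    # _export_forward_backward takes a pre-dispatch ExportedProgram whose
    # output is a scalar loss and returns a joint graph that computes the loss
    # together with gradients for the trainable parameters (here the linear
    # weight and bias), as the expected graphs below show.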
    def test_joint_basic(self) -> None:
        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)
                self.loss = torch.nn.CrossEntropyLoss()

            def forward(self, x):
                return self.loss(
                    self.linear(x).softmax(dim=0), torch.tensor([1.0, 0.0, 0.0])
                )

        m = Module()
        example_inputs = (torch.randn(3),)
        m(*example_inputs)
        ep = torch.export._trace._export(m, example_inputs, pre_dispatch=True)
        joint_ep = _export_forward_backward(ep)
        self.assertExpectedInline(
            str(joint_ep.graph_module.code).strip(),
            """\
def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
    view = torch.ops.aten.view.default(x, [1, 3]); x = None
    permute = torch.ops.aten.permute.default(p_linear_weight, [1, 0]); p_linear_weight = None
    addmm = torch.ops.aten.addmm.default(p_linear_bias, view, permute); p_linear_bias = permute = None
    view_1 = torch.ops.aten.view.default(addmm, [3]); addmm = None
    _softmax = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None
    alias = torch.ops.aten.alias.default(_softmax)
    alias_1 = torch.ops.aten.alias.default(alias); alias = None
    clone = torch.ops.aten.clone.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
    alias_2 = torch.ops.aten.alias.default(clone); clone = None
    alias_3 = torch.ops.aten.alias.default(alias_2); alias_2 = None
    alias_4 = torch.ops.aten.alias.default(alias_3); alias_3 = None
    _log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None
    alias_5 = torch.ops.aten.alias.default(_log_softmax)
    alias_6 = torch.ops.aten.alias.default(alias_5); alias_5 = None
    mul = torch.ops.aten.mul.Tensor(_log_softmax, alias_4); _log_softmax = None
    sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None
    neg = torch.ops.aten.neg.default(sum_1); sum_1 = None
    div = torch.ops.aten.div.Scalar(neg, 1); neg = None
    full_like = torch.ops.aten.full_like.default(div, 1, pin_memory = False, memory_format = torch.preserve_format)
    div_1 = torch.ops.aten.div.Scalar(full_like, 1); full_like = None
    neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None
    expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(expand, alias_4); expand = alias_4 = None
    alias_7 = torch.ops.aten.alias.default(alias_6); alias_6 = None
    alias_8 = torch.ops.aten.alias.default(alias_7); alias_7 = None
    exp = torch.ops.aten.exp.default(alias_8); alias_8 = None
    sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True)
    mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2); exp = sum_2 = None
    sub = torch.ops.aten.sub.Tensor(mul_1, mul_2); mul_1 = mul_2 = None
    alias_9 = torch.ops.aten.alias.default(alias_1); alias_1 = None
    alias_10 = torch.ops.aten.alias.default(alias_9); alias_9 = None
    mul_3 = torch.ops.aten.mul.Tensor(sub, alias_10); sub = None
    sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True)
    mul_4 = torch.ops.aten.mul.Tensor(alias_10, sum_3); alias_10 = sum_3 = None
    sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4); mul_3 = mul_4 = None
    view_2 = torch.ops.aten.view.default(sub_1, [1, 3]); sub_1 = None
    permute_1 = torch.ops.aten.permute.default(view_2, [1, 0])
    mm = torch.ops.aten.mm.default(permute_1, view); permute_1 = view = None
    permute_2 = torch.ops.aten.permute.default(mm, [1, 0]); mm = None
    sum_4 = torch.ops.aten.sum.dim_IntList(view_2, [0], True); view_2 = None
    view_3 = torch.ops.aten.view.default(sum_4, [3]); sum_4 = None
    permute_3 = torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None
    return (div, permute_3, view_3)""",
        )
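        # run_decompositions() leaves this joint graph unchanged here; the
        # expected output below is identical to the graph above.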
        ep = joint_ep.run_decompositions()
        self.assertExpectedInline(
            str(ep.graph_module.code).strip(),
            """\
def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
    view = torch.ops.aten.view.default(x, [1, 3]); x = None
    permute = torch.ops.aten.permute.default(p_linear_weight, [1, 0]); p_linear_weight = None
    addmm = torch.ops.aten.addmm.default(p_linear_bias, view, permute); p_linear_bias = permute = None
    view_1 = torch.ops.aten.view.default(addmm, [3]); addmm = None
    _softmax = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None
    alias = torch.ops.aten.alias.default(_softmax)
    alias_1 = torch.ops.aten.alias.default(alias); alias = None
    clone = torch.ops.aten.clone.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
    alias_2 = torch.ops.aten.alias.default(clone); clone = None
    alias_3 = torch.ops.aten.alias.default(alias_2); alias_2 = None
    alias_4 = torch.ops.aten.alias.default(alias_3); alias_3 = None
    _log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None
    alias_5 = torch.ops.aten.alias.default(_log_softmax)
    alias_6 = torch.ops.aten.alias.default(alias_5); alias_5 = None
    mul = torch.ops.aten.mul.Tensor(_log_softmax, alias_4); _log_softmax = None
    sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None
    neg = torch.ops.aten.neg.default(sum_1); sum_1 = None
    div = torch.ops.aten.div.Scalar(neg, 1); neg = None
    full_like = torch.ops.aten.full_like.default(div, 1, pin_memory = False, memory_format = torch.preserve_format)
    div_1 = torch.ops.aten.div.Scalar(full_like, 1); full_like = None
    neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None
    expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(expand, alias_4); expand = alias_4 = None
    alias_7 = torch.ops.aten.alias.default(alias_6); alias_6 = None
    alias_8 = torch.ops.aten.alias.default(alias_7); alias_7 = None
    exp = torch.ops.aten.exp.default(alias_8); alias_8 = None
    sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True)
    mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2); exp = sum_2 = None
    sub = torch.ops.aten.sub.Tensor(mul_1, mul_2); mul_1 = mul_2 = None
    alias_9 = torch.ops.aten.alias.default(alias_1); alias_1 = None
    alias_10 = torch.ops.aten.alias.default(alias_9); alias_9 = None
    mul_3 = torch.ops.aten.mul.Tensor(sub, alias_10); sub = None
    sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True)
    mul_4 = torch.ops.aten.mul.Tensor(alias_10, sum_3); alias_10 = sum_3 = None
    sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4); mul_3 = mul_4 = None
    view_2 = torch.ops.aten.view.default(sub_1, [1, 3]); sub_1 = None
    permute_1 = torch.ops.aten.permute.default(view_2, [1, 0])
    mm = torch.ops.aten.mm.default(permute_1, view); permute_1 = view = None
    permute_2 = torch.ops.aten.permute.default(mm, [1, 0]); mm = None
    sum_4 = torch.ops.aten.sum.dim_IntList(view_2, [0], True); view_2 = None
    view_3 = torch.ops.aten.view.default(sum_4, [3]); sum_4 = None
    permute_3 = torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None
    return (div, permute_3, view_3)""",
        )

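    # Smoke test: joint forward-backward export should also work when the
    # input has a dynamic dimension (Dim("x0")); no graph is asserted.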
    def test_joint_dynamic(self) -> None:
        from torch.export import Dim

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.y = torch.nn.Parameter(torch.randn(3))

            def forward(self, x):
                x = torch.ones(x.shape[0], 3)
                return (self.y + x).sum()

        m = Module()
        example_inputs = (torch.randn(3),)
        m(*example_inputs)
        ep = torch.export._trace._export(
            m, example_inputs, pre_dispatch=True, dynamic_shapes={"x": {0: Dim("x0")}}
        )
        joint_ep = _export_forward_backward(ep)

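    # End-to-end check that joint forward-backward export handles a small
    # convolutional network (the CIFAR10 tutorial model) with a cross-entropy
    # loss; the test only verifies that export succeeds.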
    def test_joint_cifar10_backwards(self) -> None:
        import torch.nn as nn
        import torch.nn.functional as F

        # From PyTorch's CIFAR10 example:
        # https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(3, 6, 5)
                self.pool = nn.MaxPool2d(2, 2)
                self.conv2 = nn.Conv2d(6, 16, 5)
                self.fc1 = nn.Linear(16 * 5 * 5, 120)
                self.fc2 = nn.Linear(120, 84)
                self.fc3 = nn.Linear(84, 10)
                self.loss = nn.CrossEntropyLoss()

            def forward(self, x, labels):
                x = self.pool(F.relu(self.conv1(x)))
                x = self.pool(F.relu(self.conv2(x)))
                x = torch.flatten(x, 1)  # flatten all dimensions except batch
                x = F.relu(self.fc1(x))
                x = F.relu(self.fc2(x))
                x = self.fc3(x)
                return self.loss(x, labels)

        net = Net()
        x = torch.randn(4, 3, 32, 32)
        labels = torch.ones(4, dtype=torch.int64)
        inputs = (x, labels)

        ep = export(net, inputs)
        ep = _export_forward_backward(ep)


if __name__ == "__main__":
    run_tests()