# Owner(s): ["oncall: export"]
# flake8: noqa
import copy
import types
import unittest
from dataclasses import dataclass
from typing import Dict, List, Tuple

import torch
import torch._dynamo
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
from torch._dynamo.test_case import run_tests, TestCase
from torch._functorch.aot_autograd import aot_export_module
from torch.export import export
from torch.export.experimental import _export_forward_backward, _sticky_export
from torch.export.graph_signature import OutputKind
from torch.testing import FileCheck
from torch.testing._internal.common_utils import TEST_CUDA


@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestExperiment(TestCase):
    def test_joint_basic(self) -> None:
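        # Export a Linear + CrossEntropyLoss module, lower it to a joint
        # forward/backward graph with _export_forward_backward, and check the
        # generated ATen code before and after run_decompositions().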
        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)
                self.loss = torch.nn.CrossEntropyLoss()

            def forward(self, x):
                return self.loss(
                    self.linear(x).softmax(dim=0), torch.tensor([1.0, 0.0, 0.0])
                )

        m = Module()
        example_inputs = (torch.randn(3),)
        m(*example_inputs)
        with torch._export.config.patch(use_new_tracer_experimental=True):
            ep = torch.export.export(m, example_inputs, strict=True)
        joint_ep = _export_forward_backward(ep)
        self.assertExpectedInline(
            str(joint_ep.graph_module.code).strip(),
            """\
def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
    view = torch.ops.aten.view.default(x, [1, 3]); x = None
    permute = torch.ops.aten.permute.default(p_linear_weight, [1, 0]); p_linear_weight = None
    addmm = torch.ops.aten.addmm.default(p_linear_bias, view, permute); p_linear_bias = permute = None
    view_1 = torch.ops.aten.view.default(addmm, [3]); addmm = None
    _softmax = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None
    alias = torch.ops.aten.alias.default(_softmax)
    clone = torch.ops.aten.clone.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
    _log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None
    alias_1 = torch.ops.aten.alias.default(_log_softmax)
    mul = torch.ops.aten.mul.Tensor(_log_softmax, clone); _log_softmax = None
    sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None
    neg = torch.ops.aten.neg.default(sum_1); sum_1 = None
    div = torch.ops.aten.div.Scalar(neg, 1); neg = None
    full_like = torch.ops.aten.full_like.default(div, 1, pin_memory = False, memory_format = torch.preserve_format)
    div_1 = torch.ops.aten.div.Scalar(full_like, 1); full_like = None
    neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None
    expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(expand, clone); expand = clone = None
    alias_2 = torch.ops.aten.alias.default(alias_1); alias_1 = None
    exp = torch.ops.aten.exp.default(alias_2); alias_2 = None
    sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True)
    mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2); exp = sum_2 = None
    sub = torch.ops.aten.sub.Tensor(mul_1, mul_2); mul_1 = mul_2 = None
    alias_3 = torch.ops.aten.alias.default(alias); alias = None
    mul_3 = torch.ops.aten.mul.Tensor(sub, alias_3); sub = None
    sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True)
    mul_4 = torch.ops.aten.mul.Tensor(alias_3, sum_3); alias_3 = sum_3 = None
    sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4); mul_3 = mul_4 = None
    view_2 = torch.ops.aten.view.default(sub_1, [1, 3]); sub_1 = None
    permute_1 = torch.ops.aten.permute.default(view_2, [1, 0])
    mm = torch.ops.aten.mm.default(permute_1, view); permute_1 = view = None
    permute_2 = torch.ops.aten.permute.default(mm, [1, 0]); mm = None
    sum_4 = torch.ops.aten.sum.dim_IntList(view_2, [0], True); view_2 = None
    view_3 = torch.ops.aten.view.default(sum_4, [3]); sum_4 = None
    permute_3 = torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None
    return (div, permute_3, view_3)""",
        )
        ep = joint_ep.run_decompositions()
        self.assertExpectedInline(
            str(ep.graph_module.code).strip(),
            """\
def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
    view = torch.ops.aten.view.default(x, [1, 3]); x = None
    permute = torch.ops.aten.permute.default(p_linear_weight, [1, 0]); p_linear_weight = None
    addmm = torch.ops.aten.addmm.default(p_linear_bias, view, permute); p_linear_bias = permute = None
    view_1 = torch.ops.aten.view.default(addmm, [3]); addmm = None
    _softmax = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None
    alias = torch.ops.aten.alias.default(_softmax)
    clone = torch.ops.aten.clone.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
    _log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None
    alias_1 = torch.ops.aten.alias.default(_log_softmax)
    mul = torch.ops.aten.mul.Tensor(_log_softmax, clone); _log_softmax = None
    sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None
    neg = torch.ops.aten.neg.default(sum_1); sum_1 = None
    div = torch.ops.aten.div.Scalar(neg, 1); neg = None
    full_like = torch.ops.aten.full_like.default(div, 1, pin_memory = False, memory_format = torch.preserve_format)
    div_1 = torch.ops.aten.div.Scalar(full_like, 1); full_like = None
    neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None
    expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(expand, clone); expand = clone = None
    alias_2 = torch.ops.aten.alias.default(alias_1); alias_1 = None
    exp = torch.ops.aten.exp.default(alias_2); alias_2 = None
    sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True)
    mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2); exp = sum_2 = None
    sub = torch.ops.aten.sub.Tensor(mul_1, mul_2); mul_1 = mul_2 = None
    alias_3 = torch.ops.aten.alias.default(alias); alias = None
    mul_3 = torch.ops.aten.mul.Tensor(sub, alias_3); sub = None
    sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True)
    mul_4 = torch.ops.aten.mul.Tensor(alias_3, sum_3); alias_3 = sum_3 = None
    sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4); mul_3 = mul_4 = None
    view_2 = torch.ops.aten.view.default(sub_1, [1, 3]); sub_1 = None
    permute_1 = torch.ops.aten.permute.default(view_2, [1, 0])
    mm = torch.ops.aten.mm.default(permute_1, view); permute_1 = view = None
    permute_2 = torch.ops.aten.permute.default(mm, [1, 0]); mm = None
    sum_4 = torch.ops.aten.sum.dim_IntList(view_2, [0], True); view_2 = None
    view_3 = torch.ops.aten.view.default(sum_4, [3]); sum_4 = None
    permute_3 = torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None
    return (div, permute_3, view_3)""",
        )

    def test_joint_dynamic(self) -> None:
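        # The joint-graph pass should also handle a module exported with a
        # dynamic shape on the input's batch dimension.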
        from torch.export import Dim

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.y = torch.nn.Parameter(torch.randn(3))

            def forward(self, x):
                x = torch.ones(x.shape[0], 3)
                return (self.y + x).sum()

        m = Module()
        example_inputs = (torch.randn(3),)
        m(*example_inputs)
        ep = torch.export.export(
            m, example_inputs, dynamic_shapes={"x": {0: Dim("x0")}}, strict=True
        )
        _export_forward_backward(ep)

    def test_joint_cifar10_backwards(self) -> None:
        import torch.nn as nn
        import torch.nn.functional as F

        # From PyTorch's CIFAR10 example:
        # https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(3, 6, 5)
                self.pool = nn.MaxPool2d(2, 2)
                self.conv2 = nn.Conv2d(6, 16, 5)
                self.fc1 = nn.Linear(16 * 5 * 5, 120)
                self.fc2 = nn.Linear(120, 84)
                self.fc3 = nn.Linear(84, 10)
                self.loss = nn.CrossEntropyLoss()

            def forward(self, x, labels):
                x = self.pool(F.relu(self.conv1(x)))
                x = self.pool(F.relu(self.conv2(x)))
                x = torch.flatten(x, 1)  # flatten all dimensions except batch
                x = F.relu(self.fc1(x))
                x = F.relu(self.fc2(x))
                x = self.fc3(x)
                return self.loss(x, labels)

        net = Net()
        x = torch.randn(4, 3, 32, 32)
        labels = torch.ones(4, dtype=torch.int64)
        inputs = (x, labels)

        ep = export(net, inputs, strict=True)
        ep = _export_forward_backward(ep)

    def test_joint_loss_index(self):
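        # joint_loss_index selects which user output is treated as the loss;
        # only that output should be tagged OutputKind.LOSS_OUTPUT.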
        class Foo(torch.nn.Module):
            def __init__(self, index):
                super().__init__()
                self.l = torch.nn.Linear(4, 4)
                self.index = index

            def forward(self, x):
                x = self.l(x)
                x = x.sum()
                if self.index == 0:
                    return x, -x.detach()
                else:
                    return x.detach(), x

        inputs = (torch.randn(4, 4),)
        for i in [0, 1]:
            ep = export(Foo(i), inputs, strict=True)
            ep_joint = _export_forward_backward(ep, joint_loss_index=i)
            for j, spec in enumerate(ep_joint.graph_signature.output_specs):
                if i == j:
                    self.assertTrue(spec.kind == OutputKind.LOSS_OUTPUT)
                else:
                    self.assertTrue(spec.kind != OutputKind.LOSS_OUTPUT)

    def test_joint_buffer_input_mutations(self):
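        # Buffer and user-input mutations must come first in the joint graph's
        # output specs, followed by the loss output.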
        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.l = torch.nn.Linear(4, 4)
                self.register_buffer("buf", torch.randn(4))
                self.loss = torch.nn.CrossEntropyLoss()

            def forward(self, x, label):
                x.add_(self.buf)
                x = self.l(x)
                self.buf.add_(2.0)
                return self.loss(x, label)

        inputs = (
            torch.randn(4, 4),
            torch.randint(0, 4, (4,)),
        )
        ep = export(Foo(), inputs)
        ep_joint = _export_forward_backward(ep)
        self.assertEqual(len(ep_joint.graph_signature.output_specs), 5)
        self.assertEqual(
            ep_joint.graph_signature.output_specs[0].kind,
            OutputKind.BUFFER_MUTATION,
        )
        self.assertEqual(
            ep_joint.graph_signature.output_specs[0].target,
            "buf",
        )
        self.assertEqual(
            ep_joint.graph_signature.output_specs[1].kind,
            OutputKind.USER_INPUT_MUTATION,
        )
        self.assertEqual(
            ep_joint.graph_signature.output_specs[1].target,
            "x",
        )
        self.assertEqual(
            ep_joint.graph_signature.output_specs[2].kind,
            OutputKind.LOSS_OUTPUT,
        )

    def test_sticky_export(self):
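        # Replacing module.forward with _sticky_export(forward) should be
        # transparent to callers: the wrapped pipeline must produce the same
        # result as the original eager forward.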
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.linear(x)

        class Pipeline:
            def __init__(self, model):
                self.model = model

            def generate(self, *args, **kwargs):
                return self.model(*args, **kwargs)

        inp = torch.randn(4, 4)

        p = Pipeline(Model())
        orig_forward = p.model.forward
        p.model.forward = _sticky_export(p.model.forward)
        res = p.generate(inp)

        p.model.forward = orig_forward
        res2 = p.generate(inp)
        self.assertTrue(torch.allclose(res, res2))

    def test_sticky_export_dynamic(self):
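        # Exercise _sticky_export with a dynamic_shapes_callback that marks
        # every tensor dimension as Dim.AUTO, then inspect the exported code.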
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, x):
                if x.shape[0] < 5:
                    return self.linear(x)
                return x.sin()

        class Pipeline:
            def __init__(self, model):
                self.model = model

            def generate(self, *args, **kwargs):
                return self.model(*args, **kwargs)

        inp = torch.randn(4, 4)

        def callback(*args, **kwargs):
            # It is a bit odd to rely on the forward arg names here, so
            # let's just use ShapesCollection.

            flat_args, _ = torch.utils._pytree.tree_flatten((args, kwargs))
            collections = torch.export.ShapesCollection()
            for arg in flat_args:
                if isinstance(arg, torch.Tensor):
                    collections[arg] = {
                        i: torch.export.Dim.AUTO for i in range(len(arg.shape))
                    }
            return collections

        p = Pipeline(Model())
        p.model.forward = _sticky_export(
            p.model.forward, dynamic_shapes_callback=callback
        )
        _ = p.generate(inp)
        self.assertExpectedInline(
            str(p.model.forward._exported_artifact.code).strip(),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    linear_weight = self.linear.weight
    linear_bias = self.linear.bias
    _guards_fn = self._guards_fn(x); _guards_fn = None
    linear = torch.ops.aten.linear.default(x, linear_weight, linear_bias); x = linear_weight = linear_bias = None
    return pytree.tree_unflatten((linear,), self._out_spec)""",
        )

    def test_sticky_export_nested_inp(self):
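        # _sticky_export should handle keyword-only arguments whose values are
        # nested containers of tensors.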
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, *, inputs):
                return self.linear(inputs[0]) + self.linear(inputs[1])

        class Pipeline:
            def __init__(self, model):
                self.model = model

            def generate(self, *, input_tensor, input_tensor2):
                inputs = [input_tensor, input_tensor2]
                return self.model(inputs=inputs)

        inp = torch.randn(4, 4)
        inp2 = torch.randn(4, 4)

        p = Pipeline(Model())
        orig_forward = p.model.forward
        p.model.forward = _sticky_export(p.model.forward)
        res = p.generate(input_tensor=inp, input_tensor2=inp2)

        p.model.forward = orig_forward
        res2 = p.generate(input_tensor=inp, input_tensor2=inp2)
        self.assertTrue(torch.allclose(res, res2))

    def test_export_add_in_out_info(self):
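        # _dynamo_graph_capture_for_export should round-trip nested dict/list
        # inputs and mixed tensor/constant outputs, matching eager results.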
        class Foo(torch.nn.Module):
            def forward(self, dct, lst, bleh):
                x = dct["a"] * lst[1][0]
                y = dct["b"] * lst[0]
                out_dict = {}
                # Mutate and get a new entry in there
                lst_copy = lst.copy()
                lst_copy.append(lst[0])
                out_dict["a"] = x
                out_dict["b"] = y
                return (
                    dct["a"],
                    out_dict["b"],
                    bleh,
                    lst_copy[-1],
                    out_dict["a"],
                    [5, 6],
                )

        dct = {"a": torch.randn(2, 3), "b": torch.randn(2, 3)}
        lst = [torch.randn(2, 3), [torch.randn(2, 3), torch.randn(2, 3)]]

        export_inputs = ((dct, lst, 56), {})
        eager_inputs = copy.deepcopy(export_inputs)

        from torch._dynamo.functional_export import _dynamo_graph_capture_for_export

        graph_module = _dynamo_graph_capture_for_export(Foo())(
            *export_inputs[0], **export_inputs[1]
        )

        res_export = graph_module(*export_inputs[0], **export_inputs[1])
        res_eager = Foo()(*eager_inputs[0], **eager_inputs[1])

        self.assertEqual(res_export, res_eager)

    def test_export_leaf(self):
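        # Minimal capture case: a single tensor input and a single tensor
        # output, compared against eager.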
        class Foo(torch.nn.Module):
            def forward(self, x):
                return x.sin()

        export_inputs = ((torch.randn(4, 4),), {})
        eager_inputs = copy.deepcopy(export_inputs)

        from torch._dynamo.functional_export import _dynamo_graph_capture_for_export

        graph_module = _dynamo_graph_capture_for_export(Foo())(
            *export_inputs[0], **export_inputs[1]
        )

        res_export = graph_module(*export_inputs[0], **export_inputs[1])
        res_eager = Foo()(*eager_inputs[0], **eager_inputs[1])

        self.assertEqual(res_export, res_eager)

    def test_dynamo_graph_capture(self):
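        # Same container-heavy program as test_export_add_in_out_info, but
        # captured through the public dynamo_graph_capture_for_export entry
        # point and re-run on freshly generated inputs.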
        class Foo(torch.nn.Module):
            def forward(self, dct, lst, bleh):
                x = dct["a"] * lst[1][0]
                y = dct["b"] * lst[0]
                out_dict = {}

                # Mutate and get a new entry in there
                lst_copy = lst.copy()
                lst_copy.append(lst[0])
                out_dict["a"] = x
                out_dict["b"] = y
                return (
                    dct["a"],
                    out_dict["b"],
                    bleh,
                    lst_copy[-1],
                    out_dict["a"],
                    [5, 6],
                )

        foo = Foo()

        def make_inputs():
            return (
                {"a": torch.randn(2, 3), "b": torch.randn(2, 3)},
                [torch.randn(2, 3), (torch.randn(2, 3),)],
                torch.randn(2, 3),
            )

        trace_inputs = make_inputs()
        gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
        test_inputs = make_inputs()
        self.assertEqual(gm(*test_inputs), foo(*test_inputs))

    def test_dynamo_graph_capture_custom_pytree_type(self):
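        # A dataclass registered with pytree.register_dataclass should be
        # flattened on the way in and reconstructed on the way out; the test
        # also pins down the generated wrapper code.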
        import torch.utils._pytree as pytree

        @dataclass
        class Bar:
            x: torch.Tensor
            y: torch.Tensor

        class Foo(torch.nn.Module):
            def forward(self, bar: Bar):
                return bar.x + bar.y

        foo = Foo()

        def make_inputs():
            return (Bar(torch.randn(2, 3), torch.randn(2, 3)),)

        pytree.register_dataclass(Bar)
        try:
            trace_inputs = make_inputs()
            gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
            test_inputs = make_inputs()
            self.assertExpectedInline(
                gm._in_shuffle_graph.code.strip("\r\n "),
                """\
def forward(self, arg0_1, arg1_1, arg2_1):
    return (arg1_1, arg2_1)""",
            )
            self.assertExpectedInline(
                gm.code.strip("\r\n "),
                """\
def forward(self, args_0):
    _tree_leaf_0, _tree_leaf_1, _tree_leaf_2, = pytree.tree_leaves((self, args_0,))
    L_bar_x , L_bar_y , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1, _tree_leaf_2)
    l_bar_x = L_bar_x
    l_bar_y = L_bar_y
    add = l_bar_x + l_bar_y; l_bar_x = l_bar_y = None
    return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, _tree_leaf_2, add), self._out_spec)""",
            )
            self.assertEqual(gm(*test_inputs), foo(*test_inputs))
        finally:
            pytree._deregister_pytree_node(Bar)

    def test_dynamo_graph_capture_closure(self):
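        # A tensor captured from the enclosing scope ("outer") should be lifted
        # into the graph as a tensor constant and routed through the input
        # shuffle graph.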
        from torch.export import Dim

        N = 3
        outer = torch.randn(10, 32)

        class MyModel(torch.nn.Module):
            def forward(self, x):
                z = x + outer
                y = z[:-1, :]  # [s0 - 1, 32]
                stacked = torch.stack([y] * N, dim=0)  # [N, s0 - 1, 32]
                reshaped = stacked.reshape(-1, N, 32)  # [(s0 - 1), N, 32]
                return reshaped

        inps = (torch.randn(10, 32),)
        ep = dynamo_graph_capture_for_export(MyModel())(*inps)
        self.assertExpectedInline(
            ep._in_shuffle_graph.code.strip("\r\n "),
            """\
def forward(self, arg0_1, arg1_1):
    _tensor_constant0 = self._tensor_constant0
    return (arg1_1, _tensor_constant0)""",
        )
        self.assertExpectedInline(
            ep.code.strip("\r\n "),
            """\
def forward(self, args_0):
    _tree_leaf_0, _tree_leaf_1, = pytree.tree_leaves((self, args_0,))
    L_x_ , L_outer_ , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1)
    l_x_ = L_x_
    l_outer_ = L_outer_
    z = l_x_ + l_outer_; l_x_ = l_outer_ = None
    y = z[(slice(None, -1, None), slice(None, None, None))]; z = None
    stacked = torch.stack([y, y, y], dim = 0); y = None
    reshaped = stacked.reshape(-1, 3, 32); stacked = None
    return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, reshaped), self._out_spec)""",
        )
        self.assertEqual(ep(*inps), MyModel()(*inps))

    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_dynamo_graph_capture_fx_graph_annotate_overlap_pass(self):
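        # Capture a model that mixes a custom autograd.Function with
        # fx_traceback.annotate regions under preserve_node_meta(); the string
        # literal below documents the expected wrapper code.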
        class DummyOp(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, scalar):
                ctx.save_for_backward(x)
                return x + scalar

            @staticmethod
            def backward(ctx, grad_out):
                return grad_out, None

        def mock_fw_compute(x):
            with fx_traceback.annotate({"compute": 0}):
                return DummyOp.apply(x, 10)

        def mock_bw_comm(x):
            with fx_traceback.annotate({"comm": 0}):
                return DummyOp.apply(x, 20)

        def mock_bw_compute(x):
            return DummyOp.apply(x, 30)

        class Model(torch.nn.Module):
            def forward(self, fw_in, bw_in):
                fw_out = mock_fw_compute(fw_in)
                # bw_in blocks bw_out
                bw_in = mock_bw_comm(bw_in)
                bw_out = mock_bw_compute(bw_in)
                return fw_out, bw_out

        def input_fn():
            inputs = (torch.rand(2, 128, device="cuda", requires_grad=True),)
            grad_ins = (torch.rand(2, 128, device="cuda"),)
            return (
                *inputs,
                *grad_ins,
            )

        with torch.device("meta"):
            model = Model()

        import torch.fx.traceback as fx_traceback

        with fx_traceback.preserve_node_meta():
            gm = dynamo_graph_capture_for_export(model)(*input_fn())

        """
        def forward(self, args_0, args_1):
            _tree_leaf_0, _tree_leaf_1, _tree_leaf_2, = pytree.tree_leaves((self, args_0, args_1,))
            L_fw_in_ , L_bw_in_ , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1, _tree_leaf_2)
            l_fw_in_ = L_fw_in_
            l_bw_in_ = L_bw_in_
            fwd_body_0 = self.fwd_body_0
            bwd_body_0 = self.bwd_body_0
            fw_out = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_fw_in_, args_tensor_mask = [True, False], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_fw_in_ = None
            bw_in = l_bw_in_ + 20; l_bw_in_ = None
            bw_out = bw_in + 30; bw_in = None
            return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, _tree_leaf_2, fw_out, bw_out), self._out_spec)
        """
        test_inputs = input_fn()
        self.assertEqual(gm(*test_inputs), model(*test_inputs))

    def test_dynamo_graph_capture_default_args(self):
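        # A forward() with a default argument (y=1) that is not passed at
        # capture time should still produce a callable graph module.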
        class Module(torch.nn.Module):
            def forward(self, x, y=1):
                return x + y

        m = Module()
        ep = dynamo_graph_capture_for_export(m)(torch.randn(2, 3))
        test_inputs = (torch.randn(2, 3),)
        self.assertEqual(ep(*test_inputs), m(*test_inputs))


if __name__ == "__main__":
    run_tests()