Let's first explore a couple of problems related to replacements and runtime assertions.

#### Example problem 1
Suppose we have a runtime assertion u0 == s0, where u0 is an input coming from mark_unbacked. A replacement u0 = s0 is added, so the function f(u0, s0) becomes f(s0, s0), and the assert is never inserted during insert_deferred_runtime_asserts. The reason is that insert_deferred_runtime_asserts only inserts an assertion once all of its inputs have been seen, but u0 is never seen. The same thing can happen when we defer an assertion on backed symbols, e.g. s0 == s2.

#### Example problem 2
Consider u0 == s0, where u0 comes from a call to .item(). Imagine that later on s0 specializes to 2. In that case s0 is never seen as an input during insert_deferred_runtime_asserts, and the assertion is not inserted in the graph. Worse, Inductor will generate code in the cpp wrapper that refers to s0 even though it does not exist, causing a failure. (A minimal sketch of this setup is shown after this description.)

internal xref: https://fb.workplace.com/groups/1075192433118967/permalink/1669766396994898/

## The solution
The runtime-assertion insertion loops depend on detecting that the symbols used in the runtime assertions have been seen; those symbols are either graph inputs or produced in the graph by data-dependent ops like .item(). The issues above arise when the symbols are graph inputs. To force those symbols to exist in the graph and to be seen by the runtime assertions, we do not apply replacements to placeholder expressions during codegen or during runtime-assertion insertion. This should not add performance overhead, since the graph has already been optimized with replacements; the only effect is that graph inputs used in runtime assertions are no longer mistakenly dropped. I added extended testing.

One unrelated follow-up I noticed: we might want to rename unbacked symbols in runtime assertions when we do unbacked renaming, but that's a different issue.

Other approaches that did not work:

#### Ban replacements on unbacked symbols
1. Does not work when we defer runtime assertions on backed symbols, e.g. s0 == s1. We could also ban such replacements, but then problem 2 becomes more problematic.
2. Banning those replacements hurts the quality of symbolic reasoning.

#### Apply specializations to runtime assertions before codegen
1. Can fix some issues, but may also turn runtime assertions into no-ops.
2. Does not fix the cases where the assertion is never inserted during insert_deferred_runtime_asserts because its input symbol is not detected.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153661
Approved by: https://github.com/jansel
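To make the failure mode concrete, here is a minimal sketch of the "example problem 2" setup. It is written for illustration and is not taken from the test file below: the module `M`, the concrete sizes, and the printed expectation are assumptions; only standard `torch.export` APIs are used.

```python
import torch


class M(torch.nn.Module):
    def forward(self, x, t):
        u0 = t.item()  # data-dependent op: u0 becomes an unbacked SymInt
        # Records a deferred runtime assertion u0 == s0 and lets the ShapeEnv
        # add the replacement u0 -> s0 for reasoning purposes.
        torch._check(u0 == x.size(0))
        return x.sum() + u0


x = torch.randn(4)
t = torch.tensor(4)
ep = torch.export.export(
    M(), (x, t), dynamic_shapes=({0: torch.export.Dim.AUTO}, None)
)
# With this change, the Eq(u0, s0) check should still show up as an
# _assert_scalar node in the exported graph, even though u0 was replaced
# by s0 during symbolic reasoning.
print(ep.graph_module.code)
```

The same shape of failure, with `u0` coming from a `mark_unbacked` graph input instead of `.item()`, corresponds to example problem 1 above.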
# Owner(s): ["oncall: export"]
# flake8: noqa
import types
import unittest
from typing import Dict, List, Tuple

import torch
import torch._dynamo
from torch._dynamo.test_case import run_tests, TestCase
from torch._functorch.aot_autograd import aot_export_module
from torch.export import export, export_for_training
from torch.export._trace import _convert_ts_to_export_experimental
from torch.export.experimental import _export_forward_backward, _sticky_export
from torch.export.graph_signature import OutputKind
from torch.testing import FileCheck


@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestExperiment(TestCase):
    def test_torchscript_module_export(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return x.cos() + x.sin()

        model_to_trace = M()
        inps = (torch.randn(4, 4),)
        traced_module_by_torchscript = torch.jit.trace(M(), example_inputs=inps)

        exported_module = _convert_ts_to_export_experimental(
            traced_module_by_torchscript, inps
        )

        self.assertTrue(torch.allclose(exported_module(*inps), model_to_trace(*inps)))

    def test_torchscript_module_export_single_input(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return x.cos() + x.sin()

        model_to_trace = M()
        inps = torch.randn(4, 4)
        traced_module_by_torchscript = torch.jit.trace(M(), example_inputs=inps)

        exported_module = _convert_ts_to_export_experimental(
            traced_module_by_torchscript, inps
        )

        self.assertTrue(torch.allclose(exported_module(inps), model_to_trace(inps)))

    def test_torchscript_module_export_various_inputs_with_annotated_input_names(self):
        def _check_equality_and_annotations(m_func, inps):
            # Original module.
            model_to_trace = m_func()

            # ExportedProgram from TorchScript module.
            traced_module_by_torchscript = torch.jit.trace(
                m_func(), example_inputs=inps
            )
            exported_module = _convert_ts_to_export_experimental(
                traced_module_by_torchscript, inps
            )

            # ExportedProgram from original module.
            original_exported_module = torch.export.export_for_training(
                m_func(), inps, strict=True
            )

            # Check whether input annotations are the same as tracing the original module.
            orig_ph_name_list = [
                n.name
                for n in original_exported_module.graph.nodes
                if n.op == "placeholder"
            ]
            ph_name_list = [
                n.name for n in exported_module.graph.nodes if n.op == "placeholder"
            ]
            self.assertEqual(orig_ph_name_list, ph_name_list)

            # Check results equality.
            self.assertTrue(
                torch.allclose(exported_module(*inps), model_to_trace(*inps))
            )

        # Tuple
        class MTuple(torch.nn.Module):
            def forward(self, x: Tuple[torch.Tensor]):
                return x[0] + x[1]

        _check_equality_and_annotations(MTuple, ((torch.randn(4), torch.randn(4)),))

        # List
        class MList(torch.nn.Module):
            def forward(self, x: List[torch.Tensor]):
                return x[0] + x[1]

        _check_equality_and_annotations(MList, ([torch.randn(4), torch.randn(4)],))

        # Dict
        class MDict(torch.nn.Module):
            def forward(self, x: Dict[str, torch.Tensor]):
                return x["0"] + x["1"]

        _check_equality_and_annotations(
            MDict, ({"0": torch.randn(4), "1": torch.randn(4)},)
        )

    def test_joint_basic(self) -> None:
        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)
                self.loss = torch.nn.CrossEntropyLoss()

            def forward(self, x):
                return self.loss(
                    self.linear(x).softmax(dim=0), torch.tensor([1.0, 0.0, 0.0])
                )

        m = Module()
        example_inputs = (torch.randn(3),)
        m(*example_inputs)
        ep = torch.export.export_for_training(m, example_inputs, strict=True)
        joint_ep = _export_forward_backward(ep)
        self.assertExpectedInline(
            str(joint_ep.graph_module.code).strip(),
            """\
def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
    view = torch.ops.aten.view.default(x, [1, 3]); x = None
    permute = torch.ops.aten.permute.default(p_linear_weight, [1, 0]); p_linear_weight = None
    addmm = torch.ops.aten.addmm.default(p_linear_bias, view, permute); p_linear_bias = permute = None
    view_1 = torch.ops.aten.view.default(addmm, [3]); addmm = None
    _softmax = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None
    alias = torch.ops.aten.alias.default(_softmax)
    alias_1 = torch.ops.aten.alias.default(alias); alias = None
    clone = torch.ops.aten.clone.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
    _log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None
    alias_2 = torch.ops.aten.alias.default(_log_softmax)
    alias_3 = torch.ops.aten.alias.default(alias_2); alias_2 = None
    mul = torch.ops.aten.mul.Tensor(_log_softmax, clone); _log_softmax = None
    sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None
    neg = torch.ops.aten.neg.default(sum_1); sum_1 = None
    div = torch.ops.aten.div.Scalar(neg, 1); neg = None
    full_like = torch.ops.aten.full_like.default(div, 1, pin_memory = False, memory_format = torch.preserve_format)
    div_1 = torch.ops.aten.div.Scalar(full_like, 1); full_like = None
    neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None
    expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(expand, clone); expand = clone = None
    alias_4 = torch.ops.aten.alias.default(alias_3); alias_3 = None
    alias_5 = torch.ops.aten.alias.default(alias_4); alias_4 = None
    exp = torch.ops.aten.exp.default(alias_5); alias_5 = None
    sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True)
    mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2); exp = sum_2 = None
    sub = torch.ops.aten.sub.Tensor(mul_1, mul_2); mul_1 = mul_2 = None
    alias_6 = torch.ops.aten.alias.default(alias_1); alias_1 = None
    alias_7 = torch.ops.aten.alias.default(alias_6); alias_6 = None
    mul_3 = torch.ops.aten.mul.Tensor(sub, alias_7); sub = None
    sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True)
    mul_4 = torch.ops.aten.mul.Tensor(alias_7, sum_3); alias_7 = sum_3 = None
    sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4); mul_3 = mul_4 = None
    view_2 = torch.ops.aten.view.default(sub_1, [1, 3]); sub_1 = None
    permute_1 = torch.ops.aten.permute.default(view_2, [1, 0])
    mm = torch.ops.aten.mm.default(permute_1, view); permute_1 = view = None
    permute_2 = torch.ops.aten.permute.default(mm, [1, 0]); mm = None
    sum_4 = torch.ops.aten.sum.dim_IntList(view_2, [0], True); view_2 = None
    view_3 = torch.ops.aten.view.default(sum_4, [3]); sum_4 = None
    permute_3 = torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None
    return (div, permute_3, view_3)""",
        )
        ep = joint_ep.run_decompositions()
        self.assertExpectedInline(
            str(ep.graph_module.code).strip(),
            """\
def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
    view = torch.ops.aten.view.default(x, [1, 3]); x = None
    permute = torch.ops.aten.permute.default(p_linear_weight, [1, 0]); p_linear_weight = None
    addmm = torch.ops.aten.addmm.default(p_linear_bias, view, permute); p_linear_bias = permute = None
    view_1 = torch.ops.aten.view.default(addmm, [3]); addmm = None
    _softmax = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None
    alias = torch.ops.aten.alias.default(_softmax)
    alias_1 = torch.ops.aten.alias.default(alias); alias = None
    clone = torch.ops.aten.clone.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
    _log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None
    alias_2 = torch.ops.aten.alias.default(_log_softmax)
    alias_3 = torch.ops.aten.alias.default(alias_2); alias_2 = None
    mul = torch.ops.aten.mul.Tensor(_log_softmax, clone); _log_softmax = None
    sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None
    neg = torch.ops.aten.neg.default(sum_1); sum_1 = None
    div = torch.ops.aten.div.Scalar(neg, 1); neg = None
    full_like = torch.ops.aten.full_like.default(div, 1, pin_memory = False, memory_format = torch.preserve_format)
    div_1 = torch.ops.aten.div.Scalar(full_like, 1); full_like = None
    neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None
    expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(expand, clone); expand = clone = None
    alias_4 = torch.ops.aten.alias.default(alias_3); alias_3 = None
    alias_5 = torch.ops.aten.alias.default(alias_4); alias_4 = None
    exp = torch.ops.aten.exp.default(alias_5); alias_5 = None
    sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True)
    mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2); exp = sum_2 = None
    sub = torch.ops.aten.sub.Tensor(mul_1, mul_2); mul_1 = mul_2 = None
    alias_6 = torch.ops.aten.alias.default(alias_1); alias_1 = None
    alias_7 = torch.ops.aten.alias.default(alias_6); alias_6 = None
    mul_3 = torch.ops.aten.mul.Tensor(sub, alias_7); sub = None
    sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True)
    mul_4 = torch.ops.aten.mul.Tensor(alias_7, sum_3); alias_7 = sum_3 = None
    sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4); mul_3 = mul_4 = None
    view_2 = torch.ops.aten.view.default(sub_1, [1, 3]); sub_1 = None
    permute_1 = torch.ops.aten.permute.default(view_2, [1, 0])
    mm = torch.ops.aten.mm.default(permute_1, view); permute_1 = view = None
    permute_2 = torch.ops.aten.permute.default(mm, [1, 0]); mm = None
    sum_4 = torch.ops.aten.sum.dim_IntList(view_2, [0], True); view_2 = None
    view_3 = torch.ops.aten.view.default(sum_4, [3]); sum_4 = None
    permute_3 = torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None
    return (div, permute_3, view_3)""",
        )

    def test_joint_dynamic(self) -> None:
        from torch.export import Dim

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.y = torch.nn.Parameter(torch.randn(3))

            def forward(self, x):
                x = torch.ones(x.shape[0], 3)
                return (self.y + x).sum()

        m = Module()
        example_inputs = (torch.randn(3),)
        m(*example_inputs)
        ep = torch.export.export_for_training(
            m, example_inputs, dynamic_shapes={"x": {0: Dim("x0")}}, strict=True
        )
        _export_forward_backward(ep)

    def test_joint_cifar10_backwards(self) -> None:
        import torch.nn as nn
        import torch.nn.functional as F

        # From Pytorch's CIFAR10 example:
        # https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(3, 6, 5)
                self.pool = nn.MaxPool2d(2, 2)
                self.conv2 = nn.Conv2d(6, 16, 5)
                self.fc1 = nn.Linear(16 * 5 * 5, 120)
                self.fc2 = nn.Linear(120, 84)
                self.fc3 = nn.Linear(84, 10)
                self.loss = nn.CrossEntropyLoss()

            def forward(self, x, labels):
                x = self.pool(F.relu(self.conv1(x)))
                x = self.pool(F.relu(self.conv2(x)))
                x = torch.flatten(x, 1)  # flatten all dimensions except batch
                x = F.relu(self.fc1(x))
                x = F.relu(self.fc2(x))
                x = self.fc3(x)
                return self.loss(x, labels)

        net = Net()
        x = torch.randn(4, 3, 32, 32)
        labels = torch.ones(4, dtype=torch.int64)
        inputs = (x, labels)

        ep = export_for_training(net, inputs, strict=True)
        ep = _export_forward_backward(ep)

    def test_joint_loss_index(self):
        class Foo(torch.nn.Module):
            def __init__(self, index):
                super().__init__()
                self.l = torch.nn.Linear(4, 4)
                self.index = index

            def forward(self, x):
                x = self.l(x)
                x = x.sum()
                if self.index == 0:
                    return x, -x.detach()
                else:
                    return x.detach(), x

        inputs = (torch.randn(4, 4),)
        for i in [0, 1]:
            ep = export_for_training(Foo(i), inputs, strict=True)
            ep_joint = _export_forward_backward(ep, joint_loss_index=i)
            for j, spec in enumerate(ep_joint.graph_signature.output_specs):
                if i == j:
                    self.assertTrue(spec.kind == OutputKind.LOSS_OUTPUT)
                else:
                    self.assertTrue(spec.kind != OutputKind.LOSS_OUTPUT)

    def test_joint_buffer_input_mutations(self):
        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.l = torch.nn.Linear(4, 4)
                self.register_buffer("buf", torch.randn(4))
                self.loss = torch.nn.CrossEntropyLoss()

            def forward(self, x, label):
                x.add_(self.buf)
                x = self.l(x)
                self.buf.add_(2.0)
                return self.loss(x, label)

        inputs = (
            torch.randn(4, 4),
            torch.randint(0, 4, (4,)),
        )
        ep = export(Foo(), inputs)
        ep_joint = _export_forward_backward(ep)
        self.assertEqual(len(ep_joint.graph_signature.output_specs), 5)
        self.assertEqual(
            ep_joint.graph_signature.output_specs[0].kind,
            OutputKind.BUFFER_MUTATION,
        )
        self.assertEqual(
            ep_joint.graph_signature.output_specs[0].target,
            "buf",
        )
        self.assertEqual(
            ep_joint.graph_signature.output_specs[1].kind,
            OutputKind.USER_INPUT_MUTATION,
        )
        self.assertEqual(
            ep_joint.graph_signature.output_specs[1].target,
            "x",
        )
        self.assertEqual(
            ep_joint.graph_signature.output_specs[2].kind,
            OutputKind.LOSS_OUTPUT,
        )

    def test_sticky_export(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.linear(x)

        class Pipeline:
            def __init__(self, model):
                self.model = model

            def generate(self, *args, **kwargs):
                return self.model(*args, **kwargs)

        inp = torch.randn(4, 4)

        p = Pipeline(Model())
        orig_forward = p.model.forward
        p.model.forward = _sticky_export(p.model.forward)
        res = p.generate(inp)

        p.model.forward = orig_forward
        res2 = p.generate(inp)
        self.assertTrue(torch.allclose(res, res2))

    def test_sticky_export_dynamic(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, x):
                if x.shape[0] < 5:
                    return self.linear(x)
                return x.sin()

        class Pipeline:
            def __init__(self, model):
                self.model = model

            def generate(self, *args, **kwargs):
                return self.model(*args, **kwargs)

        inp = torch.randn(4, 4)

        def callback(*args, **kwargs):
            # It is a bit weird to use the forward arg name here, so
            # let's just use ShapesCollection

            flat_args, _ = torch.utils._pytree.tree_flatten((args, kwargs))
            collections = torch.export.ShapesCollection()
            for arg in flat_args:
                if isinstance(arg, torch.Tensor):
                    collections[arg] = {
                        i: torch.export.Dim.AUTO for i in range(len(arg.shape))
                    }
            return collections

        p = Pipeline(Model())
        p.model.forward = _sticky_export(
            p.model.forward, dynamic_shapes_callback=callback
        )
        _ = p.generate(inp)
        self.assertExpectedInline(
            str(p.model.forward._exported_artifact.code).strip(),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    linear_weight = self.linear.weight
    linear_bias = self.linear.bias
    sym_size_int_2 = torch.ops.aten.sym_size.int(x, 1)
    linear = torch.ops.aten.linear.default(x, linear_weight, linear_bias); x = linear_weight = linear_bias = None
    eq = sym_size_int_2 == 4; sym_size_int_2 = None
    _assert_scalar_default = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s16, 4) on node 'eq'"); eq = _assert_scalar_default = None
    return pytree.tree_unflatten((linear,), self._out_spec)""",
        )

    def test_sticky_export_nested_inp(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, *, inputs):
                return self.linear(inputs[0]) + self.linear(inputs[1])

        class Pipeline:
            def __init__(self, model):
                self.model = model

            def generate(self, *, input_tensor, input_tensor2):
                inputs = [input_tensor, input_tensor2]
                return self.model(inputs=inputs)

        inp = torch.randn(4, 4)
        inp2 = torch.randn(4, 4)

        p = Pipeline(Model())
        orig_forward = p.model.forward
        p.model.forward = _sticky_export(p.model.forward)
        res = p.generate(input_tensor=inp, input_tensor2=inp2)

        p.model.forward = orig_forward
        res2 = p.generate(input_tensor=inp, input_tensor2=inp2)
        self.assertTrue(torch.allclose(res, res2))


if __name__ == "__main__":
    run_tests()