mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary:
We want to prioritize operators involved in data-dependent runtime assertions when legalizing the graph. For example, in the following piece of code, the `assert_scalar` and `assert_async` calls need to occur before the `slice_copy` for the program to run correctly with fake tensors. Otherwise we will run into a data-dependent error.
```
_local_scalar_dense: "Sym(u113)" = torch.ops.aten._local_scalar_dense.default(aten_minimum_default); aten_minimum_default = None
ge_1: "Sym(u113 >= 2)" = _local_scalar_dense >= 2
aten_scalar_tensor_default_3: "f32[]" = executorch_exir_dialects_edge__ops_aten_scalar_tensor_default(ge_1); ge_1 = None
aten__assert_async_msg_2 = executorch_exir_dialects_edge__ops_aten__assert_async_msg(aten_scalar_tensor_default_3, '_local_scalar_dense is outside of inline constraint [2, 1000].'); aten_scalar_tensor_default_3 = None
le_1: "Sym(u113 <= 1000)" = _local_scalar_dense <= 1000
aten_scalar_tensor_default_4: "f32[]" = executorch_exir_dialects_edge__ops_aten_scalar_tensor_default(le_1); le_1 = None
aten__assert_async_msg_3 = executorch_exir_dialects_edge__ops_aten__assert_async_msg(aten_scalar_tensor_default_4, '_local_scalar_dense is outside of inline constraint [2, 1000].'); aten_scalar_tensor_default_4 = None
mul: "Sym(-u112)" = -1 * sym_size; sym_size = None
add: "Sym(-u112 + u113)" = _local_scalar_dense + mul; mul = None
lt: "Sym(-u112 + u113 < 0)" = add < 0; add = None
aten__assert_scalar_default = executorch_exir_dialects_edge__ops_aten__assert_scalar_default(lt, 'Deferred runtime assertion failed -u0 + u1 < 0'); lt = None
aten_slice_copy_tensor_3: "f32[u113]" = executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor(getitem, 0, 0, _local_scalar_dense); getitem = None
```
Test Plan: test case
Differential Revision: D56201450
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124213
Approved by: https://github.com/SherlockNoMad
229 lines
8.2 KiB
Python
229 lines
8.2 KiB
Python
# Owner(s): ["module: fx"]
|
|
|
|
import os
import sys
import unittest

import torch

# Resolve the directory two levels above this file's real path and append it to
# ``sys.path``. This runs before the imports below, so their order is kept.
pytorch_test_dir = os.path.dirname(
    os.path.dirname(os.path.realpath(__file__))
)
sys.path.append(pytorch_test_dir)
from torch._dynamo.eval_frame import is_dynamo_supported
from torch.fx.passes.tools_common import legalize_graph
from torch.fx.passes.utils.source_matcher_utils import (
    check_subgraphs_connected,
    get_source_partitions,
)
from torch.testing._internal.jit_utils import JitTestCase
|
|
|
|
|
|
class TestSourceMatcher(JitTestCase):
    """Tests for source partitioning, subgraph connectivity and graph legalization.

    The partitioner tests export a small module with ``torch._dynamo.export``
    (``aten_graph=True``), remove dead code, and then check what
    ``get_source_partitions`` / ``check_subgraphs_connected`` report for it.
    The legalization test checks that a graph still runs under fake tensors
    after ``legalize_graph``.
    """

    def _export_and_prune(self, module, example_inputs):
        # Shared setup for the partitioner tests: export ``module`` to an ATen
        # graph using ``example_inputs`` and run dead-code elimination on it.
        gm, _ = torch._dynamo.export(module, aten_graph=True)(*example_inputs)
        gm.graph.eliminate_dead_code()
        return gm

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_module_partitioner_linear_relu_linear(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = torch.nn.Linear(3, 3)
                self.relu = torch.nn.ReLU()
                self.linear2 = torch.nn.Linear(3, 5)

            def forward(self, x):
                # ``linear1`` is applied twice; together with ``linear2`` this
                # accounts for the three Linear partitions expected below.
                once = self.linear1(x)
                twice = self.linear1(once)
                activated = self.relu(twice)
                return self.linear2(activated)

        inputs = (torch.randn(3, 3),)
        gm = self._export_and_prune(M(), inputs)

        module_partitions = get_source_partitions(
            gm.graph, [torch.nn.Linear, torch.nn.ReLU]
        )

        self.assertEqual(len(module_partitions), 2)
        linear_parts = module_partitions[torch.nn.Linear]
        self.assertEqual(len(linear_parts), 3)
        relu_parts = module_partitions[torch.nn.ReLU]
        self.assertEqual(len(relu_parts), 1)

        # Only the second Linear partition is reported as connected to the ReLU.
        self.assertFalse(check_subgraphs_connected(linear_parts[0], relu_parts[0]))
        self.assertTrue(check_subgraphs_connected(linear_parts[1], relu_parts[0]))
        self.assertFalse(check_subgraphs_connected(linear_parts[2], relu_parts[0]))

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_module_partitioner_conv_relu_maxpool(self):
        class M(torch.nn.Module):
            def __init__(self, constant_tensor: torch.Tensor) -> None:
                super().__init__()
                self.constant_tensor = constant_tensor
                # conv2 and conv3 share the same configuration; only conv1
                # differs, in its input channel count.
                shared = {"out_channels": 16, "kernel_size": 3, "padding": 1}
                self.conv1 = torch.nn.Conv2d(in_channels=3, **shared)
                self.conv2 = torch.nn.Conv2d(in_channels=16, **shared)
                self.conv3 = torch.nn.Conv2d(in_channels=16, **shared)
                self.relu = torch.nn.ReLU()
                self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                first = self.conv1(x)
                second = self.conv2(first)
                shifted = first + self.constant_tensor
                merged = self.conv3(second + shifted)
                activated = self.relu(merged)
                return self.maxpool(activated)

        inputs = (torch.randn(1, 3, 256, 256),)
        gm = self._export_and_prune(M(torch.ones(1, 16, 256, 256)), inputs)

        module_partitions = get_source_partitions(
            gm.graph, [torch.nn.Conv2d, torch.nn.ReLU, torch.nn.MaxPool2d]
        )

        self.assertEqual(len(module_partitions), 3)
        conv_parts = module_partitions[torch.nn.Conv2d]
        self.assertEqual(len(conv_parts), 3)
        relu_parts = module_partitions[torch.nn.ReLU]
        self.assertEqual(len(relu_parts), 1)
        pool_parts = module_partitions[torch.nn.MaxPool2d]
        self.assertEqual(len(pool_parts), 1)

        # Of the three convolutions, only the last one is connected to the ReLU.
        self.assertFalse(check_subgraphs_connected(conv_parts[0], relu_parts[0]))
        self.assertFalse(check_subgraphs_connected(conv_parts[1], relu_parts[0]))
        self.assertTrue(check_subgraphs_connected(conv_parts[2], relu_parts[0]))

        # The check is order-sensitive: (maxpool, relu) is not connected,
        # while (relu, maxpool) is.
        self.assertFalse(check_subgraphs_connected(pool_parts[0], relu_parts[0]))
        self.assertTrue(check_subgraphs_connected(relu_parts[0], pool_parts[0]))

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_module_partitioner_functional_conv_relu_conv(self):
        class FunctionalConv2d(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.stride = (1, 1)
                self.padding = (0, 0)
                self.dilation = (1, 1)
                self.groups = 1

            def forward(self, x, weight, bias):
                conv_settings = (
                    self.stride,
                    self.padding,
                    self.dilation,
                    self.groups,
                )
                return torch.nn.functional.conv2d(x, weight, bias, *conv_settings)

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = FunctionalConv2d()
                self.conv2 = FunctionalConv2d()

            def forward(self, x, weight, bias):
                convolved = self.conv1(x, weight, bias)
                activated = torch.nn.functional.relu(convolved)
                return self.conv2(activated, weight, bias)

        inputs = (torch.randn(1, 3, 5, 5), torch.rand(3, 3, 3, 3), torch.rand(3))
        gm = self._export_and_prune(M(), inputs)

        module_partitions = get_source_partitions(
            gm.graph, [torch.nn.functional.conv2d]
        )

        # One requested source type, matched by two separate conv2d calls.
        self.assertEqual(len(module_partitions), 1)
        self.assertEqual(len(module_partitions[torch.nn.functional.conv2d]), 2)

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_module_partitioner_functional_linear_relu_linear(self):
        class M(torch.nn.Module):
            def forward(self, x, weight, bias):
                # Two rounds of linear -> linear -> relu using the same
                # weight and bias throughout.
                functional = torch.nn.functional
                out = functional.linear(x, weight, bias)
                out = functional.linear(out, weight, bias)
                out = functional.relu(out)
                out = functional.linear(out, weight, bias)
                out = functional.linear(out, weight, bias)
                out = functional.relu(out)
                return out

        inputs = (torch.randn(1, 5), torch.rand((5, 5)), torch.zeros(5))
        gm = self._export_and_prune(M(), inputs)

        module_partitions = get_source_partitions(
            gm.graph, [torch.nn.functional.linear, torch.nn.functional.relu]
        )

        self.assertEqual(len(module_partitions), 2)
        self.assertEqual(len(module_partitions[torch.nn.functional.linear]), 4)
        self.assertEqual(len(module_partitions[torch.nn.functional.relu]), 2)

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_legalize_slice(self):
        class M(torch.nn.Module):
            def forward(self, x, y):
                b = x.item()
                torch._check_is_size(b)
                torch._check(b + 1 < y.size(0))
                return y[: b + 1]

        ep = torch.export.export(M(), (torch.tensor(4), torch.randn(10)))

        # Collect the ``val`` metadata stored on every placeholder node; these
        # values are fed to the interpreter runs below.
        fake_inputs = []
        for node in ep.graph.nodes:
            if node.op == "placeholder":
                fake_inputs.append(node.meta["val"])

        gm = ep.module()
        fake_mode = fake_inputs[0].fake_mode

        # Run the graph module as produced by export first.
        with fake_mode:
            torch.fx.Interpreter(gm).run(*fake_inputs)

        # Run again after legalization; the test passes if neither run raises.
        legalized_gm = legalize_graph(gm)
        with fake_mode:
            torch.fx.Interpreter(legalized_gm).run(*fake_inputs)
|