pytorch/test/fx/test_source_matcher_utils.py
Angela Yi d4225c55d9 [fx] Prioritize runtime assertions ops (#124213)
Summary:
We want to prioritize operators involved in data-dependent runtime assertions when legalizing the graph. For example, in the following piece of code, the `assert_scalar` and `assert_async` calls need to occur before the `slice_copy` call for the program to run correctly with fake tensors; otherwise we run into a data-dependent error.

```
        _local_scalar_dense: "Sym(u113)" = torch.ops.aten._local_scalar_dense.default(aten_minimum_default);  aten_minimum_default = None

        ge_1: "Sym(u113 >= 2)" = _local_scalar_dense >= 2
        aten_scalar_tensor_default_3: "f32[]" = executorch_exir_dialects_edge__ops_aten_scalar_tensor_default(ge_1);  ge_1 = None
        aten__assert_async_msg_2 = executorch_exir_dialects_edge__ops_aten__assert_async_msg(aten_scalar_tensor_default_3, '_local_scalar_dense is outside of inline constraint [2, 1000].');  aten_scalar_tensor_default_3 = None
        le_1: "Sym(u113 <= 1000)" = _local_scalar_dense <= 1000
        aten_scalar_tensor_default_4: "f32[]" = executorch_exir_dialects_edge__ops_aten_scalar_tensor_default(le_1);  le_1 = None
        aten__assert_async_msg_3 = executorch_exir_dialects_edge__ops_aten__assert_async_msg(aten_scalar_tensor_default_4, '_local_scalar_dense is outside of inline constraint [2, 1000].');  aten_scalar_tensor_default_4 = None

        mul: "Sym(-u112)" = -1 * sym_size;  sym_size = None
        add: "Sym(-u112 + u113)" = _local_scalar_dense + mul;  mul = None
        lt: "Sym(-u112 + u113 < 0)" = add < 0;  add = None
        aten__assert_scalar_default = executorch_exir_dialects_edge__ops_aten__assert_scalar_default(lt, 'Deferred runtime assertion failed -u0 + u1 < 0');  lt = None

        aten_slice_copy_tensor_3: "f32[u113]" = executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor(getitem, 0, 0, _local_scalar_dense);  getitem = None
```
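
As a minimal sketch of the scenario (adapted from the `test_legalize_slice` test case added in this PR, not a new API), the issue can be exercised by exporting a program with a data-dependent slice and then re-sorting it with `legalize_graph`; with assertion ops prioritized, the re-sorted graph still runs cleanly under fake tensors:

```
import torch
from torch.fx.passes.tools_common import legalize_graph

class M(torch.nn.Module):
    def forward(self, x, y):
        b = x.item()                     # data-dependent scalar
        torch._check_is_size(b)
        torch._check(b + 1 < y.size(0))  # becomes a deferred runtime assertion
        return y[: b + 1]                # slice whose length depends on b

ep = torch.export.export(M(), (torch.tensor(4), torch.randn(10)))
gm = ep.module()

# legalize_graph topologically re-sorts the graph; the assertion ops must
# stay ahead of the data-dependent slice_copy, otherwise interpreting the
# result with fake tensors raises a data-dependent error.
legalized_gm = legalize_graph(gm)
```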

Test Plan: test case

Differential Revision: D56201450

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124213
Approved by: https://github.com/SherlockNoMad
2024-05-07 21:31:10 +00:00


# Owner(s): ["module: fx"]
import os
import sys
import unittest

import torch

pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch._dynamo.eval_frame import is_dynamo_supported
from torch.fx.passes.tools_common import legalize_graph
from torch.fx.passes.utils.source_matcher_utils import (
    check_subgraphs_connected,
    get_source_partitions,
)
from torch.testing._internal.jit_utils import JitTestCase


class TestSourceMatcher(JitTestCase):
    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_module_partitioner_linear_relu_linear(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = torch.nn.Linear(3, 3)
                self.relu = torch.nn.ReLU()
                self.linear2 = torch.nn.Linear(3, 5)

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear1(x)
                x = self.relu(x)
                x = self.linear2(x)
                return x

        inputs = (torch.randn(3, 3),)
        gm, _ = torch._dynamo.export(M(), aten_graph=True)(*inputs)
        gm.graph.eliminate_dead_code()

        module_partitions = get_source_partitions(
            gm.graph, [torch.nn.Linear, torch.nn.ReLU]
        )

        self.assertEqual(len(module_partitions), 2)
        self.assertEqual(len(module_partitions[torch.nn.Linear]), 3)
        self.assertEqual(len(module_partitions[torch.nn.ReLU]), 1)

        self.assertFalse(
            check_subgraphs_connected(
                module_partitions[torch.nn.Linear][0],
                module_partitions[torch.nn.ReLU][0],
            )
        )
        self.assertTrue(
            check_subgraphs_connected(
                module_partitions[torch.nn.Linear][1],
                module_partitions[torch.nn.ReLU][0],
            )
        )
        self.assertFalse(
            check_subgraphs_connected(
                module_partitions[torch.nn.Linear][2],
                module_partitions[torch.nn.ReLU][0],
            )
        )

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_module_partitioner_conv_relu_maxpool(self):
        class M(torch.nn.Module):
            def __init__(self, constant_tensor: torch.Tensor) -> None:
                super().__init__()
                self.constant_tensor = constant_tensor
                self.conv1 = torch.nn.Conv2d(
                    in_channels=3, out_channels=16, kernel_size=3, padding=1
                )
                self.conv2 = torch.nn.Conv2d(
                    in_channels=16, out_channels=16, kernel_size=3, padding=1
                )
                self.conv3 = torch.nn.Conv2d(
                    in_channels=16, out_channels=16, kernel_size=3, padding=1
                )
                self.relu = torch.nn.ReLU()
                self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                a = self.conv1(x)
                b = self.conv2(a)
                c = a + self.constant_tensor
                z = self.conv3(b + c)
                return self.maxpool(self.relu(z))

        inputs = (torch.randn(1, 3, 256, 256),)
        gm, _ = torch._dynamo.export(M(torch.ones(1, 16, 256, 256)), aten_graph=True)(
            *inputs
        )
        gm.graph.eliminate_dead_code()

        module_partitions = get_source_partitions(
            gm.graph, [torch.nn.Conv2d, torch.nn.ReLU, torch.nn.MaxPool2d]
        )

        self.assertEqual(len(module_partitions), 3)
        self.assertEqual(len(module_partitions[torch.nn.Conv2d]), 3)
        self.assertEqual(len(module_partitions[torch.nn.ReLU]), 1)
        self.assertEqual(len(module_partitions[torch.nn.MaxPool2d]), 1)

        self.assertFalse(
            check_subgraphs_connected(
                module_partitions[torch.nn.Conv2d][0],
                module_partitions[torch.nn.ReLU][0],
            )
        )
        self.assertFalse(
            check_subgraphs_connected(
                module_partitions[torch.nn.Conv2d][1],
                module_partitions[torch.nn.ReLU][0],
            )
        )
        self.assertTrue(
            check_subgraphs_connected(
                module_partitions[torch.nn.Conv2d][2],
                module_partitions[torch.nn.ReLU][0],
            )
        )
        self.assertFalse(
            check_subgraphs_connected(
                module_partitions[torch.nn.MaxPool2d][0],
                module_partitions[torch.nn.ReLU][0],
            )
        )
        self.assertTrue(
            check_subgraphs_connected(
                module_partitions[torch.nn.ReLU][0],
                module_partitions[torch.nn.MaxPool2d][0],
            )
        )

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_module_partitioner_functional_conv_relu_conv(self):
        class FunctionalConv2d(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.stride = (1, 1)
                self.padding = (0, 0)
                self.dilation = (1, 1)
                self.groups = 1

            def forward(self, x, weight, bias):
                return torch.nn.functional.conv2d(
                    x,
                    weight,
                    bias,
                    self.stride,
                    self.padding,
                    self.dilation,
                    self.groups,
                )

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = FunctionalConv2d()
                self.conv2 = FunctionalConv2d()

            def forward(self, x, weight, bias):
                x = self.conv1(x, weight, bias)
                x = torch.nn.functional.relu(x)
                x = self.conv2(x, weight, bias)
                return x

        inputs = (torch.randn(1, 3, 5, 5), torch.rand(3, 3, 3, 3), torch.rand(3))
        gm, _ = torch._dynamo.export(M(), aten_graph=True)(*inputs)
        gm.graph.eliminate_dead_code()

        module_partitions = get_source_partitions(
            gm.graph, [torch.nn.functional.conv2d]
        )

        self.assertEqual(len(module_partitions), 1)
        self.assertEqual(len(module_partitions[torch.nn.functional.conv2d]), 2)

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_module_partitioner_functional_linear_relu_linear(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, weight, bias):
                x = torch.nn.functional.linear(x, weight, bias)
                x = torch.nn.functional.linear(x, weight, bias)
                x = torch.nn.functional.relu(x)
                x = torch.nn.functional.linear(x, weight, bias)
                x = torch.nn.functional.linear(x, weight, bias)
                x = torch.nn.functional.relu(x)
                return x

        inputs = (torch.randn(1, 5), torch.rand((5, 5)), torch.zeros(5))
        gm, _ = torch._dynamo.export(M(), aten_graph=True)(*inputs)
        gm.graph.eliminate_dead_code()

        module_partitions = get_source_partitions(
            gm.graph, [torch.nn.functional.linear, torch.nn.functional.relu]
        )

        self.assertEqual(len(module_partitions), 2)
        self.assertEqual(len(module_partitions[torch.nn.functional.linear]), 4)
        self.assertEqual(len(module_partitions[torch.nn.functional.relu]), 2)

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_legalize_slice(self):
        class M(torch.nn.Module):
            def forward(self, x, y):
                b = x.item()
                torch._check_is_size(b)
                torch._check(b + 1 < y.size(0))
                return y[: b + 1]

        ep = torch.export.export(M(), (torch.tensor(4), torch.randn(10)))
        fake_inputs = [
            node.meta["val"] for node in ep.graph.nodes if node.op == "placeholder"
        ]
        gm = ep.module()
        with fake_inputs[0].fake_mode:
            torch.fx.Interpreter(gm).run(*fake_inputs)
        # legalize_graph re-sorts the graph topologically; the runtime assertion
        # ops must keep preceding the data-dependent slice, otherwise re-running
        # the graph with fake tensors raises a data-dependent error.
        legalized_gm = legalize_graph(gm)
        with fake_inputs[0].fake_mode:
            torch.fx.Interpreter(legalized_gm).run(*fake_inputs)