This PR introduces two components. CapabilityBasedPartitioner for FX graphs: given a list of supported operators, this partitioner tries to form the largest subgraphs that contain only the supported ops. Fuser utility: given a list of nodes in an FX graph, it lifts them into a sub-GraphModule within the original graph. Pull Request resolved: https://github.com/pytorch/pytorch/pull/79439 Approved by: https://github.com/jjsjann123, https://github.com/davidberard98
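A minimal usage sketch of how the two pieces fit together, following the calls exercised in the test below (the toy function `fn` and the `AddOnlySupport` class are illustrative stand-ins, not part of the PR):

import operator
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport

class AddOnlySupport(OperatorSupport):
    # Illustrative stand-in: treat only operator.add as supported,
    # mirroring MockOperatorSupport in the test file below.
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in {operator.add}

def fn(a, b):
    # Two supported adds separated by an unsupported relu.
    return (a + b).relu() + a

traced = symbolic_trace(fn)
partitioner = CapabilityBasedPartitioner(traced, AddOnlySupport(),
                                         allows_single_node_partition=True)
partitions = partitioner.propose_partitions()    # largest subgraphs of supported ops
fused = partitioner.fuse_partitions(partitions)  # each partition becomes a sub-GraphModule
fused(torch.rand(4), torch.rand(4))              # fused graph stays numerically equivalent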
277 lines
9.0 KiB
Python
# Owner(s): ["oncall: fx"]

import operator
import logging

import torch
from torch.fx._symbolic_trace import symbolic_trace

from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.utils.fuser_utils import fuse_by_partitions

from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests
from torch.testing._internal.jit_utils import JitTestCase

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.linear2 = torch.nn.Linear(4, 4)
        self.param = torch.nn.Parameter(torch.rand(4, 4))

    def forward(self, a, b, c):
        add = a + b

        linear_1 = self.linear(add)

        add_1 = add + c
        add_2 = add_1 + self.param
        add_3 = add_1 + linear_1
        add_4 = add_2 + add_3

        linear_2 = self.linear2(add_4)

        add_5 = linear_2 + add_4
        add_6 = add_5 + a
        relu = add_6.relu()

        return add_4, add_6, relu

class TestPartitionFunctions:
    @staticmethod
    def forward1(a, b, c):
        add = a + b
        add_1 = add + b
        add_2 = add_1 + c
        relu_1 = add_2.relu()
        add_3 = add_1 + add_2
        add_4 = add_1 + relu_1 + add_3
        relu_2 = add_4.relu()
        add_5 = relu_2 + add_4
        add_6 = add_5 + add_4
        return add_4, add_6

    @staticmethod
    def forward2(a, b, _):
        add = a + b
        add_1 = add + b
        relu_1 = add_1.relu()  # blocked by this
        add_3 = add_1 + relu_1
        add_4 = add_1 + add_3
        return add_4, add_1

    @staticmethod
    def forward3(a, b, c):
        add = a + b
        add_1 = a + c
        add_2 = b + c
        return add, add_1, add_2

    @staticmethod
    def forward4(a, b, c):
        add = a + b
        add_1 = a + c
        add_2 = b + c
        return torch.where(add > 0, add_1, add_2)

    @staticmethod
    def forward5(a, b, c):
        # add should be fused into the right branch, as the left branch is not supported
        add = a + 1
        # left branch
        relu = add.relu()
        # right branch
        add_1 = add + 2
        return relu, add_1

    @staticmethod
    def forward6(a, b, c):
        # add should have its own partition, as neither branch is supported
        add = a + 1
        # left branch
        relu = add.relu()
        # right branch
        relu_1 = add.relu()
        return relu, relu_1

    @staticmethod
    def forward7(a, b, c):
        # both branches are supported, all adds should be fused together
        add = a + 1
        # left branch
        add_1 = add + 2
        # right branch is larger
        add_2 = add + 1
        add_3 = add_2 + 1
        return add_3, add_1

    @staticmethod
    def forward8(a, b, c):
        # both branches are in the same partition, add should join the same partition
        add = a + 1
        # left branch
        add_1 = add + 2
        # right branch
        add_2 = add + 1
        # left and right branches merge
        add_3 = add_2 + add_1

        return add_3

    @staticmethod
    def forward9(a, b, c):
        add = a + 1
        # branch 1
        add_1 = add + 1
        # branch 2
        add_2 = add + 1
        # branch 3
        add_3 = add + 1
        out = torch.stack([add_1, add_2, add_3])
        return out

    @staticmethod
    def forward10(a, b, c):
        add = a + 1
        # branch 1
        add_1 = add + 1
        # branch 2
        add_2 = add + 1
        # branch 3: depends on branch 2
        add_3 = add + add_2
        out = torch.stack([add_1, add_2, add_3])
        return out

    @staticmethod
    def forward11(a, b, c):
        add = a + 1
        # branch 1
        add_1 = add.relu()
        # branch 2 depends on branch 1
        add_2 = add + add_1
        # branch 3
        add_3 = add.relu()
        out = torch.stack([add_1, add_2, add_3])
        return out

# A mock OperatorSupport class, where only operator.add is supported
class MockOperatorSupport(OperatorSupport):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in {operator.add}

class TestFXGraphPasses(JitTestCase):

    @parametrize("fn, expected_partition", [
        (TestPartitionFunctions.forward1, [["add_7", "add_6"], ["add_5", "add_4", "add_3"], ["add_2", "add_1", "add"]]),
        (TestPartitionFunctions.forward2, [["add_3", "add_2"], ["add_1", "add"]]),

        # 2-branch cases
        (TestPartitionFunctions.forward5, [["add_1", "add"]]),
        (TestPartitionFunctions.forward6, [["add"]]),
        (TestPartitionFunctions.forward7, [["add_3", "add_2", "add", "add_1"]]),
        (TestPartitionFunctions.forward8, [["add_3", "add_2", "add", "add_1"]]),

        # 3-branch cases
        (TestPartitionFunctions.forward9, [['add_3', 'add_2', 'add_1', 'add']]),
        (TestPartitionFunctions.forward10, [['add_3', 'add_2', 'add', 'add_1']]),
        (TestPartitionFunctions.forward11, [['add_1'], ['add']]),
    ])
    def test_partitioner(self, fn, expected_partition):
        traced = symbolic_trace(fn)

        supported_ops = MockOperatorSupport()
        partitioner = CapabilityBasedPartitioner(traced, supported_ops, allows_single_node_partition=True)
        partitions = partitioner.propose_partitions()

        partitions_name = [[node.name for node in partition.nodes] for partition in partitions]
        assert len(partitions_name) == len(expected_partition)
        for i in range(len(partitions_name)):
            assert set(partitions_name[i]) == set(expected_partition[i])

        fused_graph = partitioner.fuse_partitions(partitions)

        a, b, c = torch.rand(4), torch.rand(4), torch.rand(4)

        expected = fn(a, b, c)
        result = fused_graph(a, b, c)
        torch.testing.assert_close(expected, result)

    @parametrize("fn, expected_partition", [
        # horizontal fusion without a common downstream node, not supported yet
        (TestPartitionFunctions.forward3, [["add_2", "add_1", "add"]]),
        # horizontal fusion with a common downstream node, not supported yet
        (TestPartitionFunctions.forward4, [["add_2", "add_1", "add"]]),
    ])
    def test_partitioner_xfail(self, fn, expected_partition):
        traced = symbolic_trace(fn)

        supported_ops = MockOperatorSupport()
        partitioner = CapabilityBasedPartitioner(traced, supported_ops, allows_single_node_partition=True)
        partitions = partitioner.propose_partitions()

        partitions_name = [[node.name for node in partition.nodes] for partition in partitions]
        with self.assertRaises(Exception):
            assert len(partitions_name) == len(expected_partition)

@parametrize("partition", [
|
|
[['add', 'add_1'], ['add_5', 'add_6']],
|
|
[['add', 'add_1', 'add_2']], # vertical fusion
|
|
[['add_2', 'add_3']], # horizontal fusion
|
|
[['add_3', 'add_4']],
|
|
[['add_6', 'add_5']], # arbitray node order
|
|
[['add_4', 'add_1', 'add_3', 'add_2']], # arbitray node order
|
|
[['add_5', 'add_6'], ['add_1', 'add_2', 'add_3', 'add_4']], # arbitray partition order
|
|
[['add_5', 'linear2']], # includes call_function + call_module node
|
|
[['add_6', 'relu']], # includes call_function + call_module node
|
|
[['param', 'add_2']], # includes get_attr + call_module nodes
|
|
[['param', 'add_1', 'linear']], # includes get_attr + call_function + call_module nodes
|
|
[["add", "linear", "add_1", "param", "add_2", "add_3", "add_4", "linear2", "add_5", "add_6", "relu"]], # full graph
|
|
])
|
|
def test_fuser_util(self, partition):
|
|
m = TestModule()
|
|
gm = symbolic_trace(m)
|
|
|
|
nodes_by_name = {node.name : node for node in gm.graph.nodes}
|
|
|
|
partitions = []
|
|
for node_names in partition:
|
|
partitions.append([nodes_by_name[name] for name in node_names])
|
|
|
|
fused_graph = fuse_by_partitions(gm, partitions)
|
|
|
|
a, b, c = torch.rand(4), torch.rand(4), torch.rand(4)
|
|
|
|
expected = m(a, b, c)
|
|
result = fused_graph(a, b, c)
|
|
|
|
torch.testing.assert_close(expected, result)
|
|
|
|
@parametrize("partition", [
|
|
[['add', 'add_1'], ['add_1', 'add_5', 'add_6']], # add_1 exists in multiple partitions
|
|
[['add', 'add_1', 'add_3']], # invalid partition: circular dependency
|
|
[['add_4', 'add_5']], # invalid partition: circular dependency
|
|
[['relu', 'add_5']], # invalid partition: circular dependency
|
|
])
|
|
def test_fuser_util_xfail(self, partition):
|
|
m = TestModule()
|
|
gm = symbolic_trace(m)
|
|
|
|
nodes_by_name = {node.name : node for node in gm.graph.nodes}
|
|
|
|
partitions = []
|
|
for node_names in partition:
|
|
partitions.append([nodes_by_name[name] for name in node_names])
|
|
|
|
with self.assertRaises(Exception):
|
|
fuse_by_partitions(gm, partitions)
|
|
|
|
instantiate_parametrized_tests(TestFXGraphPasses)

if __name__ == "__main__":
    run_tests()