pytorch/test/fx/test_partitioner_order.py
Anthony Barbier c8d44a2296 Add __main__ guards to fx tests (#154715)
This PR is part of a series attempting to re-submit #134592 as smaller PRs.

In fx tests:

- Add and use a common raise_on_run_directly method, called when a user directly runs a test file that should not be run that way; it prints the file the user should have run instead.
- Raise a RuntimeError on tests which have been disabled (not run)
- Remove any remaining uses of "unittest.main()"

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154715
Approved by: https://github.com/Skylion007
2025-06-04 14:38:50 +00:00

55 lines
1.7 KiB
Python

# Owner(s): ["module: fx"]
from collections.abc import Mapping
import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport
from torch.testing._internal.common_utils import TestCase
class DummyDevOperatorSupport(OperatorSupport):
    """Operator-support policy that reports every node as supported."""

    def is_node_supported(
        self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
    ) -> bool:
        # Neither ``submodules`` nor ``node`` is inspected: the answer is fixed.
        supported = True
        return supported
class DummyPartitioner(CapabilityBasedPartitioner):
    """CapabilityBasedPartitioner wired to DummyDevOperatorSupport, with
    single-node partitions permitted."""

    def __init__(self, graph_module: torch.fx.GraphModule):
        # Build the support policy up front, then hand everything to the base.
        support_policy = DummyDevOperatorSupport()
        super().__init__(
            graph_module,
            support_policy,
            allows_single_node_partition=True,
        )
class AddModule(torch.nn.Module):
    """Tiny module computing ``(x + x) + x`` via two ``torch.add`` calls."""

    def forward(self, x):
        doubled = torch.add(x, x)
        return torch.add(doubled, x)
class TestPartitionerOrder(TestCase):
    """Check that partitioning a freshly traced graph yields a stable node order."""

    # How many additional trace-and-partition rounds are compared to the first.
    _NUM_REPEATS = 10

    @staticmethod
    def _first_partition_node_names(module: torch.nn.Module) -> list[str]:
        """Trace ``module``, partition it with DummyPartitioner, and return the
        names of the nodes in the first proposed partition, in iteration order.

        Raises IndexError if the partitioner proposes no partitions.
        """
        traced = torch.fx.symbolic_trace(module)
        partitions = DummyPartitioner(traced).propose_partitions()
        # Only the first partition is compared, so avoid materialising the rest.
        return [node.name for node in partitions[0].nodes]

    # partitioner test to check graph node order
    def test_partitioner_order(self):
        m = AddModule()
        node_order = self._first_partition_node_names(m)
        for _ in range(self._NUM_REPEATS):
            # Re-trace each round so every comparison uses a brand-new graph.
            # assertEqual (unlike assertTrue(a == b)) shows both lists on failure.
            self.assertEqual(node_order, self._first_partition_node_names(m))
if __name__ == "__main__":
    # Running this file directly is refused; the message says the test must
    # first be enabled in discover_tests.py.
    message = (
        "This test is not currently used and should be "
        "enabled in discover_tests.py if required."
    )
    raise RuntimeError(message)