# Owner(s): ["module: fx"] from unittest import mock import torch from torch.fx.passes.net_min_base import ( _MinimizerBase, _MinimizerSettingBase, FxNetMinimizerResultMismatchError, ) from torch.fx.passes.tools_common import Names from torch.testing._internal.common_utils import TestCase class TestNetMinBaseBlock(TestCase): def setUp(self) -> None: # Setup test fixtures for each test method class SimpleModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(10, 5) self.linear2 = torch.nn.Linear(5, 5) self.relu = torch.nn.ReLU() def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.linear(x) x = self.linear2(x) x = self.relu(x) return x self.compare_fn = mock.MagicMock() self.module = torch.fx.symbolic_trace(SimpleModule()) self.sample_input = (torch.randn(2, 10),) self.settings = _MinimizerSettingBase(traverse_method="block") self.minimizer = _MinimizerBase( module=self.module, sample_input=self.sample_input, settings=self.settings, compare_fn=self.compare_fn, ) self.report = [] def assert_problematic_nodes(self, culprit_names: Names) -> None: """ Quick helper function to assert that a set of nodes (when present together in a subgraph) cause a discrepancy """ with mock.patch("torch.fx.passes.net_min_base._MinimizerBase._run_and_compare"): def run_and_compare_side_effect( split_module: torch.fx.GraphModule, submod_name: str, output_names: Names, report_idx: int = -1, ) -> None: submodule = getattr(split_module, submod_name) # Remove input/output layer names = set([node.name for node in submodule.graph.nodes][1:-1]) if set(culprit_names) <= names: raise FxNetMinimizerResultMismatchError self.minimizer._run_and_compare.side_effect = run_and_compare_side_effect # Every single node should be a discrepancy culprits = self.minimizer.minimize() self.assertEqual({node.name for node in culprits}, set(culprit_names)) def test_no_discrepancy(self) -> None: # No discrepancies should handle gracefully with an empty set with ( mock.patch("torch.fx.passes.net_min_base._MinimizerBase.run_a"), mock.patch("torch.fx.passes.net_min_base._MinimizerBase.run_b"), ): # Have both run_a and run_b return the same result return_value = torch.zeros((2, 5)) self.minimizer.run_a.return_value = return_value self.minimizer.run_b.return_value = return_value self.compare_fn.return_value = (0, True) # There should be no discrepancy between the two, and thus we should receive an empty set culprits = self.minimizer.minimize() self.assertEqual(culprits, set()) def test_all_nodes_discrepancy(self) -> None: self.assert_problematic_nodes(["linear", "linear2", "relu"]) def test_first_node_discrepancy(self) -> None: self.assert_problematic_nodes(["linear"]) def test_last_node_discrepancy(self) -> None: self.assert_problematic_nodes(["relu"]) def test_middle_node_discrepancy(self) -> None: self.assert_problematic_nodes(["linear2"]) def test_contiguous_partial_discrepancy_end(self) -> None: self.assert_problematic_nodes(["linear2", "relu"]) def test_continugous_partial_discrepancy_beginning(self) -> None: self.assert_problematic_nodes(["linear", "linear2"]) if __name__ == "__main__": raise RuntimeError( "This test is not currently used and should be " "enabled in discover_tests.py if required." )