pytorch/test/fx/test_net_min_base.py
Autin Mitra 5623d30228 [Minimizer] Gracefully exit when there is no discrepancy in block mode (#154076)
Summary:
Previously, when there was no discrepancy in results in block mode, net_min_base would throw an out-of-bounds (OOB) error.

This occurred because _block_traverse_impl returned an out-of-bounds index after exhausting subgraphs all the way down to a single node.

There was also an issue where we could get an unsound subgraph (i.e., mark an earlier node as the "end" even if the correct end is later). This was due to an incorrect check (start_idx == mid), which could make the program return prematurely while two candidate values were still left.

Test Plan:
Buck UI: https://www.internalfb.com/buck2/52524c26-ace5-4593-8a4b-843a54eb206a
Test UI: https://www.internalfb.com/intern/testinfra/testrun/3096224973363310
Network: Up: 0B  Down: 15MiB  (reSessionID-cd404e97-395f-49fc-8381-373e90a1378f)
Executing actions. Remaining     0/1
Command: test.
Time elapsed: 53.7s
Tests finished: Pass 7. Fail 0. Fatal 0. Skip 0. Build failure 0

Differential Revision: D75143242

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154076
Approved by: https://github.com/jfix71
2025-05-23 06:42:07 +00:00

103 lines
3.8 KiB
Python

# Owner(s): ["module: fx"]
from unittest import mock
import torch
from torch.fx.passes.net_min_base import (
_MinimizerBase,
_MinimizerSettingBase,
FxNetMinimizerResultMismatchError,
)
from torch.fx.passes.tools_common import Names
from torch.testing._internal.common_utils import TestCase
class TestNetMinBaseBlock(TestCase):
def setUp(self) -> None:
# Setup test fixtures for each test method
class SimpleModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 5)
self.linear2 = torch.nn.Linear(5, 5)
self.relu = torch.nn.ReLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear(x)
x = self.linear2(x)
x = self.relu(x)
return x
self.compare_fn = mock.MagicMock()
self.module = torch.fx.symbolic_trace(SimpleModule())
self.sample_input = (torch.randn(2, 10),)
self.settings = _MinimizerSettingBase(traverse_method="block")
self.minimizer = _MinimizerBase(
module=self.module,
sample_input=self.sample_input,
settings=self.settings,
compare_fn=self.compare_fn,
)
self.report = []
def assert_problematic_nodes(self, culprit_names: Names) -> None:
"""
Quick helper function to assert that a set of nodes (when present together in a subgraph) cause a discrepancy
"""
with mock.patch("torch.fx.passes.net_min_base._MinimizerBase._run_and_compare"):
def run_and_compare_side_effect(
split_module: torch.fx.GraphModule,
submod_name: str,
output_names: Names,
report_idx: int = -1,
) -> None:
submodule = getattr(split_module, submod_name)
# Remove input/output layer
names = set([node.name for node in submodule.graph.nodes][1:-1])
if set(culprit_names) <= names:
raise FxNetMinimizerResultMismatchError
self.minimizer._run_and_compare.side_effect = run_and_compare_side_effect
# Every single node should be a discrepancy
culprits = self.minimizer.minimize()
self.assertEqual({node.name for node in culprits}, set(culprit_names))
def test_no_discrepancy(self) -> None:
# No discrepancies should handle gracefully with an empty set
with (
mock.patch("torch.fx.passes.net_min_base._MinimizerBase.run_a"),
mock.patch("torch.fx.passes.net_min_base._MinimizerBase.run_b"),
):
# Have both run_a and run_b return the same result
return_value = torch.zeros((2, 5))
self.minimizer.run_a.return_value = return_value
self.minimizer.run_b.return_value = return_value
self.compare_fn.return_value = (0, True)
# There should be no discrepancy between the two, and thus we should receive an empty set
culprits = self.minimizer.minimize()
self.assertEqual(culprits, set())
def test_all_nodes_discrepancy(self) -> None:
self.assert_problematic_nodes(["linear", "linear2", "relu"])
def test_first_node_discrepancy(self) -> None:
self.assert_problematic_nodes(["linear"])
def test_last_node_discrepancy(self) -> None:
self.assert_problematic_nodes(["relu"])
def test_middle_node_discrepancy(self) -> None:
self.assert_problematic_nodes(["linear2"])
def test_contiguous_partial_discrepancy_end(self) -> None:
self.assert_problematic_nodes(["linear2", "relu"])
def test_continugous_partial_discrepancy_beginning(self) -> None:
self.assert_problematic_nodes(["linear", "linear2"])