[3/x][fx minimizer] Support all_outputs in minimizer (#139774)

Summary: output nodes may be eliminated to the input nodes if only partial output nodes are specified. add option to check results for all output nodes in the partitioned graph

Test Plan: see D65367305

Reviewed By: qcyuan

Differential Revision: D65367305

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139774
Approved by: https://github.com/jfix71
This commit is contained in:
Zejun Huang 2024-11-13 22:56:38 +00:00 committed by PyTorch MergeBot
parent 26fde110db
commit 274f4cfacb

View File

@ -66,12 +66,16 @@ class _MinimizerSettingBase:
`return_intermediate`: If true, when using `run_nodes()` function to run the
model, intermediate results of all the ops will be returned as output.
`all_outputs`: If true, when using `_run_and_compare()` function,
all the output nodes in the subgraph will be used for comparison.
"""
accumulate_error: bool = False
traverse_method: str = "sequential"
find_all: bool = False
return_intermediate: bool = False
all_outputs: bool = False
def __str__(self):
settings_str = "FX Minimizer Settings:\n"
@ -341,7 +345,7 @@ class _MinimizerBase:
report = self.reports[report_idx if report_idx >= 0 else self.iteration - 1]
report.append("Run and compare ...")
if output_names:
if output_names and not self.settings.all_outputs:
output_nodes: NodeList = []
for node in submodule.graph.nodes:
if node.op == "output":
@ -381,19 +385,23 @@ class _MinimizerBase:
self.results[result_key] = numeric_result # type: ignore[possibly-undefined]
report.append(f"Numerical accuracy = {numeric_result}")
if not bool_result:
report.append(f"Result mismatch for {result_key}")
report.append(f"Result mismatch for {result_key}") # type: ignore[possibly-undefined]
if self.module_exporter:
if isinstance(result_key, tuple): # type: ignore[possibly-undefined]
result_key = result_key[-1]
# pyre-ignore[29]: not a function
self.module_exporter(
a_input,
submodule,
str(result_key[0]) + "_cpu", # type: ignore[index]
)
# pyre-ignore[29]: not a function
self.module_exporter(
b_input,
submodule,
str(result_key[0]) + "_acc", # type: ignore[index]
)
raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}")
raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}") # type: ignore[possibly-undefined]
def _binary_search_impl(
self, all_nodes: NodeList, start_idx: int, end_idx: int