From 274f4cfacb878b56e1dfefbc6ff229190ee557e5 Mon Sep 17 00:00:00 2001 From: Zejun Huang Date: Wed, 13 Nov 2024 22:56:38 +0000 Subject: [PATCH] [3/x][fx minimizer] Support all_outputs in minimizer (#139774) Summary: output nodes may be eliminated to the input nodes if only partial output nodes are specified. add option to check results for all output nodes in the partitioned graph Test Plan: see D65367305 Reviewed By: qcyuan Differential Revision: D65367305 Pull Request resolved: https://github.com/pytorch/pytorch/pull/139774 Approved by: https://github.com/jfix71 --- torch/fx/passes/net_min_base.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/torch/fx/passes/net_min_base.py b/torch/fx/passes/net_min_base.py index 81f8a845e83..c349c896ac3 100644 --- a/torch/fx/passes/net_min_base.py +++ b/torch/fx/passes/net_min_base.py @@ -66,12 +66,16 @@ class _MinimizerSettingBase: `return_intermediate`: If true, when using `run_nodes()` function to run the model, intermediate results of all the ops will be returned as output. + + `all_outputs`: If true, when using `_run_and_compare()` function, + all the output nodes in the subgraph will be used for comparison. """ accumulate_error: bool = False traverse_method: str = "sequential" find_all: bool = False return_intermediate: bool = False + all_outputs: bool = False def __str__(self): settings_str = "FX Minimizer Settings:\n" @@ -341,7 +345,7 @@ class _MinimizerBase: report = self.reports[report_idx if report_idx >= 0 else self.iteration - 1] report.append("Run and compare ...") - if output_names: + if output_names and not self.settings.all_outputs: output_nodes: NodeList = [] for node in submodule.graph.nodes: if node.op == "output": @@ -381,19 +385,23 @@ class _MinimizerBase: self.results[result_key] = numeric_result # type: ignore[possibly-undefined] report.append(f"Numerical accuracy = {numeric_result}") if not bool_result: - report.append(f"Result mismatch for {result_key}") + report.append(f"Result mismatch for {result_key}") # type: ignore[possibly-undefined] if self.module_exporter: + if isinstance(result_key, tuple): # type: ignore[possibly-undefined] + result_key = result_key[-1] + # pyre-ignore[29]: not a function self.module_exporter( a_input, submodule, str(result_key[0]) + "_cpu", # type: ignore[index] ) + # pyre-ignore[29]: not a function self.module_exporter( b_input, submodule, str(result_key[0]) + "_acc", # type: ignore[index] ) - raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}") + raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}") # type: ignore[possibly-undefined] def _binary_search_impl( self, all_nodes: NodeList, start_idx: int, end_idx: int