diff --git a/torch/fx/passes/net_min_base.py b/torch/fx/passes/net_min_base.py index 81f8a845e83..c349c896ac3 100644 --- a/torch/fx/passes/net_min_base.py +++ b/torch/fx/passes/net_min_base.py @@ -66,12 +66,16 @@ class _MinimizerSettingBase: `return_intermediate`: If true, when using `run_nodes()` function to run the model, intermediate results of all the ops will be returned as output. + + `all_outputs`: If true, when using `_run_and_compare()` function, + all the output nodes in the subgraph will be used for comparison. """ accumulate_error: bool = False traverse_method: str = "sequential" find_all: bool = False return_intermediate: bool = False + all_outputs: bool = False def __str__(self): settings_str = "FX Minimizer Settings:\n" @@ -341,7 +345,7 @@ class _MinimizerBase: report = self.reports[report_idx if report_idx >= 0 else self.iteration - 1] report.append("Run and compare ...") - if output_names: + if output_names and not self.settings.all_outputs: output_nodes: NodeList = [] for node in submodule.graph.nodes: if node.op == "output": @@ -381,19 +385,23 @@ class _MinimizerBase: self.results[result_key] = numeric_result # type: ignore[possibly-undefined] report.append(f"Numerical accuracy = {numeric_result}") if not bool_result: - report.append(f"Result mismatch for {result_key}") + report.append(f"Result mismatch for {result_key}") # type: ignore[possibly-undefined] if self.module_exporter: + if isinstance(result_key, tuple): # type: ignore[possibly-undefined] + result_key = result_key[-1] + # pyre-ignore[29]: not a function self.module_exporter( a_input, submodule, str(result_key[0]) + "_cpu", # type: ignore[index] ) + # pyre-ignore[29]: not a function self.module_exporter( b_input, submodule, str(result_key[0]) + "_acc", # type: ignore[index] ) - raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}") + raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}") # type: ignore[possibly-undefined] def _binary_search_impl( self, all_nodes: NodeList, start_idx: int, end_idx: int