mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[3/x][fx minimizer] Support all_outputs in minimizer (#139774)
Summary: output nodes may be eliminated to the input nodes if only partial output nodes are specified. add option to check results for all output nodes in the partitioned graph Test Plan: see D65367305 Reviewed By: qcyuan Differential Revision: D65367305 Pull Request resolved: https://github.com/pytorch/pytorch/pull/139774 Approved by: https://github.com/jfix71
This commit is contained in:
parent
26fde110db
commit
274f4cfacb
|
|
@ -66,12 +66,16 @@ class _MinimizerSettingBase:
|
|||
|
||||
`return_intermediate`: If true, when using `run_nodes()` function to run the
|
||||
model, intermediate results of all the ops will be returned as output.
|
||||
|
||||
`all_outputs`: If true, when using `_run_and_compare()` function,
|
||||
all the output nodes in the subgraph will be used for comparison.
|
||||
"""
|
||||
|
||||
accumulate_error: bool = False
|
||||
traverse_method: str = "sequential"
|
||||
find_all: bool = False
|
||||
return_intermediate: bool = False
|
||||
all_outputs: bool = False
|
||||
|
||||
def __str__(self):
|
||||
settings_str = "FX Minimizer Settings:\n"
|
||||
|
|
@ -341,7 +345,7 @@ class _MinimizerBase:
|
|||
report = self.reports[report_idx if report_idx >= 0 else self.iteration - 1]
|
||||
report.append("Run and compare ...")
|
||||
|
||||
if output_names:
|
||||
if output_names and not self.settings.all_outputs:
|
||||
output_nodes: NodeList = []
|
||||
for node in submodule.graph.nodes:
|
||||
if node.op == "output":
|
||||
|
|
@ -381,19 +385,23 @@ class _MinimizerBase:
|
|||
self.results[result_key] = numeric_result # type: ignore[possibly-undefined]
|
||||
report.append(f"Numerical accuracy = {numeric_result}")
|
||||
if not bool_result:
|
||||
report.append(f"Result mismatch for {result_key}")
|
||||
report.append(f"Result mismatch for {result_key}") # type: ignore[possibly-undefined]
|
||||
if self.module_exporter:
|
||||
if isinstance(result_key, tuple): # type: ignore[possibly-undefined]
|
||||
result_key = result_key[-1]
|
||||
# pyre-ignore[29]: not a function
|
||||
self.module_exporter(
|
||||
a_input,
|
||||
submodule,
|
||||
str(result_key[0]) + "_cpu", # type: ignore[index]
|
||||
)
|
||||
# pyre-ignore[29]: not a function
|
||||
self.module_exporter(
|
||||
b_input,
|
||||
submodule,
|
||||
str(result_key[0]) + "_acc", # type: ignore[index]
|
||||
)
|
||||
raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}")
|
||||
raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}") # type: ignore[possibly-undefined]
|
||||
|
||||
def _binary_search_impl(
|
||||
self, all_nodes: NodeList, start_idx: int, end_idx: int
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user