Improve minifier printing to be more chatty when it makes sense (#100486)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100486
Approved by: https://github.com/voznesenskym
This commit is contained in:
Edward Z. Yang 2023-05-02 20:36:48 -07:00 committed by PyTorch MergeBot
parent c7e9f40653
commit c2556c034d
3 changed files with 27 additions and 20 deletions

View File

@ -24,7 +24,7 @@ def inner(x):
x = torch.cos(x)
return x
inner(torch.randn(20, 20).to("{device}"))
inner(torch.randn(2, 2).to("{device}"))
"""
# These must isolate because they crash the process
self._run_full_test(run_code, "aot", expected_error, isolate=True)

View File

@ -559,14 +559,16 @@ def isolate_fails(
)
p.wait()
if p.returncode != 0:
stdout.seek(0)
stderr.seek(0)
print(textwrap.indent(stdout.read().decode("utf-8"), prefix=">> "))
print(textwrap.indent(stderr.read().decode("utf-8"), prefix=">> "))
# print(f"Isolated test failed - {file_name}")
return True
return False
stdout.seek(0)
stderr.seek(0)
print(
textwrap.indent(stdout.read().decode("utf-8"), prefix=">> "), file=sys.stdout
)
print(
textwrap.indent(stderr.read().decode("utf-8"), prefix=">> "), file=sys.stderr
)
# print(f"Isolated test failed - {file_name}")
return p.returncode != 0
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

View File

@ -2,6 +2,7 @@ import torch.fx as fx
import copy
import torch
import math
import sys
from typing import Callable, List
from functools import wraps, partial
from dataclasses import dataclass
@ -95,13 +96,17 @@ def minifier(fail_f: fx.GraphModule, inps, module_fails, dump_state: Callable =
ConcreteProp(fail_f).propagate(*inps)
if not graph_fails(failing_graph, inps):
raise RuntimeError("Input graph did not fail the tester")
print(f"Started off with {cur_size} nodes")
print(f"Started off with {cur_size} nodes", file=sys.stderr)
def _register_strategy(strategy: Callable, name: str):
@wraps(strategy)
def new_func(old_state: ReproState, granularity=1):
print()
print(f"Strategy: {name} (G: {granularity}) ({len(old_state.graph.nodes)} nodes, {len(old_state.inps)} inputs)")
print(file=sys.stderr)
print(
f"Strategy: {name} (G: {granularity}) "
f"({len(old_state.graph.nodes)} nodes, {len(old_state.inps)} inputs)",
file=sys.stderr
)
new_state = strategy(deepcopy_fx_graph(old_state.graph), list(old_state.inps), granularity)
if new_state is not None:
new_nodes = len(new_state.graph.nodes)
@ -113,23 +118,23 @@ def minifier(fail_f: fx.GraphModule, inps, module_fails, dump_state: Callable =
progress_made = False
if new_nodes < old_nodes:
progress_made = True
print(f"SUCCESS: Went from {old_nodes} to {new_nodes} nodes")
print(f"SUCCESS: Went from {old_nodes} to {new_nodes} nodes", file=sys.stderr)
if new_inps > old_inps:
progress_made = True
print(f"SUCCESS: Went from {old_inps} to {new_inps} inputs")
print(f"SUCCESS: Went from {old_inps} to {new_inps} inputs", file=sys.stderr)
if new_outs < old_outs:
progress_made = True
print(f"SUCCESS: Went from {old_outs} to {new_outs} outputs")
print(f"SUCCESS: Went from {old_outs} to {new_outs} outputs", file=sys.stderr)
if not progress_made:
raise RuntimeError("Success raised but no progress made?")
if not graph_fails(new_state.graph, new_state.inps):
print("WARNING: Something went wrong, not applying this minification")
print("WARNING: Something went wrong, not applying this minification", file=sys.stderr)
return None
return new_state
else:
print(f"FAIL: {name}")
print(f"FAIL: {name}", file=sys.stderr)
return None
return new_func
@ -251,7 +256,7 @@ def minifier(fail_f: fx.GraphModule, inps, module_fails, dump_state: Callable =
failing_state = ReproState(failing_graph, inps)
def try_granularity(failing_state, granularity, use_non_granular):
print(f"Trying granularity {granularity}")
print(f"Trying granularity {granularity}", file=sys.stderr)
strategies = []
num_nodes = len(failing_state.graph.nodes)
@ -300,8 +305,8 @@ def minifier(fail_f: fx.GraphModule, inps, module_fails, dump_state: Callable =
if not graph_fails(failing_state.graph, failing_state.inps):
raise RuntimeError("Uh oh, something went wrong :( Final graph is not failing")
print(f"Made {num_queries} queries")
print(f"Made {num_queries} queries", file=sys.stderr)
failing_fx = fx.GraphModule(fail_f, failing_state.graph)
dump_state(failing_fx, failing_state.inps)
print("Wrote minimal repro out to repro.py")
print("Wrote minimal repro out to repro.py", file=sys.stderr)
return failing_fx, failing_state.inps