mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Improve minifier printing to be more chatty when it makes sense (#100486)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/100486 Approved by: https://github.com/voznesenskym
This commit is contained in:
parent
c7e9f40653
commit
c2556c034d
|
|
@ -24,7 +24,7 @@ def inner(x):
|
|||
x = torch.cos(x)
|
||||
return x
|
||||
|
||||
inner(torch.randn(20, 20).to("{device}"))
|
||||
inner(torch.randn(2, 2).to("{device}"))
|
||||
"""
|
||||
# These must isolate because they crash the process
|
||||
self._run_full_test(run_code, "aot", expected_error, isolate=True)
|
||||
|
|
|
|||
|
|
@ -559,14 +559,16 @@ def isolate_fails(
|
|||
)
|
||||
p.wait()
|
||||
|
||||
if p.returncode != 0:
|
||||
stdout.seek(0)
|
||||
stderr.seek(0)
|
||||
print(textwrap.indent(stdout.read().decode("utf-8"), prefix=">> "))
|
||||
print(textwrap.indent(stderr.read().decode("utf-8"), prefix=">> "))
|
||||
# print(f"Isolated test failed - {file_name}")
|
||||
return True
|
||||
return False
|
||||
stdout.seek(0)
|
||||
stderr.seek(0)
|
||||
print(
|
||||
textwrap.indent(stdout.read().decode("utf-8"), prefix=">> "), file=sys.stdout
|
||||
)
|
||||
print(
|
||||
textwrap.indent(stderr.read().decode("utf-8"), prefix=">> "), file=sys.stderr
|
||||
)
|
||||
# print(f"Isolated test failed - {file_name}")
|
||||
return p.returncode != 0
|
||||
|
||||
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import torch.fx as fx
|
|||
import copy
|
||||
import torch
|
||||
import math
|
||||
import sys
|
||||
from typing import Callable, List
|
||||
from functools import wraps, partial
|
||||
from dataclasses import dataclass
|
||||
|
|
@ -95,13 +96,17 @@ def minifier(fail_f: fx.GraphModule, inps, module_fails, dump_state: Callable =
|
|||
ConcreteProp(fail_f).propagate(*inps)
|
||||
if not graph_fails(failing_graph, inps):
|
||||
raise RuntimeError("Input graph did not fail the tester")
|
||||
print(f"Started off with {cur_size} nodes")
|
||||
print(f"Started off with {cur_size} nodes", file=sys.stderr)
|
||||
|
||||
def _register_strategy(strategy: Callable, name: str):
|
||||
@wraps(strategy)
|
||||
def new_func(old_state: ReproState, granularity=1):
|
||||
print()
|
||||
print(f"Strategy: {name} (G: {granularity}) ({len(old_state.graph.nodes)} nodes, {len(old_state.inps)} inputs)")
|
||||
print(file=sys.stderr)
|
||||
print(
|
||||
f"Strategy: {name} (G: {granularity}) "
|
||||
f"({len(old_state.graph.nodes)} nodes, {len(old_state.inps)} inputs)",
|
||||
file=sys.stderr
|
||||
)
|
||||
new_state = strategy(deepcopy_fx_graph(old_state.graph), list(old_state.inps), granularity)
|
||||
if new_state is not None:
|
||||
new_nodes = len(new_state.graph.nodes)
|
||||
|
|
@ -113,23 +118,23 @@ def minifier(fail_f: fx.GraphModule, inps, module_fails, dump_state: Callable =
|
|||
progress_made = False
|
||||
if new_nodes < old_nodes:
|
||||
progress_made = True
|
||||
print(f"SUCCESS: Went from {old_nodes} to {new_nodes} nodes")
|
||||
print(f"SUCCESS: Went from {old_nodes} to {new_nodes} nodes", file=sys.stderr)
|
||||
if new_inps > old_inps:
|
||||
progress_made = True
|
||||
print(f"SUCCESS: Went from {old_inps} to {new_inps} inputs")
|
||||
print(f"SUCCESS: Went from {old_inps} to {new_inps} inputs", file=sys.stderr)
|
||||
if new_outs < old_outs:
|
||||
progress_made = True
|
||||
print(f"SUCCESS: Went from {old_outs} to {new_outs} outputs")
|
||||
print(f"SUCCESS: Went from {old_outs} to {new_outs} outputs", file=sys.stderr)
|
||||
|
||||
if not progress_made:
|
||||
raise RuntimeError("Success raised but no progress made?")
|
||||
|
||||
if not graph_fails(new_state.graph, new_state.inps):
|
||||
print("WARNING: Something went wrong, not applying this minification")
|
||||
print("WARNING: Something went wrong, not applying this minification", file=sys.stderr)
|
||||
return None
|
||||
return new_state
|
||||
else:
|
||||
print(f"FAIL: {name}")
|
||||
print(f"FAIL: {name}", file=sys.stderr)
|
||||
return None
|
||||
|
||||
return new_func
|
||||
|
|
@ -251,7 +256,7 @@ def minifier(fail_f: fx.GraphModule, inps, module_fails, dump_state: Callable =
|
|||
failing_state = ReproState(failing_graph, inps)
|
||||
|
||||
def try_granularity(failing_state, granularity, use_non_granular):
|
||||
print(f"Trying granularity {granularity}")
|
||||
print(f"Trying granularity {granularity}", file=sys.stderr)
|
||||
|
||||
strategies = []
|
||||
num_nodes = len(failing_state.graph.nodes)
|
||||
|
|
@ -300,8 +305,8 @@ def minifier(fail_f: fx.GraphModule, inps, module_fails, dump_state: Callable =
|
|||
if not graph_fails(failing_state.graph, failing_state.inps):
|
||||
raise RuntimeError("Uh oh, something went wrong :( Final graph is not failing")
|
||||
|
||||
print(f"Made {num_queries} queries")
|
||||
print(f"Made {num_queries} queries", file=sys.stderr)
|
||||
failing_fx = fx.GraphModule(fail_f, failing_state.graph)
|
||||
dump_state(failing_fx, failing_state.inps)
|
||||
print("Wrote minimal repro out to repro.py")
|
||||
print("Wrote minimal repro out to repro.py", file=sys.stderr)
|
||||
return failing_fx, failing_state.inps
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user