pytorch/torch/_dynamo/debug_utils.py
Edward Z. Yang 0a479d9b9c Simplify minifier testing by incorporating fault injection in prod code (#100357)
Previously, minifier testing injected faults by injecting extra code
into the repro scripts, and then ensuring this code got propagated to
all subsequent subprocess calls.  This was not only quite complicated,
but also induced a big slowdown on the minifier, because to inject the
faults, you had to import torch._inductor, which would cause the
compilation threads to immediately get initialized before you even got
to do anything else in the repro script.

This new approach fixes this problem by incorporating the fault
injection into "prod" code.  Essentially, for inductor fault injection
we introduce some new config flags that let you "configure" Inductor to
be buggy; for Dynamo fault injection we just permanently keep the buggy
testing backends registered.  This is MUCH simpler: we only have to
propagate the buggy config (which is something we're already doing),
and it saves the minifier scripts from having to immediately initialize
inductor on entry.

Also, I enable the test for Triton runtime errors, now that `tl.device_assert` is available.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100357
Approved by: https://github.com/voznesenskym
2023-05-02 11:44:06 +00:00

414 lines
13 KiB
Python

import copy
import functools
import getpass
import logging
import os
import subprocess
import tempfile
import textwrap
from collections import Counter
from importlib import import_module
import torch
from torch._prims_common import is_float_dtype
from . import config
from .utils import clone_inputs, get_debug_dir
log = logging.getLogger(__name__)

# The inductor config module is looked up by name; its is_fbcode() result
# decides whether repro tooling targets buck instead of plain python.
inductor_config = import_module("torch._inductor.config")
use_buck = inductor_config.is_fbcode()

# Defaults used outside buck: no extra native libraries and no extra
# import lines.
extra_deps = []
extra_imports = ""

if use_buck:
    import libfb.py.build_info  # type: ignore[import]

    # Native op libraries that are loaded via torch.ops.load_library in the
    # generated import lines and listed as cpp_deps of generated targets.
    extra_deps = [
        "//caffe2/torch/fb/sparsenn:sparsenn_operators_gpu",
        "//caffe2/torch/fb/sparsenn:sparsenn_operators",
        "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu",
        "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops",
    ]
    # Build rule of the running binary with its "fbcode:" prefix rewritten
    # to "//". Only defined under buck.
    cur_target = libfb.py.build_info.BuildInfo.get_build_rule().replace(
        "fbcode:", "//"
    )
    extra_imports = "\n".join(
        f'torch.ops.load_library("{dep}")' for dep in extra_deps
    )

BUCK_CMD_PREFIX = ["buck2", "run", "@mode/dev-nosan"]
class BuckTargetWriter:
    """Writes a buck ``TARGETS`` file that wraps a generated repro script.

    The script at ``filename`` becomes a ``python_binary`` target in the same
    directory, runnable with ``BUCK_CMD_PREFIX`` plus the target label.
    """

    def __init__(self, filename):
        self.subdir, self.py_file = os.path.split(os.path.abspath(filename))
        # Strip only a trailing ".py". str.replace would also remove ".py"
        # occurring in the middle of the file name.
        if self.py_file.endswith(".py"):
            self.target = self.py_file[: -len(".py")]
        else:
            self.target = self.py_file
        # Dotted module path relative to the fbcode root, e.g.
        # ".../fbcode/caffe2/test/repro.py" -> "caffe2.test.repro".
        dotted = f'{self.subdir.replace("/", ".")}.{self.target}'
        self.path = self._strip_through(dotted, "fbcode.")
        # Buck label relative to the fbcode root, e.g. "//caffe2/test:repro".
        rel_dir = self._strip_through(self.subdir, "fbcode/")
        self.cmd_line_path = f"//{rel_dir}:{self.target}"

    @staticmethod
    def _strip_through(text, marker):
        """Return the part of ``text`` after the first ``marker``.

        If ``marker`` does not occur, ``text`` is returned unchanged rather than
        slicing on ``str.find``'s ``-1`` result (which produced garbage).
        """
        idx = text.find(marker)
        if idx == -1:
            return text
        return text[idx + len(marker) :]

    def build(self):
        """Return the contents of the ``TARGETS`` file for this repro script.

        Reads the module-level ``extra_deps`` and ``cur_target``. The latter is
        only defined when ``use_buck`` is true.
        """
        extra_cpp_deps = "\n".join(f'        "{dep}",' for dep in extra_deps)
        return textwrap.dedent(
            f"""
load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")
python_binary(
    name="{self.target}",
    srcs = ["{self.py_file}"],
    compile = False,
    deps = [
        "//caffe2:torch",
        "//caffe2/functorch:functorch",
        "//triton:triton",
        "{cur_target}",
    ],
    cpp_deps = [
{extra_cpp_deps}
    ],
    main_module = "{self.path}",
)
"""
        )

    def write(self, print_msg=True):
        """Write ``TARGETS`` next to the script and return the run command.

        Any existing ``TARGETS`` file in that directory is overwritten.

        Args:
            print_msg: if True, log a warning containing the command line.

        Returns:
            The command as a list of argv strings
            (``BUCK_CMD_PREFIX + [self.cmd_line_path]``).
        """
        target_file = os.path.join(self.subdir, "TARGETS")
        with open(target_file, "w") as fd:
            fd.write(self.build())
        log.debug("Wrote isolation TARGETS file at %s", target_file)
        cmd_split = BUCK_CMD_PREFIX + [self.cmd_line_path]
        if print_msg:
            log.warning(
                "Found an example that reproduces the error. Run this cmd to repro - %s",
                " ".join(cmd_split),
            )
        return cmd_split
def minifier_dir():
    """Return the directory for minifier artifacts, creating it if needed.

    Uses ``<debug dir>/minifier`` when ``get_debug_dir()`` returns a path. If
    it returns ``None``, falls back to ``<tempdir>/minifier_<user>``.
    """
    debug_dir = get_debug_dir()
    # Test the debug dir itself: os.path.join never returns None (it raises
    # TypeError on a None component), so checking its result can't work.
    if debug_dir is None:
        path = os.path.join(tempfile.gettempdir(), f"minifier_{getpass.getuser()}")
    else:
        path = os.path.join(debug_dir, "minifier")
    os.makedirs(path, exist_ok=True)
    return path
class NNModuleToString:
    """Renders an fx ``GraphModule`` as source code for a ``Repro`` module."""

    # Child module types whose repr has been checked; other types only
    # trigger a warning in can_convert_to_string.
    safe_reprs = [
        torch.nn.Linear,
        torch.nn.Conv1d,
        torch.nn.Conv2d,
        torch.nn.Conv3d,
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.LayerNorm,
        torch.nn.Dropout,
        torch.nn.Softmax,
        torch.nn.ReLU,
        torch.nn.GELU,
        torch.nn.Identity,
        torch.nn.MaxPool2d,
        torch.nn.Embedding,
        torch.nn.Tanh,
        torch.nn.ConvTranspose1d,
        torch.nn.GLU,
        torch.nn.LSTM,
        torch.nn.Flatten,
        torch.nn.AdaptiveAvgPool2d,
    ]

    @staticmethod
    def can_convert_to_string(gm):
        """Return True; log a warning for child modules outside ``safe_reprs``."""
        untested = {
            child
            for _, child in gm.named_children()
            if type(child) not in NNModuleToString.safe_reprs
        }
        if untested:
            log.warning("We have not tested reprs of some modules - %s", untested)
        # TODO - Assuming that all modules can be safely repr'd. Check if that assumption is correct.
        return True

    @staticmethod
    def convert(gm):
        """Return Python source defining ``class Repro`` that mirrors ``gm``.

        Child modules are emitted through their repr, with ``.cuda()``
        appended when their first parameter is on CUDA. ``gm``'s own buffers
        and parameters are recreated with random values of the same shape and
        dtype. ``gm.code`` is then appended, indented by 4 spaces.
        """
        from torch.nn.modules.module import _addindent

        init_indent = " " * 4 * 2  # statement level inside Repro.__init__
        pieces = [
            textwrap.dedent(
                """
                from torch.nn import *
                class Repro(torch.nn.Module):
                    def __init__(self):
                        super().__init__()
                """
            )
        ]

        for child_name, child in gm.named_children():
            # Children are expected to be core torch.nn modules, so the first
            # parameter's device stands in for the whole module.
            first_param = next(child.parameters(), None)
            on_cuda = first_param is not None and first_param.is_cuda
            suffix = ".cuda()" if on_cuda else ""
            pieces.append(f"{init_indent}self.{child_name} = {child!r}{suffix}\n")

        for buf_name, buf in gm._buffers.items():
            if buf is None:
                continue
            shape = list(buf.shape)
            if torch.is_floating_point(buf):
                ctor = f"torch.randn({shape}, dtype={buf.dtype})"
            else:
                ctor = f"torch.randint(1, size={shape}, dtype={buf.dtype})"
            if buf.is_cuda:
                ctor += ".cuda()"
            pieces.append(f"{init_indent}self.register_buffer('{buf_name}', {ctor})\n")

        for p_name, p in gm._parameters.items():
            if p is None:
                continue
            ctor = f"torch.nn.Parameter(torch.randn({list(p.shape)}, dtype={p.dtype}))"
            if p.is_cuda:
                ctor += ".cuda()"
            pieces.append(f"{init_indent}self.{p_name} = {ctor}\n")

        pieces.append(f"{_addindent(gm.code, 4)}\n")
        return "".join(pieces)
@functools.lru_cache(None) # subprocess is expensive
def _cuda_system_info_comment():
if not torch.cuda.is_available():
return "# torch.cuda.is_available()==False, no GPU info collected\n"
model_str = "# CUDA Info: \n"
try:
cuda_version_out = subprocess.run(["nvcc", "--version"], stdout=subprocess.PIPE)
cuda_version_lines = cuda_version_out.stdout.decode().split("\n")
comment = "".join([f"# {s} \n" for s in cuda_version_lines if s not in [""]])
model_str += f"{comment}\n"
except FileNotFoundError:
model_str += "# nvcc not found\n"
gpu_names = Counter(
torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())
)
model_str += "# GPU Hardware Info: \n"
for name, count in gpu_names.items():
model_str += f"# {name} : {count} \n"
model_str += "\n"
return model_str
def generate_config_string(*, stable_output=False):
import torch._functorch.config
import torch._inductor.config
if stable_output:
return "# config omitted due to stable_output=True"
return f"""\
import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
{torch._dynamo.config.codegen_config()}
{torch._inductor.config.codegen_config()}
{torch._functorch.config.codegen_config()}
"""
def get_minifier_repro_path():
    """Return the path of ``minifier_launcher.py`` inside ``minifier_dir()``."""
    launcher_dir = minifier_dir()
    return os.path.join(launcher_dir, "minifier_launcher.py")
def helper_for_dump_minify(contents):
    """Write ``contents`` to the minifier launcher script path.

    Under buck (``use_buck``), a ``TARGETS`` file for the script is written
    first via ``BuckTargetWriter``.

    Raises:
        NotImplementedError: if the script cannot be written. The type is kept
            for existing callers; the underlying ``OSError`` is chained.
    """
    minified_repro_path = get_minifier_repro_path()
    log.warning("Writing minified repro to %s", minified_repro_path)
    if use_buck:
        BuckTargetWriter(minified_repro_path).write()
    try:
        with open(minified_repro_path, "w") as fd:
            fd.write(contents)
    except OSError as e:
        # log.exception attaches the traceback; give it a real message.
        log.exception("Failed to write minified repro to %s", minified_repro_path)
        raise NotImplementedError(f"Could not write to {minified_repro_path}") from e
class AccuracyError(Exception):
    """Exception type for accuracy failures.

    NOTE(review): no raise site is visible in this section. The name suggests
    it flags compiled-vs-reference result mismatches; confirm at call sites.
    """
def clone_inputs_retaining_gradness(example_inputs):
    """Clone inputs and copy each tensor's ``requires_grad`` onto its clone.

    Unlike ``utils.clone_input``, leafness is not checked: in the minifier all
    tensors become leaves of a freshly built graph. So ``requires_grad`` is
    set unconditionally from the corresponding original input.
    """
    cloned_inputs = clone_inputs(example_inputs)
    for i in range(len(example_inputs)):
        candidate = cloned_inputs[i]
        if not isinstance(candidate, torch.Tensor):
            continue
        candidate.requires_grad_(example_inputs[i].requires_grad)
    return cloned_inputs
def run_fwd_maybe_bwd(gm, args, only_fwd=False):
    """Run a forward pass, plus a backward pass when the outputs need one.

    Works on a deep copy of ``gm`` and a gradness-preserving clone of
    ``args``. The callable is invoked with the whole argument list as one
    argument.

    Args:
        gm: module or compiled callable to run.
        args: example inputs.
        only_fwd: if True, return the forward output without running backward.

    Returns:
        The forward output when ``only_fwd`` is set. Otherwise the result of
        ``collect_results`` over the copied callable, its output and the
        cloned inputs.
    """
    from torch._functorch.aot_autograd import make_boxed_func

    from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass

    model = copy.deepcopy(gm)
    inputs = clone_inputs_retaining_gradness(args)
    if hasattr(model, "zero_grad"):
        model.zero_grad(True)

    # TorchInductor's returned callable expects a list, so module-like
    # callables that aren't already boxed are wrapped to take one too.
    params_fn = getattr(model, "named_parameters", None)
    buffers_fn = getattr(model, "named_buffers", None)
    already_boxed = hasattr(model, "_boxed_call")
    looks_like_module = params_fn is not None or buffers_fn is not None
    if not already_boxed and looks_like_module:
        model = make_boxed_func(model)

    # Re-attach the accessors to the (possibly wrapped) callable. Presumably
    # this lets collect_results reach parameters/buffers; verify in .testing.
    if params_fn is not None:
        model.named_parameters = params_fn
    if buffers_fn is not None:
        model.named_buffers = buffers_fn

    out = model(inputs)
    if only_fwd:
        return out
    if requires_bwd_pass(out):
        reduce_to_scalar_loss(out).backward()
    return collect_results(model, out, None, inputs)
def same_two_models(gm, opt_gm, example_inputs, only_fwd=False):
    """Check whether ``gm`` and ``opt_gm`` produce matching results.

    Both are run forward (and backward unless ``only_fwd``) via
    ``run_fwd_maybe_bwd``. An fp64 copy of ``gm`` is also run as a
    higher-precision reference for ``same``; if that fails, it is skipped with
    a warning.

    Note: ``OptimizedModule`` arguments are modified in place, because their
    ``named_parameters``/``named_buffers`` attributes are replaced.

    Returns:
        True if the results agree within ``config.repro_tolerance``. Also True
        if running ``opt_gm`` raises, since that failure is unrelated to
        accuracy and the graph is skipped.
    """
    from .eval_frame import OptimizedModule
    from .testing import (
        named_buffers_for_optimized_module,
        named_parameters_for_optimized_module,
    )
    from .utils import same

    for model in (gm, opt_gm):
        if isinstance(model, OptimizedModule):
            model.named_parameters = named_parameters_for_optimized_module(model)
            model.named_buffers = named_buffers_for_optimized_module(model)

    ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd)

    # The fp64 reference is best-effort: some graphs can't be cast.
    try:
        fp64_model, fp64_examples = cast_to_fp64(
            copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs)
        )
        fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_examples, only_fwd)
    except Exception:
        log.warning("Could not generate fp64 outputs")
        fp64_ref = None

    try:
        res = run_fwd_maybe_bwd(opt_gm, example_inputs, only_fwd)
    except Exception:
        # This means that the minified graph is bad/exposes a different problem.
        # As we are checking accuracy here, let's log the exception and return True.
        log.exception(
            "While minifying the program in accuracy minification mode, "
            "ran into a runtime exception which is likely an unrelated issue."
            " Skipping this graph."
        )
        return True

    return same(ref, res, fp64_ref, tol=config.repro_tolerance, equal_nan=True)
def cast_convert_element_type_to_fp64(model):
    """Retarget float ``prims.convert_element_type`` casts in ``model`` to fp64.

    Casts whose target is a floating dtype other than float64 are rewritten to
    float64. Non-float casts are untouched. The graph is linted, the module
    is recompiled in place, and ``model`` is returned.
    """
    convert_op = torch.ops.prims.convert_element_type.default
    for node in model.graph.nodes:
        if node.op != "call_function" or node.target != convert_op:
            continue
        assert len(node.args) == 2
        src, target_dtype = node.args
        if is_float_dtype(target_dtype) and target_dtype != torch.float64:
            node.args = (src, torch.float64)
    model.graph.lint()
    model.recompile()
    return model
def cast_to(dtype, model, inputs):
    """Cast ``model`` and the floating-point tensors within ``inputs`` to ``dtype``.

    ``inputs`` may be any pytree. Non-tensors and non-floating tensors pass
    through unchanged. For float64, the graph's explicit
    ``convert_element_type`` casts are retargeted too.

    Returns:
        ``(model, inputs)`` after casting.
    """
    from torch.utils._pytree import tree_map

    def _cast_leaf(leaf):
        if isinstance(leaf, torch.Tensor) and leaf.is_floating_point():
            return leaf.to(dtype)
        return leaf

    model = model.to(dtype)
    if dtype == torch.float64:
        # Explicit in-graph casts would otherwise drop values back to a
        # lower precision during the fp64 accuracy reference run.
        model = cast_convert_element_type_to_fp64(model)
    return model, tree_map(_cast_leaf, inputs)
def cast_to_fp64(model, inputs):
    """Shorthand for ``cast_to(torch.float64, model, inputs)``."""
    target_dtype = torch.float64
    return cast_to(target_dtype, model, inputs)
def backend_accuracy_fails(gm, example_inputs, compiler_fn, only_fwd=False):
    """Return True if compiling ``gm`` with ``compiler_fn`` changes its results.

    ``compiler_fn`` receives a deep copy of ``gm`` and cloned inputs. If
    compilation raises, the exception is logged and False is returned. This
    check is about accuracy, so a crash is treated as a different issue.
    """
    try:
        compiled_gm = compiler_fn(
            copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs)
        )
    except Exception:
        # This means that the minified graph is bad/exposes a different problem.
        # As we are checking accuracy here, let's log the exception and return False.
        log.exception(
            "While minifying the program in accuracy minification mode, "
            "ran into a runtime exception which is likely an unrelated issue."
            " Skipping this graph"
        )
        return False
    return not same_two_models(gm, compiled_gm, example_inputs, only_fwd)