import copy import functools import getpass import logging import os import shutil import subprocess import textwrap import uuid from collections import Counter from importlib import import_module import torch import torch.fx as fx from . import config from .optimizations.backends import register_backend from .utils import clone_inputs log = logging.getLogger(__name__) def minifier_dir(): path = config.repro_dir if path is None: path = f"/tmp/minifier_{getpass.getuser()}" if not os.path.exists(path): os.makedirs(path, exist_ok=True) return path class NNModuleToString: safe_reprs = [ torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.LayerNorm, torch.nn.Dropout, torch.nn.Softmax, torch.nn.ReLU, torch.nn.GELU, torch.nn.Identity, torch.nn.MaxPool2d, torch.nn.Embedding, torch.nn.Tanh, torch.nn.ConvTranspose1d, torch.nn.GLU, torch.nn.LSTM, torch.nn.Flatten, torch.nn.AdaptiveAvgPool2d, ] @staticmethod def can_convert_to_string(gm): cant_convert = set() for _, module in gm.named_children(): if type(module) not in NNModuleToString.safe_reprs: cant_convert.add(module) if len(cant_convert) > 0: log.warning(f"We have not tested reprs of some modules - {cant_convert}") # TODO - Assuming that all modules can be safely repr'd. Check if that assumption is correct. return True @staticmethod def convert(gm): from torch.nn.modules.module import _addindent tab = " " * 4 model_str = textwrap.dedent( """ from torch.nn import * class Repro(torch.nn.Module): def __init__(self): super().__init__() """ ) for module_name, module in gm.named_children(): module_str = f"{module.__repr__()}" model_str += f"{tab*2}self.{module_name} = {module_str}\n" for buffer_name, buffer in gm._buffers.items(): if buffer is None: continue if torch.is_floating_point(buffer): tensor_str = f"torch.randn({list(buffer.shape)}, dtype={buffer.dtype})" else: tensor_str = ( f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})" ) model_str += f"{tab*2}self.register_buffer('{buffer_name}', {tensor_str})\n" for param_name, param in gm._parameters.items(): if param is None: continue tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}))" model_str += f"{tab*2}self.{param_name} = {tensor_str}\n" # TODO - Keep this code for now. But, I don't think we will need this. # attrs = dir(gm) # for attr in attrs: # if "_tensor_constant" in attr: # val = getattr(gm, attr) # model_str += f" {attr} = {val!r}\n" model_str += f"{_addindent(gm.code, 4)}\n" return model_str @functools.lru_cache(None) # subprocess is expensive def _cuda_system_info_comment(): if not torch.cuda.is_available(): return "# torch.cuda.is_available()==False, no GPU info collected\n" model_str = "# CUDA Info: \n" try: cuda_version_out = subprocess.run(["nvcc", "--version"], stdout=subprocess.PIPE) cuda_version_lines = cuda_version_out.stdout.decode().split("\n") cuda_version_out = "".join( [f"# {s} \n" for s in cuda_version_lines if s not in [""]] ) model_str += f"{cuda_version_out}\n" except FileNotFoundError: model_str += "nvcc not found\n" gpu_names = subprocess.run( ["nvidia-smi", "--query-gpu=gpu_name", "--format=csv"], stdout=subprocess.PIPE, ) gpu_names = gpu_names.stdout.decode().split("\n") gpu_names = [name for name in gpu_names if name not in ("", "name")] gpu_names = Counter(gpu_names) model_str += "# GPU Hardware Info: \n" for name, count in gpu_names.items(): model_str += f"# {name} : {count} \n" model_str += "\n" return model_str def generate_compiler_repro_string(gm, args): model_str = textwrap.dedent( f""" import torch from torch import tensor, device import torch.fx as fx from {config.dynamo_import}.testing import rand_strided from math import inf from torch.fx.experimental.proxy_tensor import make_fx """ ) model_str += f"# torch version: {torch.version.__version__}\n" if hasattr(torch.version, "cuda"): model_str += f"# torch cuda version: {torch.version.cuda}\n" if hasattr(torch.version, "git_version"): model_str += f"# torch git version: {torch.version.git_version}\n\n\n" model_str += _cuda_system_info_comment() model_str += NNModuleToString.convert(gm) model_str += f"args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type) for a in args]!r}\n" model_str += ( "args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]\n" ) model_str += 'mod = make_fx(Repro().to(device="cuda"))(*args)\n' return model_str INDUCTOR_IMPORT = f""" from {config.inductor_import}.compile_fx import compile_fx_inner from {config.dynamo_import}.debug_utils import same_two_models """ NVFUSER_IMPORT = """ from torch.fx.passes.backends.nvfuser import NvFuserBackend nvfuser = NvFuserBackend() """ COMPILER_REPRO_OPTIONS = { "inductor": (INDUCTOR_IMPORT, "compile_fx_inner", "inductor_fails"), "inductor_accuracy": ( INDUCTOR_IMPORT, "compile_fx_inner", "inductor_accuracy_fails", ), "nvfuser": (NVFUSER_IMPORT, "nvfuser", "nvfuser_fails"), } def dump_compiler_graph_state(gm, args, compiler_name): subdir = os.path.join(minifier_dir(), "checkpoints") if not os.path.exists(subdir): os.makedirs(subdir, exist_ok=True) file_name = os.path.join(subdir, f"{len(gm.graph.nodes)}.py") log.warning(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}") with open(file_name, "w") as fd: save_graph_repro(fd, gm, args, compiler_name) repro_path = os.path.join(config.base_dir, "repro.py") try: shutil.copyfile(file_name, repro_path) log.warning(f"Copying repro file for convenience to {repro_path}") except OSError: log.warning(f"No write permissions for {repro_path}") pass def save_graph_repro(fd, gm, args, compiler_name): fd.write(generate_compiler_repro_string(gm, args)) fd.write(COMPILER_REPRO_OPTIONS[compiler_name][0]) if "_accuracy" in compiler_name: fd.write( textwrap.dedent( f""" compiled = {COMPILER_REPRO_OPTIONS[compiler_name][1]}(mod, args) assert same_two_models(mod, compiled, args, only_fwd=True), "Accuracy failed" """ ) ) else: fd.write( textwrap.dedent( f""" compiled = {COMPILER_REPRO_OPTIONS[compiler_name][1]}(mod, args) compiled(*args) """ ) ) def isolate_fails(fx_g, args, compiler_name: str, env=None): if env is None: env = {} subdir = f"{minifier_dir()}/isolate" if not os.path.exists(subdir): os.makedirs(subdir, exist_ok=True) file_name = os.path.join(subdir, f"{str(uuid.uuid4())[:5]}.py") with open(file_name, "w") as fd: fd.write(generate_compiler_repro_string(fx_g, args)) fail_fn = COMPILER_REPRO_OPTIONS[compiler_name][2] fd.write( textwrap.dedent( f""" from {__name__} import {fail_fn} """ ) ) fd.write( textwrap.dedent( f""" if {fail_fn}(mod, args): exit(1) else: exit(0) """ ) ) new_env = os.environ.copy() new_env = {**new_env, **env} p = subprocess.Popen( ["python", file_name], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=new_env, ) out, err = p.communicate() if p.returncode != 0: print(textwrap.indent(out.decode("utf-8"), prefix=">> ")) print(textwrap.indent(err.decode("utf-8"), prefix=">> ")) return True return False def inductor_fails(fx_g, args, check_str=None): compile_fx_inner = import_module( f"{config.inductor_import}.compile_fx" ).compile_fx_inner import_module(f"{config.inductor_import}.config").triton.autotune = False try: result = fx_g(*args) assert isinstance(result, (tuple, list)) assert not any([isinstance(x, (tuple, list)) for x in result]) except Exception: return False try: compile_mod = compile_fx_inner(fx_g, args) compile_mod(*args) except Exception as e: if check_str is not None and check_str not in repr(e): return False print(repr(e)) return True return False def nvfuser_fails(fx_g, args, check_str=None): from torch.fx.passes.backends.nvfuser import NvFuserBackend nvfuser = NvFuserBackend() try: compile_mod = nvfuser(fx_g, args) compile_mod = compile_mod(*args) except Exception as e: if check_str is not None and check_str not in repr(e): return False print(repr(e)) return True return False def inductor_accuracy_fails(fx_g, args, check_str=None): from torchinductor.compile_fx import compile_fx_inner return backend_aot_accuracy_fails(fx_g, args, compile_fx_inner) def helper_for_dump_minify(contents): minified_repro_path = os.path.join(minifier_dir(), "minifier_launcher.py") log.warning(f"Writing minified repro to {minified_repro_path}") try: with open(minified_repro_path, "w") as fd: fd.write(contents) except OSError as e: log.exception(e) raise NotImplementedError("Could not write to {minified_repro_path}") local_path = os.path.join(config.base_dir, "minifier_launcher.py") try: shutil.copyfile(minified_repro_path, local_path) log.warning( f"Copying minified repro from {minified_repro_path} to {local_path} for convenience" ) except OSError: log.warning(f"Don't have write permissions for {local_path}") def dump_to_minify(gm, args, compiler_name: str): favored_device = 1 if torch.cuda.device_count() >= 2 else 0 contents = textwrap.dedent( f""" {generate_compiler_repro_string(gm, args)} from functools import partial from {__name__} import ( isolate_fails, dump_compiler_graph_state, ) from functorch.compile import minifier env_variables = {{"CUDA_VISIBLE_DEVICES": "{favored_device}"}} minifier( mod, args, module_fails=partial(isolate_fails, env=env_variables, compiler_name="{compiler_name}"), dump_state=partial(dump_compiler_graph_state, compiler_name="{compiler_name}"), ) """ ) return helper_for_dump_minify(contents) def wrap_compiler_debug(compiler_fn, compiler_name: str): """ Minifier for Fx Graph modules after Aot Autograd has finished. We wrap both forward and backward call separately with the backend compiler_fn - like inductor or nvfuser. Intercepting after Aot Autograd presents neat abstration, where all the params are lifted as graph inputs, making it easy to save the graph as a string. """ @functools.wraps(compiler_fn) def debug_wrapper(gm, example_inputs, **kwargs): orig_graph = copy.deepcopy(gm.graph) assert config.repro_after in ("dynamo", "aot", None) def deferred_for_real_inputs(*real_inputs): """ Aot Autograd fw_compiler and bw_compiler can have fake tensors. So, example_inputs can be fake tensors. We can call compiler_fn (which is inductor or nvfuser) with fake tensors but the actualy compiled_fn should be called with real tensors. Therefore, the actual invocation is deffered. """ if config.repro_level == 3: # Always dump the original module in case we have segfaults dump_to_minify( fx.GraphModule(gm, orig_graph), real_inputs, compiler_name ) if config.repro_level == 4: if compiler_name != "inductor": raise NotImplementedError( "Accuracy minification is supported for inductor only" ) compiled_fn = compiler_fn(gm, example_inputs, **kwargs) if backend_aot_accuracy_fails(gm, real_inputs, compiler_fn): log.warning("Accuracy failed for the AOT Autograd graph") dump_compiler_graph_state( fx.GraphModule(gm, orig_graph), real_inputs, f"{compiler_name}_accuracy", ) dump_to_minify( fx.GraphModule(gm, orig_graph), real_inputs, f"{compiler_name}_accuracy", ) raise ValueError("Bad accuracy detected") else: # Call the compiled function with real inputs return compiled_fn(*real_inputs) else: try: # Call the compiler_fn - which is either aot_autograd or inductor # with fake inputs compiled_fn = compiler_fn(gm, example_inputs, **kwargs) # Call the compiled function with real inputs return compiled_fn(*real_inputs) except Exception as e: if config.repro_level == 1: dump_compiler_graph_state( fx.GraphModule(gm, orig_graph), real_inputs, compiler_name ) elif config.repro_level == 2: dump_to_minify( fx.GraphModule(gm, orig_graph), real_inputs, compiler_name ) raise e if config.repro_after == "aot": compiled_fn = deferred_for_real_inputs else: compiled_fn = compiler_fn(gm, example_inputs, **kwargs) return compiled_fn return debug_wrapper def run_fwd_maybe_bwd(gm, args, only_fwd=False): """ Runs a forward and possibly backward iteration for a given mod and args. """ from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass gm = copy.deepcopy(gm) new_args = clone_inputs(args) # Set the requires_grad field explicitly because clone_inputs only sets # requires_grad for leaf tensors. for narg, arg in zip(new_args, args): narg.requires_grad_(arg.requires_grad) args = new_args if hasattr(gm, "zero_grad"): gm.zero_grad(True) out = gm(*args) if only_fwd: return out if requires_bwd_pass(out): loss = reduce_to_scalar_loss(out) loss.backward() return collect_results(gm, out, None, []) def same_two_models(gm, opt_gm, example_inputs, only_fwd=False): """ Check two models have same accuracy. """ from .utils import same ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd) try: fp64_model, fp64_examples = cast_to_fp64( copy.deepcopy(gm), clone_inputs(example_inputs) ) fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_examples, only_fwd) except Exception: log.warning("Could not generate fp64 outputs") fp64_ref = None res = run_fwd_maybe_bwd(opt_gm, example_inputs, only_fwd) passing = same(ref, res, fp64_ref, tol=0.001, equal_nan=True) return passing def cast_to(dtype, model, inputs): from torch.utils._pytree import tree_map # cast model and inputs to fp16 model = model.to(dtype) inputs = tree_map( lambda x: x.to(dtype) if isinstance(x, torch.Tensor) and x.is_floating_point() else x, inputs, ) return model, inputs def cast_to_fp64(model, inputs): return cast_to(torch.float64, model, inputs) def generate_dynamo_fx_repro_string( model_str, args, compiler_name, check_accuracy=False ): """ Generate a repro string for backend-agnostic minified version. """ run_code = textwrap.dedent( f""" with torch.cuda.amp.autocast(enabled={torch.is_autocast_enabled()}): ref = run_fwd_maybe_bwd(mod, args) res = run_fwd_maybe_bwd(opt_mod, args) """ ) if config.repro_level == 4 or check_accuracy: run_code = textwrap.dedent( f""" mod.eval() opt_mod.eval() with torch.cuda.amp.autocast(enabled={torch.is_autocast_enabled()}): assert same_two_models(mod, mod, args), "Eager itself failed" assert same_two_models(mod, opt_mod, args), "Dynamo failed" """ ) return textwrap.dedent( f""" from math import inf import torch from torch import tensor, device import torch.fx as fx import {config.dynamo_import} from {config.dynamo_import}.testing import rand_strided from {config.dynamo_import}.debug_utils import run_fwd_maybe_bwd from {config.dynamo_import}.debug_utils import same_two_models args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type, a.requires_grad) for a in args]} args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args] {model_str} mod = Repro().cuda() opt_mod = {config.dynamo_import}.optimize("{compiler_name}")(mod) {run_code} """ ) def dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy=False): """ Saves the repro to a repro.py file """ subdir = os.path.join(minifier_dir()) if not os.path.exists(subdir): os.makedirs(subdir, exist_ok=True) file_name = os.path.join(subdir, f"{len(gm.graph.nodes)}.py") log.warning(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}") model_str = NNModuleToString.convert(gm) with open(file_name, "w") as fd: fd.write( generate_dynamo_fx_repro_string( model_str, args, compiler_name, check_accuracy ) ) latest_repro = os.path.join(subdir, "repro.py") log.warning(f"Copying {file_name} to {latest_repro} for convenience") shutil.copyfile(file_name, latest_repro) local_path = os.path.join(config.base_dir, "repro.py") try: shutil.copyfile(file_name, local_path) log.warning( f"Copying minified repro from {file_name} to {local_path} for convenience" ) except OSError: log.warning("No write permissions for {local_path}") # TODO - Commented because we are assuming that nn.Modules can be safely repr'd # If that does not work, we might have to bring this code back. So, keeping it # as it is for now. # def dump_backend_repro_as_tarfile(gm, args, compiler_name): # """ # Saves the repro in repro.tar.gz, as opposed to a file. This is used for # cases, where we can't convert a Fx GraphModule to a string, and therefore # fallback to to_folder for serialization. We accompany this with a repro.py # script that imports the saved module, sets it up and runs the model to repro # the error. # """ # import tarfile # subdir = os.path.join(minifier_dir(), "checkpoints") # if not os.path.exists(subdir): # os.makedirs(subdir, exist_ok=True) # tmp_dir = os.path.join(subdir, f"{len(gm.graph.nodes)}") # if os.path.exists(tmp_dir): # shutil.rmtree(tmp_dir) # os.makedirs(tmp_dir, exist_ok=True) # file_name = os.path.join(tmp_dir, "repro.py") # gm_dir = os.path.join(tmp_dir, "module") # if not os.path.exists(gm_dir): # os.makedirs(gm_dir, exist_ok=True) # for node in gm.graph.nodes: # new_kwargs = {} # for k, v in node.kwargs.items(): # if isinstance(v, torch.device): # v = v.type # new_kwargs[k] = v # node.kwargs = new_kwargs # gm.recompile() # print(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}") # with open(file_name, "w") as fd: # # TODO - Add the readable version of to_folder when available # gm.to_folder(gm_dir, "Repro") # fd.write( # generate_dynamo_fx_repro_string( # "from module import Repro", args, compiler_name # ) # ) # local_dir = os.path.join(config.base_dir, "repro") # if os.path.exists(local_dir): # shutil.rmtree(local_dir) # shutil.copytree(tmp_dir, local_dir) # local_tar_file = os.path.join(config.base_dir, "repro.tar.gz") # print(f"Writing checkpoint with {len(gm.graph.nodes)} locally to {local_tar_file}") # with tarfile.open(local_tar_file, "w:gz") as tar: # tar.add(local_dir, arcname=os.path.basename(local_dir)) def dump_backend_state(gm, args, compiler_name, check_accuracy=False): """ Dumps the dynamo graph to repro the issue. 1) It tries to convert Fx GraphModule to a string. If we can, it writes to a repro.py file. 2) If we can't convert Fx GraphModule to a string, we use to_folder to save the module and save a tar file. """ assert NNModuleToString.can_convert_to_string(gm) return dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy) # return dump_backend_repro_as_tarfile(gm, args, compiler_name) def backend_accuracy_fails(gm, example_inputs, compiler_fn, only_fwd=False): compiled_gm = compiler_fn(copy.deepcopy(gm), clone_inputs(example_inputs)) return not same_two_models(gm, compiled_gm, example_inputs, only_fwd) backend_aot_accuracy_fails = functools.partial(backend_accuracy_fails, only_fwd=True) def backend_fails(gm, example_inputs, compiler_fn, orig_failure): """ Minifier uses this function to identify if the minified graph module fails with the same error. One caveat is that minifier can potentially go into a wrong direction when the resulting graph module fails for a different reason. To avoid this, we save the string for the original exception and check similarity between new and old exception. They can be somewhat different in some cases, when the exception string depends on the failing node information. So, we have a loose similarity metric to guide the minifier path. """ from difflib import SequenceMatcher try: compiled_gm = compiler_fn(gm, example_inputs) run_fwd_maybe_bwd(compiled_gm, clone_inputs(example_inputs)) return False except Exception as e: new_failure = str(e) if SequenceMatcher(None, orig_failure, new_failure).ratio() > 0.5: return True return False def dump_to_minify_after_dynamo(gm, args, compiler_name): model_str = NNModuleToString.convert(gm) minifier_backend = "dynamo_minifier_backend" if config.repro_level == 4: minifier_backend = "dynamo_accuracy_minifier_backend" contents = textwrap.dedent( f""" import os from math import inf import torch from torch import tensor, device import torch.fx as fx import functools import {config.dynamo_import} from {config.dynamo_import}.debug_utils import run_fwd_maybe_bwd from {config.dynamo_import}.optimizations.backends import BACKENDS from {config.dynamo_import}.testing import rand_strided {config.dynamo_import}.config.repro_dir = \"{minifier_dir()}\" args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type, a.requires_grad) for a in args]} args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args] {model_str} mod = Repro().cuda() # Setup debug minifier compiler compiler_fn = BACKENDS["{minifier_backend}"] dynamo_minifier_backend = functools.partial( compiler_fn, compiler_name="{compiler_name}", ) opt_mod = {config.dynamo_import}.optimize(dynamo_minifier_backend)(mod) with torch.cuda.amp.autocast(enabled={torch.is_autocast_enabled()}): opt_mod(*args) """ ) helper_for_dump_minify(contents) def wrap_backend_debug(compiler_fn, compiler_name: str): """ A minifier decorator that wraps the TorchDynamo produced Fx graph modules. As opposed to wrap_compiler_debug, this wrapper intercepts at the TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to some level, e.g., it is useful for minifying issues related to Aot Autograd tracing. If an error is found, we minify and save the minified repro in repro.tar.gz. """ @functools.wraps(compiler_fn) def debug_wrapper(gm, example_inputs, **kwargs): assert config.repro_after in ("dynamo", "aot", None) if config.repro_after == "dynamo": # Ensure that we fail when backend fails config.raise_on_backend_error = True if config.repro_level == 3: dump_to_minify_after_dynamo(gm, example_inputs, compiler_name) # Check for either accuracy (level 4) or other type of failures. if config.repro_level == 4: # Check Accuracy compiled_gm = compiler_fn(gm, example_inputs, **kwargs) if backend_accuracy_fails(gm, example_inputs, compiler_fn): log.warning("Accuracy failed for the TorchDyanmo produced graph") dump_to_minify_after_dynamo( fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs, compiler_name, ) raise ValueError("Bad accuracy detected") else: try: compiled_gm = compiler_fn(gm, example_inputs, **kwargs) run_fwd_maybe_bwd(compiled_gm, clone_inputs(example_inputs)) except Exception as exc: log.warning( "Compiled Fx GraphModule failed with following error. Setting up minifier." ) log.exception(exc) if config.repro_level == 1: dump_state_fn = functools.partial( dump_backend_state, compiler_name=compiler_name ) dump_state_fn( fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs ) elif config.repro_level == 2: dump_to_minify_after_dynamo( fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs, compiler_name, ) raise ValueError("Issue deteced. Repro at minifier_launcher.py.") else: compiled_gm = compiler_fn(gm, example_inputs, **kwargs) return compiled_gm debug_wrapper._torchdynamo_orig_callable = compiler_fn return debug_wrapper @register_backend def dynamo_minifier_backend(gm, example_inputs, compiler_name): from functorch.compile import minifier from .eval_frame import lookup_backend compiler_fn = lookup_backend(compiler_name) try: compiled_gm = compiler_fn(gm, example_inputs) run_fwd_maybe_bwd(compiled_gm, clone_inputs(example_inputs)) raise ValueError("No issue was detected") except Exception as exc: orig_failure = str(exc) log.warning( "Compiled Fx GraphModule failed with following error. Starting minifier." ) log.exception(exc) dump_state_fn = functools.partial( dump_backend_state, compiler_name=compiler_name ) dump_state_fn(fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs) fails_fn = functools.partial( backend_fails, compiler_fn=compiler_fn, orig_failure=orig_failure, ) minifier( gm, example_inputs, module_fails=fails_fn, dump_state=dump_state_fn, ) return gm @register_backend def dynamo_accuracy_minifier_backend(gm, example_inputs, compiler_name): from functorch.compile import minifier from torchdynamo.optimizations.backends import BACKENDS if compiler_name == "inductor": from torchinductor.compile_fx import compile_fx compiler_fn = compile_fx else: compiler_fn = BACKENDS[compiler_name] # Set the eval mode to remove randomness. gm.eval() # Check Accuracy if backend_accuracy_fails(gm, example_inputs, compiler_fn): log.warning("Accuracy failed for the TorchDyanmo produced graph") dump_state_fn = functools.partial( dump_backend_state, compiler_name=compiler_name, check_accuracy=True ) fails_fn = functools.partial( backend_accuracy_fails, compiler_fn=compiler_fn, ) dump_state_fn(fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs) minifier( gm, example_inputs, module_fails=fails_fn, dump_state=dump_state_fn, ) else: log.error("Input graph does not fail accuracy testing") return gm