Opt model save and load (#126374)
## Save & load support for OptimizedModule

[Issue Description](https://github.com/pytorch/pytorch/pull/101651)

English is not my native language; please excuse typing errors.

This PR is based on commit b9588101c4d3411b107fdc860acfa8a72c642f91. I'll resolve the merge conflicts later.

### Test results for test/dynamo

Conclusion: as far as I can see, it performs the same as before.

ENV (CPU only):
platform linux -- Python 3.10.14, pytest-7.3.2, pluggy-1.5.0
configfile: pytest.ini
plugins: anyio-3.7.1, cpp-2.3.0, flakefinder-1.1.0, xdist-3.3.1, xdoctest-1.1.0, metadata-3.1.1, html-4.1.1, hypothesis-5.35.1, rerunfailures-14.0

#### Before this PR: [before](https://github.com/pytorch/pytorch/files/15329370/before.md)
#### After this PR: [after](https://github.com/pytorch/pytorch/files/15329376/after.md)

### Changes

1. Add `test_save_and_load` tests to test/dynamo/test_modules.py, with and without `backend="inductor"`.
2. Add a `__reduce__` method to `OptimizedModule` and to the derived classes of `_TorchDynamoContext` for pickling and unpickling.
3. Change the function wrappers into wrapper classes (including `convert_frame_assert`, `convert_frame`, and `catch_errors_wrapper` in torch/_dynamo/convert_frame.py, and `wrap_backend_debug` in torch/_dynamo/repro/after_dynamo.py).
4. Change `self.output.compiler_fn` into `innermost_fn(self.output.compiler_fn)` in torch/_dynamo/symbolic_convert.py to get the original compiler_fn and avoid the "compiler_fn is not eager" condition.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126374
Approved by: https://github.com/msaroufim, https://github.com/jansel
parent 9a8ab778d3
commit c3949b20a1
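For orientation, the new `test_save_and_load_inductor` test in the diff below exercises the feature end to end. A minimal usage sketch along the same lines (`MyModule` is a hypothetical stand-in module, not something from this PR):

```python
import os
import tempfile

import torch


class MyModule(torch.nn.Module):  # hypothetical stand-in, not part of the PR
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return torch.relu(self.linear(x))


mod = MyModule()
opt_mod = torch.compile(mod)  # returns an OptimizedModule wrapping `mod`
inp = torch.randn(10, 10)
opt_mod(inp)  # compile once

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "model.pt")
    torch.save(opt_mod, path)  # with this PR, the OptimizedModule pickles via __reduce__
    loaded = torch.load(path)  # newer releases may additionally need weights_only=False

loaded(inp)  # the loaded module recompiles on its first call
```

As the new tests check, the loaded module should produce the same results as both the original and the compiled module, and it compiles again the first time it runs after `torch._dynamo.reset()`.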
test/dynamo/test_modules.py

@@ -3,6 +3,8 @@
 import collections
 import copy
 import itertools
+import os
+import tempfile
 import traceback
 import types
 import unittest
@@ -16,6 +18,7 @@ import torch
 import torch._dynamo.test_case
 import torch._dynamo.testing
 import torch.nn.functional as F
+from torch._dynamo.debug_utils import same_two_models
 from torch._dynamo.eval_frame import unsupported
 from torch._dynamo.mutation_guard import GenerationTracker
 from torch._dynamo.testing import expectedFailureDynamic, same
@@ -2739,6 +2742,49 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
         self.assertEqual(test_functions._variable, 1)
         self.assertEqual(res, 3 * torch.ones(10))

+    @unittest.skipIf(
+        "inductor" not in torch._dynamo.list_backends(),
+        "inductor backend is not available",
+    )
+    def test_save_and_load_inductor(self):
+        mod = MockModule()
+        opt_mod = torch.compile(mod, backend="inductor")
+        inp = torch.randn(10, 10)
+        opt_mod(inp)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            torch.save(opt_mod, os.path.join(tmpdirname, "model.pt"))
+            loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
+        loaded_model(inp)
+        self.assertTrue(same_two_models(loaded_model, mod, [inp]))
+        self.assertTrue(same_two_models(loaded_model, opt_mod, [inp]))
+
+        torch._dynamo.reset()  # force recompiles
+        torch._inductor.metrics.generated_kernel_count = 0
+        loaded_model(inp)
+        self.assertGreater(torch._inductor.metrics.generated_kernel_count, 0)
+
+    def test_save_and_load_all_backends(self):
+        mod = MockModule()
+        inp = torch.randn(10, 10)
+        for backend in torch._dynamo.list_backends():
+            try:
+                opt_mod = torch.compile(mod, backend=backend)
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    torch.save(opt_mod, os.path.join(tmpdirname, "model.pt"))
+                    loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
+                torch._dynamo.reset()  # force recompiles
+                torch._inductor.metrics.generated_kernel_count = 0
+                opt_mod(inp)
+                opt_success = torch._inductor.metrics.generated_kernel_count == 0
+                torch._dynamo.reset()  # force recompiles
+                torch._inductor.metrics.generated_kernel_count = 0
+                loaded_model(inp)
+                loaded_success = torch._inductor.metrics.generated_kernel_count == 0
+                self.assertEqual(opt_success, loaded_success)
+            except torch._dynamo.exc.BackendCompilerFailed:
+                pass
+
     def test_monkeypatching_forward(self):
         class FakeModule(torch.nn.Module):
             def forward(self, x):
torch/_dynamo/backends/common.py

@@ -14,18 +14,22 @@ from torch.utils._python_dispatch import _disable_current_modes
 log = logging.getLogger(__name__)


-def aot_autograd(**kwargs):
-    def compiler_fn(gm: torch.fx.GraphModule, example_inputs):
+class AotAutograd:
+    def __init__(self, **kwargs):
+        self.__name__ = "compiler_fn"
+        self.kwargs = kwargs
+
+    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
         if any(isinstance(x, (list, tuple, dict)) for x in example_inputs):
             return flatten_graph_inputs(
                 gm,
                 example_inputs,
-                compiler_fn,
+                self,
             )

         # Hack to get around circular import problems with aot_eager_decomp_partition
-        if callable(kwargs.get("decompositions")):
-            kwargs["decompositions"] = kwargs["decompositions"]()
+        if callable(self.kwargs.get("decompositions")):
+            self.kwargs["decompositions"] = self.kwargs["decompositions"]()

         # NB: dont delete counter increment
         counters["aot_autograd"]["total"] += 1
@@ -42,10 +46,10 @@ def aot_autograd(**kwargs):
             # stop TorchDynamo from trying to compile our generated backwards pass
             return disable(disable(bw_compiler)(*args, **kwargs))

-        bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
-        kwargs["bw_compiler"] = _wrapped_bw_compiler
-        kwargs["inference_compiler"] = (
-            kwargs.get("inference_compiler") or kwargs["fw_compiler"]
+        bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"]
+        self.kwargs["bw_compiler"] = _wrapped_bw_compiler
+        self.kwargs["inference_compiler"] = (
+            self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"]
         )

         from functorch.compile import nop
@@ -54,7 +58,7 @@ def aot_autograd(**kwargs):

         # debug asserts slow down compile time noticeably,
         # So only default them on when the aot_eager backend is used.
-        if kwargs.get("fw_compiler", None) == nop:
+        if self.kwargs.get("fw_compiler", None) == nop:
             patch_config = patch("functorch.compile.config.debug_assert", True)
         else:
             patch_config = contextlib.nullcontext()
@@ -62,14 +66,16 @@ def aot_autograd(**kwargs):
         try:
             # NB: NOT cloned!
             with enable_aot_logging(), patch_config:
-                cg = aot_module_simplified(gm, example_inputs, **kwargs)
+                cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
                 counters["aot_autograd"]["ok"] += 1
                 return disable(cg)
         except Exception:
             counters["aot_autograd"]["not_ok"] += 1
             raise

-    return compiler_fn
+
+def aot_autograd(**kwargs):
+    return AotAutograd(**kwargs)


 def mem_efficient_fusion_kwargs(use_decomps):
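A side note on why the `compiler_fn` closure becomes the `AotAutograd` class here (and why the wrappers in the following file get the same treatment): presumably so that the backend callable stored on a compiled module can be pickled. The standard `pickle` module cannot serialize nested functions, while an instance of a module-level class pickles fine as long as its attributes do. A minimal illustration of that difference, independent of the PyTorch code above:

```python
import pickle


def make_closure(scale):
    def fn(x):  # nested function capturing `scale`
        return x * scale

    return fn


class Scaler:  # module-level class: instances are picklable
    def __init__(self, scale):
        self.scale = scale

    def __call__(self, x):
        return x * self.scale


# pickle.dumps(make_closure(2))  # would raise: Can't pickle local object 'make_closure.<locals>.fn'
restored = pickle.loads(pickle.dumps(Scaler(2)))
assert restored(3) == 6
```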
torch/_dynamo/convert_frame.py

@@ -361,17 +361,34 @@ def cprofile_wrapper(func):
     return profile_wrapper


-def convert_frame_assert(
-    compiler_fn: CompilerFn,
-    one_graph: bool = True,
-    export: bool = False,
-    export_constraints=None,
-):
-    """Fully convert a frame into an FX graph"""
-    reset_graph_break_dup_checker()
+class ConvertFrameAssert:
+    def __init__(
+        self,
+        compiler_fn: CompilerFn,
+        one_graph: bool = True,
+        export: bool = False,
+        export_constraints=None,
+    ):
+        reset_graph_break_dup_checker()
+        self._torchdynamo_orig_callable = compiler_fn  # type: ignore[attr-defined]
+        self._one_graph = one_graph
+        self._export = export
+        self._export_constraints = export_constraints

-    def _convert_frame_assert(
-        frame: types.FrameType, cache_entry, hooks: Hooks, frame_state, *, skip: int = 0
+    @property
+    def _clone_with_backend(self):
+        return lambda backend: convert_frame_assert(
+            backend, self._one_graph, self._export, self._export_constraints
+        )
+
+    def __call__(
+        self,
+        frame: types.FrameType,
+        cache_entry,
+        hooks: Hooks,
+        frame_state,
+        *,
+        skip: int = 0,
     ):
         increment_frame()

@@ -458,10 +475,10 @@ def convert_frame_assert(
             frame.f_globals,
             frame.f_locals,
             frame.f_builtins,
-            compiler_fn,
-            one_graph,
-            export,
-            export_constraints,
+            self._torchdynamo_orig_callable,
+            self._one_graph,
+            self._export,
+            self._export_constraints,
             hooks,
             cache_entry,
             cache_size,
@@ -471,13 +488,15 @@ def convert_frame_assert(
             skip=skip + 1,
         )

-    _convert_frame_assert._torchdynamo_orig_callable = compiler_fn  # type: ignore[attr-defined]

-    def _clone_with_backend(backend):
-        return convert_frame_assert(backend, one_graph, export, export_constraints)
-
-    _convert_frame_assert._clone_with_backend = _clone_with_backend  # type: ignore[attr-defined]
-    return _convert_frame_assert
+def convert_frame_assert(
+    compiler_fn: CompilerFn,
+    one_graph: bool = True,
+    export: bool = False,
+    export_constraints=None,
+):
+    """Fully convert a frame into an FX graph"""
+    return ConvertFrameAssert(compiler_fn, one_graph, export, export_constraints)


 from collections import OrderedDict
@@ -907,16 +926,27 @@ def _compile(
     torch._dynamo.callback_handler.run_end_callbacks()


-def convert_frame(compiler_fn: CompilerFn, hooks: Hooks):
-    """Try to convert a frame into an FX graph, if error leave frame unmodified"""
-    inner_convert = convert_frame_assert(compiler_fn, one_graph=False)
+class ConvertFrame:
+    def __init__(self, compiler_fn: CompilerFn, hooks: Hooks):
+        self._torchdynamo_orig_callable = compiler_fn
+        self._inner_convert = convert_frame_assert(compiler_fn, one_graph=False)
+        self._hooks = hooks

-    def _convert_frame(
-        frame: types.FrameType, cache_entry, hooks: Hooks, frame_state, skip: int = 0
+    @property
+    def _clone_with_backend(self):
+        return lambda backend: convert_frame(backend, self._hooks)
+
+    def __call__(
+        self,
+        frame: types.FrameType,
+        cache_entry,
+        hooks: Hooks,
+        frame_state,
+        skip: int = 0,
     ):
         counters["frames"]["total"] += 1
         try:
-            result = inner_convert(
+            result = self._inner_convert(
                 frame, cache_entry, hooks, frame_state, skip=skip + 1
             )
             counters["frames"]["ok"] += 1
@@ -980,9 +1010,10 @@ def convert_frame(compiler_fn: CompilerFn, hooks: Hooks):
             log.warning(error_msg, exc_info=True)
         return None

-    _convert_frame._torchdynamo_orig_callable = compiler_fn  # type: ignore[attr-defined]
-    _convert_frame._clone_with_backend = lambda backend: convert_frame(backend, hooks)  # type: ignore[attr-defined]
-    return _convert_frame
+
+def convert_frame(compiler_fn: CompilerFn, hooks: Hooks):
+    """Try to convert a frame into an FX graph, if error leave frame unmodified"""
+    return ConvertFrame(compiler_fn, hooks)


 # TODO mlazos: add support for same args, or record them
@@ -1023,9 +1054,13 @@ def first_real_inst_idx(code):
     raise RuntimeError("RESUME instruction not found in code")


-def catch_errors_wrapper(callback, hooks: Hooks):
-    @functools.wraps(callback)
-    def catch_errors(frame, cache_entry, frame_state):
+class CatchErrorsWrapper:
+    def __init__(self, callback, hooks):
+        functools.wraps(callback)(self)
+        self._torchdynamo_orig_callable = callback
+        self.hooks = hooks
+
+    def __call__(self, frame, cache_entry, frame_state):
         assert frame_state is not None

         is_skipfile = trace_rules.check(frame.f_code)
@@ -1063,19 +1098,26 @@ def catch_errors_wrapper(callback, hooks: Hooks):

             ddp_optimizer = DDPOptimizer(
                 bucket_bytes_cap=ddp_module.bucket_bytes_cap,
-                backend_compile_fn=callback._torchdynamo_orig_callable,
+                backend_compile_fn=self._torchdynamo_orig_callable._torchdynamo_orig_callable,
             )
             assert hasattr(
-                callback, "_clone_with_backend"
+                self._torchdynamo_orig_callable, "_clone_with_backend"
             ), "DDPOptimizer only supports callback fns that know how to clone themselves."
-            hijacked_callback = callback._clone_with_backend(
-                ddp_optimizer.compile_fn,
+            hijacked_callback = (
+                self._torchdynamo_orig_callable._clone_with_backend(
+                    ddp_optimizer.compile_fn,
+                )
             )
-            return hijacked_callback(frame, cache_entry, hooks, frame_state)
+            return hijacked_callback(
+                frame, cache_entry, self.hooks, frame_state
+            )

         with compile_lock, _disable_current_modes():
             # skip=1: skip this frame
-            return callback(frame, cache_entry, hooks, frame_state, skip=1)
+            return self._torchdynamo_orig_callable(
+                frame, cache_entry, self.hooks, frame_state, skip=1
+            )

-    catch_errors._torchdynamo_orig_callable = callback  # type: ignore[attr-defined]
-    return catch_errors
+
+def catch_errors_wrapper(callback, hooks: Hooks):
+    return CatchErrorsWrapper(callback, hooks)
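One detail worth calling out in `CatchErrorsWrapper.__init__`: `functools.wraps(callback)(self)` copies the callback's metadata (`__name__`, `__doc__`, `__module__`, ...) onto the instance and records the callback as `__wrapped__`, mirroring what the old `@functools.wraps(callback)` decorator did for the nested `catch_errors` function. A small standalone sketch of the same idiom:

```python
import functools


def original(x):
    """Original docstring."""
    return x + 1


class Wrapper:
    def __init__(self, fn):
        functools.wraps(fn)(self)  # copy __name__, __doc__, __module__, ... onto the instance
        self._fn = fn

    def __call__(self, *args, **kwargs):
        return self._fn(*args, **kwargs)


w = Wrapper(original)
assert w.__name__ == "original"
assert w.__wrapped__ is original
assert w(1) == 2
```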
torch/_dynamo/eval_frame.py

@@ -168,6 +168,9 @@ class OptimizedModule(torch.nn.Module):
         self._forward = self.forward
         self.forward = self._call_lazy_check

+    def __reduce__(self):
+        return (self.__class__, (self._orig_mod, self.dynamo_ctx))
+
     def __getstate__(self):
         state = dict(self.__dict__)
         state.pop("forward", None)
@@ -273,9 +276,11 @@ class _TorchDynamoContext:
         super().__init__()
         assert callable(callback) or callback is False or callback is None
         self.callback: DynamoCallback = callback
+        self._backend_ctx_ctor = backend_ctx_ctor
         self.prior: Union[Unset, DynamoCallback] = unset
         self.first_ctx = first_ctx
         self.export = export
+        self._dynamic = dynamic
         self.compiler_config = compiler_config
         self.cleanup_fns: List[Callable[[], Any]] = []
         self.enter_exit_hooks = []
@@ -379,7 +384,13 @@ class _TorchDynamoContext:
             # call to a builtin without a frame for us to capture
             fn = external_utils.wrap_inline(fn)

-        callback = self.callback
+        def do_nothing(*arg, **kwargs):
+            pass
+
+        if hasattr(self, "callback"):
+            callback = self.callback
+        else:
+            callback = do_nothing

         is_jit_tracing = torch._C._is_tracing
         is_fx_tracing = torch.fx._symbolic_trace.is_fx_tracing
@@ -522,6 +533,17 @@ class OptimizeContext(_TorchDynamoContext):

         self.enter_exit_hooks.append(call_compiled_autograd)

+    def __reduce__(self):
+        return (
+            self.__class__,
+            (self.callback, self._backend_ctx_ctor, self.first_ctx),
+            {
+                "export": self.export,
+                "dynamic": self._dynamic,
+                "compiler_config": self.compiler_config,
+            },
+        )
+

 class RunOnlyContext(_TorchDynamoContext):
     def __init__(self):
@@ -531,6 +553,9 @@ class RunOnlyContext(_TorchDynamoContext):

         super().__init__(callback=False, on_enter=on_enter)

+    def __reduce__(self):
+        return (self.__class__, ())
+

 class DisableContext(_TorchDynamoContext):
     def __init__(self):
@@ -583,6 +608,9 @@ class DisableContext(_TorchDynamoContext):

         return _fn

+    def __reduce__(self):
+        return (self.__class__, ())
+

 def _optimize_catch_errors(
     compile_fn,
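The `__reduce__` methods added above follow the standard pickle protocol: return a callable (here the class itself), a tuple of constructor arguments, and optionally a state mapping that pickle applies to the reconstructed object (through `__setstate__` if it exists, otherwise by updating `__dict__`). A schematic sketch of that pattern; the `Context` class is illustrative, not the one from the diff:

```python
import pickle


class Context:  # illustrative class, not torch._dynamo's OptimizeContext
    def __init__(self, callback, first_ctx=False):
        self.callback = callback
        self.first_ctx = first_ctx
        self.export = False  # extra state carried by the third element of __reduce__

    def __reduce__(self):
        # pickle calls Context(*args) and then updates the new
        # object's __dict__ with the state mapping.
        return (self.__class__, (self.callback, self.first_ctx), {"export": self.export})


ctx = Context(callback="eager", first_ctx=True)
ctx.export = True
restored = pickle.loads(pickle.dumps(ctx))
assert restored.callback == "eager" and restored.first_ctx and restored.export
```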
torch/_dynamo/repro/after_dynamo.py

@@ -56,19 +56,20 @@ def _accuracy_fails(gm, example_inputs, compiler_fn):
     )


-def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
-    """
-    A minifier decorator that wraps the TorchDynamo produced Fx graph modules.
-    As opposed to wrap_compiler_debug, this wrapper intercepts at the
-    TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to some
-    level, e.g., it is useful for minifying issues related to Aot Autograd
-    tracing. If an error is found, we minify and save the minified repro in
-    repro.tar.gz.
-    """
+class WrapBackendDebug:
+    def __init__(self, unconfigured_compiler_fn, compiler_name: str):
+        functools.wraps(unconfigured_compiler_fn)(self)
+        self._torchdynamo_orig_callable = unconfigured_compiler_fn  # type: ignore[attr-defined]
+        self._compiler_name = compiler_name
+        if hasattr(unconfigured_compiler_fn, "__name__"):
+            self.__name__ = unconfigured_compiler_fn.__name__
+        if hasattr(unconfigured_compiler_fn, "compiler_name"):
+            self.__name__ = unconfigured_compiler_fn.compiler_name
+        if hasattr(unconfigured_compiler_fn, "get_compiler_config"):
+            self.get_compiler_config = unconfigured_compiler_fn.get_compiler_config  # type: ignore[attr-defined]

-    @functools.wraps(unconfigured_compiler_fn)
-    def debug_wrapper(gm, example_inputs, **kwargs):
-        compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs)
+    def __call__(self, gm, example_inputs, **kwargs):
+        compiler_fn = functools.partial(self._torchdynamo_orig_callable, **kwargs)
         assert config.repro_after in ("dynamo", "aot", None)

         if config.repro_after == "dynamo":
@@ -82,7 +83,7 @@ def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
                 )

             if config.repro_level == 3:
-                dump_to_minify_after_dynamo(gm, example_inputs, compiler_name)
+                dump_to_minify_after_dynamo(gm, example_inputs, self._compiler_name)

             # Check for either accuracy (level 4) or other type of failures.
             if config.repro_level == 4:
@@ -95,7 +96,7 @@ def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
                     dump_to_minify_after_dynamo(
                         fx.GraphModule(gm, copy.deepcopy(gm.graph)),
                         example_inputs,
-                        compiler_name,
+                        self._compiler_name,
                     )
                     exc = AccuracyError("Bad accuracy detected.")
                     add_paths(exc)
@@ -110,7 +111,7 @@ def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
                     )
                     if config.repro_level == 1:
                         dump_state_fn = functools.partial(
-                            dump_backend_state, compiler_name=compiler_name
+                            dump_backend_state, compiler_name=self._compiler_name
                         )
                         dump_state_fn(
                             fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs
@@ -119,7 +120,7 @@ def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
                         dump_to_minify_after_dynamo(
                             fx.GraphModule(gm, copy.deepcopy(gm.graph)),
                             example_inputs,
-                            compiler_name,
+                            self._compiler_name,
                         )
                     add_paths(exc)
                     raise
@@ -128,12 +129,17 @@ def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):

         return compiled_gm

-    debug_wrapper._torchdynamo_orig_callable = unconfigured_compiler_fn  # type: ignore[attr-defined]
-    if hasattr(unconfigured_compiler_fn, "compiler_name"):
-        debug_wrapper.__name__ = unconfigured_compiler_fn.compiler_name
-    if hasattr(unconfigured_compiler_fn, "get_compiler_config"):
-        debug_wrapper.get_compiler_config = unconfigured_compiler_fn.get_compiler_config  # type: ignore[attr-defined]
-    return debug_wrapper
+
+def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
+    """
+    A minifier decorator that wraps the TorchDynamo produced Fx graph modules.
+    As opposed to wrap_compiler_debug, this wrapper intercepts at the
+    TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to some
+    level, e.g., it is useful for minifying issues related to Aot Autograd
+    tracing. If an error is found, we minify and save the minified repro in
+    repro.tar.gz.
+    """
+    return WrapBackendDebug(unconfigured_compiler_fn, compiler_name)


 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #