Opt model save and load (#126374)

## Save & load support for OptimizedModule

[Issue Description](https://github.com/pytorch/pytorch/pull/101651)

English is not my native language; please excuse typing errors.

This PR is based on commit b9588101c4d3411b107fdc860acfa8a72c642f91.\
I'll resolve the merge conflicts later.

### Test results for test/dynamo

Conclusion:\
As far as I can tell, the test suite behaves the same as before this change.

Environment (CPU only):\
platform linux -- Python 3.10.14, pytest-7.3.2, pluggy-1.5.0\
configfile: pytest.ini\
plugins: anyio-3.7.1, cpp-2.3.0, flakefinder-1.1.0, xdist-3.3.1, xdoctest-1.1.0, metadata-3.1.1, html-4.1.1, hypothesis-5.35.1, rerunfailures-14.0

#### Before this PR:

[before](https://github.com/pytorch/pytorch/files/15329370/before.md)

#### After this PR:

[after](https://github.com/pytorch/pytorch/files/15329376/after.md)

### Summary of changes

1. Add test_save_and_load_inductor and test_save_and_load_all_backends to test/dynamo/test_modules.py, which save and reload a compiled module with backend='inductor' and with every backend in torch._dynamo.list_backends().
2. Add a \_\_reduce\_\_ method to OptimizedModule and to the subclasses of _TorchDynamoContext so they can be pickled and unpickled (see the save/load sketch after this list).
3. Turn the closure-based wrappers into wrapper classes, since pickle cannot serialize closures: convert_frame_assert, convert_frame, and catch_errors_wrapper in torch/_dynamo/convert_frame.py, and wrap_backend_debug in torch/_dynamo/repro/after_dynamo.py (see the pickling sketch after this list).
4. Replace self.output.compiler_fn with innermost_fn(self.output.compiler_fn) in torch/_dynamo/symbolic_convert.py to recover the original compiler_fn from the new wrappers and avoid spuriously triggering the "compiler_fn is not eager" condition (see the innermost_fn sketch after this list).
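
To make the intent of items 1–2 concrete, here is a minimal save/load round-trip modelled on the new test_save_and_load_inductor test. The MLP module and tensor shapes are placeholders for illustration, and the inductor backend is assumed to be available (the test guards this with skipIf):

```python
import os
import tempfile

import torch


class MLP(torch.nn.Module):  # placeholder module, standing in for the test's MockModule
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return torch.relu(self.linear(x))


mod = MLP()
opt_mod = torch.compile(mod, backend="inductor")  # returns an OptimizedModule
inp = torch.randn(10, 10)
opt_mod(inp)  # trigger compilation once

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "model.pt")
    torch.save(opt_mod, path)   # pickling now works thanks to __reduce__
    loaded = torch.load(path)   # on PyTorch >= 2.6 you may need weights_only=False
    torch.testing.assert_close(loaded(inp), opt_mod(inp))
```

The loaded object is again an OptimizedModule, so it recompiles on first use; the inductor test checks this by resetting Dynamo and asserting that generated_kernel_count grows after running the loaded model.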
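
The reason item 3 turns the closure-style wrappers into classes is that Python's pickle cannot serialize local functions, while instances of module-level classes pickle fine as long as their attributes do. A standalone toy illustration (make_wrapper and Wrapper are illustrative names, not the PR's classes):

```python
import pickle


def make_wrapper(compiler_fn):
    # Old pattern: a closure capturing compiler_fn. pickle refuses to
    # serialize local functions, so anything holding this wrapper is unpicklable.
    def wrapper(*args, **kwargs):
        return compiler_fn(*args, **kwargs)

    return wrapper


class Wrapper:
    # New pattern: a callable class. pickle stores the class path plus the
    # instance state, so the wrapper round-trips as long as compiler_fn does.
    def __init__(self, compiler_fn):
        self.compiler_fn = compiler_fn

    def __call__(self, *args, **kwargs):
        return self.compiler_fn(*args, **kwargs)


try:
    pickle.dumps(make_wrapper(print))
except (pickle.PicklingError, AttributeError) as exc:
    print("closure wrapper:", exc)  # Can't pickle local object '...wrapper'

print("class wrapper:", len(pickle.dumps(Wrapper(print))), "bytes")
```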
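
For item 4, innermost_fn (a helper in torch/_dynamo/eval_frame.py) recovers the user's compiler_fn by peeling off wrapper layers. A simplified sketch of the idea, not the exact implementation:

```python
def innermost_fn(fn):
    # Each Dynamo wrapper (including the new wrapper classes) records the
    # callable it wraps in `_torchdynamo_orig_callable`; walking that chain
    # yields the original, unwrapped compiler_fn.
    unwrapped = fn
    while hasattr(unwrapped, "_torchdynamo_orig_callable"):
        unwrapped = unwrapped._torchdynamo_orig_callable
    return unwrapped
```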

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126374
Approved by: https://github.com/msaroufim, https://github.com/jansel
weiyusheng 2024-06-05 13:01:16 +00:00 committed by PyTorch MergeBot
parent 9a8ab778d3
commit c3949b20a1
5 changed files with 203 additions and 75 deletions

View File: test/dynamo/test_modules.py

@@ -3,6 +3,8 @@
import collections
import copy
import itertools
import os
import tempfile
import traceback
import types
import unittest
@@ -16,6 +18,7 @@ import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch.nn.functional as F
from torch._dynamo.debug_utils import same_two_models
from torch._dynamo.eval_frame import unsupported
from torch._dynamo.mutation_guard import GenerationTracker
from torch._dynamo.testing import expectedFailureDynamic, same
@@ -2739,6 +2742,49 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
self.assertEqual(test_functions._variable, 1)
self.assertEqual(res, 3 * torch.ones(10))
@unittest.skipIf(
"inductor" not in torch._dynamo.list_backends(),
"inductor backend is not available",
)
def test_save_and_load_inductor(self):
mod = MockModule()
opt_mod = torch.compile(mod, backend="inductor")
inp = torch.randn(10, 10)
opt_mod(inp)
with tempfile.TemporaryDirectory() as tmpdirname:
torch.save(opt_mod, os.path.join(tmpdirname, "model.pt"))
loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
loaded_model(inp)
self.assertTrue(same_two_models(loaded_model, mod, [inp]))
self.assertTrue(same_two_models(loaded_model, opt_mod, [inp]))
torch._dynamo.reset() # force recompiles
torch._inductor.metrics.generated_kernel_count = 0
loaded_model(inp)
self.assertGreater(torch._inductor.metrics.generated_kernel_count, 0)
def test_save_and_load_all_backends(self):
mod = MockModule()
inp = torch.randn(10, 10)
for backend in torch._dynamo.list_backends():
try:
opt_mod = torch.compile(mod, backend=backend)
with tempfile.TemporaryDirectory() as tmpdirname:
torch.save(opt_mod, os.path.join(tmpdirname, "model.pt"))
loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
torch._dynamo.reset() # force recompiles
torch._inductor.metrics.generated_kernel_count = 0
opt_mod(inp)
opt_success = torch._inductor.metrics.generated_kernel_count == 0
torch._dynamo.reset() # force recompiles
torch._inductor.metrics.generated_kernel_count = 0
loaded_model(inp)
loaded_success = torch._inductor.metrics.generated_kernel_count == 0
self.assertEqual(opt_success, loaded_success)
except torch._dynamo.exc.BackendCompilerFailed:
pass
def test_monkeypatching_forward(self):
class FakeModule(torch.nn.Module):
def forward(self, x):

View File: torch/_dynamo/backends/common.py

@@ -14,18 +14,22 @@ from torch.utils._python_dispatch import _disable_current_modes
log = logging.getLogger(__name__)
def aot_autograd(**kwargs):
def compiler_fn(gm: torch.fx.GraphModule, example_inputs):
class AotAutograd:
def __init__(self, **kwargs):
self.__name__ = "compiler_fn"
self.kwargs = kwargs
def __call__(self, gm: torch.fx.GraphModule, example_inputs):
if any(isinstance(x, (list, tuple, dict)) for x in example_inputs):
return flatten_graph_inputs(
gm,
example_inputs,
compiler_fn,
self,
)
# Hack to get around circular import problems with aot_eager_decomp_partition
if callable(kwargs.get("decompositions")):
kwargs["decompositions"] = kwargs["decompositions"]()
if callable(self.kwargs.get("decompositions")):
self.kwargs["decompositions"] = self.kwargs["decompositions"]()
# NB: dont delete counter increment
counters["aot_autograd"]["total"] += 1
@@ -42,10 +46,10 @@ def aot_autograd(**kwargs):
# stop TorchDynamo from trying to compile our generated backwards pass
return disable(disable(bw_compiler)(*args, **kwargs))
bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
kwargs["bw_compiler"] = _wrapped_bw_compiler
kwargs["inference_compiler"] = (
kwargs.get("inference_compiler") or kwargs["fw_compiler"]
bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"]
self.kwargs["bw_compiler"] = _wrapped_bw_compiler
self.kwargs["inference_compiler"] = (
self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"]
)
from functorch.compile import nop
@@ -54,7 +58,7 @@ def aot_autograd(**kwargs):
# debug asserts slow down compile time noticeably,
# So only default them on when the aot_eager backend is used.
if kwargs.get("fw_compiler", None) == nop:
if self.kwargs.get("fw_compiler", None) == nop:
patch_config = patch("functorch.compile.config.debug_assert", True)
else:
patch_config = contextlib.nullcontext()
@@ -62,14 +66,16 @@ def aot_autograd(**kwargs):
try:
# NB: NOT cloned!
with enable_aot_logging(), patch_config:
cg = aot_module_simplified(gm, example_inputs, **kwargs)
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
counters["aot_autograd"]["ok"] += 1
return disable(cg)
except Exception:
counters["aot_autograd"]["not_ok"] += 1
raise
return compiler_fn
def aot_autograd(**kwargs):
return AotAutograd(**kwargs)
def mem_efficient_fusion_kwargs(use_decomps):

View File: torch/_dynamo/convert_frame.py

@@ -361,17 +361,34 @@ def cprofile_wrapper(func):
return profile_wrapper
def convert_frame_assert(
compiler_fn: CompilerFn,
one_graph: bool = True,
export: bool = False,
export_constraints=None,
):
"""Fully convert a frame into an FX graph"""
reset_graph_break_dup_checker()
class ConvertFrameAssert:
def __init__(
self,
compiler_fn: CompilerFn,
one_graph: bool = True,
export: bool = False,
export_constraints=None,
):
reset_graph_break_dup_checker()
self._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined]
self._one_graph = one_graph
self._export = export
self._export_constraints = export_constraints
def _convert_frame_assert(
frame: types.FrameType, cache_entry, hooks: Hooks, frame_state, *, skip: int = 0
@property
def _clone_with_backend(self):
return lambda backend: convert_frame_assert(
backend, self._one_graph, self._export, self._export_constraints
)
def __call__(
self,
frame: types.FrameType,
cache_entry,
hooks: Hooks,
frame_state,
*,
skip: int = 0,
):
increment_frame()
@@ -458,10 +475,10 @@ def convert_frame_assert(
frame.f_globals,
frame.f_locals,
frame.f_builtins,
compiler_fn,
one_graph,
export,
export_constraints,
self._torchdynamo_orig_callable,
self._one_graph,
self._export,
self._export_constraints,
hooks,
cache_entry,
cache_size,
@@ -471,13 +488,15 @@ def convert_frame_assert(
skip=skip + 1,
)
_convert_frame_assert._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined]
def _clone_with_backend(backend):
return convert_frame_assert(backend, one_graph, export, export_constraints)
_convert_frame_assert._clone_with_backend = _clone_with_backend # type: ignore[attr-defined]
return _convert_frame_assert
def convert_frame_assert(
compiler_fn: CompilerFn,
one_graph: bool = True,
export: bool = False,
export_constraints=None,
):
"""Fully convert a frame into an FX graph"""
return ConvertFrameAssert(compiler_fn, one_graph, export, export_constraints)
from collections import OrderedDict
@@ -907,16 +926,27 @@ def _compile(
torch._dynamo.callback_handler.run_end_callbacks()
def convert_frame(compiler_fn: CompilerFn, hooks: Hooks):
"""Try to convert a frame into an FX graph, if error leave frame unmodified"""
inner_convert = convert_frame_assert(compiler_fn, one_graph=False)
class ConvertFrame:
def __init__(self, compiler_fn: CompilerFn, hooks: Hooks):
self._torchdynamo_orig_callable = compiler_fn
self._inner_convert = convert_frame_assert(compiler_fn, one_graph=False)
self._hooks = hooks
def _convert_frame(
frame: types.FrameType, cache_entry, hooks: Hooks, frame_state, skip: int = 0
@property
def _clone_with_backend(self):
return lambda backend: convert_frame(backend, self._hooks)
def __call__(
self,
frame: types.FrameType,
cache_entry,
hooks: Hooks,
frame_state,
skip: int = 0,
):
counters["frames"]["total"] += 1
try:
result = inner_convert(
result = self._inner_convert(
frame, cache_entry, hooks, frame_state, skip=skip + 1
)
counters["frames"]["ok"] += 1
@@ -980,9 +1010,10 @@ def convert_frame(compiler_fn: CompilerFn, hooks: Hooks):
log.warning(error_msg, exc_info=True)
return None
_convert_frame._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined]
_convert_frame._clone_with_backend = lambda backend: convert_frame(backend, hooks) # type: ignore[attr-defined]
return _convert_frame
def convert_frame(compiler_fn: CompilerFn, hooks: Hooks):
"""Try to convert a frame into an FX graph, if error leave frame unmodified"""
return ConvertFrame(compiler_fn, hooks)
# TODO mlazos: add support for same args, or record them
@@ -1023,9 +1054,13 @@ def first_real_inst_idx(code):
raise RuntimeError("RESUME instruction not found in code")
def catch_errors_wrapper(callback, hooks: Hooks):
@functools.wraps(callback)
def catch_errors(frame, cache_entry, frame_state):
class CatchErrorsWrapper:
def __init__(self, callback, hooks):
functools.wraps(callback)(self)
self._torchdynamo_orig_callable = callback
self.hooks = hooks
def __call__(self, frame, cache_entry, frame_state):
assert frame_state is not None
is_skipfile = trace_rules.check(frame.f_code)
@@ -1063,19 +1098,26 @@ def catch_errors_wrapper(callback, hooks: Hooks):
ddp_optimizer = DDPOptimizer(
bucket_bytes_cap=ddp_module.bucket_bytes_cap,
backend_compile_fn=callback._torchdynamo_orig_callable,
backend_compile_fn=self._torchdynamo_orig_callable._torchdynamo_orig_callable,
)
assert hasattr(
callback, "_clone_with_backend"
self._torchdynamo_orig_callable, "_clone_with_backend"
), "DDPOptimizer only supports callback fns that know how to clone themselves."
hijacked_callback = callback._clone_with_backend(
ddp_optimizer.compile_fn,
hijacked_callback = (
self._torchdynamo_orig_callable._clone_with_backend(
ddp_optimizer.compile_fn,
)
)
return hijacked_callback(
frame, cache_entry, self.hooks, frame_state
)
return hijacked_callback(frame, cache_entry, hooks, frame_state)
with compile_lock, _disable_current_modes():
# skip=1: skip this frame
return callback(frame, cache_entry, hooks, frame_state, skip=1)
return self._torchdynamo_orig_callable(
frame, cache_entry, self.hooks, frame_state, skip=1
)
catch_errors._torchdynamo_orig_callable = callback # type: ignore[attr-defined]
return catch_errors
def catch_errors_wrapper(callback, hooks: Hooks):
return CatchErrorsWrapper(callback, hooks)

View File: torch/_dynamo/eval_frame.py

@@ -168,6 +168,9 @@ class OptimizedModule(torch.nn.Module):
self._forward = self.forward
self.forward = self._call_lazy_check
def __reduce__(self):
return (self.__class__, (self._orig_mod, self.dynamo_ctx))
def __getstate__(self):
state = dict(self.__dict__)
state.pop("forward", None)
@@ -273,9 +276,11 @@ class _TorchDynamoContext:
super().__init__()
assert callable(callback) or callback is False or callback is None
self.callback: DynamoCallback = callback
self._backend_ctx_ctor = backend_ctx_ctor
self.prior: Union[Unset, DynamoCallback] = unset
self.first_ctx = first_ctx
self.export = export
self._dynamic = dynamic
self.compiler_config = compiler_config
self.cleanup_fns: List[Callable[[], Any]] = []
self.enter_exit_hooks = []
@@ -379,7 +384,13 @@ class _TorchDynamoContext:
# call to a builtin without a frame for us to capture
fn = external_utils.wrap_inline(fn)
callback = self.callback
def do_nothing(*arg, **kwargs):
pass
if hasattr(self, "callback"):
callback = self.callback
else:
callback = do_nothing
is_jit_tracing = torch._C._is_tracing
is_fx_tracing = torch.fx._symbolic_trace.is_fx_tracing
@@ -522,6 +533,17 @@ class OptimizeContext(_TorchDynamoContext):
self.enter_exit_hooks.append(call_compiled_autograd)
def __reduce__(self):
return (
self.__class__,
(self.callback, self._backend_ctx_ctor, self.first_ctx),
{
"export": self.export,
"dynamic": self._dynamic,
"compiler_config": self.compiler_config,
},
)
class RunOnlyContext(_TorchDynamoContext):
def __init__(self):
@@ -531,6 +553,9 @@ class RunOnlyContext(_TorchDynamoContext):
super().__init__(callback=False, on_enter=on_enter)
def __reduce__(self):
return (self.__class__, ())
class DisableContext(_TorchDynamoContext):
def __init__(self):
@@ -583,6 +608,9 @@ class DisableContext(_TorchDynamoContext):
return _fn
def __reduce__(self):
return (self.__class__, ())
def _optimize_catch_errors(
compile_fn,

View File: torch/_dynamo/repro/after_dynamo.py

@@ -56,19 +56,20 @@ def _accuracy_fails(gm, example_inputs, compiler_fn):
)
def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
"""
A minifier decorator that wraps the TorchDynamo produced Fx graph modules.
As opposed to wrap_compiler_debug, this wrapper intercepts at the
TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to some
level, e.g., it is useful for minifying issues related to Aot Autograd
tracing. If an error is found, we minify and save the minified repro in
repro.tar.gz.
"""
class WrapBackendDebug:
def __init__(self, unconfigured_compiler_fn, compiler_name: str):
functools.wraps(unconfigured_compiler_fn)(self)
self._torchdynamo_orig_callable = unconfigured_compiler_fn # type: ignore[attr-defined]
self._compiler_name = compiler_name
if hasattr(unconfigured_compiler_fn, "__name__"):
self.__name__ = unconfigured_compiler_fn.__name__
if hasattr(unconfigured_compiler_fn, "compiler_name"):
self.__name__ = unconfigured_compiler_fn.compiler_name
if hasattr(unconfigured_compiler_fn, "get_compiler_config"):
self.get_compiler_config = unconfigured_compiler_fn.get_compiler_config # type: ignore[attr-defined]
@functools.wraps(unconfigured_compiler_fn)
def debug_wrapper(gm, example_inputs, **kwargs):
compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs)
def __call__(self, gm, example_inputs, **kwargs):
compiler_fn = functools.partial(self._torchdynamo_orig_callable, **kwargs)
assert config.repro_after in ("dynamo", "aot", None)
if config.repro_after == "dynamo":
@@ -82,7 +83,7 @@ def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
)
if config.repro_level == 3:
dump_to_minify_after_dynamo(gm, example_inputs, compiler_name)
dump_to_minify_after_dynamo(gm, example_inputs, self._compiler_name)
# Check for either accuracy (level 4) or other type of failures.
if config.repro_level == 4:
@@ -95,7 +96,7 @@ def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
dump_to_minify_after_dynamo(
fx.GraphModule(gm, copy.deepcopy(gm.graph)),
example_inputs,
compiler_name,
self._compiler_name,
)
exc = AccuracyError("Bad accuracy detected.")
add_paths(exc)
@@ -110,7 +111,7 @@ def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
)
if config.repro_level == 1:
dump_state_fn = functools.partial(
dump_backend_state, compiler_name=compiler_name
dump_backend_state, compiler_name=self._compiler_name
)
dump_state_fn(
fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs
@@ -119,7 +120,7 @@ def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
dump_to_minify_after_dynamo(
fx.GraphModule(gm, copy.deepcopy(gm.graph)),
example_inputs,
compiler_name,
self._compiler_name,
)
add_paths(exc)
raise
@@ -128,12 +129,17 @@ def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
return compiled_gm
debug_wrapper._torchdynamo_orig_callable = unconfigured_compiler_fn # type: ignore[attr-defined]
if hasattr(unconfigured_compiler_fn, "compiler_name"):
debug_wrapper.__name__ = unconfigured_compiler_fn.compiler_name
if hasattr(unconfigured_compiler_fn, "get_compiler_config"):
debug_wrapper.get_compiler_config = unconfigured_compiler_fn.get_compiler_config # type: ignore[attr-defined]
return debug_wrapper
def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
"""
A minifier decorator that wraps the TorchDynamo produced Fx graph modules.
As opposed to wrap_compiler_debug, this wrapper intercepts at the
TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to some
level, e.g., it is useful for minifying issues related to Aot Autograd
tracing. If an error is found, we minify and save the minified repro in
repro.tar.gz.
"""
return WrapBackendDebug(unconfigured_compiler_fn, compiler_name)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #