mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
## save&load support for OptimizedModule [Issue Description](https://github.com/pytorch/pytorch/pull/101651) English is not my native language; please excuse any typing errors. This PR is based on commit b9588101c4d3411b107fdc860acfa8a72c642f91\ I'll resolve the merge conflicts later. ### test result for test/dynamo Conclusion:\ As far as I can see, it performs the same as before. ENV(CPU only):\ platform linux -- Python 3.10.14, pytest-7.3.2, pluggy-1.5.0\ configfile: pytest.ini\ plugins: anyio-3.7.1, cpp-2.3.0, flakefinder-1.1.0, xdist-3.3.1, xdoctest-1.1.0, metadata-3.1.1, html-4.1.1, hypothesis-5.35.1, rerunfailures-14.0 #### before this PR: [before](https://github.com/pytorch/pytorch/files/15329370/before.md) #### after this PR: [after](https://github.com/pytorch/pytorch/files/15329376/after.md) ### some changes 1. add test_save_and_load to test/dynamo/test_modules.py, with and without "backend='inductor'" 2. add a \_\_reduce\_\_ method to OptimizedModule and to the classes derived from _TorchDynamoContext for pickling & unpickling 3. change the wrapper functions into wrapper classes (including convert_frame_assert, convert_frame, catch_errors_wrapper in torch/_dynamo/convert_frame.py & wrap_backend_debug in torch/_dynamo/repro/after_dynamo.py) 4. change self.output.compiler_fn into innermost_fn(self.output.compiler_fn) in torch/_dynamo/symbolic_convert.py to get the original compiler_fn and to avoid the "compiler_fn is not eager" condition Pull Request resolved: https://github.com/pytorch/pytorch/pull/126374 Approved by: https://github.com/msaroufim, https://github.com/jansel
126 lines
3.7 KiB
Python
126 lines
3.7 KiB
Python
# mypy: ignore-errors
|
|
|
|
import contextlib
|
|
import functools
|
|
import logging
|
|
from unittest.mock import patch
|
|
|
|
import torch
|
|
from torch._dynamo import disable
|
|
from torch._dynamo.utils import counters, defake, flatten_graph_inputs
|
|
from torch._functorch.aot_autograd import aot_module_simplified
|
|
from torch.utils._python_dispatch import _disable_current_modes
|
|
|
|
# Logger for this module, named after the module itself.
log = logging.getLogger(name=__name__)
|
|
|
|
|
|
class AotAutograd:
|
|
def __init__(self, **kwargs):
|
|
self.__name__ = "compiler_fn"
|
|
self.kwargs = kwargs
|
|
|
|
def __call__(self, gm: torch.fx.GraphModule, example_inputs):
|
|
if any(isinstance(x, (list, tuple, dict)) for x in example_inputs):
|
|
return flatten_graph_inputs(
|
|
gm,
|
|
example_inputs,
|
|
self,
|
|
)
|
|
|
|
# Hack to get around circular import problems with aot_eager_decomp_partition
|
|
if callable(self.kwargs.get("decompositions")):
|
|
self.kwargs["decompositions"] = self.kwargs["decompositions"]()
|
|
|
|
# NB: dont delete counter increment
|
|
counters["aot_autograd"]["total"] += 1
|
|
use_fallback = False
|
|
|
|
if use_fallback:
|
|
log.debug("Unable to use AOT Autograd because graph has mutation")
|
|
counters["aot_autograd"]["not_ok"] += 1
|
|
return gm
|
|
|
|
# OK attempt to compile
|
|
|
|
def _wrapped_bw_compiler(*args, **kwargs):
|
|
# stop TorchDynamo from trying to compile our generated backwards pass
|
|
return disable(disable(bw_compiler)(*args, **kwargs))
|
|
|
|
bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"]
|
|
self.kwargs["bw_compiler"] = _wrapped_bw_compiler
|
|
self.kwargs["inference_compiler"] = (
|
|
self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"]
|
|
)
|
|
|
|
from functorch.compile import nop
|
|
|
|
from torch._inductor.debug import enable_aot_logging
|
|
|
|
# debug asserts slow down compile time noticeably,
|
|
# So only default them on when the aot_eager backend is used.
|
|
if self.kwargs.get("fw_compiler", None) == nop:
|
|
patch_config = patch("functorch.compile.config.debug_assert", True)
|
|
else:
|
|
patch_config = contextlib.nullcontext()
|
|
|
|
try:
|
|
# NB: NOT cloned!
|
|
with enable_aot_logging(), patch_config:
|
|
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
|
|
counters["aot_autograd"]["ok"] += 1
|
|
return disable(cg)
|
|
except Exception:
|
|
counters["aot_autograd"]["not_ok"] += 1
|
|
raise
|
|
|
|
|
|
def aot_autograd(**kwargs):
    """Build an ``AotAutograd`` backend from the given keyword arguments.

    Every keyword argument is passed straight through to the
    ``AotAutograd`` constructor.
    """
    backend = AotAutograd(**kwargs)
    return backend
|
|
|
|
|
|
def mem_efficient_fusion_kwargs(use_decomps):
    """Return AOT Autograd keyword arguments mirroring ``memory_efficient_fusion()``.

    Args:
        use_decomps: when truthy, ``default_decompositions`` is also included
            under the ``"decompositions"`` key.

    Returns:
        A new dict that maps ``fw_compiler`` and ``bw_compiler`` to
        ``ts_compile`` and ``partition_fn`` to
        ``min_cut_rematerialization_partition``.
    """
    from functorch.compile import (
        default_decompositions,
        min_cut_rematerialization_partition,
        ts_compile,
    )

    # these are taken from memory_efficient_fusion()
    result = dict(
        fw_compiler=ts_compile,
        bw_compiler=ts_compile,
        partition_fn=min_cut_rematerialization_partition,
    )
    if use_decomps:
        result.update(decompositions=default_decompositions)
    return result
|
|
|
|
|
|
def fake_tensor_unsupported(fn):
|
|
"""
|
|
Decorator for backends that need real inputs. We swap out fake
|
|
tensors for zero tensors.
|
|
"""
|
|
|
|
@functools.wraps(fn)
|
|
def wrapper(model, inputs, **kwargs):
|
|
with _disable_current_modes():
|
|
inputs = list(map(defake, inputs))
|
|
return fn(model, inputs, **kwargs)
|
|
|
|
return wrapper
|
|
|
|
|
|
def device_from_inputs(example_inputs) -> torch.device:
    """Return the ``device`` of the first input that has a ``device`` attribute.

    Inputs without such an attribute are skipped. ``None`` is returned when
    no input has one.
    """
    return next(
        (inp.device for inp in example_inputs if hasattr(inp, "device")), None
    )
|
|
|
|
|
|
def dtype_from_inputs(example_inputs) -> torch.dtype:
    """Return the ``dtype`` of the first input that has a ``dtype`` attribute.

    Inputs without such an attribute are skipped. ``None`` is returned when
    no input has one.
    """
    matches = (inp.dtype for inp in example_inputs if hasattr(inp, "dtype"))
    return next(matches, None)
|