Revert "[dynamo] Lazy disable_dynamo API out-of-dynamo (#104317)"

This reverts commit 5c12a810ac.

Reverted https://github.com/pytorch/pytorch/pull/104317 on behalf of https://github.com/huydhn due to This has been reverted internally by D47166892, so I need to also revert it on OSS to keep them in sync ([comment](https://github.com/pytorch/pytorch/pull/104317#issuecomment-1621099151))
This commit is contained in:
PyTorch MergeBot 2023-07-05 06:21:48 +00:00
parent 40f53912cf
commit 54e320d4d1
9 changed files with 22 additions and 49 deletions

View File

@ -1336,15 +1336,6 @@ for name in dir(_C._VariableFunctions):
if not name.startswith("_"):
__all__.append(name)
################################################################################
# Import TorchDynamo's lazy APIs to avoid circular dependenices
################################################################################
# needs to be before from .functional import * to avoid circular dependencies
from ._compile import _disable_dynamo
################################################################################
# Import interface functions defined in Python
################################################################################

View File

@ -1,30 +0,0 @@
"""
APIs related to torch.compile which lazily import torch._dynamo to avoid
circular dependencies.
"""
import functools
def _disable_dynamo(fn=None, recursive=True):
"""
This API should be only used inside torch, external users should still use
torch._dynamo.disable. The main goal of this API is to avoid circular
imports issues that is common while using _dynamo.disable inside torch
itself.
This API avoids it by lazily importing torch._dynamo from the import time to
the invocation of the decorated function.
"""
if fn is not None:
@functools.wraps(fn)
def inner(*args, **kwargs):
import torch._dynamo
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
return inner
else:
# decorator usage like @_disable_dynamo(recursive=False). The resulting
# object expects the original decorated function as the arg.
return functools.partial(_disable_dynamo, recursive=recursive)

View File

@ -57,6 +57,7 @@ from .utils import compile_times
log = logging.getLogger(__name__)
from torch._dispatch.python import enable_python_dispatcher
from torch.fx.experimental import proxy_tensor
always_optimize_code_objects = utils.ExactWeakKeyDictionary()
null_context = contextlib.nullcontext
@ -1213,15 +1214,31 @@ class TorchPatcher:
def patch():
# Disable TorchDynamo on some torch.* compilers generated frames
torch.jit.trace = disable(torch.jit.trace)
torch.jit.trace_module = disable(torch.jit.trace_module)
torch.jit._get_trace_graph = disable(torch.jit._get_trace_graph)
# symbolic_trace creates new frames. We disable Dynamo on such frames
torch.fx._symbolic_trace.Tracer.trace = disable(
torch.fx._symbolic_trace.Tracer.trace
)
torch.onnx.export_to_pretty_string = disable(torch.onnx.export_to_pretty_string)
torch.distributions.Distribution.set_default_validate_args(False)
proxy_tensor.dispatch_trace = disable(proxy_tensor.dispatch_trace)
optimizers = [
opt
for opt in torch.optim.__dict__.values()
if inspect.isclass(opt) and issubclass(opt, torch.optim.Optimizer)
]
# disable dynamo for the wrapper that helps give dynamo hints about entering DDP
if hasattr(DistributedDataParallel, "_inside_ddp_forward"):
DistributedDataParallel._inside_ddp_forward = disable(
DistributedDataParallel._inside_ddp_forward, recursive=False
)
# Note: this excludes the optimizers that are unsupported in excluded_opts below
from ..optim import (
adadelta,
@ -1269,6 +1286,11 @@ class TorchPatcher:
if opt in excluded_opts:
opt.step = disable(opt.step)
opt.zero_grad = disable(opt.zero_grad)
opt.state_dict = disable(opt.state_dict)
opt.load_state_dict = disable(opt.load_state_dict)
opt.add_param_group = disable(opt.add_param_group)
if hasattr(opt, "_init_group"):
opt._init_group = disable(opt._init_group)

View File

@ -693,7 +693,6 @@ class Tracer(TracerBase):
return root_fn, args
@compatibility(is_backward_compatible=True)
@torch._disable_dynamo
def trace(
self,
root: Union[torch.nn.Module, Callable[..., Any]],

View File

@ -459,7 +459,6 @@ class PythonKeyTracer(Tracer):
return super().create_arg(a)
@torch._disable_dynamo
def dispatch_trace(
root: Union[torch.nn.Module, Callable],
tracer: Tracer,

View File

@ -902,7 +902,6 @@ def trace(
_trace_module_map: Optional[Dict[Any, Any]] = None
@torch._disable_dynamo
def trace_module(
mod,
inputs,
@ -1230,7 +1229,6 @@ def _script_if_tracing(fn):
return wrapper
@torch._disable_dynamo
def _get_trace_graph(f, args=(), kwargs=None, strict=True, _force_outplace=False,
return_inputs=False, _return_inputs_states=False):
"""

View File

@ -1344,7 +1344,6 @@ class DistributedDataParallel(Module, Joinable):
# for the 'module_to_run' underneath
# see torch._dynamo/eval_frame.py TorchPatcher.patch for more details
@contextmanager
@torch._disable_dynamo(recursive=False)
def _inside_ddp_forward(self):
DistributedDataParallel._active_ddp_module = self
try:

View File

@ -1225,7 +1225,6 @@ def _model_to_graph(
@_beartype.beartype
@torch._disable_dynamo
def export_to_pretty_string(
model,
args,

View File

@ -376,7 +376,6 @@ class Optimizer:
self._optimizer_step_post_hooks[handle.id] = hook
return handle
@torch._disable_dynamo
def state_dict(self):
r"""Returns the state of the optimizer as a :class:`dict`.
@ -429,7 +428,6 @@ class Optimizer:
return value.to(device=param.device)
return value
@torch._disable_dynamo
def load_state_dict(self, state_dict):
r"""Loads the optimizer state.
@ -486,7 +484,6 @@ class Optimizer:
update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.__setstate__({'state': state, 'param_groups': param_groups})
@torch._disable_dynamo
def zero_grad(self, set_to_none: bool = True):
r"""Resets the gradients of all optimized :class:`torch.Tensor` s.
@ -541,7 +538,6 @@ class Optimizer:
"""
raise NotImplementedError
@torch._disable_dynamo
def add_param_group(self, param_group):
r"""Add a param group to the :class:`Optimizer` s `param_groups`.