Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00.
Internal failed because of torch.deploy issues with disable_dynamo in fx/* and _jit/* files. Removing disable_dynamo for both. Added a comment in the code.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104664
Approved by: https://github.com/wconstab
This commit is contained in:
parent d3589c9456
commit 0444f9f85b
torch/__init__.py

@@ -1336,6 +1336,15 @@ for name in dir(_C._VariableFunctions):
     if not name.startswith("_"):
         __all__.append(name)
 
+################################################################################
+# Import TorchDynamo's lazy APIs to avoid circular dependenices
+################################################################################
+
+# needs to be before from .functional import * to avoid circular dependencies
+from ._compile import _disable_dynamo
+
 ################################################################################
 # Import interface functions defined in Python
 ################################################################################
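This re-export is what the decorator sites later in the diff rely on: they write @torch._disable_dynamo rather than importing torch._compile directly. A quick hedged sanity check (example code, not part of the diff; assumes a PyTorch build that contains this commit):

    import torch
    from torch._compile import _disable_dynamo

    # torch/__init__.py re-exports the helper, so both names are the same object
    assert torch._disable_dynamo is _disable_dynamo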
torch/_compile.py (new file, 30 lines)

@@ -0,0 +1,30 @@
+"""
+APIs related to torch.compile which lazily import torch._dynamo to avoid
+circular dependencies.
+"""
+import functools
+
+
+def _disable_dynamo(fn=None, recursive=True):
+    """
+    This API should be only used inside torch, external users should still use
+    torch._dynamo.disable. The main goal of this API is to avoid circular
+    imports issues that is common while using _dynamo.disable inside torch
+    itself.
+
+    This API avoids it by lazily importing torch._dynamo from the import time to
+    the invocation of the decorated function.
+    """
+    if fn is not None:
+
+        @functools.wraps(fn)
+        def inner(*args, **kwargs):
+            import torch._dynamo
+
+            return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
+
+        return inner
+    else:
+        # decorator usage like @_disable_dynamo(recursive=False). The resulting
+        # object expects the original decorated function as the arg.
+        return functools.partial(_disable_dynamo, recursive=recursive)
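A minimal usage sketch of the new helper in both of its forms (hypothetical function names, not part of the diff; assumes a PyTorch build containing torch/_compile.py):

    import torch
    from torch._compile import _disable_dynamo


    @_disable_dynamo  # bare form: fn is passed directly, recursive defaults to True
    def helper(x):
        return x + 1


    @_disable_dynamo(recursive=False)  # parameterized form: goes through functools.partial
    def shallow_helper(x):
        return x * 2


    # torch._dynamo is only imported once one of these is actually called.
    print(helper(torch.ones(2)), shallow_helper(torch.ones(2)))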
torch/_dynamo/eval_frame.py

@@ -57,7 +57,6 @@ from .utils import compile_times
 log = logging.getLogger(__name__)
 
 from torch._dispatch.python import enable_python_dispatcher
-from torch.fx.experimental import proxy_tensor
 
 always_optimize_code_objects = utils.ExactWeakKeyDictionary()
 null_context = contextlib.nullcontext
@ -1212,33 +1211,23 @@ class TorchPatcher:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@functools.lru_cache(None)
|
@functools.lru_cache(None)
|
||||||
def patch():
|
def patch():
|
||||||
# Disable TorchDynamo on some torch.* compilers generated frames
|
# A better way to disable the following would be decorate the source
|
||||||
|
# functions with @torch._disable_dynamo. However, this causes issues
|
||||||
|
# with torch.deploy internally.
|
||||||
torch.jit.trace = disable(torch.jit.trace)
|
torch.jit.trace = disable(torch.jit.trace)
|
||||||
torch.jit.trace_module = disable(torch.jit.trace_module)
|
torch.jit.trace_module = disable(torch.jit.trace_module)
|
||||||
torch.jit._get_trace_graph = disable(torch.jit._get_trace_graph)
|
torch.jit._get_trace_graph = disable(torch.jit._get_trace_graph)
|
||||||
|
|
||||||
# symbolic_trace creates new frames. We disable Dynamo on such frames
|
|
||||||
torch.fx._symbolic_trace.Tracer.trace = disable(
|
torch.fx._symbolic_trace.Tracer.trace = disable(
|
||||||
torch.fx._symbolic_trace.Tracer.trace
|
torch.fx._symbolic_trace.Tracer.trace
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.onnx.export_to_pretty_string = disable(torch.onnx.export_to_pretty_string)
|
|
||||||
torch.distributions.Distribution.set_default_validate_args(False)
|
torch.distributions.Distribution.set_default_validate_args(False)
|
||||||
|
|
||||||
proxy_tensor.dispatch_trace = disable(proxy_tensor.dispatch_trace)
|
|
||||||
|
|
||||||
optimizers = [
|
optimizers = [
|
||||||
opt
|
opt
|
||||||
for opt in torch.optim.__dict__.values()
|
for opt in torch.optim.__dict__.values()
|
||||||
if inspect.isclass(opt) and issubclass(opt, torch.optim.Optimizer)
|
if inspect.isclass(opt) and issubclass(opt, torch.optim.Optimizer)
|
||||||
]
|
]
|
||||||
|
|
||||||
# disable dynamo for the wrapper that helps give dynamo hints about entering DDP
|
|
||||||
if hasattr(DistributedDataParallel, "_inside_ddp_forward"):
|
|
||||||
DistributedDataParallel._inside_ddp_forward = disable(
|
|
||||||
DistributedDataParallel._inside_ddp_forward, recursive=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Note: this excludes the optimizers that are unsupported in excluded_opts below
|
# Note: this excludes the optimizers that are unsupported in excluded_opts below
|
||||||
from ..optim import (
|
from ..optim import (
|
||||||
adadelta,
|
adadelta,
|
||||||
|
|
@@ -1284,11 +1273,6 @@ class TorchPatcher:
             if opt in excluded_opts:
                 opt.step = disable(opt.step)
 
-            opt.zero_grad = disable(opt.zero_grad)
-            opt.state_dict = disable(opt.state_dict)
-            opt.load_state_dict = disable(opt.load_state_dict)
-            opt.add_param_group = disable(opt.add_param_group)
-
             if hasattr(opt, "_init_group"):
                 opt._init_group = disable(opt._init_group)
 
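The patch() hunks above keep the runtime-wrapping style only for the jit and fx tracing entry points, instead of decorating their source. A simplified sketch of that style (the function name here is a hypothetical stand-in, not part of the diff; assumes any recent torch with torch._dynamo):

    import torch
    import torch._dynamo


    def tracer_entry_point(x):  # stand-in for e.g. torch.jit.trace
        return x * 2


    # Same shape as the lines kept in patch(): rebind the attribute to a
    # Dynamo-disabled wrapper at runtime rather than decorating the source.
    tracer_entry_point = torch._dynamo.disable(tracer_entry_point)

    print(tracer_entry_point(torch.ones(3)))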
torch/fx/experimental/proxy_tensor.py

@@ -459,6 +459,7 @@ class PythonKeyTracer(Tracer):
         return super().create_arg(a)
 
 
+@torch._disable_dynamo
 def dispatch_trace(
         root: Union[torch.nn.Module, Callable],
         tracer: Tracer,
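dispatch_trace is the tracing entry point used internally by make_fx, so a hedged sanity check is simply that proxy tracing still works with the decorator applied (example code, not part of the diff):

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx


    def f(x):
        return torch.sin(x) + x


    # make_fx goes through dispatch_trace; the decorator only changes how Dynamo
    # treats that frame, not the traced graph that comes back.
    gm = make_fx(f)(torch.randn(3))
    print(gm.code)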
torch/nn/parallel/distributed.py

@@ -1344,6 +1344,7 @@ class DistributedDataParallel(Module, Joinable):
     # for the 'module_to_run' underneath
     # see torch._dynamo/eval_frame.py TorchPatcher.patch for more details
     @contextmanager
+    @torch._disable_dynamo(recursive=False)
     def _inside_ddp_forward(self):
         DistributedDataParallel._active_ddp_module = self
         try:
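A hedged illustration of the recursive=False form used on _inside_ddp_forward: only the decorated frame itself is skipped, while functions it calls remain eligible for tracing (outer/inner are hypothetical names, not part of the diff):

    import torch
    from torch._compile import _disable_dynamo


    def inner(x):
        return torch.sin(x) + 1  # still eligible for Dynamo tracing


    @_disable_dynamo(recursive=False)
    def outer(x):
        # this frame is skipped, but the call into inner() is not excluded
        return inner(x) * 2


    print(outer(torch.ones(4)))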
torch/onnx/utils.py

@@ -1231,6 +1231,7 @@ def _model_to_graph(
 
 
 @_beartype.beartype
+@torch._disable_dynamo
 def export_to_pretty_string(
     model,
     args,
torch/optim/optimizer.py

@@ -382,6 +382,7 @@ class Optimizer:
         self._optimizer_step_post_hooks[handle.id] = hook
         return handle
 
+    @torch._disable_dynamo
     def state_dict(self):
         r"""Returns the state of the optimizer as a :class:`dict`.
 
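A small hedged round-trip check covering this method and the load_state_dict change in the next hunk; plain eager usage is unchanged by the decorator (example code, not part of the diff):

    import torch

    p = torch.nn.Parameter(torch.zeros(3))
    opt = torch.optim.Adam([p], lr=1e-3)

    sd = opt.state_dict()      # now decorated with @torch._disable_dynamo
    opt.load_state_dict(sd)    # likewise, see the following hunk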
@@ -439,6 +440,7 @@ class Optimizer:
             else:
                 return value.to(device=param.device)
 
+    @torch._disable_dynamo
     def load_state_dict(self, state_dict):
         r"""Loads the optimizer state.
 
@@ -495,6 +497,7 @@ class Optimizer:
                 update_group(g, ng) for g, ng in zip(groups, saved_groups)]
         self.__setstate__({'state': state, 'param_groups': param_groups})
 
+    @torch._disable_dynamo
     def zero_grad(self, set_to_none: bool = True):
         r"""Resets the gradients of all optimized :class:`torch.Tensor` s.
 
@@ -549,6 +552,7 @@ class Optimizer:
         """
         raise NotImplementedError
 
+    @torch._disable_dynamo
     def add_param_group(self, param_group):
         r"""Add a param group to the :class:`Optimizer` s `param_groups`.
 
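Taken together, the optimizer hunks move the "run these outside Dynamo" rule from TorchPatcher onto the methods themselves. A hedged end-to-end sketch of how that plays out under torch.compile (backend="eager" keeps it dependency-light; none of these names come from the diff):

    import torch

    model = torch.nn.Linear(4, 4)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)


    def train_step(x):
        loss = model(x).sum()
        loss.backward()
        opt.step()
        opt.zero_grad()  # decorated above, so it always runs outside Dynamo
        return loss


    compiled_step = torch.compile(train_step, backend="eager")
    print(compiled_step(torch.randn(2, 4)))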