On this line, we see that the bw_compiler that dynamo uses for AotAutograd automatically disables the backward runnable:
05dd638ee9/torch/_dynamo/backends/common.py (L76)
This disables dynamo both on the bw_compiler itself and on the runnable the compiler returns.
On an AOTAutogradCache hit, however, we never call the bw_compiler! So we don't disable dynamo properly. This only has an effect in certain cases of CPU tensor backwards, where the backward runs in Python and dynamo unnecessarily tries to trace through the inductor-generated code. It also only matters if the backward is invoked outside of dynamo itself (say, from a graph break in eager mode), since dynamo already disables the forward function properly.
```
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] TorchDynamo attempted to trace the following frames: [
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] * fn /home/jjwu/test.py:9
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] * cast /data/users/jjwu/a/pytorch-env/lib/python3.10/typing.py:1737
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] * call /tmp/torchinductor_jjwu/rq/crq327nhoyjzog5n3qlchauucdrunrtutwmmoh7ipoe2ngnson5s.py:35
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] * fn /home/jjwu/test.py:9
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] * cast /data/users/jjwu/a/pytorch-env/lib/python3.10/typing.py:1737
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] * call /tmp/torchinductor_jjwu/rq/crq327nhoyjzog5n3qlchauucdrunrtutwmmoh7ipoe2ngnson5s.py:35
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] ]
```
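For illustration only, here is a rough sketch of the kind of setup described above, not the PR's actual unit test; `fn` and the tensor shapes are assumptions. The idea is to compile a function on CPU tensors and invoke the backward from a graph-break region, so dynamo's frame hook is still active while the inductor-generated backward runs in Python; running with AOTAutogradCache enabled and warm is what exposed the missing `disable`, and `TORCH_LOGS=dynamo` shows which frames dynamo attempts to trace, as in the log above.

```python
import torch
import torch._dynamo

def fn(x):
    y = (x * x).sum()
    # Graph break: the rest of fn runs in eager mode, but dynamo's frame hook
    # is still active, so any Python frames entered by the backward get traced.
    torch._dynamo.graph_break()
    # On CPU the inductor-generated backward is plain Python; before this fix,
    # a warm AOTAutogradCache start skipped the bw_compiler and left that code
    # undisabled, so dynamo tried to trace it (the /tmp/torchinductor_... frames above).
    y.backward()
    return x.grad

opt_fn = torch.compile(fn, backend="inductor")
print(opt_fn(torch.randn(4, requires_grad=True)))
```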
This PR fixes the issue and adds a unit test showing that, with or without a cache hit, the frames dynamo traces are identical.
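For intuition only (this is not the actual patch, and `disable_cached_backward` is a made-up helper name), the shape of the fix is to apply on the cache-hit path the same `disable` wrapping that `wrap_bw_compiler` applies on a cache miss:

```python
from torch._dynamo import disable

def disable_cached_backward(cached_bw_callable):
    # Hypothetical helper: on an AOTAutogradCache hit the bw_compiler never runs,
    # so the backward callable loaded from the cache must be wrapped in disable()
    # explicitly, mirroring wrap_bw_compiler's behavior on a cache miss.
    return disable(cached_bw_callable, reason="do not trace generated backwards pass")
```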
Fixes https://github.com/pytorch/pytorch/issues/154536
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155251
Approved by: https://github.com/bdhirsh, https://github.com/anijain2305
168 lines
5.4 KiB
Python
# mypy: ignore-errors

"""
This module provides common utilities and base classes for TorchDynamo backends.

Key components:
- AotAutograd: Base class for implementing AOT (Ahead-of-Time) autograd backends
- Backend utilities for handling:
  - Fake tensor conversion
  - Device/dtype detection from inputs
  - Memory efficient fusion
  - Graph flattening
  - Common compiler configurations

The utilities here are used by various backend implementations to handle
common operations and provide consistent behavior across different backends.
AOT autograd functionality is particularly important as it enables ahead-of-time
optimization of both forward and backward passes.
"""

import contextlib
import functools
import logging
from unittest.mock import patch

import torch
from torch._dynamo import disable
from torch._dynamo.exc import TensorifyScalarRestartAnalysis
from torch._dynamo.utils import counters, defake, flatten_graph_inputs
from torch._functorch.aot_autograd import (
    aot_module_simplified,
    SerializableAOTDispatchCompiler,
)
from torch.utils._python_dispatch import _disable_current_modes


log = logging.getLogger(__name__)

class AotAutograd:
    def __init__(self, **kwargs) -> None:
        self.__name__ = "compiler_fn"
        self.kwargs = kwargs

    def __call__(self, gm: torch.fx.GraphModule, example_inputs, **kwargs):
        if kwargs:
            log.warning("aot_autograd-based backend ignoring extra kwargs %s", kwargs)

        if any(isinstance(x, (list, tuple, dict)) for x in example_inputs):
            return flatten_graph_inputs(
                gm,
                example_inputs,
                self,
            )

        # Hack to get around circular import problems with aot_eager_decomp_partition
        if callable(self.kwargs.get("decompositions")):
            self.kwargs["decompositions"] = self.kwargs["decompositions"]()

        # NB: dont delete counter increment
        counters["aot_autograd"]["total"] += 1
        use_fallback = False

        if use_fallback:
            log.debug("Unable to use AOT Autograd because graph has mutation")
            counters["aot_autograd"]["not_ok"] += 1
            return gm

        def wrap_bw_compiler(bw_compiler_fn):
            def _wrapped_bw_compiler(*args, **kwargs):
                # Note [Wrapping bw_compiler in disable]
                # The two disables here:
                # - stop TorchDynamo from trying to compile the bw_compiler function itself
                # - stop TorchDynamo from trying to compile the generated backwards pass bw_compiler produces
                return disable(
                    disable(
                        bw_compiler_fn, reason="do not trace backward compiler function"
                    )(*args, **kwargs),
                    reason="do not trace generated backwards pass",
                )

            return _wrapped_bw_compiler

        bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"]

        if isinstance(bw_compiler, SerializableAOTDispatchCompiler):
            bw_compiler.compiler_fn = wrap_bw_compiler(bw_compiler.compiler_fn)
        else:
            bw_compiler = wrap_bw_compiler(bw_compiler)

        self.kwargs["bw_compiler"] = bw_compiler
        self.kwargs["inference_compiler"] = (
            self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"]
        )

        from functorch.compile import nop
        from torch._inductor.debug import enable_aot_logging

        # debug asserts slow down compile time noticeably,
        # So only default them on when the aot_eager backend is used.
        if self.kwargs.get("fw_compiler", None) == nop:
            patch_config = patch("functorch.compile.config.debug_assert", True)
        else:
            patch_config = contextlib.nullcontext()

        try:
            # NB: NOT cloned!
            with enable_aot_logging(), patch_config:
                cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
                counters["aot_autograd"]["ok"] += 1
                return disable(cg, reason="do not trace AOT-compiled graph")
        except TensorifyScalarRestartAnalysis:
            raise
        except Exception:
            counters["aot_autograd"]["not_ok"] += 1
            raise

def aot_autograd(**kwargs) -> AotAutograd:
    return AotAutograd(**kwargs)


def mem_efficient_fusion_kwargs(use_decomps):
    from functorch.compile import (
        default_decompositions,
        min_cut_rematerialization_partition,
        ts_compile,
    )

    kwargs = {
        # these are taken from memory_efficient_fusion()
        "fw_compiler": ts_compile,
        "bw_compiler": ts_compile,
        "partition_fn": min_cut_rematerialization_partition,
    }

    if use_decomps:
        kwargs["decompositions"] = default_decompositions

    return kwargs


def fake_tensor_unsupported(fn):
    """
    Decorator for backends that need real inputs. We swap out fake
    tensors for zero tensors.
    """

    @functools.wraps(fn)
    def wrapper(model, inputs, **kwargs):
        with _disable_current_modes():
            inputs = list(map(defake, inputs))
            return fn(model, inputs, **kwargs)

    return wrapper


def device_from_inputs(example_inputs) -> torch.device:
    for x in example_inputs:
        if hasattr(x, "device"):
            return x.device


def dtype_from_inputs(example_inputs) -> torch.dtype:
    for x in example_inputs:
        if hasattr(x, "dtype"):
            return x.dtype
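For orientation, here is a minimal, hedged sketch (not part of the file above) of how a custom dynamo backend is typically assembled from `aot_autograd`; `my_fw_compiler` and `my_backend` are illustrative names, not APIs defined in this module:

```python
import torch
from functorch.compile import make_boxed_func
from torch._dynamo.backends.common import aot_autograd

def my_fw_compiler(gm: torch.fx.GraphModule, example_inputs):
    # Illustrative compiler: return the traced graph unchanged, boxed so that
    # AOT autograd can call it with a flat list of arguments.
    return make_boxed_func(gm.forward)

# aot_autograd() builds an AotAutograd instance; since no bw_compiler is given,
# the forward compiler is reused for the backward (see __call__ above), and the
# compiled result is wrapped in disable() so dynamo does not retrace it.
my_backend = aot_autograd(fw_compiler=my_fw_compiler)

@torch.compile(backend=my_backend)
def f(x):
    return torch.sin(x).sum()

f(torch.randn(8, requires_grad=True)).backward()
```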