pytorch/torch/_dynamo/backends/common.py
James Wu be2ad70cfa Fix dynamo tracing into AOTAutogradCache results in cpu tensors (#155251)
On this line, we see that the bw_compiler that dynamo uses for AotAutograd automatically disables the backward runnable:
05dd638ee9/torch/_dynamo/backends/common.py (L76)
This disables dynamo in the bw_compiler but also disables the runnable the compiler returns.

On an AOTAutogradCache hit, however, we never call the bw_compiler! So we don't disable dynamo properly. This only has an effect on certain cases of cpu tensors' backwards, where the backward is being done in python land, and dynamo unnecessarily tries to trace through the inductor generated code. It also only matters if the backward is being accessed outside of dynamo itself (say, in a graph break in eager mode), since dynamo properly disables the forward function already.

```
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] TorchDynamo attempted to trace the following frames: [
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * fn /home/jjwu/test.py:9
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * cast /data/users/jjwu/a/pytorch-env/lib/python3.10/typing.py:1737
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * call /tmp/torchinductor_jjwu/rq/crq327nhoyjzog5n3qlchauucdrunrtutwmmoh7ipoe2ngnson5s.py:35
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * fn /home/jjwu/test.py:9
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * cast /data/users/jjwu/a/pytorch-env/lib/python3.10/typing.py:1737
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * call /tmp/torchinductor_jjwu/rq/crq327nhoyjzog5n3qlchauucdrunrtutwmmoh7ipoe2ngnson5s.py:35
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] ]

```

This PR fixes the issue and adds a unit test showing that, with or without a cache hit, the frames dynamo traces are identical.

Fixes https://github.com/pytorch/pytorch/issues/154536

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155251
Approved by: https://github.com/bdhirsh, https://github.com/anijain2305
2025-06-09 02:06:16 +00:00

168 lines
5.4 KiB
Python

# mypy: ignore-errors
"""
This module provides common utilities and base classes for TorchDynamo backends.
Key components:
- AotAutograd: Base class for implementing AOT (Ahead-of-Time) autograd backends
- Backend utilities for handling:
  - Fake tensor conversion
  - Device/dtype detection from inputs
  - Memory efficient fusion
  - Graph flattening
- Common compiler configurations
The utilities here are used by various backend implementations to handle
common operations and provide consistent behavior across different backends.
AOT autograd functionality is particularly important as it enables ahead-of-time
optimization of both forward and backward passes.
"""
import contextlib
import functools
import logging
from unittest.mock import patch
import torch
from torch._dynamo import disable
from torch._dynamo.exc import TensorifyScalarRestartAnalysis
from torch._dynamo.utils import counters, defake, flatten_graph_inputs
from torch._functorch.aot_autograd import (
aot_module_simplified,
SerializableAOTDispatchCompiler,
)
from torch.utils._python_dispatch import _disable_current_modes
log = logging.getLogger(__name__)
class AotAutograd:
    """Callable Dynamo backend that compiles a graph via ``aot_module_simplified``.

    The keyword arguments given at construction time are stored in
    ``self.kwargs`` and forwarded to ``aot_module_simplified`` on every call.
    ``fw_compiler`` is used as the fallback for ``bw_compiler`` and
    ``inference_compiler`` when those entries are missing or falsy.

    Note that ``__call__`` mutates ``self.kwargs`` in place: a callable
    ``decompositions`` entry is replaced by its result, and the resolved
    ``bw_compiler`` / ``inference_compiler`` are written back.
    """

    def __init__(self, **kwargs) -> None:
        self.__name__ = "compiler_fn"
        self.kwargs = kwargs

    def __call__(self, gm: torch.fx.GraphModule, example_inputs, **kwargs):
        """Compile ``gm`` for ``example_inputs`` and return a dynamo-disabled callable.

        Args:
            gm: The FX graph module to compile.
            example_inputs: Example inputs for the graph. If any of them is a
                list, tuple or dict, compilation is delegated to
                ``flatten_graph_inputs`` with this object as the compiler.
            **kwargs: Ignored (a warning is logged when any are passed).

        Returns:
            The result of ``aot_module_simplified`` wrapped in ``disable``, or
            whatever ``flatten_graph_inputs`` returns for nested inputs.

        Raises:
            TensorifyScalarRestartAnalysis: Re-raised untouched (not counted
                as a failure).
            Exception: Any other compilation error is re-raised after the
                ``not_ok`` counter is incremented.
        """
        if kwargs:
            log.warning("aot_autograd-based backend ignoring extra kwargs %s", kwargs)

        if any(isinstance(x, (list, tuple, dict)) for x in example_inputs):
            return flatten_graph_inputs(
                gm,
                example_inputs,
                self,
            )

        # Hack to get around circular import problems with aot_eager_decomp_partition
        if callable(self.kwargs.get("decompositions")):
            self.kwargs["decompositions"] = self.kwargs["decompositions"]()

        # NB: dont delete counter increment
        counters["aot_autograd"]["total"] += 1

        def wrap_bw_compiler(bw_compiler_fn):
            # The wrapped compiler is stored back into self.kwargs (or onto a
            # shared SerializableAOTDispatchCompiler), so this code sees its own
            # wrapper again on the next graph. Without this guard every call
            # would add another layer of disable() wrappers.
            if getattr(bw_compiler_fn, "_is_wrapped_bw_compiler", False):
                return bw_compiler_fn

            def _wrapped_bw_compiler(*args, **kwargs):
                # Note [Wrapping bw_compiler in disable]
                # The two disables here:
                # - stop TorchDynamo from trying to compile the bw_compiler function itself
                # - stop TorchDynamo from trying to compile the generated backwards pass bw_compiler produces
                return disable(
                    disable(
                        bw_compiler_fn, reason="do not trace backward compiler function"
                    )(*args, **kwargs),
                    reason="do not trace generated backwards pass",
                )

            _wrapped_bw_compiler._is_wrapped_bw_compiler = True
            return _wrapped_bw_compiler

        bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"]

        if isinstance(bw_compiler, SerializableAOTDispatchCompiler):
            # Keep the serializable compiler object and only swap the function
            # it delegates to.
            bw_compiler.compiler_fn = wrap_bw_compiler(bw_compiler.compiler_fn)
        else:
            bw_compiler = wrap_bw_compiler(bw_compiler)
        self.kwargs["bw_compiler"] = bw_compiler
        self.kwargs["inference_compiler"] = (
            self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"]
        )

        from functorch.compile import nop
        from torch._inductor.debug import enable_aot_logging

        # debug asserts slow down compile time noticeably,
        # So only default them on when the aot_eager backend is used.
        if self.kwargs.get("fw_compiler", None) == nop:
            patch_config = patch("functorch.compile.config.debug_assert", True)
        else:
            patch_config = contextlib.nullcontext()

        try:
            # NB: NOT cloned!
            with enable_aot_logging(), patch_config:
                cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
                counters["aot_autograd"]["ok"] += 1
                return disable(cg, reason="do not trace AOT-compiled graph")
        except TensorifyScalarRestartAnalysis:
            raise
        except Exception:
            counters["aot_autograd"]["not_ok"] += 1
            raise
def aot_autograd(**kwargs) -> AotAutograd:
    """Build an ``AotAutograd`` backend configured with the given keyword arguments."""
    backend = AotAutograd(**kwargs)
    return backend
def mem_efficient_fusion_kwargs(use_decomps):
    """Return the keyword arguments for a memory-efficient-fusion style backend.

    The result always contains ``fw_compiler``, ``bw_compiler`` (both
    ``ts_compile``) and ``partition_fn``; ``decompositions`` is added only
    when ``use_decomps`` is truthy.
    """
    from functorch.compile import (
        default_decompositions,
        min_cut_rematerialization_partition,
        ts_compile,
    )

    # these are taken from memory_efficient_fusion()
    fusion_kwargs = dict(
        fw_compiler=ts_compile,
        bw_compiler=ts_compile,
        partition_fn=min_cut_rematerialization_partition,
    )
    if not use_decomps:
        return fusion_kwargs

    fusion_kwargs["decompositions"] = default_decompositions
    return fusion_kwargs
def fake_tensor_unsupported(fn):
"""
Decorator for backends that need real inputs. We swap out fake
tensors for zero tensors.
"""
@functools.wraps(fn)
def wrapper(model, inputs, **kwargs):
with _disable_current_modes():
inputs = list(map(defake, inputs))
return fn(model, inputs, **kwargs)
return wrapper
def device_from_inputs(example_inputs) -> torch.device:
    """Return the ``device`` of the first input that has a ``device`` attribute.

    Falls through to ``None`` when no input exposes one.
    """
    devices = (item.device for item in example_inputs if hasattr(item, "device"))
    return next(devices, None)
def dtype_from_inputs(example_inputs) -> torch.dtype:
    """Return the ``dtype`` of the first input that has a ``dtype`` attribute.

    Falls through to ``None`` when no input exposes one.
    """
    dtypes = (item.dtype for item in example_inputs if hasattr(item, "dtype"))
    return next(dtypes, None)