[compile] Switch off inference mode during compilation (#149321)

This PR does the following (see the sketch below):
* Turns `inference_mode` off and enables `no_grad` in `convert_frame` if `inference_mode` is on globally.
* Turns off `inference_mode` for fake tensor propagation. This ensures that converting a real inference tensor to a fake tensor removes the inference-ness.
* Graph breaks on `is_inference` and `is_inference_mode_enabled`.
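For illustration, a minimal sketch (not part of this diff) of the scenario the first bullet targets: compiling while `torch.inference_mode` is enabled globally. The function `fn` and the `backend="eager"` choice are arbitrary.

import torch

def fn(x):
    return x.sin() + 1

opt_fn = torch.compile(fn, backend="eager")

with torch.inference_mode():
    x = torch.randn(2, 2)
    # With this PR, Dynamo compiles the frame with inference_mode switched off
    # (no_grad is used instead); inference_mode is still respected at runtime.
    out = opt_fn(x)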

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149321
Approved by: https://github.com/jansel, https://github.com/zou3519
Animesh Jain 2025-03-18 15:42:42 -07:00 committed by PyTorch MergeBot
parent 04e251a7dd
commit a3c286677b
14 changed files with 159 additions and 21 deletions

View File

@ -1003,13 +1003,6 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
fn = torch.Tensor.dim
return fn(x + 1)
@make_test
def test_tensor_is_inference(x):
if x.is_inference():
return x + 1
else:
return x - 1
def test_is_inference_recompilation(self):
def fn(x):
if x.is_inference():
@ -1021,8 +1014,7 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
x_inference = torch.randn(2, 2)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)
opt_fn = torch.compile(fn, backend=cnts, fullgraph=False)
x = torch.randn(2, 2)
self.assertEqual(fn(x), opt_fn(x))
@ -1031,6 +1023,21 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
self.assertEqual(fn(x_inference), opt_fn(x_inference))
self.assertEqual(cnts.frame_count, 2) # Recompiles
def test_is_inference_mode_global_recompilation(self):
def fn(x):
if torch.is_inference_mode_enabled():
return x + 1
else:
return x - 1
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch.compile(fn, backend=cnts, fullgraph=False)
x = torch.randn(2, 2)
self.assertEqual(fn(x), opt_fn(x))
self.assertEqual(cnts.frame_count, 1)
@make_test
def test_get_privateuse1_name(x):
if torch._C._get_privateuse1_backend_name() == "privateuseone":

View File

@ -12647,6 +12647,9 @@ class TestAutogradInferenceMode(TestCase):
self.assertFalse(func_out.requires_grad)
self.assertTrue(func_out.is_leaf)
@skipIfTorchDynamo(
"exception from ill-formed graph module is not propagated with eager_noexcept"
)
def test_inference_mode_inf_tensor_in_normal_mode_inplace_op(self):
def run_test(fn):
for requires_grad in (False, True):

View File

@ -534,6 +534,10 @@ fake_tensor_cache_crosscheck_enabled = (
os.environ.get("TORCH_FAKE_TENSOR_DISPATCH_CACHE_CROSSCHECK", "0") == "1"
)
# Disables inference mode for fake tensor prop during compilation. At runtime,
# the inference_mode is still respected.
fake_tensor_disable_inference_mode = True
# Enables the Compiled Autograd engine to trace .backward() calls made under torch.compile().
# Note: AOT Autograd will still trace joint graphs.
compiled_autograd = False
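A small usage sketch (assumed, not from this diff) of the new `fake_tensor_disable_inference_mode` flag added above; it defaults to True, so the toggle below is only needed to restore the previous behavior.

import torch._dynamo.config as dynamo_config

# Flipping the flag back to False restores the old behavior where fake tensor
# prop keeps tracking inference_mode.
dynamo_config.fake_tensor_disable_inference_mode = False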

View File

@ -138,6 +138,8 @@ from .utils import (
is_namedtuple,
istype,
LazyString,
maybe_disable_inference_mode,
maybe_disable_inference_mode_for_fake_prop,
orig_code_map,
reset_graph_break_dup_checker,
setup_compile_debug,
@ -227,11 +229,16 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T:
guards = GlobalStateGuard()
prior_grad_mode = torch.is_grad_enabled()
# Just in case we get left in a bad dispatch state we want to restore
# it. This can happen because the dispatch bits aren't a true
# stack/counter - so we can't just increment/decrement them as we enter
# and leave.
with torch._C._PreserveDispatchKeyGuard():
with (
torch._C._PreserveDispatchKeyGuard(),
maybe_disable_inference_mode(),
maybe_disable_inference_mode_for_fake_prop(),
):
prior_inference_mode = torch.is_inference_mode_enabled()
prior_deterministic = torch.are_deterministic_algorithms_enabled()
prior_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()

View File

@ -19,3 +19,8 @@ SUPPORTABLE = [
CAUSED_BY_EARLIER_GRAPH_BREAK = [
"This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.",
]
INFERENCE_MODE = [
"Avoid using `tensor.is_inference()` and `torch.is_inference_mode_enabled()` in your compile code. "
"This is primarily used in conjunction with `torch.inference_mode`. Consider using `torch.no_grad` instead ",
" because `torch.no_grad` leads to same improvements as `inference_mode` when `torch.compile` is used.",
]
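A hedged sketch of the rewrite the hint suggests: wrap compiled code in `torch.no_grad` rather than `torch.inference_mode`, and avoid branching on inference-ness inside it (the function name `step` is made up).

import torch

@torch.compile
def step(x):
    # Do not branch on x.is_inference() or torch.is_inference_mode_enabled()
    # here; both now cause a graph break.
    return x * 2 + 1

x = torch.randn(8)
# Per the hint above, torch.no_grad gives torch.compile the same benefits as
# inference_mode, without the graph breaks.
with torch.no_grad():
    y = step(x)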

View File

@ -4474,3 +4474,41 @@ def get_optimize_ddp_mode():
f"Invalid dynamo config optimize_ddp value {mode=}"
)
return mode
@contextmanager
def maybe_disable_inference_mode() -> Generator[None, None, None]:
"""
Disables torch.inference_mode during compilation (it is still respected at
runtime). This simplifies the compile stack, where we can assume that
inference_mode is always off.
Since inference_mode is equivalent to no_grad plus a few extra optimizations
(version counter handling, etc.), we turn on no_grad here. The other
optimizations are not relevant to torch.compile.
"""
is_inference_mode_on = (
config.fake_tensor_disable_inference_mode and torch.is_inference_mode_enabled()
)
if is_inference_mode_on:
with (
torch.inference_mode(False),
torch.no_grad(),
):
yield
else:
yield
@contextmanager
def maybe_disable_inference_mode_for_fake_prop() -> Generator[None, None, None]:
"""
Turns off tracking of inference_mode for fake tensor propagation. With this
context manager, when a real tensor is converted to a fake tensor, the fake
tensor loses its inference-ness.
"""
if config.fake_tensor_disable_inference_mode:
with torch._subclasses.meta_utils.disable_inference_mode_for_fake_prop():
yield
else:
yield
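To illustrate the intended semantics of `maybe_disable_inference_mode`, a sketch assuming the default config (where `fake_tensor_disable_inference_mode` is True) and the import path implied by the `convert_frame` hunk above.

import torch
from torch._dynamo.utils import maybe_disable_inference_mode

with torch.inference_mode():
    with maybe_disable_inference_mode():
        # Compile-time view: inference_mode is off and grad mode is off.
        assert not torch.is_inference_mode_enabled()
        assert not torch.is_grad_enabled()
    # Back in the user's context, inference_mode is on again.
    assert torch.is_inference_mode_enabled()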

View File

@ -1795,6 +1795,7 @@ class VariableBuilder:
example_value = wrap_to_fake_tensor_and_record(
value, tx=self.tx, is_tensor=True, source=source
)
tensor_proxy = self.tx.output.root_tracer.create_graph_input(
re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
type(value),
@ -3029,6 +3030,7 @@ def wrap_to_fake_tensor_and_record(
symbolic_context,
type(e),
)
fake_e = wrap_fake_exception(
lambda: tx.fake_mode.from_tensor(
e,

View File

@ -598,10 +598,27 @@ class InferenceModeVariable(ContextWrappingVariable):
)
def enter(self, tx):
disabled_inference_mode_forcibly = False
if (
torch._dynamo.config.fake_tensor_disable_inference_mode
and self.target_values[0]
):
# Do not enter inference mode, because we keep it off during
# compilation. Set grad_enabled to False instead, to reflect the
# relevant part of inference_mode to torch.compile.
disabled_inference_mode_forcibly = True
prior = torch.is_grad_enabled()
torch._C._set_grad_enabled(False)
else:
ctx = torch.autograd.grad_mode._enter_inference_mode(*self.target_values)
self.set_cleanup_hook(
tx, lambda: torch.autograd.grad_mode._exit_inference_mode(ctx)
)
def cleanup_hook():
if disabled_inference_mode_forcibly:
torch._C._set_grad_enabled(prior)
else:
torch.autograd.grad_mode._exit_inference_mode(ctx)
self.set_cleanup_hook(tx, cleanup_hook)
self.state.proxy = tx.output.create_node(
"call_function",
torch.autograd.grad_mode._enter_inference_mode,

View File

@ -43,10 +43,11 @@ from torch.fx.experimental.symbolic_shapes import (
)
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from .. import config, variables
from .. import config, graph_break_hints, variables
from .._trace_wrapped_higher_order_op import trace_wrapped
from ..exc import (
unimplemented,
unimplemented_v2,
UnknownPropertiesDuringBackwardTrace,
UserError,
UserErrorType,
@ -708,6 +709,16 @@ class TensorVariable(VariableTracker):
return ConstantVariable.create(self.dtype.is_floating_point)
def method_is_inference(self):
if config.fake_tensor_disable_inference_mode:
unimplemented_v2(
gb_type="Encountered tensor.is_inference() during tracing",
context="",
explanation="tensor.is_inference() is not supported",
hints=[
*graph_break_hints.FUNDAMENTAL,
*graph_break_hints.INFERENCE_MODE,
],
)
if (fake := self.proxy.node.meta.get("example_value")) is not None:
return ConstantVariable.create(fake.is_inference())
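As a sketch of the resulting behavior (assuming the default config), calling `tensor.is_inference()` under `fullgraph=True` now raises, while `fullgraph=False` falls back to eager for that frame, which is why the test above switches to `fullgraph=False`.

import torch

@torch.compile(fullgraph=True)
def f(x):
    if x.is_inference():
        return x + 1
    return x - 1

try:
    f(torch.randn(2, 2))
except Exception as exc:  # expected: a graph-break error from Dynamo
    print(type(exc).__name__)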

View File

@ -44,7 +44,7 @@ from torch._guards import TracingContext
from torch._logging import warning_once
from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type
from .. import config, polyfills, variables
from .. import config, graph_break_hints, polyfills, variables
from ..codegen import PyCodegen
from ..create_parameter_op import (
can_convert_to_tracable_parameter,
@ -52,7 +52,7 @@ from ..create_parameter_op import (
tracable_create_parameter,
)
from ..device_interface import get_registered_device_interfaces
from ..exc import unimplemented
from ..exc import unimplemented, unimplemented_v2
from ..guards import GuardBuilder, install_guard
from ..source import CallFunctionNoArgsSource, SyntheticLocalSource
from ..utils import (
@ -519,6 +519,18 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
VariableTracker.build(tx, polyfills.radians), args, kwargs
)
@register(torch.is_inference_mode_enabled)
def handle_is_inference_mode_enabled(self, tx: "InstructionTranslator"):
unimplemented_v2(
gb_type="Encountered torch.is_inference_mode_enabled during tracing",
context="",
explanation="torch.is_inference_mode_enabled() is not supported",
hints=[
*graph_break_hints.FUNDAMENTAL,
*graph_break_hints.INFERENCE_MODE,
],
)
@register(torch.is_tensor, torch.overrides.is_tensor_like)
def handle_is_tensor(self, tx: "InstructionTranslator", arg):
if isinstance(arg, TensorVariable) or (

View File

@ -2120,6 +2120,7 @@ class FakeTensorMode(TorchDispatchMode):
and len(flat_arg_fake_tensors) != 0
and not has_symbolic_sizes
and not avoiding_device_init
and func is not aten._nested_tensor_from_tensor_list.default
):
const_flat_args = [
a.constant if self.is_our_fake(a) else a for a in flat_args

View File

@ -3,11 +3,12 @@ from __future__ import annotations
import contextlib
import dataclasses
import functools
import threading
import typing
import warnings
import weakref
from abc import abstractmethod
from contextlib import AbstractContextManager
from contextlib import AbstractContextManager, contextmanager
from dataclasses import dataclass
from typing import (
Any,
@ -45,6 +46,8 @@ from torch.utils.weak import WeakIdKeyDictionary
if TYPE_CHECKING:
from collections.abc import Generator
from torch._C._functorch import CInterpreter
from torch._guards import Source
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
@ -92,6 +95,23 @@ def assert_eq(a: _T, b: _T) -> None:
assert a == b, f"{a} != {b}"
tls = threading.local()
# Turns off inference mode tracking for fake tensor propagation. This is set to
# True only for `torch.compile`. See also
# _dynamo.config.fake_tensor_disable_inference_mode.
tls.disable_inference_mode = False
@contextmanager
def disable_inference_mode_for_fake_prop() -> Generator[None, None, None]:
prior = getattr(tls, "disable_inference_mode", False)
tls.disable_inference_mode = True
try:
yield
finally:
tls.disable_inference_mode = prior
def assert_metadata_eq(
assert_eq: Callable[[object, object], None],
m1: Union[MetaTensorDesc, torch.Tensor],
@ -116,7 +136,10 @@ def assert_metadata_eq(
# MetaTensorDesc doesn't store grad_fn; inferred from leaf
# assert_eq(m1.grad_fn is None, m2.grad_fn is None)
assert_eq(m1.is_sparse, m2.is_sparse)
if not getattr(tls, "disable_inference_mode", False):
assert_eq(m1.is_inference, m2.is_inference())
else:
assert_eq(m1.is_inference, False)
assert_eq(m1.is_conj, m2.is_conj())
assert_eq(m1.is_neg, m2.is_neg())
assert_eq(m1.grad is not None, safe_grad(m2) is not None)
@ -354,10 +377,11 @@ class MetaTensorDescriber:
# TODO: Is it important to enable torch.inference_mode before querying
# these values?
is_inference_mode_disabled = getattr(tls, "disable_inference_mode", False)
r: MetaTensorDesc = MetaTensorDesc(
id=self.get_tensor_id(t),
storage=storage,
is_inference=t.is_inference(),
is_inference=False if is_inference_mode_disabled else t.is_inference(),
is_leaf=is_leaf,
requires_grad=t.requires_grad,
# NB: ndim should be OK too but there is a disaster at
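A sketch (under stated assumptions, not from the diff) of what the TLS flag achieves: inside `disable_inference_mode_for_fake_prop`, a real inference tensor is described with `is_inference=False`, so the resulting fake tensor is expected to drop its inference-ness.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._subclasses.meta_utils import disable_inference_mode_for_fake_prop

with torch.inference_mode():
    real = torch.randn(2, 2)  # a real inference tensor
assert real.is_inference()

mode = FakeTensorMode()
with disable_inference_mode_for_fake_prop():
    fake = mode.from_tensor(real)
    # Expected with this PR: the fake tensor no longer reports inference-ness.
    print(fake.is_inference())  # expected: False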

View File

@ -299,7 +299,14 @@ def foreach_all_gather_copy_out(
out = [t.view(world_size, -1) for t in split_with_sizes_out]
# only avoid VC bump if we are not in inference mode
if torch._dynamo.is_compiling():
# For torch.compile, we turn off inference_mode for fake tensor
# propagation and therefore graph-break on is_inference(). Under compile
# we don't care about VCs, so just skip the optimization.
non_inference_outs = []
else:
non_inference_outs = [o for o in out if not o.is_inference()]
if len(non_inference_outs) > 0:
with torch.autograd._unsafe_preserve_version_counter(tuple(non_inference_outs)):
torch.ops.fsdp.split_with_sizes_copy(