[compile] Switch off inference mode during compilation (#149321)
This PR does the following:

* Turns `inference_mode` off and enables `no_grad` in `convert_frame` if inference_mode is on globally.
* Turns off inference_mode for fake tensor propagation. This ensures that converting a real inference tensor to a fake tensor removes the inference-ness.
* Graph breaks on `is_inference` and `is_inference_mode_enabled`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149321
Approved by: https://github.com/jansel, https://github.com/zou3519
parent 04e251a7dd
commit a3c286677b
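As a quick illustration of the intended behavior (a minimal sketch, not taken from the PR's test suite): with this change, `torch.compile` can trace a function even while the process is globally under `torch.inference_mode`; tracing runs with inference mode switched off, and the compiled code still respects inference mode at runtime.

import torch

@torch.compile(backend="eager")
def f(x):
    return x * 2 + 1

with torch.inference_mode():
    x = torch.randn(4)
    y = f(x)                 # compilation happens here, with inference mode
                             # temporarily switched off by Dynamo
    print(y.is_inference())  # expected True: runtime still honors inference_mode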
@@ -1003,13 +1003,6 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
         fn = torch.Tensor.dim
         return fn(x + 1)

-    @make_test
-    def test_tensor_is_inference(x):
-        if x.is_inference():
-            return x + 1
-        else:
-            return x - 1
-
     def test_is_inference_recompilation(self):
         def fn(x):
             if x.is_inference():
@@ -1021,8 +1014,7 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
             x_inference = torch.randn(2, 2)

         cnts = torch._dynamo.testing.CompileCounter()
-        opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)
-
+        opt_fn = torch.compile(fn, backend=cnts, fullgraph=False)
         x = torch.randn(2, 2)

         self.assertEqual(fn(x), opt_fn(x))
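The switch from `fullgraph=True` to `fullgraph=False` above reflects the new graph break on `is_inference()`. A hedged sketch of what a user would observe (the exact exception type and message depend on the Dynamo version):

import torch

def fn(x):
    # Per this PR, tensor.is_inference() graph-breaks during tracing,
    # so compiling with fullgraph=True is expected to fail.
    return x + 1 if x.is_inference() else x - 1

try:
    torch.compile(fn, fullgraph=True)(torch.randn(2, 2))
except Exception as err:
    print(f"graph break surfaced as: {type(err).__name__}")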
@@ -1031,6 +1023,21 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(fn(x_inference), opt_fn(x_inference))
         self.assertEqual(cnts.frame_count, 2)  # Recompiles

+    def test_is_inference_mode_global_recompilation(self):
+        def fn(x):
+            if torch.is_inference_mode_enabled():
+                return x + 1
+            else:
+                return x - 1
+
+        cnts = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch.compile(fn, backend=cnts, fullgraph=False)
+
+        x = torch.randn(2, 2)
+
+        self.assertEqual(fn(x), opt_fn(x))
+        self.assertEqual(cnts.frame_count, 1)
+
     @make_test
     def test_get_privateuse1_name(x):
         if torch._C._get_privateuse1_backend_name() == "privateuseone":
@@ -12647,6 +12647,9 @@ class TestAutogradInferenceMode(TestCase):
         self.assertFalse(func_out.requires_grad)
         self.assertTrue(func_out.is_leaf)

+    @skipIfTorchDynamo(
+        "exception from ill-formed graph module is not propagated with eager_noexcept"
+    )
     def test_inference_mode_inf_tensor_in_normal_mode_inplace_op(self):
         def run_test(fn):
             for requires_grad in (False, True):
@@ -534,6 +534,10 @@ fake_tensor_cache_crosscheck_enabled = (
     os.environ.get("TORCH_FAKE_TENSOR_DISPATCH_CACHE_CROSSCHECK", "0") == "1"
 )

+# Disables inference mode for fake tensor prop during compilation. At runtime,
+# the inference_mode is still respected.
+fake_tensor_disable_inference_mode = True
+
 # Enables the Compiled Autograd engine to trace .backward() calls made under torch.compile().
 # Note: AOT Autograd will still trace joint graphs.
 compiled_autograd = False
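The flag defaults to True. For debugging, it can presumably be flipped like any other Dynamo config knob; the snippet below is a sketch and assumes you want the old behavior (inference-mode tracking during compilation) back:

import torch

print(torch._dynamo.config.fake_tensor_disable_inference_mode)   # True by default
torch._dynamo.config.fake_tensor_disable_inference_mode = False  # debugging only
torch._dynamo.reset()  # drop previously compiled code so the change takes effect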
@@ -138,6 +138,8 @@ from .utils import (
     is_namedtuple,
     istype,
     LazyString,
+    maybe_disable_inference_mode,
+    maybe_disable_inference_mode_for_fake_prop,
     orig_code_map,
     reset_graph_break_dup_checker,
     setup_compile_debug,
@@ -227,11 +229,16 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
     def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T:
         guards = GlobalStateGuard()
         prior_grad_mode = torch.is_grad_enabled()

         # Just in case we get left in a bad dispatch state we want to restore
         # it. This can happen because the dispatch bits aren't a true
         # stack/counter - so we can't just increment/decrement them as we enter
         # and leave.
-        with torch._C._PreserveDispatchKeyGuard():
+        with (
+            torch._C._PreserveDispatchKeyGuard(),
+            maybe_disable_inference_mode(),
+            maybe_disable_inference_mode_for_fake_prop(),
+        ):
             prior_inference_mode = torch.is_inference_mode_enabled()
             prior_deterministic = torch.are_deterministic_algorithms_enabled()
             prior_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
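Outside of Dynamo, the same save-and-restore idea can be sketched with `contextlib.ExitStack`. This is a simplified stand-in for `preserve_global_state`, not its actual implementation:

import contextlib
import functools

import torch


def run_with_inference_mode_off(fn):
    # Simplified sketch: if the caller is under torch.inference_mode, run `fn`
    # with inference mode off and grad disabled, then restore the prior state.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        with contextlib.ExitStack() as stack:
            if torch.is_inference_mode_enabled():
                stack.enter_context(torch.inference_mode(False))
                stack.enter_context(torch.no_grad())
            return fn(*args, **kwargs)

    return wrapper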
@@ -19,3 +19,8 @@ SUPPORTABLE = [
 CAUSED_BY_EARLIER_GRAPH_BREAK = [
     "This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.",
 ]
+INFERENCE_MODE = [
+    "Avoid using `tensor.is_inference()` and `torch.is_inference_mode_enabled()` in your compile code. "
+    "This is primarily used in conjunction with `torch.inference_mode`. Consider using `torch.no_grad` instead, "
+    "because `torch.no_grad` leads to the same improvements as `inference_mode` when `torch.compile` is used.",
+]
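The hint's suggested rewrite, as a small sketch: wrap the call site in `torch.no_grad` rather than `torch.inference_mode` when the function will be compiled.

import torch

@torch.compile
def step(model, x):
    return model(x)

model = torch.nn.Linear(8, 8)
x = torch.randn(2, 8)

with torch.no_grad():  # compile-friendly replacement for torch.inference_mode()
    out = step(model, x)

print(out.requires_grad)  # False, as it would be under inference_mode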
@@ -4474,3 +4474,41 @@ def get_optimize_ddp_mode():
             f"Invalid dynamo config optimize_ddp value {mode=}"
         )
     return mode
+
+
+@contextmanager
+def maybe_disable_inference_mode() -> Generator[None, None, None]:
+    """
+    Disables torch.inference_mode for the compilation (it is still respected at
+    runtime). This simplifies the compile stack, where we can assume that
+    inference_mode is always off.
+
+    Since inference_mode is equivalent to no_grad plus some extra optimizations
+    (version counters, etc.), we turn on no_grad here. The other optimizations
+    are not relevant to torch.compile.
+    """
+    is_inference_mode_on = (
+        config.fake_tensor_disable_inference_mode and torch.is_inference_mode_enabled()
+    )
+    if is_inference_mode_on:
+        with (
+            torch.inference_mode(False),
+            torch.no_grad(),
+        ):
+            yield
+    else:
+        yield
+
+
+@contextmanager
+def maybe_disable_inference_mode_for_fake_prop() -> Generator[None, None, None]:
+    """
+    Turns off tracking of inference_mode for fake tensor propagation. With this
+    context manager, when a real tensor is converted to a fake tensor, the fake
+    tensor loses its inference-ness.
+    """
+    if config.fake_tensor_disable_inference_mode:
+        with torch._subclasses.meta_utils.disable_inference_mode_for_fake_prop():
+            yield
+    else:
+        yield
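A small sketch of the state these helpers establish: inside `torch.inference_mode(False)` combined with `torch.no_grad()`, inference mode reads as off while gradients stay disabled, which is the state the compile stack assumes.

import torch

with torch.inference_mode():
    assert torch.is_inference_mode_enabled()
    # Roughly what maybe_disable_inference_mode() sets up for compilation:
    with torch.inference_mode(False), torch.no_grad():
        assert not torch.is_inference_mode_enabled()
        assert not torch.is_grad_enabled()
    # On exit, the outer inference mode is active again.
    assert torch.is_inference_mode_enabled()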
@@ -1795,6 +1795,7 @@ class VariableBuilder:
             example_value = wrap_to_fake_tensor_and_record(
                 value, tx=self.tx, is_tensor=True, source=source
             )
+
             tensor_proxy = self.tx.output.root_tracer.create_graph_input(
                 re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
                 type(value),

@@ -3029,6 +3030,7 @@ def wrap_to_fake_tensor_and_record(
                 symbolic_context,
                 type(e),
             )
+
         fake_e = wrap_fake_exception(
             lambda: tx.fake_mode.from_tensor(
                 e,
@@ -598,10 +598,27 @@ class InferenceModeVariable(ContextWrappingVariable):
         )

     def enter(self, tx):
-        ctx = torch.autograd.grad_mode._enter_inference_mode(*self.target_values)
-        self.set_cleanup_hook(
-            tx, lambda: torch.autograd.grad_mode._exit_inference_mode(ctx)
-        )
+        disabled_inference_mode_forcibly = False
+        if (
+            torch._dynamo.config.fake_tensor_disable_inference_mode
+            and self.target_values[0]
+        ):
+            # Do not set the inference mode because we keep it off during
+            # compilation. Set the grad_enabled to False to reflect the relevant
+            # part of inference_mode to torch.compile.
+            disabled_inference_mode_forcibly = True
+            prior = torch.is_grad_enabled()
+            torch._C._set_grad_enabled(False)
+        else:
+            ctx = torch.autograd.grad_mode._enter_inference_mode(*self.target_values)
+
+        def cleanup_hook():
+            if disabled_inference_mode_forcibly:
+                torch._C._set_grad_enabled(prior)
+            else:
+                torch.autograd.grad_mode._exit_inference_mode(ctx)
+
+        self.set_cleanup_hook(tx, cleanup_hook)
         self.state.proxy = tx.output.create_node(
             "call_function",
             torch.autograd.grad_mode._enter_inference_mode,
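With this change, a `with torch.inference_mode():` block inside compiled code is traced as a grad-disabled region rather than a true inference-mode region. A hedged sketch (backend="eager" is used only to keep the example light):

import torch

@torch.compile(backend="eager")
def g(x):
    with torch.inference_mode():
        # Traced with grad disabled instead of inference mode (per the change
        # above); the numerics are unchanged.
        y = x * 2
    return y + 1

print(g(torch.randn(3)))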
@@ -43,10 +43,11 @@ from torch.fx.experimental.symbolic_shapes import (
 )
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass

-from .. import config, variables
+from .. import config, graph_break_hints, variables
 from .._trace_wrapped_higher_order_op import trace_wrapped
 from ..exc import (
     unimplemented,
+    unimplemented_v2,
     UnknownPropertiesDuringBackwardTrace,
     UserError,
     UserErrorType,

@@ -708,6 +709,16 @@ class TensorVariable(VariableTracker):
             return ConstantVariable.create(self.dtype.is_floating_point)

     def method_is_inference(self):
+        if config.fake_tensor_disable_inference_mode:
+            unimplemented_v2(
+                gb_type="Encountered tensor.is_inference() during tracing",
+                context="",
+                explanation="tensor.is_inference() is not supported",
+                hints=[
+                    *graph_break_hints.FUNDAMENTAL,
+                    *graph_break_hints.INFERENCE_MODE,
+                ],
+            )
         if (fake := self.proxy.node.meta.get("example_value")) is not None:
             return ConstantVariable.create(fake.is_inference())

@@ -44,7 +44,7 @@ from torch._guards import TracingContext
 from torch._logging import warning_once
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type

-from .. import config, polyfills, variables
+from .. import config, graph_break_hints, polyfills, variables
 from ..codegen import PyCodegen
 from ..create_parameter_op import (
     can_convert_to_tracable_parameter,

@@ -52,7 +52,7 @@ from ..create_parameter_op import (
     tracable_create_parameter,
 )
 from ..device_interface import get_registered_device_interfaces
-from ..exc import unimplemented
+from ..exc import unimplemented, unimplemented_v2
 from ..guards import GuardBuilder, install_guard
 from ..source import CallFunctionNoArgsSource, SyntheticLocalSource
 from ..utils import (

@@ -519,6 +519,18 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
                 VariableTracker.build(tx, polyfills.radians), args, kwargs
             )

+        @register(torch.is_inference_mode_enabled)
+        def handle_is_inference_mode_enabled(self, tx: "InstructionTranslator"):
+            unimplemented_v2(
+                gb_type="Encountered torch.is_inference_mode_enabled during tracing",
+                context="",
+                explanation="torch.is_inference_mode_enabled() is not supported",
+                hints=[
+                    *graph_break_hints.FUNDAMENTAL,
+                    *graph_break_hints.INFERENCE_MODE,
+                ],
+            )
+
         @register(torch.is_tensor, torch.overrides.is_tensor_like)
         def handle_is_tensor(self, tx: "InstructionTranslator", arg):
             if isinstance(arg, TensorVariable) or (
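Because `torch.is_inference_mode_enabled()` now graph-breaks, one workaround (a sketch, not part of the PR) is to resolve the flag outside the compiled region and pass it in as a plain bool that Dynamo can guard on:

import torch

def fn(x, inference_like: bool):
    # The bool is a constant to Dynamo, so no graph break is needed here.
    return x + 1 if inference_like else x - 1

opt_fn = torch.compile(fn, fullgraph=True)
x = torch.randn(2, 2)
out = opt_fn(x, torch.is_inference_mode_enabled())  # query happens outside the graph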
@@ -2120,6 +2120,7 @@ class FakeTensorMode(TorchDispatchMode):
             and len(flat_arg_fake_tensors) != 0
             and not has_symbolic_sizes
             and not avoiding_device_init
             and func is not aten._nested_tensor_from_tensor_list.default
         ):
             const_flat_args = [
                 a.constant if self.is_our_fake(a) else a for a in flat_args
@@ -3,11 +3,12 @@ from __future__ import annotations
 import contextlib
 import dataclasses
 import functools
+import threading
 import typing
 import warnings
 import weakref
 from abc import abstractmethod
-from contextlib import AbstractContextManager
+from contextlib import AbstractContextManager, contextmanager
 from dataclasses import dataclass
 from typing import (
     Any,
@@ -45,6 +46,8 @@ from torch.utils.weak import WeakIdKeyDictionary


 if TYPE_CHECKING:
+    from collections.abc import Generator
+
     from torch._C._functorch import CInterpreter
     from torch._guards import Source
     from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
@@ -92,6 +95,23 @@ def assert_eq(a: _T, b: _T) -> None:
     assert a == b, f"{a} != {b}"


+tls = threading.local()
+# Turns off inference mode for fake tensor propagation. This is set to True
+# only for `torch.compile`. Also look at
+# _dynamo.config.fake_tensor_disable_inference_mode
+tls.disable_inference_mode = False
+
+
+@contextmanager
+def disable_inference_mode_for_fake_prop() -> Generator[None, None, None]:
+    prior = getattr(tls, "disable_inference_mode", False)
+    tls.disable_inference_mode = True
+    try:
+        yield
+    finally:
+        tls.disable_inference_mode = prior
+
+
 def assert_metadata_eq(
     assert_eq: Callable[[object, object], None],
     m1: Union[MetaTensorDesc, torch.Tensor],
@@ -116,7 +136,10 @@ def assert_metadata_eq(
     # MetaTensorDesc doesn't store grad_fn; inferred from leaf
     # assert_eq(m1.grad_fn is None, m2.grad_fn is None)
     assert_eq(m1.is_sparse, m2.is_sparse)
-    assert_eq(m1.is_inference, m2.is_inference())
+    if not getattr(tls, "disable_inference_mode", False):
+        assert_eq(m1.is_inference, m2.is_inference())
+    else:
+        assert_eq(m1.is_inference, False)
     assert_eq(m1.is_conj, m2.is_conj())
     assert_eq(m1.is_neg, m2.is_neg())
     assert_eq(m1.grad is not None, safe_grad(m2) is not None)
@@ -354,10 +377,11 @@ class MetaTensorDescriber:

         # TODO: Is it important to enable torch.inference_mode before querying
         # these values?
+        is_inference_mode_disabled = getattr(tls, "disable_inference_mode", False)
         r: MetaTensorDesc = MetaTensorDesc(
             id=self.get_tensor_id(t),
             storage=storage,
-            is_inference=t.is_inference(),
+            is_inference=False if is_inference_mode_disabled else t.is_inference(),
             is_leaf=is_leaf,
             requires_grad=t.requires_grad,
             # NB: ndim should be OK too but there is a disaster at
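A rough sketch of the effect on fake-tensor conversion (it relies on the internal APIs shown in this diff, so details may shift): inside `disable_inference_mode_for_fake_prop()`, a fake tensor built from a real inference tensor is expected to report `is_inference() == False`.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._subclasses.meta_utils import disable_inference_mode_for_fake_prop

with torch.inference_mode():
    real = torch.randn(2, 2)  # a real inference tensor
assert real.is_inference()

fake_mode = FakeTensorMode()
with disable_inference_mode_for_fake_prop():
    fake = fake_mode.from_tensor(real)

print(fake.is_inference())  # expected False while the flag is active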
@@ -299,7 +299,14 @@ def foreach_all_gather_copy_out(
     out = [t.view(world_size, -1) for t in split_with_sizes_out]

-    non_inference_outs = [o for o in out if not o.is_inference()]
+    # only avoid VC bump if we are not in inference mode
+    if torch._dynamo.is_compiling():
+        # For torch.compile, we turn off inference_mode for fake tensor
+        # propagation, and therefore graph break on is_inference. For `compile`,
+        # we don't care about VCs, so just skip the optimization.
+        non_inference_outs = []
+    else:
+        non_inference_outs = [o for o in out if not o.is_inference()]

     if len(non_inference_outs) > 0:
         with torch.autograd._unsafe_preserve_version_counter(tuple(non_inference_outs)):
             torch.ops.fsdp.split_with_sizes_copy(
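The same guard pattern can be reused elsewhere; the sketch below uses `torch.compiler.is_compiling()`, the public spelling of the check above, and is not the FSDP implementation itself:

import torch

def pick_outputs_for_vc_preservation(outs):
    # Under torch.compile, skip the version-counter optimization entirely,
    # mirroring the FSDP change above.
    if torch.compiler.is_compiling():
        return []
    return [o for o in outs if not o.is_inference()]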