Revert "Add inductor standalone_compile API (#150670)"

This reverts commit c9aef50898.

Reverted https://github.com/pytorch/pytorch/pull/150670 on behalf of https://github.com/Camyll due to breaking internal builds with torch module not found error ([comment](https://github.com/pytorch/pytorch/pull/150670#issuecomment-2806975267))
PyTorch MergeBot 2025-04-15 17:35:59 +00:00
parent c0a0761871
commit 74f6bc28a7
13 changed files with 17 additions and 438 deletions
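
For context, the reverted change added a `torch._inductor.standalone_compile` precompilation API and a `CompiledArtifact` save/load wrapper. A minimal sketch of the binary-format workflow, based on the docstring and tests deleted below (the custom-backend capture step mirrors the deleted `TestStandaloneCompile.capture` helper, and this only runs on a build that still contains the reverted code):

import torch

def f(x):
    return x.sin() * 2

x = torch.ones(4, 1)

# Capture the FX graph and example inputs with a throwaway torch.compile backend,
# as the deleted test helper does.
captured = {}
def backend(gm, example_inputs, **kwargs):
    captured["gm"], captured["args"] = gm, example_inputs
    return gm
torch.compile(f, fullgraph=True, backend=backend)(x)

# Compile once and save the artifact to a single binary file.
artifact = torch._inductor.standalone_compile(captured["gm"], captured["args"])
artifact.save(path="compiled_artifact.bin", format="binary")

# Later, possibly in a different process: load and invoke without recompiling.
loaded = torch._inductor.CompiledArtifact.load(path="compiled_artifact.bin", format="binary")
out = loaded(*captured["args"])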

View File

@@ -3,8 +3,6 @@ import functools
import os
import pickle
import shutil
import subprocess
import sys
import tempfile
import unittest
from typing import Optional, Union
@@ -13,7 +11,6 @@ from unittest import mock
import torch
from torch._dynamo import reset
from torch._dynamo.utils import counters
from torch._functorch import config as functorch_config
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
from torch._inductor import config, metrics
from torch._inductor.codecache import (
@@ -1379,131 +1376,6 @@ class TestFxGraphCache(TestCase):
)
@instantiate_parametrized_tests
class TestStandaloneCompile(TestCase):
def setUp(self):
super().setUp()
counters.clear()
PatchCaches.setUp()
CacheArtifactManager.clear()
def tearDown(self):
super().tearDown()
PatchCaches.tearDown()
def reset(self):
AOTAutogradCache.clear()
PyCodeCache.cache_clear(purge=True)
torch._dynamo.reset()
clear_inductor_caches()
def capture(self, fn):
def inner(*args):
gm = None
actual_args = None
kwargs = None
def backend(gm_, args_, **kwargs_):
nonlocal gm
nonlocal actual_args
nonlocal kwargs
gm = gm_
actual_args = args_
kwargs = kwargs_
return gm
_ = torch.compile(fn, fullgraph=True, backend=backend)(*args)
return gm, actual_args, kwargs
return inner
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("format", ("binary", "unpacked"))
@parametrize("dynamic", (False, True))
def test_basic(self, format: str, dynamic: bool) -> None:
mod = torch.nn.Linear(1, 3)
x = torch.randn(4, 1)
if dynamic:
torch._dynamo.mark_dynamic(x, 0)
def f(x):
with torch.no_grad():
return mod(x)
eager_out = f(x)
with tempfile.TemporaryDirectory() as temp_dir:
path = (
temp_dir
if format == "unpacked"
else os.path.join(temp_dir, "compiled_artifact.bin")
)
with fresh_inductor_cache():
gm, args, kwargs = self.capture(f)(x)
assert not kwargs
compiled_artifact = torch._inductor.standalone_compile(gm, args)
compiled_artifact.save(path=path, format=format)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
with fresh_inductor_cache():
loaded = torch._inductor.CompiledArtifact.load(path=path, format=format)
compiled_out = loaded(*args)
self.assertEqual(eager_out, compiled_out)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
def test_different_process(self):
x = torch.ones(4, 1)
def f(x):
return x.sin() * 2
gm, args, kwargs = self.capture(f)(x)
assert not kwargs
with tempfile.TemporaryDirectory() as temp_dir:
path = os.path.join(temp_dir, "compiled_artifact.bin")
with fresh_inductor_cache():
compiled_artifact = torch._inductor.standalone_compile(gm, args)
compiled_artifact.save(path=path)
script = f"""
import torch
from torch._inductor.utils import fresh_inductor_cache
arg = torch.ones(4, 1)
with fresh_inductor_cache():
loaded = torch._inductor.CompiledArtifact.load(path="{path}")
compiled_result = loaded(arg)
eager_result = arg.sin() * 2
if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
raise RuntimeError("tensors do not match")
"""
try:
subprocess.check_output(
[sys.executable, "-c", script],
stderr=subprocess.STDOUT,
cwd=os.path.dirname(os.path.realpath(__file__)),
)
except subprocess.CalledProcessError as e:
self.fail(
msg=(
"Subprocess exception while attempting to run test: "
+ e.output.decode("utf-8")
)
)
class TestFxGraphCacheHashing(TestCase):
def test_parameter_constants(self):
"""

View File

@@ -488,9 +488,6 @@ class AOTAutogradCacheEntry:
forward_time_taken_ns: int
backward_time_taken_ns: int
# Used by standalone_compile
sanitized_aot_config: AOTConfig
# Turn cache entry into the original callable
def wrap_post_compile(
self,

View File

@@ -142,27 +142,6 @@ def aot_dispatch_export(
return compiled_fn, fw_metadata
def sanitize_aot_config(input: AOTConfig) -> AOTConfig:
return AOTConfig(
fw_compiler=None, # type: ignore[arg-type]
bw_compiler=None, # type: ignore[arg-type]
partition_fn=None, # type: ignore[arg-type]
decompositions={},
inference_compiler=None,
num_params_buffers=input.num_params_buffers,
aot_id=input.aot_id,
keep_inference_input_mutations=input.keep_inference_input_mutations,
is_export=input.is_export,
no_tangents=input.no_tangents,
aot_autograd_arg_pos_to_source=input.aot_autograd_arg_pos_to_source,
dynamic_shapes=input.dynamic_shapes,
enable_log=input.enable_log,
static_input_indices=input.static_input_indices,
pre_dispatch=input.pre_dispatch,
cache_info=None,
)
def aot_dispatch_base(
flat_fn,
flat_args: list[Any],
@@ -272,7 +251,6 @@ def aot_dispatch_base(
indices_of_inps_to_detach=[],
forward_time_taken_ns=time_taken_ns,
backward_time_taken_ns=0,
sanitized_aot_config=sanitize_aot_config(aot_config),
)
AOTAutogradCache.save(
cache_info.cache_key, entry, remote=should_use_remote_autograd_cache()
@@ -1327,7 +1305,6 @@ def aot_dispatch_autograd(
_indices_of_inps_to_detach,
forward_time_taken_ns,
backward_time_taken_ns,
sanitized_aot_config=sanitize_aot_config(aot_config),
)
remote = should_use_remote_autograd_cache()
AOTAutogradCache.save(cache_info.cache_key, entry, remote)

View File

@@ -41,7 +41,7 @@ treat_parameters_as_free_to_save = True
# Applies CSE to the graph before partitioning
cse = True
from torch._environment import is_fbcode
from torch._inductor.config import is_fbcode
enable_autograd_cache: bool = Config(

View File

@@ -9,8 +9,6 @@ from typing import Any, IO, Optional, TYPE_CHECKING, Union
import torch._inductor.config
import torch.fx
from .standalone_compile import CompiledArtifact # noqa: TC001
if TYPE_CHECKING:
from torch._inductor.utils import InputType
@@ -22,7 +20,6 @@ __all__ = [
"list_mode_options",
"list_options",
"cudagraph_mark_step_begin",
"standalone_compile",
]
@@ -361,34 +358,3 @@ def cudagraph_mark_step_begin():
from .cudagraph_trees import mark_step_begin
mark_step_begin()
def standalone_compile(
gm: torch.fx.GraphModule,
example_inputs: list[InputType],
options: Optional[dict[str, Any]] = None,
) -> CompiledArtifact:
"""
Precompilation API for inductor.
.. code-block:: python
compiled_artifact = torch._inductor.standalone_compile(gm, args)
compiled_artifact.save(path=path, format="binary")
# Later on a new process
loaded = torch._inductor.CompiledArtifact.load(path=path, format="binary")
compiled_out = loaded(*args)
Args:
gm: Graph Module
example_inputs: Inputs for the graph module
options: Inductor compilation options
Returns:
CompiledArtifact that can be saved to disk or invoked directly.
"""
from .standalone_compile import standalone_compile
options = options if options else {}
return standalone_compile(gm, example_inputs, **options)
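
Note that the deleted wrapper above forwards `options` as keyword arguments to the inner `standalone_compile`, which passes them straight through to `compile_fx`, so option keys must be valid `compile_fx` parameters. A hedged example, assuming `config_patches` is accepted by `compile_fx` at this revision:

# options is forwarded verbatim to compile_fx; config_patches is an assumption here.
artifact = torch._inductor.standalone_compile(
    gm, args, options={"config_patches": {"max_autotune": True}}
)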

View File

@@ -1046,20 +1046,16 @@ class FxGraphCache:
# If there's not a cache hit, we don't want the evaluation to
# affect the current env, e.g., cause the creation of new guards,
# so we evaluate with the hints instead of the symbols.
if config.unsafe_skip_cache_dynamic_shape_guards:
hit = True
else:
hit = bool(
shape_env.evaluate_guards_expression(candidate.guards_expr, hints)
)
log.debug(
"fx graph cache key %s evaluating guards [%s] with values %s => hit=%s",
key,
candidate.guards_expr,
hints,
hit,
)
hit = bool(
shape_env.evaluate_guards_expression(candidate.guards_expr, hints)
)
log.debug(
"fx graph cache key %s evaluating guards [%s] with values %s => hit=%s",
key,
candidate.guards_expr,
hints,
hit,
)
if hit:
graph = candidate
break
@@ -1107,7 +1103,7 @@ class FxGraphCache:
AutotuneCacheBundler.begin_compile(inductor_meta, code=code)
# Now re-evaluate with the symints to add any guards to the current env.
if not config.unsafe_skip_cache_dynamic_shape_guards and graph.guards_expr:
if graph.guards_expr:
check = bool(
shape_env.evaluate_guards_expression(graph.guards_expr, symints)
)

View File

@@ -70,7 +70,7 @@ from torch._inductor.output_code import (
index_expanded_dims,
OutputCode,
)
from torch._inductor.runtime.cache_dir_utils import cache_dir
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.utils import (
BoxedBool,
count_tangents,

View File

@@ -117,9 +117,6 @@ bundled_autotune_remote_cache: Optional[bool] = bundled_autotune_remote_cache_de
# Force disabled all inductor level caching -- This will override any other caching flag
force_disable_caches: bool = os.environ.get("TORCHINDUCTOR_FORCE_DISABLE_CACHES") == "1"
# Unsafe way to skip dynamic shape guards to get faster cache load
unsafe_skip_cache_dynamic_shape_guards: bool = False
# sleep in inductor for testing
sleep_sec_TESTING_ONLY: Optional[int] = None

View File

@@ -2,8 +2,6 @@ import getpass
import os
import re
import tempfile
from collections.abc import Generator
from contextlib import contextmanager
# Factoring out to file without torch dependencies
@@ -33,16 +31,3 @@ def triton_cache_dir(device: int) -> str:
"triton",
str(device),
)
@contextmanager
def temporary_cache_dir(directory: str) -> Generator[None, None, None]:
original = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
os.environ["TORCHINDUCTOR_CACHE_DIR"] = directory
try:
yield
finally:
if original is None:
del os.environ["TORCHINDUCTOR_CACHE_DIR"]
else:
os.environ["TORCHINDUCTOR_CACHE_DIR"] = original

View File

@@ -1,184 +0,0 @@
from __future__ import annotations
import logging
import os
import shutil
from contextlib import AbstractContextManager, nullcontext
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING
import torch.fx
from torch._dynamo.utils import dynamo_timed
from torch._inductor.cudagraph_utils import BoxedDeviceIndex
from torch._inductor.runtime.cache_dir_utils import temporary_cache_dir
from torch._inductor.utils import BoxedBool, InputType
from torch._subclasses import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from . import config
from .utils import shape_env_from_inputs
if TYPE_CHECKING:
from collections.abc import Sequence
from torch.compiler._cache import CacheInfo
from torch.fx import GraphModule
log = logging.getLogger(__name__)
class CompiledArtifact:
"""
CompiledArtifact class represents the precompiled inductor artifact that
can be invoked in order to avoid repeated compilation.
CompiledArtifact can be obtained by calling standalone_compile(gm, example_inputs)
to create a fresh CompiledArtifact from a GraphModule and example inputs.
Later this CompiledArtifact can be saved to disk, either as a binary or unpacked
into the provided folder via the CompiledArtifact.save function.
CompiledArtifact.load provides a way to create a CompiledArtifact from the
binary or unpacked data.
Finally, the CompiledArtifact can be invoked via the __call__ method
to execute the precompiled artifact.
"""
_compiled_fn: Callable[..., Any]
_artifacts: Optional[tuple[bytes, CacheInfo]]
def __init__(
self,
compiled_fn: Callable[..., Any],
artifacts: Optional[tuple[bytes, CacheInfo]],
):
self._compiled_fn = compiled_fn
self._artifacts = artifacts
def __call__(self, *args: Any) -> Any:
return self._compiled_fn(*args)[0]
def save(
self, *, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> None:
with dynamo_timed("CompiledArtifact.save"):
if self._artifacts is None:
raise RuntimeError(
"CompiledArtifact.save failed to save since there's no artifact to save"
)
artifact_bytes, cache_info = self._artifacts
assert len(cache_info.aot_autograd_artifacts) == 1
key = cache_info.aot_autograd_artifacts[0]
if format == "binary":
# can't assert that it is a file since it might not exist yet
assert not os.path.isdir(path)
from torch.utils._appending_byte_serializer import BytesWriter
from .codecache import torch_key
writer = BytesWriter(0)
writer.write_bytes(torch_key())
writer.write_str(key)
writer.write_bytes(artifact_bytes)
with open(path, "wb") as file:
file.write(writer.to_bytes())
else:
assert format == "unpacked"
assert os.path.isdir(path)
shutil.rmtree(path, ignore_errors=True)
with temporary_cache_dir(path):
# This function unpacks the cache artifacts to disk
torch.compiler.load_cache_artifacts(artifact_bytes)
@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> CompiledArtifact:
with dynamo_timed("CompiledArtifact.load"):
if format == "binary":
# can't assert that it is a file since it might not exist yet
assert not os.path.isdir(path)
with open(path, "rb") as file:
artifacts = file.read()
from torch.utils._appending_byte_serializer import BytesReader
from .codecache import torch_key
reader = BytesReader(artifacts)
assert reader.read_bytes() == torch_key()
key = reader.read_str()
artifact_bytes = reader.read_bytes()
assert reader.is_finished()
torch.compiler.load_cache_artifacts(artifact_bytes)
cache_dir_ctx: AbstractContextManager[None] = nullcontext()
else:
assert format == "unpacked"
assert os.path.isdir(path)
autograd_cache_dir = os.path.join(path, "aotautograd")
assert os.path.isdir(autograd_cache_dir)
files = list(os.listdir(autograd_cache_dir))
assert len(files) == 1
key = files[0]
cache_dir_ctx = temporary_cache_dir(path)
with cache_dir_ctx:
with torch._functorch.config.patch(strict_autograd_cache=True):
from torch._functorch._aot_autograd.autograd_cache import (
AOTAutogradCache,
)
entry = AOTAutogradCache._lookup(key, local=True, remote=False)
assert entry is not None
from .compile_fx import _CompileFxKwargs
fx_config = _CompileFxKwargs(
cudagraphs=BoxedBool(False),
boxed_forward_device_index=BoxedDeviceIndex(0),
)
context = torch._guards.TracingContext(
FakeTensorMode(shape_env=ShapeEnv())
)
with (
torch._guards.tracing(context),
config.patch(unsafe_skip_cache_dynamic_shape_guards=True),
):
compiled_fn = entry.wrap_post_compile(
[], entry.sanitized_aot_config, fx_config
)
return CompiledArtifact(lambda *args: compiled_fn(list(args)), None)
def standalone_compile(
gm: GraphModule, example_inputs: Sequence[InputType], **kwargs: Any
) -> CompiledArtifact:
from torch.compiler._cache import CacheArtifactManager
from .compile_fx import compile_fx
shape_env = shape_env_from_inputs(example_inputs, default=True)
assert shape_env is not None
context = torch._guards.TracingContext(FakeTensorMode(shape_env=shape_env))
with torch._guards.tracing(context):
with CacheArtifactManager.with_fresh_cache():
compiled_fn = compile_fx(gm, example_inputs, **kwargs)
assert callable(compiled_fn)
artifacts = torch.compiler.save_cache_artifacts()
if artifacts is None:
log.warning(
"standalone_compile artifact generation failed, cannot save. "
"Run with TORCH_LOGS=+torch._inductor.codecache to identify the problem"
)
return CompiledArtifact(compiled_fn, artifacts)
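
Besides the binary format, the deleted module above supports an "unpacked" layout: `save()` expands the cache artifacts into an existing directory, and `load()` points the inductor cache at that directory via `temporary_cache_dir`. A minimal sketch, reusing `gm`/`args` captured as in the example near the top of this page (the directory path is illustrative):

import tempfile

# The unpacked format requires an existing directory; save() clears it and
# repopulates it with the serialized cache artifacts.
artifact_dir = tempfile.mkdtemp()
artifact = torch._inductor.standalone_compile(gm, args)
artifact.save(path=artifact_dir, format="unpacked")

# Loading temporarily switches TORCHINDUCTOR_CACHE_DIR to the unpacked directory
# while it reconstructs the compiled function from the AOTAutograd cache entry.
loaded = torch._inductor.CompiledArtifact.load(path=artifact_dir, format="unpacked")
out = loaded(*args)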

View File

@@ -50,7 +50,6 @@ import sympy
import torch
from torch._inductor.runtime.hints import DeviceProperties
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.utils._ordered_set import OrderedSet
from torch.utils._pytree import tree_map_only
@@ -61,6 +60,7 @@ if TYPE_CHECKING:
from torch import SymBool, SymFloat, SymInt
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.fx import GraphModule
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.fx.node import Node
from .codegen.common import WorkspaceArg
@@ -2412,9 +2412,7 @@ def run_and_get_cpp_code(
return result, s
def shape_env_from_inputs(
inputs: Sequence[InputType], default: bool = False
) -> Optional[ShapeEnv]:
def shape_env_from_inputs(inputs: Sequence[InputType]) -> Optional[ShapeEnv]:
fake_mode = detect_fake_mode(inputs)
# TODO(voz): It would be nice to enable this assert, but there are lots of tests that
@@ -2430,9 +2428,6 @@ def shape_env_from_inputs(
if isinstance(input, torch.SymInt):
return input.node.shape_env
if default:
return ShapeEnv()
# TODO(voz): Should we always have one anyway?
return None

View File

@@ -2,8 +2,6 @@ import copy
import dataclasses
import logging
import os
from collections.abc import Generator
from contextlib import contextmanager
from enum import Enum
from typing import Optional, Union
@@ -123,26 +121,6 @@ class CacheArtifactManager:
cls._serializer.clear()
cls._cache_info.clear()
@classmethod
@contextmanager
def with_fresh_cache(cls) -> Generator[None, None, None]:
original_new_cache_artifacts = cls._new_cache_artifacts
original_seen_artifacts = cls._seen_artifacts
original_serializer = cls._serializer
original_cache_info = cls._cache_info
cls._new_cache_artifacts = []
cls._seen_artifacts = OrderedSet()
cls._serializer = AppendingByteSerializer(serialize_fn=CacheArtifact.serialize)
cls._cache_info = CacheInfo()
try:
yield
finally:
cls._new_cache_artifacts = original_new_cache_artifacts
cls._seen_artifacts = original_seen_artifacts
cls._serializer = original_serializer
cls._cache_info = original_cache_info
@classmethod
def record_artifact(
cls,

View File

@@ -1404,10 +1404,10 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
# ProxyTorchDispatchMode state was (if there was any).
# This lets us properly reset the state on exit.
self.enter_stack: list[Optional[ProxyTorchDispatchMode]] = []
self.decomp_layers: int = 0
self.decomp_layers = 0
from torch._inductor import config
self.emulate_precision_casts: bool = config.emulate_precision_casts
self.emulate_precision_casts = config.emulate_precision_casts
@count
def __torch_dispatch__(