pytorch/torch/_inductor/standalone_compile.py
James Wu 06773663b5 Implement an AOT precompile mode for standalone_compile (#165843)
This PR introduces an `aot` flag to standalone_compile that uses BundledAOTAutogradCacheEntry, and lets regional_inductor use it so that we can start AOT-compiling regionally compiled graphs. The diff above this one will attempt to allow GraphPickler to fully serialize graphs that contain regionally compiled subgraphs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165843
Approved by: https://github.com/oulgen
2025-10-21 15:02:45 +00:00
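
A minimal usage sketch of the new `aot` flag, calling the implementation in this file directly; the toy function, inputs, and output path below are illustrative assumptions rather than part of the PR:

import torch
from torch._inductor.standalone_compile import standalone_compile

def f(x):
    return x.sin() + 1

gm = torch.fx.symbolic_trace(f)
example_inputs = [torch.randn(4)]

# aot=True routes compilation through BundledAOTAutogradCache, so the result is a
# single serializable AOTCompiledArtifact rather than a cache-backed artifact.
artifact = standalone_compile(
    gm, example_inputs, dynamic_shapes="from_example_inputs", options={}, aot=True
)
out = artifact(*example_inputs)
artifact.save(path="/tmp/compiled_artifact.bin", format="binary")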


from __future__ import annotations
import copy
import logging
import os
import pickle
import shutil
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager, nullcontext
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING
import torch.fx
from torch._dynamo.aot_compile_types import BundledAOTAutogradSerializableCallable
from torch._dynamo.utils import dynamo_timed
from torch._inductor.cpp_builder import normalize_path_separator
from torch._inductor.cudagraph_utils import BoxedDeviceIndex
from torch._inductor.runtime.cache_dir_utils import temporary_cache_dir
from torch._inductor.utils import BoxedBool, InputType
from torch._subclasses import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from . import config
if TYPE_CHECKING:
from collections.abc import Sequence
from torch.compiler._cache import CacheInfo
from torch.fx import GraphModule
log = logging.getLogger(__name__)


class CompiledArtifact(ABC):
"""
    CompiledArtifact represents a compiled inductor artifact that can be
    invoked to avoid repeated compilation.
    A fresh CompiledArtifact is obtained by calling
    standalone_compile(gm, example_inputs) on a GraphModule and example inputs.
    The resulting CompiledArtifact can be saved to disk via CompiledArtifact.save,
    either as a single binary file or unpacked into a provided folder.
    CompiledArtifact.load recreates a CompiledArtifact from the binary or
    unpacked data.
    Finally, the CompiledArtifact can be invoked via its __call__ method to
    execute the cached artifact.
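
    Example (a sketch; assumes ``gm`` and ``example_inputs`` already exist and
    that the path below is writable)::

        artifact = standalone_compile(
            gm, example_inputs, dynamic_shapes="from_graph", options={}
        )
        out = artifact(*example_inputs)        # run the compiled function
        artifact.save(path="artifact.bin")     # binary format by default
        reloaded = CompiledArtifact.load(path="artifact.bin")
        out_again = reloaded(*example_inputs)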
"""
def __init__(
self,
compiled_fn: Callable[..., Any],
artifacts: Optional[tuple[bytes, CacheInfo]],
):
self._compiled_fn = compiled_fn
self._artifacts = artifacts
@abstractmethod
def __call__(self, *args: Any) -> Any: ...
@abstractmethod
def save(
self, *, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> None: ...
@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> CompiledArtifact:
if format == "unpacked":
# If format is unpacked, it must be a CacheCompiledArtifact
return CacheCompiledArtifact.load(path=path, format=format)
assert format == "binary"
with open(path, "rb") as file:
from torch.utils._appending_byte_serializer import BytesReader
from .codecache import torch_key
result_bytes = file.read()
reader = BytesReader(result_bytes)
header = reader.read_bytes()
if header == AOTCompiledArtifact.AOT_HEADER:
assert reader.read_bytes() == torch_key()
artifact = reader.read_bytes()
assert reader.is_finished()
return AOTCompiledArtifact.deserialize(artifact)
# Otherwise, it's in the CacheCompiledArtifact format
elif header == CacheCompiledArtifact.CACHE_HEADER:
assert reader.read_bytes() == torch_key()
key = reader.read_str()
artifact_bytes = reader.read_bytes()
assert reader.is_finished()
torch.compiler.load_cache_artifacts(artifact_bytes)
return CacheCompiledArtifact._load_impl(nullcontext(), key)
else:
raise RuntimeError(
"Invalid header, expected CacheCompiledArtifact or AOTCompiledArtifact, got: "
+ header.decode("utf-8")
)


class CacheCompiledArtifact(CompiledArtifact):
"""
CompiledArtifact that depends on torch.compiler.save_cache_artifacts
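
    Example of the unpacked format (a sketch; ``artifact`` is assumed to come from
    standalone_compile with aot=False, and the target directory is illustrative)::

        artifact.save(path="/tmp/compiled_dir", format="unpacked")
        reloaded = CompiledArtifact.load(path="/tmp/compiled_dir", format="unpacked")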
"""
CACHE_HEADER = bytes("CacheCompiledArtifact", "utf-8")
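    # On-disk binary layout written by save(format="binary") and read back by load():
    #   [CACHE_HEADER][torch_key()][aot_autograd cache key][cache artifact bytes]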
def __init__(
self,
compiled_fn: Callable[..., Any],
artifacts: Optional[tuple[bytes, CacheInfo]],
):
self._compiled_fn = compiled_fn
self._artifacts = artifacts
def __call__(self, *args: Any) -> Any:
return self._compiled_fn(*args)
def save(
self, *, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> None:
with dynamo_timed("CompiledArtifact.save"):
if self._artifacts is None:
raise RuntimeError(
"CompiledArtifact.save failed to save since there's no artifact to save"
)
artifact_bytes, cache_info = self._artifacts
assert len(cache_info.aot_autograd_artifacts) == 1, cache_info
key = cache_info.aot_autograd_artifacts[0]
if format == "binary":
# can't assert that it is a file since it might not exist yet
assert not os.path.isdir(path)
from torch.utils._appending_byte_serializer import BytesWriter
from .codecache import torch_key
writer = BytesWriter()
writer.write_bytes(CacheCompiledArtifact.CACHE_HEADER)
writer.write_bytes(torch_key())
writer.write_str(key)
writer.write_bytes(artifact_bytes)
from torch._inductor.codecache import write_atomic
write_atomic(path, writer.to_bytes())
else:
assert format == "unpacked"
if os.path.exists(path):
assert os.path.isdir(path)
shutil.rmtree(path, ignore_errors=True)
from .codecache import FxGraphCache
with temporary_cache_dir(path):
# This function unpacks the cache artifacts to disk
loaded_cache_info = torch.compiler.load_cache_artifacts(
artifact_bytes
)
assert loaded_cache_info is not None
# Now write all the output_code artifacts to disk so that
# they can be inspected and modified
for key in loaded_cache_info.inductor_artifacts:
subdir = FxGraphCache._get_tmp_dir_for_key(key)
assert os.path.exists(subdir)
for path in sorted(os.listdir(subdir)):
with open(os.path.join(subdir, path), "rb") as f:
graph = pickle.load(f)
output_file = graph.write_to_disk()
log.info("Output code written to: %s", output_file)
@staticmethod
def _load_impl(
cache_dir_ctx: AbstractContextManager[Any], key: str
) -> CompiledArtifact:
with (
cache_dir_ctx,
config.patch(unsafe_skip_cache_dynamic_shape_guards=True),
):
with torch._functorch.config.patch(strict_autograd_cache=True):
from torch._functorch._aot_autograd.autograd_cache import (
AOTAutogradCache,
)
result = AOTAutogradCache._lookup(
key,
local=True,
remote=False,
args=[],
cache_info={},
aot_config=None,
)
assert result is not None
(entry, _) = result
from .compile_fx import _CompileFxKwargs
fx_config = _CompileFxKwargs(
cudagraphs=BoxedBool(False),
boxed_forward_device_index=BoxedDeviceIndex(0),
)
context = torch._guards.TracingContext(FakeTensorMode(shape_env=ShapeEnv()))
with torch._guards.tracing(context):
compiled_fn = entry.wrap_post_compile(
[], entry.sanitized_aot_config, fx_config
)
return CacheCompiledArtifact(lambda *args: compiled_fn(list(args)), None)
@staticmethod
def _prepare_load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> tuple[str, AbstractContextManager[Any]]:
"""
        Perform format-specific preparation and loading; return the cache key and
        a cache-directory context manager.
"""
path = normalize_path_separator(path)
with dynamo_timed("CompiledArtifact.load"):
if format == "binary":
# can't assert that it is a file since it might not exist yet
assert not os.path.isdir(path)
with open(path, "rb") as file:
artifacts = file.read()
from torch.utils._appending_byte_serializer import BytesReader
from .codecache import torch_key
reader = BytesReader(artifacts)
assert reader.read_bytes() == torch_key()
key = reader.read_str()
artifact_bytes = reader.read_bytes()
assert reader.is_finished()
torch.compiler.load_cache_artifacts(artifact_bytes)
return key, nullcontext()
else:
assert format == "unpacked"
assert os.path.isdir(path)
autograd_cache_dir = os.path.join(path, "aotautograd")
assert os.path.isdir(autograd_cache_dir)
files = list(os.listdir(autograd_cache_dir))
assert len(files) == 1
key = files[0]
cache_dir_ctx = temporary_cache_dir(path)
return key, cache_dir_ctx
@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> CompiledArtifact:
key, cache_dir_ctx = CacheCompiledArtifact._prepare_load(
path=path, format=format
)
return CacheCompiledArtifact._load_impl(cache_dir_ctx, key)


class AOTCompiledArtifact(CompiledArtifact):
"""
    A CompiledArtifact that wraps a single, bundled precompiled function.
    The wrapped object is always a serializable callable: it is essentially a
    wrapper around BundledAOTAutogradSerializableCallable, which
    torch._dynamo.aot_compile uses for AOT precompilation.
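
    Example (a sketch; ``artifact`` is assumed to be the result of
    standalone_compile(..., aot=True))::

        blob = artifact.serialize()
        restored = AOTCompiledArtifact.deserialize(blob)
        # or round-trip through disk:
        artifact.save(path="aot_artifact.bin", format="binary")
        restored = AOTCompiledArtifact.load(path="aot_artifact.bin")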
"""
AOT_HEADER = bytes("AOTCompiledArtifact", "utf-8")
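    # On-disk binary layout written by save() and read by load()/CompiledArtifact.load():
    #   [AOT_HEADER][torch_key()][serialized BundledAOTAutogradSerializableCallable]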
def __init__(
self,
compiled_fn: Callable[..., Any],
):
self.inner_fn = BundledAOTAutogradSerializableCallable(compiled_fn)
self._artifacts = (
None # We don't need artifacts, the inner object handles everything
)
@staticmethod
def from_bundled_callable(
bundled_fn: BundledAOTAutogradSerializableCallable,
) -> AOTCompiledArtifact:
return AOTCompiledArtifact(bundled_fn.compiled_fn)
def __call__(self, *args: Any) -> Any:
return self.inner_fn(*args)
def save(
self, *, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> None:
if format == "unpacked":
raise RuntimeError(
"AOTCompiledArtifact does not support unpacked format yet"
)
result_bytes = self.serialize()
from torch.utils._appending_byte_serializer import BytesWriter
from .codecache import torch_key
writer = BytesWriter()
writer.write_bytes(AOTCompiledArtifact.AOT_HEADER)
writer.write_bytes(torch_key())
writer.write_bytes(result_bytes)
from torch._inductor.codecache import write_atomic
        # Atomically write the serialized artifact, prefixed with the AOT header so
        # CompiledArtifact.load can distinguish it from the cache-based format.
write_atomic(path, writer.to_bytes())
def serialize(self) -> bytes:
return BundledAOTAutogradSerializableCallable.serialize_compile_artifacts(
self.inner_fn
)
@staticmethod
def deserialize(result_bytes: bytes) -> AOTCompiledArtifact:
deserialized = (
BundledAOTAutogradSerializableCallable.deserialize_compile_artifacts(
result_bytes
)
)
assert isinstance(deserialized, BundledAOTAutogradSerializableCallable)
return AOTCompiledArtifact.from_bundled_callable(deserialized)
@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> CompiledArtifact:
if format == "unpacked":
raise RuntimeError(
"AOTCompiledArtifact does not support unpacked format yet"
)
with open(path, "rb") as file:
from torch.utils._appending_byte_serializer import BytesReader
from .codecache import torch_key
result_bytes = file.read()
reader = BytesReader(result_bytes)
header = reader.read_bytes()
assert header == AOTCompiledArtifact.AOT_HEADER
assert reader.read_bytes() == torch_key()
artifact = reader.read_bytes()
assert reader.is_finished()
return AOTCompiledArtifact.deserialize(artifact)


def standalone_compile(
gm: GraphModule,
example_inputs: Sequence[InputType],
*,
dynamic_shapes: Any,
options: Any,
aot: bool = False, # AOT mode, which uses BundledAOTAutogradCache
) -> CompiledArtifact:
"""
    Implementation of torch._inductor.standalone_compile
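
    Example (a sketch; ``gm`` and ``example_inputs`` are assumed to exist)::

        # derive shapes from the example inputs, ignoring any ambient shape env
        artifact = standalone_compile(
            gm, example_inputs, dynamic_shapes="from_example_inputs", options={}
        )
        # alternatives: dynamic_shapes="from_graph" reuses the FakeTensorMode found
        # on the graph's output metadata; "from_tracing_context" reuses the current
        # TracingContext's fake_mode.
        # aot=True additionally returns an AOTCompiledArtifact backed by
        # BundledAOTAutogradCache instead of a cache-based artifact.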
"""
from torch.compiler._cache import CacheArtifactManager
from .compile_fx import compile_fx
ignore_shape_env = False
if dynamic_shapes == "from_example_inputs":
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
# tells compile_fx to ignore the shape_envs on the ambient context
# and the graph_module.
ignore_shape_env = True
elif dynamic_shapes == "from_tracing_context":
# Reuse fake_mode from the TracingContext.
# NB: The TracingContext only exists if we're currently in a torch.compile backend.
context = torch._guards.TracingContext.get()
assert context.fake_mode is not None
fake_mode = context.fake_mode
elif dynamic_shapes == "from_graph":
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
# Strategy: find a FakeTensor in the graph output, grab its FakeTensorMode.
# The graph passed to standalone_compile must be an Inductor-approved graph,
# which means that there is at least one Tensor output and the output node
# contains a flat list of Tensors.
last_node = next(iter(reversed(gm.graph.nodes)))
assert last_node.op == "output"
assert len(last_node.args) == 1
def handle_node(node: torch.fx.Node) -> None:
nonlocal fake_mode
if "example_value" in node.meta:
maybe_tensor = node.meta["example_value"]
if isinstance(maybe_tensor, torch._subclasses.fake_tensor.FakeTensor):
fake_mode = maybe_tensor.fake_mode
# If gm came from Dynamo, then last_node.args[0] is always a list,
# even in single-Tensor returns.
#
# It's possible to get into a situation where last_node.args[0]
# is a Node (and not a list!). This happens if you call split_module
# on the graph. We allow for this case since it is common.
if isinstance(last_node.args[0], torch.fx.Node):
handle_node(last_node.args[0])
else:
for node in last_node.args[0]:
handle_node(node)
else:
raise ValueError(
f"standalone_compile got unsupported `dynamic_shapes` value: dynamic_shapes={dynamic_shapes}."
)
context = torch._guards.TracingContext(fake_mode)
with (
torch._guards.tracing(context),
CacheArtifactManager.with_fresh_cache(),
config.patch("triton.autotune_at_compile_time", True),
torch._functorch.config.patch("bundled_autograd_cache", aot),
):
# compile_fx can mutate gm
gm = copy.deepcopy(gm)
compiled_fn = compile_fx(
gm, example_inputs, ignore_shape_env=ignore_shape_env, **options
)
assert callable(compiled_fn)
if aot:
if not hasattr(compiled_fn, "serialize"):
raise RuntimeError(
"Compiled function should have serialize method when aot=True"
)
return AOTCompiledArtifact(compiled_fn)
artifacts = torch.compiler.save_cache_artifacts()
if artifacts is None:
log.warning(
"standalone_compile artifact generation failed, cannot save. "
"Run with TORCH_LOGS=+torch._inductor.codecache to identify the problem"
)
return CacheCompiledArtifact(compiled_fn, artifacts)