This PR introduces an `aot` flag to standalone_compile that uses BundledAOTAutogradCacheEntry, and then allows regional_inductor to use it so that we can start AOT-compiling regionally compiled graphs. The diff above this will attempt to allow GraphPickler to fully serialize graphs that have regionally compiled subgraphs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165843
Approved by: https://github.com/oulgen
from __future__ import annotations

import copy
import logging
import os
import pickle
import shutil
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager, nullcontext
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING

import torch.fx
from torch._dynamo.aot_compile_types import BundledAOTAutogradSerializableCallable
from torch._dynamo.utils import dynamo_timed
from torch._inductor.cpp_builder import normalize_path_separator
from torch._inductor.cudagraph_utils import BoxedDeviceIndex
from torch._inductor.runtime.cache_dir_utils import temporary_cache_dir
from torch._inductor.utils import BoxedBool, InputType
from torch._subclasses import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

from . import config


if TYPE_CHECKING:
    from collections.abc import Sequence

    from torch.compiler._cache import CacheInfo
    from torch.fx import GraphModule


log = logging.getLogger(__name__)
class CompiledArtifact(ABC):
    """
    CompiledArtifact class represents the inductor cache artifacts that
    can be invoked in order to avoid repeated compilation.

    CompiledArtifact can be obtained by calling standalone_compile(gm, example_inputs)
    to create a fresh CompiledArtifact from a GraphModule and example inputs.

    Later this CompiledArtifact can be saved to disk, either as a binary or unpacked
    into the provided folder via the CompiledArtifact.save function.

    CompiledArtifact.load provides a way to create a CompiledArtifact from the
    binary or unpacked data.

    Finally, the CompiledArtifact can be invoked via the __call__ method
    to execute the cached artifact.
    """

    def __init__(
        self,
        compiled_fn: Callable[..., Any],
        artifacts: Optional[tuple[bytes, CacheInfo]],
    ):
        self._compiled_fn = compiled_fn
        self._artifacts = artifacts

    @abstractmethod
    def __call__(self, *args: Any) -> Any: ...

    @abstractmethod
    def save(
        self, *, path: str, format: Literal["binary", "unpacked"] = "binary"
    ) -> None: ...

    @staticmethod
    def load(
        *, path: str, format: Literal["binary", "unpacked"] = "binary"
    ) -> CompiledArtifact:
        if format == "unpacked":
            # If format is unpacked, it must be a CacheCompiledArtifact
            return CacheCompiledArtifact.load(path=path, format=format)

        assert format == "binary"
        with open(path, "rb") as file:
            from torch.utils._appending_byte_serializer import BytesReader

            from .codecache import torch_key

            result_bytes = file.read()
            reader = BytesReader(result_bytes)
            header = reader.read_bytes()
            if header == AOTCompiledArtifact.AOT_HEADER:
                assert reader.read_bytes() == torch_key()
                artifact = reader.read_bytes()
                assert reader.is_finished()
                return AOTCompiledArtifact.deserialize(artifact)
            # Otherwise, it's in the CacheCompiledArtifact format
            elif header == CacheCompiledArtifact.CACHE_HEADER:
                assert reader.read_bytes() == torch_key()
                key = reader.read_str()
                artifact_bytes = reader.read_bytes()
                assert reader.is_finished()
                torch.compiler.load_cache_artifacts(artifact_bytes)
                return CacheCompiledArtifact._load_impl(nullcontext(), key)
            else:
                raise RuntimeError(
                    "Invalid header, expected CacheCompiledArtifact or AOTCompiledArtifact, got: "
                    + header.decode("utf-8")
                )
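# Illustrative sketch (not part of the upstream module): a minimal save/load
# round trip through the CompiledArtifact API described in the docstring above,
# using the binary format. The toy function, input shape, and file path are
# assumptions made for demonstration only; the exact structure of the value
# returned by the compiled callable can depend on how the GraphModule was made.
def _example_compiled_artifact_round_trip(
    path: str = "/tmp/compiled_artifact.bin",
) -> None:
    import torch

    def fn(x: torch.Tensor) -> torch.Tensor:
        return (x * 2.0).relu()

    gm = torch.fx.symbolic_trace(fn)
    x = torch.randn(8)

    # Compile once and persist the result as a single binary blob.
    artifact = standalone_compile(
        gm, [x], dynamic_shapes="from_example_inputs", options={}
    )
    artifact.save(path=path, format="binary")

    # Later (possibly in a different process), rebuild the callable and run it.
    loaded = CompiledArtifact.load(path=path, format="binary")
    loaded(x)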
class CacheCompiledArtifact(CompiledArtifact):
    """
    CompiledArtifact that depends on torch.compiler.save_cache_artifacts
    """

    CACHE_HEADER = bytes("CacheCompiledArtifact", "utf-8")

    def __init__(
        self,
        compiled_fn: Callable[..., Any],
        artifacts: Optional[tuple[bytes, CacheInfo]],
    ):
        self._compiled_fn = compiled_fn
        self._artifacts = artifacts

    def __call__(self, *args: Any) -> Any:
        return self._compiled_fn(*args)

    def save(
        self, *, path: str, format: Literal["binary", "unpacked"] = "binary"
    ) -> None:
        with dynamo_timed("CompiledArtifact.save"):
            if self._artifacts is None:
                raise RuntimeError(
                    "CompiledArtifact.save failed to save since there's no artifact to save"
                )
            artifact_bytes, cache_info = self._artifacts
            assert len(cache_info.aot_autograd_artifacts) == 1, cache_info
            key = cache_info.aot_autograd_artifacts[0]

            if format == "binary":
                # can't assert that it is a file since it might not exist yet
                assert not os.path.isdir(path)

                from torch.utils._appending_byte_serializer import BytesWriter

                from .codecache import torch_key

                writer = BytesWriter()
                writer.write_bytes(CacheCompiledArtifact.CACHE_HEADER)
                writer.write_bytes(torch_key())
                writer.write_str(key)
                writer.write_bytes(artifact_bytes)

                from torch._inductor.codecache import write_atomic

                write_atomic(path, writer.to_bytes())
            else:
                assert format == "unpacked"
                if os.path.exists(path):
                    assert os.path.isdir(path)
                    shutil.rmtree(path, ignore_errors=True)

                from .codecache import FxGraphCache

                with temporary_cache_dir(path):
                    # This function unpacks the cache artifacts to disk
                    loaded_cache_info = torch.compiler.load_cache_artifacts(
                        artifact_bytes
                    )
                    assert loaded_cache_info is not None
                    # Now write all the output_code artifacts to disk so that
                    # they can be inspected and modified
                    for key in loaded_cache_info.inductor_artifacts:
                        subdir = FxGraphCache._get_tmp_dir_for_key(key)
                        assert os.path.exists(subdir)
                        for path in sorted(os.listdir(subdir)):
                            with open(os.path.join(subdir, path), "rb") as f:
                                graph = pickle.load(f)
                                output_file = graph.write_to_disk()
                                log.info("Output code written to: %s", output_file)

    @staticmethod
    def _load_impl(
        cache_dir_ctx: AbstractContextManager[Any], key: str
    ) -> CompiledArtifact:
        with (
            cache_dir_ctx,
            config.patch(unsafe_skip_cache_dynamic_shape_guards=True),
        ):
            with torch._functorch.config.patch(strict_autograd_cache=True):
                from torch._functorch._aot_autograd.autograd_cache import (
                    AOTAutogradCache,
                )

                result = AOTAutogradCache._lookup(
                    key,
                    local=True,
                    remote=False,
                    args=[],
                    cache_info={},
                    aot_config=None,
                )

            assert result is not None
            (entry, _) = result

            from .compile_fx import _CompileFxKwargs

            fx_config = _CompileFxKwargs(
                cudagraphs=BoxedBool(False),
                boxed_forward_device_index=BoxedDeviceIndex(0),
            )

            context = torch._guards.TracingContext(FakeTensorMode(shape_env=ShapeEnv()))
            with torch._guards.tracing(context):
                compiled_fn = entry.wrap_post_compile(
                    [], entry.sanitized_aot_config, fx_config
                )
        return CacheCompiledArtifact(lambda *args: compiled_fn(list(args)), None)

    @staticmethod
    def _prepare_load(
        *, path: str, format: Literal["binary", "unpacked"] = "binary"
    ) -> tuple[str, AbstractContextManager[Any]]:
        """
        Do format specific prep and loads, return a context manager and key
        """
        path = normalize_path_separator(path)
        with dynamo_timed("CompiledArtifact.load"):
            if format == "binary":
                # can't assert that it is a file since it might not exist yet
                assert not os.path.isdir(path)
                with open(path, "rb") as file:
                    artifacts = file.read()
                from torch.utils._appending_byte_serializer import BytesReader

                from .codecache import torch_key

                reader = BytesReader(artifacts)
                assert reader.read_bytes() == torch_key()
                key = reader.read_str()
                artifact_bytes = reader.read_bytes()
                assert reader.is_finished()

                torch.compiler.load_cache_artifacts(artifact_bytes)
                return key, nullcontext()
            else:
                assert format == "unpacked"
                assert os.path.isdir(path)
                autograd_cache_dir = os.path.join(path, "aotautograd")
                assert os.path.isdir(autograd_cache_dir)
                files = list(os.listdir(autograd_cache_dir))
                assert len(files) == 1
                key = files[0]
                cache_dir_ctx = temporary_cache_dir(path)
                return key, cache_dir_ctx

    @staticmethod
    def load(
        *, path: str, format: Literal["binary", "unpacked"] = "binary"
    ) -> CompiledArtifact:
        key, cache_dir_ctx = CacheCompiledArtifact._prepare_load(
            path=path, format=format
        )
        return CacheCompiledArtifact._load_impl(cache_dir_ctx, key)
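# Illustrative sketch (an assumption, not upstream code): CacheCompiledArtifact
# also supports an "unpacked" layout, where save() materializes the cache
# entries under a directory so the generated output_code can be inspected or
# edited before being loaded again. The toy function, input shape, and
# directory name below are placeholders for demonstration.
def _example_unpacked_save_and_load(cache_dir: str = "/tmp/standalone_cache") -> None:
    import torch

    def fn(x: torch.Tensor) -> torch.Tensor:
        return torch.sin(x) + 1.0

    gm = torch.fx.symbolic_trace(fn)
    x = torch.randn(4)

    artifact = standalone_compile(
        gm, [x], dynamic_shapes="from_example_inputs", options={}
    )
    # Unpack the cache artifacts (including Inductor output code) into cache_dir.
    artifact.save(path=cache_dir, format="unpacked")

    # Reload from the unpacked directory; format="unpacked" dispatches to
    # CacheCompiledArtifact.load.
    loaded = CompiledArtifact.load(path=cache_dir, format="unpacked")
    loaded(x)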
class AOTCompiledArtifact(CompiledArtifact):
    """
    Similar to CompiledArtifact, but the object is a single, bundled precompiled function.
    This object is always a serializable callable function.

    This object is essentially a wrapper for BundledAOTAutogradSerializableCallable, which
    is used by torch._dynamo.aot_compile for AOT Precompilation.
    """

    AOT_HEADER = bytes("AOTCompiledArtifact", "utf-8")

    def __init__(
        self,
        compiled_fn: Callable[..., Any],
    ):
        self.inner_fn = BundledAOTAutogradSerializableCallable(compiled_fn)
        self._artifacts = (
            None  # We don't need artifacts, the inner object handles everything
        )

    @staticmethod
    def from_bundled_callable(
        bundled_fn: BundledAOTAutogradSerializableCallable,
    ) -> AOTCompiledArtifact:
        return AOTCompiledArtifact(bundled_fn.compiled_fn)

    def __call__(self, *args: Any) -> Any:
        return self.inner_fn(*args)

    def save(
        self, *, path: str, format: Literal["binary", "unpacked"] = "binary"
    ) -> None:
        if format == "unpacked":
            raise RuntimeError(
                "AOTCompiledArtifact does not support unpacked format yet"
            )
        result_bytes = self.serialize()
        from torch.utils._appending_byte_serializer import BytesWriter

        from .codecache import torch_key

        writer = BytesWriter()
        writer.write_bytes(AOTCompiledArtifact.AOT_HEADER)
        writer.write_bytes(torch_key())
        writer.write_bytes(result_bytes)

        from torch._inductor.codecache import write_atomic

        # Save a sentinel file to indicate that this is AOT
        write_atomic(path, writer.to_bytes())

    def serialize(self) -> bytes:
        return BundledAOTAutogradSerializableCallable.serialize_compile_artifacts(
            self.inner_fn
        )

    @staticmethod
    def deserialize(result_bytes: bytes) -> AOTCompiledArtifact:
        deserialized = (
            BundledAOTAutogradSerializableCallable.deserialize_compile_artifacts(
                result_bytes
            )
        )
        assert isinstance(deserialized, BundledAOTAutogradSerializableCallable)
        return AOTCompiledArtifact.from_bundled_callable(deserialized)

    @staticmethod
    def load(
        *, path: str, format: Literal["binary", "unpacked"] = "binary"
    ) -> CompiledArtifact:
        if format == "unpacked":
            raise RuntimeError(
                "AOTCompiledArtifact does not support unpacked format yet"
            )
        with open(path, "rb") as file:
            from torch.utils._appending_byte_serializer import BytesReader

            from .codecache import torch_key

            result_bytes = file.read()
            reader = BytesReader(result_bytes)
            header = reader.read_bytes()
            assert header == AOTCompiledArtifact.AOT_HEADER
            assert reader.read_bytes() == torch_key()
            artifact = reader.read_bytes()
            assert reader.is_finished()
            return AOTCompiledArtifact.deserialize(artifact)
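# Illustrative sketch (an assumption, not upstream code): an AOTCompiledArtifact
# is produced by standalone_compile(..., aot=True) and, unlike the cache-backed
# variant, can be moved around as raw bytes via serialize()/deserialize()
# without touching the on-disk caches. The toy function and shapes below are
# placeholders for demonstration.
def _example_aot_serialize_round_trip() -> None:
    import torch

    def fn(x: torch.Tensor) -> torch.Tensor:
        return x @ x.t()

    gm = torch.fx.symbolic_trace(fn)
    x = torch.randn(4, 4)

    artifact = standalone_compile(
        gm, [x], dynamic_shapes="from_example_inputs", options={}, aot=True
    )
    assert isinstance(artifact, AOTCompiledArtifact)

    # Ship the bundled precompiled function as bytes (e.g. over the wire or
    # into a custom store), then rebuild and invoke it like the original.
    payload = artifact.serialize()
    restored = AOTCompiledArtifact.deserialize(payload)
    restored(x)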
def standalone_compile(
    gm: GraphModule,
    example_inputs: Sequence[InputType],
    *,
    dynamic_shapes: Any,
    options: Any,
    aot: bool = False,  # AOT mode, which uses BundledAOTAutogradCache
) -> CompiledArtifact:
    """
    Implementation of torch.inductor.standalone_compile
    """
    from torch.compiler._cache import CacheArtifactManager

    from .compile_fx import compile_fx

    ignore_shape_env = False
    if dynamic_shapes == "from_example_inputs":
        fake_mode = FakeTensorMode(shape_env=ShapeEnv())
        # tells compile_fx to ignore the shape_envs on the ambient context
        # and the graph_module.
        ignore_shape_env = True
    elif dynamic_shapes == "from_tracing_context":
        # Reuse fake_mode from the TracingContext.
        # NB: The TracingContext only exists if we're currently in a torch.compile backend.
        context = torch._guards.TracingContext.get()
        assert context.fake_mode is not None
        fake_mode = context.fake_mode
    elif dynamic_shapes == "from_graph":
        fake_mode = FakeTensorMode(shape_env=ShapeEnv())
        # Strategy: find a FakeTensor in the graph output, grab its FakeTensorMode.
        # The graph passed to standalone_compile must be an Inductor-approved graph,
        # which means that there is at least one Tensor output and the output node
        # contains a flat list of Tensors.
        last_node = next(iter(reversed(gm.graph.nodes)))
        assert last_node.op == "output"
        assert len(last_node.args) == 1

        def handle_node(node: torch.fx.Node) -> None:
            nonlocal fake_mode
            if "example_value" in node.meta:
                maybe_tensor = node.meta["example_value"]
                if isinstance(maybe_tensor, torch._subclasses.fake_tensor.FakeTensor):
                    fake_mode = maybe_tensor.fake_mode

        # If gm came from Dynamo, then last_node.args[0] is always a list,
        # even in single-Tensor returns.
        #
        # It's possible to get into a situation where last_node.args[0]
        # is a Node (and not a list!). This happens if you call split_module
        # on the graph. We allow for this case since it is common.
        if isinstance(last_node.args[0], torch.fx.Node):
            handle_node(last_node.args[0])
        else:
            for node in last_node.args[0]:
                handle_node(node)

    else:
        raise ValueError(
            f"standalone_compile got unsupported `dynamic_shapes` value: dynamic_shapes={dynamic_shapes}."
        )

    context = torch._guards.TracingContext(fake_mode)
    with (
        torch._guards.tracing(context),
        CacheArtifactManager.with_fresh_cache(),
        config.patch("triton.autotune_at_compile_time", True),
        torch._functorch.config.patch("bundled_autograd_cache", aot),
    ):
        # compile_fx can mutate gm
        gm = copy.deepcopy(gm)
        compiled_fn = compile_fx(
            gm, example_inputs, ignore_shape_env=ignore_shape_env, **options
        )
        assert callable(compiled_fn)
        if aot:
            if not hasattr(compiled_fn, "serialize"):
                raise RuntimeError(
                    "Compiled function should have serialize method when aot=True"
                )
            return AOTCompiledArtifact(compiled_fn)
        artifacts = torch.compiler.save_cache_artifacts()

        if artifacts is None:
            log.warning(
                "standalone_compile artifact generation failed, cannot save. "
                "Run with TORCH_LOGS=+torch._inductor.codecache to identify the problem"
            )

        return CacheCompiledArtifact(compiled_fn, artifacts)
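# Illustrative sketch (an assumption, not upstream code): standalone_compile is
# typically invoked from inside a torch.compile backend, where a TracingContext
# (and therefore a fake mode) is already installed, which is what the
# dynamic_shapes="from_tracing_context" mode relies on. The backend, model, and
# shapes below are placeholders; aot=True (per the flag added above) would
# instead return a serializable AOTCompiledArtifact.
def _example_backend_usage() -> None:
    import torch

    def inductor_backend(
        gm: torch.fx.GraphModule, example_inputs: list[Any]
    ) -> Callable[..., Any]:
        # Reuse the fake mode already set up by the torch.compile frontend and
        # hand the resulting artifact back to Dynamo as the compiled callable.
        return standalone_compile(
            gm,
            example_inputs,
            dynamic_shapes="from_tracing_context",
            options={},
        )

    @torch.compile(backend=inductor_backend)
    def fn(x: torch.Tensor) -> torch.Tensor:
        return torch.cos(x) * 2.0

    fn(torch.randn(8))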