pytorch/torch/_inductor/standalone_compile.py
James Wu e21ff9c3be Add logging for guard miss failure (#153125)
Differential Revision: [D74371381](https://our.internmc.facebook.com/intern/diff/D74371381/)

This PR adds logging of guard misses to tlparse, so that we can tell when AOTAutogradCache and FxGraphCache miss due to guard failures.

Example tlparse result:
https://gist.github.com/jamesjwu/afa19335c0aee85b24546b13c1cf6427

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153125
Approved by: https://github.com/oulgen, https://github.com/jingsh
2025-05-09 16:51:04 +00:00


from __future__ import annotations

import copy
import logging
import os
import pickle
import shutil
from contextlib import AbstractContextManager, nullcontext
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING

import torch.fx
from torch._dynamo.utils import dynamo_timed
from torch._inductor.cudagraph_utils import BoxedDeviceIndex
from torch._inductor.runtime.cache_dir_utils import temporary_cache_dir
from torch._inductor.utils import BoxedBool, InputType
from torch._subclasses import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

from . import config


if TYPE_CHECKING:
    from collections.abc import Sequence

    from torch.compiler._cache import CacheInfo
    from torch.fx import GraphModule

log = logging.getLogger(__name__)

class CompiledArtifact:
    """
    CompiledArtifact class represents the precompiled inductor artifact that
    can be invoked in order to avoid repeated compilation.

    CompiledArtifact can be obtained by calling standalone_compile(gm, example_inputs)
    to create a fresh CompiledArtifact from a GraphModule and example inputs.

    Later this CompiledArtifact can be saved to disk, either as a binary or unpacked
    into the provided folder via the CompiledArtifact.save function.

    CompiledArtifact.load provides a way to create a CompiledArtifact from the
    binary or unpacked data.

    Finally, the CompiledArtifact can be invoked via the __call__ method
    to execute the precompiled artifact.
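
    A minimal round-trip sketch (a hedged example; ``gm`` and
    ``example_inputs`` are assumed to be supplied by the caller, and the
    path is illustrative):

        artifact = standalone_compile(
            gm, example_inputs, dynamic_shapes="from_graph", options={}
        )
        artifact.save(path="/tmp/artifact.bin", format="binary")

        loaded = CompiledArtifact.load(path="/tmp/artifact.bin", format="binary")
        out = loaded(*example_inputs)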
"""
_compiled_fn: Callable[..., Any]
_artifacts: Optional[tuple[bytes, CacheInfo]]
def __init__(
self,
compiled_fn: Callable[..., Any],
artifacts: Optional[tuple[bytes, CacheInfo]],
):
self._compiled_fn = compiled_fn
self._artifacts = artifacts
def __call__(self, *args: Any) -> Any:
return self._compiled_fn(*args)
def save(
self, *, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> None:
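        """
        Persist this artifact to disk, either as a single binary file at
        ``path`` or unpacked into the directory at ``path`` so that the
        generated output code can be inspected and modified.

        A hedged sketch (the path is illustrative):

            artifact.save(path="/tmp/standalone_cache", format="unpacked")
        """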
with dynamo_timed("CompiledArtifact.save"):
if self._artifacts is None:
raise RuntimeError(
"CompiledArtifact.save failed to save since there's no artifact to save"
)
artifact_bytes, cache_info = self._artifacts
assert len(cache_info.aot_autograd_artifacts) == 1, cache_info
key = cache_info.aot_autograd_artifacts[0]
if format == "binary":
# cant assert that it is a file since it might not exist yet
assert not os.path.isdir(path)
from torch.utils._appending_byte_serializer import BytesWriter
from .codecache import torch_key
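
                # Binary layout: torch_key() (a fingerprint of the current
                # torch build, re-checked in load()), then the AOTAutograd
                # cache key, then the serialized cache artifacts.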
                writer = BytesWriter()
                writer.write_bytes(torch_key())
                writer.write_str(key)
                writer.write_bytes(artifact_bytes)
                with open(path, "wb") as file:
                    file.write(writer.to_bytes())
            else:
                assert format == "unpacked"
                if os.path.exists(path):
                    assert os.path.isdir(path)
                    shutil.rmtree(path, ignore_errors=True)

                from .codecache import FxGraphCache
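
                # temporary_cache_dir points the inductor cache root at
                # ``path`` for the duration of the block, so the artifacts
                # below are unpacked into the user-provided directory.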
                with temporary_cache_dir(path):
                    # This function unpacks the cache artifacts to disk
                    loaded_cache_info = torch.compiler.load_cache_artifacts(
                        artifact_bytes
                    )
                    assert loaded_cache_info is not None
                    # Now write all the output_code artifacts to disk so that
                    # they can be inspected and modified
                    for key in loaded_cache_info.inductor_artifacts:
                        subdir = FxGraphCache._get_tmp_dir_for_key(key)
                        assert os.path.exists(subdir)
                        for path in sorted(os.listdir(subdir)):
                            with open(os.path.join(subdir, path), "rb") as f:
                                graph = pickle.load(f)
                            output_file = graph.write_to_disk()
                            log.info("Output code written to: %s", output_file)

    @staticmethod
    def load(
        *, path: str, format: Literal["binary", "unpacked"] = "binary"
    ) -> CompiledArtifact:
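        """
        Recreate a CompiledArtifact from data written by save(); the same
        format ("binary" or "unpacked") that was used to save must be passed.

        A hedged sketch (the path is illustrative):

            artifact = CompiledArtifact.load(path="/tmp/artifact.bin", format="binary")
        """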
with dynamo_timed("CompiledArtifact.load"):
if format == "binary":
# cant assert that it is a file since it might not exist yet
assert not os.path.isdir(path)
with open(path, "rb") as file:
artifacts = file.read()
from torch.utils._appending_byte_serializer import BytesReader
from .codecache import torch_key
reader = BytesReader(artifacts)
assert reader.read_bytes() == torch_key()
key = reader.read_str()
artifact_bytes = reader.read_bytes()
assert reader.is_finished()
torch.compiler.load_cache_artifacts(artifact_bytes)
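
                # The artifacts were just loaded into the default cache
                # directory, so no cache-dir override is needed below.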
                cache_dir_ctx: AbstractContextManager[None] = nullcontext()
            else:
                assert format == "unpacked"
                assert os.path.isdir(path)
                autograd_cache_dir = os.path.join(path, "aotautograd")
                assert os.path.isdir(autograd_cache_dir)
                files = list(os.listdir(autograd_cache_dir))
                assert len(files) == 1
                key = files[0]
                cache_dir_ctx = temporary_cache_dir(path)
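
            # Skip the dynamic-shape guard checks stored with the cache entry
            # when reading it back; that is what the "unsafe" config flag
            # opts into.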
            with (
                cache_dir_ctx,
                config.patch(unsafe_skip_cache_dynamic_shape_guards=True),
            ):
                with torch._functorch.config.patch(strict_autograd_cache=True):
                    from torch._functorch._aot_autograd.autograd_cache import (
                        AOTAutogradCache,
                    )

                    entry = AOTAutogradCache._lookup(
                        key, local=True, remote=False, args=[], cache_info={}
                    )
                assert entry is not None
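
                # Minimal post-compile config: cudagraphs disabled, forward
                # device index boxed to device 0.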
                from .compile_fx import _CompileFxKwargs

                fx_config = _CompileFxKwargs(
                    cudagraphs=BoxedBool(False),
                    boxed_forward_device_index=BoxedDeviceIndex(0),
                )

                context = torch._guards.TracingContext(
                    FakeTensorMode(shape_env=ShapeEnv())
                )
                with torch._guards.tracing(context):
                    compiled_fn = entry.wrap_post_compile(
                        [], entry.sanitized_aot_config, fx_config
                    )
            return CompiledArtifact(lambda *args: compiled_fn(list(args)), None)


def standalone_compile(
    gm: GraphModule,
    example_inputs: Sequence[InputType],
    *,
    dynamic_shapes: Any,
    options: Any,
) -> CompiledArtifact:
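    """
    Compile ``gm`` with inductor and return a CompiledArtifact that can be
    invoked directly or saved to disk and reloaded.

    ``dynamic_shapes`` selects where the FakeTensorMode/ShapeEnv comes from:
    "from_example_inputs" (a fresh ShapeEnv, ignoring ambient shape envs),
    "from_tracing_context" (the current torch.compile TracingContext), or
    "from_graph" (recovered from FakeTensors in the graph output).
    """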
    from torch.compiler._cache import CacheArtifactManager

    from .compile_fx import compile_fx

    ignore_shape_env = False
    if dynamic_shapes == "from_example_inputs":
        fake_mode = FakeTensorMode(shape_env=ShapeEnv())
        # tells compile_fx to ignore the shape_envs on the ambient context
        # and the graph_module.
        ignore_shape_env = True
    elif dynamic_shapes == "from_tracing_context":
        # Reuse fake_mode from the TracingContext.
        # NB: The TracingContext only exists if we're currently in a torch.compile backend.
        context = torch._guards.TracingContext.get()
        fake_mode = context.fake_mode
    elif dynamic_shapes == "from_graph":
        fake_mode = FakeTensorMode(shape_env=ShapeEnv())
        # Strategy: find a FakeTensor in the graph output, grab its FakeTensorMode.
        # The graph passed to standalone_compile must be an Inductor-approved graph,
        # which means that there is at least one Tensor output and the output node
        # contains a flat list of Tensors.
        last_node = next(iter(reversed(gm.graph.nodes)))
        assert last_node.op == "output"
        assert len(last_node.args) == 1
        for node in last_node.args[0]:
            if "example_value" in node.meta:
                maybe_tensor = node.meta["example_value"]
                if isinstance(maybe_tensor, torch._subclasses.fake_tensor.FakeTensor):
                    fake_mode = maybe_tensor.fake_mode
    else:
        raise ValueError(
            f"standalone_compile got unsupported `dynamic_shapes` value: dynamic_shapes={dynamic_shapes}."
        )

    context = torch._guards.TracingContext(fake_mode)
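    # Compile under a fresh tracing context and a fresh cache-artifact
    # recorder, so save_cache_artifacts() below captures exactly the
    # artifacts of this one compilation; autotune_at_compile_time makes
    # Triton autotuning happen during compilation rather than at first run.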
    with (
        torch._guards.tracing(context),
        CacheArtifactManager.with_fresh_cache(),
        config.patch("triton.autotune_at_compile_time", True),
    ):
        # compile_fx can mutate gm
        gm = copy.deepcopy(gm)
        compiled_fn = compile_fx(
            gm, example_inputs, ignore_shape_env=ignore_shape_env, **options
        )
        assert callable(compiled_fn)

        artifacts = torch.compiler.save_cache_artifacts()
        if artifacts is None:
            log.warning(
                "standalone_compile artifact generation failed, cannot save. "
                "Run with TORCH_LOGS=+torch._inductor.codecache to identify the problem"
            )

    return CompiledArtifact(compiled_fn, artifacts)