Implement an AOT precompile mode for standalone_compile (#165843)

This PR introduces an `aot` flag to `standalone_compile` that uses `BundledAOTAutogradCacheEntry`, and then lets `regional_inductor` use it so that we can start AOT-compiling regionally compiled graphs. The diff above this one will attempt to allow `GraphPickler` to fully serialize graphs that contain regionally compiled subgraphs.
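
For context, a minimal usage sketch of the new flag, mirroring the tests added in this PR (the function, save path, and inputs here are placeholders):

import torch
import torch._inductor

def f(x):
    return x.sin() + 1

def backend(gm, args, **kwargs):
    # With aot=True, standalone_compile returns an AOTCompiledArtifact backed by a
    # BundledAOTAutogradCacheEntry instead of going through the FX graph cache.
    artifact = torch._inductor.standalone_compile(gm, args, aot=True)
    artifact.save(path="/tmp/aot_artifact.bin", format="binary")  # "unpacked" is not supported with aot=True
    return artifact

out = torch.compile(f, fullgraph=True, backend=backend)(torch.randn(4))

# The saved artifact can later be reloaded and called without recompiling:
from torch._inductor.standalone_compile import CompiledArtifact
loaded = CompiledArtifact.load(path="/tmp/aot_artifact.bin", format="binary")
(result,) = loaded(torch.randn(4))  # outputs come back as a sequence, matching the graph's outputs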

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165843
Approved by: https://github.com/oulgen
James Wu 2025-10-20 17:28:45 -07:00 committed by PyTorch MergeBot
parent 0bff65503c
commit 06773663b5
6 changed files with 339 additions and 130 deletions

View File

@ -1839,12 +1839,22 @@ class TestStandaloneCompile(TestCase):
@parametrize("format", ("binary", "unpacked"))
@parametrize("dynamic", (False, True))
@parametrize("graph_partition", (False, True))
@parametrize("is_aot", (False, True))
def test_basic(
self, device: str, format: str, dynamic: bool, graph_partition: bool
self,
device: str,
format: str,
dynamic: bool,
graph_partition: bool,
is_aot: bool,
) -> None:
if device == GPU_TYPE and not HAS_GPU:
raise unittest.SkipTest(f"requires {GPU_TYPE}")
# AOT mode does not support unpacked format
if is_aot and format == "unpacked":
raise unittest.SkipTest("AOT mode does not support unpacked format")
mod = torch.nn.Linear(1, 3, device=device)
x = torch.randn(4, 1, device=device)
if dynamic:
@ -1869,7 +1879,9 @@ class TestStandaloneCompile(TestCase):
gm, args, kwargs = self.capture(f)(x)
assert not kwargs
compiled_artifact = torch._inductor.standalone_compile(gm, args)
compiled_artifact = torch._inductor.standalone_compile(
gm, args, aot=is_aot
)
compiled_artifact.save(path=path, format=format)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
@ -1885,13 +1897,15 @@ class TestStandaloneCompile(TestCase):
compiled_out = loaded(*concrete_args)
self.assertEqual(eager_out, compiled_out)
if not is_aot:
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("dynamic", (False, True))
def test_call_in_backend(self, dynamic: bool) -> None:
@parametrize("is_aot", (False, True))
def test_call_in_backend(self, dynamic: bool, is_aot: bool) -> None:
mod = torch.nn.Linear(1, 3)
x = torch.randn(4, 1)
if dynamic:
@ -1904,7 +1918,7 @@ class TestStandaloneCompile(TestCase):
eager_out = f(x)
def backend(gm, args, **kwargs):
return torch._inductor.standalone_compile(gm, args)
return torch._inductor.standalone_compile(gm, args, aot=is_aot)
with fresh_cache():
compiled_out = torch.compile(f, fullgraph=True, backend=backend)(x)
@ -2055,7 +2069,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
def test_dynamic_shapes_from_graph(self):
@parametrize("is_aot", (False, True))
def test_dynamic_shapes_from_graph(self, is_aot: bool):
def f(x):
return x.shape[0] * x
@ -2067,7 +2082,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
assert not kwargs
compiled_artifact = torch._inductor.standalone_compile(
gm, args, dynamic_shapes="from_graph"
gm, args, dynamic_shapes="from_graph", aot=is_aot
)
x = torch.ones(4)
(result,) = compiled_artifact(4, x)
@ -2077,7 +2092,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@functorch_config.patch({"autograd_cache_normalize_inputs": True})
def test_split_module(self):
@parametrize("is_aot", (False, True))
def test_split_module(self, is_aot):
class Mod(torch.nn.Module):
def forward(self, x, a0, a1, b0, b1, c0, c1):
x = x + (a0**2) + (a1 / 2)
@ -2116,13 +2132,21 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
split = torch.fx.passes.split_module.split_module(gm, gm, split)
# Each of the split graphs only has one output.
ca0 = torch._inductor.standalone_compile(split.submod_0, (a0, x, a1))
ca1 = torch._inductor.standalone_compile(split.submod_1, (b0, x, b1))
ca2 = torch._inductor.standalone_compile(split.submod_2, (c0, x, c1))
ca0 = torch._inductor.standalone_compile(
split.submod_0, (a0, x, a1), aot=is_aot
)
ca1 = torch._inductor.standalone_compile(
split.submod_1, (b0, x, b1), aot=is_aot
)
ca2 = torch._inductor.standalone_compile(
split.submod_2, (c0, x, c1), aot=is_aot
)
y = ca0(a0, x, a1)
y = ca1(b0, y, b1)
y = ca2(c0, y, c1)
if not is_aot:
# fx graph cache doesn't run in AOT mode
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2)
@ -2138,8 +2162,9 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("is_aot", (False, True))
@parametrize("config_patches", [True, False])
def test_dynamic_shapes_from_example_inputs(self, config_patches):
def test_dynamic_shapes_from_example_inputs(self, config_patches, is_aot):
def f(x):
return x.shape[0] * x
@ -2161,6 +2186,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
(5, torch.ones(4)),
dynamic_shapes="from_example_inputs",
options={"config_patches": config_patches},
aot=is_aot,
)
x = torch.ones(4)
(result,) = compiled_artifact(3, x)
@ -2175,8 +2201,9 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("is_aot", (True, False))
@parametrize("dynamic_shapes", ["from_graph", "from_example_inputs"])
def test_static_shapes(self, dynamic_shapes):
def test_static_shapes(self, dynamic_shapes, is_aot):
def f(x):
return x.shape[0] * x
@ -2186,7 +2213,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
static_gm, args, kwargs = self.capture(f, dynamic=False)(static_x)
assert not kwargs
compiled_artifact = torch._inductor.standalone_compile(
static_gm, [static_x], dynamic_shapes=dynamic_shapes
static_gm, [static_x], dynamic_shapes=dynamic_shapes, aot=is_aot
)
x = torch.randn(3)
(result,) = compiled_artifact(x)
@ -2198,8 +2225,9 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("is_aot", (True, False))
@parametrize("dynamic_shapes", ["from_tracing_context", "from_graph"])
def test_backend(self, dynamic_shapes):
def test_backend(self, dynamic_shapes, is_aot):
def f(x):
return x.shape[0] * x
@ -2208,7 +2236,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
def backend(gm, args, **kwargs):
compiled_artifact = torch._inductor.standalone_compile(
gm, args, dynamic_shapes=dynamic_shapes
gm, args, dynamic_shapes=dynamic_shapes, aot=is_aot
)
y = torch.randn(4)
(result,) = compiled_artifact(4, y)
@ -2221,7 +2249,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
def test_backend_dynamic_shapes_from_example_inputs(self):
@parametrize("is_aot", (True, False))
def test_backend_dynamic_shapes_from_example_inputs(self, is_aot):
def f(x):
return x.shape[0] * x
@ -2230,7 +2259,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
def backend(gm, args, **kwargs):
compiled_artifact = torch._inductor.standalone_compile(
gm, [5, torch.ones(4)], dynamic_shapes="from_example_inputs"
gm, [5, torch.ones(4)], dynamic_shapes="from_example_inputs", aot=is_aot
)
y = torch.ones(4)
(result,) = compiled_artifact(4, y)

View File

@ -1,4 +1,3 @@
import abc
import dataclasses
import importlib
import inspect
@ -15,6 +14,10 @@ from torch._dynamo.graph_utils import _graph_device_type
from torch._dynamo.package import SystemInfo
from . import convert_frame
from .aot_compile_types import (
BundledAOTAutogradSerializableCallable,
SerializableCallable,
)
from .hooks import Hooks
@ -26,18 +29,6 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
class SerializableCallable(abc.ABC):
@classmethod
@abc.abstractmethod
def serialize_compile_artifacts(cls, fn: Any) -> bytes:
pass
@classmethod
@abc.abstractmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
pass
def bind_locals(
signature: inspect.Signature, *args: Any, **kwargs: Any
) -> dict[str, Any]:
@ -149,53 +140,6 @@ class AOTCompiledFunction:
self._guard_check_enabled = False
class BundledAOTAutogradSerializableCallable(SerializableCallable):
"""
Represents a serializable callable generated by compile_fx.
This class wraps around the compiled function generated by AOTAutograd.
TODO: Instead of using PrecompileContext to grab it from AOTAutograd,
this object should be what's *returned* by aot_module_simplified.
We'll do that refactor in a later PR.
"""
def __init__(self, compiled_fn: Any) -> None:
"""
Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form
of a compiled function generated by AOTAutograd.
"""
assert hasattr(compiled_fn, "serialize")
self.compiled_fn = compiled_fn
def __getattr__(self, attr: Any) -> Any:
if hasattr(self, attr):
return getattr(super(), attr)
else:
return getattr(self.compiled_fn, attr)
@classmethod
def serialize_compile_artifacts(
cls, fn: "BundledAOTAutogradSerializableCallable"
) -> bytes:
with torch._functorch.config.patch("bundled_autograd_cache", True):
result = pickle.dumps(fn.compiled_fn.serialize())
return result
@classmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
from torch._functorch._aot_autograd.autograd_cache import (
deserialize_bundled_cache_entry,
)
entry = pickle.loads(data)
compiled_fn = deserialize_bundled_cache_entry(entry)
return cls(compiled_fn)
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.compiled_fn(*args, **kwargs)
def aot_compile_fullgraph(
model: Any,
example_inputs: tuple[tuple[Any, ...], dict[str, Any]],

View File

@ -0,0 +1,61 @@
import abc
import pickle
from typing import Any
import torch
class SerializableCallable(abc.ABC):
@classmethod
@abc.abstractmethod
def serialize_compile_artifacts(cls, fn: Any) -> bytes:
pass
@classmethod
@abc.abstractmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
pass
class BundledAOTAutogradSerializableCallable(SerializableCallable):
"""
Represents a serializable callable generated by compile_fx.
This class wraps around the compiled function generated by AOTAutograd.
TODO: Instead of using PrecompileContext to grab it from AOTAutograd,
this object should be what's *returned* by aot_module_simplified.
We'll do that refactor in a later PR.
"""
def __init__(self, compiled_fn: Any) -> None:
"""
Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form
of a compiled function generated by AOTAutograd.
"""
assert hasattr(compiled_fn, "serialize")
self.compiled_fn = compiled_fn
def __getattr__(self, attr: Any) -> Any:
return getattr(self.compiled_fn, attr)
@classmethod
def serialize_compile_artifacts(
cls, fn: "BundledAOTAutogradSerializableCallable"
) -> bytes:
with torch._functorch.config.patch("bundled_autograd_cache", True):
result = pickle.dumps(fn.compiled_fn.serialize())
return result
@classmethod
def deserialize_compile_artifacts(cls, data: bytes) -> Any:
from torch._functorch._aot_autograd.autograd_cache import (
deserialize_bundled_cache_entry,
)
entry = pickle.loads(data)
compiled_fn = deserialize_bundled_cache_entry(entry)
return cls(compiled_fn)
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.compiled_fn(*args, **kwargs)
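
As a rough round-trip sketch of what this new torch/_dynamo/aot_compile_types.py module provides (here `compiled_fn` is assumed to be an AOTAutograd-compiled function produced with bundled_autograd_cache enabled, so it exposes a serialize() method, and `example_inputs` is a placeholder):

from torch._dynamo.aot_compile_types import BundledAOTAutogradSerializableCallable

wrapped = BundledAOTAutogradSerializableCallable(compiled_fn)
# serialize_compile_artifacts pickles the bundled cache entry into plain bytes ...
blob = BundledAOTAutogradSerializableCallable.serialize_compile_artifacts(wrapped)
# ... which can be written to disk or shipped elsewhere and later turned back into a callable.
restored = BundledAOTAutogradSerializableCallable.deserialize_compile_artifacts(blob)
out = restored(*example_inputs)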

View File

@ -391,6 +391,7 @@ def standalone_compile(
"from_example_inputs", "from_tracing_context", "from_graph"
] = "from_graph",
options: Optional[dict[str, Any]] = None,
aot: bool = False, # AOT mode, which uses BundledAOTAutogradCache
) -> CompiledArtifact:
"""
Precompilation API for inductor.
@ -422,5 +423,5 @@ def standalone_compile(
options = options if options else {}
return standalone_compile(
gm, example_inputs, dynamic_shapes=dynamic_shapes, options=options
gm, example_inputs, dynamic_shapes=dynamic_shapes, options=options, aot=aot
)

View File

@ -5,10 +5,12 @@ import logging
import os
import pickle
import shutil
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager, nullcontext
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING
import torch.fx
from torch._dynamo.aot_compile_types import BundledAOTAutogradSerializableCallable
from torch._dynamo.utils import dynamo_timed
from torch._inductor.cpp_builder import normalize_path_separator
from torch._inductor.cudagraph_utils import BoxedDeviceIndex
@ -30,9 +32,9 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
class CompiledArtifact:
class CompiledArtifact(ABC):
"""
CompiledArtifact class represents the precompiled inductor artifact that
CompiledArtifact class represents the inductor cache artifacts that
can be invoked in order to avoid repeated compilation.
CompiledArtifact can be obtained by calling standalone_compile(gm, example_inputs)
@ -45,11 +47,68 @@ class CompiledArtifact:
binary or unpacked data.
Finally, the CompiledArtifact can be invoked via the __call__ method
to execute the precompiled artifact.
to execute the cached artifact.
"""
_compiled_fn: Callable[..., Any]
_artifacts: Optional[tuple[bytes, CacheInfo]]
def __init__(
self,
compiled_fn: Callable[..., Any],
artifacts: Optional[tuple[bytes, CacheInfo]],
):
self._compiled_fn = compiled_fn
self._artifacts = artifacts
@abstractmethod
def __call__(self, *args: Any) -> Any: ...
@abstractmethod
def save(
self, *, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> None: ...
@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> CompiledArtifact:
if format == "unpacked":
# If format is unpacked, it must be a CacheCompiledArtifact
return CacheCompiledArtifact.load(path=path, format=format)
assert format == "binary"
with open(path, "rb") as file:
from torch.utils._appending_byte_serializer import BytesReader
from .codecache import torch_key
result_bytes = file.read()
reader = BytesReader(result_bytes)
header = reader.read_bytes()
if header == AOTCompiledArtifact.AOT_HEADER:
assert reader.read_bytes() == torch_key()
artifact = reader.read_bytes()
assert reader.is_finished()
return AOTCompiledArtifact.deserialize(artifact)
# Otherwise, it's in the CacheCompiledArtifact format
elif header == CacheCompiledArtifact.CACHE_HEADER:
assert reader.read_bytes() == torch_key()
key = reader.read_str()
artifact_bytes = reader.read_bytes()
assert reader.is_finished()
torch.compiler.load_cache_artifacts(artifact_bytes)
return CacheCompiledArtifact._load_impl(nullcontext(), key)
else:
raise RuntimeError(
"Invalid header, expected CacheCompiledArtifact or AOTCompiledArtifact, got: "
+ header.decode("utf-8")
)
class CacheCompiledArtifact(CompiledArtifact):
"""
CompiledArtifact that depends on torch.compiler.save_cache_artifacts
"""
CACHE_HEADER = bytes("CacheCompiledArtifact", "utf-8")
def __init__(
self,
@ -83,6 +142,7 @@ class CompiledArtifact:
from .codecache import torch_key
writer = BytesWriter()
writer.write_bytes(CacheCompiledArtifact.CACHE_HEADER)
writer.write_bytes(torch_key())
writer.write_str(key)
writer.write_bytes(artifact_bytes)
@ -116,39 +176,9 @@ class CompiledArtifact:
log.info("Output code written to: %s", output_file)
@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
def _load_impl(
cache_dir_ctx: AbstractContextManager[Any], key: str
) -> CompiledArtifact:
path = normalize_path_separator(path)
with dynamo_timed("CompiledArtifact.load"):
if format == "binary":
# can't assert that it is a file since it might not exist yet
assert not os.path.isdir(path)
with open(path, "rb") as file:
artifacts = file.read()
from torch.utils._appending_byte_serializer import BytesReader
from .codecache import torch_key
reader = BytesReader(artifacts)
assert reader.read_bytes() == torch_key()
key = reader.read_str()
artifact_bytes = reader.read_bytes()
assert reader.is_finished()
torch.compiler.load_cache_artifacts(artifact_bytes)
cache_dir_ctx: AbstractContextManager[None] = nullcontext()
else:
assert format == "unpacked"
assert os.path.isdir(path)
autograd_cache_dir = os.path.join(path, "aotautograd")
assert os.path.isdir(autograd_cache_dir)
files = list(os.listdir(autograd_cache_dir))
assert len(files) == 1
key = files[0]
cache_dir_ctx = temporary_cache_dir(path)
with (
cache_dir_ctx,
config.patch(unsafe_skip_cache_dynamic_shape_guards=True),
@ -177,14 +207,147 @@ class CompiledArtifact:
boxed_forward_device_index=BoxedDeviceIndex(0),
)
context = torch._guards.TracingContext(
FakeTensorMode(shape_env=ShapeEnv())
)
context = torch._guards.TracingContext(FakeTensorMode(shape_env=ShapeEnv()))
with torch._guards.tracing(context):
compiled_fn = entry.wrap_post_compile(
[], entry.sanitized_aot_config, fx_config
)
return CompiledArtifact(lambda *args: compiled_fn(list(args)), None)
return CacheCompiledArtifact(lambda *args: compiled_fn(list(args)), None)
@staticmethod
def _prepare_load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> tuple[str, AbstractContextManager[Any]]:
"""
Do format specific prep and loads, return a context manager and key
"""
path = normalize_path_separator(path)
with dynamo_timed("CompiledArtifact.load"):
if format == "binary":
# can't assert that it is a file since it might not exist yet
assert not os.path.isdir(path)
with open(path, "rb") as file:
artifacts = file.read()
from torch.utils._appending_byte_serializer import BytesReader
from .codecache import torch_key
reader = BytesReader(artifacts)
assert reader.read_bytes() == torch_key()
key = reader.read_str()
artifact_bytes = reader.read_bytes()
assert reader.is_finished()
torch.compiler.load_cache_artifacts(artifact_bytes)
return key, nullcontext()
else:
assert format == "unpacked"
assert os.path.isdir(path)
autograd_cache_dir = os.path.join(path, "aotautograd")
assert os.path.isdir(autograd_cache_dir)
files = list(os.listdir(autograd_cache_dir))
assert len(files) == 1
key = files[0]
cache_dir_ctx = temporary_cache_dir(path)
return key, cache_dir_ctx
@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> CompiledArtifact:
key, cache_dir_ctx = CacheCompiledArtifact._prepare_load(
path=path, format=format
)
return CacheCompiledArtifact._load_impl(cache_dir_ctx, key)
class AOTCompiledArtifact(CompiledArtifact):
"""
Similar to CompiledArtifact, but the object is a single, bundled precompiled function.
This object is always a serializable callable function.
This object is essentially a wrapper for BundledAOTAutogradSerializableCallable, which
is used by torch._dynamo.aot_compile for AOT Precompilation.
"""
AOT_HEADER = bytes("AOTCompiledArtifact", "utf-8")
def __init__(
self,
compiled_fn: Callable[..., Any],
):
self.inner_fn = BundledAOTAutogradSerializableCallable(compiled_fn)
self._artifacts = (
None # We don't need artifacts, the inner object handles everything
)
@staticmethod
def from_bundled_callable(
bundled_fn: BundledAOTAutogradSerializableCallable,
) -> AOTCompiledArtifact:
return AOTCompiledArtifact(bundled_fn.compiled_fn)
def __call__(self, *args: Any) -> Any:
return self.inner_fn(*args)
def save(
self, *, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> None:
if format == "unpacked":
raise RuntimeError(
"AOTCompiledArtifact does not support unpacked format yet"
)
result_bytes = self.serialize()
from torch.utils._appending_byte_serializer import BytesWriter
from .codecache import torch_key
writer = BytesWriter()
writer.write_bytes(AOTCompiledArtifact.AOT_HEADER)
writer.write_bytes(torch_key())
writer.write_bytes(result_bytes)
from torch._inductor.codecache import write_atomic
# Save a sentinel file to indicate that this is AOT
write_atomic(path, writer.to_bytes())
def serialize(self) -> bytes:
return BundledAOTAutogradSerializableCallable.serialize_compile_artifacts(
self.inner_fn
)
@staticmethod
def deserialize(result_bytes: bytes) -> AOTCompiledArtifact:
deserialized = (
BundledAOTAutogradSerializableCallable.deserialize_compile_artifacts(
result_bytes
)
)
assert isinstance(deserialized, BundledAOTAutogradSerializableCallable)
return AOTCompiledArtifact.from_bundled_callable(deserialized)
@staticmethod
def load(
*, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> CompiledArtifact:
if format == "unpacked":
raise RuntimeError(
"AOTCompiledArtifact does not support unpacked format yet"
)
with open(path, "rb") as file:
from torch.utils._appending_byte_serializer import BytesReader
from .codecache import torch_key
result_bytes = file.read()
reader = BytesReader(result_bytes)
header = reader.read_bytes()
assert header == AOTCompiledArtifact.AOT_HEADER
assert reader.read_bytes() == torch_key()
artifact = reader.read_bytes()
assert reader.is_finished()
return AOTCompiledArtifact.deserialize(artifact)
def standalone_compile(
@ -193,7 +356,11 @@ def standalone_compile(
*,
dynamic_shapes: Any,
options: Any,
aot: bool = False, # AOT mode, which uses BundledAOTAutogradCache
) -> CompiledArtifact:
"""
Implementation of torch.inductor.standalone_compile
"""
from torch.compiler._cache import CacheArtifactManager
from .compile_fx import compile_fx
@ -249,6 +416,7 @@ def standalone_compile(
torch._guards.tracing(context),
CacheArtifactManager.with_fresh_cache(),
config.patch("triton.autotune_at_compile_time", True),
torch._functorch.config.patch("bundled_autograd_cache", aot),
):
# compile_fx can mutate gm
gm = copy.deepcopy(gm)
@ -256,7 +424,12 @@ def standalone_compile(
gm, example_inputs, ignore_shape_env=ignore_shape_env, **options
)
assert callable(compiled_fn)
if aot:
if not hasattr(compiled_fn, "serialize"):
raise RuntimeError(
"Compiled function should have serialize method when aot=True"
)
return AOTCompiledArtifact(compiled_fn)
artifacts = torch.compiler.save_cache_artifacts()
if artifacts is None:
log.warning(
@ -264,4 +437,4 @@ def standalone_compile(
"Run with TORCH_LOGS=+torch._inductor.codecache to identify the problem"
)
return CompiledArtifact(compiled_fn, artifacts)
return CacheCompiledArtifact(compiled_fn, artifacts)
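
To recap the serialization format these changes introduce, a rough sketch of the binary layout and the header dispatch (field order taken from the save/load code above; the path is a placeholder):

# AOTCompiledArtifact.save writes:   AOT_HEADER | torch_key() | serialized bundled AOTAutograd entry
# CacheCompiledArtifact.save writes: CACHE_HEADER | torch_key() | cache key | cache artifact bytes
# CompiledArtifact.load reads the first field and picks the matching subclass,
# so callers do not need to remember whether the artifact was produced with aot=True.

from torch._inductor.standalone_compile import CompiledArtifact
loaded = CompiledArtifact.load(path="/tmp/artifact.bin", format="binary")  # works for either header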

View File

@ -43,6 +43,8 @@ def _partition_by_supported_nodes(gm, supported_ops, prefix):
def _compile_submod(gm, prefix):
from torch._inductor.standalone_compile import AOTCompiledArtifact
for node in gm.graph.nodes:
if node.op == "call_module" and node.target.startswith(prefix):
fake_inputs = []
@ -56,13 +58,12 @@ def _compile_submod(gm, prefix):
submod = getattr(gm, node.target)
compiled_fn = torch._inductor.standalone_compile(
submod, fake_inputs, dynamic_shapes="from_tracing_context", aot=True
)
assert isinstance(compiled_fn, AOTCompiledArtifact)
# _dummy_wrapper is to make call_function happy
compiled_submod = _dummy_wrapper(
torch._inductor.standalone_compile(
submod, fake_inputs, dynamic_shapes="from_tracing_context"
)
)
compiled_submod = _dummy_wrapper(compiled_fn)
with gm.graph.inserting_after(node):
new_node = gm.graph.call_function(
compiled_submod, args=node.args, kwargs=node.kwargs