From 06773663b53f151b1a19f496eaa6077ab4ac3e8f Mon Sep 17 00:00:00 2001 From: James Wu Date: Mon, 20 Oct 2025 17:28:45 -0700 Subject: [PATCH] Implement an AOT precompile mode for standalone_compile (#165843) This PR introduces an `aot` flag to standalone_compile that uses BundledAOTAutogradCacheEntry, and then allows regional_inductor to use this so that we can start aot compiling regional compiler graphs. The diff above this will attempt to allow GraphPickler to fully serialize graphs that have regionally compiled subgraphs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165843 Approved by: https://github.com/oulgen --- test/inductor/test_codecache.py | 71 ++++--- torch/_dynamo/aot_compile.py | 64 +------ torch/_dynamo/aot_compile_types.py | 61 ++++++ torch/_inductor/__init__.py | 3 +- torch/_inductor/standalone_compile.py | 257 +++++++++++++++++++++----- torch/fx/passes/regional_inductor.py | 13 +- 6 files changed, 339 insertions(+), 130 deletions(-) create mode 100644 torch/_dynamo/aot_compile_types.py diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index ca2e9007109..2fcd5f42b77 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -1839,12 +1839,22 @@ class TestStandaloneCompile(TestCase): @parametrize("format", ("binary", "unpacked")) @parametrize("dynamic", (False, True)) @parametrize("graph_partition", (False, True)) + @parametrize("is_aot", (False, True)) def test_basic( - self, device: str, format: str, dynamic: bool, graph_partition: bool + self, + device: str, + format: str, + dynamic: bool, + graph_partition: bool, + is_aot: bool, ) -> None: if device == GPU_TYPE and not HAS_GPU: raise unittest.SkipTest(f"requires {GPU_TYPE}") + # AOT mode does not support unpacked format + if is_aot and format == "unpacked": + raise unittest.SkipTest("AOT mode does not support unpacked format") + mod = torch.nn.Linear(1, 3, device=device) x = torch.randn(4, 1, device=device) if dynamic: @@ -1869,7 +1879,9 @@ class TestStandaloneCompile(TestCase): gm, args, kwargs = self.capture(f)(x) assert not kwargs - compiled_artifact = torch._inductor.standalone_compile(gm, args) + compiled_artifact = torch._inductor.standalone_compile( + gm, args, aot=is_aot + ) compiled_artifact.save(path=path, format=format) self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) @@ -1885,13 +1897,15 @@ class TestStandaloneCompile(TestCase): compiled_out = loaded(*concrete_args) self.assertEqual(eager_out, compiled_out) - self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1) + if not is_aot: + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1) @config.patch({"fx_graph_cache": True}) @config.patch({"fx_graph_remote_cache": False}) @functorch_config.patch({"enable_autograd_cache": True}) @parametrize("dynamic", (False, True)) - def test_call_in_backend(self, dynamic: bool) -> None: + @parametrize("is_aot", (False, True)) + def test_call_in_backend(self, dynamic: bool, is_aot: bool) -> None: mod = torch.nn.Linear(1, 3) x = torch.randn(4, 1) if dynamic: @@ -1904,7 +1918,7 @@ class TestStandaloneCompile(TestCase): eager_out = f(x) def backend(gm, args, **kwargs): - return torch._inductor.standalone_compile(gm, args) + return torch._inductor.standalone_compile(gm, args, aot=is_aot) with fresh_cache(): compiled_out = torch.compile(f, fullgraph=True, backend=backend)(x) @@ -2055,7 +2069,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01): @config.patch({"fx_graph_cache": True}) @config.patch({"fx_graph_remote_cache": False}) @functorch_config.patch({"enable_autograd_cache": True}) - def test_dynamic_shapes_from_graph(self): + @parametrize("is_aot", (False, True)) + def test_dynamic_shapes_from_graph(self, is_aot: bool): def f(x): return x.shape[0] * x @@ -2067,7 +2082,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01): assert not kwargs compiled_artifact = torch._inductor.standalone_compile( - gm, args, dynamic_shapes="from_graph" + gm, args, dynamic_shapes="from_graph", aot=is_aot ) x = torch.ones(4) (result,) = compiled_artifact(4, x) @@ -2077,7 +2092,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01): @config.patch({"fx_graph_remote_cache": False}) @functorch_config.patch({"enable_autograd_cache": True}) @functorch_config.patch({"autograd_cache_normalize_inputs": True}) - def test_split_module(self): + @parametrize("is_aot", (False, True)) + def test_split_module(self, is_aot): class Mod(torch.nn.Module): def forward(self, x, a0, a1, b0, b1, c0, c1): x = x + (a0**2) + (a1 / 2) @@ -2116,16 +2132,24 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01): split = torch.fx.passes.split_module.split_module(gm, gm, split) # Each of the split graphs only has one output. - ca0 = torch._inductor.standalone_compile(split.submod_0, (a0, x, a1)) - ca1 = torch._inductor.standalone_compile(split.submod_1, (b0, x, b1)) - ca2 = torch._inductor.standalone_compile(split.submod_2, (c0, x, c1)) + ca0 = torch._inductor.standalone_compile( + split.submod_0, (a0, x, a1), aot=is_aot + ) + ca1 = torch._inductor.standalone_compile( + split.submod_1, (b0, x, b1), aot=is_aot + ) + ca2 = torch._inductor.standalone_compile( + split.submod_2, (c0, x, c1), aot=is_aot + ) y = ca0(a0, x, a1) y = ca1(b0, y, b1) y = ca2(c0, y, c1) - self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0) - self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) - self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2) + if not is_aot: + # fx graph cache doesn't run in AOT mode + self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2) # TODO: split_module causes ca1 and ca2 to have different type annotations # for the parameter x, so we can only AOTAutogradCache cache hit once instead of twice self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2) @@ -2138,8 +2162,9 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01): @config.patch({"fx_graph_cache": True}) @config.patch({"fx_graph_remote_cache": False}) @functorch_config.patch({"enable_autograd_cache": True}) + @parametrize("is_aot", (False, True)) @parametrize("config_patches", [True, False]) - def test_dynamic_shapes_from_example_inputs(self, config_patches): + def test_dynamic_shapes_from_example_inputs(self, config_patches, is_aot): def f(x): return x.shape[0] * x @@ -2161,6 +2186,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01): (5, torch.ones(4)), dynamic_shapes="from_example_inputs", options={"config_patches": config_patches}, + aot=is_aot, ) x = torch.ones(4) (result,) = compiled_artifact(3, x) @@ -2175,8 +2201,9 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01): @config.patch({"fx_graph_cache": True}) @config.patch({"fx_graph_remote_cache": False}) @functorch_config.patch({"enable_autograd_cache": True}) + @parametrize("is_aot", (True, False)) @parametrize("dynamic_shapes", ["from_graph", "from_example_inputs"]) - def test_static_shapes(self, dynamic_shapes): + def test_static_shapes(self, dynamic_shapes, is_aot): def f(x): return x.shape[0] * x @@ -2186,7 +2213,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01): static_gm, args, kwargs = self.capture(f, dynamic=False)(static_x) assert not kwargs compiled_artifact = torch._inductor.standalone_compile( - static_gm, [static_x], dynamic_shapes=dynamic_shapes + static_gm, [static_x], dynamic_shapes=dynamic_shapes, aot=is_aot ) x = torch.randn(3) (result,) = compiled_artifact(x) @@ -2198,8 +2225,9 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01): @config.patch({"fx_graph_cache": True}) @config.patch({"fx_graph_remote_cache": False}) @functorch_config.patch({"enable_autograd_cache": True}) + @parametrize("is_aot", (True, False)) @parametrize("dynamic_shapes", ["from_tracing_context", "from_graph"]) - def test_backend(self, dynamic_shapes): + def test_backend(self, dynamic_shapes, is_aot): def f(x): return x.shape[0] * x @@ -2208,7 +2236,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01): def backend(gm, args, **kwargs): compiled_artifact = torch._inductor.standalone_compile( - gm, args, dynamic_shapes=dynamic_shapes + gm, args, dynamic_shapes=dynamic_shapes, aot=is_aot ) y = torch.randn(4) (result,) = compiled_artifact(4, y) @@ -2221,7 +2249,8 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01): @config.patch({"fx_graph_cache": True}) @config.patch({"fx_graph_remote_cache": False}) @functorch_config.patch({"enable_autograd_cache": True}) - def test_backend_dynamic_shapes_from_example_inputs(self): + @parametrize("is_aot", (True, False)) + def test_backend_dynamic_shapes_from_example_inputs(self, is_aot): def f(x): return x.shape[0] * x @@ -2230,7 +2259,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01): def backend(gm, args, **kwargs): compiled_artifact = torch._inductor.standalone_compile( - gm, [5, torch.ones(4)], dynamic_shapes="from_example_inputs" + gm, [5, torch.ones(4)], dynamic_shapes="from_example_inputs", aot=is_aot ) y = torch.ones(4) (result,) = compiled_artifact(4, y) diff --git a/torch/_dynamo/aot_compile.py b/torch/_dynamo/aot_compile.py index cc1391cb774..396bc8729f7 100644 --- a/torch/_dynamo/aot_compile.py +++ b/torch/_dynamo/aot_compile.py @@ -1,4 +1,3 @@ -import abc import dataclasses import importlib import inspect @@ -15,6 +14,10 @@ from torch._dynamo.graph_utils import _graph_device_type from torch._dynamo.package import SystemInfo from . import convert_frame +from .aot_compile_types import ( + BundledAOTAutogradSerializableCallable, + SerializableCallable, +) from .hooks import Hooks @@ -26,18 +29,6 @@ if TYPE_CHECKING: log = logging.getLogger(__name__) -class SerializableCallable(abc.ABC): - @classmethod - @abc.abstractmethod - def serialize_compile_artifacts(cls, fn: Any) -> bytes: - pass - - @classmethod - @abc.abstractmethod - def deserialize_compile_artifacts(cls, data: bytes) -> Any: - pass - - def bind_locals( signature: inspect.Signature, *args: Any, **kwargs: Any ) -> dict[str, Any]: @@ -149,53 +140,6 @@ class AOTCompiledFunction: self._guard_check_enabled = False -class BundledAOTAutogradSerializableCallable(SerializableCallable): - """ - Represents a serializable callable generated by compile_fx. - This class wraps around the compiled function generated by AOTAutograd. - - TODO: Instead of using PrecompileContext to grab it from AOTAutograd, - this object should be what's *returned* by aot_module_simplified. - We'll do that refactor in a later PR. - """ - - def __init__(self, compiled_fn: Any) -> None: - """ - Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form - of a compiled function generated by AOTAutograd. - """ - assert hasattr(compiled_fn, "serialize") - self.compiled_fn = compiled_fn - - def __getattr__(self, attr: Any) -> Any: - if hasattr(self, attr): - return getattr(super(), attr) - else: - return getattr(self.compiled_fn, attr) - - @classmethod - def serialize_compile_artifacts( - cls, fn: "BundledAOTAutogradSerializableCallable" - ) -> bytes: - with torch._functorch.config.patch("bundled_autograd_cache", True): - result = pickle.dumps(fn.compiled_fn.serialize()) - return result - - @classmethod - def deserialize_compile_artifacts(cls, data: bytes) -> Any: - from torch._functorch._aot_autograd.autograd_cache import ( - deserialize_bundled_cache_entry, - ) - - entry = pickle.loads(data) - - compiled_fn = deserialize_bundled_cache_entry(entry) - return cls(compiled_fn) - - def __call__(self, *args: Any, **kwargs: Any) -> Any: - return self.compiled_fn(*args, **kwargs) - - def aot_compile_fullgraph( model: Any, example_inputs: tuple[tuple[Any, ...], dict[str, Any]], diff --git a/torch/_dynamo/aot_compile_types.py b/torch/_dynamo/aot_compile_types.py new file mode 100644 index 00000000000..2d605531bd0 --- /dev/null +++ b/torch/_dynamo/aot_compile_types.py @@ -0,0 +1,61 @@ +import abc +import pickle +from typing import Any + +import torch + + +class SerializableCallable(abc.ABC): + @classmethod + @abc.abstractmethod + def serialize_compile_artifacts(cls, fn: Any) -> bytes: + pass + + @classmethod + @abc.abstractmethod + def deserialize_compile_artifacts(cls, data: bytes) -> Any: + pass + + +class BundledAOTAutogradSerializableCallable(SerializableCallable): + """ + Represents a serializable callable generated by compile_fx. + This class wraps around the compiled function generated by AOTAutograd. + + TODO: Instead of using PrecompileContext to grab it from AOTAutograd, + this object should be what's *returned* by aot_module_simplified. + We'll do that refactor in a later PR. + """ + + def __init__(self, compiled_fn: Any) -> None: + """ + Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form + of a compiled function generated by AOTAutograd. + """ + assert hasattr(compiled_fn, "serialize") + self.compiled_fn = compiled_fn + + def __getattr__(self, attr: Any) -> Any: + return getattr(self.compiled_fn, attr) + + @classmethod + def serialize_compile_artifacts( + cls, fn: "BundledAOTAutogradSerializableCallable" + ) -> bytes: + with torch._functorch.config.patch("bundled_autograd_cache", True): + result = pickle.dumps(fn.compiled_fn.serialize()) + return result + + @classmethod + def deserialize_compile_artifacts(cls, data: bytes) -> Any: + from torch._functorch._aot_autograd.autograd_cache import ( + deserialize_bundled_cache_entry, + ) + + entry = pickle.loads(data) + + compiled_fn = deserialize_bundled_cache_entry(entry) + return cls(compiled_fn) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return self.compiled_fn(*args, **kwargs) diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py index 9c109068401..a49b64a28cd 100644 --- a/torch/_inductor/__init__.py +++ b/torch/_inductor/__init__.py @@ -391,6 +391,7 @@ def standalone_compile( "from_example_inputs", "from_tracing_context", "from_graph" ] = "from_graph", options: Optional[dict[str, Any]] = None, + aot: bool = False, # AOT mode, which uses BundledAOTAutogradCache ) -> CompiledArtifact: """ Precompilation API for inductor. @@ -422,5 +423,5 @@ def standalone_compile( options = options if options else {} return standalone_compile( - gm, example_inputs, dynamic_shapes=dynamic_shapes, options=options + gm, example_inputs, dynamic_shapes=dynamic_shapes, options=options, aot=aot ) diff --git a/torch/_inductor/standalone_compile.py b/torch/_inductor/standalone_compile.py index 0d21b06f718..536cfaaa1ec 100644 --- a/torch/_inductor/standalone_compile.py +++ b/torch/_inductor/standalone_compile.py @@ -5,10 +5,12 @@ import logging import os import pickle import shutil +from abc import ABC, abstractmethod from contextlib import AbstractContextManager, nullcontext from typing import Any, Callable, Literal, Optional, TYPE_CHECKING import torch.fx +from torch._dynamo.aot_compile_types import BundledAOTAutogradSerializableCallable from torch._dynamo.utils import dynamo_timed from torch._inductor.cpp_builder import normalize_path_separator from torch._inductor.cudagraph_utils import BoxedDeviceIndex @@ -30,9 +32,9 @@ if TYPE_CHECKING: log = logging.getLogger(__name__) -class CompiledArtifact: +class CompiledArtifact(ABC): """ - CompiledArtifact class represents the precompiled inductor artifact that + CompiledArtifact class represents the inductor cache artifacts that can be invoked in order to avoid repeated compilation. CompiledArtifact can be obtained by calling standalone_compile(gm, example_inputs) @@ -45,11 +47,68 @@ class CompiledArtifact: binary or unpacked data. Finally, the CompiledArtifact can be invoked via the __call__ method - to execute the precompiled artifact. + to execute the cached artifact. """ - _compiled_fn: Callable[..., Any] - _artifacts: Optional[tuple[bytes, CacheInfo]] + def __init__( + self, + compiled_fn: Callable[..., Any], + artifacts: Optional[tuple[bytes, CacheInfo]], + ): + self._compiled_fn = compiled_fn + self._artifacts = artifacts + + @abstractmethod + def __call__(self, *args: Any) -> Any: ... + + @abstractmethod + def save( + self, *, path: str, format: Literal["binary", "unpacked"] = "binary" + ) -> None: ... + + @staticmethod + def load( + *, path: str, format: Literal["binary", "unpacked"] = "binary" + ) -> CompiledArtifact: + if format == "unpacked": + # If format is unpacked, it must be a CacheCompiledArtifact + return CacheCompiledArtifact.load(path=path, format=format) + + assert format == "binary" + with open(path, "rb") as file: + from torch.utils._appending_byte_serializer import BytesReader + + from .codecache import torch_key + + result_bytes = file.read() + reader = BytesReader(result_bytes) + header = reader.read_bytes() + if header == AOTCompiledArtifact.AOT_HEADER: + assert reader.read_bytes() == torch_key() + artifact = reader.read_bytes() + assert reader.is_finished() + return AOTCompiledArtifact.deserialize(artifact) + # Otherwise, it's in the CacheCompiledArtifact format + elif header == CacheCompiledArtifact.CACHE_HEADER: + assert reader.read_bytes() == torch_key() + key = reader.read_str() + artifact_bytes = reader.read_bytes() + assert reader.is_finished() + torch.compiler.load_cache_artifacts(artifact_bytes) + return CacheCompiledArtifact._load_impl(nullcontext(), key) + else: + raise RuntimeError( + "Invalid header, expected CacheCompiledArtifact or AOTCompiledArtifact, got: " + + header.decode("utf-8") + ) + + +class CacheCompiledArtifact(CompiledArtifact): + """ + CompiledArtifact that depends on torch.compiler.save_cache_artifacts + """ + + CACHE_HEADER = bytes("CacheCompiledArtifact", "utf-8") def __init__( self, @@ -83,6 +142,7 @@ class CompiledArtifact: from .codecache import torch_key writer = BytesWriter() + writer.write_bytes(CacheCompiledArtifact.CACHE_HEADER) writer.write_bytes(torch_key()) writer.write_str(key) writer.write_bytes(artifact_bytes) @@ -116,9 +176,51 @@ class CompiledArtifact: log.info("Output code written to: %s", output_file) @staticmethod - def load( - *, path: str, format: Literal["binary", "unpacked"] = "binary" + def _load_impl( + cache_dir_ctx: AbstractContextManager[Any], key: str ) -> CompiledArtifact: + with ( + cache_dir_ctx, + config.patch(unsafe_skip_cache_dynamic_shape_guards=True), + ): + with torch._functorch.config.patch(strict_autograd_cache=True): + from torch._functorch._aot_autograd.autograd_cache import ( + AOTAutogradCache, + ) + + result = AOTAutogradCache._lookup( + key, + local=True, + remote=False, + args=[], + cache_info={}, + aot_config=None, + ) + + assert result is not None + (entry, _) = result + + from .compile_fx import _CompileFxKwargs + + fx_config = _CompileFxKwargs( + cudagraphs=BoxedBool(False), + boxed_forward_device_index=BoxedDeviceIndex(0), + ) + + context = torch._guards.TracingContext(FakeTensorMode(shape_env=ShapeEnv())) + with torch._guards.tracing(context): + compiled_fn = entry.wrap_post_compile( + [], entry.sanitized_aot_config, fx_config + ) + return CacheCompiledArtifact(lambda *args: compiled_fn(list(args)), None) + + @staticmethod + def _prepare_load( + *, path: str, format: Literal["binary", "unpacked"] = "binary" + ) -> tuple[str, AbstractContextManager[Any]]: + """ + Do format specific prep and loads, return a context manager and key + """ path = normalize_path_separator(path) with dynamo_timed("CompiledArtifact.load"): if format == "binary": @@ -137,8 +239,7 @@ class CompiledArtifact: assert reader.is_finished() torch.compiler.load_cache_artifacts(artifact_bytes) - - cache_dir_ctx: AbstractContextManager[None] = nullcontext() + return key, nullcontext() else: assert format == "unpacked" assert os.path.isdir(path) @@ -148,43 +249,105 @@ class CompiledArtifact: assert len(files) == 1 key = files[0] cache_dir_ctx = temporary_cache_dir(path) + return key, cache_dir_ctx - with ( - cache_dir_ctx, - config.patch(unsafe_skip_cache_dynamic_shape_guards=True), - ): - with torch._functorch.config.patch(strict_autograd_cache=True): - from torch._functorch._aot_autograd.autograd_cache import ( - AOTAutogradCache, - ) + @staticmethod + def load( + *, path: str, format: Literal["binary", "unpacked"] = "binary" + ) -> CompiledArtifact: + key, cache_dir_ctx = CacheCompiledArtifact._prepare_load( + path=path, format=format + ) + return CacheCompiledArtifact._load_impl(cache_dir_ctx, key) - result = AOTAutogradCache._lookup( - key, - local=True, - remote=False, - args=[], - cache_info={}, - aot_config=None, - ) - assert result is not None - (entry, _) = result +class AOTCompiledArtifact(CompiledArtifact): + """ + Similar to CompiledArtifact, but the object is a single, bundled precompiled function. + This object is always a serializable callable function. - from .compile_fx import _CompileFxKwargs + This object is essentially a wrapper for BundledAOTAutogradSerializableCallable, which + is used by torch._dynamo.aot_compile for AOT Precompilation. + """ - fx_config = _CompileFxKwargs( - cudagraphs=BoxedBool(False), - boxed_forward_device_index=BoxedDeviceIndex(0), - ) + AOT_HEADER = bytes("AOTCompiledArtifact", "utf-8") - context = torch._guards.TracingContext( - FakeTensorMode(shape_env=ShapeEnv()) - ) - with torch._guards.tracing(context): - compiled_fn = entry.wrap_post_compile( - [], entry.sanitized_aot_config, fx_config - ) - return CompiledArtifact(lambda *args: compiled_fn(list(args)), None) + def __init__( + self, + compiled_fn: Callable[..., Any], + ): + self.inner_fn = BundledAOTAutogradSerializableCallable(compiled_fn) + self._artifacts = ( + None # We don't need artifacts, the inner object handles everything + ) + + @staticmethod + def from_bundled_callable( + bundled_fn: BundledAOTAutogradSerializableCallable, + ) -> AOTCompiledArtifact: + return AOTCompiledArtifact(bundled_fn.compiled_fn) + + def __call__(self, *args: Any) -> Any: + return self.inner_fn(*args) + + def save( + self, *, path: str, format: Literal["binary", "unpacked"] = "binary" + ) -> None: + if format == "unpacked": + raise RuntimeError( + "AOTCompiledArtifact does not support unpacked format yet" + ) + result_bytes = self.serialize() + from torch.utils._appending_byte_serializer import BytesWriter + + from .codecache import torch_key + + writer = BytesWriter() + writer.write_bytes(AOTCompiledArtifact.AOT_HEADER) + writer.write_bytes(torch_key()) + writer.write_bytes(result_bytes) + + from torch._inductor.codecache import write_atomic + + # Save a sentinel file to indicate that this is AOT + write_atomic(path, writer.to_bytes()) + + def serialize(self) -> bytes: + return BundledAOTAutogradSerializableCallable.serialize_compile_artifacts( + self.inner_fn + ) + + @staticmethod + def deserialize(result_bytes: bytes) -> AOTCompiledArtifact: + deserialized = ( + BundledAOTAutogradSerializableCallable.deserialize_compile_artifacts( + result_bytes + ) + ) + assert isinstance(deserialized, BundledAOTAutogradSerializableCallable) + return AOTCompiledArtifact.from_bundled_callable(deserialized) + + @staticmethod + def load( + *, path: str, format: Literal["binary", "unpacked"] = "binary" + ) -> CompiledArtifact: + if format == "unpacked": + raise RuntimeError( + "AOTCompiledArtifact does not support unpacked format yet" + ) + with open(path, "rb") as file: + from torch.utils._appending_byte_serializer import BytesReader + + from .codecache import torch_key + + result_bytes = file.read() + reader = BytesReader(result_bytes) + header = reader.read_bytes() + assert header == AOTCompiledArtifact.AOT_HEADER + assert reader.read_bytes() == torch_key() + artifact = reader.read_bytes() + assert reader.is_finished() + return AOTCompiledArtifact.deserialize(artifact) def standalone_compile( @@ -193,7 +356,11 @@ def standalone_compile( *, dynamic_shapes: Any, options: Any, + aot: bool = False, # AOT mode, which uses BundledAOTAutogradCache ) -> CompiledArtifact: + """ + Implementation of torch.inductor.standalone_compile + """ from torch.compiler._cache import CacheArtifactManager from .compile_fx import compile_fx @@ -249,6 +416,7 @@ def standalone_compile( torch._guards.tracing(context), CacheArtifactManager.with_fresh_cache(), config.patch("triton.autotune_at_compile_time", True), + torch._functorch.config.patch("bundled_autograd_cache", aot), ): # compile_fx can mutate gm gm = copy.deepcopy(gm) @@ -256,7 +424,12 @@ def standalone_compile( gm, example_inputs, ignore_shape_env=ignore_shape_env, **options ) assert callable(compiled_fn) - + if aot: + if not hasattr(compiled_fn, "serialize"): + raise RuntimeError( + "Compiled function should have serialize method when aot=True" + ) + return AOTCompiledArtifact(compiled_fn) artifacts = torch.compiler.save_cache_artifacts() if artifacts is None: log.warning( @@ -264,4 +437,4 @@ def standalone_compile( "Run with TORCH_LOGS=+torch._inductor.codecache to identify the problem" ) - return CompiledArtifact(compiled_fn, artifacts) + return CacheCompiledArtifact(compiled_fn, artifacts) diff --git a/torch/fx/passes/regional_inductor.py b/torch/fx/passes/regional_inductor.py index dfd1643513e..4da422e4884 100644 --- a/torch/fx/passes/regional_inductor.py +++ b/torch/fx/passes/regional_inductor.py @@ -43,6 +43,8 @@ def _partition_by_supported_nodes(gm, supported_ops, prefix): def _compile_submod(gm, prefix): + from torch._inductor.standalone_compile import AOTCompiledArtifact + for node in gm.graph.nodes: if node.op == "call_module" and node.target.startswith(prefix): fake_inputs = [] @@ -56,13 +58,12 @@ def _compile_submod(gm, prefix): submod = getattr(gm, node.target) - # _dummy_wrapper is to make call_function happy - compiled_submod = _dummy_wrapper( - torch._inductor.standalone_compile( - submod, fake_inputs, dynamic_shapes="from_tracing_context" - ) + compiled_fn = torch._inductor.standalone_compile( + submod, fake_inputs, dynamic_shapes="from_tracing_context", aot=True ) - + assert isinstance(compiled_fn, AOTCompiledArtifact) + # _dummy_wrapper is to make call_function happy + compiled_submod = _dummy_wrapper(compiled_fn) with gm.graph.inserting_after(node): new_node = gm.graph.call_function( compiled_submod, args=node.args, kwargs=node.kwargs