mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add regional aot eager support to AOTAutogradCacheEntry (#166650)
This PR does two things: - It genericizes `BundledAOTAutogradCacheEntry` to support *any* outputcode, not just CompiledFxGraphs - It adds a brand new OutputCode for the `aot_eager_regional_inductor` backend, i.e. a graph module that has regional inductor components in it. This allows BundledAOTAutogradCache to just integrate nicely with inductor out of the box, but more importantly, it allows the result of aot_autograd to be fully serializable when using `aot_eager_regional_inductor`. This will allow us to AOT precompile cases where we have an eager graph that has scooped up inductor bits. It's a bit unfortunate that the naming makes BundledAOTAutogradCacheEntry sound like its primary use is for caching, but really the more common use is going to be as an AOTAutogradOutput. It may be worth revisiting how to refactor/rename these in a later PR: - AOTAutogradCacheEntry -> AOTAutogradResult - BundledAOTAutogradCacheEntry -> BundledAOTAutogradResult Pull Request resolved: https://github.com/pytorch/pytorch/pull/166650 Approved by: https://github.com/zhxchen17
This commit is contained in:
parent
b470e59c38
commit
30157d30f0
|
|
@ -41,6 +41,20 @@ from torch.testing._internal.triton_utils import requires_cuda_and_triton
|
|||
from torch.testing._internal.two_tensor import TwoTensor
|
||||
|
||||
|
||||
def aot_eager_regional_inductor():
|
||||
"""
|
||||
Regional inductor backend for AOT autograd.
|
||||
Uses regional_inductor as both forward and backward compiler.
|
||||
"""
|
||||
from torch._dynamo.backends.common import aot_autograd
|
||||
from torch.fx.passes.regional_inductor import regional_inductor
|
||||
|
||||
return aot_autograd(
|
||||
fw_compiler=regional_inductor,
|
||||
bw_compiler=regional_inductor,
|
||||
)
|
||||
|
||||
|
||||
def saved_tensors_hooks_to_gm(
|
||||
pack_fn,
|
||||
unpack_fn,
|
||||
|
|
@ -1898,6 +1912,171 @@ class AOTAutogradCacheTests(InductorTestCase):
|
|||
# no recompiles
|
||||
self.assertFalse(counters)
|
||||
|
||||
@inductor_config.patch("fx_graph_remote_cache", False)
|
||||
@inductor_config.patch("fx_graph_cache", True)
|
||||
@functorch_config.patch({"enable_autograd_cache": True})
|
||||
@functorch_config.patch({"bundled_autograd_cache": True})
|
||||
def test_regional_inductor_basic(self):
|
||||
"""
|
||||
Basic test for regional inductor with bundled autograd cache.
|
||||
Tests that regional inductor compilation results can be cached and hit.
|
||||
"""
|
||||
import torch.fx.traceback as fx_traceback
|
||||
|
||||
def fn(x, y):
|
||||
sin = torch.sin(x)
|
||||
# Mark this region to be compiled with inductor
|
||||
with fx_traceback.annotate({"compile_with_inductor": 0}):
|
||||
mul = sin * y
|
||||
add = mul + 1
|
||||
return torch.sin(add)
|
||||
|
||||
x = torch.randn(10, device="cpu")
|
||||
y = torch.randn(10, device="cpu")
|
||||
|
||||
# Compile with regional inductor backend
|
||||
compiled_fn = torch.compile(
|
||||
fn, backend=aot_eager_regional_inductor(), fullgraph=True
|
||||
)
|
||||
|
||||
# First call should miss in cache
|
||||
result1 = compiled_fn(x, y)
|
||||
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
|
||||
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
|
||||
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
|
||||
|
||||
# Second call should hit (after clearing dynamo)
|
||||
self._clear_dynamo_and_codecache()
|
||||
result2 = compiled_fn(x, y)
|
||||
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
|
||||
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
|
||||
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
|
||||
|
||||
# Results should be the same
|
||||
self.assertEqual(result1, result2)
|
||||
|
||||
@inductor_config.patch("fx_graph_remote_cache", False)
|
||||
@inductor_config.patch("fx_graph_cache", True)
|
||||
@functorch_config.patch({"enable_autograd_cache": True})
|
||||
@functorch_config.patch({"bundled_autograd_cache": True})
|
||||
def test_regional_inductor_with_backward(self):
|
||||
"""
|
||||
Test regional inductor with backward pass and bundled autograd cache.
|
||||
Note: Regional inductor triggers multiple AOT autograd compilations:
|
||||
- One for the outer graph (with regional inductor backend)
|
||||
- One for each marked region (via standalone_compile)
|
||||
"""
|
||||
import torch.fx.traceback as fx_traceback
|
||||
|
||||
def fn(x, y):
|
||||
sin = torch.sin(x)
|
||||
# Mark this region to be compiled with inductor
|
||||
with fx_traceback.annotate({"compile_with_inductor": 0}):
|
||||
mul = sin * y
|
||||
add = mul + 1
|
||||
return torch.sin(add)
|
||||
|
||||
x = torch.randn(10, requires_grad=True)
|
||||
y = torch.randn(10, requires_grad=True)
|
||||
x2 = x.detach().clone().requires_grad_(True)
|
||||
y2 = y.detach().clone().requires_grad_(True)
|
||||
|
||||
# Compile with regional inductor backend
|
||||
compiled_fn = torch.compile(
|
||||
fn, backend=aot_eager_regional_inductor(), fullgraph=True
|
||||
)
|
||||
|
||||
# First call: AOT autograd compiles the outer graph (1 miss)
|
||||
# Regional inductor then compiles the marked region (1 more miss)
|
||||
result1 = compiled_fn(x, y)
|
||||
result1.sum().backward()
|
||||
|
||||
# We expect 2 cache misses: outer graph + marked region
|
||||
initial_misses = counters["aot_autograd"]["autograd_cache_miss"]
|
||||
initial_saves = counters["aot_autograd"]["autograd_cache_saved"]
|
||||
self.assertGreater(initial_misses, 0)
|
||||
self.assertGreater(initial_saves, 0)
|
||||
|
||||
# Second call should hit (after clearing dynamo)
|
||||
self._clear_dynamo_and_codecache()
|
||||
result2 = compiled_fn(x2, y2)
|
||||
result2.sum().backward()
|
||||
|
||||
# Should have cache hits now
|
||||
final_hits = counters["aot_autograd"]["autograd_cache_hit"]
|
||||
self.assertGreater(final_hits, 0)
|
||||
|
||||
# Cache misses and saves should not increase
|
||||
self.assertEqual(
|
||||
counters["aot_autograd"]["autograd_cache_miss"], initial_misses
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["aot_autograd"]["autograd_cache_saved"], initial_saves
|
||||
)
|
||||
|
||||
# Results and gradients should be the same
|
||||
self.assertEqual(result1, result2)
|
||||
self.assertEqual(x.grad, x2.grad)
|
||||
self.assertEqual(y.grad, y2.grad)
|
||||
|
||||
@inductor_config.patch("fx_graph_remote_cache", False)
|
||||
@inductor_config.patch("fx_graph_cache", True)
|
||||
@functorch_config.patch({"enable_autograd_cache": True})
|
||||
@functorch_config.patch({"bundled_autograd_cache": True})
|
||||
def test_regional_inductor_cache_miss_on_change(self):
|
||||
"""
|
||||
Test that changing the function causes a cache miss with regional inductor.
|
||||
Regional inductor creates multiple AOT compilations, so we track
|
||||
the change in cache misses rather than absolute counts.
|
||||
"""
|
||||
import torch.fx.traceback as fx_traceback
|
||||
|
||||
def fn1(x, y):
|
||||
sin = torch.sin(x)
|
||||
with fx_traceback.annotate({"compile_with_inductor": 0}):
|
||||
mul = sin * y
|
||||
add = mul + 1
|
||||
return torch.sin(add)
|
||||
|
||||
def fn2(x, y):
|
||||
sin = torch.sin(x)
|
||||
with fx_traceback.annotate({"compile_with_inductor": 0}):
|
||||
mul = sin * y
|
||||
add = mul + 2 # Changed from +1 to +2
|
||||
return torch.sin(add)
|
||||
|
||||
x = torch.randn(10)
|
||||
y = torch.randn(10)
|
||||
|
||||
# Compile first function
|
||||
compiled_fn1 = torch.compile(
|
||||
fn1, backend=aot_eager_regional_inductor(), fullgraph=True
|
||||
)
|
||||
result1 = compiled_fn1(x, y)
|
||||
first_misses = counters["aot_autograd"]["autograd_cache_miss"]
|
||||
first_saves = counters["aot_autograd"]["autograd_cache_saved"]
|
||||
self.assertGreater(first_misses, 0)
|
||||
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
|
||||
self.assertGreater(first_saves, 0)
|
||||
|
||||
# Compile second function (different graph)
|
||||
self._clear_dynamo_and_codecache()
|
||||
compiled_fn2 = torch.compile(
|
||||
fn2, backend=aot_eager_regional_inductor(), fullgraph=True
|
||||
)
|
||||
result2 = compiled_fn2(x, y)
|
||||
# Should miss because graph is different (more misses than before)
|
||||
self.assertGreater(
|
||||
counters["aot_autograd"]["autograd_cache_miss"], first_misses
|
||||
)
|
||||
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
|
||||
self.assertGreater(
|
||||
counters["aot_autograd"]["autograd_cache_saved"], first_saves
|
||||
)
|
||||
|
||||
# Results should be different
|
||||
self.assertNotEqual(result1, result2)
|
||||
|
||||
|
||||
@functorch_config.patch({"bundled_autograd_cache": True})
|
||||
class AOTAutogradCacheBundledTests(AOTAutogradCacheTests):
|
||||
|
|
|
|||
|
|
@ -1,13 +1,16 @@
|
|||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
import functools
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch._inductor.test_case
|
||||
import torch.fx.traceback as fx_traceback
|
||||
import torch.utils.checkpoint
|
||||
from torch._dynamo.backends.common import aot_autograd
|
||||
from torch._functorch._aot_autograd.autograd_cache import BundledCompiledForward
|
||||
from torch._guards import detect_fake_mode
|
||||
from torch._inductor.output_code import RegionalOutputCode
|
||||
from torch._inductor.test_case import run_tests
|
||||
from torch._inductor.utils import run_fw_bw_and_get_code
|
||||
from torch.fx._graph_pickler import GraphPickler
|
||||
|
|
@ -21,6 +24,10 @@ from torch.testing._internal.common_utils import (
|
|||
from torch.testing._internal.triton_utils import requires_cuda_and_triton
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._inductor.compile_fx import _CompileFxKwargs
|
||||
|
||||
|
||||
# Open questions / follow-ups
|
||||
# 1) CSE behavior with meta custom nodes
|
||||
# Common subexpression elimination may not differentiate between distinct meta
|
||||
|
|
@ -462,5 +469,154 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
|
|||
self.assertEqual(len(codes), 2)
|
||||
|
||||
|
||||
@skipIfTorchDynamo("Not a suitable dynamo wrapped test")
|
||||
class TestRegionalOutputCode(torch._inductor.test_case.TestCase):
|
||||
"""Tests for RegionalOutputCode and RegionalAOTAutogradCacheEntry."""
|
||||
|
||||
def test_regional_output_code_serialization(self):
|
||||
"""Test that RegionalOutputCode can be serialized and deserialized."""
|
||||
|
||||
def fn(x, y):
|
||||
sin = torch.sin(x)
|
||||
with fx_traceback.annotate({"compile_with_inductor": 0}):
|
||||
mul = sin * y
|
||||
add = mul + 1
|
||||
return torch.sin(add)
|
||||
|
||||
x = torch.randn(10, requires_grad=True)
|
||||
y = torch.randn(10, requires_grad=True)
|
||||
|
||||
# Compile with regional inductor
|
||||
with torch.fx.traceback.preserve_node_meta(enable=False):
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
|
||||
fake_mode = FakeTensorMode()
|
||||
with fake_mode:
|
||||
fake_x = fake_mode.from_tensor(x)
|
||||
fake_y = fake_mode.from_tensor(y)
|
||||
gm = make_fx(fn)(fake_x, fake_y)
|
||||
|
||||
# Run regional_inductor on the graph
|
||||
result_gm = regional_inductor(gm, fake_x, fake_y)
|
||||
|
||||
# Create RegionalOutputCode
|
||||
output_code = RegionalOutputCode(result_gm)
|
||||
|
||||
# Test that we can call it
|
||||
self.assertIsNotNone(output_code._graph_module)
|
||||
|
||||
# Serialize
|
||||
output_code.prepare_for_serialization()
|
||||
self.assertIsNone(output_code._graph_module)
|
||||
self.assertIsNotNone(output_code._serialized_graph_module)
|
||||
|
||||
# Deserialize via post_compile
|
||||
from torch._inductor.output_code import CompiledFxGraphConstants
|
||||
|
||||
fx_config: _CompileFxKwargs = {"is_backward": False}
|
||||
output_code.post_compile(
|
||||
[fake_x, fake_y], CompiledFxGraphConstants(), fx_config
|
||||
)
|
||||
self.assertIsNotNone(output_code._graph_module)
|
||||
self.assertIsInstance(output_code._graph_module, torch.fx.GraphModule)
|
||||
|
||||
# Test that deserialized graph works
|
||||
with fake_mode:
|
||||
result = output_code([fake_x, fake_y])
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
def test_regional_output_code_with_backward(self):
|
||||
"""Test RegionalOutputCode with both forward and backward compilation."""
|
||||
|
||||
def fn(x, y):
|
||||
sin = torch.sin(x)
|
||||
with fx_traceback.annotate({"compile_with_inductor": 0}):
|
||||
mul = sin * y
|
||||
add = mul + 1
|
||||
return torch.sin(add)
|
||||
|
||||
x = torch.randn(10, requires_grad=True)
|
||||
y = torch.randn(10, requires_grad=True)
|
||||
|
||||
# Compile with regional inductor backend
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
|
||||
fake_mode = FakeTensorMode()
|
||||
with fake_mode:
|
||||
fake_x = fake_mode.from_tensor(x)
|
||||
fake_y = fake_mode.from_tensor(y)
|
||||
|
||||
# Create forward graph
|
||||
with torch.fx.traceback.preserve_node_meta(enable=False):
|
||||
gm = make_fx(fn)(fake_x, fake_y)
|
||||
forward_gm = regional_inductor(gm, fake_x, fake_y)
|
||||
|
||||
# Create forward output code
|
||||
fw_code = RegionalOutputCode(forward_gm)
|
||||
|
||||
# Verify it can be called
|
||||
with fake_mode:
|
||||
result = fw_code([fake_x, fake_y])
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
# Test serialization round-trip
|
||||
fw_code.prepare_for_serialization()
|
||||
|
||||
# Deserialize via post_compile
|
||||
|
||||
from torch._inductor.output_code import CompiledFxGraphConstants
|
||||
|
||||
fx_config: _CompileFxKwargs = {"is_backward": False}
|
||||
fw_code.post_compile([fake_x, fake_y], CompiledFxGraphConstants(), fx_config)
|
||||
|
||||
with fake_mode:
|
||||
result2 = fw_code([fake_x, fake_y])
|
||||
self.assertIsNotNone(result2)
|
||||
|
||||
def test_regional_compiled_forward_backward(self):
|
||||
"""Test BundledCompiledForward and BundledCompiledBackward with RegionalOutputCode."""
|
||||
|
||||
def fn(x):
|
||||
with fx_traceback.annotate({"compile_with_inductor": 0}):
|
||||
return torch.sin(x) * 2
|
||||
|
||||
x = torch.randn(5, requires_grad=True)
|
||||
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
|
||||
fake_mode = FakeTensorMode()
|
||||
with fake_mode:
|
||||
fake_x = fake_mode.from_tensor(x)
|
||||
|
||||
with torch.fx.traceback.preserve_node_meta(enable=False):
|
||||
gm = make_fx(fn)(fake_x)
|
||||
compiled_gm = regional_inductor(gm, fake_x)
|
||||
|
||||
# Create forward using the generic BundledCompiledForward
|
||||
fw_code = RegionalOutputCode(compiled_gm)
|
||||
fw_compiled = BundledCompiledForward[RegionalOutputCode](result=fw_code)
|
||||
|
||||
# Test pre_save
|
||||
fw_compiled.pre_save()
|
||||
# After pre_save, fw_compiled.result is a copy with serialized graph
|
||||
self.assertIsNotNone(fw_compiled.result._serialized_graph_module)
|
||||
self.assertIsNone(
|
||||
fw_compiled.result._graph_module
|
||||
) # Should be cleared after serialization
|
||||
|
||||
# Test load (doesn't deserialize yet)
|
||||
loaded_code = fw_compiled.load([fake_x])
|
||||
self.assertIsNone(loaded_code._graph_module) # Not yet deserialized
|
||||
self.assertIsNotNone(loaded_code._serialized_graph_module)
|
||||
|
||||
fx_config: _CompileFxKwargs = {"is_backward": False}
|
||||
post_compiled = fw_compiled.post_compile(loaded_code, fx_config)
|
||||
self.assertIsNotNone(post_compiled)
|
||||
self.assertIsNotNone(post_compiled._graph_module) # Now deserialized
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -524,32 +524,40 @@ class InductorOutput(ABC, Generic[TOut]):
|
|||
def post_compile(self, result: TOut, fx_config: _CompileFxKwargs) -> TOut: ...
|
||||
|
||||
|
||||
TOutputCode = TypeVar("TOutputCode", bound=OutputCode)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompiledFxGraphLoadable(InductorOutput[CompiledFxGraph]):
|
||||
class BundledOutputCodeLoadable(InductorOutput[TOutputCode], Generic[TOutputCode]):
|
||||
"""
|
||||
A full compiled fx graph that doesn't need to lookup the FxGraphCache
|
||||
to run
|
||||
A generic wrapper for OutputCode objects that are bundled directly in the cache
|
||||
(rather than looked up via FxGraphCache).
|
||||
|
||||
This works for any OutputCode subclass (CompiledFxGraph, RegionalOutputCode, etc.)
|
||||
"""
|
||||
|
||||
result: CompiledFxGraph
|
||||
result: TOutputCode
|
||||
|
||||
def pre_save(self) -> None:
|
||||
disk_compiled_graph = copy(self.result)
|
||||
disk_compiled_graph.prepare_for_serialization()
|
||||
self.result = disk_compiled_graph
|
||||
disk_result = copy(self.result)
|
||||
disk_result.prepare_for_serialization()
|
||||
self.result = disk_result
|
||||
return
|
||||
|
||||
def load(self, example_inputs) -> CompiledFxGraph:
|
||||
def load(self, example_inputs) -> TOutputCode:
|
||||
self.example_inputs = example_inputs
|
||||
|
||||
return self.result
|
||||
|
||||
def post_compile(
|
||||
self, result: CompiledFxGraph, fx_config: _CompileFxKwargs
|
||||
) -> CompiledFxGraph:
|
||||
self, result: TOutputCode, fx_config: _CompileFxKwargs
|
||||
) -> TOutputCode:
|
||||
constants = CompiledFxGraphConstants()
|
||||
# Cache hit specific post compile
|
||||
graph, cache_info = FxGraphCache.cache_hit_post_compile(result, {}, constants)
|
||||
|
||||
# Special handling for CompiledFxGraph - needs FxGraphCache.cache_hit_post_compile
|
||||
if isinstance(result, CompiledFxGraph):
|
||||
graph, cache_info = FxGraphCache.cache_hit_post_compile(
|
||||
result, {}, constants
|
||||
)
|
||||
if graph is None:
|
||||
raise BypassAOTAutogradCache("Failed to reload cache entry from disk")
|
||||
torch._logging.trace_structured(
|
||||
|
|
@ -560,9 +568,17 @@ class CompiledFxGraphLoadable(InductorOutput[CompiledFxGraph]):
|
|||
},
|
||||
payload_fn=lambda: json.dumps(cache_info),
|
||||
)
|
||||
result = graph # type: ignore[assignment]
|
||||
|
||||
# Run normal post compile
|
||||
graph.post_compile(self.example_inputs, constants, fx_config)
|
||||
return graph
|
||||
result.post_compile(self.example_inputs, constants, fx_config)
|
||||
return result
|
||||
|
||||
|
||||
# Backwards compatibility alias
|
||||
CompiledFxGraphLoadable: type[BundledOutputCodeLoadable[CompiledFxGraph]] = (
|
||||
BundledOutputCodeLoadable[CompiledFxGraph]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -689,18 +705,31 @@ class CompiledBackward(GenericCompiledBackward[CompiledFxGraph], FxGraphCacheLoa
|
|||
)
|
||||
|
||||
|
||||
# Forward types don't have any extra parameters, so this is just a TypeAlias, in essence
|
||||
class BundledCompiledForward(CompiledFxGraphLoadable):
|
||||
pass
|
||||
# Generic bundled forward/backward classes that work with any OutputCode type
|
||||
@dataclass
|
||||
class BundledCompiledForward(
|
||||
BundledOutputCodeLoadable[TOutputCode], Generic[TOutputCode]
|
||||
):
|
||||
"""
|
||||
Generic forward function for bundled compilation.
|
||||
Works with any OutputCode type (CompiledFxGraph, RegionalOutputCode, etc.)
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class BundledCompiledBackward(
|
||||
GenericCompiledBackward[CompiledFxGraph], CompiledFxGraphLoadable
|
||||
GenericCompiledBackward[TOutputCode],
|
||||
BundledOutputCodeLoadable[TOutputCode],
|
||||
Generic[TOutputCode],
|
||||
):
|
||||
"""
|
||||
Generic backward function for bundled compilation.
|
||||
Works with any OutputCode type (CompiledFxGraph, RegionalOutputCode, etc.)
|
||||
"""
|
||||
|
||||
def post_compile(
|
||||
self, result: CompiledFxGraph, fx_config: _CompileFxKwargs
|
||||
) -> CompiledFxGraph:
|
||||
self, result: TOutputCode, fx_config: _CompileFxKwargs
|
||||
) -> TOutputCode:
|
||||
compiled_bw = super().post_compile(result, fx_config)
|
||||
# See note [Wrapping bw_compiler in disable]
|
||||
# This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py
|
||||
|
|
@ -990,11 +1019,44 @@ class AOTAutogradCacheEntry(
|
|||
|
||||
|
||||
class BundledAOTAutogradCacheEntry(
|
||||
GenericAOTAutogradCacheEntry[BundledCompiledForward, BundledCompiledBackward]
|
||||
GenericAOTAutogradCacheEntry[
|
||||
BundledCompiledForward[TOutputCode], BundledCompiledBackward[TOutputCode]
|
||||
],
|
||||
Generic[TOutputCode],
|
||||
):
|
||||
"""
|
||||
AOTAutogradCacheEntry where we save the entire CompiledFxGraph instead
|
||||
of relying on cache keys from FxGraphCache
|
||||
Generic AOTAutogradCacheEntry where we bundle the entire OutputCode directly
|
||||
(rather than looking it up via FxGraphCache).
|
||||
|
||||
This works with any OutputCode type:
|
||||
- CompiledFxGraph: Traditional inductor compilation
|
||||
- RegionalOutputCode: Regional inductor compilation with GraphPickler serialization
|
||||
- Any future OutputCode subclasses
|
||||
|
||||
Type parameter:
|
||||
TOutputCode: The OutputCode subclass (e.g., CompiledFxGraph, RegionalOutputCode)
|
||||
|
||||
Usage with CompiledFxGraph:
|
||||
entry = BundledAOTAutogradCacheEntry[CompiledFxGraph](
|
||||
compiled_fw=BundledCompiledForward(result=CompiledFxGraph(...)),
|
||||
compiled_bw=BundledCompiledBackward(
|
||||
result=CompiledFxGraph(...),
|
||||
backward_state_indices=[...],
|
||||
num_symints_saved_for_bw_=...,
|
||||
),
|
||||
...
|
||||
)
|
||||
|
||||
Usage with RegionalOutputCode:
|
||||
entry = BundledAOTAutogradCacheEntry[RegionalOutputCode](
|
||||
compiled_fw=BundledCompiledForward(result=RegionalOutputCode(gm)),
|
||||
compiled_bw=BundledCompiledBackward(
|
||||
result=RegionalOutputCode(gm),
|
||||
backward_state_indices=[...],
|
||||
num_symints_saved_for_bw_=...,
|
||||
),
|
||||
...
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
|
|
@ -1469,8 +1531,8 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
|
|||
|
||||
@staticmethod
|
||||
def make_entry(
|
||||
compiled_fw_func: CompiledFxGraph,
|
||||
compiled_bw_func: Optional[CompiledFxGraph],
|
||||
compiled_fw_func: OutputCode,
|
||||
compiled_bw_func: Optional[OutputCode],
|
||||
aot_joint_graph_str: Optional[str],
|
||||
aot_forward_graph_str: Optional[str],
|
||||
aot_backward_graph_str: Optional[str],
|
||||
|
|
@ -1490,19 +1552,19 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
|
|||
if should_bundle_autograd_cache():
|
||||
# Helper function to unwrap all the wrappers we added during aotdispatch
|
||||
# They get reapplied on cache load
|
||||
def unwrap_compiled_fx_graph(obj):
|
||||
def unwrap_output_code(obj):
|
||||
while hasattr(obj, "__wrapped__"):
|
||||
obj = obj.__wrapped__
|
||||
assert isinstance(obj, CompiledFxGraph)
|
||||
assert isinstance(obj, OutputCode)
|
||||
return obj
|
||||
|
||||
compiled_fw_graph = unwrap_compiled_fx_graph(compiled_fw_func)
|
||||
compiled_fw_graph = unwrap_output_code(compiled_fw_func)
|
||||
bundled_compiled_forward = BundledCompiledForward(compiled_fw_graph)
|
||||
bundled_compiled_backward = None
|
||||
if compiled_bw_func is not None:
|
||||
assert backward_state_indices is not None
|
||||
assert num_symints_saved_for_bw is not None
|
||||
compiled_bw_graph = unwrap_compiled_fx_graph(compiled_bw_func)
|
||||
compiled_bw_graph = unwrap_output_code(compiled_bw_func)
|
||||
bundled_compiled_backward = BundledCompiledBackward(
|
||||
compiled_bw_graph, backward_state_indices, num_symints_saved_for_bw
|
||||
)
|
||||
|
|
|
|||
|
|
@ -88,6 +88,9 @@ class OutputCode:
|
|||
def __call__(self, inputs: Sequence[Any]) -> Any:
|
||||
raise NotImplementedError(type(self))
|
||||
|
||||
def prepare_for_serialization(self) -> None:
|
||||
raise NotImplementedError(type(self))
|
||||
|
||||
def post_compile(
|
||||
self,
|
||||
example_inputs: Sequence[InputType],
|
||||
|
|
@ -783,6 +786,9 @@ class CompiledAOTI(OutputCode):
|
|||
) -> None:
|
||||
pass
|
||||
|
||||
def prepare_for_serialization(self) -> None:
|
||||
pass
|
||||
|
||||
def set_triton_bundle(self, triton_bundle: Any) -> None:
|
||||
pass
|
||||
|
||||
|
|
@ -807,3 +813,97 @@ class MockFXGraphCacheOutput(OutputCode):
|
|||
|
||||
def set_triton_bundle(self, triton_bundle: Any) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RegionalOutputCode(OutputCode):
|
||||
"""
|
||||
OutputCode for regional inductor compilation results.
|
||||
|
||||
Regional inductor returns a torch.fx.GraphModule that contains both
|
||||
compiled regions (via standalone_compile) and eager regions. This needs
|
||||
special serialization using GraphPickler instead of standard pickle.
|
||||
|
||||
The serialization strategy stores the GraphModule as bytes using
|
||||
GraphPickler.dumps(), which handles FakeTensors, AOTCompiledArtifacts,
|
||||
and other special objects that standard pickle cannot handle.
|
||||
"""
|
||||
|
||||
# The serialized graph module as bytes (using GraphPickler)
|
||||
_serialized_graph_module: Optional[bytes] = dataclasses.field(
|
||||
default=None, init=False
|
||||
)
|
||||
|
||||
# The actual graph module (cleared during serialization)
|
||||
_graph_module: Optional[torch.fx.GraphModule] = dataclasses.field(
|
||||
default=None, init=False
|
||||
)
|
||||
|
||||
def __init__(self, graph_module: torch.fx.GraphModule):
|
||||
"""
|
||||
Args:
|
||||
graph_module: The torch.fx.GraphModule returned by regional_inductor
|
||||
"""
|
||||
super().__init__()
|
||||
self._graph_module = graph_module
|
||||
self._serialized_graph_module = None
|
||||
|
||||
def __call__(self, inputs: Sequence[Any]) -> Any:
|
||||
"""Execute the regional compiled graph."""
|
||||
if self._graph_module is None:
|
||||
raise RuntimeError(
|
||||
"RegionalOutputCode has no graph module loaded. "
|
||||
"Did you forget to call post_compile()?"
|
||||
)
|
||||
return self._graph_module(*inputs)
|
||||
|
||||
def post_compile(
|
||||
self,
|
||||
example_inputs: Sequence[InputType],
|
||||
constants: CompiledFxGraphConstants,
|
||||
graph_kwargs: _CompileFxKwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Post-compile processing for regional inductor.
|
||||
|
||||
This deserializes the GraphModule from bytes using GraphPickler,
|
||||
extracting the fake_mode from example_inputs.
|
||||
"""
|
||||
if self._graph_module is not None:
|
||||
return
|
||||
assert self._serialized_graph_module is not None
|
||||
# Get fake mode from example inputs
|
||||
from torch._guards import detect_fake_mode
|
||||
|
||||
fake_mode = detect_fake_mode(example_inputs)
|
||||
if fake_mode is None:
|
||||
raise RuntimeError(
|
||||
"Could not detect fake mode from example inputs. "
|
||||
"Regional inductor requires fake mode for deserialization."
|
||||
)
|
||||
|
||||
# Deserialize the graph module
|
||||
from torch.fx._graph_pickler import GraphPickler
|
||||
|
||||
gm = GraphPickler.loads(self._serialized_graph_module, fake_mode)
|
||||
assert isinstance(gm, torch.fx.GraphModule)
|
||||
gm.recompile()
|
||||
self._graph_module = gm
|
||||
|
||||
def set_triton_bundle(self, triton_bundle: Any) -> None:
|
||||
"""Regional inductor doesn't use triton bundles directly."""
|
||||
|
||||
def prepare_for_serialization(self) -> None:
|
||||
"""
|
||||
Prepare for serialization by converting the GraphModule to bytes.
|
||||
|
||||
This uses GraphPickler to serialize the graph module since it contains
|
||||
special objects like FakeTensors and AOTCompiledArtifacts that need
|
||||
custom pickling.
|
||||
"""
|
||||
if self._graph_module is not None:
|
||||
from torch.fx._graph_pickler import GraphPickler
|
||||
|
||||
self._serialized_graph_module = GraphPickler.dumps(self._graph_module)
|
||||
# Clear the graph module to avoid pickling it with standard pickle
|
||||
self._graph_module = None
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user