Add regional aot eager support to AOTAutogradCacheEntry (#166650)

This PR does two things:

- It genericizes `BundledAOTAutogradCacheEntry` to support *any* `OutputCode`, not just `CompiledFxGraph`
- It adds a brand new `OutputCode` for the `aot_eager_regional_inductor` backend, i.e. a graph module that has regional inductor components in it (see the usage sketch below)
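
For concreteness, here is a minimal usage sketch of the backend, adapted from the tests added in this PR. Note that `aot_eager_regional_inductor` is the test-only helper defined in the test file, not a built-in backend name:

```python
import torch
import torch.fx.traceback as fx_traceback
from torch._dynamo.backends.common import aot_autograd
from torch.fx.passes.regional_inductor import regional_inductor


def aot_eager_regional_inductor():
    # Test helper from this PR: run AOT autograd "eagerly", but compile the
    # annotated regions with inductor in both forward and backward.
    return aot_autograd(fw_compiler=regional_inductor, bw_compiler=regional_inductor)


def fn(x, y):
    sin = torch.sin(x)
    # Only this region is compiled with inductor; everything else stays eager.
    with fx_traceback.annotate({"compile_with_inductor": 0}):
        mul = sin * y
        add = mul + 1
    return torch.sin(add)


compiled_fn = torch.compile(fn, backend=aot_eager_regional_inductor(), fullgraph=True)
out = compiled_fn(torch.randn(10), torch.randn(10))
```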

This allows `BundledAOTAutogradCache` to integrate nicely with inductor out of the box, but, more importantly, it makes the result of AOT autograd fully serializable when using `aot_eager_regional_inductor`. This will let us AOT-precompile cases where we have an eager graph that has scooped up inductor bits.
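
To illustrate the serialization story, here is a rough sketch of the round trip, adapted from the new `TestRegionalOutputCode` tests in this PR. The names and call sequence follow the test code; the exact nesting of context managers is adapted for readability, so treat this as a sketch rather than a public API guarantee:

```python
import torch
import torch.fx.traceback as fx_traceback
from torch._inductor.output_code import CompiledFxGraphConstants, RegionalOutputCode
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.regional_inductor import regional_inductor


def fn(x, y):
    sin = torch.sin(x)
    with fx_traceback.annotate({"compile_with_inductor": 0}):
        mul = sin * y
        add = mul + 1
    return torch.sin(add)


x, y = torch.randn(10), torch.randn(10)

# Trace under a fake mode, as the tests do.
with torch.fx.traceback.preserve_node_meta(enable=False):
    fake_mode = FakeTensorMode()
    with fake_mode:
        fake_x = fake_mode.from_tensor(x)
        fake_y = fake_mode.from_tensor(y)
        gm = make_fx(fn)(fake_x, fake_y)
        # Compile the annotated region with inductor and wrap the result.
        output_code = RegionalOutputCode(regional_inductor(gm, fake_x, fake_y))

# Serialize: the GraphModule becomes GraphPickler bytes and the live module is dropped.
output_code.prepare_for_serialization()

# Rehydrate on load: post_compile() rebuilds the GraphModule, using the fake mode
# detected from the example inputs.
output_code.post_compile([fake_x, fake_y], CompiledFxGraphConstants(), {"is_backward": False})
with fake_mode:
    result = output_code([fake_x, fake_y])
```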

It's a bit unfortunate that the naming makes `BundledAOTAutogradCacheEntry` sound like its primary use is caching, when really the more common use is going to be as an AOT autograd output. It may be worth revisiting how to refactor/rename these in a later PR:

- AOTAutogradCacheEntry -> AOTAutogradResult
- BundledAOTAutogradCacheEntry -> BundledAOTAutogradResult

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166650
Approved by: https://github.com/zhxchen17
Authored by James Wu on 2025-10-30 09:53:41 -07:00; committed by PyTorch MergeBot
parent b470e59c38
commit 30157d30f0
4 changed files with 537 additions and 40 deletions


@@ -41,6 +41,20 @@ from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.testing._internal.two_tensor import TwoTensor
def aot_eager_regional_inductor():
"""
Regional inductor backend for AOT autograd.
Uses regional_inductor as both forward and backward compiler.
"""
from torch._dynamo.backends.common import aot_autograd
from torch.fx.passes.regional_inductor import regional_inductor
return aot_autograd(
fw_compiler=regional_inductor,
bw_compiler=regional_inductor,
)
def saved_tensors_hooks_to_gm(
pack_fn,
unpack_fn,
@@ -1898,6 +1912,171 @@ class AOTAutogradCacheTests(InductorTestCase):
# no recompiles
self.assertFalse(counters)
@inductor_config.patch("fx_graph_remote_cache", False)
@inductor_config.patch("fx_graph_cache", True)
@functorch_config.patch({"enable_autograd_cache": True})
@functorch_config.patch({"bundled_autograd_cache": True})
def test_regional_inductor_basic(self):
"""
Basic test for regional inductor with bundled autograd cache.
Tests that regional inductor compilation results can be cached and hit.
"""
import torch.fx.traceback as fx_traceback
def fn(x, y):
sin = torch.sin(x)
# Mark this region to be compiled with inductor
with fx_traceback.annotate({"compile_with_inductor": 0}):
mul = sin * y
add = mul + 1
return torch.sin(add)
x = torch.randn(10, device="cpu")
y = torch.randn(10, device="cpu")
# Compile with regional inductor backend
compiled_fn = torch.compile(
fn, backend=aot_eager_regional_inductor(), fullgraph=True
)
# First call should miss in cache
result1 = compiled_fn(x, y)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
# Second call should hit (after clearing dynamo)
self._clear_dynamo_and_codecache()
result2 = compiled_fn(x, y)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
# Results should be the same
self.assertEqual(result1, result2)
@inductor_config.patch("fx_graph_remote_cache", False)
@inductor_config.patch("fx_graph_cache", True)
@functorch_config.patch({"enable_autograd_cache": True})
@functorch_config.patch({"bundled_autograd_cache": True})
def test_regional_inductor_with_backward(self):
"""
Test regional inductor with backward pass and bundled autograd cache.
Note: Regional inductor triggers multiple AOT autograd compilations:
- One for the outer graph (with regional inductor backend)
- One for each marked region (via standalone_compile)
"""
import torch.fx.traceback as fx_traceback
def fn(x, y):
sin = torch.sin(x)
# Mark this region to be compiled with inductor
with fx_traceback.annotate({"compile_with_inductor": 0}):
mul = sin * y
add = mul + 1
return torch.sin(add)
x = torch.randn(10, requires_grad=True)
y = torch.randn(10, requires_grad=True)
x2 = x.detach().clone().requires_grad_(True)
y2 = y.detach().clone().requires_grad_(True)
# Compile with regional inductor backend
compiled_fn = torch.compile(
fn, backend=aot_eager_regional_inductor(), fullgraph=True
)
# First call: AOT autograd compiles the outer graph (1 miss)
# Regional inductor then compiles the marked region (1 more miss)
result1 = compiled_fn(x, y)
result1.sum().backward()
# We expect 2 cache misses: outer graph + marked region
initial_misses = counters["aot_autograd"]["autograd_cache_miss"]
initial_saves = counters["aot_autograd"]["autograd_cache_saved"]
self.assertGreater(initial_misses, 0)
self.assertGreater(initial_saves, 0)
# Second call should hit (after clearing dynamo)
self._clear_dynamo_and_codecache()
result2 = compiled_fn(x2, y2)
result2.sum().backward()
# Should have cache hits now
final_hits = counters["aot_autograd"]["autograd_cache_hit"]
self.assertGreater(final_hits, 0)
# Cache misses and saves should not increase
self.assertEqual(
counters["aot_autograd"]["autograd_cache_miss"], initial_misses
)
self.assertEqual(
counters["aot_autograd"]["autograd_cache_saved"], initial_saves
)
# Results and gradients should be the same
self.assertEqual(result1, result2)
self.assertEqual(x.grad, x2.grad)
self.assertEqual(y.grad, y2.grad)
@inductor_config.patch("fx_graph_remote_cache", False)
@inductor_config.patch("fx_graph_cache", True)
@functorch_config.patch({"enable_autograd_cache": True})
@functorch_config.patch({"bundled_autograd_cache": True})
def test_regional_inductor_cache_miss_on_change(self):
"""
Test that changing the function causes a cache miss with regional inductor.
Regional inductor creates multiple AOT compilations, so we track
the change in cache misses rather than absolute counts.
"""
import torch.fx.traceback as fx_traceback
def fn1(x, y):
sin = torch.sin(x)
with fx_traceback.annotate({"compile_with_inductor": 0}):
mul = sin * y
add = mul + 1
return torch.sin(add)
def fn2(x, y):
sin = torch.sin(x)
with fx_traceback.annotate({"compile_with_inductor": 0}):
mul = sin * y
add = mul + 2 # Changed from +1 to +2
return torch.sin(add)
x = torch.randn(10)
y = torch.randn(10)
# Compile first function
compiled_fn1 = torch.compile(
fn1, backend=aot_eager_regional_inductor(), fullgraph=True
)
result1 = compiled_fn1(x, y)
first_misses = counters["aot_autograd"]["autograd_cache_miss"]
first_saves = counters["aot_autograd"]["autograd_cache_saved"]
self.assertGreater(first_misses, 0)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
self.assertGreater(first_saves, 0)
# Compile second function (different graph)
self._clear_dynamo_and_codecache()
compiled_fn2 = torch.compile(
fn2, backend=aot_eager_regional_inductor(), fullgraph=True
)
result2 = compiled_fn2(x, y)
# Should miss because graph is different (more misses than before)
self.assertGreater(
counters["aot_autograd"]["autograd_cache_miss"], first_misses
)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
self.assertGreater(
counters["aot_autograd"]["autograd_cache_saved"], first_saves
)
# Results should be different
self.assertNotEqual(result1, result2)
@functorch_config.patch({"bundled_autograd_cache": True})
class AOTAutogradCacheBundledTests(AOTAutogradCacheTests):


@@ -1,13 +1,16 @@
# Owner(s): ["module: dynamo"]
import functools
from typing import TYPE_CHECKING
import torch
import torch._inductor.test_case
import torch.fx.traceback as fx_traceback
import torch.utils.checkpoint
from torch._dynamo.backends.common import aot_autograd
from torch._functorch._aot_autograd.autograd_cache import BundledCompiledForward
from torch._guards import detect_fake_mode
from torch._inductor.output_code import RegionalOutputCode
from torch._inductor.test_case import run_tests
from torch._inductor.utils import run_fw_bw_and_get_code
from torch.fx._graph_pickler import GraphPickler
@@ -21,6 +24,10 @@ from torch.testing._internal.common_utils import (
from torch.testing._internal.triton_utils import requires_cuda_and_triton
if TYPE_CHECKING:
from torch._inductor.compile_fx import _CompileFxKwargs
# Open questions / follow-ups
# 1) CSE behavior with meta custom nodes
# Common subexpression elimination may not differentiate between distinct meta
@@ -462,5 +469,154 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
self.assertEqual(len(codes), 2)
@skipIfTorchDynamo("Not a suitable dynamo wrapped test")
class TestRegionalOutputCode(torch._inductor.test_case.TestCase):
"""Tests for RegionalOutputCode and RegionalAOTAutogradCacheEntry."""
def test_regional_output_code_serialization(self):
"""Test that RegionalOutputCode can be serialized and deserialized."""
def fn(x, y):
sin = torch.sin(x)
with fx_traceback.annotate({"compile_with_inductor": 0}):
mul = sin * y
add = mul + 1
return torch.sin(add)
x = torch.randn(10, requires_grad=True)
y = torch.randn(10, requires_grad=True)
# Compile with regional inductor
with torch.fx.traceback.preserve_node_meta(enable=False):
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
fake_mode = FakeTensorMode()
with fake_mode:
fake_x = fake_mode.from_tensor(x)
fake_y = fake_mode.from_tensor(y)
gm = make_fx(fn)(fake_x, fake_y)
# Run regional_inductor on the graph
result_gm = regional_inductor(gm, fake_x, fake_y)
# Create RegionalOutputCode
output_code = RegionalOutputCode(result_gm)
# Test that we can call it
self.assertIsNotNone(output_code._graph_module)
# Serialize
output_code.prepare_for_serialization()
self.assertIsNone(output_code._graph_module)
self.assertIsNotNone(output_code._serialized_graph_module)
# Deserialize via post_compile
from torch._inductor.output_code import CompiledFxGraphConstants
fx_config: _CompileFxKwargs = {"is_backward": False}
output_code.post_compile(
[fake_x, fake_y], CompiledFxGraphConstants(), fx_config
)
self.assertIsNotNone(output_code._graph_module)
self.assertIsInstance(output_code._graph_module, torch.fx.GraphModule)
# Test that deserialized graph works
with fake_mode:
result = output_code([fake_x, fake_y])
self.assertIsNotNone(result)
def test_regional_output_code_with_backward(self):
"""Test RegionalOutputCode with both forward and backward compilation."""
def fn(x, y):
sin = torch.sin(x)
with fx_traceback.annotate({"compile_with_inductor": 0}):
mul = sin * y
add = mul + 1
return torch.sin(add)
x = torch.randn(10, requires_grad=True)
y = torch.randn(10, requires_grad=True)
# Compile with regional inductor backend
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
fake_mode = FakeTensorMode()
with fake_mode:
fake_x = fake_mode.from_tensor(x)
fake_y = fake_mode.from_tensor(y)
# Create forward graph
with torch.fx.traceback.preserve_node_meta(enable=False):
gm = make_fx(fn)(fake_x, fake_y)
forward_gm = regional_inductor(gm, fake_x, fake_y)
# Create forward output code
fw_code = RegionalOutputCode(forward_gm)
# Verify it can be called
with fake_mode:
result = fw_code([fake_x, fake_y])
self.assertIsNotNone(result)
# Test serialization round-trip
fw_code.prepare_for_serialization()
# Deserialize via post_compile
from torch._inductor.output_code import CompiledFxGraphConstants
fx_config: _CompileFxKwargs = {"is_backward": False}
fw_code.post_compile([fake_x, fake_y], CompiledFxGraphConstants(), fx_config)
with fake_mode:
result2 = fw_code([fake_x, fake_y])
self.assertIsNotNone(result2)
def test_regional_compiled_forward_backward(self):
"""Test BundledCompiledForward and BundledCompiledBackward with RegionalOutputCode."""
def fn(x):
with fx_traceback.annotate({"compile_with_inductor": 0}):
return torch.sin(x) * 2
x = torch.randn(5, requires_grad=True)
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
fake_mode = FakeTensorMode()
with fake_mode:
fake_x = fake_mode.from_tensor(x)
with torch.fx.traceback.preserve_node_meta(enable=False):
gm = make_fx(fn)(fake_x)
compiled_gm = regional_inductor(gm, fake_x)
# Create forward using the generic BundledCompiledForward
fw_code = RegionalOutputCode(compiled_gm)
fw_compiled = BundledCompiledForward[RegionalOutputCode](result=fw_code)
# Test pre_save
fw_compiled.pre_save()
# After pre_save, fw_compiled.result is a copy with serialized graph
self.assertIsNotNone(fw_compiled.result._serialized_graph_module)
self.assertIsNone(
fw_compiled.result._graph_module
) # Should be cleared after serialization
# Test load (doesn't deserialize yet)
loaded_code = fw_compiled.load([fake_x])
self.assertIsNone(loaded_code._graph_module) # Not yet deserialized
self.assertIsNotNone(loaded_code._serialized_graph_module)
fx_config: _CompileFxKwargs = {"is_backward": False}
post_compiled = fw_compiled.post_compile(loaded_code, fx_config)
self.assertIsNotNone(post_compiled)
self.assertIsNotNone(post_compiled._graph_module) # Now deserialized
if __name__ == "__main__":
run_tests()


@@ -524,32 +524,40 @@ class InductorOutput(ABC, Generic[TOut]):
def post_compile(self, result: TOut, fx_config: _CompileFxKwargs) -> TOut: ...
TOutputCode = TypeVar("TOutputCode", bound=OutputCode)
@dataclass
class CompiledFxGraphLoadable(InductorOutput[CompiledFxGraph]):
class BundledOutputCodeLoadable(InductorOutput[TOutputCode], Generic[TOutputCode]):
"""
A full compiled fx graph that doesn't need to lookup the FxGraphCache
to run
A generic wrapper for OutputCode objects that are bundled directly in the cache
(rather than looked up via FxGraphCache).
This works for any OutputCode subclass (CompiledFxGraph, RegionalOutputCode, etc.)
"""
result: CompiledFxGraph
result: TOutputCode
def pre_save(self) -> None:
disk_compiled_graph = copy(self.result)
disk_compiled_graph.prepare_for_serialization()
self.result = disk_compiled_graph
disk_result = copy(self.result)
disk_result.prepare_for_serialization()
self.result = disk_result
return
def load(self, example_inputs) -> CompiledFxGraph:
def load(self, example_inputs) -> TOutputCode:
self.example_inputs = example_inputs
return self.result
def post_compile(
self, result: CompiledFxGraph, fx_config: _CompileFxKwargs
) -> CompiledFxGraph:
self, result: TOutputCode, fx_config: _CompileFxKwargs
) -> TOutputCode:
constants = CompiledFxGraphConstants()
# Cache hit specific post compile
graph, cache_info = FxGraphCache.cache_hit_post_compile(result, {}, constants)
# Special handling for CompiledFxGraph - needs FxGraphCache.cache_hit_post_compile
if isinstance(result, CompiledFxGraph):
graph, cache_info = FxGraphCache.cache_hit_post_compile(
result, {}, constants
)
if graph is None:
raise BypassAOTAutogradCache("Failed to reload cache entry from disk")
torch._logging.trace_structured(
@@ -560,9 +568,17 @@ class CompiledFxGraphLoadable(InductorOutput[CompiledFxGraph]):
},
payload_fn=lambda: json.dumps(cache_info),
)
result = graph # type: ignore[assignment]
# Run normal post compile
graph.post_compile(self.example_inputs, constants, fx_config)
return graph
result.post_compile(self.example_inputs, constants, fx_config)
return result
# Backwards compatibility alias
CompiledFxGraphLoadable: type[BundledOutputCodeLoadable[CompiledFxGraph]] = (
BundledOutputCodeLoadable[CompiledFxGraph]
)
@dataclass
@@ -689,18 +705,31 @@ class CompiledBackward(GenericCompiledBackward[CompiledFxGraph], FxGraphCacheLoa
)
# Forward types don't have any extra parameters, so this is just a TypeAlias, in essence
class BundledCompiledForward(CompiledFxGraphLoadable):
pass
# Generic bundled forward/backward classes that work with any OutputCode type
@dataclass
class BundledCompiledForward(
BundledOutputCodeLoadable[TOutputCode], Generic[TOutputCode]
):
"""
Generic forward function for bundled compilation.
Works with any OutputCode type (CompiledFxGraph, RegionalOutputCode, etc.)
"""
@dataclass
class BundledCompiledBackward(
GenericCompiledBackward[CompiledFxGraph], CompiledFxGraphLoadable
GenericCompiledBackward[TOutputCode],
BundledOutputCodeLoadable[TOutputCode],
Generic[TOutputCode],
):
"""
Generic backward function for bundled compilation.
Works with any OutputCode type (CompiledFxGraph, RegionalOutputCode, etc.)
"""
def post_compile(
self, result: CompiledFxGraph, fx_config: _CompileFxKwargs
) -> CompiledFxGraph:
self, result: TOutputCode, fx_config: _CompileFxKwargs
) -> TOutputCode:
compiled_bw = super().post_compile(result, fx_config)
# See note [Wrapping bw_compiler in disable]
# This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py
@@ -990,11 +1019,44 @@ class AOTAutogradCacheEntry(
class BundledAOTAutogradCacheEntry(
GenericAOTAutogradCacheEntry[BundledCompiledForward, BundledCompiledBackward]
GenericAOTAutogradCacheEntry[
BundledCompiledForward[TOutputCode], BundledCompiledBackward[TOutputCode]
],
Generic[TOutputCode],
):
"""
AOTAutogradCacheEntry where we save the entire CompiledFxGraph instead
of relying on cache keys from FxGraphCache
Generic AOTAutogradCacheEntry where we bundle the entire OutputCode directly
(rather than looking it up via FxGraphCache).
This works with any OutputCode type:
- CompiledFxGraph: Traditional inductor compilation
- RegionalOutputCode: Regional inductor compilation with GraphPickler serialization
- Any future OutputCode subclasses
Type parameter:
TOutputCode: The OutputCode subclass (e.g., CompiledFxGraph, RegionalOutputCode)
Usage with CompiledFxGraph:
entry = BundledAOTAutogradCacheEntry[CompiledFxGraph](
compiled_fw=BundledCompiledForward(result=CompiledFxGraph(...)),
compiled_bw=BundledCompiledBackward(
result=CompiledFxGraph(...),
backward_state_indices=[...],
num_symints_saved_for_bw_=...,
),
...
)
Usage with RegionalOutputCode:
entry = BundledAOTAutogradCacheEntry[RegionalOutputCode](
compiled_fw=BundledCompiledForward(result=RegionalOutputCode(gm)),
compiled_bw=BundledCompiledBackward(
result=RegionalOutputCode(gm),
backward_state_indices=[...],
num_symints_saved_for_bw_=...,
),
...
)
"""
@@ -1469,8 +1531,8 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
@staticmethod
def make_entry(
compiled_fw_func: CompiledFxGraph,
compiled_bw_func: Optional[CompiledFxGraph],
compiled_fw_func: OutputCode,
compiled_bw_func: Optional[OutputCode],
aot_joint_graph_str: Optional[str],
aot_forward_graph_str: Optional[str],
aot_backward_graph_str: Optional[str],
@@ -1490,19 +1552,19 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
if should_bundle_autograd_cache():
# Helper function to unwrap all the wrappers we added during aotdispatch
# They get reapplied on cache load
def unwrap_compiled_fx_graph(obj):
def unwrap_output_code(obj):
while hasattr(obj, "__wrapped__"):
obj = obj.__wrapped__
assert isinstance(obj, CompiledFxGraph)
assert isinstance(obj, OutputCode)
return obj
compiled_fw_graph = unwrap_compiled_fx_graph(compiled_fw_func)
compiled_fw_graph = unwrap_output_code(compiled_fw_func)
bundled_compiled_forward = BundledCompiledForward(compiled_fw_graph)
bundled_compiled_backward = None
if compiled_bw_func is not None:
assert backward_state_indices is not None
assert num_symints_saved_for_bw is not None
compiled_bw_graph = unwrap_compiled_fx_graph(compiled_bw_func)
compiled_bw_graph = unwrap_output_code(compiled_bw_func)
bundled_compiled_backward = BundledCompiledBackward(
compiled_bw_graph, backward_state_indices, num_symints_saved_for_bw
)


@@ -88,6 +88,9 @@ class OutputCode:
def __call__(self, inputs: Sequence[Any]) -> Any:
raise NotImplementedError(type(self))
def prepare_for_serialization(self) -> None:
raise NotImplementedError(type(self))
def post_compile(
self,
example_inputs: Sequence[InputType],
@@ -783,6 +786,9 @@ class CompiledAOTI(OutputCode):
) -> None:
pass
def prepare_for_serialization(self) -> None:
pass
def set_triton_bundle(self, triton_bundle: Any) -> None:
pass
@@ -807,3 +813,97 @@ class MockFXGraphCacheOutput(OutputCode):
def set_triton_bundle(self, triton_bundle: Any) -> None:
pass
@dataclasses.dataclass
class RegionalOutputCode(OutputCode):
"""
OutputCode for regional inductor compilation results.
Regional inductor returns a torch.fx.GraphModule that contains both
compiled regions (via standalone_compile) and eager regions. This needs
special serialization using GraphPickler instead of standard pickle.
The serialization strategy stores the GraphModule as bytes using
GraphPickler.dumps(), which handles FakeTensors, AOTCompiledArtifacts,
and other special objects that standard pickle cannot handle.
"""
# The serialized graph module as bytes (using GraphPickler)
_serialized_graph_module: Optional[bytes] = dataclasses.field(
default=None, init=False
)
# The actual graph module (cleared during serialization)
_graph_module: Optional[torch.fx.GraphModule] = dataclasses.field(
default=None, init=False
)
def __init__(self, graph_module: torch.fx.GraphModule):
"""
Args:
graph_module: The torch.fx.GraphModule returned by regional_inductor
"""
super().__init__()
self._graph_module = graph_module
self._serialized_graph_module = None
def __call__(self, inputs: Sequence[Any]) -> Any:
"""Execute the regional compiled graph."""
if self._graph_module is None:
raise RuntimeError(
"RegionalOutputCode has no graph module loaded. "
"Did you forget to call post_compile()?"
)
return self._graph_module(*inputs)
def post_compile(
self,
example_inputs: Sequence[InputType],
constants: CompiledFxGraphConstants,
graph_kwargs: _CompileFxKwargs,
) -> None:
"""
Post-compile processing for regional inductor.
This deserializes the GraphModule from bytes using GraphPickler,
extracting the fake_mode from example_inputs.
"""
if self._graph_module is not None:
return
assert self._serialized_graph_module is not None
# Get fake mode from example inputs
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode(example_inputs)
if fake_mode is None:
raise RuntimeError(
"Could not detect fake mode from example inputs. "
"Regional inductor requires fake mode for deserialization."
)
# Deserialize the graph module
from torch.fx._graph_pickler import GraphPickler
gm = GraphPickler.loads(self._serialized_graph_module, fake_mode)
assert isinstance(gm, torch.fx.GraphModule)
gm.recompile()
self._graph_module = gm
def set_triton_bundle(self, triton_bundle: Any) -> None:
"""Regional inductor doesn't use triton bundles directly."""
def prepare_for_serialization(self) -> None:
"""
Prepare for serialization by converting the GraphModule to bytes.
This uses GraphPickler to serialize the graph module since it contains
special objects like FakeTensors and AOTCompiledArtifacts that need
custom pickling.
"""
if self._graph_module is not None:
from torch.fx._graph_pickler import GraphPickler
self._serialized_graph_module = GraphPickler.dumps(self._graph_module)
# Clear the graph module to avoid pickling it with standard pickle
self._graph_module = None