diff --git a/test/dynamo/test_aot_autograd_cache.py b/test/dynamo/test_aot_autograd_cache.py index 04af76c90c5..40b8b1a5b6c 100644 --- a/test/dynamo/test_aot_autograd_cache.py +++ b/test/dynamo/test_aot_autograd_cache.py @@ -41,6 +41,20 @@ from torch.testing._internal.triton_utils import requires_cuda_and_triton from torch.testing._internal.two_tensor import TwoTensor +def aot_eager_regional_inductor(): + """ + Regional inductor backend for AOT autograd. + Uses regional_inductor as both forward and backward compiler. + """ + from torch._dynamo.backends.common import aot_autograd + from torch.fx.passes.regional_inductor import regional_inductor + + return aot_autograd( + fw_compiler=regional_inductor, + bw_compiler=regional_inductor, + ) + + def saved_tensors_hooks_to_gm( pack_fn, unpack_fn, @@ -1898,6 +1912,171 @@ class AOTAutogradCacheTests(InductorTestCase): # no recompiles self.assertFalse(counters) + @inductor_config.patch("fx_graph_remote_cache", False) + @inductor_config.patch("fx_graph_cache", True) + @functorch_config.patch({"enable_autograd_cache": True}) + @functorch_config.patch({"bundled_autograd_cache": True}) + def test_regional_inductor_basic(self): + """ + Basic test for regional inductor with bundled autograd cache. + Tests that regional inductor compilation results can be cached and hit. + """ + import torch.fx.traceback as fx_traceback + + def fn(x, y): + sin = torch.sin(x) + # Mark this region to be compiled with inductor + with fx_traceback.annotate({"compile_with_inductor": 0}): + mul = sin * y + add = mul + 1 + return torch.sin(add) + + x = torch.randn(10, device="cpu") + y = torch.randn(10, device="cpu") + + # Compile with regional inductor backend + compiled_fn = torch.compile( + fn, backend=aot_eager_regional_inductor(), fullgraph=True + ) + + # First call should miss in cache + result1 = compiled_fn(x, y) + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) + + # Second call should hit (after clearing dynamo) + self._clear_dynamo_and_codecache() + result2 = compiled_fn(x, y) + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) + + # Results should be the same + self.assertEqual(result1, result2) + + @inductor_config.patch("fx_graph_remote_cache", False) + @inductor_config.patch("fx_graph_cache", True) + @functorch_config.patch({"enable_autograd_cache": True}) + @functorch_config.patch({"bundled_autograd_cache": True}) + def test_regional_inductor_with_backward(self): + """ + Test regional inductor with backward pass and bundled autograd cache. 
+ Note: Regional inductor triggers multiple AOT autograd compilations: + - One for the outer graph (with regional inductor backend) + - One for each marked region (via standalone_compile) + """ + import torch.fx.traceback as fx_traceback + + def fn(x, y): + sin = torch.sin(x) + # Mark this region to be compiled with inductor + with fx_traceback.annotate({"compile_with_inductor": 0}): + mul = sin * y + add = mul + 1 + return torch.sin(add) + + x = torch.randn(10, requires_grad=True) + y = torch.randn(10, requires_grad=True) + x2 = x.detach().clone().requires_grad_(True) + y2 = y.detach().clone().requires_grad_(True) + + # Compile with regional inductor backend + compiled_fn = torch.compile( + fn, backend=aot_eager_regional_inductor(), fullgraph=True + ) + + # First call: AOT autograd compiles the outer graph (1 miss) + # Regional inductor then compiles the marked region (1 more miss) + result1 = compiled_fn(x, y) + result1.sum().backward() + + # We expect 2 cache misses: outer graph + marked region + initial_misses = counters["aot_autograd"]["autograd_cache_miss"] + initial_saves = counters["aot_autograd"]["autograd_cache_saved"] + self.assertGreater(initial_misses, 0) + self.assertGreater(initial_saves, 0) + + # Second call should hit (after clearing dynamo) + self._clear_dynamo_and_codecache() + result2 = compiled_fn(x2, y2) + result2.sum().backward() + + # Should have cache hits now + final_hits = counters["aot_autograd"]["autograd_cache_hit"] + self.assertGreater(final_hits, 0) + + # Cache misses and saves should not increase + self.assertEqual( + counters["aot_autograd"]["autograd_cache_miss"], initial_misses + ) + self.assertEqual( + counters["aot_autograd"]["autograd_cache_saved"], initial_saves + ) + + # Results and gradients should be the same + self.assertEqual(result1, result2) + self.assertEqual(x.grad, x2.grad) + self.assertEqual(y.grad, y2.grad) + + @inductor_config.patch("fx_graph_remote_cache", False) + @inductor_config.patch("fx_graph_cache", True) + @functorch_config.patch({"enable_autograd_cache": True}) + @functorch_config.patch({"bundled_autograd_cache": True}) + def test_regional_inductor_cache_miss_on_change(self): + """ + Test that changing the function causes a cache miss with regional inductor. + Regional inductor creates multiple AOT compilations, so we track + the change in cache misses rather than absolute counts. 
+ """ + import torch.fx.traceback as fx_traceback + + def fn1(x, y): + sin = torch.sin(x) + with fx_traceback.annotate({"compile_with_inductor": 0}): + mul = sin * y + add = mul + 1 + return torch.sin(add) + + def fn2(x, y): + sin = torch.sin(x) + with fx_traceback.annotate({"compile_with_inductor": 0}): + mul = sin * y + add = mul + 2 # Changed from +1 to +2 + return torch.sin(add) + + x = torch.randn(10) + y = torch.randn(10) + + # Compile first function + compiled_fn1 = torch.compile( + fn1, backend=aot_eager_regional_inductor(), fullgraph=True + ) + result1 = compiled_fn1(x, y) + first_misses = counters["aot_autograd"]["autograd_cache_miss"] + first_saves = counters["aot_autograd"]["autograd_cache_saved"] + self.assertGreater(first_misses, 0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertGreater(first_saves, 0) + + # Compile second function (different graph) + self._clear_dynamo_and_codecache() + compiled_fn2 = torch.compile( + fn2, backend=aot_eager_regional_inductor(), fullgraph=True + ) + result2 = compiled_fn2(x, y) + # Should miss because graph is different (more misses than before) + self.assertGreater( + counters["aot_autograd"]["autograd_cache_miss"], first_misses + ) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertGreater( + counters["aot_autograd"]["autograd_cache_saved"], first_saves + ) + + # Results should be different + self.assertNotEqual(result1, result2) + @functorch_config.patch({"bundled_autograd_cache": True}) class AOTAutogradCacheBundledTests(AOTAutogradCacheTests): diff --git a/test/dynamo/test_regional_inductor.py b/test/dynamo/test_regional_inductor.py index 0dc208ddadc..ea3fb2943a9 100644 --- a/test/dynamo/test_regional_inductor.py +++ b/test/dynamo/test_regional_inductor.py @@ -1,13 +1,16 @@ # Owner(s): ["module: dynamo"] import functools +from typing import TYPE_CHECKING import torch import torch._inductor.test_case import torch.fx.traceback as fx_traceback import torch.utils.checkpoint from torch._dynamo.backends.common import aot_autograd +from torch._functorch._aot_autograd.autograd_cache import BundledCompiledForward from torch._guards import detect_fake_mode +from torch._inductor.output_code import RegionalOutputCode from torch._inductor.test_case import run_tests from torch._inductor.utils import run_fw_bw_and_get_code from torch.fx._graph_pickler import GraphPickler @@ -21,6 +24,10 @@ from torch.testing._internal.common_utils import ( from torch.testing._internal.triton_utils import requires_cuda_and_triton +if TYPE_CHECKING: + from torch._inductor.compile_fx import _CompileFxKwargs + + # Open questions / follow-ups # 1) CSE behavior with meta custom nodes # Common subexpression elimination may not differentiate between distinct meta @@ -462,5 +469,154 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase): self.assertEqual(len(codes), 2) +@skipIfTorchDynamo("Not a suitable dynamo wrapped test") +class TestRegionalOutputCode(torch._inductor.test_case.TestCase): + """Tests for RegionalOutputCode and RegionalAOTAutogradCacheEntry.""" + + def test_regional_output_code_serialization(self): + """Test that RegionalOutputCode can be serialized and deserialized.""" + + def fn(x, y): + sin = torch.sin(x) + with fx_traceback.annotate({"compile_with_inductor": 0}): + mul = sin * y + add = mul + 1 + return torch.sin(add) + + x = torch.randn(10, requires_grad=True) + y = torch.randn(10, requires_grad=True) + + # Compile with regional inductor + with 
torch.fx.traceback.preserve_node_meta(enable=False): + from torch._subclasses.fake_tensor import FakeTensorMode + from torch.fx.experimental.proxy_tensor import make_fx + + fake_mode = FakeTensorMode() + with fake_mode: + fake_x = fake_mode.from_tensor(x) + fake_y = fake_mode.from_tensor(y) + gm = make_fx(fn)(fake_x, fake_y) + + # Run regional_inductor on the graph + result_gm = regional_inductor(gm, fake_x, fake_y) + + # Create RegionalOutputCode + output_code = RegionalOutputCode(result_gm) + + # Test that we can call it + self.assertIsNotNone(output_code._graph_module) + + # Serialize + output_code.prepare_for_serialization() + self.assertIsNone(output_code._graph_module) + self.assertIsNotNone(output_code._serialized_graph_module) + + # Deserialize via post_compile + from torch._inductor.output_code import CompiledFxGraphConstants + + fx_config: _CompileFxKwargs = {"is_backward": False} + output_code.post_compile( + [fake_x, fake_y], CompiledFxGraphConstants(), fx_config + ) + self.assertIsNotNone(output_code._graph_module) + self.assertIsInstance(output_code._graph_module, torch.fx.GraphModule) + + # Test that deserialized graph works + with fake_mode: + result = output_code([fake_x, fake_y]) + self.assertIsNotNone(result) + + def test_regional_output_code_with_backward(self): + """Test RegionalOutputCode with both forward and backward compilation.""" + + def fn(x, y): + sin = torch.sin(x) + with fx_traceback.annotate({"compile_with_inductor": 0}): + mul = sin * y + add = mul + 1 + return torch.sin(add) + + x = torch.randn(10, requires_grad=True) + y = torch.randn(10, requires_grad=True) + + # Compile with regional inductor backend + from torch._subclasses.fake_tensor import FakeTensorMode + from torch.fx.experimental.proxy_tensor import make_fx + + fake_mode = FakeTensorMode() + with fake_mode: + fake_x = fake_mode.from_tensor(x) + fake_y = fake_mode.from_tensor(y) + + # Create forward graph + with torch.fx.traceback.preserve_node_meta(enable=False): + gm = make_fx(fn)(fake_x, fake_y) + forward_gm = regional_inductor(gm, fake_x, fake_y) + + # Create forward output code + fw_code = RegionalOutputCode(forward_gm) + + # Verify it can be called + with fake_mode: + result = fw_code([fake_x, fake_y]) + self.assertIsNotNone(result) + + # Test serialization round-trip + fw_code.prepare_for_serialization() + + # Deserialize via post_compile + + from torch._inductor.output_code import CompiledFxGraphConstants + + fx_config: _CompileFxKwargs = {"is_backward": False} + fw_code.post_compile([fake_x, fake_y], CompiledFxGraphConstants(), fx_config) + + with fake_mode: + result2 = fw_code([fake_x, fake_y]) + self.assertIsNotNone(result2) + + def test_regional_compiled_forward_backward(self): + """Test BundledCompiledForward and BundledCompiledBackward with RegionalOutputCode.""" + + def fn(x): + with fx_traceback.annotate({"compile_with_inductor": 0}): + return torch.sin(x) * 2 + + x = torch.randn(5, requires_grad=True) + + from torch._subclasses.fake_tensor import FakeTensorMode + from torch.fx.experimental.proxy_tensor import make_fx + + fake_mode = FakeTensorMode() + with fake_mode: + fake_x = fake_mode.from_tensor(x) + + with torch.fx.traceback.preserve_node_meta(enable=False): + gm = make_fx(fn)(fake_x) + compiled_gm = regional_inductor(gm, fake_x) + + # Create forward using the generic BundledCompiledForward + fw_code = RegionalOutputCode(compiled_gm) + fw_compiled = BundledCompiledForward[RegionalOutputCode](result=fw_code) + + # Test pre_save + fw_compiled.pre_save() + # After pre_save, 
fw_compiled.result is a copy with serialized graph + self.assertIsNotNone(fw_compiled.result._serialized_graph_module) + self.assertIsNone( + fw_compiled.result._graph_module + ) # Should be cleared after serialization + + # Test load (doesn't deserialize yet) + loaded_code = fw_compiled.load([fake_x]) + self.assertIsNone(loaded_code._graph_module) # Not yet deserialized + self.assertIsNotNone(loaded_code._serialized_graph_module) + + fx_config: _CompileFxKwargs = {"is_backward": False} + post_compiled = fw_compiled.post_compile(loaded_code, fx_config) + self.assertIsNotNone(post_compiled) + self.assertIsNotNone(post_compiled._graph_module) # Now deserialized + + if __name__ == "__main__": run_tests() diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index e9df75de7a8..92b38fa1efb 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -524,45 +524,61 @@ class InductorOutput(ABC, Generic[TOut]): def post_compile(self, result: TOut, fx_config: _CompileFxKwargs) -> TOut: ... +TOutputCode = TypeVar("TOutputCode", bound=OutputCode) + + @dataclass -class CompiledFxGraphLoadable(InductorOutput[CompiledFxGraph]): +class BundledOutputCodeLoadable(InductorOutput[TOutputCode], Generic[TOutputCode]): """ - A full compiled fx graph that doesn't need to lookup the FxGraphCache - to run + A generic wrapper for OutputCode objects that are bundled directly in the cache + (rather than looked up via FxGraphCache). + + This works for any OutputCode subclass (CompiledFxGraph, RegionalOutputCode, etc.) """ - result: CompiledFxGraph + result: TOutputCode def pre_save(self) -> None: - disk_compiled_graph = copy(self.result) - disk_compiled_graph.prepare_for_serialization() - self.result = disk_compiled_graph + disk_result = copy(self.result) + disk_result.prepare_for_serialization() + self.result = disk_result return - def load(self, example_inputs) -> CompiledFxGraph: + def load(self, example_inputs) -> TOutputCode: self.example_inputs = example_inputs - return self.result def post_compile( - self, result: CompiledFxGraph, fx_config: _CompileFxKwargs - ) -> CompiledFxGraph: + self, result: TOutputCode, fx_config: _CompileFxKwargs + ) -> TOutputCode: constants = CompiledFxGraphConstants() - # Cache hit specific post compile - graph, cache_info = FxGraphCache.cache_hit_post_compile(result, {}, constants) - if graph is None: - raise BypassAOTAutogradCache("Failed to reload cache entry from disk") - torch._logging.trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "fx_graph_bundled_cache_hit", # always a hit - "encoding": "json", - }, - payload_fn=lambda: json.dumps(cache_info), - ) + + # Special handling for CompiledFxGraph - needs FxGraphCache.cache_hit_post_compile + if isinstance(result, CompiledFxGraph): + graph, cache_info = FxGraphCache.cache_hit_post_compile( + result, {}, constants + ) + if graph is None: + raise BypassAOTAutogradCache("Failed to reload cache entry from disk") + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "fx_graph_bundled_cache_hit", # always a hit + "encoding": "json", + }, + payload_fn=lambda: json.dumps(cache_info), + ) + result = graph # type: ignore[assignment] + # Run normal post compile - graph.post_compile(self.example_inputs, constants, fx_config) - return graph + result.post_compile(self.example_inputs, constants, fx_config) + return result + + +# Backwards compatibility alias +CompiledFxGraphLoadable: 
type[BundledOutputCodeLoadable[CompiledFxGraph]] = ( + BundledOutputCodeLoadable[CompiledFxGraph] +) @dataclass @@ -689,18 +705,31 @@ class CompiledBackward(GenericCompiledBackward[CompiledFxGraph], FxGraphCacheLoa ) -# Forward types don't have any extra parameters, so this is just a TypeAlias, in essence -class BundledCompiledForward(CompiledFxGraphLoadable): - pass +# Generic bundled forward/backward classes that work with any OutputCode type +@dataclass +class BundledCompiledForward( + BundledOutputCodeLoadable[TOutputCode], Generic[TOutputCode] +): + """ + Generic forward function for bundled compilation. + Works with any OutputCode type (CompiledFxGraph, RegionalOutputCode, etc.) + """ @dataclass class BundledCompiledBackward( - GenericCompiledBackward[CompiledFxGraph], CompiledFxGraphLoadable + GenericCompiledBackward[TOutputCode], + BundledOutputCodeLoadable[TOutputCode], + Generic[TOutputCode], ): + """ + Generic backward function for bundled compilation. + Works with any OutputCode type (CompiledFxGraph, RegionalOutputCode, etc.) + """ + def post_compile( - self, result: CompiledFxGraph, fx_config: _CompileFxKwargs - ) -> CompiledFxGraph: + self, result: TOutputCode, fx_config: _CompileFxKwargs + ) -> TOutputCode: compiled_bw = super().post_compile(result, fx_config) # See note [Wrapping bw_compiler in disable] # This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py @@ -990,11 +1019,44 @@ class AOTAutogradCacheEntry( class BundledAOTAutogradCacheEntry( - GenericAOTAutogradCacheEntry[BundledCompiledForward, BundledCompiledBackward] + GenericAOTAutogradCacheEntry[ + BundledCompiledForward[TOutputCode], BundledCompiledBackward[TOutputCode] + ], + Generic[TOutputCode], ): """ - AOTAutogradCacheEntry where we save the entire CompiledFxGraph instead - of relying on cache keys from FxGraphCache + Generic AOTAutogradCacheEntry where we bundle the entire OutputCode directly + (rather than looking it up via FxGraphCache). + + This works with any OutputCode type: + - CompiledFxGraph: Traditional inductor compilation + - RegionalOutputCode: Regional inductor compilation with GraphPickler serialization + - Any future OutputCode subclasses + + Type parameter: + TOutputCode: The OutputCode subclass (e.g., CompiledFxGraph, RegionalOutputCode) + + Usage with CompiledFxGraph: + entry = BundledAOTAutogradCacheEntry[CompiledFxGraph]( + compiled_fw=BundledCompiledForward(result=CompiledFxGraph(...)), + compiled_bw=BundledCompiledBackward( + result=CompiledFxGraph(...), + backward_state_indices=[...], + num_symints_saved_for_bw_=..., + ), + ... + ) + + Usage with RegionalOutputCode: + entry = BundledAOTAutogradCacheEntry[RegionalOutputCode]( + compiled_fw=BundledCompiledForward(result=RegionalOutputCode(gm)), + compiled_bw=BundledCompiledBackward( + result=RegionalOutputCode(gm), + backward_state_indices=[...], + num_symints_saved_for_bw_=..., + ), + ... 
+ ) """ @@ -1469,8 +1531,8 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]): @staticmethod def make_entry( - compiled_fw_func: CompiledFxGraph, - compiled_bw_func: Optional[CompiledFxGraph], + compiled_fw_func: OutputCode, + compiled_bw_func: Optional[OutputCode], aot_joint_graph_str: Optional[str], aot_forward_graph_str: Optional[str], aot_backward_graph_str: Optional[str], @@ -1490,19 +1552,19 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]): if should_bundle_autograd_cache(): # Helper function to unwrap all the wrappers we added during aotdispatch # They get reapplied on cache load - def unwrap_compiled_fx_graph(obj): + def unwrap_output_code(obj): while hasattr(obj, "__wrapped__"): obj = obj.__wrapped__ - assert isinstance(obj, CompiledFxGraph) + assert isinstance(obj, OutputCode) return obj - compiled_fw_graph = unwrap_compiled_fx_graph(compiled_fw_func) + compiled_fw_graph = unwrap_output_code(compiled_fw_func) bundled_compiled_forward = BundledCompiledForward(compiled_fw_graph) bundled_compiled_backward = None if compiled_bw_func is not None: assert backward_state_indices is not None assert num_symints_saved_for_bw is not None - compiled_bw_graph = unwrap_compiled_fx_graph(compiled_bw_func) + compiled_bw_graph = unwrap_output_code(compiled_bw_func) bundled_compiled_backward = BundledCompiledBackward( compiled_bw_graph, backward_state_indices, num_symints_saved_for_bw ) diff --git a/torch/_inductor/output_code.py b/torch/_inductor/output_code.py index 6f1e192d46f..f5ab01374d8 100644 --- a/torch/_inductor/output_code.py +++ b/torch/_inductor/output_code.py @@ -88,6 +88,9 @@ class OutputCode: def __call__(self, inputs: Sequence[Any]) -> Any: raise NotImplementedError(type(self)) + def prepare_for_serialization(self) -> None: + raise NotImplementedError(type(self)) + def post_compile( self, example_inputs: Sequence[InputType], @@ -783,6 +786,9 @@ class CompiledAOTI(OutputCode): ) -> None: pass + def prepare_for_serialization(self) -> None: + pass + def set_triton_bundle(self, triton_bundle: Any) -> None: pass @@ -807,3 +813,97 @@ class MockFXGraphCacheOutput(OutputCode): def set_triton_bundle(self, triton_bundle: Any) -> None: pass + + +@dataclasses.dataclass +class RegionalOutputCode(OutputCode): + """ + OutputCode for regional inductor compilation results. + + Regional inductor returns a torch.fx.GraphModule that contains both + compiled regions (via standalone_compile) and eager regions. This needs + special serialization using GraphPickler instead of standard pickle. + + The serialization strategy stores the GraphModule as bytes using + GraphPickler.dumps(), which handles FakeTensors, AOTCompiledArtifacts, + and other special objects that standard pickle cannot handle. + """ + + # The serialized graph module as bytes (using GraphPickler) + _serialized_graph_module: Optional[bytes] = dataclasses.field( + default=None, init=False + ) + + # The actual graph module (cleared during serialization) + _graph_module: Optional[torch.fx.GraphModule] = dataclasses.field( + default=None, init=False + ) + + def __init__(self, graph_module: torch.fx.GraphModule): + """ + Args: + graph_module: The torch.fx.GraphModule returned by regional_inductor + """ + super().__init__() + self._graph_module = graph_module + self._serialized_graph_module = None + + def __call__(self, inputs: Sequence[Any]) -> Any: + """Execute the regional compiled graph.""" + if self._graph_module is None: + raise RuntimeError( + "RegionalOutputCode has no graph module loaded. 
" + "Did you forget to call post_compile()?" + ) + return self._graph_module(*inputs) + + def post_compile( + self, + example_inputs: Sequence[InputType], + constants: CompiledFxGraphConstants, + graph_kwargs: _CompileFxKwargs, + ) -> None: + """ + Post-compile processing for regional inductor. + + This deserializes the GraphModule from bytes using GraphPickler, + extracting the fake_mode from example_inputs. + """ + if self._graph_module is not None: + return + assert self._serialized_graph_module is not None + # Get fake mode from example inputs + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode(example_inputs) + if fake_mode is None: + raise RuntimeError( + "Could not detect fake mode from example inputs. " + "Regional inductor requires fake mode for deserialization." + ) + + # Deserialize the graph module + from torch.fx._graph_pickler import GraphPickler + + gm = GraphPickler.loads(self._serialized_graph_module, fake_mode) + assert isinstance(gm, torch.fx.GraphModule) + gm.recompile() + self._graph_module = gm + + def set_triton_bundle(self, triton_bundle: Any) -> None: + """Regional inductor doesn't use triton bundles directly.""" + + def prepare_for_serialization(self) -> None: + """ + Prepare for serialization by converting the GraphModule to bytes. + + This uses GraphPickler to serialize the graph module since it contains + special objects like FakeTensors and AOTCompiledArtifacts that need + custom pickling. + """ + if self._graph_module is not None: + from torch.fx._graph_pickler import GraphPickler + + self._serialized_graph_module = GraphPickler.dumps(self._graph_module) + # Clear the graph module to avoid pickling it with standard pickle + self._graph_module = None