From 57d8278ab902b34e70e23f044e7e15a1bcea400c Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Thu, 30 Jan 2025 14:05:26 -0800 Subject: [PATCH] pickler for GraphModule (#141659) Pickling GraphModule needs some special handling for wrapping things that normally can't be pickled - but async compile needs to pass them across a wire so we need to be able to serialize it - add some helpers to enable that. Differential Revision: [D68921318](https://our.internmc.facebook.com/intern/diff/D68921318) Pull Request resolved: https://github.com/pytorch/pytorch/pull/141659 Approved by: https://github.com/jamesjwu --- .../fsdp/test_fully_shard_compile.py | 8 +- test/fx/test_graph_pickler.py | 96 +++ test/inductor/test_torchinductor.py | 2 +- torch/_dynamo/variables/misc.py | 14 + torch/_inductor/codecache.py | 7 + torch/_inductor/compile_fx.py | 287 ++++++++- torch/_inductor/graph.py | 9 +- torch/_inductor/output_code.py | 5 - torch/_inductor/utils.py | 8 +- torch/_subclasses/fake_tensor.py | 2 +- torch/_subclasses/meta_utils.py | 32 +- torch/fx/_graph_pickler.py | 582 ++++++++++++++++++ torch/fx/node.py | 3 +- 13 files changed, 1014 insertions(+), 41 deletions(-) create mode 100644 test/fx/test_graph_pickler.py create mode 100644 torch/fx/_graph_pickler.py diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py index f8f24a68ddd..6351a74459b 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py @@ -751,7 +751,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]}, "inductor", fwd_fullgraph=fwd_fullgraph, bwd_resize_count_before_inductor=48 if fwd_fullgraph else None, - ) + ), ) if fwd_fullgraph: self.assertEqual( @@ -834,7 +834,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]}, *self._create_nested_fully_shard_factory_fns(fwd_fullgraph=False), "inductor", fwd_fullgraph=False, - ) + ), ) # TODO: when fwd_fullgraph=False and there is graph break in FWD graph, # there are several recompiles, need to figure out why. @@ -978,7 +978,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]}, "inductor", fwd_fullgraph=fwd_fullgraph, bwd_resize_count_before_inductor=76 if fwd_fullgraph else None, - ) + ), ) if fwd_fullgraph: self.assertEqual( @@ -1071,7 +1071,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]}, ), "inductor", fwd_fullgraph=fwd_fullgraph, - ) + ), ) # TODO: when fwd_fullgraph=False and there is graph break in FWD graph, # there are several recompiles, need to figure out why. diff --git a/test/fx/test_graph_pickler.py b/test/fx/test_graph_pickler.py new file mode 100644 index 00000000000..1cfd6a2ef57 --- /dev/null +++ b/test/fx/test_graph_pickler.py @@ -0,0 +1,96 @@ +# Owner(s): ["module: fx"] + +# +# Tests the graph pickler by using pickling on all the inductor tests. 
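+# Each generated test patches torch._inductor.compile_fx.fx_compile_mode to
+# SERIALIZE, so every compile round-trips its GraphModule and CompiledFxGraph
+# through GraphPickler.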
+# + +import contextlib +import importlib +import os +import sys +from unittest.mock import patch + +import torch +import torch.library +from torch._dynamo.testing import make_test_cls_with_patches +from torch._inductor.test_case import TestCase +from torch.testing._internal.common_utils import TEST_WITH_ASAN +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_GPU + + +# Make the helper files in test/ importable +pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(pytorch_test_dir) +from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library + check_model, + CommonTemplate, + copy_tests, +) + + +importlib.import_module("filelock") + +# xfail by default, set is_skip=True to skip +test_failures = {} + + +def make_test_cls(cls, xfail_prop="_expected_failure_graph_pickler"): + return make_test_cls_with_patches( + cls, + "GraphPickler", + "_graph_pickler", + ( + torch._inductor.compile_fx, + "fx_compile_mode", + torch._inductor.compile_fx.FxCompileMode.SERIALIZE, + ), + xfail_prop=xfail_prop, + ) + + +GraphPicklerCommonTemplate = make_test_cls(CommonTemplate) + + +if HAS_CPU: + + class GraphPicklerCpuTests(TestCase): + common = check_model + device = "cpu" + + copy_tests(GraphPicklerCommonTemplate, GraphPicklerCpuTests, "cpu", test_failures) + + +class TestGraphPickler(TestCase): + def setUp(self): + torch._dynamo.reset() + TestCase.setUp(self) + + self._stack = contextlib.ExitStack() + self._stack.enter_context( + patch( + "torch._inductor.compile_fx.fx_compile_mode", + torch._inductor.compile_fx.FxCompileMode.SERIALIZE, + ) + ) + + def tearDown(self): + self._stack.close() + TestCase.tearDown(self) + torch._dynamo.reset() + + def test_simple(self): + # Make sure that compiling works when we pass the input + output from + # fx_codegen_and_compile() through serde. + + def fn(a, b): + return a + b + + check_model(self, fn, (torch.tensor([False, True]), torch.tensor([True, True]))) + + +if __name__ == "__main__": + from torch._inductor.test_case import run_tests + + # Slow on ASAN after https://github.com/pytorch/pytorch/pull/94068 + if (HAS_CPU or HAS_GPU) and not TEST_WITH_ASAN: + run_tests(needs="filelock") diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index bb8cc3d09e0..df54173d66a 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -12370,7 +12370,7 @@ def copy_tests( new_test = unittest.expectedFailure(new_test) tf = test_failures and test_failures.get(name) - if tf is not None and suffix in tf.suffixes: + if tf and suffix in tf.suffixes: skip_func = ( unittest.skip("Skipped!") if tf.is_skip diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index aa69ee9fd35..483f1bf0624 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -1237,6 +1237,10 @@ class TypingVariable(VariableTracker): @functools.lru_cache(maxsize=1) def get_np_to_tnp_map(): + """ + This generates a mapping from numpy modules to their torch._numpy + modules equivalents. + """ from ..utils import NP_TO_TNP_MODULE np_fn_to_tnp_fn = {} @@ -1252,6 +1256,16 @@ def get_np_to_tnp_map(): return np_fn_to_tnp_fn +@functools.lru_cache(maxsize=1) +def get_tnp_to_np_map(): + """ + This is just the reverse mapping of get_np_to_tnp_map() - mapping from + torch._numpy modules to numpy equivalents. 
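+
+    A sketch of the expected relationship (assuming torch._numpy mirrors
+    numpy.add):
+
+        >>> import numpy as np
+        >>> import torch._numpy as tnp
+        >>> get_tnp_to_np_map()[tnp.add] is np.add
+        True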
+ """ + m = get_np_to_tnp_map() + return {v: k for k, v in m.items()} + + class NumpyVariable(VariableTracker): """ Wrapper around `numpy.*`. Currently, is able to trace a small subset of numpy functions as well as numpy dtypes. diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 6e1b5f94c1b..6a3474987d3 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1053,6 +1053,13 @@ class FxGraphCache: try: artifact_path = graph.after_deserialization(constants) + + from .graph import GraphLowering + + # This is used by tests to check the output for specific details. + if GraphLowering.save_output_code is not None: + GraphLowering.save_output_code(graph.source_code) + except OSError: # Not expected, but in case the PyCodeCache entry is removed from # underneath us, treat it as a cache miss and recompile. diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index b18b68c9d59..b7ad9f1fde8 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -1,21 +1,25 @@ from __future__ import annotations import contextlib +import enum import functools import io import itertools import json import logging +import os import sys import time import warnings from abc import ABC, abstractmethod +from dataclasses import dataclass from inspect import currentframe from itertools import count from typing import ( Any, Callable, ContextManager, + Mapping, Optional, TYPE_CHECKING, TypeVar, @@ -56,12 +60,18 @@ from torch._functorch.aot_autograd import ( make_boxed_func, SerializableAOTDispatchCompiler, ) -from torch._inductor.codecache import code_hash, FxGraphCache, output_code_log +from torch._inductor.codecache import ( + BypassFxGraphCache, + code_hash, + FxGraphCache, + output_code_log, +) from torch._inductor.cudagraph_utils import BoxedDeviceIndex, PlaceholderInfo from torch._inductor.debug import save_args_for_compile_fx_inner from torch._inductor.output_code import ( CompiledAOTI, CompiledFxGraph, + CompiledFxGraphConstants, CompiledFxGraphConstantsWithGm, get_expanded_dims, index_expanded_dims, @@ -146,6 +156,43 @@ if TYPE_CHECKING: GraphSignature, ) + +# For testing - use the serde FxCompile scheme to debug serialization and +# deserialization of GraphMoule and CompiledFxGraph. +class FxCompileMode(enum.Enum): + NORMAL = 0 + # For testing - use the serde FxCompile scheme to debug serialization and + # deserialization of GraphMoule and CompiledFxGraph. + SERIALIZE = 1 + + +def _fx_compile_mode_default() -> FxCompileMode: + name = "TORCHINDUCTOR_FX_COMPILE_MODE" + value = os.environ.get(name) + NORMAL = FxCompileMode.NORMAL + if value is None: + return NORMAL + try: + value = value.upper() + return FxCompileMode[value] + except KeyError: + import logging + + log = logging.getLogger(__name__) + log.error( + "Invalid value of %s for %s. Expected one of %s. Using default.", + value, + name, + ", ".join(sorted(repr(x) for x in FxCompileMode.__members__.keys())), + ) + # Remove from the environment so subprocesses don't ALSO complain. 
+ os.environ.pop(name) + return FxCompileMode.NORMAL + + +fx_compile_mode = _fx_compile_mode_default() + + log = logging.getLogger(__name__) perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") pre_grad_graphs_log = torch._logging.getArtifactLogger(__name__, "pre_grad_graphs") @@ -755,9 +802,11 @@ def _compile_fx_inner( cache_event_time=start_time, key=cache_info.get("key") if cache_info else None, components=cache_info.get("components") if cache_info else None, - cache_bypass_reason=cache_info.get("cache_bypass_reason") - if cache_info - else "cache not enabled", + cache_bypass_reason=( + cache_info.get("cache_bypass_reason") + if cache_info + else "cache not enabled" + ), remote_cache_enabled=remote, local_cache_enabled=local, ) @@ -787,6 +836,11 @@ def _compile_fx_inner( class FxCompile(ABC): + """ + An FxCompile represents a mechanism that can turn a GraphModule into an + OutputCode. + """ + # TODO: We should probably eventually add some kind of async version of this # so we can kick off a compile and then go do other things - but we'll need # to know what kind of API we want for that first. @@ -1137,6 +1191,195 @@ class _InProcessFxCompile(FxCompile): ) +def _current_fake_mode() -> torch._subclasses.FakeTensorMode: + fake_mode = None + if context := torch._guards.TracingContext.try_get(): + fake_mode = context.fake_mode + if fake_mode is not None: + return fake_mode + + shape_env = torch.fx.experimental.symbolic_shapes.ShapeEnv() + return torch._subclasses.FakeTensorMode(shape_env=shape_env) + + +@dataclass +class _WireProtocolInput: + """ + For _SerializedFxCompile - encapsulates all the data being transferred + (sent) from the parent to the child. + """ + + gm: torch.fx.GraphModule + example_inputs: Sequence[InputType] + inputs_to_check: Sequence[int] + graph_kwargs: _CompileFxKwargs + # TODO: Add additional state to transfer to the child. + + def serialize(self) -> _WireProtocolPickledInput: + """ + Turns this object into a _WireProtocolPickledInput which can be + directly transferred across a stream. + """ + from torch.fx._graph_pickler import GraphPickler + + return _WireProtocolPickledInput(GraphPickler.dumps(self)) + + +@dataclass +class _WireProtocolPickledInput: + value: bytes + + def deserialize(self) -> _WireProtocolInput: + """ + Turn this streamable object back into a _WireProtocolInput. + """ + from torch.fx._graph_pickler import GraphPickler + + fake_mode = _current_fake_mode() + result = GraphPickler.loads(self.value, fake_mode) + assert isinstance(result, _WireProtocolInput) + return result + + +@dataclass +class _WireProtocolOutput: + """ + For _SerializedFxCompile - encapsulates all the data being transferred + (returned) back from the child to the parent. + """ + + graph: OutputCode + + def serialize(self) -> _WireProtocolPickledOutput: + """ + Turns this object into a _WireProtocolPickledOutput which can be + directly transferred across a stream. + """ + from torch.fx._graph_pickler import GraphPickler + + if isinstance(self.graph, CompiledFxGraph): + self.graph.prepare_for_serialization() + return _WireProtocolPickledOutput(GraphPickler.dumps(self)) + + +@dataclass +class _WireProtocolPickledOutput: + value: bytes + + def deserialize(self, constants: CompiledFxGraphConstants) -> _WireProtocolOutput: + """ + Turn this streamable object back into a _WireProtocolOutput. 
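+
+        The `constants` argument re-attaches per-graph state (such as the
+        tensor constants) that prepare_for_serialization() dropped before
+        pickling; a sketch of the expected round trip:
+
+            pickled = _WireProtocolOutput(graph).serialize()
+            graph2 = pickled.deserialize(constants).graph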
+ """ + from torch.fx._graph_pickler import GraphPickler + + fake_mode = _current_fake_mode() + result = GraphPickler.loads(self.value, fake_mode) + assert isinstance(result, _WireProtocolOutput) + if isinstance(result.graph, CompiledFxGraph): + result.graph.after_deserialization(constants) + return result + + +class _SerializedFxCompile(FxCompile): + """ + This is used to represent an FxCompile which occurs across a serialized + boundary. + """ + + @override + def codegen_and_compile( + self, + gm: GraphModule, + example_inputs: Sequence[InputType], + inputs_to_check: Sequence[int], + graph_kwargs: _CompileFxKwargs, + ) -> OutputCode: + # _context = torch._guards.TracingContext.try_get() + constants = CompiledFxGraphConstantsWithGm(gm) + + try: + input = _WireProtocolInput( + gm, + example_inputs, + inputs_to_check, + graph_kwargs, + ).serialize() + except (AttributeError, BypassFxGraphCache): + # For example: AttributeError: Can't pickle local object + # 'make_opaque_unary_fn..OpaqueUnaryFn' + + # TODO: scuba record about not being able to do this? + log.debug("Unable to pickle input graph or example inputs", exc_info=True) + + # Fallback to in-process + return _InProcessFxCompile().codegen_and_compile( + gm, example_inputs, inputs_to_check, graph_kwargs + ) + + output = self._send_to_child(input).deserialize(constants) + + self._postprocess(output) + + # TODO: Do we need to figure out what changed in TracingContext in the + # child and plumb that back up to the parent? + + return output.graph + + @abstractmethod + def _send_to_child( + self, pickled_input: _WireProtocolPickledInput + ) -> _WireProtocolPickledOutput: + # The implementation of this should transfer `input` to the child, call + # `_run_in_child(input)` and transfer the result back. + ... + + def _postprocess(self, output: _WireProtocolOutput) -> None: + pass + + @classmethod + def _run_in_child( + cls, + pickled_input: _WireProtocolPickledInput, + extra_env: Optional[Mapping[str, str]] = None, + ) -> _WireProtocolPickledOutput: + with contextlib.ExitStack() as stack: + if extra_env is not None: + import unittest + + stack.enter_context(unittest.mock.patch.dict("os.environ", extra_env)) + + # TODO: Should we split the input into multiple sections where each + # section sets up state for the previous section? (i.e. a Config section + # which we decode and apply, followed by a FakeTensorMode section which + # we decode and apply, etc) + input = pickled_input.deserialize() + + stack.enter_context(DebugContext()) + + output_graph = _InProcessFxCompile().codegen_and_compile( + input.gm, + input.example_inputs, + input.inputs_to_check, + input.graph_kwargs, + ) + + return _WireProtocolOutput( + output_graph, + ).serialize() + + +# This is a debugging/testing implementation of FxCompile which serializes the +# input and output but still runs the FxCompile in-process. +class _DebugSerdeFxCompile(_SerializedFxCompile): + @override + def _send_to_child( + self, pickled_input: _WireProtocolPickledInput + ) -> _WireProtocolPickledOutput: + # For debugging just serde the input and output but don't run in a + # subprocess. 
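+        # A real subprocess implementation would ship pickled_input.value
+        # (raw bytes) to a worker and read the pickled output back; here the
+        # "child" is just a direct call in the same process.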
+ return self._run_in_child(pickled_input) + + def fx_codegen_and_compile( gm: GraphModule, example_inputs: Sequence[InputType], @@ -1145,7 +1388,13 @@ def fx_codegen_and_compile( inputs_to_check: Sequence[int], **graph_kwargs: Unpack[_CompileFxKwargs], ) -> OutputCode: - scheme: FxCompile = _InProcessFxCompile() + scheme: FxCompile + if fx_compile_mode == FxCompileMode.NORMAL: + scheme = _InProcessFxCompile() + elif fx_compile_mode == FxCompileMode.SERIALIZE: + scheme = _DebugSerdeFxCompile() + else: + raise NotImplementedError return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs) @@ -1272,11 +1521,13 @@ def cudagraphify_impl( # allocate static tensor inputs static_inputs = [ - x - if not isinstance(x, torch.Tensor) - else static_input(x) - if idx not in static_input_idxs - else x.detach() + ( + x + if not isinstance(x, torch.Tensor) + else static_input(x) + if idx not in static_input_idxs + else x.detach() + ) for idx, x in enumerate(inputs) ] @@ -1506,9 +1757,11 @@ def fw_compiler_freezing( def get_cpp_wrapper_config() -> dict[str, object]: return { # Set autotune_at_compile_time to True as default if the option is not explicitly set - "triton.autotune_at_compile_time": config.triton.autotune_at_compile_time - if config.triton.autotune_at_compile_time is not None - else has_triton(), + "triton.autotune_at_compile_time": ( + config.triton.autotune_at_compile_time + if config.triton.autotune_at_compile_time is not None + else has_triton() + ), "triton.autotune_cublasLt": False, "triton.cudagraphs": False, # TODO: to be removed "triton.store_cubin": True, @@ -1842,9 +2095,11 @@ def compile_fx( model_outputs_node.meta["user_visible_output_idxs"] = [] fixed = count_tangents(gm) - with config.patch( - get_cpp_wrapper_config() - ) if config.cpp_wrapper else contextlib.nullcontext(): + with ( + config.patch(get_cpp_wrapper_config()) + if config.cpp_wrapper + else contextlib.nullcontext() + ): return inner_compile( gm, example_inputs, diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 66e3a841dcd..76a3fee4860 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1989,10 +1989,8 @@ class GraphLowering(torch.fx.Interpreter): return total_bytes, node_counts, node_runtimes - @staticmethod - def save_output_code(code: str) -> None: - # No-op to be patched for unit tests - pass + # No-op to be patched for unit tests + save_output_code: Optional[Callable[[str], None]] = None def compile_to_module(self) -> ModuleType: with dynamo_timed( @@ -2018,7 +2016,8 @@ class GraphLowering(torch.fx.Interpreter): + '"""\n' ) code = tuning_code + code - GraphLowering.save_output_code(code) + if GraphLowering.save_output_code is not None: + GraphLowering.save_output_code(code) output_code_log.debug("Output code: \n%s", code) inductor_meta = autotune_cache.inductor_meta_from_config() diff --git a/torch/_inductor/output_code.py b/torch/_inductor/output_code.py index 0171bb647a0..393e282d03c 100644 --- a/torch/_inductor/output_code.py +++ b/torch/_inductor/output_code.py @@ -546,11 +546,6 @@ class CompiledFxGraph(OutputCode): write_atomic(artifact_path, code, make_dirs=True) - from .graph import GraphLowering - - # This is used by tests to check the output for specific details. 
- GraphLowering.save_output_code(code) - try: with dynamo_timed( "PyCodeCache.load_by_key_path", diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index bca55585450..bca692af9ad 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1465,12 +1465,16 @@ class DebugDirManager: torch._dynamo.config.debug_dir_root = self.prev_debug_name -def run_and_get_code(fn, *args, **kwargs) -> tuple[Any, list[str]]: +def run_and_get_code( + fn: Callable[P, _T], + *args: P.args, + **kwargs: P.kwargs, +) -> tuple[_T, list[str]]: from .graph import GraphLowering source_codes: list[str] = [] - def save_output_code(code: str): + def save_output_code(code: str) -> None: source_codes.append(code) with mock.patch.object(GraphLowering, "save_output_code", save_output_code): diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index adcf264e038..b3e0813f853 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -379,7 +379,7 @@ class FakeTensorConverter: out = self.meta_converter( t, shape_env=shape_env, - callback=mk_fake_tensor, # type: ignore[arg-type] + callback=mk_fake_tensor, source=source, symbolic_context=symbolic_context, trace=trace, diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index d0d1905ae6e..3ffbaa96ec5 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -541,11 +541,27 @@ class _CustomViewFunc(ViewFunc[_TensorT], Generic[_TensorT]): return self.func(new_base, symint_visitor_fn, tensor_visitor_fn) +# A callback where the device is either optional or required. +# All of these satisfy this protocol: +# def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str]) +# def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str] = "meta") +# def mk(arg: Callable[[], torch.Tensor], device: Optional[Union[torch.device, str]] = None) +class _MetaTensorCallback(Protocol, Generic[_TensorT_cov]): + def __call__( + self, arg: Callable[[], torch.Tensor], /, *, device: Union[torch.device, str] + ) -> _TensorT_cov: + ... + + class _MetaTensorCallbackKwargs(TypedDict, total=False): device: Union[torch.device, str] -class _MetaTensorCallback(Protocol, Generic[_TensorT_cov]): +# A callback where the device may not be provided (is optional). 
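+# (The device keyword can be left out entirely - contrast with
+# _MetaTensorCallback above, where it is required.)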
+# All of these satisfy this protocol: +# def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str] = "meta") +# def mk(arg: Callable[[], torch.Tensor], device: Optional[Union[torch.device, str]] = None) +class _MetaTensorCallbackOptDevice(Protocol, Generic[_TensorT_cov]): def __call__( self, arg: Callable[[], torch.Tensor], @@ -832,11 +848,13 @@ class MetaConverter(Generic[_TensorT]): self, t: MetaTensorDesc, shape_env: Optional[ShapeEnv], - callback: _MetaTensorCallback[_TensorT], + callback_: _MetaTensorCallback[_TensorT], source: Optional[Source], symbolic_context: Optional[SymbolicContext], ) -> _TensorT: - callback = functools.partial(callback, device=t.device) + callback: _MetaTensorCallbackOptDevice = functools.partial( + callback_, device=t.device + ) if source is None: from torch._dynamo.source import ConstantSource @@ -981,7 +999,7 @@ class MetaConverter(Generic[_TensorT]): symbolic_context: Optional[ torch.fx.experimental.symbolic_shapes.SymbolicContext ], - callback: _MetaTensorCallback[_TensorT], + callback: _MetaTensorCallbackOptDevice[_TensorT], source: torch._guards.Source, ) -> _TensorT: # We are hitting plain meta_desc tensor so actually @@ -1216,7 +1234,7 @@ class MetaConverter(Generic[_TensorT]): shape_env: Optional[ torch.fx.experimental.symbolic_shapes.ShapeEnv ] = shape_env, - callback: _MetaTensorCallback[_TensorT] = callback, + callback: _MetaTensorCallbackOptDevice[_TensorT] = callback, ) -> torch.Tensor: # It's possible to close over an undefined tensor (e.g. NJT's lengths). if visited_t is None: @@ -1769,7 +1787,9 @@ class MetaConverter(Generic[_TensorT]): # Thanks to storage resizing, it's possible to end up with a tensor # that advertises a real size, but has a storage that actually has zero bytes. # Need to reflect this in the generated FakeTensor. - if t.storage is not None and t.storage.size == 0: + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if t.storage is not None and guard_size_oblivious(t.storage.size == 0): r.untyped_storage().resize_(0) if t.is_parameter: diff --git a/torch/fx/_graph_pickler.py b/torch/fx/_graph_pickler.py new file mode 100644 index 00000000000..5a5c77cd588 --- /dev/null +++ b/torch/fx/_graph_pickler.py @@ -0,0 +1,582 @@ +import dataclasses +import importlib +import io +import pickle +from abc import abstractmethod +from typing import Any, Callable, Dict, NewType, Optional, Tuple, Type, TypeVar, Union +from typing_extensions import override, Self + +import torch +import torch.utils._pytree as pytree +from torch._guards import TracingContext +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, Tensor +from torch._subclasses.meta_utils import ( + MetaConverter, + MetaTensorDesc, + MetaTensorDescriber, +) +from torch.fx.experimental.sym_node import SymNode +from torch.fx.experimental.symbolic_shapes import ShapeEnv +from torch.utils._mode_utils import no_dispatch + + +_SymNodeT = TypeVar("_SymNodeT", torch.SymInt, torch.SymFloat) + + +class GraphPickler(pickle.Pickler): + """ + GraphPickler is a Pickler which helps pickling fx graph - in particular + GraphModule. + """ + + def __init__(self, file: io.BytesIO) -> None: + super().__init__(file) + + # This abomination is so we can pass external decoding state to the + # unpickler functions. We serialize _unpickle_state as a persistent + # external item and when we deserialize it we return the common state + # object. + self._unpickle_state = _UnpickleStateToken(object()) + + # This is used to describe tensors. 
It needs to be common across the + # pickle so that duplicates and views are properly handled. + self._meta_tensor_describer = MetaTensorDescriber(copy_data=False) + + @override + def reducer_override( + self, obj: object + ) -> Tuple[Callable[..., Any], Tuple[Any, ...]]: + # This function is supposed to return either NotImplemented (meaning to + # do the default pickle behavior) or a pair of (unpickle callable, data + # to pass to unpickle). + + # We could instead teach individual classes how to pickle themselves but + # that has a few problems: + # + # 1. If we have some special needs (maybe for this use-case we don't + # want to fully serialize every field) then we're adding private + # details to a public interface. + # + # 2. If we need to have some common shared data (such as a + # FakeTensorMode) which is passed to each value it's harder to + # support. + + # These are the types that need special handling. See the individual + # *PickleData classes for details on pickling that particular type. + if isinstance(obj, FakeTensor): + return _TensorPickleData.reduce_helper(self, obj) + elif isinstance(obj, torch.fx.GraphModule): + return _GraphModulePickleData.reduce_helper(self, obj) + elif isinstance(obj, (torch._ops.OperatorBase, torch._ops.OpOverloadPacket)): + return _OpPickleData.reduce_helper(self, obj) + elif isinstance(obj, ShapeEnv): + return _ShapeEnvPickleData.reduce_helper(self, obj) + elif isinstance(obj, torch.SymInt): + return _SymNodePickleData.reduce_helper(self, obj) + elif isinstance(obj, torch._guards.TracingContext): + return _TracingContextPickleData.reduce_helper(self, obj) + else: + # We should never get a raw Node! + assert not isinstance(obj, torch.fx.Node) + if reduce := _TorchNumpyPickleData.reduce_helper(self, obj): + return reduce + + # returning `NotImplemented` causes pickle to revert to the default + # behavior for this object. + return NotImplemented + + @override + def persistent_id(self, obj: object) -> Optional[str]: + if obj is self._unpickle_state: + return "unpickle_state" + else: + return None + + @classmethod + def dumps(cls, obj: object) -> bytes: + """ + Pickle an object. + """ + with io.BytesIO() as stream: + pickler = cls(stream) + pickler.dump(obj) + return stream.getvalue() + + @staticmethod + def loads(data: bytes, fake_mode: FakeTensorMode) -> object: + """ + Unpickle an object. + """ + state = _UnpickleState(fake_mode) + with io.BytesIO(data) as stream: + unpickler = _GraphUnpickler(stream, state) + return unpickler.load() + + +class _UnpickleState: + def __init__(self, fake_mode: FakeTensorMode) -> None: + self.fake_mode = fake_mode + self.meta_converter: MetaConverter[FakeTensor] = MetaConverter() + + +# This token is passed when pickling to indicate that we want to use the +# unpickler's _UnpickleState as a parameter in that position. 
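+# Concretely: GraphPickler.persistent_id() maps the token to the string
+# "unpickle_state", and _GraphUnpickler.persistent_load() maps that string
+# back to the live _UnpickleState on the receiving side.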
+_UnpickleStateToken = NewType("_UnpickleStateToken", object) + + +class _GraphUnpickler(pickle.Unpickler): + def __init__(self, stream: io.BytesIO, unpickle_state: _UnpickleState) -> None: + super().__init__(stream) + self._unpickle_state = unpickle_state + + @override + def persistent_load(self, pid: object) -> object: + if pid == "unpickle_state": + return self._unpickle_state + else: + raise pickle.UnpicklingError("Invalid persistent ID") + + +class _ShapeEnvPickleData: + data: Dict[str, object] + + @classmethod + def reduce_helper( + cls, pickler: GraphPickler, obj: ShapeEnv + ) -> Tuple[ + Callable[[Self, _UnpickleState], ShapeEnv], Tuple[Self, _UnpickleStateToken] + ]: + return cls.unpickle, (cls(obj), pickler._unpickle_state) + + def __init__(self, env: ShapeEnv) -> None: + # In theory pickle should recognize that a given ShapeEnv was already + # pickled and reuse the resulting _ShapeEnvPickleData (so two objects + # pointing at the same ShapeEnv get the same ShapeEnv out). + assert not env._translation_validation_enabled + self.data = env.__dict__.copy() + del self.data["tracked_fakes"] + del self.data["fake_tensor_cache"] + + def unpickle(self, unpickle_state: _UnpickleState) -> ShapeEnv: + # Fill in the existing ShapeEnv rather than creating a new one + assert unpickle_state.fake_mode + assert unpickle_state.fake_mode.shape_env + + for k, v in self.data.items(): + setattr(unpickle_state.fake_mode.shape_env, k, v) + + return unpickle_state.fake_mode.shape_env + + +class _SymNodePickleData: + @classmethod + def reduce_helper( + cls, + pickler: GraphPickler, + obj: _SymNodeT, + ) -> Tuple[ + Callable[[Self, _UnpickleState], _SymNodeT], Tuple[Self, _UnpickleStateToken] + ]: + args = (cls(obj.node), pickler._unpickle_state) + if isinstance(obj, torch.SymInt): + return _SymNodePickleData.unpickle_sym_int, args + else: + raise NotImplementedError(f"Unhandled SymNode type {type(obj)}") + + def __init__(self, node: SymNode) -> None: + self.expr = node._expr + self.shape_env = node.shape_env + self.pytype = node.pytype + self.hint = node._hint + + def _to_sym_node(self) -> SymNode: + from torch.fx.experimental.sym_node import SymNode + + assert self.shape_env is not None + return SymNode(self.expr, self.shape_env, self.pytype, self.hint) + + def unpickle_sym_int(self, unpickle_state: _UnpickleState) -> torch.SymInt: + return torch.SymInt(self._to_sym_node()) + + +class _TensorPickleData: + metadata: MetaTensorDesc[FakeTensor] + + @classmethod + def reduce_helper( + cls, pickler: GraphPickler, obj: FakeTensor + ) -> Tuple[ + Callable[[Self, _UnpickleState], FakeTensor], Tuple[Self, _UnpickleStateToken] + ]: + return cls.unpickle, ( + cls(pickler._meta_tensor_describer, obj), + pickler._unpickle_state, + ) + + def __init__(self, describer: MetaTensorDescriber, t: Tensor) -> None: + # THINGS TO WORRY ABOUT: + # 1. Need to make sure that two tensors with the same id end up with the + # same id on the other side of the wire. + + metadata = describer.describe_tensor(t) + + # view_func is fine if it's either None or a _FakeTensorViewFunc. A + # custom one (which is basically a lambda) can't be serialized. 
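+        # (A custom ViewFunc wraps an arbitrary callable, so there is no
+        # stable way to reconstruct it on the receiving side.)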
+ assert not metadata.view_func or isinstance( + metadata.view_func, torch._subclasses.meta_utils._FakeTensorViewFunc + ) + self.metadata = dataclasses.replace(metadata, fake_mode=None) + + # Some debugging/verification + for k in MetaTensorDesc._UNSERIALIZABLE: + if k in ("fake_mode", "view_func"): + continue + assert ( + getattr(self.metadata, k) is None + ), f"not None: {k}: {getattr(self.metadata, k)}" + + def unpickle(self, unpickle_state: _UnpickleState) -> FakeTensor: + # TODO: make common w/ _output_from_cache_entry() in fake_tensor.py? + metadata = dataclasses.replace( + self.metadata, + fake_mode=unpickle_state.fake_mode, + ) + + def with_fake( + make_meta_t: Callable[[], torch.Tensor], device: Union[torch.device, str] + ) -> FakeTensor: + with no_dispatch(): + return FakeTensor( + unpickle_state.fake_mode, + make_meta_t(), + device, + ) + + return unpickle_state.meta_converter.meta_tensor( + metadata, + unpickle_state.fake_mode.shape_env, + with_fake, + None, + None, + ) + + +class _TorchNumpyPickleData: + @classmethod + def reduce_helper( + cls, pickler: GraphPickler, obj: object + ) -> Optional[ + Tuple[ + Callable[[Self, _UnpickleState], object], Tuple[Self, _UnpickleStateToken] + ] + ]: + if data := cls.from_object(obj): + return (cls.unpickle, (data, pickler._unpickle_state)) + else: + return None + + def __init__(self, mod: str, name: str) -> None: + self.mod = mod + self.name = name + + def unpickle(self, unpickle_state: _UnpickleState) -> Callable[..., object]: + np = getattr(importlib.import_module(self.mod), self.name) + return torch._dynamo.variables.misc.get_np_to_tnp_map()[np] + + @classmethod + def from_object(cls, tnp: object) -> Optional[Self]: + if not callable(tnp): + return None + + tnp_to_np = torch._dynamo.variables.misc.get_tnp_to_np_map() + try: + if not (np := tnp_to_np.get(tnp)): + return None + except TypeError: + return None + + if not (mod := getattr(np, "__module__", None)): + mod = "numpy" + + if not (name := getattr(np, "__name__", None)): + return None + + assert np == getattr(importlib.import_module(mod), name) + return cls(mod, name) + + +class _GraphModulePickleData: + @classmethod + def reduce_helper( + cls, pickler: GraphPickler, obj: torch.fx.GraphModule + ) -> Tuple[ + Callable[[Self, _UnpickleState], torch.fx.GraphModule], + Tuple[Self, _UnpickleStateToken], + ]: + return cls.unpickle, ( + cls(obj), + pickler._unpickle_state, + ) + + def __init__(self, gm: torch.fx.GraphModule) -> None: + # Need to do this to ensure the code is created for later pickling. 
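+        # (recompile() regenerates and caches the module's Python code;
+        # _LazyGraphModule defers that work, so it is forced here via
+        # _real_recompile().)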
+ if isinstance(gm, torch.fx._lazy_graph_module._LazyGraphModule): + _python_code = gm._real_recompile() + else: + _python_code = gm.recompile() + self.gm_dict = gm.__dict__.copy() + del self.gm_dict["_graph"] + self.graph = _GraphPickleData(gm._graph) + + def unpickle(self, unpickle_state: _UnpickleState) -> torch.fx.GraphModule: + gm = torch.fx.GraphModule.__new__(torch.fx.GraphModule) + gm.__dict__ = self.gm_dict + gm._graph = self.graph.unpickle(gm, unpickle_state) + return gm + + +class _NodePickleData: + def __init__( + self, node: torch.fx.Node, mapping: Dict[torch.fx.Node, "_NodePickleData"] + ) -> None: + self.args = pytree.tree_map_only(torch.fx.Node, lambda n: mapping[n], node.args) + self.kwargs = pytree.tree_map_only( + torch.fx.Node, lambda n: mapping[n], node.kwargs + ) + # -- self.graph = node.graph + self.name = node.name + self.op = node.op + self.target = _OpPickleData.pickle(node.target) + # self.input_nodes = node._input_nodes + # self.users = node.users + self.type = node.type + # self.sort_key = node._sort_key + # self.repr_fn = node._repr_fn + # self.meta = node.meta + self.meta = node.meta + + def unpickle( + self, + graph: torch.fx.Graph, + mapping: Dict["_NodePickleData", torch.fx.Node], + unpickle_state: _UnpickleState, + ) -> torch.fx.Node: + args = pytree.tree_map_only(_NodePickleData, lambda n: mapping[n], self.args) + kwargs = pytree.tree_map_only( + _NodePickleData, lambda n: mapping[n], self.kwargs + ) + target = self.target.unpickle(unpickle_state) + assert callable(target) or isinstance(target, str) + node = graph.create_node(self.op, target, args, kwargs, self.name, self.type) + node.meta = self.meta + return node + + +class _OpPickleData: + @classmethod + def reduce_helper( + cls, pickler: GraphPickler, op: object + ) -> Tuple[Callable[[_UnpickleState], object], Tuple[_UnpickleStateToken]]: + result = cls.pickle(op) + return (result.unpickle, (pickler._unpickle_state,)) + + @classmethod + def pickle(cls, op: object) -> "_OpPickleData": + if isinstance(op, str): + return _OpStrPickleData(op) + + name = torch.fx.Node._pretty_print_target(op) + if isinstance(op, torch._ops.OpOverload): + return cls._pickle_op(name, _OpOverloadPickleData) + elif isinstance(op, torch._ops.OpOverloadPacket): + return cls._pickle_op(name, _OpOverloadPacketPickleData) + elif name.startswith(("builtins.", "math.", "torch.")): + root, detail = name.split(".", 1) + return _OpBuiltinPickleData(root, detail) + elif name.startswith("operator."): + _, detail = name.split(".", 1) + return _OpOperatorPickleData(detail) + else: + # TODO: raise a BypassFxGraphCache so we will just bypass this one... + raise NotImplementedError(f"TARGET: {type(op)} {op} {name}") + + @staticmethod + def _pickle_op( + name: str, + datacls: Union[ + Type["_OpOverloadPickleData"], Type["_OpOverloadPacketPickleData"] + ], + ) -> "_OpPickleData": + if not name.startswith("torch.ops.aten"): # TODO: What's the full list? + from torch._inductor.codecache import BypassFxGraphCache + + raise BypassFxGraphCache(f"Unable to pickle non-standard op: {name}") + return datacls(name) + + @abstractmethod + def unpickle(self, unpickle_state: _UnpickleState) -> object: + pass + + @classmethod + def _lookup_global_by_name(cls, name: str) -> object: + """ + Like `globals()[name]` but supports dotted names. + """ + if "." 
in name: + mod, rest = name.split(".", 1) + root = globals()[mod] + return cls._getattr_by_name(root, rest) + else: + return globals()[name] + + @staticmethod + def _getattr_by_name(root: object, name: str) -> object: + """ + Like `getattr(root, name)` but supports dotted names. + """ + while "." in name: + mod, name = name.split(".", 1) + root = getattr(root, mod) + return getattr(root, name) + + +class _OpStrPickleData(_OpPickleData): + def __init__(self, name: str) -> None: + self.name = name + + def unpickle(self, unpickle_state: _UnpickleState) -> str: + return self.name + + +class _OpOverloadPickleData(_OpPickleData): + def __init__(self, name: str) -> None: + self.name = name + + def unpickle(self, unpickle_state: _UnpickleState) -> torch._ops.OpOverload: + obj = self._lookup_global_by_name(self.name) + assert isinstance(obj, torch._ops.OpOverload) + return obj + + +class _OpOverloadPacketPickleData(_OpPickleData): + def __init__(self, name: str) -> None: + self.name = name + + def unpickle(self, unpickle_state: _UnpickleState) -> torch._ops.OpOverloadPacket: + obj = self._lookup_global_by_name(self.name) + assert isinstance(obj, torch._ops.OpOverloadPacket) + return obj + + +class _OpBuiltinPickleData(_OpPickleData): + def __init__(self, root: str, name: str) -> None: + self.root = root + self.name = name + + def unpickle(self, unpickle_state: _UnpickleState) -> object: + if self.root == "builtins": + return __builtins__.get(self.name) # type: ignore[attr-defined] + elif self.root == "math": + import math + + return self._getattr_by_name(math, self.name) + elif self.root == "torch": + return self._getattr_by_name(torch, self.name) + else: + raise NotImplementedError + + +class _OpOperatorPickleData(_OpPickleData): + def __init__(self, name: str) -> None: + self.name = name + + def unpickle(self, unpickle_state: _UnpickleState) -> object: + import operator + + return self._getattr_by_name(operator, self.name) + + +class _GraphPickleData: + def __init__(self, graph: torch.fx.Graph) -> None: + self.tracer_cls = graph._tracer_cls + self.tracer_extras = graph._tracer_extras + + nodes: Dict[torch.fx.Node, _NodePickleData] = {} + for node in graph.nodes: + nodes[node] = _NodePickleData(node, nodes) + self.nodes = tuple(nodes.values()) + + # Unpickled variables: + # self._used_names = graph._used_names + # -- self._insert = self._root.prepend + # self._len = graph._len + # self._graph_namespace = graph._graph_namespace + # self._owning_module = graph._owning_module + # self._codegen = graph._codegen + # self._co_fields: Dict[str, Any] = graph._co_fields + # -- self._find_nodes_lookup_table = _FindNodesLookupTable() + + def unpickle( + self, gm: torch.fx.GraphModule, unpickle_state: _UnpickleState + ) -> torch.fx.Graph: + graph = torch.fx.Graph(gm, self.tracer_cls, self.tracer_extras) + + nodes: Dict[_NodePickleData, torch.fx.Node] = {} + for nd in self.nodes: + nodes[nd] = nd.unpickle(graph, nodes, unpickle_state) + + return graph + + +class _TracingContextPickleData: + @classmethod + def reduce_helper( + cls, pickler: GraphPickler, obj: torch._guards.TracingContext + ) -> Tuple[ + Callable[[Self, _UnpickleState], torch._guards.TracingContext], + Tuple[Self, _UnpickleStateToken], + ]: + return ( + cls.unpickle, + ( + cls(obj), + pickler._unpickle_state, + ), + ) + + def __init__(self, context: TracingContext) -> None: + # TODO: Do we really need all of this? 
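+        # Only plain-data fields are copied below; the "Not saved" list at
+        # the end covers state that is process-local or rebuilt on unpickle
+        # (e.g. fake_mode comes from _UnpickleState).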
+ self.module_context = context.module_context + self.frame_summary_stack = context.frame_summary_stack + self.loc_in_frame = context.loc_in_frame + self.aot_graph_name = context.aot_graph_name + self.params_flat = context.params_flat + self.params_flat_unwrap_subclasses = context.params_flat_unwrap_subclasses + self.params_unwrapped_to_flat_index = context.params_unwrapped_to_flat_index + self.output_strides = context.output_strides + self.force_unspec_int_unbacked_size_like = ( + context.force_unspec_int_unbacked_size_like + ) + # Not saved (because it's difficult and maybe not needed?): + # self.fw_metadata = context.fw_metadata + # self.guards_context = None + # self.global_context = None + # self.fake_mode = None + # self.fakify_first_call = None + # self.hop_dispatch_set_cache = None + # self.tensor_to_context = context.tensor_to_context + + def unpickle(self, unpickle_state: _UnpickleState) -> TracingContext: + context = TracingContext(unpickle_state.fake_mode) + context.module_context = self.module_context + context.frame_summary_stack = self.frame_summary_stack + context.loc_in_frame = self.loc_in_frame + context.aot_graph_name = self.aot_graph_name + context.params_flat = self.params_flat + context.params_flat_unwrap_subclasses = self.params_flat_unwrap_subclasses + context.params_unwrapped_to_flat_index = self.params_unwrapped_to_flat_index + context.output_strides = self.output_strides + context.force_unspec_int_unbacked_size_like = ( + self.force_unspec_int_unbacked_size_like + ) + return context diff --git a/torch/fx/node.py b/torch/fx/node.py index 87cd4a04aeb..39c7f82d8c2 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -602,7 +602,8 @@ class Node(_NodeBase): return self._repr_fn(self) return self.name - def _pretty_print_target(self, target: object) -> str: + @staticmethod + def _pretty_print_target(target: object) -> str: """ Make target printouts more user-friendly. 1) builtins will be printed as `builtins.xyz`
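
Putting the pieces together, a minimal usage sketch of the new pickler (it
assumes a GraphModule whose attributes are all picklable; the FakeTensorMode
construction mirrors _current_fake_mode() in compile_fx.py):

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.experimental.symbolic_shapes import ShapeEnv
    from torch.fx._graph_pickler import GraphPickler

    def roundtrip(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        # dumps() routes FakeTensor, GraphModule, ops, ShapeEnv, SymInt and
        # TracingContext through reducer_override; everything else pickles
        # normally.
        payload: bytes = GraphPickler.dumps(gm)

        # loads() needs a FakeTensorMode (with a ShapeEnv) to rebuild fake
        # tensors and symbolic shapes on the receiving side.
        fake_mode = FakeTensorMode(shape_env=ShapeEnv())
        out = GraphPickler.loads(payload, fake_mode)
        assert isinstance(out, torch.fx.GraphModule)
        return out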