diff --git a/BUILD.bazel b/BUILD.bazel
index 789695351c7..2d3e1d7cdf7 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -376,6 +376,7 @@ cc_library(
         ":torch_headers",
         "@fbgemm",
         "@ideep",
+        "@nlohmann",
     ],
     alwayslink = True,
 )
diff --git a/aten/src/ATen/functorch/Interpreter.h b/aten/src/ATen/functorch/Interpreter.h
index bdea11d3b2a..1c76230fb45 100644
--- a/aten/src/ATen/functorch/Interpreter.h
+++ b/aten/src/ATen/functorch/Interpreter.h
@@ -8,6 +8,8 @@
 #include <...>
 #include <...>
+#include <nlohmann/json.hpp>
+
 namespace at::functorch {
 
 // NOTE: [functorch interpreter stack]
@@ -91,24 +93,95 @@ std::ostream& operator<<(std::ostream& os, const TransformType& t);
 struct VmapInterpreterMeta {
   explicit VmapInterpreterMeta(c10::SymInt batchSize, RandomnessType randomness) : batchSize_(std::move(batchSize)), randomness_(randomness) {}
+
   c10::SymInt batchSize_;
   RandomnessType randomness_;
+
+  VmapInterpreterMeta() = default;
+  VmapInterpreterMeta(const VmapInterpreterMeta&) = default;
+  VmapInterpreterMeta(VmapInterpreterMeta&&) = default;
+  VmapInterpreterMeta& operator=(const VmapInterpreterMeta&) = default;
+  VmapInterpreterMeta& operator=(VmapInterpreterMeta&&) = default;
+  ~VmapInterpreterMeta() = default;
+
+  template <typename T>
+  friend void to_json(T& json_j, const VmapInterpreterMeta& json_t) {
+    if (json_t.batchSize_.is_heap_allocated()) {
+      throw std::runtime_error("Serialization for heap-allocated SymInt is not implemented yet");
+    }
+    json_j["batchSize"] = json_t.batchSize_.as_int_unchecked();
+    json_j["randomness"] = static_cast<int64_t>(json_t.randomness_);
+  }
+
+  template <typename T>
+  friend void from_json(const T& json_j, VmapInterpreterMeta& json_t) {
+    json_t.batchSize_ = c10::SymInt(SymInt::Unchecked::UNCHECKED, json_j["batchSize"]);
+    json_t.randomness_ = static_cast<RandomnessType>(json_j["randomness"]);
+  }
 };
 
 struct GradInterpreterMeta {
   explicit GradInterpreterMeta(bool prevGradMode): prevGradMode_(prevGradMode) {}
+  GradInterpreterMeta() = default;
+  GradInterpreterMeta(const GradInterpreterMeta&) = default;
+  GradInterpreterMeta(GradInterpreterMeta&&) = default;
+  GradInterpreterMeta& operator=(const GradInterpreterMeta&) = default;
+  GradInterpreterMeta& operator=(GradInterpreterMeta&&) = default;
+  ~GradInterpreterMeta() = default;
+
   bool prevGradMode_;
+
+  template <typename T>
+  friend void to_json(T& json_j, const GradInterpreterMeta& json_t) {
+    json_j["prevGradMode"] = json_t.prevGradMode_;
+  }
+
+  template <typename T>
+  friend void from_json(const T& json_j, GradInterpreterMeta& json_t) {
+    json_t.prevGradMode_ = json_j["prevGradMode"];
+  }
 };
 
 struct JvpInterpreterMeta {
   explicit JvpInterpreterMeta(bool prevFwdGradMode) : prevFwdGradMode_(prevFwdGradMode) {}
+  JvpInterpreterMeta() = default;
+  JvpInterpreterMeta(const JvpInterpreterMeta&) = default;
+  JvpInterpreterMeta(JvpInterpreterMeta&&) = default;
+  JvpInterpreterMeta& operator=(const JvpInterpreterMeta&) = default;
+  JvpInterpreterMeta& operator=(JvpInterpreterMeta&&) = default;
+  ~JvpInterpreterMeta() = default;
+
   bool prevFwdGradMode_;
+
+  template <typename T>
+  friend void to_json(T& json_j, const JvpInterpreterMeta& json_t) {
+    json_j["prevFwdGradMode"] = json_t.prevFwdGradMode_;
+  }
+
+  template <typename T>
+  friend void from_json(const T& json_j, JvpInterpreterMeta& json_t) {
+    json_t.prevFwdGradMode_ = json_j["prevFwdGradMode"];
+  }
 };
 
 struct FunctionalizeInterpreterMeta {
   explicit FunctionalizeInterpreterMeta(bool functionalizeAddBackViews) : functionalizeAddBackViews_(functionalizeAddBackViews) {}
+  FunctionalizeInterpreterMeta() = default;
+  FunctionalizeInterpreterMeta(const FunctionalizeInterpreterMeta&) = default;
+  FunctionalizeInterpreterMeta(FunctionalizeInterpreterMeta&&) = default;
+  FunctionalizeInterpreterMeta& operator=(const FunctionalizeInterpreterMeta&) = default;
+  FunctionalizeInterpreterMeta& operator=(FunctionalizeInterpreterMeta&&) = default;
+  ~FunctionalizeInterpreterMeta() = default;
+
   bool functionalizeAddBackViews_;
+
+  template <typename T>
+  friend void to_json(T& json_j, const FunctionalizeInterpreterMeta& json_t) {
+    json_j["functionalizeAddBackViews"] = json_t.functionalizeAddBackViews_;
+  }
+
+  template <typename T>
+  friend void from_json(const T& json_j, FunctionalizeInterpreterMeta& json_t) {
+    json_t.functionalizeAddBackViews_ = json_j["functionalizeAddBackViews"];
+  }
 };
 
 typedef std::variant<
@@ -172,6 +245,75 @@ struct Interpreter {
   // Please don't use this
   explicit Interpreter() = default;
 
+  template <typename T>
+  friend void to_json(T& json_j, const Interpreter& json_t) {
+    json_j["type"] = static_cast<int64_t>(json_t.type_);
+    json_j["level"] = json_t.level_;
+    if (json_t.savedLocalDispatchKeySet_) {
+      json_j["savedLocalDispatchKeySet"] = {
+        {"included", json_t.savedLocalDispatchKeySet_->included_.raw_repr()},
+        {"excluded", json_t.savedLocalDispatchKeySet_->excluded_.raw_repr()}
+      };
+    } else {
+      json_j["savedLocalDispatchKeySet"] = nlohmann::json();
+    }
+    json_j["is_alive"] = *json_t.is_alive_;
+    std::visit([&](auto&& arg) {
+      using V = std::decay_t<decltype(arg)>;
+      if constexpr (std::is_same_v<V, int64_t>) {
+        json_j["meta"] = {{"Torch", arg}};
+      } else if constexpr (std::is_same_v<V, GradInterpreterMeta>) {
+        json_j["meta"] = {{"Grad", arg}};
+      } else if constexpr (std::is_same_v<V, JvpInterpreterMeta>) {
+        json_j["meta"] = {{"Jvp", arg}};
+      } else if constexpr (std::is_same_v<V, VmapInterpreterMeta>) {
+        json_j["meta"] = {{"Vmap", arg}};
+      } else if constexpr (std::is_same_v<V, FunctionalizeInterpreterMeta>) {
+        json_j["meta"] = {{"Functionalize", arg}};
+      } else {
+        static_assert(false && sizeof(V), "unknown variant case");
+      }
+    }, json_t.meta_);
+  }
+
+  template <typename T>
+  friend void from_json(const T& json_j, Interpreter& json_t) {
+    json_t.type_ = static_cast<TransformType>(json_j["type"]);
+    json_t.level_ = json_j["level"];
+    auto savedLocalDispatchKeySet = json_j["savedLocalDispatchKeySet"];
+    if (savedLocalDispatchKeySet.is_null()) {
+      json_t.savedLocalDispatchKeySet_ = std::nullopt;
+    } else {
+      c10::impl::PODLocalDispatchKeySet pod;
+      pod.set_included(DispatchKeySet::from_raw_repr(savedLocalDispatchKeySet["included"].template get<uint64_t>()));
+      pod.set_excluded(DispatchKeySet::from_raw_repr(savedLocalDispatchKeySet["excluded"].template get<uint64_t>()));
+      json_t.savedLocalDispatchKeySet_ = c10::impl::LocalDispatchKeySet(pod);
+    }
+    json_t.is_alive_ = std::make_shared<bool>(json_j["is_alive"]);
+    auto meta = json_j["meta"];
+    if (meta.contains("Torch")) {
+      json_t.meta_.emplace<int64_t>(meta["Torch"].template get<int64_t>());
+    } else if (meta.contains("Grad")) {
+      json_t.meta_.emplace<GradInterpreterMeta>(meta["Grad"].template get<GradInterpreterMeta>());
+    } else if (meta.contains("Jvp")) {
+      json_t.meta_.emplace<JvpInterpreterMeta>(meta["Jvp"].template get<JvpInterpreterMeta>());
+    } else if (meta.contains("Vmap")) {
+      json_t.meta_.emplace<VmapInterpreterMeta>(meta["Vmap"].template get<VmapInterpreterMeta>());
+    } else if (meta.contains("Functionalize")) {
+      json_t.meta_.emplace<FunctionalizeInterpreterMeta>(meta["Functionalize"].template get<FunctionalizeInterpreterMeta>());
+    } else {
+      throw std::runtime_error("unknown interpreter metadata type");
+    }
+  }
+
+  std::string serialize() const {
+    return nlohmann::json(*this).dump();
+  }
+
+  static Interpreter deserialize(const std::string& serialized) {
+    return nlohmann::json::parse(serialized).get<Interpreter>();
+  }
+
  private:
   explicit Interpreter(TransformType type, int64_t level, InterpreterMeta meta):
     type_(type), level_(level), is_alive_(std::make_shared<bool>(false)), meta_(std::move(meta)) {}
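
Note: the to_json/from_json overloads above flatten each interpreter frame into a small JSON object. A sketch of what one Grad-transform frame could round-trip to, mirrored in Python for readability (field names come from the code above; the concrete values, including the numeric "type" tag, are only illustrative):

import json

# Illustrative payload only; the shape follows to_json(Interpreter) above.
frame = {
    "type": 2,  # static_cast'ed TransformType tag (value here is made up)
    "level": 1,
    "savedLocalDispatchKeySet": {"included": 0, "excluded": 0},
    "is_alive": True,
    "meta": {"Grad": {"prevGradMode": True}},
}
blob = json.dumps(frame)  # roughly what Interpreter::serialize() hands back
assert json.loads(blob)["meta"]["Grad"]["prevGradMode"] is True
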
diff --git a/test/dynamo/test_guard_serialization.py b/test/dynamo/test_guard_serialization.py
index 4c22137621f..8bb2b96dc4c 100644
--- a/test/dynamo/test_guard_serialization.py
+++ b/test/dynamo/test_guard_serialization.py
@@ -94,7 +94,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
             fn.__closure__ or (),
             [],  # TODO tf_mode_stack,
             code_options,
-            lambda gm, *args, **kwargs: gm.forward,
+            torch._dynamo.lookup_backend("eager"),
             one_graph=False,
             export=False,
             export_constraints=None,
@@ -326,6 +326,126 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
         with torch.autograd.forward_ad.dual_level():
             self._test_check_fn(ref, loaded, {"x": x}, False)
 
+    def test_functorch_stack_match(self):
+        # Test when the functorch stack is empty.
+        def fn(x):
+            return torch.func.jvp(torch.sin, (x,), (x,))
+
+        x = torch.randn(3, 4)
+        ref, loaded = self._test_serialization("FUNCTORCH_STACK_MATCH", fn, x)
+
+        self._test_check_fn(ref, loaded, {"x": x}, True)
+        with torch._functorch.vmap.vmap_increment_nesting(2, "error"):
+            self._test_check_fn(ref, loaded, {"x": x}, False)
+
+        def fn(x):
+            def g(x):
+                return torch.vmap(torch.func.grad(torch.sin))(x)
+
+            return torch.vmap(g)(x)
+
+        x = torch.randn(4, 5)
+        ref, loaded = self._test_serialization("FUNCTORCH_STACK_MATCH", fn, x)
+        self._test_check_fn(ref, loaded, {"x": x}, True)
+        with torch._functorch.eager_transforms.grad_increment_nesting():
+            self._test_check_fn(ref, loaded, {"x": x}, False)
+
+        # Test when there are more than 0 functorch layers.
+        # Simulate the case where torch.compile is nested inside eager transforms.
+
+        # Case 1: vmap
+        def fn(x):
+            return x.sum()
+
+        ref = loaded = None
+
+        def run(x):
+            nonlocal ref, loaded
+            # Turn off automatic dynamic shapes so that functionalization
+            # doesn't produce extra SymInts to serialize.
+            with torch._dynamo.config.patch(automatic_dynamic_shapes=False):
+                ref, loaded = self._test_serialization("FUNCTORCH_STACK_MATCH", fn, x)
+            return fn(x)
+
+        torch.vmap(run)(x)
+
+        self._test_check_fn(ref, loaded, {"x": x}, False)
+        with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
+            self._test_check_fn(ref, loaded, {"x": x}, True)
+            with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
+                self._test_check_fn(ref, loaded, {"x": x}, False)
+
+        with torch._functorch.eager_transforms.grad_increment_nesting():
+            self._test_check_fn(ref, loaded, {"x": x}, False)
+
+        # Case 2: grad
+        x = torch.randn(3, 2)
+        ref = loaded = None
+        torch.func.grad(run)(x)
+        self._test_check_fn(ref, loaded, {"x": x}, False)
+        with torch._functorch.eager_transforms.grad_increment_nesting():
+            self._test_check_fn(ref, loaded, {"x": x}, True)
+            with torch._functorch.eager_transforms.grad_increment_nesting():
+                self._test_check_fn(ref, loaded, {"x": x}, False)
+
+        with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
+            self._test_check_fn(ref, loaded, {"x": x}, False)
+
+        # Case 3: jvp + vmap
+        x = torch.randn(3, 4)
+        ref = loaded = None
+
+        def fn(x):
+            return torch.func.jvp(torch.sin, (x,), (x,))
+
+        torch.func.jvp(torch.vmap(run), (x,), (x,))
+        self._test_check_fn(ref, loaded, {"x": x}, False)
+
+        with torch._functorch.eager_transforms.jvp_increment_nesting():
+            with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
+                self._test_check_fn(ref, loaded, {"x": x}, True)
+
+        with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
+            with torch._functorch.eager_transforms.jvp_increment_nesting():
+                self._test_check_fn(ref, loaded, {"x": x}, False)
+
+        # Case 4: functionalize
+        x = torch.randn(3, 2)
+        ref = loaded = None
+        torch.func.functionalize(run)(x)
+        self._test_check_fn(ref, loaded, {"x": x}, False)
+
+        torch._C._functorch._func_increment_nesting(True)
+        try:
+            self._test_check_fn(ref, loaded, {"x": x}, True)
+        finally:
+            torch._C._functorch._func_decrement_nesting()
+
+        with torch._functorch.eager_transforms.jvp_increment_nesting():
+            self._test_check_fn(ref, loaded, {"x": x}, False)
+
+        # Case 5: vmap + grad
+        def fn(x):
+            return x.sum()
+
+        x = torch.randn(3, 2)
+        ref = loaded = None
+        torch.vmap(torch.func.grad(run))(x)
+        self._test_check_fn(ref, loaded, {"x": x}, False)
+        with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
+            with torch._functorch.eager_transforms.grad_increment_nesting():
+                self._test_check_fn(ref, loaded, {"x": x}, True)
+
+        with torch._functorch.eager_transforms.grad_increment_nesting():
+            with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
+                self._test_check_fn(ref, loaded, {"x": x}, False)
+
+        with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
+            self._test_check_fn(ref, loaded, {"x": x}, False)
+
+        with torch._functorch.eager_transforms.grad_increment_nesting():
+            self._test_check_fn(ref, loaded, {"x": x}, False)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
diff --git a/torch/_C/_functorch.pyi b/torch/_C/_functorch.pyi
index 4cdfb1346fd..2e37b3d1099 100644
--- a/torch/_C/_functorch.pyi
+++ b/torch/_C/_functorch.pyi
@@ -49,6 +49,9 @@ class RandomnessType(Enum):
 class CInterpreter:
     def key(self) -> TransformType: ...
     def level(self) -> int: ...
+    def serialize(self) -> bytes: ...
+    @staticmethod
+    def deserialize(bytes) -> CInterpreter: ...
 
 class CGradInterpreterPtr:
     def __init__(self, interpreter: CInterpreter) -> None: ...
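
The serialize()/deserialize() bindings declared above are what GuardsStatePickler uses further down to pickle a functorch interpreter frame. A minimal sketch of exercising them directly from Python, assuming a transform is active so the stack is non-empty (the `_cdata` attribute and the exact return type of serialize() follow the stubs and bindings in this patch; the helper names here are hypothetical):

import torch
from torch._functorch.pyfunctorch import retrieve_all_functorch_interpreters

def roundtrip_functorch_stack():
    # Each Python-side wrapper keeps the underlying CInterpreter in `_cdata`.
    layers = retrieve_all_functorch_interpreters()
    blobs = [layer._cdata.serialize() for layer in layers]
    return [torch._C._functorch.CInterpreter.deserialize(b) for b in blobs]

def body(x):
    restored = roundtrip_functorch_stack()
    assert len(restored) == 1  # exactly the one vmap layer pushed below
    return x.sin()

# torch.vmap(body)(torch.randn(3))  # run under a transform so the stack is populated
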
diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py
index a4135f3be7d..437ac999566 100644
--- a/torch/_dynamo/guards.py
+++ b/torch/_dynamo/guards.py
@@ -475,7 +475,11 @@ def get_verbose_code_part(code_part: str, guard: Guard) -> str:
                     extra = f"  # {format_frame(fs, line=True)}"
                     break
         elif guard.stack:
-            extra = f"  # {format_frame(guard.stack.summary()[-1])}"
+            summary = guard.stack.summary()
+            if len(summary) > 0:
+                extra = f"  # {format_frame(summary[-1])}"
+            else:
+                extra = "  # "
     return f"{code_part:<60}{extra}"
 
 
@@ -1591,7 +1595,7 @@ class GuardBuilder(GuardBuilderBase):
     def FUNCTORCH_STACK_MATCH(self, guard: Guard):
         # Invalidate functorch code if current level is different than
         # the one when FX graph was generated
-        cis = torch._functorch.pyfunctorch.retrieve_all_functorch_interpreters()
+        cis = self.check_fn_manager.output_graph.functorch_layers
         states = [ci.get_state() for ci in cis]
         code = [f"torch._functorch.pyfunctorch.compare_functorch_state({states})"]
         self._set_guard_export_info(guard, code)
@@ -2522,7 +2526,13 @@ class GuardsStatePickler(pickle.Pickler):
     def _unpickle_dispatch_key_set(cls, raw_repr: int):
         return torch._C.DispatchKeySet.from_raw_repr(raw_repr)
 
+    @classmethod
+    def _unpickle_functorch_interpreter(cls, json: bytes):
+        return torch._C._functorch.CInterpreter.deserialize(json)
+
     def reducer_override(self, obj):
+        import sympy
+
         if isinstance(obj, torch.Tensor) and obj.device.type != "meta":
             return type(self)._unpickle_tensor, (
                 torch.empty_like(obj, device="meta"),
@@ -2543,6 +2553,20 @@ class GuardsStatePickler(pickle.Pickler):
         elif isinstance(obj, torch._C.DispatchKeySet):
             return type(self)._unpickle_dispatch_key_set, (obj.raw_repr(),)
 
+        elif isinstance(obj, torch._C._functorch.CInterpreter):
+            return type(self)._unpickle_functorch_interpreter, (obj.serialize(),)
+
+        elif (
+            inspect.isclass(obj)
+            and issubclass(obj, sympy.Function)
+            and hasattr(obj, "_torch_handler_name")
+        ):
+            assert hasattr(obj, "_torch_unpickler")
+            return obj._torch_unpickler, (obj._torch_handler_name,)
+
+        elif isinstance(obj, torch.SymInt):
+            raise RuntimeError(f"Cannot serialize SymInt {obj} (node: {obj.node})")
+
         if type(obj).__qualname__ != type(obj).__name__:
             raise RuntimeError(
                 f"Type {type(obj)} for object {obj} cannot be saved "
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index 566f4daa585..c76f32c24cb 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -301,6 +301,7 @@ class OutputGraphGuardsState:
     # Map from graph input's `Source` to sizes / strides metadata
     input_source_to_sizes_strides: dict[Source, dict[str, Any]]
     dual_level: int
+    functorch_layers: list[torch._functorch.pyfunctorch.FuncTorchInterpreter]
     export: bool = False
     export_constraints: bool = False
 
@@ -354,6 +355,7 @@ class OutputGraph(OutputGraphGuardsState):
             guard_on_key_order=set(),
             input_source_to_sizes_strides={},
             dual_level=torch.autograd.forward_ad._current_level,
+            functorch_layers=torch._functorch.pyfunctorch.retrieve_all_functorch_interpreters(),
         )
         self.tracers = [SubgraphTracer(self, is_export=export)]
         # Map from graph input's `Source` to its `VariableTracker` to
@@ -590,6 +592,7 @@ class OutputGraph(OutputGraphGuardsState):
             guard_on_key_order=self.guard_on_key_order,
             input_source_to_sizes_strides=self.input_source_to_sizes_strides,
             dual_level=self.dual_level,
+            functorch_layers=self.functorch_layers,
             export=self.export,
             export_constraints=self.export_constraints,
             _guards=self.guards,
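
The FUNCTORCH_STACK_MATCH guard built in the guards.py hunk above emits a call to torch._functorch.pyfunctorch.compare_functorch_state(states), where `states` is captured from output_graph.functorch_layers at trace time. Conceptually the check passes only when the live interpreter stack matches that snapshot layer for layer; a simplified sketch (compare_functorch_state, get_state, and retrieve_all_functorch_interpreters are existing helpers, the surrounding driver function is hypothetical):

from torch._functorch.pyfunctorch import (
    compare_functorch_state,
    retrieve_all_functorch_interpreters,
)

# Trace time: snapshot per-layer state (transform type, level, flags).
saved_states = [ci.get_state() for ci in retrieve_all_functorch_interpreters()]

# Check time: the generated guard boils down to this call, which is True only
# when the current stack has the same depth and per-layer state as the snapshot.
def functorch_stack_guard_holds() -> bool:
    return compare_functorch_state(saved_states)
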
diff --git a/torch/_functorch/pyfunctorch.py b/torch/_functorch/pyfunctorch.py
index 28bd74f28d3..2976e22c47a 100644
--- a/torch/_functorch/pyfunctorch.py
+++ b/torch/_functorch/pyfunctorch.py
@@ -1,6 +1,7 @@
 # mypy: allow-untyped-defs
 import contextlib
 from abc import ABC, abstractmethod
+from functools import cached_property
 from typing import Any
 
 import torch
@@ -79,6 +80,11 @@ class FuncTorchInterpreter(ABC):
     def check_state(self, state):
         return state == self.get_state()
 
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state.pop("_cptr", None)
+        return state
+
 
 @contextlib.contextmanager
 def temporarily_pop_interpreter_stack():
@@ -123,7 +129,10 @@ class VmapInterpreter(FuncTorchInterpreter):
         # cdata is a generic CInterpreter. We wrap it in a CVmapInterpreterPtr
         # so that we can access methods specific to the vmap interpreter
         self._cdata = cdata
-        self._cptr = CVmapInterpreterPtr(cdata)
+
+    @cached_property
+    def _cptr(self):
+        return CVmapInterpreterPtr(self._cdata)
 
     def process(self, op, args, kwargs):
         kernel = op.functorch_table[TransformType.Vmap]
@@ -159,7 +168,10 @@ class GradInterpreter(FuncTorchInterpreter):
         assert cdata.key() == TransformType.Grad
         # See NOTE: [Interpreter cdata vs cptr]
         self._cdata = cdata
-        self._cptr = CGradInterpreterPtr(cdata)
+
+    @cached_property
+    def _cptr(self):
+        return CGradInterpreterPtr(self._cdata)
 
     def lift(self, args, kwargs):
         args, kwargs = pytree.tree_map_only(
@@ -193,7 +205,10 @@ class JvpInterpreter(FuncTorchInterpreter):
         assert cdata.key() == TransformType.Jvp
         # See NOTE: [Interpreter cdata vs cptr]
         self._cdata = cdata
-        self._cptr = CJvpInterpreterPtr(cdata)
+
+    @cached_property
+    def _cptr(self):
+        return CJvpInterpreterPtr(self._cdata)
 
     def lift(self, args, kwargs):
         args, kwargs = pytree.tree_map_only(
@@ -226,7 +241,10 @@ class FunctionalizeInterpreter(FuncTorchInterpreter):
     def __init__(self, cdata: CInterpreter):
         assert cdata.key() == TransformType.Functionalize
         self._cdata = cdata
-        self._cptr = CFunctionalizeInterpreterPtr(cdata)
+
+    @cached_property
+    def _cptr(self):
+        return CFunctionalizeInterpreterPtr(self._cdata)
 
     def process(self, op, args, kwargs):
         kernel = op.functorch_table[TransformType.Functionalize]
diff --git a/torch/csrc/functorch/init.cpp b/torch/csrc/functorch/init.cpp
index 5a6ad0e29cc..3ad53c3f403 100644
--- a/torch/csrc/functorch/init.cpp
+++ b/torch/csrc/functorch/init.cpp
@@ -575,7 +575,9 @@ void initFuncTorchBindings(PyObject* module) {
       .value("Different", RandomnessType::Different);
   py::class_<Interpreter>(m, "CInterpreter")
      .def("key", &Interpreter::key)
-      .def("level", &Interpreter::level);
+      .def("level", &Interpreter::level)
+      .def("serialize", &Interpreter::serialize)
+      .def_static("deserialize", &Interpreter::deserialize);
   py::class_<GradInterpreterPtr>(m, "CGradInterpreterPtr")
       .def(py::init<const Interpreter*>())
       .def("key", &GradInterpreterPtr::key)
diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py
index 61c63ece236..23d0bc24bf1 100644
--- a/torch/utils/_sympy/functions.py
+++ b/torch/utils/_sympy/functions.py
@@ -1318,6 +1318,7 @@ def make_opaque_unary_fn(name):
         """
 
         _torch_handler_name = name
+        _torch_unpickler = make_opaque_unary_fn
 
         @classmethod
         def eval(cls, a):
@@ -1378,6 +1379,9 @@ def make_opaque_bitwise_fn(name, real_op_name):
     class BitwiseFn(sympy.Function):
         _torch_handler_name = name
         precedence: int = prec
+        _torch_unpickler = functools.partial(
+            make_opaque_bitwise_fn, real_op_name=real_op_name
+        )
 
         @classmethod
         def eval(cls, a, b):
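
The _torch_unpickler hooks added in functions.py are what let GuardsStatePickler.reducer_override reduce an opaque sympy function class to just its handler name; unpickling re-invokes the factory. A rough sketch of that round trip (the reduce tuple mirrors the reducer_override branch in guards.py above; the variable names are illustrative):

from torch.utils._sympy.functions import make_opaque_unary_fn

# What the pickler records for an opaque unary fn class: (factory, (handler_name,)).
OpaqueSin = make_opaque_unary_fn("sin")
reduce_fn, reduce_args = OpaqueSin._torch_unpickler, (OpaqueSin._torch_handler_name,)

# What the unpickler does: call the factory again to rebuild an equivalent class.
Rebuilt = reduce_fn(*reduce_args)
assert Rebuilt._torch_handler_name == "sin"
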