From 7aaad0b8325b83cf1bb551db271a1cb370b9b88c Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Mon, 16 Jan 2023 19:04:47 -0800 Subject: [PATCH] Rename flag that enables/disables _SingleLevelFunction for functorch (#92025) functorch used to have a switch that enables/disables autograd.Function. That switch now enables/disables torch.autograd.function._SingleLevelFunction, so I've renamed it accordingly. We could just delete the switch because users should not be directly working with torch.autograd.function._SingleLevelFunction. However, it was useful for debugging when something went wrong when I was implementing the autograd.Function <> functorch interaction, so I want to keep it around as a debugging tool for a while since the code is already there. Test Plan: - updated tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/92025 Approved by: https://github.com/soulitzer --- aten/src/ATen/FuncTorchTLS.h | 5 +--- aten/src/ATen/functorch/DynamicLayer.cpp | 19 ++++++------- aten/src/ATen/functorch/DynamicLayer.h | 13 +++++---- test/functorch/test_eager_transforms.py | 34 ++---------------------- torch/_C/_functorch.pyi | 4 +-- torch/_functorch/autograd_function.py | 4 +-- torch/_functorch/utils.py | 12 ++++----- torch/csrc/autograd/python_function.cpp | 9 +++++-- torch/csrc/functorch/init.cpp | 8 +++--- 9 files changed, 42 insertions(+), 66 deletions(-) diff --git a/aten/src/ATen/FuncTorchTLS.h b/aten/src/ATen/FuncTorchTLS.h index 9242a1d177a..b8fde728fad 100644 --- a/aten/src/ATen/FuncTorchTLS.h +++ b/aten/src/ATen/FuncTorchTLS.h @@ -28,10 +28,7 @@ struct TORCH_API FuncTorchTLSBase { virtual ~FuncTorchTLSBase() = default; virtual std::unique_ptr deepcopy() const = 0; - // functorch doesn't always work with autograd.Function. - // This is a hook to get into functorch -- functorch will determine - // if it should raise an error message - virtual int64_t checkSupportsAutogradFunction() const = 0; + virtual int64_t checkSupportsSingleLevelAutogradFunction() const = 0; virtual void checkSupportsInplaceRequiresGrad() const = 0; virtual void checkSupportsRetainGrad() const = 0; }; diff --git a/aten/src/ATen/functorch/DynamicLayer.cpp b/aten/src/ATen/functorch/DynamicLayer.cpp index cd8db068b5e..5acec2a3b01 100644 --- a/aten/src/ATen/functorch/DynamicLayer.cpp +++ b/aten/src/ATen/functorch/DynamicLayer.cpp @@ -91,10 +91,11 @@ class FuncTorchTLS : public FuncTorchTLSBase { return result; } - int64_t checkSupportsAutogradFunction() const override { - TORCH_CHECK(dynamicLayerStack.size() == 0 || getAutogradFunctionAllowed(), - "functorch functions (vmap, grad, vjp, etc.) currently do not support the use of autograd.Function. ", - "Please rewrite your function to not use autograd.Function while we work on fixing this"); + int64_t checkSupportsSingleLevelAutogradFunction() const override { + TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() == 0 || getSingleLevelAutogradFunctionAllowed(), + "functorch functions (vmap, grad, vjp, etc.) incorrectly used with ", + "torch.autograd.function._SingleLevelFunction. 
", + "This is not expected, please file a bug."); return 0; } @@ -119,7 +120,7 @@ class FuncTorchTLS : public FuncTorchTLSBase { std::vector dynamicLayerStack; bool allow_inplace_requires_grad_ = false; - bool allow_autograd_function_ = false; + bool allow_single_level_autograd_function_ = false; }; static FuncTorchTLS* getRawFunctorchTLS() { @@ -143,14 +144,14 @@ bool getInplaceRequiresGradAllowed() { return functorch_tls->allow_inplace_requires_grad_; } -void setAutogradFunctionAllowed(bool allowed) { +void setSingleLevelAutogradFunctionAllowed(bool allowed) { auto* functorch_tls = getRawFunctorchTLS(); - functorch_tls->allow_autograd_function_ = allowed; + functorch_tls->allow_single_level_autograd_function_ = allowed; } -bool getAutogradFunctionAllowed() { +bool getSingleLevelAutogradFunctionAllowed() { auto* functorch_tls = getRawFunctorchTLS(); - return functorch_tls->allow_autograd_function_; + return functorch_tls->allow_single_level_autograd_function_; } static std::vector& dynamicLayerStackAccessor() { diff --git a/aten/src/ATen/functorch/DynamicLayer.h b/aten/src/ATen/functorch/DynamicLayer.h index 060470a2b63..2d5682b8bf4 100644 --- a/aten/src/ATen/functorch/DynamicLayer.h +++ b/aten/src/ATen/functorch/DynamicLayer.h @@ -108,11 +108,14 @@ TORCH_API bool isDeadTensorWrapper(const Tensor& tensor); TORCH_API std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer); TORCH_API std::ostream& operator<<(std::ostream& os, const std::vector& dynamicLayerStack); -// While a functorch transform is active, autograd.Function is disabled -// by default. The following two APIs are APIs for enabling -// autograd.Function. These are not user-facing APIs. -TORCH_API void setAutogradFunctionAllowed(bool allowed); -TORCH_API bool getAutogradFunctionAllowed(); +// While a functorch transform is active, torch.autograd.function._SingleLevelFunction +// is disabled by default. The following two APIs are APIs for enabling +// it. These are not user-facing APIs. We can delete this in the future, but +// it is useful for debugging when something goes wrong with the +// autograd.Function <> functorch interaction, which uses _SingleLevelFunction, +// because it leads to loud errors if something is incorrect. +TORCH_API void setSingleLevelAutogradFunctionAllowed(bool allowed); +TORCH_API bool getSingleLevelAutogradFunctionAllowed(); // While a functorch grad transform is active, Tensor.requires_grad_() gets // disabled. These two functions are the mechanism to controlling that. 
diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index 4a94645239c..a5e9d7bc2c2 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -38,7 +38,7 @@ from torch._functorch.make_functional import ( from torch._functorch.eager_transforms import _slice_argnums from functorch.experimental import functionalize from torch._ops import PyOperator -from torch._functorch.utils import enable_autograd_function +from torch._functorch.utils import enable_single_level_autograd_function from torch.autograd.function import _set_autograd_function_extension_enabled import torch.autograd.forward_ad as fwAD from torch.func import functional_call, stack_module_state @@ -3035,36 +3035,6 @@ class TestComposability(TestCase): with self.assertRaises(RuntimeError): grad(f)(x) - def test_autograd_function_debug_switch(self, device): - class MySin(torch.autograd.Function): - @staticmethod - def forward(ctx, x): - ctx.save_for_backward(x) - return x.sin() - - @staticmethod - def backward(ctx, gy): - x, = ctx.saved_tensors - return gy * x.cos() - - x = torch.randn([]) - - with torch.autograd.function._set_autograd_function_extension_enabled(False): - # by default, autograd.Function is disabled in a functorch transform - with self.assertRaisesRegex(RuntimeError, "autograd.Function"): - grad(MySin.apply)(x) - - # we have a debug switch to allow it - self.assertFalse(torch._C._functorch.get_autograd_function_allowed()) - try: - torch._C._functorch.set_autograd_function_allowed(True) - self.assertTrue(torch._C._functorch.get_autograd_function_allowed()) - y = grad(MySin.apply)(x) - finally: - torch._C._functorch.set_autograd_function_allowed(False) - self.assertFalse(torch._C._functorch.get_autograd_function_allowed()) - self.assertEqual(y, x.cos()) - @_set_autograd_function_extension_enabled() @parametrize('transform', [ 'vmap', 'grad', 'jacrev', 'jacfwd', 'grad_and_value', 'hessian', 'functionalize' @@ -4346,7 +4316,7 @@ def construct_sum_pyop(): def backward(ctx, gy): return gy.unsqueeze(ctx.dim).expand(ctx.x_shape), None - with enable_autograd_function(): + with enable_single_level_autograd_function(): return MySum.apply(x, dim) @mysum.py_impl(torch._C.DispatchKey.AutogradCPU) diff --git a/torch/_C/_functorch.pyi b/torch/_C/_functorch.pyi index 4b7b19dcdba..bafa63daa44 100644 --- a/torch/_C/_functorch.pyi +++ b/torch/_C/_functorch.pyi @@ -19,8 +19,8 @@ def _unwrap_batched(tensor: Tensor, level: int) -> Tuple[Tensor, Optional[int]]: def current_level() -> int: ... def _add_batch_dim(tensor: Tensor, bdim: int, level: int) -> Tensor: ... -def set_autograd_function_allowed(allowed: bool) -> None: ... -def get_autograd_function_allowed() -> bool: ... +def set_single_level_autograd_function_allowed(allowed: bool) -> None: ... +def get_single_level_autograd_function_allowed() -> bool: ... 
# Defined in aten/src/ATen/functorch/Interpreter.h class TransformType(Enum): diff --git a/torch/_functorch/autograd_function.py b/torch/_functorch/autograd_function.py index 8fb3f1e4334..eb234102c7c 100644 --- a/torch/_functorch/autograd_function.py +++ b/torch/_functorch/autograd_function.py @@ -1,7 +1,7 @@ import torch from torch._ops import PyOperator from torch._C._functorch import TransformType -from torch._functorch.utils import enable_autograd_function +from torch._functorch.utils import enable_single_level_autograd_function import torch.utils._pytree as pytree from torch._C._functorch import ( _wrap_for_grad, @@ -86,7 +86,7 @@ custom_function_call = CustomFunctionPyOperator() @custom_function_call.py_impl(TransformType.Jvp) def custom_function_call_grad(interpreter, autograd_function, *operands): Generated = generate_single_level_function(interpreter, autograd_function) - with enable_autograd_function(): + with enable_single_level_autograd_function(): flat_out = Generated.apply(*operands) return flat_out diff --git a/torch/_functorch/utils.py b/torch/_functorch/utils.py index 13326d05f6a..d09535850fb 100644 --- a/torch/_functorch/utils.py +++ b/torch/_functorch/utils.py @@ -1,19 +1,19 @@ import contextlib import torch from torch._C._functorch import ( - set_autograd_function_allowed, - get_autograd_function_allowed, + set_single_level_autograd_function_allowed, + get_single_level_autograd_function_allowed, unwrap_if_dead, ) @contextlib.contextmanager -def enable_autograd_function(): +def enable_single_level_autograd_function(): try: - prev_state = get_autograd_function_allowed() - set_autograd_function_allowed(True) + prev_state = get_single_level_autograd_function_allowed() + set_single_level_autograd_function_allowed(True) yield finally: - set_autograd_function_allowed(prev_state) + set_single_level_autograd_function_allowed(prev_state) def unwrap_dead_wrappers(args): # NB: doesn't use tree_map_only for performance reasons diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index e1f031fec22..0b08a52778e 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -877,10 +877,15 @@ PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) { unpacked_input.input_vars.begin(), unpacked_input.input_vars.end()), seq_id); - // Temporary hack to improve functorch UX. We'll find a better solution. const auto& functorch_tls = at::functorch::functorchTLSAccessor(); if (functorch_tls) { - functorch_tls->checkSupportsAutogradFunction(); + // autograd.Function support for functorch is handled in Python. + // If we have gotten here, then either we are dealing with a + // torch.autograd.function._SingleLevelFunction, or something in + // the implementation went wrong. + // The following code is useful for debugging when something goes wrong + // because it'll raise a loud error (instead of being silently incorrect). 
+ functorch_tls->checkSupportsSingleLevelAutogradFunction(); } THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls")); diff --git a/torch/csrc/functorch/init.cpp b/torch/csrc/functorch/init.cpp index b97bad6c6de..232b403f668 100644 --- a/torch/csrc/functorch/init.cpp +++ b/torch/csrc/functorch/init.cpp @@ -447,11 +447,11 @@ void initFuncTorchBindings(PyObject* module) { "get_inplace_requires_grad_allowed", &at::functorch::getInplaceRequiresGradAllowed); m.def( - "set_autograd_function_allowed", - &at::functorch::setAutogradFunctionAllowed); + "set_single_level_autograd_function_allowed", + &at::functorch::setSingleLevelAutogradFunctionAllowed); m.def( - "get_autograd_function_allowed", - &at::functorch::getAutogradFunctionAllowed); + "get_single_level_autograd_function_allowed", + &at::functorch::getSingleLevelAutogradFunctionAllowed); m.def("unwrap_if_dead", &unwrapIfDead); m.def("is_dead_tensor_wrapper", &isDeadTensorWrapper); m.def("dlevel", &dlevel, "dlevel");
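
For illustration only (not part of the patch): a minimal sketch of how the renamed debug switch behaves on a build that includes this change. It uses only names introduced in the diff above, namely the torch._C._functorch bindings declared in torch/_C/_functorch.pyi and registered in torch/csrc/functorch/init.cpp, plus the enable_single_level_autograd_function context manager from torch/_functorch/utils.py. The flag starting out disabled is an assumption based on allow_single_level_autograd_function_ = false in DynamicLayer.cpp and on no transform or prior toggle being active on the current thread.

# Sketch: exercising the renamed _SingleLevelFunction debug switch.
# Assumes a PyTorch build that contains this patch.
import torch
from torch._C._functorch import (
    get_single_level_autograd_function_allowed,
    set_single_level_autograd_function_allowed,
)
from torch._functorch.utils import enable_single_level_autograd_function

# The thread-local flag defaults to off
# (allow_single_level_autograd_function_ = false in DynamicLayer.cpp).
assert not get_single_level_autograd_function_allowed()

# The context manager saves the previous state, enables the flag, and
# restores the saved state on exit; this is the pattern
# custom_function_call_grad in torch/_functorch/autograd_function.py uses
# around Generated.apply(*operands).
with enable_single_level_autograd_function():
    assert get_single_level_autograd_function_allowed()
assert not get_single_level_autograd_function_allowed()

# The raw setter is the debugging knob the commit message refers to;
# restore the previous value when flipping it by hand.
prev = get_single_level_autograd_function_allowed()
try:
    set_single_level_autograd_function_allowed(True)
finally:
    set_single_level_autograd_function_allowed(prev)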