Rename flag that enables/disables _SingleLevelFunction for functorch (#92025)

functorch used to have a switch that enabled/disabled autograd.Function.
That switch now enables/disables torch.autograd.function._SingleLevelFunction, so
I've renamed it accordingly.

We could just delete the switch, because users should not be working
directly with torch.autograd.function._SingleLevelFunction. However,
it was useful for debugging while I was implementing the
autograd.Function <> functorch interaction, so I want to keep it around
as a debugging tool for a while, since the code is already there.
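
For reference, a minimal sketch of how the renamed debug switch is toggled from
Python (assuming a build that includes this rename; these torch._C._functorch
bindings are internal and not a user-facing API), mirroring the
torch._functorch.utils helper in the diff below:

import torch
from torch._C._functorch import (
    get_single_level_autograd_function_allowed,
    set_single_level_autograd_function_allowed,
)

# The switch is off by default while a functorch transform is active.
prev = get_single_level_autograd_function_allowed()
try:
    set_single_level_autograd_function_allowed(True)
    # ... run code that exercises torch.autograd.function._SingleLevelFunction
    # under a functorch transform here ...
finally:
    # Restore the previous state so later transforms/tests are unaffected.
    set_single_level_autograd_function_allowed(prev)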

Test Plan:
- updated tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92025
Approved by: https://github.com/soulitzer
Author: Richard Zou, 2023-01-16 19:04:47 -08:00 (committed by PyTorch MergeBot)
parent 14ff58d4fa
commit 7aaad0b832
9 changed files with 42 additions and 66 deletions


@@ -28,10 +28,7 @@ struct TORCH_API FuncTorchTLSBase {
   virtual ~FuncTorchTLSBase() = default;
   virtual std::unique_ptr<FuncTorchTLSBase> deepcopy() const = 0;

-  // functorch doesn't always work with autograd.Function.
-  // This is a hook to get into functorch -- functorch will determine
-  // if it should raise an error message
-  virtual int64_t checkSupportsAutogradFunction() const = 0;
+  virtual int64_t checkSupportsSingleLevelAutogradFunction() const = 0;
   virtual void checkSupportsInplaceRequiresGrad() const = 0;
   virtual void checkSupportsRetainGrad() const = 0;
 };


@@ -91,10 +91,11 @@ class FuncTorchTLS : public FuncTorchTLSBase {
     return result;
   }

-  int64_t checkSupportsAutogradFunction() const override {
-    TORCH_CHECK(dynamicLayerStack.size() == 0 || getAutogradFunctionAllowed(),
-        "functorch functions (vmap, grad, vjp, etc.) currently do not support the use of autograd.Function. ",
-        "Please rewrite your function to not use autograd.Function while we work on fixing this");
+  int64_t checkSupportsSingleLevelAutogradFunction() const override {
+    TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() == 0 || getSingleLevelAutogradFunctionAllowed(),
+        "functorch functions (vmap, grad, vjp, etc.) incorrectly used with ",
+        "torch.autograd.function._SingleLevelFunction. ",
+        "This is not expected, please file a bug.");
     return 0;
   }

@@ -119,7 +120,7 @@ class FuncTorchTLS : public FuncTorchTLSBase {
   std::vector<DynamicLayer> dynamicLayerStack;
   bool allow_inplace_requires_grad_ = false;
-  bool allow_autograd_function_ = false;
+  bool allow_single_level_autograd_function_ = false;
 };

 static FuncTorchTLS* getRawFunctorchTLS() {
@@ -143,14 +144,14 @@ bool getInplaceRequiresGradAllowed() {
   return functorch_tls->allow_inplace_requires_grad_;
 }

-void setAutogradFunctionAllowed(bool allowed) {
+void setSingleLevelAutogradFunctionAllowed(bool allowed) {
   auto* functorch_tls = getRawFunctorchTLS();
-  functorch_tls->allow_autograd_function_ = allowed;
+  functorch_tls->allow_single_level_autograd_function_ = allowed;
 }

-bool getAutogradFunctionAllowed() {
+bool getSingleLevelAutogradFunctionAllowed() {
   auto* functorch_tls = getRawFunctorchTLS();
-  return functorch_tls->allow_autograd_function_;
+  return functorch_tls->allow_single_level_autograd_function_;
 }

 static std::vector<DynamicLayer>& dynamicLayerStackAccessor() {


@@ -108,11 +108,14 @@ TORCH_API bool isDeadTensorWrapper(const Tensor& tensor);
 TORCH_API std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer);
 TORCH_API std::ostream& operator<<(std::ostream& os, const std::vector<DynamicLayer>& dynamicLayerStack);

-// While a functorch transform is active, autograd.Function is disabled
-// by default. The following two APIs are APIs for enabling
-// autograd.Function. These are not user-facing APIs.
-TORCH_API void setAutogradFunctionAllowed(bool allowed);
-TORCH_API bool getAutogradFunctionAllowed();
+// While a functorch transform is active, torch.autograd.function._SingleLevelFunction
+// is disabled by default. The following two APIs are APIs for enabling
+// it. These are not user-facing APIs. We can delete this in the future, but
+// it is useful for debugging when something goes wrong with the
+// autograd.Function <> functorch interaction, which uses _SingleLevelFunction,
+// because it leads to loud errors if something is incorrect.
+TORCH_API void setSingleLevelAutogradFunctionAllowed(bool allowed);
+TORCH_API bool getSingleLevelAutogradFunctionAllowed();

 // While a functorch grad transform is active, Tensor.requires_grad_() gets
 // disabled. These two functions are the mechanism to controlling that.


@@ -38,7 +38,7 @@ from torch._functorch.make_functional import (
 from torch._functorch.eager_transforms import _slice_argnums
 from functorch.experimental import functionalize
 from torch._ops import PyOperator
-from torch._functorch.utils import enable_autograd_function
+from torch._functorch.utils import enable_single_level_autograd_function
 from torch.autograd.function import _set_autograd_function_extension_enabled
 import torch.autograd.forward_ad as fwAD
 from torch.func import functional_call, stack_module_state
@@ -3035,36 +3035,6 @@ class TestComposability(TestCase):
         with self.assertRaises(RuntimeError):
             grad(f)(x)

-    def test_autograd_function_debug_switch(self, device):
-        class MySin(torch.autograd.Function):
-            @staticmethod
-            def forward(ctx, x):
-                ctx.save_for_backward(x)
-                return x.sin()
-
-            @staticmethod
-            def backward(ctx, gy):
-                x, = ctx.saved_tensors
-                return gy * x.cos()
-
-        x = torch.randn([])
-        with torch.autograd.function._set_autograd_function_extension_enabled(False):
-            # by default, autograd.Function is disabled in a functorch transform
-            with self.assertRaisesRegex(RuntimeError, "autograd.Function"):
-                grad(MySin.apply)(x)
-
-            # we have a debug switch to allow it
-            self.assertFalse(torch._C._functorch.get_autograd_function_allowed())
-            try:
-                torch._C._functorch.set_autograd_function_allowed(True)
-                self.assertTrue(torch._C._functorch.get_autograd_function_allowed())
-                y = grad(MySin.apply)(x)
-            finally:
-                torch._C._functorch.set_autograd_function_allowed(False)
-            self.assertFalse(torch._C._functorch.get_autograd_function_allowed())
-            self.assertEqual(y, x.cos())
-
     @_set_autograd_function_extension_enabled()
     @parametrize('transform', [
         'vmap', 'grad', 'jacrev', 'jacfwd', 'grad_and_value', 'hessian', 'functionalize'
@@ -4346,7 +4316,7 @@ def construct_sum_pyop():
             def backward(ctx, gy):
                 return gy.unsqueeze(ctx.dim).expand(ctx.x_shape), None

-        with enable_autograd_function():
+        with enable_single_level_autograd_function():
             return MySum.apply(x, dim)

     @mysum.py_impl(torch._C.DispatchKey.AutogradCPU)


@@ -19,8 +19,8 @@ def _unwrap_batched(tensor: Tensor, level: int) -> Tuple[Tensor, Optional[int]]:
 def current_level() -> int: ...
 def _add_batch_dim(tensor: Tensor, bdim: int, level: int) -> Tensor: ...
-def set_autograd_function_allowed(allowed: bool) -> None: ...
-def get_autograd_function_allowed() -> bool: ...
+def set_single_level_autograd_function_allowed(allowed: bool) -> None: ...
+def get_single_level_autograd_function_allowed() -> bool: ...

 # Defined in aten/src/ATen/functorch/Interpreter.h
 class TransformType(Enum):


@@ -1,7 +1,7 @@
 import torch
 from torch._ops import PyOperator
 from torch._C._functorch import TransformType
-from torch._functorch.utils import enable_autograd_function
+from torch._functorch.utils import enable_single_level_autograd_function
 import torch.utils._pytree as pytree
 from torch._C._functorch import (
     _wrap_for_grad,
@@ -86,7 +86,7 @@ custom_function_call = CustomFunctionPyOperator()
 @custom_function_call.py_impl(TransformType.Jvp)
 def custom_function_call_grad(interpreter, autograd_function, *operands):
     Generated = generate_single_level_function(interpreter, autograd_function)
-    with enable_autograd_function():
+    with enable_single_level_autograd_function():
         flat_out = Generated.apply(*operands)
     return flat_out


@@ -1,19 +1,19 @@
 import contextlib
 import torch
 from torch._C._functorch import (
-    set_autograd_function_allowed,
-    get_autograd_function_allowed,
+    set_single_level_autograd_function_allowed,
+    get_single_level_autograd_function_allowed,
     unwrap_if_dead,
 )

 @contextlib.contextmanager
-def enable_autograd_function():
+def enable_single_level_autograd_function():
     try:
-        prev_state = get_autograd_function_allowed()
-        set_autograd_function_allowed(True)
+        prev_state = get_single_level_autograd_function_allowed()
+        set_single_level_autograd_function_allowed(True)
         yield
     finally:
-        set_autograd_function_allowed(prev_state)
+        set_single_level_autograd_function_allowed(prev_state)

 def unwrap_dead_wrappers(args):
     # NB: doesn't use tree_map_only for performance reasons
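
For completeness, a short usage sketch of the renamed context manager from the hunk
above; apply_single_level and Generated are hypothetical stand-ins for the real call
site shown in the custom_function_call hunk, assuming a build that includes this rename:

from torch._functorch.utils import enable_single_level_autograd_function

def apply_single_level(Generated, *operands):
    # Temporarily allow torch.autograd.function._SingleLevelFunction while a
    # functorch transform is active; the previous state is restored on exit.
    with enable_single_level_autograd_function():
        return Generated.apply(*operands)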


@@ -877,10 +877,15 @@ PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) {
       unpacked_input.input_vars.begin(), unpacked_input.input_vars.end()),
       seq_id);

-  // Temporary hack to improve functorch UX. We'll find a better solution.
   const auto& functorch_tls = at::functorch::functorchTLSAccessor();
   if (functorch_tls) {
-    functorch_tls->checkSupportsAutogradFunction();
+    // autograd.Function support for functorch is handled in Python.
+    // If we have gotten here, then either we are dealing with a
+    // torch.autograd.function._SingleLevelFunction, or something in
+    // the implementation went wrong.
+    // The following code is useful for debugging when something goes wrong
+    // because it'll raise a loud error (instead of being silently incorrect).
+    functorch_tls->checkSupportsSingleLevelAutogradFunction();
   }

   THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls"));


@@ -447,11 +447,11 @@ void initFuncTorchBindings(PyObject* module) {
       "get_inplace_requires_grad_allowed",
       &at::functorch::getInplaceRequiresGradAllowed);
   m.def(
-      "set_autograd_function_allowed",
-      &at::functorch::setAutogradFunctionAllowed);
+      "set_single_level_autograd_function_allowed",
+      &at::functorch::setSingleLevelAutogradFunctionAllowed);
   m.def(
-      "get_autograd_function_allowed",
-      &at::functorch::getAutogradFunctionAllowed);
+      "get_single_level_autograd_function_allowed",
+      &at::functorch::getSingleLevelAutogradFunctionAllowed);
   m.def("unwrap_if_dead", &unwrapIfDead);
   m.def("is_dead_tensor_wrapper", &isDeadTensorWrapper);
   m.def("dlevel", &dlevel, "dlevel");