Rename flag that enables/disables _SingleLevelFunction for functorch (#92025)

functorch used to have a switch that enabled/disabled autograd.Function.
That switch now enables/disables torch.autograd.function._SingleLevelFunction, so
I've renamed it accordingly.

We could just delete the switch, since users should not be working directly
with torch.autograd.function._SingleLevelFunction. However, it was useful for
debugging when something went wrong while I was implementing the
autograd.Function <> functorch interaction, so I want to keep it around as a
debugging tool for a while; the code is already there anyway.
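
For reference, a rough sketch of how the renamed switch is exercised (all names
are taken from this diff; this is a debugging aid, not a user-facing API, and
the asserts assume nothing has already enabled the flag):

```python
# Rough sketch (debugging aid only; not a user-facing API).
from torch._C._functorch import (
    get_single_level_autograd_function_allowed,
    set_single_level_autograd_function_allowed,
)
from torch._functorch.utils import enable_single_level_autograd_function

# The TLS flag defaults to False (allow_single_level_autograd_function_).
assert not get_single_level_autograd_function_allowed()

# The context manager saves the previous state, enables the flag, and
# restores it on exit.
with enable_single_level_autograd_function():
    assert get_single_level_autograd_function_allowed()
assert not get_single_level_autograd_function_allowed()

# Equivalent manual toggling via the raw bindings.
prev = get_single_level_autograd_function_allowed()
try:
    set_single_level_autograd_function_allowed(True)
finally:
    set_single_level_autograd_function_allowed(prev)
```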

Test Plan:
- updated tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92025
Approved by: https://github.com/soulitzer
Author: Richard Zou
Date: 2023-01-16 19:04:47 -08:00 (committed by PyTorch MergeBot)
Parent: 14ff58d4fa
Commit: 7aaad0b832
9 changed files with 42 additions and 66 deletions

View File

@@ -28,10 +28,7 @@ struct TORCH_API FuncTorchTLSBase {
   virtual ~FuncTorchTLSBase() = default;
   virtual std::unique_ptr<FuncTorchTLSBase> deepcopy() const = 0;
-  // functorch doesn't always work with autograd.Function.
-  // This is a hook to get into functorch -- functorch will determine
-  // if it should raise an error message
-  virtual int64_t checkSupportsAutogradFunction() const = 0;
+  virtual int64_t checkSupportsSingleLevelAutogradFunction() const = 0;
   virtual void checkSupportsInplaceRequiresGrad() const = 0;
   virtual void checkSupportsRetainGrad() const = 0;
 };

View File

@@ -91,10 +91,11 @@ class FuncTorchTLS : public FuncTorchTLSBase {
     return result;
   }
-  int64_t checkSupportsAutogradFunction() const override {
-    TORCH_CHECK(dynamicLayerStack.size() == 0 || getAutogradFunctionAllowed(),
-        "functorch functions (vmap, grad, vjp, etc.) currently do not support the use of autograd.Function. ",
-        "Please rewrite your function to not use autograd.Function while we work on fixing this");
+  int64_t checkSupportsSingleLevelAutogradFunction() const override {
+    TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() == 0 || getSingleLevelAutogradFunctionAllowed(),
+        "functorch functions (vmap, grad, vjp, etc.) incorrectly used with ",
+        "torch.autograd.function._SingleLevelFunction. ",
+        "This is not expected, please file a bug.");
     return 0;
   }
@@ -119,7 +120,7 @@ class FuncTorchTLS : public FuncTorchTLSBase {
   std::vector<DynamicLayer> dynamicLayerStack;
   bool allow_inplace_requires_grad_ = false;
-  bool allow_autograd_function_ = false;
+  bool allow_single_level_autograd_function_ = false;
 };
 static FuncTorchTLS* getRawFunctorchTLS() {
@@ -143,14 +144,14 @@ bool getInplaceRequiresGradAllowed() {
   return functorch_tls->allow_inplace_requires_grad_;
 }
-void setAutogradFunctionAllowed(bool allowed) {
+void setSingleLevelAutogradFunctionAllowed(bool allowed) {
   auto* functorch_tls = getRawFunctorchTLS();
-  functorch_tls->allow_autograd_function_ = allowed;
+  functorch_tls->allow_single_level_autograd_function_ = allowed;
 }
-bool getAutogradFunctionAllowed() {
+bool getSingleLevelAutogradFunctionAllowed() {
   auto* functorch_tls = getRawFunctorchTLS();
-  return functorch_tls->allow_autograd_function_;
+  return functorch_tls->allow_single_level_autograd_function_;
 }
 static std::vector<DynamicLayer>& dynamicLayerStackAccessor() {

View File

@@ -108,11 +108,14 @@ TORCH_API bool isDeadTensorWrapper(const Tensor& tensor);
 TORCH_API std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer);
 TORCH_API std::ostream& operator<<(std::ostream& os, const std::vector<DynamicLayer>& dynamicLayerStack);
-// While a functorch transform is active, autograd.Function is disabled
-// by default. The following two APIs are APIs for enabling
-// autograd.Function. These are not user-facing APIs.
-TORCH_API void setAutogradFunctionAllowed(bool allowed);
-TORCH_API bool getAutogradFunctionAllowed();
+// While a functorch transform is active, torch.autograd.function._SingleLevelFunction
+// is disabled by default. The following two APIs are APIs for enabling
+// it. These are not user-facing APIs. We can delete this in the future, but
+// it is useful for debugging when something goes wrong with the
+// autograd.Function <> functorch interaction, which uses _SingleLevelFunction,
+// because it leads to loud errors if something is incorrect.
+TORCH_API void setSingleLevelAutogradFunctionAllowed(bool allowed);
+TORCH_API bool getSingleLevelAutogradFunctionAllowed();
 // While a functorch grad transform is active, Tensor.requires_grad_() gets
 // disabled. These two functions are the mechanism to controlling that.

View File

@@ -38,7 +38,7 @@ from torch._functorch.make_functional import (
 from torch._functorch.eager_transforms import _slice_argnums
 from functorch.experimental import functionalize
 from torch._ops import PyOperator
-from torch._functorch.utils import enable_autograd_function
+from torch._functorch.utils import enable_single_level_autograd_function
 from torch.autograd.function import _set_autograd_function_extension_enabled
 import torch.autograd.forward_ad as fwAD
 from torch.func import functional_call, stack_module_state
@@ -3035,36 +3035,6 @@ class TestComposability(TestCase):
         with self.assertRaises(RuntimeError):
             grad(f)(x)
-    def test_autograd_function_debug_switch(self, device):
-        class MySin(torch.autograd.Function):
-            @staticmethod
-            def forward(ctx, x):
-                ctx.save_for_backward(x)
-                return x.sin()
-            @staticmethod
-            def backward(ctx, gy):
-                x, = ctx.saved_tensors
-                return gy * x.cos()
-        x = torch.randn([])
-        with torch.autograd.function._set_autograd_function_extension_enabled(False):
-            # by default, autograd.Function is disabled in a functorch transform
-            with self.assertRaisesRegex(RuntimeError, "autograd.Function"):
-                grad(MySin.apply)(x)
-            # we have a debug switch to allow it
-            self.assertFalse(torch._C._functorch.get_autograd_function_allowed())
-            try:
-                torch._C._functorch.set_autograd_function_allowed(True)
-                self.assertTrue(torch._C._functorch.get_autograd_function_allowed())
-                y = grad(MySin.apply)(x)
-            finally:
-                torch._C._functorch.set_autograd_function_allowed(False)
-            self.assertFalse(torch._C._functorch.get_autograd_function_allowed())
-            self.assertEqual(y, x.cos())
     @_set_autograd_function_extension_enabled()
     @parametrize('transform', [
         'vmap', 'grad', 'jacrev', 'jacfwd', 'grad_and_value', 'hessian', 'functionalize'
@@ -4346,7 +4316,7 @@ def construct_sum_pyop():
            def backward(ctx, gy):
                return gy.unsqueeze(ctx.dim).expand(ctx.x_shape), None
-        with enable_autograd_function():
+        with enable_single_level_autograd_function():
            return MySum.apply(x, dim)
    @mysum.py_impl(torch._C.DispatchKey.AutogradCPU)

View File

@@ -19,8 +19,8 @@ def _unwrap_batched(tensor: Tensor, level: int) -> Tuple[Tensor, Optional[int]]:
 def current_level() -> int: ...
 def _add_batch_dim(tensor: Tensor, bdim: int, level: int) -> Tensor: ...
-def set_autograd_function_allowed(allowed: bool) -> None: ...
-def get_autograd_function_allowed() -> bool: ...
+def set_single_level_autograd_function_allowed(allowed: bool) -> None: ...
+def get_single_level_autograd_function_allowed() -> bool: ...
 # Defined in aten/src/ATen/functorch/Interpreter.h
 class TransformType(Enum):

View File

@@ -1,7 +1,7 @@
 import torch
 from torch._ops import PyOperator
 from torch._C._functorch import TransformType
-from torch._functorch.utils import enable_autograd_function
+from torch._functorch.utils import enable_single_level_autograd_function
 import torch.utils._pytree as pytree
 from torch._C._functorch import (
     _wrap_for_grad,
@@ -86,7 +86,7 @@ custom_function_call = CustomFunctionPyOperator()
 @custom_function_call.py_impl(TransformType.Jvp)
 def custom_function_call_grad(interpreter, autograd_function, *operands):
     Generated = generate_single_level_function(interpreter, autograd_function)
-    with enable_autograd_function():
+    with enable_single_level_autograd_function():
         flat_out = Generated.apply(*operands)
     return flat_out

View File

@@ -1,19 +1,19 @@
 import contextlib
 import torch
 from torch._C._functorch import (
-    set_autograd_function_allowed,
-    get_autograd_function_allowed,
+    set_single_level_autograd_function_allowed,
+    get_single_level_autograd_function_allowed,
     unwrap_if_dead,
 )
 @contextlib.contextmanager
-def enable_autograd_function():
+def enable_single_level_autograd_function():
     try:
-        prev_state = get_autograd_function_allowed()
-        set_autograd_function_allowed(True)
+        prev_state = get_single_level_autograd_function_allowed()
+        set_single_level_autograd_function_allowed(True)
         yield
     finally:
-        set_autograd_function_allowed(prev_state)
+        set_single_level_autograd_function_allowed(prev_state)
 def unwrap_dead_wrappers(args):
     # NB: doesn't use tree_map_only for performance reasons

View File

@@ -877,10 +877,15 @@ PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) {
           unpacked_input.input_vars.begin(), unpacked_input.input_vars.end()),
       seq_id);
-  // Temporary hack to improve functorch UX. We'll find a better solution.
   const auto& functorch_tls = at::functorch::functorchTLSAccessor();
   if (functorch_tls) {
-    functorch_tls->checkSupportsAutogradFunction();
+    // autograd.Function support for functorch is handled in Python.
+    // If we have gotten here, then either we are dealing with a
+    // torch.autograd.function._SingleLevelFunction, or something in
+    // the implementation went wrong.
+    // The following code is useful for debugging when something goes wrong
+    // because it'll raise a loud error (instead of being silently incorrect).
+    functorch_tls->checkSupportsSingleLevelAutogradFunction();
   }
   THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls"));

View File

@@ -447,11 +447,11 @@ void initFuncTorchBindings(PyObject* module) {
       "get_inplace_requires_grad_allowed",
       &at::functorch::getInplaceRequiresGradAllowed);
   m.def(
-      "set_autograd_function_allowed",
-      &at::functorch::setAutogradFunctionAllowed);
+      "set_single_level_autograd_function_allowed",
+      &at::functorch::setSingleLevelAutogradFunctionAllowed);
   m.def(
-      "get_autograd_function_allowed",
-      &at::functorch::getAutogradFunctionAllowed);
+      "get_single_level_autograd_function_allowed",
+      &at::functorch::getSingleLevelAutogradFunctionAllowed);
   m.def("unwrap_if_dead", &unwrapIfDead);
   m.def("is_dead_tensor_wrapper", &isDeadTensorWrapper);
   m.def("dlevel", &dlevel, "dlevel");