Rename flag that enables/disables _SingleLevelFunction for functorch (#92025)
functorch used to have a switch that enables/disables autograd.Function. That switch now enables/disables torch.autograd.function._SingleLevelFunction, so I've renamed it accordingly.

We could just delete the switch because users should not be directly working with torch.autograd.function._SingleLevelFunction. However, it was useful for debugging when something went wrong while I was implementing the autograd.Function <> functorch interaction, so I want to keep it around as a debugging tool for a while, since the code is already there.

Test Plan:
- updated tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92025
Approved by: https://github.com/soulitzer
This commit is contained in:
parent 14ff58d4fa
commit 7aaad0b832
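For illustration, a minimal sketch of how the renamed debug switch can be flipped from Python after this change. The binding names come from the _functorch stub and initFuncTorchBindings hunks below; the try/finally pattern mirrors the enable_single_level_autograd_function helper in this diff and is not a user-facing API.

from torch._C._functorch import (
    get_single_level_autograd_function_allowed,
    set_single_level_autograd_function_allowed,
)

# Debug-only pattern (not a user-facing API): temporarily allow
# _SingleLevelFunction while a functorch transform is active, then restore
# the previous thread-local state.
prev = get_single_level_autograd_function_allowed()
try:
    set_single_level_autograd_function_allowed(True)
    # ... run the transform being debugged here ...
finally:
    set_single_level_autograd_function_allowed(prev)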
@@ -28,10 +28,7 @@ struct TORCH_API FuncTorchTLSBase {
   virtual ~FuncTorchTLSBase() = default;
   virtual std::unique_ptr<FuncTorchTLSBase> deepcopy() const = 0;
 
-  // functorch doesn't always work with autograd.Function.
-  // This is a hook to get into functorch -- functorch will determine
-  // if it should raise an error message
-  virtual int64_t checkSupportsAutogradFunction() const = 0;
+  virtual int64_t checkSupportsSingleLevelAutogradFunction() const = 0;
   virtual void checkSupportsInplaceRequiresGrad() const = 0;
   virtual void checkSupportsRetainGrad() const = 0;
 };
@@ -91,10 +91,11 @@ class FuncTorchTLS : public FuncTorchTLSBase {
     return result;
   }
 
-  int64_t checkSupportsAutogradFunction() const override {
-    TORCH_CHECK(dynamicLayerStack.size() == 0 || getAutogradFunctionAllowed(),
-        "functorch functions (vmap, grad, vjp, etc.) currently do not support the use of autograd.Function. ",
-        "Please rewrite your function to not use autograd.Function while we work on fixing this");
+  int64_t checkSupportsSingleLevelAutogradFunction() const override {
+    TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() == 0 || getSingleLevelAutogradFunctionAllowed(),
+        "functorch functions (vmap, grad, vjp, etc.) incorrectly used with ",
+        "torch.autograd.function._SingleLevelFunction. ",
+        "This is not expected, please file a bug.");
     return 0;
   }
 
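Note the change in failure mode above: the old TORCH_CHECK raised a user-facing error, while the new TORCH_INTERNAL_ASSERT treats reaching this code without the switch as a bug in the autograd.Function <> functorch integration. A rough Python paraphrase of the renamed check, using hypothetical stand-ins for the C++ thread-local state:

# dynamic_layer_stack and single_level_allowed are hypothetical stand-ins for
# FuncTorchTLS::dynamicLayerStack and allow_single_level_autograd_function_.
dynamic_layer_stack: list = []
single_level_allowed: bool = False

def check_supports_single_level_autograd_function() -> int:
    # Outside a transform (empty stack) this is a no-op; inside a transform it
    # is an internal assertion rather than a user-facing error.
    assert len(dynamic_layer_stack) == 0 or single_level_allowed, (
        "functorch functions (vmap, grad, vjp, etc.) incorrectly used with "
        "torch.autograd.function._SingleLevelFunction. "
        "This is not expected, please file a bug."
    )
    return 0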
@@ -119,7 +120,7 @@ class FuncTorchTLS : public FuncTorchTLSBase {
 
   std::vector<DynamicLayer> dynamicLayerStack;
   bool allow_inplace_requires_grad_ = false;
-  bool allow_autograd_function_ = false;
+  bool allow_single_level_autograd_function_ = false;
 };
 
 static FuncTorchTLS* getRawFunctorchTLS() {
@@ -143,14 +144,14 @@ bool getInplaceRequiresGradAllowed() {
   return functorch_tls->allow_inplace_requires_grad_;
 }
 
-void setAutogradFunctionAllowed(bool allowed) {
+void setSingleLevelAutogradFunctionAllowed(bool allowed) {
   auto* functorch_tls = getRawFunctorchTLS();
-  functorch_tls->allow_autograd_function_ = allowed;
+  functorch_tls->allow_single_level_autograd_function_ = allowed;
 }
 
-bool getAutogradFunctionAllowed() {
+bool getSingleLevelAutogradFunctionAllowed() {
   auto* functorch_tls = getRawFunctorchTLS();
-  return functorch_tls->allow_autograd_function_;
+  return functorch_tls->allow_single_level_autograd_function_;
 }
 
 static std::vector<DynamicLayer>& dynamicLayerStackAccessor() {
@@ -108,11 +108,14 @@ TORCH_API bool isDeadTensorWrapper(const Tensor& tensor);
 TORCH_API std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer);
 TORCH_API std::ostream& operator<<(std::ostream& os, const std::vector<DynamicLayer>& dynamicLayerStack);
 
-// While a functorch transform is active, autograd.Function is disabled
-// by default. The following two APIs are APIs for enabling
-// autograd.Function. These are not user-facing APIs.
-TORCH_API void setAutogradFunctionAllowed(bool allowed);
-TORCH_API bool getAutogradFunctionAllowed();
+// While a functorch transform is active, torch.autograd.function._SingleLevelFunction
+// is disabled by default. The following two APIs are APIs for enabling
+// it. These are not user-facing APIs. We can delete this in the future, but
+// it is useful for debugging when something goes wrong with the
+// autograd.Function <> functorch interaction, which uses _SingleLevelFunction,
+// because it leads to loud errors if something is incorrect.
+TORCH_API void setSingleLevelAutogradFunctionAllowed(bool allowed);
+TORCH_API bool getSingleLevelAutogradFunctionAllowed();
 
 // While a functorch grad transform is active, Tensor.requires_grad_() gets
 // disabled. These two functions are the mechanism to controlling that.
@@ -38,7 +38,7 @@ from torch._functorch.make_functional import (
 from torch._functorch.eager_transforms import _slice_argnums
 from functorch.experimental import functionalize
 from torch._ops import PyOperator
-from torch._functorch.utils import enable_autograd_function
+from torch._functorch.utils import enable_single_level_autograd_function
 from torch.autograd.function import _set_autograd_function_extension_enabled
 import torch.autograd.forward_ad as fwAD
 from torch.func import functional_call, stack_module_state
@@ -3035,36 +3035,6 @@ class TestComposability(TestCase):
         with self.assertRaises(RuntimeError):
             grad(f)(x)
 
-    def test_autograd_function_debug_switch(self, device):
-        class MySin(torch.autograd.Function):
-            @staticmethod
-            def forward(ctx, x):
-                ctx.save_for_backward(x)
-                return x.sin()
-
-            @staticmethod
-            def backward(ctx, gy):
-                x, = ctx.saved_tensors
-                return gy * x.cos()
-
-        x = torch.randn([])
-
-        with torch.autograd.function._set_autograd_function_extension_enabled(False):
-            # by default, autograd.Function is disabled in a functorch transform
-            with self.assertRaisesRegex(RuntimeError, "autograd.Function"):
-                grad(MySin.apply)(x)
-
-            # we have a debug switch to allow it
-            self.assertFalse(torch._C._functorch.get_autograd_function_allowed())
-            try:
-                torch._C._functorch.set_autograd_function_allowed(True)
-                self.assertTrue(torch._C._functorch.get_autograd_function_allowed())
-                y = grad(MySin.apply)(x)
-            finally:
-                torch._C._functorch.set_autograd_function_allowed(False)
-            self.assertFalse(torch._C._functorch.get_autograd_function_allowed())
-            self.assertEqual(y, x.cos())
-
     @_set_autograd_function_extension_enabled()
     @parametrize('transform', [
         'vmap', 'grad', 'jacrev', 'jacfwd', 'grad_and_value', 'hessian', 'functionalize'
@@ -4346,7 +4316,7 @@ def construct_sum_pyop():
             def backward(ctx, gy):
                 return gy.unsqueeze(ctx.dim).expand(ctx.x_shape), None
 
-        with enable_autograd_function():
+        with enable_single_level_autograd_function():
             return MySum.apply(x, dim)
 
     @mysum.py_impl(torch._C.DispatchKey.AutogradCPU)
@@ -19,8 +19,8 @@ def _unwrap_batched(tensor: Tensor, level: int) -> Tuple[Tensor, Optional[int]]:
 def current_level() -> int: ...
 def _add_batch_dim(tensor: Tensor, bdim: int, level: int) -> Tensor: ...
 
-def set_autograd_function_allowed(allowed: bool) -> None: ...
-def get_autograd_function_allowed() -> bool: ...
+def set_single_level_autograd_function_allowed(allowed: bool) -> None: ...
+def get_single_level_autograd_function_allowed() -> bool: ...
 
 # Defined in aten/src/ATen/functorch/Interpreter.h
 class TransformType(Enum):
@@ -1,7 +1,7 @@
 import torch
 from torch._ops import PyOperator
 from torch._C._functorch import TransformType
-from torch._functorch.utils import enable_autograd_function
+from torch._functorch.utils import enable_single_level_autograd_function
 import torch.utils._pytree as pytree
 from torch._C._functorch import (
     _wrap_for_grad,
@@ -86,7 +86,7 @@ custom_function_call = CustomFunctionPyOperator()
 @custom_function_call.py_impl(TransformType.Jvp)
 def custom_function_call_grad(interpreter, autograd_function, *operands):
     Generated = generate_single_level_function(interpreter, autograd_function)
-    with enable_autograd_function():
+    with enable_single_level_autograd_function():
         flat_out = Generated.apply(*operands)
     return flat_out
 
@@ -1,19 +1,19 @@
 import contextlib
 import torch
 from torch._C._functorch import (
-    set_autograd_function_allowed,
-    get_autograd_function_allowed,
+    set_single_level_autograd_function_allowed,
+    get_single_level_autograd_function_allowed,
     unwrap_if_dead,
 )
 
 @contextlib.contextmanager
-def enable_autograd_function():
+def enable_single_level_autograd_function():
     try:
-        prev_state = get_autograd_function_allowed()
-        set_autograd_function_allowed(True)
+        prev_state = get_single_level_autograd_function_allowed()
+        set_single_level_autograd_function_allowed(True)
         yield
     finally:
-        set_autograd_function_allowed(prev_state)
+        set_single_level_autograd_function_allowed(prev_state)
 
 def unwrap_dead_wrappers(args):
     # NB: doesn't use tree_map_only for performance reasons
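For reference, a minimal sketch of how the renamed context manager is used elsewhere in this diff (the custom_function_call grad rule and the mysum test operator follow the same pattern); GeneratedFunction is a hypothetical stand-in for the class produced by generate_single_level_function:

from torch._functorch.utils import enable_single_level_autograd_function

def apply_generated(GeneratedFunction, *operands):
    # Flip the debug switch only for the duration of the apply; the previous
    # thread-local state is restored when the block exits.
    with enable_single_level_autograd_function():
        return GeneratedFunction.apply(*operands)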
@@ -877,10 +877,15 @@ PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) {
           unpacked_input.input_vars.begin(), unpacked_input.input_vars.end()),
       seq_id);
 
-  // Temporary hack to improve functorch UX. We'll find a better solution.
   const auto& functorch_tls = at::functorch::functorchTLSAccessor();
   if (functorch_tls) {
-    functorch_tls->checkSupportsAutogradFunction();
+    // autograd.Function support for functorch is handled in Python.
+    // If we have gotten here, then either we are dealing with a
+    // torch.autograd.function._SingleLevelFunction, or something in
+    // the implementation went wrong.
+    // The following code is useful for debugging when something goes wrong
+    // because it'll raise a loud error (instead of being silently incorrect).
+    functorch_tls->checkSupportsSingleLevelAutogradFunction();
   }
 
   THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls"));
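As the new comment in THPFunction_apply notes, the only autograd.Function subclasses expected to hit this C++ entry point while a functorch transform is active are torch.autograd.function._SingleLevelFunction subclasses, which functorch generates internally. A hypothetical sketch of the shape of such a class, assuming the usual forward/backward staticmethod protocol; it is for illustration only, since users are not expected to subclass _SingleLevelFunction directly:

import torch
from torch.autograd.function import _SingleLevelFunction

# Hypothetical illustration only: functorch generates classes of roughly this
# shape internally; user code should subclass torch.autograd.Function instead.
class GeneratedSin(_SingleLevelFunction):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.sin()

    @staticmethod
    def backward(ctx, gy):
        x, = ctx.saved_tensors
        return gy * x.cos()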
@@ -447,11 +447,11 @@ void initFuncTorchBindings(PyObject* module) {
       "get_inplace_requires_grad_allowed",
       &at::functorch::getInplaceRequiresGradAllowed);
   m.def(
-      "set_autograd_function_allowed",
-      &at::functorch::setAutogradFunctionAllowed);
+      "set_single_level_autograd_function_allowed",
+      &at::functorch::setSingleLevelAutogradFunctionAllowed);
   m.def(
-      "get_autograd_function_allowed",
-      &at::functorch::getAutogradFunctionAllowed);
+      "get_single_level_autograd_function_allowed",
+      &at::functorch::getSingleLevelAutogradFunctionAllowed);
   m.def("unwrap_if_dead", &unwrapIfDead);
   m.def("is_dead_tensor_wrapper", &isDeadTensorWrapper);
   m.def("dlevel", &dlevel, "dlevel");