Rename flag that enables/disables _SingleLevelFunction for functorch (#92025)
functorch used to have a switch that enables/disables autograd.Function. That switch now enables/disables torch.autograd.function._SingleLevelFunction, so I've renamed it accordingly.

We could just delete the switch, because users should not be working directly with torch.autograd.function._SingleLevelFunction. However, it was useful for debugging when something went wrong while I was implementing the autograd.Function <> functorch interaction, so I want to keep it around as a debugging tool for a while since the code is already there.

Test Plan:
- updated tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92025
Approved by: https://github.com/soulitzer
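For reference, here is a minimal sketch of how the renamed debug switch can be toggled from Python. It uses only the internal bindings touched by this PR (torch._C._functorch.set_single_level_autograd_function_allowed / get_single_level_autograd_function_allowed); these are not user-facing APIs, and the try/finally pattern simply mirrors torch._functorch.utils.enable_single_level_autograd_function from the diff below.

```python
# Illustrative sketch only: internal debug switch, not a public API.
from torch._C._functorch import (
    get_single_level_autograd_function_allowed,
    set_single_level_autograd_function_allowed,
)

prev = get_single_level_autograd_function_allowed()
try:
    # Allow torch.autograd.function._SingleLevelFunction inside functorch
    # transforms while debugging the autograd.Function <> functorch interaction.
    set_single_level_autograd_function_allowed(True)
    # ... run the code under investigation here ...
finally:
    # Restore the previous state, as the context manager
    # enable_single_level_autograd_function() does.
    set_single_level_autograd_function_allowed(prev)
```

In most cases the context manager torch._functorch.utils.enable_single_level_autograd_function() (renamed in this PR, see the diff) is the more convenient way to flip the switch, since it restores the previous value automatically.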
parent 14ff58d4fa
commit 7aaad0b832
@@ -28,10 +28,7 @@ struct TORCH_API FuncTorchTLSBase {
   virtual ~FuncTorchTLSBase() = default;
   virtual std::unique_ptr<FuncTorchTLSBase> deepcopy() const = 0;
 
-  // functorch doesn't always work with autograd.Function.
-  // This is a hook to get into functorch -- functorch will determine
-  // if it should raise an error message
-  virtual int64_t checkSupportsAutogradFunction() const = 0;
+  virtual int64_t checkSupportsSingleLevelAutogradFunction() const = 0;
   virtual void checkSupportsInplaceRequiresGrad() const = 0;
   virtual void checkSupportsRetainGrad() const = 0;
 };

@@ -91,10 +91,11 @@ class FuncTorchTLS : public FuncTorchTLSBase {
     return result;
   }
 
-  int64_t checkSupportsAutogradFunction() const override {
-    TORCH_CHECK(dynamicLayerStack.size() == 0 || getAutogradFunctionAllowed(),
-        "functorch functions (vmap, grad, vjp, etc.) currently do not support the use of autograd.Function. ",
-        "Please rewrite your function to not use autograd.Function while we work on fixing this");
+  int64_t checkSupportsSingleLevelAutogradFunction() const override {
+    TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() == 0 || getSingleLevelAutogradFunctionAllowed(),
+        "functorch functions (vmap, grad, vjp, etc.) incorrectly used with ",
+        "torch.autograd.function._SingleLevelFunction. ",
+        "This is not expected, please file a bug.");
     return 0;
   }
 

@@ -119,7 +120,7 @@ class FuncTorchTLS : public FuncTorchTLSBase {
 
   std::vector<DynamicLayer> dynamicLayerStack;
   bool allow_inplace_requires_grad_ = false;
-  bool allow_autograd_function_ = false;
+  bool allow_single_level_autograd_function_ = false;
 };
 
 static FuncTorchTLS* getRawFunctorchTLS() {

@@ -143,14 +144,14 @@ bool getInplaceRequiresGradAllowed() {
   return functorch_tls->allow_inplace_requires_grad_;
 }
 
-void setAutogradFunctionAllowed(bool allowed) {
+void setSingleLevelAutogradFunctionAllowed(bool allowed) {
   auto* functorch_tls = getRawFunctorchTLS();
-  functorch_tls->allow_autograd_function_ = allowed;
+  functorch_tls->allow_single_level_autograd_function_ = allowed;
 }
 
-bool getAutogradFunctionAllowed() {
+bool getSingleLevelAutogradFunctionAllowed() {
   auto* functorch_tls = getRawFunctorchTLS();
-  return functorch_tls->allow_autograd_function_;
+  return functorch_tls->allow_single_level_autograd_function_;
 }
 
 static std::vector<DynamicLayer>& dynamicLayerStackAccessor() {

@@ -108,11 +108,14 @@ TORCH_API bool isDeadTensorWrapper(const Tensor& tensor);
 TORCH_API std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer);
 TORCH_API std::ostream& operator<<(std::ostream& os, const std::vector<DynamicLayer>& dynamicLayerStack);
 
-// While a functorch transform is active, autograd.Function is disabled
-// by default. The following two APIs are APIs for enabling
-// autograd.Function. These are not user-facing APIs.
-TORCH_API void setAutogradFunctionAllowed(bool allowed);
-TORCH_API bool getAutogradFunctionAllowed();
+// While a functorch transform is active, torch.autograd.function._SingleLevelFunction
+// is disabled by default. The following two APIs are APIs for enabling
+// it. These are not user-facing APIs. We can delete this in the future, but
+// it is useful for debugging when something goes wrong with the
+// autograd.Function <> functorch interaction, which uses _SingleLevelFunction,
+// because it leads to loud errors if something is incorrect.
+TORCH_API void setSingleLevelAutogradFunctionAllowed(bool allowed);
+TORCH_API bool getSingleLevelAutogradFunctionAllowed();
 
 // While a functorch grad transform is active, Tensor.requires_grad_() gets
 // disabled. These two functions are the mechanism to controlling that.

@@ -38,7 +38,7 @@ from torch._functorch.make_functional import (
 from torch._functorch.eager_transforms import _slice_argnums
 from functorch.experimental import functionalize
 from torch._ops import PyOperator
-from torch._functorch.utils import enable_autograd_function
+from torch._functorch.utils import enable_single_level_autograd_function
 from torch.autograd.function import _set_autograd_function_extension_enabled
 import torch.autograd.forward_ad as fwAD
 from torch.func import functional_call, stack_module_state

@@ -3035,36 +3035,6 @@ class TestComposability(TestCase):
         with self.assertRaises(RuntimeError):
             grad(f)(x)
 
-    def test_autograd_function_debug_switch(self, device):
-        class MySin(torch.autograd.Function):
-            @staticmethod
-            def forward(ctx, x):
-                ctx.save_for_backward(x)
-                return x.sin()
-
-            @staticmethod
-            def backward(ctx, gy):
-                x, = ctx.saved_tensors
-                return gy * x.cos()
-
-        x = torch.randn([])
-
-        with torch.autograd.function._set_autograd_function_extension_enabled(False):
-            # by default, autograd.Function is disabled in a functorch transform
-            with self.assertRaisesRegex(RuntimeError, "autograd.Function"):
-                grad(MySin.apply)(x)
-
-            # we have a debug switch to allow it
-            self.assertFalse(torch._C._functorch.get_autograd_function_allowed())
-            try:
-                torch._C._functorch.set_autograd_function_allowed(True)
-                self.assertTrue(torch._C._functorch.get_autograd_function_allowed())
-                y = grad(MySin.apply)(x)
-            finally:
-                torch._C._functorch.set_autograd_function_allowed(False)
-            self.assertFalse(torch._C._functorch.get_autograd_function_allowed())
-            self.assertEqual(y, x.cos())
-
     @_set_autograd_function_extension_enabled()
     @parametrize('transform', [
         'vmap', 'grad', 'jacrev', 'jacfwd', 'grad_and_value', 'hessian', 'functionalize'

@@ -4346,7 +4316,7 @@ def construct_sum_pyop():
             def backward(ctx, gy):
                 return gy.unsqueeze(ctx.dim).expand(ctx.x_shape), None
 
-        with enable_autograd_function():
+        with enable_single_level_autograd_function():
             return MySum.apply(x, dim)
 
     @mysum.py_impl(torch._C.DispatchKey.AutogradCPU)

@@ -19,8 +19,8 @@ def _unwrap_batched(tensor: Tensor, level: int) -> Tuple[Tensor, Optional[int]]:
 def current_level() -> int: ...
 def _add_batch_dim(tensor: Tensor, bdim: int, level: int) -> Tensor: ...
 
-def set_autograd_function_allowed(allowed: bool) -> None: ...
-def get_autograd_function_allowed() -> bool: ...
+def set_single_level_autograd_function_allowed(allowed: bool) -> None: ...
+def get_single_level_autograd_function_allowed() -> bool: ...
 
 # Defined in aten/src/ATen/functorch/Interpreter.h
 class TransformType(Enum):

@@ -1,7 +1,7 @@
 import torch
 from torch._ops import PyOperator
 from torch._C._functorch import TransformType
-from torch._functorch.utils import enable_autograd_function
+from torch._functorch.utils import enable_single_level_autograd_function
 import torch.utils._pytree as pytree
 from torch._C._functorch import (
     _wrap_for_grad,

@@ -86,7 +86,7 @@ custom_function_call = CustomFunctionPyOperator()
 @custom_function_call.py_impl(TransformType.Jvp)
 def custom_function_call_grad(interpreter, autograd_function, *operands):
     Generated = generate_single_level_function(interpreter, autograd_function)
-    with enable_autograd_function():
+    with enable_single_level_autograd_function():
         flat_out = Generated.apply(*operands)
     return flat_out
 

@@ -1,19 +1,19 @@
 import contextlib
 import torch
 from torch._C._functorch import (
-    set_autograd_function_allowed,
-    get_autograd_function_allowed,
+    set_single_level_autograd_function_allowed,
+    get_single_level_autograd_function_allowed,
     unwrap_if_dead,
 )
 
 @contextlib.contextmanager
-def enable_autograd_function():
+def enable_single_level_autograd_function():
     try:
-        prev_state = get_autograd_function_allowed()
-        set_autograd_function_allowed(True)
+        prev_state = get_single_level_autograd_function_allowed()
+        set_single_level_autograd_function_allowed(True)
         yield
     finally:
-        set_autograd_function_allowed(prev_state)
+        set_single_level_autograd_function_allowed(prev_state)
 
 def unwrap_dead_wrappers(args):
     # NB: doesn't use tree_map_only for performance reasons

@@ -877,10 +877,15 @@ PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) {
           unpacked_input.input_vars.begin(), unpacked_input.input_vars.end()),
       seq_id);
 
-  // Temporary hack to improve functorch UX. We'll find a better solution.
   const auto& functorch_tls = at::functorch::functorchTLSAccessor();
   if (functorch_tls) {
-    functorch_tls->checkSupportsAutogradFunction();
+    // autograd.Function support for functorch is handled in Python.
+    // If we have gotten here, then either we are dealing with a
+    // torch.autograd.function._SingleLevelFunction, or something in
+    // the implementation went wrong.
+    // The following code is useful for debugging when something goes wrong
+    // because it'll raise a loud error (instead of being silently incorrect).
+    functorch_tls->checkSupportsSingleLevelAutogradFunction();
   }
 
   THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls"));

@@ -447,11 +447,11 @@ void initFuncTorchBindings(PyObject* module) {
       "get_inplace_requires_grad_allowed",
      &at::functorch::getInplaceRequiresGradAllowed);
   m.def(
-      "set_autograd_function_allowed",
-      &at::functorch::setAutogradFunctionAllowed);
+      "set_single_level_autograd_function_allowed",
+      &at::functorch::setSingleLevelAutogradFunctionAllowed);
   m.def(
-      "get_autograd_function_allowed",
-      &at::functorch::getAutogradFunctionAllowed);
+      "get_single_level_autograd_function_allowed",
+      &at::functorch::getSingleLevelAutogradFunctionAllowed);
   m.def("unwrap_if_dead", &unwrapIfDead);
   m.def("is_dead_tensor_wrapper", &isDeadTensorWrapper);
   m.def("dlevel", &dlevel, "dlevel");