Add C API to return all torch function disablement status (#133136)
This PR adds a C function to check whether torch function is entirely disabled. Recall that there are three torch function enablement states:

* All disabled
* Torch Function Subclass disabled
* All enabled

Before this change, the API provides two functions:

* `_is_torch_function_enabled` - returns True iff the current state is All enabled
* `_is_torch_function_mode_enabled` - returns True iff the state is not All disabled and the torch function mode stack is non-empty

The crux of why a new API is needed is the following: if dynamo enters a frame with the torch function mode stack empty and `_is_torch_function_enabled` == False, it is impossible to determine whether a newly pushed mode should be entered, because we cannot tell whether the enablement state is All disabled or only Subclass disabled. Adding an API that checks whether the state is All disabled disambiguates this case.

In the next PR, Dynamo's InstructionTranslator will expose clearer flags than the underlying C API:

* a flag indicating whether subclasses are disabled (i.e. the current state is All disabled or Subclass disabled)
* a flag indicating whether modes are disabled (i.e. the current state is All disabled)
* a symbolic stack that can be checked for the presence of any modes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133136
Approved by: https://github.com/bdhirsh
ghstack dependencies: #133130, #133729, #133131, #133132, #133133, #133134
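To make the disambiguation concrete, here is a minimal Python sketch of how the full ternary state can be recovered once this API lands. The helper name `infer_torch_function_state` and the two flag variables are hypothetical illustrations, not part of this PR; only the `torch._C` functions and context managers appearing in the diff below are real.

    import torch

    def infer_torch_function_state() -> str:
        # True iff the state is All enabled.
        if torch._C._is_torch_function_enabled():
            return "ALL_ENABLED"
        # New in this PR: True iff the state is All disabled. Without it,
        # the two remaining states are indistinguishable whenever the
        # torch function mode stack is empty.
        if torch._C._is_torch_function_all_disabled():
            return "ALL_DISABLED"
        return "SUBCLASSES_DISABLED"

    # The two flags the follow-up PR gives Dynamo's InstructionTranslator
    # would derive from the same primitives (hypothetical names):
    subclasses_disabled = not torch._C._is_torch_function_enabled()  # All or Subclass disabled
    modes_disabled = torch._C._is_torch_function_all_disabled()      # All disabled only

    with torch._C.DisableTorchFunctionSubclass():
        assert infer_torch_function_state() == "SUBCLASSES_DISABLED"
    with torch._C.DisableTorchFunction():
        assert infer_torch_function_state() == "ALL_DISABLED"

Mirroring the test added below: under DisableTorchFunctionSubclass the new function still returns False, so the mode path stays live; only DisableTorchFunction puts the thread into the All disabled state.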
This commit is contained in:
parent d97ca968cd
commit 48ee0984ac
@@ -46,4 +46,9 @@ bool torch_function_mode_enabled() {
       PythonTorchFunctionTLS::stack_len() > 0;
 }
 
+// This is needed to disambiguate the ternary torch function disabled states
+bool torch_function_all_disabled() {
+  return PythonTorchFunctionTLS::get_disabled_state() == TorchFunctionDisabledState::ALL_DISABLED;
+}
+
 } // namespace at::impl
@@ -31,4 +31,6 @@ struct TORCH_API PythonTorchFunctionTLS {
 
 TORCH_API bool torch_function_mode_enabled();
 
+TORCH_API bool torch_function_all_disabled();
+
 } // namespace at::impl
@@ -1550,6 +1550,23 @@ class TestTorchFunctionMode(TestCase):
         finally:
             del g
 
+    def test_torch_function_all_disabled_api(self):
+        from torch._C import _is_torch_function_all_disabled
+
+        state = _is_torch_function_all_disabled()
+        self.assertFalse(state)
+
+        with torch._C.DisableTorchFunction():
+            state = _is_torch_function_all_disabled()
+            self.assertTrue(state)
+
+        state = _is_torch_function_all_disabled()
+        self.assertFalse(state)
+
+        with torch._C.DisableTorchFunctionSubclass():
+            state = _is_torch_function_all_disabled()
+            self.assertFalse(state)
+
     def test_subclass_hash(self):
         class DiagTensor(torch.Tensor):
             def __init__(self, diag):
@@ -1238,6 +1238,7 @@ def _set_check_sparse_tensor_invariants(
 def _set_default_mobile_cpu_allocator() -> None: ...  # THPModule_setDefaultMobileCPUAllocator
 def _unset_default_mobile_cpu_allocator() -> None: ...  # THPModule_unsetDefaultMobileCPUAllocator
 def _is_torch_function_enabled() -> _bool: ...  # THPModule_isEnabledTorchFunction
+def _is_torch_function_all_disabled() -> _bool: ...  # THPModule_isAllDisabledTorchFunction
 def _has_torch_function(
     args: Iterable[Any],
 ) -> _bool: ...  # THPModule_has_torch_function
@@ -1530,6 +1530,10 @@ static PyMethodDef TorchMethods[] = { // NOLINT
      THPModule_isEnabledTorchFunction,
      METH_NOARGS,
      nullptr},
+    {"_is_torch_function_all_disabled",
+     THPModule_isAllDisabledTorchFunction,
+     METH_NOARGS,
+     nullptr},
     {"_disabled_torch_function_impl",
      THPModule_disable_torch_function,
      METH_VARARGS,
@@ -64,6 +64,16 @@ PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* unused) {
   }
 }
 
+PyObject* THPModule_isAllDisabledTorchFunction(
+    PyObject* self,
+    PyObject* unused) {
+  if (at::impl::torch_function_all_disabled()) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+}
+
 static PyMethodDef DisableTorchFunctionSubclass_methods[] = { // NOLINT
     {"__enter__", DisableTorchFunctionSubclass__enter, METH_NOARGS, nullptr},
     {"__exit__", DisableTorchFunctionSubclass__exit, METH_VARARGS, nullptr},
@@ -30,6 +30,9 @@ struct DisableTorchDispatch {
 } // namespace torch
 
 PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* unused);
+PyObject* THPModule_isAllDisabledTorchFunction(
+    PyObject* self,
+    PyObject* unused);
 PyObject* THPModule_DisableTorchFunctionType();
 PyObject* THPModule_DisableTorchFunctionSubclassType();
 PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* args);