add torch.autograd._unsafe_set_version_counter API (#92924)

Adds a private API, torch._C._autograd._unsafe_set_version_counter(tensor, version), for directly overwriting a tensor's version counter, along with a torch.autograd._unsafe_preserve_version_counter context manager that restores a tensor's version counter on exit so that mutations performed inside the block are hidden from autograd. This is meant to fix https://github.com/pytorch/pytorch/issues/91093.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92924
Approved by: https://github.com/ezyang, https://github.com/alanwaketan, https://github.com/albanD
Brian Hirsh 2023-02-11 16:08:43 +00:00 committed by PyTorch MergeBot
parent c74f438c01
commit 2b36d35b9c
6 changed files with 68 additions and 1 deletion


@@ -388,6 +388,15 @@ struct C10_API VariableVersion {
  }
  }
  void set_version(int64_t i) {
    TORCH_CHECK(
        version_counter_,
        "Tried to call torch.autograd._unsafe_set_version_counter() on a tensor "
        "that does not have a version counter. Was it created in inference mode?");
    TORCH_CHECK(i >= 0, "Cannot set a version_counter to a value below 0: ", i);
    version_counter_->version_ = i;
  }
  // Inference tensor doesn't have version counter so it shouldn't be
  // accessed.
  uint32_t current_version() const {
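For orientation, here is a short Python sketch (not part of the diff; it assumes a build that includes this commit) of the counter that set_version() writes to. Every in-place op bumps a tensor's `._version`, and the binding added later in this diff overwrites it directly; the two TORCH_CHECKs above surface as RuntimeErrors from Python.

    import torch

    t = torch.zeros(3)
    assert t._version == 0      # fresh tensors start at version 0
    t.add_(1)
    assert t._version == 1      # every in-place op bumps the counter

    # The new binding overwrites the counter directly.
    torch._C._autograd._unsafe_set_version_counter(t, 0)
    assert t._version == 0

    # Both checks in set_version() surface as RuntimeErrors:
    #   - calling it on an inference tensor (which has no version counter)
    #   - passing a negative version, e.g.:
    # torch._C._autograd._unsafe_set_version_counter(t, -1)   # RuntimeError: "Cannot set ..."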


@@ -3736,6 +3736,23 @@ SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
        out = f(x)
        self.assertTrue("AsStridedBackward" in str(out.grad_fn))
    def test_unsafe_set_version_counter(self):
        x = torch.ones(2, requires_grad=True).clone()
        x.add_(1)
        x.add_(2)
        self.assertEqual(2, x._version)
        with torch.autograd._unsafe_preserve_version_counter(x):
            x.mul_(2)
            x.mul_(3)
        # mutations inside the context manager are hidden: the version counter
        # is restored to its previous value on exit
        self.assertEqual(2, x._version)

        torch._C._autograd._unsafe_set_version_counter(x, 0)
        self.assertEqual(0, x._version)
        with self.assertRaisesRegex(RuntimeError, "Cannot set"):
            torch._C._autograd._unsafe_set_version_counter(x, -1)
    def test_current_node(self):
        pr = []
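For context on what the test is exercising, the following sketch (not part of the diff) shows the check that the version counter normally enforces: mutating a tensor that autograd has saved for backward raises at backward time, and it is exactly this check that the new unsafe APIs let expert users bypass deliberately.

    import torch

    a = torch.ones(3, requires_grad=True)
    b = a * 2
    c = b.sin()          # SinBackward0 saves `b` for the backward pass
    b.mul_(10)           # bumps b._version after it was saved

    try:
        c.sum().backward()
    except RuntimeError as e:
        # "... has been modified by an inplace operation ..."
        print(e)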


@@ -76,6 +76,8 @@ def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...
def _push_saved_tensors_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
def _pop_saved_tensors_default_hooks() -> None: ...
def _unsafe_set_version_counter(t: torch.Tensor, prev_version: int) -> None: ...
def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
def _disable_profiler_legacy() -> List[List[ProfilerEvent]]: ...
def _profiler_type() -> ActiveProfilerType: ...


@@ -16,7 +16,8 @@ from .variable import Variable
from .function import Function, NestedIOFunction
from .gradcheck import gradcheck, gradgradcheck
from .grad_mode import (
    no_grad, enable_grad, set_grad_enabled, inference_mode, set_multithreading_enabled, _force_original_view_tracking,
    _unsafe_preserve_version_counter
)
from .anomaly_mode import detect_anomaly, set_detect_anomaly
from ..overrides import has_torch_function, handle_torch_function, is_tensor_like


@@ -289,3 +289,36 @@ class _force_original_view_tracking(_DecoratorContextManager):
    def clone(self):
        return self.__class__(self.mode)
class _unsafe_preserve_version_counter(_DecoratorContextManager):
    r"""DO NOT USE THIS UNLESS YOU KNOW EXACTLY WHAT YOU'RE DOING!

    This context manager can lead to arbitrary silent-correctness issues in any other part of your code
    (even the ones not touched directly by the context manager)!

    Ordinarily, autograd will track mutations to tensors by incrementing its `._version` attribute.
    This is generally important for correctness, since, for example, mutating a tensor that autograd has saved
    for the backwards pass can result in incorrect gradients, and autograd uses the version counter to detect
    and error out in this situation.

    However, there are rare instances where it might be useful to hide mutations from autograd. For example:
    if a tensor is very large, and you'd like to free its memory by storing it elsewhere, and re-populate
    the tensor right before it is needed by autograd.

    Args:
        tensor (torch.Tensor): the tensor in question, that you would like to preserve the version counter of.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
    """

    def __init__(self, tensor: torch.Tensor) -> None:
        self.tensor = tensor
        self.prev_version = tensor._version

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args) -> None:
        torch._C._autograd._unsafe_set_version_counter(self.tensor, self.prev_version)
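To make the use case from the docstring concrete, here is a hedged sketch (not part of the diff) of the offload-and-repopulate pattern: the data of a large saved tensor is stashed elsewhere, the tensor is blanked to free memory, and it is refilled right before backward, with the context manager hiding both mutations from the version check. The `stash` variable and the specific ops are illustrative only.

    import torch

    x = torch.randn(4, requires_grad=True)
    big = x * 2              # pretend this is a large intermediate
    out = big.sin()          # SinBackward0 saves `big` for backward

    stash = big.detach().clone()     # in a real setting: move to CPU / disk
    with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(big):
        big.zero_()          # would normally bump big._version and break backward

    # ... much later, right before running backward ...
    with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(big):
        big.copy_(stash)     # re-populate the original values

    out.sum().backward()     # passes the version check and uses the restored data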


@@ -301,6 +301,11 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
    return activities;
  });
m.def("_unsafe_set_version_counter", [](at::Tensor t, int64_t i) {
auto vc = torch::autograd::impl::version_counter(t);
vc.set_version(i);
});
m.def("_enable_profiler_legacy", enableProfilerLegacy); m.def("_enable_profiler_legacy", enableProfilerLegacy);
py::class_<ProfilerDisableOptions>(m, "_ProfilerDisableOptions") py::class_<ProfilerDisableOptions>(m, "_ProfilerDisableOptions")
.def(py::init<bool, bool>()); .def(py::init<bool, bool>());