mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
add torch.autograd._unsafe_set_version_counter API (#92924)
better description coming soon (but this is meant to fix https://github.com/pytorch/pytorch/issues/91093) Pull Request resolved: https://github.com/pytorch/pytorch/pull/92924 Approved by: https://github.com/ezyang, https://github.com/alanwaketan, https://github.com/albanD
This commit is contained in:
parent
c74f438c01
commit
2b36d35b9c
|
|
@ -388,6 +388,15 @@ struct C10_API VariableVersion {
|
|||
}
|
||||
}
|
||||
|
||||
void set_version(int64_t i) {
|
||||
TORCH_CHECK(
|
||||
version_counter_,
|
||||
"Tried to call torch.autograd._unsafe_set_version() on a tensor "
|
||||
"that does not have a version counter. Was it created in inference mode?");
|
||||
TORCH_CHECK(i >= 0, "Cannot set a version_counter to a value below 0: ", i);
|
||||
version_counter_->version_ = i;
|
||||
}
|
||||
|
||||
// Inference tensor doesn't have version counter so it shouldn't be
|
||||
// accessed.
|
||||
uint32_t current_version() const {
|
||||
|
|
|
|||
|
|
@ -3736,6 +3736,23 @@ SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
|
|||
out = f(x)
|
||||
self.assertTrue("AsStridedBackward" in str(out.grad_fn))
|
||||
|
||||
def test_unsafe_set_version_counter(self):
|
||||
x = torch.ones(2, requires_grad=True).clone()
|
||||
x.add_(1)
|
||||
x.add_(2)
|
||||
self.assertEqual(2, x._version)
|
||||
with torch.autograd._unsafe_preserve_version_counter(x):
|
||||
x.mul_(2)
|
||||
x.mul_(3)
|
||||
# version counter doesn't change inside of the context manager
|
||||
self.assertEqual(2, x._version)
|
||||
|
||||
torch._C._autograd._unsafe_set_version_counter(x, 0)
|
||||
self.assertEqual(0, x._version)
|
||||
with self.assertRaisesRegex(RuntimeError, "Cannot set"):
|
||||
torch._C._autograd._unsafe_set_version_counter(x, -1)
|
||||
|
||||
|
||||
def test_current_node(self):
|
||||
pr = []
|
||||
|
||||
|
|
|
|||
|
|
@ -76,6 +76,8 @@ def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...
|
|||
def _push_saved_tensors_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
|
||||
def _pop_saved_tensors_default_hooks() -> None: ...
|
||||
|
||||
def _unsafe_set_version_counter(t: torch.Tensor, prev_version: int) -> None: ...
|
||||
|
||||
def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
|
||||
def _disable_profiler_legacy() -> List[List[ProfilerEvent]]: ...
|
||||
def _profiler_type() -> ActiveProfilerType: ...
|
||||
|
|
|
|||
|
|
@ -16,7 +16,8 @@ from .variable import Variable
|
|||
from .function import Function, NestedIOFunction
|
||||
from .gradcheck import gradcheck, gradgradcheck
|
||||
from .grad_mode import (
|
||||
no_grad, enable_grad, set_grad_enabled, inference_mode, set_multithreading_enabled, _force_original_view_tracking
|
||||
no_grad, enable_grad, set_grad_enabled, inference_mode, set_multithreading_enabled, _force_original_view_tracking,
|
||||
_unsafe_preserve_version_counter
|
||||
)
|
||||
from .anomaly_mode import detect_anomaly, set_detect_anomaly
|
||||
from ..overrides import has_torch_function, handle_torch_function, is_tensor_like
|
||||
|
|
|
|||
|
|
@ -289,3 +289,36 @@ class _force_original_view_tracking(_DecoratorContextManager):
|
|||
|
||||
def clone(self):
|
||||
return self.__class__(self.mode)
|
||||
|
||||
class _unsafe_preserve_version_counter(_DecoratorContextManager):
|
||||
r"""DO NOT USE THIS UNLESS YOU KNOW EXACTLY WHAT YOU'RE DOING!
|
||||
|
||||
This context manager can lead to arbitrary silent-correctness issues in any other part of your code
|
||||
(even the ones not touched directly by the context manager)!
|
||||
|
||||
Ordinarily, autograd will track mutations to tensors by incrementing it's `._version` attribute.
|
||||
This is generally important for correctness, as for example, mutating a tensor that autograd has saved
|
||||
for the backwards pass can result in incorrect gradients, and autograd uses the version counter to detect
|
||||
and error out in this situation.
|
||||
|
||||
However, there are rare instances where it might be useful to hide mutations from autograd. For example:
|
||||
if a tensor is very large, and you'd like to free its memory by storing it elsewhere, and re-populate
|
||||
the tensor right before it is needed by autograd.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): the tensor in question, that you would like to preserve the version counter of.
|
||||
|
||||
.. note::
|
||||
This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, tensor: torch.Tensor) -> None:
|
||||
self.tensor = tensor
|
||||
self.prev_version = tensor._version
|
||||
|
||||
def __enter__(self) -> None:
|
||||
pass
|
||||
|
||||
def __exit__(self, *args) -> None:
|
||||
torch._C._autograd._unsafe_set_version_counter(self.tensor, self.prev_version)
|
||||
|
|
|
|||
|
|
@ -301,6 +301,11 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
|
|||
return activities;
|
||||
});
|
||||
|
||||
m.def("_unsafe_set_version_counter", [](at::Tensor t, int64_t i) {
|
||||
auto vc = torch::autograd::impl::version_counter(t);
|
||||
vc.set_version(i);
|
||||
});
|
||||
|
||||
m.def("_enable_profiler_legacy", enableProfilerLegacy);
|
||||
py::class_<ProfilerDisableOptions>(m, "_ProfilerDisableOptions")
|
||||
.def(py::init<bool, bool>());
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user