mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
add torch.autograd._unsafe_set_version_counter API (#92924)
better description coming soon (but this is meant to fix https://github.com/pytorch/pytorch/issues/91093) Pull Request resolved: https://github.com/pytorch/pytorch/pull/92924 Approved by: https://github.com/ezyang, https://github.com/alanwaketan, https://github.com/albanD
This commit is contained in:
parent
c74f438c01
commit
2b36d35b9c
|
|
@ -388,6 +388,15 @@ struct C10_API VariableVersion {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void set_version(int64_t i) {
|
||||||
|
TORCH_CHECK(
|
||||||
|
version_counter_,
|
||||||
|
"Tried to call torch.autograd._unsafe_set_version() on a tensor "
|
||||||
|
"that does not have a version counter. Was it created in inference mode?");
|
||||||
|
TORCH_CHECK(i >= 0, "Cannot set a version_counter to a value below 0: ", i);
|
||||||
|
version_counter_->version_ = i;
|
||||||
|
}
|
||||||
|
|
||||||
// Inference tensor doesn't have version counter so it shouldn't be
|
// Inference tensor doesn't have version counter so it shouldn't be
|
||||||
// accessed.
|
// accessed.
|
||||||
uint32_t current_version() const {
|
uint32_t current_version() const {
|
||||||
|
|
|
||||||
|
|
@ -3736,6 +3736,23 @@ SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
|
||||||
out = f(x)
|
out = f(x)
|
||||||
self.assertTrue("AsStridedBackward" in str(out.grad_fn))
|
self.assertTrue("AsStridedBackward" in str(out.grad_fn))
|
||||||
|
|
||||||
|
def test_unsafe_set_version_counter(self):
|
||||||
|
x = torch.ones(2, requires_grad=True).clone()
|
||||||
|
x.add_(1)
|
||||||
|
x.add_(2)
|
||||||
|
self.assertEqual(2, x._version)
|
||||||
|
with torch.autograd._unsafe_preserve_version_counter(x):
|
||||||
|
x.mul_(2)
|
||||||
|
x.mul_(3)
|
||||||
|
# version counter doesn't change inside of the context manager
|
||||||
|
self.assertEqual(2, x._version)
|
||||||
|
|
||||||
|
torch._C._autograd._unsafe_set_version_counter(x, 0)
|
||||||
|
self.assertEqual(0, x._version)
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "Cannot set"):
|
||||||
|
torch._C._autograd._unsafe_set_version_counter(x, -1)
|
||||||
|
|
||||||
|
|
||||||
def test_current_node(self):
|
def test_current_node(self):
|
||||||
pr = []
|
pr = []
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -76,6 +76,8 @@ def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...
|
||||||
def _push_saved_tensors_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
|
def _push_saved_tensors_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
|
||||||
def _pop_saved_tensors_default_hooks() -> None: ...
|
def _pop_saved_tensors_default_hooks() -> None: ...
|
||||||
|
|
||||||
|
def _unsafe_set_version_counter(t: torch.Tensor, prev_version: int) -> None: ...
|
||||||
|
|
||||||
def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
|
def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
|
||||||
def _disable_profiler_legacy() -> List[List[ProfilerEvent]]: ...
|
def _disable_profiler_legacy() -> List[List[ProfilerEvent]]: ...
|
||||||
def _profiler_type() -> ActiveProfilerType: ...
|
def _profiler_type() -> ActiveProfilerType: ...
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,8 @@ from .variable import Variable
|
||||||
from .function import Function, NestedIOFunction
|
from .function import Function, NestedIOFunction
|
||||||
from .gradcheck import gradcheck, gradgradcheck
|
from .gradcheck import gradcheck, gradgradcheck
|
||||||
from .grad_mode import (
|
from .grad_mode import (
|
||||||
no_grad, enable_grad, set_grad_enabled, inference_mode, set_multithreading_enabled, _force_original_view_tracking
|
no_grad, enable_grad, set_grad_enabled, inference_mode, set_multithreading_enabled, _force_original_view_tracking,
|
||||||
|
_unsafe_preserve_version_counter
|
||||||
)
|
)
|
||||||
from .anomaly_mode import detect_anomaly, set_detect_anomaly
|
from .anomaly_mode import detect_anomaly, set_detect_anomaly
|
||||||
from ..overrides import has_torch_function, handle_torch_function, is_tensor_like
|
from ..overrides import has_torch_function, handle_torch_function, is_tensor_like
|
||||||
|
|
|
||||||
|
|
@ -289,3 +289,36 @@ class _force_original_view_tracking(_DecoratorContextManager):
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
return self.__class__(self.mode)
|
return self.__class__(self.mode)
|
||||||
|
|
||||||
|
class _unsafe_preserve_version_counter(_DecoratorContextManager):
|
||||||
|
r"""DO NOT USE THIS UNLESS YOU KNOW EXACTLY WHAT YOU'RE DOING!
|
||||||
|
|
||||||
|
This context manager can lead to arbitrary silent-correctness issues in any other part of your code
|
||||||
|
(even the ones not touched directly by the context manager)!
|
||||||
|
|
||||||
|
Ordinarily, autograd will track mutations to tensors by incrementing it's `._version` attribute.
|
||||||
|
This is generally important for correctness, as for example, mutating a tensor that autograd has saved
|
||||||
|
for the backwards pass can result in incorrect gradients, and autograd uses the version counter to detect
|
||||||
|
and error out in this situation.
|
||||||
|
|
||||||
|
However, there are rare instances where it might be useful to hide mutations from autograd. For example:
|
||||||
|
if a tensor is very large, and you'd like to free its memory by storing it elsewhere, and re-populate
|
||||||
|
the tensor right before it is needed by autograd.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor (torch.Tensor): the tensor in question, that you would like to preserve the version counter of.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, tensor: torch.Tensor) -> None:
|
||||||
|
self.tensor = tensor
|
||||||
|
self.prev_version = tensor._version
|
||||||
|
|
||||||
|
def __enter__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __exit__(self, *args) -> None:
|
||||||
|
torch._C._autograd._unsafe_set_version_counter(self.tensor, self.prev_version)
|
||||||
|
|
|
||||||
|
|
@ -301,6 +301,11 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
|
||||||
return activities;
|
return activities;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
m.def("_unsafe_set_version_counter", [](at::Tensor t, int64_t i) {
|
||||||
|
auto vc = torch::autograd::impl::version_counter(t);
|
||||||
|
vc.set_version(i);
|
||||||
|
});
|
||||||
|
|
||||||
m.def("_enable_profiler_legacy", enableProfilerLegacy);
|
m.def("_enable_profiler_legacy", enableProfilerLegacy);
|
||||||
py::class_<ProfilerDisableOptions>(m, "_ProfilerDisableOptions")
|
py::class_<ProfilerDisableOptions>(m, "_ProfilerDisableOptions")
|
||||||
.def(py::init<bool, bool>());
|
.def(py::init<bool, bool>());
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user