diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h
index 278a72746b5..bf7ae9f5bb4 100644
--- a/c10/core/TensorImpl.h
+++ b/c10/core/TensorImpl.h
@@ -388,6 +388,15 @@ struct C10_API VariableVersion {
     }
   }
 
+  void set_version(int64_t i) {
+    TORCH_CHECK(
+        version_counter_,
+        "Tried to call torch.autograd._unsafe_set_version() on a tensor "
+        "that does not have a version counter. Was it created in inference mode?");
+    TORCH_CHECK(i >= 0, "Cannot set a version_counter to a value below 0: ", i);
+    version_counter_->version_ = i;
+  }
+
   // Inference tensor doesn't have version counter so it shouldn't be
   // accessed.
   uint32_t current_version() const {
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 2a66d4b806d..e620bb6d2ba 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -3736,6 +3736,23 @@ SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
         out = f(x)
         self.assertTrue("AsStridedBackward" in str(out.grad_fn))
 
+    def test_unsafe_set_version_counter(self):
+        x = torch.ones(2, requires_grad=True).clone()
+        x.add_(1)
+        x.add_(2)
+        self.assertEqual(2, x._version)
+        with torch.autograd._unsafe_preserve_version_counter(x):
+            x.mul_(2)
+            x.mul_(3)
+        # version counter doesn't change inside of the context manager
+        self.assertEqual(2, x._version)
+
+        torch._C._autograd._unsafe_set_version_counter(x, 0)
+        self.assertEqual(0, x._version)
+        with self.assertRaisesRegex(RuntimeError, "Cannot set"):
+            torch._C._autograd._unsafe_set_version_counter(x, -1)
+
+
     def test_current_node(self):
         pr = []
diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi
index bdba43cb693..391095e3b3b 100644
--- a/torch/_C/_autograd.pyi
+++ b/torch/_C/_autograd.pyi
@@ -76,6 +76,8 @@ def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...
 def _push_saved_tensors_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
 def _pop_saved_tensors_default_hooks() -> None: ...
 
+def _unsafe_set_version_counter(t: torch.Tensor, prev_version: int) -> None: ...
+
 def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
 def _disable_profiler_legacy() -> List[List[ProfilerEvent]]: ...
 def _profiler_type() -> ActiveProfilerType: ...
diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py
index b520a531bcd..84fec205feb 100644
--- a/torch/autograd/__init__.py
+++ b/torch/autograd/__init__.py
@@ -16,7 +16,8 @@ from .variable import Variable
 from .function import Function, NestedIOFunction
 from .gradcheck import gradcheck, gradgradcheck
 from .grad_mode import (
-    no_grad, enable_grad, set_grad_enabled, inference_mode, set_multithreading_enabled, _force_original_view_tracking
+    no_grad, enable_grad, set_grad_enabled, inference_mode, set_multithreading_enabled, _force_original_view_tracking,
+    _unsafe_preserve_version_counter
 )
 from .anomaly_mode import detect_anomaly, set_detect_anomaly
 from ..overrides import has_torch_function, handle_torch_function, is_tensor_like
diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py
index c699a252583..9b2f8613f8d 100644
--- a/torch/autograd/grad_mode.py
+++ b/torch/autograd/grad_mode.py
@@ -289,3 +289,36 @@ class _force_original_view_tracking(_DecoratorContextManager):
 
     def clone(self):
         return self.__class__(self.mode)
+
+class _unsafe_preserve_version_counter(_DecoratorContextManager):
+    r"""DO NOT USE THIS UNLESS YOU KNOW EXACTLY WHAT YOU'RE DOING!
+
+    This context manager can lead to arbitrary silent-correctness issues in any other part of your code
+    (even the ones not touched directly by the context manager)!
+
+    Ordinarily, autograd will track mutations to tensors by incrementing its `._version` attribute.
+    This is generally important for correctness: for example, mutating a tensor that autograd has saved
+    for the backward pass can result in incorrect gradients, and autograd uses the version counter to detect
+    this situation and error out.
+
+    However, there are rare instances where it might be useful to hide mutations from autograd. For example:
+    if a tensor is very large, you might want to free its memory by storing its contents elsewhere, and
+    re-populate the tensor right before it is needed by autograd.
+
+    Args:
+        tensor (torch.Tensor): the tensor whose version counter you would like to preserve.
+
+    .. note::
+        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
+
+    """
+
+    def __init__(self, tensor: torch.Tensor) -> None:
+        self.tensor = tensor
+        self.prev_version = tensor._version
+
+    def __enter__(self) -> None:
+        pass
+
+    def __exit__(self, *args) -> None:
+        torch._C._autograd._unsafe_set_version_counter(self.tensor, self.prev_version)
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index fdbe961691b..cfdf291b66b 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -301,6 +301,11 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
         return activities;
       });
 
+  m.def("_unsafe_set_version_counter", [](at::Tensor t, int64_t i) {
+    auto vc = torch::autograd::impl::version_counter(t);
+    vc.set_version(i);
+  });
+
   m.def("_enable_profiler_legacy", enableProfilerLegacy);
   py::class_<ProfilerDisableOptions>(m, "_ProfilerDisableOptions")
       .def(py::init<bool, bool>());
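
For context when reviewing, here is a minimal usage sketch of the two entry points this patch adds: the torch.autograd._unsafe_preserve_version_counter context manager and the raw torch._C._autograd._unsafe_set_version_counter binding. It is not part of the change itself; it assumes a build that includes this patch, and the tensor names (x, w, y, t), shapes, and the zero_/copy_ stand-in for "free and re-populate" are purely illustrative.

import torch

# Forward pass: autograd saves `w` (its value is needed to compute x's
# gradient) together with w's current version counter value.
x = torch.randn(4, requires_grad=True)
w = torch.randn(4)
y = (x * w).sum()

# Temporarily overwrite `w` and restore its contents before backward.
# Without the context manager, the two in-place ops below would bump
# w._version and backward() would raise a version-mismatch error.
version_before = w._version
saved_contents = w.clone()
with torch.autograd._unsafe_preserve_version_counter(w):
    w.zero_()                # stand-in for "free the memory / store elsewhere"
    w.copy_(saved_contents)  # re-populate right before autograd needs it
assert w._version == version_before  # counter restored on exit

# Succeeds, and the gradients are correct only because w again holds the
# values it had during the forward pass.
y.backward()

# The low-level binding sets the counter to an explicit value instead of
# restoring a previously recorded one.
t = torch.ones(2)
t.add_(1)
torch._C._autograd._unsafe_set_version_counter(t, 0)
assert t._version == 0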