mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
add torch.autograd._set_view_replay_enabled, use in aot autograd (#92588)
tldr; this should fix some minor perf regressions that were caused by adding more as_strided() calls in aot autograd.
This PR adds a new context manager, `torch.autograd._set_view_replay_enabled()`.
Context: AOT Autograd has special handling for "outputs that alias graph intermediates". E.g. given this function:
```
def f(x):
y = torch.mul(x, 2)
out = y.view(-1)
return out
```
AOT Autograd will do the following:
```
def fn_to_compile(x):
y = torch.mul(x, 2)
out = y.view(-1)
# return the graph intermediate
return y, out
compiled_fn = compile(fn_to_compile)
def wrapper(x):
y, out = compiled_fn(x)
# regenerate the alias of the graph intermediate
return out._view_func(y)
```
What's annoying is that `out._view_func()` will result in a `.as_strided` call, because `out` is an ordinary runtime tensor. This (likely?) caused a perf regression, because when running the backward, out `as_strided_backward()` is slower than our `view_backward()`.
In this PR, I added some TLS for instructing autograd to do view replay instead of as_strided, even when given a normal tensor. I'm definitely interested in thoughts from autograd folks (cc @albanD @soulitzer). A few points that I want to bring up:
(1) One reason that this API seems generally useful to me is because of the case where you `torch.compile()` a function, and you pass in two inputs that alias each other, and mutate one of the inputs. Autograd is forced to add a bunch of as_strided() calls into the graph when this happens, but this would give users an escape hatch for better compiled perf in this situation
(2) To be fair, AOT Autograd probably won't need this TLS in the long term. There's a better (more complicated) solution, where AOT Autograd manually precomputes the view chain off of graph intermediates during tracing, and re-applies them at runtime. This is kind of complicated though and feels lower priority to implement immediately.
(3) Given all of that I made the API private, but lmk what you all think.
This is a followup of https://github.com/pytorch/pytorch/pull/92255.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92588
Approved by: https://github.com/ezyang, https://github.com/albanD
This commit is contained in:
parent
333e771394
commit
83275d8cdf
|
|
@ -36,6 +36,10 @@ struct C10_API AutogradState {
|
|||
mulithreading_enabled_ = mulithreading_enabled;
|
||||
}
|
||||
|
||||
void set_view_replay_enabled(bool view_replay_enabled) {
|
||||
view_replay_enabled_ = view_replay_enabled;
|
||||
}
|
||||
|
||||
bool get_grad_mode() const {
|
||||
return grad_mode_;
|
||||
}
|
||||
|
|
@ -52,11 +56,16 @@ struct C10_API AutogradState {
|
|||
return mulithreading_enabled_;
|
||||
}
|
||||
|
||||
bool get_view_replay_enabled() const {
|
||||
return view_replay_enabled_;
|
||||
}
|
||||
|
||||
private:
|
||||
bool grad_mode_ : 1;
|
||||
bool inference_mode_ : 1;
|
||||
bool fw_grad_mode_ : 1;
|
||||
bool mulithreading_enabled_ : 1;
|
||||
bool view_replay_enabled_ : 1;
|
||||
};
|
||||
|
||||
} // namespace c10
|
||||
|
|
|
|||
|
|
@ -3717,6 +3717,25 @@ SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
|
|||
with self.assertRaisesRegex(RuntimeError, "expects the current backward to be executed with multithreading disabled"):
|
||||
t.backward()
|
||||
|
||||
def test_view_replay_enabled(self):
|
||||
def f(x):
|
||||
out = x.clone().view(-1)
|
||||
# mutate the view, triggering autograd view-replay logic
|
||||
out.add_(1)
|
||||
return out
|
||||
|
||||
x = torch.ones(2, 2, requires_grad=True)
|
||||
with torch.autograd._force_original_view_tracking(True):
|
||||
out = f(x)
|
||||
|
||||
# view-replay was enabled, so we should see ViewBackward in the graph
|
||||
# instead of AsStridedBackward.
|
||||
self.assertTrue("ViewBackward" in str(out.grad_fn))
|
||||
|
||||
# Without view-replay we should as an AsStridedBackward
|
||||
out = f(x)
|
||||
self.assertTrue("AsStridedBackward" in str(out.grad_fn))
|
||||
|
||||
def test_current_node(self):
|
||||
pr = []
|
||||
|
||||
|
|
|
|||
|
|
@ -158,7 +158,8 @@ at::_ops::${unambiguous_name}::call(${unpacked_args})"""
|
|||
SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE = CodeTemplate(
|
||||
"""\
|
||||
std::function<at::Tensor(const at::Tensor&)> func=nullptr;
|
||||
if (${is_view_with_metadata_change} || !self.unsafeGetTensorImpl()->support_as_strided()) {
|
||||
if (${is_view_with_metadata_change} || !self.unsafeGetTensorImpl()->support_as_strided() ||
|
||||
c10::AutogradState::get_tls_state().get_view_replay_enabled()) {
|
||||
${replay_view_func}
|
||||
}
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -986,6 +986,9 @@ class _EnableTorchFunction:
|
|||
class _MultithreadingEnabled:
|
||||
def __init__(self, mode: _bool) -> None: ...
|
||||
|
||||
class _ViewReplayEnabled:
|
||||
def __init__(self, mode: _bool) -> None: ...
|
||||
|
||||
# Defined in torch/csrc/jit/python/script_init.cpp
|
||||
class LoggerBase(object):
|
||||
...
|
||||
|
|
|
|||
|
|
@ -1928,6 +1928,7 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig):
|
|||
else:
|
||||
args_with_synthetic_bases = args
|
||||
|
||||
with torch.autograd._force_original_view_tracking(True):
|
||||
all_outs = CompiledFunction.apply(*args_with_synthetic_bases)
|
||||
|
||||
num_mutated_inps = CompiledFunction.num_mutated_inputs
|
||||
|
|
@ -2028,9 +2029,7 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig):
|
|||
# TODO: handle the custom autograd function case here.
|
||||
# We need a way to check whether a tensor came from a custom autograd fn from python,
|
||||
# AND a way to replay that custom view fn.
|
||||
regenerated_out = gen_alias_from_base(
|
||||
aliased_base_tensor, o_, o_grad
|
||||
)
|
||||
regenerated_out = gen_alias_from_base(aliased_base_tensor, o_, o_grad)
|
||||
fw_outs_including_aliases.append(regenerated_out)
|
||||
return fw_outs_including_aliases
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -15,7 +15,9 @@ from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast
|
|||
from .variable import Variable
|
||||
from .function import Function, NestedIOFunction
|
||||
from .gradcheck import gradcheck, gradgradcheck
|
||||
from .grad_mode import no_grad, enable_grad, set_grad_enabled, inference_mode, set_multithreading_enabled
|
||||
from .grad_mode import (
|
||||
no_grad, enable_grad, set_grad_enabled, inference_mode, set_multithreading_enabled, _force_original_view_tracking
|
||||
)
|
||||
from .anomaly_mode import detect_anomaly, set_detect_anomaly
|
||||
from ..overrides import has_torch_function, handle_torch_function, is_tensor_like
|
||||
from . import functional
|
||||
|
|
|
|||
|
|
@ -253,3 +253,39 @@ class set_multithreading_enabled(_DecoratorContextManager):
|
|||
|
||||
def clone(self) -> "set_multithreading_enabled":
|
||||
return self.__class__(self.mode)
|
||||
|
||||
|
||||
class _force_original_view_tracking(_DecoratorContextManager):
|
||||
r"""Context-manager that sets whether or not to always enable view-replay in autograd.
|
||||
|
||||
``set_view_replay_enabled`` will enable or disable view-replay based on its argument :attr:`mode`.
|
||||
It can be used as a context-manager or as a function.
|
||||
|
||||
This context manager is thread local; it will not affect computation
|
||||
in other threads.
|
||||
|
||||
When a tensor view is mutated, the autograd engine needs to decide whether or not
|
||||
to regenerate the "updated view" by either replaying the chain of views from the updated base,
|
||||
or with a single call to as_strided.
|
||||
|
||||
If set_view_replay_enabled is set to True, then autograd will always use view replay.
|
||||
Otherwise, it will fall back to its existing logic.
|
||||
|
||||
Args:
|
||||
mode (bool): Flag whether to enable view-replay (``True``), or disable
|
||||
(``False``).
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, mode: bool) -> None:
|
||||
self.mode = mode
|
||||
self._force_original_view_tracking_guard = torch._C._ViewReplayEnabled(mode)
|
||||
|
||||
def __enter__(self) -> None:
|
||||
pass
|
||||
|
||||
def __exit__(self, *args) -> None:
|
||||
del self._force_original_view_tracking_guard
|
||||
|
||||
def clone(self):
|
||||
return self.__class__(self.mode)
|
||||
|
|
|
|||
|
|
@ -55,6 +55,17 @@ struct MultithreadingEnabled {
|
|||
bool old_;
|
||||
};
|
||||
|
||||
struct ViewReplayEnabled {
|
||||
ViewReplayEnabled(bool enabled)
|
||||
: old_(c10::AutogradState::get_tls_state().get_view_replay_enabled()) {
|
||||
c10::AutogradState::get_tls_state().set_view_replay_enabled(enabled);
|
||||
}
|
||||
~ViewReplayEnabled() {
|
||||
c10::AutogradState::get_tls_state().set_view_replay_enabled(old_);
|
||||
}
|
||||
bool old_;
|
||||
};
|
||||
|
||||
struct DisableAutocast {
|
||||
c10::impl::ExcludeDispatchKeyGuard guard_{c10::autocast_dispatch_keyset};
|
||||
};
|
||||
|
|
@ -360,6 +371,8 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
|
|||
.def(py::init<bool>());
|
||||
py::class_<DisableAutocast>(std::move(_C_m), "_DisableAutocast")
|
||||
.def(py::init<>());
|
||||
py::class_<ViewReplayEnabled>(_C_m, "_ViewReplayEnabled")
|
||||
.def(py::init<bool>());
|
||||
py::class_<torch::autograd::SavedVariable>(std::move(m), "SavedTensor")
|
||||
.def(py::init([]() -> torch::autograd::SavedVariable {
|
||||
TORCH_CHECK(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user