add torch.autograd._set_view_replay_enabled, use in aot autograd (#92588)

tl;dr: this should fix some minor perf regressions that were caused by AOT Autograd adding extra as_strided() calls.

This PR adds a new context manager, `torch.autograd._set_view_replay_enabled()` (exposed in the diff below as `torch.autograd._force_original_view_tracking`).

Context: AOT Autograd has special handling for "outputs that alias graph intermediates". E.g. given this function:

```
def f(x):
    y = torch.mul(x, 2)
    out = y.view(-1)
    return out
```

AOT Autograd will do the following:

```
def fn_to_compile(x):
    y = torch.mul(x, 2)
    out = y.view(-1)
    # return the graph intermediate
    return y, out

compiled_fn = compile(fn_to_compile)

def wrapper(x):
    y, out = compiled_fn(x)
    # regenerate the alias of the graph intermediate
    return out._view_func(y)
```

What's annoying is that `out._view_func()` will result in an `.as_strided()` call, because `out` is an ordinary runtime tensor. This (likely?) caused a perf regression, because when running the backward, our `as_strided_backward()` is slower than our `view_backward()`.
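
To make the difference concrete, here's a small sketch (basically the same pattern as the test added in this PR) showing how the new flag swaps the recorded backward node from `AsStridedBackward` to `ViewBackward`:

```
import torch

def f(x):
    out = x.clone().view(-1)
    out.add_(1)  # in-place op on a view forces autograd to regenerate the view's grad_fn
    return out

x = torch.ones(2, 2, requires_grad=True)

# Default behavior: for an ordinary runtime tensor, autograd falls back to as_strided,
# so str(out.grad_fn) contains "AsStridedBackward".
print(f(x).grad_fn)

# With the new (private) TLS flag set, autograd replays the original view op instead,
# so str(out.grad_fn) contains "ViewBackward".
with torch.autograd._force_original_view_tracking(True):
    print(f(x).grad_fn)
```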

In this PR, I added some TLS for instructing autograd to do view replay instead of as_strided, even when given a normal tensor. I'm definitely interested in thoughts from autograd folks (cc @albanD @soulitzer). A few points that I want to bring up:

(1) One reason this API seems generally useful to me: if you `torch.compile()` a function, pass in two inputs that alias each other, and mutate one of the inputs, autograd is forced to add a bunch of as_strided() calls into the graph. This would give users an escape hatch for better compiled perf in that situation (a rough sketch follows this list).

(2) To be fair, AOT Autograd probably won't need this TLS in the long term. There's a better (more complicated) solution, where AOT Autograd manually precomputes the view chain off of graph intermediates during tracing and re-applies it at runtime. That feels lower priority to implement right now, though.

(3) Given all of that I made the API private, but lmk what you all think.
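
Here's a rough sketch of the scenario from (1); the exact setup (shapes, whether the mutation graph-breaks under inductor, etc.) is illustrative rather than a benchmark:

```
import torch

def f(a, b):
    a.mul_(2)        # mutate one of two aliasing inputs
    return a + b

base = torch.ones(8, requires_grad=True).clone()  # non-leaf, so its views may be mutated
a = base[:4]
b = base[2:6]        # a and b alias the same underlying storage

compiled_f = torch.compile(f)

# Opting in to view replay lets autograd regenerate the aliased views without
# as_strided() calls; this is the "escape hatch" described in (1).
with torch.autograd._force_original_view_tracking(True):
    out = compiled_f(a, b)
```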

This is a followup of https://github.com/pytorch/pytorch/pull/92255.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92588
Approved by: https://github.com/ezyang, https://github.com/albanD
commit 83275d8cdf (parent 333e771394)
Author: Brian Hirsh
Date: 2023-02-07 21:02:48 +00:00
Committed by: PyTorch MergeBot

8 changed files with 88 additions and 6 deletions


```
@@ -36,6 +36,10 @@ struct C10_API AutogradState {
     mulithreading_enabled_ = mulithreading_enabled;
   }

+  void set_view_replay_enabled(bool view_replay_enabled) {
+    view_replay_enabled_ = view_replay_enabled;
+  }
+
   bool get_grad_mode() const {
     return grad_mode_;
   }
@@ -52,11 +56,16 @@
     return mulithreading_enabled_;
   }

+  bool get_view_replay_enabled() const {
+    return view_replay_enabled_;
+  }
+
 private:
   bool grad_mode_ : 1;
   bool inference_mode_ : 1;
   bool fw_grad_mode_ : 1;
   bool mulithreading_enabled_ : 1;
+  bool view_replay_enabled_ : 1;
 };

 } // namespace c10
```


```
@@ -3717,6 +3717,25 @@ SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
         with self.assertRaisesRegex(RuntimeError, "expects the current backward to be executed with multithreading disabled"):
             t.backward()

+    def test_view_replay_enabled(self):
+        def f(x):
+            out = x.clone().view(-1)
+            # mutate the view, triggering autograd view-replay logic
+            out.add_(1)
+            return out
+
+        x = torch.ones(2, 2, requires_grad=True)
+        with torch.autograd._force_original_view_tracking(True):
+            out = f(x)
+        # view-replay was enabled, so we should see ViewBackward in the graph
+        # instead of AsStridedBackward.
+        self.assertTrue("ViewBackward" in str(out.grad_fn))
+
+        # Without view-replay we should see an AsStridedBackward.
+        out = f(x)
+        self.assertTrue("AsStridedBackward" in str(out.grad_fn))
+
     def test_current_node(self):
         pr = []
```


```
@@ -158,7 +158,8 @@ at::_ops::${unambiguous_name}::call(${unpacked_args})"""
 SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE = CodeTemplate(
     """\
 std::function<at::Tensor(const at::Tensor&)> func=nullptr;
-if (${is_view_with_metadata_change} || !self.unsafeGetTensorImpl()->support_as_strided()) {
+if (${is_view_with_metadata_change} || !self.unsafeGetTensorImpl()->support_as_strided() ||
+    c10::AutogradState::get_tls_state().get_view_replay_enabled()) {
   ${replay_view_func}
 }
 """
```


```
@@ -986,6 +986,9 @@ class _EnableTorchFunction:
 class _MultithreadingEnabled:
     def __init__(self, mode: _bool) -> None: ...

+class _ViewReplayEnabled:
+    def __init__(self, mode: _bool) -> None: ...
+
 # Defined in torch/csrc/jit/python/script_init.cpp
 class LoggerBase(object):
     ...
```


```
@@ -1928,7 +1928,8 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig):
         else:
             args_with_synthetic_bases = args

-        all_outs = CompiledFunction.apply(*args_with_synthetic_bases)
+        with torch.autograd._force_original_view_tracking(True):
+            all_outs = CompiledFunction.apply(*args_with_synthetic_bases)

         num_mutated_inps = CompiledFunction.num_mutated_inputs
         num_intermediate_bases = CompiledFunction.fw_metadata.num_intermediate_bases
@@ -2028,9 +2029,7 @@
                 # TODO: handle the custom autograd function case here.
                 # We need a way to check whether a tensor came from a custom autograd fn from python,
                 # AND a way to replay that custom view fn.
-                regenerated_out = gen_alias_from_base(
-                    aliased_base_tensor, o_, o_grad
-                )
+                regenerated_out = gen_alias_from_base(aliased_base_tensor, o_, o_grad)
                 fw_outs_including_aliases.append(regenerated_out)
             return fw_outs_including_aliases
         else:
```


```
@@ -15,7 +15,9 @@ from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast
 from .variable import Variable
 from .function import Function, NestedIOFunction
 from .gradcheck import gradcheck, gradgradcheck
-from .grad_mode import no_grad, enable_grad, set_grad_enabled, inference_mode, set_multithreading_enabled
+from .grad_mode import (
+    no_grad, enable_grad, set_grad_enabled, inference_mode, set_multithreading_enabled, _force_original_view_tracking
+)
 from .anomaly_mode import detect_anomaly, set_detect_anomaly
 from ..overrides import has_torch_function, handle_torch_function, is_tensor_like
 from . import functional
```


```
@@ -253,3 +253,39 @@ class set_multithreading_enabled(_DecoratorContextManager):

     def clone(self) -> "set_multithreading_enabled":
         return self.__class__(self.mode)
+
+
+class _force_original_view_tracking(_DecoratorContextManager):
+    r"""Context-manager that sets whether or not to always enable view-replay in autograd.
+
+    ``set_view_replay_enabled`` will enable or disable view-replay based on its argument :attr:`mode`.
+    It can be used as a context-manager or as a function.
+
+    This context manager is thread local; it will not affect computation
+    in other threads.
+
+    When a tensor view is mutated, the autograd engine needs to decide whether or not
+    to regenerate the "updated view" by either replaying the chain of views from the updated base,
+    or with a single call to as_strided.
+
+    If set_view_replay_enabled is set to True, then autograd will always use view replay.
+    Otherwise, it will fall back to its existing logic.
+
+    Args:
+        mode (bool): Flag whether to enable view-replay (``True``), or disable
+                     (``False``).
+    """
+
+    def __init__(self, mode: bool) -> None:
+        self.mode = mode
+        self._force_original_view_tracking_guard = torch._C._ViewReplayEnabled(mode)
+
+    def __enter__(self) -> None:
+        pass
+
+    def __exit__(self, *args) -> None:
+        del self._force_original_view_tracking_guard
+
+    def clone(self):
+        return self.__class__(self.mode)
```
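
Since this subclasses `_DecoratorContextManager`, it should also work as a decorator, the same way `no_grad` and friends do. A quick sketch of both forms:

```
import torch

x = torch.ones(2, 2, requires_grad=True)

# As a context manager:
with torch.autograd._force_original_view_tracking(True):
    y = x.clone().view(-1)
    y.add_(1)          # recorded with ViewBackward rather than AsStridedBackward

# As a decorator (behavior inherited from _DecoratorContextManager):
@torch.autograd._force_original_view_tracking(True)
def make_view(t):
    out = t.clone().view(-1)
    out.add_(1)
    return out

print(make_view(x).grad_fn)  # should show a ViewBackward node
```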


```
@@ -55,6 +55,17 @@ struct MultithreadingEnabled {
   bool old_;
 };

+struct ViewReplayEnabled {
+  ViewReplayEnabled(bool enabled)
+      : old_(c10::AutogradState::get_tls_state().get_view_replay_enabled()) {
+    c10::AutogradState::get_tls_state().set_view_replay_enabled(enabled);
+  }
+  ~ViewReplayEnabled() {
+    c10::AutogradState::get_tls_state().set_view_replay_enabled(old_);
+  }
+  bool old_;
+};
+
 struct DisableAutocast {
   c10::impl::ExcludeDispatchKeyGuard guard_{c10::autocast_dispatch_keyset};
 };
@@ -360,6 +371,8 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
       .def(py::init<bool>());
   py::class_<DisableAutocast>(std::move(_C_m), "_DisableAutocast")
       .def(py::init<>());
+  py::class_<ViewReplayEnabled>(_C_m, "_ViewReplayEnabled")
+      .def(py::init<bool>());
   py::class_<torch::autograd::SavedVariable>(std::move(m), "SavedTensor")
       .def(py::init([]() -> torch::autograd::SavedVariable {
         TORCH_CHECK(
```
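
For reference, the binding behaves like an RAII guard over the TLS bit; this is what `_force_original_view_tracking` holds onto internally (user code should go through the Python context manager rather than touching `torch._C` directly):

```
import torch

# Constructing the guard flips the thread-local view-replay flag to True...
guard = torch._C._ViewReplayEnabled(True)

# ...and destroying it restores the previous value, which is what
# _force_original_view_tracking.__exit__ relies on.
del guard
```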