Test reland "AOTAutograd: gate view-replay behind config, not the def… (#124948)
A parallel attempt at landing https://github.com/pytorch/pytorch/pull/124945, but attempting to land through fbcode first.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124948
Approved by: https://github.com/albanD
This commit is contained in:
parent fc13c1c850
commit fc2aa23c1e
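For reference, a minimal sketch (not part of this commit) of how the flag introduced here can be inspected and overridden from user code; plain attribute assignment on torch._functorch.config is assumed as the override mechanism:

# Minimal sketch (not from this commit): inspecting and overriding the new flag.
import torch._functorch.config as functorch_config

# Default after this change: enabled in OSS builds, disabled in fbcode builds.
print(functorch_config.view_replay_for_aliased_outputs)

# Opt in explicitly; the tests below patch the same attribute via
# unittest.mock.patch("functorch.compile.config.view_replay_for_aliased_outputs", True).
functorch_config.view_replay_for_aliased_outputs = True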
@@ -3261,6 +3261,7 @@ def forward(self, tangents_1):
         return lambda f: aot_function(f, fw_compiler=lambda g, _: partial(wrapper, g))
 
+    @patch("functorch.compile.config.view_replay_for_aliased_outputs", True)
     def test_output_aliases_input_view_meta_replay(self):
         @self._compile_and_erase_bases(0)
         def f(a):
@@ -3274,6 +3275,7 @@ def forward(self, tangents_1):
             str(out.grad_fn.__class__), """<class 'ViewBackward0'>"""
         )
 
+    @patch("functorch.compile.config.view_replay_for_aliased_outputs", True)
     def test_output_aliases_intermediate_view_meta_replay(self):
         @self._compile_and_erase_bases(0, 1)
         def f(a):
@@ -3293,6 +3295,7 @@ def forward(self, tangents_1):
             str(out2.grad_fn.__class__), """<class 'ViewBackward0'>"""
         )
 
+    @patch("functorch.compile.config.view_replay_for_aliased_outputs", True)
     def test_output_aliases_output_view_meta_replay(self):
         @self._compile_and_erase_bases(1)
         def f(a):
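The three tests above patch the flag on and assert that the regenerated aliased output carries a ViewBackward0 node. A hedged standalone sketch of the same check, using functorch.compile.aot_function with the nop compiler instead of the _compile_and_erase_bases helper (so the exact grad_fn can differ depending on which regeneration path AOTAutograd takes):

# Standalone sketch of what the patched tests assert (simplified; not the
# _compile_and_erase_bases helper used in the diff above).
from unittest.mock import patch

import torch
from functorch.compile import aot_function, nop

def f(a):
    return a.view(-1)  # the output aliases the input

with patch("functorch.compile.config.view_replay_for_aliased_outputs", True):
    compiled_f = aot_function(f, fw_compiler=nop)
    out = compiled_f(torch.ones(2, 2, requires_grad=True))
    # With view replay enabled (and applicable), the alias is rebuilt from
    # recorded view ops, so a view backward node is expected here rather than
    # AsStridedBackward0.
    print(out.grad_fn.__class__)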
@@ -257,6 +257,19 @@ intentionally_not_handled = {
     "resize_": {b8, f16, f32, f64, i32, i64},
     "resize_as_": {b8, f16, f32, f64, i32, i64},
 }
+# This is only fixed when this config is set
+# We should eventually always turn it on
+import torch._functorch.config as functorch_config
+
+if not functorch_config.view_replay_for_aliased_outputs:
+    intentionally_not_handled['("as_strided", "partial_views")'] = {
+        b8,
+        f16,
+        f32,
+        f64,
+        i32,
+        i64,
+    }
 
 inductor_expected_failures_single_sample["cuda"].update(intentionally_not_handled)
 
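A hedged sketch of the same gating idea outside the opinfo harness, using the standard-library unittest.skipUnless; the class and test names here are hypothetical:

# Hedged sketch of the same config gating as a plain unittest skip.
# PartialViewTests and test_partial_view_reconstruction are hypothetical names.
import unittest

import torch._functorch.config as functorch_config

class PartialViewTests(unittest.TestCase):
    @unittest.skipUnless(
        functorch_config.view_replay_for_aliased_outputs,
        "as_strided partial-view cases are only expected to pass with view replay on",
    )
    def test_partial_view_reconstruction(self):
        self.assertTrue(True)  # placeholder body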
@@ -18,6 +18,7 @@ from torch.utils._python_dispatch import (
     is_traceable_wrapper_subclass,
     transform_subclass,
 )
+from .. import config
 
 aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph")
 
@@ -219,7 +220,7 @@ def gen_alias_from_base(
     # In summary, we use the fact that FunctionalTensorWrapper saves the view
     # functions applied to itself (collected during functionalization) so as
     # to replay them (view functions) on the aliased_base_tensor.
-    if target_functional_tensor is not None:
+    if config.view_replay_for_aliased_outputs and target_functional_tensor is not None:
         from .schemas import FunctionalTensorMetadataEq
 
         assert isinstance(target_functional_tensor, FunctionalTensorMetadataEq)
@@ -237,11 +238,10 @@ def gen_alias_from_base(
             #
             # In order for this to work, we should have a way to replace those
             # symbolic shapes with concrete numbers.
-            aot_joint_log.warning(
+            aot_joint_log.info(
                 "could not reconstruct view by re-applying a ViewMeta sequence. "
-                "This error is possibly caused by dynamic shapes. "
                 "Fallbacking to reconstruction using as_strided. "
-                "Error message: %s",
+                "Reason: %s",
                 str(e),
             )
     else:
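The comments in gen_alias_from_base describe two ways of rebuilding an aliased output from its base: replaying the recorded view functions, or falling back to a single as_strided call. A conceptual sketch of the difference (not the AOTAutograd implementation):

# Conceptual sketch only: rebuild an alias of `base` either by replaying the
# view ops that produced it or with a single as_strided call using the
# target's metadata.
import torch

base = torch.randn(4, 4)
target = base.t()[1:]  # alias produced by a chain of two view ops

# (a) view replay: re-apply the recorded view functions to the base
replayed = base.t()[1:]

# (b) as_strided fallback: one call matching the target's size/stride/offset
strided = base.as_strided(target.size(), target.stride(), target.storage_offset())

assert torch.equal(replayed, target) and torch.equal(strided, target)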
@@ -41,6 +41,26 @@ static_weight_shapes = True
 # Applies CSE to the graph before partitioning
 cse = True
 
+# When AOTAutograd regenerates aliased graph outputs,
+# attempt to use functionalization's view-replay logic
+# before falling back to the autograd engine's view replay or as_strided.
+# This can have some perf implications
+# (although for many models this will not matter).
+# (1) If you have many view ops chained together, replaying all of them
+#     at runtime can have more overhead compared to a single as_strided call
+# (2) If you are doing training, AsStridedBackward is quite slow,
+#     and the individual view op backward formulas will likely be faster.
+# (3) Some backends like XLA do not support as_strided
+
+# Temporary hack: disable this flag for internal
+# (needed to fix an internal issue while avoiding bumping the XLA pin).
+# Eventually: either default this config to false completely
+# once the XLA pin update works,
+# or default the config to true and fix the relevant bugs.
+from torch._inductor.config import is_fbcode
+
+view_replay_for_aliased_outputs = not is_fbcode()
+
 # Restricts the amount of computation AOTAutograd can do.
 # NB: We have essentially disabled this heuristic now. However, this is kept
 # here for now in case it's useful. Setting it low can artificially reduce the
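Point (2) in the new config comment can be illustrated directly in eager mode; this is a hedged illustration independent of the diff:

# Illustration of point (2): the same alias built as a view vs. via as_strided
# gets different backward nodes; AsStridedBackward0 is the slower one.
import torch

x = torch.randn(8, requires_grad=True)

view_out = x.view(2, 4)
strided_out = x.as_strided((2, 4), (4, 1))

print(type(view_out.grad_fn).__name__)     # ViewBackward0
print(type(strided_out.grad_fn).__name__)  # AsStridedBackward0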