Test reland "AOTAutograd: gate view-replay behind config, not the def… (#124948)

A parallel attempt to land https://github.com/pytorch/pytorch/pull/124945, this time going through fbcode first

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124948
Approved by: https://github.com/albanD
Brian Hirsh 2024-04-26 13:16:24 +00:00 committed by PyTorch MergeBot
parent fc13c1c850
commit fc2aa23c1e
4 changed files with 40 additions and 4 deletions
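
For context, the change being relanded amounts to one extra condition in AOTAutograd's alias-regeneration path. Below is a minimal sketch of that condition under the flag added by this commit; `should_replay_views` is a hypothetical helper used only for illustration (the real check is inline in gen_alias_from_base, shown in the diff further down):

# Minimal sketch, not part of this diff: the gate this commit relands.
import torch._functorch.config as functorch_config


def should_replay_views(target_functional_tensor) -> bool:
    # Hypothetical helper: view replay now requires both the saved view
    # metadata and the new config flag (which defaults to False in fbcode).
    return (
        functorch_config.view_replay_for_aliased_outputs
        and target_functional_tensor is not None
    )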

View File

@@ -3261,6 +3261,7 @@ def forward(self, tangents_1):
         return lambda f: aot_function(f, fw_compiler=lambda g, _: partial(wrapper, g))

+    @patch("functorch.compile.config.view_replay_for_aliased_outputs", True)
     def test_output_aliases_input_view_meta_replay(self):
         @self._compile_and_erase_bases(0)
         def f(a):
@@ -3274,6 +3275,7 @@ def forward(self, tangents_1):
             str(out.grad_fn.__class__), """<class 'ViewBackward0'>"""
         )

+    @patch("functorch.compile.config.view_replay_for_aliased_outputs", True)
     def test_output_aliases_intermediate_view_meta_replay(self):
         @self._compile_and_erase_bases(0, 1)
         def f(a):
@@ -3293,6 +3295,7 @@ def forward(self, tangents_1):
             str(out2.grad_fn.__class__), """<class 'ViewBackward0'>"""
         )

+    @patch("functorch.compile.config.view_replay_for_aliased_outputs", True)
     def test_output_aliases_output_view_meta_replay(self):
         @self._compile_and_erase_bases(1)
         def f(a):
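
The three @patch decorators above force the new flag on for these tests, independent of the build default. A minimal sketch of the same pattern outside the test suite, assuming functorch.compile.config aliases torch._functorch.config (as the patched target implies); the function f and the printout are illustrative only:

# Minimal sketch, not from the diff: running a compiled alias-producing
# function with view replay forced on via unittest.mock.patch.
from unittest.mock import patch

import torch
from functorch.compile import aot_function, nop


@patch("functorch.compile.config.view_replay_for_aliased_outputs", True)
def check_view_replay():
    def f(a):
        return a.view(a.shape[0], -1)  # output aliases the input

    compiled_f = aot_function(f, fw_compiler=nop)
    x = torch.ones(2, 3, requires_grad=True)
    out = compiled_f(x)
    # Inspect how the aliased output was regenerated.
    print(type(out.grad_fn))


check_view_replay()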

View File

@@ -257,6 +257,19 @@ intentionally_not_handled = {
     "resize_": {b8, f16, f32, f64, i32, i64},
     "resize_as_": {b8, f16, f32, f64, i32, i64},
 }
+# This is only fixed when this config is set
+# We should eventually always turn it on
+import torch._functorch.config as functorch_config
+
+if not functorch_config.view_replay_for_aliased_outputs:
+    intentionally_not_handled['("as_strided", "partial_views")'] = {
+        b8,
+        f16,
+        f32,
+        f64,
+        i32,
+        i64,
+    }

 inductor_expected_failures_single_sample["cuda"].update(intentionally_not_handled)
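
This hunk registers ("as_strided", "partial_views") as an expected failure only while the new flag is off. Here is the same gating pattern in isolation, as a minimal sketch; b8/f16/etc. appear to be the test file's short dtype aliases, so plain torch dtypes are used instead:

# Minimal sketch, not from the diff: conditionally marking expected failures
# based on the functorch config flag.
import torch
import torch._functorch.config as functorch_config

expected_failures: dict = {}
if not functorch_config.view_replay_for_aliased_outputs:
    # Only expected to fail while view replay is disabled.
    expected_failures['("as_strided", "partial_views")'] = {
        torch.bool, torch.float16, torch.float32,
        torch.float64, torch.int32, torch.int64,
    }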

View File

@@ -18,6 +18,7 @@ from torch.utils._python_dispatch import (
     is_traceable_wrapper_subclass,
     transform_subclass,
 )

+from .. import config

 aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph")
@@ -219,7 +220,7 @@ def gen_alias_from_base(
     # In summary, we use the fact that FunctionalTensorWrapper saves the view
     # functions applied to itself (collected during functionalization) so as
     # to replay them (view functions) on the aliased_base_tensor.
-    if target_functional_tensor is not None:
+    if config.view_replay_for_aliased_outputs and target_functional_tensor is not None:
         from .schemas import FunctionalTensorMetadataEq

         assert isinstance(target_functional_tensor, FunctionalTensorMetadataEq)
@@ -237,11 +238,10 @@ def gen_alias_from_base(
             #
             # In order for this to work, we should have a way to replace those
             # symbolic shapes with concrete numbers.
-            aot_joint_log.warning(
+            aot_joint_log.info(
                 "could not reconstruct view by re-applying a ViewMeta sequence. "
-                "This error is possibly caused by dynamic shapes. "
                 "Fallbacking to reconstruction using as_strided. "
-                "Error message: %s",
+                "Reason: %s",
                 str(e),
             )
         else:
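
With the gate in place, gen_alias_from_base only attempts ViewMeta replay when the flag is on, and it now logs at info level (rather than warning) before falling back to as_strided. A minimal sketch contrasting the two strategies; regenerate_alias is a hypothetical stand-in, and the view() call stands in for replaying the saved ViewMeta sequence:

# Minimal sketch, not from the diff: the two alias-regeneration strategies.
import torch
import torch._functorch.config as functorch_config


def regenerate_alias(base: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    if functorch_config.view_replay_for_aliased_outputs:
        # Preferred path: re-apply real view ops, keeping cheap view backward
        # formulas such as ViewBackward0.
        return base.view(target.shape)
    # Fallback path: always applicable, but AsStridedBackward is slower in
    # training than the individual view backward formulas.
    return base.as_strided(
        target.size(), target.stride(), target.storage_offset()
    )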

View File

@@ -41,6 +41,26 @@ static_weight_shapes = True
 # Applies CSE to the graph before partitioning
 cse = True

+# When AOTAutograd regenerates aliased graph outputs,
+# attempt to use functionalization's view-replay logic
+# before falling back to the autograd engine's view replay or as_strided.
+# This can have some perf implications
+# (although for many models this will not matter).
+# (1) If you have many view ops chained together, replaying all of them
+#     at runtime can have more overhead compared to a single as_strided call
+# (2) If you are doing training, AsStridedBackward is quite slow,
+#     and the individual view op backward formulas will likely be faster.
+# (3) Some backends like XLA do not support as_strided
+#
+# Temporary hack: disable this flag for internal
+# (needed to fix an internal issue while avoiding bumping XLA pin)
+# eventually: either default this config to false completely
+# once XLA pin update works,
+# or default config to true and fix relevant bugs
+from torch._inductor.config import is_fbcode
+
+view_replay_for_aliased_outputs = not is_fbcode()
+
 # Restricts the amount of computation AOTAutograd can do.
 # NB: We have essentially disabled this heuristic now. However, this is kept
 # here for now in case it's useful. Setting it low can artificially reduce the
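
Since the new flag is a plain module-level attribute of torch._functorch.config, it can also be flipped at runtime. A minimal sketch of toggling it (not part of the diff):

# Minimal sketch: inspecting and toggling the new flag.
import torch._functorch.config as functorch_config

print(functorch_config.view_replay_for_aliased_outputs)   # not is_fbcode() by default
functorch_config.view_replay_for_aliased_outputs = True   # prefer view replay
# ... compile and run the model here ...
functorch_config.view_replay_for_aliased_outputs = False  # fall back to as_strided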