Test reland "AOTAutograd: gate view-replay behind config, not the def… (#124948)
A parallel attempt at landing https://github.com/pytorch/pytorch/pull/124945, but attempting to land through fbcode first Pull Request resolved: https://github.com/pytorch/pytorch/pull/124948 Approved by: https://github.com/albanD
commit fc2aa23c1e (parent fc13c1c850)
@@ -3261,6 +3261,7 @@ def forward(self, tangents_1):
         return lambda f: aot_function(f, fw_compiler=lambda g, _: partial(wrapper, g))

+    @patch("functorch.compile.config.view_replay_for_aliased_outputs", True)
     def test_output_aliases_input_view_meta_replay(self):
         @self._compile_and_erase_bases(0)
         def f(a):
@@ -3274,6 +3275,7 @@ def forward(self, tangents_1):
             str(out.grad_fn.__class__), """<class 'ViewBackward0'>"""
         )

+    @patch("functorch.compile.config.view_replay_for_aliased_outputs", True)
     def test_output_aliases_intermediate_view_meta_replay(self):
         @self._compile_and_erase_bases(0, 1)
         def f(a):
@@ -3293,6 +3295,7 @@ def forward(self, tangents_1):
             str(out2.grad_fn.__class__), """<class 'ViewBackward0'>"""
         )

+    @patch("functorch.compile.config.view_replay_for_aliased_outputs", True)
     def test_output_aliases_output_view_meta_replay(self):
         @self._compile_and_erase_bases(1)
         def f(a):
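For reference, the three tests above opt back in to view replay with unittest.mock.patch, since the new default depends on is_fbcode(). Below is a minimal sketch of that pattern; it is not part of the diff, and it patches the torch._functorch.config module directly rather than through the functorch.compile.config alias the tests use.

from unittest.mock import patch

import torch._functorch.config as functorch_config

@patch("torch._functorch.config.view_replay_for_aliased_outputs", True)
def check_flag_is_forced_on():
    # While the patch is active, AOTAutograd's alias-regeneration path is
    # allowed to use functionalization's ViewMeta replay instead of as_strided,
    # regardless of the is_fbcode()-based default.
    assert functorch_config.view_replay_for_aliased_outputs

check_flag_is_forced_on()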
@@ -257,6 +257,19 @@ intentionally_not_handled = {
     "resize_": {b8, f16, f32, f64, i32, i64},
     "resize_as_": {b8, f16, f32, f64, i32, i64},
 }
+# This is only fixed when this config is set
+# We should eventually always turn it on
+import torch._functorch.config as functorch_config
+
+if not functorch_config.view_replay_for_aliased_outputs:
+    intentionally_not_handled['("as_strided", "partial_views")'] = {
+        b8,
+        f16,
+        f32,
+        f64,
+        i32,
+        i64,
+    }

 inductor_expected_failures_single_sample["cuda"].update(intentionally_not_handled)
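The hunk above only registers the as_strided partial_views expected failure when the flag is off. A toy sketch of that conditional-expected-failure pattern follows; the dict names and dtype sets here are placeholders, not the real test-harness tables.

import torch
import torch._functorch.config as functorch_config

# op name (or '("op", "variant")' key) -> set of dtypes expected to fail
intentionally_not_handled = {"resize_": {torch.float32, torch.float64}}

if not functorch_config.view_replay_for_aliased_outputs:
    # Without view replay, the partial_views variant of as_strided is expected to fail.
    intentionally_not_handled['("as_strided", "partial_views")'] = {torch.float32}

expected_failures = {"cuda": {}}
expected_failures["cuda"].update(intentionally_not_handled)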
@@ -18,6 +18,7 @@ from torch.utils._python_dispatch import (
     is_traceable_wrapper_subclass,
     transform_subclass,
 )
+from .. import config

 aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph")

@@ -219,7 +220,7 @@ def gen_alias_from_base(
     # In summary, we use the fact that FunctionalTensorWrapper saves the view
     # functions applied to itself (collected during functionalization) so as
     # to replay them (view functions) on the aliased_base_tensor.
-    if target_functional_tensor is not None:
+    if config.view_replay_for_aliased_outputs and target_functional_tensor is not None:
         from .schemas import FunctionalTensorMetadataEq

         assert isinstance(target_functional_tensor, FunctionalTensorMetadataEq)

@@ -237,11 +238,10 @@ def gen_alias_from_base(
             #
             # In order for this to work, we should have a way to replace those
             # symbolic shapes with concrete numbers.
-            aot_joint_log.warning(
+            aot_joint_log.info(
                 "could not reconstruct view by re-applying a ViewMeta sequence. "
-                "This error is possibly caused by dynamic shapes. "
                 "Fallbacking to reconstruction using as_strided. "
-                "Error message: %s",
+                "Reason: %s",
                 str(e),
             )
         else:
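To make the new gate easier to follow, here is a simplified, hypothetical sketch of the control flow in gen_alias_from_base after this change: view replay is only attempted when the config flag is on and a ViewMeta target was recorded, and any failure falls through to a single as_strided call. replay_views and regenerate_alias are illustrative stand-ins, not the real helpers, and the sketch assumes the flag added by this commit exists in torch._functorch.config.

import torch
from torch._functorch import config

def replay_views(base, view_fns):
    # Hypothetical stand-in for FunctionalTensorWrapper's ViewMeta replay:
    # re-apply each recorded view function to the base tensor.
    out = base
    for fn in view_fns:
        out = fn(out)
    return out

def regenerate_alias(base, view_fns, size, stride, storage_offset):
    # Prefer replaying the recorded views when the config flag is on and a
    # ViewMeta sequence was captured for this output (the diff's new gate).
    if config.view_replay_for_aliased_outputs and view_fns is not None:
        try:
            return replay_views(base, view_fns)
        except Exception as e:
            # Mirrors the diff: replay can fail (e.g. under dynamic shapes),
            # so note it and fall through to the as_strided reconstruction.
            print(f"could not reconstruct view by re-applying a ViewMeta sequence. Reason: {e}")
    # Fallback: rebuild the alias with a single as_strided call.
    return base.as_strided(size, stride, storage_offset)

# Toy usage: regenerate base[0] either by replaying the view or via as_strided.
base = torch.arange(6.0).reshape(2, 3)
alias = regenerate_alias(base, [lambda t: t.select(0, 0)], (3,), (1,), 0)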
@@ -41,6 +41,26 @@ static_weight_shapes = True
 # Applies CSE to the graph before partitioning
 cse = True

+# When AOTAutograd regenerates aliased graph outputs,
+# attempt to use functionalization's view-replay logic
+# before falling back to the autograd engine's view replay or as_strided.
+# This can have some perf implications
+# (although for many models this will not matter).
+# (1) If you have many view ops chained together, replaying all of them
+# at runtime can have more overhead compared to a single as_strided call
+# (2) If you are doing training, AsStridedBackward is quite slow,
+# and the individual view op backward formulas will likely be faster.
+# (3) Some backends like XLA do not support as_strided
+
+# Temporary hack: disable this flag for internal
+# (needed to fix an internal issue while avoiding bumping XLA pin)
+# eventually: either default this config to false completely
+# once XLA pin update works,
+# or default config to true and fix relevant bugs
+from torch._inductor.config import is_fbcode
+
+view_replay_for_aliased_outputs = not is_fbcode()
+
 # Restricts the amount of computation AOTAutograd can do.
 # NB: We have essentially disabled this heuristic now. However, this is kept
 # here for now in case it's useful. Setting it low can artificially reduce the
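As a usage note (not part of the diff), the new flag is a plain module-level config in torch._functorch.config, so after this change callers can read or override it directly: a backend without as_strided support (the XLA-style case in the comment above) would want it on, while code sensitive to per-view replay overhead may turn it off.

import torch._functorch.config as functorch_config

# Default after this commit: enabled everywhere except fbcode builds.
print(functorch_config.view_replay_for_aliased_outputs)

# Opt in explicitly (e.g. for an as_strided-free backend), or set False to opt out.
functorch_config.view_replay_for_aliased_outputs = True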