Enable reverse view_funcs by default for python subclasses (#116512)

Part 3 of the implementation for general [subclass view fake-ification](https://docs.google.com/document/d/1C5taWiplmX7nKiURXDOAZG2W5VNJ2iV0fQFq92H0Cxw).

Changes codegen to generate `view_func()` / `rev_view_func()` by default for Python subclasses. Since `view_func()` now exists more often, the lazy view rebase logic [here](f10c3f4184/torch/csrc/autograd/variable.cpp (L665-L695)) causes some slight behavior changes for in-place ops on views:
* Additional view nodes are inserted into output graphs, changing their string representation even though the graphs are functionally the same. The extra nodes are removed by AOTAutograd's DCE pass.
* When `t` is a `FunctionalTensor`, accessing `t.grad_fn` can now invoke `view_func()`; we need to make sure a `FunctionalTensorMode` is active so those view op calls succeed (see the sketch below).
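
For context, the new `.grad_fn` access pattern added in the `run_functionalized_fw_and_collect_metadata` hunks below boils down to the following minimal sketch; `o` is only a placeholder for a functionalized forward output:

```python
import torch
from torch import Tensor
from torch._subclasses.functional_tensor import FunctionalTensorMode

def get_grad_fn_under_functional_mode(o):
    # `o` is a placeholder for a (possibly FunctionalTensor) forward output.
    grad_fn = None
    if isinstance(o, Tensor):
        # Accessing .grad_fn may trigger a lazy view rebase that replays view
        # ops on FunctionalTensors, so a FunctionalTensorMode must be active.
        with FunctionalTensorMode():
            grad_fn = o.grad_fn
    # Both custom C++ (CppFunction) and Python (BackwardCFunction) autograd
    # functions count as results of a custom autograd fn.
    is_custom = (
        type(grad_fn).__name__ == "CppFunction"
        or isinstance(grad_fn, torch.autograd.function.BackwardCFunction)
    )
    return grad_fn, is_custom
```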
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116512
Approved by: https://github.com/bdhirsh, https://github.com/soulitzer
ghstack dependencies: #115894
Joel Schlosser 2024-01-04 17:41:42 -05:00 committed by PyTorch MergeBot
parent 3c21264c9b
commit 7956ca16e6
5 changed files with 45 additions and 26 deletions


@@ -1736,8 +1736,8 @@ def forward(self, primals_1):
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = add = None
as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_3); as_strided_2 = as_strided_3 = None
as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_2 = as_strided_5 = None
return [as_strided_scatter, add_1]""") # noqa: B950
def test_input_mutation_aliases_other_input2(self):
@@ -1762,8 +1762,8 @@ def forward(self, primals_1):
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = add = None
as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0)
add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_3); as_strided_2 = as_strided_3 = None
as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0)
add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_2 = as_strided_5 = None
return [as_strided_scatter, add_1]""") # noqa: B950
def test_input_mutation_aliases_and_output_alias(self):
@@ -1786,8 +1786,8 @@ def forward(self, primals_1):
as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
as_strided_6 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
view_1 = torch.ops.aten.view.default(as_strided_6, [4]); as_strided_6 = None
as_strided_8 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
view_1 = torch.ops.aten.view.default(as_strided_8, [4]); as_strided_8 = None
return [as_strided_scatter, view_1]""") # noqa: B950
def test_input_aliased_with_mutation_output_alias(self):
@@ -1816,8 +1816,8 @@ def forward(self, primals_1, primals_2):
mul = torch.ops.aten.mul.Tensor(as_strided_1, 2); as_strided_1 = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None
add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
as_strided_6 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
view_1 = torch.ops.aten.view.default(as_strided_6, [-1]); as_strided_6 = None
as_strided_7 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
view_1 = torch.ops.aten.view.default(as_strided_7, [-1]); as_strided_7 = None
return [as_strided_scatter, add, view_1]""") # noqa: B950
def test_input_metadata_mutation_aliases(self):
@@ -1907,11 +1907,11 @@ def forward(self, primals_1, primals_2, primals_3):
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3); primals_2 = primals_3 = None
as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
unsqueeze_1 = torch.ops.aten.unsqueeze.default(as_strided_3, 0); as_strided_3 = None
as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
unsqueeze_1 = torch.ops.aten.unsqueeze.default(as_strided_5, 0); as_strided_5 = None
add_2 = torch.ops.aten.add.Tensor(add_1, unsqueeze_1); add_1 = None
as_strided_11 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
view_2 = torch.ops.aten.view.default(as_strided_11, [-1]); as_strided_11 = None
as_strided_14 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
view_2 = torch.ops.aten.view.default(as_strided_14, [-1]); as_strided_14 = None
return [as_strided_scatter, add_2, view_2, unsqueeze_1]""") # noqa: B950
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
@@ -1975,8 +1975,8 @@ def forward(self, primals_1, primals_2):
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None
as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
t = torch.ops.aten.t.default(view); view = None
as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
add = torch.ops.aten.add.Tensor(as_strided_3, as_strided_2); as_strided_3 = as_strided_2 = None
as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
add = torch.ops.aten.add.Tensor(as_strided_5, as_strided_2); as_strided_5 = as_strided_2 = None
view_1 = torch.ops.aten.view.default(add, [-1])
t_1 = torch.ops.aten.t.default(t)
unsqueeze = torch.ops.aten.unsqueeze.default(view_1, 0)


@@ -1779,12 +1779,15 @@ def forward(self, arg0_1, arg1_1, arg2_1):
# Make a non-leaf
x = torch.randn(2, requires_grad=True) + 1
fx_g = make_fx(f_functionalized)(x)
# NB: view_1 below is expected (though unused) due to view replay. AOTAutograd runs a
# DCE pass that will remove nodes like this later on.
self.assertExpectedInline(fx_g.code.strip(), """\
def forward(self, x_1):
view = torch.ops.aten.view.default(x_1, [-1])
mul = torch.ops.aten.mul.Tensor(x_1, 2); x_1 = None
view_1 = torch.ops.aten.view.default(mul, [-1]); mul = None
add = torch.ops.aten.add.Tensor(view_1, 1); view_1 = None
view_1 = torch.ops.aten.view.default(mul, [-1])
view_2 = torch.ops.aten.view.default(mul, [-1]); mul = None
add = torch.ops.aten.add.Tensor(view_2, 1); view_2 = None
return add""")
def test_python_functionalization_zero_tensor(self):
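
The DCE behavior referenced in the comment above can be reproduced in isolation with torch.fx's dead-code elimination; this is only a minimal illustration, not AOTAutograd's actual pass:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    y = x * 2
    _dead = y.view(-1)      # unused view node, analogous to view_1 above
    return y.view(-1) + 1

gm = make_fx(f)(torch.randn(2))
gm.graph.eliminate_dead_code()  # drops the unused aten.view node
gm.recompile()
print(gm.code)
```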


@@ -173,7 +173,9 @@ SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE = CodeT
"""\
std::function<at::Tensor(const at::Tensor&)> func=nullptr;
std::function<at::Tensor(const at::Tensor&)> rev_func=nullptr;
if (${is_view_with_metadata_change} || !self.unsafeGetTensorImpl()->support_as_strided() ||
if (${is_view_with_metadata_change} ||
!self.unsafeGetTensorImpl()->support_as_strided() ||
self.unsafeGetTensorImpl()->is_python_dispatch() ||
c10::AutogradState::get_tls_state().get_view_replay_enabled()) {
${replay_view_func}
${reverse_replay_view_func}
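
Conceptually, `func` replays the original view op on a fresh base, and `rev_func` goes the other direction, from the view back to the base. A rough Python analogue for a `narrow()` view is sketched below; the helper names and the exact reverse semantics are illustrative assumptions, not the generated code:

```python
import torch

# Rough analogue of the replay pair set up above for a narrow() view. The
# generated C++ captures the real view metadata rather than hard-coding
# sizes/strides like this sketch does.
def replay_view(new_base: torch.Tensor) -> torch.Tensor:
    # Forward replay: re-apply the original view op to a different base.
    return new_base.narrow(0, 0, 2)

def reverse_replay_view(view: torch.Tensor) -> torch.Tensor:
    # Reverse replay (assumed semantics): map from the view back to a tensor
    # with the base's metadata over the same storage.
    return view.as_strided((4,), (1,), 0)

base = torch.arange(4.0)
v = replay_view(base)           # tensor([0., 1.]) -- a view of `base`
b = reverse_replay_view(v)      # tensor([0., 1., 2., 3.]) -- base layout again
```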


@@ -352,20 +352,31 @@ def run_functionalized_fw_and_collect_metadata(
and o is not curr
]
)
is_result_of_custom_autograd_fn = False
# See Note [Accessing .grad_fn on FunctionalTensor]
# In-place operations on views will trigger a lazy rebase of the autograd graph;
# this runs during access to the .grad_fn. The rebase logic will invoke view ops
# on FunctionalTensors, so we must enable a FunctionalTensorMode here to ensure
# these op calls succeed.
grad_fn = None
if isinstance(o, Tensor):
# Need to check for both custom cpp (CppFunction) and python (BackwardCFunction) autograd fns
if type(o.grad_fn).__name__ == "CppFunction":
is_result_of_custom_autograd_fn = True
if isinstance(o.grad_fn, torch.autograd.function.BackwardCFunction):
is_result_of_custom_autograd_fn = True
with FunctionalTensorMode():
grad_fn = o.grad_fn
is_result_of_custom_autograd_fn = False
# Need to check for both custom cpp (CppFunction) and python (BackwardCFunction)
# autograd fns
if type(grad_fn).__name__ == "CppFunction":
is_result_of_custom_autograd_fn = True
if isinstance(grad_fn, torch.autograd.function.BackwardCFunction):
is_result_of_custom_autograd_fn = True
if not isinstance(o, Tensor):
output_type = OutputType.non_alias
base_idx = None
elif (
curr_storage in inp_storage_refs
and o.grad_fn is not None
and grad_fn is not None
and is_result_of_custom_autograd_fn
):
output_type = OutputType.custom_function_view
@@ -384,7 +395,7 @@ def run_functionalized_fw_and_collect_metadata(
num_aliased_outs - num_multi_output_view_outs
)
if (
o.grad_fn is not None
grad_fn is not None
and num_aliased_outs_that_are_not_multi_output_views == 0
):
# See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]


@@ -12,6 +12,7 @@ from torch import Tensor
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import lazy_format_graph_code
from torch._logging import getArtifactLogger
from torch._subclasses.functional_tensor import FunctionalTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from .functional_utils import assert_functional_graph
@@ -28,7 +29,9 @@ aot_graphs_log = getArtifactLogger(__name__, "aot_graphs")
def _create_graph(f, args, *, aot_config: AOTConfig) -> torch.fx.GraphModule:
with enable_python_dispatcher():
# FunctionalTensorMode must be enabled here.
# See Note [Accessing .grad_fn on FunctionalTensor]
with enable_python_dispatcher(), FunctionalTensorMode(aot_config.pre_dispatch):
fx_g = make_fx(
f,
decomposition_table=aot_config.decompositions,