Enable reverse view_funcs by default for python subclasses (#116512)
Part 3 of the implementation of general [subclass view fake-ification](https://docs.google.com/document/d/1C5taWiplmX7nKiURXDOAZG2W5VNJ2iV0fQFq92H0Cxw).
Changes codegen to generate `view_func()` / `rev_view_func()` by default for python subclasses. Now that `view_func()` exists more often, the lazy view rebase logic [here](f10c3f4184/torch/csrc/autograd/variable.cpp (L665-L695)) causes some slight behavior changes for in-place ops on views:
* Additional view nodes are inserted into output graphs, changing their string representation, although the graphs are functionally the same. The extra nodes are removed in AOTAutograd's DCE pass.
* When `t` is a `FunctionalTensor`, accessing `t.grad_fn` will now invoke `view_func()`; we need to make sure we're operating in a `FunctionalTensorMode` so the view op calls succeed (see the sketch below).
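A minimal sketch of that pattern (the helper name `grad_fn_of` is hypothetical; the actual change lives in AOTAutograd's metadata collection, shown in the diff below):

```python
import torch
from torch._subclasses.functional_tensor import FunctionalTensorMode

def grad_fn_of(t):
    # Hypothetical helper: accessing .grad_fn may trigger the lazy view rebase,
    # which replays view ops. Run the access under FunctionalTensorMode so those
    # view calls succeed when `t` is a FunctionalTensor.
    if not isinstance(t, torch.Tensor):
        return None
    with FunctionalTensorMode():
        return t.grad_fn
```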
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116512
Approved by: https://github.com/bdhirsh, https://github.com/soulitzer
ghstack dependencies: #115894
Parent: 3c21264c9b
Commit: 7956ca16e6
@@ -1736,8 +1736,8 @@ def forward(self, primals_1):
 add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
 as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = add = None
 as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
-as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
-add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_3); as_strided_2 = as_strided_3 = None
+as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
+add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_2 = as_strided_5 = None
 return [as_strided_scatter, add_1]""") # noqa: B950

 def test_input_mutation_aliases_other_input2(self):
@@ -1762,8 +1762,8 @@ def forward(self, primals_1):
 add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
 as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = add = None
 as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
-as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0)
-add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_3); as_strided_2 = as_strided_3 = None
+as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0)
+add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_2 = as_strided_5 = None
 return [as_strided_scatter, add_1]""") # noqa: B950

 def test_input_mutation_aliases_and_output_alias(self):
@@ -1786,8 +1786,8 @@ def forward(self, primals_1):
 as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
 add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
 as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
-as_strided_6 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
-view_1 = torch.ops.aten.view.default(as_strided_6, [4]); as_strided_6 = None
+as_strided_8 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
+view_1 = torch.ops.aten.view.default(as_strided_8, [4]); as_strided_8 = None
 return [as_strided_scatter, view_1]""") # noqa: B950

 def test_input_aliased_with_mutation_output_alias(self):
@@ -1816,8 +1816,8 @@ def forward(self, primals_1, primals_2):
 mul = torch.ops.aten.mul.Tensor(as_strided_1, 2); as_strided_1 = None
 as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None
 add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
-as_strided_6 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
-view_1 = torch.ops.aten.view.default(as_strided_6, [-1]); as_strided_6 = None
+as_strided_7 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
+view_1 = torch.ops.aten.view.default(as_strided_7, [-1]); as_strided_7 = None
 return [as_strided_scatter, add, view_1]""") # noqa: B950

 def test_input_metadata_mutation_aliases(self):
@@ -1907,11 +1907,11 @@ def forward(self, primals_1, primals_2, primals_3):
 add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
 as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
 add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3); primals_2 = primals_3 = None
-as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
-unsqueeze_1 = torch.ops.aten.unsqueeze.default(as_strided_3, 0); as_strided_3 = None
+as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
+unsqueeze_1 = torch.ops.aten.unsqueeze.default(as_strided_5, 0); as_strided_5 = None
 add_2 = torch.ops.aten.add.Tensor(add_1, unsqueeze_1); add_1 = None
-as_strided_11 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
-view_2 = torch.ops.aten.view.default(as_strided_11, [-1]); as_strided_11 = None
+as_strided_14 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
+view_2 = torch.ops.aten.view.default(as_strided_14, [-1]); as_strided_14 = None
 return [as_strided_scatter, add_2, view_2, unsqueeze_1]""") # noqa: B950

 @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
@@ -1975,8 +1975,8 @@ def forward(self, primals_1, primals_2):
 as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None
 as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
 t = torch.ops.aten.t.default(view); view = None
-as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
-add = torch.ops.aten.add.Tensor(as_strided_3, as_strided_2); as_strided_3 = as_strided_2 = None
+as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
+add = torch.ops.aten.add.Tensor(as_strided_5, as_strided_2); as_strided_5 = as_strided_2 = None
 view_1 = torch.ops.aten.view.default(add, [-1])
 t_1 = torch.ops.aten.t.default(t)
 unsqueeze = torch.ops.aten.unsqueeze.default(view_1, 0)
@@ -1779,12 +1779,15 @@ def forward(self, arg0_1, arg1_1, arg2_1):
 # Make a non-leaf
 x = torch.randn(2, requires_grad=True) + 1
 fx_g = make_fx(f_functionalized)(x)
+# NB: view_1 below is expected (though unused) due to view replay. AOTAutograd runs a
+# DCE pass that will remove nodes like this later on.
 self.assertExpectedInline(fx_g.code.strip(), """\
 def forward(self, x_1):
 view = torch.ops.aten.view.default(x_1, [-1])
 mul = torch.ops.aten.mul.Tensor(x_1, 2); x_1 = None
-view_1 = torch.ops.aten.view.default(mul, [-1]); mul = None
-add = torch.ops.aten.add.Tensor(view_1, 1); view_1 = None
+view_1 = torch.ops.aten.view.default(mul, [-1])
+view_2 = torch.ops.aten.view.default(mul, [-1]); mul = None
+add = torch.ops.aten.add.Tensor(view_2, 1); view_2 = None
 return add""")

 def test_python_functionalization_zero_tensor(self):
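The NB comment in the hunk above refers to AOTAutograd's DCE pass dropping the unused `view_1` node. A hedged illustration of that kind of cleanup (not code from this PR), using `torch.fx`'s built-in `Graph.eliminate_dead_code()` on a `make_fx` trace:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    unused = x.view(-1)  # produces a view node with no users, analogous to view_1 above
    return x * 2

gm = make_fx(f)(torch.randn(2))
gm.graph.eliminate_dead_code()  # drops nodes whose results are never used
gm.recompile()
print(gm.code)  # the dead aten.view node no longer appears
```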
@@ -173,7 +173,9 @@ SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE = CodeT
 """\
 std::function<at::Tensor(const at::Tensor&)> func=nullptr;
 std::function<at::Tensor(const at::Tensor&)> rev_func=nullptr;
-if (${is_view_with_metadata_change} || !self.unsafeGetTensorImpl()->support_as_strided() ||
+if (${is_view_with_metadata_change} ||
+    !self.unsafeGetTensorImpl()->support_as_strided() ||
+    self.unsafeGetTensorImpl()->is_python_dispatch() ||
     c10::AutogradState::get_tls_state().get_view_replay_enabled()) {
   ${replay_view_func}
   ${reverse_replay_view_func}
@@ -352,20 +352,31 @@ def run_functionalized_fw_and_collect_metadata(
                 and o is not curr
             ]
         )
-        is_result_of_custom_autograd_fn = False
+
+        # See Note [Accessing .grad_fn on FunctionalTensor]
+        # In-place operations on views will trigger a lazy rebase of the autograd graph;
+        # this runs during access to the .grad_fn. The rebase logic will invoke view ops
+        # on FunctionalTensors, so we must enable a FunctionalTensorMode here to ensure
+        # these op calls succeed.
+        grad_fn = None
         if isinstance(o, Tensor):
-            # Need to check for both custom cpp (CppFunction) and python (BackwardCFunction) autograd fns
-            if type(o.grad_fn).__name__ == "CppFunction":
-                is_result_of_custom_autograd_fn = True
-            if isinstance(o.grad_fn, torch.autograd.function.BackwardCFunction):
-                is_result_of_custom_autograd_fn = True
+            with FunctionalTensorMode():
+                grad_fn = o.grad_fn
+
+        is_result_of_custom_autograd_fn = False
+        # Need to check for both custom cpp (CppFunction) and python (BackwardCFunction)
+        # autograd fns
+        if type(grad_fn).__name__ == "CppFunction":
+            is_result_of_custom_autograd_fn = True
+        if isinstance(grad_fn, torch.autograd.function.BackwardCFunction):
+            is_result_of_custom_autograd_fn = True

         if not isinstance(o, Tensor):
             output_type = OutputType.non_alias
             base_idx = None
         elif (
             curr_storage in inp_storage_refs
-            and o.grad_fn is not None
+            and grad_fn is not None
             and is_result_of_custom_autograd_fn
         ):
             output_type = OutputType.custom_function_view
@@ -384,7 +395,7 @@ def run_functionalized_fw_and_collect_metadata(
             num_aliased_outs - num_multi_output_view_outs
         )
         if (
-            o.grad_fn is not None
+            grad_fn is not None
             and num_aliased_outs_that_are_not_multi_output_views == 0
         ):
             # See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]
@@ -12,6 +12,7 @@ from torch import Tensor
 from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.utils import lazy_format_graph_code
 from torch._logging import getArtifactLogger
+from torch._subclasses.functional_tensor import FunctionalTensorMode
 from torch.fx.experimental.proxy_tensor import make_fx

 from .functional_utils import assert_functional_graph
@@ -28,7 +29,9 @@ aot_graphs_log = getArtifactLogger(__name__, "aot_graphs")


 def _create_graph(f, args, *, aot_config: AOTConfig) -> torch.fx.GraphModule:
-    with enable_python_dispatcher():
+    # FunctionalTensorMode must be enabled here.
+    # See Note [Accessing .grad_fn on FunctionalTensor]
+    with enable_python_dispatcher(), FunctionalTensorMode(aot_config.pre_dispatch):
         fx_g = make_fx(
             f,
             decomposition_table=aot_config.decompositions,