[aotd] Do not force contiguous() for channels_last (#135225)

Original Issue: https://github.com/pytorch/pytorch/issues/134644

During the initial tracing we assume traced_tangents have the same memory_format as the corresponding inputs, outputs, and intermediates.

The approach:

Tracing time:
- Store traced_tangent_memory_formats in the metadata.
- Coerce tangents to the deduced memory format.

Runtime:
- Coerce runtime tangents to the traced memory format recorded in the metadata.
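
At runtime this coercion is a no-op whenever the incoming tangent already matches the traced layout. A minimal sketch of the plain-tensor path (the function name is illustrative; the actual `AOTDispatchAutograd.coerce_runtime_tangent_tracing_memory_format` in the diff below also recurses into wrapper-subclass elements):

```
import torch

def coerce_runtime_tangent_to_traced_format(x, memory_format):
    # Simplified sketch: only copy when the runtime tangent's layout differs
    # from the memory format recorded at trace time.
    if not isinstance(x, torch.Tensor):
        return x
    if not x.is_contiguous(memory_format=memory_format):
        x = x.contiguous(memory_format=memory_format)
    return x

# A channels_last runtime tangent passes through untouched when the backward
# was traced with a channels_last tangent, so no extra contiguous() copy is made.
g = torch.randn(2, 3, 5, 5).to(memory_format=torch.channels_last)
assert coerce_runtime_tangent_to_traced_format(g, torch.channels_last) is g
```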

Subclass logic:
- Previously the tangent-coercion logic did not handle the nested-subclass case; this is fixed here.

For subclasses we deduce the memory format of the subclass tensor first, then of each of its elements:
[subclass_tensor_memory_format, subclass_tensor_elem0_memory_format, ...]

If a subclass element (one of the `__tensor_flatten__()[0]` tensors) is itself a subclass, its slot contains a nested list of the same structure.

The recursive traversal of the subclass tree is expensive, so we do memory-format deduction and coercion in the same pass (the `coerce_tangent_and_suggest_memory_format` method), keeping a single traversal. This is no regression compared to the previous logic, which also did one traversal.
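
A condensed sketch of this combined pass (illustration only; the helper name `coerce_and_suggest` is made up here, and the real `coerce_tangent_and_suggest_memory_format` in the diff below additionally detaches the tangent and handles `__coerce_tangent_metadata__`):

```
import torch
from torch._prims_common import suggest_memory_format
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

def coerce_and_suggest(x):
    # Simplified sketch: deduce the memory format and coerce in one recursive
    # traversal, returning (coerced_tangent, memory_format_structure, updated).
    if not isinstance(x, torch.Tensor):
        return x, None, False
    memory_format = suggest_memory_format(x)
    out = x.contiguous(memory_format=memory_format)
    updated = out is not x
    if not is_traceable_wrapper_subclass(out):
        return out, memory_format, updated
    # For subclasses the outer memory format comes first, followed by one entry
    # per inner tensor; a nested subclass contributes a nested list in its slot.
    formats = [memory_format]
    for attr in out.__tensor_flatten__()[0]:
        new_elem, elem_formats, elem_updated = coerce_and_suggest(getattr(out, attr))
        formats.append(elem_formats)
        if elem_updated:
            setattr(out, attr, new_elem)
    return out, formats, updated

# E.g. for TwoTensor(TwoTensor(a, b), c) with channels_last components the
# returned structure is roughly:
# [channels_last, [channels_last, channels_last, channels_last], channels_last]
```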

Other small change:
Remove a duplicated, unrelated comment.

Testing:

```
python test/functorch/test_aotdispatch.py -k test_channels_last_grads_no_force_contiguous
```

Benchmarking:
After change:
```
└─ $ PYTORCH_AOTD_DEBUG_PROFILE=1 python test/functorch/test_aotdispatch.py -k test_benchmark_grads_no_force_contiguous
Benchmark SUBCLASS avg_bwd_duration:4.059906005859375 ms
Benchmark NO_SUBCLASS avg_bwd_duration:3.1563830375671387 ms
```
Before change:
```
BEFORE_CHANGE SUBCLASS 4.1194
```

No significant change in processing time.

(We do a single traversal of the subclass tree to collect memory formats and coerce tangents during tracing.)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135225
Approved by: https://github.com/bdhirsh
IvanKobzarev 2024-09-26 09:52:00 -07:00 committed by PyTorch MergeBot
parent de159f0c8d
commit 34d788ffb0
7 changed files with 353 additions and 67 deletions

View File

@ -10,7 +10,7 @@ import copy
import itertools
import unittest
import warnings
from contextlib import nullcontext
from contextlib import ContextDecorator, nullcontext
from functools import partial, wraps
from typing import Any, Callable, Dict, List, Optional, Union
from unittest.mock import patch
@ -5746,6 +5746,62 @@ def forward(self, tangents_1, tangents_2):
self.assertEqual(ref_out2.requires_grad, out2.requires_grad)
class GradsNoForceContiguousContextManager(ContextDecorator):
def __enter__(self):
# flake8: noqa: TOR901
self.lib = torch.library.Library("_mylib", "FRAGMENT")
self.d = {
torch.channels_last: 0,
torch.contiguous_format: 0,
}
self.lib.define("foo(Tensor x) -> Tensor")
self.lib.define("foo2(Tensor x) -> Tensor")
def foo_impl(a):
return a.clone()
def foo_meta(a):
return a.clone()
def foo2_impl(x):
self.d[torch._prims_common.suggest_memory_format(x)] += 1
return x.clone()
def foo2_meta(a):
return a.clone()
for backend in ["CPU", "CUDA"]:
self.lib.impl("foo", foo_impl, backend)
self.lib.impl("foo2", foo2_impl, backend)
self.lib.impl("foo", foo_meta, "Meta")
self.lib.impl("foo2", foo2_meta, "Meta")
def foo_bwd(ctx, grad):
torch.ops._mylib.foo2(grad)
return grad.clone()
torch.library.register_autograd("_mylib::foo", foo_bwd, lib=self.lib)
from torch._higher_order_ops.effects import _EffectType, _register_effectful_op
_register_effectful_op(torch.ops._mylib.foo.default, _EffectType.ORDERED)
_register_effectful_op(torch.ops._mylib.foo2.default, _EffectType.ORDERED)
return self
def __exit__(self, type, value, tb):
self.lib._destroy()
return False
def reset_counters(self):
self.d = {
torch.channels_last: 0,
torch.contiguous_format: 0,
}
class TestAOTModuleSimplified(AOTTestCase):
def test_aot_module_simplified(self):
class MockModule(torch.nn.Module):
@ -5969,6 +6025,165 @@ class TestAOTModuleSimplified(AOTTestCase):
out = torch.compile(fn, backend="aot_eager", fullgraph=True)(inp)
self.assertEqual(ref_out, out)
# Next several tests are related to issue:
# https://github.com/pytorch/pytorch/issues/134644
# AOTD tries to predict tangents for tracing ahead of time.
# The first strategy was to coerce traced_tangents and runtime_tangents to be contiguous().
# But for models working in channels_last memory format this adds extra contiguous() calls.
# The fix is to predict the tangents' memory format to match the outputs' memory format
# and to coerce runtime tangents to that traced memory format.
def test_grads_no_force_contiguous_dense(self):
with GradsNoForceContiguousContextManager() as ctx:
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
def forward(self, x, y, cont_inp):
z = y + 3
y.mul_(2)
r = self.conv(x)
r = torch.ops._mylib.foo(r)
return (
r,
r.transpose(0, 1),
z.view(-1),
z.transpose(0, 1),
cont_inp * 2,
)
m = M()
m.to(memory_format=torch.channels_last)
m.train()
def dense_inps():
return (
torch.randn(2, 3, 5, 5, requires_grad=True).to(
memory_format=torch.channels_last
),
torch.randn(3, 2, 1, 1, requires_grad=True).to(
memory_format=torch.channels_last
),
torch.randn(3, 2, 1, 1, requires_grad=True),
)
ref_inps = dense_inps()
ref_outs = m(*ref_inps)
ref_outs[0].sum().backward()
ctx.reset_counters()
inps = dense_inps()
outs = torch.compile(m, backend="inductor", fullgraph=True)(*inps)
outs[0].sum().backward()
self.assertEqual(ctx.d[torch.channels_last], 1)
self.assertEqual(ctx.d[torch.contiguous_format], 0)
def test_grads_no_force_contiguous_subclass(self):
with GradsNoForceContiguousContextManager() as ctx:
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
def forward(self, x, y):
r = self.conv(x)
r = torch.ops._mylib.foo(r)
return r, y + 1
m = M()
m.to(memory_format=torch.channels_last)
m.train()
def inps_fn():
return (
TwoTensor(
torch.randn(2, 3, 5, 5, requires_grad=True).to(
memory_format=torch.channels_last
),
torch.randn(2, 3, 5, 5, requires_grad=True).to(
memory_format=torch.channels_last
),
),
torch.randn(3, 2, requires_grad=True).clone(),
)
ref_outs = m(*inps_fn())
ref_outs[0].sum().backward()
ctx.reset_counters()
mc = M()
mc.to(memory_format=torch.channels_last)
mc.train()
outs = torch.compile(mc, backend="aot_eager", fullgraph=True)(*inps_fn())
outs[0].sum().backward()
self.assertEqual(ctx.d[torch.channels_last], 2)
self.assertEqual(ctx.d[torch.contiguous_format], 0)
def test_grads_no_force_contiguous_nested_subclass(self):
with GradsNoForceContiguousContextManager() as ctx:
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
def forward(self, x):
r = self.conv(x)
r = torch.ops._mylib.foo(r)
return r
m = M()
m.to(memory_format=torch.channels_last)
m.train()
def inps_fn(x):
return (
TwoTensor(
TwoTensor(x.clone(), x.clone()), TwoTensor(x.clone(), x.clone())
),
)
x = torch.randn(2, 3, 5, 5, requires_grad=True).to(
memory_format=torch.channels_last
)
ref_inps = inps_fn(x)
ref_outs = m(*ref_inps)
ref_outs[0].sum().backward()
ctx.reset_counters()
mc = M()
mc.to(memory_format=torch.channels_last)
mc.train()
x = torch.randn(2, 3, 5, 5, requires_grad=True).to(
memory_format=torch.channels_last
)
inps = inps_fn(x)
outs = torch.compile(mc, backend="aot_eager", fullgraph=True)(*inps)
outs[0].sum().backward()
self.assertEqual(ctx.d[torch.channels_last], 4)
self.assertEqual(ctx.d[torch.contiguous_format], 0)
def test_grads_no_force_contiguous_nested_tensor_tangent(self):
# NestedTensor setattr could fail with AttributeError for attr "_min_seqlen_tensor"
# Adding test to verify that it is handled.
def fn(x):
return x.clone()
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64)
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64)
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64)
nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
out = torch.compile(fn, backend="aot_eager", fullgraph=True)(nt)
out_buffer = out.values()
ga, gb, gc = torch.autograd.grad(out_buffer.sum(), (a, b, c))
@torch._inductor.config.patch({"freezing": True})
def test_inductor_freezing_with_subclasses(self):
class M(torch.nn.Module):

View File

@ -56,19 +56,36 @@ log = logging.getLogger(__name__)
static_input_logger = getArtifactLogger("torch._dynamo", "cudagraph_static_inputs")
# Note [Tangents must be contiguous]
# We force tangents to be contiguous today.
# Note [Tangents memory format]
# We assume the tangents' memory format to be similar to the corresponding output's memory_format.
# The idea is that we are technically making a guess about the strides of our tangents,
# while we trace out the joint.
# Today, we force this guess to be correct by additionally calling contiguous()
# on all tangents at runtime.
# In the future, you could imagine lifting this restriction, since these contiguous()
# calls can have noticeable perf overhead depending on the model.
def coerce_tangent(x):
# If the runtime tangents do not have the same memory format as the predicted traced tangents,
# we coerce them at runtime to the traced tangents' memory format.
# Coercion and collection of the traced tangents' memory formats happen in one recursive traversal.
# mypy: ignore-errors
def coerce_tangent_and_suggest_memory_format(x: Tensor):
updated = False
if not isinstance(x, Tensor):
return x
out = x.detach().contiguous()
# Note [Tangents must be contiguous, Part 2]
return x, None, updated
out = x.detach()
suggest_memory_format = torch._prims_common.suggest_memory_format
is_subclass = is_traceable_wrapper_subclass(out)
memory_format = suggest_memory_format(out)
was = out
out = out.contiguous(memory_format=memory_format)
updated = out is not was
# For subclasses we keep the memory format of the outer strides at the beginning of the list
out_memory_format = [memory_format] if is_subclass else memory_format
# Note [Tangents memory format, Part 2]
# In the same way that "what strides do we assign to our tangents" is a question
# that we cannot answer (and therefore have to guess) as we trace the backward ahead-of-time,
# the same applies to any tensor subclass metadata, when we have tangents that are subclasses.
@ -87,20 +104,24 @@ def coerce_tangent(x):
# placement into one with a Shard() placement, in the case that we "guessed wrong",
# and traced tangents with a Shard() placement at compile time.
#
if is_traceable_wrapper_subclass(out) and hasattr(
out, "__coerce_tangent_metadata__"
):
if is_subclass and hasattr(out, "__coerce_tangent_metadata__"):
out = out.__coerce_tangent_metadata__() # type: ignore[attr-defined]
# It's possible to have a subclass that advertises as contiguous,
# but has noncontiguous inner tensors.
# Force these to be contiguous too
if is_traceable_wrapper_subclass(out):
for attr in out.__tensor_flatten__()[0]: # type: ignore[attr-defined]
if is_subclass:
attrs = out.__tensor_flatten__()[0]
for attr in attrs:
elem = getattr(out, attr)
if not elem.is_contiguous():
elem_contig = elem.contiguous()
setattr(out, attr, elem_contig)
return out
(
new_elem,
new_elem_memory_format,
elem_updated,
) = coerce_tangent_and_suggest_memory_format(elem)
out_memory_format.append(new_elem_memory_format)
if elem_updated:
setattr(out, attr, new_elem)
return out, out_memory_format, updated
# This is a version of functionalization that is specifically designed
@ -669,13 +690,15 @@ from a multi-output view call"
traced_tangents = pytree.tree_map(
view_avoid_dupes_with_primals, traced_tangents
)
# See Note [Tangents must be contiguous]
traced_tangents = pytree.tree_map(
coerce_tangent,
traced_tangents,
)
user_outs = pytree.tree_map(from_fun, f_output_tangents)
output_tangents_start_idx = len(f_input_tangents)
output_tangents_end_idx = output_tangents_start_idx + len(f_output_tangents)
tangents_and_memory_formats = [
coerce_tangent_and_suggest_memory_format(tt)
for i, tt in enumerate(traced_tangents)
]
traced_tangents = [t[0] for t in tangents_and_memory_formats]
traced_tangent_memory_formats = [t[1] for t in tangents_and_memory_formats]
nonlocal static_input_indices
static_input_indices = static_input_indices or []
if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
@ -736,6 +759,7 @@ from a multi-output view call"
num_intermediate_bases=len(intermediate_bases),
keep_input_mutations=keep_input_mutations,
traced_tangents=traced_tangents,
traced_tangent_memory_formats=traced_tangent_memory_formats,
subclass_inp_meta=create_subclass_meta(flat_args),
subclass_fw_graph_out_meta=create_subclass_meta(fw_graph_outs),
subclass_tangent_meta=create_subclass_meta(traced_tangents),

View File

@ -20,7 +20,7 @@ from torch._subclasses.functional_tensor import FunctionalTensor
from torch.fx.experimental.symbolic_shapes import is_concrete_int
from .. import config
from .collect_metadata_analysis import coerce_tangent
from .collect_metadata_analysis import coerce_tangent_and_suggest_memory_format
from .schemas import (
BackwardSignature,
GraphSignature,
@ -49,12 +49,16 @@ def remove_dupe_metadata(
other_traced_tangents = m.traced_tangents[num_data_mutations:]
inp_traced_tangents = m.traced_tangents[:num_data_mutations]
filtered_inp_traced_tangents = [
# See Note [Tangents must be contiguous]
# See Note [Tangents memory format]
x
for i, x in enumerate(inp_traced_tangents)
if keep_arg_mask[m.mutated_inp_runtime_indices[i]]
]
traced_tangents = filtered_inp_traced_tangents + other_traced_tangents
assert m.traced_tangent_memory_formats is not None
traced_tangent_memory_formats = [torch.contiguous_format] * len(
filtered_inp_traced_tangents
) + m.traced_tangent_memory_formats[num_data_mutations:]
return ViewAndMutationMeta(
input_info=[x for i, x in enumerate(m.input_info) if keep_arg_mask[i]],
@ -74,6 +78,7 @@ def remove_dupe_metadata(
num_intermediate_bases=m.num_intermediate_bases,
keep_input_mutations=m.keep_input_mutations,
traced_tangents=traced_tangents,
traced_tangent_memory_formats=traced_tangent_memory_formats,
# We are guaranteed not to get here, since dupes are not supported today with subclass inputs.
subclass_inp_meta=[],
subclass_fw_graph_out_meta=[],
@ -82,18 +87,6 @@ def remove_dupe_metadata(
)
# Given our ViewAndMutation metadata, this fn constructs a new set of metadata,
# after adding synthetic base arguments to the function.
# Most of the work in this fn is slogging through all of the metadata corresponding to inputs,
# and updating it with our synthetic base calling convention.
#
# When config.debug_assert is set, we automatically regenerate the metadata
# and compare it to this output for sanity.
#
# In addition to the updated metadata, also return the list of input indices
# that will need to be updated in the synthetic base epilogue
# Given our ViewAndMutation metadata, this fn constructs a new set of metadata,
# after adding synthetic base arguments to the function.
# Most of the work in this fn is slogging through all of the metadata corresponding to inputs,
@ -235,18 +228,27 @@ def create_synthetic_base_metadata(
)
)
inner_mutated_tangents = [
# See Note [Tangents must be contiguous]
coerce_tangent(x)
inner_mutated_tangents_and_memory_formats = [
# See Note [Tangents memory format]
coerce_tangent_and_suggest_memory_format(x)
for inner_idx, x in enumerate(inner_args)
if input_infos[inner_idx].mutates_data and input_infos[inner_idx].requires_grad
]
inner_mutated_tangents = [x[0] for x in inner_mutated_tangents_and_memory_formats]
inner_mutated_tangents_memory_formats = [
x[1] for x in inner_mutated_tangents_and_memory_formats
]
output_info = existing_output_infos + input_metadata_output_info
# Regenerate traced tangents to include mutated inputs including synthetic bases
traced_tangents = (
inner_mutated_tangents + m.traced_tangents[len(inner_mutated_tangents) :]
)
assert m.traced_tangent_memory_formats is not None
traced_tangent_memory_formats = (
inner_mutated_tangents_memory_formats
+ m.traced_tangent_memory_formats[len(inner_mutated_tangents) :]
)
return (
ViewAndMutationMeta(
@ -255,6 +257,7 @@ def create_synthetic_base_metadata(
num_intermediate_bases=m.num_intermediate_bases,
keep_input_mutations=m.keep_input_mutations,
traced_tangents=traced_tangents,
traced_tangent_memory_formats=traced_tangent_memory_formats,
# We are guaranteed not to get here, since synthetic_base codepaths are not supported today with subclass inputs.
subclass_inp_meta=[],
subclass_fw_graph_out_meta=[],

View File

@ -1438,19 +1438,29 @@ class AutogradLazyBackwardCompileInfo:
# No need to make it into an actual CompilerWrapper because it doesn't fit the abstract as cleanly
class AOTDispatchAutograd:
@staticmethod
def _force_contiguous(x):
def coerce_runtime_tangent_tracing_memory_format(x, memory_format):
if not isinstance(x, torch.Tensor):
return x
x = x.contiguous()
if not is_traceable_wrapper_subclass(x):
is_subclass: bool = is_traceable_wrapper_subclass(x)
mem_format = memory_format[0] if is_subclass else memory_format
if not x.is_contiguous(memory_format=mem_format):
x = x.contiguous(memory_format=mem_format)
if not is_subclass:
return x
for attr in x.__tensor_flatten__()[0]: # type: ignore[attr-defined]
for i, attr in enumerate(x.__tensor_flatten__()[0]): # type: ignore[attr-defined]
elem = getattr(x, attr)
if not elem.is_contiguous():
setattr(x, attr, elem.contiguous())
new_elem = AOTDispatchAutograd.coerce_runtime_tangent_tracing_memory_format(
elem, memory_format[1 + i]
)
if new_elem is not elem:
setattr(x, attr, new_elem)
return x
# See Note [Tangents must be contiguous, Part 2]
# See Note [Tangents memory format, Part 2]
@staticmethod
def coerce_runtime_tangent(x, metadata):
if not isinstance(x, torch.Tensor):
@ -1865,19 +1875,41 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
)
for i, t in enumerate(all_args)
]
# Coerce the tangents' memory format before unwrapping tensor subclasses,
# as we have to coerce the subclass tangent first and then its tensor attributes.
assert (
CompiledFunction.metadata.traced_tangent_memory_formats
is not None
)
all_args = [
(
AOTDispatchAutograd.coerce_runtime_tangent_tracing_memory_format(
t,
CompiledFunction.metadata.traced_tangent_memory_formats[
i - tangents_start_idx
],
)
if tangents_start_idx <= i < tangents_end_idx
else t
)
for i, t in enumerate(all_args)
]
all_args = unwrap_tensor_subclasses(
all_args, is_joint_structure=False
)
tangents_start_idx = (
len(all_args) - len_tangents - len(rng_args) - len(bw_tokens)
else:
assert (
CompiledFunction.metadata.traced_tangent_memory_formats
is not None
)
tangents_end_idx = tangents_start_idx + len_tangents
# Make the tangents contiguous. Note that we must do this after subclass desugaring
# because inputs to inductor have to be contiguous
all_args = [
(
AOTDispatchAutograd._force_contiguous(t)
AOTDispatchAutograd.coerce_runtime_tangent_tracing_memory_format(
t,
CompiledFunction.metadata.traced_tangent_memory_formats[
i - tangents_start_idx
],
)
if (tangents_start_idx <= i < tangents_end_idx)
else t
)

View File

@ -29,6 +29,7 @@ from .utils import strict_zip
zip = strict_zip
OutputType = Enum(
"OutputType",
(
@ -313,6 +314,14 @@ class ViewAndMutationMeta:
# This list is generated after calling make_runtime_safe().
traced_tangent_metas: Optional[List[Any]] = None
# For each tangent at index i:
# if the tangent is a plain tensor, traced_tangent_memory_formats[i] holds the memory format
# that we need to coerce the tangent to;
# if the tangent is a subclass, traced_tangent_memory_formats[i] holds a list of memory formats,
# containing the expected memory format of the subclass **and** all of its inner tensors
TANGENT_MEMORY_FORMAT = Union[torch.memory_format, List["TANGENT_MEMORY_FORMAT"]]
traced_tangent_memory_formats: Optional[List[TANGENT_MEMORY_FORMAT]] = None
num_symints_saved_for_bw: Optional[int] = None
# The grad_enabled mutation that will be emitted in the runtime_wrapper epilogue

View File

@ -290,6 +290,7 @@ def create_metadata_for_subclass(meta: ViewAndMutationMeta) -> ViewAndMutationMe
num_intermediate_bases = None
keep_input_mutations = meta.keep_input_mutations
traced_tangents = None
traced_tangent_memory_formats = None
subclass_inp_meta = None
subclass_fw_graph_out_meta = None
subclass_tangent_meta = None
@ -300,6 +301,7 @@ def create_metadata_for_subclass(meta: ViewAndMutationMeta) -> ViewAndMutationMe
num_intermediate_bases=num_intermediate_bases, # type: ignore[arg-type]
keep_input_mutations=keep_input_mutations, # type: ignore[arg-type]
traced_tangents=traced_tangents, # type: ignore[arg-type]
traced_tangent_memory_formats=traced_tangent_memory_formats,
subclass_inp_meta=subclass_inp_meta, # type: ignore[arg-type]
subclass_fw_graph_out_meta=subclass_fw_graph_out_meta, # type: ignore[arg-type]
subclass_tangent_meta=subclass_tangent_meta, # type: ignore[arg-type]

View File

@ -678,6 +678,7 @@ def _create_aot_dispatcher_function(
num_intermediate_bases=fw_metadata.num_intermediate_bases,
keep_input_mutations=aot_config.keep_inference_input_mutations,
traced_tangents=fw_metadata.traced_tangents,
traced_tangent_memory_formats=fw_metadata.traced_tangent_memory_formats,
subclass_inp_meta=fw_metadata.subclass_inp_meta,
subclass_fw_graph_out_meta=fw_metadata.subclass_fw_graph_out_meta,
subclass_tangent_meta=fw_metadata.subclass_tangent_meta,