mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[aotd] Do not force contiguous() for channels_last (#135225)
Original Issue: https://github.com/pytorch/pytorch/issues/134644 We assume trace_tangents to have the same memory_format as inputs, outputs, intermediate during first tracing. => Tracing time: - Store trace_tangents_memory_formats in metadata - Coerce tangents to deduced memory_format Runtime: - Coerce tangents to tracing memory format from metadata Subclasses logic: - Previously coercing tangents logic did not handle nested subclasses case, fixing this. For Subclasses we deduce memory format for subclass_tensor first, then for each element of subclass: [subclass_tensor_memory_format, subclass_tensor_elem0_memory_format, ... ] If subclass element (__tensor_flatten__[0] tensors) is also subclass => on its place we will have a nested list of the same structure. The recursive traversal of subclass tree is expensive. So we do memory format deduction and coercing at the same time, to keep only one traverse for this. With this approach there is no regression in comparison with previous logic which also does one traversal. (`coerce_tangent_and_suggest_memory_format` method). Other small change: Remove duplicated not-related comment. Testing ``` python test/functorch/test_aotdispatch.py -k test_channels_last_grads_no_force_contiguous ``` Benchmarking: After change: ``` └─ $ PYTORCH_AOTD_DEBUG_PROFILE=1 python test/functorch/test_aotdispatch.py -k test_benchmark_grads_no_force_contiguous Benchmark SUBCLASS avg_bwd_duration:4.059906005859375 ms Benchmark NO_SUBCLASS avg_bwd_duration:3.1563830375671387 ms ``` Before change: ``` BEFORE_CHANGE SUBCLASS 4.1194 ``` No siginificant changes in processing time. (We do single traverse of subclass tree for collecting memory_formats and coercing during tracing.) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135225 Approved by: https://github.com/bdhirsh
This commit is contained in:
parent
de159f0c8d
commit
34d788ffb0
|
|
@ -10,7 +10,7 @@ import copy
|
|||
import itertools
|
||||
import unittest
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from contextlib import ContextDecorator, nullcontext
|
||||
from functools import partial, wraps
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from unittest.mock import patch
|
||||
|
|
@ -5746,6 +5746,62 @@ def forward(self, tangents_1, tangents_2):
|
|||
self.assertEqual(ref_out2.requires_grad, out2.requires_grad)
|
||||
|
||||
|
||||
class GradsNoForceContiguousContextManager(ContextDecorator):
|
||||
def __enter__(self):
|
||||
# flake8: noqa: TOR901
|
||||
self.lib = torch.library.Library("_mylib", "FRAGMENT")
|
||||
self.d = {
|
||||
torch.channels_last: 0,
|
||||
torch.contiguous_format: 0,
|
||||
}
|
||||
|
||||
self.lib.define("foo(Tensor x) -> Tensor")
|
||||
self.lib.define("foo2(Tensor x) -> Tensor")
|
||||
|
||||
def foo_impl(a):
|
||||
return a.clone()
|
||||
|
||||
def foo_meta(a):
|
||||
return a.clone()
|
||||
|
||||
def foo2_impl(x):
|
||||
self.d[torch._prims_common.suggest_memory_format(x)] += 1
|
||||
return x.clone()
|
||||
|
||||
def foo2_meta(a):
|
||||
return a.clone()
|
||||
|
||||
for backend in ["CPU", "CUDA"]:
|
||||
self.lib.impl("foo", foo_impl, backend)
|
||||
self.lib.impl("foo2", foo2_impl, backend)
|
||||
|
||||
self.lib.impl("foo", foo_meta, "Meta")
|
||||
self.lib.impl("foo2", foo2_meta, "Meta")
|
||||
|
||||
def foo_bwd(ctx, grad):
|
||||
torch.ops._mylib.foo2(grad)
|
||||
return grad.clone()
|
||||
|
||||
torch.library.register_autograd("_mylib::foo", foo_bwd, lib=self.lib)
|
||||
|
||||
from torch._higher_order_ops.effects import _EffectType, _register_effectful_op
|
||||
|
||||
_register_effectful_op(torch.ops._mylib.foo.default, _EffectType.ORDERED)
|
||||
_register_effectful_op(torch.ops._mylib.foo2.default, _EffectType.ORDERED)
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, tb):
|
||||
self.lib._destroy()
|
||||
return False
|
||||
|
||||
def reset_counters(self):
|
||||
self.d = {
|
||||
torch.channels_last: 0,
|
||||
torch.contiguous_format: 0,
|
||||
}
|
||||
|
||||
|
||||
class TestAOTModuleSimplified(AOTTestCase):
|
||||
def test_aot_module_simplified(self):
|
||||
class MockModule(torch.nn.Module):
|
||||
|
|
@ -5969,6 +6025,165 @@ class TestAOTModuleSimplified(AOTTestCase):
|
|||
out = torch.compile(fn, backend="aot_eager", fullgraph=True)(inp)
|
||||
self.assertEqual(ref_out, out)
|
||||
|
||||
# Next several tests are related to issue:
|
||||
# https://github.com/pytorch/pytorch/issues/134644
|
||||
# AOTD tries to predict tangents for tracing ahead of time.
|
||||
# The first strategy was to coerce traced_tangents and runtime_tangents to be contiguous().
|
||||
# But for models working in channels_last memory format this will add additional contiguous() calls.
|
||||
# The fix is predicting tangents memory format to be similar to outputs memory format.
|
||||
# And coerce runtime tangents to that traced memory format.
|
||||
def test_grads_no_force_contiguous_dense(self):
|
||||
with GradsNoForceContiguousContextManager() as ctx:
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(3, 3, 3)
|
||||
|
||||
def forward(self, x, y, cont_inp):
|
||||
z = y + 3
|
||||
y.mul_(2)
|
||||
r = self.conv(x)
|
||||
r = torch.ops._mylib.foo(r)
|
||||
return (
|
||||
r,
|
||||
r.transpose(0, 1),
|
||||
z.view(-1),
|
||||
z.transpose(0, 1),
|
||||
cont_inp * 2,
|
||||
)
|
||||
|
||||
m = M()
|
||||
m.to(memory_format=torch.channels_last)
|
||||
m.train()
|
||||
|
||||
def dense_inps():
|
||||
return (
|
||||
torch.randn(2, 3, 5, 5, requires_grad=True).to(
|
||||
memory_format=torch.channels_last
|
||||
),
|
||||
torch.randn(3, 2, 1, 1, requires_grad=True).to(
|
||||
memory_format=torch.channels_last
|
||||
),
|
||||
torch.randn(3, 2, 1, 1, requires_grad=True),
|
||||
)
|
||||
|
||||
ref_inps = dense_inps()
|
||||
ref_outs = m(*ref_inps)
|
||||
ref_outs[0].sum().backward()
|
||||
|
||||
ctx.reset_counters()
|
||||
inps = dense_inps()
|
||||
outs = torch.compile(m, backend="inductor", fullgraph=True)(*inps)
|
||||
outs[0].sum().backward()
|
||||
|
||||
self.assertEqual(ctx.d[torch.channels_last], 1)
|
||||
self.assertEqual(ctx.d[torch.contiguous_format], 0)
|
||||
|
||||
def test_grads_no_force_contiguous_subclass(self):
|
||||
with GradsNoForceContiguousContextManager() as ctx:
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(3, 3, 3)
|
||||
|
||||
def forward(self, x, y):
|
||||
r = self.conv(x)
|
||||
r = torch.ops._mylib.foo(r)
|
||||
return r, y + 1
|
||||
|
||||
m = M()
|
||||
m.to(memory_format=torch.channels_last)
|
||||
m.train()
|
||||
|
||||
def inps_fn():
|
||||
return (
|
||||
TwoTensor(
|
||||
torch.randn(2, 3, 5, 5, requires_grad=True).to(
|
||||
memory_format=torch.channels_last
|
||||
),
|
||||
torch.randn(2, 3, 5, 5, requires_grad=True).to(
|
||||
memory_format=torch.channels_last
|
||||
),
|
||||
),
|
||||
torch.randn(3, 2, requires_grad=True).clone(),
|
||||
)
|
||||
|
||||
ref_outs = m(*inps_fn())
|
||||
ref_outs[0].sum().backward()
|
||||
|
||||
ctx.reset_counters()
|
||||
mc = M()
|
||||
mc.to(memory_format=torch.channels_last)
|
||||
mc.train()
|
||||
outs = torch.compile(mc, backend="aot_eager", fullgraph=True)(*inps_fn())
|
||||
outs[0].sum().backward()
|
||||
|
||||
self.assertEqual(ctx.d[torch.channels_last], 2)
|
||||
self.assertEqual(ctx.d[torch.contiguous_format], 0)
|
||||
|
||||
def test_grads_no_force_contiguous_nested_subclass(self):
|
||||
with GradsNoForceContiguousContextManager() as ctx:
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(3, 3, 3)
|
||||
|
||||
def forward(self, x):
|
||||
r = self.conv(x)
|
||||
r = torch.ops._mylib.foo(r)
|
||||
return r
|
||||
|
||||
m = M()
|
||||
m.to(memory_format=torch.channels_last)
|
||||
m.train()
|
||||
|
||||
def inps_fn(x):
|
||||
return (
|
||||
TwoTensor(
|
||||
TwoTensor(x.clone(), x.clone()), TwoTensor(x.clone(), x.clone())
|
||||
),
|
||||
)
|
||||
|
||||
x = torch.randn(2, 3, 5, 5, requires_grad=True).to(
|
||||
memory_format=torch.channels_last
|
||||
)
|
||||
ref_inps = inps_fn(x)
|
||||
ref_outs = m(*ref_inps)
|
||||
ref_outs[0].sum().backward()
|
||||
|
||||
ctx.reset_counters()
|
||||
|
||||
mc = M()
|
||||
mc.to(memory_format=torch.channels_last)
|
||||
mc.train()
|
||||
|
||||
x = torch.randn(2, 3, 5, 5, requires_grad=True).to(
|
||||
memory_format=torch.channels_last
|
||||
)
|
||||
inps = inps_fn(x)
|
||||
outs = torch.compile(mc, backend="aot_eager", fullgraph=True)(*inps)
|
||||
outs[0].sum().backward()
|
||||
self.assertEqual(ctx.d[torch.channels_last], 4)
|
||||
self.assertEqual(ctx.d[torch.contiguous_format], 0)
|
||||
|
||||
def test_grads_no_force_contiguous_nested_tensor_tangent(self):
|
||||
# NestedTensor setattr could fails with AttributeError for attr "_min_seqlen_tensor"
|
||||
# Adding test to verify that it is handled.
|
||||
def fn(x):
|
||||
return x.clone()
|
||||
|
||||
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64)
|
||||
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64)
|
||||
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64)
|
||||
nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
|
||||
|
||||
out = torch.compile(fn, backend="aot_eager", fullgraph=True)(nt)
|
||||
out_buffer = out.values()
|
||||
ga, gb, gc = torch.autograd.grad(out_buffer.sum(), (a, b, c))
|
||||
|
||||
@torch._inductor.config.patch({"freezing": True})
|
||||
def test_inductor_freezing_with_subclasses(self):
|
||||
class M(torch.nn.Module):
|
||||
|
|
|
|||
|
|
@ -56,19 +56,36 @@ log = logging.getLogger(__name__)
|
|||
static_input_logger = getArtifactLogger("torch._dynamo", "cudagraph_static_inputs")
|
||||
|
||||
|
||||
# Note [Tangents must be contiguous]
|
||||
# We force tangents to be contiguous today.
|
||||
# Note [Tangents memory format]
|
||||
# We assume tangents memory format to be similar to corresponding output's memory_format.
|
||||
# The idea is that we are technically making a guess about the strides of our tangents,
|
||||
# while we trace out the joint.
|
||||
# Today, we force this guess to be correct by additioanlly calling contiguous()
|
||||
# on all tangents at runtime.
|
||||
# In the future, you could imagine lifting this restriction, since these contiguous()
|
||||
# calls can have noticeable perf overhead depending on the model.
|
||||
def coerce_tangent(x):
|
||||
# If runtime specfied tangents will not have the same memory format as predicted traced tangents,
|
||||
# we coerce them at runtime to traced tangents memory format.
|
||||
|
||||
|
||||
# Coercing and collecting traced tangents memory format in one recursive traversal
|
||||
# mypy: ignore-errors
|
||||
def coerce_tangent_and_suggest_memory_format(x: Tensor):
|
||||
updated = False
|
||||
if not isinstance(x, Tensor):
|
||||
return x
|
||||
out = x.detach().contiguous()
|
||||
# Note [Tangents must be contiguous, Part 2]
|
||||
return x, None, updated
|
||||
|
||||
out = x.detach()
|
||||
|
||||
suggest_memory_format = torch._prims_common.suggest_memory_format
|
||||
is_subclass = is_traceable_wrapper_subclass(out)
|
||||
|
||||
memory_format = suggest_memory_format(out)
|
||||
|
||||
was = out
|
||||
out = out.contiguous(memory_format=memory_format)
|
||||
updated = out is not was
|
||||
|
||||
# For subclass we keep memory format of outer strides at the beggining of the list
|
||||
out_memory_format = [memory_format] if is_subclass else memory_format
|
||||
|
||||
# Note [Tangents memory format, Part 2]
|
||||
# In the same way that "what strides do we assigns to our tangents" is a question
|
||||
# that we can not answer (and therefore have to guess) as we trace the backward ahead-of-time,
|
||||
# The same applies to any tensor subclass metadata, when we have tangents that are subclasses.
|
||||
|
|
@ -87,20 +104,24 @@ def coerce_tangent(x):
|
|||
# placement into one with a Shard() placement, in the case that we "guessed wrong",
|
||||
# and traced tangents with a Shard() placement at compile time.
|
||||
#
|
||||
if is_traceable_wrapper_subclass(out) and hasattr(
|
||||
out, "__coerce_tangent_metadata__"
|
||||
):
|
||||
if is_subclass and hasattr(out, "__coerce_tangent_metadata__"):
|
||||
out = out.__coerce_tangent_metadata__() # type: ignore[attr-defined]
|
||||
# It's possible to have a subclass that advertises as contiguous,
|
||||
# but has noncontiguous inner tensors.
|
||||
# Force these to be conntiguous too
|
||||
if is_traceable_wrapper_subclass(out):
|
||||
for attr in out.__tensor_flatten__()[0]: # type: ignore[attr-defined]
|
||||
|
||||
if is_subclass:
|
||||
attrs = out.__tensor_flatten__()[0]
|
||||
|
||||
for attr in attrs:
|
||||
elem = getattr(out, attr)
|
||||
if not elem.is_contiguous():
|
||||
elem_contig = elem.contiguous()
|
||||
setattr(out, attr, elem_contig)
|
||||
return out
|
||||
(
|
||||
new_elem,
|
||||
new_elem_memory_format,
|
||||
elem_updated,
|
||||
) = coerce_tangent_and_suggest_memory_format(elem)
|
||||
out_memory_format.append(new_elem_memory_format)
|
||||
if elem_updated:
|
||||
setattr(out, attr, new_elem)
|
||||
|
||||
return out, out_memory_format, updated
|
||||
|
||||
|
||||
# This is a version of functionalization that is specifically designed
|
||||
|
|
@ -669,13 +690,15 @@ from a multi-output view call"
|
|||
traced_tangents = pytree.tree_map(
|
||||
view_avoid_dupes_with_primals, traced_tangents
|
||||
)
|
||||
# See Note [Tangents must be contiguous]
|
||||
traced_tangents = pytree.tree_map(
|
||||
coerce_tangent,
|
||||
traced_tangents,
|
||||
)
|
||||
user_outs = pytree.tree_map(from_fun, f_output_tangents)
|
||||
|
||||
output_tangents_start_idx = len(f_input_tangents)
|
||||
output_tangents_end_idx = output_tangents_start_idx + len(f_output_tangents)
|
||||
tangents_and_memory_formats = [
|
||||
coerce_tangent_and_suggest_memory_format(tt)
|
||||
for i, tt in enumerate(traced_tangents)
|
||||
]
|
||||
traced_tangents = [t[0] for t in tangents_and_memory_formats]
|
||||
traced_tangent_memory_formats = [t[1] for t in tangents_and_memory_formats]
|
||||
nonlocal static_input_indices
|
||||
static_input_indices = static_input_indices or []
|
||||
if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
|
||||
|
|
@ -736,6 +759,7 @@ from a multi-output view call"
|
|||
num_intermediate_bases=len(intermediate_bases),
|
||||
keep_input_mutations=keep_input_mutations,
|
||||
traced_tangents=traced_tangents,
|
||||
traced_tangent_memory_formats=traced_tangent_memory_formats,
|
||||
subclass_inp_meta=create_subclass_meta(flat_args),
|
||||
subclass_fw_graph_out_meta=create_subclass_meta(fw_graph_outs),
|
||||
subclass_tangent_meta=create_subclass_meta(traced_tangents),
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from torch._subclasses.functional_tensor import FunctionalTensor
|
|||
from torch.fx.experimental.symbolic_shapes import is_concrete_int
|
||||
|
||||
from .. import config
|
||||
from .collect_metadata_analysis import coerce_tangent
|
||||
from .collect_metadata_analysis import coerce_tangent_and_suggest_memory_format
|
||||
from .schemas import (
|
||||
BackwardSignature,
|
||||
GraphSignature,
|
||||
|
|
@ -49,12 +49,16 @@ def remove_dupe_metadata(
|
|||
other_traced_tangents = m.traced_tangents[num_data_mutations:]
|
||||
inp_traced_tangents = m.traced_tangents[:num_data_mutations]
|
||||
filtered_inp_traced_tangents = [
|
||||
# See Note [Tangents must be contiguous]
|
||||
# See Note [Tangents memory format]
|
||||
x
|
||||
for i, x in enumerate(inp_traced_tangents)
|
||||
if keep_arg_mask[m.mutated_inp_runtime_indices[i]]
|
||||
]
|
||||
traced_tangents = filtered_inp_traced_tangents + other_traced_tangents
|
||||
assert m.traced_tangent_memory_formats is not None
|
||||
traced_tangent_memory_formats = [torch.contiguous_format] * len(
|
||||
filtered_inp_traced_tangents
|
||||
) + m.traced_tangent_memory_formats[num_data_mutations:]
|
||||
|
||||
return ViewAndMutationMeta(
|
||||
input_info=[x for i, x in enumerate(m.input_info) if keep_arg_mask[i]],
|
||||
|
|
@ -74,6 +78,7 @@ def remove_dupe_metadata(
|
|||
num_intermediate_bases=m.num_intermediate_bases,
|
||||
keep_input_mutations=m.keep_input_mutations,
|
||||
traced_tangents=traced_tangents,
|
||||
traced_tangent_memory_formats=traced_tangent_memory_formats,
|
||||
# We are guaranteed not to get here, since dupes are not supported today with subclass inputs.
|
||||
subclass_inp_meta=[],
|
||||
subclass_fw_graph_out_meta=[],
|
||||
|
|
@ -82,18 +87,6 @@ def remove_dupe_metadata(
|
|||
)
|
||||
|
||||
|
||||
# Given our ViewAndMutation metadata, this fn constructs a new set of metadata,
|
||||
# after adding synthetic base arguments to the function.
|
||||
# Most of the work in this fn is slogging through all of the metadata corresponding to inputs,
|
||||
# and updating it with our synthetic base calling convention.
|
||||
#
|
||||
# When config.debug_assert is set, we automatically regenerate the metadata
|
||||
# and compare it to this output for sanity.
|
||||
#
|
||||
# In addition to the updated metadata, also return the list of input indices
|
||||
# that will need to be updated in the synthetic base epilogue
|
||||
|
||||
|
||||
# Given our ViewAndMutation metadata, this fn constructs a new set of metadata,
|
||||
# after adding synthetic base arguments to the function.
|
||||
# Most of the work in this fn is slogging through all of the metadata corresponding to inputs,
|
||||
|
|
@ -235,18 +228,27 @@ def create_synthetic_base_metadata(
|
|||
)
|
||||
)
|
||||
|
||||
inner_mutated_tangents = [
|
||||
# See Note [Tangents must be contiguous]
|
||||
coerce_tangent(x)
|
||||
inner_mutated_tangents_and_memory_formats = [
|
||||
# See Note [Tangents memory format]
|
||||
coerce_tangent_and_suggest_memory_format(x)
|
||||
for inner_idx, x in enumerate(inner_args)
|
||||
if input_infos[inner_idx].mutates_data and input_infos[inner_idx].requires_grad
|
||||
]
|
||||
inner_mutated_tangents = [x[0] for x in inner_mutated_tangents_and_memory_formats]
|
||||
inner_mutated_tangents_memory_formats = [
|
||||
x[1] for x in inner_mutated_tangents_and_memory_formats
|
||||
]
|
||||
|
||||
output_info = existing_output_infos + input_metadata_output_info
|
||||
# Regenerate traced tangents to include mutated inputs including synthetic bases
|
||||
traced_tangents = (
|
||||
inner_mutated_tangents + m.traced_tangents[len(inner_mutated_tangents) :]
|
||||
)
|
||||
assert m.traced_tangent_memory_formats is not None
|
||||
traced_tangent_memory_formats = (
|
||||
inner_mutated_tangents_memory_formats
|
||||
+ m.traced_tangent_memory_formats[len(inner_mutated_tangents) :]
|
||||
)
|
||||
|
||||
return (
|
||||
ViewAndMutationMeta(
|
||||
|
|
@ -255,6 +257,7 @@ def create_synthetic_base_metadata(
|
|||
num_intermediate_bases=m.num_intermediate_bases,
|
||||
keep_input_mutations=m.keep_input_mutations,
|
||||
traced_tangents=traced_tangents,
|
||||
traced_tangent_memory_formats=traced_tangent_memory_formats,
|
||||
# We are guaranteed not to get here, since synthetic_base codepaths are not supported today with subclass inputs.
|
||||
subclass_inp_meta=[],
|
||||
subclass_fw_graph_out_meta=[],
|
||||
|
|
|
|||
|
|
@ -1438,19 +1438,29 @@ class AutogradLazyBackwardCompileInfo:
|
|||
# No need to make it into an actual CompilerWrapper because it doesn't fit the abstract as cleanly
|
||||
class AOTDispatchAutograd:
|
||||
@staticmethod
|
||||
def _force_contiguous(x):
|
||||
def coerce_runtime_tangent_tracing_memory_format(x, memory_format):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return x
|
||||
x = x.contiguous()
|
||||
if not is_traceable_wrapper_subclass(x):
|
||||
|
||||
is_subclass: bool = is_traceable_wrapper_subclass(x)
|
||||
mem_format = memory_format[0] if is_subclass else memory_format
|
||||
|
||||
if not x.is_contiguous(memory_format=mem_format):
|
||||
x = x.contiguous(memory_format=mem_format)
|
||||
|
||||
if not is_subclass:
|
||||
return x
|
||||
for attr in x.__tensor_flatten__()[0]: # type: ignore[attr-defined]
|
||||
for i, attr in enumerate(x.__tensor_flatten__()[0]): # type: ignore[attr-defined]
|
||||
elem = getattr(x, attr)
|
||||
if not elem.is_contiguous():
|
||||
setattr(x, attr, elem.contiguous())
|
||||
new_elem = AOTDispatchAutograd.coerce_runtime_tangent_tracing_memory_format(
|
||||
elem, memory_format[1 + i]
|
||||
)
|
||||
if new_elem is not elem:
|
||||
setattr(x, attr, new_elem)
|
||||
|
||||
return x
|
||||
|
||||
# See Note [Tangents must be contiguous, Part 2]
|
||||
# See Note [Tangents memory format, Part 2]
|
||||
@staticmethod
|
||||
def coerce_runtime_tangent(x, metadata):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
|
|
@ -1865,24 +1875,46 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
|
|||
)
|
||||
for i, t in enumerate(all_args)
|
||||
]
|
||||
# Coercing tangents memory format before unwrapping tensor subclasses,
|
||||
# As we have to coerce Subclass tangent first and then its Tensor attributes.
|
||||
assert (
|
||||
CompiledFunction.metadata.traced_tangent_memory_formats
|
||||
is not None
|
||||
)
|
||||
all_args = [
|
||||
(
|
||||
AOTDispatchAutograd.coerce_runtime_tangent_tracing_memory_format(
|
||||
t,
|
||||
CompiledFunction.metadata.traced_tangent_memory_formats[
|
||||
i - tangents_start_idx
|
||||
],
|
||||
)
|
||||
if tangents_start_idx <= i < tangents_end_idx
|
||||
else t
|
||||
)
|
||||
for i, t in enumerate(all_args)
|
||||
]
|
||||
all_args = unwrap_tensor_subclasses(
|
||||
all_args, is_joint_structure=False
|
||||
)
|
||||
tangents_start_idx = (
|
||||
len(all_args) - len_tangents - len(rng_args) - len(bw_tokens)
|
||||
else:
|
||||
assert (
|
||||
CompiledFunction.metadata.traced_tangent_memory_formats
|
||||
is not None
|
||||
)
|
||||
tangents_end_idx = tangents_start_idx + len_tangents
|
||||
|
||||
# Make the tangents contiguous. Note that we must do this after subclass desugaring
|
||||
# because inputs to inductor have to be contiguous
|
||||
all_args = [
|
||||
(
|
||||
AOTDispatchAutograd._force_contiguous(t)
|
||||
if (tangents_start_idx <= i < tangents_end_idx)
|
||||
else t
|
||||
)
|
||||
for i, t in enumerate(all_args)
|
||||
]
|
||||
all_args = [
|
||||
(
|
||||
AOTDispatchAutograd.coerce_runtime_tangent_tracing_memory_format(
|
||||
t,
|
||||
CompiledFunction.metadata.traced_tangent_memory_formats[
|
||||
i - tangents_start_idx
|
||||
],
|
||||
)
|
||||
if (tangents_start_idx <= i < tangents_end_idx)
|
||||
else t
|
||||
)
|
||||
for i, t in enumerate(all_args)
|
||||
]
|
||||
|
||||
def call_compiled_backward():
|
||||
if ctx._is_compiled_autograd_tracing():
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from .utils import strict_zip
|
|||
|
||||
zip = strict_zip
|
||||
|
||||
|
||||
OutputType = Enum(
|
||||
"OutputType",
|
||||
(
|
||||
|
|
@ -313,6 +314,14 @@ class ViewAndMutationMeta:
|
|||
# This list is generated after calling make_runtime_safe().
|
||||
traced_tangent_metas: Optional[List[Any]] = None
|
||||
|
||||
# for each tangent at index i:
|
||||
# if the tangent is a plain tensor, traced_tangent_memory_formats[i] holds the memory format
|
||||
# of the tangent that we need to coerce to
|
||||
# if the tangent is a subclass, traced_tangent_memory_formats[i] holds a list of memory formats,
|
||||
# containing the expected memory format of the subclass **and** all of its inner tensors
|
||||
TANGENT_MEMORY_FORMAT = Union[torch.memory_format, List["TANGENT_MEMORY_FORMAT"]]
|
||||
traced_tangent_memory_formats: Optional[List[TANGENT_MEMORY_FORMAT]] = None
|
||||
|
||||
num_symints_saved_for_bw: Optional[int] = None
|
||||
|
||||
# The grad_enabled mutation that will be emitted in the runtime_wrapper epilogue
|
||||
|
|
|
|||
|
|
@ -290,6 +290,7 @@ def create_metadata_for_subclass(meta: ViewAndMutationMeta) -> ViewAndMutationMe
|
|||
num_intermediate_bases = None
|
||||
keep_input_mutations = meta.keep_input_mutations
|
||||
traced_tangents = None
|
||||
traced_tangent_memory_formats = None
|
||||
subclass_inp_meta = None
|
||||
subclass_fw_graph_out_meta = None
|
||||
subclass_tangent_meta = None
|
||||
|
|
@ -300,6 +301,7 @@ def create_metadata_for_subclass(meta: ViewAndMutationMeta) -> ViewAndMutationMe
|
|||
num_intermediate_bases=num_intermediate_bases, # type: ignore[arg-type]
|
||||
keep_input_mutations=keep_input_mutations, # type: ignore[arg-type]
|
||||
traced_tangents=traced_tangents, # type: ignore[arg-type]
|
||||
traced_tangent_memory_formats=traced_tangent_memory_formats,
|
||||
subclass_inp_meta=subclass_inp_meta, # type: ignore[arg-type]
|
||||
subclass_fw_graph_out_meta=subclass_fw_graph_out_meta, # type: ignore[arg-type]
|
||||
subclass_tangent_meta=subclass_tangent_meta, # type: ignore[arg-type]
|
||||
|
|
|
|||
|
|
@ -678,6 +678,7 @@ def _create_aot_dispatcher_function(
|
|||
num_intermediate_bases=fw_metadata.num_intermediate_bases,
|
||||
keep_input_mutations=aot_config.keep_inference_input_mutations,
|
||||
traced_tangents=fw_metadata.traced_tangents,
|
||||
traced_tangent_memory_formats=fw_metadata.traced_tangent_memory_formats,
|
||||
subclass_inp_meta=fw_metadata.subclass_inp_meta,
|
||||
subclass_fw_graph_out_meta=fw_metadata.subclass_fw_graph_out_meta,
|
||||
subclass_tangent_meta=fw_metadata.subclass_tangent_meta,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user