From 2c4bc65366a2cfce00fbddb2efa19e7b337e9b60 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Thu, 20 Mar 2025 04:01:28 -0700 Subject: [PATCH] [aotd] Guess tangents stride as output strides (#144579) AOTDispatch doing AOT backward graph preparation does not know real tangents that user will specify when runs backward. AOTD guesses the tangents. Before - we guessed that memory format of tangents will be as memory format of corresponding outputs. And if specified tangents at runtime are not the same memory format as we guessed during compilation, AOTD does coercion (copy) to guessed memory_format But as Horace found, there are popular use cases, where the outputs of compiled region will be in specific memory_format. E.g. in 4D tensor transposing dims 1 and 2. https://github.com/karpathy/nanoGPT/blob/master/model.py#L57 This PR changes the logic, that AOTD expects the same "strideness" of tangents as outputs. As a result it will avoid coercion for the case of transposed dims. Limitations: We keep guessing memory_format for: 1/ Dynamic shapes (needs more changes) 2/ Tensor subclasses (needs more changes) Other changes: test_torchinductor was always creating contiguous tangents via `torch.randn()`, changing them to be `torch.randn_like()` to compare computation with the same strideness. (E.g. for cuda float16 strideness affects numerics for fft ops). Pull Request resolved: https://github.com/pytorch/pytorch/pull/144579 Approved by: https://github.com/bdhirsh --- test/dynamo/test_backward_higher_order_ops.py | 42 ++--- test/functorch/test_aotdispatch.py | 174 ++++++++++++++++-- test/inductor/test_torchinductor.py | 2 +- test/inductor/test_torchinductor_opinfo.py | 27 +++ .../collect_metadata_analysis.py | 11 +- .../_aot_autograd/input_output_analysis.py | 5 +- .../_aot_autograd/runtime_wrappers.py | 42 ++++- torch/_functorch/_aot_autograd/schemas.py | 44 ++++- .../_aot_autograd/subclass_utils.py | 6 +- torch/nested/_internal/nested_tensor.py | 8 + .../_internal/common_methods_invocations.py | 20 +- 11 files changed, 322 insertions(+), 59 deletions(-) diff --git a/test/dynamo/test_backward_higher_order_ops.py b/test/dynamo/test_backward_higher_order_ops.py index 4eae0c877e2..52baa610dcf 100644 --- a/test/dynamo/test_backward_higher_order_ops.py +++ b/test/dynamo/test_backward_higher_order_ops.py @@ -131,23 +131,23 @@ class _multiply_invoke(torch.nn.Module): actual, """\ class GraphModule(torch.nn.Module): - def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(s0)"): + def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(2)"): l_inputs_ = L_inputs_ l_sizes_0_ = L_sizes_0_ - getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None + getitem: "f32[2]" = l_inputs_[0]; l_inputs_ = None validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [l_sizes_0_], False)]); getitem = l_sizes_0_ = None - getitem_9: "f32[s0]" = validate_outputs[0]; validate_outputs = None + getitem_9: "f32[2]" = validate_outputs[0]; validate_outputs = None call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None - aot1_tangents_1: "f32[s0]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None + aot1_tangents_1: "f32[2]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None - new_grad: "f32[s0]" = torch.clone(aot1_tangents_1) + new_grad: "f32[2]" = torch.clone(aot1_tangents_1) - result: "f32[s0]" = aot1_tangents_1 * aot1_tangents_1; aot1_tangents_1 = None + result: "f32[2]" = aot1_tangents_1 * aot1_tangents_1; aot1_tangents_1 = None - new_grad_1: "f32[s0]" = torch.clone(result); result = None + new_grad_1: "f32[2]" = torch.clone(result); result = None return (new_grad, new_grad_1) """, ) @@ -156,23 +156,23 @@ class GraphModule(torch.nn.Module): actual, """\ class GraphModule(torch.nn.Module): - def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(s0)"): + def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(2)"): l_inputs_ = L_inputs_ l_sizes_0_ = L_sizes_0_ - getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None + getitem: "f32[2]" = l_inputs_[0]; l_inputs_ = None validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [l_sizes_0_], False)]); getitem = l_sizes_0_ = None - getitem_9: "f32[s0]" = validate_outputs[0]; validate_outputs = None + getitem_9: "f32[2]" = validate_outputs[0]; validate_outputs = None call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None - aot3_tangents_1: "f32[s0]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None + aot3_tangents_1: "f32[2]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None - new_grad: "f32[s0]" = torch.clone(aot3_tangents_1) + new_grad: "f32[2]" = torch.clone(aot3_tangents_1) - result: "f32[s0]" = aot3_tangents_1 * aot3_tangents_1; aot3_tangents_1 = None + result: "f32[2]" = aot3_tangents_1 * aot3_tangents_1; aot3_tangents_1 = None - new_grad_1: "f32[s0]" = torch.clone(result); result = None + new_grad_1: "f32[2]" = torch.clone(result); result = None return (new_grad, new_grad_1) """, ) @@ -233,26 +233,26 @@ class GraphModule(torch.nn.Module): actual, """\ class GraphModule(torch.nn.Module): - def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(s0)", L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s7)"): + def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(2)", L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s7)"): l_inputs_ = L_inputs_ l_sizes_0_ = L_sizes_0_ l_hooks_1_keywords_fn_keywords_obj_counter = L_hooks_1_keywords_fn_keywords_obj_counter - getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None + getitem: "f32[2]" = l_inputs_[0]; l_inputs_ = None validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [l_sizes_0_], False)]); getitem = l_sizes_0_ = None - getitem_9: "f32[s0]" = validate_outputs[0]; validate_outputs = None + getitem_9: "f32[2]" = validate_outputs[0]; validate_outputs = None call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None - aot0_tangents_1: "f32[s0]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None + aot0_tangents_1: "f32[2]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None - new_grad: "f32[s0]" = torch.clone(aot0_tangents_1) + new_grad: "f32[2]" = torch.clone(aot0_tangents_1) add: "Sym(s7 + 1)" = l_hooks_1_keywords_fn_keywords_obj_counter + 1; l_hooks_1_keywords_fn_keywords_obj_counter = None - result: "f32[s0]" = aot0_tangents_1 * aot0_tangents_1; aot0_tangents_1 = None + result: "f32[2]" = aot0_tangents_1 * aot0_tangents_1; aot0_tangents_1 = None - new_grad_1: "f32[s0]" = torch.clone(result); result = None + new_grad_1: "f32[2]" = torch.clone(result); result = None return (new_grad, new_grad_1, add) """, ) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 0b9347b5391..bc71e7eb9e1 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -52,6 +52,7 @@ from torch._inductor.output_code import MockFXGraphCacheOutput from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode from torch.fx.experimental.proxy_tensor import is_sym_node from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode, ShapeEnv +from torch.nn.attention.flex_attention import flex_attention from torch.nn.utils.rnn import PackedSequence from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, @@ -5802,45 +5803,69 @@ metadata incorrectly. class GradsNoForceContiguousContextManager(ContextDecorator): def __enter__(self): # flake8: noqa: TOR901 - self.lib = torch.library.Library("_mylib", "FRAGMENT") + self.lib = torch.library.Library("_test_aotdispatch_lib", "FRAGMENT") self.d = { torch.channels_last: 0, torch.contiguous_format: 0, } + self.tangent_strides = [] - self.lib.define("foo(Tensor x) -> Tensor") - self.lib.define("foo2(Tensor x) -> Tensor") + self.lib.define("log_tangents_memory_format(Tensor x) -> Tensor") + self.lib.define("log_tangents_memory_format_log(Tensor x) -> Tensor") - def foo_impl(a): + def log_tangents_memory_format_impl(a): return a.clone() - def foo_meta(a): + def log_tangents_memory_format_meta(a): return a.clone() - def foo2_impl(x): + def log_tangents_memory_format_log_impl(x): self.d[torch._prims_common.suggest_memory_format(x)] += 1 + self.tangent_strides.append(x.stride()) return x.clone() - def foo2_meta(a): + def log_tangents_memory_format_log_meta(a): return a.clone() for backend in ["CPU", "CUDA"]: - self.lib.impl("foo", foo_impl, backend) - self.lib.impl("foo2", foo2_impl, backend) + self.lib.impl( + "log_tangents_memory_format", log_tangents_memory_format_impl, backend + ) + self.lib.impl( + "log_tangents_memory_format_log", + log_tangents_memory_format_log_impl, + backend, + ) - self.lib.impl("foo", foo_meta, "Meta") - self.lib.impl("foo2", foo2_meta, "Meta") + self.lib.impl( + "log_tangents_memory_format", log_tangents_memory_format_meta, "Meta" + ) + self.lib.impl( + "log_tangents_memory_format_log", + log_tangents_memory_format_log_meta, + "Meta", + ) - def foo_bwd(ctx, grad): - torch.ops._mylib.foo2(grad) + def log_tangents_memory_format_bwd(ctx, grad): + torch.ops._test_aotdispatch_lib.log_tangents_memory_format_log(grad) return grad.clone() - torch.library.register_autograd("_mylib::foo", foo_bwd, lib=self.lib) + torch.library.register_autograd( + "_test_aotdispatch_lib::log_tangents_memory_format", + log_tangents_memory_format_bwd, + lib=self.lib, + ) from torch._higher_order_ops.effects import _EffectType, _register_effectful_op - _register_effectful_op(torch.ops._mylib.foo.default, _EffectType.ORDERED) - _register_effectful_op(torch.ops._mylib.foo2.default, _EffectType.ORDERED) + _register_effectful_op( + torch.ops._test_aotdispatch_lib.log_tangents_memory_format.default, + _EffectType.ORDERED, + ) + _register_effectful_op( + torch.ops._test_aotdispatch_lib.log_tangents_memory_format_log.default, + _EffectType.ORDERED, + ) return self @@ -6097,7 +6122,7 @@ class TestAOTModuleSimplified(AOTTestCase): z = y + 3 y.mul_(2) r = self.conv(x) - r = torch.ops._mylib.foo(r) + r = torch.ops._test_aotdispatch_lib.log_tangents_memory_format(r) return ( r, r.transpose(0, 1), @@ -6143,7 +6168,7 @@ class TestAOTModuleSimplified(AOTTestCase): def forward(self, x, y): r = self.conv(x) - r = torch.ops._mylib.foo(r) + r = torch.ops._test_aotdispatch_lib.log_tangents_memory_format(r) return r, y + 1 m = M() @@ -6186,7 +6211,7 @@ class TestAOTModuleSimplified(AOTTestCase): def forward(self, x): r = self.conv(x) - r = torch.ops._mylib.foo(r) + r = torch.ops._test_aotdispatch_lib.log_tangents_memory_format(r) return r m = M() @@ -6466,6 +6491,116 @@ metadata incorrectly. _test_fn(fn_mutation) _test_fn(fn_inplace, check_backward=False) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") + @parametrize("dynamic_shapes", [True, False]) + @parametrize("test_subclasses", [True, False]) + @parametrize("device", ["cuda", "cpu"]) + def test_noncontig_nonmemformat_tangents( + self, dynamic_shapes, test_subclasses, device + ): + B = 2 + T = 4 + E = 6 + + def fn(x): + x = x + 1 + return x.transpose(1, 2) + + def _inp_dense(): + t = torch.randn(B, T, E, device=device, requires_grad=True) + if dynamic_shapes: + for i in range(t.ndim): + torch._dynamo.mark_dynamic(t, i) + return t + + def _inp_sc(): + return TwoTensor(_inp_dense(), _inp_dense()) + + _inp = _inp_dense if not test_subclasses else _inp_sc + + comp_fn = torch.compile(fn, backend="aot_eager", fullgraph=True) + + def _tg3(y): + t = torch.randn( + 2 * y.shape, dtype=y.dtype, layout=y.layout, device=y.device + ) + return t.as_strided(y.shape, tuple(s * 2 for s in y.stride())) + + TEST_CASES = [ + (_inp, lambda y: torch.ones(y.shape, dtype=y.dtype, device=y.device)), + # Memory overlap, dense tangent + ( + _inp, + lambda y: torch.tensor([1], dtype=y.dtype, device=y.device).as_strided( + y.shape, (0,) * y.ndim + ), + ), + # No memory overlap, not-dense tangent + (_inp, _tg3), + ] + + for i, (inp_fn, tg_fn) in enumerate(TEST_CASES): + ref_x = inp_fn() + x = ref_x.detach().clone().requires_grad_() + + ref_y = fn(ref_x) + + y = comp_fn(x) + self.assertEqual(ref_y, y) + + ref_tg = ( + tg_fn(ref_y) + if not test_subclasses + else TwoTensor(tg_fn(ref_y), tg_fn(ref_y)) + ) + tg = ref_tg.clone() + + ref_y.backward(ref_tg) + y.backward(tg) + + self.assertEqual(ref_x.grad, x.grad) + + def test_flex_attn_noncontiguous_tangents(self): + with GradsNoForceContiguousContextManager() as ctx: + E = 16 # embedding dim + H = 4 # number of heads + + @torch.compile(backend="aot_eager", fullgraph=True) + def attn_fn(q, k, v): + y = flex_attention(query=q, key=k, value=v) + y = torch.ops._test_aotdispatch_lib.log_tangents_memory_format(y) + return y + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.c_attn = torch.nn.Linear(E, 3 * E) + + def forward(self, x): + B, T, E = x.size() + q, k, v = self.c_attn(x).split(E, dim=2) + k = k.view(B, T, H, E // H).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, H, E // H).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, H, E // H).transpose(1, 2) # (B, nh, T, hs) + + y = attn_fn(q, k, v) + + return y.transpose(1, 2).contiguous().view(B, T, E) + + m = M() + B = 1 + T = 8 + + def _inp(): + return torch.randn(B, T, E, requires_grad=True) + + x = _inp() + y = m(x) + y.backward(torch.ones_like(y).contiguous()) + + self.assertEqual(1, len(ctx.tangent_strides)) + self.assertEqual((128, 4, 16, 1), ctx.tangent_strides[0]) + # entries in here don't work and need to be fixed. # Each one of these is a bug (or needs to be investigated) @@ -6745,6 +6880,7 @@ class TestEagerFusionModuleInfo(AOTTestCase): instantiate_parametrized_tests(TestAOTAutograd) +instantiate_parametrized_tests(TestAOTModuleSimplified) only_for = "cpu" instantiate_device_type_tests( TestPythonKey, diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index fde903c9fd4..e500b548be9 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -551,7 +551,7 @@ def check_model( # generate random unit norm gradients grads = [ - torch.rand(r.shape, device=r.device, dtype=r.dtype) + torch.randn_like(r) for r in correct_flat if isinstance(r, torch.Tensor) and r.requires_grad ] diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 0e06c5af79b..62d43ea1493 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -466,6 +466,33 @@ inductor_override_kwargs["cuda"] = { ("index_reduce.amax", f32): {"check_gradient": False}, ("index_reduce.amax", f16): {"check_gradient": False}, ("tanh", f16): {"atol": 1e-4, "rtol": 1e-2}, + ("_unsafe_masked_index", f16): { + "reference_in_float": True, + "atol": 3e-4, + "rtol": 2e-3, + }, + ("nn.functional.interpolate.linear", f16): {"reference_in_float": True}, + ("nn.functional.prelu", f16): { + "reference_in_float": True, + "atol": 1e-3, + "rtol": 4e-3, + }, + ("addmm", f16): {"reference_in_float": True}, + ("logaddexp", f16): {"reference_in_float": True}, + ("std_mean", f16): {"reference_in_float": True}, + ("hypot", f16): {"reference_in_float": True, "atol": 3e-4, "rtol": 2e-3}, + ("cummin", f16): {"reference_in_float": True, "atol": 5e-5, "rtol": 2e-3}, + ("unfold_copy", f16): {"reference_in_float": True, "atol": 2e-5, "rtol": 1e-2}, + ("nn.functional.upsample_bilinear", f16): { + "reference_in_float": True, + "atol": 1e-4, + "rtol": 2e-3, + }, + ("nn.functional.embedding_bag", f16): { + "reference_in_float": True, + "atol": 1e-4, + "rtol": 1e-2, + }, } inductor_override_kwargs["xpu"] = { diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py index 87d5411c05d..e128901d39a 100644 --- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py +++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py @@ -41,6 +41,7 @@ from .functional_utils import ( from .schemas import ( FunctionalTensorMetadataEq, InputAliasInfo, + MemoryFormatMeta, MutationType, OutputAliasInfo, OutputType, @@ -73,14 +74,14 @@ def coerce_tangent_and_suggest_memory_format(x: Tensor): out = x.detach() - suggest_memory_format = torch._prims_common.suggest_memory_format is_subclass = is_traceable_wrapper_subclass(out) - memory_format = suggest_memory_format(out) + memory_format = MemoryFormatMeta.from_tensor(out) - was = out - out = out.contiguous(memory_format=memory_format) - updated = out is not was + if memory_format.memory_format is not None: + was = out + out = out.contiguous(memory_format=memory_format.memory_format) + updated = was is not out # For subclass we keep memory format of outer strides at the beggining of the list out_memory_format = [memory_format] if is_subclass else memory_format diff --git a/torch/_functorch/_aot_autograd/input_output_analysis.py b/torch/_functorch/_aot_autograd/input_output_analysis.py index 7dc4112a101..62470e3b683 100644 --- a/torch/_functorch/_aot_autograd/input_output_analysis.py +++ b/torch/_functorch/_aot_autograd/input_output_analysis.py @@ -28,6 +28,7 @@ from .schemas import ( BackwardSignature, GraphSignature, InputAliasInfo, + MemoryFormatMeta, OutputAliasInfo, OutputType, ViewAndMutationMeta, @@ -61,7 +62,9 @@ def remove_dupe_metadata( assert m.subclass_tangent_meta is not None subclass_tangent_meta = [ - PlainTensorMeta(0, memory_format=torch.contiguous_format) + PlainTensorMeta( + 0, memory_format=MemoryFormatMeta(memory_format=torch.contiguous_format) + ) ] * len(filtered_inp_traced_tangents) + m.subclass_tangent_meta[num_data_mutations:] return ViewAndMutationMeta( diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index 539c1a91052..f5081324527 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -45,6 +45,7 @@ from .logging_utils import describe_input, format_guard_bug_msg, track_graph_com from .schemas import ( AOTConfig, InputAliasInfo, + MemoryFormatMeta, MutationType, OutputType, PlainTensorMeta, @@ -1752,6 +1753,39 @@ def _backward_epilogue_functional( return out +def coerce_to_expected_memory_format(x: torch.Tensor, memory_format: MemoryFormatMeta): + if memory_format.memory_format is not None: + # Coerce to torch.memory_format + if not x.is_contiguous(memory_format=memory_format.memory_format): + x = x.contiguous(memory_format=memory_format.memory_format) + return x + + expected_size = memory_format.size + assert expected_size is not None + expected_stride = memory_format.stride + assert expected_stride is not None + # Expected size and stride are static ints + # ok to use == to compare runtime tensor strides and shapes + + if x.shape == expected_size and x.stride() == expected_stride: + # Runtime tangent size and stride are the same as expected, no need to coerce + return x + + # Empty_strided creates a raw Tensor. + # We are guranteed that only raw Tensors has expected size and stride. + # Subclasses have only expected memory_format. + restrided = torch.empty_strided( + size=expected_size, + stride=expected_stride, + dtype=x.dtype, + device=x.device, + layout=x.layout, + requires_grad=x.requires_grad, + ) + restrided.copy_(x) + return restrided + + # This is wrapped in a class just for namespacing purposes # No need to make it into an actual CompilerWrapper because it doesn't fit the abstract as cleanly class AOTDispatchAutograd: @@ -1761,8 +1795,8 @@ class AOTDispatchAutograd: return x, [x] if isinstance(x, FakeTensor): - if not x.is_contiguous(memory_format=meta.memory_format): - x = x.contiguous(memory_format=meta.memory_format) + assert meta.memory_format + x = coerce_to_expected_memory_format(x, meta.memory_format) return x, [x] expected_type: Optional[type] = torch.Tensor @@ -1820,8 +1854,8 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa ) # Coerce to expected memory format - if not x.is_contiguous(memory_format=meta.memory_format): - x = x.contiguous(memory_format=meta.memory_format) + assert meta.memory_format + x = coerce_to_expected_memory_format(x, meta.memory_format) if not is_traceable_wrapper_subclass(x): return x, [x] diff --git a/torch/_functorch/_aot_autograd/schemas.py b/torch/_functorch/_aot_autograd/schemas.py index 6259d082e2a..c0ef87fd50e 100644 --- a/torch/_functorch/_aot_autograd/schemas.py +++ b/torch/_functorch/_aot_autograd/schemas.py @@ -7,7 +7,8 @@ input/output types, metadata, config, function signatures etc. import collections import dataclasses import functools -from collections.abc import Iterable +import itertools +from collections.abc import Iterable, Sequence from dataclasses import dataclass, field from enum import Enum from typing import Any, Callable, NewType, Optional, Union @@ -155,10 +156,47 @@ class InputAliasInfo: return MutationType.MUTATED_OUT_GRAPH +@dataclass +class MemoryFormatMeta: + # For static shapes we assume tangents have the same strideness as outputs + size: Optional[Sequence[int]] = None + stride: Optional[Sequence[int]] = None + + # For dynamic shapes we assume the same memory format: contiguous, channels_last etc. + memory_format: Optional[torch.memory_format] = None + + @staticmethod + def from_tensor(t: torch.Tensor) -> Optional["MemoryFormatMeta"]: + # We only memorize expected memory format for + # 1. Traceable wrapper subclasses + # We can not create restrided subclass tensor, as torch.empty_strided works only with dense tensors. + # 2. Dynamic shape tensors + # Support for symbolic shapes is not implemented yet. + use_memory_format: bool = is_traceable_wrapper_subclass(t) + if not use_memory_format: + is_static_shape = True + for s in itertools.chain(t.shape, t.stride()): + if not isinstance(s, int): + is_static_shape = False + break + + use_memory_format = not is_static_shape + + if use_memory_format: + return MemoryFormatMeta( + memory_format=torch._prims_common.suggest_memory_format(t), + ) + + return MemoryFormatMeta( + size=t.size(), + stride=t.stride(), + ) + + @dataclass class PlainTensorMeta: unwrapped_idx: int - memory_format: Optional[torch.memory_format] = None + memory_format: Optional[MemoryFormatMeta] = None @dataclass @@ -204,7 +242,7 @@ class SubclassCreationMeta: # Used at runtime to determine the subclass type, so we don't need to save the original subclass original_subclass_type: Optional[type] = None - memory_format: Optional[torch.memory_format] = None + memory_format: Optional[MemoryFormatMeta] = None def compute_outer_size_and_stride( self, diff --git a/torch/_functorch/_aot_autograd/subclass_utils.py b/torch/_functorch/_aot_autograd/subclass_utils.py index d352f43da37..79e27daf7c1 100644 --- a/torch/_functorch/_aot_autograd/subclass_utils.py +++ b/torch/_functorch/_aot_autograd/subclass_utils.py @@ -46,16 +46,16 @@ def requires_subclass_dispatch(args, fw_metadata: ViewAndMutationMeta) -> bool: return any_subclass_args or any_subclass_outputs -suggest_memory_format = torch._prims_common.suggest_memory_format +from .schemas import MemoryFormatMeta def maybe_suggest_memory_format( t, with_memory_format: bool -) -> Optional[torch.memory_format]: +) -> Optional[MemoryFormatMeta]: if not with_memory_format: return None - return suggest_memory_format(t) + return MemoryFormatMeta.from_tensor(t) def get_subclass_typing_container( diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py index 958ee96c499..1ec4f34e404 100644 --- a/torch/nested/_internal/nested_tensor.py +++ b/torch/nested/_internal/nested_tensor.py @@ -200,10 +200,18 @@ class NestedTensor(torch.Tensor): def _max_seqlen_tensor(self) -> Optional[torch.Tensor]: return self._metadata_cache.get("max_seqlen", None) + @_max_seqlen_tensor.setter + def _max_seqlen_tensor(self, val: Optional[torch.Tensor]) -> None: + self._metadata_cache["max_seqlen"] = val + @property def _min_seqlen_tensor(self) -> Optional[torch.Tensor]: return self._metadata_cache.get("min_seqlen", None) + @_min_seqlen_tensor.setter + def _min_seqlen_tensor(self, val: Optional[torch.Tensor]) -> None: + self._metadata_cache["min_seqlen"] = val + # These are old private @property accessors that are kept around for internal BC # reasons. TODO: Remove these! @property diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 95d717b3be3..4e53c489d9b 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -12696,10 +12696,26 @@ op_db: list[OpInfo] = [ check_batched_gradgrad=True, sample_inputs_func=sample_inputs_linalg_cholesky_inverse, gradcheck_wrapper=gradcheck_wrapper_triangular_input_real_positive_diagonal, - decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], + decorators=[ + skipCUDAIfNoMagma, + skipCPUIfNoLapack, + DecorateInfo( + toleranceOverride({ + torch.float32: tol(atol=5e-03, rtol=1e-04) + }), + 'TestCommon', device_type='cpu', + ), + DecorateInfo( + toleranceOverride({ + torch.float32: tol(atol=5e-03, rtol=1e-04) + }), + 'TestEagerFusionOpInfo', device_type='cpu', + ), + ], skips=( # Strides are not the same! Original strides were ((4, 2, 1),) and strides are now ((4, 1, 2),) - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),)), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),), + ), OpInfo('cholesky_solve', op=torch.cholesky_solve, dtypes=floating_and_complex_types(),