[aotd] Guess tangents stride as output strides (#144579)

AOTDispatch, while preparing the AOT backward graph, does not know the real tangents that the user will pass when running backward.

So AOTD guesses the tangents. Before this PR, we guessed that the memory format of the tangents would match the memory format of the corresponding outputs; if the tangents supplied at runtime did not have the memory format guessed during compilation, AOTD coerced (copied) them to the guessed memory_format.
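That coercion amounted to the check this PR replaces in the runtime wrapper (see the diff below); a minimal sketch, with a hypothetical helper name:

```python
import torch

def coerce_to_guessed_memory_format(tangent: torch.Tensor,
                                     guessed: torch.memory_format) -> torch.Tensor:
    # Old behavior (sketch): if the runtime tangent does not match the memory
    # format guessed at compile time, copy it into that guessed format.
    if not tangent.is_contiguous(memory_format=guessed):
        return tangent.contiguous(memory_format=guessed)
    return tangent
```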

But as Horace found, there are popular use cases where the outputs of the compiled region come out in a specific non-contiguous layout, e.g. a 4D tensor with dims 1 and 2 transposed:

https://github.com/karpathy/nanoGPT/blob/master/model.py#L57

This PR changes the logic so that AOTD expects tangents with the same "strideness" (exact strides) as the outputs. As a result, the coercion copy is avoided for the transposed-dims case.
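A minimal repro of the transposed-output pattern, adapted from the `test_noncontig_nonmemformat_tangents` test added in this PR (shapes are arbitrary):

```python
import torch

def fn(x):
    x = x + 1
    return x.transpose(1, 2)  # output has transposed (non-contiguous) strides

comp_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)

x = torch.randn(2, 4, 6, requires_grad=True)
y = comp_fn(x)
# A tangent with the same strides as the output (randn_like preserves the
# transposed layout) no longer gets copied into a guessed memory format
# before the compiled backward runs.
y.backward(torch.randn_like(y))
```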

Limitations:
We keep guessing only the memory_format (rather than exact strides) for the following cases (a simplified sketch follows this list):
1/ Dynamic shapes (needs more changes)
2/ Tensor subclasses (needs more changes)
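
A simplified sketch of that decision, following `MemoryFormatMeta.from_tensor` in the diff below (the standalone helper here is illustrative, not the actual API):

```python
import torch
import torch._prims_common as prims_common
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

def guess_tangent_layout(t: torch.Tensor):
    # Dynamic shapes: any symbolic size or stride means exact strides cannot be pinned.
    is_dynamic = not all(isinstance(s, int) for s in (*t.shape, *t.stride()))
    if is_traceable_wrapper_subclass(t) or is_dynamic:
        # Fall back to guessing a memory format (contiguous, channels_last, ...).
        return {"memory_format": prims_common.suggest_memory_format(t)}
    # Static-shape plain tensors: remember the output's exact size and stride.
    return {"size": tuple(t.shape), "stride": t.stride()}
```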

Other changes:
test_torchinductor always created contiguous tangents via `torch.rand(r.shape, ...)`; they are changed to `torch.randn_like(r)` so that eager and compiled computations are compared with tangents of the same strideness as the outputs.

(E.g. for CUDA float16, strideness affects numerics for fft ops.)
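
For context, the difference in tangent strides for a transposed output (standalone illustration, arbitrary shapes):

```python
import torch

r = torch.randn(2, 4, 6).transpose(1, 2)  # "output" with strides (24, 1, 6)

g_old = torch.rand(r.shape)    # always contiguous: strides (24, 4, 1)
g_new = torch.randn_like(r)    # preserves r's transposed strides: (24, 1, 6)

print(g_old.stride(), g_new.stride())
```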

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144579
Approved by: https://github.com/bdhirsh
IvanKobzarev 2025-03-20 04:01:28 -07:00 committed by PyTorch MergeBot
parent 9b1127437e
commit 2c4bc65366
11 changed files with 322 additions and 59 deletions


@@ -131,23 +131,23 @@ class _multiply_invoke(torch.nn.Module):
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(s0)"):
def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(2)"):
l_inputs_ = L_inputs_
l_sizes_0_ = L_sizes_0_
getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None
getitem: "f32[2]" = l_inputs_[0]; l_inputs_ = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [l_sizes_0_], False)]); getitem = l_sizes_0_ = None
getitem_9: "f32[s0]" = validate_outputs[0]; validate_outputs = None
getitem_9: "f32[2]" = validate_outputs[0]; validate_outputs = None
call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
aot1_tangents_1: "f32[s0]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None
aot1_tangents_1: "f32[2]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None
new_grad: "f32[s0]" = torch.clone(aot1_tangents_1)
new_grad: "f32[2]" = torch.clone(aot1_tangents_1)
result: "f32[s0]" = aot1_tangents_1 * aot1_tangents_1; aot1_tangents_1 = None
result: "f32[2]" = aot1_tangents_1 * aot1_tangents_1; aot1_tangents_1 = None
new_grad_1: "f32[s0]" = torch.clone(result); result = None
new_grad_1: "f32[2]" = torch.clone(result); result = None
return (new_grad, new_grad_1)
""",
)
@@ -156,23 +156,23 @@ class GraphModule(torch.nn.Module):
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(s0)"):
def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(2)"):
l_inputs_ = L_inputs_
l_sizes_0_ = L_sizes_0_
getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None
getitem: "f32[2]" = l_inputs_[0]; l_inputs_ = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [l_sizes_0_], False)]); getitem = l_sizes_0_ = None
getitem_9: "f32[s0]" = validate_outputs[0]; validate_outputs = None
getitem_9: "f32[2]" = validate_outputs[0]; validate_outputs = None
call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
aot3_tangents_1: "f32[s0]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None
aot3_tangents_1: "f32[2]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None
new_grad: "f32[s0]" = torch.clone(aot3_tangents_1)
new_grad: "f32[2]" = torch.clone(aot3_tangents_1)
result: "f32[s0]" = aot3_tangents_1 * aot3_tangents_1; aot3_tangents_1 = None
result: "f32[2]" = aot3_tangents_1 * aot3_tangents_1; aot3_tangents_1 = None
new_grad_1: "f32[s0]" = torch.clone(result); result = None
new_grad_1: "f32[2]" = torch.clone(result); result = None
return (new_grad, new_grad_1)
""",
)
@@ -233,26 +233,26 @@ class GraphModule(torch.nn.Module):
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(s0)", L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s7)"):
def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(2)", L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s7)"):
l_inputs_ = L_inputs_
l_sizes_0_ = L_sizes_0_
l_hooks_1_keywords_fn_keywords_obj_counter = L_hooks_1_keywords_fn_keywords_obj_counter
getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None
getitem: "f32[2]" = l_inputs_[0]; l_inputs_ = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [l_sizes_0_], False)]); getitem = l_sizes_0_ = None
getitem_9: "f32[s0]" = validate_outputs[0]; validate_outputs = None
getitem_9: "f32[2]" = validate_outputs[0]; validate_outputs = None
call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
aot0_tangents_1: "f32[s0]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None
aot0_tangents_1: "f32[2]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None
new_grad: "f32[s0]" = torch.clone(aot0_tangents_1)
new_grad: "f32[2]" = torch.clone(aot0_tangents_1)
add: "Sym(s7 + 1)" = l_hooks_1_keywords_fn_keywords_obj_counter + 1; l_hooks_1_keywords_fn_keywords_obj_counter = None
result: "f32[s0]" = aot0_tangents_1 * aot0_tangents_1; aot0_tangents_1 = None
result: "f32[2]" = aot0_tangents_1 * aot0_tangents_1; aot0_tangents_1 = None
new_grad_1: "f32[s0]" = torch.clone(result); result = None
new_grad_1: "f32[2]" = torch.clone(result); result = None
return (new_grad, new_grad_1, add)
""",
)


@@ -52,6 +52,7 @@ from torch._inductor.output_code import MockFXGraphCacheOutput
from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode
from torch.fx.experimental.proxy_tensor import is_sym_node
from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode, ShapeEnv
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.utils.rnn import PackedSequence
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
@@ -5802,45 +5803,69 @@ metadata incorrectly.
class GradsNoForceContiguousContextManager(ContextDecorator):
def __enter__(self):
# flake8: noqa: TOR901
self.lib = torch.library.Library("_mylib", "FRAGMENT")
self.lib = torch.library.Library("_test_aotdispatch_lib", "FRAGMENT")
self.d = {
torch.channels_last: 0,
torch.contiguous_format: 0,
}
self.tangent_strides = []
self.lib.define("foo(Tensor x) -> Tensor")
self.lib.define("foo2(Tensor x) -> Tensor")
self.lib.define("log_tangents_memory_format(Tensor x) -> Tensor")
self.lib.define("log_tangents_memory_format_log(Tensor x) -> Tensor")
def foo_impl(a):
def log_tangents_memory_format_impl(a):
return a.clone()
def foo_meta(a):
def log_tangents_memory_format_meta(a):
return a.clone()
def foo2_impl(x):
def log_tangents_memory_format_log_impl(x):
self.d[torch._prims_common.suggest_memory_format(x)] += 1
self.tangent_strides.append(x.stride())
return x.clone()
def foo2_meta(a):
def log_tangents_memory_format_log_meta(a):
return a.clone()
for backend in ["CPU", "CUDA"]:
self.lib.impl("foo", foo_impl, backend)
self.lib.impl("foo2", foo2_impl, backend)
self.lib.impl(
"log_tangents_memory_format", log_tangents_memory_format_impl, backend
)
self.lib.impl(
"log_tangents_memory_format_log",
log_tangents_memory_format_log_impl,
backend,
)
self.lib.impl("foo", foo_meta, "Meta")
self.lib.impl("foo2", foo2_meta, "Meta")
self.lib.impl(
"log_tangents_memory_format", log_tangents_memory_format_meta, "Meta"
)
self.lib.impl(
"log_tangents_memory_format_log",
log_tangents_memory_format_log_meta,
"Meta",
)
def foo_bwd(ctx, grad):
torch.ops._mylib.foo2(grad)
def log_tangents_memory_format_bwd(ctx, grad):
torch.ops._test_aotdispatch_lib.log_tangents_memory_format_log(grad)
return grad.clone()
torch.library.register_autograd("_mylib::foo", foo_bwd, lib=self.lib)
torch.library.register_autograd(
"_test_aotdispatch_lib::log_tangents_memory_format",
log_tangents_memory_format_bwd,
lib=self.lib,
)
from torch._higher_order_ops.effects import _EffectType, _register_effectful_op
_register_effectful_op(torch.ops._mylib.foo.default, _EffectType.ORDERED)
_register_effectful_op(torch.ops._mylib.foo2.default, _EffectType.ORDERED)
_register_effectful_op(
torch.ops._test_aotdispatch_lib.log_tangents_memory_format.default,
_EffectType.ORDERED,
)
_register_effectful_op(
torch.ops._test_aotdispatch_lib.log_tangents_memory_format_log.default,
_EffectType.ORDERED,
)
return self
@@ -6097,7 +6122,7 @@ class TestAOTModuleSimplified(AOTTestCase):
z = y + 3
y.mul_(2)
r = self.conv(x)
r = torch.ops._mylib.foo(r)
r = torch.ops._test_aotdispatch_lib.log_tangents_memory_format(r)
return (
r,
r.transpose(0, 1),
@@ -6143,7 +6168,7 @@ class TestAOTModuleSimplified(AOTTestCase):
def forward(self, x, y):
r = self.conv(x)
r = torch.ops._mylib.foo(r)
r = torch.ops._test_aotdispatch_lib.log_tangents_memory_format(r)
return r, y + 1
m = M()
@@ -6186,7 +6211,7 @@ class TestAOTModuleSimplified(AOTTestCase):
def forward(self, x):
r = self.conv(x)
r = torch.ops._mylib.foo(r)
r = torch.ops._test_aotdispatch_lib.log_tangents_memory_format(r)
return r
m = M()
@@ -6466,6 +6491,116 @@ metadata incorrectly.
_test_fn(fn_mutation)
_test_fn(fn_inplace, check_backward=False)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
@parametrize("dynamic_shapes", [True, False])
@parametrize("test_subclasses", [True, False])
@parametrize("device", ["cuda", "cpu"])
def test_noncontig_nonmemformat_tangents(
self, dynamic_shapes, test_subclasses, device
):
B = 2
T = 4
E = 6
def fn(x):
x = x + 1
return x.transpose(1, 2)
def _inp_dense():
t = torch.randn(B, T, E, device=device, requires_grad=True)
if dynamic_shapes:
for i in range(t.ndim):
torch._dynamo.mark_dynamic(t, i)
return t
def _inp_sc():
return TwoTensor(_inp_dense(), _inp_dense())
_inp = _inp_dense if not test_subclasses else _inp_sc
comp_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
def _tg3(y):
t = torch.randn(
2 * y.shape, dtype=y.dtype, layout=y.layout, device=y.device
)
return t.as_strided(y.shape, tuple(s * 2 for s in y.stride()))
TEST_CASES = [
(_inp, lambda y: torch.ones(y.shape, dtype=y.dtype, device=y.device)),
# Memory overlap, dense tangent
(
_inp,
lambda y: torch.tensor([1], dtype=y.dtype, device=y.device).as_strided(
y.shape, (0,) * y.ndim
),
),
# No memory overlap, not-dense tangent
(_inp, _tg3),
]
for i, (inp_fn, tg_fn) in enumerate(TEST_CASES):
ref_x = inp_fn()
x = ref_x.detach().clone().requires_grad_()
ref_y = fn(ref_x)
y = comp_fn(x)
self.assertEqual(ref_y, y)
ref_tg = (
tg_fn(ref_y)
if not test_subclasses
else TwoTensor(tg_fn(ref_y), tg_fn(ref_y))
)
tg = ref_tg.clone()
ref_y.backward(ref_tg)
y.backward(tg)
self.assertEqual(ref_x.grad, x.grad)
def test_flex_attn_noncontiguous_tangents(self):
with GradsNoForceContiguousContextManager() as ctx:
E = 16 # embedding dim
H = 4 # number of heads
@torch.compile(backend="aot_eager", fullgraph=True)
def attn_fn(q, k, v):
y = flex_attention(query=q, key=k, value=v)
y = torch.ops._test_aotdispatch_lib.log_tangents_memory_format(y)
return y
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.c_attn = torch.nn.Linear(E, 3 * E)
def forward(self, x):
B, T, E = x.size()
q, k, v = self.c_attn(x).split(E, dim=2)
k = k.view(B, T, H, E // H).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, H, E // H).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, H, E // H).transpose(1, 2) # (B, nh, T, hs)
y = attn_fn(q, k, v)
return y.transpose(1, 2).contiguous().view(B, T, E)
m = M()
B = 1
T = 8
def _inp():
return torch.randn(B, T, E, requires_grad=True)
x = _inp()
y = m(x)
y.backward(torch.ones_like(y).contiguous())
self.assertEqual(1, len(ctx.tangent_strides))
self.assertEqual((128, 4, 16, 1), ctx.tangent_strides[0])
# entries in here don't work and need to be fixed.
# Each one of these is a bug (or needs to be investigated)
@@ -6745,6 +6880,7 @@ class TestEagerFusionModuleInfo(AOTTestCase):
instantiate_parametrized_tests(TestAOTAutograd)
instantiate_parametrized_tests(TestAOTModuleSimplified)
only_for = "cpu"
instantiate_device_type_tests(
TestPythonKey,


@@ -551,7 +551,7 @@ def check_model(
# generate random unit norm gradients
grads = [
torch.rand(r.shape, device=r.device, dtype=r.dtype)
torch.randn_like(r)
for r in correct_flat
if isinstance(r, torch.Tensor) and r.requires_grad
]


@@ -466,6 +466,33 @@ inductor_override_kwargs["cuda"] = {
("index_reduce.amax", f32): {"check_gradient": False},
("index_reduce.amax", f16): {"check_gradient": False},
("tanh", f16): {"atol": 1e-4, "rtol": 1e-2},
("_unsafe_masked_index", f16): {
"reference_in_float": True,
"atol": 3e-4,
"rtol": 2e-3,
},
("nn.functional.interpolate.linear", f16): {"reference_in_float": True},
("nn.functional.prelu", f16): {
"reference_in_float": True,
"atol": 1e-3,
"rtol": 4e-3,
},
("addmm", f16): {"reference_in_float": True},
("logaddexp", f16): {"reference_in_float": True},
("std_mean", f16): {"reference_in_float": True},
("hypot", f16): {"reference_in_float": True, "atol": 3e-4, "rtol": 2e-3},
("cummin", f16): {"reference_in_float": True, "atol": 5e-5, "rtol": 2e-3},
("unfold_copy", f16): {"reference_in_float": True, "atol": 2e-5, "rtol": 1e-2},
("nn.functional.upsample_bilinear", f16): {
"reference_in_float": True,
"atol": 1e-4,
"rtol": 2e-3,
},
("nn.functional.embedding_bag", f16): {
"reference_in_float": True,
"atol": 1e-4,
"rtol": 1e-2,
},
}
inductor_override_kwargs["xpu"] = {


@@ -41,6 +41,7 @@ from .functional_utils import (
from .schemas import (
FunctionalTensorMetadataEq,
InputAliasInfo,
MemoryFormatMeta,
MutationType,
OutputAliasInfo,
OutputType,
@@ -73,14 +74,14 @@ def coerce_tangent_and_suggest_memory_format(x: Tensor):
out = x.detach()
suggest_memory_format = torch._prims_common.suggest_memory_format
is_subclass = is_traceable_wrapper_subclass(out)
memory_format = suggest_memory_format(out)
memory_format = MemoryFormatMeta.from_tensor(out)
if memory_format.memory_format is not None:
was = out
out = out.contiguous(memory_format=memory_format)
updated = out is not was
out = out.contiguous(memory_format=memory_format.memory_format)
updated = was is not out
# For subclass we keep memory format of outer strides at the beginning of the list
out_memory_format = [memory_format] if is_subclass else memory_format


@@ -28,6 +28,7 @@ from .schemas import (
BackwardSignature,
GraphSignature,
InputAliasInfo,
MemoryFormatMeta,
OutputAliasInfo,
OutputType,
ViewAndMutationMeta,
@@ -61,7 +62,9 @@ def remove_dupe_metadata(
assert m.subclass_tangent_meta is not None
subclass_tangent_meta = [
PlainTensorMeta(0, memory_format=torch.contiguous_format)
PlainTensorMeta(
0, memory_format=MemoryFormatMeta(memory_format=torch.contiguous_format)
)
] * len(filtered_inp_traced_tangents) + m.subclass_tangent_meta[num_data_mutations:]
return ViewAndMutationMeta(


@@ -45,6 +45,7 @@ from .logging_utils import describe_input, format_guard_bug_msg, track_graph_com
from .schemas import (
AOTConfig,
InputAliasInfo,
MemoryFormatMeta,
MutationType,
OutputType,
PlainTensorMeta,
@@ -1752,6 +1753,39 @@ def _backward_epilogue_functional(
return out
def coerce_to_expected_memory_format(x: torch.Tensor, memory_format: MemoryFormatMeta):
if memory_format.memory_format is not None:
# Coerce to torch.memory_format
if not x.is_contiguous(memory_format=memory_format.memory_format):
x = x.contiguous(memory_format=memory_format.memory_format)
return x
expected_size = memory_format.size
assert expected_size is not None
expected_stride = memory_format.stride
assert expected_stride is not None
# Expected size and stride are static ints
# ok to use == to compare runtime tensor strides and shapes
if x.shape == expected_size and x.stride() == expected_stride:
# Runtime tangent size and stride are the same as expected, no need to coerce
return x
# Empty_strided creates a raw Tensor.
# We are guaranteed that only raw Tensors have expected size and stride.
# Subclasses have only expected memory_format.
restrided = torch.empty_strided(
size=expected_size,
stride=expected_stride,
dtype=x.dtype,
device=x.device,
layout=x.layout,
requires_grad=x.requires_grad,
)
restrided.copy_(x)
return restrided
# This is wrapped in a class just for namespacing purposes
# No need to make it into an actual CompilerWrapper because it doesn't fit the abstract as cleanly
class AOTDispatchAutograd:
@@ -1761,8 +1795,8 @@ class AOTDispatchAutograd:
return x, [x]
if isinstance(x, FakeTensor):
if not x.is_contiguous(memory_format=meta.memory_format):
x = x.contiguous(memory_format=meta.memory_format)
assert meta.memory_format
x = coerce_to_expected_memory_format(x, meta.memory_format)
return x, [x]
expected_type: Optional[type] = torch.Tensor
@@ -1820,8 +1854,8 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
)
# Coerce to expected memory format
if not x.is_contiguous(memory_format=meta.memory_format):
x = x.contiguous(memory_format=meta.memory_format)
assert meta.memory_format
x = coerce_to_expected_memory_format(x, meta.memory_format)
if not is_traceable_wrapper_subclass(x):
return x, [x]


@@ -7,7 +7,8 @@ input/output types, metadata, config, function signatures etc.
import collections
import dataclasses
import functools
from collections.abc import Iterable
import itertools
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, NewType, Optional, Union
@@ -155,10 +156,47 @@ class InputAliasInfo:
return MutationType.MUTATED_OUT_GRAPH
@dataclass
class MemoryFormatMeta:
# For static shapes we assume tangents have the same strideness as outputs
size: Optional[Sequence[int]] = None
stride: Optional[Sequence[int]] = None
# For dynamic shapes we assume the same memory format: contiguous, channels_last etc.
memory_format: Optional[torch.memory_format] = None
@staticmethod
def from_tensor(t: torch.Tensor) -> Optional["MemoryFormatMeta"]:
# We only memorize expected memory format for
# 1. Traceable wrapper subclasses
# We cannot create a restrided subclass tensor, as torch.empty_strided works only with dense tensors.
# 2. Dynamic shape tensors
# Support for symbolic shapes is not implemented yet.
use_memory_format: bool = is_traceable_wrapper_subclass(t)
if not use_memory_format:
is_static_shape = True
for s in itertools.chain(t.shape, t.stride()):
if not isinstance(s, int):
is_static_shape = False
break
use_memory_format = not is_static_shape
if use_memory_format:
return MemoryFormatMeta(
memory_format=torch._prims_common.suggest_memory_format(t),
)
return MemoryFormatMeta(
size=t.size(),
stride=t.stride(),
)
@dataclass
class PlainTensorMeta:
unwrapped_idx: int
memory_format: Optional[torch.memory_format] = None
memory_format: Optional[MemoryFormatMeta] = None
@dataclass
@@ -204,7 +242,7 @@ class SubclassCreationMeta:
# Used at runtime to determine the subclass type, so we don't need to save the original subclass
original_subclass_type: Optional[type] = None
memory_format: Optional[torch.memory_format] = None
memory_format: Optional[MemoryFormatMeta] = None
def compute_outer_size_and_stride(
self,


@@ -46,16 +46,16 @@ def requires_subclass_dispatch(args, fw_metadata: ViewAndMutationMeta) -> bool:
return any_subclass_args or any_subclass_outputs
suggest_memory_format = torch._prims_common.suggest_memory_format
from .schemas import MemoryFormatMeta
def maybe_suggest_memory_format(
t, with_memory_format: bool
) -> Optional[torch.memory_format]:
) -> Optional[MemoryFormatMeta]:
if not with_memory_format:
return None
return suggest_memory_format(t)
return MemoryFormatMeta.from_tensor(t)
def get_subclass_typing_container(


@@ -200,10 +200,18 @@ class NestedTensor(torch.Tensor):
def _max_seqlen_tensor(self) -> Optional[torch.Tensor]:
return self._metadata_cache.get("max_seqlen", None)
@_max_seqlen_tensor.setter
def _max_seqlen_tensor(self, val: Optional[torch.Tensor]) -> None:
self._metadata_cache["max_seqlen"] = val
@property
def _min_seqlen_tensor(self) -> Optional[torch.Tensor]:
return self._metadata_cache.get("min_seqlen", None)
@_min_seqlen_tensor.setter
def _min_seqlen_tensor(self, val: Optional[torch.Tensor]) -> None:
self._metadata_cache["min_seqlen"] = val
# These are old private @property accessors that are kept around for internal BC
# reasons. TODO: Remove these!
@property


@@ -12696,10 +12696,26 @@ op_db: list[OpInfo] = [
check_batched_gradgrad=True,
sample_inputs_func=sample_inputs_linalg_cholesky_inverse,
gradcheck_wrapper=gradcheck_wrapper_triangular_input_real_positive_diagonal,
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
decorators=[
skipCUDAIfNoMagma,
skipCPUIfNoLapack,
DecorateInfo(
toleranceOverride({
torch.float32: tol(atol=5e-03, rtol=1e-04)
}),
'TestCommon', device_type='cpu',
),
DecorateInfo(
toleranceOverride({
torch.float32: tol(atol=5e-03, rtol=1e-04)
}),
'TestEagerFusionOpInfo', device_type='cpu',
),
],
skips=(
# Strides are not the same! Original strides were ((4, 2, 1),) and strides are now ((4, 1, 2),)
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),)),
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),),
),
OpInfo('cholesky_solve',
op=torch.cholesky_solve,
dtypes=floating_and_complex_types(),