expose mem-eff to autograd (#110495)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110495
Approved by: https://github.com/jbschlosser
This commit is contained in:
drisspg 2023-11-10 17:22:56 -08:00 committed by PyTorch MergeBot
parent 3afb4e5cf7
commit c46fc46dba
15 changed files with 254 additions and 30 deletions

View File

@ -14491,13 +14491,13 @@
CUDA: _flash_attention_backward
# Returns ouput, logsumexp if compute_logsumexp
- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset)
- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k)
variants: function
dispatch:
CUDA: _efficient_attention_forward
tags: nondeterministic_seeded
- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int max_seqlen_k, int max_seqlen_q, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None) -> (Tensor, Tensor, Tensor, Tensor)
- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt max_seqlen_q, SymInt max_seqlen_k, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None) -> (Tensor, Tensor, Tensor, Tensor)
device_check: NoCheck
variants: function
dispatch:

View File

@ -307,7 +307,8 @@ _scaled_dot_product_efficient_attention_nestedtensor_cuda(
: sdp::CustomMaskType::NoCustomMask;
// See Note [Seed and Offset] for description of seed and offset
auto [attention, log_sumexp, seed, offset] = at::_efficient_attention_forward(
// Although max_seqlen_q, and max_seqlen_batch_kv is returned we drop these values.
auto [attention, log_sumexp, seed, offset, max_seqlen_q, max_seqlen_batch_kv] = at::_efficient_attention_forward(
query_buffer_reshaped.unsqueeze(0),
key_buffer_reshaped.unsqueeze(0),
value_buffer_reshaped.unsqueeze(0),

View File

@ -742,7 +742,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
? sdp::CustomMaskType::CausalFromTopLeft
: sdp::CustomMaskType::NoCustomMask;
auto [attention, log_sumexp, seed, offset] = at::_efficient_attention_forward(
auto [attention, log_sumexp, seed, offset, max_seqlen_batch_q, max_seqlen_batch_kv] = at::_efficient_attention_forward(
q_t,
k_t,
v_t,
@ -874,7 +874,7 @@ _flash_attention_forward(
Tensor());
}
std::tuple<at::Tensor, at::Tensor, Tensor, Tensor> _efficient_attention_forward(
std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_attention_forward(
const at::Tensor& query, // [b, seqlen, num_heads, K]
const at::Tensor& key, // [b, seqlen, num_heads, K]
const at::Tensor& value, // [b, seqlen, num_heads, Kv]
@ -915,8 +915,8 @@ std::tuple<at::Tensor, at::Tensor, Tensor, Tensor> _efficient_attention_forward(
// Embedding per head
TORCH_CHECK(query.size(3) == key.size(3));
// TODO_DRISS we should return max_seqlen_k;
int64_t max_seqlen_q, max_seqlen_k;
int64_t max_seqlen_q = 0, max_seqlen_k = 0;
TORCH_CHECK(seqstart_q.has_value() == seqstart_k.has_value());
if (seqstart_q.has_value()) {
TORCH_CHECK(seqstart_q->scalar_type() == at::ScalarType::Int);
@ -1164,10 +1164,12 @@ std::tuple<at::Tensor, at::Tensor, Tensor, Tensor> _efficient_attention_forward(
std::move(res),
std::move(logsumexp),
std::move(seed_t),
std::move(offset_t));
std::move(offset_t),
max_seqlen_q,
max_seqlen_k);
#endif
TORCH_CHECK(false, "USE_MEM_EFF_ATTENTION was not enabled for build.")
return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}, 0, 0);
}
Tensor triton_scaled_dot_attention(const Tensor& q, const Tensor& k, const Tensor& v, double dropout_p){

View File

@ -134,14 +134,14 @@ _efficient_attention_backward(
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const c10::optional<at::Tensor>& bias, // additive attention bias
const c10::optional<at::Tensor>& kernel_bias, // additive attention bias
const at::Tensor& out,
// (Mode 1MHK only) [b+1]: cu_seqlens_q[b] contains the
// position of the first query token for batch $b
const c10::optional<at::Tensor>& cu_seqlens_q,
const c10::optional<at::Tensor>& cu_seqlens_q_dummy,
// (Mode 1MHK only) [b+1]: cu_seqlens_k[b] contains the
// position of the first key token for batch $b
const c10::optional<at::Tensor>& cu_seqlens_k,
const c10::optional<at::Tensor>& cu_seqlens_k_dummy,
// (Mode 1MHK only) Maximum sequence length across batches
int64_t max_seqlen_q,
// (Mode 1MHK only) Maximum sequence length across batches
@ -158,6 +158,14 @@ _efficient_attention_backward(
if (!grad_out_.defined()) {
return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
}
// This path is used when we directly call _efficient_attention_forward
// from python.
// This is needed because SaveVariable automatically converts
// c10::optional to undefined tensor
c10::optional<Tensor> bias, cu_seqlens_q, cu_seqlens_k;
bias = kernel_bias.has_value() && !kernel_bias->defined() ? c10::nullopt : kernel_bias;
cu_seqlens_q = cu_seqlens_q_dummy.has_value() && !cu_seqlens_q_dummy->defined() ? c10::nullopt : cu_seqlens_q_dummy;
cu_seqlens_k = cu_seqlens_k_dummy.has_value() && !cu_seqlens_k_dummy->defined() ? c10::nullopt : cu_seqlens_k_dummy;
// ndim
TORCH_CHECK(query.dim() == grad_out_.dim());

View File

@ -264,11 +264,8 @@ ALLOW_LIST = [
("aten::_upsample_nearest_exact1d_backward", datetime.date(2022, 12, 15)),
("aten::_upsample_nearest_exact2d", datetime.date(2022, 12, 15)),
("aten::_upsample_nearest_exact2d_backward", datetime.date(2022, 12, 15)),
("aten::_scaled_dot_product_attention", datetime.date(2023, 8, 1)),
("aten::_chunk_grad_outputs_efficient_attention", datetime.date(2023, 8, 1)),
("aten::_scaled_dot_product_flash_attention", datetime.date(2023, 9, 30)),
("aten::_flash_attention_forward", datetime.date(2023, 9, 30)),
("aten::_flash_attention_backward", datetime.date(2023, 9, 30)),
("aten::_efficient_attention_forward", datetime.date(2023, 11, 30)),
("aten::_efficient_attention_backward", datetime.date(2023, 11, 30)),
("aten::_sparse_mask_helper", datetime.date(2023, 3, 15)),
("mkldnn::_convolution_pointwise.binary", datetime.date(2022, 12, 15)),
("prim::CudaFusionIvalGuard", datetime.date(2023, 2, 1)),

View File

@ -392,6 +392,7 @@ class TestOperators(TestCase):
# query: last dimension must be contiguous
# Fused attention kernels require last dim to be contiguous
xfail('nn.functional.scaled_dot_product_attention'),
xfail("torch.ops.aten._efficient_attention_forward"),
}))
@opsToleranceOverride('TestOperators', 'test_grad', (
tol1('nn.functional.binary_cross_entropy_with_logits',
@ -474,6 +475,7 @@ class TestOperators(TestCase):
xfail("_native_batch_norm_legit"), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
xfail('nn.functional.scaled_dot_product_attention'),
xfail('torch.ops.aten._efficient_attention_forward'),
xfail('nn.functional.rrelu'), # in-place test errors out with no formula implemented
xfail('NumpyExpMarkDirtyAutogradFunction'), # TODO: https://github.com/pytorch/pytorch/issues/91280
@ -601,6 +603,7 @@ class TestOperators(TestCase):
# RuntimeError: query: last dimension must be contiguous
# The fused attention kernels require the last dim to be contiguous
xfail('nn.functional.scaled_dot_product_attention'),
xfail('torch.ops.aten._efficient_attention_forward'),
# BUG
# AssertionError: Tensor-likes are not close!
xfail('as_strided'),
@ -679,6 +682,7 @@ class TestOperators(TestCase):
xfail('sparse.sampled_addmm', ''), # sparse tensors have no strides
xfail('sparse.mm', 'reduce'), # sparse tensors have no strides
skip('nn.functional.scaled_dot_product_attention'),
xfail('torch.ops.aten._efficient_attention_forward'),
# AssertionError: Tensor-likes are not close!
# Mismatched elements: 1 / 15 (6.7%)
# Greatest absolute difference: 24.0 at index (2, 4) (up to 1e-05 allowed)
@ -774,6 +778,7 @@ class TestOperators(TestCase):
skip("nn.functional.fractional_max_pool2d"), # calls random op
skip("nn.functional.fractional_max_pool3d"), # calls random op
xfail('nn.functional.scaled_dot_product_attention'), # randomness
xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
xfail('nn.functional.multi_head_attention_forward'), # randomness
# It looks like you're either (1) calling .item() on a Tensor or
# (2) attempting to use a Tensor in some data-dependent control flow or
@ -888,6 +893,7 @@ class TestOperators(TestCase):
skip('nn.functional.dropout3d', ''), # randomness
skip('nn.functional.alpha_dropout'), # randomness
skip('nn.functional.scaled_dot_product_attention'), # randomness
xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
skip('nn.functional.multi_head_attention_forward'), # randomness
xfail('index_put', ''), # not possible due to dynamic shapes; we support a subset
xfail('nn.functional.fractional_max_pool2d'), # random
@ -982,6 +988,7 @@ class TestOperators(TestCase):
skip('nn.functional.dropout2d', ''),
skip('nn.functional.dropout3d', ''),
skip('nn.functional.scaled_dot_product_attention'), # randomness
xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
skip('nn.functional.multi_head_attention_forward'), # randomness
skip('nn.functional.alpha_dropout'), # randomness
skip('nn.functional.feature_alpha_dropout', 'without_train'),
@ -1253,6 +1260,7 @@ class TestOperators(TestCase):
skip('nn.functional.feature_alpha_dropout', 'with_train'), # randomness
skip('nn.functional.feature_alpha_dropout', 'without_train'), # randomness
skip('nn.functional.scaled_dot_product_attention'),
xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
skip('nn.functional.multi_head_attention_forward'), # randomness
skip('nn.functional.alpha_dropout'), # randomness
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
@ -1376,6 +1384,7 @@ class TestOperators(TestCase):
xfail('nn.functional.ctc_loss', ''), # NYI: forward-AD for _ctc_loss
xfail('nn.functional.pdist', ''), # NYI: forward-AD with _pdist_forward
skip('nn.functional.scaled_dot_product_attention'),
xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
xfail('nn.functional.multi_margin_loss', ''), # NYI: forward AD with multi_margin_loss
skip('linalg.householder_product', '', device_type='cuda'), # flaky, I'm not sure why
xfail('sparse.sampled_addmm', ''), # Sparse tensors have no strides
@ -1500,6 +1509,7 @@ class TestOperators(TestCase):
xfail('nn.functional.dropout3d'), # calls random op
xfail('nn.functional.dropout'), # calls random op
xfail('nn.functional.scaled_dot_product_attention'), # randomness
xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
xfail('nn.functional.multi_head_attention_forward'), # randomness
xfail('nn.functional.embedding_bag'), # Forward AD not implemented and no decomposition
xfail('nn.functional.alpha_dropout'), # calls randomn op
@ -1768,6 +1778,7 @@ class TestOperators(TestCase):
xfail('nn.functional.max_unpool2d', 'grad'), # contiguous call
xfail('nn.functional.max_unpool2d'), # contiguous call
xfail('to_sparse'), # dispatch key issue
xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
# https://github.com/pytorch/pytorch/issues/96560
decorate('xlogy', decorator=skipIfRocm),

View File

@ -3604,6 +3604,8 @@ class TestVmapOperatorsOpInfo(TestCase):
xfail('addcmul'),
xfail('clamp'),
xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
# TypeError: expected Tensor as element 0 in argument 0, but got float
xfail('item'),
}))
@ -3660,6 +3662,7 @@ class TestVmapOperatorsOpInfo(TestCase):
xfail('nn.functional.dropout'), # works, can't check against for loop because of randomness inconsistency
xfail('nn.functional.scaled_dot_product_attention'), # randomness
xfail('nn.functional.multi_head_attention_forward'), # randomness
xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
xfail('resize_'),
xfail('view_as_complex'),
xfail('matrix_exp'),

View File

@ -234,6 +234,7 @@ inductor_expected_failures_single_sample["cuda"] = {
"to_sparse": {f16, f32, f64},
"pca_lowrank": {f32, f64},
"svd_lowrank": {f32, f64},
"torch.ops.aten._efficient_attention_forward": {f16, bf16, f32},
}

View File

@ -2783,6 +2783,10 @@
# output_differentiability: [True, False, False, False, False, False, False, False]
# query, key, value: _flash_attention_backward(grad, query, key, value, output, softmax_logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
- name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k)
output_differentiability: [True, False, False, False, False, False]
query, key, value, bias: _efficient_attention_backward_symint(grad, query, key, value, bias, output, cu_seqlens_q, cu_seqlens_k, max_seqlen_batch_q, max_seqlen_batch_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias.requires_grad(), scale)
# fft
- name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor
self: fft_r2c_backward(grad, dim, normalization, onesided, self.sym_size(dim.back()))

View File

@ -2082,7 +2082,8 @@ make_fallback(
sdpa_constraint,
warn=False,
)
make_fallback(torch.ops.aten._efficient_attention_forward.default)
make_fallback(torch.ops.aten._efficient_attention_backward.default)
make_fallback(aten.sort)
make_fallback(aten.sort.stable)
make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors)

View File

@ -5188,6 +5188,90 @@ def meta__scaled_dot_product_efficient_backward(
return grad_q, grad_k, grad_v, grad_bias
@register_meta(
[
aten._efficient_attention_forward,
]
)
def meta__efficient_attention_forward(
query: Tensor,
key: Tensor,
value: Tensor,
bias: Optional[Tensor],
cu_seqlens_q: Optional[Tensor],
cu_seqlens_k: Optional[Tensor],
max_seqlen_q: Optional[int],
dropout_p: float,
custom_mask_type: int,
compute_log_sumexp: bool = False,
scale: Optional[float] = None,
causal_diagonal: Optional[Tensor] = None,
seqlen_k: Optional[Tensor] = None,
):
B = query.size(0)
M = query.size(1)
N = key.size(1)
num_heads = query.size(-2)
K = query.size(-1)
Kv = value.size(-1)
res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)
logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
logsum_exp = torch.empty(
(B, num_heads, logsumexp_dim),
dtype=torch.float,
device=query.device,
)
# See Note [Seed and Offset]:
seed = torch.empty((), dtype=torch.long, device="meta")
offset = torch.empty((), dtype=torch.long, device="meta")
return res, logsum_exp, seed, offset, M, N
@register_meta(
[
aten._efficient_attention_backward,
]
)
def meta__efficient_attention_backward(
grad_out: Tensor,
query: Tensor,
key: Tensor,
value: Tensor,
bias: Optional[Tensor],
cu_seqlens_q: Optional[Tensor],
cu_seqlens_k: Optional[Tensor],
max_seqlen_q: int,
max_seqlen_k: int,
logsumexp: Tensor,
dropout_p: float,
philox_seed: Tensor,
philox_offset: Tensor,
custom_mask_type: int,
bias_requires_grad: bool,
scale: Optional[float] = None,
num_splits_key: Optional[int] = None,
):
grad_query = torch.empty_like(query)
grad_key = torch.empty_like(key)
grad_value = torch.empty_like(value)
if bias is not None:
assert bias is not None
lastDim = bias.size(-1)
lastDimAligned = 16 * ((lastDim + 15) // 16)
new_sizes = list(bias.size())
new_sizes[-1] = lastDimAligned
grad_bias = torch.empty(new_sizes, dtype=bias.dtype, device=bias.device)
else:
grad_bias = torch.empty((), device=query.device)
return grad_query, grad_key, grad_value, grad_bias
@register_meta([aten._scaled_mm.default])
def meta_scaled_mm(
self: torch.Tensor,

View File

@ -45,6 +45,25 @@ def output_alias_each_other(outputs):
return False
def is_sdpa_error(func, idx, e):
if (
func is aten._scaled_dot_product_flash_attention.default
and idx in (6, 7)
and "Devices" in repr(e)
):
return True
if (
(
func is aten._scaled_dot_product_efficient_attention.default
or func is aten._efficient_attention_forward.default
)
and idx in (2, 3)
and "Devices" in repr(e)
):
return True
return False
class CrossRefFakeMode(TorchDispatchMode):
def __init__(
self,
@ -155,17 +174,7 @@ class CrossRefFakeMode(TorchDispatchMode):
allow_rhs_unbacked=True,
)
except Exception as e:
if (
func is aten._scaled_dot_product_flash_attention.default
and idx in (6, 7)
and "Devices" in repr(e)
):
continue
if (
func is aten._scaled_dot_product_efficient_attention.default
and idx in (2, 3)
and "Devices" in repr(e)
):
if is_sdpa_error(func, idx, e):
continue
error_message = (
f"{context} mismatched tensor metadata: {e}"

View File

@ -76,6 +76,10 @@ _all_types_and_half = _all_types + (torch.half,)
def all_types_and_half():
return _all_types_and_half
def custom_types(*dtypes):
"""Create a list of arbitrary dtypes"""
return _empty_types + _validate_dtypes(*dtypes)
# The functions below are used for convenience in our test suite and thus have no corresponding C++ dispatch macro
# See AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.

View File

@ -18,7 +18,7 @@ from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
_dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
all_types, empty_types, complex_types_and, integral_types
all_types, empty_types, complex_types_and, integral_types, custom_types
)
from torch.testing._internal.common_device_type import \
(onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
@ -8242,6 +8242,78 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_
yield from samples
def sample_inputs_efficient_attention_forward(op_info, device, dtype, requires_grad, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
batch, num_heads, head_dim = 4, 4, 8
seq_q = 11
seq_kv = 32
dim_4_q_shape = (batch, num_heads, seq_q, head_dim)
dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim)
qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)]
samples = []
mask_types = [1, 2] # UpperLeft, LowerRight
scales = [None, 1.0]
for qkv_shape, is_causal, dropout_p, mask_type, scale in product(
qkv_shapes, [True, False], [0.0, 0.5], mask_types, scales):
shape_q, shape_kv = qkv_shape
samples.append(SampleInput(
make(shape_q).transpose(1, 2),
make(shape_kv).transpose(1, 2),
make(shape_kv).transpose(1, 2),
bias=None,
cu_seqlens_q=None,
cu_seqlens_k=None,
max_seqlen_q=None,
dropout_p=dropout_p,
custom_mask_type=mask_type,
compute_log_sumexp=requires_grad,
scale=scale,
causal_diagonal=None,
seqlen_k=None
))
# Add non standard shapes
diff_v_head_dim = SampleInput(
make((batch, seq_q, num_heads, head_dim)),
make((batch, seq_kv, num_heads, head_dim)),
make((batch, seq_kv, num_heads, head_dim + 8)),
bias=None,
cu_seqlens_q=None,
cu_seqlens_k=None,
max_seqlen_q=None,
dropout_p=dropout_p,
custom_mask_type=0, # No Mask
compute_log_sumexp=requires_grad,
scale=None,
causal_diagonal=None,
seqlen_k=None
)
# Add an attn_mask
samples.append(
SampleInput(
make((batch, seq_q, num_heads, head_dim)),
make((batch, seq_kv, num_heads, head_dim)),
make((batch, seq_kv, num_heads, head_dim)),
bias=make(batch, num_heads, seq_q, seq_kv),
cu_seqlens_q=None,
cu_seqlens_k=None,
max_seqlen_q=None,
dropout_p=dropout_p,
custom_mask_type=0, # No Mask
compute_log_sumexp=requires_grad,
scale=None,
causal_diagonal=None,
seqlen_k=None
)
)
yield from samples
def sample_inputs_pairwise_distance(op_info, device, dtype, requires_grad, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@ -14172,6 +14244,31 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness',
device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater),),
),
OpInfo(
'torch.ops.aten._efficient_attention_forward',
sample_inputs_func=sample_inputs_efficient_attention_forward,
dtypes=empty_types(),
dtypesIfCUDA=custom_types(torch.float16, torch.float32)
if not SM80OrLater
else custom_types(torch.float16, torch.float32, torch.bfloat16),
supports_out=False,
supports_autograd=True,
supports_fwgrad_bwgrad=False,
supports_forward_ad=False,
check_batched_forward_grad=False,
decorators=[],
skips=(
# Device mismatch due to philox seed and offset
DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast', device_type='cuda'),
DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake', device_type='cuda'),
# Checking the scaler value of the philox seed and offset
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'),
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', device_type='cuda'),
# None Mismatch Tensor
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward', device_type='cuda'),
)
),
UnaryUfuncInfo(
'nn.functional.silu',
aten_backward_name='silu_backward',

View File

@ -422,6 +422,7 @@ def compute_expected_grads(op, args, kwargs, output_process_fn_grad=None, gradch
results = output_process_fn_grad(results)
flat_results = pytree.tree_leaves(results)
flat_results = [r for r in flat_results if isinstance(r, torch.Tensor)]
flat_diff_results = [r for r in flat_results if r.requires_grad]
assert len(flat_diff_results) > 0
@ -467,6 +468,7 @@ def check_backward_formula(op: Callable, args, kwargs,
)
flat_results = pytree.tree_leaves(results)
flat_results = [r for r in flat_results if isinstance(r, torch.Tensor)]
flat_diff_results = [r for r in flat_results if r.requires_grad]
assert len(flat_diff_results) > 0