expose mem-eff to autograd (#110495)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110495
Approved by: https://github.com/jbschlosser
Authored by drisspg on 2023-11-10 17:22:56 -08:00; committed by PyTorch MergeBot
parent 3afb4e5cf7
commit c46fc46dba
15 changed files with 254 additions and 30 deletions

@@ -14491,13 +14491,13 @@
     CUDA: _flash_attention_backward
 # Returns ouput, logsumexp if compute_logsumexp
-- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset)
+- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k)
   variants: function
   dispatch:
     CUDA: _efficient_attention_forward
   tags: nondeterministic_seeded
-- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int max_seqlen_k, int max_seqlen_q, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None) -> (Tensor, Tensor, Tensor, Tensor)
+- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt max_seqlen_q, SymInt max_seqlen_k, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None) -> (Tensor, Tensor, Tensor, Tensor)
   device_check: NoCheck
   variants: function
   dispatch:
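The schema change is the core of the PR: the forward now also returns the per-batch max sequence lengths as SymInts, and the backward accepts them as SymInts in (q, k) order. A minimal sketch of calling the op directly on a CUDA build (the shapes, dtypes, and calling style are illustrative assumptions, not part of the diff):

```python
import torch

# [B, M, num_heads, K] layout expected by the kernel (see the C++ signature further down).
q = torch.randn(2, 16, 4, 32, device="cuda", dtype=torch.float16)
k = torch.randn(2, 24, 4, 32, device="cuda", dtype=torch.float16)
v = torch.randn(2, 24, 4, 32, device="cuda", dtype=torch.float16)

out, lse, philox_seed, philox_offset, max_q, max_k = (
    torch.ops.aten._efficient_attention_forward(
        q, k, v,
        None, None, None, None,   # bias, cu_seqlens_q, cu_seqlens_k, max_seqlen_q
        0.0,                      # dropout_p
        0,                        # custom_mask_type (no mask)
        compute_log_sumexp=True,
    )
)
# With dense (non-varlen) inputs the two new outputs are simply the sequence lengths.
print(max_q, max_k)  # 16 24
```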

@@ -307,7 +307,8 @@ _scaled_dot_product_efficient_attention_nestedtensor_cuda(
           : sdp::CustomMaskType::NoCustomMask;
   // See Note [Seed and Offset] for description of seed and offset
-  auto [attention, log_sumexp, seed, offset] = at::_efficient_attention_forward(
+  // Although max_seqlen_q, and max_seqlen_batch_kv is returned we drop these values.
+  auto [attention, log_sumexp, seed, offset, max_seqlen_q, max_seqlen_batch_kv] = at::_efficient_attention_forward(
       query_buffer_reshaped.unsqueeze(0),
       key_buffer_reshaped.unsqueeze(0),
       value_buffer_reshaped.unsqueeze(0),

@@ -742,7 +742,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
           ? sdp::CustomMaskType::CausalFromTopLeft
           : sdp::CustomMaskType::NoCustomMask;
-  auto [attention, log_sumexp, seed, offset] = at::_efficient_attention_forward(
+  auto [attention, log_sumexp, seed, offset, max_seqlen_batch_q, max_seqlen_batch_kv] = at::_efficient_attention_forward(
       q_t,
       k_t,
       v_t,
@@ -874,7 +874,7 @@ _flash_attention_forward(
       Tensor());
 }
-std::tuple<at::Tensor, at::Tensor, Tensor, Tensor> _efficient_attention_forward(
+std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_attention_forward(
     const at::Tensor& query, // [b, seqlen, num_heads, K]
     const at::Tensor& key, // [b, seqlen, num_heads, K]
     const at::Tensor& value, // [b, seqlen, num_heads, Kv]
@@ -915,8 +915,8 @@ std::tuple<at::Tensor, at::Tensor, Tensor, Tensor> _efficient_attention_forward(
   // Embedding per head
   TORCH_CHECK(query.size(3) == key.size(3));
-  // TODO_DRISS we should return max_seqlen_k;
-  int64_t max_seqlen_q, max_seqlen_k;
+  int64_t max_seqlen_q = 0, max_seqlen_k = 0;
   TORCH_CHECK(seqstart_q.has_value() == seqstart_k.has_value());
   if (seqstart_q.has_value()) {
     TORCH_CHECK(seqstart_q->scalar_type() == at::ScalarType::Int);
@@ -1164,10 +1164,12 @@ std::tuple<at::Tensor, at::Tensor, Tensor, Tensor> _efficient_attention_forward(
       std::move(res),
       std::move(logsumexp),
       std::move(seed_t),
-      std::move(offset_t));
+      std::move(offset_t),
+      max_seqlen_q,
+      max_seqlen_k);
 #endif
   TORCH_CHECK(false, "USE_MEM_EFF_ATTENTION was not enabled for build.")
-  return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
+  return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}, 0, 0);
 }

 Tensor triton_scaled_dot_attention(const Tensor& q, const Tensor& k, const Tensor& v, double dropout_p){
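The public SDPA entry point keeps its four-output contract; the two extra values are simply dropped inside `_scaled_dot_product_efficient_attention_cuda`. A small sketch of exercising that path (the backend-selection context manager is assumed from the PyTorch of this era):

```python
import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

q = torch.randn(2, 4, 16, 32, device="cuda", dtype=torch.float16)  # [B, H, M, K]
k = torch.randn(2, 4, 24, 32, device="cuda", dtype=torch.float16)
v = torch.randn(2, 4, 24, 32, device="cuda", dtype=torch.float16)

# Restrict SDPA to the memory-efficient backend so the kernel changed above runs.
with sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(q, k, v)
```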

@@ -134,14 +134,14 @@ _efficient_attention_backward(
     const at::Tensor& query,
     const at::Tensor& key,
     const at::Tensor& value,
-    const c10::optional<at::Tensor>& bias, // additive attention bias
+    const c10::optional<at::Tensor>& kernel_bias, // additive attention bias
     const at::Tensor& out,
     // (Mode 1MHK only) [b+1]: cu_seqlens_q[b] contains the
     // position of the first query token for batch $b
-    const c10::optional<at::Tensor>& cu_seqlens_q,
+    const c10::optional<at::Tensor>& cu_seqlens_q_dummy,
     // (Mode 1MHK only) [b+1]: cu_seqlens_k[b] contains the
     // position of the first key token for batch $b
-    const c10::optional<at::Tensor>& cu_seqlens_k,
+    const c10::optional<at::Tensor>& cu_seqlens_k_dummy,
     // (Mode 1MHK only) Maximum sequence length across batches
     int64_t max_seqlen_q,
     // (Mode 1MHK only) Maximum sequence length across batches
@@ -158,6 +158,14 @@ _efficient_attention_backward(
   if (!grad_out_.defined()) {
     return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
   }
+  // This path is used when we directly call _efficient_attention_forward
+  // from python.
+  // This is needed because SaveVariable automatically converts
+  // c10::optional to undefined tensor
+  c10::optional<Tensor> bias, cu_seqlens_q, cu_seqlens_k;
+  bias = kernel_bias.has_value() && !kernel_bias->defined() ? c10::nullopt : kernel_bias;
+  cu_seqlens_q = cu_seqlens_q_dummy.has_value() && !cu_seqlens_q_dummy->defined() ? c10::nullopt : cu_seqlens_q_dummy;
+  cu_seqlens_k = cu_seqlens_k_dummy.has_value() && !cu_seqlens_k_dummy->defined() ? c10::nullopt : cu_seqlens_k_dummy;
   // ndim
   TORCH_CHECK(query.dim() == grad_out_.dim());
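The prologue added here is what makes direct Python autograd calls work: SavedVariable turns a saved c10::nullopt into an undefined tensor, so the backward maps "present but undefined" back to nullopt before validating. Roughly the path it serves, as a sketch (CUDA build assumed, shapes illustrative):

```python
import torch

q = torch.randn(2, 16, 4, 32, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(2, 16, 4, 32, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(2, 16, 4, 32, device="cuda", dtype=torch.float16, requires_grad=True)

# bias / cu_seqlens are passed as None; autograd saves them as undefined tensors,
# and the new prologue in _efficient_attention_backward restores them to nullopt.
out, *_ = torch.ops.aten._efficient_attention_forward(
    q, k, v, None, None, None, None, 0.0, 0, True)
out.sum().backward()
```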

@@ -264,11 +264,8 @@ ALLOW_LIST = [
     ("aten::_upsample_nearest_exact1d_backward", datetime.date(2022, 12, 15)),
     ("aten::_upsample_nearest_exact2d", datetime.date(2022, 12, 15)),
     ("aten::_upsample_nearest_exact2d_backward", datetime.date(2022, 12, 15)),
-    ("aten::_scaled_dot_product_attention", datetime.date(2023, 8, 1)),
-    ("aten::_chunk_grad_outputs_efficient_attention", datetime.date(2023, 8, 1)),
-    ("aten::_scaled_dot_product_flash_attention", datetime.date(2023, 9, 30)),
-    ("aten::_flash_attention_forward", datetime.date(2023, 9, 30)),
-    ("aten::_flash_attention_backward", datetime.date(2023, 9, 30)),
+    ("aten::_efficient_attention_forward", datetime.date(2023, 11, 30)),
+    ("aten::_efficient_attention_backward", datetime.date(2023, 11, 30)),
     ("aten::_sparse_mask_helper", datetime.date(2023, 3, 15)),
     ("mkldnn::_convolution_pointwise.binary", datetime.date(2022, 12, 15)),
     ("prim::CudaFusionIvalGuard", datetime.date(2023, 2, 1)),

@@ -392,6 +392,7 @@ class TestOperators(TestCase):
         # query: last dimension must be contiguous
         # Fused attention kernels require last dim to be contiguous
         xfail('nn.functional.scaled_dot_product_attention'),
+        xfail("torch.ops.aten._efficient_attention_forward"),
     }))
     @opsToleranceOverride('TestOperators', 'test_grad', (
         tol1('nn.functional.binary_cross_entropy_with_logits',
@@ -474,6 +475,7 @@ class TestOperators(TestCase):
         xfail("_native_batch_norm_legit"), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
         xfail('nn.functional.scaled_dot_product_attention'),
+        xfail('torch.ops.aten._efficient_attention_forward'),
         xfail('nn.functional.rrelu'), # in-place test errors out with no formula implemented
         xfail('NumpyExpMarkDirtyAutogradFunction'), # TODO: https://github.com/pytorch/pytorch/issues/91280
@@ -601,6 +603,7 @@ class TestOperators(TestCase):
         # RuntimeError: query: last dimension must be contiguous
         # The fused attention kernels require the last dim to be contiguous
         xfail('nn.functional.scaled_dot_product_attention'),
+        xfail('torch.ops.aten._efficient_attention_forward'),
         # BUG
         # AssertionError: Tensor-likes are not close!
         xfail('as_strided'),
@@ -679,6 +682,7 @@ class TestOperators(TestCase):
         xfail('sparse.sampled_addmm', ''), # sparse tensors have no strides
         xfail('sparse.mm', 'reduce'), # sparse tensors have no strides
         skip('nn.functional.scaled_dot_product_attention'),
+        xfail('torch.ops.aten._efficient_attention_forward'),
         # AssertionError: Tensor-likes are not close!
         # Mismatched elements: 1 / 15 (6.7%)
         # Greatest absolute difference: 24.0 at index (2, 4) (up to 1e-05 allowed)
@@ -774,6 +778,7 @@ class TestOperators(TestCase):
         skip("nn.functional.fractional_max_pool2d"), # calls random op
         skip("nn.functional.fractional_max_pool3d"), # calls random op
         xfail('nn.functional.scaled_dot_product_attention'), # randomness
+        xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
         xfail('nn.functional.multi_head_attention_forward'), # randomness
         # It looks like you're either (1) calling .item() on a Tensor or
         # (2) attempting to use a Tensor in some data-dependent control flow or
@@ -888,6 +893,7 @@ class TestOperators(TestCase):
         skip('nn.functional.dropout3d', ''), # randomness
         skip('nn.functional.alpha_dropout'), # randomness
         skip('nn.functional.scaled_dot_product_attention'), # randomness
+        xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
         skip('nn.functional.multi_head_attention_forward'), # randomness
         xfail('index_put', ''), # not possible due to dynamic shapes; we support a subset
         xfail('nn.functional.fractional_max_pool2d'), # random
@@ -982,6 +988,7 @@ class TestOperators(TestCase):
         skip('nn.functional.dropout2d', ''),
         skip('nn.functional.dropout3d', ''),
         skip('nn.functional.scaled_dot_product_attention'), # randomness
+        xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
         skip('nn.functional.multi_head_attention_forward'), # randomness
         skip('nn.functional.alpha_dropout'), # randomness
         skip('nn.functional.feature_alpha_dropout', 'without_train'),
@@ -1253,6 +1260,7 @@ class TestOperators(TestCase):
         skip('nn.functional.feature_alpha_dropout', 'with_train'), # randomness
         skip('nn.functional.feature_alpha_dropout', 'without_train'), # randomness
         skip('nn.functional.scaled_dot_product_attention'),
+        xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
         skip('nn.functional.multi_head_attention_forward'), # randomness
         skip('nn.functional.alpha_dropout'), # randomness
         skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
@@ -1376,6 +1384,7 @@ class TestOperators(TestCase):
         xfail('nn.functional.ctc_loss', ''), # NYI: forward-AD for _ctc_loss
         xfail('nn.functional.pdist', ''), # NYI: forward-AD with _pdist_forward
         skip('nn.functional.scaled_dot_product_attention'),
+        xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
         xfail('nn.functional.multi_margin_loss', ''), # NYI: forward AD with multi_margin_loss
         skip('linalg.householder_product', '', device_type='cuda'), # flaky, I'm not sure why
         xfail('sparse.sampled_addmm', ''), # Sparse tensors have no strides
@@ -1500,6 +1509,7 @@ class TestOperators(TestCase):
         xfail('nn.functional.dropout3d'), # calls random op
         xfail('nn.functional.dropout'), # calls random op
         xfail('nn.functional.scaled_dot_product_attention'), # randomness
+        xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
         xfail('nn.functional.multi_head_attention_forward'), # randomness
         xfail('nn.functional.embedding_bag'), # Forward AD not implemented and no decomposition
         xfail('nn.functional.alpha_dropout'), # calls randomn op
@@ -1768,6 +1778,7 @@ class TestOperators(TestCase):
         xfail('nn.functional.max_unpool2d', 'grad'), # contiguous call
         xfail('nn.functional.max_unpool2d'), # contiguous call
         xfail('to_sparse'), # dispatch key issue
+        xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
         # https://github.com/pytorch/pytorch/issues/96560
         decorate('xlogy', decorator=skipIfRocm),

@@ -3604,6 +3604,8 @@ class TestVmapOperatorsOpInfo(TestCase):
         xfail('addcmul'),
         xfail('clamp'),
+        xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
         # TypeError: expected Tensor as element 0 in argument 0, but got float
         xfail('item'),
     }))
@@ -3660,6 +3662,7 @@ class TestVmapOperatorsOpInfo(TestCase):
         xfail('nn.functional.dropout'), # works, can't check against for loop because of randomness inconsistency
         xfail('nn.functional.scaled_dot_product_attention'), # randomness
         xfail('nn.functional.multi_head_attention_forward'), # randomness
+        xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
         xfail('resize_'),
         xfail('view_as_complex'),
         xfail('matrix_exp'),

@@ -234,6 +234,7 @@ inductor_expected_failures_single_sample["cuda"] = {
     "to_sparse": {f16, f32, f64},
     "pca_lowrank": {f32, f64},
     "svd_lowrank": {f32, f64},
+    "torch.ops.aten._efficient_attention_forward": {f16, bf16, f32},
 }

@@ -2783,6 +2783,10 @@
 #  output_differentiability: [True, False, False, False, False, False, False, False]
 #  query, key, value: _flash_attention_backward(grad, query, key, value, output, softmax_logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
+- name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k)
+  output_differentiability: [True, False, False, False, False, False]
+  query, key, value, bias: _efficient_attention_backward_symint(grad, query, key, value, bias, output, cu_seqlens_q, cu_seqlens_k, max_seqlen_batch_q, max_seqlen_batch_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias.requires_grad(), scale)
 # fft
 - name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor
   self: fft_r2c_backward(grad, dim, normalization, onesided, self.sym_size(dim.back()))
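With this derivative entry, gradients flow to query, key, value and, when it requires grad, to the bias; only the first output is differentiable, so compute_log_sumexp must be true for a useful backward. A sketch of exercising the formula from Python (CUDA build assumed; the bias shape follows the OpInfo samples added later in this PR, and the 32-wide key length keeps grad_bias's aligned last dimension a no-op):

```python
import torch

B, H, M, N, K = 2, 4, 16, 32, 32
q = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(B, N, H, K, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(B, N, H, K, device="cuda", dtype=torch.float16, requires_grad=True)
bias = torch.randn(B, H, M, N, device="cuda", dtype=torch.float16, requires_grad=True)

out, *_ = torch.ops.aten._efficient_attention_forward(
    q, k, v, bias, None, None, None, 0.0, 0, True)  # compute_log_sumexp=True
dq, dk, dv, dbias = torch.autograd.grad(out.sum(), (q, k, v, bias))
```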

@@ -2082,7 +2082,8 @@ make_fallback(
     sdpa_constraint,
     warn=False,
 )
+make_fallback(torch.ops.aten._efficient_attention_forward.default)
+make_fallback(torch.ops.aten._efficient_attention_backward.default)
 make_fallback(aten.sort)
 make_fallback(aten.sort.stable)
 make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors)
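make_fallback means Inductor will not try to lower or codegen these ops; under torch.compile they run through the eager CUDA kernels. A sketch of what that enables (assuming a CUDA build; how much of the surrounding graph stays fused is up to Inductor):

```python
import torch

@torch.compile
def mem_eff_fwd(q, k, v):
    # The attention op itself is kept as an extern/fallback kernel call.
    out, *_ = torch.ops.aten._efficient_attention_forward(
        q, k, v, None, None, None, None, 0.0, 0, True)
    return out

q = k = v = torch.randn(2, 16, 4, 32, device="cuda", dtype=torch.float16)
out = mem_eff_fwd(q, k, v)
```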

@@ -5188,6 +5188,90 @@ def meta__scaled_dot_product_efficient_backward(
     return grad_q, grad_k, grad_v, grad_bias
+
+
+@register_meta(
+    [
+        aten._efficient_attention_forward,
+    ]
+)
+def meta__efficient_attention_forward(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    bias: Optional[Tensor],
+    cu_seqlens_q: Optional[Tensor],
+    cu_seqlens_k: Optional[Tensor],
+    max_seqlen_q: Optional[int],
+    dropout_p: float,
+    custom_mask_type: int,
+    compute_log_sumexp: bool = False,
+    scale: Optional[float] = None,
+    causal_diagonal: Optional[Tensor] = None,
+    seqlen_k: Optional[Tensor] = None,
+):
+    B = query.size(0)
+    M = query.size(1)
+    N = key.size(1)
+    num_heads = query.size(-2)
+    K = query.size(-1)
+    Kv = value.size(-1)
+
+    res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)
+
+    logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
+    logsum_exp = torch.empty(
+        (B, num_heads, logsumexp_dim),
+        dtype=torch.float,
+        device=query.device,
+    )
+
+    # See Note [Seed and Offset]:
+    seed = torch.empty((), dtype=torch.long, device="meta")
+    offset = torch.empty((), dtype=torch.long, device="meta")
+
+    return res, logsum_exp, seed, offset, M, N
+
+
+@register_meta(
+    [
+        aten._efficient_attention_backward,
+    ]
+)
+def meta__efficient_attention_backward(
+    grad_out: Tensor,
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    bias: Optional[Tensor],
+    cu_seqlens_q: Optional[Tensor],
+    cu_seqlens_k: Optional[Tensor],
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    logsumexp: Tensor,
+    dropout_p: float,
+    philox_seed: Tensor,
+    philox_offset: Tensor,
+    custom_mask_type: int,
+    bias_requires_grad: bool,
+    scale: Optional[float] = None,
+    num_splits_key: Optional[int] = None,
+):
+    grad_query = torch.empty_like(query)
+    grad_key = torch.empty_like(key)
+    grad_value = torch.empty_like(value)
+
+    if bias is not None:
+        assert bias is not None
+        lastDim = bias.size(-1)
+        lastDimAligned = 16 * ((lastDim + 15) // 16)
+        new_sizes = list(bias.size())
+        new_sizes[-1] = lastDimAligned
+        grad_bias = torch.empty(new_sizes, dtype=bias.dtype, device=bias.device)
+    else:
+        grad_bias = torch.empty((), device=query.device)
+
+    return grad_query, grad_key, grad_value, grad_bias
+
+
 @register_meta([aten._scaled_mm.default])
 def meta_scaled_mm(
     self: torch.Tensor,
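These meta kernels give the op shape and dtype propagation without touching a GPU, which is what FakeTensor, torch.compile, and the functorch tests rely on. A quick sketch on the meta device (shapes illustrative):

```python
import torch

q = torch.empty(2, 16, 4, 32, dtype=torch.float16, device="meta")
k = torch.empty(2, 24, 4, 32, dtype=torch.float16, device="meta")
v = torch.empty(2, 24, 4, 64, dtype=torch.float16, device="meta")

out, lse, seed, offset, max_q, max_k = torch.ops.aten._efficient_attention_forward(
    q, k, v, None, None, None, None, 0.0, 0, True)

print(out.shape)     # torch.Size([2, 16, 4, 64]) -> [B, M, num_heads, Kv]
print(lse.shape)     # torch.Size([2, 4, 32])     -> logsumexp dim padded to a multiple of 32
print(max_q, max_k)  # 16 24
```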

@@ -45,6 +45,25 @@ def output_alias_each_other(outputs):
     return False
+
+
+def is_sdpa_error(func, idx, e):
+    if (
+        func is aten._scaled_dot_product_flash_attention.default
+        and idx in (6, 7)
+        and "Devices" in repr(e)
+    ):
+        return True
+    if (
+        (
+            func is aten._scaled_dot_product_efficient_attention.default
+            or func is aten._efficient_attention_forward.default
+        )
+        and idx in (2, 3)
+        and "Devices" in repr(e)
+    ):
+        return True
+    return False
+
+
 class CrossRefFakeMode(TorchDispatchMode):
     def __init__(
         self,
@@ -155,17 +174,7 @@ class CrossRefFakeMode(TorchDispatchMode):
                             allow_rhs_unbacked=True,
                         )
                     except Exception as e:
-                        if (
-                            func is aten._scaled_dot_product_flash_attention.default
-                            and idx in (6, 7)
-                            and "Devices" in repr(e)
-                        ):
-                            continue
-                        if (
-                            func is aten._scaled_dot_product_efficient_attention.default
-                            and idx in (2, 3)
-                            and "Devices" in repr(e)
-                        ):
+                        if is_sdpa_error(func, idx, e):
                             continue
                         error_message = (
                             f"{context} mismatched tensor metadata: {e}"

@@ -76,6 +76,10 @@ _all_types_and_half = _all_types + (torch.half,)
 def all_types_and_half():
     return _all_types_and_half
+
+def custom_types(*dtypes):
+    """Create a list of arbitrary dtypes"""
+    return _empty_types + _validate_dtypes(*dtypes)
+
 # The functions below are used for convenience in our test suite and thus have no corresponding C++ dispatch macro
 # See AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
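Unlike the fixed helpers above it, custom_types lets an OpInfo declare an exact dtype set; it is used that way for the new OpInfo added later in this PR. For example:

```python
import torch
from torch.testing._internal.common_dtype import custom_types

dtypes = custom_types(torch.float16, torch.float32)
assert torch.float16 in dtypes and torch.float64 not in dtypes
```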

@@ -18,7 +18,7 @@ from torch.testing import make_tensor
 from torch.testing._internal.common_dtype import (
     _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
     floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
-    all_types, empty_types, complex_types_and, integral_types
+    all_types, empty_types, complex_types_and, integral_types, custom_types
 )
 from torch.testing._internal.common_device_type import \
     (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
@@ -8242,6 +8242,78 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_
     yield from samples
+
+
+def sample_inputs_efficient_attention_forward(op_info, device, dtype, requires_grad, **kwargs):
+    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    batch, num_heads, head_dim = 4, 4, 8
+    seq_q = 11
+    seq_kv = 32
+
+    dim_4_q_shape = (batch, num_heads, seq_q, head_dim)
+    dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim)
+
+    qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)]
+    samples = []
+    mask_types = [1, 2]  # UpperLeft, LowerRight
+    scales = [None, 1.0]
+
+    for qkv_shape, is_causal, dropout_p, mask_type, scale in product(
+            qkv_shapes, [True, False], [0.0, 0.5], mask_types, scales):
+        shape_q, shape_kv = qkv_shape
+        samples.append(SampleInput(
+            make(shape_q).transpose(1, 2),
+            make(shape_kv).transpose(1, 2),
+            make(shape_kv).transpose(1, 2),
+            bias=None,
+            cu_seqlens_q=None,
+            cu_seqlens_k=None,
+            max_seqlen_q=None,
+            dropout_p=dropout_p,
+            custom_mask_type=mask_type,
+            compute_log_sumexp=requires_grad,
+            scale=scale,
+            causal_diagonal=None,
+            seqlen_k=None
+        ))
+
+    # Add non standard shapes
+    diff_v_head_dim = SampleInput(
+        make((batch, seq_q, num_heads, head_dim)),
+        make((batch, seq_kv, num_heads, head_dim)),
+        make((batch, seq_kv, num_heads, head_dim + 8)),
+        bias=None,
+        cu_seqlens_q=None,
+        cu_seqlens_k=None,
+        max_seqlen_q=None,
+        dropout_p=dropout_p,
+        custom_mask_type=0,  # No Mask
+        compute_log_sumexp=requires_grad,
+        scale=None,
+        causal_diagonal=None,
+        seqlen_k=None
+    )
+
+    # Add an attn_mask
+    samples.append(
+        SampleInput(
+            make((batch, seq_q, num_heads, head_dim)),
+            make((batch, seq_kv, num_heads, head_dim)),
+            make((batch, seq_kv, num_heads, head_dim)),
+            bias=make(batch, num_heads, seq_q, seq_kv),
+            cu_seqlens_q=None,
+            cu_seqlens_k=None,
+            max_seqlen_q=None,
+            dropout_p=dropout_p,
+            custom_mask_type=0,  # No Mask
+            compute_log_sumexp=requires_grad,
+            scale=None,
+            causal_diagonal=None,
+            seqlen_k=None
+        )
+    )
+
+    yield from samples
+
+
 def sample_inputs_pairwise_distance(op_info, device, dtype, requires_grad, **kwargs):
     make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -14172,6 +14244,31 @@ op_db: List[OpInfo] = [
                DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness',
                             device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater),),
     ),
+    OpInfo(
+        'torch.ops.aten._efficient_attention_forward',
+        sample_inputs_func=sample_inputs_efficient_attention_forward,
+        dtypes=empty_types(),
+        dtypesIfCUDA=custom_types(torch.float16, torch.float32)
+        if not SM80OrLater
+        else custom_types(torch.float16, torch.float32, torch.bfloat16),
+        supports_out=False,
+        supports_autograd=True,
+        supports_fwgrad_bwgrad=False,
+        supports_forward_ad=False,
+        check_batched_forward_grad=False,
+        decorators=[],
+        skips=(
+            # Device mismatch due to philox seed and offset
+            DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast', device_type='cuda'),
+            DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake', device_type='cuda'),
+            # Checking the scaler value of the philox seed and offset
+            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', device_type='cuda'),
+            # None Mismatch Tensor
+            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward', device_type='cuda'),
+        )
+    ),
     UnaryUfuncInfo(
         'nn.functional.silu',
         aten_backward_name='silu_backward',

@@ -422,6 +422,7 @@ def compute_expected_grads(op, args, kwargs, output_process_fn_grad=None, gradch
         results = output_process_fn_grad(results)

     flat_results = pytree.tree_leaves(results)
+    flat_results = [r for r in flat_results if isinstance(r, torch.Tensor)]
     flat_diff_results = [r for r in flat_results if r.requires_grad]
     assert len(flat_diff_results) > 0
@@ -467,6 +468,7 @@ def check_backward_formula(op: Callable, args, kwargs,
         )

         flat_results = pytree.tree_leaves(results)
+        flat_results = [r for r in flat_results if isinstance(r, torch.Tensor)]
         flat_diff_results = [r for r in flat_results if r.requires_grad]
         assert len(flat_diff_results) > 0
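The added filter is needed because the op's flattened outputs now contain plain ints (the two max sequence lengths), and ints have no requires_grad. A tiny standalone illustration of the same filtering:

```python
import torch
import torch.utils._pytree as pytree

results = (torch.randn(3, requires_grad=True), torch.randn(3), 16, 24)  # tensors + max_seqlen ints
flat_results = pytree.tree_leaves(results)
flat_results = [r for r in flat_results if isinstance(r, torch.Tensor)]  # drop non-tensors
flat_diff_results = [r for r in flat_results if r.requires_grad]
assert len(flat_diff_results) == 1
```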