Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00

expose mem-eff to autograd (#110495)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110495
Approved by: https://github.com/jbschlosser

Commit: c46fc46dba (parent: 3afb4e5cf7)
@@ -14491,13 +14491,13 @@
     CUDA: _flash_attention_backward

 # Returns ouput, logsumexp if compute_logsumexp
-- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset)
+- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k)
   variants: function
   dispatch:
     CUDA: _efficient_attention_forward
   tags: nondeterministic_seeded

-- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int max_seqlen_k, int max_seqlen_q, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None) -> (Tensor, Tensor, Tensor, Tensor)
+- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt max_seqlen_q, SymInt max_seqlen_k, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None) -> (Tensor, Tensor, Tensor, Tensor)
   device_check: NoCheck
   variants: function
   dispatch:
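Note on the schema change above: _efficient_attention_forward now also returns the per-batch maximum sequence lengths as SymInts so they can be saved for the backward. A minimal sketch of calling the op directly with the new signature (an assumption-laden example, not part of this commit: it needs a CUDA build with the memory-efficient attention kernels enabled, and uses the dense case where the optional seqlen arguments are None):

    import torch

    # query/key/value are [batch, seqlen, num_heads, head_dim] for this kernel
    q = torch.randn(2, 11, 4, 8, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 32, 4, 8, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 32, 4, 8, device="cuda", dtype=torch.float16)

    # positional args follow the schema: bias, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
    # dropout_p, custom_mask_type, compute_log_sumexp
    out, lse, seed, offset, max_q, max_k = torch.ops.aten._efficient_attention_forward(
        q, k, v, None, None, None, None, 0.0, 0, True)
    print(out.shape, max_q, max_k)  # expected: torch.Size([2, 11, 4, 8]) 11 32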
@@ -307,7 +307,8 @@ _scaled_dot_product_efficient_attention_nestedtensor_cuda(
       : sdp::CustomMaskType::NoCustomMask;

   // See Note [Seed and Offset] for description of seed and offset
-  auto [attention, log_sumexp, seed, offset] = at::_efficient_attention_forward(
+  // Although max_seqlen_q, and max_seqlen_batch_kv is returned we drop these values.
+  auto [attention, log_sumexp, seed, offset, max_seqlen_q, max_seqlen_batch_kv] = at::_efficient_attention_forward(
       query_buffer_reshaped.unsqueeze(0),
       key_buffer_reshaped.unsqueeze(0),
       value_buffer_reshaped.unsqueeze(0),
@@ -742,7 +742,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
       ? sdp::CustomMaskType::CausalFromTopLeft
       : sdp::CustomMaskType::NoCustomMask;

-  auto [attention, log_sumexp, seed, offset] = at::_efficient_attention_forward(
+  auto [attention, log_sumexp, seed, offset, max_seqlen_batch_q, max_seqlen_batch_kv] = at::_efficient_attention_forward(
       q_t,
       k_t,
       v_t,
@@ -874,7 +874,7 @@ _flash_attention_forward(
       Tensor());
 }

-std::tuple<at::Tensor, at::Tensor, Tensor, Tensor> _efficient_attention_forward(
+std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_attention_forward(
     const at::Tensor& query, // [b, seqlen, num_heads, K]
     const at::Tensor& key, // [b, seqlen, num_heads, K]
     const at::Tensor& value, // [b, seqlen, num_heads, Kv]
@@ -915,8 +915,8 @@ std::tuple<at::Tensor, at::Tensor, Tensor, Tensor> _efficient_attention_forward(

   // Embedding per head
   TORCH_CHECK(query.size(3) == key.size(3));
-  // TODO_DRISS we should return max_seqlen_k;
-  int64_t max_seqlen_q, max_seqlen_k;
+  int64_t max_seqlen_q = 0, max_seqlen_k = 0;
   TORCH_CHECK(seqstart_q.has_value() == seqstart_k.has_value());
   if (seqstart_q.has_value()) {
     TORCH_CHECK(seqstart_q->scalar_type() == at::ScalarType::Int);
@@ -1164,10 +1164,12 @@ std::tuple<at::Tensor, at::Tensor, Tensor, Tensor> _efficient_attention_forward(
       std::move(res),
       std::move(logsumexp),
       std::move(seed_t),
-      std::move(offset_t));
+      std::move(offset_t),
+      max_seqlen_q,
+      max_seqlen_k);
 #endif
   TORCH_CHECK(false, "USE_MEM_EFF_ATTENTION was not enabled for build.")
-  return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
+  return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}, 0, 0);
 }

 Tensor triton_scaled_dot_attention(const Tensor& q, const Tensor& k, const Tensor& v, double dropout_p){
@@ -134,14 +134,14 @@ _efficient_attention_backward(
     const at::Tensor& query,
     const at::Tensor& key,
     const at::Tensor& value,
-    const c10::optional<at::Tensor>& bias, // additive attention bias
+    const c10::optional<at::Tensor>& kernel_bias, // additive attention bias
     const at::Tensor& out,
     // (Mode 1MHK only) [b+1]: cu_seqlens_q[b] contains the
     // position of the first query token for batch $b
-    const c10::optional<at::Tensor>& cu_seqlens_q,
+    const c10::optional<at::Tensor>& cu_seqlens_q_dummy,
     // (Mode 1MHK only) [b+1]: cu_seqlens_k[b] contains the
     // position of the first key token for batch $b
-    const c10::optional<at::Tensor>& cu_seqlens_k,
+    const c10::optional<at::Tensor>& cu_seqlens_k_dummy,
     // (Mode 1MHK only) Maximum sequence length across batches
     int64_t max_seqlen_q,
     // (Mode 1MHK only) Maximum sequence length across batches
@@ -158,6 +158,14 @@ _efficient_attention_backward(
   if (!grad_out_.defined()) {
     return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
   }
+  // This path is used when we directly call _efficient_attention_forward
+  // from python.
+  // This is needed because SaveVariable automatically converts
+  // c10::optional to undefined tensor
+  c10::optional<Tensor> bias, cu_seqlens_q, cu_seqlens_k;
+  bias = kernel_bias.has_value() && !kernel_bias->defined() ? c10::nullopt : kernel_bias;
+  cu_seqlens_q = cu_seqlens_q_dummy.has_value() && !cu_seqlens_q_dummy->defined() ? c10::nullopt : cu_seqlens_q_dummy;
+  cu_seqlens_k = cu_seqlens_k_dummy.has_value() && !cu_seqlens_k_dummy->defined() ? c10::nullopt : cu_seqlens_k_dummy;

   // ndim
   TORCH_CHECK(query.dim() == grad_out_.dim());
@@ -264,11 +264,8 @@ ALLOW_LIST = [
     ("aten::_upsample_nearest_exact1d_backward", datetime.date(2022, 12, 15)),
     ("aten::_upsample_nearest_exact2d", datetime.date(2022, 12, 15)),
     ("aten::_upsample_nearest_exact2d_backward", datetime.date(2022, 12, 15)),
-    ("aten::_scaled_dot_product_attention", datetime.date(2023, 8, 1)),
-    ("aten::_chunk_grad_outputs_efficient_attention", datetime.date(2023, 8, 1)),
-    ("aten::_scaled_dot_product_flash_attention", datetime.date(2023, 9, 30)),
-    ("aten::_flash_attention_forward", datetime.date(2023, 9, 30)),
-    ("aten::_flash_attention_backward", datetime.date(2023, 9, 30)),
+    ("aten::_efficient_attention_forward", datetime.date(2023, 11, 30)),
+    ("aten::_efficient_attention_backward", datetime.date(2023, 11, 30)),
     ("aten::_sparse_mask_helper", datetime.date(2023, 3, 15)),
     ("mkldnn::_convolution_pointwise.binary", datetime.date(2022, 12, 15)),
     ("prim::CudaFusionIvalGuard", datetime.date(2023, 2, 1)),
@@ -392,6 +392,7 @@ class TestOperators(TestCase):
         # query: last dimension must be contiguous
         # Fused attention kernels require last dim to be contiguous
         xfail('nn.functional.scaled_dot_product_attention'),
+        xfail("torch.ops.aten._efficient_attention_forward"),
     }))
     @opsToleranceOverride('TestOperators', 'test_grad', (
         tol1('nn.functional.binary_cross_entropy_with_logits',
@@ -474,6 +475,7 @@ class TestOperators(TestCase):
         xfail("_native_batch_norm_legit"), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents

         xfail('nn.functional.scaled_dot_product_attention'),
+        xfail('torch.ops.aten._efficient_attention_forward'),

         xfail('nn.functional.rrelu'), # in-place test errors out with no formula implemented
         xfail('NumpyExpMarkDirtyAutogradFunction'), # TODO: https://github.com/pytorch/pytorch/issues/91280
@@ -601,6 +603,7 @@ class TestOperators(TestCase):
         # RuntimeError: query: last dimension must be contiguous
         # The fused attention kernels require the last dim to be contiguous
         xfail('nn.functional.scaled_dot_product_attention'),
+        xfail('torch.ops.aten._efficient_attention_forward'),
         # BUG
         # AssertionError: Tensor-likes are not close!
         xfail('as_strided'),
@@ -679,6 +682,7 @@ class TestOperators(TestCase):
         xfail('sparse.sampled_addmm', ''), # sparse tensors have no strides
         xfail('sparse.mm', 'reduce'), # sparse tensors have no strides
         skip('nn.functional.scaled_dot_product_attention'),
+        xfail('torch.ops.aten._efficient_attention_forward'),
         # AssertionError: Tensor-likes are not close!
         # Mismatched elements: 1 / 15 (6.7%)
         # Greatest absolute difference: 24.0 at index (2, 4) (up to 1e-05 allowed)
@@ -774,6 +778,7 @@ class TestOperators(TestCase):
         skip("nn.functional.fractional_max_pool2d"), # calls random op
         skip("nn.functional.fractional_max_pool3d"), # calls random op
         xfail('nn.functional.scaled_dot_product_attention'), # randomness
+        xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
         xfail('nn.functional.multi_head_attention_forward'), # randomness
         # It looks like you're either (1) calling .item() on a Tensor or
         # (2) attempting to use a Tensor in some data-dependent control flow or
@@ -888,6 +893,7 @@ class TestOperators(TestCase):
         skip('nn.functional.dropout3d', ''), # randomness
         skip('nn.functional.alpha_dropout'), # randomness
         skip('nn.functional.scaled_dot_product_attention'), # randomness
+        xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
         skip('nn.functional.multi_head_attention_forward'), # randomness
         xfail('index_put', ''), # not possible due to dynamic shapes; we support a subset
         xfail('nn.functional.fractional_max_pool2d'), # random
@@ -982,6 +988,7 @@ class TestOperators(TestCase):
         skip('nn.functional.dropout2d', ''),
         skip('nn.functional.dropout3d', ''),
         skip('nn.functional.scaled_dot_product_attention'), # randomness
+        xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
         skip('nn.functional.multi_head_attention_forward'), # randomness
         skip('nn.functional.alpha_dropout'), # randomness
         skip('nn.functional.feature_alpha_dropout', 'without_train'),
@@ -1253,6 +1260,7 @@ class TestOperators(TestCase):
         skip('nn.functional.feature_alpha_dropout', 'with_train'), # randomness
         skip('nn.functional.feature_alpha_dropout', 'without_train'), # randomness
         skip('nn.functional.scaled_dot_product_attention'),
+        xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
         skip('nn.functional.multi_head_attention_forward'), # randomness
         skip('nn.functional.alpha_dropout'), # randomness
         skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
@@ -1376,6 +1384,7 @@ class TestOperators(TestCase):
         xfail('nn.functional.ctc_loss', ''), # NYI: forward-AD for _ctc_loss
         xfail('nn.functional.pdist', ''), # NYI: forward-AD with _pdist_forward
         skip('nn.functional.scaled_dot_product_attention'),
+        xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
         xfail('nn.functional.multi_margin_loss', ''), # NYI: forward AD with multi_margin_loss
         skip('linalg.householder_product', '', device_type='cuda'), # flaky, I'm not sure why
         xfail('sparse.sampled_addmm', ''), # Sparse tensors have no strides
@@ -1500,6 +1509,7 @@ class TestOperators(TestCase):
         xfail('nn.functional.dropout3d'), # calls random op
         xfail('nn.functional.dropout'), # calls random op
         xfail('nn.functional.scaled_dot_product_attention'), # randomness
+        xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
         xfail('nn.functional.multi_head_attention_forward'), # randomness
         xfail('nn.functional.embedding_bag'), # Forward AD not implemented and no decomposition
         xfail('nn.functional.alpha_dropout'), # calls randomn op
@@ -1768,6 +1778,7 @@ class TestOperators(TestCase):
         xfail('nn.functional.max_unpool2d', 'grad'), # contiguous call
         xfail('nn.functional.max_unpool2d'), # contiguous call
         xfail('to_sparse'), # dispatch key issue
+        xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints

         # https://github.com/pytorch/pytorch/issues/96560
         decorate('xlogy', decorator=skipIfRocm),
@@ -3604,6 +3604,8 @@ class TestVmapOperatorsOpInfo(TestCase):
         xfail('addcmul'),
         xfail('clamp'),

+        xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
+
         # TypeError: expected Tensor as element 0 in argument 0, but got float
         xfail('item'),
     }))
@@ -3660,6 +3662,7 @@ class TestVmapOperatorsOpInfo(TestCase):
         xfail('nn.functional.dropout'), # works, can't check against for loop because of randomness inconsistency
         xfail('nn.functional.scaled_dot_product_attention'), # randomness
         xfail('nn.functional.multi_head_attention_forward'), # randomness
+        xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
         xfail('resize_'),
         xfail('view_as_complex'),
         xfail('matrix_exp'),
@@ -234,6 +234,7 @@ inductor_expected_failures_single_sample["cuda"] = {
     "to_sparse": {f16, f32, f64},
     "pca_lowrank": {f32, f64},
     "svd_lowrank": {f32, f64},
+    "torch.ops.aten._efficient_attention_forward": {f16, bf16, f32},
 }

@@ -2783,6 +2783,10 @@
 # output_differentiability: [True, False, False, False, False, False, False, False]
 # query, key, value: _flash_attention_backward(grad, query, key, value, output, softmax_logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)

+- name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k)
+  output_differentiability: [True, False, False, False, False, False]
+  query, key, value, bias: _efficient_attention_backward_symint(grad, query, key, value, bias, output, cu_seqlens_q, cu_seqlens_k, max_seqlen_batch_q, max_seqlen_batch_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias.requires_grad(), scale)
+
 # fft
 - name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor
   self: fft_r2c_backward(grad, dim, normalization, onesided, self.sym_size(dim.back()))
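The derivatives.yaml entry above is what actually exposes the op to autograd: gradients for query, key, value (and bias, when it requires grad) are produced by _efficient_attention_backward_symint, with the SymInt max sequence lengths from the forward threaded through. A minimal end-to-end sketch, again under the assumption of a CUDA build with USE_MEM_EFF_ATTENTION (not part of the commit itself):

    import torch

    q = torch.randn(1, 16, 4, 8, device="cuda", dtype=torch.float16, requires_grad=True)
    k = torch.randn(1, 16, 4, 8, device="cuda", dtype=torch.float16, requires_grad=True)
    v = torch.randn(1, 16, 4, 8, device="cuda", dtype=torch.float16, requires_grad=True)

    # compute_log_sumexp=True so the logsumexp needed by the backward is saved
    out, lse, seed, offset, max_q, max_k = torch.ops.aten._efficient_attention_forward(
        q, k, v, None, None, None, None, 0.0, 0, True)
    out.sum().backward()
    print(q.grad.shape, k.grad.shape, v.grad.shape)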
@@ -2082,7 +2082,8 @@ make_fallback(
     sdpa_constraint,
     warn=False,
 )
+make_fallback(torch.ops.aten._efficient_attention_forward.default)
+make_fallback(torch.ops.aten._efficient_attention_backward.default)
 make_fallback(aten.sort)
 make_fallback(aten.sort.stable)
 make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors)
@@ -5188,6 +5188,90 @@ def meta__scaled_dot_product_efficient_backward(
     return grad_q, grad_k, grad_v, grad_bias


+@register_meta(
+    [
+        aten._efficient_attention_forward,
+    ]
+)
+def meta__efficient_attention_forward(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    bias: Optional[Tensor],
+    cu_seqlens_q: Optional[Tensor],
+    cu_seqlens_k: Optional[Tensor],
+    max_seqlen_q: Optional[int],
+    dropout_p: float,
+    custom_mask_type: int,
+    compute_log_sumexp: bool = False,
+    scale: Optional[float] = None,
+    causal_diagonal: Optional[Tensor] = None,
+    seqlen_k: Optional[Tensor] = None,
+):
+    B = query.size(0)
+    M = query.size(1)
+    N = key.size(1)
+    num_heads = query.size(-2)
+    K = query.size(-1)
+    Kv = value.size(-1)
+
+    res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)
+
+    logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
+    logsum_exp = torch.empty(
+        (B, num_heads, logsumexp_dim),
+        dtype=torch.float,
+        device=query.device,
+    )
+
+    # See Note [Seed and Offset]:
+    seed = torch.empty((), dtype=torch.long, device="meta")
+    offset = torch.empty((), dtype=torch.long, device="meta")
+
+    return res, logsum_exp, seed, offset, M, N
+
+
+@register_meta(
+    [
+        aten._efficient_attention_backward,
+    ]
+)
+def meta__efficient_attention_backward(
+    grad_out: Tensor,
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    bias: Optional[Tensor],
+    cu_seqlens_q: Optional[Tensor],
+    cu_seqlens_k: Optional[Tensor],
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    logsumexp: Tensor,
+    dropout_p: float,
+    philox_seed: Tensor,
+    philox_offset: Tensor,
+    custom_mask_type: int,
+    bias_requires_grad: bool,
+    scale: Optional[float] = None,
+    num_splits_key: Optional[int] = None,
+):
+    grad_query = torch.empty_like(query)
+    grad_key = torch.empty_like(key)
+    grad_value = torch.empty_like(value)
+
+    if bias is not None:
+        assert bias is not None
+        lastDim = bias.size(-1)
+        lastDimAligned = 16 * ((lastDim + 15) // 16)
+        new_sizes = list(bias.size())
+        new_sizes[-1] = lastDimAligned
+        grad_bias = torch.empty(new_sizes, dtype=bias.dtype, device=bias.device)
+    else:
+        grad_bias = torch.empty((), device=query.device)
+
+    return grad_query, grad_key, grad_value, grad_bias
+
+
 @register_meta([aten._scaled_mm.default])
 def meta_scaled_mm(
     self: torch.Tensor,
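With the meta kernels registered above, the op can be shape-propagated without launching the CUDA kernel, which is what FakeTensor tracing (and the inductor fallback added in the previous hunk) relies on. A small sketch using plain meta tensors; treat it as an illustration of the formulas in meta__efficient_attention_forward rather than a test from this commit:

    import torch

    q = torch.randn(2, 11, 4, 8, device="meta", dtype=torch.float16)
    k = torch.randn(2, 32, 4, 8, device="meta", dtype=torch.float16)
    v = torch.randn(2, 32, 4, 8, device="meta", dtype=torch.float16)

    out, lse, seed, offset, max_q, max_k = torch.ops.aten._efficient_attention_forward(
        q, k, v, None, None, None, None, 0.0, 0, True)
    print(out.shape)     # torch.Size([2, 11, 4, 8])
    print(lse.shape)     # torch.Size([2, 4, 32]); seqlen 11 rounded up to a multiple of 32
    print(max_q, max_k)  # 11 32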
@@ -45,6 +45,25 @@ def output_alias_each_other(outputs):
     return False


+def is_sdpa_error(func, idx, e):
+    if (
+        func is aten._scaled_dot_product_flash_attention.default
+        and idx in (6, 7)
+        and "Devices" in repr(e)
+    ):
+        return True
+    if (
+        (
+            func is aten._scaled_dot_product_efficient_attention.default
+            or func is aten._efficient_attention_forward.default
+        )
+        and idx in (2, 3)
+        and "Devices" in repr(e)
+    ):
+        return True
+    return False
+
+
 class CrossRefFakeMode(TorchDispatchMode):
     def __init__(
         self,
@@ -155,17 +174,7 @@ class CrossRefFakeMode(TorchDispatchMode):
                         allow_rhs_unbacked=True,
                     )
                 except Exception as e:
-                    if (
-                        func is aten._scaled_dot_product_flash_attention.default
-                        and idx in (6, 7)
-                        and "Devices" in repr(e)
-                    ):
-                        continue
-                    if (
-                        func is aten._scaled_dot_product_efficient_attention.default
-                        and idx in (2, 3)
-                        and "Devices" in repr(e)
-                    ):
+                    if is_sdpa_error(func, idx, e):
                         continue
                     error_message = (
                         f"{context} mismatched tensor metadata: {e}"
@@ -76,6 +76,10 @@ _all_types_and_half = _all_types + (torch.half,)
 def all_types_and_half():
     return _all_types_and_half

+def custom_types(*dtypes):
+    """Create a list of arbitrary dtypes"""
+    return _empty_types + _validate_dtypes(*dtypes)
+
 # The functions below are used for convenience in our test suite and thus have no corresponding C++ dispatch macro

 # See AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
@@ -18,7 +18,7 @@ from torch.testing import make_tensor
 from torch.testing._internal.common_dtype import (
     _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
     floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
-    all_types, empty_types, complex_types_and, integral_types
+    all_types, empty_types, complex_types_and, integral_types, custom_types
 )
 from torch.testing._internal.common_device_type import \
     (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
@@ -8242,6 +8242,78 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_

     yield from samples

+
+def sample_inputs_efficient_attention_forward(op_info, device, dtype, requires_grad, **kwargs):
+    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    batch, num_heads, head_dim = 4, 4, 8
+    seq_q = 11
+    seq_kv = 32
+
+    dim_4_q_shape = (batch, num_heads, seq_q, head_dim)
+    dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim)
+
+    qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)]
+    samples = []
+    mask_types = [1, 2]  # UpperLeft, LowerRight
+    scales = [None, 1.0]
+
+    for qkv_shape, is_causal, dropout_p, mask_type, scale in product(
+            qkv_shapes, [True, False], [0.0, 0.5], mask_types, scales):
+        shape_q, shape_kv = qkv_shape
+        samples.append(SampleInput(
+            make(shape_q).transpose(1, 2),
+            make(shape_kv).transpose(1, 2),
+            make(shape_kv).transpose(1, 2),
+            bias=None,
+            cu_seqlens_q=None,
+            cu_seqlens_k=None,
+            max_seqlen_q=None,
+            dropout_p=dropout_p,
+            custom_mask_type=mask_type,
+            compute_log_sumexp=requires_grad,
+            scale=scale,
+            causal_diagonal=None,
+            seqlen_k=None
+        ))
+
+    # Add non standard shapes
+    diff_v_head_dim = SampleInput(
+        make((batch, seq_q, num_heads, head_dim)),
+        make((batch, seq_kv, num_heads, head_dim)),
+        make((batch, seq_kv, num_heads, head_dim + 8)),
+        bias=None,
+        cu_seqlens_q=None,
+        cu_seqlens_k=None,
+        max_seqlen_q=None,
+        dropout_p=dropout_p,
+        custom_mask_type=0,  # No Mask
+        compute_log_sumexp=requires_grad,
+        scale=None,
+        causal_diagonal=None,
+        seqlen_k=None
+    )
+
+    # Add an attn_mask
+    samples.append(
+        SampleInput(
+            make((batch, seq_q, num_heads, head_dim)),
+            make((batch, seq_kv, num_heads, head_dim)),
+            make((batch, seq_kv, num_heads, head_dim)),
+            bias=make(batch, num_heads, seq_q, seq_kv),
+            cu_seqlens_q=None,
+            cu_seqlens_k=None,
+            max_seqlen_q=None,
+            dropout_p=dropout_p,
+            custom_mask_type=0,  # No Mask
+            compute_log_sumexp=requires_grad,
+            scale=None,
+            causal_diagonal=None,
+            seqlen_k=None
+        )
+    )
+
+    yield from samples
+
 def sample_inputs_pairwise_distance(op_info, device, dtype, requires_grad, **kwargs):
     make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -14172,6 +14244,31 @@ op_db: List[OpInfo] = [
                      DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness',
                                   device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater),),
     ),
+    OpInfo(
+        'torch.ops.aten._efficient_attention_forward',
+        sample_inputs_func=sample_inputs_efficient_attention_forward,
+        dtypes=empty_types(),
+        dtypesIfCUDA=custom_types(torch.float16, torch.float32)
+        if not SM80OrLater
+        else custom_types(torch.float16, torch.float32, torch.bfloat16),
+        supports_out=False,
+        supports_autograd=True,
+        supports_fwgrad_bwgrad=False,
+        supports_forward_ad=False,
+        check_batched_forward_grad=False,
+        decorators=[],
+        skips=(
+            # Device mismatch due to philox seed and offset
+            DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast', device_type='cuda'),
+            DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake', device_type='cuda'),
+            # Checking the scaler value of the philox seed and offset
+            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', device_type='cuda'),
+            # None Mismatch Tensor
+            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward', device_type='cuda'),
+        )
+    ),
     UnaryUfuncInfo(
         'nn.functional.silu',
         aten_backward_name='silu_backward',
@@ -422,6 +422,7 @@ def compute_expected_grads(op, args, kwargs, output_process_fn_grad=None, gradch
         results = output_process_fn_grad(results)

     flat_results = pytree.tree_leaves(results)
+    flat_results = [r for r in flat_results if isinstance(r, torch.Tensor)]
     flat_diff_results = [r for r in flat_results if r.requires_grad]
     assert len(flat_diff_results) > 0

@@ -467,6 +468,7 @@ def check_backward_formula(op: Callable, args, kwargs,
         )

         flat_results = pytree.tree_leaves(results)
+        flat_results = [r for r in flat_results if isinstance(r, torch.Tensor)]
         flat_diff_results = [r for r in flat_results if r.requires_grad]
         assert len(flat_diff_results) > 0
