mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
expose mem-eff to autograd (#110495)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110495 Approved by: https://github.com/jbschlosser
This commit is contained in:
parent
3afb4e5cf7
commit
c46fc46dba
|
|
@ -14491,13 +14491,13 @@
|
|||
CUDA: _flash_attention_backward
|
||||
|
||||
# Returns ouput, logsumexp if compute_logsumexp
|
||||
- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset)
|
||||
- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k)
|
||||
variants: function
|
||||
dispatch:
|
||||
CUDA: _efficient_attention_forward
|
||||
tags: nondeterministic_seeded
|
||||
|
||||
- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int max_seqlen_k, int max_seqlen_q, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None) -> (Tensor, Tensor, Tensor, Tensor)
|
||||
- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt max_seqlen_q, SymInt max_seqlen_k, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None) -> (Tensor, Tensor, Tensor, Tensor)
|
||||
device_check: NoCheck
|
||||
variants: function
|
||||
dispatch:
|
||||
|
|
|
|||
|
|
@ -307,7 +307,8 @@ _scaled_dot_product_efficient_attention_nestedtensor_cuda(
|
|||
: sdp::CustomMaskType::NoCustomMask;
|
||||
|
||||
// See Note [Seed and Offset] for description of seed and offset
|
||||
auto [attention, log_sumexp, seed, offset] = at::_efficient_attention_forward(
|
||||
// Although max_seqlen_q, and max_seqlen_batch_kv is returned we drop these values.
|
||||
auto [attention, log_sumexp, seed, offset, max_seqlen_q, max_seqlen_batch_kv] = at::_efficient_attention_forward(
|
||||
query_buffer_reshaped.unsqueeze(0),
|
||||
key_buffer_reshaped.unsqueeze(0),
|
||||
value_buffer_reshaped.unsqueeze(0),
|
||||
|
|
|
|||
|
|
@ -742,7 +742,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
|
|||
? sdp::CustomMaskType::CausalFromTopLeft
|
||||
: sdp::CustomMaskType::NoCustomMask;
|
||||
|
||||
auto [attention, log_sumexp, seed, offset] = at::_efficient_attention_forward(
|
||||
auto [attention, log_sumexp, seed, offset, max_seqlen_batch_q, max_seqlen_batch_kv] = at::_efficient_attention_forward(
|
||||
q_t,
|
||||
k_t,
|
||||
v_t,
|
||||
|
|
@ -874,7 +874,7 @@ _flash_attention_forward(
|
|||
Tensor());
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, Tensor, Tensor> _efficient_attention_forward(
|
||||
std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_attention_forward(
|
||||
const at::Tensor& query, // [b, seqlen, num_heads, K]
|
||||
const at::Tensor& key, // [b, seqlen, num_heads, K]
|
||||
const at::Tensor& value, // [b, seqlen, num_heads, Kv]
|
||||
|
|
@ -915,8 +915,8 @@ std::tuple<at::Tensor, at::Tensor, Tensor, Tensor> _efficient_attention_forward(
|
|||
|
||||
// Embedding per head
|
||||
TORCH_CHECK(query.size(3) == key.size(3));
|
||||
// TODO_DRISS we should return max_seqlen_k;
|
||||
int64_t max_seqlen_q, max_seqlen_k;
|
||||
|
||||
int64_t max_seqlen_q = 0, max_seqlen_k = 0;
|
||||
TORCH_CHECK(seqstart_q.has_value() == seqstart_k.has_value());
|
||||
if (seqstart_q.has_value()) {
|
||||
TORCH_CHECK(seqstart_q->scalar_type() == at::ScalarType::Int);
|
||||
|
|
@ -1164,10 +1164,12 @@ std::tuple<at::Tensor, at::Tensor, Tensor, Tensor> _efficient_attention_forward(
|
|||
std::move(res),
|
||||
std::move(logsumexp),
|
||||
std::move(seed_t),
|
||||
std::move(offset_t));
|
||||
std::move(offset_t),
|
||||
max_seqlen_q,
|
||||
max_seqlen_k);
|
||||
#endif
|
||||
TORCH_CHECK(false, "USE_MEM_EFF_ATTENTION was not enabled for build.")
|
||||
return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
|
||||
return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}, 0, 0);
|
||||
}
|
||||
|
||||
Tensor triton_scaled_dot_attention(const Tensor& q, const Tensor& k, const Tensor& v, double dropout_p){
|
||||
|
|
|
|||
|
|
@ -134,14 +134,14 @@ _efficient_attention_backward(
|
|||
const at::Tensor& query,
|
||||
const at::Tensor& key,
|
||||
const at::Tensor& value,
|
||||
const c10::optional<at::Tensor>& bias, // additive attention bias
|
||||
const c10::optional<at::Tensor>& kernel_bias, // additive attention bias
|
||||
const at::Tensor& out,
|
||||
// (Mode 1MHK only) [b+1]: cu_seqlens_q[b] contains the
|
||||
// position of the first query token for batch $b
|
||||
const c10::optional<at::Tensor>& cu_seqlens_q,
|
||||
const c10::optional<at::Tensor>& cu_seqlens_q_dummy,
|
||||
// (Mode 1MHK only) [b+1]: cu_seqlens_k[b] contains the
|
||||
// position of the first key token for batch $b
|
||||
const c10::optional<at::Tensor>& cu_seqlens_k,
|
||||
const c10::optional<at::Tensor>& cu_seqlens_k_dummy,
|
||||
// (Mode 1MHK only) Maximum sequence length across batches
|
||||
int64_t max_seqlen_q,
|
||||
// (Mode 1MHK only) Maximum sequence length across batches
|
||||
|
|
@ -158,6 +158,14 @@ _efficient_attention_backward(
|
|||
if (!grad_out_.defined()) {
|
||||
return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
|
||||
}
|
||||
// This path is used when we directly call _efficient_attention_forward
|
||||
// from python.
|
||||
// This is needed because SaveVariable automatically converts
|
||||
// c10::optional to undefined tensor
|
||||
c10::optional<Tensor> bias, cu_seqlens_q, cu_seqlens_k;
|
||||
bias = kernel_bias.has_value() && !kernel_bias->defined() ? c10::nullopt : kernel_bias;
|
||||
cu_seqlens_q = cu_seqlens_q_dummy.has_value() && !cu_seqlens_q_dummy->defined() ? c10::nullopt : cu_seqlens_q_dummy;
|
||||
cu_seqlens_k = cu_seqlens_k_dummy.has_value() && !cu_seqlens_k_dummy->defined() ? c10::nullopt : cu_seqlens_k_dummy;
|
||||
|
||||
// ndim
|
||||
TORCH_CHECK(query.dim() == grad_out_.dim());
|
||||
|
|
|
|||
|
|
@ -264,11 +264,8 @@ ALLOW_LIST = [
|
|||
("aten::_upsample_nearest_exact1d_backward", datetime.date(2022, 12, 15)),
|
||||
("aten::_upsample_nearest_exact2d", datetime.date(2022, 12, 15)),
|
||||
("aten::_upsample_nearest_exact2d_backward", datetime.date(2022, 12, 15)),
|
||||
("aten::_scaled_dot_product_attention", datetime.date(2023, 8, 1)),
|
||||
("aten::_chunk_grad_outputs_efficient_attention", datetime.date(2023, 8, 1)),
|
||||
("aten::_scaled_dot_product_flash_attention", datetime.date(2023, 9, 30)),
|
||||
("aten::_flash_attention_forward", datetime.date(2023, 9, 30)),
|
||||
("aten::_flash_attention_backward", datetime.date(2023, 9, 30)),
|
||||
("aten::_efficient_attention_forward", datetime.date(2023, 11, 30)),
|
||||
("aten::_efficient_attention_backward", datetime.date(2023, 11, 30)),
|
||||
("aten::_sparse_mask_helper", datetime.date(2023, 3, 15)),
|
||||
("mkldnn::_convolution_pointwise.binary", datetime.date(2022, 12, 15)),
|
||||
("prim::CudaFusionIvalGuard", datetime.date(2023, 2, 1)),
|
||||
|
|
|
|||
|
|
@ -392,6 +392,7 @@ class TestOperators(TestCase):
|
|||
# query: last dimension must be contiguous
|
||||
# Fused attention kernels require last dim to be contiguous
|
||||
xfail('nn.functional.scaled_dot_product_attention'),
|
||||
xfail("torch.ops.aten._efficient_attention_forward"),
|
||||
}))
|
||||
@opsToleranceOverride('TestOperators', 'test_grad', (
|
||||
tol1('nn.functional.binary_cross_entropy_with_logits',
|
||||
|
|
@ -474,6 +475,7 @@ class TestOperators(TestCase):
|
|||
xfail("_native_batch_norm_legit"), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
|
||||
|
||||
xfail('nn.functional.scaled_dot_product_attention'),
|
||||
xfail('torch.ops.aten._efficient_attention_forward'),
|
||||
|
||||
xfail('nn.functional.rrelu'), # in-place test errors out with no formula implemented
|
||||
xfail('NumpyExpMarkDirtyAutogradFunction'), # TODO: https://github.com/pytorch/pytorch/issues/91280
|
||||
|
|
@ -601,6 +603,7 @@ class TestOperators(TestCase):
|
|||
# RuntimeError: query: last dimension must be contiguous
|
||||
# The fused attention kernels require the last dim to be contiguous
|
||||
xfail('nn.functional.scaled_dot_product_attention'),
|
||||
xfail('torch.ops.aten._efficient_attention_forward'),
|
||||
# BUG
|
||||
# AssertionError: Tensor-likes are not close!
|
||||
xfail('as_strided'),
|
||||
|
|
@ -679,6 +682,7 @@ class TestOperators(TestCase):
|
|||
xfail('sparse.sampled_addmm', ''), # sparse tensors have no strides
|
||||
xfail('sparse.mm', 'reduce'), # sparse tensors have no strides
|
||||
skip('nn.functional.scaled_dot_product_attention'),
|
||||
xfail('torch.ops.aten._efficient_attention_forward'),
|
||||
# AssertionError: Tensor-likes are not close!
|
||||
# Mismatched elements: 1 / 15 (6.7%)
|
||||
# Greatest absolute difference: 24.0 at index (2, 4) (up to 1e-05 allowed)
|
||||
|
|
@ -774,6 +778,7 @@ class TestOperators(TestCase):
|
|||
skip("nn.functional.fractional_max_pool2d"), # calls random op
|
||||
skip("nn.functional.fractional_max_pool3d"), # calls random op
|
||||
xfail('nn.functional.scaled_dot_product_attention'), # randomness
|
||||
xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
|
||||
xfail('nn.functional.multi_head_attention_forward'), # randomness
|
||||
# It looks like you're either (1) calling .item() on a Tensor or
|
||||
# (2) attempting to use a Tensor in some data-dependent control flow or
|
||||
|
|
@ -888,6 +893,7 @@ class TestOperators(TestCase):
|
|||
skip('nn.functional.dropout3d', ''), # randomness
|
||||
skip('nn.functional.alpha_dropout'), # randomness
|
||||
skip('nn.functional.scaled_dot_product_attention'), # randomness
|
||||
xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
|
||||
skip('nn.functional.multi_head_attention_forward'), # randomness
|
||||
xfail('index_put', ''), # not possible due to dynamic shapes; we support a subset
|
||||
xfail('nn.functional.fractional_max_pool2d'), # random
|
||||
|
|
@ -982,6 +988,7 @@ class TestOperators(TestCase):
|
|||
skip('nn.functional.dropout2d', ''),
|
||||
skip('nn.functional.dropout3d', ''),
|
||||
skip('nn.functional.scaled_dot_product_attention'), # randomness
|
||||
xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
|
||||
skip('nn.functional.multi_head_attention_forward'), # randomness
|
||||
skip('nn.functional.alpha_dropout'), # randomness
|
||||
skip('nn.functional.feature_alpha_dropout', 'without_train'),
|
||||
|
|
@ -1253,6 +1260,7 @@ class TestOperators(TestCase):
|
|||
skip('nn.functional.feature_alpha_dropout', 'with_train'), # randomness
|
||||
skip('nn.functional.feature_alpha_dropout', 'without_train'), # randomness
|
||||
skip('nn.functional.scaled_dot_product_attention'),
|
||||
xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
|
||||
skip('nn.functional.multi_head_attention_forward'), # randomness
|
||||
skip('nn.functional.alpha_dropout'), # randomness
|
||||
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
|
||||
|
|
@ -1376,6 +1384,7 @@ class TestOperators(TestCase):
|
|||
xfail('nn.functional.ctc_loss', ''), # NYI: forward-AD for _ctc_loss
|
||||
xfail('nn.functional.pdist', ''), # NYI: forward-AD with _pdist_forward
|
||||
skip('nn.functional.scaled_dot_product_attention'),
|
||||
xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
|
||||
xfail('nn.functional.multi_margin_loss', ''), # NYI: forward AD with multi_margin_loss
|
||||
skip('linalg.householder_product', '', device_type='cuda'), # flaky, I'm not sure why
|
||||
xfail('sparse.sampled_addmm', ''), # Sparse tensors have no strides
|
||||
|
|
@ -1500,6 +1509,7 @@ class TestOperators(TestCase):
|
|||
xfail('nn.functional.dropout3d'), # calls random op
|
||||
xfail('nn.functional.dropout'), # calls random op
|
||||
xfail('nn.functional.scaled_dot_product_attention'), # randomness
|
||||
xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
|
||||
xfail('nn.functional.multi_head_attention_forward'), # randomness
|
||||
xfail('nn.functional.embedding_bag'), # Forward AD not implemented and no decomposition
|
||||
xfail('nn.functional.alpha_dropout'), # calls randomn op
|
||||
|
|
@ -1768,6 +1778,7 @@ class TestOperators(TestCase):
|
|||
xfail('nn.functional.max_unpool2d', 'grad'), # contiguous call
|
||||
xfail('nn.functional.max_unpool2d'), # contiguous call
|
||||
xfail('to_sparse'), # dispatch key issue
|
||||
xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/96560
|
||||
decorate('xlogy', decorator=skipIfRocm),
|
||||
|
|
|
|||
|
|
@ -3604,6 +3604,8 @@ class TestVmapOperatorsOpInfo(TestCase):
|
|||
xfail('addcmul'),
|
||||
xfail('clamp'),
|
||||
|
||||
xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
|
||||
|
||||
# TypeError: expected Tensor as element 0 in argument 0, but got float
|
||||
xfail('item'),
|
||||
}))
|
||||
|
|
@ -3660,6 +3662,7 @@ class TestVmapOperatorsOpInfo(TestCase):
|
|||
xfail('nn.functional.dropout'), # works, can't check against for loop because of randomness inconsistency
|
||||
xfail('nn.functional.scaled_dot_product_attention'), # randomness
|
||||
xfail('nn.functional.multi_head_attention_forward'), # randomness
|
||||
xfail('torch.ops.aten._efficient_attention_forward'), # outputs ints
|
||||
xfail('resize_'),
|
||||
xfail('view_as_complex'),
|
||||
xfail('matrix_exp'),
|
||||
|
|
|
|||
|
|
@ -234,6 +234,7 @@ inductor_expected_failures_single_sample["cuda"] = {
|
|||
"to_sparse": {f16, f32, f64},
|
||||
"pca_lowrank": {f32, f64},
|
||||
"svd_lowrank": {f32, f64},
|
||||
"torch.ops.aten._efficient_attention_forward": {f16, bf16, f32},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2783,6 +2783,10 @@
|
|||
# output_differentiability: [True, False, False, False, False, False, False, False]
|
||||
# query, key, value: _flash_attention_backward(grad, query, key, value, output, softmax_logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
|
||||
|
||||
- name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k)
|
||||
output_differentiability: [True, False, False, False, False, False]
|
||||
query, key, value, bias: _efficient_attention_backward_symint(grad, query, key, value, bias, output, cu_seqlens_q, cu_seqlens_k, max_seqlen_batch_q, max_seqlen_batch_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias.requires_grad(), scale)
|
||||
|
||||
# fft
|
||||
- name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor
|
||||
self: fft_r2c_backward(grad, dim, normalization, onesided, self.sym_size(dim.back()))
|
||||
|
|
|
|||
|
|
@ -2082,7 +2082,8 @@ make_fallback(
|
|||
sdpa_constraint,
|
||||
warn=False,
|
||||
)
|
||||
|
||||
make_fallback(torch.ops.aten._efficient_attention_forward.default)
|
||||
make_fallback(torch.ops.aten._efficient_attention_backward.default)
|
||||
make_fallback(aten.sort)
|
||||
make_fallback(aten.sort.stable)
|
||||
make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors)
|
||||
|
|
|
|||
|
|
@ -5188,6 +5188,90 @@ def meta__scaled_dot_product_efficient_backward(
|
|||
return grad_q, grad_k, grad_v, grad_bias
|
||||
|
||||
|
||||
@register_meta(
|
||||
[
|
||||
aten._efficient_attention_forward,
|
||||
]
|
||||
)
|
||||
def meta__efficient_attention_forward(
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
bias: Optional[Tensor],
|
||||
cu_seqlens_q: Optional[Tensor],
|
||||
cu_seqlens_k: Optional[Tensor],
|
||||
max_seqlen_q: Optional[int],
|
||||
dropout_p: float,
|
||||
custom_mask_type: int,
|
||||
compute_log_sumexp: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
causal_diagonal: Optional[Tensor] = None,
|
||||
seqlen_k: Optional[Tensor] = None,
|
||||
):
|
||||
B = query.size(0)
|
||||
M = query.size(1)
|
||||
N = key.size(1)
|
||||
num_heads = query.size(-2)
|
||||
K = query.size(-1)
|
||||
Kv = value.size(-1)
|
||||
|
||||
res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)
|
||||
|
||||
logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
|
||||
logsum_exp = torch.empty(
|
||||
(B, num_heads, logsumexp_dim),
|
||||
dtype=torch.float,
|
||||
device=query.device,
|
||||
)
|
||||
|
||||
# See Note [Seed and Offset]:
|
||||
seed = torch.empty((), dtype=torch.long, device="meta")
|
||||
offset = torch.empty((), dtype=torch.long, device="meta")
|
||||
|
||||
return res, logsum_exp, seed, offset, M, N
|
||||
|
||||
|
||||
@register_meta(
|
||||
[
|
||||
aten._efficient_attention_backward,
|
||||
]
|
||||
)
|
||||
def meta__efficient_attention_backward(
|
||||
grad_out: Tensor,
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
bias: Optional[Tensor],
|
||||
cu_seqlens_q: Optional[Tensor],
|
||||
cu_seqlens_k: Optional[Tensor],
|
||||
max_seqlen_q: int,
|
||||
max_seqlen_k: int,
|
||||
logsumexp: Tensor,
|
||||
dropout_p: float,
|
||||
philox_seed: Tensor,
|
||||
philox_offset: Tensor,
|
||||
custom_mask_type: int,
|
||||
bias_requires_grad: bool,
|
||||
scale: Optional[float] = None,
|
||||
num_splits_key: Optional[int] = None,
|
||||
):
|
||||
grad_query = torch.empty_like(query)
|
||||
grad_key = torch.empty_like(key)
|
||||
grad_value = torch.empty_like(value)
|
||||
|
||||
if bias is not None:
|
||||
assert bias is not None
|
||||
lastDim = bias.size(-1)
|
||||
lastDimAligned = 16 * ((lastDim + 15) // 16)
|
||||
new_sizes = list(bias.size())
|
||||
new_sizes[-1] = lastDimAligned
|
||||
grad_bias = torch.empty(new_sizes, dtype=bias.dtype, device=bias.device)
|
||||
else:
|
||||
grad_bias = torch.empty((), device=query.device)
|
||||
|
||||
return grad_query, grad_key, grad_value, grad_bias
|
||||
|
||||
|
||||
@register_meta([aten._scaled_mm.default])
|
||||
def meta_scaled_mm(
|
||||
self: torch.Tensor,
|
||||
|
|
|
|||
|
|
@ -45,6 +45,25 @@ def output_alias_each_other(outputs):
|
|||
return False
|
||||
|
||||
|
||||
def is_sdpa_error(func, idx, e):
|
||||
if (
|
||||
func is aten._scaled_dot_product_flash_attention.default
|
||||
and idx in (6, 7)
|
||||
and "Devices" in repr(e)
|
||||
):
|
||||
return True
|
||||
if (
|
||||
(
|
||||
func is aten._scaled_dot_product_efficient_attention.default
|
||||
or func is aten._efficient_attention_forward.default
|
||||
)
|
||||
and idx in (2, 3)
|
||||
and "Devices" in repr(e)
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class CrossRefFakeMode(TorchDispatchMode):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -155,17 +174,7 @@ class CrossRefFakeMode(TorchDispatchMode):
|
|||
allow_rhs_unbacked=True,
|
||||
)
|
||||
except Exception as e:
|
||||
if (
|
||||
func is aten._scaled_dot_product_flash_attention.default
|
||||
and idx in (6, 7)
|
||||
and "Devices" in repr(e)
|
||||
):
|
||||
continue
|
||||
if (
|
||||
func is aten._scaled_dot_product_efficient_attention.default
|
||||
and idx in (2, 3)
|
||||
and "Devices" in repr(e)
|
||||
):
|
||||
if is_sdpa_error(func, idx, e):
|
||||
continue
|
||||
error_message = (
|
||||
f"{context} mismatched tensor metadata: {e}"
|
||||
|
|
|
|||
|
|
@ -76,6 +76,10 @@ _all_types_and_half = _all_types + (torch.half,)
|
|||
def all_types_and_half():
|
||||
return _all_types_and_half
|
||||
|
||||
def custom_types(*dtypes):
|
||||
"""Create a list of arbitrary dtypes"""
|
||||
return _empty_types + _validate_dtypes(*dtypes)
|
||||
|
||||
# The functions below are used for convenience in our test suite and thus have no corresponding C++ dispatch macro
|
||||
|
||||
# See AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from torch.testing import make_tensor
|
|||
from torch.testing._internal.common_dtype import (
|
||||
_dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
|
||||
floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
|
||||
all_types, empty_types, complex_types_and, integral_types
|
||||
all_types, empty_types, complex_types_and, integral_types, custom_types
|
||||
)
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
|
||||
|
|
@ -8242,6 +8242,78 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_
|
|||
|
||||
yield from samples
|
||||
|
||||
|
||||
def sample_inputs_efficient_attention_forward(op_info, device, dtype, requires_grad, **kwargs):
|
||||
make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
batch, num_heads, head_dim = 4, 4, 8
|
||||
seq_q = 11
|
||||
seq_kv = 32
|
||||
|
||||
dim_4_q_shape = (batch, num_heads, seq_q, head_dim)
|
||||
dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim)
|
||||
|
||||
qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)]
|
||||
samples = []
|
||||
mask_types = [1, 2] # UpperLeft, LowerRight
|
||||
scales = [None, 1.0]
|
||||
|
||||
for qkv_shape, is_causal, dropout_p, mask_type, scale in product(
|
||||
qkv_shapes, [True, False], [0.0, 0.5], mask_types, scales):
|
||||
shape_q, shape_kv = qkv_shape
|
||||
samples.append(SampleInput(
|
||||
make(shape_q).transpose(1, 2),
|
||||
make(shape_kv).transpose(1, 2),
|
||||
make(shape_kv).transpose(1, 2),
|
||||
bias=None,
|
||||
cu_seqlens_q=None,
|
||||
cu_seqlens_k=None,
|
||||
max_seqlen_q=None,
|
||||
dropout_p=dropout_p,
|
||||
custom_mask_type=mask_type,
|
||||
compute_log_sumexp=requires_grad,
|
||||
scale=scale,
|
||||
causal_diagonal=None,
|
||||
seqlen_k=None
|
||||
))
|
||||
|
||||
# Add non standard shapes
|
||||
diff_v_head_dim = SampleInput(
|
||||
make((batch, seq_q, num_heads, head_dim)),
|
||||
make((batch, seq_kv, num_heads, head_dim)),
|
||||
make((batch, seq_kv, num_heads, head_dim + 8)),
|
||||
bias=None,
|
||||
cu_seqlens_q=None,
|
||||
cu_seqlens_k=None,
|
||||
max_seqlen_q=None,
|
||||
dropout_p=dropout_p,
|
||||
custom_mask_type=0, # No Mask
|
||||
compute_log_sumexp=requires_grad,
|
||||
scale=None,
|
||||
causal_diagonal=None,
|
||||
seqlen_k=None
|
||||
)
|
||||
|
||||
# Add an attn_mask
|
||||
samples.append(
|
||||
SampleInput(
|
||||
make((batch, seq_q, num_heads, head_dim)),
|
||||
make((batch, seq_kv, num_heads, head_dim)),
|
||||
make((batch, seq_kv, num_heads, head_dim)),
|
||||
bias=make(batch, num_heads, seq_q, seq_kv),
|
||||
cu_seqlens_q=None,
|
||||
cu_seqlens_k=None,
|
||||
max_seqlen_q=None,
|
||||
dropout_p=dropout_p,
|
||||
custom_mask_type=0, # No Mask
|
||||
compute_log_sumexp=requires_grad,
|
||||
scale=None,
|
||||
causal_diagonal=None,
|
||||
seqlen_k=None
|
||||
)
|
||||
)
|
||||
|
||||
yield from samples
|
||||
|
||||
def sample_inputs_pairwise_distance(op_info, device, dtype, requires_grad, **kwargs):
|
||||
make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
|
||||
|
|
@ -14172,6 +14244,31 @@ op_db: List[OpInfo] = [
|
|||
DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness',
|
||||
device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater),),
|
||||
),
|
||||
OpInfo(
|
||||
'torch.ops.aten._efficient_attention_forward',
|
||||
sample_inputs_func=sample_inputs_efficient_attention_forward,
|
||||
dtypes=empty_types(),
|
||||
dtypesIfCUDA=custom_types(torch.float16, torch.float32)
|
||||
if not SM80OrLater
|
||||
else custom_types(torch.float16, torch.float32, torch.bfloat16),
|
||||
supports_out=False,
|
||||
supports_autograd=True,
|
||||
supports_fwgrad_bwgrad=False,
|
||||
supports_forward_ad=False,
|
||||
check_batched_forward_grad=False,
|
||||
decorators=[],
|
||||
skips=(
|
||||
# Device mismatch due to philox seed and offset
|
||||
DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast', device_type='cuda'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake', device_type='cuda'),
|
||||
# Checking the scaler value of the philox seed and offset
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', device_type='cuda'),
|
||||
# None Mismatch Tensor
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward', device_type='cuda'),
|
||||
)
|
||||
),
|
||||
UnaryUfuncInfo(
|
||||
'nn.functional.silu',
|
||||
aten_backward_name='silu_backward',
|
||||
|
|
|
|||
|
|
@ -422,6 +422,7 @@ def compute_expected_grads(op, args, kwargs, output_process_fn_grad=None, gradch
|
|||
results = output_process_fn_grad(results)
|
||||
|
||||
flat_results = pytree.tree_leaves(results)
|
||||
flat_results = [r for r in flat_results if isinstance(r, torch.Tensor)]
|
||||
flat_diff_results = [r for r in flat_results if r.requires_grad]
|
||||
assert len(flat_diff_results) > 0
|
||||
|
||||
|
|
@ -467,6 +468,7 @@ def check_backward_formula(op: Callable, args, kwargs,
|
|||
)
|
||||
|
||||
flat_results = pytree.tree_leaves(results)
|
||||
flat_results = [r for r in flat_results if isinstance(r, torch.Tensor)]
|
||||
flat_diff_results = [r for r in flat_results if r.requires_grad]
|
||||
assert len(flat_diff_results) > 0
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user