Call _sdp_attention in nn.functional.mha (#89470)

# Summary
Replaces the inline block of code in nn.functional.mha with `_scaled_dot_product_attention`. This function allows the fused kernels to be called when all of the required input conditions are met.
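
As a rough illustration of the user-facing effect (sizes below are arbitrary): with `need_weights=False`, the attention math inside `multi_head_attention_forward` can now route through `_scaled_dot_product_attention`, and whether a fused kernel is actually used still depends on the gate checks in sdp_utils (dtype, device, head dimension, mask, dropout, and so on).

```python
import torch
import torch.nn as nn

# Arbitrary sizes for illustration only.
embed_dim, num_heads, seq_len, bsz = 64, 4, 16, 2

mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
x = torch.randn(bsz, seq_len, embed_dim)

# With need_weights=False the forward can go through
# _scaled_dot_product_attention, which may in turn dispatch to a fused
# kernel when the input checks pass.
out, weights = mha(x, x, x, need_weights=False)
print(out.shape, weights)  # torch.Size([2, 16, 64]) None
```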

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89470
Approved by: https://github.com/cpuhrsch, https://github.com/mikekgfb
Driss Guessous 2022-12-02 19:46:22 +00:00 committed by PyTorch MergeBot
parent 3916d729c8
commit 78bdb858f9
6 changed files with 53 additions and 17 deletions

View File

@ -9,7 +9,6 @@
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
@ -741,10 +740,10 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
}
auto attn_mask = attn_mask_;
// Naive, composite implementation defined here.
const auto embed_size = query_.size(-1);
// Scale q,k before matmul for stability see https://tinyurl.com/sudb9s96 for math
const double scaling_factor = ::sqrt(::sqrt(static_cast<double>(embed_size)));
const auto embed_size = SymFloat(query_.sym_size(-1));
const auto scaling_factor = embed_size.sqrt().sqrt();
const auto query = query_ / scaling_factor;
if (is_causal) {
TORCH_CHECK(!attn_mask.has_value(),
@ -753,8 +752,8 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
"_scaled_dot_product_attention: Nested tensors for query / key are not supported when is_causal=True");
// Replace attn_mask with causal mask; lower triangular elements take part in attention.
const auto L = query.size(-2), S = key.size(-2);
attn_mask = at::ones({L, S}, query.options().dtype(at::kBool)).tril();
const auto L = query.sym_size(-2), S = key.sym_size(-2);
attn_mask = at::ones_symint({L, S}, query.options().dtype(at::kBool)).tril();
}
if (attn_mask.has_value()) {
TORCH_CHECK(!query.is_nested() && !key.is_nested(),
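
For reference, a minimal Python sketch of the composite math path these lines touch, under the simplifying assumption that the boolean mask is applied with `masked_fill` (the real kernel also handles float masks and nested tensors). Scaling q and k by `E ** -0.25` each gives the usual `1/sqrt(E)` factor on the logits while keeping intermediate values small:

```python
import math
import torch

# Arbitrary sizes for illustration.
B, L, S, E = 2, 5, 5, 64
q = torch.randn(B, L, E)
k = torch.randn(B, S, E)
v = torch.randn(B, S, E)
is_causal = True

# Pre-scale q and k so their product carries the 1/sqrt(E) factor.
scale = 1.0 / math.sqrt(math.sqrt(E))
attn = (q * scale) @ (k * scale).transpose(-2, -1)

if is_causal:
    # Lower-triangular positions take part in attention, as in the diff.
    causal_mask = torch.ones(L, S, dtype=torch.bool).tril()
    attn = attn.masked_fill(~causal_mask, float("-inf"))

attn = attn.softmax(dim=-1)
out = attn @ v
print(out.shape)  # torch.Size([2, 5, 64])
```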

View File

@ -40,7 +40,9 @@ inline bool check_tensor_dtype(
allowed_dtypes.end()))) {
TORCH_CHECK(
!debug,
"Expected query, key and value to be of dtype float16 or bfloat16 but got Query dtype: ",
"Expected query, key and value to all be of dtype: {",
c10::Join(", ", allowed_dtypes), "}. Got ",
"Query dtype: ",
params.query.dtype(),
", Key dtype: ",
params.key.dtype(),
@ -162,6 +164,25 @@ inline bool check_head_dim_size(sdp_params params, bool debug) {
return true;
}
inline bool check_head_dim_size_mem_efficient(sdp_params params, bool debug) {
const int64_t query_size_last = params.query.size(-1);
if (!(query_size_last == params.key.size(-1) &&
query_size_last == params.value.size(-1) && query_size_last >= 8)) {
TORCH_CHECK(
!debug,
"Mem efficient attention requires last dimension of inputs to be >= 8.",
"Got Query.size(-1): ",
query_size_last,
", Key.size(-1): ",
params.key.size(-1),
", Value.size(-1): ",
params.value.size(-1),
" instead.");
return false;
}
return true;
}
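
To make the new gate explicit, here is a hypothetical Python mirror of the check (the function name `head_dim_ok_for_mem_efficient` is made up for illustration):

```python
import torch

def head_dim_ok_for_mem_efficient(q, k, v, min_head_dim=8):
    # Hypothetical mirror of check_head_dim_size_mem_efficient:
    # q, k and v must share the same last dimension, and it must be >= 8.
    d = q.size(-1)
    return d == k.size(-1) == v.size(-1) and d >= min_head_dim

q = k = v = torch.randn(2, 4, 16, 64)
print(head_dim_ok_for_mem_efficient(q, k, v))           # True
print(head_dim_ok_for_mem_efficient(q, k, v[..., :4]))  # False: last dim is 4
```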
inline bool check_runtime_disabled_flash(sdp_params params, bool debug) {
// We check the global context to see if the user has explicitly turned off flash
// sdp kernels
@ -259,13 +280,14 @@ inline bool use_mem_efficient_attention(sdp_params params, bool debug) {
at::kHalf, at::kFloat, at::kBFloat16};
// Define gate functions that determine if a mem efficient kernel can be run
constexpr std::array<bool(*)(sdp_params, bool), 8> constraints{{
constexpr std::array<bool(*)(sdp_params, bool), 9> constraints{{
check_gpu_sm50_or_greater,
check_runtime_disabled_mem_efficient,
check_requires_grad_and_nested,
check_for_attn_weights,
check_tensor_shapes,
check_for_attn_mask,
check_head_dim_size_mem_efficient,
check_for_seq_len_1_nested_tensor,
check_for_non_zero_dropout}};
for (auto& constraint : constraints) {
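
The dispatch pattern here is an ordered list of predicates, where the first failing check vetoes the fused kernel. A hypothetical Python sketch of that pattern (names are made up, and the real list also covers SM capability, runtime-disable flags, requires-grad/nested combinations, masks, and dropout):

```python
import torch

def can_use_mem_efficient(q, k, v, constraints):
    # Evaluate the gates in order; any failure disqualifies the fused path.
    return all(check(q, k, v) for check in constraints)

constraints = [
    lambda q, k, v: q.dtype in (torch.half, torch.float, torch.bfloat16),
    lambda q, k, v: q.dim() == 4 and k.dim() == 4 and v.dim() == 4,
    lambda q, k, v: q.size(-1) == k.size(-1) == v.size(-1) >= 8,
]

q = k = v = torch.randn(2, 4, 16, 64)
print(can_use_mem_efficient(q, k, v, constraints))  # True
```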

View File

@ -1,6 +1,7 @@
#include <c10/core/SymFloat.h>
#include <c10/core/SymNodeImpl.h>
#include <array>
#include <cmath>
#include <utility>
namespace c10 {
@ -70,6 +71,15 @@ std::ostream& operator<<(std::ostream& os, const SymFloat& s) {
return os;
}
SymFloat SymFloat::sqrt() const {
if (!is_symbolic()) {
return SymFloat(std::sqrt(data_));
}
auto other = SymFloat(0.5);
auto res = normalize_symfloats(*this, other);
return SymFloat(res[0]->pow(res[1]));
}
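
On the symbolic branch, sqrt is lowered to pow with exponent 0.5, so the double sqrt used for the attention scaling is just the quarter power of the embedding size. A quick numeric sanity check:

```python
import math

E = 64
# sqrt(x) is pow(x, 0.5), and sqrt(sqrt(x)) is pow(x, 0.25).
assert math.isclose(math.sqrt(E), E ** 0.5)
assert math.isclose(math.sqrt(math.sqrt(E)), E ** 0.25)
print(E ** 0.25)  # 2.8284271247461903
```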
double SymFloat::guard_float(const char* file, int64_t line) const {
if (!is_symbolic()) {
return data_;

View File

@ -40,6 +40,9 @@ class C10_API SymFloat {
SymFloat operator*(const SymFloat&) const;
SymFloat operator/(const SymFloat&) const;
// Need guidance on where to put this code
SymFloat sqrt() const;
// Insert a guard for the float to be its concrete value, and then return
// that value. This operation always works, even if the float is symbolic,
// so long as we know what the underlying value is. Don't blindly put this

View File

@ -394,6 +394,7 @@ class TestModelsONNXRuntime(onnx_test_common._TestONNXRuntime):
)
@skipScriptTest() # TODO: #75625
@skipIfUnsupportedMinOpsetVersion(20)
def test_transformer_encoder(self):
class MyModule(torch.nn.Module):
def __init__(self, ninp, nhead, nhid, dropout, nlayers):

View File

@ -5173,19 +5173,20 @@ def multi_head_attention_forward(
# (deep breath) calculate attention and out projection
#
B, Nt, E = q.shape
q_scaled = q / math.sqrt(E)
if attn_mask is not None:
attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
else:
attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
attn_output_weights = softmax(attn_output_weights, dim=-1)
if dropout_p > 0.0:
attn_output_weights = dropout(attn_output_weights, p=dropout_p)
if attn_mask.size(0) == 1:
attn_mask = attn_mask.unsqueeze(0)
else:
attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
attn_output = torch.bmm(attn_output_weights, v)
q = q.view(bsz, num_heads, tgt_len, head_dim)
k = k.view(bsz, num_heads, src_len, head_dim)
v = v.view(bsz, num_heads, src_len, head_dim)
attn_output, attn_output_weights = _scaled_dot_product_attention(
q, k, v, attn_mask, dropout_p, need_weights, False)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
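
Putting the new functional.py flow together, here is a self-contained sketch of the reshape-and-call sequence, using the public `scaled_dot_product_attention` that later superseded the private `_scaled_dot_product_attention` op used in this diff (sizes are arbitrary, and the public op does not return attention weights, so the `need_weights` plumbing is omitted):

```python
import torch
import torch.nn.functional as F

# Arbitrary sizes for illustration.
bsz, num_heads, tgt_len, src_len, head_dim = 2, 4, 5, 7, 16
embed_dim = num_heads * head_dim

# Projected q/k/v as multi_head_attention_forward holds them at this point:
# shape (bsz * num_heads, seq_len, head_dim).
q = torch.randn(bsz * num_heads, tgt_len, head_dim)
k = torch.randn(bsz * num_heads, src_len, head_dim)
v = torch.randn(bsz * num_heads, src_len, head_dim)

# Fold the flattened batch back into (bsz, num_heads, seq, head_dim).
q = q.view(bsz, num_heads, tgt_len, head_dim)
k = k.view(bsz, num_heads, src_len, head_dim)
v = v.view(bsz, num_heads, src_len, head_dim)

# Public successor of the private op called in the diff.
attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0)

# Back to the (tgt_len * bsz, embed_dim) layout expected by the output projection.
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
print(attn_output.shape)  # torch.Size([10, 64])
```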