Call _sdp_attention in nn.functional.mha (#89470)
# Summary

Replaces the inline block of code in nn.functional.mha with `_scaled_dot_product_attention`. This function allows the fused kernels to be called if all the required input conditions are met.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89470
Approved by: https://github.com/cpuhrsch, https://github.com/mikekgfb
parent 3916d729c8
commit 78bdb858f9
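As a rough illustration of what the change buys, here is a hedged sketch (not code from this PR) of calling the scaled-dot-product-attention primitive directly. The public `torch.nn.functional.scaled_dot_product_attention` name used below landed in later releases; the call added to mha in this commit is the private `_scaled_dot_product_attention`. When the inputs meet the kernel constraints checked in sdp_utils, dispatch can hit a fused kernel, otherwise it falls back to the composite math implementation.

```python
# Illustrative only: shapes, dtypes, and the public API name are assumptions,
# not part of this diff.
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 128, 64)  # (batch, num_heads, seq_len, head_dim)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

# With no mask, no dropout, and a supported dtype/head_dim, a fused
# flash / memory-efficient kernel may be selected on capable GPUs.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0)
print(out.shape)  # torch.Size([2, 8, 128, 64])
```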
@@ -9,7 +9,6 @@
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
@@ -741,10 +740,10 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
  }
  auto attn_mask = attn_mask_;
  // Naive, composite implementation defined here.
-  const auto embed_size = query_.size(-1);

  // Scale q,k before matmul for stability see https://tinyurl.com/sudb9s96 for math
-  const double scaling_factor = ::sqrt(::sqrt(static_cast<double>(embed_size)));
+  const auto embed_size = SymFloat(query_.sym_size(-1));
+  const auto scaling_factor = embed_size.sqrt().sqrt();
  const auto query = query_ / scaling_factor;
  if (is_causal) {
    TORCH_CHECK(!attn_mask.has_value(),
@@ -753,8 +752,8 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
        "_scaled_dot_product_attention: Nested tensors for query / key are not supported when is_causal=True");

    // Replace attn_mask with causal mask; lower triangular elements take part in attention.
-    const auto L = query.size(-2), S = key.size(-2);
-    attn_mask = at::ones({L, S}, query.options().dtype(at::kBool)).tril();
+    const auto L = query.sym_size(-2), S = key.sym_size(-2);
+    attn_mask = at::ones_symint({L, S}, query.options().dtype(at::kBool)).tril();
  }
  if (attn_mask.has_value()) {
    TORCH_CHECK(!query.is_nested() && !key.is_nested(),
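The "Scale q,k before matmul" comment above refers to the standard 1/sqrt(d) attention scaling, split symmetrically across q and k as a fourth root for numerical stability. A quick hedged check of that equivalence (illustrative only, not code from this PR):

```python
import torch

d = 64
q = torch.randn(4, 16, d, dtype=torch.float64)
k = torch.randn(4, 16, d, dtype=torch.float64)

# Reference: scale the q @ k^T scores by 1/sqrt(d).
scores_ref = (q @ k.transpose(-2, -1)) / d ** 0.5

# Equivalent: divide q and k each by d ** 0.25 before the matmul,
# i.e. by sqrt(sqrt(d)) as computed in the hunk above.
scale = d ** 0.25
scores_pre = (q / scale) @ (k / scale).transpose(-2, -1)

print(torch.allclose(scores_ref, scores_pre))  # True
```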
@@ -40,7 +40,9 @@ inline bool check_tensor_dtype(
          allowed_dtypes.end()))) {
    TORCH_CHECK(
        !debug,
-        "Expected query, key and value to be of dtype float16 or bfloat16 but got Query dtype: ",
+        "Expected query, key and value to all be of dtype: {",
+        c10::Join(", ", allowed_dtypes), "}. Got ",
+        "Query dtype: ",
        params.query.dtype(),
        ", Key dtype: ",
        params.key.dtype(),
@@ -162,6 +164,25 @@ inline bool check_head_dim_size(sdp_params params, bool debug) {
  return true;
}

+inline bool check_head_dim_size_mem_efficient(sdp_params params, bool debug) {
+  const int64_t query_size_last = params.query.size(-1);
+  if (!(query_size_last == params.key.size(-1) &&
+        query_size_last == params.value.size(-1) && query_size_last >= 8)) {
+    TORCH_CHECK(
+        !debug,
+        "Mem efficient attention requires last dimension of inputs to be >= 8.",
+        "Got Query.size(-1): ",
+        query_size_last,
+        ", Key.size(-1): ",
+        params.key.size(-1),
+        ", Value.size(-1): ",
+        params.value.size(-1),
+        " instead.");
+    return false;
+  }
+  return true;
+}
+
inline bool check_runtime_disabled_flash(sdp_params params, bool debug) {
  // We check the global context to see if user has explicitly turned of flash
  // sdp kernels
@@ -259,13 +280,14 @@ inline bool use_mem_efficient_attention(sdp_params params, bool debug) {
      at::kHalf, at::kFloat, at::kBFloat16};

  // Define gate functions that determine if a flash kernel can be ran
-  constexpr std::array<bool(*)(sdp_params, bool), 8> constraints{{
+  constexpr std::array<bool(*)(sdp_params, bool), 9> constraints{{
      check_gpu_sm50_or_greater,
      check_runtime_disabled_mem_efficient,
      check_requires_grad_and_nested,
      check_for_attn_weights,
      check_tensor_shapes,
      check_for_attn_mask,
+      check_head_dim_size_mem_efficient,
      check_for_seq_len_1_nested_tensor,
      check_for_non_zero_dropout}};
  for (auto& constraint : constraints) {
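The kernel selection above is just an array of predicates that must all pass before the memory-efficient backend is chosen. A hedged Python transcription of that gating pattern, for illustration only (the helper names below are stand-ins for the C++ check_* functions, not an actual PyTorch API):

```python
import torch

def check_dtype(q, k, v):
    # Mirrors the allowed-dtype check: half, float, bfloat16.
    return q.dtype in (torch.float16, torch.float32, torch.bfloat16)

def check_head_dim_mem_efficient(q, k, v):
    # Mirrors check_head_dim_size_mem_efficient: matching last dims, >= 8.
    d = q.size(-1)
    return d == k.size(-1) == v.size(-1) and d >= 8

CONSTRAINTS = (check_dtype, check_head_dim_mem_efficient)

def can_use_mem_efficient(q, k, v):
    # The fused path is eligible only if every gate function passes.
    return all(check(q, k, v) for check in CONSTRAINTS)

q = k = v = torch.randn(2, 8, 128, 64, dtype=torch.float16)
print(can_use_mem_efficient(q, k, v))  # True
```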
@@ -1,6 +1,7 @@
#include <c10/core/SymFloat.h>
#include <c10/core/SymNodeImpl.h>
#include <array>
+#include <cmath>
#include <utility>

namespace c10 {
@@ -70,6 +71,15 @@ std::ostream& operator<<(std::ostream& os, const SymFloat& s) {
  return os;
}

+SymFloat SymFloat::sqrt() const {
+  if (!is_symbolic()) {
+    return SymFloat(std::sqrt(data_));
+  }
+  auto other = SymFloat(-0.5);
+  auto res = normalize_symfloats(*this, other);
+  return SymFloat(res[0]->pow(res[1]));
+}
+
double SymFloat::guard_float(const char* file, int64_t line) const {
  if (!is_symbolic()) {
    return data_;
@@ -40,6 +40,9 @@ class C10_API SymFloat {
  SymFloat operator*(const SymFloat&) const;
  SymFloat operator/(const SymFloat&) const;

+  // Need guidance on where to put this code
+  SymFloat sqrt() const;
+
  // Insert a guard for the float to be its concrete value, and then return
  // that value. This operation always works, even if the float is symbolic,
  // so long as we know what the underlying value is. Don't blindly put this
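For context on how the new SymFloat::sqrt is consumed by the attention math above: the scaling factor is sqrt applied twice, i.e. the fourth root of the head dimension, and on the symbolic branch (where sqrt is expressed through pow(-0.5)) two applications also land on d ** 0.25. A quick arithmetic check, illustrative only:

```python
import math

d = 64.0

composed_sqrt = math.sqrt(math.sqrt(d))  # sqrt twice -> d ** 0.25
composed_pow = (d ** -0.5) ** -0.5       # pow(-0.5) twice -> d ** 0.25

print(composed_sqrt, composed_pow, d ** 0.25)  # all ~2.8284
```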
@@ -394,6 +394,7 @@ class TestModelsONNXRuntime(onnx_test_common._TestONNXRuntime):
        )

    @skipScriptTest()  # TODO: #75625
+    @skipIfUnsupportedMinOpsetVersion(20)
    def test_transformer_encoder(self):
        class MyModule(torch.nn.Module):
            def __init__(self, ninp, nhead, nhid, dropout, nlayers):
@@ -5173,19 +5173,20 @@ def multi_head_attention_forward(
    # (deep breath) calculate attention and out projection
    #

-    B, Nt, E = q.shape
-    q_scaled = q / math.sqrt(E)
    if attn_mask is not None:
-        attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
-    else:
-        attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
-    attn_output_weights = softmax(attn_output_weights, dim=-1)
-    if dropout_p > 0.0:
-        attn_output_weights = dropout(attn_output_weights, p=dropout_p)
+        if attn_mask.size(0) == 1:
+            attn_mask = attn_mask.unsqueeze(0)
+        else:
+            attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)

-    attn_output = torch.bmm(attn_output_weights, v)
+    q = q.view(bsz, num_heads, tgt_len, head_dim)
+    k = k.view(bsz, num_heads, src_len, head_dim)
+    v = v.view(bsz, num_heads, src_len, head_dim)

+    attn_output, attn_output_weights = _scaled_dot_product_attention(
+        q, k, v, attn_mask, dropout_p, need_weights, False)
+    attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)

-    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
    attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
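At the module level the change is transparent; a hedged usage sketch (values are illustrative, not from this PR) of the path that now runs through the scaled-dot-product-attention primitive:

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)
x = torch.randn(2, 16, 64)  # (batch, seq_len, embed_dim)

# need_weights=False skips materializing averaged attention weights, which is
# the friendlier configuration for the fused kernels gated in sdp_utils.
out, weights = mha(x, x, x, need_weights=False)
print(out.shape, weights)  # torch.Size([2, 16, 64]) None
```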