Move native MHA code out of PyTorch core (#72944)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72944

It doesn't make sense to develop this code in PyTorch core right now.
ghstack-source-id: 149456040

Test Plan:
CI

Run the MHA benchmark in benchmark_transformers.py to make sure it doesn't crash.

Reviewed By: zrphercule

Differential Revision: D34283104

fbshipit-source-id: 4f0c7a6bc066f938ceac891320d4cf4c3f8a9cd6
(cherry picked from commit b9df65e97c)
Author: Scott Wolchok (2022-02-18 12:37:56 -08:00), committed by PyTorch MergeBot
Commit: 79a216ce57 (parent: 1646a0033d)
7 changed files with 0 additions and 733 deletions
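
For orientation, the two deleted source files below implement a fused forward pass of multi-head self-attention. Here is a minimal PyTorch sketch of the same sequence of steps (QKV projection, bias add with 1/sqrt(dim_per_head) rescale of q, Q @ K^T, softmax, attention-weighted sum of v, output projection). It is an illustrative reference under the shapes stated in the code comments, not the removed op's API, and the function names are made up. Like the deleted kernels (see their TODOs), it ignores the attention mask.

import math
import torch

def mha_forward_reference(query, qkv_weight, qkv_bias, proj_weight, proj_bias, num_head):
    # query: [B, T, D]; qkv_weight: [3*D, D]; qkv_bias: [3*D]; proj_weight: [D, D]
    B, T, D = query.shape
    dim_per_head = D // num_head
    # [B, T, 3*D]  (gemm_nt plus bias in the deleted code)
    qkv = query @ qkv_weight.t() + qkv_bias
    q, k, v = qkv.split(D, dim=-1)
    def split_heads(x):
        # [B, T, D] -> [B, num_head, T, dim_per_head]
        return x.reshape(B, T, num_head, dim_per_head).permute(0, 2, 1, 3)
    q = split_heads(q) / math.sqrt(dim_per_head)  # only q is rescaled
    k, v = split_heads(k), split_heads(v)
    # [B, num_head, T, T]  (bmm_nt followed by softmax over the last dim)
    attn = torch.softmax(q @ k.transpose(-2, -1), dim=-1)
    # [B, T, D]  (bmm_nn followed by transform_0213)
    ctx = (attn @ v).permute(0, 2, 1, 3).reshape(B, T, D)
    # output projection (gemm_nt_bias)
    return ctx @ proj_weight.t() + proj_bias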

@@ -1,339 +0,0 @@
#include <type_traits>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec256/vec256.h>
namespace at {
namespace native {
namespace {
Tensor gemm_nt(const Tensor& a, const Tensor& b) {
return at::native::matmul(a, b.t());
}
template <typename scalar_t>
void transform_bias_rescale_qkv_inner_loop(
int64_t B,
int64_t T,
int64_t _3D,
int64_t D,
int64_t num_head,
int64_t dim_per_head,
scalar_t* qkv_data,
scalar_t* qkv_bias_data,
scalar_t* q_k_v_data,
scalar_t sqrt_dim_per_head,
int64_t begin,
int64_t end) {
for (auto i : c10::irange(begin, end)) {
auto t = i % T;
i /= T;
auto nh = i % num_head;
i /= num_head;
auto b = i;
using Vec = vec::Vectorized<scalar_t>;
auto V = vec::Vectorized<scalar_t>::size();
auto dh = 0;
auto d = nh * dim_per_head;
for (; dh + V <= dim_per_head; dh += V, d += V) {
// load
auto q_bias_data = Vec::loadu(&qkv_bias_data[d + 0 * D]);
auto k_bias_data = Vec::loadu(&qkv_bias_data[d + 1 * D]);
auto v_bias_data = Vec::loadu(&qkv_bias_data[d + 2 * D]);
auto q_data =
Vec::loadu(&qkv_data[b * _3D * T + t * _3D + d + 0 * D]) +
q_bias_data;
auto k_data =
Vec::loadu(&qkv_data[b * _3D * T + t * _3D + d + 1 * D]) +
k_bias_data;
auto v_data =
Vec::loadu(&qkv_data[b * _3D * T + t * _3D + d + 2 * D]) +
v_bias_data;
q_data = q_data / Vec(sqrt_dim_per_head);
q_data.store(&q_k_v_data
[0 * B * num_head * T * dim_per_head +
b * num_head * T * dim_per_head +
nh * T * dim_per_head +
t * dim_per_head + dh]);
k_data.store(&q_k_v_data
[1 * B * num_head * T * dim_per_head +
b * num_head * T * dim_per_head +
nh * T * dim_per_head +
t * dim_per_head + dh]);
v_data.store(&q_k_v_data
[2 * B * num_head * T * dim_per_head +
b * num_head * T * dim_per_head +
nh * T * dim_per_head +
t * dim_per_head + dh]);
}
for (; dh < dim_per_head; dh++) {
auto d = nh * dim_per_head + dh;
auto q_bias = qkv_bias_data[d + 0 * D];
auto k_bias = qkv_bias_data[d + 1 * D];
auto v_bias = qkv_bias_data[d + 2 * D];
auto q_data = qkv_data[b * _3D * T + t * _3D + d + 0 * D] + q_bias;
auto k_data = qkv_data[b * _3D * T + t * _3D + d + 1 * D] + k_bias;
auto v_data = qkv_data[b * _3D * T + t * _3D + d + 2 * D] + v_bias;
q_data = q_data / sqrt_dim_per_head;
q_k_v_data[0 * B * num_head * T * dim_per_head +
b * num_head * T * dim_per_head +
nh * T * dim_per_head +
t * dim_per_head + dh] = q_data;
q_k_v_data[1 * B * num_head * T * dim_per_head +
b * num_head * T * dim_per_head +
nh * T * dim_per_head +
t * dim_per_head + dh] = k_data;
q_k_v_data[2 * B * num_head * T * dim_per_head +
b * num_head * T * dim_per_head +
nh * T * dim_per_head +
t * dim_per_head + dh] = v_data;
}
}
}
// compute q = (q + q_bias) / sqrt(dim_per_head), k = k + k_bias, v = v + v_bias
std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv(
const Tensor& qkv,
const Tensor& qkv_bias,
const int64_t num_head) {
auto B = qkv.size(0);
auto T = qkv.size(1);
auto _3D = qkv.size(2);
auto D = _3D / 3;
TORCH_CHECK(D % num_head == 0);
TORCH_CHECK(_3D % 3 == 0);
const auto dim_per_head = D / num_head;
auto q_k_v = at::empty({3, B, num_head, T, dim_per_head}, qkv.options());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(q_k_v.is_contiguous());
const auto qkv_contig = qkv.expect_contiguous();
const auto qkv_bias_contig = qkv_bias.expect_contiguous();
AT_DISPATCH_FLOATING_TYPES_AND2(
ScalarType::Half,
ScalarType::BFloat16,
qkv.scalar_type(),
"transform_bias_rescale_qkv",
[&] {
scalar_t* qkv_data = qkv_contig->data_ptr<scalar_t>();
scalar_t* qkv_bias_data = qkv_bias_contig->data_ptr<scalar_t>();
scalar_t* q_k_v_data = q_k_v.data_ptr<scalar_t>();
const scalar_t sqrt_dim_per_head = std::sqrt(static_cast<scalar_t>(dim_per_head));
int64_t grain_size =
std::max(internal::GRAIN_SIZE / (3 * dim_per_head), (int64_t)1);
parallel_for(
0, B * num_head * T, grain_size, [&](int64_t begin, int64_t end) {
transform_bias_rescale_qkv_inner_loop(B, T, _3D, D, num_head, dim_per_head, qkv_data, qkv_bias_data, q_k_v_data, sqrt_dim_per_head, begin, end);
});
});
auto q_k_v_s =
at::native::split(q_k_v.view({3 * B, num_head, T, dim_per_head}), B, 0);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(q_k_v_s.size() == 3);
return std::make_tuple(q_k_v_s[0], q_k_v_s[1], q_k_v_s[2]);
}
Tensor bmm_nt(const Tensor& a, const Tensor& b) {
auto a_ = a.view({a.size(0) * a.size(1), a.size(2), a.size(3)});
auto b_ = b.view({b.size(0) * b.size(1), b.size(2), b.size(3)});
auto bt_ = b_.transpose(2, 1);
// TODO: are these a single call to cublas batched matmul?
auto c_ = at::matmul(a_, bt_);
return c_.view({a.size(0), a.size(1), a.size(2), b.size(2)});
}
void masked_softmax_dropout(
Tensor& attn_scores,
const c10::optional<Tensor>& attn_mask) {
auto B = attn_scores.size(0);
auto num_heads = attn_scores.size(1);
auto T = attn_scores.size(2);
if (attn_mask) {
TORCH_CHECK(attn_mask->is_contiguous());
} else {
at::_softmax_out(attn_scores, attn_scores, 3, false);
return;
}
AT_DISPATCH_FLOATING_TYPES_AND2(
ScalarType::Half,
ScalarType::BFloat16,
attn_scores.scalar_type(),
"masked_softmax_dropout",
[&] {
using accscalar_t = acc_type<scalar_t, false>;
// TODO: proper implementation with masking.
scalar_t* attn_scores_data = attn_scores.data_ptr<scalar_t>();
int64_t grain_size = std::min(internal::GRAIN_SIZE / T, (int64_t)1);
parallel_for(
0, B * num_heads * T, grain_size, [&](int64_t begin, int64_t end) {
for (const auto i : c10::irange(begin, end)) {
using Vec = vec::Vectorized<scalar_t>;
auto V = vec::Vectorized<scalar_t>::size();
scalar_t* input_data = attn_scores_data + i;
auto max_input = Vec(std::numeric_limits<scalar_t>::lowest());
// TODO: handle epilogue
TORCH_CHECK(T % V == 0, "epilogue not implemented yet");
for (auto t = 0; t < T; t += V) {
auto v = Vec::loadu(&input_data[t]);
max_input = vec::maximum(max_input, v);
}
auto hmax = std::numeric_limits<scalar_t>::lowest();
for (auto i = 0; i < V; ++i) {
hmax = std::max(max_input[i], hmax);
}
accscalar_t hsum = 0;
TORCH_CHECK(T % V == 0, "epilogue not implemented yet");
for (auto t = 0; t < T; t += V) {
auto v = Vec::loadu(&input_data[t]);
// TODO: vectorize in accscalar_t?
for (auto i = 0; i < V; ++i) {
hsum += std::exp(static_cast<accscalar_t>(v[i]) - hmax);
}
}
auto inv_denominator = 1.0 / hsum;
TORCH_CHECK(T % V == 0, "epilogue not implemented yet");
for (auto t = 0; t < T; t += V) {
Vec v = Vec::loadu(&input_data[t]);
// TODO: vectorize in accscalar_t?
// TODO this faster solution does not work on Android build
/*
for (auto i = 0; i < V; ++i) {
v[i] = static_cast<scalar_t>(std::exp(static_cast<accscalar_t>(v[i]) - hmax) * inv_denominator);
}
v.store(&input_data[t]);
*/
for (auto i = 0; i < V; ++i) {
input_data[t + i] = static_cast<scalar_t>(std::exp(static_cast<accscalar_t>(v[i]) - hmax) * inv_denominator);
}
}
}
});
});
}
Tensor bmm_nn(const Tensor& a, const Tensor& b) {
auto a_ = a.view({a.size(0) * a.size(1), a.size(2), a.size(3)});
auto b_ = b.view({b.size(0) * b.size(1), b.size(2), b.size(3)});
// TODO: are these a single call to cublas batched matmul?
auto c_ = at::matmul(a_, b_);
return c_.view({a.size(0), a.size(1), a.size(2), b.size(3)});
}
Tensor transform_0213(const Tensor& a) {
// TODO: check perf vs dedicated kernel.
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(1));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(3));
return a.permute({0, 2, 1, 3})
.contiguous()
.view({a.size(0), a.size(2), a.size(1) * a.size(3)});
}
Tensor gemm_nt_bias(const Tensor& a, const Tensor& b, const Tensor& c) {
auto a_ = a.view({a.size(0) * a.size(1), a.size(2)});
auto r_ = at::native::linear(a_, b, c);
return r_.view({a.size(0), a.size(1), r_.size(1)});
}
void debug_assert_shape(const Tensor& t, c10::IntArrayRef shape) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY((size_t)t.dim() == shape.size(), "expected ", shape.size(), "-D tensor but got ", t.dim());
for (auto idx : c10::irange(shape.size())) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t.sizes()[idx] == shape[idx], "expected dim ", idx, " to be ", shape[idx], " but got ", t.sizes()[idx]);
}
}
} // namespace
std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_op_cpu(
const Tensor& qkv,
const Tensor& qkv_bias,
const int64_t num_head) {
auto result = transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
return std::make_tuple(std::get<0>(result).clone(), std::get<1>(result).clone(), std::get<2>(result).clone());
}
Tensor multi_head_self_attention_cpu(
const Tensor& query,
const Tensor& qkv_weight,
const Tensor& qkv_bias,
const Tensor& proj_weight,
const Tensor& proj_bias,
const int64_t num_head,
const c10::optional<Tensor>& mask) {
// query shape: [B, T, D]
// qkv_weight shape: [3 * D, D]
const auto D = query.sizes()[2];
TORCH_CHECK(query.dim() == 3, "expected 3-dimensional query, got ", query.dim(), "-D tensor");
TORCH_CHECK(qkv_weight.dim() == 2, "expected 2-dimensional qkv_weight, got ", qkv_weight.dim(), "-D tensor");
TORCH_CHECK(D * 3 == qkv_weight.sizes()[0], "expected qkv_weight first dim to be 3x last dim of query");
TORCH_CHECK(D == qkv_weight.sizes()[1], "expected qkv_weight second dim and last dim of query to be equal");
TORCH_CHECK(qkv_bias.dim() == 1, "expected 1-dimensional qkv_bias, got ", qkv_bias.dim(), "-D tensor");
TORCH_CHECK(qkv_bias.sizes()[0] == 3 * D, "expected qkv_bias first dim to be 3x last dim of query");
TORCH_CHECK(D % num_head == 0, "D must divide evenly by num_head");
#ifndef NDEBUG
const auto B = query.sizes()[0];
const auto T = query.sizes()[1];
const auto dim_per_head = D / num_head;
#endif
// shape: [B, T, 3 x D]
auto qkv = gemm_nt(query, qkv_weight);
#ifndef NDEBUG
debug_assert_shape(qkv, {B, T, 3 * D});
#endif
// shape: 3 x [B, num_head, T, dim_per_head]
auto q_k_v = transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
const auto& q = std::get<0>(q_k_v);
const auto& k = std::get<1>(q_k_v);
const auto& v = std::get<2>(q_k_v);
#ifndef NDEBUG
debug_assert_shape(q, {B, num_head, T, dim_per_head});
debug_assert_shape(k, {B, num_head, T, dim_per_head});
debug_assert_shape(v, {B, num_head, T, dim_per_head});
#endif
// shape: [B, num_head, T, T]
auto qkt = bmm_nt(q, k);
#ifndef NDEBUG
debug_assert_shape(qkt, {B, num_head, T, T});
#endif
// shape: [B, num_head, T, T]
masked_softmax_dropout(qkt, mask);
// shape: [B, num_head, T, dim_per_head]
auto attn_ctx = bmm_nn(qkt, v);
#ifndef NDEBUG
debug_assert_shape(attn_ctx, {B, num_head, T, dim_per_head});
#endif
// shape: [B, T, D]
auto attn = transform_0213(attn_ctx);
#ifndef NDEBUG
debug_assert_shape(attn, {B, T, D});
#endif
// shape: [B, T, D]
auto proj = gemm_nt_bias(attn, proj_weight, proj_bias);
#ifndef NDEBUG
debug_assert_shape(proj, {B, T, D});
#endif
return proj;
}
} // namespace native
} // namespace at
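
The masked_softmax_dropout above computes a numerically stable softmax over the last dimension in three passes per row: find the row maximum, sum the shifted exponentials, then recompute and normalize. A small PyTorch sketch of that per-row scheme, assuming a 1-D row tensor (illustrative only; the deleted kernel vectorizes the loops and currently requires T to be a multiple of the vector width):

import torch

def stable_softmax_row(row: torch.Tensor) -> torch.Tensor:
    # Pass 1: row maximum, so exp() never sees large positive arguments.
    row_max = row.max()
    # Pass 2: denominator as the sum of shifted exponentials.
    denom = torch.exp(row - row_max).sum()
    # Pass 3: recompute the shifted exponentials and normalize,
    # mirroring the kernel's final store loop.
    return torch.exp(row - row_max) / denom

For a full score tensor this matches torch.softmax(scores, dim=-1) applied row by row, up to floating-point rounding.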

@@ -1,342 +0,0 @@
#include <type_traits>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/TensorAccessor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/block_reduce.cuh>
#include <ATen/native/cuda/PersistentSoftmax.cuh>
#include <c10/cuda/CUDAMathCompat.h>
namespace at {
namespace native {
namespace {
Tensor gemm_nt(const Tensor& a, const Tensor& b) {
return at::native::matmul(a, b.t());
}
static constexpr int TRANSFORM_BIAS_RESCALE_VEC = 4;
template <typename scalar_t, typename accscalar_t, bool assume_aligned>
__global__ void transform_bias_rescale_qkv_kernel(
// [B, T, 3 * D]
const PackedTensorAccessor64<scalar_t, 3, RestrictPtrTraits> qkv,
// [3 * D]
const PackedTensorAccessor64<scalar_t, 1, RestrictPtrTraits> qkv_bias,
// [3, B, NH, T, DH]
PackedTensorAccessor64<scalar_t, 5, RestrictPtrTraits> q_k_v) {
// warp per DH.
// so launch B * NH * T warps.
auto NH = q_k_v.size(2);
auto T = q_k_v.size(3);
auto DH = q_k_v.size(4);
auto t = blockIdx.x % T;
auto b = blockIdx.x / T;
auto D = NH * DH;
const scalar_t sqrt_dim_per_head = std::sqrt(static_cast<scalar_t>(DH));
if (assume_aligned) {
constexpr int VEC = TRANSFORM_BIAS_RESCALE_VEC;
using LoadT = memory::aligned_vector<scalar_t, VEC>;
for (int32_t d_v = threadIdx.x; d_v < D / VEC; d_v += blockDim.x) {
auto d = d_v * VEC;
auto nh = d / DH;
auto dh = d % DH;
scalar_t qkv_bias_q[VEC];
scalar_t qkv_bias_k[VEC];
scalar_t qkv_bias_v[VEC];
scalar_t qkv_q[VEC];
scalar_t qkv_k[VEC];
scalar_t qkv_v[VEC];
// Here we require D % VEC == 0 for these vectorized loads.
*reinterpret_cast<LoadT*>(&qkv_bias_q) =
*reinterpret_cast<const LoadT*>(&qkv_bias[d + 0 * D]);
*reinterpret_cast<LoadT*>(&qkv_bias_k) =
*reinterpret_cast<const LoadT*>(&qkv_bias[d + 1 * D]);
*reinterpret_cast<LoadT*>(&qkv_bias_v) =
*reinterpret_cast<const LoadT*>(&qkv_bias[d + 2 * D]);
*reinterpret_cast<LoadT*>(&qkv_q) =
*reinterpret_cast<const LoadT*>(&qkv[b][t][d + 0 * D]);
*reinterpret_cast<LoadT*>(&qkv_k) =
*reinterpret_cast<const LoadT*>(&qkv[b][t][d + 1 * D]);
*reinterpret_cast<LoadT*>(&qkv_v) =
*reinterpret_cast<const LoadT*>(&qkv[b][t][d + 2 * D]);
#pragma unroll
// TODO: specialize for float2half2/half2float2?
for (auto ii = 0; ii < VEC; ++ii) {
qkv_q[ii] = static_cast<scalar_t>(
(static_cast<accscalar_t>(qkv_q[ii]) +
static_cast<accscalar_t>(qkv_bias_q[ii])) /
static_cast<accscalar_t>(sqrt_dim_per_head));
qkv_k[ii] = static_cast<scalar_t>(
(static_cast<accscalar_t>(qkv_k[ii]) +
static_cast<accscalar_t>(qkv_bias_k[ii])));
qkv_v[ii] = static_cast<scalar_t>(
(static_cast<accscalar_t>(qkv_v[ii]) +
static_cast<accscalar_t>(qkv_bias_v[ii])));
}
// Here we require DH % VEC == 0 for these vectorized stores.
*reinterpret_cast<LoadT*>(&q_k_v[0][b][nh][t][dh]) =
*reinterpret_cast<const LoadT*>(&qkv_q);
*reinterpret_cast<LoadT*>(&q_k_v[1][b][nh][t][dh]) =
*reinterpret_cast<const LoadT*>(&qkv_k);
*reinterpret_cast<LoadT*>(&q_k_v[2][b][nh][t][dh]) =
*reinterpret_cast<const LoadT*>(&qkv_v);
}
} else {
// Same as above, but we can't vectorize memory access.
for (int32_t d = threadIdx.x; d < D; d += blockDim.x) {
auto nh = d / DH;
auto dh = d % DH;
scalar_t qkv_bias_q = qkv_bias[d + 0 * D];
scalar_t qkv_bias_k = qkv_bias[d + 1 * D];
scalar_t qkv_bias_v = qkv_bias[d + 2 * D];
scalar_t qkv_q = qkv[b][t][d + 0 * D];
scalar_t qkv_k = qkv[b][t][d + 1 * D];
scalar_t qkv_v = qkv[b][t][d + 2 * D];
qkv_q = static_cast<scalar_t>(
(static_cast<accscalar_t>(qkv_q) +
static_cast<accscalar_t>(qkv_bias_q)) /
static_cast<accscalar_t>(sqrt_dim_per_head));
qkv_k = static_cast<scalar_t>(
(static_cast<accscalar_t>(qkv_k) +
static_cast<accscalar_t>(qkv_bias_k)));
qkv_v = static_cast<scalar_t>(
(static_cast<accscalar_t>(qkv_v) +
static_cast<accscalar_t>(qkv_bias_v)));
q_k_v[0][b][nh][t][dh] = qkv_q;
q_k_v[1][b][nh][t][dh] = qkv_k;
q_k_v[2][b][nh][t][dh] = qkv_v;
}
}
}
// compute q = (q + q_bias) / sqrt(dim_per_head), k = k + k_bias, v = v + v_bias
std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv(
const Tensor& qkv,
const Tensor& qkv_bias,
const int64_t num_head) {
auto B = qkv.size(0);
auto T = qkv.size(1);
auto _3D = qkv.size(2);
auto D = _3D / 3;
TORCH_CHECK(D % num_head == 0);
const auto dim_per_head = D / num_head;
auto q_k_v = at::empty({3, B, num_head, T, dim_per_head}, qkv.options());
#define CALL_KERNEL(assume_aligned) \
transform_bias_rescale_qkv_kernel<scalar_t, accscalar_t, assume_aligned> \
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
qkv.packed_accessor64<scalar_t, 3, RestrictPtrTraits>(), \
qkv_bias.packed_accessor64<scalar_t, 1, RestrictPtrTraits>(), \
q_k_v.packed_accessor64<scalar_t, 5, RestrictPtrTraits>())
AT_DISPATCH_FLOATING_TYPES_AND2(
ScalarType::Half,
ScalarType::BFloat16,
qkv.scalar_type(),
"transform_bias_rescale_qkv",
[&] {
using accscalar_t = acc_type<scalar_t, true>;
auto threads = std::max(std::min<int32_t>(1024, D / TRANSFORM_BIAS_RESCALE_VEC), 1);
auto blocks = B * T;
if (dim_per_head % TRANSFORM_BIAS_RESCALE_VEC == 0) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
D % TRANSFORM_BIAS_RESCALE_VEC == 0,
"D = num_heads * dim_per_head, so we should have dim_per_head % "
"TRANSFORM_BIAS_RESCALE_VEC == 0 => "
"D % TRANSFORM_BIAS_RESCALE_VEC == 0");
CALL_KERNEL(true);
} else {
CALL_KERNEL(false);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
#undef CALL_KERNEL
auto q_k_v_s =
at::native::split(q_k_v.view({3 * B, num_head, T, dim_per_head}), B, 0);
return std::make_tuple(q_k_v_s[0], q_k_v_s[1], q_k_v_s[2]);
}
Tensor bmm_nt(const Tensor& a, const Tensor& b) {
auto a_ = a.view({a.size(0) * a.size(1), a.size(2), a.size(3)});
auto b_ = b.view({b.size(0) * b.size(1), b.size(2), b.size(3)});
auto bt_ = b_.transpose(2, 1);
// TODO: are these a single call to cublas batched matmul?
auto c_ = at::matmul(a_, bt_);
return c_.view({a.size(0), a.size(1), a.size(2), b.size(2)});
}
template <typename T>
__inline__ __device__ T WarpReduceMax(T val) {
#pragma unroll
for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
val = std::max(val, WARP_SHFL_DOWN(val, offset));
}
return val;
}
template <typename T>
__inline__ __device__ T WarpReduceSum(T val) {
#pragma unroll
for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
val += WARP_SHFL_DOWN(val, offset);
}
return val;
}
void masked_softmax_dropout(
const Tensor& attn_scores,
const c10::optional<Tensor>& attn_mask) {
auto B = attn_scores.size(0);
auto num_heads = attn_scores.size(1);
auto T = attn_scores.size(2);
if (attn_mask) {
TORCH_CHECK(attn_mask->is_contiguous());
}
AT_DISPATCH_FLOATING_TYPES_AND2(
ScalarType::Half,
ScalarType::BFloat16,
attn_scores.scalar_type(),
"masked_softmax_dropout",
[&] {
using accscalar_t = acc_type<scalar_t, true>;
// TODO: proper implementation with masking.
dispatch_softmax_forward<scalar_t, scalar_t, accscalar_t, false, false>(
attn_scores.data_ptr<scalar_t>(),
attn_scores.data_ptr<scalar_t>(),
T,
T,
B * num_heads * T
);
});
}
Tensor bmm_nn(const Tensor& a, const Tensor& b) {
auto a_ = a.view({a.size(0) * a.size(1), a.size(2), a.size(3)});
auto b_ = b.view({b.size(0) * b.size(1), b.size(2), b.size(3)});
// TODO: are these a single call to cublas batched matmul?
auto c_ = at::matmul(a_, b_);
return c_.view({a.size(0), a.size(1), a.size(2), b.size(3)});
}
Tensor transform_0213(const Tensor& a) {
// TODO: check perf vs dedicated kernel.
return a.permute({0, 2, 1, 3})
.contiguous()
.view({a.size(0), a.size(2), a.size(1) * a.size(3)});
}
Tensor gemm_nt_bias(const Tensor& a, const Tensor& b, const Tensor& c) {
auto a_ = a.view({a.size(0) * a.size(1), a.size(2)});
auto r_ = at::native::linear(a_, b, c);
return r_.view({a.size(0), a.size(1), r_.size(1)});
}
void debug_assert_shape(const Tensor& t, c10::IntArrayRef shape) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY((size_t)t.dim() == shape.size(), "expected ", shape.size(), "-D tensor but got ", t.dim());
for (auto idx : c10::irange(shape.size())) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t.sizes()[idx] == shape[idx], "expected dim ", idx, " to be ", shape[idx], " but got ", t.sizes()[idx]);
}
}
} // namespace
std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_op_cuda(
const Tensor& qkv,
const Tensor& qkv_bias,
const int64_t num_head) {
auto result = transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
return std::make_tuple(std::get<0>(result).clone(), std::get<1>(result).clone(), std::get<2>(result).clone());
}
Tensor multi_head_self_attention_cuda(
const Tensor& query,
const Tensor& qkv_weight,
const Tensor& qkv_bias,
const Tensor& proj_weight,
const Tensor& proj_bias,
const int64_t num_head,
const c10::optional<Tensor>& mask) {
// query shape: [B, T, D]
// qkv_weight shape: [3 * D, D]
const auto D = query.sizes()[2];
TORCH_CHECK(query.dim() == 3, "expected 3-dimensional query, got ", query.dim(), "-D tensor");
TORCH_CHECK(qkv_weight.dim() == 2, "expected 2-dimensional qkv_weight, got ", qkv_weight.dim(), "-D tensor");
TORCH_CHECK(D * 3 == qkv_weight.sizes()[0], "expected qkv_weight first dim to be 3x last dim of query");
TORCH_CHECK(D == qkv_weight.sizes()[1], "expected qkv_weight second dim and last dim of query to be equal");
TORCH_CHECK(D % num_head == 0, "D must divide evenly by num_head");
#ifndef NDEBUG
const auto B = query.sizes()[0];
const auto T = query.sizes()[1];
const auto dim_per_head = D / num_head;
#endif
// shape: [B, T, 3 x D]
auto qkv = gemm_nt(query, qkv_weight);
#ifndef NDEBUG
debug_assert_shape(qkv, {B, T, 3 * D});
#endif
// shape: 3 x [B, num_head, T, dim_per_head]
auto q_k_v = transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
const auto& q = std::get<0>(q_k_v);
const auto& k = std::get<1>(q_k_v);
const auto& v = std::get<2>(q_k_v);
#ifndef NDEBUG
debug_assert_shape(q, {B, num_head, T, dim_per_head});
debug_assert_shape(k, {B, num_head, T, dim_per_head});
debug_assert_shape(v, {B, num_head, T, dim_per_head});
#endif
// shape: [B, num_head, T, T]
auto qkt = bmm_nt(q, k);
#ifndef NDEBUG
debug_assert_shape(qkt, {B, num_head, T, T});
#endif
// shape: [B, num_head, T, T]
masked_softmax_dropout(qkt, mask);
// shape: [B, num_head, T, dim_per_head]
auto attn_ctx = bmm_nn(qkt, v);
#ifndef NDEBUG
debug_assert_shape(attn_ctx, {B, num_head, T, dim_per_head});
#endif
// shape: [B, T, D]
auto attn = transform_0213(attn_ctx);
#ifndef NDEBUG
debug_assert_shape(attn, {B, T, D});
#endif
// shape: [B, T, D]
auto proj = gemm_nt_bias(attn, proj_weight, proj_bias);
#ifndef NDEBUG
debug_assert_shape(proj, {B, T, D});
#endif
return proj;
}
} // namespace native
} // namespace at
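
Both the CPU and CUDA files implement the same transform_bias_rescale_qkv contract: given packed projections qkv of shape [B, T, 3*D] and a bias of shape [3*D], produce q, k and v of shape [B, num_head, T, dim_per_head], with only q divided by sqrt(dim_per_head). A plain PyTorch sketch of that contract, close to the reference used by the deleted test further down (the function names here are illustrative, not part of the removed API):

import math
import torch

def transform_bias_rescale_qkv_reference(qkv, qkv_bias, num_head):
    # qkv: [B, T, 3*D]; qkv_bias: [3*D]
    B, T, three_d = qkv.shape
    D = three_d // 3
    dim_per_head = D // num_head
    q, k, v = (qkv + qkv_bias).split(D, dim=-1)
    def to_heads(x):
        # [B, T, D] -> [B, num_head, T, dim_per_head]
        return x.reshape(B, T, num_head, dim_per_head).permute(0, 2, 1, 3)
    return to_heads(q) / math.sqrt(dim_per_head), to_heads(k), to_heads(v)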

@@ -2549,16 +2549,6 @@
CUDA: layer_norm_cuda
CompositeImplicitAutograd: math_native_layer_norm
- func: _native_multi_head_self_attention(Tensor query, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, int num_head, Tensor? mask=None) -> Tensor
dispatch:
CPU: multi_head_self_attention_cpu
CUDA: multi_head_self_attention_cuda
- func: _transform_bias_rescale_qkv(Tensor qkv, Tensor qkv_bias, int num_head) -> (Tensor, Tensor, Tensor)
dispatch:
CPU: transform_bias_rescale_qkv_op_cpu
CUDA: transform_bias_rescale_qkv_op_cuda
- func: native_layer_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
dispatch:
CPU: layer_norm_backward_cpu

@@ -104,8 +104,6 @@ ALLOW_LIST = [
("aten::nanquantile", datetime.date(2022, 9, 30)),
("aten::_convolution_double_backward", datetime.date(2022, 3, 31)),
("aten::_scatter_reduce", datetime.date(2022, 1, 31)),
("aten::native_multi_head_self_attention", datetime.date(9999, 1, 1)),
("aten::_native_multi_head_self_attention", datetime.date(9999, 1, 1)),
("aten::scatter_reduce.two", datetime.date(2022, 3, 15)), ("aten::scatter_reduce.two", datetime.date(2022, 3, 15)),
] ]

@@ -17540,44 +17540,6 @@ class TestNNDeviceType(NNTestCase):
self._test_EmbeddingBag(device, 'sum', True, wdtype=torch.bfloat16, dtype=dtypes[0], odtype=dtypes[1], test_backward=True)
self._test_EmbeddingBag(device, 'mean', True, wdtype=torch.bfloat16, dtype=dtypes[0], odtype=dtypes[1], test_backward=True)
@dtypesIfCUDA(torch.float)
@dtypes(torch.float)
def test_transform_bias_rescale_qkv(self, device, dtype):
# TODO: debug CPU test failure with settings (48, 4, 16, 8) and add that mode
tests = [
(64, 4, 16, 8),
# dim_per_head = 12 does not divide evenly by CPU vectorization length of 8
(24, 2, 4, 2),
# Make sure CUDA can handle small input sizes
(2, 2, 2, 2),
# dim_per_head = 6 does not divide evenly by CUDA vectorization length of 4, causes alignment issues
(24, 4, 4, 2)
]
for (embed_dim, num_heads, sl, bs) in tests:
x = torch.randn(sl, bs, embed_dim, device=device, dtype=dtype) * 10
qkv = torch.nn.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)
with torch.no_grad():
(q, k, v) = torch._transform_bias_rescale_qkv(x @ qkv.weight.t(), qkv.bias, num_head=num_heads)
def simple_transform_bias_rescale_qkv(qkv, bias):
(q, k, v) = torch.split(qkv, embed_dim, dim=-1)
(q_bias, k_bias, v_bias) = torch.split(bias, embed_dim, dim=-1)
return tuple(
x.reshape((sl, bs, num_heads, embed_dim // num_heads)).transpose(2, 1)
for x in (
(q + q_bias) / math.sqrt(embed_dim // num_heads),
(k + k_bias),
(v + v_bias)
)
)
correct_q, correct_k, correct_v = simple_transform_bias_rescale_qkv(x @ qkv.weight.t(), qkv.bias)
self.assertEqual(q.size(), correct_q.size())
self.assertTrue(torch.allclose(q, correct_q))
self.assertTrue(torch.allclose(k, correct_k))
self.assertTrue(torch.allclose(v, correct_v))
@onlyCUDA
@dtypes(torch.half, torch.float, torch.double)
def test_multihead_attention_dtype(self, device, dtype):

@@ -1176,7 +1176,6 @@ aten_native_source_non_codegen_list = [
"aten/src/ATen/native/quantized/library.cpp",
"aten/src/ATen/quantized/QTensorImpl.cpp",
"aten/src/ATen/quantized/Quantizer.cpp",
"aten/src/ATen/native/attention.cpp",
"aten/src/ATen/native/Activation.cpp", "aten/src/ATen/native/Activation.cpp",
"aten/src/ATen/native/AdaptiveAveragePooling.cpp", "aten/src/ATen/native/AdaptiveAveragePooling.cpp",
"aten/src/ATen/native/AdaptiveAveragePooling3d.cpp", "aten/src/ATen/native/AdaptiveAveragePooling3d.cpp",

@@ -669,7 +669,6 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.native_batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps: -1,
torch.native_dropout: lambda input, p, train: -1,
torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
torch._native_multi_head_self_attention: lambda query, qkv_weight, qkv_bias, proj_weight, proj_bias, mask=None: -1,
torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1,
torch.native_norm: lambda input, p=2: -1,
torch.native_norm: lambda input, p=2: -1,