Move native MHA code out of PyTorch core (#72944)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72944
It doesn't make sense to develop it in core right now.
ghstack-source-id: 149456040
Test Plan:
CI
Run the MHA benchmark in benchmark_transformers.py to make sure it doesn't crash.
Reviewed By: zrphercule
Differential Revision: D34283104
fbshipit-source-id: 4f0c7a6bc066f938ceac891320d4cf4c3f8a9cd6
(cherry picked from commit b9df65e97c)
This commit is contained in:
parent 1646a0033d
commit 79a216ce57

@@ -1,339 +0,0 @@
#include <type_traits>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec256/vec256.h>

namespace at {

namespace native {

namespace {

Tensor gemm_nt(const Tensor& a, const Tensor& b) {
  return at::native::matmul(a, b.t());
}

template <typename scalar_t>
void transform_bias_rescale_qkv_inner_loop(
    int64_t B,
    int64_t T,
    int64_t _3D,
    int64_t D,
    int64_t num_head,
    int64_t dim_per_head,
    scalar_t* qkv_data,
    scalar_t* qkv_bias_data,
    scalar_t* q_k_v_data,
    scalar_t sqrt_dim_per_head,
    int64_t begin,
    int64_t end) {
  for (auto i : c10::irange(begin, end)) {
    auto t = i % T;
    i /= T;
    auto nh = i % num_head;
    i /= num_head;
    auto b = i;
    using Vec = vec::Vectorized<scalar_t>;
    auto V = vec::Vectorized<scalar_t>::size();
    auto dh = 0;
    auto d = nh * dim_per_head;
    for (; dh + V <= dim_per_head; dh += V, d += V) {
      // load
      auto q_bias_data = Vec::loadu(&qkv_bias_data[d + 0 * D]);
      auto k_bias_data = Vec::loadu(&qkv_bias_data[d + 1 * D]);
      auto v_bias_data = Vec::loadu(&qkv_bias_data[d + 2 * D]);

      auto q_data =
          Vec::loadu(&qkv_data[b * _3D * T + t * _3D + d + 0 * D]) +
          q_bias_data;
      auto k_data =
          Vec::loadu(&qkv_data[b * _3D * T + t * _3D + d + 1 * D]) +
          k_bias_data;
      auto v_data =
          Vec::loadu(&qkv_data[b * _3D * T + t * _3D + d + 2 * D]) +
          v_bias_data;

      q_data = q_data / Vec(sqrt_dim_per_head);

      q_data.store(&q_k_v_data
                       [0 * B * num_head * T * dim_per_head +
                        b * num_head * T * dim_per_head +
                        nh * T * dim_per_head +
                        t * dim_per_head + dh]);
      k_data.store(&q_k_v_data
                       [1 * B * num_head * T * dim_per_head +
                        b * num_head * T * dim_per_head +
                        nh * T * dim_per_head +
                        t * dim_per_head + dh]);
      v_data.store(&q_k_v_data
                       [2 * B * num_head * T * dim_per_head +
                        b * num_head * T * dim_per_head +
                        nh * T * dim_per_head +
                        t * dim_per_head + dh]);
    }
    for (; dh < dim_per_head; dh++) {
      auto d = nh * dim_per_head + dh;
      auto q_bias = qkv_bias_data[d + 0 * D];
      auto k_bias = qkv_bias_data[d + 1 * D];
      auto v_bias = qkv_bias_data[d + 2 * D];
      auto q_data = qkv_data[b * _3D * T + t * _3D + d + 0 * D] + q_bias;
      auto k_data = qkv_data[b * _3D * T + t * _3D + d + 1 * D] + k_bias;
      auto v_data = qkv_data[b * _3D * T + t * _3D + d + 2 * D] + v_bias;
      q_data = q_data / sqrt_dim_per_head;
      q_k_v_data[0 * B * num_head * T * dim_per_head +
                 b * num_head * T * dim_per_head +
                 nh * T * dim_per_head +
                 t * dim_per_head + dh] = q_data;
      q_k_v_data[1 * B * num_head * T * dim_per_head +
                 b * num_head * T * dim_per_head +
                 nh * T * dim_per_head +
                 t * dim_per_head + dh] = k_data;
      q_k_v_data[2 * B * num_head * T * dim_per_head +
                 b * num_head * T * dim_per_head +
                 nh * T * dim_per_head +
                 t * dim_per_head + dh] = v_data;
    }
  }
}
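
The loop above decomposes the flat parallel_for index over [0, B * num_head * T) into (b, nh, t) coordinates, then handles each head's dim_per_head slice with a vectorized main loop plus a scalar tail. For reference, a minimal Python sketch of that index decomposition (editorial; the helper name is illustrative, not part of the diff):

def decompose_index(i, T, num_head):
    # mirrors the i % T / i //= T chain in the loop above
    t = i % T
    i //= T
    nh = i % num_head
    i //= num_head
    b = i
    return b, nh, t

# e.g. with T=16, num_head=4: decompose_index(37, 16, 4) -> (0, 2, 5)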

// compute q = (q + q_bias) / sqrt(dim_per_head), k = k + k_bias, v = v + v_bias
std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv(
    const Tensor& qkv,
    const Tensor& qkv_bias,
    const int64_t num_head) {
  auto B = qkv.size(0);
  auto T = qkv.size(1);
  auto _3D = qkv.size(2);
  auto D = _3D / 3;
  TORCH_CHECK(D % num_head == 0);
  TORCH_CHECK(_3D % 3 == 0);
  const auto dim_per_head = D / num_head;
  auto q_k_v = at::empty({3, B, num_head, T, dim_per_head}, qkv.options());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(q_k_v.is_contiguous());

  const auto qkv_contig = qkv.expect_contiguous();
  const auto qkv_bias_contig = qkv_bias.expect_contiguous();
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half,
      ScalarType::BFloat16,
      qkv.scalar_type(),
      "transform_bias_rescale_qkv",
      [&] {
        scalar_t* qkv_data = qkv_contig->data_ptr<scalar_t>();
        scalar_t* qkv_bias_data = qkv_bias_contig->data_ptr<scalar_t>();
        scalar_t* q_k_v_data = q_k_v.data_ptr<scalar_t>();
        const scalar_t sqrt_dim_per_head = std::sqrt(static_cast<scalar_t>(dim_per_head));

        int64_t grain_size =
            std::max(internal::GRAIN_SIZE / (3 * dim_per_head), (int64_t)1);
        parallel_for(
            0, B * num_head * T, grain_size, [&](int64_t begin, int64_t end) {
              transform_bias_rescale_qkv_inner_loop(
                  B, T, _3D, D, num_head, dim_per_head, qkv_data,
                  qkv_bias_data, q_k_v_data, sqrt_dim_per_head, begin, end);
            });
      });
  auto q_k_v_s =
      at::native::split(q_k_v.view({3 * B, num_head, T, dim_per_head}), B, 0);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(q_k_v_s.size() == 3);
  return std::make_tuple(q_k_v_s[0], q_k_v_s[1], q_k_v_s[2]);
}
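
At the tensor level, the function above splits the fused QKV projection, adds the bias, scales q by 1/sqrt(dim_per_head), and lays the result out head-major. A hedged Python reference (editorial; the function name is illustrative, and the math is the same as the simple_transform_bias_rescale_qkv helper in the test further below):

import math
import torch

def transform_bias_rescale_qkv_reference(qkv, qkv_bias, num_head):
    # qkv: [B, T, 3 * D], qkv_bias: [3 * D]; D == num_head * dim_per_head
    B, T, _3D = qkv.shape
    D = _3D // 3
    dim_per_head = D // num_head
    q, k, v = (qkv + qkv_bias).split(D, dim=-1)
    def to_heads(x):  # [B, T, D] -> [B, num_head, T, dim_per_head]
        return x.reshape(B, T, num_head, dim_per_head).permute(0, 2, 1, 3)
    return to_heads(q) / math.sqrt(dim_per_head), to_heads(k), to_heads(v)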

Tensor bmm_nt(const Tensor& a, const Tensor& b) {
  auto a_ = a.view({a.size(0) * a.size(1), a.size(2), a.size(3)});
  auto b_ = b.view({b.size(0) * b.size(1), b.size(2), b.size(3)});
  auto bt_ = b_.transpose(2, 1);
  // TODO: are these a single call to cublas batched matmul?
  auto c_ = at::matmul(a_, bt_);
  return c_.view({a.size(0), a.size(1), a.size(2), b.size(2)});
}
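
bmm_nt folds the leading [B, num_head] dimensions into a single batch dimension and multiplies a by b with its last two dimensions transposed; in the attention path this produces the [B, num_head, T, T] score matrix. A minimal sketch of the same contraction with broadcasting matmul (editorial, assuming 4-D inputs):

import torch

def bmm_nt_reference(a, b):
    # a, b: [B, H, T, Dh] -> [B, H, T, T]
    return a @ b.transpose(-2, -1)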

void masked_softmax_dropout(
    Tensor& attn_scores,
    const c10::optional<Tensor>& attn_mask) {
  auto B = attn_scores.size(0);
  auto num_heads = attn_scores.size(1);
  auto T = attn_scores.size(2);
  if (attn_mask) {
    TORCH_CHECK(attn_mask->is_contiguous());
  } else {
    at::_softmax_out(attn_scores, attn_scores, 3, false);
    return;
  }
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half,
      ScalarType::BFloat16,
      attn_scores.scalar_type(),
      "masked_softmax_dropout",
      [&] {
        using accscalar_t = acc_type<scalar_t, false>;
        // TODO: proper implementation with masking.
        scalar_t* attn_scores_data = attn_scores.data_ptr<scalar_t>();
        int64_t grain_size = std::min(internal::GRAIN_SIZE / T, (int64_t)1);
        parallel_for(
            0, B * num_heads * T, grain_size, [&](int64_t begin, int64_t end) {
              for (const auto i : c10::irange(begin, end)) {
                using Vec = vec::Vectorized<scalar_t>;
                auto V = vec::Vectorized<scalar_t>::size();

                scalar_t* input_data = attn_scores_data + i;
                auto max_input = Vec(std::numeric_limits<scalar_t>::lowest());
                // TODO: handle epilogue
                TORCH_CHECK(T % V == 0, "epilogue not implemented yet");
                for (auto t = 0; t < T; t += V) {
                  auto v = Vec::loadu(&input_data[t]);
                  max_input = vec::maximum(max_input, v);
                }

                auto hmax = std::numeric_limits<scalar_t>::lowest();
                for (auto i = 0; i < V; ++i) {
                  hmax = std::max(max_input[i], hmax);
                }
                accscalar_t hsum = 0;
                TORCH_CHECK(T % V == 0, "epilogue not implemented yet");
                for (auto t = 0; t < T; t += V) {
                  auto v = Vec::loadu(&input_data[t]);
                  // TODO: vectorize in accscalar_t?
                  for (auto i = 0; i < V; ++i) {
                    hsum += std::exp(static_cast<accscalar_t>(v[i]) - hmax);
                  }
                }
                auto inv_denominator = 1.0 / hsum;
                TORCH_CHECK(T % V == 0, "epilogue not implemented yet");
                for (auto t = 0; t < T; t += V) {
                  Vec v = Vec::loadu(&input_data[t]);

                  // TODO: vectorize in accscalar_t?
                  // TODO this faster solution does not work on Android build
                  /*
                  for (auto i = 0; i < V; ++i) {
                    v[i] = static_cast<scalar_t>(std::exp(static_cast<accscalar_t>(v[i]) - hmax) * inv_denominator);
                  }
                  v.store(&input_data[t]);
                  */
                  for (auto i = 0; i < V; ++i) {
                    input_data[t + i] = static_cast<scalar_t>(std::exp(static_cast<accscalar_t>(v[i]) - hmax) * inv_denominator);
                  }
                }
              }
            });
      });
}
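
Despite its name, the function above does not yet apply the mask or dropout (see the TODOs); both branches reduce to a numerically stable softmax over the last dimension. A hedged Python equivalent of the three passes (editorial sketch):

import torch

def softmax_lastdim_reference(attn_scores):
    m = attn_scores.amax(dim=-1, keepdim=True)   # pass 1: per-row max
    e = (attn_scores - m).exp()                   # exp(x - max), as in passes 2 and 3
    return e / e.sum(dim=-1, keepdim=True)        # divide by the row sum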

Tensor bmm_nn(const Tensor& a, const Tensor& b) {
  auto a_ = a.view({a.size(0) * a.size(1), a.size(2), a.size(3)});
  auto b_ = b.view({b.size(0) * b.size(1), b.size(2), b.size(3)});
  // TODO: are these a single call to cublas batched matmul?
  auto c_ = at::matmul(a_, b_);
  return c_.view({a.size(0), a.size(1), a.size(2), b.size(3)});
}

Tensor transform_0213(const Tensor& a) {
  // TODO: check perf vs dedicated kernel.
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(1));
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(3));
  return a.permute({0, 2, 1, 3})
      .contiguous()
      .view({a.size(0), a.size(2), a.size(1) * a.size(3)});
}
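
transform_0213 undoes the head split: it moves the head dimension next to dim_per_head and flattens the two back into D. An editorial one-liner with the same effect (the function name is illustrative):

import torch

def transform_0213_reference(a):
    # [B, num_head, T, dim_per_head] -> [B, T, num_head * dim_per_head]
    B, H, T, DH = a.shape
    return a.permute(0, 2, 1, 3).reshape(B, T, H * DH)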

Tensor gemm_nt_bias(const Tensor& a, const Tensor& b, const Tensor& c) {
  auto a_ = a.view({a.size(0) * a.size(1), a.size(2)});
  auto r_ = at::native::linear(a_, b, c);
  return r_.view({a.size(0), a.size(1), r_.size(1)});
}

void debug_assert_shape(const Tensor& t, c10::IntArrayRef shape) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY((size_t)t.dim() == shape.size(), "expected ", shape.size(), "-D tensor but got ", t.dim());
  for (auto idx : c10::irange(shape.size())) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t.sizes()[idx] == shape[idx], "expected dim ", idx, " to be ", shape[idx], " but got ", t.sizes()[idx]);
  }
}

} // namespace

std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_op_cpu(
    const Tensor& qkv,
    const Tensor& qkv_bias,
    const int64_t num_head) {
  auto result = transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
  return std::make_tuple(std::get<0>(result).clone(), std::get<1>(result).clone(), std::get<2>(result).clone());
}

Tensor multi_head_self_attention_cpu(
    const Tensor& query,
    const Tensor& qkv_weight,
    const Tensor& qkv_bias,
    const Tensor& proj_weight,
    const Tensor& proj_bias,
    const int64_t num_head,
    const c10::optional<Tensor>& mask) {
  // query shape: [B, T, D]
  // qkv_weight shape: [3 * D, D]

  const auto D = query.sizes()[2];

  TORCH_CHECK(query.dim() == 3, "expected 3-dimensional query, got ", query.dim(), "-D tensor");
  TORCH_CHECK(qkv_weight.dim() == 2, "expected 2-dimensional qkv_weight, got ", qkv_weight.dim(), "-D tensor");
  TORCH_CHECK(D * 3 == qkv_weight.sizes()[0], "expected qkv_weight first dim to be 3x last dim of query");
  TORCH_CHECK(D == qkv_weight.sizes()[1], "expected qkv_weight second dim and last dim of query to be equal");
  TORCH_CHECK(qkv_bias.dim() == 1, "expected 1-dimensional qkv_bias, got ", qkv_bias.dim(), "-D tensor");
  TORCH_CHECK(qkv_bias.sizes()[0] == 3 * D, "expected qkv_bias first dim to be 3x last dim of query");
  TORCH_CHECK(D % num_head == 0, "D must divide evenly by num_head");

#ifndef NDEBUG
  const auto B = query.sizes()[0];
  const auto T = query.sizes()[1];
  const auto dim_per_head = D / num_head;
#endif

  // shape: [B, T, 3 x D]
  auto qkv = gemm_nt(query, qkv_weight);
#ifndef NDEBUG
  debug_assert_shape(qkv, {B, T, 3 * D});
#endif

  // shape: 3 x [B, num_head, T, dim_per_head]
  auto q_k_v = transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
  const auto& q = std::get<0>(q_k_v);
  const auto& k = std::get<1>(q_k_v);
  const auto& v = std::get<2>(q_k_v);
#ifndef NDEBUG
  debug_assert_shape(q, {B, num_head, T, dim_per_head});
  debug_assert_shape(k, {B, num_head, T, dim_per_head});
  debug_assert_shape(v, {B, num_head, T, dim_per_head});
#endif

  // shape: [B, num_head, T, T]
  auto qkt = bmm_nt(q, k);
#ifndef NDEBUG
  debug_assert_shape(qkt, {B, num_head, T, T});
#endif

  // shape: [B, num_head, T, T]
  masked_softmax_dropout(qkt, mask);

  // shape: [B, num_head, T, dim_per_head]
  auto attn_ctx = bmm_nn(qkt, v);
#ifndef NDEBUG
  debug_assert_shape(attn_ctx, {B, num_head, T, dim_per_head});
#endif

  // shape: [B, T, D]
  auto attn = transform_0213(attn_ctx);
#ifndef NDEBUG
  debug_assert_shape(attn, {B, T, D});
#endif

  // shape: [B, T, D]
  auto proj = gemm_nt_bias(attn, proj_weight, proj_bias);
#ifndef NDEBUG
  debug_assert_shape(proj, {B, T, D});
#endif
  return proj;
}
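
Putting the pieces together, here is a hedged end-to-end Python reference for the CPU path above (editorial sketch; the mask and dropout are skipped, matching the TODOs, and the function name is illustrative):

import math
import torch
import torch.nn.functional as F

def multi_head_self_attention_reference(query, qkv_weight, qkv_bias,
                                        proj_weight, proj_bias, num_head,
                                        mask=None):
    # query: [B, T, D]; qkv_weight: [3 * D, D]; proj_weight: [D, D]
    B, T, D = query.shape
    dph = D // num_head
    qkv = F.linear(query, qkv_weight, qkv_bias)              # gemm_nt + bias
    q, k, v = qkv.reshape(B, T, 3, num_head, dph).permute(2, 0, 3, 1, 4)
    q = q / math.sqrt(dph)                                   # transform_bias_rescale_qkv
    attn = torch.softmax(q @ k.transpose(-2, -1), dim=-1)    # bmm_nt + masked_softmax_dropout
    ctx = (attn @ v).permute(0, 2, 1, 3).reshape(B, T, D)    # bmm_nn + transform_0213
    return F.linear(ctx, proj_weight, proj_bias)             # gemm_nt_bias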

} // namespace native
} // namespace at

@@ -1,342 +0,0 @@
#include <type_traits>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/TensorAccessor.h>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/block_reduce.cuh>
#include <ATen/native/cuda/PersistentSoftmax.cuh>

#include <c10/cuda/CUDAMathCompat.h>

namespace at {

namespace native {

namespace {

Tensor gemm_nt(const Tensor& a, const Tensor& b) {
  return at::native::matmul(a, b.t());
}

static constexpr int TRANSFORM_BIAS_RESCALE_VEC = 4;

template <typename scalar_t, typename accscalar_t, bool assume_aligned>
__global__ void transform_bias_rescale_qkv_kernel(
    // [B, T, 3 * D]
    const PackedTensorAccessor64<scalar_t, 3, RestrictPtrTraits> qkv,
    // [3 * D]
    const PackedTensorAccessor64<scalar_t, 1, RestrictPtrTraits> qkv_bias,
    // [3, B, NH, T, DH]
    PackedTensorAccessor64<scalar_t, 5, RestrictPtrTraits> q_k_v) {
  // warp per DH.
  // so launch B * NH * T warps.
  auto NH = q_k_v.size(2);
  auto T = q_k_v.size(3);
  auto DH = q_k_v.size(4);

  auto t = blockIdx.x % T;
  auto b = blockIdx.x / T;

  auto D = NH * DH;
  const scalar_t sqrt_dim_per_head = std::sqrt(static_cast<scalar_t>(DH));

  if (assume_aligned) {
    constexpr int VEC = TRANSFORM_BIAS_RESCALE_VEC;
    using LoadT = memory::aligned_vector<scalar_t, VEC>;
    for (int32_t d_v = threadIdx.x; d_v < D / VEC; d_v += blockDim.x) {
      auto d = d_v * VEC;
      auto nh = d / DH;
      auto dh = d % DH;
      scalar_t qkv_bias_q[VEC];
      scalar_t qkv_bias_k[VEC];
      scalar_t qkv_bias_v[VEC];
      scalar_t qkv_q[VEC];
      scalar_t qkv_k[VEC];
      scalar_t qkv_v[VEC];

      // Here we require D % VEC == 0 for these vectorized loads.
      *reinterpret_cast<LoadT*>(&qkv_bias_q) =
          *reinterpret_cast<const LoadT*>(&qkv_bias[d + 0 * D]);
      *reinterpret_cast<LoadT*>(&qkv_bias_k) =
          *reinterpret_cast<const LoadT*>(&qkv_bias[d + 1 * D]);
      *reinterpret_cast<LoadT*>(&qkv_bias_v) =
          *reinterpret_cast<const LoadT*>(&qkv_bias[d + 2 * D]);

      *reinterpret_cast<LoadT*>(&qkv_q) =
          *reinterpret_cast<const LoadT*>(&qkv[b][t][d + 0 * D]);
      *reinterpret_cast<LoadT*>(&qkv_k) =
          *reinterpret_cast<const LoadT*>(&qkv[b][t][d + 1 * D]);
      *reinterpret_cast<LoadT*>(&qkv_v) =
          *reinterpret_cast<const LoadT*>(&qkv[b][t][d + 2 * D]);

#pragma unroll
      // TODO: specialize for float2half2/half2float2?
      for (auto ii = 0; ii < VEC; ++ii) {
        qkv_q[ii] = static_cast<scalar_t>(
            (static_cast<accscalar_t>(qkv_q[ii]) +
             static_cast<accscalar_t>(qkv_bias_q[ii])) /
            static_cast<accscalar_t>(sqrt_dim_per_head));
        qkv_k[ii] = static_cast<scalar_t>(
            (static_cast<accscalar_t>(qkv_k[ii]) +
             static_cast<accscalar_t>(qkv_bias_k[ii])));
        qkv_v[ii] = static_cast<scalar_t>(
            (static_cast<accscalar_t>(qkv_v[ii]) +
             static_cast<accscalar_t>(qkv_bias_v[ii])));
      }

      // Here we require DH % VEC == 0 for these vectorized stores.
      *reinterpret_cast<LoadT*>(&q_k_v[0][b][nh][t][dh]) =
          *reinterpret_cast<const LoadT*>(&qkv_q);
      *reinterpret_cast<LoadT*>(&q_k_v[1][b][nh][t][dh]) =
          *reinterpret_cast<const LoadT*>(&qkv_k);
      *reinterpret_cast<LoadT*>(&q_k_v[2][b][nh][t][dh]) =
          *reinterpret_cast<const LoadT*>(&qkv_v);
    }
  } else {
    // Same as above, but we can't vectorize memory access.
    for (int32_t d = threadIdx.x; d < D; d += blockDim.x) {
      auto nh = d / DH;
      auto dh = d % DH;
      scalar_t qkv_bias_q = qkv_bias[d + 0 * D];
      scalar_t qkv_bias_k = qkv_bias[d + 1 * D];
      scalar_t qkv_bias_v = qkv_bias[d + 2 * D];
      scalar_t qkv_q = qkv[b][t][d + 0 * D];
      scalar_t qkv_k = qkv[b][t][d + 1 * D];
      scalar_t qkv_v = qkv[b][t][d + 2 * D];
      qkv_q = static_cast<scalar_t>(
          (static_cast<accscalar_t>(qkv_q) +
           static_cast<accscalar_t>(qkv_bias_q)) /
          static_cast<accscalar_t>(sqrt_dim_per_head));
      qkv_k = static_cast<scalar_t>(
          (static_cast<accscalar_t>(qkv_k) +
           static_cast<accscalar_t>(qkv_bias_k)));
      qkv_v = static_cast<scalar_t>(
          (static_cast<accscalar_t>(qkv_v) +
           static_cast<accscalar_t>(qkv_bias_v)));

      q_k_v[0][b][nh][t][dh] = qkv_q;
      q_k_v[1][b][nh][t][dh] = qkv_k;
      q_k_v[2][b][nh][t][dh] = qkv_v;
    }
  }
}

// compute q = (q + q_bias) / sqrt(dim_per_head), k = k + k_bias, v = v + v_bias
std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv(
    const Tensor& qkv,
    const Tensor& qkv_bias,
    const int64_t num_head) {
  auto B = qkv.size(0);
  auto T = qkv.size(1);
  auto _3D = qkv.size(2);
  auto D = _3D / 3;
  TORCH_CHECK(D % num_head == 0);
  const auto dim_per_head = D / num_head;
  auto q_k_v = at::empty({3, B, num_head, T, dim_per_head}, qkv.options());
#define CALL_KERNEL(assume_aligned)                                          \
  transform_bias_rescale_qkv_kernel<scalar_t, accscalar_t, assume_aligned>   \
      <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(            \
          qkv.packed_accessor64<scalar_t, 3, RestrictPtrTraits>(),           \
          qkv_bias.packed_accessor64<scalar_t, 1, RestrictPtrTraits>(),      \
          q_k_v.packed_accessor64<scalar_t, 5, RestrictPtrTraits>())
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half,
      ScalarType::BFloat16,
      qkv.scalar_type(),
      "transform_bias_rescale_qkv",
      [&] {
        using accscalar_t = acc_type<scalar_t, true>;
        auto threads = std::max(std::min<int32_t>(1024, D / TRANSFORM_BIAS_RESCALE_VEC), 1);
        auto blocks = B * T;
        if (dim_per_head % TRANSFORM_BIAS_RESCALE_VEC == 0) {
          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
              D % TRANSFORM_BIAS_RESCALE_VEC == 0,
              "D = num_heads * dim_per_head, so we should have dim_per_head % "
              "TRANSFORM_BIAS_RESCALE_VEC == 0 => "
              "D % TRANSFORM_BIAS_RESCALE_VEC == 0");
          CALL_KERNEL(true);
        } else {
          CALL_KERNEL(false);
        }
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
#undef CALL_KERNEL
  auto q_k_v_s =
      at::native::split(q_k_v.view({3 * B, num_head, T, dim_per_head}), B, 0);
  return std::make_tuple(q_k_v_s[0], q_k_v_s[1], q_k_v_s[2]);
}
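
The launch configuration above uses one block per (b, t) position, with threads striding over the D elements of each row (handling the q, k, and v slots at each offset), and picks the vectorized kernel only when dim_per_head is a multiple of the vector width. A small editorial Python sketch of that choice (names are illustrative):

def launch_config(B, T, D, dim_per_head, VEC=4, max_threads=1024):
    # mirrors: blocks = B * T; threads = max(min(1024, D / VEC), 1)
    aligned = dim_per_head % VEC == 0     # selects the assume_aligned kernel path
    threads = max(min(max_threads, D // VEC), 1)
    blocks = B * T
    return blocks, threads, aligned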

Tensor bmm_nt(const Tensor& a, const Tensor& b) {
  auto a_ = a.view({a.size(0) * a.size(1), a.size(2), a.size(3)});
  auto b_ = b.view({b.size(0) * b.size(1), b.size(2), b.size(3)});
  auto bt_ = b_.transpose(2, 1);
  // TODO: are these a single call to cublas batched matmul?
  auto c_ = at::matmul(a_, bt_);
  return c_.view({a.size(0), a.size(1), a.size(2), b.size(2)});
}

template <typename T>
__inline__ __device__ T WarpReduceMax(T val) {
#pragma unroll
  for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
    val = std::max(val, WARP_SHFL_DOWN(val, offset));
  }
  return val;
}

template <typename T>
__inline__ __device__ T WarpReduceSum(T val) {
#pragma unroll
  for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
    val += WARP_SHFL_DOWN(val, offset);
  }
  return val;
}
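
WarpReduceMax and WarpReduceSum implement the usual shuffle-down tree reduction: at each step a lane combines its value with the value held offset lanes above it, halving the offset until lane 0 holds the warp-wide result. An editorial Python emulation of the sum variant (the warp size of 32 and the function name are assumptions for illustration):

def warp_reduce_sum_emulated(lane_vals, warp_size=32):
    # lane_vals: one value per lane; returns what lane 0 ends up holding
    vals = list(lane_vals)
    offset = warp_size // 2
    while offset > 0:
        vals = [
            vals[i] + (vals[i + offset] if i + offset < warp_size else 0)
            for i in range(warp_size)
        ]
        offset //= 2
    return vals[0]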

void masked_softmax_dropout(
    const Tensor& attn_scores,
    const c10::optional<Tensor>& attn_mask) {
  auto B = attn_scores.size(0);
  auto num_heads = attn_scores.size(1);
  auto T = attn_scores.size(2);
  if (attn_mask) {
    TORCH_CHECK(attn_mask->is_contiguous());
  }
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half,
      ScalarType::BFloat16,
      attn_scores.scalar_type(),
      "masked_softmax_dropout",
      [&] {
        using accscalar_t = acc_type<scalar_t, true>;
        // TODO: proper implementation with masking.
        dispatch_softmax_forward<scalar_t, scalar_t, accscalar_t, false, false>(
            attn_scores.data_ptr<scalar_t>(),
            attn_scores.data_ptr<scalar_t>(),
            T,
            T,
            B * num_heads * T);
      });
}
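
On CUDA the work is delegated to the persistent softmax kernel, writing in place over the scores with rows of length T and B * num_heads * T rows in total; the mask is only checked for contiguity and, like dropout, not yet applied (see the TODO). At the tensor level this is simply (editorial sketch):

import torch

def masked_softmax_dropout_reference(attn_scores, attn_mask=None):
    # mask and dropout intentionally ignored, matching the TODO above
    return torch.softmax(attn_scores, dim=-1)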

Tensor bmm_nn(const Tensor& a, const Tensor& b) {
  auto a_ = a.view({a.size(0) * a.size(1), a.size(2), a.size(3)});
  auto b_ = b.view({b.size(0) * b.size(1), b.size(2), b.size(3)});
  // TODO: are these a single call to cublas batched matmul?
  auto c_ = at::matmul(a_, b_);
  return c_.view({a.size(0), a.size(1), a.size(2), b.size(3)});
}

Tensor transform_0213(const Tensor& a) {
  // TODO: check perf vs dedicated kernel.
  return a.permute({0, 2, 1, 3})
      .contiguous()
      .view({a.size(0), a.size(2), a.size(1) * a.size(3)});
}

Tensor gemm_nt_bias(const Tensor& a, const Tensor& b, const Tensor& c) {
  auto a_ = a.view({a.size(0) * a.size(1), a.size(2)});
  auto r_ = at::native::linear(a_, b, c);
  return r_.view({a.size(0), a.size(1), r_.size(1)});
}

void debug_assert_shape(const Tensor& t, c10::IntArrayRef shape) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY((size_t)t.dim() == shape.size(), "expected ", shape.size(), "-D tensor but got ", t.dim());
  for (auto idx : c10::irange(shape.size())) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t.sizes()[idx] == shape[idx], "expected dim ", idx, " to be ", shape[idx], " but got ", t.sizes()[idx]);
  }
}

} // namespace

std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_op_cuda(
    const Tensor& qkv,
    const Tensor& qkv_bias,
    const int64_t num_head) {
  auto result = transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
  return std::make_tuple(std::get<0>(result).clone(), std::get<1>(result).clone(), std::get<2>(result).clone());
}

Tensor multi_head_self_attention_cuda(
    const Tensor& query,
    const Tensor& qkv_weight,
    const Tensor& qkv_bias,
    const Tensor& proj_weight,
    const Tensor& proj_bias,
    const int64_t num_head,
    const c10::optional<Tensor>& mask) {
  // query shape: [B, T, D]
  // qkv_weight shape: [3 * D, D]

  const auto D = query.sizes()[2];

  TORCH_CHECK(query.dim() == 3, "expected 3-dimensional query, got ", query.dim(), "-D tensor");
  TORCH_CHECK(qkv_weight.dim() == 2, "expected 2-dimensional qkv_weight, got ", qkv_weight.dim(), "-D tensor");
  TORCH_CHECK(D * 3 == qkv_weight.sizes()[0], "expected qkv_weight first dim to be 3x last dim of query");
  TORCH_CHECK(D == qkv_weight.sizes()[1], "expected qkv_weight second dim and last dim of query to be equal");
  TORCH_CHECK(D % num_head == 0, "D must divide evenly by num_head");

#ifndef NDEBUG
  const auto B = query.sizes()[0];
  const auto T = query.sizes()[1];
  const auto dim_per_head = D / num_head;
#endif

  // shape: [B, T, 3 x D]
  auto qkv = gemm_nt(query, qkv_weight);
#ifndef NDEBUG
  debug_assert_shape(qkv, {B, T, 3 * D});
#endif

  // shape: 3 x [B, num_head, T, dim_per_head]
  auto q_k_v = transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
  const auto& q = std::get<0>(q_k_v);
  const auto& k = std::get<1>(q_k_v);
  const auto& v = std::get<2>(q_k_v);
#ifndef NDEBUG
  debug_assert_shape(q, {B, num_head, T, dim_per_head});
  debug_assert_shape(k, {B, num_head, T, dim_per_head});
  debug_assert_shape(v, {B, num_head, T, dim_per_head});
#endif

  // shape: [B, num_head, T, T]
  auto qkt = bmm_nt(q, k);
#ifndef NDEBUG
  debug_assert_shape(qkt, {B, num_head, T, T});
#endif

  // shape: [B, num_head, T, T]
  masked_softmax_dropout(qkt, mask);

  // shape: [B, num_head, T, dim_per_head]
  auto attn_ctx = bmm_nn(qkt, v);
#ifndef NDEBUG
  debug_assert_shape(attn_ctx, {B, num_head, T, dim_per_head});
#endif

  // shape: [B, T, D]
  auto attn = transform_0213(attn_ctx);
#ifndef NDEBUG
  debug_assert_shape(attn, {B, T, D});
#endif

  // shape: [B, T, D]
  auto proj = gemm_nt_bias(attn, proj_weight, proj_bias);
#ifndef NDEBUG
  debug_assert_shape(proj, {B, T, D});
#endif

  return proj;
}

} // namespace native
} // namespace at

@@ -2549,16 +2549,6 @@
    CUDA: layer_norm_cuda
    CompositeImplicitAutograd: math_native_layer_norm

- func: _native_multi_head_self_attention(Tensor query, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, int num_head, Tensor? mask=None) -> Tensor
  dispatch:
    CPU: multi_head_self_attention_cpu
    CUDA: multi_head_self_attention_cuda

- func: _transform_bias_rescale_qkv(Tensor qkv, Tensor qkv_bias, int num_head) -> (Tensor, Tensor, Tensor)
  dispatch:
    CPU: transform_bias_rescale_qkv_op_cpu
    CUDA: transform_bias_rescale_qkv_op_cuda

- func: native_layer_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
  dispatch:
    CPU: layer_norm_backward_cpu
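
The two schemas removed above were registered as private ATen ops. Before this commit they could be invoked from Python roughly as below (editorial sketch; the _transform_bias_rescale_qkv call follows the test further down, and the _native_multi_head_self_attention call follows the schema's positional argument order):

import torch

B, T, D, num_head = 2, 4, 8, 2
query = torch.randn(B, T, D)
qkv_weight, qkv_bias = torch.randn(3 * D, D), torch.randn(3 * D)
proj_weight, proj_bias = torch.randn(D, D), torch.randn(D)

q, k, v = torch._transform_bias_rescale_qkv(
    query @ qkv_weight.t(), qkv_bias, num_head=num_head)
out = torch._native_multi_head_self_attention(
    query, qkv_weight, qkv_bias, proj_weight, proj_bias, num_head)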

@@ -104,8 +104,6 @@ ALLOW_LIST = [
    ("aten::nanquantile", datetime.date(2022, 9, 30)),
    ("aten::_convolution_double_backward", datetime.date(2022, 3, 31)),
    ("aten::_scatter_reduce", datetime.date(2022, 1, 31)),
    ("aten::native_multi_head_self_attention", datetime.date(9999, 1, 1)),
    ("aten::_native_multi_head_self_attention", datetime.date(9999, 1, 1)),
    ("aten::scatter_reduce.two", datetime.date(2022, 3, 15)),
]

@@ -17540,44 +17540,6 @@ class TestNNDeviceType(NNTestCase):
            self._test_EmbeddingBag(device, 'sum', True, wdtype=torch.bfloat16, dtype=dtypes[0], odtype=dtypes[1], test_backward=True)
            self._test_EmbeddingBag(device, 'mean', True, wdtype=torch.bfloat16, dtype=dtypes[0], odtype=dtypes[1], test_backward=True)

    @dtypesIfCUDA(torch.float)
    @dtypes(torch.float)
    def test_transform_bias_rescale_qkv(self, device, dtype):
        # TODO: debug CPU test failure with settings (48, 4, 16, 8) and add that mode
        tests = [
            (64, 4, 16, 8),
            # dim_per_head = 12 does not divide evenly by CPU vectorization length of 8
            (24, 2, 4, 2),
            # Make sure CUDA can handle small input sizes
            (2, 2, 2, 2),
            # dim_per_head = 6 does not divide evenly by CUDA vectorization length of 4, causes alignment issues
            (24, 4, 4, 2)
        ]
        for (embed_dim, num_heads, sl, bs) in tests:
            x = torch.randn(sl, bs, embed_dim, device=device, dtype=dtype) * 10
            qkv = torch.nn.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)

            with torch.no_grad():
                (q, k, v) = torch._transform_bias_rescale_qkv(x @ qkv.weight.t(), qkv.bias, num_head=num_heads)

            def simple_transform_bias_rescale_qkv(qkv, bias):
                (q, k, v) = torch.split(qkv, embed_dim, dim=-1)
                (q_bias, k_bias, v_bias) = torch.split(bias, embed_dim, dim=-1)
                return tuple(
                    x.reshape((sl, bs, num_heads, embed_dim // num_heads)).transpose(2, 1)
                    for x in (
                        (q + q_bias) / math.sqrt(embed_dim // num_heads),
                        (k + k_bias),
                        (v + v_bias)
                    )
                )
            correct_q, correct_k, correct_v = simple_transform_bias_rescale_qkv(x @ qkv.weight.t(), qkv.bias)

            self.assertEqual(q.size(), correct_q.size())
            self.assertTrue(torch.allclose(q, correct_q))
            self.assertTrue(torch.allclose(k, correct_k))
            self.assertTrue(torch.allclose(v, correct_v))

    @onlyCUDA
    @dtypes(torch.half, torch.float, torch.double)
    def test_multihead_attention_dtype(self, device, dtype):

@@ -1176,7 +1176,6 @@ aten_native_source_non_codegen_list = [
    "aten/src/ATen/native/quantized/library.cpp",
    "aten/src/ATen/quantized/QTensorImpl.cpp",
    "aten/src/ATen/quantized/Quantizer.cpp",
    "aten/src/ATen/native/attention.cpp",
    "aten/src/ATen/native/Activation.cpp",
    "aten/src/ATen/native/AdaptiveAveragePooling.cpp",
    "aten/src/ATen/native/AdaptiveAveragePooling3d.cpp",

@@ -669,7 +669,6 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.native_batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps: -1,
        torch.native_dropout: lambda input, p, train: -1,
        torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
        torch._native_multi_head_self_attention: lambda query, qkv_weight, qkv_bias, proj_weight, proj_bias, mask=None: -1,
        torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1,
        torch.native_norm: lambda input, p=2: -1,
        torch.native_norm: lambda input, p=2: -1,