Replace _baddbmm_mkl_ with cpublas::gemm_batched (#66165)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66165

Test Plan: Imported from OSS

Reviewed By: dagitses

Differential Revision: D31493952

Pulled By: ngimel

fbshipit-source-id: 87cf79036c2d0f4955edbeeeb78f578b0fd223ab
Peter Bell 2021-10-08 17:10:57 -07:00 committed by Facebook GitHub Bot
parent 51835bec07
commit c957d9fdf6
5 changed files with 45 additions and 189 deletions
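For context on the operator being rerouted (this sketch is not part of the commit itself): baddbmm computes out = beta * self + alpha * bmm(batch1, batch2) over a leading batch dimension. A minimal ATen usage sketch, with shapes chosen arbitrarily for illustration:

#include <ATen/ATen.h>

int main() {
  // batch1: B x M x K, batch2: B x K x N, self/out: B x M x N
  auto self   = at::zeros({4, 3, 5});
  auto batch1 = at::randn({4, 3, 7});
  auto batch2 = at::randn({4, 7, 5});
  // After this change, the large contiguous CPU case routes through
  // cpublas::gemm_batched_with_stride instead of the removed _baddbmm_mkl_ path.
  auto out = at::baddbmm(self, batch1, batch2, /*beta=*/1.0, /*alpha=*/1.0);
  return 0;
}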

View File

@@ -27,7 +27,6 @@ _(aten, _amp_update_scale_) \
 _(aten, _arange) \
 _(aten, _argmax) \
 _(aten, _argmin) \
-_(aten, _baddbmm_mkl) \
 _(aten, _cast_Byte) \
 _(aten, _cast_Char) \
 _(aten, _cast_Double) \

View File

@@ -3,6 +3,7 @@
 #include <ATen/Dispatch.h>
 #include <ATen/ExpandUtils.h>
 #include <ATen/NamedTensorUtils.h>
+#include <ATen/OpMathType.h>
 #include <ATen/native/mkldnn/Matmul.h>
 #include <ATen/native/CPUBlas.h>
 #include <ATen/native/IndexingUtils.h>
@@ -1240,6 +1241,48 @@ inline void baddbmm_cpu_kernel(const Tensor& result, const Tensor& self, const T
   });
 }
+void baddbmm_with_gemm_(const Tensor &result, const Tensor &mat1, const Tensor &mat2, const Scalar &beta_, const Scalar &alpha_) {
+  TORCH_INTERNAL_ASSERT(result.is_contiguous());
+  const auto result_sizes = result.sizes();
+  const auto result_strides = result.strides();
+  const auto mat1_strides = mat1.strides();
+  const auto mat2_strides = mat2.strides();
+  const auto mat1_sizes = mat1.sizes();
+  const auto mat2_sizes = mat2.sizes();
+  auto is_transposed = [](const c10::IntArrayRef& strides, const c10::IntArrayRef& sizes) {
+    return strides[1] == 1 && strides[2] >= sizes[1];
+  };
+  // gemm expects fortran order matrices, so we swap argument order to transpose everything
+  const auto transpose_a = is_transposed(mat2_strides, mat2_sizes);
+  const auto transpose_b = is_transposed(mat1_strides, mat1_sizes);
+  const int64_t batch_size = mat1_sizes[0];
+  const int64_t m = result_sizes[2];
+  const int64_t n = result_sizes[1];
+  const int64_t k = mat2_sizes[1];
+  const int64_t lda = mat2_strides[transpose_a ? 2 : 1];
+  const int64_t ldb = mat1_strides[transpose_b ? 2 : 1];
+  const int64_t ldc = result_strides[1];
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(result.scalar_type(), "baddbmm_with_gemm", [&] {
+    using opmath_t = at::opmath_type<scalar_t>;
+    const auto alpha = alpha_.to<opmath_t>();
+    const auto beta = beta_.to<opmath_t>();
+    at::native::cpublas::gemm_batched_with_stride(
+        transpose_a ? TransposeType::Transpose : TransposeType::NoTranspose,
+        transpose_b ? TransposeType::Transpose : TransposeType::NoTranspose,
+        batch_size, m, n, k, alpha,
+        mat2.data_ptr<scalar_t>(), lda, mat2_strides[0],
+        mat1.data_ptr<scalar_t>(), ldb, mat1_strides[0],
+        beta,
+        result.data_ptr<scalar_t>(), ldc, result_strides[0]);
+  });
+}
 // This tries to apply some optimizations to bmm/baddbmm:
 // - When the operand size is small, computation are parallelized over the batch
 //   dimension using OMP and naive matrix multiplication is applied.
@@ -1319,7 +1362,7 @@ static inline void bmm_out_or_baddbmm_(const Tensor& self_or_result_, const Tens
       && batch_items_contiguous_or_transposed(batch1)
       && batch_items_contiguous_or_transposed(batch2)
       && self_or_result.is_contiguous()) {
-    at::native::_baddbmm_mkl_(self_or_result, batch1, batch2, beta, alpha);
+    baddbmm_with_gemm_(self_or_result, batch1, batch2, beta, alpha);
   } else { // split along batch dimension
 #ifdef C10_MOBILE
     /*
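The argument swap in the new baddbmm_with_gemm_ above (mat2 is passed as the first GEMM operand and m is read from the last result dimension) relies on the fact that a row-major matrix reinterpreted in column-major (Fortran) order is its transpose, so the column-major GEMM computes the transposed product, which is exactly the row-major result:

    C = \alpha A B + \beta C
    \iff
    C^{\mathsf{T}} = \alpha\, B^{\mathsf{T}} A^{\mathsf{T}} + \beta\, C^{\mathsf{T}}

Passing (B, A) to a Fortran-order GEMM, with transpose flags derived from each operand's strides, therefore yields C laid out in row-major order without any explicit copies.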

View File

@@ -1,6 +1,4 @@
 #include <ATen/native/mkl/LinearAlgebra.h>
-#include <ATen/ATen.h>
-#include <ATen/NativeFunctions.h>
 #include <ATen/Config.h>
 #if !AT_MKL_ENABLED()
@@ -39,100 +37,14 @@ void mkl_gemm_batched(
   TORCH_INTERNAL_ASSERT(false, "mkl_gemm_batched: ATen not compiled with MKL support");
 }
-Tensor& _baddbmm_mkl_(Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
-  AT_ERROR("bmm: ATen not compiled with MKL support");
-}
 }}
 #else // AT_MKL_ENABLED
-#include <ATen/ATen.h>
-#include <ATen/Config.h>
-#include <ATen/Dispatch.h>
-#include <ATen/Utils.h>
-#include <ATen/NativeFunctions.h>
-#include <algorithm>
-#include <vector>
-#include <numeric>
-#include <cmath>
 #include <mkl.h>
-#include <ATen/mkl/Exceptions.h>
-#include <ATen/mkl/Descriptors.h>
-#include <ATen/mkl/Limits.h>
 namespace at { namespace native {
-static inline void gemm(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B,
-    const int M, const int N, const int K, const float alpha, const float* A,
-    const int lda, const float* B, const int ldb, const float beta, float* C, const int ldc) {
-  cblas_sgemm(CblasRowMajor, trans_A, trans_B, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
-}
-static inline void gemm(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B,
-    const int M, const int N, const int K, const double alpha, const double* A,
-    const int lda, const double* B, const int ldb, const double beta, double* C, const int ldc) {
-  cblas_dgemm(CblasRowMajor, trans_A, trans_B, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
-}
-static inline void gemm(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B,
-    const int M, const int N, const int K, const c10::complex<float> alpha,
-    const c10::complex<float>* A, const int lda, const c10::complex<float>* B, const int ldb,
-    const c10::complex<float> beta, c10::complex<float>* C, const int ldc) {
-  cblas_cgemm(CblasRowMajor, trans_A, trans_B, M, N, K, reinterpret_cast<const void *>(&alpha),
-    reinterpret_cast<const void*>(A), lda, reinterpret_cast<const void*>(B), ldb,
-    reinterpret_cast<const void*>(&beta), reinterpret_cast<void*>(C), ldc);
-}
-static inline void gemm(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B,
-    const int M, const int N, const int K, const c10::complex<double> alpha,
-    const c10::complex<double>* A, const int lda, const c10::complex<double>* B, const int ldb,
-    const c10::complex<double> beta, c10::complex<double>* C, const int ldc) {
-  cblas_zgemm(CblasRowMajor, trans_A, trans_B, M, N, K, reinterpret_cast<const void *>(&alpha),
-    reinterpret_cast<const void*>(A), lda, reinterpret_cast<const void*>(B), ldb,
-    reinterpret_cast<const void*>(&beta), reinterpret_cast<void*>(C), ldc);
-}
-static inline void gemm_batched(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B,
-    const int batch_size, const int M, const int N, const int K, const float alpha,
-    const float** A, const int lda, const float** B, const int ldb, const float beta,
-    float** C, const int ldc) {
-  cblas_sgemm_batch(CblasRowMajor, &trans_A, &trans_B, &M, &N, &K, &alpha,
-    A, &lda, B, &ldb, &beta, C, &ldc, 1, &batch_size);
-}
-static inline void gemm_batched(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B,
-    const int batch_size, const int M, const int N, const int K, const double alpha,
-    const double** A, const int lda, const double** B, const int ldb, const double beta,
-    double** C, const int ldc) {
-  cblas_dgemm_batch(CblasRowMajor, &trans_A, &trans_B, &M, &N, &K, &alpha,
-    A, &lda, B, &ldb, &beta, C, &ldc, 1, &batch_size);
-}
-static inline void gemm_batched(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B,
-    const int batch_size, const int M, const int N, const int K, const c10::complex<float> alpha,
-    const c10::complex<float>** A, const int lda, const c10::complex<float>** B, const int ldb,
-    const c10::complex<float> beta, c10::complex<float>** C, const int ldc) {
-  cblas_cgemm_batch(CblasRowMajor, &trans_A, &trans_B, &M, &N, &K, reinterpret_cast<const void*>(&alpha),
-    reinterpret_cast<const void**>(A), &lda, reinterpret_cast<const void**>(B), &ldb,
-    reinterpret_cast<const void*>(&beta), reinterpret_cast<void**>(C), &ldc, 1, &batch_size);
-}
-static inline void gemm_batched(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B,
-    const int batch_size, const int M, const int N, const int K, const c10::complex<double> alpha,
-    const c10::complex<double>** A, const int lda, const c10::complex<double>** B, const int ldb,
-    const c10::complex<double> beta, c10::complex<double>** C, const int ldc) {
-  cblas_zgemm_batch(CblasRowMajor, &trans_A, &trans_B, &M, &N, &K, reinterpret_cast<const void*>(&alpha),
-    reinterpret_cast<const void**>(A), &lda, reinterpret_cast<const void**>(B), &ldb,
-    reinterpret_cast<const void*>(&beta), reinterpret_cast<void**>(C), &ldc, 1, &batch_size);
-}
 static CBLAS_TRANSPOSE to_cblas(TransposeType x) {
   switch (x) {
     case TransposeType::NoTranspose: return CblasNoTrans;
@@ -190,102 +102,6 @@ void mkl_gemm_batched(
       reinterpret_cast<const void*>(&beta), reinterpret_cast<void**>(C), &ldc, 1, &batch_size);
 }
-template <typename scalar_t>
-static inline void baddbmm_mkl_template(const Tensor& res, const Tensor& mat1, const Tensor& mat2, const Scalar& beta_, const Scalar& alpha_) {
-  const auto mat1_strides = mat1.strides();
-  const auto mat2_strides = mat2.strides();
-  const auto mat1_sizes = mat1.sizes();
-  const auto mat2_sizes = mat2.sizes();
-  auto is_transposed = [](const c10::IntArrayRef& strides, const c10::IntArrayRef& sizes) {
-    return strides[1] == 1 && strides[2] >= sizes[1];
-  };
-  const CBLAS_TRANSPOSE trans_A =
-      is_transposed(mat1_strides, mat1_sizes) ? CblasTrans : CblasNoTrans;
-  const CBLAS_TRANSPOSE trans_B =
-      is_transposed(mat2_strides, mat2_sizes) ? CblasTrans : CblasNoTrans;
-  // mat1: batch_size * M * K
-  const int batch_size = mat1_sizes[0];
-  const int M = mat1_sizes[1];
-  // mat2: batch_size * K * N
-  const int N = mat2_sizes[2];
-  const int K = mat1_sizes[2];
-  scalar_t alpha = alpha_.to<scalar_t>();
-  scalar_t beta = beta_.to<scalar_t>();
-  const int lda = trans_A == CblasTrans ? mat1_strides[2] : mat1_strides[1];
-  const int ldb = trans_B == CblasTrans ? mat2_strides[2] : mat2_strides[1];
-  const int ldc = res.strides()[1];
-  // avoid using tensor accessor in the case of mat1/mat2 not being transposed
-  // or only transposed in the last two axes
-  const bool canAvoidTensorAccessor = mat1_strides[0] == mat1_sizes[1] * mat1_sizes[2] &&
-      mat2_strides[0] == mat2_sizes[1] * mat2_sizes[2];
-  scalar_t* const res_data = res.data_ptr<scalar_t>();
-  if (batch_size == 1) {
-    const scalar_t* A;
-    const scalar_t* B;
-    if (canAvoidTensorAccessor) {
-      scalar_t* mat1_data = mat1.data_ptr<scalar_t>();
-      scalar_t* mat2_data = mat2.data_ptr<scalar_t>();
-      A = mat1_data;
-      B = mat2_data;
-    } else {
-      auto mat1_acc = mat1.accessor<scalar_t, 3>();
-      auto mat2_acc = mat2.accessor<scalar_t, 3>();
-      A = mat1_acc[0].data();
-      B = mat2_acc[0].data();
-    }
-    gemm(trans_A, trans_B, M, N, K, alpha, A, lda, B, ldb, beta, res_data, ldc);
-    return;
-  }
-  std::vector<const scalar_t*> A;
-  A.reserve(batch_size);
-  std::vector<const scalar_t*> B;
-  B.reserve(batch_size);
-  std::vector<scalar_t*> C;
-  C.reserve(batch_size);
-  // avoid using tensor accessor in the case of mat1/mat2 not being transposed
-  // or only transposed in the last two axis
-  const auto res_sizes = res.sizes();
-  if (canAvoidTensorAccessor) {
-    scalar_t* mat1_data = mat1.data_ptr<scalar_t>();
-    scalar_t* mat2_data = mat2.data_ptr<scalar_t>();
-    for (int64_t batch = 0; batch < batch_size; batch++) {
-      A.emplace_back(mat1_data + batch * mat1_sizes[1] * mat1_sizes[2]);
-      B.emplace_back(mat2_data + batch * mat2_sizes[1] * mat2_sizes[2]);
-      C.emplace_back(res_data + batch * res_sizes[1] * res_sizes[2]);
-    }
-  } else {
-    auto mat1_acc = mat1.accessor<scalar_t, 3>();
-    auto mat2_acc = mat2.accessor<scalar_t, 3>();
-    for (int64_t batch = 0; batch < batch_size; batch++) {
-      A.emplace_back(mat1_acc[batch].data());
-      B.emplace_back(mat2_acc[batch].data());
-      C.emplace_back(res_data + batch * res_sizes[1] * res_sizes[2]);
-    }
-  }
-  gemm_batched(trans_A, trans_B, batch_size, M, N, K, alpha, A.data(), lda, B.data(), ldb, beta, C.data(), ldc);
-}
-Tensor& _baddbmm_mkl_(Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
-  // checks are done in native/LinearAlgebra.cpp
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "baddbmm__mkl", [&] {
-    baddbmm_mkl_template<scalar_t>(self, batch1, batch2, beta, alpha);
-  });
-  return self;
-}
 }} // namespace at::native
 #endif
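The deleted MKL path above materializes one pointer per batch and hands the arrays to cblas_*gemm_batch, while the replacement passes a single base pointer plus a batch stride to cpublas::gemm_batched_with_stride (see the new baddbmm_with_gemm_ in the second file's diff). A minimal standalone sketch of the equivalence; the helper name is hypothetical and only illustrates how a base pointer and a batch stride describe the same per-batch addresses the old code computed explicitly:

#include <cstdint>
#include <vector>

// Expand (base pointer, batch stride) into the explicit per-batch pointer
// array that the removed MKL code built by hand before each batched GEMM.
std::vector<const float*> make_batch_pointers(const float* base,
                                              int64_t batch_size,
                                              int64_t batch_stride) {
  std::vector<const float*> ptrs;
  ptrs.reserve(batch_size);
  for (int64_t b = 0; b < batch_size; ++b) {
    ptrs.push_back(base + b * batch_stride);
  }
  return ptrs;
}

Passing the stride directly avoids building these vectors for every call, which the old code special-cased via its canAvoidTensorAccessor branch.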

View File

@@ -729,9 +729,6 @@
   variants: method
   structured_delegate: baddbmm.out
-- func: _baddbmm_mkl_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
-  variants: function
 - func: baddbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
   structured: True
   variants: function

View File

@@ -56,6 +56,7 @@ ALLOW_LIST = [
     ("aten::_softmax_backward_data", datetime.date(2021, 10, 21)),
     ("aten::fused_moving_avg_obs_fake_quant", datetime.date(2021, 10, 21)),
     ("aten::_fused_moving_avg_obs_fq_helper", datetime.date(2021, 10, 21)),
+    ("aten::_baddbmm_mkl_", datetime.date(2021, 10, 31)),
 ]
 ALLOW_LIST_COMPILED = [