Revert "Update the heuristic for AArch64 bmm/baddbmm (#149122)"

This reverts commit d759a517af.

Reverted https://github.com/pytorch/pytorch/pull/149122 on behalf of https://github.com/jeanschmidt due to breaking internal models, @malfet may you help merge this? ([comment](https://github.com/pytorch/pytorch/pull/149122#issuecomment-2904703075))
PyTorch MergeBot 2025-05-23 14:54:54 +00:00
parent 5859582ee4
commit 866142ff16
2 changed files with 39 additions and 54 deletions


@@ -1360,6 +1360,41 @@ Tensor outer(const Tensor& self, const Tensor& vec2) {
#endif
static inline int64_t get_mkldnn_matmul_min_dim() {
  static auto value = [&] {
    const int64_t default_min_dim = [&] {
      // Minimum dimension requirement for MKLDNN; derived based on experiments.
      // It is enabled on all Neoverse CPUs.
      return is_arm_neoverse() ? 8 : 0;
    }();
    const auto value = c10::utils::get_env("TORCH_MKLDNN_MATMUL_MIN_DIM");
    return value.has_value() ? std::stoi(value.value()) : default_min_dim;
  }();
  return value;
}

static inline int64_t get_mkldnn_matmul_min_size() {
  static auto value = [&] {
    const int64_t default_min_size = [&] {
      // Minimum size requirement for MKLDNN; derived based on experiments.
      // It is enabled on all Neoverse CPUs.
      return is_arm_neoverse() ? 8 * 1024 : 0;
    }();
    const auto value = c10::utils::get_env("TORCH_MKLDNN_MATMUL_MIN_SIZE");
    return value.has_value() ? std::stoi(value.value()) : default_min_size;
  }();
  return value;
}

static inline bool apply_mkldnn_matmul_heur(int64_t m, int64_t k, int64_t n) {
  const int64_t min_dim = get_mkldnn_matmul_min_dim();
  const int64_t min_size = get_mkldnn_matmul_min_size();
  return at::globalContext().userEnabledMkldnn() &&
      m > min_dim && k > min_dim && n > min_dim && m * k * n > min_size;
}
static void addmm_impl_cpu_(
    Tensor &result, const Tensor &self, Tensor m1, Tensor m2, const Scalar& beta, const Scalar& alpha) {
  TORCH_INTERNAL_ASSERT(self.dim() == 2 && m1.dim() == 2 && m2.dim() == 2);
@@ -1479,7 +1514,8 @@ static void addmm_impl_cpu_(
  // that will call then into Arm® Compute Library (ACL) GEMM kernel and also
  // additionally have support for running kernel with BF16 instructions
  if (transpose_c) {
    if (use_mkldnn_matmul(b, a, c) && transpose_a && !transpose_b && result.scalar_type() == at::ScalarType::Float) {
    bool apply_heur = apply_mkldnn_matmul_heur(b.sizes()[0], b.sizes()[1], a.sizes()[1]);
    if (apply_heur && transpose_a && !transpose_b && result.scalar_type() == at::ScalarType::Float) {
      try {
        mkldnn_matmul(b, a, c, beta.to<float>(), alpha.to<float>());
        // We have dispatched to ACL GEMM for single precision float
@@ -1735,7 +1771,8 @@ static inline void bmm_out_or_baddbmm_(const Tensor& self_or_result_, const Tens
        (strides[1] == 1 && (sizes[2] == 1 || strides[2] >= sizes[1]));
  };
  if (use_mkldnn_matmul(batch1, batch2, self_or_result)) {
  bool apply_heur = apply_mkldnn_matmul_heur(batch1.sizes()[1], batch1.sizes()[2], batch2.sizes()[2]);
  if (apply_heur && use_mkldnn_matmul(batch1, batch2, self_or_result)) {
    try {
      mkldnn_matmul(batch1, batch2, self_or_result, beta.to<float>(), alpha.to<float>());
      return;
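
The restored apply_mkldnn_matmul_heur gate reads as: dispatch to oneDNN/ACL only when every GEMM dimension exceeds min_dim and the total work m * k * n exceeds min_size. A minimal standalone sketch of that size check follows, assuming the Neoverse defaults of 8 and 8 * 1024 with no TORCH_MKLDNN_MATMUL_MIN_DIM / TORCH_MKLDNN_MATMUL_MIN_SIZE overrides; heuristic_allows and the sample shapes are illustrative only, not PyTorch source, and the userEnabledMkldnn() check is omitted.

#include <cstdint>
#include <iostream>

// Illustrative re-statement of the restored size/dimension gate (not PyTorch source).
static bool heuristic_allows(int64_t m, int64_t k, int64_t n,
                             int64_t min_dim = 8, int64_t min_size = 8 * 1024) {
  // All three dimensions must exceed min_dim and m * k * n must exceed min_size.
  return m > min_dim && k > min_dim && n > min_dim && m * k * n > min_size;
}

int main() {
  std::cout << heuristic_allows(4, 64, 64) << "\n";   // 0: m <= 8, stays on the BLAS path
  std::cout << heuristic_allows(16, 16, 16) << "\n";  // 0: 16 * 16 * 16 = 4096 <= 8192
  std::cout << heuristic_allows(64, 64, 64) << "\n";  // 1: large enough to try oneDNN/ACL
  return 0;
}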


@@ -391,42 +391,6 @@ void mkldnn_matmul(
}
#if AT_MKLDNN_ACL_ENABLED()
// Experimentally derived heuristics for MKLDNN+ACL on NEOVERSE cores
static inline int64_t get_mkldnn_acl_addmm_min_dim() {
  static auto value = [&] {
    const int64_t default_min_dim = [&] {
      return is_arm_neoverse() ? 8 : 0;
    }();
    const char* ptr = std::getenv("TORCH_MKLDNN_ADDMM_MIN_DIM");
    return ptr != nullptr ? std::atoi(ptr) : default_min_dim;
  }();
  return value;
}

static inline int64_t get_mkldnn_acl_addmm_min_size() {
  static auto value = [&] {
    const int64_t default_min_size = [&] {
      return is_arm_neoverse() ? 8 * 1024 : 0;
    }();
    const char* ptr = std::getenv("TORCH_MKLDNN_ADDMM_MIN_SIZE");
    return ptr != nullptr ? std::atoi(ptr) : default_min_size;
  }();
  return value;
}

static inline int64_t get_mkldnn_acl_bmm_baddbmm_threshold() {
  static auto value = [&] {
    const int64_t default_threshold = [&] {
      return is_arm_neoverse() ? 1L << 22 : 0;
    }();
    const char* ptr = std::getenv("TORCH_MKLDNN_BMM_BADDBMM_THRESHOLD");
    return ptr != nullptr ? std::atoi(ptr) : default_threshold;
  }();
  return value;
}
#endif
static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
  // if dim = 2, mat1's size = (m * n), mat2's size = (n * k)
  // else if dim = 3, mat1's size = (b * m * n), mat2's size = (b * n * k)
@@ -441,26 +405,10 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
    return mat1.size(0) * mat1.size(1) > mkldnn_gemm_min_size;
  } else if (mat2.dim() == 2 && mat2.dim() == 2) {
    // aten::addmm
#if AT_MKLDNN_ACL_ENABLED()
    const int64_t mkldnn_acl_addmm_min_dim = get_mkldnn_acl_addmm_min_dim();
    const int64_t mkldnn_acl_addmm_min_size = get_mkldnn_acl_addmm_min_size();
    // M > MIN_DIM and N > MIN_DIM and K > MIN_DIM and M*N*K > MIN_SIZE
    return mat1.size(0) > mkldnn_acl_addmm_min_dim
        && mat1.size(1) > mkldnn_acl_addmm_min_dim
        && mat2.size(1) > mkldnn_acl_addmm_min_dim
        && mat1.size(0) * mat1.size(1) * mat2.size(1) > mkldnn_acl_addmm_min_size;
#else
    return mat1.size(0) * mat1.size(1) * mat2.size(1) > mkldnn_gemm_min_size;
#endif
  } else {
    // aten::bmm, aten::baddbmm
#if AT_MKLDNN_ACL_ENABLED()
    const int64_t mkldnn_acl_bmm_baddbmm_threshold = get_mkldnn_acl_bmm_baddbmm_threshold();
    // BATCH_SIZE^2 * M * N * K >= THRESHOLD
    return mat1.size(0) * mat1.size(0) * mat1.size(1) * mat1.size(2) * mat2.size(2) >= mkldnn_acl_bmm_baddbmm_threshold;
#else
    return mat1.size(0) * mat1.size(1) * mat1.size(2) * mat2.size(2) > mkldnn_gemm_min_size;
#endif
  }
}
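
For comparison, the bmm/baddbmm branch removed above gated ACL on BATCH_SIZE^2 * M * N * K >= THRESHOLD, with a Neoverse default of 1 << 22. A minimal sketch of that arithmetic follows, assuming no TORCH_MKLDNN_BMM_BADDBMM_THRESHOLD override; acl_bmm_threshold_allows and the sample shapes are illustrative only, not PyTorch source.

#include <cstdint>
#include <iostream>

// Illustrative re-statement of the removed bmm/baddbmm check (not PyTorch source).
static bool acl_bmm_threshold_allows(int64_t b, int64_t m, int64_t n, int64_t k,
                                     int64_t threshold = 1L << 22) {
  // BATCH_SIZE^2 * M * N * K >= THRESHOLD
  return b * b * m * n * k >= threshold;
}

int main() {
  std::cout << acl_bmm_threshold_allows(4, 32, 32, 32) << "\n";   // 0: 16 * 32768 = 524288 < 4194304
  std::cout << acl_bmm_threshold_allows(16, 64, 64, 64) << "\n";  // 1: 256 * 262144 = 67108864 >= 4194304
  return 0;
}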