Revert "Update the heuristic for AArch64 bmm/baddbmm (#149122)"
This reverts commit d759a517af.
Reverted https://github.com/pytorch/pytorch/pull/149122 on behalf of https://github.com/jeanschmidt due to breaking internal models, @malfet may you help merge this? ([comment](https://github.com/pytorch/pytorch/pull/149122#issuecomment-2904703075))
This commit is contained in:
parent 5859582ee4
commit 866142ff16
aten/src/ATen/native/LinearAlgebra.cpp
@@ -1360,6 +1360,41 @@ Tensor outer(const Tensor& self, const Tensor& vec2) {
 #endif
+
+static inline int64_t get_mkldnn_matmul_min_dim() {
+  static auto value = [&] {
+    const int64_t default_min_dim = [&] {
+      // Minimum dimension requirement for MKLDNN; derived from experiments.
+      // It's enabled on all Neoverse cpus.
+      return is_arm_neoverse() ? 8 : 0;
+    }();
+    const auto value = c10::utils::get_env("TORCH_MKLDNN_MATMUL_MIN_DIM");
+    return value.has_value() ? std::stoi(value.value()) : default_min_dim;
+  }();
+  return value;
+}
+
+static inline int64_t get_mkldnn_matmul_min_size() {
+  static auto value = [&] {
+    const int64_t default_min_size = [&] {
+      // Minimum size requirement for MKLDNN; derived from experiments.
+      // It's enabled on all Neoverse cpus.
+      return is_arm_neoverse() ? 8 * 1024 : 0;
+    }();
+    const auto value = c10::utils::get_env("TORCH_MKLDNN_MATMUL_MIN_SIZE");
+    return value.has_value() ? std::stoi(value.value()) : default_min_size;
+  }();
+  return value;
+}
+
+static inline bool apply_mkldnn_matmul_heur(int64_t m, int64_t k, int64_t n) {
+  const int64_t min_dim = get_mkldnn_matmul_min_dim();
+  const int64_t min_size = get_mkldnn_matmul_min_size();
+  return at::globalContext().userEnabledMkldnn() && m > min_dim && k > min_dim && n > min_dim && m * k * n > min_size;
+}
+
 static void addmm_impl_cpu_(
     Tensor &result, const Tensor &self, Tensor m1, Tensor m2, const Scalar& beta, const Scalar& alpha) {
   TORCH_INTERNAL_ASSERT(self.dim() == 2 && m1.dim() == 2 && m2.dim() == 2);
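For reference, the restored gating computes each threshold once per process and lets an environment variable override the Neoverse default. Here is a minimal standalone sketch of that pattern, assuming a Neoverse core; `is_arm_neoverse()` is mocked, `std::getenv` stands in for `c10::utils::get_env`, and the helper names are mine rather than PyTorch's:

```cpp
#include <cstdint>
#include <cstdlib>

// Stand-in for PyTorch's is_arm_neoverse(); assume a Neoverse core here.
static bool is_arm_neoverse() { return true; }

// Read an integer override from the environment, else keep the default.
static int64_t env_or_default(const char* name, int64_t fallback) {
  const char* value = std::getenv(name);
  return value != nullptr ? std::atoll(value) : fallback;
}

static int64_t matmul_min_dim() {
  // Computed once, like the static lambda in the diff above.
  static const int64_t value =
      env_or_default("TORCH_MKLDNN_MATMUL_MIN_DIM", is_arm_neoverse() ? 8 : 0);
  return value;
}

static int64_t matmul_min_size() {
  static const int64_t value = env_or_default(
      "TORCH_MKLDNN_MATMUL_MIN_SIZE", is_arm_neoverse() ? 8 * 1024 : 0);
  return value;
}

// Mirrors apply_mkldnn_matmul_heur, minus the userEnabledMkldnn() check.
static bool apply_matmul_heur(int64_t m, int64_t k, int64_t n) {
  return m > matmul_min_dim() && k > matmul_min_dim() &&
         n > matmul_min_dim() && m * k * n > matmul_min_size();
}
```

Setting `TORCH_MKLDNN_MATMUL_MIN_DIM` or `TORCH_MKLDNN_MATMUL_MIN_SIZE` therefore moves the cutoff at runtime, without a rebuild.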
@@ -1479,7 +1514,8 @@ static void addmm_impl_cpu_(
   // that will then call into Arm® Compute Library (ACL) GEMM kernel and also
   // additionally have support for running the kernel with BF16 instructions
   if (transpose_c) {
-    if (use_mkldnn_matmul(b, a, c) && transpose_a && !transpose_b && result.scalar_type() == at::ScalarType::Float) {
+    bool apply_heur = apply_mkldnn_matmul_heur(b.sizes()[0], b.sizes()[1], a.sizes()[1]);
+    if (apply_heur && transpose_a && !transpose_b && result.scalar_type() == at::ScalarType::Float) {
       try {
         mkldnn_matmul(b, a, c, beta.to<float>(), alpha.to<float>());
         // We have dispatched to ACL GEMM for single precision float
@@ -1735,7 +1771,8 @@ static inline void bmm_out_or_baddbmm_(const Tensor& self_or_result_, const Tens
         (strides[1] == 1 && (sizes[2] == 1 || strides[2] >= sizes[1]));
   };
 
-  if (use_mkldnn_matmul(batch1, batch2, self_or_result)) {
+  bool apply_heur = apply_mkldnn_matmul_heur(batch1.sizes()[1], batch1.sizes()[2], batch2.sizes()[2]);
+  if (apply_heur && use_mkldnn_matmul(batch1, batch2, self_or_result)) {
     try {
       mkldnn_matmul(batch1, batch2, self_or_result, beta.to<float>(), alpha.to<float>());
       return;
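Note how the two call sites map shapes onto (m, k, n): addmm passes (b.sizes()[0], b.sizes()[1], a.sizes()[1]), while bmm/baddbmm passes the per-batch (batch1.sizes()[1], batch1.sizes()[2], batch2.sizes()[2]). A self-contained driver, with the restored Neoverse defaults hard-coded (min_dim = 8, min_size = 8 * 1024) and output labels of my choosing, shows where the thresholds flip the decision:

```cpp
#include <cstdint>
#include <cstdio>

// Default Neoverse thresholds from the restored helpers above.
constexpr int64_t kMinDim = 8;
constexpr int64_t kMinSize = 8 * 1024;

static bool heur(int64_t m, int64_t k, int64_t n) {
  return m > kMinDim && k > kMinDim && n > kMinDim && m * k * n > kMinSize;
}

int main() {
  // 8x8x8 fails the per-dimension check: every dim must exceed 8.
  std::printf("8x8x8    -> %s\n", heur(8, 8, 8) ? "mkldnn" : "fallback");
  // 16x16x16 passes the dims, but 16*16*16 = 4096 <= 8192 fails on volume.
  std::printf("16x16x16 -> %s\n", heur(16, 16, 16) ? "mkldnn" : "fallback");
  // 32x32x32 gives 32768 > 8192 and takes the MKLDNN/ACL path.
  std::printf("32x32x32 -> %s\n", heur(32, 32, 32) ? "mkldnn" : "fallback");
}
```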
aten/src/ATen/native/mkldnn/Matmul.cpp
@@ -391,42 +391,6 @@ void mkldnn_matmul(
 }
 
-#if AT_MKLDNN_ACL_ENABLED()
-// Experimentally derived heuristics for MKLDNN+ACL on NEOVERSE cores
-static inline int64_t get_mkldnn_acl_addmm_min_dim() {
-  static auto value = [&] {
-    const int64_t default_min_dim = [&] {
-      return is_arm_neoverse() ? 8 : 0;
-    }();
-    const char* ptr = std::getenv("TORCH_MKLDNN_ADDMM_MIN_DIM");
-    return ptr != nullptr ? std::atoi(ptr) : default_min_dim;
-  }();
-  return value;
-}
-
-static inline int64_t get_mkldnn_acl_addmm_min_size() {
-  static auto value = [&] {
-    const int64_t default_min_size = [&] {
-      return is_arm_neoverse() ? 8 * 1024 : 0;
-    }();
-    const char* ptr = std::getenv("TORCH_MKLDNN_ADDMM_MIN_SIZE");
-    return ptr != nullptr ? std::atoi(ptr) : default_min_size;
-  }();
-  return value;
-}
-
-static inline int64_t get_mkldnn_acl_bmm_baddbmm_threshold() {
-  static auto value = [&] {
-    const int64_t default_threshold = [&] {
-      return is_arm_neoverse() ? 1L << 22 : 0;
-    }();
-    const char* ptr = std::getenv("TORCH_MKLDNN_BMM_BADDBMM_THRESHOLD");
-    return ptr != nullptr ? std::atoi(ptr) : default_threshold;
-  }();
-  return value;
-}
-#endif
-
 static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
   // if dim = 2, mat1's size = (m * n), mat2's size = (n * k)
   // else if dim = 3, mat1's size = (b * m * n), mat2's size = (b * n * k)
@@ -441,26 +405,10 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
     return mat1.size(0) * mat1.size(1) > mkldnn_gemm_min_size;
   } else if (mat1.dim() == 2 && mat2.dim() == 2) {
     // aten::addmm
-#if AT_MKLDNN_ACL_ENABLED()
-    const int64_t mkldnn_acl_addmm_min_dim = get_mkldnn_acl_addmm_min_dim();
-    const int64_t mkldnn_acl_addmm_min_size = get_mkldnn_acl_addmm_min_size();
-    // M > MIN_DIM and N > MIN_DIM and K > MIN_DIM and M*N*K > MIN_SIZE
-    return mat1.size(0) > mkldnn_acl_addmm_min_dim
-        && mat1.size(1) > mkldnn_acl_addmm_min_dim
-        && mat2.size(1) > mkldnn_acl_addmm_min_dim
-        && mat1.size(0) * mat1.size(1) * mat2.size(1) > mkldnn_acl_addmm_min_size;
-#else
     return mat1.size(0) * mat1.size(1) * mat2.size(1) > mkldnn_gemm_min_size;
-#endif
   } else {
     // aten::bmm, aten::baddbmm
-#if AT_MKLDNN_ACL_ENABLED()
-    const int64_t mkldnn_acl_bmm_baddbmm_threshold = get_mkldnn_acl_bmm_baddbmm_threshold();
-    // BATCH_SIZE^2 * M * N * K >= THRESHOLD
-    return mat1.size(0) * mat1.size(0) * mat1.size(1) * mat1.size(2) * mat2.size(2) >= mkldnn_acl_bmm_baddbmm_threshold;
-#else
     return mat1.size(0) * mat1.size(1) * mat1.size(2) * mat2.size(2) > mkldnn_gemm_min_size;
-#endif
   }
 }
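For bmm/baddbmm the net effect is a different gate: the deleted ACL branch required BATCH^2 * M * N * K >= 2^22, while the surviving generic check compares BATCH * M * N * K against `mkldnn_gemm_min_size`. A quick arithmetic sketch on one illustrative shape of my choosing, assuming the default `mkldnn_gemm_min_size` of 8 * 1024 * 1024 * 2 defined just above this excerpt:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative bmm shape (my example, not from the PR): batch=16, m=n=k=64.
  const int64_t batch = 16, m = 64, n = 64, k = 64;

  // Deleted ACL heuristic: BATCH^2 * M * N * K >= 2^22.
  const bool acl_path = batch * batch * m * n * k >= (int64_t{1} << 22);

  // Surviving generic check, assuming mkldnn_gemm_min_size = 8*1024*1024*2.
  const bool generic_path = batch * m * n * k > int64_t{8} * 1024 * 1024 * 2;

  // Prints acl_path=1 generic_path=0: the reverted heuristic accepted this
  // shape (67108864 >= 4194304) while the restored check rejects it
  // (4194304 is not > 16777216), so it stays on the non-MKLDNN path.
  std::printf("acl_path=%d generic_path=%d\n", acl_path, generic_path);
}
```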