[PyTorch] Fix MHA grain size computation (#72463)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72463

maxing with 1 makes a lot more sense to me than minning with 1, but I have no idea what I'm doing.
ghstack-source-id: 149067332

Test Plan: CI

Reviewed By: zrphercule

Differential Revision: D33990633

fbshipit-source-id: c706148c357473c929020f5dc65cc5050611af8f
This commit is contained in:
Scott Wolchok 2022-02-14 18:12:50 -08:00 committed by Facebook GitHub Bot
parent 5d8de9a122
commit 2adf3be11a

View File

@ -42,7 +42,7 @@ std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv(
const scalar_t sqrt_dim_per_head = std::sqrt(static_cast<scalar_t>(dim_per_head)); const scalar_t sqrt_dim_per_head = std::sqrt(static_cast<scalar_t>(dim_per_head));
int64_t grain_size = int64_t grain_size =
std::min(internal::GRAIN_SIZE / (3 * dim_per_head), (int64_t)1); std::max(internal::GRAIN_SIZE / (3 * dim_per_head), (int64_t)1);
parallel_for( parallel_for(
0, B * num_head * T, grain_size, [&](int64_t begin, int64_t end) { 0, B * num_head * T, grain_size, [&](int64_t begin, int64_t end) {
for (auto i : c10::irange(begin, end)) { for (auto i : c10::irange(begin, end)) {