Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[PyTorch] Fix MHA grain size computation (#72463)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72463

Taking the max with 1 makes a lot more sense here than taking the min: with min, the grain size was clamped down to a single element, so parallel_for split the work into the smallest possible chunks instead of chunks of roughly GRAIN_SIZE elements.

ghstack-source-id: 149067332

Test Plan: CI

Reviewed By: zrphercule

Differential Revision: D33990633

fbshipit-source-id: c706148c357473c929020f5dc65cc5050611af8f
This commit is contained in: parent 5d8de9a122, commit 2adf3be11a
@@ -42,7 +42,7 @@ std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv(
   const scalar_t sqrt_dim_per_head = std::sqrt(static_cast<scalar_t>(dim_per_head));

   int64_t grain_size =
-      std::min(internal::GRAIN_SIZE / (3 * dim_per_head), (int64_t)1);
+      std::max(internal::GRAIN_SIZE / (3 * dim_per_head), (int64_t)1);
   parallel_for(
       0, B * num_head * T, grain_size, [&](int64_t begin, int64_t end) {
         for (auto i : c10::irange(begin, end)) {
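A minimal standalone sketch (not the ATen source) of what the one-character change does. It assumes ATen's default internal::GRAIN_SIZE of 32768 and an illustrative dim_per_head of 64; both values are only for demonstration.

// grain_size_sketch.cpp: compare the old (min) and fixed (max) clamps.
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t GRAIN_SIZE = 32768;                   // assumed ATen default
  const int64_t dim_per_head = 64;                    // illustrative value
  const int64_t per_element_work = 3 * dim_per_head;  // Q, K, V per element

  // Old code: min collapses the grain size to 1 whenever the quotient is >= 1,
  // so parallel_for would split the range into single-element chunks.
  int64_t grain_min = std::min(GRAIN_SIZE / per_element_work, (int64_t)1);

  // Fixed code: max keeps the computed chunk size and only raises it to 1
  // when 3 * dim_per_head exceeds GRAIN_SIZE (quotient rounds down to 0).
  int64_t grain_max = std::max(GRAIN_SIZE / per_element_work, (int64_t)1);

  std::cout << "min clamp: " << grain_min << "\n";  // prints 1
  std::cout << "max clamp: " << grain_max << "\n";  // prints 170
  return 0;
}

With the max clamp, each parallel_for task covers on the order of GRAIN_SIZE elements of work, while the clamp still guarantees a grain size of at least 1 for very large heads.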