mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[PyTorch] MHA: fix contiguity assumption in transform_bias_rescale_qkv (#72465)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72465 This code path incorrectly assumed input tensors were contiguous. Now we check that. ghstack-source-id: 149201476 Test Plan: CI Reviewed By: ngimel Differential Revision: D34007665 fbshipit-source-id: c43438f2495e32304ea3f7846e01eceb4a9448f7
This commit is contained in:
parent
03e45ceb89
commit
0767b225f2
|
|
@ -113,15 +113,18 @@ std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv(
|
||||||
TORCH_CHECK(_3D % 3 == 0);
|
TORCH_CHECK(_3D % 3 == 0);
|
||||||
const auto dim_per_head = D / num_head;
|
const auto dim_per_head = D / num_head;
|
||||||
auto q_k_v = at::empty({3, B, num_head, T, dim_per_head}, qkv.options());
|
auto q_k_v = at::empty({3, B, num_head, T, dim_per_head}, qkv.options());
|
||||||
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(q_k_v.is_contiguous());
|
||||||
|
|
||||||
|
const auto qkv_contig = qkv.expect_contiguous();
|
||||||
|
const auto qkv_bias_contig = qkv_bias.expect_contiguous();
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||||
ScalarType::Half,
|
ScalarType::Half,
|
||||||
ScalarType::BFloat16,
|
ScalarType::BFloat16,
|
||||||
qkv.scalar_type(),
|
qkv.scalar_type(),
|
||||||
"transform_bias_rescale_qkv",
|
"transform_bias_rescale_qkv",
|
||||||
[&] {
|
[&] {
|
||||||
scalar_t* qkv_data = qkv.data_ptr<scalar_t>();
|
scalar_t* qkv_data = qkv_contig->data_ptr<scalar_t>();
|
||||||
scalar_t* qkv_bias_data = qkv_bias.data_ptr<scalar_t>();
|
scalar_t* qkv_bias_data = qkv_bias_contig->data_ptr<scalar_t>();
|
||||||
scalar_t* q_k_v_data = q_k_v.data_ptr<scalar_t>();
|
scalar_t* q_k_v_data = q_k_v.data_ptr<scalar_t>();
|
||||||
const scalar_t sqrt_dim_per_head = std::sqrt(static_cast<scalar_t>(dim_per_head));
|
const scalar_t sqrt_dim_per_head = std::sqrt(static_cast<scalar_t>(dim_per_head));
|
||||||
|
|
||||||
|
|
@ -134,6 +137,7 @@ std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv(
|
||||||
});
|
});
|
||||||
auto q_k_v_s =
|
auto q_k_v_s =
|
||||||
at::native::split(q_k_v.view({3 * B, num_head, T, dim_per_head}), B, 0);
|
at::native::split(q_k_v.view({3 * B, num_head, T, dim_per_head}), B, 0);
|
||||||
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(q_k_v_s.size() == 3);
|
||||||
return std::make_tuple(q_k_v_s[0], q_k_v_s[1], q_k_v_s[2]);
|
return std::make_tuple(q_k_v_s[0], q_k_v_s[1], q_k_v_s[2]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user