Fix mha torch._check in jit tracing (#142059)

Test Plan: `buck2 run @//mode/dev-nosan //mobile-vision/d2go/projects_oss/detr:tests -- -r test_detr_fbnet_export`

Differential Revision: D66769339

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142059
Approved by: https://github.com/ezyang
This commit is contained in:
Angela Yi 2024-12-05 18:38:14 +00:00 committed by PyTorch MergeBot
parent 540dc0c114
commit a9d84875a9

View File

@ -5978,6 +5978,21 @@ def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")
def _check_key_padding_mask(
key_padding_mask: torch.Tensor, src_len: int, bsz: int
) -> None:
torch._check_with(
AssertionError,
key_padding_mask.shape[0] == bsz,
lambda: f"Expected key_padded_mask.shape[0] to be {bsz}, but got {key_padding_mask.shape[0]}",
)
torch._check_with(
AssertionError,
key_padding_mask.shape[1] == src_len,
lambda: f"Expected key_padded_mask.shape[1] to be {src_len}, but got {key_padding_mask.shape[1]}",
)
def multi_head_attention_forward(
query: Tensor,
key: Tensor,
@ -6316,17 +6331,8 @@ def multi_head_attention_forward(
# merge key padding and attention masks
if key_padding_mask is not None:
if not torch.jit.is_scripting():
torch._check_with(
AssertionError,
key_padding_mask.shape[0] == bsz,
None,
)
torch._check_with(
AssertionError,
key_padding_mask.shape[1] == src_len,
None,
)
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_check_key_padding_mask(key_padding_mask, src_len, bsz)
key_padding_mask = (
key_padding_mask.view(bsz, 1, 1, src_len)