Fix mha torch._check in jit tracing (#142059)
Test Plan: `buck2 run @//mode/dev-nosan //mobile-vision/d2go/projects_oss/detr:tests -- -r test_detr_fbnet_export`

Differential Revision: D66769339

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142059
Approved by: https://github.com/ezyang
This commit is contained in:
parent 540dc0c114
commit a9d84875a9
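For orientation before the diff: the change follows a pattern used elsewhere in torch.nn.functional, where Python-level shape validation runs only in eager mode and is skipped under both TorchScript scripting and tracing. A minimal sketch of that guard pattern (the helper name and error type here are illustrative, not taken from this commit):

```python
import torch

def _validate_rank(x: torch.Tensor, expected_dim: int) -> None:
    # Hypothetical helper: run the check only in eager mode, so neither
    # torch.jit.script nor torch.jit.trace sees the Python-level assertion.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        torch._check_with(
            ValueError,
            x.dim() == expected_dim,
            lambda: f"expected a {expected_dim}-D tensor, got {x.dim()}-D",
        )

_validate_rank(torch.zeros(2, 3), 2)  # passes silently in eager mode
```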
@@ -5978,6 +5978,21 @@ def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
     raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")
 
 
+def _check_key_padding_mask(
+    key_padding_mask: torch.Tensor, src_len: int, bsz: int
+) -> None:
+    torch._check_with(
+        AssertionError,
+        key_padding_mask.shape[0] == bsz,
+        lambda: f"Expected key_padded_mask.shape[0] to be {bsz}, but got {key_padding_mask.shape[0]}",
+    )
+    torch._check_with(
+        AssertionError,
+        key_padding_mask.shape[1] == src_len,
+        lambda: f"Expected key_padded_mask.shape[1] to be {src_len}, but got {key_padding_mask.shape[1]}",
+    )
+
+
 def multi_head_attention_forward(
     query: Tensor,
     key: Tensor,
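A small sketch (not part of the commit) of how the new helper behaves in eager mode, assuming it is defined in torch.nn.functional exactly as in the hunk above:

```python
import torch
import torch.nn.functional as F

bsz, src_len = 2, 5
good_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
bad_mask = torch.zeros(bsz + 1, src_len, dtype=torch.bool)

# Matching shape: both torch._check_with calls pass silently.
F._check_key_padding_mask(good_mask, src_len, bsz)

# Mismatched batch dimension: raises AssertionError with the lambda's message.
try:
    F._check_key_padding_mask(bad_mask, src_len, bsz)
except AssertionError as err:
    print(err)  # Expected key_padded_mask.shape[0] to be 2, but got 3
```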
@@ -6316,17 +6331,8 @@ def multi_head_attention_forward(
 
     # merge key padding and attention masks
     if key_padding_mask is not None:
-        if not torch.jit.is_scripting():
-            torch._check_with(
-                AssertionError,
-                key_padding_mask.shape[0] == bsz,
-                None,
-            )
-            torch._check_with(
-                AssertionError,
-                key_padding_mask.shape[1] == src_len,
-                None,
-            )
+        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+            _check_key_padding_mask(key_padding_mask, src_len, bsz)
 
         key_padding_mask = (
             key_padding_mask.view(bsz, 1, 1, src_len)
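To illustrate why the added torch.jit.is_tracing() guard matters, here is a rough repro sketch (my own, not the commit's test plan, which runs the D2Go DETR export test): trace a call that routes a key_padding_mask through nn.MultiheadAttention, which dispatches to F.multi_head_attention_forward internally. With the guard in place, the eager-only shape checks are skipped while the trace is recorded:

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)

seq_len, bsz = 4, 2
query = torch.randn(seq_len, bsz, 8)                            # (L, N, E), batch_first=False
key_padding_mask = torch.zeros(bsz, seq_len, dtype=torch.bool)  # (N, S)

def run(q, mask):
    out, _ = mha(q, q, q, key_padding_mask=mask)
    return out

# Tracing records the graph without executing the Python-level shape checks.
traced = torch.jit.trace(run, (query, key_padding_mask))
print(traced(query, key_padding_mask).shape)  # torch.Size([4, 2, 8])
```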