Update MultiheadAttention documentation (#20071)

Summary:
Add documentation for add_bias_kv, add_zero_attn, and attn_mask.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20071

Differential Revision: D15213034

Pulled By: zhangguanheng66

fbshipit-source-id: c3db4b9e8527863420ba3ce6abf6098d3b0fb7a7
Guanheng Zhang, 2019-05-04 13:52:27 -07:00, committed by Facebook Github Bot
parent ecdeef37df
commit fc00bfd12e

@@ -689,8 +689,11 @@ class MultiheadAttention(Module):
         \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
 
     Args:
-        embed_dim: total dimension of the model
-        num_heads: parallel attention layers, or heads
+        embed_dim: total dimension of the model.
+        num_heads: parallel attention heads.
+        add_bias_kv: add bias to the key and value sequences at dim=0.
+        add_zero_attn: add a new batch of zeros to the key and
+                       value sequences at dim=1.
 
     Examples::
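
For context, a minimal sketch of constructing the module with the arguments documented above; the sizes are arbitrary illustration values, not part of the commit:

    import torch.nn as nn

    embed_dim, num_heads = 256, 8
    # add_bias_kv appends a learnable bias to the key/value sequences at dim=0;
    # add_zero_attn appends a new batch of zeros to the key/value sequences at dim=1.
    multihead_attn = nn.MultiheadAttention(embed_dim, num_heads,
                                           add_bias_kv=True,
                                           add_zero_attn=True)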
@@ -741,19 +744,37 @@ class MultiheadAttention(Module):
     @weak_script_method
     def forward(self, query, key, value, key_padding_mask=None, incremental_state=None,
                 need_weights=True, static_kv=False, attn_mask=None):
-        """
-        Inputs of forward function
-            query: [target length, batch size, embed dim]
-            key: [sequence length, batch size, embed dim]
-            value: [sequence length, batch size, embed dim]
-            key_padding_mask: if True, mask padding based on batch size
-            incremental_state: if provided, previous time steps are cashed
-            need_weights: output attn_output_weights
-            static_kv: key and value are static
+        r"""
+        Args:
+            query, key, value: map a query and a set of key-value pairs to an output.
+                See "Attention Is All You Need" for more details.
+            key_padding_mask: if provided, specified padding elements in the key will
+                be ignored by the attention.
+            incremental_state: if provided, previous time steps are cached.
+            need_weights: output attn_output_weights.
+            static_kv: if true, key and value are static. The key and value in previous
+                states will be used.
+            attn_mask: mask that prevents attention to certain positions.
 
-        Outputs of forward function
-            attn_output: [target length, batch size, embed dim]
-            attn_output_weights: [batch size, target length, sequence length]
+        Shape:
+            - Inputs:
+            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+              the embedding dimension.
+            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+              the embedding dimension.
+            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+              the embedding dimension.
+            - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
+            - incremental_state: a dictionary used for storing states.
+            - attn_mask: :math:`(L, L)` where L is the target sequence length.
+            - Outputs:
+            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+              E is the embedding dimension.
+            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+              L is the target sequence length, S is the source sequence length.
         """
         qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
         kv_same = key.data_ptr() == value.data_ptr()
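
As a sanity check on the shapes listed above, a hedged example of calling forward (sizes and mask contents are made up; it assumes a recent torch build that accepts boolean masks, whereas the docstring at this commit still describes key_padding_mask as a ByteTensor):

    import torch
    import torch.nn as nn

    L = S = 6     # target/source lengths kept equal so the (L, L) attn_mask applies
    N, E = 2, 16  # batch size, embedding dimension
    attn = nn.MultiheadAttention(E, num_heads=4)

    query = torch.randn(L, N, E)
    key = torch.randn(S, N, E)
    value = torch.randn(S, N, E)

    # True entries mark key positions the attention should ignore (e.g. padding).
    key_padding_mask = torch.zeros(N, S, dtype=torch.bool)
    key_padding_mask[:, -1] = True

    # True entries block attention to those positions (here, a causal mask).
    attn_mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)

    attn_output, attn_output_weights = attn(query, key, value,
                                            key_padding_mask=key_padding_mask,
                                            attn_mask=attn_mask)
    print(attn_output.shape)           # torch.Size([6, 2, 16])  -> (L, N, E)
    print(attn_output_weights.shape)   # torch.Size([2, 6, 6])   -> (N, L, S)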