Update MultiheadAttention documentation (#20071)
Summary: Add documentation for add_bias_kv, add_zero_attn, and attn_mask.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20071
Differential Revision: D15213034
Pulled By: zhangguanheng66
fbshipit-source-id: c3db4b9e8527863420ba3ce6abf6098d3b0fb7a7
parent ecdeef37df
commit fc00bfd12e
@@ -689,8 +689,11 @@ class MultiheadAttention(Module):
         \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
 
     Args:
-        embed_dim: total dimension of the model
-        num_heads: parallel attention layers, or heads
+        embed_dim: total dimension of the model.
+        num_heads: parallel attention heads.
+        add_bias_kv: add bias to the key and value sequences at dim=0.
+        add_zero_attn: add a new batch of zeros to the key and
+                       value sequences at dim=1.
 
     Examples::
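For reference, the two constructor arguments documented in this hunk can be exercised as follows. This is a minimal sketch, assuming the torch.nn.MultiheadAttention constructor signature at this revision:

import torch.nn as nn

# add_bias_kv=True adds a learned bias to the key and value sequences at dim=0;
# add_zero_attn=True appends a new batch of zeros to key/value at dim=1, giving
# every query one extra all-zero position it may attend to.
attn = nn.MultiheadAttention(embed_dim=8, num_heads=2,
                             add_bias_kv=True, add_zero_attn=True)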
@@ -741,19 +744,37 @@ class MultiheadAttention(Module):
     @weak_script_method
     def forward(self, query, key, value, key_padding_mask=None, incremental_state=None,
                 need_weights=True, static_kv=False, attn_mask=None):
-        """
-        Inputs of forward function
-            query: [target length, batch size, embed dim]
-            key: [sequence length, batch size, embed dim]
-            value: [sequence length, batch size, embed dim]
-            key_padding_mask: if True, mask padding based on batch size
-            incremental_state: if provided, previous time steps are cashed
-            need_weights: output attn_output_weights
-            static_kv: key and value are static
-
-        Outputs of forward function
-            attn_output: [target length, batch size, embed dim]
-            attn_output_weights: [batch size, target length, sequence length]
-        """
+        r"""
+        Args:
+            query, key, value: map a query and a set of key-value pairs to an output.
+                See "Attention Is All You Need" for more details.
+            key_padding_mask: if provided, specified padding elements in the key will
+                be ignored by the attention.
+            incremental_state: if provided, previous time steps are cached.
+            need_weights: output attn_output_weights.
+            static_kv: if true, key and value are static. The key and value in previous
+                states will be used.
+            attn_mask: mask that prevents attention to certain positions.
+
+        Shape:
+            - Inputs:
+
+            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+              the embedding dimension.
+            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+              the embedding dimension.
+            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+              the embedding dimension.
+            - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
+            - incremental_state: a dictionary used for storing states.
+            - attn_mask: :math:`(L, L)` where L is the target sequence length.
+
+            - Outputs:
+
+            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+              E is the embedding dimension.
+            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+              L is the target sequence length, S is the source sequence length.
+        """
         qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
         kv_same = key.data_ptr() == value.data_ptr()
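As a quick check of the shapes documented above, here is a minimal usage sketch. It assumes only the public call attn(query, key, value, ...); the mask usage is illustrative and the additive float attn_mask convention is an assumption about this revision, not part of the diff:

import torch
import torch.nn as nn

L, S, N, E = 5, 7, 3, 8               # target len, source len, batch, embed dim
attn = nn.MultiheadAttention(E, num_heads=2)

query = torch.randn(L, N, E)          # (L, N, E)
key = torch.randn(S, N, E)            # (S, N, E)
value = torch.randn(S, N, E)          # (S, N, E)

attn_output, attn_output_weights = attn(query, key, value)
print(attn_output.shape)              # torch.Size([5, 3, 8])  -> (L, N, E)
print(attn_output_weights.shape)      # torch.Size([3, 5, 7])  -> (N, L, S)

# attn_mask as an additive (L, L) float mask (assumed convention here): -inf
# entries block attention to those positions, e.g. a causal mask for
# self-attention, where the key length equals L.
causal_mask = torch.full((L, L), float('-inf')).triu(diagonal=1)
self_out, _ = attn(query, query, query, attn_mask=causal_mask)

# Passing the same tensor for query, key, and value makes the data_ptr()
# comparisons in the hunk above (qkv_same) true, selecting the
# self-attention fast path.
x = torch.randn(L, N, E)
fast_out, _ = attn(x, x, x)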