pytorch/docs/source/nn.attention.bias.rst
drisspg d4c79a3078 Add an attention bias subclass for a lower right causal masking (#114823)
# Summary
This PR introduces a new Tensor subclass that is designed to be used with torch.nn.functional.scaled_dot_product_attention. Currently we have a boolean `is_causal` flag that allows users to do causal masking without having to create the "realized" attention bias and pass it into sdpa. We originally added this flag because both fused kernels we support have native support for causal masking. This provides a big performance gain: the kernels only need to iterate over ~0.5x of the sequence, and for very large sequence lengths this can provide very large memory improvements.
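As a rough illustration of where the ~0.5x figure comes from, the sketch below counts how many entries of a square causal score matrix are actually attended to. It is a plain-Python element count using a hypothetical helper, `causal_work_fraction`, not PyTorch code. The fused kernels skip masked work at tile granularity, so their real saving only approximates this fraction.

```python
def causal_work_fraction(seq_len: int) -> float:
    """Fraction of a seq_len x seq_len score matrix kept by causal masking.

    Hypothetical helper for illustration: query i may attend to keys 0..i,
    so row i keeps i + 1 entries.
    """
    kept = sum(i + 1 for i in range(seq_len))  # equals seq_len * (seq_len + 1) / 2
    return kept / (seq_len * seq_len)


# The kept fraction shrinks toward one half as the sequence grows.
for n in (4, 256, 4096):
    print(f"seq_len={n:5d}  fraction kept={causal_work_fraction(n):.4f}")
```

The fraction works out to (n + 1) / (2n), so it is 0.625 for n = 4 and approaches 0.5 for long sequences.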

The flag was introduced early on in kernel development, and at the time it implicitly meant "upper_left" causal attention. This distinction only matters when the attention_bias is not square. For a more detailed breakdown see: https://github.com/pytorch/pytorch/issues/108108. The kernels' default behavior has since changed to lower_right, largely due to the rise of autoregressive text generation, but unfortunately changing the meaning of `is_causal` to match would be a BC break. In the long term it may actually be beneficial to change the default meaning of `is_causal` to represent lower_right causal masking.
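To make the upper_left vs. lower_right distinction concrete, here is a minimal plain-Python sketch. The function `causal_mask` and its string `variant` argument are hypothetical and are not the PyTorch API. `True` marks a position that may be attended to. Upper-left aligns the causal diagonal with the top-left corner of the score matrix. Lower-right shifts it by `kv_len - q_len` so that it ends in the bottom-right corner, which is what autoregressive decoding needs when the queries are the last tokens of a longer key/value sequence.

```python
def causal_mask(q_len: int, kv_len: int, variant: str) -> list[list[bool]]:
    """Boolean causal mask of shape (q_len, kv_len); True means "may attend".

    Hypothetical helper for illustration only.
    """
    if variant == "upper_left":
        offset = 0  # diagonal starts at the top-left corner
    elif variant == "lower_right":
        offset = kv_len - q_len  # diagonal ends at the bottom-right corner
    else:
        raise ValueError(f"unknown variant: {variant!r}")
    # Query i may attend to key j when j <= i + offset.
    return [[j <= i + offset for j in range(kv_len)] for i in range(q_len)]


def show(mask: list[list[bool]]) -> None:
    """Print a mask as rows of 1s (attend) and 0s (masked)."""
    for row in mask:
        print(" ".join("1" if keep else "0" for keep in row))


# Two queries attending over four keys (e.g. two cached tokens plus two new ones).
print("upper_left:")
show(causal_mask(2, 4, "upper_left"))
print("lower_right:")
show(causal_mask(2, 4, "lower_right"))
```

For a square mask (`q_len == kv_len`) the offset is 0 and both variants produce the same mask, which is why the distinction only matters for non-square attention biases. The same diagonal offset can be expressed in PyTorch as `torch.tril(torch.ones(q_len, kv_len, dtype=torch.bool), diagonal=kv_len - q_len)`.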

The larger theme, though, is laid out here: https://github.com/pytorch/pytorch/issues/110681. The thesis is that there is a lot of innovation in SDPA revolving around the attention_bias being used. This is hopefully the first of several attention biases that we would like to add. The next interesting one would be `sliding_window`, which is used by the popular Mistral model family.

Benchmarking results are below. I also improved the meff_attention (memory-efficient attention) performance, hence the slightly decreased max speedup.
```Shell
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
|  Type   |      Speedup       | batch_size | num_heads | q_seq_len | k_seq_len | embed_dim |     dtype      | head_dim |
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
| Average | 1.2388050062214226 |            |           |           |           |           |                |          |
|   Max   | 1.831672915579016  |    128     |    32     |   1024    |   2048    |   2048    | torch.bfloat16 |    64    |
|   Min   | 0.9430534166730135 |     1      |    16     |    256    |    416    |   2048    | torch.bfloat16 |   128    |
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114823
Approved by: https://github.com/cpuhrsch
2023-12-06 08:29:26 +00:00

.. role:: hidden
    :class: hidden-section

torch.nn.attention.bias
========================

.. automodule:: torch.nn.attention.bias
.. currentmodule:: torch.nn.attention.bias

CausalBias
==========

.. autoclass:: CausalBias

.. autosummary::
    :toctree: generated
    :nosignatures:

    causal_lower_right
    causal_upper_left
    CausalVariant