# Summary

This PR introduces a new Tensor subclass designed to be used with `torch.nn.functional.scaled_dot_product_attention`. Currently we have a boolean `is_causal` flag that lets users apply causal masking without having to create the "realized" attention bias and pass it into SDPA. We originally added this flag because both of the fused kernels we support handle causal masking natively. This provides a big performance gain (the kernels only need to iterate over roughly half the sequence), and for very large sequence lengths it can also provide very large memory savings.

The flag was introduced early in kernel development, and at the time it implicitly meant "upper_left" causal attention. This distinction only matters when the attention bias is not square. For a more detailed breakdown see https://github.com/pytorch/pytorch/issues/108108. The kernels' default behavior has since changed, largely due to the rise of autoregressive text generation, and unfortunately updating `is_causal` to match would be a BC break. In the long term it may actually be beneficial to change the default meaning of `is_causal` to represent lower_right causal masking.

The larger theme is laid out in https://github.com/pytorch/pytorch/issues/110681: there is a lot of innovation in SDPA revolving around the attention bias being used. This is the first of hopefully a few more attention biases we would like to add. The next interesting one would be `sliding_window`, which is used by the popular Mistral model family.

Benchmark results (I improved the meff_attention perf, hence the slightly decreased max speedup):

```Shell
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
| Type    | Speedup            | batch_size | num_heads | q_seq_len | k_seq_len | embed_dim | dtype          | head_dim |
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
| Average | 1.2388050062214226 |            |           |           |           |           |                |          |
| Max     | 1.831672915579016  | 128        | 32        | 1024      | 2048      | 2048      | torch.bfloat16 | 64       |
| Min     | 0.9430534166730135 | 1          | 16        | 256       | 416       | 2048      | torch.bfloat16 | 128      |
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114823
Approved by: https://github.com/cpuhrsch
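Below is a minimal usage sketch (not part of the PR description) showing how the two causal-bias factories exported by `torch.nn.attention.bias` can be passed to SDPA. The tensor shapes, the CPU/default-dtype setup, and the `(seq_len_q, seq_len_kv)` calling convention for `causal_lower_right`/`causal_upper_left` are assumptions for illustration; only the module path and the names listed in the docs stub below come from this PR.

```python
# Hedged sketch: illustrative shapes and the (seq_len_q, seq_len_kv) factory
# signature are assumptions; only the module path and function names come from
# this PR's docs stub (torch.nn.attention.bias).
import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right, causal_upper_left

batch, num_heads, head_dim = 4, 8, 64    # illustrative sizes
seq_len_q, seq_len_kv = 2, 10            # non-square: short query vs. longer KV cache

q = torch.randn(batch, num_heads, seq_len_q, head_dim)
k = torch.randn(batch, num_heads, seq_len_kv, head_dim)
v = torch.randn(batch, num_heads, seq_len_kv, head_dim)

# Lower-right causal masking: the mask is anchored at the bottom-right corner,
# so the last query token can attend to every key (autoregressive decoding).
lr_bias = causal_lower_right(seq_len_q, seq_len_kv)
out_lr = F.scaled_dot_product_attention(q, k, v, attn_mask=lr_bias)

# Upper-left causal masking: anchored at the top-left corner, which is what
# `is_causal=True` has historically meant.
ul_bias = causal_upper_left(seq_len_q, seq_len_kv)
out_ul = F.scaled_dot_product_attention(q, k, v, attn_mask=ul_bias)

# The two variants only disagree when seq_len_q != seq_len_kv.
print(torch.allclose(out_lr, out_ul))  # False for this non-square example
```

When `seq_len_q == seq_len_kv` both variants describe the same mask as `is_causal=True`, which is why the distinction only shows up for non-square attention biases.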
.. role:: hidden
    :class: hidden-section

torch.nn.attention.bias
========================

.. automodule:: torch.nn.attention.bias
.. currentmodule:: torch.nn.attention.bias

CausalBias
==========

.. autoclass:: CausalBias

.. autosummary::
    :toctree: generated
    :nosignatures:

    causal_lower_right
    causal_upper_left
    CausalVariant