mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
# Summary

This PR introduces a new Tensor subclass that is designed to be used with torch.nn.functional.scaled_dot_product_attention. Currently we have a boolean `is_causal` flag that allows users to do causal masking without the need to actually create the "realized" attention bias and pass it into sdpa. We originally added this flag since there is native support for it in both fused kernels we support. This provides a big performance gain (the kernels only need to iterate over ~0.5x the sequence), and for very large sequence lengths it can provide very large memory improvements.

The flag was introduced early on in the kernel development, and at the time it was implicitly meant to mean "upper_left" causal attention. This distinction only matters when the attention_bias is not square. For a more detailed breakdown see: https://github.com/pytorch/pytorch/issues/108108. The kernels' default behavior has since changed, largely due to the rise of autoregressive text generation, and unfortunately this would lead to a BC break. In the long term it may actually be beneficial to change the default meaning of `is_causal` to represent lower_right causal masking.

The larger theme though is laid out here: https://github.com/pytorch/pytorch/issues/110681. The thesis is that there is a lot of innovation in SDPA revolving around the attention_bias being used. This is the first of hopefully a few more attention_biases that we would like to add. The next interesting one would be `sliding_window`, which is used by the popular Mistral model family.

Results from benchmarking: I improved the meff_attention perf, hence the slightly decreased max perf.

```Shell
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
| Type    | Speedup            | batch_size | num_heads | q_seq_len | k_seq_len | embed_dim | dtype          | head_dim |
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
| Average | 1.2388050062214226 |            |           |           |           |           |                |          |
| Max     | 1.831672915579016  | 128        | 32        | 1024      | 2048      | 2048      | torch.bfloat16 | 64       |
| Min     | 0.9430534166730135 | 1          | 16        | 256       | 416       | 2048      | torch.bfloat16 | 128      |
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114823
Approved by: https://github.com/cpuhrsch
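
For readers unfamiliar with the "upper_left" versus "lower_right" distinction mentioned in the summary, here is a minimal pure-Python sketch of the two alignments. It is illustrative only and does not use the Tensor subclass added by this PR. The helper name `causal_mask` and its `alignment` argument are hypothetical, and the sketch assumes that "lower_right" means the diagonal is shifted by `kv_len - q_len` so that the last query can attend to every key.

```python
def causal_mask(q_len, kv_len, alignment):
    """Build a q_len x kv_len boolean mask (hypothetical helper).

    True means "query row q may attend to key column k".
    """
    if alignment == "upper_left":
        # Diagonal anchored at the top-left corner: query q sees keys 0..q.
        offset = 0
    elif alignment == "lower_right":
        # Diagonal anchored at the bottom-right corner (assumed definition):
        # the last query sees every key.
        offset = kv_len - q_len
    else:
        raise ValueError(f"unknown alignment: {alignment!r}")
    return [[k <= q + offset for k in range(kv_len)] for q in range(q_len)]


def show(mask):
    """Print the mask with 1 for attend and 0 for masked."""
    for row in mask:
        print(" ".join("1" if allowed else "0" for allowed in row))


if __name__ == "__main__":
    # Non-square case (2 queries, 4 keys): the two alignments differ.
    show(causal_mask(2, 4, "upper_left"))
    print()
    show(causal_mask(2, 4, "lower_right"))
```

When `q_len == kv_len` the offset is zero and both alignments produce the same mask, which is why the distinction only matters for a non-square attention bias.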
146 lines
3.8 KiB
ReStructuredText
.. PyTorch documentation master file, created by
   sphinx-quickstart on Fri Dec 23 13:31:47 2016.
   You can adapt this file completely to your liking, but it should at least
   contain the root `toctree` directive.

:github_url: https://github.com/pytorch/pytorch

PyTorch documentation
===================================

PyTorch is an optimized tensor library for deep learning using GPUs and CPUs.

Features described in this documentation are classified by release status:

  *Stable:* These features will be maintained long-term and there should generally
  be no major performance limitations or gaps in documentation.
  We also expect to maintain backwards compatibility (although
  breaking changes can happen and notice will be given one release ahead
  of time).

  *Beta:* These features are tagged as Beta because the API may change based on
  user feedback, because the performance needs to improve, or because
  coverage across operators is not yet complete. For Beta features, we are
  committing to seeing the feature through to the Stable classification.
  We are not, however, committing to backwards compatibility.

  *Prototype:* These features are typically not available as part of
  binary distributions like PyPI or Conda, except sometimes behind run-time
  flags, and are at an early stage for feedback and testing.

.. toctree::
   :glob:
   :maxdepth: 1
   :caption: Community

   community/*

.. toctree::
   :glob:
   :maxdepth: 1
   :caption: Developer Notes

   notes/*

.. toctree::
   :maxdepth: 1
   :caption: Language Bindings

   cpp_index
   Javadoc <https://pytorch.org/javadoc/>
   torch::deploy <deploy>

.. toctree::
   :glob:
   :maxdepth: 2
   :caption: Python API

   torch
   nn
   nn.functional
   tensors
   tensor_attributes
   tensor_view
   torch.amp <amp>
   torch.autograd <autograd>
   torch.library <library>
   cpu
   cuda
   torch.cuda.memory <torch_cuda_memory>
   mps
   torch.backends <backends>
   torch.export <export>
   torch.distributed <distributed>
   torch.distributed.algorithms.join <distributed.algorithms.join>
   torch.distributed.elastic <distributed.elastic>
   torch.distributed.fsdp <fsdp>
   torch.distributed.optim <distributed.optim>
   torch.distributed.tensor.parallel <distributed.tensor.parallel>
   torch.distributed.checkpoint <distributed.checkpoint>
   torch.distributions <distributions>
   torch.compiler <torch.compiler>
   torch.fft <fft>
   torch.func <func>
   futures
   fx
   holistic_trace_analysis
   torch.hub <hub>
   torch.jit <jit>
   torch.linalg <linalg>
   torch.monitor <monitor>
   torch.signal <signal>
   torch.special <special>
   torch.overrides
   torch.package <package>
   profiler
   nn.init
   nn.attention.bias
   onnx
   optim
   complex_numbers
   ddp_comm_hooks
   pipeline
   quantization
   rpc
   torch.random <random>
   masked
   torch.nested <nested>
   sparse
   storage
   torch.testing <testing>
   torch.utils <utils>
   torch.utils.benchmark <benchmark_utils>
   torch.utils.bottleneck <bottleneck>
   torch.utils.checkpoint <checkpoint>
   torch.utils.cpp_extension <cpp_extension>
   torch.utils.data <data>
   torch.utils.deterministic <deterministic>
   torch.utils.jit <jit_utils>
   torch.utils.dlpack <dlpack>
   torch.utils.mobile_optimizer <mobile_optimizer>
   torch.utils.model_zoo <model_zoo>
   torch.utils.tensorboard <tensorboard>
   type_info
   named_tensor
   name_inference
   torch.__config__ <config_mod>
   logging

.. toctree::
   :maxdepth: 1
   :caption: Libraries

   torchaudio <https://pytorch.org/audio/stable>
   TorchData <https://pytorch.org/data>
   TorchRec <https://pytorch.org/torchrec>
   TorchServe <https://pytorch.org/serve>
   torchtext <https://pytorch.org/text/stable>
   torchvision <https://pytorch.org/vision/stable>
   PyTorch on XLA Devices <https://pytorch.org/xla/>

Indices and tables
==================

* :ref:`genindex`
* :ref:`modindex`