Mention TF32 on related docs (#44690)

Summary:
cc: ptrblck

![image](https://user-images.githubusercontent.com/1032377/93168022-cbbfcb80-f6d6-11ea-8f6e-f2c8a15c5bea.png)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44690

Reviewed By: ngimel

Differential Revision: D23727921

Pulled By: mruberry

fbshipit-source-id: db7cc8e74cde09c13d6a57683129fd839863b914
Author: Xiang Gao
Date: 2020-09-16 19:12:25 -07:00
Committed by: Facebook GitHub Bot
Parent: 79108fc16c
Commit: e48201c5cf
7 changed files with 140 additions and 6 deletions

docs/source/backends.rst

@ -0,0 +1,87 @@
.. role:: hidden
:class: hidden-section
torch.backends
==============
`torch.backends` controls the behavior of various backends that PyTorch supports.
These backends include:
- ``torch.backends.cuda``
- ``torch.backends.cudnn``
- ``torch.backends.mkl``
- ``torch.backends.mkldnn``
- ``torch.backends.openmp``
torch.backends.cuda
^^^^^^^^^^^^^^^^^^^
.. autofunction:: torch.backends.cuda.is_built
.. attribute:: torch.backends.cuda.matmul.allow_tf32
A :class:`bool` that controls whether TensorFloat-32 tensor cores may be used in matrix
multiplications on Ampere or newer GPUs. See :ref:`tf32_on_ampere`.
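A minimal sketch of how this flag is typically toggled (assumes a CUDA build and an Ampere-or-newer GPU; on older hardware the flag has no effect):

```python
import torch

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

torch.backends.cuda.matmul.allow_tf32 = True   # matmuls may use TF32 tensor cores
approx = a @ b

torch.backends.cuda.matmul.allow_tf32 = False  # force full-precision FP32 matmuls
exact = a @ b
```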
.. attribute:: torch.backends.cuda.cufft_plan_cache
``cufft_plan_cache`` caches cuFFT plans so that repeated FFTs of the same configuration can reuse them.
.. attribute:: size
A read-only :class:`int` that shows the number of plans currently in the cuFFT plan cache.
.. attribute:: max_size
An :class:`int` that controls the capacity of the cuFFT plan cache.
.. method:: clear()
Clears the cuFFT plan cache.
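For illustration, a hedged sketch of inspecting and clearing the plan cache (assumes a CUDA device; the reported sizes will vary):

```python
import torch

x = torch.randn(64, 256, device="cuda")
torch.fft.rfft(x)  # running an FFT creates and caches a cuFFT plan

cache = torch.backends.cuda.cufft_plan_cache
print(cache.size)    # number of plans currently cached
cache.max_size = 16  # cap the cache capacity
cache.clear()        # drop all cached plans
```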
torch.backends.cudnn
^^^^^^^^^^^^^^^^^^^^
.. autofunction:: torch.backends.cudnn.version
.. autofunction:: torch.backends.cudnn.is_available
.. attribute:: torch.backends.cudnn.enabled
A :class:`bool` that controls whether cuDNN is enabled.
.. attribute:: torch.backends.cudnn.allow_tf32
A :class:`bool` that controls whether TensorFloat-32 tensor cores may be used in cuDNN
convolutions on Ampere or newer GPUs. See :ref:`tf32_on_ampere`.
.. attribute:: torch.backends.cudnn.deterministic
A :class:`bool` that, if True, causes cuDNN to only use deterministic convolution algorithms.
See also :func:`torch.is_deterministic` and :func:`torch.set_deterministic`.
.. attribute:: torch.backends.cudnn.benchmark
A :class:`bool` that, if True, causes cuDNN to benchmark multiple convolution algorithms
and select the fastest.
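Taken together, a minimal sketch of these cuDNN switches around a convolution (assumes a CUDA build with cuDNN available):

```python
import torch
import torch.nn.functional as F

torch.backends.cudnn.enabled = True
torch.backends.cudnn.allow_tf32 = True      # TF32 tensor cores for convolutions (Ampere+)
torch.backends.cudnn.benchmark = True       # autotune algorithms for fixed input shapes
torch.backends.cudnn.deterministic = False  # allow faster, nondeterministic algorithms

x = torch.randn(8, 3, 224, 224, device="cuda")
w = torch.randn(16, 3, 3, 3, device="cuda")
y = F.conv2d(x, w, padding=1)  # shape: (8, 16, 224, 224)
```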
torch.backends.mkl
^^^^^^^^^^^^^^^^^^
.. autofunction:: torch.backends.mkl.is_available
torch.backends.mkldnn
^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: torch.backends.mkldnn.is_available
torch.backends.openmp
^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: torch.backends.openmp.is_available


@ -37,6 +37,7 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs.
torch.autograd <autograd>
cuda
torch.cuda.amp <amp>
torch.backends <backends>
torch.distributed <distributed>
torch.distributions <distributions>
torch.fft <fft>


@ -106,6 +106,10 @@ factory_data_common_args = parse_kwargs("""
the pinned memory. Works only for CPU tensors. Default: ``False``.
""")
tf32_notes = {
"tf32_note": """This operator supports :ref:`TensorFloat32<tf32_on_ampere>`."""
}
add_docstr(torch.abs, r"""
abs(input, *, out=None) -> Tensor
@ -290,6 +294,8 @@ it will not be propagated.
For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha`
must be real numbers, otherwise they should be integers.
{tf32_note}
Args:
batch1 (Tensor): the first batch of matrices to be multiplied
batch2 (Tensor): the second batch of matrices to be multiplied
@ -309,7 +315,7 @@ Example::
tensor([[ 6.6311, 0.0503, 6.9768, -12.0362, -2.1653],
[ -4.8185, -1.4255, -6.6760, 8.9453, 2.5743],
[ -3.8202, 4.3691, 1.0943, -1.1109, 5.4730]])
""".format(**common_args))
""".format(**common_args, **tf32_notes))
add_docstr(torch.addcdiv, r"""
addcdiv(input, tensor1, tensor2, *, value=1, out=None) -> Tensor
@ -417,6 +423,8 @@ it will not be propagated.
For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and
:attr:`alpha` must be real numbers, otherwise they should be integers.
{tf32_note}
Args:
input (Tensor): matrix to be added
mat1 (Tensor): the first matrix to be multiplied
@ -435,7 +443,7 @@ Example::
>>> torch.addmm(M, mat1, mat2)
tensor([[-4.8716, 1.4671, -1.3746],
[ 0.7573, -3.9555, -2.8681]])
""".format(**common_args))
""".format(**common_args, **tf32_notes))
add_docstr(torch.addmv,
r"""
@ -833,6 +841,8 @@ it will not be propagated.
For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and
:attr:`alpha` must be real numbers, otherwise they should be integers.
{tf32_note}
Args:
input (Tensor): the tensor to be added
batch1 (Tensor): the first batch of matrices to be multiplied
@ -850,7 +860,7 @@ Example::
>>> batch2 = torch.randn(10, 4, 5)
>>> torch.baddbmm(M, batch1, batch2).size()
torch.Size([10, 3, 5])
""".format(**common_args))
""".format(**common_args, **tf32_notes))
add_docstr(torch.bernoulli,
r"""
@ -991,6 +1001,8 @@ If :attr:`input` is a :math:`(b \times n \times m)` tensor, :attr:`mat2` is a
.. math::
\text{out}_i = \text{input}_i \mathbin{@} \text{mat2}_i
""" + r"""
{tf32_note}
.. note:: This function does not :ref:`broadcast <broadcasting-semantics>`.
For broadcasting matrix products, see :func:`torch.matmul`.
@ -1012,7 +1024,7 @@ Example::
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])
""".format(**common_args))
""".format(**common_args, **tf32_notes))
add_docstr(torch.bitwise_and,
r"""
@ -4855,6 +4867,8 @@ If :attr:`input` is a :math:`(n \times m)` tensor, :attr:`mat2` is a
.. note:: This function does not :ref:`broadcast <broadcasting-semantics>`.
For broadcasting matrix products, see :func:`torch.matmul`.
{tf32_note}
Args:
input (Tensor): the first matrix to be multiplied
mat2 (Tensor): the second matrix to be multiplied
@ -4869,7 +4883,7 @@ Example::
>>> torch.mm(mat1, mat2)
tensor([[ 0.4851, 0.5037, -0.3633],
[-0.0760, -3.6705, 2.4784]])
""".format(**common_args))
""".format(**common_args, **tf32_notes))
add_docstr(torch.matmul,
r"""
@ -4896,6 +4910,8 @@ The behavior depends on the dimensionality of the tensors as follows:
:math:`(j \times 1 \times n \times m)` tensor and :attr:`other` is a :math:`(k \times m \times p)`
tensor, :attr:`out` will be an :math:`(j \times k \times n \times p)` tensor.
{tf32_note}
.. note::
The 1-dimensional dot product version of this function does not support an :attr:`out` parameter.
@ -4935,7 +4951,7 @@ Example::
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
""".format(**common_args))
""".format(**common_args, **tf32_notes))
add_docstr(torch.mode,
r"""


@ -46,6 +46,7 @@ else:
def version():
"""Returns the version of cuDNN"""
if not _init():
return None
return __cudnn_version
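A small usage sketch (the printed version number is only an example and depends on the installed cuDNN):

```python
import torch

if torch.backends.cudnn.is_available():
    print(torch.backends.cudnn.version())  # e.g. 8002; None if cuDNN fails to initialize
else:
    print("cuDNN not available")
```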


@ -21,6 +21,8 @@ conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> T
Applies a 1D convolution over an input signal composed of several input
planes.
This operator supports :ref:`TensorFloat32<tf32_on_ampere>`.
See :class:`~torch.nn.Conv1d` for details and output shape.
Note:
@ -57,6 +59,8 @@ conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> T
Applies a 2D convolution over an input image composed of several input
planes.
This operator supports :ref:`TensorFloat32<tf32_on_ampere>`.
See :class:`~torch.nn.Conv2d` for details and output shape.
Note:
@ -95,6 +99,8 @@ conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> T
Applies a 3D convolution over an input image composed of several input
planes.
This operator supports :ref:`TensorFloat32<tf32_on_ampere>`.
See :class:`~torch.nn.Conv3d` for details and output shape.
Note:
@ -131,6 +137,8 @@ conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0
Applies a 1D transposed convolution operator over an input signal
composed of several input planes, sometimes also called "deconvolution".
This operator supports :ref:`TensorFloat32<tf32_on_ampere>`.
See :class:`~torch.nn.ConvTranspose1d` for details and output shape.
Note:
@ -170,6 +178,8 @@ conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0
Applies a 2D transposed convolution operator over an input image
composed of several input planes, sometimes also called "deconvolution".
This operator supports :ref:`TensorFloat32<tf32_on_ampere>`.
See :class:`~torch.nn.ConvTranspose2d` for details and output shape.
Note:
@ -211,6 +221,8 @@ conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0
Applies a 3D transposed convolution operator over an input image
composed of several input planes, sometimes also called "deconvolution".
This operator supports :ref:`TensorFloat32<tf32_on_ampere>`.
See :class:`~torch.nn.ConvTranspose3d` for details and output shape.
Note:
@ -1660,6 +1672,8 @@ def linear(input, weight, bias=None):
r"""
Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.
This operator supports :ref:`TensorFloat32<tf32_on_ampere>`.
Shape:
- Input: :math:`(N, *, in\_features)` N is the batch size, `*` means any number of


@ -129,6 +129,8 @@ class Conv1d(_ConvNd):
:math:`N` is a batch size, :math:`C` denotes a number of channels,
:math:`L` is a length of signal sequence.
This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
* :attr:`stride` controls the stride for the cross-correlation, a single
number or a one-element tuple.
@ -275,6 +277,8 @@ class Conv2d(_ConvNd):
:math:`H` is a height of input planes in pixels, and :math:`W` is
width in pixels.
This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
* :attr:`stride` controls the stride for the cross-correlation, a single
number or a tuple.
@ -431,6 +435,8 @@ class Conv3d(_ConvNd):
where :math:`\star` is the valid 3D `cross-correlation`_ operator
This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
* :attr:`stride` controls the stride for the cross-correlation.
* :attr:`padding` controls the amount of implicit zero-paddings on both
@ -629,6 +635,8 @@ class ConvTranspose1d(_ConvTransposeNd):
It is also known as a fractionally-strided convolution or
a deconvolution (although it is not an actual deconvolution operation).
This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
* :attr:`stride` controls the stride for the cross-correlation.
* :attr:`padding` controls the amount of implicit zero-paddings on both
@ -681,6 +689,7 @@ class ConvTranspose1d(_ConvTransposeNd):
True``.
Please see the notes on :doc:`/notes/randomness` for background.
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
@ -762,6 +771,8 @@ class ConvTranspose2d(_ConvTransposeNd):
It is also known as a fractionally-strided convolution or
a deconvolution (although it is not an actual deconvolution operation).
This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
* :attr:`stride` controls the stride for the cross-correlation.
* :attr:`padding` controls the amount of implicit zero-paddings on both
@ -928,6 +939,8 @@ class ConvTranspose3d(_ConvTransposeNd):
It is also known as a fractionally-strided convolution or
a deconvolution (although it is not an actual deconvolution operation).
This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
* :attr:`stride` controls the stride for the cross-correlation.
* :attr:`padding` controls the amount of implicit zero-paddings on both


@ -34,6 +34,8 @@ class Identity(Module):
class Linear(Module):
r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
Args:
in_features: size of each input sample
out_features: size of each output sample
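To close the loop, a hedged end-to-end sketch tying the module docs back to the TF32 switch documented above (assumes an Ampere-or-newer GPU; on older hardware the flag is simply ignored):

```python
import torch
import torch.nn as nn

torch.backends.cuda.matmul.allow_tf32 = True  # nn.Linear's matmul may use TF32 on Ampere

layer = nn.Linear(in_features=128, out_features=64).cuda()
x = torch.randn(32, 128, device="cuda")
y = layer(x)  # shape: (32, 64)
```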