fix torch.set_float32_matmul_precision doc (#119620)
Fixes #119606, clarifies the explicitly stored number of mantissa bits in the doc.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119620
Approved by: https://github.com/eqy, https://github.com/malfet
This commit is contained in:
parent 88183923d2
commit 6cd82253ae
@@ -1015,24 +1015,24 @@ def set_float32_matmul_precision(precision: str) -> None:
     Supports three settings:
 
         * "highest", float32 matrix multiplications use the float32 datatype (24 mantissa
-          bits) for internal computations.
+          bits with 23 bits explicitly stored) for internal computations.
         * "high", float32 matrix multiplications either use the TensorFloat32 datatype (10
-          mantissa bits) or treat each float32 number as the sum of two bfloat16 numbers
-          (approximately 16 mantissa bits), if the appropriate fast matrix multiplication
+          mantissa bits explicitly stored) or treat each float32 number as the sum of two bfloat16 numbers
+          (approximately 16 mantissa bits with 14 bits explicitly stored), if the appropriate fast matrix multiplication
           algorithms are available. Otherwise float32 matrix multiplications are computed
           as if the precision is "highest". See below for more information on the bfloat16
           approach.
         * "medium", float32 matrix multiplications use the bfloat16 datatype (8 mantissa
-          bits) for internal computations, if a fast matrix multiplication algorithm
+          bits with 7 bits explicitly stored) for internal computations, if a fast matrix multiplication algorithm
           using that datatype internally is available. Otherwise float32
           matrix multiplications are computed as if the precision is "high".
 
     When using "high" precision, float32 multiplications may use a bfloat16-based algorithm
     that is more complicated than simply truncating to some smaller number mantissa bits
-    (e.g. 10 for TensorFloat32, 8 for bfloat16). Refer to [Henry2019]_ for a complete
+    (e.g. 10 for TensorFloat32, 7 for bfloat16 explicitly stored). Refer to [Henry2019]_ for a complete
     description of this algorithm. To briefly explain here, the first step is to realize
     that we can perfectly encode a single float32 number as the sum of three bfloat16
-    numbers (because float32 has 24 mantissa bits while bfloat16 has 8, and both have the
+    numbers (because float32 has 23 mantissa bits while bfloat16 has 7 explicitly stored, and both have the
     same number of exponent bits). This means that the product of two float32 numbers can
     be exactly given by the sum of nine products of bfloat16 numbers. We can then trade
     accuracy for speed by dropping some of these products. The "high" precision algorithm
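For readers skimming the patch, a minimal usage sketch of the setting this docstring documents (the tensor names and sizes below are illustrative, not part of the change):

import torch

# "high" lets float32 matmuls use TensorFloat32 or the bfloat16-sum algorithm
# described in the docstring above, trading a few mantissa bits for speed.
torch.set_float32_matmul_precision("high")

a = torch.randn(1024, 1024)
b = torch.randn(1024, 1024)
c = a @ b  # may run with reduced internal precision on supported hardware

print(torch.get_float32_matmul_precision())  # "high"

# Restore the default: full float32 internal precision (24 mantissa bits, 23 stored).
torch.set_float32_matmul_precision("highest")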
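The docstring's claim that a float32 number can be perfectly encoded as the sum of three bfloat16 numbers can be checked numerically. Below is a small sketch; the helper name split_to_bfloat16_triple is ours, not a PyTorch API, and exact reconstruction assumes values away from the extremes of the exponent range:

import torch

def split_to_bfloat16_triple(x: torch.Tensor):
    # Hypothetical helper, not a PyTorch API: decompose a float32 tensor into
    # three bfloat16 tensors whose float32 sum reproduces the original exactly
    # (float32 stores 23 mantissa bits, bfloat16 stores 7, same exponent width).
    hi = x.to(torch.bfloat16)
    mid = (x - hi.to(torch.float32)).to(torch.bfloat16)
    lo = (x - hi.to(torch.float32) - mid.to(torch.float32)).to(torch.bfloat16)
    return hi, mid, lo

x = torch.randn(8)  # float32 by default
hi, mid, lo = split_to_bfloat16_triple(x)
recon = hi.to(torch.float32) + mid.to(torch.float32) + lo.to(torch.float32)
print(torch.equal(x, recon))  # expected True for typical values

Since each factor splits into three bfloat16 terms, the exact product of two float32 numbers expands into nine bfloat16 products; the "high" setting drops some of them to trade accuracy for speed, as the docstring describes.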