Reproducibility
===============

Completely reproducible results are not guaranteed across PyTorch releases,
individual commits, or different platforms. Furthermore, results may not be
reproducible between CPU and GPU executions, even when using identical seeds.

However, there are some steps you can take to limit the number of sources of
nondeterministic behavior for a specific platform, device, and PyTorch release.
First, you can control sources of randomness that can cause multiple executions
of your application to behave differently. Second, you can configure PyTorch
to avoid using nondeterministic algorithms for some operations, so that multiple
calls to those operations, given the same inputs, will produce the same result.

.. warning::

    Deterministic operations are often slower than nondeterministic operations, so
    single-run performance may decrease for your model. However, determinism may
    save time in development by facilitating experimentation, debugging, and
    regression testing.

Controlling sources of randomness
.................................

PyTorch random number generator
-------------------------------
You can use :meth:`torch.manual_seed()` to seed the RNG for all devices (both
CPU and CUDA)::

    import torch
    torch.manual_seed(0)

Random number generators in other libraries
-------------------------------------------
If you or any of the libraries you are using rely on NumPy, you can seed the global
NumPy RNG with::

    import numpy as np
    np.random.seed(0)

However, some applications and libraries may use NumPy Random Generator objects,
not the global RNG
(`<https://numpy.org/doc/stable/reference/random/generator.html>`_), and those will
need to be seeded consistently as well.
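
For example, a Generator created with :code:`numpy.random.default_rng` takes an explicit
seed at construction time (a minimal sketch; the variable name :code:`rng` is only
illustrative)::

    import numpy as np

    # Seed a standalone Generator instead of the global NumPy RNG
    rng = np.random.default_rng(0)
    sample = rng.random(3)  # same values on every run with the same seed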

If you are using any other libraries that use random number generators, refer to
the documentation for those libraries to see how to set consistent seeds for them.
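
For instance, Python's built-in :code:`random` module keeps its own global state, separate
from PyTorch and NumPy, and can be seeded in the same way (a minimal sketch)::

    import random

    # Seed the standard library's global RNG
    random.seed(0)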

CUDA convolution benchmarking
-----------------------------
The cuDNN library, used by CUDA convolution operations, can be a source of nondeterminism
across multiple executions of an application. When a cuDNN convolution is called with a
new set of size parameters, an optional feature can run multiple convolution algorithms,
benchmarking them to find the fastest one. Then, the fastest algorithm will be used
consistently during the rest of the process for the corresponding set of size parameters.
Due to benchmarking noise and different hardware, the benchmark may select different
algorithms on subsequent runs, even on the same machine.

Disabling the benchmarking feature with :code:`torch.backends.cudnn.benchmark = False`
causes cuDNN to deterministically select an algorithm, possibly at the cost of reduced
performance.

However, if you do not need reproducibility across multiple executions of your application,
then performance might improve if the benchmarking feature is enabled with
:code:`torch.backends.cudnn.benchmark = True`.
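
Either way, the flag is a plain module-level attribute and is typically set once near the
start of the program, before the convolutions of interest run (a minimal sketch)::

    import torch

    # Prefer run-to-run reproducibility over single-run speed
    torch.backends.cudnn.benchmark = False

    # Or prefer speed when reproducibility across runs is not needed:
    # torch.backends.cudnn.benchmark = True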

Note that this setting is different from the :code:`torch.backends.cudnn.deterministic`
setting discussed below.

Avoiding nondeterministic algorithms
....................................
:meth:`torch.set_deterministic` lets you configure PyTorch to use deterministic
algorithms instead of nondeterministic ones where available, and to throw an error
if an operation is known to be nondeterministic (and without a deterministic
alternative).

Please check the documentation for :meth:`torch.set_deterministic()` for a full
list of affected operations. If an operation does not act correctly according to
the documentation, or if you need a deterministic implementation of an operation
that does not have one, please submit an issue:
`<https://github.com/pytorch/pytorch/issues?q=label:%22topic:%20determinism%22>`_

For example, running the nondeterministic CUDA implementation of :meth:`torch.Tensor.index_add_`
will throw an error::

    >>> import torch
    >>> torch.set_deterministic(True)
    >>> torch.randn(2, 2).cuda().index_add_(0, torch.tensor([0, 1]), torch.randn(2, 2))
    Traceback (most recent call last):
    File "<stdin>", line 1, in <module>
    RuntimeError: index_add_cuda_ does not have a deterministic implementation, but you set
    'torch.set_deterministic(True)'. ...

When :meth:`torch.bmm` is called with sparse-dense CUDA tensors, it typically uses a
nondeterministic algorithm, but when the deterministic flag is turned on, its alternate
deterministic implementation will be used::

    >>> import torch
    >>> torch.set_deterministic(True)
    >>> torch.bmm(torch.randn(2, 2, 2).to_sparse().cuda(), torch.randn(2, 2, 2).cuda())
    tensor([[[ 1.1900, -2.3409],
             [ 0.4796,  0.8003]],

            [[ 0.1509,  1.8027],
             [ 0.0333, -1.1444]]], device='cuda:0')

Furthermore, if you are using CUDA tensors, and your CUDA version is 10.2 or greater, you
should set the environment variable ``CUBLAS_WORKSPACE_CONFIG`` according to CUDA documentation:
`<https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility>`_
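
As a sketch, the variable can be exported in the shell before launching the program, or set
from Python before any CUDA work is done; the value ``:4096:8`` used here is one of the
values listed in the cuBLAS documentation linked above::

    import os

    # Set before the first CUDA call in the process; ":16:8" is the other documented value.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"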

CUDA convolution determinism
----------------------------
While disabling CUDA convolution benchmarking (discussed above) ensures that CUDA
selects the same algorithm each time an application is run, that algorithm itself
may be nondeterministic, unless either :code:`torch.set_deterministic(True)` or
:code:`torch.backends.cudnn.deterministic = True` is set. The latter setting controls
only this behavior, unlike :meth:`torch.set_deterministic`, which will make other
PyTorch operations behave deterministically, too.
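
In code, the convolution-only flag looks like this (a minimal sketch)::

    import torch

    # Restrict only cuDNN convolutions to deterministic algorithms
    torch.backends.cudnn.deterministic = True

    # Alternatively, request determinism for all supported operations:
    # torch.set_deterministic(True)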

CUDA RNN and LSTM
-----------------
In some versions of CUDA, RNNs and LSTM networks may have nondeterministic behavior.
See :meth:`torch.nn.RNN` and :meth:`torch.nn.LSTM` for details and workarounds.