Frequently Asked Questions
==========================

My model reports "cuda runtime error(2): out of memory"
--------------------------------------------------------

As the error message suggests, you have run out of memory on your
GPU. Since we often deal with large amounts of data in PyTorch,
small mistakes can rapidly cause your program to use up all of your
GPU memory; fortunately, the fixes in these cases are often simple.
Here are a few common things to check:

**Don't accumulate history across your training loop.**
By default, computations involving variables that require gradients
will keep history. This means that you should avoid using such
variables in computations which will live beyond your training loops,
e.g., when tracking statistics. Instead, you should detach the variable
or access its underlying data.

Sometimes, it can be non-obvious when differentiable variables can
occur. Consider the following training loop (abridged from `source
<https://discuss.pytorch.org/t/high-memory-usage-while-training/162>`_):

.. code-block:: python

    total_loss = 0
    for i in range(10000):
        optimizer.zero_grad()
        output = model(input)
        loss = criterion(output)
        loss.backward()
        optimizer.step()
        total_loss += loss

Here, ``total_loss`` is accumulating history across your training loop,
since ``loss`` is a differentiable tensor that carries autograd history.
Writing ``total_loss += loss`` therefore keeps every iteration's graph
alive. You can fix this by writing ``total_loss += float(loss)`` instead,
which accumulates only a plain Python number.
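
For example, a corrected version of the loop above might look like this
(a minimal sketch; ``model``, ``criterion``, ``optimizer`` and ``input``
are assumed to be defined as in your own training script):

.. code-block:: python

    total_loss = 0.0
    for i in range(10000):
        optimizer.zero_grad()
        output = model(input)
        loss = criterion(output)
        loss.backward()
        optimizer.step()
        # float(loss) (or loss.item()) drops the autograd history,
        # so only a plain Python number is accumulated.
        total_loss += float(loss)

The same idea applies to any statistic you track across iterations:
accumulate ``tensor.item()`` or ``tensor.detach()`` rather than the
tensor that is still attached to the graph.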

Other instances of this problem:
`1 <https://discuss.pytorch.org/t/resolved-gpu-out-of-memory-error-with-batch-size-1/3719>`_.

**Don't hold onto tensors and variables you don't need.**
If you assign a Tensor or Variable to a local, Python will not
deallocate it until the local goes out of scope. You can free
this reference by using ``del x``. Similarly, if you assign
a Tensor or Variable to a member variable of an object, it will
not be deallocated until the object goes out of scope. You will
get the best memory usage if you don't hold onto temporaries
you don't need.

The scopes of locals can be larger than you expect. For example:

.. code-block:: python

    for i in range(5):
        intermediate = f(input[i])
        result += g(intermediate)
    output = h(result)
    return output

Here, ``intermediate`` remains live even while ``h`` is executing,
because its scope extends past the end of the loop. To free it
earlier, you should ``del intermediate`` when you are done with it.
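
Concretely, the loop above could release the temporary as soon as it is
no longer needed (a sketch reusing the placeholder names ``f``, ``g``,
``h``, ``input`` and ``result`` from the snippet above):

.. code-block:: python

    for i in range(5):
        intermediate = f(input[i])
        result += g(intermediate)
        # Drop the only reference to the large temporary so its memory
        # can be reclaimed before h(result) runs.
        del intermediate
    output = h(result)
    return output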

**Don't run RNNs on sequences that are too large.**
The amount of memory required to backpropagate through an RNN scales
linearly with the length of the input sequence; thus, you will run out
of memory if you try to feed an RNN a sequence that is too long.

The technical term for this phenomenon is `backpropagation through time
<https://en.wikipedia.org/wiki/Backpropagation_through_time>`_,
and there are plenty of references for how to implement truncated
BPTT, including in the `word language model <https://github.com/pytorch/examples/tree/master/word_language_model>`_ example; truncation is handled by the
``repackage`` function as described in
`this forum post <https://discuss.pytorch.org/t/help-clarifying-repackage-hidden-in-word-language-model/226>`_.
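
The core idea behind that function is simply to detach the hidden state
between BPTT chunks so that the graph from earlier chunks can be freed.
A minimal sketch (the helper in the example repository may differ in
its details):

.. code-block:: python

    import torch

    def repackage_hidden(h):
        """Detach hidden states from their history so backprop stops here."""
        if isinstance(h, torch.Tensor):
            return h.detach()
        # LSTMs return a tuple of (h_n, c_n); handle nested containers.
        return tuple(repackage_hidden(v) for v in h)

Calling this on the hidden state at the start of each truncated chunk
keeps backpropagation from reaching all the way back to the beginning
of the sequence.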

**Don't use linear layers that are too large.**
A linear layer ``nn.Linear(m, n)`` uses :math:`O(nm)` memory: that is to say,
the memory requirements of the weights
scale quadratically with the number of features. It is very easy
to `blow through your memory <https://github.com/pytorch/pytorch/issues/958>`_
this way (and remember that you will need at least twice the size of the
weights, since you also need to store the gradients).
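
For a sense of scale, the snippet below counts the parameters of a large
square linear layer (the sizes are chosen purely for illustration):

.. code-block:: python

    import torch.nn as nn

    layer = nn.Linear(10000, 10000)
    num_params = sum(p.numel() for p in layer.parameters())
    print(num_params)                  # 100,010,000 (weights + bias)
    print(num_params * 4 / 1024 ** 2)  # ~381 MiB in float32, before gradients

With gradients (and any optimizer state) on top, a single layer like
this can easily consume a gigabyte or more of GPU memory.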

My GPU memory isn't freed properly
-------------------------------------------------------
PyTorch uses a caching memory allocator to speed up memory allocations. As a
result, the values shown in ``nvidia-smi`` usually don't reflect the true
memory usage. See :ref:`cuda-memory-management` for more details about GPU
memory management.
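
To see what PyTorch itself is holding, you can query the caching
allocator directly rather than relying on ``nvidia-smi`` (a small
sketch; by default these calls report the current CUDA device):

.. code-block:: python

    import torch

    # Memory occupied by live tensors.
    print(torch.cuda.memory_allocated())
    # Memory held in the allocator's cache (roughly what nvidia-smi sees).
    print(torch.cuda.memory_reserved())

    # Release cached blocks back to the driver; this does not free live tensors.
    torch.cuda.empty_cache()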

If your GPU memory isn't freed even after Python quits, it is very likely that
some Python subprocesses are still alive. You may find them via
``ps -elf | grep python`` and manually kill them with ``kill -9 [pid]``.

My data loader workers return identical random numbers
-------------------------------------------------------
You are likely using other libraries to generate random numbers in the dataset.
For example, NumPy's RNG is duplicated when worker subprocesses are started via
``fork``. See :class:`torch.utils.data.DataLoader`'s documentation for how to
properly set up random seeds in workers.
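
One common pattern is to re-seed each worker's libraries in a
``worker_init_fn`` (a sketch; ``dataset`` stands in for your own dataset
object):

.. code-block:: python

    import random

    import numpy as np
    import torch
    from torch.utils.data import DataLoader

    def seed_worker(worker_id):
        # torch.initial_seed() already differs per worker; derive the
        # NumPy and random seeds from it so they differ as well.
        worker_seed = torch.initial_seed() % 2 ** 32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    loader = DataLoader(dataset, num_workers=4, worker_init_fn=seed_worker)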