.. _torch.compiler_get_started:

Getting Started
===============

Before you read this section, make sure to read the :ref:`torch.compiler_overview`.

Let's start by looking at a simple ``torch.compile`` example that demonstrates
how to use ``torch.compile`` for inference. The example uses ``torch.cos()``
and ``torch.sin()``, which are examples of pointwise operators because they
operate element by element on a tensor. This example might not show
significant performance gains but should help you form an intuitive
understanding of how you can use ``torch.compile`` in your own programs.

.. note::
   To run this script, you need to have at least one GPU on your machine.
   If you do not have a GPU, you can remove the ``.to(device="cuda:0")`` code
   in the snippet below and it will run on CPU.

.. code:: python

   import torch

   def fn(x):
       a = torch.cos(x)
       b = torch.sin(a)
       return b

   new_fn = torch.compile(fn, backend="inductor")
   input_tensor = torch.randn(10000).to(device="cuda:0")
   a = new_fn(input_tensor)

A more famous pointwise operator you might want to use would be something like
``torch.relu()``. Pointwise ops in eager mode are suboptimal because each one
would need to read a tensor from memory, make some changes, and then write
those changes back. The single most important optimization that inductor
performs is fusion. In the example above we can turn 2 reads (``x``, ``a``)
and 2 writes (``a``, ``b``) into 1 read (``x``) and 1 write (``b``), which is
crucial especially for newer GPUs where the bottleneck is memory bandwidth
(how quickly you can send data to a GPU) rather than compute (how quickly your
GPU can crunch floating point operations).
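
To see the effect of fusion on your own hardware, you can time the eager and
compiled versions of the function above. Below is a minimal timing sketch; it
assumes a CUDA device is available, and the exact numbers will vary by GPU:

.. code-block:: python

   import torch

   def fn(x):
       return torch.sin(torch.cos(x))

   compiled_fn = torch.compile(fn, backend="inductor")
   x = torch.randn(10_000_000, device="cuda")

   # Warm up: the first compiled call triggers compilation.
   compiled_fn(x)

   for label, f in [("eager", fn), ("compiled", compiled_fn)]:
       start = torch.cuda.Event(enable_timing=True)
       end = torch.cuda.Event(enable_timing=True)
       start.record()
       for _ in range(100):
           f(x)
       end.record()
       torch.cuda.synchronize()
       print(f"{label}: {start.elapsed_time(end) / 100:.3f} ms per call")
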
Another major optimization that inductor provides is automatic support for
CUDA graphs. CUDA graphs help eliminate the overhead of launching individual
kernels from a Python program, which is especially relevant for newer GPUs.
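
You can ask Inductor to use CUDA graphs through the ``mode="reduce-overhead"``
option of ``torch.compile``. A small sketch (assuming a CUDA device; whether
graphs are actually captured depends on your PyTorch version and model):

.. code-block:: python

   import torch

   def fn(x):
       return torch.sin(torch.cos(x))

   # "reduce-overhead" asks Inductor to capture and replay CUDA graphs
   # where possible, amortizing per-kernel launch overhead.
   fn_graphs = torch.compile(fn, mode="reduce-overhead")
   out = fn_graphs(torch.randn(10000, device="cuda"))
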
TorchDynamo supports many different backends, but TorchInductor specifically works
by generating `Triton <https://github.com/openai/triton>`__ kernels. Let's save
our example above into a file called ``example.py``. We can inspect the
generated Triton kernel code by running ``TORCH_COMPILE_DEBUG=1 python example.py``.
As the script executes, you should see ``DEBUG`` messages printed to the
terminal. Closer to the end of the log, you should see a path to a folder
that contains ``torchinductor_<your_username>``. In that folder, you can find
the ``output_code.py`` file that contains the generated kernel code, similar to
the following:

.. code-block:: python

   @pointwise(size_hints=[16384], filename=__file__, triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]})
   @triton.jit
   def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
       xnumel = 10000
       xoffset = tl.program_id(0) * XBLOCK
       xindex = xoffset + tl.arange(0, XBLOCK)[:]
       xmask = xindex < xnumel
       x0 = xindex
       tmp0 = tl.load(in_ptr0 + (x0), xmask, other=0.0)
       tmp1 = tl.cos(tmp0)
       tmp2 = tl.sin(tmp1)
       tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp2, xmask)

.. note:: The above code snippet is an example. Depending on your hardware,
   you might see different code generated.
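
If you prefer to enable the debug trace from inside the script rather than on
the command line, you can set the environment variable before importing
``torch``. A sketch, assuming the same ``fn`` as before:

.. code-block:: python

   import os
   # Equivalent to prefixing the command with TORCH_COMPILE_DEBUG=1;
   # it must be set before torch is imported.
   os.environ["TORCH_COMPILE_DEBUG"] = "1"

   import torch

   def fn(x):
       return torch.sin(torch.cos(x))

   torch.compile(fn, backend="inductor")(torch.randn(10000, device="cuda"))
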
You can verify that fusing the ``cos`` and ``sin`` did actually occur,
because both operations appear within a single Triton kernel and the
temporary variables are held in registers with very fast access.

Read more on Triton's performance
`here <https://openai.com/blog/triton/>`__. Because the code is written
in Python, it's fairly easy to understand even if you have not written all that
many CUDA kernels.

Next, let's try a real model like resnet50 from the PyTorch hub:

.. code-block:: python

   import torch
   model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
   opt_model = torch.compile(model, backend="inductor")
   opt_model(torch.randn(1,3,64,64))

Inductor is not the only available backend; you can run
``torch.compiler.list_backends()`` in a REPL to see all the available backends.
Try the ``cudagraphs`` backend next as inspiration.
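
A quick sketch of both (the exact backend list depends on your installation,
and the ``cudagraphs`` backend expects CUDA tensors):

.. code-block:: python

   import torch

   print(torch.compiler.list_backends())  # e.g. ['cudagraphs', 'inductor', ...]

   model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True).cuda()
   opt_model = torch.compile(model, backend="cudagraphs")
   opt_model(torch.randn(1, 3, 64, 64, device="cuda"))
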
Using a pretrained model
~~~~~~~~~~~~~~~~~~~~~~~~

PyTorch users frequently leverage pretrained models from
`transformers <https://github.com/huggingface/transformers>`__ or
`TIMM <https://github.com/rwightman/pytorch-image-models>`__, and one of
the design goals of TorchDynamo and TorchInductor is to work out of the box
with any model that people would like to author.

Let's download a pretrained model directly from the HuggingFace hub and optimize
it:

.. code-block:: python

   import torch
   from transformers import BertTokenizer, BertModel
   # Copy pasted from here https://huggingface.co/bert-base-uncased
   tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
   model = BertModel.from_pretrained("bert-base-uncased").to(device="cuda:0")
   model = torch.compile(model, backend="inductor") # This is the only line of code that we changed
   text = "Replace me by any text you'd like."
   encoded_input = tokenizer(text, return_tensors='pt').to(device="cuda:0")
   output = model(**encoded_input)

If you remove the ``to(device="cuda:0")`` calls from the model and
``encoded_input``, then TorchInductor will generate C++ kernels that are
optimized for running on your CPU. You can inspect both the Triton and C++
kernels for BERT. They are more complex than the trigonometry example we
tried above, but you can similarly skim through them and see if you
understand how PyTorch works.
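
A sketch of the CPU-only variant (same model and tokenizer as above, with
nothing moved to a GPU):

.. code-block:: python

   import torch
   from transformers import BertTokenizer, BertModel

   tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
   # Without any .to(device="cuda:0") calls, Inductor emits C++ kernels.
   model = torch.compile(BertModel.from_pretrained("bert-base-uncased"))
   encoded_input = tokenizer("Replace me by any text you'd like.", return_tensors='pt')
   output = model(**encoded_input)
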
Similarly, let's try out a TIMM example:

.. code-block:: python

   import timm
   import torch
   model = timm.create_model('resnext101_32x8d', pretrained=True, num_classes=2)
   opt_model = torch.compile(model, backend="inductor")
   opt_model(torch.randn(64,3,7,7))

Next Steps
~~~~~~~~~~

In this section, we have reviewed a few inference examples and developed a
basic understanding of how ``torch.compile`` works. Here is what you can check
out next:

- `torch.compile tutorial on training <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
- :ref:`torch.compiler_api`
- :ref:`torchdynamo_fine_grain_tracing`