This is a follow-up to PR #95506 that runs all the Triton kernels in a compiled module individually, as suggested by Horace.
Here are the steps:
1. Run the model as usual with a benchmark script and with TORCHINDUCTOR_BENCHMARK_KERNEL enabled, e.g.
```
TORCHINDUCTOR_BENCHMARK_KERNEL=1 python benchmarks/dynamo/torchbench.py --backend inductor --amp --performance --dashboard --only resnet18 --disable-cudagraphs --training
```
2. From the output we will see 3 lines like
```
Compiled module path: /tmp/torchinductor_shunting/rs/crsuc6zrt3y6lktz33jjqgpkuahya56xj6sentyiz7iv4pjud43j.py
```
That's because there is one graph module each for the forward, backward, and optimizer; each graph module produces one such line pointing to its compiled module.
3. We can run the compiled module directly. Without extra arguments, it keeps the previous behavior and runs the `call` function, which does what the original graph module does but more efficiently. If we add the `-k` argument, it instead runs a benchmark for each individual kernel in the file.
```
python /tmp/torchinductor_shunting/rs/crsuc6zrt3y6lktz33jjqgpkuahya56xj6sentyiz7iv4pjud43j.py -k
```
Example output:
<img width="430" alt="Screenshot 2023-03-01 at 4 51 06 PM" src="https://user-images.githubusercontent.com/52589240/222302996-814a85be-472b-463c-9e85-39d2c9d20e1a.png">
Note: I use the first 10 characters of the hash to identify each kernel since
1. the hash is easier to get in the code :)
2. a name like `triton__3` only makes sense within a compiled module, while a hash identifies a kernel even without specifying the compiled module (assuming we keep enough bytes of the hash)
If we find a Triton kernel with a hash like c226iuf2wi performing poorly, we can look it up in the original compiled module file. This works because each compiled Triton kernel is annotated with a comment containing the full hash.
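For reference, the per-kernel timing itself boils down to something like the sketch below (a minimal illustration using `triton.testing.do_bench`; the real generated harness differs, and `bench`/`bench_fn` are names made up for this example):
```python
import triton

def bench(label: str, bench_fn) -> None:
    # bench_fn is any zero-argument callable that launches one compiled kernel.
    res = triton.testing.do_bench(bench_fn)
    # do_bench returns either a mean time or a tuple of percentiles depending on
    # the Triton version; take the first value in the tuple case.
    ms = res[0] if isinstance(res, (tuple, list)) else res
    print(f"{label}: {ms:.3f} ms")
```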
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95845
Approved by: https://github.com/Chillee
Summary: The AOT mode currently works for the CPP backend. When it is turned on, Inductor compiles the model code into a .so file with `aot_inductor_entry` as the entry function. If the AOT compilation fails, Inductor will fail explicitly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94822
Approved by: https://github.com/jansel
Inductor implementations of collectives/wait must match the eager implementations in `_functional_collectives` in terms of how they interact with the `_register_tensor_work` API. If they do, then splitting a collective-wait pair so that one half is in a compiled graph should work fine.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95893
Approved by: https://github.com/kumpera
A PR to generate benchmark code for individual Triton kernels. With the saved compiled kernels, we can explore improving autotuning directly. This can potentially speed up our iteration and separate those concerns from the upstream components that generate the compiled module.
Since I'm still ramping up on Inductor, I'll write down what I learned here so people can correct me if I'm wrong. In Inductor, the `WrapperCodeGen` class is used to generate the compiled module for CUDA (i.e. Triton). Here is an example compiled module for a toy model like `def f(x): return sin(x) + cos(x)`: https://gist.github.com/shunting314/c6ed9f571919e3b414166f1696dcc61b . A compiled module contains the following parts:
- various triton kernels
- a wrapper (a method named `call`; the name is hardcoded) that calls the Triton kernels, and potentially ATen kernels, to efficiently do the same work as the original FX graph being compiled by Inductor
- some utility code that generates random inputs and runs the wrapper
The Triton kernels in the compiled module are annotated with decorators like `pointwise`, which are used for autotuning.
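To make the structure concrete, here is a hand-written, simplified analogue of such a compiled module for the toy `sin(x) + cos(x)` model (only a sketch: the real generated file uses Inductor's internal helpers such as `async_compile` and its autotuning decorators):
```python
import torch
import triton
import triton.language as tl

# a Triton kernel (the generated module would wrap this in async_compile.triton
# and decorate it for autotuning)
@triton.jit
def triton__0(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    xmask = xindex < xnumel
    x = tl.load(in_ptr0 + xindex, mask=xmask)
    tl.store(out_ptr0 + xindex, tl.sin(x) + tl.cos(x), mask=xmask)

# the wrapper: does the same work as the original FX graph
def call(args):
    (arg0_1,) = args
    buf0 = torch.empty_like(arg0_1)
    xnumel = arg0_1.numel()
    triton__0[(triton.cdiv(xnumel, 1024),)](arg0_1, buf0, xnumel, XBLOCK=1024)
    return (buf0,)

# utility code: generate a random input and run the wrapper
if __name__ == "__main__":
    x = torch.randn(8192, device="cuda")
    (out,) = call([x])
    print(torch.allclose(out, torch.sin(x) + torch.cos(x), atol=1e-4))
```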
This PR adds a config option; enabling it makes Inductor print the path of the compiled module. It can be controlled via an environment variable as well.
The path to each compiled Triton kernel is added as a comment in the compiled module, e.g.
```
# kernel path: /tmp/torchinductor_shunting/gn/cgn6x3mqoltu7q77gjnu2elwfupinsvcovqwibc6fhsoiy34tvga.py
triton__0 = async_compile.triton('''
import triton
import triton.language as tl
...
""")
````
Example command:
```
TORCHINDUCTOR_OUTPUT_COMPILED_MODULE_PATH=1 TORCHINDUCTOR_BENCHMARK_KERNEL=1 python benchmarks/dynamo/huggingface.py --backend inductor --amp --performance --training --dashboard --only AlbertForMaskedLM --disable-cudagraphs
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95506
Approved by: https://github.com/Chillee
Summary:
Update the hashing method for the `ChoiceCaller` class.
`TritonTemplateCaller` objects will now be hashed to:
`{name}-({BLOCK_M}, {BLOCK_N}, {BLOCK_K})-{num_stages}-{num_warps}-{code_hash}`
for example:
`triton_mm-(64, 32, 32)-4-8-cptlntwzcl2gaaofd2oabdwhaqv4ox3lluvbuxitjfhhpz6cyl4o`
`ExternKernelCaller` objects will now be hashed to:
`{name}-{kwargs.keys()[0]}={kwargs.vals()[0]}-...-{code_hash}`
for example:
`addmm-alpha=1-beta=1-c4xxd3iocu4yt6z4udrlqnumays7q6mfnfd3qprh4fxgsvyhqdkf`
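Roughly, the keys are assembled like the sketch below (function names and the loop over kwargs are just for exposition; the actual values live on the `ChoiceCaller` subclasses):
```python
def triton_template_key(name, BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps, code_hash):
    # e.g. "triton_mm-(64, 32, 32)-4-8-cptlntwz..."
    return f"{name}-({BLOCK_M}, {BLOCK_N}, {BLOCK_K})-{num_stages}-{num_warps}-{code_hash}"

def extern_kernel_key(name, kwargs, code_hash):
    # e.g. "addmm-alpha=1-beta=1-c4xxd3io..."
    kv = "-".join(f"{k}={v}" for k, v in kwargs.items())
    return f"{name}-{kv}-{code_hash}"
```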
Test Plan: sandcastle
Differential Revision: D43285470
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94853
Approved by: https://github.com/jansel, https://github.com/bertmaher
This adds `torch.cuda._DeviceGuard`, which is a stripped-down version of `torch.cuda.device` with lower overhead. To do this, it only accepts an `int` as the device, so we don't need to call `_get_device_index`, and it is implemented with a new C++ helper, `torch._C._cuda_exchangeDevice`, that allows `_DeviceGuard.__enter__` to be just a single function call. On my machine, I see a drop from 3.8 us of overhead to 0.94 us with this simple benchmark:
```python
def set_device():
    with torch.cuda.device(0):
        pass
%timeit set_device()
```
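For comparison, the new guard is used the same way; this is a sketch based on the description above (`_DeviceGuard` takes a plain `int` device index):
```python
def set_device_guard():
    with torch.cuda._DeviceGuard(0):
        pass
%timeit set_device_guard()
```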
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91045
Approved by: https://github.com/ngimel, https://github.com/anijain2305
Fixes https://github.com/pytorch/torchdynamo/issues/1717, https://github.com/pytorch/torchdynamo/issues/1990
<s>TODO: add test with multiple devices, figure out extra context initialization</s>
Problems:
<s>It still initializes context on 0-th device that it shouldn't, I'll take a look where that happens and fix before landing</s>
It adds a Python device context manager that is absurdly slow, taking ~2.5 us (it should be nanoseconds). That's not a problem for real models, because it will be called just once, but it is an inconvenience for microbenchmarking; we should make that context manager more performant (won't fix in this PR).
It can still have bugs for graphs that run on multiple devices: buffers may be incorrectly shared between multiple devices by memory reuse. If that happens, it will need to be solved separately.
Generated code:
```
def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    with torch.cuda.device(1):
        buf0 = empty_strided((4, ), (1, ), device='cuda', dtype=torch.float32)
        stream1 = get_cuda_stream(1)
        triton_fused_div_0.run(arg0_1, arg1_1, buf0, 4, grid=grid(4), stream=stream1)
        del arg0_1
        del arg1_1
        return (buf0, )
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90934
Approved by: https://github.com/wconstab
## Pitch
Change the input args type from `std::tuple` to `std::vector` to reduce compilation time.
## Description
`std::tie()` takes quite a long time to compile when the number of input args grows.
For example, for a graph from the `PegasusForConditionalGeneration` model with 318 input args, compiling the `std::tie` for the args takes about 10s. By changing to `std::vector`, the compilation time of the arg assignment is reduced to less than 1s.
### Code before:
```cpp
at::Tensor call_0(std::tuple<at::Tensor&, at::Tensor&> args) {
    at::Tensor arg0_1, arg1_1;
    std::tie(arg0_1, arg1_1) = args;
    ...
    return buf0;
}
```
### Code after:
```cpp
at::Tensor call_0(std::vector<at::Tensor> args) {
    at::Tensor arg0_1, arg1_1;
    arg0_1 = args[0];
    arg1_1 = args[1];
    ...
    return buf0;
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90754
Approved by: https://github.com/jgong5, https://github.com/jansel
For reductions, the code strings in the codegen stage and the execution stage differ due to `\`.
- The code string obtained from `code.getvalue()` (`code` is an `IndentedBuffer`) in the codegen stage:
```
#pragma omp declare reduction(argmax : struct IndexValue_1 :\
omp_out.value = omp_in.value < omp_out.value ? omp_out.value : omp_in.value,\
omp_out.index = omp_in.value < omp_out.value ? omp_out.index : omp_in.index)\
initializer(omp_priv = {0, -std::numeric_limits<float>::infinity()})
```
- The code string loaded during execution (the trailing `\` gets consumed, so the lines are joined):
```
#pragma omp declare reduction(argmax : struct IndexValue_1 : omp_out.value = omp_in.value < omp_out.value ? omp_out.value : omp_in.value, omp_out.index = omp_in.value < omp_out.value ? omp_out.index : omp_in.index) initializer(omp_priv = {0, -std::numeric_limits<float>::infinity()})
```
Thus we can't get the same hash value for these two pieces of code.
This PR adds a function to escape the backslash in the codegen stage so that the two representations hash to the same value.
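Below is a minimal, self-contained reconstruction of the mismatch and of the escaping idea (my own illustration, not the actual Inductor code; `eval` of a triple-quoted literal stands in for re-loading the generated source):
```python
# The codegen-stage string contains a literal backslash followed by a newline.
raw = "#pragma omp declare reduction(argmax : struct IndexValue_1 :\\\n    omp_out.value = omp_in.value)"

# Re-loading the unescaped string through a Python literal consumes "\<newline>"
# as a line continuation, so the loaded string (and hence its hash) differs.
reloaded = eval(f'"""{raw}"""')
print(raw == reloaded)  # False

# Escaping the backslash at codegen time keeps the two representations in sync.
escaped = raw.replace("\\", "\\\\")
print(raw == eval(f'"""{escaped}"""'))  # True
```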
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88561
Approved by: https://github.com/jgong5, https://github.com/jansel, https://github.com/desertfire
This PR adds an option `config.profiler_mark_wrapper_call` (disabled by default) to mark the duration of the wrapper call in the PyTorch profiler. This makes it easy to identify the duration and the start/end of each wrapper call in the profiler output.
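Conceptually, the marker amounts to wrapping the generated `call` in a profiler range, as in the sketch below (the range name and the stub body are illustrative, not the exact generated code):
```python
import torch
from torch.profiler import profile, record_function

def call(args):
    # The generated wrapper would launch its kernels inside this range.
    with record_function("inductor_wrapper_call"):
        return args

with profile() as prof:
    call([torch.randn(4)])
print(prof.key_averages().table(sort_by="cpu_time_total"))
```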
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89674
Approved by: https://github.com/jansel
- Propagates origin FX nodes through inlining during lowering
- Concatenates op names into the kernel name (see the sketch below)
- Adds a config to cap the number of ops in the kernel name so names don't get too long
Caveats:
- The ordering in the name may not match the order in which the ops are executed in the kernel
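The naming scheme is roughly what the sketch below does (the helper name and the cap value are made up for illustration; the result resembles generated names like `triton_fused_div_0` seen in the wrapper code above):
```python
def fused_kernel_name(op_names, index, max_ops=4):
    # e.g. (["div"], 0)                   -> "triton_fused_div_0"
    #      (["add", "mul", "sigmoid"], 3) -> "triton_fused_add_mul_sigmoid_3"
    return "triton_fused_" + "_".join(op_names[:max_ops]) + f"_{index}"
```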
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88624
Approved by: https://github.com/anijain2305, https://github.com/jansel
Inductor internally models any `size=1` dimension as having `stride=0` to simplify indexing formulas (sympy will remove these terms from the expression).
This caused a bug in our generated stride assert in detectron2_maskrcnn_r_50_fpn, where we asserted the wrong stride for a size==1 dimension.
This fixes that bug and moves the size/stride assert logic to C++, which should be a small perf gain.
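As a small illustration of why the stride of a size==1 dimension is arbitrary (its index is always 0, so the stride term drops out of the addressing formula):
```python
import torch

base = torch.arange(12.)
a = base.as_strided((3, 1, 4), (4, 1, 1))    # stride 1 on the size-1 dim
b = base.as_strided((3, 1, 4), (4, 999, 1))  # any stride describes the same layout
print(torch.equal(a, b))  # True: asserting one particular stride here is too strict
```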
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87028
Approved by: https://github.com/anijain2305