Summary:
Recent runs of the bandwidth profiler did not behave as expected. I observed a few issues and fixed them in this diff:
1. The return value of the DebugAutotuner class.
2. Profiling results show very large overhead: DebugAutotuner.run() reports a benchmark time of around 45ms, while CachingAutotuner.run() reports around 0.45ms. The `_find_names` and `re.match` calls account for the ~45ms: P1669186358. After commenting out `_find_names` and `re.match`, the benchmark time becomes consistent with non-profiling mode: P1669185589.
3. Introduce a variable `bandwidth_info` to control the path taken in DebugAutotuner.run(); during benchmarking for configuration selection, `bandwidth_info` should be turned off (see the sketch below).
After applying this diff, the profiling issues mentioned above are fixed: P1669273172
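For clarity, here is a minimal sketch of the intended control flow; everything other than the class names and the `bandwidth_info` flag is a simplified stand-in, not the actual Inductor implementation:
```python
# Minimal sketch, assuming a boolean `bandwidth_info` attribute gates the
# expensive reporting path; stand-in classes, not the real Inductor code.
class CachingAutotuner:
    def run(self, *args, **kwargs):
        # placeholder for the real kernel launch / benchmarking path
        return 0.45  # ms


class DebugAutotuner(CachingAutotuner):
    def __init__(self, bandwidth_info=True):
        self.bandwidth_info = bandwidth_info

    def run(self, *args, **kwargs):
        result = super().run(*args, **kwargs)
        if self.bandwidth_info:
            # Only pay for the expensive _find_names / re.match lookups when
            # bandwidth reporting is requested, never while benchmarking
            # configs for autotuning.
            pass  # report_bandwidth(result)
        return result
```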
Test Plan:
```
TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_PROFILE=1 TORCHINDUCTOR_PROFILE_OUTPUT=~/tmp/profile.txt TORCH_LOGS='+inductor,+schedule,output_code' TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 TORCHINDUCTOR_BENCHMARK_KERNEL=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 CUDA_VISIBLE_DEVICES=5 buck run mode/{opt,inplace} scripts/wwei6/triton_examples:test_mat 2>&1 | tee profiling-5.log
```
To disable the ATen backend, just add TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS="TRITON" to the command.
Differential Revision: D64883079
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139607
Approved by: https://github.com/chenyang78
# Add workspace buffer support for Triton templates
## Summary
Adds support for templates to allocate and use temporary workspace buffers
## Key Changes
- Add `WorkspaceArg` support in Triton template system
- Automatic workspace allocation/deallocation around kernel execution
- Zero-initialization support for workspace buffers
- Seamless integration with existing tensor management
## Example Usage
```python
def generate(self, ...):
    workspace_arg = WorkspaceArg(
        count=1024 * 1024,  # 1MB workspace
        zero_fill=True,     # Zero-initialized
    )
    return TritonTemplateCaller(..., workspace_arg=workspace_arg)
```
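For context, a hedged sketch of what the generated wrapper code is expected to do around the kernel launch; the buffer name, dtype, and the commented launch call are illustrative, not the exact generated code:
```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Allocate the workspace right before the kernel launch; zero_fill=True
# corresponds to a zero-initialized buffer.
workspace = torch.zeros(1024 * 1024, dtype=torch.uint8, device=device)
# triton_mm.run(a, b, out, workspace, grid=grid, stream=stream)  # illustrative launch
del workspace  # freed once the kernel no longer needs it
```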
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138050
Approved by: https://github.com/Chillee, https://github.com/eellison
Summary:
We skip save_gpu_kernel if the kernel has already been saved.
This gives us a more accurate Triton profiling result. The
following traces show before/after the change when benchmarking a
trivial addmm:
Before: https://github.com/user-attachments/assets/5aea05ef-6ef0-464c-8da9-17b31c97b43a (screenshot of the trace)
After: https://github.com/user-attachments/assets/488b7d4f-268f-41cf-8553-cb16ceeae118 (screenshot of the trace)
We can see that before the change, the benchmarking includes two parts:
(1) the overhead of our triton_heuristics call, which includes the save/get and the (expensive) hash computation, and
(2) the actual computation of the Triton kernel.
Part (1) accounts for >50% of the time, which causes kernel selection during profiling to choose aten kernels over Triton kernels.
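A minimal sketch of the kind of guard this adds; the method and attribute names are assumptions, not the exact CachingAutotuner code:
```python
class AutotunerSketch:
    """Illustrative stand-in for the autotuner; not the real Inductor class."""

    def __init__(self):
        self.gpu_kernel_saved = False

    def save_gpu_kernel(self, grid, stream, launcher):
        if self.gpu_kernel_saved:
            # Already saved: skip the save/get round trip and the expensive
            # hash computation so benchmarking measures only the kernel.
            return
        # ... serialize kernel metadata for later retrieval ...
        self.gpu_kernel_saved = True
```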
Test Plan:
Existing OSS CI
python test/inductor/test_cuda_cpp_wrapper.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137073
Approved by: https://github.com/desertfire
Move benchmarking out of `torch._inductor.runtime.runtime_utils` and into `torch._inductor.runtime.benchmarking`, and prefer this path over directly accessing Triton's benchmarking utilities.
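A hedged usage sketch of the new path; the `benchmarker.benchmark_gpu` entry point is my assumption about the module's surface and should be checked against the actual code:
```python
import torch
from torch._inductor.runtime.benchmarking import benchmarker  # assumed entry point

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

# Preferred over calling triton.testing.do_bench directly.
ms = benchmarker.benchmark_gpu(lambda: a @ b)
print(f"matmul: {ms:.3f} ms")
```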
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132827
Approved by: https://github.com/eellison
Automated fixes to put imports that are only used in type hints into TYPE_CHECKING blocks. This also enables the RUFF TCH rules, which will automatically move imports in and out of TYPE_CHECKING blocks as needed in the future. This makes the initial PyTorch import faster and reduces cyclic dependencies.
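A small, generic example of the pattern (the helper function is hypothetical):
```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only needed for annotations; never imported at runtime, which keeps
    # module import cheap and avoids import cycles.
    from collections.abc import Sequence


def first_positive(values: Sequence[int]) -> int | None:  # hypothetical helper
    return next((v for v in values if v > 0), None)
```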
Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127688
Approved by: https://github.com/XuehaiPan, https://github.com/ezyang, https://github.com/malfet
This is to prevent the import from being removed as an unused import. What's annoying is that the check doesn't run consistently: lintrunner doesn't warn me on this PR even without the comment, but it does warn on other PRs.
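For illustration, the kind of comment in question, on an import kept only for its side effects (the exact import in the PR may differ):
```python
# The trailing comment keeps linters/autofixers from deleting an import that
# exists only for its side effects.
import torch._inductor.async_compile  # noqa: F401
```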
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127545
Approved by: https://github.com/masnesral
By moving AsyncCompile to its own file, we can import codecache without running the side effects of AsyncCompile. This will be important for AOTAutogradCaching, where we want to share some implementation details with codecache.py without spawning new processes.
To conservatively maintain the same behavior elsewhere, I've added an import of torch._inductor.async_compile everywhere we import codecache (except in autograd_cache.py, where the explicit goal is to not do this).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127235
Approved by: https://github.com/aorenste, https://github.com/oulgen, https://github.com/masnesral
This PR adds the autotuning infrastructure for CPU. It generalizes and extends `BenchmarkRequest` with CPU support and a C++ module loader. A `do_bench_cpu` util function is added for benchmarking functions on CPU; it performs warmup runs and returns the median time from multiple trials.
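A hedged sketch of what a `do_bench_cpu`-style helper does (warmup, multiple timed trials, median); the real utility's signature and defaults may differ:
```python
import statistics
import time


def do_bench_cpu_sketch(fn, warmup=5, trials=20):
    # Warm up to amortize one-time costs (allocations, lazy init, caches).
    for _ in range(warmup):
        fn()
    times_ms = []
    for _ in range(trials):
        start = time.perf_counter()
        fn()
        times_ms.append((time.perf_counter() - start) * 1e3)
    # Median is robust to occasional outliers from OS scheduling noise.
    return statistics.median(times_ms)


# Example: do_bench_cpu_sketch(lambda: sum(range(100_000)))
```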
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125159
Approved by: https://github.com/jansel
Enable nonzero workspace and Cutlass StreamK for Inductor Cutlass GEMM ops.
This is a simpler rewrite of my original version of #119005, using @peterbell10's workspace allocation mechanism from #117992.
Test Plan:
- Additional unit test in test_cutlass_backend.py which specifically tests StreamK GEMM with workspace requirement
- CI
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125406
Approved by: https://github.com/jansel
If a kernel does not return in a reasonable amount of time during autotuning, it can delay Inductor compilation significantly. This change introduces soft/hard kill timeouts and a mechanism to kill kernels being profiled in subprocesses if they take too long.
Correspondingly, a few new config options are introduced within _inductor/config.py, all of them with inline docs.
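A hedged sketch of the soft/hard timeout idea using plain multiprocessing; the actual config option names and subprocess machinery in Inductor differ:
```python
import multiprocessing as mp


def _benchmark_candidate(q):
    # Stand-in for profiling one autotuning candidate in a subprocess.
    q.put(1.23)  # pretend timing in ms


def benchmark_with_timeouts(soft_timeout_s=60.0, hard_timeout_s=10.0):
    q = mp.Queue()
    p = mp.Process(target=_benchmark_candidate, args=(q,))
    p.start()
    p.join(soft_timeout_s)
    if p.is_alive():
        p.terminate()           # soft kill (SIGTERM)
        p.join(hard_timeout_s)
        if p.is_alive():
            p.kill()            # hard kill (SIGKILL)
            p.join()
        return float("inf")     # mark the candidate as unusable
    return q.get()
```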
Test Plan:
Existing tests within test_max_autotune.py and test_cutlass_backend.py cover the new codepaths.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123932
Approved by: https://github.com/jansel
ghstack dependencies: #121497, #123930
Adds a ruff lint rule to ban raising raw exceptions. Most of these should at the very least be a RuntimeError, ValueError, TypeError, or some other more specific error. There are hundreds of instances of these bad exception types already in the codebase, so I have noqa'd most of them. Hopefully this error code will get committers to rethink what exception type they should raise when they submit a PR.
I also encourage people to gradually go and fix all the existing noqas that have been added so they can be removed over time and our exception typing can be improved.
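An illustration of the rule's intent (the function is hypothetical):
```python
def set_block_size(n: int) -> None:
    if n <= 0:
        # Flagged by the new lint rule:
        # raise Exception("bad block size")
        # Preferred: a specific exception type.
        raise ValueError(f"block size must be positive, got {n}")
```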
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124570
Approved by: https://github.com/ezyang
Internal users reported that they hit failures with max-autotune when tensors are not on device 0. It turns out that we may take tensors on, say, device 6 and run kernels on them on device 0.
This PR enforces that we do benchmarking for max-autotune on the correct device.
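A hedged sketch of the intent, benchmarking on the device that actually holds the inputs rather than on device 0; the helper name is illustrative:
```python
import torch


def benchmark_on_input_device(fn, *tensors):
    device = tensors[0].device  # e.g. cuda:6
    with torch.cuda.device(device):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn(*tensors)
        end.record()
        torch.cuda.synchronize(device)
    return start.elapsed_time(end)  # milliseconds
```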
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122479
Approved by: https://github.com/xintwfb, https://github.com/Chillee
Before this PR, we were not precompiling Triton templates in parallel; compilation would occur during benchmarking.
Triton benchmarking templates were emitted as:
```
@triton.jit
def triton_mm(arg_A, arg_B, out_ptr0):
```
In order to precompile, we need to give the full kernel specification, as we do when we emit the template during final output code generation.
```
@triton_heuristics.template(
num_stages=3,
num_warps=8,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'cdeecfeccd31ad7810f96b5752194b1c2406d0a81e39a6ca09c8ee150baae183'},
)
@triton.jit
def triton_mm(arg_A, arg_B, out_ptr0):
```
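With the full specification available up front, the compile step for each template choice can be launched ahead of benchmarking; a hedged sketch of that idea, where `choices` and `precompile()` are illustrative names rather than the exact Inductor API:
```python
from concurrent.futures import ThreadPoolExecutor


def precompile_choices(choices, max_workers=8):
    # Kick off compilation of every template choice in parallel, then wait,
    # so benchmarking later measures only kernel runtime, not compile time.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(choice.precompile) for choice in choices]
    for f in futures:
        f.result()  # surface any compilation errors early
```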
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121998
Approved by: https://github.com/jansel
ghstack dependencies: #121996, #120275, #121997