mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary:
Using `CUDA_KERNEL_ASSERT_PRINTF` inside kernels lets us log invalid values to the console, which can in turn be used to surface _hopefully_ clearer error messages.
This does increase the number of registers needed for the values being logged. (I confirmed by diffing PTX that there is no other impact relative to using `__assert_fail`.)
To avoid causing perf bottlenecks, this change adds a compile-time switch that enables more verbose errors in some of the common kernels that cause device-side assertions (DSAs). There is also a Buck config that can be used to set this switch more conveniently.
## Alternatives considered
I considered making the behavior of `CUDA_KERNEL_ASSERT_PRINTF` itself controllable via a compile-time macro instead of writing another wrapper for it. However, there are kernels where the extra register pressure is not as severe, and in those cases it is useful to have the more informative error messages by default.
Test Plan:
## Simple Python Driver:
```
# scatter_errors.py
import torch


def main() -> None:
    a = torch.rand(128, device="cuda:0")
    idx = torch.randint(0, 128, (100,), device="cuda:0")
    idx[0] = 9999  # deliberately out of bounds for the 128-element input
    b = torch.scatter(a, 0, idx, 555.0)
    print(b)


if __name__ == "__main__":
    main()
```
When running normally via:
```
$ buck2 run @//mode/opt :scatter_errors
```
we see the following DSA message:
```
fbcode/caffe2/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:410: operator(): block: [0,0,0], thread: [0,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
```
Running via:
```
$ buck2 run @//mode/opt -c fbcode.c10_enable_verbose_assert=1 :scatter_errors
```
however produces:
```
[CUDA_KERNEL_ASSERT] fbcode/caffe2/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:410: operator(): block: [0,0,0], thread: [0,0,0]: Assertion failed: `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"`: Expected 0 <= idx_dim < index_size (128), but got idx_dim = 9999
```
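To make the condition in the verbose message concrete, here is a minimal host-side sketch in plain Python of the same bounds check. The helper name `check_scatter_index` is hypothetical and is not part of PyTorch or of this change; the real check runs inside the CUDA kernel, and this sketch only mirrors the wording of the message above.

```python
# Hypothetical host-side mirror of the kernel's bounds check; not a PyTorch API.
from typing import Sequence


def check_scatter_index(index: Sequence[int], index_size: int) -> None:
    """Raise IndexError for the first entry outside [0, index_size)."""
    for position, idx_dim in enumerate(index):
        if not (0 <= idx_dim < index_size):
            raise IndexError(
                f"Expected 0 <= idx_dim < index_size ({index_size}), "
                f"but got idx_dim = {idx_dim} (at position {position})"
            )


if __name__ == "__main__":
    # Same kind of failure as the driver above: one index far outside a 128-element input.
    try:
        check_scatter_index([9999, 5, 17], index_size=128)
    except IndexError as err:
        print(err)
```

With a tensor index, one could pass `idx.tolist()` to such a helper. That forces a device-to-host copy, so a check like this is better suited to debugging than to hot paths.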
Differential Revision: D85185987
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166171
Approved by: https://github.com/ngimel
| | | |
|---|---|---|
| core | | |
| cpu/vec | | |
| macros | | |
| util | | |
| BUCK.oss | | |
| BUILD.bazel | | |
| build.bzl | | |
| CMakeLists.txt | | |
| ovrsource_defs.bzl | | |
| version.h.in | | |