Summary:
Relands D69965761 / https://github.com/pytorch/pytorch/pull/147583
Before this PR, calling a triton kernel would look like:
```py
kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0)
```
where the `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg:
```py
kernel.run(a, b, xnumel, stream=stream0)
```
instead now the grid computation is included in the kernel launcher, with something like:
```py
def launcher(in_ptr0, out_ptr0, xnumel, stream):
grid_0 = ((xnumel + 1023) >> 10)
grid_1 = 1
grid_2 = 1
runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel)
```
This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`.
It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid.
This unification allows this PR to be a net deletion of code.
Differential [disconnected] Revision: D70471332
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148305
Approved by: https://github.com/shunting314, https://github.com/eellison
Summary:
Relands D69965761 / https://github.com/pytorch/pytorch/pull/147583
Before this PR, calling a triton kernel would look like:
```py
kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0)
```
where the `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg:
```py
kernel.run(a, b, xnumel, stream=stream0)
```
instead now the grid computation is included in the kernel launcher, with something like:
```py
def launcher(in_ptr0, out_ptr0, xnumel, stream):
grid_0 = ((xnumel + 1023) >> 10)
grid_1 = 1
grid_2 = 1
runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel)
```
This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`.
It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid.
This unification allows this PR to be a net deletion of code.
Differential Revision: D70471332
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148305
Approved by: https://github.com/shunting314, https://github.com/eellison
Before this PR, calling a triton kernel would look like:
```py
kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0)
```
where the `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg:
```py
kernel.run(a, b, xnumel, stream=stream0)
```
instead now the grid computation is included in the kernel launcher, with something like:
```py
def launcher(in_ptr0, out_ptr0, xnumel, stream):
grid_0 = ((xnumel + 1023) >> 10)
grid_1 = 1
grid_2 = 1
runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel)
```
This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`.
It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid.
This unification allows this PR to be a net deletion of code.
Note the attached diff contains some minor fbcode-only changes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147583
Approved by: https://github.com/eellison, https://github.com/shunting314
Currently, there are compilation warnings as below, which are resolved after the fix
```
/tmp/torchinductor_root/c7t6qm4gf35cxkk5jywa5booovl5n6ivzwdbbs5og7rdemqtgrzh/caoefkofe5jrkuaoch4lfpjwtodlcy4savxgzsxqldkcdof7ifh7.cpp: In function ‘ihipModuleSymbol_t* loadKernel(std::string, const string&, uint32_t, const std::optional<std::__cxx11::basic_string<char> >&)’:
/tmp/torchinductor_root/c7t6qm4gf35cxkk5jywa5booovl5n6ivzwdbbs5og7rdemqtgrzh/caoefkofe5jrkuaoch4lfpjwtodlcy4savxgzsxqldkcdof7ifh7.cpp:482:25: warning: ignoring returned value of type ‘hipError_t’, declared with attribute nodiscard [-Wunused-result]
482 | hipDrvGetErrorString(code, &msg); \
| ~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~
/tmp/torchinductor_root/c7t6qm4gf35cxkk5jywa5booovl5n6ivzwdbbs5og7rdemqtgrzh/caoefkofe5jrkuaoch4lfpjwtodlcy4savxgzsxqldkcdof7ifh7.cpp:519:5: note: in expansion of macro ‘CUDA_DRIVER_CHECK’
519 | CUDA_DRIVER_CHECK(hipModuleLoad(&mod, filePath.c_str()));
| ^~~~~~~~~~~~~~~~~
In file included from /opt/rocm/include/hip/hip_runtime.h:70,
from /pytorch/torch/include/torch/csrc/inductor/aoti_runtime/device_utils.h:14,
from /pytorch/torch/include/torch/csrc/inductor/aoti_runtime/model.h:17,
from /pytorch/torch/include/torch/csrc/inductor/aoti_runtime/model_container.h:13,
from /tmp/torchinductor_root/c7t6qm4gf35cxkk5jywa5booovl5n6ivzwdbbs5og7rdemqtgrzh/caoefkofe5jrkuaoch4lfpjwtodlcy4savxgzsxqldkcdof7ifh7.cpp:4:
/opt/rocm/include/hip/hip_runtime_api.h:2369:12: note: in call to ‘hipError_t hipDrvGetErrorString(hipError_t, const char**)’, declared here
2369 | hipError_t hipDrvGetErrorString(hipError_t hipError, const char** errorString);
| ^~~~~~~~~~~~~~~~~~~~
In file included from /opt/rocm/include/hip/hip_runtime.h:70,
from /pytorch/torch/include/torch/csrc/inductor/aoti_runtime/device_utils.h:14,
from /pytorch/torch/include/torch/csrc/inductor/aoti_runtime/model.h:17,
from /pytorch/torch/include/torch/csrc/inductor/aoti_runtime/model_container.h:13,
from /tmp/torchinductor_root/c7t6qm4gf35cxkk5jywa5booovl5n6ivzwdbbs5og7rdemqtgrzh/caoefkofe5jrkuaoch4lfpjwtodlcy4savxgzsxqldkcdof7ifh7.cpp:4:
/opt/rocm/include/hip/hip_runtime_api.h:399:3: note: ‘hipError_t’ declared here
399 | } hipError_t;
| ^~~~~~~~~~
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137626
Approved by: https://github.com/ColinPeppler, https://github.com/chenyang78
Summary: We need to redefine RE_PYTORCH_PREPROCESSOR here since in hipify_torch, it will apply positive lookbehind (?<=\W) and lookahead (?=\W) to the pattern to avoid matching keyword at the beginning and end of code line. However, this can happen in codegen, which will cause the pattern to not match.
Test Plan:
```
buck2 run //caffe2/test/inductor:test_cpp_wrapper_hipify
```
```
File changed: fbcode//caffe2/test/inductor/test_cpp_wrapper_hipify.py
Buck UI: https://www.internalfb.com/buck2/395155fa-b2dc-4892-8c71-74e52c65fa2f
Note: Using experimental modern dice
Network: Up: 0B Down: 0B (reSessionID-8fcfc520-755c-48f9-bacc-507c62f59231)
Jobs completed: 10947. Time elapsed: 0.5s.
Cache hits: 0%. Commands: 2 (cached: 0, remote: 0, local: 2)
BUILD SUCCEEDED
/data/users/zhuoran/fbsource/buck-out/v2/gen/fbcode/15b7034708b669be/caffe2/test/inductor/__test_cpp_wrapper_hipify__/test_cpp_wrapper_hipify#link-tree/torch/_utils_internal.py:282: NCCL_DEBUG env var is set to None
/data/users/zhuoran/fbsource/buck-out/v2/gen/fbcode/15b7034708b669be/caffe2/test/inductor/__test_cpp_wrapper_hipify__/test_cpp_wrapper_hipify#link-tree/torch/_utils_internal.py:300: NCCL_DEBUG is forced to WARN from None
test_hipify_aoti_driver_header (caffe2.test.inductor.test_cpp_wrapper_hipify.TestCppWrapperHipify) ... ok
test_hipify_basic_declaration (caffe2.test.inductor.test_cpp_wrapper_hipify.TestCppWrapperHipify) ... ok
test_hipify_cross_platform (caffe2.test.inductor.test_cpp_wrapper_hipify.TestCppWrapperHipify) ... ok
----------------------------------------------------------------------
Ran 3 tests in 0.262s
OK
```
e2e test:
```
TORCH_LOGS="output_code,graph_code" buck2 run mode/{opt,amd-gpu,inplace} -c fbcode.triton_backend=amd -c fbcode.enable_gpu_sections=true //aiplatform/modelstore/model_generation/gpu_lowering_service:gpu_lowering_cli -- --model_input_path="ads_storage_fblearner/tree/user/facebook/fblearner/predictor/936383960/0/gpu_lowering/input.merge" --model_output_path="ads_storage_fblearner/tree/user/facebook/fblearner/predictor/936383960/0/gpu_lowering/mi300_inductor_output.merge" --lowering_backend AOT_INDUCTOR --is_ads_model False --aot_inductor_lowering_settings_json='{"use_scripting":true,"preset_lowerer":"standalone_hstu_cint;disable_new_lowering_weights;disable_dper_passes:passes=fuse_parallel_linear_no_weight_change","precision":4,"output_precision":4, "remove_unexpected_type_cast":false, "sample_input_tile_factor":32}' 2>&1 | tee local_benchmark_log.txt
```
Differential Revision: D58705216
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128912
Approved by: https://github.com/desertfire