Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00.
Summary: Relands D69965761 / https://github.com/pytorch/pytorch/pull/147583

Before this PR, calling a Triton kernel would look like:

```py
kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0)
```

where `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg:

```py
kernel.run(a, b, xnumel, stream=stream0)
```

Instead, the grid computation is now included in the kernel launcher, with something like:

```py
def launcher(in_ptr0, out_ptr0, xnumel, stream):
    grid_0 = ((xnumel + 1023) >> 10)
    grid_1 = 1
    grid_2 = 1
    runner(grid_0, grid_1, grid_2, stream, function,
           metadata, None, launch_enter_hook, launch_exit_hook,
           in_ptr0, out_ptr0, xnumel)
```

This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`.

It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, the C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid. This unification allows this PR to be a net deletion of code.

Differential Revision: D70471332

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148305
Approved by: https://github.com/shunting314, https://github.com/eellison
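The grid change described in the summary can be sketched in plain Python. This is a minimal illustration under stated assumptions, not Inductor's implementation. `old_style_grid`, `make_launcher`, and the `XBLOCK` metadata key are hypothetical names, and no real Triton launch happens. The sketch also assumes a 1D grid with a power-of-two block size, so that ceiling division can be written as an add-and-shift, as in the `((xnumel + 1023) >> 10)` line.

```python
# Hypothetical sketch (not Inductor's actual code): contrast a call-time
# grid closure with a launcher whose grid arithmetic is generated per config.


def old_style_grid(xnumel):
    """Old style (sketch): return a closure that is evaluated at launch time.

    The closure looks up the block size in a metadata dict on every call.
    """
    def grid_fn(meta):
        xblock = meta["XBLOCK"]
        return ((xnumel + xblock - 1) // xblock, 1, 1)  # ceil(xnumel / xblock)
    return grid_fn


def make_launcher(xblock):
    """New style (sketch): generate a launcher specialized for one XBLOCK.

    The block size is baked into the generated source as constants, so the
    launch path needs no dict lookup and no extra closure call.
    """
    if xblock <= 0 or xblock & (xblock - 1):
        raise ValueError("sketch assumes a power-of-two XBLOCK")
    shift = xblock.bit_length() - 1  # log2(xblock) for a power of two
    src = (
        "def launcher(xnumel):\n"
        f"    grid_0 = ((xnumel + {xblock - 1}) >> {shift})\n"
        "    grid_1 = 1\n"
        "    grid_2 = 1\n"
        "    return grid_0, grid_1, grid_2\n"
    )
    namespace = {}
    exec(src, namespace)  # compile the specialized launcher once per config
    launcher = namespace["launcher"]
    launcher.source = src  # keep the generated text around for inspection
    return launcher


if __name__ == "__main__":
    launch_1024 = make_launcher(1024)
    print(launch_1024.source)
    print(old_style_grid(5000)({"XBLOCK": 1024}))
    print(launch_1024(5000))
```

With `xblock=1024`, the generated line is `grid_0 = ((xnumel + 1023) >> 10)`, and both styles return `(5, 1, 1)` for `xnumel=5000`. The design point is that the generated launcher works only with constants. The closure, by contrast, pays for a dict lookup and an extra function call on every launch.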
| Name |
|---|
| `_strobelight` |
| `_sympy` |
| `backcompat` |
| `benchmark` |
| `bottleneck` |
| `data` |
| `hipify` |
| `jit` |
| `model_dump` |
| `serialization` |
| `tensorboard` |
| `viz` |
| `__init__.py` |
| `_appending_byte_serializer.py` |
| `_backport_slots.py` |
| `_config_module.py` |
| `_config_typing.pyi` |
| `_content_store.py` |
| `_contextlib.py` |
| `_cpp_embed_headers.py` |
| `_cpp_extension_versioner.py` |
| `_cxx_pytree.py` |
| `_device.py` |
| `_exposed_in.py` |
| `_filelock.py` |
| `_foreach_utils.py` |
| `_freeze.py` |
| `_functools.py` |
| `_get_clean_triton.py` |
| `_import_utils.py` |
| `_mode_utils.py` |
| `_ordered_set.py` |
| `_python_dispatch.py` |
| `_pytree.py` |
| `_stats.py` |
| `_thunk.py` |
| `_traceback.py` |
| `_triton.py` |
| `_typing_utils.py` |
| `_zip.py` |
| `backend_registration.py` |
| `bundled_inputs.py` |
| `checkpoint.py` |
| `collect_env.py` |
| `cpp_backtrace.py` |
| `cpp_extension.py` |
| `deterministic.py` |
| `dlpack.py` |
| `file_baton.py` |
| `flop_counter.py` |
| `hooks.py` |
| `mkldnn.py` |
| `mobile_optimizer.py` |
| `model_zoo.py` |
| `module_tracker.py` |
| `show_pickle.py` |
| `throughput_benchmark.py` |
| `weak.py` |