pytorch/torch/sparse
Nikita Vedeneev 46f16b9363 Improve bsr @ strided performance in baddmm for bfloat16/half with Triton kernels. (#88078)
As per title.

Additionally, we introduce support for the following (see the usage sketch after this list):
- Rectangular block sizes that are powers of 2 and at least 16 (a limitation of Triton's `dot`).
- Batched inputs, with broadcasting over either argument.
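As a rough illustration of what the kernel enables, here is a minimal sketch. It assumes a CUDA build with Triton available and that the kernel is exposed as `bsr_dense_mm` in `torch.sparse._triton_ops` (an internal module, so the name and signature may differ between releases):

```python
import torch

# Assumption: the Triton-backed BSR @ strided kernel lives in this
# internal module under this name; internal APIs can change.
from torch.sparse._triton_ops import bsr_dense_mm

# Half-precision operands on CUDA, which is the case this PR targets.
a = torch.randn(256, 256, dtype=torch.half, device="cuda")
b = torch.randn(256, 128, dtype=torch.half, device="cuda")

# Convert the left operand to BSR. The 32x32 block size satisfies the
# constraint above: a power of 2 and at least 16 (Triton's `dot` limit).
a_bsr = a.to_sparse_bsr(blocksize=(32, 32))

# bsr @ strided via the Triton kernel; check against a dense matmul
# with tolerances loose enough for half precision.
res = bsr_dense_mm(a_bsr, b)
torch.testing.assert_close(res, a @ b, rtol=1e-2, atol=1e-2)
```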

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88078
Approved by: https://github.com/cpuhrsch
2023-01-26 07:58:27 +00:00
__init__.py Revert "Improve bsr @ strided performance in baddmm for bfloat16/half with Triton kernels. (#88078)" 2023-01-19 23:37:59 +00:00
_triton_ops.py Improve bsr @ strided performance in baddmm for bfloat16/half with Triton kernels. (#88078) 2023-01-26 07:58:27 +00:00