Commit Graph

41 Commits

Author SHA1 Message Date
Pearu Peterson
49f0d127fb Fix a bug in retrieving approximate bsr_dense_addmm kernel meta data (#124371)
Fixes #124333

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124371
Approved by: https://github.com/eqy, https://github.com/lezcano
2024-04-24 13:59:18 +00:00
Pearu Peterson
a39e638707 Update bsr_dense_addmm kernel parameters for sizes 3 x 2 ^ N (#122506)
As in the title. The speed-ups for a particular set of input sizes range from about 7 to 85 % depending on the used BSR tensor block sizes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122506
Approved by: https://github.com/cpuhrsch
2024-03-23 11:54:33 +00:00
Peter Bell
3a8bf25fdd [SparseCsr] Remove triton sdpa skip after triton pin update (#109601)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109601
Approved by: https://github.com/desertfire, https://github.com/amjames
2024-02-08 16:40:25 +00:00
Catherine Lee
4f5785b6b3 Enable possibly-undefined error code (#118533)
Fixes https://github.com/pytorch/pytorch/issues/118129

Suppressions automatically added with

```
import re

with open("error_file.txt", "r") as f:
    errors = f.readlines()

error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Co-authored-by: Catherine Lee <csl@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2024-01-30 21:07:01 +00:00
PyTorch MergeBot
40ece2e579 Revert "Enable possibly-undefined error code (#118533)"
This reverts commit 4f13f69a45.

Reverted https://github.com/pytorch/pytorch/pull/118533 on behalf of https://github.com/clee2000 due to sorry i'm trying to figure out a codev merge conflict, if this works i'll be back to rebase and merge ([comment](https://github.com/pytorch/pytorch/pull/118533#issuecomment-1917695185))
2024-01-30 19:00:34 +00:00
Edward Z. Yang
4f13f69a45 Enable possibly-undefined error code (#118533)
Fixes https://github.com/pytorch/pytorch/issues/118129

Suppressions automatically added with

```
import re

with open("error_file.txt", "r") as f:
    errors = f.readlines()

error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2024-01-30 05:08:10 +00:00
Pearu Peterson
32286512cc Add tune_bsr_dense_addmm as an API to find optimal triton kernel parameters for bsr_dense_addmm (#115499)
As in the title.

In addition:
- improve the algorithm for finding a minima of operation timings: break the inner loop early when a next minima candidate is found
- add tests and fix bugs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115499
Approved by: https://github.com/cpuhrsch
2023-12-12 16:44:51 +00:00
Pearu Peterson
12085914b8 Replace bsr_dense_mm triton kernel with bsr_dense_addm triton kernel (#115030)
The `bsr_dense_addmm` triton kernel introduced in https://github.com/pytorch/pytorch/pull/114595 is a generalization of `bsr_dense_mm` triton kernel and a more efficient version of it because it uses an extra kernel parameter `SPLIT_N` that has notable effect to performance for r.h.s operand with a larger number of columns.

This PR eliminates the `bsr_dense_mm` triton kernel in favor of using `bsr_dense_addmm` triton kernel.

The performance increase of `bsr_dense_mm` is as follows (float16, `NVIDIA A100-SXM4-80GB`):
- with 16x16 blocks, the average/maximal speed up is 50/71 %
- with 32x32 blocks, the average/maximal speed up is 30/63 %
- with 64x64 blocks, the average/maximal speed up is 12/26 %
- with 128x128 blocks, the average/maximal speed up is 7/17 %

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115030
Approved by: https://github.com/cpuhrsch
2023-12-05 22:29:24 +00:00
Pearu Peterson
4ba37e1804 Add tests for bsr_dense_addmm and bsr_dense_mm triton kernels (#114800)
As in the title.

In addition,
- resolve https://github.com/pytorch/pytorch/pull/114757#discussion_r1409547917 re triton-contiguous inputs
- support non-contiguous inputs and outputs in triton kernels
- fix a couple of minor bugs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114800
Approved by: https://github.com/cpuhrsch
2023-12-04 22:07:47 +00:00
Pearu Peterson
69f112d586 Call triton bsr_dense_mm/bsr_dense_addmm kernels on mm/addmm float32 inputs when appropiate (#114757)
As in the title.

In addition, this PR fixes a bug in `bsr_dense_mm` and `bsr_dense_addmm` return value handling where computations are performed on `make_triton_contiguous` return value while `bsr_dense_mm`/`bsr_dense_addmm` return a tensor that is an input to `make_triton_contiguous`. If `make_triton_contiguous` makes a copy of the input, the return values of `bsr_dense_mm`/`bsr_dense_addmm` will contain garbage.

The PR increases the performance of nn.linear as follows (float32, `NVIDIA A100-SXM4-80GB`):
- with 16x16 blocks, the average/maximal speed up is 67/78 %
- with 32x32 blocks, the average/maximal speed up is 72/79 %
- with 64x64 blocks, the average/maximal speed up is 71/79 %
- with 128x128 blocks, the average/maximal speed up is 62/76 %

The performance increase is illustrated also by the following sparsity-speedup graphs (before and after this PR):
<img src="https://github.com/pytorch/pytorch/assets/402156/55ce0bf7-8ef2-47ab-99e8-8878f159037d" width="48%"> <img src="https://github.com/pytorch/pytorch/assets/402156/df256175-a594-4bd7-b244-90867fb9a45e" width="48%">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114757
Approved by: https://github.com/cpuhrsch
2023-11-30 13:38:07 +00:00
Pearu Peterson
69c4819f53 Add bsr_dense_addmm triton kernel (#114595)
As in the title.

The `bsr_dense_addmm` kernel implemented in this PR is a generalization of `bsr_dense_mm` in the following respects (in addition of having input, beta, and alpha parameters):
- it implements `SPLIT_N` kernel parameter that enables efficient kernel launches in the case of wide inputs. For instance, the timing of nn.linear with 256x256 BSR weights having 16x16 blocks and 256x131072 strided input reduced about 16x (this corresponds to the 94 % speed up value listed below).
- it supports rectangular blocks in sparse BSR tensor weights

The performance increase of nn.linear is as follows (float16, `NVIDIA A100-SXM4-80GB`):
- with 16x16 blocks, the average/maximal speed up is  55/94 %
- with 32x32 blocks, the average/maximal speed up is  33/63 %
- with 64x64 blocks, the average/maximal speed up is  23/42 %
- with 128x128 blocks, the average/maximal speed up is  15/39 %

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114595
Approved by: https://github.com/cpuhrsch
2023-11-29 05:29:25 +00:00
Pearu Peterson
cffea773e3 Fix bsr_dense_mm with a non-contiguous out argument. (#113801)
Fixes https://github.com/pytorch/pytorch/issues/113754

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113801
Approved by: https://github.com/cpuhrsch
2023-11-16 05:56:17 +00:00
Pearu Peterson
e1c872e009 Add optimal triton kernel parameters to bsr_dense_mm and scatter_mm for bfloat16 and float32 dtypes (#113553)
As in the title.

This PR is a follow-up to PR https://github.com/pytorch/pytorch/pull/112737 to address bfloat16 and float32 dtype cases. The performance increase is as follows (`NVIDIA A100-SXM4-80GB`):

- bsr_scatter_mm and bfloat16
  - for blocksize 16x16, the average/maximum speed up is about 29/75 %.
  - for blocksize 32x32, the average/maximum speed up is about 23/58 %.
  - for blocksize 64x64, the average/maximum speed up is about 27/66 %.
  - for blocksize 128x128, the average/maximum speed up is about 33/72 %.
- bsr_dense_mm and bfloat16
  - for blocksize 16x16, the average/maximum speed up is about 47/61 %.
  - for blocksize 32x32, the average/maximum speed up is about 29/43 %.
  - for blocksize 64x64, the average/maximum speed up is about 21/41 %.
  - for blocksize 128x128, the average/maximum speed up is about 12/29 %.
- bsr_dense_mm and  float32
  - for blocksize 16x16, the average/maximum speed up is about 35/49 %.
  - for blocksize 32x32, the average/maximum speed up is about 2/5 %.
  - for blocksize 64x64, the average/maximum speed up is about 2/21 %.
  - for blocksize 128x128, the average/maximum speed up is about 79/84 %.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113553
Approved by: https://github.com/cpuhrsch
2023-11-14 00:47:59 +00:00
Pearu Peterson
e64d250210 Add a tool for a semi-automatic optimization of bsr_dense_mm meta parameters. (#112737)
Finding optimal meta parameters for bsr_dense_mm and bsr_scatter_mm triton kernels is a tedious job. This PR introduces a tool (a Python script `torch/sparse/_triton_ops_meta.py`) that finds the optimal set of meta parameters for a given set of matrix multiplication inputs and their block sizes. Currently, such a set is found for square bsr tensor inputs with sizes 256...16384 and square blocksizes 16...128, and dense tensor inputs with sizes 256...131072.
As a result, bsr_dense_mm performance has increased as follows (`NVIDIA A100-SXM4-80GB`):
- for blocksize 16x16, the average/maximum speed up is about 40/60 %.
- for blocksize 32x32, the average/maximum speed up is about 28/45 %.
- for blocksize 64x64, the average/maximum speed up is about 26/43 %.
- for blocksize 128x128, the average/maximum speed up is about 12/28 %.

To enable the performance improvements through meta parameter optimization for other CUDA devices, one must execute the `_triton_ops_meta.py` which will calculate the optimal meta parameters and store the results in a dictionary object defined in `_triton_ops_meta.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112737
Approved by: https://github.com/cpuhrsch
2023-11-05 12:52:09 +00:00
Pearu Peterson
33c41daf60 Fix scatter_mm kernel failure on non-contiguous tensor arguments (#112337)
This PR fixes
```
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
```
that appears when using large non-contiguous tensor arguments in `scatter_mm` kernel launch.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112337
Approved by: https://github.com/cpuhrsch
ghstack dependencies: #112154, #112076
2023-10-30 19:16:05 +00:00
Pearu Peterson
cf6041e942 Use weakref in storing tensors as keys (follow-up to #111470) (#112076)
This PR addresses the discussion items in https://github.com/pytorch/pytorch/pull/111470#discussion_r1369008167, that is,
- use weakref when storing tensors as keys,
- add `storage_offset` to the key data,
- and revise the description of the `TensorAsKey` utility.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112076
Approved by: https://github.com/cpuhrsch
ghstack dependencies: #112154
2023-10-30 19:16:05 +00:00
Pearu Peterson
b969c675f5 Add batched dimensions support to the second operand of bsr_scatter_mm (#111796)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111796
Approved by: https://github.com/cpuhrsch
ghstack dependencies: #110396, #111470, #111489, #111760
2023-10-23 23:52:49 +00:00
Pearu Peterson
6382011843 Add NVIDIA A100 optimized meta parameters to bsr_dense_mm (#111760)
As in the title.

The figures below illustrate the performance differences of bsr_dense_mm with optimized parameters and bsr_dense_mm with default parameters (GPU: NVIDIA A100-SXM4-80GB). The first figure represents the performance equilibrium point in BSR tensor sparsity at which value bsr_dense_mm have the same performance characteristics as torch.matmul. The second figure represents speedups from using optimized meta parameters in bsr_dense_mm at its performance equilibrium points with respect to bsr_dense_mm with default meta parameters.

In sum, this PR speeds up `bsr_dense_mm` about 50 % depending on the bsr tensor shape and blocksize and lowers the performance equilibrium points of BSR tensor sparsity and strided tensor for matmul operations.

<img src="https://github.com/pytorch/pytorch/assets/402156/6fe9d35f-dd21-4aa0-bb01-6ee257254453" width="48%"> <img src="https://github.com/pytorch/pytorch/assets/402156/506921c6-3770-4209-ad3d-498d2ae4989d" width="48%">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111760
Approved by: https://github.com/cpuhrsch
ghstack dependencies: #110396, #111470, #111489
2023-10-23 23:52:49 +00:00
Pearu Peterson
f3d08ab271 Use more performant bsr_scatter_mm within bsr_dense_mm when blocksize is 16. (#111489)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111489
Approved by: https://github.com/cpuhrsch
ghstack dependencies: #110396, #111470
2023-10-23 23:52:49 +00:00
Pearu Peterson
6078ed95cc Use lru_cache to cache indices data for bsr_scatter_mm. (#111470)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111470
Approved by: https://github.com/cpuhrsch
ghstack dependencies: #110396
2023-10-23 23:52:49 +00:00
Pearu Peterson
d4708a6da7 Add scatter_mm and bsr_scatter_mm operations. (#110396)
This PR introduces `scatter_mm` operation (compute `mm` of arbitrary pairs of tensors given in batches of tensors) that is used to implement `bsr_scatter_mm` that is equivalent to `bsr_dense_mm` (the `mm` operation on bsr and strided tensors). The implementation is provided both in Triton (when tensor dimensions are multiples of 16) and in PyTorch (otherwise).

The figures below illustrate the performance differences of `bsr_scatter_mm` and `bsr_dense_mm` (GPU: `NVIDIA GeForce RTX 2060 SUPER`). The first figure represents the performance equilibrium point in BSR tensor sparsity at which value `bsr_scatter_mm` or `bsr_dense_mm` have the same performance characteristics as `torch.matmul`. The second figure represents speedups from using `bsr_scatter_mm` at its performance equilibrium points with respect to `bsr_dense_mm`.

<img src="https://github.com/pytorch/pytorch/assets/402156/526d182e-937f-4812-a6c4-904f52d6d5ab" width="48%"> <img src="https://github.com/pytorch/pytorch/assets/402156/ccb606ab-1f3f-4133-887c-b56285f4f168" width="48%">

The same figures for GPU card `NVIDIA A100-SXM4-80GB`:

<img src="https://github.com/pytorch/pytorch/assets/402156/25466f1d-df34-4d1c-a975-afb478e4d9f0" width="48%"> <img src="https://github.com/pytorch/pytorch/assets/402156/6ada91f0-a20f-4f0d-8a48-1f4ccc60d08e" width="48%">

In sum:
- `bsr_scatter_mm` is about 2x faster than `bsr_dense_mm` for small block sizes of 16 and 32 and large tensors [GPU: `NVIDIA GeForce RTX 2060 SUPER`].
- `bsr_scatter_mm` is up to 2x faster than `bsr_dense_mm` for small block sizes of 16 and large tensors [GPU: `NVIDIA A100-SXM4-80GB`].
- `bsr_dense_mm` is up to 20 % faster than `bsr_scatter_mm` for block sizes of 64 or larger [GPU: `NVIDIA GeForce RTX 2060 SUPER`].
- However, `bsr_dense_mm` fails with `OutOfResources` exception for block sizes of 256 or larger whereas `bsr_scatter_mm` succeeds.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110396
Approved by: https://github.com/cpuhrsch
2023-10-23 19:45:30 +00:00
Oguz Ulgen
1df14f1bf8 Move has_triton to top level triton utils so that dynamo can also access (#109832)
it without creating cyclic dependencies

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109832
Approved by: https://github.com/zou3519
2023-09-22 19:33:41 +00:00
Pearu Peterson
4e042cfed5 Improve triton bsr_dense_mm performance on column-major ordered inputs with float32 dtype (#108512)
As in the title.

The bsr_dense_mm performance on inputs using column-major storage order is relevant for `linear(x, W)` operation that for BSR weights is defined as `bsr_dense_mm(W, x.transpose(-2, -1)).transpose(-2, 1)` so that the second argument to `bse_dense_mm` is a strided tensor using column-major storage order when `x` is C-contiguous.

For large inputs (size > 1000) and moderate sparsity in the BSR input, the speed up can be more than 3 times, as illustrated in the following figure (raw data: [bench_bsr_dense_mm_1_results.txt](https://github.com/pytorch/pytorch/files/12512245/bench_bsr_dense_mm_1_results.txt)):

![bench_bsr_dense_mm_1](https://github.com/pytorch/pytorch/assets/402156/c6372008-dfae-4d26-b119-2c3c944a74ae)

For small inputs (size=512), there exists a slight degradation of performance.

For row-major ordered inputs, there is no change in performance (see raw data above).

For inputs with float16 dtype, there is no considerable change in performance (see blue marks in the figure).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108512
Approved by: https://github.com/cpuhrsch
2023-09-06 17:30:06 +00:00
nikitaved
44c8515d0d SDPA: frontend for BSR masks (#104042)
This PR implements a (yet private) frontend for scaled_dot_product_attention that works with BSR `attn_mask`.

This function is directly comparable (with suitable masks) with `torch.nn.functional.scaled_dot_product_attention` once `attn_mask.dtype == torch.bool`, but it's behavior is different when `attn_mask.dtype != torch.bool`. This is because `torch.nn.functional.scaled_dot_product_attention` assumes that irrelevant values are supposed to be filled with `-inf`, while the selected ones should be `0`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104042
Approved by: https://github.com/amjames, https://github.com/cpuhrsch
2023-07-13 18:01:21 +00:00
Nikita Vedeneev
39a22e2791 softmax: Triton kernel for BSR inputs (#102095)
Implements `softmax` Triton kernel for BSR inputs. So far, only over `dim=-1`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102095
Approved by: https://github.com/cpuhrsch
2023-06-21 01:23:27 +00:00
Nikita Vedeneev
6c7410ddc3 sampled_addmm: BSR support (#101163)
This PR implements a `sampled_addmm` kernel that works with a BSR mask.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101163
Approved by: https://github.com/cpuhrsch
2023-05-25 12:33:50 +00:00
Nikita Vedeneev
dd2c22f4bb bsr_dense_bmm(): enable more precise float32 support with float64 accumulators (#100882)
Float64 is there in Triton! This PR increases precision for float32 inputs with float64 accumulation dtype.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100882
Approved by: https://github.com/cpuhrsch
2023-05-11 11:22:55 +00:00
Nikita Vedeneev
0141a242fd bsr_dense_bmm(): remove sparse_rowspace kernel and some dead code (#100876)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100876
Approved by: https://github.com/cpuhrsch, https://github.com/Skylion007
2023-05-09 16:12:11 +00:00
Nikita Vedeneev
c4bc259f00 bsr_dense_mm(): better test coverage (#100543)
This PR improves test coverage for `bsr_dense_mm` by:
- ~~enabling correctness tests for `float32`~~.
- extending and testing input correctness checks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100543
Approved by: https://github.com/cpuhrsch, https://github.com/malfet
2023-05-09 09:26:02 +00:00
Nikita Vedeneev
cd8b82e5c6 bsr_dense_mm(): code refactoring (#100634)
Code unification/refactoring for better re-use. Intended for easier `sampled_addmm` implementation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100634
Approved by: https://github.com/cpuhrsch
2023-05-08 13:27:39 +00:00
Nikita Vedeneev
05dda7ff65 bsr_dense_mm Triton kernel: fix out kwarg (#96648)
As per title. The kernel did not handle `out=` correctly and returned a different tensor which only shared storage with `out`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96648
Approved by: https://github.com/cpuhrsch
2023-03-14 18:01:22 +00:00
Natalia Gimelshein
76cac70939 new triton main pin (#95896)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95896
Approved by: https://github.com/jansel, https://github.com/malfet
2023-03-10 06:30:41 +00:00
PyTorch MergeBot
d0731271cd Revert "new triton main pin (#95896)"
This reverts commit 6e0359dd42.

Reverted https://github.com/pytorch/pytorch/pull/95896 on behalf of https://github.com/huydhn due to I am not quite sure what this is about yet, but testing 3.8 wheel starts to fail 6e0359dd42
2023-03-10 05:41:45 +00:00
Natalia Gimelshein
6e0359dd42 new triton main pin (#95896)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95896
Approved by: https://github.com/jansel
2023-03-10 03:40:37 +00:00
Nikita Vedeneev
d809020fc8 Triton kernel for bsr @ dense (#94823)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94823
Approved by: https://github.com/cpuhrsch, https://github.com/malfet
2023-03-03 15:11:28 +00:00
PyTorch MergeBot
7012d985fa Revert "Improve bsr @ strided performance in baddmm for bfloat16/half with Triton kernels. (#88078)"
This reverts commit 46f16b9363.

Reverted https://github.com/pytorch/pytorch/pull/88078 on behalf of https://github.com/ZainRizvi due to Causing a test to fail consistently: test_decomp.py::HasDecompTest::test_has_decomposition
2023-01-26 16:22:29 +00:00
Nikita Vedeneev
46f16b9363 Improve bsr @ strided performance in baddmm for bfloat16/half with Triton kernels. (#88078)
As per title.

Additionally we also introduce support for:
- Rectangular block sizes which are powers of 2 and at least 16 (triton's `dot` limitation).
- Batch support with broadcasting for either of the arguments.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88078
Approved by: https://github.com/cpuhrsch
2023-01-26 07:58:27 +00:00
PyTorch MergeBot
60bf851931 Revert "Improve bsr @ strided performance in baddmm for bfloat16/half with Triton kernels. (#88078)"
This reverts commit 8383b5c488.

Reverted https://github.com/pytorch/pytorch/pull/88078 on behalf of https://github.com/malfet due to This seems to have broke sm_86 testing, see https://hud.pytorch.org/hud/pytorch/pytorch/master/1?per_page=50&name_filter=sm86%20%2F%20test%20(default%2C%203
2023-01-19 23:37:59 +00:00
Nikita Vedeneev
8383b5c488 Improve bsr @ strided performance in baddmm for bfloat16/half with Triton kernels. (#88078)
As per title.

Additionally we also introduce support for:
- Rectangular block sizes which are powers of 2 and at least 16 (triton's `dot` limitation).
- Batch support with broadcasting for either of the arguments.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88078
Approved by: https://github.com/cpuhrsch
2023-01-19 03:14:54 +00:00
PyTorch MergeBot
89f1ad08b4 Revert "Improve bsr @ strided performance in baddmm for bfloat16/half with Triton kernels. (#88078)"
This reverts commit 7f256fff77.

Reverted https://github.com/pytorch/pytorch/pull/88078 on behalf of https://github.com/huydhn due to This breaks lint 7f256fff77
2023-01-17 22:14:37 +00:00
Nikita Vedeneev
7f256fff77 Improve bsr @ strided performance in baddmm for bfloat16/half with Triton kernels. (#88078)
As per title.

Additionally we also introduce support for:
- Rectangular block sizes which are powers of 2 and at least 16 (triton's `dot` limitation).
- Batch support with broadcasting for either of the arguments.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88078
Approved by: https://github.com/cpuhrsch
2023-01-17 21:43:20 +00:00