Summary:
This PR adds `torch.float8e4m3fn` support to cuSPARSELt and `to_sparse_semi_structured`.
This will let users to run fp8 + 2:4 sparse matmuls on Hopper GPUs with
cusparselt >= 0.6.2, via to `scaled_mm` API.
```
A = rand_sparse_semi_structured_mask(256, 128, dtype=torch.float16)
B = torch.rand(dense_input_shape, device=device).to(torch.float16).t()
A_fp8, A_scale = to_float8(A)
B_fp8, B_scale = to_float8(B)
dense_result = torch._scaled_mm(
A_fp8, B_fp8,
scale_a=A_scale, scale_b=B_scale,
out_dtype=out_dtype
)
A_fp8_sparse = to_sparse_semi_structured(A_fp8)
sparse_result = torch._scaled_mm(
A_fp8_sparse, B_fp8,
scale_a=A_scale, scale_b=B_scale,
out_dtype=out_dtype
)
```
Note that to keep this consistent with normal torch behavior, calling
`torch.mm(A_fp8_sparse, B_fp8)` will raise a NotImplementedError.
I also turned on cuSPARSELt by default and added CUSPARSELT_MAX_ID to the
backend to make the tests a bit cleaner
Test Plan:
```
python test/test_sparse_semi_structured -k scaled_mm
python test/test_sparse_semi_structured -k fp8
```
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136397
Approved by: https://github.com/drisspg
As in the title.
Tackles https://github.com/pytorch/ao/pull/821/files#r1759821413
The PR assumes that the existing tuning parameters are good also when using scaling arguments. This needs to be verified as a follow-up task.
Also, this PR redefines triton-contiguous tensors: the tensor must have strides not larger than 1. This will now allow zero strides that previously triggered `contiguous` call although the underlying memory buffer was contiguous.
Re: "a considerable slow-down occurs because tensor data is copied element-wise rather than chunk-wise" - this note should refer to a code (torch or triton?) that implements the element/chunk-wise copy so that we could verify that allowing zero strides indeed would not trigger element-wise copies. Atm, the performance increase in ViT-H benchmarks (that involve using 0 strides) is an evidence that allowing zero strides does not lead to slow-downs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136104
Approved by: https://github.com/cpuhrsch
Summary:
This PR updated cuSPARSELt to v0.6.2. I think we should land
https://github.com/pytorch/pytorch/pull/128534 first though.
Most of this PR is just enabling tests to run when cuSPARSELt v0.6.2 is
available.
Unfortunately was running into a bug with fp32 support on Hopper, so I
removed fp32 support from the cuSPARSELt backend. I think this should be
fine since almost everybody uses the bfloat/float16/int8 kernels.
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134022
Approved by: https://github.com/jerryzh168, https://github.com/malfet
ghstack dependencies: #128534
This PR switches to cuDSS library and has the same purpose of #127692, which is to add Sparse CSR tensor support to linalg.solve.
Fixes#69538
Minimum example of usage:
```
import torch
if __name__ == '__main__':
spd = torch.rand(4, 3)
A = spd.T @ spd
b = torch.rand(3).to(torch.float64).cuda()
A = A.to_sparse_csr().to(torch.float64).cuda()
x = torch.linalg.solve(A, b)
print((A @ x - b).norm())
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129856
Approved by: https://github.com/amjames, https://github.com/lezcano, https://github.com/huydhn
Co-authored-by: Zihang Fang <zhfang1108@gmail.com>
Co-authored-by: Huy Do <huydhn@gmail.com>
As in the title. In addition, the PR introduces `_int_bsr_dense_addmm` that is equivalent to `bsr_dense_addmm` except for int8 inputs the operation result is int32 tensor (similar to existing `_int_mm`).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133855
Approved by: https://github.com/cpuhrsch
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
Resolves#126888
- #126888
This PR is split from PR #126898.
- #126898
------
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.
Resolves#126888
- #126888
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
This PR adds in fast semi-structured sparsification kernels to PyTorch.
These kernels allow for accelerated semi-structured sparsification
kernels in PyTorch.
The kernels have been added as aten native functions
In particular, three new functions have been added:
* `torch._sparse_semi_structured_tile`
This function will return the packed representation and metadata for
both X and X', as well as the thread masks. Note that this applies 2:4
sparsity in a 4x4 tile instead of a 1x4 strip as usual.
* `torch._sparse_semi_structured_apply`
This function takes in an input tensor and thread masks from the above
function and returns a packed representation and metadata from applying
thread masks to the input tensor.
* `torch._sparse_semi_structured_apply_dense`
This function does the same thing as above but instead of returning the
tensor in the sparse representation it returns it in the dense
representation
The subclasses have also been updated to add a new
`prune_dense_static_sort`
classmethod to create sparse tensors with this format. I've added some
additional documentatino on how to calculate the compressed tensors
needed to create a SparseSemiStructuredTensor oneself.
To this end, there are two new helper functions added:
`sparse_semi_structured_tile`
`compute_compressed_swizzled_bitmask`
Differential Revision: [D56190801](https://our.internmc.facebook.com/intern/diff/D56190801)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122350
Approved by: https://github.com/cpuhrsch
This PR adds in fast semi-structured sparsification kernels to PyTorch.
These kernels allow for accelerated semi-structured sparsification
kernels in PyTorch.
The kernels have been added as aten native functions
In particular, three new functions have been added:
* `torch._sparse_semi_structured_tile`
This function will return the packed representation and metadata for
both X and X', as well as the thread masks. Note that this applies 2:4
sparsity in a 4x4 tile instead of a 1x4 strip as usual.
* `torch._sparse_semi_structured_apply`
This function takes in an input tensor and thread masks from the above
function and returns a packed representation and metadata from applying
thread masks to the input tensor.
* `torch._sparse_semi_structured_apply_dense`
This function does the same thing as above but instead of returning the
tensor in the sparse representation it returns it in the dense
representation
The subclasses have also been updated to add a new
`prune_dense_static_sort`
classmethod to create sparse tensors with this format. I've added some
additional documentatino on how to calculate the compressed tensors
needed to create a SparseSemiStructuredTensor oneself.
To this end, there are two new helper functions added:
`sparse_semi_structured_tile`
`compute_compressed_swizzled_bitmask`
Differential Revision: [D56190801](https://our.internmc.facebook.com/intern/diff/D56190801)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122350
Approved by: https://github.com/cpuhrsch
This PR adds in fast semi-structured sparsification kernels to PyTorch.
These kernels allow for accelerated semi-structured sparsification
kernels in PyTorch.
The kernels have been added as aten native functions
In particular, three new functions have been added:
* `torch._sparse_semi_structured_tile`
This function will return the packed representation and metadata for
both X and X', as well as the thread masks. Note that this applies 2:4
sparsity in a 4x4 tile instead of a 1x4 strip as usual.
* `torch._sparse_semi_structured_apply`
This function takes in an input tensor and thread masks from the above
function and returns a packed representation and metadata from applying
thread masks to the input tensor.
* `torch._sparse_semi_structured_apply_dense`
This function does the same thing as above but instead of returning the
tensor in the sparse representation it returns it in the dense
representation
The subclasses have also been updated to add a new
`prune_dense_static_sort`
classmethod to create sparse tensors with this format. I've added some
additional documentatino on how to calculate the compressed tensors
needed to create a SparseSemiStructuredTensor oneself.
To this end, there are two new helper functions added:
`sparse_semi_structured_tile`
`compute_compressed_swizzled_bitmask`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122350
Approved by: https://github.com/cpuhrsch
This PR adds in fast semi-structured sparsification kernels to PyTorch.
These kernels allow for accelerated semi-structured sparsification
kernels in PyTorch.
The kernels have been added as aten native functions
In particular, three new functions have been added:
* `torch._sparse_semi_structured_tile`
This function will return the packed representation and metadata for
both X and X', as well as the thread masks. Note that this applies 2:4
sparsity in a 4x4 tile instead of a 1x4 strip as usual.
* `torch._sparse_semi_structured_apply`
This function takes in an input tensor and thread masks from the above
function and returns a packed representation and metadata from applying
thread masks to the input tensor.
* `torch._sparse_semi_structured_apply_dense`
This function does the same thing as above but instead of returning the
tensor in the sparse representation it returns it in the dense
representation
The subclasses have also been updated to add a new
`prune_dense_static_sort`
classmethod to create sparse tensors with this format. I've added some
additional documentatino on how to calculate the compressed tensors
needed to create a SparseSemiStructuredTensor oneself.
To this end, there are two new helper functions added:
`sparse_semi_structured_tile`
`compute_compressed_swizzled_bitmask`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122350
Approved by: https://github.com/cpuhrsch
Summary:
This PR is a refactor of semi-structured sparsity support.
**deprecation**:
Before `torch.sparse.to_sparse_semi_structured` had a kwarg param
`transposed=False`, which has been removed. This kwarg was unused and
now thros a deprecation warning.
Namely, I've taken the subclassing implementation that xFormers has
created and brought it over to PyTorch, as part of our plan to upstream
runtime 2:4 sparsity.
I've also copied over all the op support that Daniel implemenented that
did not depend on the fast sparsification routines, into
`_sparse_semi_structured_ops.py`
With this subclass, all of our internal tests pass, as well as those in
xFormers.
The main change is that we now define a base subclass,
`SparseSemiStructuredTensor` that is inherited from for each of the
specific backends.
We also now can arbitrarily override the sparse dispatch table with
`_load_dispatch_table()`, idea being this is still general enough
where users don't need to modify pytorch source code to get their model
working.
This also adds in padding support and stores alg_id and fuse_transpose
as flags on the tensor, instead of hardcoding them.
There still remains two components in xFormers that will need to be
ported over eventually:
- the autograd functions (`Sparsify24`, `Sparsify24_like`)
- fast sparsification routines that they rely on
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117302
Approved by: https://github.com/alexsamardzic, https://github.com/HDCharles
Fixes https://github.com/pytorch/pytorch/issues/118129
Suppressions automatically added with
```
import re
with open("error_file.txt", "r") as f:
errors = f.readlines()
error_lines = {}
for error in errors:
match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
if match:
file_path, line_number, error_type = match.groups()
if file_path not in error_lines:
error_lines[file_path] = {}
error_lines[file_path][int(line_number)] = error_type
for file_path, lines in error_lines.items():
with open(file_path, "r") as f:
code = f.readlines()
for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
code[line_number - 1] = code[line_number - 1].rstrip() + f" # type: ignore[{error_type}]\n"
with open(file_path, "w") as f:
f.writelines(code)
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Co-authored-by: Catherine Lee <csl@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
Fixes https://github.com/pytorch/pytorch/issues/118129
Suppressions automatically added with
```
import re
with open("error_file.txt", "r") as f:
errors = f.readlines()
error_lines = {}
for error in errors:
match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
if match:
file_path, line_number, error_type = match.groups()
if file_path not in error_lines:
error_lines[file_path] = {}
error_lines[file_path][int(line_number)] = error_type
for file_path, lines in error_lines.items():
with open(file_path, "r") as f:
code = f.readlines()
for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
code[line_number - 1] = code[line_number - 1].rstrip() + f" # type: ignore[{error_type}]\n"
with open(file_path, "w") as f:
f.writelines(code)
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
Summary:
Both cuSPASRELt and CUTLASS support 1:2 semi-structured sparsity for
fp32, which this PR enables.(thanks @alexsamardzic).
Furthermore, this PR also updates the sparse_config to take into account
the different shape constraints for sparse and dense matrices.
Technically, cuSPARSELt supports smaller sparse matrix constraints as it
seens to pad to the CUTLASS constraints under the hood. However, in
practice small sparse matrices are not commonly used and we care more
about the dense constraints for LLM inference.
For now, we keep the CUTLASS constraints in place for both cuSPARSELt
and CUTLASS tensors
This PR also reconnects the _FUSE_TRANSPOSE flag for cuSPARSELt tensors.
Test Plan:
```
python test/test_sparse_semi_structured.py
```
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115550
Approved by: https://github.com/cpuhrsch
As in the title.
In addition:
- improve the algorithm for finding a minima of operation timings: break the inner loop early when a next minima candidate is found
- add tests and fix bugs
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115499
Approved by: https://github.com/cpuhrsch
The `bsr_dense_addmm` triton kernel introduced in https://github.com/pytorch/pytorch/pull/114595 is a generalization of `bsr_dense_mm` triton kernel and a more efficient version of it because it uses an extra kernel parameter `SPLIT_N` that has notable effect to performance for r.h.s operand with a larger number of columns.
This PR eliminates the `bsr_dense_mm` triton kernel in favor of using `bsr_dense_addmm` triton kernel.
The performance increase of `bsr_dense_mm` is as follows (float16, `NVIDIA A100-SXM4-80GB`):
- with 16x16 blocks, the average/maximal speed up is 50/71 %
- with 32x32 blocks, the average/maximal speed up is 30/63 %
- with 64x64 blocks, the average/maximal speed up is 12/26 %
- with 128x128 blocks, the average/maximal speed up is 7/17 %
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115030
Approved by: https://github.com/cpuhrsch
Continuation of #112185, following the design in this [doc](https://docs.google.com/document/d/1ipSxcTzEMMOAPvxP-YJlD5JBZZmIGgh8Q34ixtOUCRo).
Summary:
* Introduce `SubclassSymbolicPolicy` containing separate dynamic dim / constraint policies for the outer and inner tensors
* Expand the automatic dynamic algorithm to recurse into inner tensors and produce one of these for a subclass instance
* Maintain legacy behavior for subclasses by recursively calling `mark_dynamic()` on inner tensors *of the same dim as outer* when `mark_dynamic(outer, ...)` is called
* Addresses this: 6a86cf00ad/torch/_dynamo/variables/builder.py (L1750)
* Add `outer_size` and `outer_stride` arguments to `__tensor_unflatten__()` so that you can find out what symbols were allocated for the outer size / stride (you are expected to return a tensor that compares equal to the outer symbols)
* Signatures now:
```python
# attrs is a list of inner tensor attributes on x; inner_tensor = getattr(x, attr)
# ctx is anything useful for rebuilding the class we want to guard on
attrs, ctx = x.__tensor_flatten__()
...
# inner_tensors is a dict of {attr -> tensor}
# ctx is taken unmodified from flattening and (eventually) guarded on
# outer_size is the expected size of the output; possibly symbolic
# outer_stride is the expected strides of the output; possibly symbolic
y = MySubclass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride)
# at the __tensor_unflatten__() call-site in PT2, we assert y.shape == outer_size and y.stride() == outer_stride
# the assert simplifies symbols when there are relationships between outer and inner symbols
```
* Size info needed for `NestedTensor` at least, stride info needed for `DTensor` at least
* Punting on `outer_storage_offset` because storage_offset handling is horribly broken in PT2 right now
* ~~Add new `__tensor_mark_dynamic__()` to allow overriding the behavior of mark_dynamic on a per-subclass basis~~ (booted to future work)
* ~~Add guards for tensor subclasses by calling `__tensor_flatten__()` in the guard to test equality on `ctx`~~
* Now handled in #114469
* Next PR: add TENSOR_MATCH guards on inner tensors
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114311
Approved by: https://github.com/ezyang, https://github.com/drisspg, https://github.com/voznesenskym, https://github.com/bdhirsh
Summary:
This PR adds in support for passing in a alpha Tensor, which represents
a tensor of alpha values to fuse into the matmul.
```
cusparselt_sparse_mm = alpha A @ B + bias
```
This operation is necessary for quantization, where we would like to
fuse one of the dequant matmuls into the sparse op.
Test Plan:
```
python test/test_sparse_semi_structured -k alpha
```
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112056
Approved by: https://github.com/cpuhrsch
As in the title.
In addition, this PR fixes a bug in `bsr_dense_mm` and `bsr_dense_addmm` return value handling where computations are performed on `make_triton_contiguous` return value while `bsr_dense_mm`/`bsr_dense_addmm` return a tensor that is an input to `make_triton_contiguous`. If `make_triton_contiguous` makes a copy of the input, the return values of `bsr_dense_mm`/`bsr_dense_addmm` will contain garbage.
The PR increases the performance of nn.linear as follows (float32, `NVIDIA A100-SXM4-80GB`):
- with 16x16 blocks, the average/maximal speed up is 67/78 %
- with 32x32 blocks, the average/maximal speed up is 72/79 %
- with 64x64 blocks, the average/maximal speed up is 71/79 %
- with 128x128 blocks, the average/maximal speed up is 62/76 %
The performance increase is illustrated also by the following sparsity-speedup graphs (before and after this PR):
<img src="https://github.com/pytorch/pytorch/assets/402156/55ce0bf7-8ef2-47ab-99e8-8878f159037d" width="48%"> <img src="https://github.com/pytorch/pytorch/assets/402156/df256175-a594-4bd7-b244-90867fb9a45e" width="48%">
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114757
Approved by: https://github.com/cpuhrsch
As in the title.
The `bsr_dense_addmm` kernel implemented in this PR is a generalization of `bsr_dense_mm` in the following respects (in addition of having input, beta, and alpha parameters):
- it implements `SPLIT_N` kernel parameter that enables efficient kernel launches in the case of wide inputs. For instance, the timing of nn.linear with 256x256 BSR weights having 16x16 blocks and 256x131072 strided input reduced about 16x (this corresponds to the 94 % speed up value listed below).
- it supports rectangular blocks in sparse BSR tensor weights
The performance increase of nn.linear is as follows (float16, `NVIDIA A100-SXM4-80GB`):
- with 16x16 blocks, the average/maximal speed up is 55/94 %
- with 32x32 blocks, the average/maximal speed up is 33/63 %
- with 64x64 blocks, the average/maximal speed up is 23/42 %
- with 128x128 blocks, the average/maximal speed up is 15/39 %
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114595
Approved by: https://github.com/cpuhrsch
This PR:
- updates `torch/sparse/_triton_ops_meta.py` for the API change in `triton.testing.do_bench`
- force `num_stages` to be 1 when blocksize is 128x128 to avoid out of resources exception when `bsr_dense_mm` is called from `nn.linear`.
- as in the title. The performance of `nn.linear` on BSR tensor weights (dtypes `float16` and `bfloat16`) is increased as follows (`NVIDIA A100-SXM4-80GB`):
- for blocksize 16x16, the average/maximum speed up is about 11/20 %
- for blocksize 32x32, the average/maximum speed up is about 15/24 %
- for blocksize 64x64, the average/maximum speed up is about 18/26 %
- for blocksize 128x128, the average/maximum speed up is about 15/28 %
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114026
Approved by: https://github.com/cpuhrsch