Compared to #104848, this PR makes a step further: when the enable_sparse_support decorator is applied to `torch.autograd.gradcheck`, the resulting callable is equivalent to `torch.autograd.gradcheck` with an extra feature of supporting functions that can have input sparse tensors or/and can return sparse tensors.
At the same time, the underlying call to `torch.autograd.gradcheck` will operate on strided tensors only. This basically means that torch/autograd/gradcheck.py can be cleaned up by removing the code that deals with sparse tensors.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107150
Approved by: https://github.com/albanD, https://github.com/amjames, https://github.com/cpuhrsch
ghstack dependencies: #107638, #107777
This PR adds in support for semi-structured sparsity via a tensor
subclass. It currently uses the CUTLASS kernels merged in PR #100881.
In the future we plan to add in cuSPARSELt support (see the other PRs in
the stack), which will give us larger performance gains.
This PR adds in 2 things:
- a Tensor subclass, `SparseSemiStructuredTensor` to store the
sparse tensor in copmressed form and override `__torch_dispatch__`.
- a conversion function that takes in a dense tensor and a
semi-structured sparse bool mask and creates an instance of the
subclass.
**SparseSemiStructuredTensor**
The subclass stores the dense tensor in a contiguous flattened tensor
for future compatability with cuSPARSELt, which expects this format.
Note that the CUTLASS kernels do not have this limitation, as the
specified values and the metadata are passed separately in
`_structured_sparse_linear`. In the future we can use the cuSPARSELT bindings
[here](https://github.com/pytorch/pytorch/pull/103700) for faster matmul, better dtype converage, and relaxed shape
constraints.
Since we currently don't have a way to go back from the sparse
representation to the dense representation, and we store the weights in
compressed form, we don't have a great way to handle .t().
Instead, we keep track of how often we've called transpose on our
tensor, and if it's an unexpected number we throw an error. When the first
argument is sparse, we expect an even number of calls to transpose,
while when the second argument is sparse, we expect an odd number of
calls. This is because we support second argument sparse matrix
multiplications by using transpose properties.
**to_sparse_semi_structured**
This is a conversion function to convert a dense tensor and a
semi-structured sparse bool mask into a subclass. Currently, we must
pass in a bool mask, since we can't infer it becuase there may be
additional zero elements in the dense tensor, so `tensor !=0` is not 2:4
sparse.
Once we add either a method to derive the mask from the dense tensor or
cuSPARSELt, we no longer need to pass in the mask. cuSPARSELt has it's
own helper functions to create the metadata mask.
**User Details**
We have implemented support for the following ops for `torch.float16`
and `torch.int8`:
```
torch.addmm(bias, dense, sparse.t())
torch.mm(dense, sparse)
torch.mm(sparse, dense)
aten.linear.default
aten.t.default
aten.t.detach
```
The end user interface to accelerate a nn.Linaer module with the
subclass would look like this:
```
from torch.sparse import to_sparse_semi_structured
mask = torch.Tensor([0, 0, 1, 1]).tile(128, 32).cuda().bool()
linear = Model(128, 128).half().cuda()
linear.weight = nn.Parameter(to_sparse_semi_structured(linear.weight,
mask=linear.weight.bool())
```
This also updates tests and the `torch.sparse` module docstring to
reflect these changes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102135
Approved by: https://github.com/albanD
This PR adds in support for semi-structured sparsity via a tensor
subclass. It currently uses the CUTLASS kernels merged in PR #100881.
In the future we plan to add in cuSPARSELt support (see the other PRs in
the stack), which will give us larger performance gains.
This PR adds in 2 things:
- a Tensor subclass, `SparseSemiStructuredTensor` to store the
sparse tensor in copmressed form and override `__torch_dispatch__`.
- a conversion function that takes in a dense tensor and a
semi-structured sparse bool mask and creates an instance of the
subclass.
**SparseSemiStructuredTensor**
The subclass stores the dense tensor in a contiguous flattened tensor
for future compatability with cuSPARSELt, which expects this format.
Note that the CUTLASS kernels do not have this limitation, as the
specified values and the metadata are passed separately in
`_structured_sparse_linear`. In the future we can use the cuSPARSELT bindings
[here](https://github.com/pytorch/pytorch/pull/103700) for faster matmul, better dtype converage, and relaxed shape
constraints.
Since we currently don't have a way to go back from the sparse
representation to the dense representation, and we store the weights in
compressed form, we don't have a great way to handle .t().
Instead, we keep track of how often we've called transpose on our
tensor, and if it's an unexpected number we throw an error. When the first
argument is sparse, we expect an even number of calls to transpose,
while when the second argument is sparse, we expect an odd number of
calls. This is because we support second argument sparse matrix
multiplications by using transpose properties.
**to_sparse_semi_structured**
This is a conversion function to convert a dense tensor and a
semi-structured sparse bool mask into a subclass. Currently, we must
pass in a bool mask, since we can't infer it becuase there may be
additional zero elements in the dense tensor, so `tensor !=0` is not 2:4
sparse.
Once we add either a method to derive the mask from the dense tensor or
cuSPARSELt, we no longer need to pass in the mask. cuSPARSELt has it's
own helper functions to create the metadata mask.
**User Details**
We have implemented support for the following ops for `torch.float16`
and `torch.int8`:
```
torch.addmm(bias, dense, sparse.t())
torch.mm(dense, sparse)
torch.mm(sparse, dense)
aten.linear.default
aten.t.default
aten.t.detach
```
The end user interface to accelerate a nn.Linaer module with the
subclass would look like this:
```
from torch.sparse import to_sparse_semi_structured
mask = torch.Tensor([0, 0, 1, 1]).tile(128, 32).cuda().bool()
linear = Model(128, 128).half().cuda()
linear.weight = nn.Parameter(to_sparse_semi_structured(linear.weight,
mask=linear.weight.bool())
```
This also updates tests and the `torch.sparse` module docstring to
reflect these changes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102135
Approved by: https://github.com/albanD
This PR adds in support for semi-structured sparsity via a tensor
subclass. It currently uses the CUTLASS kernels merged in PR #100881.
In the future we plan to add in cuSPARSELt support (see the other PRs in
the stack), which will give us larger performance gains.
This PR adds in 2 things:
- a Tensor subclass, `SparseSemiStructuredTensor` to store the
sparse tensor in copmressed form and override `__torch_dispatch__`.
- a conversion function that takes in a dense tensor and a
semi-structured sparse bool mask and creates an instance of the
subclass.
**SparseSemiStructuredTensor**
The subclass stores the dense tensor in a contiguous flattened tensor
for future compatability with cuSPARSELt, which expects this format.
Note that the CUTLASS kernels do not have this limitation, as the
specified values and the metadata are passed separately in
`_structured_sparse_linear`. In the future we can use the cuSPARSELT bindings
[here](https://github.com/pytorch/pytorch/pull/103700) for faster matmul, better dtype converage, and relaxed shape
constraints.
Since we currently don't have a way to go back from the sparse
representation to the dense representation, and we store the weights in
compressed form, we don't have a great way to handle .t().
Instead, we keep track of how often we've called transpose on our
tensor, and if it's an unexpected number we throw an error. When the first
argument is sparse, we expect an even number of calls to transpose,
while when the second argument is sparse, we expect an odd number of
calls. This is because we support second argument sparse matrix
multiplications by using transpose properties.
**to_sparse_semi_structured**
This is a conversion function to convert a dense tensor and a
semi-structured sparse bool mask into a subclass. Currently, we must
pass in a bool mask, since we can't infer it becuase there may be
additional zero elements in the dense tensor, so `tensor !=0` is not 2:4
sparse.
Once we add either a method to derive the mask from the dense tensor or
cuSPARSELt, we no longer need to pass in the mask. cuSPARSELt has it's
own helper functions to create the metadata mask.
**User Details**
We have implemented support for the following ops for `torch.float16`
and `torch.int8`:
```
torch.addmm(bias, dense, sparse.t())
torch.mm(dense, sparse)
torch.mm(sparse, dense)
aten.linear.default
aten.t.default
aten.t.detach
```
The end user interface to accelerate a nn.Linaer module with the
subclass would look like this:
```
from torch.sparse import to_sparse_semi_structured
mask = torch.Tensor([0, 0, 1, 1]).tile(128, 32).cuda().bool()
linear = Model(128, 128).half().cuda()
linear.weight = nn.Parameter(to_sparse_semi_structured(linear.weight,
mask=linear.weight.bool())
```
This also updates tests and the `torch.sparse` module docstring to
reflect these changes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102135
Approved by: https://github.com/albanD
The previous sentence seemed to imply that sparse may not always be helpful, ie, your execution time may increase when using sparse. But the docs mentioned otherwise.
A simple re-ordering of two words in the documentation to better align with the contextual sentiment.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93258
Approved by: https://github.com/cpuhrsch
This PR is a copy of https://github.com/pytorch/pytorch/pull/90849 that merge was reverted.
The PR adds "check sparse tensor invariants" flag to Context that when enabled will trigger sparse tensor data invariants checks in unsafe methods of constructing sparse COO/CSR/CSC/BSR/BSC tensors. The feature includes the following changes to UI:
`torch.sparse.check_sparse_tensor_invariants` class provides different ways to enable/disable the invariant checking.
`torch.sparse_coo/csr/csc/bsr/bsc/compressed_tensor` functions have a new optional argument `check_invariants` to enable/disable the invariant checks explicitly. When the `check_invariants` argument is specified, the global state of the feature is temporarily overridden.
The PR fixes https://github.com/pytorch/pytorch/issues/90833
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92094
Approved by: https://github.com/cpuhrsch
This PR adds "check sparse tensor invariants" flag to Context that when enabled will trigger sparse tensor data invariants checks in unsafe methods of constructing sparse COO/CSR/CSC/BSR/BSC tensors. The feature includes the following changes to UI:
- `torch.enable_check_sparse_tensor_invariants` and `torch.is_check_sparse_tensor_invariants_enabled` functions to globally enable/disable the invariant checks and to retrieve the state of the feature, respectively
- `torch.sparse_coo/csr/csc/bsr/bsc/compressed_tensor` functions have a new optional argument `check_invariants` to enable/disable the invariant checks explicitly. When the `check_invariants` argument is specified, the global state of the feature is temporarily overridden.
The PR also fixes https://github.com/pytorch/pytorch/issues/90833
# Main issue
*The following content is outdated after merging the PRs in this ghstack but kept for the record.*
The importance of this feature is that when enabling the invariants checks by default, say, via
<details>
```
$ git diff
diff --git a/torch/__init__.py b/torch/__init__.py
index c8543057c7..19a91d0482 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -1239,3 +1239,8 @@ if 'TORCH_CUDA_SANITIZER' in os.environ:
# Populate magic methods on SymInt and SymFloat
import torch.fx.experimental.symbolic_shapes
+
+# temporarily enable sparse tensor arguments validation in unsafe
+# constructors:
+
+torch._C._set_check_sparse_tensor_invariants(True)
```
</details>
a massive number of test failures/errors occur in test_sparse_csr.py tests:
```
$ pytest -sv test/test_sparse_csr.py
<snip>
==== 4293 failed, 1557 passed, 237 skipped, 2744 errors in 69.71s (0:01:09) ====
```
that means that we are silently constructing sparse compressed tensors that do not satisfy the sparse tensor invariants. In particular, the following errors are raised:
```
AssertionError: "resize_as_sparse_compressed_tensor_: self and src must have the same layout" does not match "expected values to be a strided and contiguous tensor"
RuntimeError: CUDA error: device-side assert triggered
RuntimeError: `col_indices[..., crow_indices[..., i - 1]:crow_indices[..., i]] for all i = 1, ..., nrows are sorted and distinct along the last dimension values` is not satisfied.
RuntimeError: expected col_indices to be a strided and contiguous tensor
RuntimeError: expected row_indices to be a strided and contiguous tensor
RuntimeError: expected values to be a strided and contiguous tensor
RuntimeError: for_each: failed to synchronize: cudaErrorAssert: device-side assert triggered
RuntimeError: tensor dimensionality must be sum of batch, base, and dense dimensionalities (=0 + 2 + 0) but got 3
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90849
Approved by: https://github.com/amjames, https://github.com/cpuhrsch
Similar to [scipy.sparse.spdiags](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.spdiags.html#scipy-sparse-spdiags)
Part of #70926
In other functions (ie (torch.diagonal)[https://pytorch.org/docs/stable/generated/torch.diagonal.html#torch.diagonal]) diagonals of a tensor are referenced using the offset and the two dimensions that the diagonal is taken with respect to.
Here the reference implementation from scipy is only considering matrix output, so even if we only support 2-d output at first. It may be useful to consider how the dimensions corresponding to each diagonal would be specified for higher dimensional output.
The proposed torch signature implies that all offsets refer to the diagonals with respect to the only two dimensions of the output:
```
torch.sparse.spdiags(Tensor diagonals, IntTensor offsets, int[] shape, Layout? layout=None) -> SparseTensor
```
Above it is required that: `diagonals.ndimension() == 2`, `offsets.ndimensions() == 1`, `offsets.shape[0] == diagonals.shape[0]` and `len(shape) == 2`.
This would need to be altered for the case where `len(shape)` > 2. One options is:
```
torch.sparse.spdiags(Tensor[] diagonals, IntTensor[] offsets, IntTensor dims, int[] shape, Layout? layout=None) -> SparseTensor
```
Here `offsets` and `diagonals` becomes lists of tensors, and the `IntTensor dims` argument is introduced. This would require that `len(diagonals) == len(offsets) == dims.shape[0]`, `dims.ndimension() == 2` and `dims.shape[1] == 2` also the same restrictions as the 2d case above apply to the elements of `diagonals` and `offsets` pairwise (that is `diagonals[i].ndimension() == 2`, `offsets[i].ndimension() == 1` and `offsets[i].shape[0] == diagonals[i].shape[0]` for all i). This form of the signature would construct the sparse result by placing the values from `diagonals[i][j]` into the diagonal with offset `offset[i][j]` taken with respect to dimensions `dims[i]`. The specialization back to the original signature for the 2d case could be seen as allowing the single row of dims to default to `[0, 1]` when there is only one `diagonals`, `offsets` provided, and shape is `2-d`. This option allows the rows of an input element `diagonals[i]` to have a different length which may be appropriate as the max length of a diagonal along different dimension pairs will be different.
Another option is to specify the dimensions the diagonal is taken with respect to for each offset. This signature would look like:
```
torch.sparse.spdiags(Tensor diagonals, IntTensor offsets, IntTensor dims, int[] shape, Layout? layout=None) -> SparseTensor
```
Here, `diagonals` is still 2-D with dimension 0 matching the length of 1-D `offsets` and the tensor input `dims` is also 2-D with dimension 0 matching the length of 1-D `offsets` and the second dimension being fixed at `2` in this case the sparse result is constructed by placing the elements from `diagonals[i]` into the output diagonal `output.diagonal(offset[i], dim0=dims[i][0], dim1=dims[i][1])` (with some additional consideration that makes it more complicated than simply asigning to that view). The specialization from this back to the 2-D form could be seen as assuming `dims = [[0, 1], [0, 1]... len(offsets) times ]` when `len shape==2`.
In both proposed signatures for the N-D case the specialization back to the 2-D signature is a bit of a stretch for your typical default arguments logic, however I think the first is better choice as it offers more flexibility.
I think some discussion is required about:
- [x] Should the N-D output case be implemented from the outset
- [x] If not, should the future addition of the N-D output case be considered when designing the interface.
- [x] Other thoughts on the signature which includes the `dims` information for the N-D output case.
**Resolution**: Since no one has offered a request for N-D output support, I think is fine to restrict this to sparse matrix generation. Should a request for N-D support come later, an overload accepting the additional `dims` could be added.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78439
Approved by: https://github.com/nikitaved, https://github.com/cpuhrsch, https://github.com/pearu
Similar to [scipy.sparse.spdiags](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.spdiags.html#scipy-sparse-spdiags)
Part of #70926
In other functions (ie (torch.diagonal)[https://pytorch.org/docs/stable/generated/torch.diagonal.html#torch.diagonal]) diagonals of a tensor are referenced using the offset and the two dimensions that the diagonal is taken with respect to.
Here the reference implementation from scipy is only considering matrix output, so even if we only support 2-d output at first. It may be useful to consider how the dimensions corresponding to each diagonal would be specified for higher dimensional output.
The proposed torch signature implies that all offsets refer to the diagonals with respect to the only two dimensions of the output:
```
torch.sparse.spdiags(Tensor diagonals, IntTensor offsets, int[] shape, Layout? layout=None) -> SparseTensor
```
Above it is required that: `diagonals.ndimension() == 2`, `offsets.ndimensions() == 1`, `offsets.shape[0] == diagonals.shape[0]` and `len(shape) == 2`.
This would need to be altered for the case where `len(shape)` > 2. One options is:
```
torch.sparse.spdiags(Tensor[] diagonals, IntTensor[] offsets, IntTensor dims, int[] shape, Layout? layout=None) -> SparseTensor
```
Here `offsets` and `diagonals` becomes lists of tensors, and the `IntTensor dims` argument is introduced. This would require that `len(diagonals) == len(offsets) == dims.shape[0]`, `dims.ndimension() == 2` and `dims.shape[1] == 2` also the same restrictions as the 2d case above apply to the elements of `diagonals` and `offsets` pairwise (that is `diagonals[i].ndimension() == 2`, `offsets[i].ndimension() == 1` and `offsets[i].shape[0] == diagonals[i].shape[0]` for all i). This form of the signature would construct the sparse result by placing the values from `diagonals[i][j]` into the diagonal with offset `offset[i][j]` taken with respect to dimensions `dims[i]`. The specialization back to the original signature for the 2d case could be seen as allowing the single row of dims to default to `[0, 1]` when there is only one `diagonals`, `offsets` provided, and shape is `2-d`. This option allows the rows of an input element `diagonals[i]` to have a different length which may be appropriate as the max length of a diagonal along different dimension pairs will be different.
Another option is to specify the dimensions the diagonal is taken with respect to for each offset. This signature would look like:
```
torch.sparse.spdiags(Tensor diagonals, IntTensor offsets, IntTensor dims, int[] shape, Layout? layout=None) -> SparseTensor
```
Here, `diagonals` is still 2-D with dimension 0 matching the length of 1-D `offsets` and the tensor input `dims` is also 2-D with dimension 0 matching the length of 1-D `offsets` and the second dimension being fixed at `2` in this case the sparse result is constructed by placing the elements from `diagonals[i]` into the output diagonal `output.diagonal(offset[i], dim0=dims[i][0], dim1=dims[i][1])` (with some additional consideration that makes it more complicated than simply asigning to that view). The specialization from this back to the 2-D form could be seen as assuming `dims = [[0, 1], [0, 1]... len(offsets) times ]` when `len shape==2`.
In both proposed signatures for the N-D case the specialization back to the 2-D signature is a bit of a stretch for your typical default arguments logic, however I think the first is better choice as it offers more flexibility.
I think some discussion is required about:
- [x] Should the N-D output case be implemented from the outset
- [x] If not, should the future addition of the N-D output case be considered when designing the interface.
- [x] Other thoughts on the signature which includes the `dims` information for the N-D output case.
**Resolution**: Since no one has offered a request for N-D output support, I think is fine to restrict this to sparse matrix generation. Should a request for N-D support come later, an overload accepting the additional `dims` could be added.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78439
Approved by: https://github.com/nikitaved, https://github.com/cpuhrsch, https://github.com/pearu
Summary:
Working towards https://docs.google.com/document/d/10yx2-4gs0gTMOimVS403MnoAWkqitS8TUHX73PN8EjE/edit?pli=1#
This PR:
- Ensure that all the submodules are listed in a rst file (that ensure they are considered by the coverage tool)
- Remove some long deprecated code that just error out on import
- Remove the allow list altogether to ensure nothing gets added back there
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73983
Reviewed By: anjali411
Differential Revision: D34787908
Pulled By: albanD
fbshipit-source-id: 163ce61e133b12b2f2e1cbe374f979e3d6858db7
(cherry picked from commit c9edfead7a01dc45bfc24eaf7220d2a84ab1f62e)
Summary:
This PR introduces the `cuSolverSP` backend for `linalg.solve` with sparse CSR input matrices. The motivation comes from the issue: https://github.com/pytorch/pytorch/issues/69538.
`cuSolver` provides [`cusolverSp<t>csrlsvluHost`](https://docs.nvidia.com/cuda/cusolver/index.html#cusolver-lt-t-gt-csrlsvlu) API, a few things to note:
1. As mentioned in the documentation: `only CPU (Host) path is provided.` From the profiling, there doesn't seem to be any GPU kernel launch for optimization, please see the profiling below.
2. Since only `host` path is provided, the CPU path uses `csrlsvluHost` (but requires PyTorch to be installed/built with CUDA support).
3. The documentation mentions reordering helps optimize stuff, but it isn't clear how it affects the performance. There are options for reordering, so we stick to `reorder = 0` as the default choice.
`cuSolver` has [`csrlsvqr`](https://docs.nvidia.com/cuda/cusolver/index.html#cusolver-lt-t-gt-csrlsvqr) function which provides a `device` path to solve the linear system. This function is used for the CUDA path in this PR.
**Gist:**
For CPU Path: we call [`csrlsvluHost` function of cuSolver](https://docs.nvidia.com/cuda/cusolver/index.html#cusolver-lt-t-gt-csrlsvlu).
For CUDA Path: we call [`csrlsvqr` function of cuSolver](https://docs.nvidia.com/cuda/cusolver/index.html#cusolver-lt-t-gt-csrlsvqr).
**Profiling:** (On sparse input tensor of size 1000 x 1000, with a vector of shape length 1000), for `csrlsvlu` function (to show no GPU optimization)
```cpp
==3999651== Profiling result:
Type Time(%) Time Calls Avg Min Max Name
GPU activities: 100.00% 2.1440us 1 2.1440us 2.1440us 2.1440us [CUDA memcpy HtoD]
API calls: 99.72% 1.07199s 9 119.11ms 500ns 1.07164s cudaFree
0.11% 1.2182ms 398 3.0600us 140ns 137.94us cuDeviceGetAttribute
0.06% 674.45us 4 168.61us 165.50us 173.64us cuDeviceTotalMem
0.03% 357.07us 4 89.268us 2.7800us 201.89us cudaMalloc
0.03% 309.29us 1 309.29us 309.29us 309.29us cudaGetDeviceProperties
0.01% 160.47us 332 483ns 350ns 3.3300us cudaFuncSetAttribute
0.01% 115.12us 4 28.780us 26.290us 33.410us cuDeviceGetName
0.00% 28.591us 5 5.7180us 440ns 16.921us cudaGetDevice
0.00% 22.061us 4 5.5150us 871ns 18.690us cudaDeviceSynchronize
0.00% 20.370us 18 1.1310us 410ns 6.9900us cudaEventDestroy
0.00% 16.390us 1 16.390us 16.390us 16.390us cudaMemcpy
0.00% 11.540us 2 5.7700us 1.4900us 10.050us cuDeviceGetPCIBusId
0.00% 10.510us 18 583ns 430ns 1.6200us cudaEventCreateWithFlags
0.00% 7.9100us 21 376ns 290ns 700ns cudaDeviceGetAttribute
0.00% 1.4300us 6 238ns 150ns 590ns cuDeviceGet
0.00% 1.2200us 4 305ns 190ns 500ns cuDeviceGetCount
0.00% 900ns 1 900ns 900ns 900ns cuInit
0.00% 860ns 4 215ns 180ns 260ns cuDeviceGetUuid
0.00% 240ns 1 240ns 240ns 240ns cuDriverGetVersion
0.00% 230ns 1 230ns 230ns 230ns cudaGetDeviceCount
```
Script:
```python
import torch
def solve(x, other, out):
torch.linalg.solve(x, other, out=out)
if __name__ == "__main__":
dense_inp = torch.randn((1000, 1000), dtype=torch.float64)
# Set 50% of the values to 0 randomly
dense_inp = torch.nn.functional.dropout(dense_inp, p=0.5)
sparse_inp = dense_inp.to_sparse_csr()
other = torch.randint(100, (1000,), dtype=torch.float64)
out = torch.randint(1, (1000,), dtype=torch.float64)
solve(sparse_inp, other, out)
```
The following error is raised when the function is used on a CPU device with PyTorch built/installed without CUDA support:
* When built without CUDA support:
```python
/home/krshrimali/pytorch/torch/autograd/profiler.py:151: UserWarning: CUDA is not available, disabling CUDA profiling
warn("CUDA is not available, disabling CUDA profiling")
Traceback (most recent call last):
File "/home/krshrimali/pytorch/test_sp.py", line 17, in <module>
solve(x, other, out)
File "/home/krshrimali/pytorch/test_sp.py", line 5, in solve
torch.linalg.solve(x, other, out=out)
RuntimeError: PyTorch was not built with CUDA support. Please use PyTorch built CUDA support
```
**Performance Comparison** (vs SciPy's [`scipy.sparse.linalg.spsolve`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.spsolve.html):
Time taken by `scipy.sparse.linalg.spsolve` : 0.595 seconds
On CPU: Time taken by `torch.linalg.solve` : 4.565 seconds
On CUDA: Time taken by `torch.linalg.solve`: 1.838 seconds
The inputs are of dimensions: (17281, 17281) and (17281, 1), and were taken from https://math.nist.gov/MatrixMarket/extreme.html.
Thanks to IvanYashchuk for helping me with the PR, and guiding me through it.
cc: IvanYashchuk pearu nikitaved cpuhrsch
cc nikitaved pearu cpuhrsch
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71399
Reviewed By: VitalyFedyunin
Differential Revision: D33767740
Pulled By: cpuhrsch
fbshipit-source-id: a945f065210cd719096eb8d7cdbf8e8937c2fce9
(cherry picked from commit f4f35c17da414e1ca6c6d91402933521857aa1ea)
Summary:
Let's make the documentation for `torch.sparse.sampled_addmm` searchable in the PyTorch documentation.
This PR shall be cherry-picked for the next 1.11 release.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72312
Reviewed By: davidberard98
Differential Revision: D34045230
Pulled By: cpuhrsch
fbshipit-source-id: c1b1dc907443284857f48c8ce1efab22c6701bbe
(cherry picked from commit 225929ecf2)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27850
Many of these are real problems in the documentation (i.e., link or
bullet point doesn't display correctly).
Test Plan: - built and viewed the documentation for each change locally.
Differential Revision: D17908123
Pulled By: zou3519
fbshipit-source-id: 65c92a352c89b90fb6b508c388b0874233a3817a
Summary:
- to fix#12241
- add `_sparse_sum()` to ATen, and expose as `torch.sparse.sum()`, not support `SparseTensor.sum()` currently
- this PR depends on #11253, and will need to be updated upon it lands
- [x] implement forward
- [x] implement backward
- performance [benchmark script](https://gist.github.com/weiyangfb/f4c55c88b6092ef8f7e348f6b9ad8946#file-sparse_sum_benchmark-py):
- sum all dims is fastest for sparse tensor
- when input is sparse enough nnz = 0.1%, sum of sparse tensor is faster than dense in CPU, but not necessary in CUDA
- CUDA backward is comparable (<2x) between `sum several dims` vs `sum all dims` in sparse
- CPU backward uses binary search is still slow in sparse, takes `5x` time in `sum [0, 2, 3] dims` vs `sum all dims`
- optimize CUDA backward for now
- using thrust for sort and binary search, but runtime not improved
- both of CPU and CUDA forward are slow in sparse (`sum several dims` vs `sum all dims`), at most `20x` slower in CPU, and `10x` in CUDA
- improve CPU and CUDA forward kernels
(nnz, sizes, sum_dims, keepdim, sum all or dims, bk=backward) | CPU (sparse vs dense) | CUDA(sparse vs dense)
-- | -- | --
(1000, [1000, 1000, 2, 2], [0, 1], False, sumAll) | 8.77 µs vs 72.9 µs | 42.5 µs vs 108 µs
(1000, [1000, 1000, 2, 2], [0, 1], False, sumD) | 112 µs vs 4.47 ms | 484 µs vs 407 µs
(1000, [1000, 1000, 2, 2], [0, 1], False, sumAll, bk) | 141 µs vs 148 µs | 647 µs vs 231 µs
(1000, [1000, 1000, 2, 2], [0, 1], False, sumD, bk) | 235 µs vs 1.23 ms | 781 µs vs 213 µs
(1000, [1000, 1000, 2, 2], [2, 3], False, sumD) | 48.5 µs vs 360 µs | 160 µs vs 2.03 ms
(1000, [1000, 1000, 2, 2], [2, 3], False, sumD, bk) | 258 µs vs 1.22 ms | 798 µs vs 224 µs
(1000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD) | 204 µs vs 882 µs | 443 µs vs 133 µs
(1000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD, bk) | 709 µs vs 1.15 ms | 893 µs vs 202 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumAll) | 39.8 µs vs 81 µs | 42.4 µs vs 113 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumD) | 747 µs vs 4.7 ms | 2.4 ms vs 414 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumAll, bk) | 1.04 ms vs 126 µs | 5.03 ms vs 231 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumD, bk) | 1.12 ms vs 1.24 ms | 5.99 ms vs 213 µs
(10000, [1000, 1000, 2, 2], [2, 3], False, sumD) | 133 µs vs 366 µs | 463 µs vs 2.03 ms
(10000, [1000, 1000, 2, 2], [2, 3], False, sumD, bk) | 1.56 ms vs 1.22 ms | 6.11 ms vs 229 µs
(10000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD) | 1.53 ms vs 799 µs | 824 µs vs 134 µs
(10000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD, bk) | 5.15 ms vs 1.09 ms | 7.02 ms vs 205 µs
- after improving CPU and CUDA forward kernels
- in `(1000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD)` forward, CPU takes ~~`171 µs`~~, in which `130 µs` is spent on `coalesce()`, for CUDA, total time is ~~`331 µs`~~, in which `141 µs` is spent on `coalesce()`, we need to reduce time at other places outside `coalesce()`.
- after a few simple tweaks, now in the forward, it is at most `10x` slower in CPU, and `7x` in CUDA. And time takes in `sum dense dims only [2, 3]` is `~2x` of `sum all dims`. Speed of `sum all sparse dims [0, 1]` is on bar with `sum all dims`
(nnz, sizes, sum_dims, keepdim, sum all or dims, bk=backward) | CPU (sparse vs dense) | CUDA(sparse vs dense)
-- | -- | --
(1000, [1000, 1000, 2, 2], [0, 1], False, sumAll) | 7 µs vs 69.5 µs | 31.5 µs vs 61.6 µs
(1000, [1000, 1000, 2, 2], [0, 1], False, sumD) | 11.3 µs vs 4.72 ms | 35.2 µs vs 285 µs
(1000, [1000, 1000, 2, 2], [0, 1], False, sumAll, bk) | 197 µs vs 124 µs | 857 µs vs 134 µs
(1000, [1000, 1000, 2, 2], [0, 1], False, sumD, bk) | 124 µs vs 833 µs | 796 µs vs 106 µs
(1000, [1000, 1000, 2, 2], [2, 3], False, sumD) | 20.5 µs vs 213 µs | 39.4 µs vs 1.24 ms
(1000, [1000, 1000, 2, 2], [2, 3], False, sumD, bk) | 131 µs vs 830 µs | 881 µs vs 132 µs
(1000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD) | 95.8 µs vs 409 µs | 246 µs vs 87.2 µs
(1000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD, bk) | 624 µs vs 820 µs | 953 µs vs 124 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumAll) | 45.3 µs vs 72.9 µs | 33.9 µs vs 57.2 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumD) | 81.4 µs vs 4.49 ms | 39.7 µs vs 280 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumAll, bk) | 984 µs vs 111 µs | 6.41 ms vs 121 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumD, bk) | 1.45 ms vs 828 µs | 6.77 ms vs 113 µs
(10000, [1000, 1000, 2, 2], [2, 3], False, sumD) | 74.9 µs vs 209 µs | 37.7 µs vs 1.23 ms
(10000, [1000, 1000, 2, 2], [2, 3], False, sumD, bk) | 1.48 ms vs 845 µs | 6.96 ms vs 132 µs
(10000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD) | 1.14 ms vs 411 µs | 252 µs vs 87.8 µs
(10000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD, bk) | 4.53 ms vs 851 µs | 7.12 ms vs 128 µs
- time takes in CUDA backward of sparse is super long with large variance (in case of nnz=10000, it normally takes 6-7ms). To improve backward of sparse ops, we will need to debug at places other than CUDA kernels. here is a benchmark of `torch.copy_()`:
```
>>> d = [1000, 1000, 2, 2]
>>> nnz = 10000
>>> I = torch.cat([torch.randint(0, d[0], size=(nnz,)),
torch.randint(0, d[1], size=(nnz,))], 0).reshape(2, nnz)
>>> V = torch.randn(nnz, d[2], d[3])
>>> size = torch.Size(d)
>>> S = torch.sparse_coo_tensor(I, V, size).coalesce().cuda()
>>> S2 = torch.sparse_coo_tensor(I, V, size).coalesce().cuda().requires_grad_()
>>> data = S2.clone()
>>> S.copy_(S2)
>>> y = S * 2
>>> torch.cuda.synchronize()
>>> %timeit y.backward(data, retain_graph=True); torch.cuda.synchronize()
7.07 ms ± 3.06 ms per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12430
Differential Revision: D12878313
Pulled By: weiyangfb
fbshipit-source-id: e16dc7681ba41fdabf4838cf05e491ca9108c6fe
Summary:
Couple questions:
1) I used the log1p implementation in #8969 as a guide especially for testing. I'm not sure what the ```skipIfROCM``` annotation is for, so unsure if i need it for my test.
2) I implemented the branching logic in the narrow function itself; is this the right place to do so? I noticed that there a number of places where sparse-specific logic is handled with just an if statement in this file. Or should I implement a separate dispatch in native_functions.yml as in the log1p?
And of course, happy to make any any other updates/changes that I may have missed as well. This is my first PR to the project.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11342
Differential Revision: D9978430
Pulled By: weiyangfb
fbshipit-source-id: e73dc20302ab58925afb19e609e31f4a38c634ad
Summary:
This PR cleans up the `at::Tensor` class by removing all methods that start with an underscore in favor of functions in the `at::` namespace. This greatly cleans up the `Tensor` class and makes it clearer what is the public and non-public API.
For this I changed `native_functions.yaml` and `Declarations.cwrap` to make all underscore methods `variant: function` (or add such a statement to begin with), and then fixed all code locations using the underscore methods.
ezyang colesbury gchanan
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11152
Differential Revision: D9683607
Pulled By: goldsborough
fbshipit-source-id: 97f869f788fa56639c05a439e2a33be49f10f543
Basically, it's easy to confuse the dimensions of the index tensor.
This adds some more text which should hopefully clarify the situation.
Fixes#2416.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
As discussed in #1441.
I also added some docs giving clear guidance about how to coalescing
in sparse tensors.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>