Argument "device" was missed.
So, "test_sparse.py::TestSparseAnyCUDA::test_as_sparse_gradcheck_*_cuda" was always run on the default device ("cpu") if another default torch device was not configured before.
This fix will probably detect a number of issues on various devices which were previously missed.
Should fix failed rocm CI jobs with "##[error]The action has timed out." and speedup test execution
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111584
Approved by: https://github.com/soulitzer
Compared to #104848, this PR makes a step further: when the enable_sparse_support decorator is applied to `torch.autograd.gradcheck`, the resulting callable is equivalent to `torch.autograd.gradcheck` with an extra feature of supporting functions that can have input sparse tensors or/and can return sparse tensors.
At the same time, the underlying call to `torch.autograd.gradcheck` will operate on strided tensors only. This basically means that torch/autograd/gradcheck.py can be cleaned up by removing the code that deals with sparse tensors.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107150
Approved by: https://github.com/albanD, https://github.com/amjames, https://github.com/cpuhrsch
ghstack dependencies: #107638, #107777
Resolves https://github.com/pytorch/pytorch/issues/107097
After this PR, instead of
```python
torch.sparse_coo_tensor(indices, values, size)._coalesced_(is_coalesced)
```
(that does not work in the autograd context, see #107097), use
```python
torch.sparse_coo_tensor(indices, values, size, is_coalesced=is_coalesced)
```
All sparse coo factory functions that take indices as input support the `is_coalesced` argument.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107638
Approved by: https://github.com/cpuhrsch
Update distutils.Version to packaging.version due to the deprecation warning.
```python
/root/Git.d/pytorch/pytorch/torch/testing/_internal/common_methods_invocations.py:17136: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
active_if=TEST_SCIPY and LooseVersion(scipy.__version__) < "1.4.0"),
/root/Git.d/pytorch/pytorch/torch/testing/_internal/common_methods_invocations.py:17138: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
active_if=TEST_SCIPY and LooseVersion(scipy.__version__) < "1.4.0"),
/root/Git.d/pytorch/pytorch/torch/testing/_internal/common_methods_invocations.py:17140: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
active_if=TEST_SCIPY and LooseVersion(scipy.__version__) < "1.4.0"),
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107207
Approved by: https://github.com/soulitzer
Issue: #93684
# Problem
Reduce graph breaks when dynamo compiles python functions containing numpy functions and ndarray operations.
# Design (as I know it)
* Use torch_np.ndarray(a wrapper of tensor) to back a `VariableTracker`: `NumpyTensorVariable`.
* Translate all attributes and methods calls, on ndarray, to torch_np.ndarray equivalent.
This PR adds `NumpyTensorVariable` and supports:
1. tensor to ndarray, ndarray to tensor
2. numpy functions such as numpy.meshgrid()
3. ndarray attributes such as `itemsize`, `stride`
Next PR will handle returning `np.ndarray` and add support for ndarray methods
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95849
Approved by: https://github.com/ezyang
As in the title.
The `masked_grad` kw argument is required for `to_dense` backward to distinguish the expected semantics of sparse tensors. `masked_grad=True` means that the `to_dense` backward will apply a mask to the returned gradient where the mask is defined by the input indices. The default semantics implies `masked_grad==True` for BC but see the [comment](https://github.com/pytorch/pytorch/pull/96095/files#diff-d4df180433a09071e891d552426911c227b30ae9b8a8e56da31046e7ecb1afbeR501-R513) in `to_dense_backward`.
As a consequence, existing code that is run through autograd engine must replace `.to_dense()` calls with `.to_dense(masked_grad=False)`. For example,
```python
torch.autograd.gradcheck(lambda x: torch.sum(x, [0]).to_dense())
torch.autograd.gradcheck(lambda x: torch.sparse.sum(x, [0]).to_dense())
```
(recall, gradcheck has `masked=False` as default) must be updated to
```python
torch.autograd.gradcheck(lambda x: torch.sum(x, [0]).to_dense(masked_grad=False))
torch.autograd.gradcheck(lambda x: torch.sparse.sum(x, [0]).to_dense(masked_grad=True), masked=True)
```
Fixes https://github.com/pytorch/pytorch/issues/95550
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96095
Approved by: https://github.com/cpuhrsch
As in the title. The bug was reported in https://github.com/pytorch/pytorch/pull/94728#discussion_r1108892366 and has the following reproducer:
```python
>>> import torch
>>> check_ctx = torch.sparse.check_sparse_tensor_invariants(True)
>>> no_check_ctx = torch.sparse.check_sparse_tensor_invariants(False)
>>> with check_ctx:
... assert torch.sparse.check_sparse_tensor_invariants.is_enabled()
... with no_check_ctx:
... assert not torch.sparse.check_sparse_tensor_invariants.is_enabled()
... assert torch.sparse.check_sparse_tensor_invariants.is_enabled()
...
Traceback (most recent call last):
File "<stdin>", line 5, in <module>
AssertionError
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95372
Approved by: https://github.com/cpuhrsch
I applied some flake8 fixes and enabled checking for them in the linter. I also enabled some checks for my previous comprehensions PR.
This is a follow up to #94323 where I enable the flake8 checkers for the fixes I made and fix a few more of them.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94601
Approved by: https://github.com/ezyang