## Description
Add a test case to cover the corner case of empty shards when creating ShardedTensor.
The original fix was contributed by a user in https://github.com/pytorch/pytorch/pull/108915.
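For reference, a sketch of what such a test might look like (the world size, tensor shape, and assertion are illustrative assumptions, not necessarily the exact test added here; it presumes the usual ShardedTensor distributed test harness):
```python
import torch
from torch.distributed._shard.api import _shard_tensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

def test_shard_tensor_with_empty_shard(self):
    # With 4 ranks but only 3 rows, torch.chunk along dim 0 produces just
    # 3 chunks, so the last rank ends up with an empty shard.
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    tensor = torch.rand(3, 12).cuda(self.rank)
    st = _shard_tensor(tensor, spec)
    # The global size is preserved even though one rank holds an empty shard.
    self.assertEqual(st.size(), torch.Size([3, 12]))
```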
## Test
With the fix, the newly added test passes.
Without the fix from https://github.com/pytorch/pytorch/pull/108915, the added test case throws the following assertion error:
```
(/home/irisz/local/a/pytorch-env) [irisz@devgpu051.cln3 ~/local/pytorch (add_test_for_corner_case_for_chunk_sharding_spec)]$ python3 test/distributed/_shard/sharded_tensor/test_sharded_tensor.py TestShardTensor.test_shard_tensor_with_empty_shard
Fail to import hypothesis in common_utils, tests are not derandomized
INFO:numba.cuda.cudadrv.driver:init
Fail to import hypothesis in common_utils, tests are not derandomized
Fail to import hypothesis in common_utils, tests are not derandomized
Fail to import hypothesis in common_utils, tests are not derandomized
Fail to import hypothesis in common_utils, tests are not derandomized
NCCL version 2.18.3+cuda12.0
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] Caught exception:
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] Traceback (most recent call last):
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 658, in run_test
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] getattr(self, test_name)()
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 544, in wrapper
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] fn()
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] File "/data/users/irisz/pytorch/torch/testing/_internal/common_utils.py", line 2406, in wrapper
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] method(*args, **kwargs)
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] File "/data/users/irisz/pytorch/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py", line 94, in wrapper
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] func(self, *args, **kwargs)
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 174, in wrapper
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] return func(*args, **kwargs)
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] File "/data/users/irisz/pytorch/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py", line 258, in test_shard_tensor_with_empty_shard
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] st = _shard_tensor(tensor, spec)
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] File "/data/users/irisz/pytorch/torch/distributed/_shard/api.py", line 68, in _shard_tensor
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=process_group)
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] File "/data/users/irisz/pytorch/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py", line 170, in shard
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] assert local_tensor is not None
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] AssertionError
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] exiting process 3 with exit code: 10
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] Caught exception:
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] Traceback (most recent call last):
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 658, in run_test
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] getattr(self, test_name)()
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 544, in wrapper
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] fn()
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] File "/data/users/irisz/pytorch/torch/testing/_internal/common_utils.py", line 2406, in wrapper
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] method(*args, **kwargs)
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] File "/data/users/irisz/pytorch/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py", line 94, in wrapper
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] func(self, *args, **kwargs)
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 174, in wrapper
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] return func(*args, **kwargs)
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] File "/data/users/irisz/pytorch/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py", line 258, in test_shard_tensor_with_empty_shard
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] st = _shard_tensor(tensor, spec)
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] File "/data/users/irisz/pytorch/torch/distributed/_shard/api.py", line 68, in _shard_tensor
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=process_group)
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] File "/data/users/irisz/pytorch/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py", line 179, in shard
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] dist.scatter(
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] File "/data/users/irisz/pytorch/torch/distributed/c10d_logger.py", line 68, in wrapper
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] return func(*args, **kwargs)
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] File "/data/users/irisz/pytorch/torch/distributed/distributed_c10d.py", line 3143, in scatter
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] _check_tensor_list(scatter_list, "scatter_list")
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] File "/data/users/irisz/pytorch/torch/distributed/distributed_c10d.py", line 808, in _check_tensor_list
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] raise TypeError(
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] TypeError: Invalid function argument. Expected parameter `scatter_list` to be of type List[torch.Tensor].
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] To execute this test, run the following from the base repo dir:
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] python test/distributed/_shard/sharded_tensor/test_sharded_tensor.py -k test_shard_tensor_with_empty_shard
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] exiting process 0 with exit code: 10
Process 3 terminated with exit code 10, terminating remaining processes.
E
======================================================================
ERROR: test_shard_tensor_with_empty_shard (__main__.TestShardTensor)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 542, in wrapper
self._join_processes(fn)
File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 761, in _join_processes
self._check_return_codes(elapsed_time)
File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 811, in _check_return_codes
raise RuntimeError(error)
RuntimeError: Process 3 exited with error code 10 and exception:
Traceback (most recent call last):
File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 658, in run_test
getattr(self, test_name)()
File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 544, in wrapper
fn()
File "/data/users/irisz/pytorch/torch/testing/_internal/common_utils.py", line 2406, in wrapper
method(*args, **kwargs)
File "/data/users/irisz/pytorch/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py", line 94, in wrapper
func(self, *args, **kwargs)
File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 174, in wrapper
return func(*args, **kwargs)
File "/data/users/irisz/pytorch/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py", line 258, in test_shard_tensor_with_empty_shard
st = _shard_tensor(tensor, spec)
File "/data/users/irisz/pytorch/torch/distributed/_shard/api.py", line 68, in _shard_tensor
st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=process_group)
File "/data/users/irisz/pytorch/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py", line 170, in shard
assert local_tensor is not None
AssertionError
----------------------------------------------------------------------
Ran 1 test in 21.207s
FAILED (errors=1)
```
cc. @fduwjj @wanchaol
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109626
Approved by: https://github.com/fduwjj
Preparation for the next PR in this stack: #89559.
I replaced
- `self.assertTrue(torch.equal(...))` with `self.assertEqual(..., rtol=0, atol=0, exact_device=True)`,
- `self.assertFalse(torch.equal(...))` with `self.assertNotEqual(...)` in the same way, and
- `assert torch.equal(...)` with `torch.testing.assert_close(..., rtol=0, atol=0)` (note that we don't need to set `check_device=True` here since that is the default).
There were a few instances where the result of `torch.equal` was used directly. In those cases I replaced it with `(... == ...).all().item()`, sometimes also dropping the `.item()` depending on the context.
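A minimal before/after sketch of these replacements (it assumes a test class deriving from torch.testing._internal.common_utils.TestCase, which provides the tensor-aware assertEqual/assertNotEqual used below):
```python
import torch
from torch.testing._internal.common_utils import TestCase, run_tests

class ExampleTest(TestCase):
    def test_replacements(self):
        a = torch.arange(4.0)
        b = torch.arange(4.0)

        # Old: self.assertTrue(torch.equal(a, b))
        self.assertEqual(a, b, rtol=0, atol=0, exact_device=True)

        # Old: self.assertFalse(torch.equal(a, b + 1))
        self.assertNotEqual(a, b + 1, rtol=0, atol=0)

        # Old: assert torch.equal(a, b)
        torch.testing.assert_close(a, b, rtol=0, atol=0)

        # Old: result of torch.equal used directly as a boolean.
        flag = (a == b).all().item()
        self.assertTrue(flag)

if __name__ == "__main__":
    run_tests()
```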
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89527
Approved by: https://github.com/mruberry
Change StorageReader and StorageWriter to follow the new SavePlanner / LoadPlanner design.
Add an optional planner param to load_state_dict and save_state_dict and implement the new protocol.
This includes a small rework of the FileSystem layer to support a single file per rank and to make fsync optional, matching torch.save behavior.
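A minimal sketch of the optional planner param (the module path, the DefaultSavePlanner/DefaultLoadPlanner names, and the no_dist single-process setup are assumptions for illustration):
```python
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.default_planner import (
    DefaultLoadPlanner,
    DefaultSavePlanner,
)

state_dict = {"weight": torch.rand(4, 4)}

# planner is optional; when omitted, the default planners are used.
dcp.save_state_dict(
    state_dict=state_dict,
    storage_writer=dcp.FileSystemWriter("/tmp/ckpt"),
    planner=DefaultSavePlanner(),
    no_dist=True,  # single-process mode so this sketch is self-contained
)

dcp.load_state_dict(
    state_dict=state_dict,
    storage_reader=dcp.FileSystemReader("/tmp/ckpt"),
    planner=DefaultLoadPlanner(),
    no_dist=True,
)
```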
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83781
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
Fix `use-dict-literal` pylint suggestions by changing `dict()` to `{}`. This PR makes the change in every Python file except test/jit/test_list_dict.py, where the intent appears to be to test the constructor.
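For illustration:
```python
# Before (flagged by pylint's use-dict-literal check):
kwargs = dict(stride=1, padding=0)

# After:
kwargs = {"stride": 1, "padding": 0}
```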
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83718
Approved by: https://github.com/albanD
The planners come with default implementations in default_planner.py.
The default planners expose their core functionality as separate functions
so that other checkpoint implementations can easily reuse it.
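A hypothetical sketch of that reuse (the planner class name here is made up; the module path and `create_default_local_save_plan` signature are assumptions based on the torch.distributed.checkpoint API):
```python
from torch.distributed.checkpoint.default_planner import (
    DefaultSavePlanner,
    create_default_local_save_plan,
)

class SkipOptimizerStateSavePlanner(DefaultSavePlanner):
    """Example custom planner: filter what gets saved, but reuse the default
    local planning logic instead of re-implementing it."""

    def create_local_plan(self):
        filtered = {
            key: value
            for key, value in self.state_dict.items()
            if not key.startswith("optim.")
        }
        # The default planning logic is exposed as a free function precisely so
        # that custom planners like this one can delegate to it.
        self.plan = create_default_local_save_plan(filtered, self.is_coordinator)
        return self.plan
```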
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83419
Approved by: https://github.com/wanchaol
These two ops (Embedding and EmbeddingBag for ShardedTensor) are very inefficient, especially for row-wise sharding, and hard to fit into the future design. So this PR tries to:
1. Remove all unnecessary collective communications; only one gather and one reduce (or reduce-scatter) is needed (see the sketch after this list).
2. Use autograd-enabled collectives so that we can use these ops in real model training.
3. Do some minor code cleanup.
4. Treat the input differently when it is a replicated tensor. (More on this in the next few PRs.)
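A rough sketch of the communication pattern described in point 1 for a row-wise sharded embedding (illustrative only: it uses plain c10d collectives and a single all_reduce, whereas the actual implementation uses autograd-enabled collectives and a reduce-scatter):
```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def rowwise_sharded_embedding(indices, local_weight, rank, world_size, rows_per_rank):
    """Each rank owns embedding rows [rank * rows_per_rank, (rank + 1) * rows_per_rank)."""
    # One gather: collect every rank's indices.
    gathered = [torch.empty_like(indices) for _ in range(world_size)]
    dist.all_gather(gathered, indices)
    all_indices = torch.cat(gathered)

    # Local lookup: map global row ids into the local shard and zero out rows
    # that this rank does not own.
    start = rank * rows_per_rank
    owned = (all_indices >= start) & (all_indices < start + rows_per_rank)
    local_ids = (all_indices - start).clamp(0, rows_per_rank - 1)
    partial = F.embedding(local_ids, local_weight) * owned.unsqueeze(-1).to(local_weight.dtype)

    # One reduce: sum the partial lookups across ranks. (A reduce-scatter would
    # instead leave each rank with only the rows for its own input chunk.)
    dist.all_reduce(partial)
    return partial
```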
Differential Revision: [D37965687](https://our.internmc.facebook.com/intern/diff/D37965687/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81853
Approved by: https://github.com/wanchaol
This PR implements the following changes:
- Move to a new checkpoint metadata format with a split between logical and storage data. This is a step toward supporting extensible checkpointing, as it moves us away from the hardcoded storage model enforced by the FileSystem storage layer.
- Change CheckpointException to include the exception traceback. Exception tracebacks are not serializable, so we need to take care of that ourselves; otherwise we present horribly bad errors to users (see the sketch after this list).
- Finally, remove `validate_state_dict`, as it has lost its usefulness. Loading is becoming flexible enough that the only reasonable way to verify whether a given configuration can be loaded is to actually try it.
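For the CheckpointException point, the general pattern is to capture the formatted traceback as a string before shipping a failure across ranks; a generic sketch (not the exact code in this PR):
```python
import traceback

class WrappedException(Exception):
    """Carries a picklable summary of a failure instead of the raw exception,
    whose traceback object cannot be serialized across ranks."""

    def __init__(self, original: BaseException):
        self.original_type = type(original).__name__
        self.formatted_traceback = "".join(
            traceback.format_exception(type(original), original, original.__traceback__)
        )
        super().__init__(f"{self.original_type}: {original}")

# Usage on the failing rank:
try:
    raise ValueError("bad shard metadata")
except ValueError as exc:
    payload = WrappedException(exc)  # safe to gather/broadcast via pickle
```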
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82078
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
MetadataIndex simplifies indexing into a state dict and into Metadata.
This includes a find_state_dict_object helper that searches a state dict.
This PR doesn't include search over Metadata, as that requires changes that will land
in a subsequent PR.
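A hypothetical usage sketch (module paths and the exact MetadataIndex fields are assumptions based on the torch.distributed.checkpoint API):
```python
import torch
from torch.distributed.checkpoint.metadata import MetadataIndex
from torch.distributed.checkpoint.utils import find_state_dict_object

state_dict = {"linear.weight": torch.rand(4, 4)}

# Index a plain tensor by its fully qualified name; for ShardedTensors the
# optional offset/index fields would select a particular shard.
idx = MetadataIndex(fqn="linear.weight")
obj = find_state_dict_object(state_dict, idx)
assert obj is state_dict["linear.weight"]
```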
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81909
Approved by: https://github.com/wanchaol
Introduce error handling across all ranks when loading and saving checkpoints.
This makes it a lot simpler for users to handle failures and, as a positive side effect, to coordinate on when the operation successfully finished.
This change requires 3 collectives when saving and 1 when loading.
All of these collectives carry a small payload, so they are latency bound and write time should dominate.
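As a generic sketch of the coordination pattern (simplified to two object collectives; not the PR's exact code or collective count): each rank reports its local outcome to the coordinator, which broadcasts a global verdict so all ranks either succeed or raise together.
```python
import torch.distributed as dist

def run_or_raise_on_all_ranks(do_local_work, coordinator_rank=0):
    """Run do_local_work() on every rank; if any rank fails, raise on all ranks."""
    rank = dist.get_rank()
    try:
        local = ("ok", do_local_work())
    except Exception as exc:  # capture instead of crashing only this rank
        local = ("error", repr(exc))

    # Coordinator gathers every rank's outcome (small, latency-bound payload).
    gathered = [None] * dist.get_world_size() if rank == coordinator_rank else None
    dist.gather_object(local, gathered, dst=coordinator_rank)

    # Coordinator decides, then broadcasts the global verdict back.
    verdict = [None]
    if rank == coordinator_rank:
        verdict[0] = {
            i: payload for i, (status, payload) in enumerate(gathered) if status == "error"
        }
    dist.broadcast_object_list(verdict, src=coordinator_rank)

    if verdict[0]:
        raise RuntimeError(f"Checkpoint operation failed on ranks: {verdict[0]}")
    return local[1]
```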
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77091
Approved by: https://github.com/pritamdamania87, https://github.com/wanchaol