pytorch/test/distributed/_shard/sharded_tensor
wz337 e73efbffab [Test][ShardedTensor] Add test for corner case for chunk sharding spec (#109626)
## Description
Add a test case covering the corner case where some ranks receive an empty shard when a tensor is sharded into a ShardedTensor with `ChunkShardingSpec`.
The original fix was contributed by a user in https://github.com/pytorch/pytorch/pull/108915.
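The sketch below shows how an empty shard can arise. It is a minimal pure-Python illustration that assumes chunk sharding splits the sharding dimension with ceil division, the same way `torch.chunk` sizes its chunks. The helpers `chunk_sizes` and `empty_shard_ranks` are hypothetical and not PyTorch APIs. The shapes are illustrative and not necessarily the ones used in the test. With 9 rows on 4 ranks, each chunk holds 3 rows, so the last rank is left with nothing.

```python
from typing import List


def chunk_sizes(dim_size: int, world_size: int) -> List[int]:
    """Per-rank shard lengths along the sharding dim.

    Assumption: ceil-division chunking, i.e. every rank gets at most
    ceil(dim_size / world_size) rows and trailing ranks may get fewer.
    """
    # Integer ceil division avoids floating-point rounding issues.
    split_size = (dim_size + world_size - 1) // world_size
    sizes = []
    for rank in range(world_size):
        # Clamp both ends to dim_size so ranks past the end get 0 rows.
        start = min(split_size * rank, dim_size)
        end = min(split_size * (rank + 1), dim_size)
        sizes.append(end - start)  # 0 means this rank's shard is empty
    return sizes


def empty_shard_ranks(dim_size: int, world_size: int) -> List[int]:
    """Ranks that would receive an empty shard under ceil-division chunking."""
    return [r for r, n in enumerate(chunk_sizes(dim_size, world_size)) if n == 0]


print(chunk_sizes(9, 4))        # [3, 3, 3, 0]
print(empty_shard_ranks(9, 4))  # [3]
print(chunk_sizes(2, 4))        # [1, 1, 0, 0]
```

Any sharding-dimension size that leaves one or more trailing ranks with zero rows produces the same situation, for example 2 rows on 4 ranks. When the size divides evenly, such as 8 rows on 4 ranks, no rank receives an empty shard.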

## Test
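With the fix, the new test passes.
Without the fix in https://github.com/pytorch/pytorch/pull/108915, the new test fails in two ways. Rank 3 hits an `AssertionError` (`assert local_tensor is not None`) in `ChunkShardingSpec.shard`. Rank 0 raises a `TypeError` from `dist.scatter` because `scatter_list` is not a list of tensors. The full output is: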
```
(/home/irisz/local/a/pytorch-env) [irisz@devgpu051.cln3 ~/local/pytorch (add_test_for_corner_case_for_chunk_sharding_spec)]$ python3 test/distributed/_shard/sharded_tensor/test_sharded_tensor.py TestShardTensor.test_shard_tensor_with_empty_shard
Fail to import hypothesis in common_utils, tests are not derandomized
INFO:numba.cuda.cudadrv.driver:init
Fail to import hypothesis in common_utils, tests are not derandomized
Fail to import hypothesis in common_utils, tests are not derandomized
Fail to import hypothesis in common_utils, tests are not derandomized
Fail to import hypothesis in common_utils, tests are not derandomized
NCCL version 2.18.3+cuda12.0
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] Caught exception:
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] Traceback (most recent call last):
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 658, in run_test
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]     getattr(self, test_name)()
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 544, in wrapper
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]     fn()
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/testing/_internal/common_utils.py", line 2406, in wrapper
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]     method(*args, **kwargs)
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py", line 94, in wrapper
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]     func(self, *args, **kwargs)
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 174, in wrapper
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]     return func(*args, **kwargs)
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py", line 258, in test_shard_tensor_with_empty_shard
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]     st = _shard_tensor(tensor, spec)
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/distributed/_shard/api.py", line 68, in _shard_tensor
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]     st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=process_group)
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py", line 170, in shard
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]     assert local_tensor is not None
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] AssertionError
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]  exiting process 3 with exit code: 10
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] Caught exception:
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] Traceback (most recent call last):
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 658, in run_test
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]     getattr(self, test_name)()
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 544, in wrapper
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]     fn()
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/testing/_internal/common_utils.py", line 2406, in wrapper
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]     method(*args, **kwargs)
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py", line 94, in wrapper
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]     func(self, *args, **kwargs)
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 174, in wrapper
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]     return func(*args, **kwargs)
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py", line 258, in test_shard_tensor_with_empty_shard
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]     st = _shard_tensor(tensor, spec)
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/distributed/_shard/api.py", line 68, in _shard_tensor
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]     st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=process_group)
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py", line 179, in shard
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]     dist.scatter(
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/distributed/c10d_logger.py", line 68, in wrapper
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]     return func(*args, **kwargs)
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/distributed/distributed_c10d.py", line 3143, in scatter
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]     _check_tensor_list(scatter_list, "scatter_list")
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/distributed/distributed_c10d.py", line 808, in _check_tensor_list
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]     raise TypeError(
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] TypeError: Invalid function argument. Expected parameter `scatter_list` to be of type List[torch.Tensor].
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] To execute this test, run the following from the base repo dir:
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]      python test/distributed/_shard/sharded_tensor/test_sharded_tensor.py -k test_shard_tensor_with_empty_shard
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]  exiting process 0 with exit code: 10
Process 3 terminated with exit code 10, terminating remaining processes.
E
======================================================================
ERROR: test_shard_tensor_with_empty_shard (__main__.TestShardTensor)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 542, in wrapper
    self._join_processes(fn)
  File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 761, in _join_processes
    self._check_return_codes(elapsed_time)
  File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 811, in _check_return_codes
    raise RuntimeError(error)
RuntimeError: Process 3 exited with error code 10 and exception:
Traceback (most recent call last):
  File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 658, in run_test
    getattr(self, test_name)()
  File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 544, in wrapper
    fn()
  File "/data/users/irisz/pytorch/torch/testing/_internal/common_utils.py", line 2406, in wrapper
    method(*args, **kwargs)
  File "/data/users/irisz/pytorch/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py", line 94, in wrapper
    func(self, *args, **kwargs)
  File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 174, in wrapper
    return func(*args, **kwargs)
  File "/data/users/irisz/pytorch/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py", line 258, in test_shard_tensor_with_empty_shard
    st = _shard_tensor(tensor, spec)
  File "/data/users/irisz/pytorch/torch/distributed/_shard/api.py", line 68, in _shard_tensor
    st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=process_group)
  File "/data/users/irisz/pytorch/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py", line 170, in shard
    assert local_tensor is not None
AssertionError
----------------------------------------------------------------------
Ran 1 test in 21.207s

FAILED (errors=1)
```

cc @fduwjj @wanchaol
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109626
Approved by: https://github.com/fduwjj
2023-09-20 14:40:07 +00:00
ops
test_logger.py
test_sharded_tensor_reshard.py
test_sharded_tensor.py