Commit Graph

77 Commits

wz337
e73efbffab [Test][ShardedTensor] Add test for corner case for chunk sharding spec (#109626)
## Description
Add a test case to cover the corner case of empty shards when creating ShardedTensor.
Original fix contributed by a user.
https://github.com/pytorch/pytorch/pull/108915

## Test
With the fix, the added test passes.
Without the fix from https://github.com/pytorch/pytorch/pull/108915, the added test throws the following assertion error:
```
(/home/irisz/local/a/pytorch-env) [irisz@devgpu051.cln3 ~/local/pytorch (add_test_for_corner_case_for_chunk_sharding_spec)]$ python3 test/distributed/_shard/sharded_tensor/test_sharded_tensor.py TestShardTensor.test_shard_tensor_with_empty_shard
Fail to import hypothesis in common_utils, tests are not derandomized
INFO:numba.cuda.cudadrv.driver:init
Fail to import hypothesis in common_utils, tests are not derandomized
Fail to import hypothesis in common_utils, tests are not derandomized
Fail to import hypothesis in common_utils, tests are not derandomized
Fail to import hypothesis in common_utils, tests are not derandomized
NCCL version 2.18.3+cuda12.0
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] Caught exception:
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] Traceback (most recent call last):
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 658, in run_test
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]     getattr(self, test_name)()
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 544, in wrapper
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]     fn()
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/testing/_internal/common_utils.py", line 2406, in wrapper
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]     method(*args, **kwargs)
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py", line 94, in wrapper
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]     func(self, *args, **kwargs)
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 174, in wrapper
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]     return func(*args, **kwargs)
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py", line 258, in test_shard_tensor_with_empty_shard
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]     st = _shard_tensor(tensor, spec)
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/distributed/_shard/api.py", line 68, in _shard_tensor
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]     st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=process_group)
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py", line 170, in shard
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]     assert local_tensor is not None
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR] AssertionError
[rank3]:[2023-09-19 11:19:27,071] torch.testing._internal.common_distributed: [ERROR]  exiting process 3 with exit code: 10
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] Caught exception:
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] Traceback (most recent call last):
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 658, in run_test
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]     getattr(self, test_name)()
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 544, in wrapper
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]     fn()
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/testing/_internal/common_utils.py", line 2406, in wrapper
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]     method(*args, **kwargs)
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py", line 94, in wrapper
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]     func(self, *args, **kwargs)
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 174, in wrapper
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]     return func(*args, **kwargs)
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py", line 258, in test_shard_tensor_with_empty_shard
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]     st = _shard_tensor(tensor, spec)
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/distributed/_shard/api.py", line 68, in _shard_tensor
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]     st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=process_group)
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py", line 179, in shard
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]     dist.scatter(
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/distributed/c10d_logger.py", line 68, in wrapper
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]     return func(*args, **kwargs)
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/distributed/distributed_c10d.py", line 3143, in scatter
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]     _check_tensor_list(scatter_list, "scatter_list")
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]   File "/data/users/irisz/pytorch/torch/distributed/distributed_c10d.py", line 808, in _check_tensor_list
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]     raise TypeError(
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] TypeError: Invalid function argument. Expected parameter `scatter_list` to be of type List[torch.Tensor].
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] To execute this test, run the following from the base repo dir:
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]      python test/distributed/_shard/sharded_tensor/test_sharded_tensor.py -k test_shard_tensor_with_empty_shard
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR] This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
[rank0]:[2023-09-19 11:19:27,123] torch.testing._internal.common_distributed: [ERROR]  exiting process 0 with exit code: 10
Process 3 terminated with exit code 10, terminating remaining processes.
E
======================================================================
ERROR: test_shard_tensor_with_empty_shard (__main__.TestShardTensor)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 542, in wrapper
    self._join_processes(fn)
  File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 761, in _join_processes
    self._check_return_codes(elapsed_time)
  File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 811, in _check_return_codes
    raise RuntimeError(error)
RuntimeError: Process 3 exited with error code 10 and exception:
Traceback (most recent call last):
  File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 658, in run_test
    getattr(self, test_name)()
  File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 544, in wrapper
    fn()
  File "/data/users/irisz/pytorch/torch/testing/_internal/common_utils.py", line 2406, in wrapper
    method(*args, **kwargs)
  File "/data/users/irisz/pytorch/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py", line 94, in wrapper
    func(self, *args, **kwargs)
  File "/data/users/irisz/pytorch/torch/testing/_internal/common_distributed.py", line 174, in wrapper
    return func(*args, **kwargs)
  File "/data/users/irisz/pytorch/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py", line 258, in test_shard_tensor_with_empty_shard
    st = _shard_tensor(tensor, spec)
  File "/data/users/irisz/pytorch/torch/distributed/_shard/api.py", line 68, in _shard_tensor
    st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=process_group)
  File "/data/users/irisz/pytorch/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py", line 170, in shard
    assert local_tensor is not None
AssertionError
----------------------------------------------------------------------
Ran 1 test in 21.207s

FAILED (errors=1)
```
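The corner case above comes from chunk-based sharding producing fewer non-empty chunks than ranks. A minimal pure-Python sketch of the chunk-size math (mirroring `torch.chunk` semantics for illustration; this is not the actual `ChunkShardingSpec` code):

```python
import math

def chunk_sizes(length, num_chunks):
    # torch.chunk-style split: each chunk holds ceil(length / num_chunks)
    # elements, so trailing ranks can end up with an empty shard.
    full = math.ceil(length / num_chunks)
    sizes, remaining = [], length
    for _ in range(num_chunks):
        sizes.append(min(full, max(remaining, 0)))
        remaining -= full
    return sizes

# A length-5 dim sharded across 4 ranks: rank 3 gets an empty shard,
# which is what used to trip `assert local_tensor is not None` above.
print(chunk_sizes(5, 4))  # [2, 2, 1, 0]
```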

cc. @fduwjj @wanchaol
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109626
Approved by: https://github.com/fduwjj
2023-09-20 14:40:07 +00:00
Ivan Kobzarev
4582ceb2c4 [distributed][sharded_tensor] Move local_shards check from ShardedTensorBase to ShardedTensor (#100197)
Differential Revision: [D45369211](https://our.internmc.facebook.com/intern/diff/D45369211)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100197
Approved by: https://github.com/fduwjj
2023-05-02 12:42:24 +00:00
Iris
ca8625f456 [BE][1/N]Add sharding spec logger for ShardedTensor (#99748)
Set up a nullHandler() on the OSS side.
Next step is to set up the counterpart in internal.

This is part of the effort to deprecate ShardedTensor. We want to log internal use cases for the different sharding specs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99748
Approved by: https://github.com/H-Huang, https://github.com/fegin
2023-04-22 04:05:21 +00:00
fduwjj
5cc2e4d7c9 [10/N] Remove ST init ops (#96985)
Differential Revision: [D44158326](https://our.internmc.facebook.com/intern/diff/D44158326)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96985
Approved by: https://github.com/wz337, https://github.com/wanchaol
2023-03-22 20:26:18 +00:00
fduwjj
546835c45a [9/N] Remove ST multiple ops (#96989)
Differential Revision: [D44158327](https://our.internmc.facebook.com/intern/diff/D44158327)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96989
Approved by: https://github.com/wz337, https://github.com/wanchaol
2023-03-22 20:02:58 +00:00
fduwjj
7863efbd76 [BE][8/N] Remove ShardedTensor from TP FSDP integration test and other tests depending on Sharded Linear (#96254)
We removed ShardedLinear in https://github.com/pytorch/pytorch/pull/95948, but that broke the TP_FSDP integration test because the test used ShardedTensor. Migrating the test to DTensor fixes it. DTensor also shards the bias, so the test needs a small change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96254
Approved by: https://github.com/huydhn
2023-03-08 21:56:41 +00:00
fduwjj
28aa2efd14 [7/N][BE] Remove Partial Tensor and its dependency (#95949)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95949
Approved by: https://github.com/wanchaol
2023-03-06 19:57:46 +00:00
fduwjj
6dddc0d689 [6/N][BE] Remove Sharded Linear Op for ShardedTensor (#95948)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95948
Approved by: https://github.com/wanchaol
2023-03-06 19:57:19 +00:00
Sergii Dymchenko
35bf5bac26 Fix "sandcastle_skip_if decorator name is confusing" (#95649)
Fixes https://github.com/pytorch/pytorch/issues/89473

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95649
Approved by: https://github.com/atalman, https://github.com/malfet
2023-03-03 09:29:40 +00:00
fduwjj
fa7f17799a [3/N][BE][ST Deprecate] Remove Replicated Tensor (#95453)
Please use distributed tensor instead. We are deprecating ShardedTensor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95453
Approved by: https://github.com/wanchaol
2023-02-26 06:18:31 +00:00
fduwjj
a88bfc60c7 [2/N][ST deprecate][BE] Remove Replicate Tensor convert from DDP and PTD (#95450)
No remaining use was found for this ST/ReplicatedTensor-based DDP. As part of the ShardedTensor migration, remove this logic, undoing everything in https://github.com/pytorch/pytorch/pull/75753.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95450
Approved by: https://github.com/wanchaol
2023-02-26 03:03:37 +00:00
Xuehai Pan
046e88a291 [BE] [3/3] Rewrite super() calls in test (#94592)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```

Some cases that change the semantics should be kept unchanged. E.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94592
Approved by: https://github.com/ezyang, https://github.com/seemethere
2023-02-12 22:20:53 +00:00
Howard Huang
bc764f453d Fix sharded_tensor test_sharded_tensor_to_cpu (#91453)
Fixes https://github.com/pytorch/pytorch/issues/91381

The assert in the test needs to be updated. Run `ciflow/periodic` to exercise the multigpu tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91453
Approved by: https://github.com/clee2000
2022-12-29 13:21:30 +00:00
PyTorch MergeBot
cba96366a2 Revert "remove torch.equal usages (#89527)"
This reverts commit 4095ef8b80.

Reverted https://github.com/pytorch/pytorch/pull/89527 on behalf of https://github.com/clee2000 due to broke periodic multigpu tests 4095ef8b80 https://github.com/pytorch/pytorch/actions/runs/3592806602/jobs/6049368502
2022-12-02 21:36:13 +00:00
Philip Meier
4095ef8b80 remove torch.equal usages (#89527)
Preparation for the next PR in this stack: #89559.

I replaced

- `self.assertTrue(torch.equal(...))` with `self.assertEqual(..., rtol=0, atol=0, exact_device=True)`,
- the same for `self.assertFalse(...)` with `self.assertNotEqual(...)`, and
- `assert torch.equal(...)` with `torch.testing.assert_close(..., rtol=0, atol=0)` (note that we don't need to set `check_device=True` here since that is the default).

There were a few instances where the result of `torch.equal` was used directly. In those cases I replaced it with `(... == ...).all().item()`, sometimes dropping the `.item()` depending on the context.
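A small illustration of the substitutions described above (assuming a standard torch install; the tensors are made-up values):

```python
import torch

a = torch.tensor([1.0, 2.0])
b = torch.tensor([1.0, 2.0])

# Before: assert torch.equal(a, b)
# After: exact comparison via assert_close with zero tolerances
# (device checking is already on by default).
torch.testing.assert_close(a, b, rtol=0, atol=0)

# Where the boolean result of torch.equal was used directly:
same = (a == b).all().item()
print(same)  # True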

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89527
Approved by: https://github.com/mruberry
2022-12-01 11:22:52 +00:00
Iris
aee96bbf5a [PT-D][Checkpointing] Move distributed checkpointing from torch.distributed._shard.checkpoint to torch.distributed.checkpoint (#88698)
Context in RFC: https://github.com/pytorch/pytorch/issues/86620

.rst file will be finalized in subsequent PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88698
Approved by: https://github.com/wanchaol
2022-11-16 21:06:38 +00:00
Chien-Chin Huang
bd1e95ce30 Improve the performance of validate_non_overlapping_shards_metadata (#85639)
`validate_non_overlapping_shards_metadata()` uses a quadratic algorithm to verify overlap. However, in some cases (when only one dimension is sharded), an O(n log n) algorithm can easily be implemented. This PR changes the implementation of `validate_non_overlapping_shards_metadata()` accordingly.
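The O(n log n) idea is sort-then-compare-adjacent: after sorting by offset along the single sharded dimension, only neighbouring shards can overlap. A pure-Python sketch (PyTorch's real code works on `ShardMetadata` objects, not `(offset, length)` pairs):

```python
def non_overlapping(shards):
    # `shards` is a list of (offset, length) pairs along the one
    # sharded dimension. Sort by offset; then a shard can only collide
    # with its immediate successor.
    shards = sorted(shards)
    for (off_a, len_a), (off_b, _len_b) in zip(shards, shards[1:]):
        if off_a + len_a > off_b:  # previous shard runs into the next
            return False
    return True

print(non_overlapping([(0, 4), (4, 4), (8, 4)]))  # True
print(non_overlapping([(0, 5), (4, 4)]))          # False
```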

Differential Revision: [D39681725](https://our.internmc.facebook.com/intern/diff/D39681725/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85639
Approved by: https://github.com/wanchaol
2022-10-20 23:51:48 +00:00
Rodrigo Kumpera
f66be71d77 [checkpoint] Adopt Planner interface across the board. (#83781)
Change StorageReader and StorageWriter to follow the new SavePlanner / LoadPlanner design.

Add optional planner param to load_state_dict and save_state_dict and implement the new protocol.

This includes a small rework of the FileSystem layer to support a single file per rank, and makes fsync optional to match torch.save behavior.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83781
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2022-08-29 14:38:32 +00:00
Sergii Dymchenko
591222f5d9 Fix use-dict-literal lint (#83718)
Fix use-dict-literal pylint suggestions by changing `dict()` to `{}`. This PR should do the change for every Python file except test/jit/test_list_dict.py, where I think the intent is to test the constructor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83718
Approved by: https://github.com/albanD
2022-08-24 00:26:46 +00:00
Rodrigo Kumpera
d11d3dd036 [dist.cp] Introduce LoadPlanner and SavePlanner extensibility API. (#83419)
The planners come with default implementations in default_planner.py.

The default planners expose their core functionality as separate functions
to make it easy for other checkpoint implementations to use this functionality.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83419
Approved by: https://github.com/wanchaol
2022-08-18 19:40:15 +00:00
fduwjj
d3a176a156 [PT-D][BE][TP perf 1/N] Get rid of unnecessary collectives in Embedding/EmbeddingBag and use autograd-enabled collectives (#81853)
These two ops (Embedding and EmbeddingBag for ShardedTensor), especially for row-wise sharding, are very inefficient and hard to fit into the future design. So this PR aims to:
1. Remove all unnecessary collective communications: only one gather and one reduce (or reduce-scatter) is needed.
2. Use autograd-enabled collectives so that we can use these ops in real model training.
3. Do some minor code cleanup.
4. Treat the input differently when it is a replicated tensor. (More on this in the next few PRs.)

Differential Revision: [D37965687](https://our.internmc.facebook.com/intern/diff/D37965687/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81853
Approved by: https://github.com/wanchaol
2022-08-17 04:32:41 +00:00
Wanchao Liang
cda8635a5e [_shard] only check shard metadata for copy_ (#82655)
`copy_` does not restrict tensor properties; it does not check things like `requires_grad` or `dtype`, so we only check that the shard metadata are the same.

Differential Revision: [D38359176](https://our.internmc.facebook.com/intern/diff/D38359176/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82655
Approved by: https://github.com/fduwjj
2022-08-04 06:12:19 +00:00
Rodrigo Kumpera
f4ee37453c [dist.checkpoint] Change metadata format and improve error reporting (#82078)
This PR implements the following changes.

Move to new checkpoint metadata format with split between logical and storage data.
This is a step in the direction of supporting extensible checkpointing as it moves us away from the hardcoded storage model enforced by the FileSystem storage layer.

Change CheckpointException to include the exception traceback. Exception tracebacks are not serializable, so we need to take care of that; otherwise we surface very poor errors to users.
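One standard way to ship a traceback across processes is to format it to a string first, since traceback objects don't pickle. A sketch of the idea (the helper name is hypothetical, not the actual CheckpointException code):

```python
import traceback

def format_error(exc):
    # Traceback objects are not picklable, so serialize the formatted
    # text and attach that to the error reported back to users.
    return "".join(
        traceback.format_exception(type(exc), exc, exc.__traceback__)
    )

try:
    1 / 0
except ZeroDivisionError as e:
    msg = format_error(e)

print("ZeroDivisionError" in msg)  # True
```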

Finally, remove `validate_state_dict`, as it has lost its usefulness. Loading is becoming flexible enough that the only reasonable way to verify whether a given configuration can be loaded is to actually try it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82078
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2022-08-03 17:00:12 +00:00
Wanchao Liang
48a34acf13 [_shard] add copy_ to shardedtensor (#82508)
as titled

Differential Revision: [D38290442](https://our.internmc.facebook.com/intern/diff/D38290442)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82508
Approved by: https://github.com/fduwjj
2022-08-01 23:52:19 +00:00
Rodrigo Kumpera
69eecdbc9c Introduce MetadataIndex and helper to use it. (#81909)
MetadataIndex simplifies indexing into state dict and Metadata.

This includes a `find_state_dict_object` helper that searches into a state dict.

This PR doesn't include search over Metadata, as that requires changes that will land
in a subsequent PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81909
Approved by: https://github.com/wanchaol
2022-07-28 12:20:58 +00:00
Wanchao Liang
28c43190b8 [_shard] Add ShardedTensorBase (#82291)
This PR adds ShardedTensorBase, the base class of ShardedTensor. It contains
only local shards and ShardedTensorMetadata, and has no communication backend
(i.e. ProcessGroup) attached.

Differential Revision: [D38190272](https://our.internmc.facebook.com/intern/diff/D38190272)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82291
Approved by: https://github.com/fduwjj
2022-07-28 00:17:52 +00:00
Wanchao Liang
7ff121e75a [reland] make ShardedTensor be a Tensor and nn.Parameter (#82089)
This is the reland of https://github.com/pytorch/pytorch/pull/79825,
which was reverted due to multi-GPU CI failures. This PR fixes those
failures and relands it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82089
Approved by: https://github.com/fduwjj
2022-07-25 19:06:01 +00:00
PyTorch MergeBot
f51cf774c6 Revert "[_shard] make ShardedTensor be a Tensor and nn.Parameter (#79825)"
This reverts commit 9c32439a77.

Reverted https://github.com/pytorch/pytorch/pull/79825 on behalf of https://github.com/janeyx99 due to Sorry, reverting for breaking multigpu tests 9c32439a77
2022-07-22 20:39:44 +00:00
Wanchao Liang
9c32439a77 [_shard] make ShardedTensor be a Tensor and nn.Parameter (#79825)
Differential Revision: [D37707371](https://our.internmc.facebook.com/intern/diff/D37707371)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79825
Approved by: https://github.com/kumpera
2022-07-22 16:50:12 +00:00
Wen Zhang
b8f9751f11 Add cpu/gloo tests for sharded tensor distributed checkpoint (#80997)
ShardedTensor checkpointing does not depend on NCCL; replicate the GPU cases for CPU and enable development on Mac.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80997
Approved by: https://github.com/wanchaol
2022-07-20 23:02:52 +00:00
Wanchao Liang
bef2fecbbc [shard] make state_dict hook be consistent
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79650

For the root module we shouldn't accidentally add a "." prefix to state_dict keys; the
prefix should be empty instead, to match `module.state_dict()` behavior.
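The key construction before and after, sketched with hypothetical helper names (the real hook builds keys from a `prefix` argument):

```python
def make_key_buggy(module_path, name):
    # Old behavior: always join with ".", which yields ".weight" at the root.
    return module_path + "." + name

def make_key_fixed(module_path, name):
    # Fixed: the root module contributes an empty prefix, matching
    # nn.Module.state_dict().
    return name if not module_path else module_path + "." + name

print(make_key_buggy("", "weight"))        # .weight  (wrong)
print(make_key_fixed("", "weight"))        # weight
print(make_key_fixed("linear", "weight"))  # linear.weight
```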

Differential Revision: [D37191203](https://our.internmc.facebook.com/intern/diff/D37191203/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D37191203/)!

Approved by: https://github.com/pritamdamania87, https://github.com/fduwjj
2022-06-17 22:08:06 +00:00
Rodrigo Kumpera
270c518be0 [checkpoint] Implement interop between Tensor and Sharded Tensor (#78120)
This allows loading a Tensor from a checkpoint that has a ShardedTensor at the same FQN.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78120
Approved by: https://github.com/pritamdamania87
2022-06-16 15:31:09 +00:00
fduwjj
f4edbaa62f [PT-D] Use process group of the partial tensor so sub pg comm will be enabled during reshard
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79357

Approved by: https://github.com/wanchaol
2022-06-14 17:44:51 +00:00
pritam
a81be44410 Fix shard_module to appropriately deal with sub process groups.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79264

The `shard_module` API didn't work correctly with a sub-pg, since
`dist.scatter` actually takes the global rank as input for `src`.

Fix this by passing the appropriate global rank to `dist.scatter`.
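The group-rank to global-rank mapping that `dist.scatter`'s `src` expects, sketched in plain Python (recent torch exposes this as `dist.get_global_rank`; the helper name here is illustrative):

```python
def to_global_rank(group_global_ranks, group_rank):
    # A sub process group is defined by an ordered list of global ranks;
    # rank i inside the group corresponds to group_global_ranks[i].
    return group_global_ranks[group_rank]

# Sub-pg over global ranks 4..7: src rank 1 within the group must be
# passed to dist.scatter as global rank 5.
print(to_global_rank([4, 5, 6, 7], 1))  # 5
```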

Differential Revision: [D37062766](https://our.internmc.facebook.com/intern/diff/D37062766/)

Approved by: https://github.com/fduwjj, https://github.com/wanchaol
2022-06-12 03:50:45 +00:00
pritam
b9e3d722c4 Use appropriate dtype for sharded linear implementation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79255

We use several collective operations in our sharded linear
implementation, and for many of them we do not set the `dtype` of the
output tensor appropriately. As a result, using a dtype like torch.float16
(rather than the default torch.float32) results in errors.

Fixing this across the board and adding appropriate tests.

Differential Revision: [D37059752](https://our.internmc.facebook.com/intern/diff/D37059752/)

Approved by: https://github.com/fduwjj, https://github.com/wanchaol
2022-06-10 07:32:15 +00:00
Wanchao Liang
2fce7483a5 Back out "[shard] make ShardedTensor a torch.Tensor subclass" (#78796)
Summary:
Original commit changeset: f3ce270bad56

Original Phabricator Diff: D36569064

Test Plan: wait for sandcastle and doing additional checks

Reviewed By: guangyuwang

Differential Revision: D36890625

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78796
Approved by: https://github.com/pbelevich
2022-06-03 19:39:26 +00:00
pritam
5aa2ed1922 Remove call to .contiguous() for local_shard_t.
The call to contiguous was probably left over from a previous
implementation and is no longer needed.

Had to adjust atol for one of the tests to accommodate this.

Differential Revision: [D36797942](https://our.internmc.facebook.com/intern/diff/D36797942/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78598

Approved by: https://github.com/kumpera
2022-06-01 18:50:10 +00:00
pritam
44aa4ad894 Use _all_gather_base and fuse matmul for sharded linear.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78477

Use `_all_gather_base` instead of `all_gather` for col-wise sharding,
since `_all_gather_base` returns a single fused tensor that can be used to
perform one matmul instead of looping through and performing multiple
matmuls.

This improves performance for col-wise sharding.
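The fused-buffer idea, sketched with nested lists rather than real tensors (illustrative only; the PR itself uses `_all_gather_base` and torch matmuls):

```python
def matmul(a, b):
    # Tiny dense matmul on nested lists, for illustration only.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

shards = [[[1.0, 2.0]], [[3.0, 4.0]]]   # one gathered row per rank
weight = [[1.0, 0.0], [0.0, 1.0]]

# Old path: one matmul per gathered shard, results stitched together.
looped = [row for s in shards for row in matmul(s, weight)]

# New path: an _all_gather_base-style single contiguous buffer,
# then exactly one matmul.
gathered = [row for s in shards for row in s]
fused = matmul(gathered, weight)

print(fused == looped)  # True
```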

Differential Revision: [D36754385](https://our.internmc.facebook.com/intern/diff/D36754385/)

Approved by: https://github.com/aazzolini, https://github.com/wanchaol
2022-06-01 17:17:34 +00:00
fduwjj
141238a889 [PT-D] Enable nan_to_num op for sharded tensor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78223

Approved by: https://github.com/pritamdamania87
2022-05-25 18:03:42 +00:00
Wanchao Liang
8eb62bd7ba [shard] make ShardedTensor a torch.Tensor subclass
This is the reland of PR https://github.com/pytorch/pytorch/pull/74695, which was reverted due to some internal failures.

It also removes the ShardedTensorInterface change, we will delay that
change later if we found there's a need to do that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78027

Approved by: https://github.com/pritamdamania87, https://github.com/fduwjj
2022-05-24 01:20:45 +00:00
pritam
37eb31599c [reland] Add sharding tests to multigpu-test.sh and fix custom operator decorator (#77987)
1. Enabled multigpu tests.
2. Fixed failing multigpu tests.
3. Fixed custom operator decorator to be first preference in operator dispatch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77987
Approved by: https://github.com/fduwjj, https://github.com/wanchaol, https://github.com/janeyx99
2022-05-21 22:33:58 +00:00
PyTorch MergeBot
0f74b44f1a Revert "Add sharding tests to multigpu-test.sh and fix custom operator decorator (#77825)"
This reverts commit 8d4c8df33a.

Reverted https://github.com/pytorch/pytorch/pull/77825 on behalf of https://github.com/janeyx99 due to as it will break multigpu test reporting
2022-05-20 17:59:03 +00:00
pritam
8d4c8df33a Add sharding tests to multigpu-test.sh and fix custom operator decorator (#77825)
1. Enabled multigpu tests.
2. Fixed failing multigpu tests.
3. Fixed custom operator decorator to be first preference in operator dispatch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77825
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2022-05-20 16:53:27 +00:00
Wanchao Liang
4124307fae [shard] fix failed tests in sharded tensor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77800

Approved by: https://github.com/pritamdamania87, https://github.com/fduwjj
2022-05-18 23:21:47 +00:00
Rodrigo Kumpera
dac3fba274 Add testing workaround for EFA and TensorPipe (#77363)
This is a workaround for EFA for TensorPipe.

This allows RPC-enabled tests to be run on AWS clusters.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77363
Approved by: https://github.com/wanchaol
2022-05-18 22:54:15 +00:00
Rodrigo Kumpera
c9570e4b88 [checkpoint] Synchronize error handling across all ranks (#77091)
Introduce error handling across all ranks when loading and saving checkpoints.

This makes it a lot simpler for users to handle failures and, as a positive side effect, to know when checkpointing successfully finished.

This change requires 3 collectives when saving and 1 when loading.
All of those collectives carry a small payload, so they will be latency-bound, and write time should dominate.
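The coordination idea, sketched: each rank contributes its local error (or `None`) to an allgather-style exchange, and every rank raises the same exception if any rank failed. Names here are illustrative, not the dist.checkpoint API:

```python
def raise_if_any_failed(all_results):
    # `all_results` stands in for the allgathered per-rank outcomes:
    # None on success, an error message on failure. Every rank sees the
    # same list, so every rank raises (or succeeds) together.
    failures = {rank: err for rank, err in enumerate(all_results) if err}
    if failures:
        raise RuntimeError(f"checkpoint failed on ranks {sorted(failures)}")

raise_if_any_failed([None, None, None])  # all ranks succeeded: no error
try:
    raise_if_any_failed([None, "disk full", None])
except RuntimeError as e:
    print(e)  # checkpoint failed on ranks [1]
```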

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77091
Approved by: https://github.com/pritamdamania87, https://github.com/wanchaol
2022-05-18 21:24:09 +00:00
fduwjj
3b2375291a [PT-D][Sharding] Fix view op and matrix ops unit test
Fix a corner case: when the sharding dim is a negative number, we need to normalize and handle it correctly.
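The usual normalization for a negative dim, sketched (a hypothetical helper following torch's dim-wrapping convention, not the actual view-op code):

```python
def normalize_dim(dim, ndim):
    # Wrap a negative sharding dim into [0, ndim), as torch does for
    # ordinary tensor dims: -1 on a 2-D tensor means dim 1.
    if dim < 0:
        dim += ndim
    if not 0 <= dim < ndim:
        raise ValueError(f"dim out of range for ndim {ndim}")
    return dim

print(normalize_dim(-1, 2))  # 1
print(normalize_dim(0, 2))   # 0
```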

Also disable RPC for the matrix ops, where it is not necessary and fails under pytest on AWS.

Differential Revision: [D36464797](https://our.internmc.facebook.com/intern/diff/D36464797/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77706

Approved by: https://github.com/pritamdamania87, https://github.com/wanchaol
2022-05-18 03:10:37 +00:00
pritam
c83f8ee46a Fix partial_tensor ops.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77580

Replace process_group with _process_group.

Differential Revision: [D36419464](https://our.internmc.facebook.com/intern/diff/D36419464/)

Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2022-05-17 08:21:38 +00:00
Wanchao Liang
25fa964d96 [shard] add clone/detach and set requires_grad for ShardedTensor
This PR adds clone/detach and the ability to set requires_grad on ShardedTensor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77367

Approved by: https://github.com/pritamdamania87
2022-05-16 21:42:27 +00:00
fduwjj
a2cb94d21a [PT-D][Sharding] Enable more ops needed in the transformer model training
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77214

From the code base of the MetaSeq model, we found that many ops are not supported by ShardedTensor. In https://github.com/pytorch/pytorch/pull/75374 we already enabled most of them, and this PR/diff aims at enabling the rest.

Fix some unit test errors.

Differential Revision: [D36302780](https://our.internmc.facebook.com/intern/diff/D36302780/)

Approved by: https://github.com/wanchaol, https://github.com/pritamdamania87
2022-05-15 22:43:47 +00:00