Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75374
From the FairSeq and MetaSeq codebases (which are essentially transformer models), we found that many ops are not supported by ShardedTensor. We now implement a simple version of these ops so that we can at least run a transformer example.
Ops include: chunk, transpose, view, masked_fill, dropout, softmax and type_as.
We also isolate the common logic of registering simple ops into a function, so registering a new op in the future requires implementing at most three functions (see the sketch below).
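A minimal sketch of that common registration logic, with hypothetical names (the real helper lives under `torch.distributed._shard` and its exact signature differs; this only illustrates the "at most three functions per op" idea):
```python
from typing import Callable, Dict, Optional

import torch.nn.functional as F

# Hypothetical registry mapping a torch op to its ShardedTensor implementation.
_SHARDED_OPS: Dict[Callable, Callable] = {}

def register_simple_sharded_op(
    op: Callable,
    extra_check: Optional[Callable] = None,      # 1) validate op-specific args
    customized_func: Optional[Callable] = None,  # 2) override local computation
):
    """Register `op` so it is applied to each local shard of a ShardedTensor."""
    def sharded_impl(sharded_tensor, *args, **kwargs):
        if extra_check is not None:
            extra_check(sharded_tensor, *args, **kwargs)
        local_fn = customized_func if customized_func is not None else op
        # Simple ops keep the sharding layout, so we just rewrite each shard.
        for shard in sharded_tensor.local_shards():
            shard.tensor = local_fn(shard.tensor, *args, **kwargs)
        return sharded_tensor

    _SHARDED_OPS[op] = sharded_impl

# softmax needs neither an extra check nor a customized local function.
register_simple_sharded_op(F.softmax)
```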
ghstack-source-id: 155309147
Test Plan: CI
Reviewed By: pritamdamania87
Differential Revision: D35123021
fbshipit-source-id: 660e559fb8b4a910eb63e0586c63ab927873a2ce
(cherry picked from commit 83a87ebf627d863448dfe1019c7c5f7112cc14ab)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76199
Since PartialTensor is somewhat independent of ShardedTensor, we now move it to the _shard folder.
We also add logic to remove padding when the size is not divisible by the world size, and modify the unit test to reflect these changes (a sketch of the padding logic follows below).
Finally, we need to consider the placement order of the resharding spec for PartialTensor; the related logic is added in this change. Furthermore, for sharded linear we need to order the placements by rank to get the expected local result.
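A hedged sketch of the pad/unpad idea, assuming the collective must operate on equal-sized chunks (illustrative helper names, not the actual internals):
```python
import torch

def pad_to_divisible(t: torch.Tensor, world_size: int, dim: int = 0) -> torch.Tensor:
    """Zero-pad `t` on `dim` so its size is divisible by world_size."""
    pad = (-t.size(dim)) % world_size
    if pad == 0:
        return t
    pad_shape = list(t.shape)
    pad_shape[dim] = pad
    return torch.cat([t, t.new_zeros(pad_shape)], dim=dim)

def remove_padding(t: torch.Tensor, orig_size: int, dim: int = 0) -> torch.Tensor:
    """Drop the trailing padding after the collective, restoring the original size."""
    return t.narrow(dim, 0, orig_size)

x = torch.arange(10.0)            # size 10 with world_size 4 -> padded to 12
padded = pad_to_divisible(x, 4)
assert remove_padding(padded, x.size(0)).equal(x)
```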
ghstack-source-id: 154853290
Test Plan: CI
Reviewed By: pritamdamania87, wanchaol
Differential Revision: D35827894
fbshipit-source-id: 58dab77969b8b6557f42afa7e8f5a8a053dd5793
(cherry picked from commit abeb28f16582dcf707c2e165f39df6caf692384d)
As per title.
### When running `python run_test.py -h`
It used to show:
- The general unittest parser help that we print via a second thread (see `torch/testing/_internal/common_utils.py` lines 467-470 at commit 35545d85dc)
- The common_utils's parser help
<details><summary>Full result</summary>
<p>
```bash
$ python run_test.py -h
usage: run_test.py [-h] [-v] [-q] [--locals] [-f] [-c] [-b] [-k TESTNAMEPATTERNS] [tests [tests ...]]
positional arguments:
  tests                a list of any number of test modules, classes and test methods.

optional arguments:
  -h, --help           show this help message and exit
  -v, --verbose        Verbose output
  -q, --quiet          Quiet output
  --locals             Show local variables in tracebacks
  -f, --failfast       Stop on first fail or error
  -c, --catch          Catch Ctrl-C and display results so far
  -b, --buffer         Buffer stdout and stderr during tests
  -k TESTNAMEPATTERNS  Only run tests which match the given substring

Examples:
  run_test.py                           - run default set of tests
  run_test.py MyTestSuite               - run suite 'MyTestSuite'
  run_test.py MyTestCase.testSomething  - run MyTestCase.testSomething
  run_test.py MyTestCase                - run all 'test*' test methods
                                          in MyTestCase
usage: run_test.py [-h] [--subprocess] [--seed SEED] [--accept] [--jit_executor JIT_EXECUTOR] [--repeat REPEAT] [--test_bailouts]
                   [--save-xml [SAVE_XML]] [--discover-tests] [--log-suffix LOG_SUFFIX] [--run-parallel RUN_PARALLEL]
                   [--import-slow-tests [IMPORT_SLOW_TESTS]] [--import-disabled-tests [IMPORT_DISABLED_TESTS]]
optional arguments:
  -h, --help            show this help message and exit
  --subprocess          whether to run each test in a subprocess
  --seed SEED
  --accept
  --jit_executor JIT_EXECUTOR
  --repeat REPEAT
  --test_bailouts
  --save-xml [SAVE_XML]
  --discover-tests
  --log-suffix LOG_SUFFIX
  --run-parallel RUN_PARALLEL
  --import-slow-tests [IMPORT_SLOW_TESTS]
  --import-disabled-tests [IMPORT_DISABLED_TESTS]
```
</p>
</details>
It now prints:
- The general unittest parser help, the same way as before. Should we remove this? Unfortunately we can't merge them, as unittest does not accept a parent parser / does not expose its parser for us to take as a parent.
- The combined common_utils + run_test parsers help
<details><summary>Full result</summary>
<p>
```bash
$ python run_test.py -h
usage: run_test.py [-h] [-v] [-q] [--locals] [-f] [-c] [-b] [-k TESTNAMEPATTERNS] [tests [tests ...]]
positional arguments:
  tests                a list of any number of test modules, classes and test methods.

optional arguments:
  -h, --help           show this help message and exit
  -v, --verbose        Verbose output
  -q, --quiet          Quiet output
  --locals             Show local variables in tracebacks
  -f, --failfast       Stop on first fail or error
  -c, --catch          Catch Ctrl-C and display results so far
  -b, --buffer         Buffer stdout and stderr during tests
  -k TESTNAMEPATTERNS  Only run tests which match the given substring

Examples:
  run_test.py                           - run default set of tests
  run_test.py MyTestSuite               - run suite 'MyTestSuite'
  run_test.py MyTestCase.testSomething  - run MyTestCase.testSomething
  run_test.py MyTestCase                - run all 'test*' test methods
                                          in MyTestCase
Ignoring disabled issues: []
usage: run_test.py [-h] [--subprocess] [--seed SEED] [--accept] [--jit_executor JIT_EXECUTOR] [--repeat REPEAT] [--test_bailouts]
                   [--save-xml [SAVE_XML]] [--discover-tests] [--log-suffix LOG_SUFFIX] [--run-parallel RUN_PARALLEL]
                   [--import-slow-tests [IMPORT_SLOW_TESTS]] [--import-disabled-tests [IMPORT_DISABLED_TESTS]] [-v] [--jit]
                   [--distributed-tests] [-core] [-pt] [-c] [-i TESTS [TESTS ...]] [-x TESTS [TESTS ...]] [-f TESTS] [-l TESTS]
                   [--bring-to-front TESTS [TESTS ...]] [--ignore-win-blocklist] [--continue-through-error]
                   [--export-past-test-times [EXPORT_PAST_TEST_TIMES]] [--shard SHARD SHARD] [--exclude-jit-executor]
                   [--exclude-distributed-tests] [--run-specified-test-cases [RUN_SPECIFIED_TEST_CASES]]
                   [--use-specified-test-cases-by {include,bring-to-front}] [--dry-run]
                   [additional_unittest_args [additional_unittest_args ...]]
Run the PyTorch unit test suite
positional arguments:
  additional_unittest_args
                        additional arguments passed through to unittest, e.g., python run_test.py -i sparse -- TestSparse.test_factory_size_check

optional arguments:
  -h, --help            show this help message and exit
  --subprocess          whether to run each test in a subprocess
  --seed SEED
  --accept
  --jit_executor JIT_EXECUTOR
  --repeat REPEAT
  --test_bailouts
  --save-xml [SAVE_XML]
  --discover-tests
  --log-suffix LOG_SUFFIX
  --run-parallel RUN_PARALLEL
  --import-slow-tests [IMPORT_SLOW_TESTS]
  --import-disabled-tests [IMPORT_DISABLED_TESTS]
  -v, --verbose         print verbose information and test-by-test results
  --jit, --jit          run all jit tests
  --distributed-tests, --distributed-tests
                        run all distributed tests
  -core, --core         Only run core tests, or tests that validate PyTorch's ops, modules,and autograd. They are defined by CORE_TEST_LIST.
  -pt, --pytest         If true, use `pytest` to execute the tests. E.g., this runs TestTorch with pytest in verbose and coverage mode: python run_test.py -vci torch -pt
  -c, --coverage        enable coverage
  -i TESTS [TESTS ...], --include TESTS [TESTS ...]
                        select a set of tests to include (defaults to ALL tests). tests must be a part of the TESTS list defined in run_test.py
  -x TESTS [TESTS ...], --exclude TESTS [TESTS ...]
                        select a set of tests to exclude
  -f TESTS, --first TESTS
                        select the test to start from (excludes previous tests)
  -l TESTS, --last TESTS
                        select the last test to run (excludes following tests)
  --bring-to-front TESTS [TESTS ...]
                        select a set of tests to run first. This can be used in situations where you want to run all tests, but care more about some set, e.g. after making a change to a specific component
  --ignore-win-blocklist
                        always run blocklisted windows tests
  --continue-through-error
                        Runs the full test suite despite one of the tests failing
  --export-past-test-times [EXPORT_PAST_TEST_TIMES]
                        dumps test times from previous S3 stats into a file, format JSON
  --shard SHARD SHARD   runs a shard of the tests (taking into account other selections), e.g., --shard 2 3 will break up the selected tests into 3 shards and run the tests in the 2nd shard (the first number should not exceed the second)
  --exclude-jit-executor
                        exclude tests that are run for a specific jit config
  --exclude-distributed-tests
                        exclude distributed tests
  --run-specified-test-cases [RUN_SPECIFIED_TEST_CASES]
                        load specified test cases file dumped from previous OSS CI stats, format CSV. If all test cases should run for a <test_module> please add a single row:
                        test_filename,test_case_name
                        ...
                        <test_module>,__all__
                        ...
                        how we use the stats will be based on option "--use-specified-test-cases-by".
  --use-specified-test-cases-by {include,bring-to-front}
                        used together with option "--run-specified-test-cases". When specified test case file is set, this option allows the user to control whether to only run the specified test modules or to simply bring the specified modules to front and also run the remaining modules. Note: regardless of this option, we will only run the specified test cases within a specified test module. For unspecified test modules with the bring-to-front option, all test cases will be run, as one may expect.
  --dry-run             Only list the test that will run.
where TESTS is any of: benchmark_utils/test_benchmark_utils, distributed/_shard/sharded_optim/test_sharded_optim, distributed/_shard/sharded_tensor/ops/test_binary_cmp, distributed/_shard/sharded_tensor/ops/test_elementwise_ops, distributed/_shard/sharded_tensor/ops/test_embedding, distributed/_shard/sharded_tensor/ops/test_embedding_bag, distributed/_shard/sharded_tensor/ops/test_init, distributed/_shard/sharded_tensor/ops/test_linear, distributed/_shard/sharded_tensor/ops/test_math_ops, distributed/_shard/sharded_tensor/test_megatron_prototype, distributed/_shard/sharded_tensor/test_partial_tensor, distributed/_shard/sharded_tensor/test_sharded_tensor, distributed/_shard/sharded_tensor/test_sharded_tensor_reshard, distributed/_shard/sharding_spec/test_sharding_spec, distributed/_shard/test_replicated_tensor, distributed/algorithms/test_join, distributed/elastic/events/lib_test, distributed/elastic/metrics/api_test, distributed/elastic/multiprocessing/api_test, distributed/elastic/timer/api_test, distributed/elastic/timer/local_timer_example, distributed/elastic/timer/local_timer_test, distributed/elastic/utils/distributed_test, distributed/elastic/utils/logging_test, distributed/elastic/utils/util_test, distributed/fsdp/test_flatten_params_wrapper, distributed/fsdp/test_fsdp_apply, distributed/fsdp/test_fsdp_checkpoint, distributed/fsdp/test_fsdp_clip_grad_norm, distributed/fsdp/test_fsdp_comm, distributed/fsdp/test_fsdp_core, distributed/fsdp/test_fsdp_freezing_weights, distributed/fsdp/test_fsdp_grad_acc, distributed/fsdp/test_fsdp_ignored_modules, distributed/fsdp/test_fsdp_input, distributed/fsdp/test_fsdp_memory, distributed/fsdp/test_fsdp_mixed_precision, distributed/fsdp/test_fsdp_multiple_forward, distributed/fsdp/test_fsdp_multiple_wrapping, distributed/fsdp/test_fsdp_optim_state, distributed/fsdp/test_fsdp_overlap, distributed/fsdp/test_fsdp_pure_fp16, distributed/fsdp/test_fsdp_state_dict, distributed/fsdp/test_fsdp_summon_full_params, distributed/fsdp/test_fsdp_traversal, distributed/fsdp/test_fsdp_uneven, distributed/fsdp/test_shard_utils, distributed/fsdp/test_utils, distributed/fsdp/test_wrap, distributed/nn/jit/test_instantiator, distributed/optim/test_zero_redundancy_optimizer, distributed/pipeline/sync/skip/test_api, distributed/pipeline/sync/skip/test_gpipe, distributed/pipeline/sync/skip/test_inspect_skip_layout, distributed/pipeline/sync/skip/test_leak, distributed/pipeline/sync/skip/test_portal, distributed/pipeline/sync/skip/test_stash_pop, distributed/pipeline/sync/skip/test_tracker, distributed/pipeline/sync/skip/test_verify_skippables, distributed/pipeline/sync/test_balance, distributed/pipeline/sync/test_bugs, distributed/pipeline/sync/test_checkpoint, distributed/pipeline/sync/test_copy, distributed/pipeline/sync/test_deferred_batch_norm, distributed/pipeline/sync/test_dependency, distributed/pipeline/sync/test_inplace, distributed/pipeline/sync/test_microbatch, distributed/pipeline/sync/test_phony, distributed/pipeline/sync/test_pipe, distributed/pipeline/sync/test_pipeline, distributed/pipeline/sync/test_stream, distributed/pipeline/sync/test_transparency, distributed/pipeline/sync/test_worker, distributed/rpc/cuda/test_tensorpipe_agent, distributed/rpc/test_faulty_agent, distributed/rpc/test_tensorpipe_agent, distributed/test_c10d_common, distributed/test_c10d_gloo, distributed/test_c10d_nccl, distributed/test_c10d_spawn_gloo, distributed/test_c10d_spawn_nccl, distributed/test_data_parallel, distributed/test_distributed_spawn, distributed/test_launcher, 
distributed/test_nccl, distributed/test_pg_wrapper, distributed/test_store, distributions/test_constraints, distributions/test_distributions, lazy/test_bindings, lazy/test_extract_compiled_graph, lazy/test_ts_opinfo, test_ao_sparsity, test_autocast, test_autograd, test_binary_ufuncs, test_bundled_inputs, test_complex, test_cpp_api_parity, test_cpp_extensions_aot_ninja, test_cpp_extensions_aot_no_ninja, test_cpp_extensions_jit, test_cuda, test_cuda_primary_ctx, test_dataloader, test_datapipe, test_deploy, test_deploy, test_dispatch, test_expanded_weights, test_foreach, test_function_schema, test_functional_autograd_benchmark, test_functional_optim, test_functionalization, test_futures, test_fx, test_fx_experimental, test_hub, test_import_stats, test_indexing, test_jit, test_jit_autocast, test_jit_cuda_fuser, test_jit_disabled, test_jit_fuser_legacy, test_jit_fuser_te, test_jit_legacy, test_jit_profiling, test_license, test_linalg, test_logging, test_masked, test_mkldnn, test_mobile_optimizer, test_model_dump, test_module_init, test_modules, test_monitor, test_multiprocessing, test_multiprocessing_spawn, test_namedtensor, test_namedtuple_return_api, test_native_functions, test_nestedtensor, test_nn, test_numba_integration, test_numpy_interop, test_openmp, test_ops, test_ops_gradients, test_ops_jit, test_optim, test_overrides, test_package, test_per_overload_api, test_profiler, test_pruning_op, test_public_bindings, test_python_dispatch, test_pytree, test_quantization, test_reductions, test_scatter_gather_ops, test_serialization, test_set_default_mobile_cpu_allocator, test_shape_ops, test_show_pickle, test_sort_and_select, test_sparse, test_sparse_csr, test_spectral_ops, test_stateless, test_tensor_creation_ops, test_tensorboard, test_tensorexpr, test_tensorexpr_pybind, test_testing, test_torch, test_type_hints, test_type_info, test_type_promotion, test_unary_ufuncs, test_utils, test_view_ops, test_vmap, test_vulkan, test_xnnpack_integration
```
</p>
</details>
### When running anything else (for example `python test_autograd.py -h`)
This behavior did not change; it still prints:
- The general unittest parser help that we print via a second thread
- The common_utils's parser help
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76152
Approved by: https://github.com/malfet, https://github.com/seemethere
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73873
Basic ShardingPlan interface and Sharder implementation:
1. We provide `ShardingPlan` to allow users to specify all parameter sharding strategies for a given model, including `plan` for sharding the parameters, `output_plan` for tagging the output layout, and `return_local_tensor` for converting back to DDP.
2. We introduce the `shard_module` API, which takes an nn.Module and a ShardingPlan, then shards the module according to the plan (a usage sketch follows below).
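A hedged usage sketch of the interface described above (illustrative spec and plan keys; `shard_module` requires an initialized process group, and exact import paths may differ by version):
```python
import torch.nn as nn
from torch.distributed._shard import shard_module
from torch.distributed._shard.sharding_plan import ShardingPlan
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 16)

    def forward(self, x):
        return self.fc1(x)

colwise = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])

plan = ShardingPlan(
    plan={"fc1.weight": colwise},   # how each named parameter is sharded
    output_plan={"fc1": colwise},   # tag the layout of fc1's output
    return_local_tensor=["fc1"],    # convert fc1's output back to a local tensor
)

# model = MyModel()
# shard_module(model, plan)  # shards fc1.weight in place according to the plan
```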
TODO:
In the next PR we will introduce an extensible Sharder and a ShardingPlanner.
ghstack-source-id: 154682421
Test Plan: test_sharding_plan.py
Reviewed By: pritamdamania87, fduwjj
Differential Revision: D34695159
fbshipit-source-id: 3d695803c4b7e9a7543177ade5b709b5f847baa9
(cherry picked from commit 670cd279b0e5304a9bf0ce6e6651a08273a77035)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73322
These tests have been disabled in OSS CI since #34785.
Test Plan: Imported from OSS
Reviewed By: eellison
Differential Revision: D34436844
Pulled By: davidberard98
fbshipit-source-id: c5b14b33e7f369a6fa1e9cfbcb484a30dffc659e
(cherry picked from commit b08f51587c0203c3e8b69f06ea613759e740aa4f)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73529
Add ReplicatedTensor. A ReplicatedTensor is a type of tensor that has the same value on all ranks across the world_size.
ReplicatedTensor is a :class:`~torch.Tensor` subclass, and it can be used together with ShardedTensor/Tensor to express different types of computation. The inter-op rules are defined as (using torch.add as an example op):
ReplicatedTensor + ReplicatedTensor = ReplicatedTensor
ReplicatedTensor + torch.Tensor = torch.Tensor
ReplicatedTensor + ShardedTensor = ShardedTensor
We also added a `validate()` API to help users validate whether a replicated tensor on a certain process_group is truly replicated or not (see the sketch below).
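A hedged illustration of these rules and of `validate()` (the live calls are commented out because they require an initialized process group; the import path may differ by version):
```python
import torch
from torch.distributed._shard.replicated_tensor import ReplicatedTensor

# r = ReplicatedTensor(torch.ones(4))   # same value expected on every rank
# t = torch.ones(4)
#
# type(r + r)   # -> ReplicatedTensor
# type(r + t)   # -> torch.Tensor
# type(r + st)  # -> ShardedTensor, for a ShardedTensor `st`
#
# r.validate()  # checks the value is identical across all ranks of the
#               # process group; raises if some rank has diverged
```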
TODO: the next PR will add ShardedTensor/PartialTensor logic to handle ReplicatedTensor.
ghstack-source-id: 152064781
Test Plan: test_replicated_tensor
Reviewed By: pritamdamania87, fduwjj
Differential Revision: D34529374
fbshipit-source-id: 16ccb300e9f9c47ac29a17eb6d46d029ab7d60b8
(cherry picked from commit 44f4e11e795a1bf330a8108bda256950ca769525)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73676
For some reason https://github.com/pytorch/pytorch/pull/72637 ended up getting messed up during rebasing, so please refer to that PR for review history.
This PR creates a new workflow called `deploy-linux-xenial-cuda11.3-py3.7-gcc7` for torch::deploy tests.
For testing, go to https://www.torch-ci.com/pytorch/pytorch/pull/73676 and check whether a build and test job occur with `deploy-linux-xenial-cuda11.3-py3.7-gcc7`.
Test Plan: Imported from OSS
Reviewed By: soulitzer
Differential Revision: D34586702
Pulled By: PaliC
fbshipit-source-id: 5627cf4ff411a4a04030f8b7726f84af979da213
(cherry picked from commit df6dddebb9fe078a6053a31033b5a40cc742fcf3)
Fixes #72368
As per the referenced issue, test_ops as a single file takes around 3:30-4:00 hours to execute on ASAN jobs:
Reference: pytorch_test_times.json
```
{
"commit": "39535fec6c3ff5bf7c2d322d096c59571c3295ed",
"JOB_BASE_NAME": "linux-xenial-py3.7-clang7-asan",
"job_times": {
"test_ops": 14928.355000000636, <- This test group is over 4hrs alone
```
----
Hence, test_ops is separated into the following parts:
1. TestGradients
2. TestJit
3. TestCommon and TestMathBits
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74297
Approved by: https://github.com/malfet
Summary:
Remove fx2trt test from oss CI
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72595
Test Plan: CI
Reviewed By: houseroad
Differential Revision: D34112595
Pulled By: wushirong
fbshipit-source-id: 02376ef0f25381eff31b72dcbf964c1966af9793
(cherry picked from commit e3d698a942)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69735
We want to build a prototype of Megatron-LM so that we can apply PT-D ops to models like transformers and other Meta flagship models.
The basic idea of Megatron-LM is as follows:
1. Col-wise sharding of linear weight. Perform the linear op for the first layer.
2. Perform a math op (optional), such as ReLU or GeLU. We use GeLU in our example unit test. The input is from step 1.
3. Row-wise sharding of linear weight. Perform the linear op for the second layer. The input is from step 2.
We thereby save the communication needed to concatenate the col-wise sharding results and to spread the input across ranks for the row-wise sharding (a sketch of this pattern follows below).
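A hedged sketch of the two-layer pattern above, assuming two ranks and the ChunkShardingSpec API (`shard_parameter`'s exact import path may differ by version; the calls are commented out because they require an initialized process group):
```python
import torch.nn as nn
from torch.distributed._shard import shard_parameter
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

class TwoLayerMLP(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.fc1 = nn.Linear(d, 4 * d)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(4 * d, d)

    def forward(self, x):
        return self.fc2(self.gelu(self.fc1(x)))

placements = ["rank:0/cuda:0", "rank:1/cuda:1"]
# nn.Linear stores weight as (out_features, in_features), so chunking dim=0
# shards columns of the matmul (step 1) and dim=1 shards rows (step 3).
colwise = ChunkShardingSpec(dim=0, placements=placements)
rowwise = ChunkShardingSpec(dim=1, placements=placements)

# model = TwoLayerMLP(64)
# shard_parameter(model.fc1, "weight", colwise)  # fc1 returns a ShardedTensor
# shard_parameter(model.fc2, "weight", rowwise)  # fc2 returns a PartialTensor
```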
The change is as follows:
1. Return a ShardedTensor for the col-wise sharding in the sharded_linear op.
2. Return a PartialTensor for the row-wise sharding in the sharded_linear op.
3. Leverage APIs already defined for `reshard` to merge/aggregate local results to a fully sync local result if needed.
4. Add helper function to create sharded tensor based on the local result.
5. Add a unit test to test the Megatron-LM idea mentioned above and compare with local ops, including the grad and optimizer so that we can ensure the correctness of the implementation.
6. Refactor the unit test of sharded linear to reflect the changes in the code.
ghstack-source-id: 148273049
Test Plan: Unit test + CI
Reviewed By: pritamdamania87
Differential Revision: D32978221
fbshipit-source-id: 565fc92e7807e19d53b0261f8ace3945bef69e3e
(cherry picked from commit 344abe7520)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70079
We define a new concept named `PartialTensor`, which is an abstraction representing tensors that need aggregation across multiple devices and multiple processes.
We also define an API `reshard_output` to reshard a `PartialTensor` to a `Tensor`, or a `ShardedTensor` to a `ShardedTensor/Tensor`. This is done via the class `ModuleResharder`, which acts as a wrapper around the original module plus a reshard in the final step.
The `reshard` logic is defined in each class (`ShardedTensor` and `PartialTensor`); a conceptual sketch follows below.
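A hedged conceptual sketch of the idea with illustrative names (the real classes live under torch.distributed._shard and are more involved):
```python
import torch
import torch.distributed as dist

class PartialResult:
    """Toy stand-in for PartialTensor: a local partial value plus a reduce op."""

    def __init__(self, local: torch.Tensor, reduce_op=dist.ReduceOp.SUM):
        self.local = local
        self.reduce_op = reduce_op

    def reshard(self, process_group=None) -> torch.Tensor:
        # Aggregate the partial results from all ranks into a full Tensor.
        out = self.local.clone()
        dist.all_reduce(out, op=self.reduce_op, group=process_group)
        return out

class ReshardWrapper(torch.nn.Module):
    """ModuleResharder-style wrapper: run the module, then reshard its output."""

    def __init__(self, module: torch.nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        out = self.module(*args, **kwargs)
        return out.reshard() if isinstance(out, PartialResult) else out
```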
ghstack-source-id: 148273050
Test Plan: Unit test is in the next PR.
Reviewed By: pritamdamania87
Differential Revision: D33121037
fbshipit-source-id: 5f56617ea526b857c5b73df6e069697d428ec359
(cherry picked from commit 58b1457cbc)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72141
We have many sharding components currently:
torch.distributed._sharded_tensor, torch.distributed._sharding_spec,
torch.distributed._sharded_optimizer and more coming.
As a result, we are organizing all of this under the `torch.distributed._shard`
package. For BC reasons, I'm still keeping the old packages and having them just
reference the new package.
ghstack-source-id: 148150861
Test Plan: waitforbuildbot
Reviewed By: fduwjj
Differential Revision: D33904585
fbshipit-source-id: 057e847eb7521b536a3ee4e0f94871aacc752062
(cherry picked from commit 29a70dd7af)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71742
We have many sharding components currently:
torch.distributed._sharded_tensor, torch.distributed._sharding_spec,
torch.distributed._sharded_optimizer and more coming.
As a result, we are organizing all of this under the `torch.distributed.shard`
package. For BC reasons, I'm still keeping the old packages and having them just
reference the new package.
ghstack-source-id: 147899768
Test Plan: waitforbuildbot
Reviewed By: fduwjj, wanchaol
Differential Revision: D33755913
fbshipit-source-id: dc692b31e2607063d55dfcb3db33ec53961d5a5b
(cherry picked from commit 5b6885f358)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70145
Added support for `torch.equal` to ShardedTensor. This is really
helpful for comparing two ShardedTensors; a usage sketch follows below.
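A hedged usage sketch (the calls are commented out because they require an initialized process group with two ranks; exact import paths may differ by version):
```python
import torch
import torch.distributed._shard.sharded_tensor as sharded_tensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])

# st1 = sharded_tensor.ones(spec, 8, 4)
# st2 = sharded_tensor.ones(spec, 8, 4)
# torch.equal(st1, st2)  # True: same metadata and identical local shards
```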
ghstack-source-id: 146066939
Test Plan: waitforbuildbot
Reviewed By: wanchaol
Differential Revision: D33201714
fbshipit-source-id: 56adfc36e345d512c9901c56c07759bf658c745b
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69734
Added support for `torch.equal` to ShardedTensor. This is really
helpful for comparing two ShardedTensors.
We will implement `allclose` in a follow-up PR.
ghstack-source-id: 145301451
Test Plan: waitforbuildbot
Reviewed By: fduwjj, wanchaol
Differential Revision: D33004315
fbshipit-source-id: 786fe26baf82e1bb4fecfdbfc9ad4b64e704877f
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68607
This PR adds ShardedOptimizer and an API to get module parameters along with ShardedTensor params. It allows users to use this optimizer wrapper to construct an optimizer that involves ShardedTensors (a usage sketch follows below).
The state_dict support will come in a follow-up diff.
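A hedged usage sketch of the wrapper (module path and helper name are assumptions that may differ by version; the calls are commented out because they require a sharded model on an initialized process group):
```python
import torch
from torch.distributed._shard.sharded_optim import (
    ShardedOptimizer,
    named_params_with_sharded_tensor,
)

# `model` is assumed to be an nn.Module whose parameters were sharded, so some
# of them are ShardedTensors that a vanilla optimizer cannot iterate over.
# opt = ShardedOptimizer(
#     dict(named_params_with_sharded_tensor(model)),  # local + sharded params
#     torch.optim.SGD,   # any torch.optim optimizer class
#     lr=0.01,
# )
# loss.backward()
# opt.step()  # updates the local shards like a regular optimizer step
```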
ghstack-source-id: 145532834
Test Plan: python test_sharded_optim.py
Reviewed By: pritamdamania87
Differential Revision: D32539994
fbshipit-source-id: a3313c6870d1f1817fc3e08dc2fc27dc43bef743