Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48670
Support an optional error feedback for PowerSGD -- storing the difference (i.e., the local error caused by compression) between the input gradient (adjusted by the existing error) and the gradient after decompression, and reinserting it at the next iteration.
Still need to add an index field to GradBucket as the key of error_dict. This is because the current key, input tensor of the bucket, can change across steps, as the buckets may be rebuilt in forward pass in order to save peak memory usage.
This is halfway of error feedback. Plan to add the new index field in a separate PR.
Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202
ghstack-source-id: 117636492
Test Plan: buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_powerSGD_ddp_comm_hook_nccl
Reviewed By: rohan-varma
Differential Revision: D25240290
fbshipit-source-id: 5b6e11e711caccfb8984ac2767dd107dbf4c9b3b
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48507
Previously the random seed is the length of input tensor, which is not guaranteed to be the different for different batches. Now initialize a random generator in PowerSGD state, and use this generator to create a random seed to randomize the low-rank tensor Q at every step.
Therefore, the initial tensor Q should be the same across all the replicas at the same step, but different at different steps.
'torch.manual_seed' is used in the same way as https://github.com/epfml/powersgd/blob/master/gradient_reducers.py#L675
Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202
ghstack-source-id: 117483639
Test Plan:
buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_powerSGD_ddp_comm_hook_nccl
buck test mode/dev-nosan caffe2/test/distributed:c10d --
test_powerSGD_ddp_comm_hook_nccl_grad_is_view
Also checked the initial Qs and input random seeds of torch.manual_seed() of different ranks for a few steps in real runs.
Example logs:
Exactly same random seed of different ranks at the same step on two nodes, and the random seed varies at each step.
{F346971916}
Reviewed By: rohan-varma
Differential Revision: D25191589
fbshipit-source-id: f7f17df3ad2075ecae1a2a56ca082160f7c5fcfc
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48348
To support the features like error feedback, warm start, PowerSGD comm hook needs to maintain a state besides process group. Currently this state only includes a process group and a matrix approximation rank config.
This diff is a pure refactoring. Plan to add more state fields later.
Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202
ghstack-source-id: 117305280
Test Plan:
buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_powerSGD_ddp_comm_hook_nccl
buck test mode/dev-nosan caffe2/test/distributed:c10d --
test_powerSGD_ddp_comm_hook_nccl_grad_is_view
Reviewed By: rohan-varma
Differential Revision: D25137962
fbshipit-source-id: cd72b8b01e20f80a92c7577d22f2c96e9eebdc52
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48253
Explained why a hand-crafted orthogonalize function is used instead of `torch.qr`.
Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202
ghstack-source-id: 117132622
Test Plan: N/A
Reviewed By: rohan-varma
Differential Revision: D25088607
fbshipit-source-id: ebc228afcb4737bb8529e7143ea170086730520e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48060
Implement a PowerSGD variant that applies to a batched flattened tensor with zero paddings.
This version does not require handling 1D tensors and multi-dimenionsal tensors in the input separately, and hence it does not need to create two asyncrhonous future chains.
Potential optimizations:
1) Consider FP16 compression throughout PowerSGD.
2) Warm start and save one matrix multiplication per ieration.
Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202
ghstack-source-id: 117105938
Test Plan: buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_default_ddp_comm_hooks_nccl
Reviewed By: jiayisuse
Differential Revision: D24843692
fbshipit-source-id: f44200b1fd6e12e829fc543d21ab7ae086769561
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47158
1. Test the default Python comm hook implementations ALLREDUCE and FP16_COMPRESS, besides an ad-hoc all-reduce implementation.
2. Typo fix.
3. Reformat default_hooks.py.
4. Publish register_comm_hook API for DDP module (This should be done in a separate diff, but got merged unintentionally.)
The new style can be used for testing any new comm hook like PowerSGD easily.
Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202
ghstack-source-id: 116012600
Test Plan: buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_default_ddp_comm_hooks_nccl
Reviewed By: rohan-varma
Differential Revision: D24669639
fbshipit-source-id: 048c87084234edc2398f0ea6f01f2f083a707939
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47270
This is almost same as #46959, except that in caffe2/torch/nn/parallel/distributed.py, BuiltinCommHookType should be imported conditionally, only when dist.is_available(). Otherwise, this Python enum type defined in caffe2/torch/scrc/distributed/c10d/init.cpp cannot be imported. See https://github.com/pytorch/pytorch/issues/47153
I tried to follow another enum type enum type ReduceOp defined in the same file, but did not work, because the C++ enum class is defined torch/lib/c10d library, but BuiltinCommHookType is defined in torch/csrc/distributed library. These two libraries are compiled in two different ways.
To avoid adding typing to distributed package, which can be a new project, I simply removed the arg type of BuiltinCommHookType in this file.
To review the diff on top of #46959, compare V1 vs Latest:
https://www.internalfb.com/diff/D24700959?src_version_fbid=270445741055617
Main Changes in V1 (#46959):
1. Implemented the Pybind part.
2. In the reducer, once the builtin_comm_hook_type is set, a c++ comm hook instance will be created in Reducer::autograd_hook.
3. Added unit tests for the builit-in comm hooks.
Original PR issue: C++ DDP Communication Hook https://github.com/pytorch/pytorch/issues/46348
ghstack-source-id: 115783237
Test Plan:
buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_builtin_ddp_comm_hooks_nccl
//arvr/projects/eye_tracking/Masquerade:python_test
USE_DISTRIBUTED=0 USE_GLOO=0 BUILD_TEST=0 USE_CUDA=1 USE_MKLDNN=0 DEBUG=0 python setup.py install
Reviewed By: mrshenli
Differential Revision: D24700959
fbshipit-source-id: 69f303a48ae275aa856e6e9b50e12ad8602e1c7a
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46959
1. Implemented the Pybind part.
2. In the reducer, once the builtin_comm_hook_type is set, a c++ comm hook instance will be created in Reducer::autograd_hook.
3. Added unit tests for the builit-in comm hooks.
Original PR issue: C++ DDP Communication Hook https://github.com/pytorch/pytorch/issues/46348
ghstack-source-id: 115629230
Test Plan: buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_builtin_ddp_comm_hooks_nccl
Reviewed By: pritamdamania87
Differential Revision: D24471910
fbshipit-source-id: f96b752298549ea2067e2568189f1b394abcd99a
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46078
The peak memory usage of ddp comm hook has increased due to an extra copy of gradient tensors. To reduce the memory usage, decompress the fp16 tensor in place of the tensor stored in the the gradient bucket.
#Closes: https://github.com/pytorch/pytorch/issues/45968
ghstack-source-id: 113996453
Test Plan:
buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_accumulate_gradients_no_sync_allreduce_hook
Also verified the decrease in memory consumption with some toy modeling exmaples.
Reviewed By: pritamdamania87
Differential Revision: D24178118
fbshipit-source-id: 453d0b52930809bd836172936b77abd69610237a
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44643
This method is not used anywhere else.
Also formatted the file.
Test Plan: buck test caffe2/test/distributed/algorithms/ddp_comm_hooks:test_ddp_hooks
Reviewed By: pritamdamania87
Differential Revision: D23675945
fbshipit-source-id: 2d04f94589a20913e46b8d71e6a39b70940c1461
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43310
In this diff, we prepared some example DDP communication hooks [#40848](https://github.com/pytorch/pytorch/pull/40848):
1\. `allreduce_hook`: This DDP communication hook just calls ``allreduce`` using ``GradBucket`` tensors. Once gradient tensors are aggregated across all workers, its ``then`` callback takes the mean and returns the result. If user registers this hook DDP results is expected to be same as the case where no hook was registered. Hence, this won't change behavior of DDP and user can use this as a reference or modify this hook to log useful information or any other purposes while unaffecting DDP behavior.
2\. `allgather_then_aggregate_hook` Similar to ``allreduce_hook``, this hook first gathers ``GradBucket`` tensors and its ``then`` callback aggregates the gathered gradient tensors and takes mean. Instead of ``allreduce`` this hook uses ``allgather``. Note that with W workers, both the computation and communication time scale as O(W) for allgather compared to O(logW) for allreduce. Therefore, this hook is expected to be much slower than ``allreduce_hook`` although both essentially do the same thing with the gradients.
3\. `fp16_compress_hook` This DDP communication hook implements a simple gradient compression approach that converts ``GradBucket`` tensors whose type is assumed to be ``torch.float32`` to half-precision floating point format (``torch.float16``). It allreduces those ``float16`` gradient tensors. Once compressed gradient tensors are allreduced, its then callback called ``decompress`` converts the aggregated result back to ``float32`` and takes the mean.
4\. `quantization_pertensor_hook` does quantization per tensor and uses the idea in https://pytorch.org/docs/master/generated/torch.quantize_per_tensor.html. Note that we separately send scale and zero_point (two floats per rank) before quantized tensors.
5\. `quantization_perchannel_hook` does quantization per channel similar to https://pytorch.org/docs/master/generated/torch.quantize_per_channel.html. The main motivation is that after the initial QSGD study diff, we realized that for considerably large gradient tensors such as a tensor that contains 6 million floats quantizing dividing it into smaller channels (512 float chunks) and quantizing independently may significantly increase the resolution and result with lower error.
ghstack-source-id: 110923269
Test Plan:
python torch/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py
Couldn't download test skip set, leaving all tests enabled...
.....
----------------------------------------------------------------------
Ran 4 tests in 26.724s
OK
Internal testing:
```
buck run mode/dev-nosan //caffe2/test/distributed/algorithms/ddp_comm_hooks:test_ddp_hooks
```
Reviewed By: malfet
Differential Revision: D22937999
fbshipit-source-id: 274452e7932414570999cb978ae77a97eb3fb0ec