pytorch/torch/distributed
Yi Wang 17f53bffef [Gradient Compression] Replace the key of error_dict in PowerSGD state with bucket index (#48867)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48867

Previously, the key of error_dict was the hashcode of the tensor; it is now replaced with the bucket index.

The bucket index has a few advantages over the tensor hashcode.
1) The error dict in the state never removes any key, so if the bucket rebuild process occurred frequently, the size of the error dict could grow. For now, such rebuilds are infrequent, so this is probably fine.

2) An integer index is more readable than a hashcode, and it facilitates debugging.
To debug tensor values, usually only a specific bucket needs to be targeted. It is easy to specify such a condition (e.g., bucket_index == 0), but hard to specify a hashcode in advance, since it is only determined at runtime.

Note that sometimes the buckets can be rebuilt in the forward pass. In that case, the shape of the bucket with a given index may not match the shape from the previous iteration, so the error tensor is re-initialized as a zero tensor of the new shape. Therefore, `and state.error_dict[bucket_index].shape[0] == padded_total_length` is added to the condition for applying the local error from the previous iteration.
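The error-feedback bookkeeping described above can be sketched as follows. This is a hedged, dependency-free illustration, not PyTorch's actual implementation: plain Python lists stand in for tensors, and `error_dict`, `bucket_index`, and `padded_total_length` mirror the names in this summary, while `apply_error_feedback` is a hypothetical helper invented for the sketch.

```python
# Illustrative sketch of error feedback keyed by bucket index, with the
# shape guard for rebuilt buckets. Lists stand in for 1-D tensors.
error_dict = {}  # bucket_index -> local compression error from the previous iteration

def apply_error_feedback(bucket_index, input_tensor, padded_total_length):
    if (bucket_index in error_dict
            and len(error_dict[bucket_index]) == padded_total_length):
        # Same bucket shape as the previous iteration: add the stored
        # local error back into the input before compression.
        return [x + e for x, e in zip(input_tensor, error_dict[bucket_index])]
    # First iteration, or the bucket was rebuilt with a new shape:
    # re-initialize the error as a zero tensor of the new length.
    error_dict[bucket_index] = [0.0] * padded_total_length
    return input_tensor
```

Because the key is a bounded integer index rather than a per-rebuild tensor hashcode, a rebuilt bucket simply overwrites its old entry instead of leaving a stale one behind.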

Deleted the `dist._GradBucket` arg type annotation in powerSGD_hook.py, because test_run_mypy - TestTypeHints failed with:
AssertionError: mypy failed: torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py:128: error: "_GradBucket" has no attribute "get_index"  [attr-defined]

Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202
ghstack-source-id: 117951402

Test Plan: buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_powerSGD_ddp_comm_hook_nccl

Reviewed By: rohan-varma

Differential Revision: D25346347

fbshipit-source-id: 8348aa103002ec1c69e3ae759504b431140b3b0d
2020-12-05 23:53:27 -08:00
_pipeline Remove balance and devices parameter from Pipe. (#48432) 2020-12-01 11:21:59 -08:00
algorithms [Gradient Compression] Replace the key of error_dict in PowerSGD state with bucket index (#48867) 2020-12-05 23:53:27 -08:00
autograd Add Python declaration of torch._C and torch._C._autograd modules. (#46622) 2020-11-06 01:25:47 -08:00
benchmarks Benchmark combining Distributed Data Parallel and Distributed RPC (#46993) 2020-11-04 18:53:19 -08:00
nn Fix typing errors in torch.distributed.nn.* directory. (#47533) 2020-11-16 23:27:55 -08:00
optim [dist_optim] serialize compilation when creating dist_optim (#45871) 2020-10-07 15:10:41 -07:00
rpc RRef proxy support for ScriptModule methods (#48339) 2020-12-04 11:33:16 -08:00
__init__.py Enable TCPStore on Windows (#47749) 2020-12-03 08:32:01 -08:00
constants.py Add NCCL_ASYNC_ERROR_HANDLING to docs (#46856) 2020-10-26 14:41:32 -07:00
CONTRIBUTING.md Move python-independent c10d implementations to torch/lib (#47309) 2020-11-03 23:39:54 -08:00
distributed_c10d.py scatter_object_list API for c10d (#43930) 2020-12-04 18:55:57 -08:00
launch.py Fix typing errors in torch.distributed.*, close issue #42967. (#47534) 2020-11-16 23:27:59 -08:00
rendezvous.py Enable TCPStore on Windows (#47749) 2020-12-03 08:32:01 -08:00