pytorch/torch/testing
pritam a81be44410 Fix shard_module to appropriately deal with sub process groups.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79264

The `shard_module` API did not work correctly with a sub process group, because `dist.scatter` takes the global rank, not the group rank, as input for `src`.

This is fixed by passing the appropriate (global) rank as `src` to `dist.scatter`.
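Below is a minimal sketch of the rank translation this fix describes. It assumes the sub-group was created with `torch.distributed.new_group(ranks=...)`, which orders group members by ascending global rank. `group_rank_to_global_rank` and `scatter_within_subgroup` are hypothetical helpers for illustration, not the code from this PR.

```python
from typing import Sequence


def group_rank_to_global_rank(group_ranks: Sequence[int], group_rank: int) -> int:
    """Map a rank *within* a sub process group to its global rank.

    Hypothetical helper: assumes the sub-group was built with
    ``dist.new_group(ranks=group_ranks)``, which orders members by ascending
    global rank, so group rank i is the i-th smallest global rank.
    """
    ordered = sorted(group_ranks)
    if not 0 <= group_rank < len(ordered):
        raise ValueError(
            f"group_rank {group_rank} out of range for a group of {len(ordered)} ranks"
        )
    return ordered[group_rank]


def scatter_within_subgroup(tensor, scatter_list, group_ranks, src_group_rank, group):
    """Scatter from a source identified by its *group* rank (sketch only).

    Requires an initialized default process group and a sub-group `group`
    created over `group_ranks`.
    """
    import torch.distributed as dist  # imported lazily so the mapping above works without torch

    # dist.scatter interprets `src` as a GLOBAL rank even when `group` is a
    # sub-group, so translate the group-local source rank first.
    src = group_rank_to_global_rank(group_ranks, src_group_rank)
    # Only the source rank supplies the list of chunks; other ranks pass None.
    chunks = scatter_list if dist.get_rank() == src else None
    dist.scatter(tensor, scatter_list=chunks, src=src, group=group)
```

For a sub-group over global ranks `[2, 3]`, group rank 0 is global rank 2. Passing `src=0` directly would name global rank 0, which is not a member of that group at all. That is the kind of mismatch this commit addresses.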

Differential Revision: [D37062766](https://our.internmc.facebook.com/intern/diff/D37062766/)

Approved by: https://github.com/fduwjj, https://github.com/wanchaol
2022-06-12 03:50:45 +00:00
_internal Fix shard_module to appropriately deal with sub process groups. 2022-06-12 03:50:45 +00:00
__init__.py cleanup torch.testing namespace (#72708) 2022-02-25 06:30:31 +00:00
_comparison.py move MPS compat into common comparison machinery (#77836) 2022-06-08 08:09:18 +00:00
_creation.py [chalf] div, eq, masked_fill, index_put (#77479) 2022-05-18 17:01:08 +00:00
_deprecated.py promote torch.testing to stable (#73348) 2022-02-25 06:30:31 +00:00
_legacy.py [complex32] real and imag (also remove unused real and imag kernels) 2022-05-05 04:36:58 +00:00