Pull Request resolved: https://github.com/pytorch/pytorch/pull/75753

As per the design in https://github.com/pytorch/pytorch/issues/72138, convert DDP parameters to ReplicatedTensor during the forward pass. Concretely, this is done as follows:

1) Create a separate `_replicated_tensor_module`, which is a copy of `self.module` without copying the Tensors themselves.
2) Use `_replicated_tensor_module` instead of `self.module` during the forward pass.
3) Gate this behind a context manager, `_ddp_replicated_tensor`, since certain edge cases can fail when `self.module` is changed out of band, resulting in a discrepancy between `self.module` and `_replicated_tensor_module`. (A sketch of this mechanism is shown below.)

Differential Revision: [D35533736](https://our.internmc.facebook.com/intern/diff/D35533736/)

Approved by: https://github.com/wanchaol, https://github.com/rohan-varma
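As a rough illustration of the mechanism the commit message describes, here is a minimal sketch. The commit message names `_replicated_tensor_module` and `_ddp_replicated_tensor`; everything else here (the thread-local flag, `_make_replicated_module`, the `DDPSketch` wrapper) is a hypothetical stand-in for the actual DDP internals, not the real implementation.

```python
import copy
import threading

import torch.nn as nn

# Hypothetical thread-local flag toggled by the _ddp_replicated_tensor
# context manager named in the commit message (step 3).
_flag = threading.local()


class _ddp_replicated_tensor:
    """Sketch of the gating context manager; not the real PyTorch API."""

    def __init__(self, enabled: bool):
        self._enabled = enabled

    def __enter__(self):
        self._prev = getattr(_flag, "value", False)
        _flag.value = self._enabled

    def __exit__(self, *exc):
        _flag.value = self._prev


def _make_replicated_module(module: nn.Module) -> nn.Module:
    # Step 1: copy the module structure without copying the tensors.
    # Seeding deepcopy's memo with each parameter/buffer mapped to itself
    # makes the copy share storage with the original module.
    memo = {id(t): t for t in list(module.parameters()) + list(module.buffers())}
    return copy.deepcopy(module, memo)


class DDPSketch(nn.Module):
    """Hypothetical stand-in for DistributedDataParallel's forward routing."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module
        self._replicated_tensor_module = _make_replicated_module(module)

    def forward(self, *args, **kwargs):
        # Step 2: run the forward through the shared-storage copy when the
        # context manager has enabled the ReplicatedTensor path.
        if getattr(_flag, "value", False):
            return self._replicated_tensor_module(*args, **kwargs)
        return self.module(*args, **kwargs)
```

Usage would then look like `with _ddp_replicated_tensor(True): out = ddp(x)`. The opt-in gate matches the edge case described in step 3: code that swaps `self.module` out of band would leave `_replicated_tensor_module` stale, so the path is only taken when explicitly enabled.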
```python
# torch/distributed/nn/__init__.py
import torch

# RemoteModule depends on the RPC framework, which may not be available
# in every build, so its export is guarded.
if torch.distributed.rpc.is_available():
    from .api.remote_module import RemoteModule
from .functional import *  # noqa: F403
```
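Because of the guard above, `RemoteModule` is only exported from `torch.distributed.nn` when the RPC framework is available; consumer code can mirror the same check. A small sketch, assuming nothing beyond the guard shown in the file itself:

```python
import torch

# Mirror the package's own guard: RemoteModule is only importable from
# torch.distributed.nn when the RPC framework is available in this build.
if torch.distributed.rpc.is_available():
    from torch.distributed.nn import RemoteModule
```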