diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index 0fceb2137a3..7d64ed8a417 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -329,7 +329,7 @@ class DistributedDataParallel(Module):
         Example::
 
             >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
-            >>> net = torch.nn.DistributedDataParallel(model, pg)
+            >>> net = torch.nn.parallel.DistributedDataParallel(model, pg)
     """
     def __init__(self, module, device_ids=None,
                  output_device=None, dim=0, broadcast_buffers=True,
@@ -626,7 +626,7 @@ class DistributedDataParallel(Module):
 
         Example::
 
-            >>> ddp = torch.nn.DistributedDataParallel(model, pg)
+            >>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
             >>> with ddp.no_sync():
             >>>     for input in inputs:
             >>>         ddp(input).backward()  # no synchronization, accumulate grads
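
The patch only corrects the module path in two docstring examples (DistributedDataParallel lives under torch.nn.parallel, not torch.nn). For context, here is a minimal, self-contained sketch of the corrected usage, including the no_sync() pattern from the second hunk. It is illustrative, not part of the change: the gloo backend, single-process world size, MASTER_ADDR/MASTER_PORT values, and the nn.Linear model are assumptions chosen so the snippet runs on one CPU machine.

import os

import torch
import torch.distributed as dist
import torch.nn as nn


def main() -> None:
    # Single-process setup purely for demonstration; real training launches
    # one process per rank (e.g. via torch.multiprocessing or a launcher).
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    model = nn.Linear(10, 10)
    # The corrected path from the patch: torch.nn.parallel.DistributedDataParallel.
    ddp = nn.parallel.DistributedDataParallel(model)

    inputs = [torch.randn(20, 10) for _ in range(3)]

    # no_sync() disables gradient all-reduce, so gradients accumulate
    # locally; the first backward outside the block synchronizes them.
    with ddp.no_sync():
        for inp in inputs[:-1]:
            ddp(inp).sum().backward()
    ddp(inputs[-1]).sum().backward()  # synchronizes accumulated grads

    dist.destroy_process_group()


if __name__ == "__main__":
    main()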