Summary: Closes https://github.com/pytorch/pytorch/issues/18724
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19089
Differential Revision: D16073654
Pulled By: ezyang
fbshipit-source-id: 5642179651ce45ab7c5a46cc1fcc4fd6b37fa71c
28 lines · 1022 B · Python
from ..modules import Module
from typing import Any, Optional, TypeVar
from .common_types import _devices_t, _device_t

T_co = TypeVar('T_co', covariant=True)


class DistributedDataParallel(Module[T_co]):
    process_group: Any = ...
    dim: int = ...
    module: Module[T_co] = ...
    device_ids: _devices_t = ...
    output_device: _device_t = ...
    broadcast_buffers: bool = ...
    check_reduction: bool = ...
    broadcast_bucket_size: float = ...
    bucket_bytes_cap: float = ...

    # TODO type process_group once `distributed` module is stubbed
    def __init__(self, module: Module[T_co], device_ids: Optional[_devices_t] = ...,
                 output_device: Optional[_device_t] = ..., dim: int = ...,
                 broadcast_buffers: bool = ..., process_group: Optional[Any] = ..., bucket_cap_mb: float = ...,
                 check_reduction: bool = ...) -> None: ...

    def forward(self, *inputs: Any, **kwargs: Any) -> T_co: ...

    def __call__(self, *inputs: Any, **kwargs: Any) -> T_co: ...
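
For context, below is a minimal usage sketch of the API this stub declares. It is not part of the stub file; the function name `main`, the "nccl" backend choice, and the environment-variable rendezvous are assumptions for illustration, and process-group setup will vary by launcher.

# Hedged usage sketch of the constructor and __call__/forward declared above.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

def main(rank: int, world_size: int) -> None:
    # Assumes MASTER_ADDR/MASTER_PORT are provided by the launcher (assumption).
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    model = nn.Linear(10, 10).to(rank)
    # Arguments mirror the stubbed __init__ signature: module, device_ids,
    # output_device (broadcast_buffers, bucket_cap_mb, etc. keep their defaults).
    ddp = DistributedDataParallel(model, device_ids=[rank], output_device=rank)
    out = ddp(torch.randn(2, 10, device=rank))  # dispatches to forward()
    out.sum().backward()  # gradients are all-reduced across ranks
    dist.destroy_process_group()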