mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47534 Test Plan: Imported from OSS Reviewed By: walterddr Differential Revision: D24952497 Pulled By: xuzhao9 fbshipit-source-id: 063bfd0707198436fcfd9431f72f9a392bc0017e
This commit is contained in:
parent
7f66fa62ca
commit
49f0e5dfeb
12
mypy.ini
12
mypy.ini
|
|
@ -53,18 +53,6 @@ ignore_errors = True
|
||||||
[mypy-torch.backends._nnapi.*]
|
[mypy-torch.backends._nnapi.*]
|
||||||
ignore_errors = True
|
ignore_errors = True
|
||||||
|
|
||||||
[mypy-torch.distributed.*]
|
|
||||||
ignore_errors = True
|
|
||||||
|
|
||||||
[mypy-torch.distributed.rpc.*]
|
|
||||||
ignore_errors = False
|
|
||||||
|
|
||||||
[mypy-torch.distributed.distributed_c10d.*]
|
|
||||||
ignore_errors = False
|
|
||||||
|
|
||||||
[mypy-torch.distributed.nn.*]
|
|
||||||
ignore_errors = False
|
|
||||||
|
|
||||||
[mypy-torch.testing._internal.hypothesis_utils.*]
|
[mypy-torch.testing._internal.hypothesis_utils.*]
|
||||||
ignore_errors = True
|
ignore_errors = True
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,8 @@ class DeferredBatchNorm(_BatchNorm):
|
||||||
|
|
||||||
sum: Tensor
|
sum: Tensor
|
||||||
sum_squares: Tensor
|
sum_squares: Tensor
|
||||||
|
running_mean: Tensor
|
||||||
|
running_var: Tensor
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -149,6 +151,8 @@ class DeferredBatchNorm(_BatchNorm):
|
||||||
if module.affine:
|
if module.affine:
|
||||||
module_output.register_parameter("weight", module.weight)
|
module_output.register_parameter("weight", module.weight)
|
||||||
module_output.register_parameter("bias", module.bias)
|
module_output.register_parameter("bias", module.bias)
|
||||||
|
assert isinstance(module.running_mean, Tensor)
|
||||||
|
assert isinstance(module.running_var, Tensor)
|
||||||
module_output.register_buffer("running_mean", module.running_mean)
|
module_output.register_buffer("running_mean", module.running_mean)
|
||||||
module_output.register_buffer("running_var", module.running_var)
|
module_output.register_buffer("running_var", module.running_var)
|
||||||
module_output.register_buffer("num_batches_tracked", module.num_batches_tracked)
|
module_output.register_buffer("num_batches_tracked", module.num_batches_tracked)
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ import threading
|
||||||
from typing import TYPE_CHECKING, Deque, Generator, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Deque, Generator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import ByteTensor, Tensor
|
from torch import Tensor
|
||||||
import torch.autograd
|
import torch.autograd
|
||||||
|
|
||||||
from .dependency import fork, join
|
from .dependency import fork, join
|
||||||
|
|
@ -45,7 +45,7 @@ TensorOrTensors = Union[Tensor, Tensors]
|
||||||
|
|
||||||
# Types for shared memory between Checkpoint and Recompute.
|
# Types for shared memory between Checkpoint and Recompute.
|
||||||
Recomputed = Tuple[TensorOrTensors, Tensors] # (output, input_leaf)
|
Recomputed = Tuple[TensorOrTensors, Tensors] # (output, input_leaf)
|
||||||
RNGStates = Tuple[ByteTensor, Optional[ByteTensor]] # (cpu_rng_state, gpu_rng_state)
|
RNGStates = Tuple[Tensor, Optional[Tensor]] # (cpu_rng_state, gpu_rng_state)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
@ -207,7 +207,7 @@ def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None
|
||||||
"""
|
"""
|
||||||
cpu_rng_state = torch.get_rng_state()
|
cpu_rng_state = torch.get_rng_state()
|
||||||
|
|
||||||
gpu_rng_state: Optional[ByteTensor]
|
gpu_rng_state: Optional[Tensor]
|
||||||
if device.type == "cuda":
|
if device.type == "cuda":
|
||||||
gpu_rng_state = torch.cuda.get_rng_state(device)
|
gpu_rng_state = torch.cuda.get_rng_state(device)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,8 @@ Tensors = Tuple[Tensor, ...]
|
||||||
TensorOrTensors = Union[Tensor, Tensors]
|
TensorOrTensors = Union[Tensor, Tensors]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
Module = nn.Module[TensorOrTensors]
|
# Typechecking: nn.Module is not a Generic
|
||||||
|
Module = nn.Module[TensorOrTensors] # type: ignore[type-arg]
|
||||||
NamedModules = OrderedDict[str, Module]
|
NamedModules = OrderedDict[str, Module]
|
||||||
else:
|
else:
|
||||||
Module = nn.Module
|
Module = nn.Module
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,8 @@ TensorOrTensors = Union[Tensor, Tensors]
|
||||||
StashPop = Union["stash", "pop"]
|
StashPop = Union["stash", "pop"]
|
||||||
StashPopGenerator = Generator[StashPop, Optional[Tensor], TensorOrTensors]
|
StashPopGenerator = Generator[StashPop, Optional[Tensor], TensorOrTensors]
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
SkippableModule = nn.Module[Union[StashPopGenerator, TensorOrTensors]]
|
# Typechecking: nn.Module is not a Generic
|
||||||
|
SkippableModule = nn.Module[Union[StashPopGenerator, TensorOrTensors]] # type: ignore[type-arg]
|
||||||
else:
|
else:
|
||||||
SkippableModule = nn.Module
|
SkippableModule = nn.Module
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -104,7 +104,8 @@ def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None:
|
||||||
#
|
#
|
||||||
tensor = tensor.new_empty([0]).set_(tensor.storage())
|
tensor = tensor.new_empty([0]).set_(tensor.storage())
|
||||||
|
|
||||||
tensor.record_stream(as_cuda(stream))
|
# Typechecking: torch.cuda.Stream is incompatible with torch._C.Stream
|
||||||
|
tensor.record_stream(as_cuda(stream)) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
def is_cuda(stream: AbstractStream) -> bool:
|
def is_cuda(stream: AbstractStream) -> bool:
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ import torch.distributed as dist
|
||||||
|
|
||||||
|
|
||||||
def allreduce_hook(
|
def allreduce_hook(
|
||||||
process_group: object, bucket: dist._GradBucket
|
process_group: dist.ProcessGroup, bucket: dist._GradBucket
|
||||||
) -> torch.futures.Future:
|
) -> torch.futures.Future:
|
||||||
"""
|
"""
|
||||||
This DDP communication hook just calls ``allreduce`` using ``GradBucket``
|
This DDP communication hook just calls ``allreduce`` using ``GradBucket``
|
||||||
|
|
@ -32,7 +32,7 @@ def allreduce_hook(
|
||||||
|
|
||||||
|
|
||||||
def fp16_compress_hook(
|
def fp16_compress_hook(
|
||||||
process_group: object, bucket: dist._GradBucket
|
process_group: dist.ProcessGroup, bucket: dist._GradBucket
|
||||||
) -> torch.futures.Future:
|
) -> torch.futures.Future:
|
||||||
"""
|
"""
|
||||||
This DDP communication hook implements a simple gradient compression
|
This DDP communication hook implements a simple gradient compression
|
||||||
|
|
@ -79,7 +79,7 @@ def _get_allgather_out_list(all_gather_in_list, world_size):
|
||||||
|
|
||||||
|
|
||||||
def _allgather_then_aggregate_hook(
|
def _allgather_then_aggregate_hook(
|
||||||
process_group: object, bucket: dist._GradBucket
|
process_group: dist.ProcessGroup, bucket: dist._GradBucket
|
||||||
) -> torch.futures.Future:
|
) -> torch.futures.Future:
|
||||||
"""
|
"""
|
||||||
Similar to ``allreduce_hook``, this hook first gathers ``GradBucket`` tensors
|
Similar to ``allreduce_hook``, this hook first gathers ``GradBucket`` tensors
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,7 @@ def _get_allgather_out_list(all_gather_in_list, world_size):
|
||||||
|
|
||||||
|
|
||||||
def quantization_pertensor_hook(
|
def quantization_pertensor_hook(
|
||||||
process_group: object, bucket: dist._GradBucket
|
process_group: dist.ProcessGroup, bucket: dist._GradBucket
|
||||||
) -> torch.futures.Future:
|
) -> torch.futures.Future:
|
||||||
"""
|
"""
|
||||||
Applies the ``torch.quantize_per_tensor`` logic to DDP using ``allgather``
|
Applies the ``torch.quantize_per_tensor`` logic to DDP using ``allgather``
|
||||||
|
|
@ -118,7 +118,7 @@ def quantization_pertensor_hook(
|
||||||
|
|
||||||
|
|
||||||
def quantization_perchannel_hook(
|
def quantization_perchannel_hook(
|
||||||
process_group: object, bucket: dist._GradBucket, bucket_size=512
|
process_group: dist.ProcessGroup, bucket: dist._GradBucket, bucket_size=512
|
||||||
) -> torch.futures.Future:
|
) -> torch.futures.Future:
|
||||||
"""
|
"""
|
||||||
Applies the ``torch.quantize_per_channel`` logic to DDP using ``allgather``
|
Applies the ``torch.quantize_per_channel`` logic to DDP using ``allgather``
|
||||||
|
|
|
||||||
|
|
@ -141,6 +141,7 @@ import sys
|
||||||
import subprocess
|
import subprocess
|
||||||
import os
|
import os
|
||||||
from argparse import ArgumentParser, REMAINDER
|
from argparse import ArgumentParser, REMAINDER
|
||||||
|
from typing import Optional, IO
|
||||||
|
|
||||||
node_local_rank_stdout_filename = "node_{}_local_rank_{}_stdout"
|
node_local_rank_stdout_filename = "node_{}_local_rank_{}_stdout"
|
||||||
node_local_rank_stderr_filename = "node_{}_local_rank_{}_stderr"
|
node_local_rank_stderr_filename = "node_{}_local_rank_{}_stderr"
|
||||||
|
|
@ -269,6 +270,8 @@ def main():
|
||||||
|
|
||||||
cmd.extend(args.training_script_args)
|
cmd.extend(args.training_script_args)
|
||||||
|
|
||||||
|
stdout_handle: Optional[IO]
|
||||||
|
stderr_handle: Optional[IO]
|
||||||
if args.logdir:
|
if args.logdir:
|
||||||
directory_path = os.path.join(os.getcwd(), args.logdir)
|
directory_path = os.path.join(os.getcwd(), args.logdir)
|
||||||
node_rank = args.node_rank
|
node_rank = args.node_rank
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,14 @@
|
||||||
try:
|
try:
|
||||||
from urllib.parse import urlparse, urlunparse
|
from urllib.parse import urlparse, urlunparse
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from urlparse import urlparse, urlunparse
|
raise ImportError("urllib cannot be found, urlparse from python2 is no longer supported.")
|
||||||
|
|
||||||
import torch._six as six
|
import torch._six as six
|
||||||
import numbers
|
import numbers
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from datetime import timedelta
|
||||||
|
from typing import Optional, Dict, Union
|
||||||
from torch._C._distributed_c10d import FileStore
|
from torch._C._distributed_c10d import FileStore
|
||||||
from .constants import default_pg_timeout
|
from .constants import default_pg_timeout
|
||||||
|
|
||||||
|
|
@ -47,7 +49,7 @@ def register_rendezvous_handler(scheme, handler):
|
||||||
_rendezvous_handlers[scheme] = handler
|
_rendezvous_handlers[scheme] = handler
|
||||||
|
|
||||||
|
|
||||||
def rendezvous(url, rank=-1, world_size=-1, **kwargs):
|
def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
|
||||||
if not isinstance(url, six.string_classes):
|
if not isinstance(url, six.string_classes):
|
||||||
raise RuntimeError("`url` must be a string. {}: {}".format(type(url), url))
|
raise RuntimeError("`url` must be a string. {}: {}".format(type(url), url))
|
||||||
|
|
||||||
|
|
@ -60,8 +62,9 @@ def rendezvous(url, rank=-1, world_size=-1, **kwargs):
|
||||||
# Append node-specific arguments.
|
# Append node-specific arguments.
|
||||||
result = urlparse(url)
|
result = urlparse(url)
|
||||||
if rank != -1 or world_size != -1:
|
if rank != -1 or world_size != -1:
|
||||||
query_dict = dict(
|
query_dict: Dict[str, Union[int, str]] = dict(
|
||||||
pair.split("=") for pair in filter(None, result.query.split("&"))
|
# mypy doesn't allow dict() to accept List of values (#257)
|
||||||
|
pair.split("=") for pair in filter(None, result.query.split("&")) # type: ignore[arg-type, misc]
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
"rank" not in query_dict and "world_size" not in query_dict
|
"rank" not in query_dict and "world_size" not in query_dict
|
||||||
|
|
@ -87,7 +90,7 @@ def _rendezvous_error(msg):
|
||||||
return ValueError("Error initializing torch.distributed using " + msg)
|
return ValueError("Error initializing torch.distributed using " + msg)
|
||||||
|
|
||||||
|
|
||||||
def _file_rendezvous_handler(url, **kwargs):
|
def _file_rendezvous_handler(url: str, **kwargs):
|
||||||
def _error(msg):
|
def _error(msg):
|
||||||
return _rendezvous_error("file:// rendezvous: " + msg)
|
return _rendezvous_error("file:// rendezvous: " + msg)
|
||||||
|
|
||||||
|
|
@ -99,7 +102,9 @@ def _file_rendezvous_handler(url, **kwargs):
|
||||||
|
|
||||||
if not path:
|
if not path:
|
||||||
raise _error("path missing")
|
raise _error("path missing")
|
||||||
query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))
|
query: Dict[str, str]
|
||||||
|
# mypy doesn't allow dict() to accept List of values (#257)
|
||||||
|
query = dict(pair.split("=") for pair in filter(None, result.query.split("&"))) # type: ignore[misc, arg-type]
|
||||||
if "rank" not in query:
|
if "rank" not in query:
|
||||||
raise _error("rank parameter missing")
|
raise _error("rank parameter missing")
|
||||||
if "world_size" not in query:
|
if "world_size" not in query:
|
||||||
|
|
@ -114,14 +119,16 @@ def _file_rendezvous_handler(url, **kwargs):
|
||||||
raise RuntimeError("Unable to perform rerendezvous using file:// method")
|
raise RuntimeError("Unable to perform rerendezvous using file:// method")
|
||||||
|
|
||||||
|
|
||||||
def _tcp_rendezvous_handler(url, timeout=default_pg_timeout, **kwargs):
|
def _tcp_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, **kwargs):
|
||||||
def _error(msg):
|
def _error(msg):
|
||||||
return _rendezvous_error("tcp:// rendezvous: " + msg)
|
return _rendezvous_error("tcp:// rendezvous: " + msg)
|
||||||
|
|
||||||
result = urlparse(url)
|
result = urlparse(url)
|
||||||
if not result.port:
|
if not result.port:
|
||||||
raise _error("port number missing")
|
raise _error("port number missing")
|
||||||
query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))
|
query: Dict[str, Union[int, str]]
|
||||||
|
# mypy doesn't allow dict() to accept List of values (#257)
|
||||||
|
query = dict(pair.split("=") for pair in filter(None, result.query.split("&"))) # type: ignore[misc, arg-type]
|
||||||
if "rank" not in query:
|
if "rank" not in query:
|
||||||
raise _error("rank parameter missing")
|
raise _error("rank parameter missing")
|
||||||
if "world_size" not in query:
|
if "world_size" not in query:
|
||||||
|
|
@ -130,6 +137,7 @@ def _tcp_rendezvous_handler(url, timeout=default_pg_timeout, **kwargs):
|
||||||
rank = int(query["rank"])
|
rank = int(query["rank"])
|
||||||
world_size = int(query["world_size"])
|
world_size = int(query["world_size"])
|
||||||
start_daemon = rank == 0
|
start_daemon = rank == 0
|
||||||
|
assert result.hostname is not None
|
||||||
store = TCPStore(result.hostname, result.port, world_size, start_daemon, timeout)
|
store = TCPStore(result.hostname, result.port, world_size, start_daemon, timeout)
|
||||||
yield (store, rank, world_size)
|
yield (store, rank, world_size)
|
||||||
|
|
||||||
|
|
@ -137,7 +145,7 @@ def _tcp_rendezvous_handler(url, timeout=default_pg_timeout, **kwargs):
|
||||||
raise RuntimeError("Unable to perform rerendezvous using tcp:// method")
|
raise RuntimeError("Unable to perform rerendezvous using tcp:// method")
|
||||||
|
|
||||||
|
|
||||||
def _env_rendezvous_handler(url, timeout=default_pg_timeout, **kwargs):
|
def _env_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, **kwargs):
|
||||||
def _error(msg):
|
def _error(msg):
|
||||||
return _rendezvous_error("env:// rendezvous: " + msg)
|
return _rendezvous_error("env:// rendezvous: " + msg)
|
||||||
|
|
||||||
|
|
@ -145,7 +153,13 @@ def _env_rendezvous_handler(url, timeout=default_pg_timeout, **kwargs):
|
||||||
return _error("environment variable %s expected, but not set" % var)
|
return _error("environment variable %s expected, but not set" % var)
|
||||||
|
|
||||||
result = urlparse(url)
|
result = urlparse(url)
|
||||||
query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))
|
query: Dict[str, Union[int, str]]
|
||||||
|
# mypy doesn't allow dict() to accept List of values (#257)
|
||||||
|
query = dict(pair.split("=") for pair in filter(None, result.query.split("&"))) # type: ignore[misc, arg-type]
|
||||||
|
|
||||||
|
rank: Optional[Union[str, int]]
|
||||||
|
world_size: Optional[Union[str, int]]
|
||||||
|
master_port: Optional[Union[str, int]]
|
||||||
|
|
||||||
if "rank" in query:
|
if "rank" in query:
|
||||||
rank = int(query["rank"])
|
rank = int(query["rank"])
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user