# Owner(s): ["oncall: distributed"]

import copy
import os
import sys
import tempfile
import threading
import time
from datetime import timedelta
from itertools import product
from sys import platform

import torch
import torch.distributed as dist

if not dist.is_available():
    print("distributed package not available, skipping tests", file=sys.stderr)
    sys.exit(0)

import torch.distributed.distributed_c10d as c10d
import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD
import torch.nn.functional as F
import torch.testing._internal.common_utils as common
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
)
from torch.testing._internal.common_utils import (
    TestCase,
    load_tests,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)

if TEST_WITH_DEV_DBG_ASAN:
    print("Multiprocessing spawn is not compatible with dev/dbg asan", file=sys.stderr)
    sys.exit(0)

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

if platform == "darwin":
    LOOPBACK = "lo0"
else:
    LOOPBACK = "lo"

torch.backends.cuda.matmul.allow_tf32 = False
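# TF32 is disabled above so that fp32 matmuls stay in full precision; the DDP
# tests below compare distributed and local results with tight tolerances, and
# reduced-precision matmuls could push them outside those bounds (this is the
# presumed intent of the setting, not something stated in the file).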


def gpus_for_rank(world_size):
    """Multi-GPU tests are designed to simulate multiple nodes, each with
    multiple GPUs. The NCCL backend requires an equal number of GPUs in each
    process. On a single node, all visible GPUs are evenly divided into
    subsets, and each process uses only its own subset.
    """
    visible_devices = list(range(torch.cuda.device_count()))
    gpus_per_process = torch.cuda.device_count() // world_size
    gpus_for_rank = []
    for rank in range(world_size):
        gpus_for_rank.append(
            visible_devices[rank * gpus_per_process : (rank + 1) * gpus_per_process]
        )
    return gpus_for_rank


class AbstractTimeoutTest(object):
    def _test_store_timeout(self, backend, init_method, c2p):
        try:
            dist.init_process_group(
                backend=backend,
                init_method=init_method,
                world_size=1,
                rank=0,
                timeout=timedelta(seconds=1),
            )
            default_store = c10d._get_default_store()
            tik = time.time()
            with self.assertRaisesRegex(RuntimeError, "Timeout"):
                default_store.get("nonexistent key")
            tok = time.time()
            dist.destroy_process_group()
            c2p.append(float(tok - tik))
        except RuntimeError as e:
            # catch "Address already in use" error and report it to the main
            # thread
            c2p.append(e)

    def _init_methods(self):
        f = tempfile.NamedTemporaryFile(delete=False)
        if sys.platform == "win32":
            yield "file:///%s" % f.name.replace("\\", "/")
            f.close()
        else:
            yield "file://%s" % f.name
            f.close()
            yield "tcp://127.0.0.1:%d" % common.find_free_port()

    def _test_default_store_timeout(self, backend):
        for init_method in self._init_methods():
            c2p = []
            t = threading.Thread(
                target=self._test_store_timeout, args=(backend, init_method, c2p)
            )
            t.daemon = True
            t.start()
            t.join(5)

            self.assertEqual(1, len(c2p))
            if isinstance(c2p[0], float):
                # waiting time should be 1s, use 3s to rule out false alarm
                self.assertGreater(3, c2p[0])
            elif isinstance(c2p[0], RuntimeError):
                # let @retry_on_connect_failures handle the error
                raise c2p[0]
            else:
                raise RuntimeError("Unexpected type {}".format(type(c2p[0])))


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(2, 10, bias=False)
        self.fc2 = nn.Linear(10, 50, bias=False)
        self.fc3 = nn.Linear(50, 4, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return F.softmax(x, dim=1)


class DoubleGpuNet(nn.Module):
    def __init__(self, gpus):
        super(DoubleGpuNet, self).__init__()
        self.fc1 = nn.Linear(2, 10, bias=False).to(gpus[0])
        self.fc2 = nn.Linear(10, 50, bias=False).to(gpus[1])
        self.fc3 = nn.Linear(50, 4, bias=False).to(gpus[1])
        self.relu = nn.ReLU()
        self.no_grad_param = nn.Parameter(
            torch.tensor([2, 2]).long(), requires_grad=False
        ).to(gpus[0])

    def forward(self, x):
        dev0 = self.fc1.weight.device
        dev1 = self.fc2.weight.device
        x = self.relu(self.fc1(x.to(dev0)))
        x = self.relu(self.fc2(x.to(dev1)))
        x = self.fc3(x)
        return F.softmax(x, dim=1).to(dev0)


class QuadraGpuNet(nn.Module):
    def __init__(self, gpus):
        super(QuadraGpuNet, self).__init__()
        self.fc1 = nn.Linear(2, 10, bias=False).to(gpus[0])
        self.fc2 = nn.Linear(10, 50, bias=False).to(gpus[1])
        self.fc3 = nn.Linear(50, 4, bias=False).to(gpus[2])
        self.fc4 = nn.Linear(4, 4, bias=False).to(gpus[3])
        self.relu = nn.ReLU()
        self.no_grad_param = nn.Parameter(
            torch.tensor([2, 2]).long(), requires_grad=False
        ).to(gpus[0])

    def forward(self, x):
        dev0 = self.fc1.weight.device
        dev1 = self.fc2.weight.device
        dev2 = self.fc3.weight.device
        dev3 = self.fc4.weight.device
        x = self.relu(self.fc1(x.to(dev0)))
        x = self.relu(self.fc2(x.to(dev1)))
        x = self.relu(self.fc3(x.to(dev2)))
        x = self.fc4(x.to(dev3))
        return F.softmax(x, dim=1).to(dev0)


class ConvNet(nn.Module):
    def __init__(self, gpus, layouts, dtypes):
        super(ConvNet, self).__init__()
        self.dtypes = dtypes
        if isinstance(gpus, list):
            self.layer_gpus = gpus
        else:
            gpus = [gpus] * 4
        self.conv0 = torch.nn.Conv2d(8, 16, (2, 2)).to(
            device=gpus[0], memory_format=layouts[0], dtype=dtypes[0]
        )
        self.conv1 = torch.nn.Conv2d(16, 32, (2, 2)).to(
            device=gpus[1], memory_format=layouts[1], dtype=dtypes[1]
        )
        self.conv2 = torch.nn.Conv2d(32, 16, (2, 2)).to(
            device=gpus[2], memory_format=layouts[2], dtype=dtypes[2]
        )
        self.conv3 = torch.nn.Conv2d(16, 8, (2, 2)).to(
            device=gpus[3], memory_format=layouts[3], dtype=dtypes[3]
        )

    def forward(self, x):
        x = x.to(self.dtypes[0])
        # Could say
        # x = self.conv0(x).to(device=self.conv1.weight.device, dtype=self.dtypes[1])
        # etc. But I don't want to appeal to the weights' devices directly, because part of this test's purpose
        # is to verify weights are where expected if the model gets replicated.
        gpus = self.layer_gpus if hasattr(self, "layer_gpus") else [x.device] * 4
        x = self.conv0(x).to(device=gpus[1], dtype=self.dtypes[1])
        x = self.conv1(x).to(device=gpus[2], dtype=self.dtypes[2])
        x = self.conv2(x).to(device=gpus[3], dtype=self.dtypes[3])
        return self.conv3(x)


class Task(nn.Module):
    def __init__(self):
        super().__init__()
        self.p = nn.Parameter(torch.ones(2, 2))

    def forward(self, x):
        return self.p + x


class ModuleForDdpCommHook(nn.Module):
    def __init__(self):
        super().__init__()
        self.t0 = Task()

    def forward(self, x, rank):
        return self.t0(x + rank)


class SparseGradientModule(nn.Module):
    def __init__(self):
        super(SparseGradientModule, self).__init__()
        self.embedding = nn.EmbeddingBag(10, 10, sparse=True)

    def forward(self, x):
        return F.softmax(self.embedding(x), dim=1)


class AbstractDistributedDataParallelTest(object):
    def tearDown(self):
        # DistributedDataParallel tests don't seem to call the FileStore
        # destructor.
        # TODO: investigate; the test is known to have issues.
        # Use this hack to remove the file left behind by those tests.
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @property
    def world_size(self):
        return 2

    def _prepare_single_device_module(
        self,
        process_group,
        devices,
        device_ids,
        global_batch_size,
        gradient_as_bucket_view=False,
    ):
        model = Net()
        device = devices[0] if devices else torch.device("cuda:%d" % self.rank)
        ddp_model = DistributedDataParallel(
            copy.deepcopy(model).to(device),
            device_ids=device_ids,
            process_group=process_group,
            bucket_cap_mb=0.001,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )
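        # Note: the very small bucket_cap_mb above splits the gradients into
        # many tiny buckets, which exercises DDP's bucketing and allreduce
        # paths far more than the default 25 MB cap would (a reading of the
        # test's intent, not something stated in the file).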

        model.to(device)

        input = torch.randn(global_batch_size, 2).to(device)
        target = torch.randn(global_batch_size, 4).to(device)

        return model, ddp_model, input, target

    def _prepare_multi_device_module(
        self,
        process_group,
        devices,
        device_ids,
        global_batch_size,
        gradient_as_bucket_view=False,
    ):
        self.assertTrue(
            len(devices) == 2 or len(devices) == 4,
            "unexpected devices for ddp tests {}".format(devices),
        )
        if len(devices) == 2:
            model = DoubleGpuNet(devices)
        elif len(devices) == 4:
            model = QuadraGpuNet(devices)

        ddp_model = DistributedDataParallel(
            copy.deepcopy(model),
            device_ids=device_ids,
            process_group=process_group,
            bucket_cap_mb=0.001,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

        input = torch.randn(global_batch_size, 2).cuda(devices[0])
        target = torch.randn(global_batch_size, 4)

        return model, ddp_model, input, target

    def _test_ddp_with_process_group(
        self,
        process_group,
        devices,
        device_ids,
        multi_device=False,
        gradient_as_bucket_view=False,
    ):
        """
        Note: we pass down `device_ids` all the way to DistributedDataParallel
        as part of the test. Below you find tests that either use a list of
        integers, a list of `torch.device` instances, or an empty list.
        The `devices` argument is used to control placement of the model and
        must always be specified as a list of `torch.device` instances.
        """
        local_batch_size = 1 if devices is None else len(devices)
        global_batch_size = self.world_size * local_batch_size

        if multi_device:
            model, ddp_model, input, target = self._prepare_multi_device_module(
                process_group,
                devices,
                device_ids,
                global_batch_size,
                gradient_as_bucket_view,
            )
            ddp_logging_data = ddp_model._get_ddp_logging_data()
            self.assertTrue(ddp_logging_data.get("is_multi_device_module"))
        else:
            model, ddp_model, input, target = self._prepare_single_device_module(
                process_group,
                devices,
                device_ids,
                global_batch_size,
                gradient_as_bucket_view,
            )
            ddp_logging_data = ddp_model._get_ddp_logging_data()
            self.assertFalse(ddp_logging_data.get("is_multi_device_module"))

        def step_model(model, input, target):
            model.train()
            output = model(input)
            loss = F.mse_loss(output, target.to(output.device))
            loss.backward()

        def update_parameters(model):
            for param in model.parameters():
                with torch.no_grad():
                    param -= param.grad
                param.grad = None

        # check two model parameters over 2 iterations
        for iteration in range(2):
            # single cpu/gpu training
            step_model(model, input, target)

            # DDP training, DDP scatters subsets of input to nodes/GPUs
            step_model(
                ddp_model,
                input[
                    self.rank * local_batch_size : (self.rank + 1) * local_batch_size
                ],
                target[
                    self.rank * local_batch_size : (self.rank + 1) * local_batch_size
                ],
            )

            # Update weights and run a second iteration to shake out errors
            update_parameters(model)
            update_parameters(ddp_model)
            self.assertEqual(
                len(list(model.parameters())), len(list(ddp_model.parameters()))
            )
            for i, j in zip(model.parameters(), ddp_model.parameters()):
                self.assertEqual(i, j, rtol=1.3e-06, atol=5e-5)

            # Shuffle the input so that DDP input is different
            torch.manual_seed(1337 + iteration)
            input = input[torch.randperm(global_batch_size)]

    def _gpu_model_with_ddp_comm_hook(
        self, process_group, hook=None, gradient_as_bucket_view=False, state=None
    ):
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            ModuleForDdpCommHook().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

        # Register a DDP communication hook if any.
        if hook is not None:
            gpu_model.register_comm_hook(state, hook)

        return gpu_model

    def _gpu_model_with_builtin_ddp_comm_hook(
        self, process_group, hook=None, gradient_as_bucket_view=False
    ):
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            ModuleForDdpCommHook().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

        # Register a built-in DDP communication hook if defined
        if hook is not None:
            gpu_model._register_builtin_comm_hook(hook)

        return gpu_model

    def _run_and_verify_hook(self, model, input, expected_grad):
        # Run forward
        output = model(input, self.rank)

        # Run backward
        output.mean().backward()

        [self.assertEqual(p.grad, expected_grad) for p in model.parameters()]

    def _simple_hook(
        self, state: object, bucket: dist.GradBucket
    ) -> torch.futures.Future[torch.Tensor]:
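        # This hook ignores the bucket's real gradients: it first resolves the
        # future with a tensor of ones shaped like the bucket, then fut_then
        # adds another tensor of ones, so every parameter ends up with a
        # gradient of all 2s (which is what _run_and_verify_hook can check
        # against an expected_grad of twos).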
        fut = torch.futures.Future()
        fut.set_result(torch.ones_like(bucket.buffer()))

        def fut_then(fut):
            # Add ones to fut's result.
            t = fut.value()
            return t + torch.ones_like(t)

        return fut.then(fut_then)


class DistributedDataParallelTest(
    AbstractDistributedDataParallelTest, MultiProcessTestCase
):
    def setUp(self):
        super(DistributedDataParallelTest, self).setUp()
        self._spawn_processes()

    def test_invalid_powerSGD_state(self):
        for start_powerSGD_iter, use_error_feedback, warm_start in product(
            [0, 1], [True, False], [True, False]
        ):
            if not use_error_feedback and not warm_start:
                continue
            with self.assertRaisesRegex(
                ValueError,
                "Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, "
                "because PowerSGD can only be applied after the first two iterations in DDP.",
            ):
                state = powerSGD.PowerSGDState(
                    process_group=None,
                    matrix_approximation_rank=1,
                    start_powerSGD_iter=start_powerSGD_iter,
                    use_error_feedback=use_error_feedback,
                    warm_start=warm_start,
                )


class ComputeBucketAssignmentTest(TestCase):
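    # The size limits passed to _compute_bucket_assignment_by_size below are
    # byte counts, and buckets are formed per dtype. Reading the expected
    # results assumes 4 bytes per torch.float and 8 bytes per torch.double,
    # e.g. a 100-element float tensor occupies 400 bytes and fills a 400-byte
    # bucket on its own.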
    def test_single_limit_single_dtype(self):
        tensors = [
            torch.empty([100], dtype=torch.float),
            torch.empty([200], dtype=torch.float),
            torch.empty([100], dtype=torch.float),
            torch.empty([50], dtype=torch.float),
        ]
        result, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
            tensors, [400]
        )
        self.assertTrue(all(size_lim == 400 for size_lim in per_bucket_size_limits))
        self.assertEqual([[0], [1], [2], [3]], result)

    def test_single_limit_multi_dtype(self):
        tensors = [
            torch.empty([50], dtype=torch.float),
            torch.empty([25], dtype=torch.double),
            torch.empty([50], dtype=torch.float),
            torch.empty([25], dtype=torch.double),
            torch.empty([50], dtype=torch.float),
            torch.empty([25], dtype=torch.double),
        ]
        result, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
            tensors, [400]
        )
        self.assertTrue(all(size_lim == 400 for size_lim in per_bucket_size_limits))
        self.assertEqual([[0, 2], [1, 3], [4], [5]], result)

    def test_multi_limit_single_dtype(self):
        tensors = [
            torch.empty([10], dtype=torch.float),
            torch.empty([10], dtype=torch.float),
            torch.empty([10], dtype=torch.float),
            torch.empty([10], dtype=torch.float),
        ]
        result, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
            tensors, [40, 80]
        )
        self.assertEqual(per_bucket_size_limits, [40, 80, 80])
        self.assertEqual([[0], [1, 2], [3]], result)

    def test_multi_limit_multi_dtype(self):
        tensors = [
            torch.empty([50], dtype=torch.float),
            torch.empty([25], dtype=torch.double),
            torch.empty([50], dtype=torch.float),
            torch.empty([25], dtype=torch.double),
            torch.empty([50], dtype=torch.float),
            torch.empty([25], dtype=torch.double),
        ]
        result, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
            tensors, [200, 400]
        )
        self.assertEqual([[0], [1], [2, 4], [3, 5]], result)
        self.assertEqual(per_bucket_size_limits, [200, 200, 400, 400])


class AbstractCommTest(object):
    @property
    def op_timeout_sec(self):
        return 1

    @property
    def world_size(self):
        return 2

    def _verify_sequence_number_across_pg(self, pg, verify_pg):

        seq_num = pg._get_sequence_number_for_group()
        obj_list = [None for _ in range(dist.get_world_size(verify_pg))]
        # We use a separate pg to verify the sequence numbers, otherwise these
        # collectives will themselves increment the sequence number.
        dist.all_gather_object(obj_list, seq_num, group=verify_pg)
        self.assertEqual(len(set(obj_list)), 1)
        return obj_list[0]

    def _test_sequence_num_incremented(self, process_group, ranks):
        # verify initial sequence numbers. Use a distinct process group for
        # verification to keep counts as expected with respect to process_group.
        verify_pg = dist.new_group(
            ranks=ranks,
            backend="gloo",
        )
        assert dist.get_world_size(process_group) == dist.get_world_size(verify_pg)

        initial_num = (
            self._verify_sequence_number_across_pg(
                pg=process_group, verify_pg=verify_pg
            )
            if not c10d._rank_not_in_group(process_group)
            else -1
        )

        # Verify sequence numbers are appropriately incremented
        for i in range(10):
            t = torch.ones(1, device=torch.cuda.current_device())
            dist.all_reduce(t, group=process_group)
            if not c10d._rank_not_in_group(process_group):
                seq_num = self._verify_sequence_number_across_pg(
                    pg=process_group,
                    verify_pg=verify_pg,
                )
                self.assertEqual(initial_num + i + 1, seq_num)

        if dist.get_world_size(process_group) > 2:
            # Test when certain ranks don't call collectives
            if dist.get_rank(process_group) not in [0, 2]:
                dist.all_reduce(t, group=process_group, async_op=True)
            # Now ranks 0 and 2 should be lagging by 1.
            if not c10d._rank_not_in_group(process_group):
                seq_num = process_group._get_sequence_number_for_group()
                rank = dist.get_rank(process_group)
                obj_list = [None for _ in range(dist.get_world_size(verify_pg))]
                dist.all_gather_object(obj_list, (rank, seq_num), group=verify_pg)
                rank_to_seq_num = {rank: num for (rank, num) in obj_list}
                self.assertEqual(len(set(rank_to_seq_num.values())), 2)
                self.assertEqual(rank_to_seq_num[0], rank_to_seq_num[2])
                expected_same = {
                    rank_to_seq_num[i]
                    for i in rank_to_seq_num.keys()
                    if i not in [0, 2]
                }
                self.assertEqual(len(expected_same), 1)
                self.assertEqual(rank_to_seq_num[0] + 1, rank_to_seq_num[1])

    def _test_sequence_num_incremented_default_group(self, backend_name):
        torch.cuda.set_device(self.rank)
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend_name,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        self._test_sequence_num_incremented(
            c10d._get_default_group(),
            ranks=list(i for i in range(dist.get_world_size())),
        )

    def _test_sequence_num_incremented_subgroup(self, backend_name):
        torch.cuda.set_device(self.rank)
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend_name,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        subgroup_ranks = [0, 1, 2]
        subgroup = dist.new_group(subgroup_ranks)
        self._test_sequence_num_incremented(subgroup, subgroup_ranks)

    def _test_sequence_num_set_default_pg(self, backend):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )

        default_pg = c10d._get_default_group()
        seq_num = default_pg._get_sequence_number_for_group()
        obj_list = [None for _ in range(dist.get_world_size())]
        dist.all_gather_object(obj_list, seq_num)
        self.assertEqual(len(set(obj_list)), 1)

    def _test_sequence_num_set_new_group(self, backend):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )

        subgroup = dist.new_group([0, 1])

        if not c10d._rank_not_in_group(subgroup):
            subgroup_seq = subgroup._get_sequence_number_for_group()
            obj_list = [None for _ in range(dist.get_world_size(subgroup))]
            dist.all_gather_object(obj_list, subgroup_seq, group=subgroup)
            self.assertEqual(len(set(obj_list)), 1)

    def _test_warn_not_in_group(self, backend):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        in_group_ranks = list(filter(lambda x: x % 2 == 0, range(self.world_size)))
        group = dist.new_group(in_group_ranks)

        x = torch.zeros(2, 2).cuda(self.rank)
        xs = [torch.zeros(2, 2).cuda(self.rank) for _ in range(len(in_group_ranks))]
        if self.rank not in in_group_ranks:
            msg = ".*{}.*does not belong to.*"
            with self.assertWarnsOnceRegex(UserWarning, msg.format("all_gather")):
                dist.all_gather(xs, x, group=group)
            with self.assertWarnsOnceRegex(UserWarning, msg.format("all_reduce")):
                dist.all_reduce(x, group=group)
            with self.assertWarnsOnceRegex(UserWarning, msg.format("barrier")):
                dist.barrier(group=group)
            with self.assertWarnsOnceRegex(UserWarning, msg.format("broadcast")):
                dist.broadcast(x, src=0, group=group)
        else:
            dist.all_gather(xs, x, group=group)
            dist.all_reduce(x, group=group)
            dist.barrier(group=group)
            dist.broadcast(x, src=0, group=group)


class CommTest(AbstractCommTest, MultiProcessTestCase):
    def setUp(self):
        super(CommTest, self).setUp()
        self._spawn_processes()

    def tearDown(self):
        super(CommTest, self).tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def test_distributed_debug_mode(self):
        # Default should be off
        default_debug_mode = dist._get_debug_mode()
        self.assertEqual(default_debug_mode, dist._DistributedDebugLevel.OFF)
        mapping = {
            "OFF": dist._DistributedDebugLevel.OFF,
            "INFO": dist._DistributedDebugLevel.INFO,
            "DETAIL": dist._DistributedDebugLevel.DETAIL,
        }
        invalid_debug_modes = ["foo", 0, 1, -1]

        for mode in mapping.keys():
            os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode)
            set_debug_mode = dist._get_debug_mode()
            self.assertEqual(
                set_debug_mode,
                mapping[mode],
                f"Expected {mode} to map to {mapping[mode]} but got {set_debug_mode}",
            )

        for mode in invalid_debug_modes:
            os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode)
            with self.assertRaisesRegex(RuntimeError, "to be one of"):
                dist._get_debug_mode()


class DummyWork(dist._Work):
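    # dist._Work models the handle returned by an (async) collective; wait()
    # here only synchronizes the current CUDA stream when CUDA is available
    # and then reports completion.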
    def wait(self, timeout=5.0):
        if torch.cuda.is_available():
            torch.cuda.current_stream().synchronize()
        return True


class DummyProcessGroup(dist.ProcessGroup):
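    # This dummy backend stands in for a real communication backend with
    # cheap, deterministic arithmetic: allgather copies the input into every
    # output, allreduce and recv add 2 in place, broadcast and send add 1 in
    # place. The assertions in PythonProcessGroupExtensionTest rely on exactly
    # these conventions.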
    def getBackendName(self):
        return "Dummy"

    def allgather(self, output_tensor_lists, input_tensor_list, opts=None):
        for output_tensor_list, input_tensor in zip(output_tensor_lists, input_tensor_list):
            for output_tensor in output_tensor_list:
                output_tensor.copy_(input_tensor)

        return DummyWork()

    def allreduce(self, tensor_list, opts=None):
        for tensor in tensor_list:
            tensor.add_(2)

        return DummyWork()

    def barrier(self, opts=None):
        store = c10d._get_default_store()
        key = "TEST:DummyProcessGroup:barrier"
        if self.rank() == 0:
            worker_count = 0
            # By default, TCPServer lives on rank 0. So rank 0 needs to make
            # sure that it does not exit too early before other ranks finish
            # using the store.
            # Note that, _store_based_barrier does not solve this problem, as
            # all ranks need to run at least one store.add(key, 0) before
            # exiting, but there is no guarantee that rank 0 is still alive at
            # that point.
            while worker_count < self.size() - 1:
                worker_count = store.add(key, 0)
        else:
            store.add(key, 1)

        return DummyWork()

    def broadcast(self, tensor_list, opts=None):
        for tensor in tensor_list:
            tensor.add_(1)

        return DummyWork()

    def reduce_scatter(self, output_tensor_list, input_tensor_lists, opts=None):
        for output_tensor, input_tensor_list in zip(output_tensor_list, input_tensor_lists):
            output_tensor.copy_(input_tensor_list[self.rank()])

        return DummyWork()

    def send(self, tensor_list, dst, tag=0):
        for tensor in tensor_list:
            tensor.add_(1)

        return DummyWork()

    def recv(self, tensor_list, src, tag=0):
        for tensor in tensor_list:
            tensor.add_(2)

        return DummyWork()


class PythonProcessGroupExtensionTest(MultiProcessTestCase):
    def setUp(self):
        super(PythonProcessGroupExtensionTest, self).setUp()
        self._spawn_processes()

    def tearDown(self):
        super(PythonProcessGroupExtensionTest, self).tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def test_get_backend_name(self):
        dpg = DummyProcessGroup(0, 1)
        self.assertEqual("Dummy", dpg.name())

    def test_backend_class_attr(self):
        dist.Backend.register_backend(
            "dummy",
            PythonProcessGroupExtensionTest.create_dummy
        )
        self.assertEqual(dist.Backend.DUMMY, "DUMMY")
        self.assertEqual(
            dist.Backend._plugins["DUMMY"],
            PythonProcessGroupExtensionTest.create_dummy
        )

    @staticmethod
    def create_dummy(store, rank, size, timeout):
        return DummyProcessGroup(rank, size)

    def test_collectives(self):
        dist.Backend.register_backend("dummy", PythonProcessGroupExtensionTest.create_dummy)

        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '6789'
        dist.init_process_group("dummy", rank=self.rank, world_size=self.world_size)

        # test all_gather
        input_tensor = torch.ones(2, 2) * 7
        output_tensor_list = [torch.zeros(2, 2) for _ in range(self.world_size)]
        dist.all_gather(output_tensor_list, input_tensor)

        for tensor in output_tensor_list:
            self.assertEqual(tensor, input_tensor)

        # test all_reduce
        input_tensor = torch.ones(2, 2) * 7
        dist.all_reduce(input_tensor)
        self.assertEqual(input_tensor, torch.ones(2, 2) * 7 + 2)

        # test broadcast
        input_tensor = torch.zeros(2, 2)
        dist.broadcast(input_tensor, 0, async_op=True).wait()
        self.assertEqual(torch.ones(2, 2), input_tensor)

        # test reduce_scatter
        output_tensor = torch.zeros(2, 2)
        input_tensor_list = [torch.ones(2, 2) for _ in range(self.world_size)]
        dist.reduce_scatter(output_tensor, input_tensor_list)
        self.assertEqual(output_tensor, torch.zeros(2, 2) + 1)

        dist.barrier()
        dist.destroy_process_group()

    def test_send_recv(self):
        dist.Backend.register_backend("dummy", PythonProcessGroupExtensionTest.create_dummy)

        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '6789'
        dist.init_process_group("dummy", rank=self.rank, world_size=self.world_size)

        # test send
        input_tensor = torch.zeros(2, 2)
        dist.send(input_tensor, (self.rank + 1) % self.world_size)
        self.assertEqual(input_tensor, torch.zeros(2, 2) + 1)

        # test recv
        input_tensor = torch.zeros(2, 2)
        dist.recv(input_tensor, (self.rank + 1) % self.world_size)
        self.assertEqual(input_tensor, torch.zeros(2, 2) + 2)

        dist.barrier()
        # Intentionally not calling `destroy_process_group`, as not all user
        # applications would explicitly call it.


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()