Port 3 distributed tests to Intel GPU and unify some common functions (#158533)
For https://github.com/pytorch/pytorch/issues/114850, we port distributed tests to Intel GPU. We enable Intel GPU with the following methods, trying our best to keep the original code style:
- instantiate_device_type_tests()
- use "torch.accelerator.current_accelerator()" to determine the accelerator backend
- enable XPU for some test paths
- unify some common code under torch/testing/_internal for multiple backends, for example:
  - requires_nccl_version
  - _dynamo_dist_per_rank_init
  - DynamoDistributedSingleProcTestCase
  - DistTestCases
  - FSDPTestMultiThread

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158533
Approved by: https://github.com/guangyey, https://github.com/d4l3k
Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
This commit is contained in:
parent 9a06e6d031
commit 6e8865fbc1
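The diff below applies one device-agnostic pattern repeatedly: resolve the active accelerator once, then route seeding, device-index queries, and tensor placement through that value instead of hard-coded torch.cuda calls. The snippet that follows is a standalone sketch of that pattern using the same APIs the change relies on (torch.accelerator, torch.get_device_module); it is illustrative only, is not code taken from the PR, and assumes a PyTorch build recent enough to ship the torch.accelerator module.

import torch

# Resolve the active accelerator type ("cuda", "xpu", ...), falling back to "cpu".
device_type = (
    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)

torch.manual_seed(0)
if device_type != "cpu":
    # Seed the backend-specific RNG without hard-coding torch.cuda.manual_seed().
    torch.get_device_module(device_type).manual_seed(0)
    # Device-index queries go through the generic accelerator API.
    idx = torch.accelerator.current_device_index()

# Tensors are created against the resolved device string instead of a literal "cuda".
x = torch.randn(8, 8, device=device_type)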
@@ -13,7 +13,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecis
 from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from torch.testing._internal.common_distributed import (
-    requires_nccl,
+    requires_accelerator_dist_backend,
     requires_nccl_version,
     skip_but_pass_in_sandcastle_if,
     skip_if_lt_x_gpu,
@@ -30,17 +30,22 @@ if not dist.is_available():
     print("Distributed not available, skipping tests", file=sys.stderr)
     sys.exit(0)
 
-# bfloat16 is only supported by CUDA 11+
-BFLOAT16_AVAILABLE = torch.cuda.is_available() and (
-    torch.version.cuda is not None or torch.version.hip is not None
+device_type = (
+    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
 )
 
+# bfloat16 is only supported by CUDA 11+ or XPU
+BFLOAT16_AVAILABLE = (
+    torch.cuda.is_available()
+    and (torch.version.cuda is not None or torch.version.hip is not None)
+) or torch.xpu.is_available()
+
 
 class Net(nn.Module):
     def __init__(self, has_wrapping, sharding_strategy, mixed_precision=None):
         # to ensure determinism
         torch.manual_seed(0)
-        torch.cuda.manual_seed(0)
+        torch.get_device_module(device_type).manual_seed(0)
         super().__init__()
 
         if has_wrapping:
@@ -50,12 +55,12 @@ class Net(nn.Module):
                     nn.ReLU(),
                     FSDP(
                         nn.Linear(16, 8),
-                        device_id=torch.cuda.current_device(),
+                        device_id=torch.accelerator.current_device_index(),
                         sharding_strategy=sharding_strategy,
                         mixed_precision=mixed_precision,
                     ),
                 ),
-                device_id=torch.cuda.current_device(),
+                device_id=torch.accelerator.current_device_index(),
                 sharding_strategy=sharding_strategy,
                 mixed_precision=mixed_precision,
             )
@@ -134,11 +139,11 @@ class TestCommunicationHooks(FSDPTest):
         """
         out_dim = self.world_size
         net = torch.nn.Linear(1, out_dim, bias=False)
-        inpt = torch.tensor([self.rank]).float().cuda(self.rank)
+        inpt = torch.tensor([self.rank]).float().to(self.rank)
 
         net_default_hook = FSDP(
             net,
-            device_id=torch.cuda.current_device(),
+            device_id=torch.accelerator.current_device_index(),
             sharding_strategy=sharding_strategy,
         ).to(self.rank)
 
@@ -172,10 +177,10 @@ class TestCommunicationHooks(FSDPTest):
         ]
 
     def _init_model(self, core, sharding_strategy, mixed_precision=None):
-        device = torch.device("cuda")
+        device = torch.device(device_type)
         return FSDP(
             core,
-            device_id=torch.cuda.current_device(),
+            device_id=torch.accelerator.current_device_index(),
             sharding_strategy=sharding_strategy,
             mixed_precision=mixed_precision,
         ).to(device)
@@ -277,7 +282,7 @@ class TestCommunicationHooks(FSDPTest):
             ShardingStrategy.HYBRID_SHARD,
             ShardingStrategy._HYBRID_SHARD_ZERO2,
         ):
-            model = Net(False, None, None).cuda()
+            model = Net(False, None, None).to(device=device_type)
             fsdp_model = FSDP(
                 model,
                 auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
@@ -337,7 +342,7 @@ class TestCommunicationHooks(FSDPTest):
     ):
         # keep everything deterministic for input data
         torch.manual_seed(0)
-        torch.cuda.manual_seed(0)
+        torch.get_device_module(device_type).manual_seed(0)
 
         fsdp_with_hook = self._init_model(
             Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy),
@@ -359,7 +364,7 @@ class TestCommunicationHooks(FSDPTest):
         optim_hook = torch.optim.SGD(fsdp_with_hook.parameters(), lr=0.1)
         optim_mp = torch.optim.SGD(fsdp_with_mp.parameters(), lr=0.1)
 
-        in_data = torch.rand(16, 8).cuda()
+        in_data = torch.rand(16, 8).to(device=device_type)
         fsdp_with_hook.train()
         fsdp_with_mp.train()
         loss_hook = fsdp_with_hook(in_data).sum()
@@ -378,7 +383,7 @@ class TestCommunicationHooks(FSDPTest):
         ):
             self.assertEqual(hook_param.grad, mp_param.grad)
 
-    @requires_nccl()
+    @requires_accelerator_dist_backend(["nccl", "xccl"])
     @skip_if_lt_x_gpu(2)
     @parametrize("has_wrapping", [True, False])
     @parametrize(
@@ -399,11 +404,11 @@ class TestCommunicationHooks(FSDPTest):
             state, hook, sharding_strategy, torch.float16, has_wrapping
         )
 
-    @requires_nccl()
+    @requires_accelerator_dist_backend(["nccl", "xccl"])
     @requires_nccl_version((2, 10), "Need NCCL 2.10+ for BF16_COMPRESS")
     @skip_but_pass_in_sandcastle_if(
         not BFLOAT16_AVAILABLE,
-        "BFloat16 is only supported by CUDA 11+",
+        "BFloat16 is only supported by CUDA 11+ or XPU",
     )
     @skip_if_lt_x_gpu(2)
     @parametrize("has_wrapping", [True, False])

@@ -60,6 +60,10 @@ if TEST_WITH_DEV_DBG_ASAN:
     )
     sys.exit(0)
 
+device_type = (
+    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
+)
+
 
 class MyModel(nn.Module):
     def __init__(self) -> None:
@@ -93,9 +97,9 @@ class TestFSDPMiscMultiProcess(FSDPTest):
         without specifying a device ID (i.e. ``torch.device("cuda")``) warns
         """
         dev_id = (
-            torch.cuda.current_device()
+            torch.accelerator.current_device_index()
             if use_index
-            else torch.device("cuda", torch.cuda.current_device())
+            else torch.device(device_type, torch.accelerator.current_device_index())
         )
 
         def _check_device_matches(module, device_id):
@@ -108,7 +112,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
             self.assertEqual(1, len(devices))
             found_device = devices.pop()
             if use_index and not isinstance(device_id, torch.device):
-                device = torch.device("cuda", device_id)
+                device = torch.device(device_type, device_id)
             else:
                 device = device_id
             self.assertEqual(found_device, device)
@@ -140,10 +144,11 @@ class TestFSDPMiscMultiProcess(FSDPTest):
             self.process_group,
             FSDPInitMode.RECURSIVE,
             DEVICEInitMode.DEVICE_BEFORE,
-            fsdp_kwargs={"device_id": torch.device("cuda")},
+            fsdp_kwargs={"device_id": torch.device(device_type)},
         )
         _check_device_matches(
-            nested_wrapped_module, torch.device("cuda", torch.cuda.current_device())
+            nested_wrapped_module,
+            torch.device(device_type, torch.accelerator.current_device_index()),
         )
 
     @skip_if_lt_x_gpu(2)
@@ -178,8 +183,8 @@ class TestFSDPMiscMultiProcess(FSDPTest):
                 loss = torch.nn.functional.cross_entropy(output, y)
                 return loss
 
-        model = Mnist().cuda()
-        model1 = Mnist().cuda()
+        model = Mnist().to(device=device_type)
+        model1 = Mnist().to(device=device_type)
         model1.load_state_dict(model.state_dict())
         fsdp_model = FSDP(
             model,
@@ -197,17 +202,17 @@ class TestFSDPMiscMultiProcess(FSDPTest):
 
         seed = self.rank + 20231010
         torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        torch.get_device_module(device_type).manual_seed(seed)
 
         losses = []
         grads = []
         for i in range(5):
-            x = torch.randn(8, 1, 28, 28, device="cuda").requires_grad_()
-            y = torch.randint(low=0, high=9, size=(8,), device="cuda")
+            x = torch.randn(8, 1, 28, 28, device=device_type).requires_grad_()
+            y = torch.randint(low=0, high=9, size=(8,), device=device_type)
             for model, opt in ((fsdp_model, fsdp_opt), (ddp_model, ddp_opt)):
                 seed = self.rank + i
                 torch.manual_seed(seed)
-                torch.cuda.manual_seed(seed)
+                torch.get_device_module(device_type).manual_seed(seed)
                 loss = model(x, y).sum()
                 losses.append(loss)
                 loss.backward()
@@ -223,8 +228,8 @@ class TestFSDPMiscMultiProcess(FSDPTest):
         fsdp_model.eval()
         ddp_model.eval()
         for _ in range(5):
-            x = torch.randn(8, 1, 28, 28, device="cuda").requires_grad_()
-            y = torch.randint(low=0, high=9, size=(8,), device="cuda")
+            x = torch.randn(8, 1, 28, 28, device=device_type).requires_grad_()
+            y = torch.randint(low=0, high=9, size=(8,), device=device_type)
             fsdp_loss = fsdp_model(x, y)
             ddp_loss = ddp_model(x, y)
             assert torch.allclose(fsdp_loss, ddp_loss)
@@ -232,12 +237,12 @@ class TestFSDPMiscMultiProcess(FSDPTest):
         fsdp_model.train()
         ddp_model.train()
         for i in range(5):
-            x = torch.randn(8, 1, 28, 28, device="cuda").requires_grad_()
-            y = torch.randint(low=0, high=9, size=(8,), device="cuda")
+            x = torch.randn(8, 1, 28, 28, device=device_type).requires_grad_()
+            y = torch.randint(low=0, high=9, size=(8,), device=device_type)
             for model, opt in ((fsdp_model, fsdp_opt), (ddp_model, ddp_opt)):
                 seed = self.rank + i
                 torch.manual_seed(seed)
-                torch.cuda.manual_seed(seed)
+                torch.get_device_module(device_type).manual_seed(seed)
                 loss = model(x, y).sum()
                 losses.append(loss)
                 loss.backward()
@@ -272,12 +277,12 @@ class TestFSDPMiscMultiProcess(FSDPTest):
                 return out1
 
         fsdp = FSDP(
-            MyModel().cuda(),
+            MyModel().to(device=device_type),
             sharding_strategy=sharding_strategy,
             auto_wrap_policy=always_wrap_policy,
         )
-        x = torch.randn(10, 10, device="cuda")
-        y = torch.randn(10, 10, device="cuda")
+        x = torch.randn(10, 10, device=device_type)
+        y = torch.randn(10, 10, device=device_type)
         for _ in range(4):
             if use_second_layer:
                 a, _ = fsdp(x, y)
@@ -336,7 +341,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
                     torch.testing.assert_close(p1, p2)
 
         fsdp_ctor = functools.partial(FSDP, sharding_strategy=sharding_strategy)
-        m = MyModule().cuda()
+        m = MyModule().to(device=device_type)
         m_local = deepcopy(m)
         local_m = m_local
         prev_params = [p.clone() for p in m_local.parameters()]
@@ -349,7 +354,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
         opt_local = torch.optim.SGD(local_m.parameters(), lr=1e-3)
 
         for i in range(6):
-            t = torch.ones(4, device="cuda")
+            t = torch.ones(4, device=device_type)
             a, b = m(t)
             local_a, local_b = local_m(t)
             if i < 2:
@@ -385,7 +390,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
     @skip_if_lt_x_gpu(2)
     def test_fsdp_optim_overlap_no_use_orig_params_error(self):
         fsdp_overlap = FSDP(
-            MyModel().cuda(),
+            MyModel().to(device=device_type),
             auto_wrap_policy=always_wrap_policy,
             use_orig_params=False,
         )
@@ -398,7 +403,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
             register_hook=False,
         )
 
-        inp = torch.randn(10, 10, device="cuda")
+        inp = torch.randn(10, 10, device=device_type)
         with self.assertRaisesRegex(
             RuntimeError, "only supported with use_orig_params=True"
         ):
@@ -409,16 +414,16 @@ class TestFSDPMiscMultiProcess(FSDPTest):
         torch.manual_seed(0)
         for cpu_offload in [True, False]:
             offload = CPUOffload(offload_params=cpu_offload)
-            model = MyModel().cuda()
+            model = MyModel().to(device=device_type)
             model_overlap = deepcopy(model)
             fsdp = FSDP(
-                model.cuda(),
+                model.to(device=device_type),
                 auto_wrap_policy=always_wrap_policy,
                 use_orig_params=True,
                 cpu_offload=offload,
             )
             fsdp_overlap = FSDP(
-                model_overlap.cuda(),
+                model_overlap.to(device=device_type),
                 auto_wrap_policy=always_wrap_policy,
                 use_orig_params=True,
                 cpu_offload=offload,
@@ -445,7 +450,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
             ]
 
             for i in range(6):
-                inp = torch.randn(2, 2, device="cuda")
+                inp = torch.randn(2, 2, device=device_type)
                 with torch.no_grad():
                     inp_clone = inp.clone()
                 fsdp(inp, inp).sum().backward()
@@ -546,7 +551,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
         """Tests that passing a CPU module to FSDP preserves that the wrapped
         module is on CPU after FSDP initialization, albeit after logging a
        warning, and that FSDP moves CPU input to GPU before the forward."""
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)
        regex = "passed-in `module` is on CPU"
        context = self.assertWarnsRegex(
            expected_warning=UserWarning, expected_regex=regex
@@ -561,7 +566,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
         devices = {p.device for p in fsdp_model.parameters()}
         self.assertEqual(1, len(devices))
         self.assertEqual(torch.device("cpu"), devices.pop())
-        fsdp_model = fsdp_model.cuda()
+        fsdp_model = fsdp_model.to(device=device_type)
         # Ensure fwd + backward can be performed after moving to CUDA.
         # CPU input also tests that input is correctly moved to appropriate
         # CUDA device.
@@ -606,19 +611,19 @@ class TestFSDPMiscMultiProcess(FSDPTest):
             nested_wrapped_module,
             self.process_group,
             auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
-            device_id=torch.cuda.current_device(),
+            device_id=torch.accelerator.current_device_index(),
             sync_module_states=True,
         )
         # Each rank's buffers should be 0s since rank 0 is the source, and they
         # should be on GPU since we specified `device_id`
         self.assertEqual(
             nested_wrapped_module.buf.device,
-            torch.device("cuda", torch.cuda.current_device()),
+            torch.device(device_type, torch.accelerator.current_device_index()),
         )
         self.assertEqual(nested_wrapped_module.buf, torch.zeros((2, 2)))
         self.assertEqual(
             nested_wrapped_module.module.module[0].buf.device,
-            torch.device("cuda", torch.cuda.current_device()),
+            torch.device(device_type, torch.accelerator.current_device_index()),
         )
         self.assertEqual(
             nested_wrapped_module.module.module[0].buf, torch.zeros((3, 2))
@@ -644,9 +649,9 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
             def forward(self, x):
                 return x
 
-        m = MyModule().cuda()
+        m = MyModule().to(device=device_type)
         m = FSDP(m)
-        t = torch.ones(1, device="cuda", requires_grad=True)
+        t = torch.ones(1, device=device_type, requires_grad=True)
 
         MyOutputType = namedtuple(
             "MyOutputType", ["a", "b", "c", "d"], defaults=(t, t, t, t)
@@ -683,7 +688,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
         auto_wrap_policy = ModuleWrapPolicy(module_classes)
         fsdp_kwargs = {
             "auto_wrap_policy": auto_wrap_policy,
-            "device_id": torch.cuda.current_device(),
+            "device_id": torch.accelerator.current_device_index(),
         }
         fsdp_model = TransformerWithSharedParams.init(
             self.process_group,
@@ -694,7 +699,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
         for fsdp_module in FSDP.fsdp_modules(fsdp_model):
             self.assertEqual(
                 fsdp_module.compute_device,
-                torch.device("cuda", torch.cuda.current_device()),
+                torch.device(device_type, torch.accelerator.current_device_index()),
             )
 
     @skip_if_lt_x_gpu(2)
@@ -729,7 +734,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
             model,
             auto_wrap_policy=auto_wrap_policy,
             cpu_offload=CPUOffload(offload_params=True),
-            device_id=torch.cuda.current_device(),
+            device_id=torch.accelerator.current_device_index(),
             use_orig_params=use_orig_params,
         )
         cpu_device = torch.device("cpu")
@@ -742,12 +747,16 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
         module that does not match the GPU device ID raises an error."""
         # TODO: override FSDP MT Thread _run to set this instead of here for
         # every test.
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)
 
         context = (
-            self.assertRaisesRegex(ValueError, f"cuda:{self.rank} vs cuda:0")
+            self.assertRaisesRegex(
+                ValueError, f"{device_type}:{self.rank} vs {device_type}:0"
+            )
             if self.rank != 0
             else nullcontext()
         )
 
         with context:
             NestedWrappedModule.init(
                 self.process_group,
@@ -764,18 +773,20 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
         """Tests a CPU + GPU module supported if device_id is passed
         in, errors if device_id is not.
         """
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)
 
         class CPUGPUModule(nn.Module):
             def __init__(self) -> None:
                 super().__init__()
-                self.a = nn.Linear(1, 1).cuda()
+                self.a = nn.Linear(1, 1).to(device=device_type)
                 self.b = nn.Linear(1, 1)
 
         cpu_gpu = CPUGPUModule()
-        fsdp = FSDP(cpu_gpu, device_id=torch.cuda.current_device())
+        fsdp = FSDP(cpu_gpu, device_id=torch.accelerator.current_device_index())
         for param in fsdp.parameters():
-            self.assertEqual(param.device, torch.device(torch.cuda.current_device()))
+            self.assertEqual(
+                param.device, torch.device(torch.accelerator.current_device_index())
+            )
 
         # without device_id, we hit an error
         with self.assertRaisesRegex(RuntimeError, "please pass in device_id"):
@@ -783,7 +794,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
 
     @skip_if_lt_x_gpu(2)
     def test_fsdp_ignored_module_meta(self):
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)
 
         class CPUGPUModule(nn.Module):
             def __init__(self) -> None:
@@ -802,11 +813,11 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
         m = CPUGPUModule()
         m = FSDP(
             m,
-            device_id=torch.cuda.current_device(),
+            device_id=torch.accelerator.current_device_index(),
             ignored_modules=[m.a],
             use_orig_params=True,
             param_init_fn=lambda m: m.to_empty(
-                device=torch.cuda.current_device(), recurse=False
+                device=torch.accelerator.current_device_index(), recurse=False
             ),
         )
         self.assertEqual(meta_device, next(m.a.parameters()).device)
@@ -854,20 +865,20 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
         """
         # TODO: override FSDP MT Thread _run to set this instead of here for
         # every test.
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)
         # Test CPU
         no_params = nn.ReLU()
         FSDP(no_params)
         # Test CUDA
-        no_params = nn.ReLU().cuda()
+        no_params = nn.ReLU().to(device=device_type)
         FSDP(no_params)
         # Test CPU + device_id
         no_params = nn.ReLU()
-        FSDP(no_params, device_id=torch.cuda.current_device())
+        FSDP(no_params, device_id=torch.accelerator.current_device_index())
         # For modules with no params, wrong device_id will raise error about
         # inconsistency between compute_device and device_id, since compute_device
         # is computed as torch.cuda.current_device when there are no params.
-        no_params = nn.ReLU().cuda()
+        no_params = nn.ReLU().to(device=device_type)
         context = (
             (
                 self.assertRaisesRegex(
@@ -892,11 +903,11 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
                 super().__init__()
                 # Seed via rank to make model different across ranks
                 torch.manual_seed(rank)
-                torch.cuda.manual_seed(rank)
+                torch.get_device_module(device_type).manual_seed(rank)
                 self.lin = nn.Linear(10, 10, bias=False)
                 self.buffer = nn.Buffer(torch.ones(1) * rank)
 
-        m = MyModel(self.rank).cuda()
+        m = MyModel(self.rank).to(device=device_type)
         _assert_module_states(
             m, process_group=self.process_group, assert_fn=self.assertNotEqual
         )
@@ -913,7 +924,11 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
             m, process_group=self.process_group, assert_fn=self.assertNotEqual
         )
         # Passing sync_module_states into FSDP makes model the same during init.
-        fsdp = FSDP(m, device_id=torch.cuda.current_device(), sync_module_states=True)
+        fsdp = FSDP(
+            m,
+            device_id=torch.accelerator.current_device_index(),
+            sync_module_states=True,
+        )
         with fsdp.summon_full_params(fsdp):
             _assert_module_states(
                 fsdp, process_group=self.process_group, assert_fn=self.assertEqual
@@ -968,7 +983,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
         with self.assertRaisesRegex(
             ValueError, f"Expects one homogeneous value for {attr_name}"
         ):
-            inp = fsdp_model.module.get_input(torch.device("cuda"))
+            inp = fsdp_model.module.get_input(torch.device(device_type))
             fsdp_model(*inp)
 
     @skip_if_lt_x_gpu(2)
@@ -976,7 +991,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
         regex = r"FSDP will not all-gather parameters for containers that do not implement forward"
         model = nn.ModuleList([MLP(8, torch.device("cpu")) for _ in range(3)])
         with self.assertWarnsRegex(UserWarning, regex):
-            FSDP(model, device_id="cuda")
+            FSDP(model, device_id=device_type)
         model = nn.ModuleDict(
             {"1": MLP(8, torch.device("cpu")), "2": MLP(8, torch.device("cpu"))}
         )
@@ -1000,7 +1015,10 @@ class TestFSDPMiscWorldSize1(FSDPTestMultiThread):
         # warning
         with warnings.catch_warnings(record=True) as w:
             warnings.simplefilter("always")  # trigger all warnings
-            FSDP(nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.NO_SHARD)
+            FSDP(
+                nn.Linear(3, 3).to(device=device_type),
+                sharding_strategy=ShardingStrategy.NO_SHARD,
+            )
             for warning in w:
                 self.assertTrue(
                     warning.category != UserWarning
@@ -1014,16 +1032,20 @@ class TestFSDPMiscWorldSize1(FSDPTestMultiThread):
             warning_prefix + " " + str(ShardingStrategy.FULL_SHARD) + warning_suffix
         )
         with self.assertWarnsRegex(UserWarning, expected_regex_full_shard):
-            FSDP(nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.FULL_SHARD)
+            FSDP(
+                nn.Linear(3, 3).to(device=device_type),
+                sharding_strategy=ShardingStrategy.FULL_SHARD,
+            )
         with self.assertWarnsRegex(UserWarning, expected_regex_full_shard):
-            FSDP(nn.Linear(3, 3).cuda())
+            FSDP(nn.Linear(3, 3).to(device=device_type))
         # - Pass `SHARD_GRAD_OP`
         expected_regex_shard_grad_op = (
             warning_prefix + " " + str(ShardingStrategy.SHARD_GRAD_OP) + warning_suffix
         )
         with self.assertWarnsRegex(UserWarning, expected_regex_shard_grad_op):
             FSDP(
-                nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.SHARD_GRAD_OP
+                nn.Linear(3, 3).to(device=device_type),
+                sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
             )
 
     @skip_if_lt_x_gpu(1)
@@ -1047,7 +1069,7 @@ class TestFSDPMiscWorldSize1(FSDPTestMultiThread):
        # Incorrectly moving from CPU -> GPU
        model = torch.nn.Linear(10, 10)
        fsdp_model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))
-        fsdp_model.to(torch.device("cuda"))
+        fsdp_model.to(torch.device(device_type))
        inp = torch.randn((2, 10))
        with self.assertRaisesRegex(
            RuntimeError,
@@ -1088,16 +1110,16 @@ class TestFSDPMiscWorldSize1(FSDPTestMultiThread):
 
         # Construct FSDP module without changing any environment variables and
         # run forward, which triggers both unsharded and sharded view setting
-        module = SetattrLinear(5, 5, torch.device("cuda"))
+        module = SetattrLinear(5, 5, torch.device(device_type))
         fsdp_module = FSDP(module, use_orig_params=use_orig_params)
-        inp = torch.randn((8, 5), device=torch.device("cuda"))
+        inp = torch.randn((8, 5), device=torch.device(device_type))
         called_setattr_override = False
         fsdp_module(inp)
         self.assertTrue(called_setattr_override)
 
         # Repeat with unsafe setattr explicitly enabled
         os.environ[_FSDP_USE_UNSAFE_SETATTR] = "1"
-        module = SetattrLinear(5, 5, torch.device("cuda"))
+        module = SetattrLinear(5, 5, torch.device(device_type))
         fsdp_module = FSDP(module, use_orig_params=use_orig_params)
         called_setattr_override = False
         fsdp_module(inp)
@@ -1105,7 +1127,7 @@ class TestFSDPMiscWorldSize1(FSDPTestMultiThread):
 
         # Repeat with unsafe setattr explicitly disabled
         os.environ[_FSDP_USE_UNSAFE_SETATTR] = "0"
-        module = SetattrLinear(5, 5, torch.device("cuda"))
+        module = SetattrLinear(5, 5, torch.device(device_type))
         fsdp_module = FSDP(module, use_orig_params=use_orig_params)
         called_setattr_override = False
         fsdp_module(inp)

@@ -24,7 +24,7 @@ if not dist.is_available():
 from torch.testing._internal.common_distributed import (
     DistributedTestBase,
     MultiThreadedTestCase,
-    requires_nccl,
+    requires_accelerator_dist_backend,
     TEST_SKIPS,
 )
 from torch.testing._internal.common_utils import (
@@ -34,6 +34,7 @@ from torch.testing._internal.common_utils import (
     skipIfHpu,
     TEST_CUDA,
     TEST_HPU,
+    TEST_XPU,
     TestCase,
 )
 
@@ -64,6 +65,9 @@ devices = ["cpu"]
 if TEST_HPU:
     devices.append("hpu")
     DEVICE = "hpu"
+elif TEST_XPU:
+    devices.append("xpu")
+    DEVICE = "xpu"
 elif TEST_CUDA:
     devices.append("cuda")
 
@@ -269,10 +273,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):
 
     @parametrize("device", devices)
     def test_broadcast(self, device):
-        if device == "cuda":
-            if torch.cuda.device_count() < self.world_size:
-                self.skipTest("Not enough CUDA devices")
-            torch.cuda.set_device(dist.get_rank())
+        if device != "cpu":
+            if torch.accelerator.device_count() < self.world_size:
+                self.skipTest("Not enough accelerator devices")
+            torch.accelerator.set_device_index(dist.get_rank())
 
         if dist.get_rank() == 0:
             tensor = torch.ones([4], device=device)
@@ -285,10 +289,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):
 
     @parametrize("device", devices)
     def test_all_reduce_eager(self, device):
-        if device == "cuda":
-            if torch.cuda.device_count() < self.world_size:
-                self.skipTest("Not enough CUDA devices")
-            torch.cuda.set_device(dist.get_rank())
+        if device != "cpu":
+            if torch.accelerator.device_count() < self.world_size:
+                self.skipTest("Not enough accelerator devices")
+            torch.accelerator.set_device_index(dist.get_rank())
 
         tensor = torch.ones([4], device=device)
         mesh = dt.DeviceMesh(device, torch.arange(4))
@@ -302,10 +306,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):
 
     @parametrize("device", devices)
     def test_all_reduce_coalesced_eager(self, device):
-        if device == "cuda":
-            if torch.cuda.device_count() < self.world_size:
-                self.skipTest("Not enough CUDA devices")
-            torch.cuda.set_device(dist.get_rank())
+        if device != "cpu":
+            if torch.accelerator.device_count() < self.world_size:
+                self.skipTest("Not enough accelerator devices")
+            torch.accelerator.set_device_index(dist.get_rank())
 
         t0 = torch.ones([4], device=device)
         t1 = torch.ones([6], device=device) + 2
@@ -317,10 +321,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):
 
     @parametrize("device", devices)
    def test_all_gather_tensor(self, device):
-        if device == "cuda":
-            if torch.cuda.device_count() < self.world_size:
-                self.skipTest("Not enough CUDA devices")
-            torch.cuda.set_device(dist.get_rank())
+        if device != "cpu":
+            if torch.accelerator.device_count() < self.world_size:
+                self.skipTest("Not enough accelerator devices")
+            torch.accelerator.set_device_index(dist.get_rank())
 
         # testing 1d/2d mesh
         mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size))
@@ -339,10 +343,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):
 
     @parametrize("device", devices)
     def test_all_gather_into_tensor_coalesced(self, device):
-        if device == "cuda":
-            if torch.cuda.device_count() < self.world_size:
-                self.skipTest("Not enough CUDA devices")
-            torch.cuda.set_device(dist.get_rank())
+        if device != "cpu":
+            if torch.accelerator.device_count() < self.world_size:
+                self.skipTest("Not enough accelerator devices")
+            torch.accelerator.set_device_index(dist.get_rank())
 
         tensors = [torch.ones([4], device=device), torch.ones([4], device=device) + 1]
         mesh = dt.DeviceMesh(device, torch.arange(4))
@@ -356,10 +360,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):
 
     @parametrize("device", devices)
     def test_reduce_scatter_tensor(self, device):
-        if device == "cuda":
-            if torch.cuda.device_count() < self.world_size:
-                self.skipTest("Not enough CUDA devices")
-            torch.cuda.set_device(dist.get_rank())
+        if device != "cpu":
+            if torch.accelerator.device_count() < self.world_size:
+                self.skipTest("Not enough accelerator devices")
+            torch.accelerator.set_device_index(dist.get_rank())
 
         # testing 1d/2d mesh
         mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size))
@@ -380,10 +384,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):
 
     @parametrize("device", devices)
     def test_reduce_scatter_into_tensor_coalesced(self, device):
-        if device == "cuda":
-            if torch.cuda.device_count() < self.world_size:
-                self.skipTest("Not enough CUDA devices")
-            torch.cuda.set_device(dist.get_rank())
+        if device != "cpu":
+            if torch.accelerator.device_count() < self.world_size:
+                self.skipTest("Not enough accelerator devices")
+            torch.accelerator.set_device_index(dist.get_rank())
         tensors = [
             torch.ones([4], dtype=torch.int64, device=device),
             torch.ones([4], dtype=torch.int64, device=device) + 1,
@@ -474,18 +478,17 @@ BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
 # And then set the BACKEND variable appropriately.
 if TEST_HPU:
     BACKEND = dist.Backend.HCCL
+elif TEST_XPU:
+    BACKEND = dist.Backend.XCCL
 
 
 # allows you to check for multiple accelerator irrespective of device type
 # to add new device types to this check simply follow the same format
 # and append an elif with the conditional and appropriate device count function for your new device
 def exit_if_lt_x_accelerators(x):
-    if TEST_CUDA:
-        if torch.cuda.device_count() < x:
-            sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
-    elif TEST_HPU:
-        if torch.hpu.device_count() < x:
-            sys.exit(TEST_SKIPS[f"multi-hpu-{x}"].exit_code)
+    if torch.accelerator.is_available():
+        if torch.accelerator.device_count() < x:
+            sys.exit(TEST_SKIPS[f"multi-accelerator-{x}"].exit_code)
 
 
 def with_comms(func=None):
@@ -494,7 +497,9 @@ def with_comms(func=None):
 
     @wraps(func)
     def wrapper(self, *args, **kwargs):
-        if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
+        if (
+            BACKEND == dist.Backend.NCCL or BACKEND == dist.Backend.XCCL
+        ) and torch.accelerator.device_count() < self.world_size:
             sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
 
         kwargs["device"] = DEVICE
@@ -572,7 +577,7 @@ class TestCollectivesWithDistributedBackend(DistributedTestBase):
         self.assertEqual(y, expected)
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-    @requires_nccl()
+    @requires_accelerator_dist_backend(["nccl", "xccl"])
     @with_comms()
     def test_tracing(self, device):
         def allreduce(t, pg):
@@ -599,7 +604,7 @@ class TestCollectivesWithDistributedBackend(DistributedTestBase):
         dist.destroy_process_group()
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-    @requires_nccl()
+    @requires_accelerator_dist_backend(["nccl", "xccl"])
     @with_comms()
     def test_tracing_with_dce_code(self, device):
         if self.world_size > 2:
@@ -818,13 +823,19 @@ class TestFunctionalAutogradWithDistributedBackend(DistributedTestBase):
 
 # Update the supported devices in DEVICE
 instantiate_device_type_tests(
-    TestCollectivesWithDistributedBackend, globals(), only_for=DEVICE
+    TestCollectivesWithDistributedBackend, globals(), only_for=DEVICE, allow_xpu=True
 )
 instantiate_device_type_tests(
-    TestDistributedBackendCollectivesWithWorldSize4, globals(), only_for=DEVICE
+    TestDistributedBackendCollectivesWithWorldSize4,
+    globals(),
+    only_for=DEVICE,
+    allow_xpu=True,
 )
 instantiate_device_type_tests(
-    TestFunctionalAutogradWithDistributedBackend, globals(), only_for=DEVICE
+    TestFunctionalAutogradWithDistributedBackend,
+    globals(),
+    only_for=DEVICE,
+    allow_xpu=True,
 )
 
 if __name__ == "__main__":

@@ -96,10 +96,10 @@ TEST_SKIPS = {
 class DistTestCases:
     # Backends that do not support a specific collective
     skip_collective = {}
-    skip_collective["allgather_coalesced"] = {"nccl", "mpi", "ucc"}
+    skip_collective["allgather_coalesced"] = {"nccl", "mpi", "ucc", "xccl"}
     skip_collective["reduce"] = set()
-    skip_collective["sendrecv anysource"] = {"nccl", "ucc"}
-    skip_collective["cpu barrier"] = {"nccl", "ucc"}
+    skip_collective["sendrecv anysource"] = {"nccl", "ucc", "xccl"}
+    skip_collective["cpu barrier"] = {"nccl", "ucc", "xccl"}
 
     # Sets showing that something is implemented
     backend_feature = {}
@@ -338,15 +338,26 @@ def requires_gloo():
 
 
 def requires_nccl_version(version, msg):
-    if not c10d.is_nccl_available():
-        return skip_but_pass_in_sandcastle(
-            "c10d was not compiled with the NCCL backend",
-        )
+    if TEST_CUDA:
+        if not c10d.is_nccl_available():
+            return skip_but_pass_in_sandcastle(
+                "c10d was not compiled with the NCCL backend",
+            )
+        else:
+            return skip_but_pass_in_sandcastle_if(
+                torch.cuda.nccl.version() < version,
+                f"Requires NCCL version greater than or equal to: {version}, found: {torch.cuda.nccl.version()}, reason: {msg}",
+            )
     else:
-        return skip_but_pass_in_sandcastle_if(
-            torch.cuda.nccl.version() < version,
-            f"Requires NCCL version greater than or equal to: {version}, found: {torch.cuda.nccl.version()}, reason: {msg}",
-        )
+
+        def decorator(func):
+            @wraps(func)
+            def wrapper(*args, **kwargs):
+                return func(*args, **kwargs)
+
+            return wrapper
+
+        return decorator
 
 
 def requires_nccl():
@@ -435,9 +446,10 @@ def sm_is_or_higher_than(device: torch.device, major: int, minor: int) -> bool:
     Returns True if the device's compute capability is (major, minor) or higher.
     Error out if the device is not a CUDA device.
     Returns False if device is a RoCM device.
+    Returns True if device is a non-CUDA device.
     """
     if device.type != "cuda":
-        raise ValueError("sm_is_or_later() is only supported for CUDA devices")
+        return True
 
     if torch.version.hip is not None:
         # ROCm devices may have different compute capability codes
@@ -1456,12 +1468,19 @@ class SaveForwardInputsModel(nn.Module):
 
 @contextmanager
 def _dynamo_dist_per_rank_init(
-    rank, world_size, backend="nccl", init_pg=True, fake_pg=False
+    rank, world_size, backend=None, init_pg=True, fake_pg=False
 ):
     # To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase,
     # Just manually implement the most important part of the dynamo behavior to reset/clear.
     if not fake_pg:
         torch.accelerator.set_device_index(rank)
 
+    device_type = (
+        acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+    )
+    if backend is None:
+        backend = c10d.get_default_backend_for_device(device_type)
+
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "6789"
     if init_pg:
@@ -1508,9 +1527,12 @@ class DynamoDistributedSingleProcTestCase(torch._dynamo.test_case.TestCase):
             )
         )
         cls.rank = 0
-        cls.device = f"cuda:{cls.rank}"
-        cls.device_ids = None if "cuda" in cls.device else [cls.rank]
-        c10d.init_process_group("nccl", rank=cls.rank, world_size=1)
+        device = torch.accelerator.current_accelerator().type
+        cls.device = f"{device}:{cls.rank}"
+        cls.device_ids = None if device in cls.device else [cls.rank]
+        c10d.init_process_group(
+            c10d.get_default_backend_for_device(device), rank=cls.rank, world_size=1
+        )
 
     @classmethod
     def tearDownClass(cls):

@@ -6,6 +6,7 @@ import os
 import re
 import sys
 import time
+import unittest
 import warnings
 from abc import ABC, abstractmethod
 from contextlib import nullcontext
@@ -1122,6 +1123,7 @@ def check_sharded_parity(
         cls.assertEqual(sharded_param.grad.to_local(), sharded_ref_grad.to_local())
 
 
+@unittest.skipIf(TEST_XPU, "not-support-multithread")
 class FSDPTestMultiThread(MultiThreadedTestCase):
     @property
     def world_size(self):
@@ -1187,7 +1189,7 @@ class FSDPTest(MultiProcessTestCase):
         fake_pg = kwargs.get("fake_pg", False)
 
         print(f"dist init r={self.rank}, world={self.world_size}")
-        if torch.cuda.device_count() < self.world_size:
+        if torch.accelerator.device_count() < self.world_size:
             sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
 
         # Specify gloo backend to make 'init_process_group()' succeed,