Port 3 distributed tests to Intel GPU and unify some common functions (#158533)

As part of https://github.com/pytorch/pytorch/issues/114850, this change ports three distributed test files to Intel GPU. Intel GPU support is enabled with the following methods, keeping the original code style as much as possible:
- use instantiate_device_type_tests()
- use torch.accelerator.current_accelerator() to determine the accelerator backend
- enable XPU for some test paths
- unify some common code under torch/testing/_internal for multiple backends, for example:
  - requires_nccl_version
  - _dynamo_dist_per_rank_init
  - DynamoDistributedSingleProcTestCase
  - DistTestCases
  - FSDPTestMultiThread

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158533
Approved by: https://github.com/guangyey, https://github.com/d4l3k
Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
parent 9a06e6d031
commit 6e8865fbc1
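The recurring pattern across the ported tests is to resolve the active accelerator once and derive everything else from it (device objects, RNG seeding, the process-group backend) instead of hard-coding torch.cuda calls. Below is a minimal sketch of that pattern, assuming a PyTorch build that provides torch.accelerator and torch.distributed.get_default_backend_for_device (both appear in the diff that follows); the layer and tensor at the end are illustrative only and not part of this PR.

import torch
import torch.distributed as c10d

# Resolve the accelerator type once; fall back to "cpu" when none is available.
device_type = (
    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)

# Pick the matching process-group backend (e.g. nccl for cuda, xccl for xpu, gloo for cpu).
backend = c10d.get_default_backend_for_device(device_type)

if device_type != "cpu":
    # Device-agnostic replacements for the former torch.cuda-only calls.
    torch.accelerator.set_device_index(0)
    torch.get_device_module(device_type).manual_seed(0)
    device_index = torch.accelerator.current_device_index()

# Illustrative usage: the same line now targets cuda, xpu, or cpu.
layer = torch.nn.Linear(8, 8).to(device=device_type)
x = torch.rand(4, 8, device=device_type)
y = layer(x)

With this in place the test bodies stay device-agnostic, which is why most hunks below are one-line swaps from torch.cuda.* to torch.accelerator.* or to .to(device=device_type).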
@@ -13,7 +13,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecis
 from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from torch.testing._internal.common_distributed import (
-requires_nccl,
+requires_accelerator_dist_backend,
 requires_nccl_version,
 skip_but_pass_in_sandcastle_if,
 skip_if_lt_x_gpu,
@@ -30,17 +30,22 @@ if not dist.is_available():
 print("Distributed not available, skipping tests", file=sys.stderr)
 sys.exit(0)

-# bfloat16 is only supported by CUDA 11+
-BFLOAT16_AVAILABLE = torch.cuda.is_available() and (
-torch.version.cuda is not None or torch.version.hip is not None
+device_type = (
+acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
 )
+
+# bfloat16 is only supported by CUDA 11+ or XPU
+BFLOAT16_AVAILABLE = (
+torch.cuda.is_available()
+and (torch.version.cuda is not None or torch.version.hip is not None)
+) or torch.xpu.is_available()


 class Net(nn.Module):
 def __init__(self, has_wrapping, sharding_strategy, mixed_precision=None):
 # to ensure determinism
 torch.manual_seed(0)
-torch.cuda.manual_seed(0)
+torch.get_device_module(device_type).manual_seed(0)
 super().__init__()

 if has_wrapping:
@@ -50,12 +55,12 @@ class Net(nn.Module):
 nn.ReLU(),
 FSDP(
 nn.Linear(16, 8),
-device_id=torch.cuda.current_device(),
+device_id=torch.accelerator.current_device_index(),
 sharding_strategy=sharding_strategy,
 mixed_precision=mixed_precision,
 ),
 ),
-device_id=torch.cuda.current_device(),
+device_id=torch.accelerator.current_device_index(),
 sharding_strategy=sharding_strategy,
 mixed_precision=mixed_precision,
 )
@@ -134,11 +139,11 @@ class TestCommunicationHooks(FSDPTest):
 """
 out_dim = self.world_size
 net = torch.nn.Linear(1, out_dim, bias=False)
-inpt = torch.tensor([self.rank]).float().cuda(self.rank)
+inpt = torch.tensor([self.rank]).float().to(self.rank)

 net_default_hook = FSDP(
 net,
-device_id=torch.cuda.current_device(),
+device_id=torch.accelerator.current_device_index(),
 sharding_strategy=sharding_strategy,
 ).to(self.rank)

@@ -172,10 +177,10 @@ class TestCommunicationHooks(FSDPTest):
 ]

 def _init_model(self, core, sharding_strategy, mixed_precision=None):
-device = torch.device("cuda")
+device = torch.device(device_type)
 return FSDP(
 core,
-device_id=torch.cuda.current_device(),
+device_id=torch.accelerator.current_device_index(),
 sharding_strategy=sharding_strategy,
 mixed_precision=mixed_precision,
 ).to(device)
@@ -277,7 +282,7 @@ class TestCommunicationHooks(FSDPTest):
 ShardingStrategy.HYBRID_SHARD,
 ShardingStrategy._HYBRID_SHARD_ZERO2,
 ):
-model = Net(False, None, None).cuda()
+model = Net(False, None, None).to(device=device_type)
 fsdp_model = FSDP(
 model,
 auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
@@ -337,7 +342,7 @@ class TestCommunicationHooks(FSDPTest):
 ):
 # keep everything deterministic for input data
 torch.manual_seed(0)
-torch.cuda.manual_seed(0)
+torch.get_device_module(device_type).manual_seed(0)

 fsdp_with_hook = self._init_model(
 Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy),
@@ -359,7 +364,7 @@ class TestCommunicationHooks(FSDPTest):
 optim_hook = torch.optim.SGD(fsdp_with_hook.parameters(), lr=0.1)
 optim_mp = torch.optim.SGD(fsdp_with_mp.parameters(), lr=0.1)

-in_data = torch.rand(16, 8).cuda()
+in_data = torch.rand(16, 8).to(device=device_type)
 fsdp_with_hook.train()
 fsdp_with_mp.train()
 loss_hook = fsdp_with_hook(in_data).sum()
@@ -378,7 +383,7 @@ class TestCommunicationHooks(FSDPTest):
 ):
 self.assertEqual(hook_param.grad, mp_param.grad)

-@requires_nccl()
+@requires_accelerator_dist_backend(["nccl", "xccl"])
 @skip_if_lt_x_gpu(2)
 @parametrize("has_wrapping", [True, False])
 @parametrize(
@@ -399,11 +404,11 @@ class TestCommunicationHooks(FSDPTest):
 state, hook, sharding_strategy, torch.float16, has_wrapping
 )

-@requires_nccl()
+@requires_accelerator_dist_backend(["nccl", "xccl"])
 @requires_nccl_version((2, 10), "Need NCCL 2.10+ for BF16_COMPRESS")
 @skip_but_pass_in_sandcastle_if(
 not BFLOAT16_AVAILABLE,
-"BFloat16 is only supported by CUDA 11+",
+"BFloat16 is only supported by CUDA 11+ or XPU",
 )
 @skip_if_lt_x_gpu(2)
 @parametrize("has_wrapping", [True, False])
@@ -60,6 +60,10 @@ if TEST_WITH_DEV_DBG_ASAN:
 )
 sys.exit(0)

+device_type = (
+acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
+)
+

 class MyModel(nn.Module):
 def __init__(self) -> None:
@@ -93,9 +97,9 @@ class TestFSDPMiscMultiProcess(FSDPTest):
 without specifying a device ID (i.e. ``torch.device("cuda")``) warns
 """
 dev_id = (
-torch.cuda.current_device()
+torch.accelerator.current_device_index()
 if use_index
-else torch.device("cuda", torch.cuda.current_device())
+else torch.device(device_type, torch.accelerator.current_device_index())
 )

 def _check_device_matches(module, device_id):
@@ -108,7 +112,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
 self.assertEqual(1, len(devices))
 found_device = devices.pop()
 if use_index and not isinstance(device_id, torch.device):
-device = torch.device("cuda", device_id)
+device = torch.device(device_type, device_id)
 else:
 device = device_id
 self.assertEqual(found_device, device)
@@ -140,10 +144,11 @@ class TestFSDPMiscMultiProcess(FSDPTest):
 self.process_group,
 FSDPInitMode.RECURSIVE,
 DEVICEInitMode.DEVICE_BEFORE,
-fsdp_kwargs={"device_id": torch.device("cuda")},
+fsdp_kwargs={"device_id": torch.device(device_type)},
 )
 _check_device_matches(
-nested_wrapped_module, torch.device("cuda", torch.cuda.current_device())
+nested_wrapped_module,
+torch.device(device_type, torch.accelerator.current_device_index()),
 )

 @skip_if_lt_x_gpu(2)
@@ -178,8 +183,8 @@ class TestFSDPMiscMultiProcess(FSDPTest):
 loss = torch.nn.functional.cross_entropy(output, y)
 return loss

-model = Mnist().cuda()
-model1 = Mnist().cuda()
+model = Mnist().to(device=device_type)
+model1 = Mnist().to(device=device_type)
 model1.load_state_dict(model.state_dict())
 fsdp_model = FSDP(
 model,
@@ -197,17 +202,17 @@ class TestFSDPMiscMultiProcess(FSDPTest):

 seed = self.rank + 20231010
 torch.manual_seed(seed)
-torch.cuda.manual_seed(seed)
+torch.get_device_module(device_type).manual_seed(seed)

 losses = []
 grads = []
 for i in range(5):
-x = torch.randn(8, 1, 28, 28, device="cuda").requires_grad_()
-y = torch.randint(low=0, high=9, size=(8,), device="cuda")
+x = torch.randn(8, 1, 28, 28, device=device_type).requires_grad_()
+y = torch.randint(low=0, high=9, size=(8,), device=device_type)
 for model, opt in ((fsdp_model, fsdp_opt), (ddp_model, ddp_opt)):
 seed = self.rank + i
 torch.manual_seed(seed)
-torch.cuda.manual_seed(seed)
+torch.get_device_module(device_type).manual_seed(seed)
 loss = model(x, y).sum()
 losses.append(loss)
 loss.backward()
@@ -223,8 +228,8 @@ class TestFSDPMiscMultiProcess(FSDPTest):
 fsdp_model.eval()
 ddp_model.eval()
 for _ in range(5):
-x = torch.randn(8, 1, 28, 28, device="cuda").requires_grad_()
-y = torch.randint(low=0, high=9, size=(8,), device="cuda")
+x = torch.randn(8, 1, 28, 28, device=device_type).requires_grad_()
+y = torch.randint(low=0, high=9, size=(8,), device=device_type)
 fsdp_loss = fsdp_model(x, y)
 ddp_loss = ddp_model(x, y)
 assert torch.allclose(fsdp_loss, ddp_loss)
@@ -232,12 +237,12 @@ class TestFSDPMiscMultiProcess(FSDPTest):
 fsdp_model.train()
 ddp_model.train()
 for i in range(5):
-x = torch.randn(8, 1, 28, 28, device="cuda").requires_grad_()
-y = torch.randint(low=0, high=9, size=(8,), device="cuda")
+x = torch.randn(8, 1, 28, 28, device=device_type).requires_grad_()
+y = torch.randint(low=0, high=9, size=(8,), device=device_type)
 for model, opt in ((fsdp_model, fsdp_opt), (ddp_model, ddp_opt)):
 seed = self.rank + i
 torch.manual_seed(seed)
-torch.cuda.manual_seed(seed)
+torch.get_device_module(device_type).manual_seed(seed)
 loss = model(x, y).sum()
 losses.append(loss)
 loss.backward()
@@ -272,12 +277,12 @@ class TestFSDPMiscMultiProcess(FSDPTest):
 return out1

 fsdp = FSDP(
-MyModel().cuda(),
+MyModel().to(device=device_type),
 sharding_strategy=sharding_strategy,
 auto_wrap_policy=always_wrap_policy,
 )
-x = torch.randn(10, 10, device="cuda")
-y = torch.randn(10, 10, device="cuda")
+x = torch.randn(10, 10, device=device_type)
+y = torch.randn(10, 10, device=device_type)
 for _ in range(4):
 if use_second_layer:
 a, _ = fsdp(x, y)
@@ -336,7 +341,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
 torch.testing.assert_close(p1, p2)

 fsdp_ctor = functools.partial(FSDP, sharding_strategy=sharding_strategy)
-m = MyModule().cuda()
+m = MyModule().to(device=device_type)
 m_local = deepcopy(m)
 local_m = m_local
 prev_params = [p.clone() for p in m_local.parameters()]
@@ -349,7 +354,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
 opt_local = torch.optim.SGD(local_m.parameters(), lr=1e-3)

 for i in range(6):
-t = torch.ones(4, device="cuda")
+t = torch.ones(4, device=device_type)
 a, b = m(t)
 local_a, local_b = local_m(t)
 if i < 2:
@@ -385,7 +390,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
 @skip_if_lt_x_gpu(2)
 def test_fsdp_optim_overlap_no_use_orig_params_error(self):
 fsdp_overlap = FSDP(
-MyModel().cuda(),
+MyModel().to(device=device_type),
 auto_wrap_policy=always_wrap_policy,
 use_orig_params=False,
 )
@@ -398,7 +403,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
 register_hook=False,
 )

-inp = torch.randn(10, 10, device="cuda")
+inp = torch.randn(10, 10, device=device_type)
 with self.assertRaisesRegex(
 RuntimeError, "only supported with use_orig_params=True"
 ):
@@ -409,16 +414,16 @@ class TestFSDPMiscMultiProcess(FSDPTest):
 torch.manual_seed(0)
 for cpu_offload in [True, False]:
 offload = CPUOffload(offload_params=cpu_offload)
-model = MyModel().cuda()
+model = MyModel().to(device=device_type)
 model_overlap = deepcopy(model)
 fsdp = FSDP(
-model.cuda(),
+model.to(device=device_type),
 auto_wrap_policy=always_wrap_policy,
 use_orig_params=True,
 cpu_offload=offload,
 )
 fsdp_overlap = FSDP(
-model_overlap.cuda(),
+model_overlap.to(device=device_type),
 auto_wrap_policy=always_wrap_policy,
 use_orig_params=True,
 cpu_offload=offload,
@@ -445,7 +450,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
 ]

 for i in range(6):
-inp = torch.randn(2, 2, device="cuda")
+inp = torch.randn(2, 2, device=device_type)
 with torch.no_grad():
 inp_clone = inp.clone()
 fsdp(inp, inp).sum().backward()
@@ -546,7 +551,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
 """Tests that passing a CPU module to FSDP preserves that the wrapped
 module is on CPU after FSDP initialization, albeit after logging a
 warning, and that FSDP moves CPU input to GPU before the forward."""
-torch.cuda.set_device(self.rank)
+torch.accelerator.set_device_index(self.rank)
 regex = "passed-in `module` is on CPU"
 context = self.assertWarnsRegex(
 expected_warning=UserWarning, expected_regex=regex
@@ -561,7 +566,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
 devices = {p.device for p in fsdp_model.parameters()}
 self.assertEqual(1, len(devices))
 self.assertEqual(torch.device("cpu"), devices.pop())
-fsdp_model = fsdp_model.cuda()
+fsdp_model = fsdp_model.to(device=device_type)
 # Ensure fwd + backward can be performed after moving to CUDA.
 # CPU input also tests that input is correctly moved to appropriate
 # CUDA device.
@@ -606,19 +611,19 @@ class TestFSDPMiscMultiProcess(FSDPTest):
 nested_wrapped_module,
 self.process_group,
 auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
-device_id=torch.cuda.current_device(),
+device_id=torch.accelerator.current_device_index(),
 sync_module_states=True,
 )
 # Each rank's buffers should be 0s since rank 0 is the source, and they
 # should be on GPU since we specified `device_id`
 self.assertEqual(
 nested_wrapped_module.buf.device,
-torch.device("cuda", torch.cuda.current_device()),
+torch.device(device_type, torch.accelerator.current_device_index()),
 )
 self.assertEqual(nested_wrapped_module.buf, torch.zeros((2, 2)))
 self.assertEqual(
 nested_wrapped_module.module.module[0].buf.device,
-torch.device("cuda", torch.cuda.current_device()),
+torch.device(device_type, torch.accelerator.current_device_index()),
 )
 self.assertEqual(
 nested_wrapped_module.module.module[0].buf, torch.zeros((3, 2))
@@ -644,9 +649,9 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
 def forward(self, x):
 return x

-m = MyModule().cuda()
+m = MyModule().to(device=device_type)
 m = FSDP(m)
-t = torch.ones(1, device="cuda", requires_grad=True)
+t = torch.ones(1, device=device_type, requires_grad=True)

 MyOutputType = namedtuple(
 "MyOutputType", ["a", "b", "c", "d"], defaults=(t, t, t, t)
@@ -683,7 +688,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
 auto_wrap_policy = ModuleWrapPolicy(module_classes)
 fsdp_kwargs = {
 "auto_wrap_policy": auto_wrap_policy,
-"device_id": torch.cuda.current_device(),
+"device_id": torch.accelerator.current_device_index(),
 }
 fsdp_model = TransformerWithSharedParams.init(
 self.process_group,
@@ -694,7 +699,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
 for fsdp_module in FSDP.fsdp_modules(fsdp_model):
 self.assertEqual(
 fsdp_module.compute_device,
-torch.device("cuda", torch.cuda.current_device()),
+torch.device(device_type, torch.accelerator.current_device_index()),
 )

 @skip_if_lt_x_gpu(2)
@@ -729,7 +734,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
 model,
 auto_wrap_policy=auto_wrap_policy,
 cpu_offload=CPUOffload(offload_params=True),
-device_id=torch.cuda.current_device(),
+device_id=torch.accelerator.current_device_index(),
 use_orig_params=use_orig_params,
 )
 cpu_device = torch.device("cpu")
@@ -742,12 +747,16 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
 module that does not match the GPU device ID raises an error."""
 # TODO: override FSDP MT Thread _run to set this instead of here for
 # every test.
-torch.cuda.set_device(self.rank)
+torch.accelerator.set_device_index(self.rank)
+
 context = (
-self.assertRaisesRegex(ValueError, f"cuda:{self.rank} vs cuda:0")
+self.assertRaisesRegex(
+ValueError, f"{device_type}:{self.rank} vs {device_type}:0"
+)
 if self.rank != 0
 else nullcontext()
 )
+
 with context:
 NestedWrappedModule.init(
 self.process_group,
@@ -764,18 +773,20 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
 """Tests a CPU + GPU module supported if device_id is passed
 in, errors if device_id is not.
 """
-torch.cuda.set_device(self.rank)
+torch.accelerator.set_device_index(self.rank)

 class CPUGPUModule(nn.Module):
 def __init__(self) -> None:
 super().__init__()
-self.a = nn.Linear(1, 1).cuda()
+self.a = nn.Linear(1, 1).to(device=device_type)
 self.b = nn.Linear(1, 1)

 cpu_gpu = CPUGPUModule()
-fsdp = FSDP(cpu_gpu, device_id=torch.cuda.current_device())
+fsdp = FSDP(cpu_gpu, device_id=torch.accelerator.current_device_index())
 for param in fsdp.parameters():
-self.assertEqual(param.device, torch.device(torch.cuda.current_device()))
+self.assertEqual(
+param.device, torch.device(torch.accelerator.current_device_index())
+)

 # without device_id, we hit an error
 with self.assertRaisesRegex(RuntimeError, "please pass in device_id"):
@@ -783,7 +794,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):

 @skip_if_lt_x_gpu(2)
 def test_fsdp_ignored_module_meta(self):
-torch.cuda.set_device(self.rank)
+torch.accelerator.set_device_index(self.rank)

 class CPUGPUModule(nn.Module):
 def __init__(self) -> None:
@@ -802,11 +813,11 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
 m = CPUGPUModule()
 m = FSDP(
 m,
-device_id=torch.cuda.current_device(),
+device_id=torch.accelerator.current_device_index(),
 ignored_modules=[m.a],
 use_orig_params=True,
 param_init_fn=lambda m: m.to_empty(
-device=torch.cuda.current_device(), recurse=False
+device=torch.accelerator.current_device_index(), recurse=False
 ),
 )
 self.assertEqual(meta_device, next(m.a.parameters()).device)
@@ -854,20 +865,20 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
 """
 # TODO: override FSDP MT Thread _run to set this instead of here for
 # every test.
-torch.cuda.set_device(self.rank)
+torch.accelerator.set_device_index(self.rank)
 # Test CPU
 no_params = nn.ReLU()
 FSDP(no_params)
 # Test CUDA
-no_params = nn.ReLU().cuda()
+no_params = nn.ReLU().to(device=device_type)
 FSDP(no_params)
 # Test CPU + device_id
 no_params = nn.ReLU()
-FSDP(no_params, device_id=torch.cuda.current_device())
+FSDP(no_params, device_id=torch.accelerator.current_device_index())
 # For modules with no params, wrong device_id will raise error about
 # inconsistency between compute_device and device_id, since compute_device
 # is computed as torch.cuda.current_device when there are no params.
-no_params = nn.ReLU().cuda()
+no_params = nn.ReLU().to(device=device_type)
 context = (
 (
 self.assertRaisesRegex(
@@ -892,11 +903,11 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
 super().__init__()
 # Seed via rank to make model different across ranks
 torch.manual_seed(rank)
-torch.cuda.manual_seed(rank)
+torch.get_device_module(device_type).manual_seed(rank)
 self.lin = nn.Linear(10, 10, bias=False)
 self.buffer = nn.Buffer(torch.ones(1) * rank)

-m = MyModel(self.rank).cuda()
+m = MyModel(self.rank).to(device=device_type)
 _assert_module_states(
 m, process_group=self.process_group, assert_fn=self.assertNotEqual
 )
@@ -913,7 +924,11 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
 m, process_group=self.process_group, assert_fn=self.assertNotEqual
 )
 # Passing sync_module_states into FSDP makes model the same during init.
-fsdp = FSDP(m, device_id=torch.cuda.current_device(), sync_module_states=True)
+fsdp = FSDP(
+m,
+device_id=torch.accelerator.current_device_index(),
+sync_module_states=True,
+)
 with fsdp.summon_full_params(fsdp):
 _assert_module_states(
 fsdp, process_group=self.process_group, assert_fn=self.assertEqual
@@ -968,7 +983,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
 with self.assertRaisesRegex(
 ValueError, f"Expects one homogeneous value for {attr_name}"
 ):
-inp = fsdp_model.module.get_input(torch.device("cuda"))
+inp = fsdp_model.module.get_input(torch.device(device_type))
 fsdp_model(*inp)

 @skip_if_lt_x_gpu(2)
@@ -976,7 +991,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
 regex = r"FSDP will not all-gather parameters for containers that do not implement forward"
 model = nn.ModuleList([MLP(8, torch.device("cpu")) for _ in range(3)])
 with self.assertWarnsRegex(UserWarning, regex):
-FSDP(model, device_id="cuda")
+FSDP(model, device_id=device_type)
 model = nn.ModuleDict(
 {"1": MLP(8, torch.device("cpu")), "2": MLP(8, torch.device("cpu"))}
 )
@@ -1000,7 +1015,10 @@ class TestFSDPMiscWorldSize1(FSDPTestMultiThread):
 # warning
 with warnings.catch_warnings(record=True) as w:
 warnings.simplefilter("always") # trigger all warnings
-FSDP(nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.NO_SHARD)
+FSDP(
+nn.Linear(3, 3).to(device=device_type),
+sharding_strategy=ShardingStrategy.NO_SHARD,
+)
 for warning in w:
 self.assertTrue(
 warning.category != UserWarning
@@ -1014,16 +1032,20 @@ class TestFSDPMiscWorldSize1(FSDPTestMultiThread):
 warning_prefix + " " + str(ShardingStrategy.FULL_SHARD) + warning_suffix
 )
 with self.assertWarnsRegex(UserWarning, expected_regex_full_shard):
-FSDP(nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.FULL_SHARD)
+FSDP(
+nn.Linear(3, 3).to(device=device_type),
+sharding_strategy=ShardingStrategy.FULL_SHARD,
+)
 with self.assertWarnsRegex(UserWarning, expected_regex_full_shard):
-FSDP(nn.Linear(3, 3).cuda())
+FSDP(nn.Linear(3, 3).to(device=device_type))
 # - Pass `SHARD_GRAD_OP`
 expected_regex_shard_grad_op = (
 warning_prefix + " " + str(ShardingStrategy.SHARD_GRAD_OP) + warning_suffix
 )
 with self.assertWarnsRegex(UserWarning, expected_regex_shard_grad_op):
 FSDP(
-nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.SHARD_GRAD_OP
+nn.Linear(3, 3).to(device=device_type),
+sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
 )

 @skip_if_lt_x_gpu(1)
@@ -1047,7 +1069,7 @@ class TestFSDPMiscWorldSize1(FSDPTestMultiThread):
 # Incorrectly moving from CPU -> GPU
 model = torch.nn.Linear(10, 10)
 fsdp_model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))
-fsdp_model.to(torch.device("cuda"))
+fsdp_model.to(torch.device(device_type))
 inp = torch.randn((2, 10))
 with self.assertRaisesRegex(
 RuntimeError,
@@ -1088,16 +1110,16 @@ class TestFSDPMiscWorldSize1(FSDPTestMultiThread):

 # Construct FSDP module without changing any environment variables and
 # run forward, which triggers both unsharded and sharded view setting
-module = SetattrLinear(5, 5, torch.device("cuda"))
+module = SetattrLinear(5, 5, torch.device(device_type))
 fsdp_module = FSDP(module, use_orig_params=use_orig_params)
-inp = torch.randn((8, 5), device=torch.device("cuda"))
+inp = torch.randn((8, 5), device=torch.device(device_type))
 called_setattr_override = False
 fsdp_module(inp)
 self.assertTrue(called_setattr_override)

 # Repeat with unsafe setattr explicitly enabled
 os.environ[_FSDP_USE_UNSAFE_SETATTR] = "1"
-module = SetattrLinear(5, 5, torch.device("cuda"))
+module = SetattrLinear(5, 5, torch.device(device_type))
 fsdp_module = FSDP(module, use_orig_params=use_orig_params)
 called_setattr_override = False
 fsdp_module(inp)
@@ -1105,7 +1127,7 @@ class TestFSDPMiscWorldSize1(FSDPTestMultiThread):

 # Repeat with unsafe setattr explicitly disabled
 os.environ[_FSDP_USE_UNSAFE_SETATTR] = "0"
-module = SetattrLinear(5, 5, torch.device("cuda"))
+module = SetattrLinear(5, 5, torch.device(device_type))
 fsdp_module = FSDP(module, use_orig_params=use_orig_params)
 called_setattr_override = False
 fsdp_module(inp)
@@ -24,7 +24,7 @@ if not dist.is_available():
 from torch.testing._internal.common_distributed import (
 DistributedTestBase,
 MultiThreadedTestCase,
-requires_nccl,
+requires_accelerator_dist_backend,
 TEST_SKIPS,
 )
 from torch.testing._internal.common_utils import (
@@ -34,6 +34,7 @@ from torch.testing._internal.common_utils import (
 skipIfHpu,
 TEST_CUDA,
 TEST_HPU,
+TEST_XPU,
 TestCase,
 )

@@ -64,6 +65,9 @@ devices = ["cpu"]
 if TEST_HPU:
 devices.append("hpu")
 DEVICE = "hpu"
+elif TEST_XPU:
+devices.append("xpu")
+DEVICE = "xpu"
 elif TEST_CUDA:
 devices.append("cuda")

@@ -269,10 +273,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):

 @parametrize("device", devices)
 def test_broadcast(self, device):
-if device == "cuda":
-if torch.cuda.device_count() < self.world_size:
-self.skipTest("Not enough CUDA devices")
-torch.cuda.set_device(dist.get_rank())
+if device != "cpu":
+if torch.accelerator.device_count() < self.world_size:
+self.skipTest("Not enough accelerator devices")
+torch.accelerator.set_device_index(dist.get_rank())

 if dist.get_rank() == 0:
 tensor = torch.ones([4], device=device)
@@ -285,10 +289,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):

 @parametrize("device", devices)
 def test_all_reduce_eager(self, device):
-if device == "cuda":
-if torch.cuda.device_count() < self.world_size:
-self.skipTest("Not enough CUDA devices")
-torch.cuda.set_device(dist.get_rank())
+if device != "cpu":
+if torch.accelerator.device_count() < self.world_size:
+self.skipTest("Not enough accelerator devices")
+torch.accelerator.set_device_index(dist.get_rank())

 tensor = torch.ones([4], device=device)
 mesh = dt.DeviceMesh(device, torch.arange(4))
@@ -302,10 +306,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):

 @parametrize("device", devices)
 def test_all_reduce_coalesced_eager(self, device):
-if device == "cuda":
-if torch.cuda.device_count() < self.world_size:
-self.skipTest("Not enough CUDA devices")
-torch.cuda.set_device(dist.get_rank())
+if device != "cpu":
+if torch.accelerator.device_count() < self.world_size:
+self.skipTest("Not enough accelerator devices")
+torch.accelerator.set_device_index(dist.get_rank())

 t0 = torch.ones([4], device=device)
 t1 = torch.ones([6], device=device) + 2
@@ -317,10 +321,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):

 @parametrize("device", devices)
 def test_all_gather_tensor(self, device):
-if device == "cuda":
-if torch.cuda.device_count() < self.world_size:
-self.skipTest("Not enough CUDA devices")
-torch.cuda.set_device(dist.get_rank())
+if device != "cpu":
+if torch.accelerator.device_count() < self.world_size:
+self.skipTest("Not enough accelerator devices")
+torch.accelerator.set_device_index(dist.get_rank())

 # testing 1d/2d mesh
 mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size))
@@ -339,10 +343,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):

 @parametrize("device", devices)
 def test_all_gather_into_tensor_coalesced(self, device):
-if device == "cuda":
-if torch.cuda.device_count() < self.world_size:
-self.skipTest("Not enough CUDA devices")
-torch.cuda.set_device(dist.get_rank())
+if device != "cpu":
+if torch.accelerator.device_count() < self.world_size:
+self.skipTest("Not enough accelerator devices")
+torch.accelerator.set_device_index(dist.get_rank())

 tensors = [torch.ones([4], device=device), torch.ones([4], device=device) + 1]
 mesh = dt.DeviceMesh(device, torch.arange(4))
@@ -356,10 +360,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):

 @parametrize("device", devices)
 def test_reduce_scatter_tensor(self, device):
-if device == "cuda":
-if torch.cuda.device_count() < self.world_size:
-self.skipTest("Not enough CUDA devices")
-torch.cuda.set_device(dist.get_rank())
+if device != "cpu":
+if torch.accelerator.device_count() < self.world_size:
+self.skipTest("Not enough accelerator devices")
+torch.accelerator.set_device_index(dist.get_rank())

 # testing 1d/2d mesh
 mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size))
@@ -380,10 +384,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):

 @parametrize("device", devices)
 def test_reduce_scatter_into_tensor_coalesced(self, device):
-if device == "cuda":
-if torch.cuda.device_count() < self.world_size:
-self.skipTest("Not enough CUDA devices")
-torch.cuda.set_device(dist.get_rank())
+if device != "cpu":
+if torch.accelerator.device_count() < self.world_size:
+self.skipTest("Not enough accelerator devices")
+torch.accelerator.set_device_index(dist.get_rank())
 tensors = [
 torch.ones([4], dtype=torch.int64, device=device),
 torch.ones([4], dtype=torch.int64, device=device) + 1,
@@ -474,18 +478,17 @@ BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
 # And then set the BACKEND variable appropriately.
 if TEST_HPU:
 BACKEND = dist.Backend.HCCL
+elif TEST_XPU:
+BACKEND = dist.Backend.XCCL


 # allows you to check for multiple accelerator irrespective of device type
 # to add new device types to this check simply follow the same format
 # and append an elif with the conditional and appropriate device count function for your new device
 def exit_if_lt_x_accelerators(x):
-if TEST_CUDA:
-if torch.cuda.device_count() < x:
-sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
-elif TEST_HPU:
-if torch.hpu.device_count() < x:
-sys.exit(TEST_SKIPS[f"multi-hpu-{x}"].exit_code)
+if torch.accelerator.is_available():
+if torch.accelerator.device_count() < x:
+sys.exit(TEST_SKIPS[f"multi-accelerator-{x}"].exit_code)


 def with_comms(func=None):
@@ -494,7 +497,9 @@ def with_comms(func=None):

 @wraps(func)
 def wrapper(self, *args, **kwargs):
-if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
+if (
+BACKEND == dist.Backend.NCCL or BACKEND == dist.Backend.XCCL
+) and torch.accelerator.device_count() < self.world_size:
 sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)

 kwargs["device"] = DEVICE
@@ -572,7 +577,7 @@ class TestCollectivesWithDistributedBackend(DistributedTestBase):
 self.assertEqual(y, expected)

 @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-@requires_nccl()
+@requires_accelerator_dist_backend(["nccl", "xccl"])
 @with_comms()
 def test_tracing(self, device):
 def allreduce(t, pg):
@@ -599,7 +604,7 @@ class TestCollectivesWithDistributedBackend(DistributedTestBase):
 dist.destroy_process_group()

 @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-@requires_nccl()
+@requires_accelerator_dist_backend(["nccl", "xccl"])
 @with_comms()
 def test_tracing_with_dce_code(self, device):
 if self.world_size > 2:
@@ -818,13 +823,19 @@ class TestFunctionalAutogradWithDistributedBackend(DistributedTestBase):

 # Update the supported devices in DEVICE
 instantiate_device_type_tests(
-TestCollectivesWithDistributedBackend, globals(), only_for=DEVICE
+TestCollectivesWithDistributedBackend, globals(), only_for=DEVICE, allow_xpu=True
 )
 instantiate_device_type_tests(
-TestDistributedBackendCollectivesWithWorldSize4, globals(), only_for=DEVICE
+TestDistributedBackendCollectivesWithWorldSize4,
+globals(),
+only_for=DEVICE,
+allow_xpu=True,
 )
 instantiate_device_type_tests(
-TestFunctionalAutogradWithDistributedBackend, globals(), only_for=DEVICE
+TestFunctionalAutogradWithDistributedBackend,
+globals(),
+only_for=DEVICE,
+allow_xpu=True,
 )

 if __name__ == "__main__":
@@ -96,10 +96,10 @@ TEST_SKIPS = {
 class DistTestCases:
 # Backends that do not support a specific collective
 skip_collective = {}
-skip_collective["allgather_coalesced"] = {"nccl", "mpi", "ucc"}
+skip_collective["allgather_coalesced"] = {"nccl", "mpi", "ucc", "xccl"}
 skip_collective["reduce"] = set()
-skip_collective["sendrecv anysource"] = {"nccl", "ucc"}
-skip_collective["cpu barrier"] = {"nccl", "ucc"}
+skip_collective["sendrecv anysource"] = {"nccl", "ucc", "xccl"}
+skip_collective["cpu barrier"] = {"nccl", "ucc", "xccl"}

 # Sets showing that something is implemented
 backend_feature = {}
@@ -338,15 +338,26 @@ def requires_gloo():


 def requires_nccl_version(version, msg):
-if not c10d.is_nccl_available():
-return skip_but_pass_in_sandcastle(
-"c10d was not compiled with the NCCL backend",
-)
+if TEST_CUDA:
+if not c10d.is_nccl_available():
+return skip_but_pass_in_sandcastle(
+"c10d was not compiled with the NCCL backend",
+)
+else:
+return skip_but_pass_in_sandcastle_if(
+torch.cuda.nccl.version() < version,
+f"Requires NCCL version greater than or equal to: {version}, found: {torch.cuda.nccl.version()}, reason: {msg}",
+)
 else:
-return skip_but_pass_in_sandcastle_if(
-torch.cuda.nccl.version() < version,
-f"Requires NCCL version greater than or equal to: {version}, found: {torch.cuda.nccl.version()}, reason: {msg}",
-)
+
+def decorator(func):
+@wraps(func)
+def wrapper(*args, **kwargs):
+return func(*args, **kwargs)
+
+return wrapper
+
+return decorator


 def requires_nccl():
@@ -435,9 +446,10 @@ def sm_is_or_higher_than(device: torch.device, major: int, minor: int) -> bool:
 Returns True if the device's compute capability is (major, minor) or higher.
 Error out if the device is not a CUDA device.
 Returns False if device is a RoCM device.
+Returns True if device is a non-CUDA device.
 """
 if device.type != "cuda":
-raise ValueError("sm_is_or_later() is only supported for CUDA devices")
+return True

 if torch.version.hip is not None:
 # ROCm devices may have different compute capability codes
@@ -1456,12 +1468,19 @@ class SaveForwardInputsModel(nn.Module):

 @contextmanager
 def _dynamo_dist_per_rank_init(
-rank, world_size, backend="nccl", init_pg=True, fake_pg=False
+rank, world_size, backend=None, init_pg=True, fake_pg=False
 ):
 # To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase,
 # Just manually implement the most important part of the dynamo behavior to reset/clear.
 if not fake_pg:
 torch.accelerator.set_device_index(rank)
+
+device_type = (
+acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+)
+if backend is None:
+backend = c10d.get_default_backend_for_device(device_type)
+
 os.environ["MASTER_ADDR"] = "localhost"
 os.environ["MASTER_PORT"] = "6789"
 if init_pg:
@@ -1508,9 +1527,12 @@ class DynamoDistributedSingleProcTestCase(torch._dynamo.test_case.TestCase):
 )
 )
 cls.rank = 0
-cls.device = f"cuda:{cls.rank}"
-cls.device_ids = None if "cuda" in cls.device else [cls.rank]
-c10d.init_process_group("nccl", rank=cls.rank, world_size=1)
+device = torch.accelerator.current_accelerator().type
+cls.device = f"{device}:{cls.rank}"
+cls.device_ids = None if device in cls.device else [cls.rank]
+c10d.init_process_group(
+c10d.get_default_backend_for_device(device), rank=cls.rank, world_size=1
+)

 @classmethod
 def tearDownClass(cls):
@@ -6,6 +6,7 @@ import os
 import re
 import sys
 import time
+import unittest
 import warnings
 from abc import ABC, abstractmethod
 from contextlib import nullcontext
@@ -1122,6 +1123,7 @@ def check_sharded_parity(
 cls.assertEqual(sharded_param.grad.to_local(), sharded_ref_grad.to_local())


+@unittest.skipIf(TEST_XPU, "not-support-multithread")
 class FSDPTestMultiThread(MultiThreadedTestCase):
 @property
 def world_size(self):
@@ -1187,7 +1189,7 @@ class FSDPTest(MultiProcessTestCase):
 fake_pg = kwargs.get("fake_pg", False)

 print(f"dist init r={self.rank}, world={self.world_size}")
-if torch.cuda.device_count() < self.world_size:
+if torch.accelerator.device_count() < self.world_size:
 sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)

 # Specify gloo backend to make 'init_process_group()' succeed,