port 3 distributed tests to Intel GPU and unify some common functions (#158533)

As part of https://github.com/pytorch/pytorch/issues/114850, we are porting distributed tests to Intel GPU.
This PR enables Intel GPU with the following methods while keeping the original code style as much as possible:

- use instantiate_device_type_tests() (with allow_xpu=True) to instantiate device-specific tests
- use torch.accelerator.current_accelerator() to determine the accelerator backend (a short sketch follows this list)
- enable XPU in some test paths
- unify some common code under torch/testing/_internal for multiple backends, for example:
  - requires_nccl_version
  - _dynamo_dist_per_rank_init
  - DynamoDistributedSingleProcTestCase
  - DistTestCases
  - FSDPTestMultiThread
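
To make the pattern concrete, here is a minimal sketch of the device-agnostic idiom these changes apply, assuming a PyTorch build that provides the `torch.accelerator` and `get_default_backend_for_device` APIs used in the diffs below:

```python
import torch
import torch.distributed as dist

# Resolve the active accelerator ("cuda", "xpu", "hpu", ...) or fall back to CPU.
device_type = (
    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)

torch.manual_seed(0)
if device_type != "cpu":
    # Seed the matching accelerator module instead of hard-coding torch.cuda.
    torch.get_device_module(device_type).manual_seed(0)

# Pick the default process-group backend for this device ("nccl", "xccl", "gloo", ...).
if dist.is_available():
    backend = dist.get_default_backend_for_device(device_type)
    print(f"device={device_type}, backend={backend}")
```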

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158533
Approved by: https://github.com/guangyey, https://github.com/d4l3k

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
Deng, Daisy authored on 2025-08-13 08:13:20 +00:00; committed by PyTorch MergeBot
parent 9a06e6d031
commit 6e8865fbc1
5 changed files with 201 additions and 139 deletions


@@ -13,7 +13,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecis
 from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from torch.testing._internal.common_distributed import (
-    requires_nccl,
+    requires_accelerator_dist_backend,
     requires_nccl_version,
     skip_but_pass_in_sandcastle_if,
     skip_if_lt_x_gpu,
@@ -30,17 +30,22 @@ if not dist.is_available():
     print("Distributed not available, skipping tests", file=sys.stderr)
     sys.exit(0)

-# bfloat16 is only supported by CUDA 11+
-BFLOAT16_AVAILABLE = torch.cuda.is_available() and (
-    torch.version.cuda is not None or torch.version.hip is not None
-)
+device_type = (
+    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
+)
+
+# bfloat16 is only supported by CUDA 11+ or XPU
+BFLOAT16_AVAILABLE = (
+    torch.cuda.is_available()
+    and (torch.version.cuda is not None or torch.version.hip is not None)
+) or torch.xpu.is_available()


 class Net(nn.Module):
     def __init__(self, has_wrapping, sharding_strategy, mixed_precision=None):
         # to ensure determinism
         torch.manual_seed(0)
-        torch.cuda.manual_seed(0)
+        torch.get_device_module(device_type).manual_seed(0)
         super().__init__()
         if has_wrapping:
@@ -50,12 +55,12 @@ class Net(nn.Module):
                     nn.ReLU(),
                     FSDP(
                         nn.Linear(16, 8),
-                        device_id=torch.cuda.current_device(),
+                        device_id=torch.accelerator.current_device_index(),
                         sharding_strategy=sharding_strategy,
                         mixed_precision=mixed_precision,
                     ),
                 ),
-                device_id=torch.cuda.current_device(),
+                device_id=torch.accelerator.current_device_index(),
                 sharding_strategy=sharding_strategy,
                 mixed_precision=mixed_precision,
             )
@@ -134,11 +139,11 @@ class TestCommunicationHooks(FSDPTest):
         """
         out_dim = self.world_size
         net = torch.nn.Linear(1, out_dim, bias=False)
-        inpt = torch.tensor([self.rank]).float().cuda(self.rank)
+        inpt = torch.tensor([self.rank]).float().to(self.rank)

         net_default_hook = FSDP(
             net,
-            device_id=torch.cuda.current_device(),
+            device_id=torch.accelerator.current_device_index(),
             sharding_strategy=sharding_strategy,
         ).to(self.rank)
@@ -172,10 +177,10 @@ class TestCommunicationHooks(FSDPTest):
         ]

     def _init_model(self, core, sharding_strategy, mixed_precision=None):
-        device = torch.device("cuda")
+        device = torch.device(device_type)
         return FSDP(
             core,
-            device_id=torch.cuda.current_device(),
+            device_id=torch.accelerator.current_device_index(),
             sharding_strategy=sharding_strategy,
             mixed_precision=mixed_precision,
         ).to(device)
@@ -277,7 +282,7 @@ class TestCommunicationHooks(FSDPTest):
             ShardingStrategy.HYBRID_SHARD,
             ShardingStrategy._HYBRID_SHARD_ZERO2,
         ):
-            model = Net(False, None, None).cuda()
+            model = Net(False, None, None).to(device=device_type)
             fsdp_model = FSDP(
                 model,
                 auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
@@ -337,7 +342,7 @@ class TestCommunicationHooks(FSDPTest):
     ):
         # keep everything deterministic for input data
         torch.manual_seed(0)
-        torch.cuda.manual_seed(0)
+        torch.get_device_module(device_type).manual_seed(0)

         fsdp_with_hook = self._init_model(
             Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy),
@@ -359,7 +364,7 @@ class TestCommunicationHooks(FSDPTest):
         optim_hook = torch.optim.SGD(fsdp_with_hook.parameters(), lr=0.1)
         optim_mp = torch.optim.SGD(fsdp_with_mp.parameters(), lr=0.1)

-        in_data = torch.rand(16, 8).cuda()
+        in_data = torch.rand(16, 8).to(device=device_type)
         fsdp_with_hook.train()
         fsdp_with_mp.train()
         loss_hook = fsdp_with_hook(in_data).sum()
@@ -378,7 +383,7 @@ class TestCommunicationHooks(FSDPTest):
         ):
             self.assertEqual(hook_param.grad, mp_param.grad)

-    @requires_nccl()
+    @requires_accelerator_dist_backend(["nccl", "xccl"])
    @skip_if_lt_x_gpu(2)
    @parametrize("has_wrapping", [True, False])
    @parametrize(
@@ -399,11 +404,11 @@ class TestCommunicationHooks(FSDPTest):
             state, hook, sharding_strategy, torch.float16, has_wrapping
         )

-    @requires_nccl()
+    @requires_accelerator_dist_backend(["nccl", "xccl"])
    @requires_nccl_version((2, 10), "Need NCCL 2.10+ for BF16_COMPRESS")
    @skip_but_pass_in_sandcastle_if(
        not BFLOAT16_AVAILABLE,
-        "BFloat16 is only supported by CUDA 11+",
+        "BFloat16 is only supported by CUDA 11+ or XPU",
    )
    @skip_if_lt_x_gpu(2)
    @parametrize("has_wrapping", [True, False])


@@ -60,6 +60,10 @@ if TEST_WITH_DEV_DBG_ASAN:
     )
     sys.exit(0)

+device_type = (
+    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
+)
+

 class MyModel(nn.Module):
     def __init__(self) -> None:
@@ -93,9 +97,9 @@ class TestFSDPMiscMultiProcess(FSDPTest):
         without specifying a device ID (i.e. ``torch.device("cuda")``) warns
         """
         dev_id = (
-            torch.cuda.current_device()
+            torch.accelerator.current_device_index()
             if use_index
-            else torch.device("cuda", torch.cuda.current_device())
+            else torch.device(device_type, torch.accelerator.current_device_index())
         )

         def _check_device_matches(module, device_id):
@@ -108,7 +112,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
             self.assertEqual(1, len(devices))
             found_device = devices.pop()
             if use_index and not isinstance(device_id, torch.device):
-                device = torch.device("cuda", device_id)
+                device = torch.device(device_type, device_id)
             else:
                 device = device_id
             self.assertEqual(found_device, device)
@@ -140,10 +144,11 @@ class TestFSDPMiscMultiProcess(FSDPTest):
             self.process_group,
             FSDPInitMode.RECURSIVE,
             DEVICEInitMode.DEVICE_BEFORE,
-            fsdp_kwargs={"device_id": torch.device("cuda")},
+            fsdp_kwargs={"device_id": torch.device(device_type)},
         )
         _check_device_matches(
-            nested_wrapped_module, torch.device("cuda", torch.cuda.current_device())
+            nested_wrapped_module,
+            torch.device(device_type, torch.accelerator.current_device_index()),
         )

     @skip_if_lt_x_gpu(2)
@@ -178,8 +183,8 @@ class TestFSDPMiscMultiProcess(FSDPTest):
                 loss = torch.nn.functional.cross_entropy(output, y)
                 return loss

-        model = Mnist().cuda()
-        model1 = Mnist().cuda()
+        model = Mnist().to(device=device_type)
+        model1 = Mnist().to(device=device_type)
         model1.load_state_dict(model.state_dict())
         fsdp_model = FSDP(
             model,
@@ -197,17 +202,17 @@ class TestFSDPMiscMultiProcess(FSDPTest):
         seed = self.rank + 20231010
         torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        torch.get_device_module(device_type).manual_seed(seed)

         losses = []
         grads = []
         for i in range(5):
-            x = torch.randn(8, 1, 28, 28, device="cuda").requires_grad_()
-            y = torch.randint(low=0, high=9, size=(8,), device="cuda")
+            x = torch.randn(8, 1, 28, 28, device=device_type).requires_grad_()
+            y = torch.randint(low=0, high=9, size=(8,), device=device_type)
             for model, opt in ((fsdp_model, fsdp_opt), (ddp_model, ddp_opt)):
                 seed = self.rank + i
                 torch.manual_seed(seed)
-                torch.cuda.manual_seed(seed)
+                torch.get_device_module(device_type).manual_seed(seed)
                 loss = model(x, y).sum()
                 losses.append(loss)
                 loss.backward()
@@ -223,8 +228,8 @@ class TestFSDPMiscMultiProcess(FSDPTest):
         fsdp_model.eval()
         ddp_model.eval()
         for _ in range(5):
-            x = torch.randn(8, 1, 28, 28, device="cuda").requires_grad_()
-            y = torch.randint(low=0, high=9, size=(8,), device="cuda")
+            x = torch.randn(8, 1, 28, 28, device=device_type).requires_grad_()
+            y = torch.randint(low=0, high=9, size=(8,), device=device_type)
             fsdp_loss = fsdp_model(x, y)
             ddp_loss = ddp_model(x, y)
             assert torch.allclose(fsdp_loss, ddp_loss)
@@ -232,12 +237,12 @@ class TestFSDPMiscMultiProcess(FSDPTest):
         fsdp_model.train()
         ddp_model.train()
         for i in range(5):
-            x = torch.randn(8, 1, 28, 28, device="cuda").requires_grad_()
-            y = torch.randint(low=0, high=9, size=(8,), device="cuda")
+            x = torch.randn(8, 1, 28, 28, device=device_type).requires_grad_()
+            y = torch.randint(low=0, high=9, size=(8,), device=device_type)
             for model, opt in ((fsdp_model, fsdp_opt), (ddp_model, ddp_opt)):
                 seed = self.rank + i
                 torch.manual_seed(seed)
-                torch.cuda.manual_seed(seed)
+                torch.get_device_module(device_type).manual_seed(seed)
                 loss = model(x, y).sum()
                 losses.append(loss)
                 loss.backward()
@@ -272,12 +277,12 @@ class TestFSDPMiscMultiProcess(FSDPTest):
                 return out1

         fsdp = FSDP(
-            MyModel().cuda(),
+            MyModel().to(device=device_type),
             sharding_strategy=sharding_strategy,
             auto_wrap_policy=always_wrap_policy,
         )
-        x = torch.randn(10, 10, device="cuda")
-        y = torch.randn(10, 10, device="cuda")
+        x = torch.randn(10, 10, device=device_type)
+        y = torch.randn(10, 10, device=device_type)
         for _ in range(4):
             if use_second_layer:
                 a, _ = fsdp(x, y)
@@ -336,7 +341,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
                 torch.testing.assert_close(p1, p2)

         fsdp_ctor = functools.partial(FSDP, sharding_strategy=sharding_strategy)
-        m = MyModule().cuda()
+        m = MyModule().to(device=device_type)
         m_local = deepcopy(m)
         local_m = m_local
         prev_params = [p.clone() for p in m_local.parameters()]
@@ -349,7 +354,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
         opt_local = torch.optim.SGD(local_m.parameters(), lr=1e-3)

         for i in range(6):
-            t = torch.ones(4, device="cuda")
+            t = torch.ones(4, device=device_type)
             a, b = m(t)
             local_a, local_b = local_m(t)
             if i < 2:
@@ -385,7 +390,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
     @skip_if_lt_x_gpu(2)
     def test_fsdp_optim_overlap_no_use_orig_params_error(self):
         fsdp_overlap = FSDP(
-            MyModel().cuda(),
+            MyModel().to(device=device_type),
             auto_wrap_policy=always_wrap_policy,
             use_orig_params=False,
         )
@@ -398,7 +403,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
             register_hook=False,
         )

-        inp = torch.randn(10, 10, device="cuda")
+        inp = torch.randn(10, 10, device=device_type)
         with self.assertRaisesRegex(
             RuntimeError, "only supported with use_orig_params=True"
         ):
@@ -409,16 +414,16 @@ class TestFSDPMiscMultiProcess(FSDPTest):
         torch.manual_seed(0)
         for cpu_offload in [True, False]:
             offload = CPUOffload(offload_params=cpu_offload)
-            model = MyModel().cuda()
+            model = MyModel().to(device=device_type)
             model_overlap = deepcopy(model)
             fsdp = FSDP(
-                model.cuda(),
+                model.to(device=device_type),
                 auto_wrap_policy=always_wrap_policy,
                 use_orig_params=True,
                 cpu_offload=offload,
             )
             fsdp_overlap = FSDP(
-                model_overlap.cuda(),
+                model_overlap.to(device=device_type),
                 auto_wrap_policy=always_wrap_policy,
                 use_orig_params=True,
                 cpu_offload=offload,
@@ -445,7 +450,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
             ]

             for i in range(6):
-                inp = torch.randn(2, 2, device="cuda")
+                inp = torch.randn(2, 2, device=device_type)
                 with torch.no_grad():
                     inp_clone = inp.clone()
                 fsdp(inp, inp).sum().backward()
@@ -546,7 +551,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
         """Tests that passing a CPU module to FSDP preserves that the wrapped
         module is on CPU after FSDP initialization, albeit after logging a
         warning, and that FSDP moves CPU input to GPU before the forward."""
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)
         regex = "passed-in `module` is on CPU"
         context = self.assertWarnsRegex(
             expected_warning=UserWarning, expected_regex=regex
@@ -561,7 +566,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
         devices = {p.device for p in fsdp_model.parameters()}
         self.assertEqual(1, len(devices))
         self.assertEqual(torch.device("cpu"), devices.pop())
-        fsdp_model = fsdp_model.cuda()
+        fsdp_model = fsdp_model.to(device=device_type)
         # Ensure fwd + backward can be performed after moving to CUDA.
         # CPU input also tests that input is correctly moved to appropriate
         # CUDA device.
@@ -606,19 +611,19 @@ class TestFSDPMiscMultiProcess(FSDPTest):
             nested_wrapped_module,
             self.process_group,
             auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
-            device_id=torch.cuda.current_device(),
+            device_id=torch.accelerator.current_device_index(),
             sync_module_states=True,
         )
         # Each rank's buffers should be 0s since rank 0 is the source, and they
         # should be on GPU since we specified `device_id`
         self.assertEqual(
             nested_wrapped_module.buf.device,
-            torch.device("cuda", torch.cuda.current_device()),
+            torch.device(device_type, torch.accelerator.current_device_index()),
         )
         self.assertEqual(nested_wrapped_module.buf, torch.zeros((2, 2)))
         self.assertEqual(
             nested_wrapped_module.module.module[0].buf.device,
-            torch.device("cuda", torch.cuda.current_device()),
+            torch.device(device_type, torch.accelerator.current_device_index()),
         )
         self.assertEqual(
             nested_wrapped_module.module.module[0].buf, torch.zeros((3, 2))
@@ -644,9 +649,9 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
             def forward(self, x):
                 return x

-        m = MyModule().cuda()
+        m = MyModule().to(device=device_type)
         m = FSDP(m)
-        t = torch.ones(1, device="cuda", requires_grad=True)
+        t = torch.ones(1, device=device_type, requires_grad=True)

         MyOutputType = namedtuple(
             "MyOutputType", ["a", "b", "c", "d"], defaults=(t, t, t, t)
@@ -683,7 +688,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
         auto_wrap_policy = ModuleWrapPolicy(module_classes)
         fsdp_kwargs = {
             "auto_wrap_policy": auto_wrap_policy,
-            "device_id": torch.cuda.current_device(),
+            "device_id": torch.accelerator.current_device_index(),
         }
         fsdp_model = TransformerWithSharedParams.init(
             self.process_group,
@@ -694,7 +699,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
         for fsdp_module in FSDP.fsdp_modules(fsdp_model):
             self.assertEqual(
                 fsdp_module.compute_device,
-                torch.device("cuda", torch.cuda.current_device()),
+                torch.device(device_type, torch.accelerator.current_device_index()),
             )

     @skip_if_lt_x_gpu(2)
@@ -729,7 +734,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
             model,
             auto_wrap_policy=auto_wrap_policy,
             cpu_offload=CPUOffload(offload_params=True),
-            device_id=torch.cuda.current_device(),
+            device_id=torch.accelerator.current_device_index(),
             use_orig_params=use_orig_params,
         )
         cpu_device = torch.device("cpu")
@@ -742,12 +747,16 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
         module that does not match the GPU device ID raises an error."""
         # TODO: override FSDP MT Thread _run to set this instead of here for
         # every test.
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)
         context = (
-            self.assertRaisesRegex(ValueError, f"cuda:{self.rank} vs cuda:0")
+            self.assertRaisesRegex(
+                ValueError, f"{device_type}:{self.rank} vs {device_type}:0"
+            )
             if self.rank != 0
             else nullcontext()
         )
         with context:
             NestedWrappedModule.init(
                 self.process_group,
@@ -764,18 +773,20 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
         """Tests a CPU + GPU module supported if device_id is passed
         in, errors if device_id is not.
         """
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)

         class CPUGPUModule(nn.Module):
             def __init__(self) -> None:
                 super().__init__()
-                self.a = nn.Linear(1, 1).cuda()
+                self.a = nn.Linear(1, 1).to(device=device_type)
                 self.b = nn.Linear(1, 1)

         cpu_gpu = CPUGPUModule()
-        fsdp = FSDP(cpu_gpu, device_id=torch.cuda.current_device())
+        fsdp = FSDP(cpu_gpu, device_id=torch.accelerator.current_device_index())
         for param in fsdp.parameters():
-            self.assertEqual(param.device, torch.device(torch.cuda.current_device()))
+            self.assertEqual(
+                param.device, torch.device(torch.accelerator.current_device_index())
+            )

         # without device_id, we hit an error
         with self.assertRaisesRegex(RuntimeError, "please pass in device_id"):
@@ -783,7 +794,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
     @skip_if_lt_x_gpu(2)
     def test_fsdp_ignored_module_meta(self):
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)

         class CPUGPUModule(nn.Module):
             def __init__(self) -> None:
@@ -802,11 +813,11 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
         m = CPUGPUModule()
         m = FSDP(
             m,
-            device_id=torch.cuda.current_device(),
+            device_id=torch.accelerator.current_device_index(),
             ignored_modules=[m.a],
             use_orig_params=True,
             param_init_fn=lambda m: m.to_empty(
-                device=torch.cuda.current_device(), recurse=False
+                device=torch.accelerator.current_device_index(), recurse=False
             ),
         )
         self.assertEqual(meta_device, next(m.a.parameters()).device)
@@ -854,20 +865,20 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
         """
         # TODO: override FSDP MT Thread _run to set this instead of here for
         # every test.
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)
         # Test CPU
         no_params = nn.ReLU()
         FSDP(no_params)
         # Test CUDA
-        no_params = nn.ReLU().cuda()
+        no_params = nn.ReLU().to(device=device_type)
         FSDP(no_params)
         # Test CPU + device_id
         no_params = nn.ReLU()
-        FSDP(no_params, device_id=torch.cuda.current_device())
+        FSDP(no_params, device_id=torch.accelerator.current_device_index())
         # For modules with no params, wrong device_id will raise error about
         # inconsistency between compute_device and device_id, since compute_device
         # is computed as torch.cuda.current_device when there are no params.
-        no_params = nn.ReLU().cuda()
+        no_params = nn.ReLU().to(device=device_type)
         context = (
             (
                 self.assertRaisesRegex(
@@ -892,11 +903,11 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
                 super().__init__()
                 # Seed via rank to make model different across ranks
                 torch.manual_seed(rank)
-                torch.cuda.manual_seed(rank)
+                torch.get_device_module(device_type).manual_seed(rank)
                 self.lin = nn.Linear(10, 10, bias=False)
                 self.buffer = nn.Buffer(torch.ones(1) * rank)

-        m = MyModel(self.rank).cuda()
+        m = MyModel(self.rank).to(device=device_type)
         _assert_module_states(
             m, process_group=self.process_group, assert_fn=self.assertNotEqual
         )
@@ -913,7 +924,11 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
             m, process_group=self.process_group, assert_fn=self.assertNotEqual
         )
         # Passing sync_module_states into FSDP makes model the same during init.
-        fsdp = FSDP(m, device_id=torch.cuda.current_device(), sync_module_states=True)
+        fsdp = FSDP(
+            m,
+            device_id=torch.accelerator.current_device_index(),
+            sync_module_states=True,
+        )
         with fsdp.summon_full_params(fsdp):
             _assert_module_states(
                 fsdp, process_group=self.process_group, assert_fn=self.assertEqual
@@ -968,7 +983,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
         with self.assertRaisesRegex(
             ValueError, f"Expects one homogeneous value for {attr_name}"
         ):
-            inp = fsdp_model.module.get_input(torch.device("cuda"))
+            inp = fsdp_model.module.get_input(torch.device(device_type))
             fsdp_model(*inp)

     @skip_if_lt_x_gpu(2)
@@ -976,7 +991,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
         regex = r"FSDP will not all-gather parameters for containers that do not implement forward"
         model = nn.ModuleList([MLP(8, torch.device("cpu")) for _ in range(3)])
         with self.assertWarnsRegex(UserWarning, regex):
-            FSDP(model, device_id="cuda")
+            FSDP(model, device_id=device_type)
         model = nn.ModuleDict(
             {"1": MLP(8, torch.device("cpu")), "2": MLP(8, torch.device("cpu"))}
         )
@@ -1000,7 +1015,10 @@ class TestFSDPMiscWorldSize1(FSDPTestMultiThread):
         # warning
         with warnings.catch_warnings(record=True) as w:
             warnings.simplefilter("always")  # trigger all warnings
-            FSDP(nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.NO_SHARD)
+            FSDP(
+                nn.Linear(3, 3).to(device=device_type),
+                sharding_strategy=ShardingStrategy.NO_SHARD,
+            )
         for warning in w:
             self.assertTrue(
                 warning.category != UserWarning
@@ -1014,16 +1032,20 @@ class TestFSDPMiscWorldSize1(FSDPTestMultiThread):
             warning_prefix + " " + str(ShardingStrategy.FULL_SHARD) + warning_suffix
         )
         with self.assertWarnsRegex(UserWarning, expected_regex_full_shard):
-            FSDP(nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.FULL_SHARD)
+            FSDP(
+                nn.Linear(3, 3).to(device=device_type),
+                sharding_strategy=ShardingStrategy.FULL_SHARD,
+            )
         with self.assertWarnsRegex(UserWarning, expected_regex_full_shard):
-            FSDP(nn.Linear(3, 3).cuda())
+            FSDP(nn.Linear(3, 3).to(device=device_type))
         # - Pass `SHARD_GRAD_OP`
         expected_regex_shard_grad_op = (
             warning_prefix + " " + str(ShardingStrategy.SHARD_GRAD_OP) + warning_suffix
         )
         with self.assertWarnsRegex(UserWarning, expected_regex_shard_grad_op):
             FSDP(
-                nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.SHARD_GRAD_OP
+                nn.Linear(3, 3).to(device=device_type),
+                sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
             )

     @skip_if_lt_x_gpu(1)
@@ -1047,7 +1069,7 @@ class TestFSDPMiscWorldSize1(FSDPTestMultiThread):
         # Incorrectly moving from CPU -> GPU
         model = torch.nn.Linear(10, 10)
         fsdp_model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))
-        fsdp_model.to(torch.device("cuda"))
+        fsdp_model.to(torch.device(device_type))
         inp = torch.randn((2, 10))
         with self.assertRaisesRegex(
             RuntimeError,
@@ -1088,16 +1110,16 @@ class TestFSDPMiscWorldSize1(FSDPTestMultiThread):
         # Construct FSDP module without changing any environment variables and
         # run forward, which triggers both unsharded and sharded view setting
-        module = SetattrLinear(5, 5, torch.device("cuda"))
+        module = SetattrLinear(5, 5, torch.device(device_type))
         fsdp_module = FSDP(module, use_orig_params=use_orig_params)
-        inp = torch.randn((8, 5), device=torch.device("cuda"))
+        inp = torch.randn((8, 5), device=torch.device(device_type))
         called_setattr_override = False
         fsdp_module(inp)
         self.assertTrue(called_setattr_override)

         # Repeat with unsafe setattr explicitly enabled
         os.environ[_FSDP_USE_UNSAFE_SETATTR] = "1"
-        module = SetattrLinear(5, 5, torch.device("cuda"))
+        module = SetattrLinear(5, 5, torch.device(device_type))
         fsdp_module = FSDP(module, use_orig_params=use_orig_params)
         called_setattr_override = False
         fsdp_module(inp)
@@ -1105,7 +1127,7 @@ class TestFSDPMiscWorldSize1(FSDPTestMultiThread):
         # Repeat with unsafe setattr explicitly disabled
         os.environ[_FSDP_USE_UNSAFE_SETATTR] = "0"
-        module = SetattrLinear(5, 5, torch.device("cuda"))
+        module = SetattrLinear(5, 5, torch.device(device_type))
         fsdp_module = FSDP(module, use_orig_params=use_orig_params)
         called_setattr_override = False
         fsdp_module(inp)


@@ -24,7 +24,7 @@ if not dist.is_available():
 from torch.testing._internal.common_distributed import (
     DistributedTestBase,
     MultiThreadedTestCase,
-    requires_nccl,
+    requires_accelerator_dist_backend,
     TEST_SKIPS,
 )
 from torch.testing._internal.common_utils import (
@@ -34,6 +34,7 @@ from torch.testing._internal.common_utils import (
     skipIfHpu,
     TEST_CUDA,
     TEST_HPU,
+    TEST_XPU,
     TestCase,
 )
@@ -64,6 +65,9 @@ devices = ["cpu"]
 if TEST_HPU:
     devices.append("hpu")
     DEVICE = "hpu"
+elif TEST_XPU:
+    devices.append("xpu")
+    DEVICE = "xpu"
 elif TEST_CUDA:
     devices.append("cuda")
@@ -269,10 +273,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):
     @parametrize("device", devices)
     def test_broadcast(self, device):
-        if device == "cuda":
-            if torch.cuda.device_count() < self.world_size:
-                self.skipTest("Not enough CUDA devices")
-            torch.cuda.set_device(dist.get_rank())
+        if device != "cpu":
+            if torch.accelerator.device_count() < self.world_size:
+                self.skipTest("Not enough accelerator devices")
+            torch.accelerator.set_device_index(dist.get_rank())

         if dist.get_rank() == 0:
             tensor = torch.ones([4], device=device)
@@ -285,10 +289,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):
     @parametrize("device", devices)
     def test_all_reduce_eager(self, device):
-        if device == "cuda":
-            if torch.cuda.device_count() < self.world_size:
-                self.skipTest("Not enough CUDA devices")
-            torch.cuda.set_device(dist.get_rank())
+        if device != "cpu":
+            if torch.accelerator.device_count() < self.world_size:
+                self.skipTest("Not enough accelerator devices")
+            torch.accelerator.set_device_index(dist.get_rank())

         tensor = torch.ones([4], device=device)
         mesh = dt.DeviceMesh(device, torch.arange(4))
@@ -302,10 +306,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):
     @parametrize("device", devices)
     def test_all_reduce_coalesced_eager(self, device):
-        if device == "cuda":
-            if torch.cuda.device_count() < self.world_size:
-                self.skipTest("Not enough CUDA devices")
-            torch.cuda.set_device(dist.get_rank())
+        if device != "cpu":
+            if torch.accelerator.device_count() < self.world_size:
+                self.skipTest("Not enough accelerator devices")
+            torch.accelerator.set_device_index(dist.get_rank())

         t0 = torch.ones([4], device=device)
         t1 = torch.ones([6], device=device) + 2
@@ -317,10 +321,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):
     @parametrize("device", devices)
     def test_all_gather_tensor(self, device):
-        if device == "cuda":
-            if torch.cuda.device_count() < self.world_size:
-                self.skipTest("Not enough CUDA devices")
-            torch.cuda.set_device(dist.get_rank())
+        if device != "cpu":
+            if torch.accelerator.device_count() < self.world_size:
+                self.skipTest("Not enough accelerator devices")
+            torch.accelerator.set_device_index(dist.get_rank())

         # testing 1d/2d mesh
         mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size))
@@ -339,10 +343,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):
     @parametrize("device", devices)
     def test_all_gather_into_tensor_coalesced(self, device):
-        if device == "cuda":
-            if torch.cuda.device_count() < self.world_size:
-                self.skipTest("Not enough CUDA devices")
-            torch.cuda.set_device(dist.get_rank())
+        if device != "cpu":
+            if torch.accelerator.device_count() < self.world_size:
+                self.skipTest("Not enough accelerator devices")
+            torch.accelerator.set_device_index(dist.get_rank())

         tensors = [torch.ones([4], device=device), torch.ones([4], device=device) + 1]
         mesh = dt.DeviceMesh(device, torch.arange(4))
@@ -356,10 +360,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):
     @parametrize("device", devices)
     def test_reduce_scatter_tensor(self, device):
-        if device == "cuda":
-            if torch.cuda.device_count() < self.world_size:
-                self.skipTest("Not enough CUDA devices")
-            torch.cuda.set_device(dist.get_rank())
+        if device != "cpu":
+            if torch.accelerator.device_count() < self.world_size:
+                self.skipTest("Not enough accelerator devices")
+            torch.accelerator.set_device_index(dist.get_rank())

         # testing 1d/2d mesh
         mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size))
@@ -380,10 +384,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):
     @parametrize("device", devices)
     def test_reduce_scatter_into_tensor_coalesced(self, device):
-        if device == "cuda":
-            if torch.cuda.device_count() < self.world_size:
-                self.skipTest("Not enough CUDA devices")
-            torch.cuda.set_device(dist.get_rank())
+        if device != "cpu":
+            if torch.accelerator.device_count() < self.world_size:
+                self.skipTest("Not enough accelerator devices")
+            torch.accelerator.set_device_index(dist.get_rank())

         tensors = [
             torch.ones([4], dtype=torch.int64, device=device),
             torch.ones([4], dtype=torch.int64, device=device) + 1,
@@ -474,18 +478,17 @@ BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
 # And then set the BACKEND variable appropriately.
 if TEST_HPU:
     BACKEND = dist.Backend.HCCL
+elif TEST_XPU:
+    BACKEND = dist.Backend.XCCL


 # allows you to check for multiple accelerator irrespective of device type
 # to add new device types to this check simply follow the same format
 # and append an elif with the conditional and appropriate device count function for your new device
 def exit_if_lt_x_accelerators(x):
-    if TEST_CUDA:
-        if torch.cuda.device_count() < x:
-            sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
-    elif TEST_HPU:
-        if torch.hpu.device_count() < x:
-            sys.exit(TEST_SKIPS[f"multi-hpu-{x}"].exit_code)
+    if torch.accelerator.is_available():
+        if torch.accelerator.device_count() < x:
+            sys.exit(TEST_SKIPS[f"multi-accelerator-{x}"].exit_code)


 def with_comms(func=None):
@@ -494,7 +497,9 @@ def with_comms(func=None):
     @wraps(func)
     def wrapper(self, *args, **kwargs):
-        if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
+        if (
+            BACKEND == dist.Backend.NCCL or BACKEND == dist.Backend.XCCL
+        ) and torch.accelerator.device_count() < self.world_size:
             sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)

         kwargs["device"] = DEVICE
@@ -572,7 +577,7 @@ class TestCollectivesWithDistributedBackend(DistributedTestBase):
         self.assertEqual(y, expected)

     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-    @requires_nccl()
+    @requires_accelerator_dist_backend(["nccl", "xccl"])
     @with_comms()
     def test_tracing(self, device):
         def allreduce(t, pg):
@@ -599,7 +604,7 @@ class TestCollectivesWithDistributedBackend(DistributedTestBase):
         dist.destroy_process_group()

     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-    @requires_nccl()
+    @requires_accelerator_dist_backend(["nccl", "xccl"])
     @with_comms()
     def test_tracing_with_dce_code(self, device):
         if self.world_size > 2:
@@ -818,13 +823,19 @@ class TestFunctionalAutogradWithDistributedBackend(DistributedTestBase):
 # Update the supported devices in DEVICE
 instantiate_device_type_tests(
-    TestCollectivesWithDistributedBackend, globals(), only_for=DEVICE
+    TestCollectivesWithDistributedBackend, globals(), only_for=DEVICE, allow_xpu=True
 )
 instantiate_device_type_tests(
-    TestDistributedBackendCollectivesWithWorldSize4, globals(), only_for=DEVICE
+    TestDistributedBackendCollectivesWithWorldSize4,
+    globals(),
+    only_for=DEVICE,
+    allow_xpu=True,
 )
 instantiate_device_type_tests(
-    TestFunctionalAutogradWithDistributedBackend, globals(), only_for=DEVICE
+    TestFunctionalAutogradWithDistributedBackend,
+    globals(),
+    only_for=DEVICE,
+    allow_xpu=True,
 )

 if __name__ == "__main__":


@@ -96,10 +96,10 @@ TEST_SKIPS = {
 class DistTestCases:
     # Backends that do not support a specific collective
     skip_collective = {}
-    skip_collective["allgather_coalesced"] = {"nccl", "mpi", "ucc"}
+    skip_collective["allgather_coalesced"] = {"nccl", "mpi", "ucc", "xccl"}
     skip_collective["reduce"] = set()
-    skip_collective["sendrecv anysource"] = {"nccl", "ucc"}
-    skip_collective["cpu barrier"] = {"nccl", "ucc"}
+    skip_collective["sendrecv anysource"] = {"nccl", "ucc", "xccl"}
+    skip_collective["cpu barrier"] = {"nccl", "ucc", "xccl"}

     # Sets showing that something is implemented
     backend_feature = {}
@@ -338,15 +338,26 @@ def requires_gloo():
 def requires_nccl_version(version, msg):
-    if not c10d.is_nccl_available():
-        return skip_but_pass_in_sandcastle(
-            "c10d was not compiled with the NCCL backend",
-        )
-    else:
-        return skip_but_pass_in_sandcastle_if(
-            torch.cuda.nccl.version() < version,
-            f"Requires NCCL version greater than or equal to: {version}, found: {torch.cuda.nccl.version()}, reason: {msg}",
-        )
+    if TEST_CUDA:
+        if not c10d.is_nccl_available():
+            return skip_but_pass_in_sandcastle(
+                "c10d was not compiled with the NCCL backend",
+            )
+        else:
+            return skip_but_pass_in_sandcastle_if(
+                torch.cuda.nccl.version() < version,
+                f"Requires NCCL version greater than or equal to: {version}, found: {torch.cuda.nccl.version()}, reason: {msg}",
+            )
+    else:
+
+        def decorator(func):
+            @wraps(func)
+            def wrapper(*args, **kwargs):
+                return func(*args, **kwargs)
+
+            return wrapper
+
+        return decorator


 def requires_nccl():
@@ -435,9 +446,10 @@ def sm_is_or_higher_than(device: torch.device, major: int, minor: int) -> bool:
     Returns True if the device's compute capability is (major, minor) or higher.
     Error out if the device is not a CUDA device.
     Returns False if device is a RoCM device.
+    Returns True if device is a non-CUDA device.
     """
     if device.type != "cuda":
-        raise ValueError("sm_is_or_later() is only supported for CUDA devices")
+        return True

     if torch.version.hip is not None:
         # ROCm devices may have different compute capability codes
@@ -1456,12 +1468,19 @@ class SaveForwardInputsModel(nn.Module):
 @contextmanager
 def _dynamo_dist_per_rank_init(
-    rank, world_size, backend="nccl", init_pg=True, fake_pg=False
+    rank, world_size, backend=None, init_pg=True, fake_pg=False
 ):
     # To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase,
     # Just manually implement the most important part of the dynamo behavior to reset/clear.
     if not fake_pg:
         torch.accelerator.set_device_index(rank)
+    device_type = (
+        acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+    )
+    if backend is None:
+        backend = c10d.get_default_backend_for_device(device_type)
+
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "6789"
     if init_pg:
@@ -1508,9 +1527,12 @@ class DynamoDistributedSingleProcTestCase(torch._dynamo.test_case.TestCase):
             )
         )
         cls.rank = 0
-        cls.device = f"cuda:{cls.rank}"
-        cls.device_ids = None if "cuda" in cls.device else [cls.rank]
-        c10d.init_process_group("nccl", rank=cls.rank, world_size=1)
+        device = torch.accelerator.current_accelerator().type
+        cls.device = f"{device}:{cls.rank}"
+        cls.device_ids = None if device in cls.device else [cls.rank]
+        c10d.init_process_group(
+            c10d.get_default_backend_for_device(device), rank=cls.rank, world_size=1
+        )

     @classmethod
     def tearDownClass(cls):


@@ -6,6 +6,7 @@ import os
 import re
 import sys
 import time
+import unittest
 import warnings
 from abc import ABC, abstractmethod
 from contextlib import nullcontext
@@ -1122,6 +1123,7 @@ def check_sharded_parity(
         cls.assertEqual(sharded_param.grad.to_local(), sharded_ref_grad.to_local())


+@unittest.skipIf(TEST_XPU, "not-support-multithread")
 class FSDPTestMultiThread(MultiThreadedTestCase):
     @property
     def world_size(self):
@@ -1187,7 +1189,7 @@ class FSDPTest(MultiProcessTestCase):
         fake_pg = kwargs.get("fake_pg", False)

         print(f"dist init r={self.rank}, world={self.world_size}")

-        if torch.cuda.device_count() < self.world_size:
+        if torch.accelerator.device_count() < self.world_size:
             sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)

         # Specify gloo backend to make 'init_process_group()' succeed,
# Specify gloo backend to make 'init_process_group()' succeed, # Specify gloo backend to make 'init_process_group()' succeed,