Port 3 distributed tests to Intel GPU and unify some common functions (#158533)

For https://github.com/pytorch/pytorch/issues/114850, we are porting distributed tests to Intel GPU.
We enable Intel GPU with the following methods while keeping the original code style as much as possible (illustrative sketches follow the list):

- Use instantiate_device_type_tests() with allow_xpu=True
- Use torch.accelerator.current_accelerator() to determine the accelerator backend
- Enable XPU for some test paths
- Unify some common code under torch/testing/_internal for multiple backends, for example:
  - requires_nccl_version
  - _dynamo_dist_per_rank_init
  - DynamoDistributedSingleProcTestCase
  - DistTestCases
  - FSDPTestMultiThread
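
Below is a minimal sketch of the device-agnostic pattern these methods boil down to, based on the device_type detection and the device-generic seeding/placement calls visible in the diff; it is illustrative, not the exact file contents:

```python
import torch
import torch.nn as nn

# Resolve the active accelerator ("cuda", "xpu", ...) or fall back to CPU,
# mirroring the device_type detection added at the top of the ported test files.
device_type = (
    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)

torch.manual_seed(0)
if device_type != "cpu":
    # Device-generic replacements for the previous CUDA-only calls.
    torch.get_device_module(device_type).manual_seed(0)   # was: torch.cuda.manual_seed(0)
    device_id = torch.accelerator.current_device_index()  # was: torch.cuda.current_device()
    torch.accelerator.set_device_index(device_id)         # was: torch.cuda.set_device(...)

# Module and tensor placement go through the detected device instead of .cuda().
model = nn.Linear(8, 8).to(device=device_type)
```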

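The unified helpers likewise stop hard-coding NCCL: tests move from @requires_nccl() to @requires_accelerator_dist_backend(["nccl", "xccl"]), and the process-group backend is derived from the device. A sketch of that selection logic (single-rank setup shown purely for illustration):

```python
import os

import torch
import torch.distributed as c10d

device_type = (
    acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
)
# Map the device to its default backend, e.g. cuda -> nccl, xpu -> xccl, cpu -> gloo.
backend = c10d.get_default_backend_for_device(device_type)

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "6789")
c10d.init_process_group(backend, rank=0, world_size=1)
c10d.destroy_process_group()
```

This is what _dynamo_dist_per_rank_init and DynamoDistributedSingleProcTestCase now do when no backend is passed, so CUDA runs keep using NCCL while XPU runs pick up XCCL.
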
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158533
Approved by: https://github.com/guangyey, https://github.com/d4l3k

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
Authored by Deng, Daisy on 2025-08-13 08:13:20 +00:00, committed by PyTorch MergeBot
parent 9a06e6d031
commit 6e8865fbc1
5 changed files with 201 additions and 139 deletions

View File

@ -13,7 +13,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecis
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_distributed import (
requires_nccl,
requires_accelerator_dist_backend,
requires_nccl_version,
skip_but_pass_in_sandcastle_if,
skip_if_lt_x_gpu,
@ -30,17 +30,22 @@ if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
# bfloat16 is only supported by CUDA 11+
BFLOAT16_AVAILABLE = torch.cuda.is_available() and (
torch.version.cuda is not None or torch.version.hip is not None
device_type = (
acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)
# bfloat16 is only supported by CUDA 11+ or XPU
BFLOAT16_AVAILABLE = (
torch.cuda.is_available()
and (torch.version.cuda is not None or torch.version.hip is not None)
) or torch.xpu.is_available()
class Net(nn.Module):
def __init__(self, has_wrapping, sharding_strategy, mixed_precision=None):
# to ensure determinism
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.get_device_module(device_type).manual_seed(0)
super().__init__()
if has_wrapping:
@ -50,12 +55,12 @@ class Net(nn.Module):
nn.ReLU(),
FSDP(
nn.Linear(16, 8),
device_id=torch.cuda.current_device(),
device_id=torch.accelerator.current_device_index(),
sharding_strategy=sharding_strategy,
mixed_precision=mixed_precision,
),
),
device_id=torch.cuda.current_device(),
device_id=torch.accelerator.current_device_index(),
sharding_strategy=sharding_strategy,
mixed_precision=mixed_precision,
)
@ -134,11 +139,11 @@ class TestCommunicationHooks(FSDPTest):
"""
out_dim = self.world_size
net = torch.nn.Linear(1, out_dim, bias=False)
inpt = torch.tensor([self.rank]).float().cuda(self.rank)
inpt = torch.tensor([self.rank]).float().to(self.rank)
net_default_hook = FSDP(
net,
device_id=torch.cuda.current_device(),
device_id=torch.accelerator.current_device_index(),
sharding_strategy=sharding_strategy,
).to(self.rank)
@ -172,10 +177,10 @@ class TestCommunicationHooks(FSDPTest):
]
def _init_model(self, core, sharding_strategy, mixed_precision=None):
device = torch.device("cuda")
device = torch.device(device_type)
return FSDP(
core,
device_id=torch.cuda.current_device(),
device_id=torch.accelerator.current_device_index(),
sharding_strategy=sharding_strategy,
mixed_precision=mixed_precision,
).to(device)
@ -277,7 +282,7 @@ class TestCommunicationHooks(FSDPTest):
ShardingStrategy.HYBRID_SHARD,
ShardingStrategy._HYBRID_SHARD_ZERO2,
):
model = Net(False, None, None).cuda()
model = Net(False, None, None).to(device=device_type)
fsdp_model = FSDP(
model,
auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
@ -337,7 +342,7 @@ class TestCommunicationHooks(FSDPTest):
):
# keep everything deterministic for input data
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.get_device_module(device_type).manual_seed(0)
fsdp_with_hook = self._init_model(
Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy),
@ -359,7 +364,7 @@ class TestCommunicationHooks(FSDPTest):
optim_hook = torch.optim.SGD(fsdp_with_hook.parameters(), lr=0.1)
optim_mp = torch.optim.SGD(fsdp_with_mp.parameters(), lr=0.1)
in_data = torch.rand(16, 8).cuda()
in_data = torch.rand(16, 8).to(device=device_type)
fsdp_with_hook.train()
fsdp_with_mp.train()
loss_hook = fsdp_with_hook(in_data).sum()
@ -378,7 +383,7 @@ class TestCommunicationHooks(FSDPTest):
):
self.assertEqual(hook_param.grad, mp_param.grad)
@requires_nccl()
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_if_lt_x_gpu(2)
@parametrize("has_wrapping", [True, False])
@parametrize(
@ -399,11 +404,11 @@ class TestCommunicationHooks(FSDPTest):
state, hook, sharding_strategy, torch.float16, has_wrapping
)
@requires_nccl()
@requires_accelerator_dist_backend(["nccl", "xccl"])
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for BF16_COMPRESS")
@skip_but_pass_in_sandcastle_if(
not BFLOAT16_AVAILABLE,
"BFloat16 is only supported by CUDA 11+",
"BFloat16 is only supported by CUDA 11+ or XPU",
)
@skip_if_lt_x_gpu(2)
@parametrize("has_wrapping", [True, False])

View File

@ -60,6 +60,10 @@ if TEST_WITH_DEV_DBG_ASAN:
)
sys.exit(0)
device_type = (
acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)
class MyModel(nn.Module):
def __init__(self) -> None:
@ -93,9 +97,9 @@ class TestFSDPMiscMultiProcess(FSDPTest):
without specifying a device ID (i.e. ``torch.device("cuda")``) warns
"""
dev_id = (
torch.cuda.current_device()
torch.accelerator.current_device_index()
if use_index
else torch.device("cuda", torch.cuda.current_device())
else torch.device(device_type, torch.accelerator.current_device_index())
)
def _check_device_matches(module, device_id):
@ -108,7 +112,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
self.assertEqual(1, len(devices))
found_device = devices.pop()
if use_index and not isinstance(device_id, torch.device):
device = torch.device("cuda", device_id)
device = torch.device(device_type, device_id)
else:
device = device_id
self.assertEqual(found_device, device)
@ -140,10 +144,11 @@ class TestFSDPMiscMultiProcess(FSDPTest):
self.process_group,
FSDPInitMode.RECURSIVE,
DEVICEInitMode.DEVICE_BEFORE,
fsdp_kwargs={"device_id": torch.device("cuda")},
fsdp_kwargs={"device_id": torch.device(device_type)},
)
_check_device_matches(
nested_wrapped_module, torch.device("cuda", torch.cuda.current_device())
nested_wrapped_module,
torch.device(device_type, torch.accelerator.current_device_index()),
)
@skip_if_lt_x_gpu(2)
@ -178,8 +183,8 @@ class TestFSDPMiscMultiProcess(FSDPTest):
loss = torch.nn.functional.cross_entropy(output, y)
return loss
model = Mnist().cuda()
model1 = Mnist().cuda()
model = Mnist().to(device=device_type)
model1 = Mnist().to(device=device_type)
model1.load_state_dict(model.state_dict())
fsdp_model = FSDP(
model,
@ -197,17 +202,17 @@ class TestFSDPMiscMultiProcess(FSDPTest):
seed = self.rank + 20231010
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.get_device_module(device_type).manual_seed(seed)
losses = []
grads = []
for i in range(5):
x = torch.randn(8, 1, 28, 28, device="cuda").requires_grad_()
y = torch.randint(low=0, high=9, size=(8,), device="cuda")
x = torch.randn(8, 1, 28, 28, device=device_type).requires_grad_()
y = torch.randint(low=0, high=9, size=(8,), device=device_type)
for model, opt in ((fsdp_model, fsdp_opt), (ddp_model, ddp_opt)):
seed = self.rank + i
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.get_device_module(device_type).manual_seed(seed)
loss = model(x, y).sum()
losses.append(loss)
loss.backward()
@ -223,8 +228,8 @@ class TestFSDPMiscMultiProcess(FSDPTest):
fsdp_model.eval()
ddp_model.eval()
for _ in range(5):
x = torch.randn(8, 1, 28, 28, device="cuda").requires_grad_()
y = torch.randint(low=0, high=9, size=(8,), device="cuda")
x = torch.randn(8, 1, 28, 28, device=device_type).requires_grad_()
y = torch.randint(low=0, high=9, size=(8,), device=device_type)
fsdp_loss = fsdp_model(x, y)
ddp_loss = ddp_model(x, y)
assert torch.allclose(fsdp_loss, ddp_loss)
@ -232,12 +237,12 @@ class TestFSDPMiscMultiProcess(FSDPTest):
fsdp_model.train()
ddp_model.train()
for i in range(5):
x = torch.randn(8, 1, 28, 28, device="cuda").requires_grad_()
y = torch.randint(low=0, high=9, size=(8,), device="cuda")
x = torch.randn(8, 1, 28, 28, device=device_type).requires_grad_()
y = torch.randint(low=0, high=9, size=(8,), device=device_type)
for model, opt in ((fsdp_model, fsdp_opt), (ddp_model, ddp_opt)):
seed = self.rank + i
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.get_device_module(device_type).manual_seed(seed)
loss = model(x, y).sum()
losses.append(loss)
loss.backward()
@ -272,12 +277,12 @@ class TestFSDPMiscMultiProcess(FSDPTest):
return out1
fsdp = FSDP(
MyModel().cuda(),
MyModel().to(device=device_type),
sharding_strategy=sharding_strategy,
auto_wrap_policy=always_wrap_policy,
)
x = torch.randn(10, 10, device="cuda")
y = torch.randn(10, 10, device="cuda")
x = torch.randn(10, 10, device=device_type)
y = torch.randn(10, 10, device=device_type)
for _ in range(4):
if use_second_layer:
a, _ = fsdp(x, y)
@ -336,7 +341,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
torch.testing.assert_close(p1, p2)
fsdp_ctor = functools.partial(FSDP, sharding_strategy=sharding_strategy)
m = MyModule().cuda()
m = MyModule().to(device=device_type)
m_local = deepcopy(m)
local_m = m_local
prev_params = [p.clone() for p in m_local.parameters()]
@ -349,7 +354,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
opt_local = torch.optim.SGD(local_m.parameters(), lr=1e-3)
for i in range(6):
t = torch.ones(4, device="cuda")
t = torch.ones(4, device=device_type)
a, b = m(t)
local_a, local_b = local_m(t)
if i < 2:
@ -385,7 +390,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
@skip_if_lt_x_gpu(2)
def test_fsdp_optim_overlap_no_use_orig_params_error(self):
fsdp_overlap = FSDP(
MyModel().cuda(),
MyModel().to(device=device_type),
auto_wrap_policy=always_wrap_policy,
use_orig_params=False,
)
@ -398,7 +403,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
register_hook=False,
)
inp = torch.randn(10, 10, device="cuda")
inp = torch.randn(10, 10, device=device_type)
with self.assertRaisesRegex(
RuntimeError, "only supported with use_orig_params=True"
):
@ -409,16 +414,16 @@ class TestFSDPMiscMultiProcess(FSDPTest):
torch.manual_seed(0)
for cpu_offload in [True, False]:
offload = CPUOffload(offload_params=cpu_offload)
model = MyModel().cuda()
model = MyModel().to(device=device_type)
model_overlap = deepcopy(model)
fsdp = FSDP(
model.cuda(),
model.to(device=device_type),
auto_wrap_policy=always_wrap_policy,
use_orig_params=True,
cpu_offload=offload,
)
fsdp_overlap = FSDP(
model_overlap.cuda(),
model_overlap.to(device=device_type),
auto_wrap_policy=always_wrap_policy,
use_orig_params=True,
cpu_offload=offload,
@ -445,7 +450,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
]
for i in range(6):
inp = torch.randn(2, 2, device="cuda")
inp = torch.randn(2, 2, device=device_type)
with torch.no_grad():
inp_clone = inp.clone()
fsdp(inp, inp).sum().backward()
@ -546,7 +551,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
"""Tests that passing a CPU module to FSDP preserves that the wrapped
module is on CPU after FSDP initialization, albeit after logging a
warning, and that FSDP moves CPU input to GPU before the forward."""
torch.cuda.set_device(self.rank)
torch.accelerator.set_device_index(self.rank)
regex = "passed-in `module` is on CPU"
context = self.assertWarnsRegex(
expected_warning=UserWarning, expected_regex=regex
@ -561,7 +566,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
devices = {p.device for p in fsdp_model.parameters()}
self.assertEqual(1, len(devices))
self.assertEqual(torch.device("cpu"), devices.pop())
fsdp_model = fsdp_model.cuda()
fsdp_model = fsdp_model.to(device=device_type)
# Ensure fwd + backward can be performed after moving to CUDA.
# CPU input also tests that input is correctly moved to appropriate
# CUDA device.
@ -606,19 +611,19 @@ class TestFSDPMiscMultiProcess(FSDPTest):
nested_wrapped_module,
self.process_group,
auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
device_id=torch.cuda.current_device(),
device_id=torch.accelerator.current_device_index(),
sync_module_states=True,
)
# Each rank's buffers should be 0s since rank 0 is the source, and they
# should be on GPU since we specified `device_id`
self.assertEqual(
nested_wrapped_module.buf.device,
torch.device("cuda", torch.cuda.current_device()),
torch.device(device_type, torch.accelerator.current_device_index()),
)
self.assertEqual(nested_wrapped_module.buf, torch.zeros((2, 2)))
self.assertEqual(
nested_wrapped_module.module.module[0].buf.device,
torch.device("cuda", torch.cuda.current_device()),
torch.device(device_type, torch.accelerator.current_device_index()),
)
self.assertEqual(
nested_wrapped_module.module.module[0].buf, torch.zeros((3, 2))
@ -644,9 +649,9 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
def forward(self, x):
return x
m = MyModule().cuda()
m = MyModule().to(device=device_type)
m = FSDP(m)
t = torch.ones(1, device="cuda", requires_grad=True)
t = torch.ones(1, device=device_type, requires_grad=True)
MyOutputType = namedtuple(
"MyOutputType", ["a", "b", "c", "d"], defaults=(t, t, t, t)
@ -683,7 +688,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
auto_wrap_policy = ModuleWrapPolicy(module_classes)
fsdp_kwargs = {
"auto_wrap_policy": auto_wrap_policy,
"device_id": torch.cuda.current_device(),
"device_id": torch.accelerator.current_device_index(),
}
fsdp_model = TransformerWithSharedParams.init(
self.process_group,
@ -694,7 +699,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
for fsdp_module in FSDP.fsdp_modules(fsdp_model):
self.assertEqual(
fsdp_module.compute_device,
torch.device("cuda", torch.cuda.current_device()),
torch.device(device_type, torch.accelerator.current_device_index()),
)
@skip_if_lt_x_gpu(2)
@ -729,7 +734,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
model,
auto_wrap_policy=auto_wrap_policy,
cpu_offload=CPUOffload(offload_params=True),
device_id=torch.cuda.current_device(),
device_id=torch.accelerator.current_device_index(),
use_orig_params=use_orig_params,
)
cpu_device = torch.device("cpu")
@ -742,12 +747,16 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
module that does not match the GPU device ID raises an error."""
# TODO: override FSDP MT Thread _run to set this instead of here for
# every test.
torch.cuda.set_device(self.rank)
torch.accelerator.set_device_index(self.rank)
context = (
self.assertRaisesRegex(ValueError, f"cuda:{self.rank} vs cuda:0")
self.assertRaisesRegex(
ValueError, f"{device_type}:{self.rank} vs {device_type}:0"
)
if self.rank != 0
else nullcontext()
)
with context:
NestedWrappedModule.init(
self.process_group,
@ -764,18 +773,20 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
"""Tests a CPU + GPU module supported if device_id is passed
in, errors if device_id is not.
"""
torch.cuda.set_device(self.rank)
torch.accelerator.set_device_index(self.rank)
class CPUGPUModule(nn.Module):
def __init__(self) -> None:
super().__init__()
self.a = nn.Linear(1, 1).cuda()
self.a = nn.Linear(1, 1).to(device=device_type)
self.b = nn.Linear(1, 1)
cpu_gpu = CPUGPUModule()
fsdp = FSDP(cpu_gpu, device_id=torch.cuda.current_device())
fsdp = FSDP(cpu_gpu, device_id=torch.accelerator.current_device_index())
for param in fsdp.parameters():
self.assertEqual(param.device, torch.device(torch.cuda.current_device()))
self.assertEqual(
param.device, torch.device(torch.accelerator.current_device_index())
)
# without device_id, we hit an error
with self.assertRaisesRegex(RuntimeError, "please pass in device_id"):
@ -783,7 +794,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
@skip_if_lt_x_gpu(2)
def test_fsdp_ignored_module_meta(self):
torch.cuda.set_device(self.rank)
torch.accelerator.set_device_index(self.rank)
class CPUGPUModule(nn.Module):
def __init__(self) -> None:
@ -802,11 +813,11 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
m = CPUGPUModule()
m = FSDP(
m,
device_id=torch.cuda.current_device(),
device_id=torch.accelerator.current_device_index(),
ignored_modules=[m.a],
use_orig_params=True,
param_init_fn=lambda m: m.to_empty(
device=torch.cuda.current_device(), recurse=False
device=torch.accelerator.current_device_index(), recurse=False
),
)
self.assertEqual(meta_device, next(m.a.parameters()).device)
@ -854,20 +865,20 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
"""
# TODO: override FSDP MT Thread _run to set this instead of here for
# every test.
torch.cuda.set_device(self.rank)
torch.accelerator.set_device_index(self.rank)
# Test CPU
no_params = nn.ReLU()
FSDP(no_params)
# Test CUDA
no_params = nn.ReLU().cuda()
no_params = nn.ReLU().to(device=device_type)
FSDP(no_params)
# Test CPU + device_id
no_params = nn.ReLU()
FSDP(no_params, device_id=torch.cuda.current_device())
FSDP(no_params, device_id=torch.accelerator.current_device_index())
# For modules with no params, wrong device_id will raise error about
# inconsistency between compute_device and device_id, since compute_device
# is computed as torch.cuda.current_device when there are no params.
no_params = nn.ReLU().cuda()
no_params = nn.ReLU().to(device=device_type)
context = (
(
self.assertRaisesRegex(
@ -892,11 +903,11 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
super().__init__()
# Seed via rank to make model different across ranks
torch.manual_seed(rank)
torch.cuda.manual_seed(rank)
torch.get_device_module(device_type).manual_seed(rank)
self.lin = nn.Linear(10, 10, bias=False)
self.buffer = nn.Buffer(torch.ones(1) * rank)
m = MyModel(self.rank).cuda()
m = MyModel(self.rank).to(device=device_type)
_assert_module_states(
m, process_group=self.process_group, assert_fn=self.assertNotEqual
)
@ -913,7 +924,11 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
m, process_group=self.process_group, assert_fn=self.assertNotEqual
)
# Passing sync_module_states into FSDP makes model the same during init.
fsdp = FSDP(m, device_id=torch.cuda.current_device(), sync_module_states=True)
fsdp = FSDP(
m,
device_id=torch.accelerator.current_device_index(),
sync_module_states=True,
)
with fsdp.summon_full_params(fsdp):
_assert_module_states(
fsdp, process_group=self.process_group, assert_fn=self.assertEqual
@ -968,7 +983,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
with self.assertRaisesRegex(
ValueError, f"Expects one homogeneous value for {attr_name}"
):
inp = fsdp_model.module.get_input(torch.device("cuda"))
inp = fsdp_model.module.get_input(torch.device(device_type))
fsdp_model(*inp)
@skip_if_lt_x_gpu(2)
@ -976,7 +991,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
regex = r"FSDP will not all-gather parameters for containers that do not implement forward"
model = nn.ModuleList([MLP(8, torch.device("cpu")) for _ in range(3)])
with self.assertWarnsRegex(UserWarning, regex):
FSDP(model, device_id="cuda")
FSDP(model, device_id=device_type)
model = nn.ModuleDict(
{"1": MLP(8, torch.device("cpu")), "2": MLP(8, torch.device("cpu"))}
)
@ -1000,7 +1015,10 @@ class TestFSDPMiscWorldSize1(FSDPTestMultiThread):
# warning
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always") # trigger all warnings
FSDP(nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.NO_SHARD)
FSDP(
nn.Linear(3, 3).to(device=device_type),
sharding_strategy=ShardingStrategy.NO_SHARD,
)
for warning in w:
self.assertTrue(
warning.category != UserWarning
@ -1014,16 +1032,20 @@ class TestFSDPMiscWorldSize1(FSDPTestMultiThread):
warning_prefix + " " + str(ShardingStrategy.FULL_SHARD) + warning_suffix
)
with self.assertWarnsRegex(UserWarning, expected_regex_full_shard):
FSDP(nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.FULL_SHARD)
FSDP(
nn.Linear(3, 3).to(device=device_type),
sharding_strategy=ShardingStrategy.FULL_SHARD,
)
with self.assertWarnsRegex(UserWarning, expected_regex_full_shard):
FSDP(nn.Linear(3, 3).cuda())
FSDP(nn.Linear(3, 3).to(device=device_type))
# - Pass `SHARD_GRAD_OP`
expected_regex_shard_grad_op = (
warning_prefix + " " + str(ShardingStrategy.SHARD_GRAD_OP) + warning_suffix
)
with self.assertWarnsRegex(UserWarning, expected_regex_shard_grad_op):
FSDP(
nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.SHARD_GRAD_OP
nn.Linear(3, 3).to(device=device_type),
sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
)
@skip_if_lt_x_gpu(1)
@ -1047,7 +1069,7 @@ class TestFSDPMiscWorldSize1(FSDPTestMultiThread):
# Incorrectly moving from CPU -> GPU
model = torch.nn.Linear(10, 10)
fsdp_model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))
fsdp_model.to(torch.device("cuda"))
fsdp_model.to(torch.device(device_type))
inp = torch.randn((2, 10))
with self.assertRaisesRegex(
RuntimeError,
@ -1088,16 +1110,16 @@ class TestFSDPMiscWorldSize1(FSDPTestMultiThread):
# Construct FSDP module without changing any environment variables and
# run forward, which triggers both unsharded and sharded view setting
module = SetattrLinear(5, 5, torch.device("cuda"))
module = SetattrLinear(5, 5, torch.device(device_type))
fsdp_module = FSDP(module, use_orig_params=use_orig_params)
inp = torch.randn((8, 5), device=torch.device("cuda"))
inp = torch.randn((8, 5), device=torch.device(device_type))
called_setattr_override = False
fsdp_module(inp)
self.assertTrue(called_setattr_override)
# Repeat with unsafe setattr explicitly enabled
os.environ[_FSDP_USE_UNSAFE_SETATTR] = "1"
module = SetattrLinear(5, 5, torch.device("cuda"))
module = SetattrLinear(5, 5, torch.device(device_type))
fsdp_module = FSDP(module, use_orig_params=use_orig_params)
called_setattr_override = False
fsdp_module(inp)
@ -1105,7 +1127,7 @@ class TestFSDPMiscWorldSize1(FSDPTestMultiThread):
# Repeat with unsafe setattr explicitly disabled
os.environ[_FSDP_USE_UNSAFE_SETATTR] = "0"
module = SetattrLinear(5, 5, torch.device("cuda"))
module = SetattrLinear(5, 5, torch.device(device_type))
fsdp_module = FSDP(module, use_orig_params=use_orig_params)
called_setattr_override = False
fsdp_module(inp)

View File

@ -24,7 +24,7 @@ if not dist.is_available():
from torch.testing._internal.common_distributed import (
DistributedTestBase,
MultiThreadedTestCase,
requires_nccl,
requires_accelerator_dist_backend,
TEST_SKIPS,
)
from torch.testing._internal.common_utils import (
@ -34,6 +34,7 @@ from torch.testing._internal.common_utils import (
skipIfHpu,
TEST_CUDA,
TEST_HPU,
TEST_XPU,
TestCase,
)
@ -64,6 +65,9 @@ devices = ["cpu"]
if TEST_HPU:
devices.append("hpu")
DEVICE = "hpu"
elif TEST_XPU:
devices.append("xpu")
DEVICE = "xpu"
elif TEST_CUDA:
devices.append("cuda")
@ -269,10 +273,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):
@parametrize("device", devices)
def test_broadcast(self, device):
if device == "cuda":
if torch.cuda.device_count() < self.world_size:
self.skipTest("Not enough CUDA devices")
torch.cuda.set_device(dist.get_rank())
if device != "cpu":
if torch.accelerator.device_count() < self.world_size:
self.skipTest("Not enough accelerator devices")
torch.accelerator.set_device_index(dist.get_rank())
if dist.get_rank() == 0:
tensor = torch.ones([4], device=device)
@ -285,10 +289,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):
@parametrize("device", devices)
def test_all_reduce_eager(self, device):
if device == "cuda":
if torch.cuda.device_count() < self.world_size:
self.skipTest("Not enough CUDA devices")
torch.cuda.set_device(dist.get_rank())
if device != "cpu":
if torch.accelerator.device_count() < self.world_size:
self.skipTest("Not enough accelerator devices")
torch.accelerator.set_device_index(dist.get_rank())
tensor = torch.ones([4], device=device)
mesh = dt.DeviceMesh(device, torch.arange(4))
@ -302,10 +306,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):
@parametrize("device", devices)
def test_all_reduce_coalesced_eager(self, device):
if device == "cuda":
if torch.cuda.device_count() < self.world_size:
self.skipTest("Not enough CUDA devices")
torch.cuda.set_device(dist.get_rank())
if device != "cpu":
if torch.accelerator.device_count() < self.world_size:
self.skipTest("Not enough accelerator devices")
torch.accelerator.set_device_index(dist.get_rank())
t0 = torch.ones([4], device=device)
t1 = torch.ones([6], device=device) + 2
@ -317,10 +321,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):
@parametrize("device", devices)
def test_all_gather_tensor(self, device):
if device == "cuda":
if torch.cuda.device_count() < self.world_size:
self.skipTest("Not enough CUDA devices")
torch.cuda.set_device(dist.get_rank())
if device != "cpu":
if torch.accelerator.device_count() < self.world_size:
self.skipTest("Not enough accelerator devices")
torch.accelerator.set_device_index(dist.get_rank())
# testing 1d/2d mesh
mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size))
@ -339,10 +343,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):
@parametrize("device", devices)
def test_all_gather_into_tensor_coalesced(self, device):
if device == "cuda":
if torch.cuda.device_count() < self.world_size:
self.skipTest("Not enough CUDA devices")
torch.cuda.set_device(dist.get_rank())
if device != "cpu":
if torch.accelerator.device_count() < self.world_size:
self.skipTest("Not enough accelerator devices")
torch.accelerator.set_device_index(dist.get_rank())
tensors = [torch.ones([4], device=device), torch.ones([4], device=device) + 1]
mesh = dt.DeviceMesh(device, torch.arange(4))
@ -356,10 +360,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):
@parametrize("device", devices)
def test_reduce_scatter_tensor(self, device):
if device == "cuda":
if torch.cuda.device_count() < self.world_size:
self.skipTest("Not enough CUDA devices")
torch.cuda.set_device(dist.get_rank())
if device != "cpu":
if torch.accelerator.device_count() < self.world_size:
self.skipTest("Not enough accelerator devices")
torch.accelerator.set_device_index(dist.get_rank())
# testing 1d/2d mesh
mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size))
@ -380,10 +384,10 @@ class TestTraceableCollectives(MultiThreadedTestCase):
@parametrize("device", devices)
def test_reduce_scatter_into_tensor_coalesced(self, device):
if device == "cuda":
if torch.cuda.device_count() < self.world_size:
self.skipTest("Not enough CUDA devices")
torch.cuda.set_device(dist.get_rank())
if device != "cpu":
if torch.accelerator.device_count() < self.world_size:
self.skipTest("Not enough accelerator devices")
torch.accelerator.set_device_index(dist.get_rank())
tensors = [
torch.ones([4], dtype=torch.int64, device=device),
torch.ones([4], dtype=torch.int64, device=device) + 1,
@ -474,18 +478,17 @@ BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
# And then set the BACKEND variable appropriately.
if TEST_HPU:
BACKEND = dist.Backend.HCCL
elif TEST_XPU:
BACKEND = dist.Backend.XCCL
# allows you to check for multiple accelerator irrespective of device type
# to add new device types to this check simply follow the same format
# and append an elif with the conditional and appropriate device count function for your new device
def exit_if_lt_x_accelerators(x):
if TEST_CUDA:
if torch.cuda.device_count() < x:
sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
elif TEST_HPU:
if torch.hpu.device_count() < x:
sys.exit(TEST_SKIPS[f"multi-hpu-{x}"].exit_code)
if torch.accelerator.is_available():
if torch.accelerator.device_count() < x:
sys.exit(TEST_SKIPS[f"multi-accelerator-{x}"].exit_code)
def with_comms(func=None):
@ -494,7 +497,9 @@ def with_comms(func=None):
@wraps(func)
def wrapper(self, *args, **kwargs):
if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
if (
BACKEND == dist.Backend.NCCL or BACKEND == dist.Backend.XCCL
) and torch.accelerator.device_count() < self.world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
kwargs["device"] = DEVICE
@ -572,7 +577,7 @@ class TestCollectivesWithDistributedBackend(DistributedTestBase):
self.assertEqual(y, expected)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@requires_nccl()
@requires_accelerator_dist_backend(["nccl", "xccl"])
@with_comms()
def test_tracing(self, device):
def allreduce(t, pg):
@ -599,7 +604,7 @@ class TestCollectivesWithDistributedBackend(DistributedTestBase):
dist.destroy_process_group()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@requires_nccl()
@requires_accelerator_dist_backend(["nccl", "xccl"])
@with_comms()
def test_tracing_with_dce_code(self, device):
if self.world_size > 2:
@ -818,13 +823,19 @@ class TestFunctionalAutogradWithDistributedBackend(DistributedTestBase):
# Update the supported devices in DEVICE
instantiate_device_type_tests(
TestCollectivesWithDistributedBackend, globals(), only_for=DEVICE
TestCollectivesWithDistributedBackend, globals(), only_for=DEVICE, allow_xpu=True
)
instantiate_device_type_tests(
TestDistributedBackendCollectivesWithWorldSize4, globals(), only_for=DEVICE
TestDistributedBackendCollectivesWithWorldSize4,
globals(),
only_for=DEVICE,
allow_xpu=True,
)
instantiate_device_type_tests(
TestFunctionalAutogradWithDistributedBackend, globals(), only_for=DEVICE
TestFunctionalAutogradWithDistributedBackend,
globals(),
only_for=DEVICE,
allow_xpu=True,
)
if __name__ == "__main__":

View File

@ -96,10 +96,10 @@ TEST_SKIPS = {
class DistTestCases:
# Backends that do not support a specific collective
skip_collective = {}
skip_collective["allgather_coalesced"] = {"nccl", "mpi", "ucc"}
skip_collective["allgather_coalesced"] = {"nccl", "mpi", "ucc", "xccl"}
skip_collective["reduce"] = set()
skip_collective["sendrecv anysource"] = {"nccl", "ucc"}
skip_collective["cpu barrier"] = {"nccl", "ucc"}
skip_collective["sendrecv anysource"] = {"nccl", "ucc", "xccl"}
skip_collective["cpu barrier"] = {"nccl", "ucc", "xccl"}
# Sets showing that something is implemented
backend_feature = {}
@ -338,15 +338,26 @@ def requires_gloo():
def requires_nccl_version(version, msg):
if not c10d.is_nccl_available():
return skip_but_pass_in_sandcastle(
"c10d was not compiled with the NCCL backend",
)
if TEST_CUDA:
if not c10d.is_nccl_available():
return skip_but_pass_in_sandcastle(
"c10d was not compiled with the NCCL backend",
)
else:
return skip_but_pass_in_sandcastle_if(
torch.cuda.nccl.version() < version,
f"Requires NCCL version greater than or equal to: {version}, found: {torch.cuda.nccl.version()}, reason: {msg}",
)
else:
return skip_but_pass_in_sandcastle_if(
torch.cuda.nccl.version() < version,
f"Requires NCCL version greater than or equal to: {version}, found: {torch.cuda.nccl.version()}, reason: {msg}",
)
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
return decorator
def requires_nccl():
@ -435,9 +446,10 @@ def sm_is_or_higher_than(device: torch.device, major: int, minor: int) -> bool:
Returns True if the device's compute capability is (major, minor) or higher.
Error out if the device is not a CUDA device.
Returns False if device is a RoCM device.
Returns True if device is a non-CUDA device.
"""
if device.type != "cuda":
raise ValueError("sm_is_or_later() is only supported for CUDA devices")
return True
if torch.version.hip is not None:
# ROCm devices may have different compute capability codes
@ -1456,12 +1468,19 @@ class SaveForwardInputsModel(nn.Module):
@contextmanager
def _dynamo_dist_per_rank_init(
rank, world_size, backend="nccl", init_pg=True, fake_pg=False
rank, world_size, backend=None, init_pg=True, fake_pg=False
):
# To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase,
# Just manually implement the most important part of the dynamo behavior to reset/clear.
if not fake_pg:
torch.accelerator.set_device_index(rank)
device_type = (
acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
)
if backend is None:
backend = c10d.get_default_backend_for_device(device_type)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "6789"
if init_pg:
@ -1508,9 +1527,12 @@ class DynamoDistributedSingleProcTestCase(torch._dynamo.test_case.TestCase):
)
)
cls.rank = 0
cls.device = f"cuda:{cls.rank}"
cls.device_ids = None if "cuda" in cls.device else [cls.rank]
c10d.init_process_group("nccl", rank=cls.rank, world_size=1)
device = torch.accelerator.current_accelerator().type
cls.device = f"{device}:{cls.rank}"
cls.device_ids = None if device in cls.device else [cls.rank]
c10d.init_process_group(
c10d.get_default_backend_for_device(device), rank=cls.rank, world_size=1
)
@classmethod
def tearDownClass(cls):

View File

@ -6,6 +6,7 @@ import os
import re
import sys
import time
import unittest
import warnings
from abc import ABC, abstractmethod
from contextlib import nullcontext
@ -1122,6 +1123,7 @@ def check_sharded_parity(
cls.assertEqual(sharded_param.grad.to_local(), sharded_ref_grad.to_local())
@unittest.skipIf(TEST_XPU, "not-support-multithread")
class FSDPTestMultiThread(MultiThreadedTestCase):
@property
def world_size(self):
@ -1187,7 +1189,7 @@ class FSDPTest(MultiProcessTestCase):
fake_pg = kwargs.get("fake_pg", False)
print(f"dist init r={self.rank}, world={self.world_size}")
if torch.cuda.device_count() < self.world_size:
if torch.accelerator.device_count() < self.world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
# Specify gloo backend to make 'init_process_group()' succeed,