[2/N] Port 5 _composable distributed test to Intel GPU (#159241)

For https://github.com/pytorch/pytorch/issues/114850, we will port distributed tests to Intel GPU. This is the second PR for the _composable cases; the first was https://github.com/pytorch/pytorch/pull/159118.
We enable Intel GPU support with the following methods, keeping the original code style where possible (a sketch of the shared pattern follows the list):

- Use "torch.accelerator.current_accelerator()" to determine the accelerator backend
- Enable XPU for some test paths
- Skip some test cases that Intel GPU does not support
- Add "cpu:gloo,xpu:xccl" as a distributed backend option
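The device/backend selection pattern shared by the ported tests is sketched below. Only the selection logic mirrors the actual changes in this PR; the linear model and tensor shapes are placeholders added for illustration.

```python
# Sketch of the device/backend selection pattern used in the ported tests.
# The model and tensor below are placeholders; only the selection logic
# mirrors the actual changes.
import torch
import torch.distributed as dist
import torch.nn as nn

# Resolve the accelerator at runtime: "cuda" on NVIDIA, "xpu" on Intel GPU,
# falling back to "cpu" when no accelerator is available.
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"

# Ask PyTorch for the matching distributed backend ("nccl" for cuda, "xccl" for xpu).
backend = dist.get_default_backend_for_device(device_type)

# Tests then place models and tensors with device_type instead of a hard-coded "cuda".
model = nn.Linear(8, 8).to(device_type)
inp = torch.randn(4, 8, device=device_type)
```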

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159241
Approved by: https://github.com/guangyey, https://github.com/d4l3k
Zeng, Xiangdong 2025-09-15 06:24:55 +00:00 committed by PyTorch MergeBot
parent 06bb32d55e
commit 814ba34fa6
6 changed files with 120 additions and 75 deletions

View File

@@ -10,10 +10,13 @@ import torch
 import torch.nn as nn
 from torch.distributed._composable import checkpoint
 from torch.testing._internal.common_cuda import TEST_CUDA
-from torch.testing._internal.common_utils import run_tests, TestCase
+from torch.testing._internal.common_utils import run_tests, TEST_XPU, TestCase
 from torch.utils.checkpoint import CheckpointError
+device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
 class MemoryDelta(ContextDecorator):
     def __init__(self, device: torch.device):
         self.device: torch.device = device
@@ -22,16 +25,16 @@ class MemoryDelta(ContextDecorator):
     def __enter__(self):
         self.active_memory_enter = (
-            torch.cuda.memory_stats()["active_bytes.all.current"]
-            if self.device.type == "cuda"
+            torch.accelerator.memory_stats()["active_bytes.all.current"]
+            if self.device.type == "cuda" or self.device.type == "xpu"
             else 0
         )
         return self
     def __exit__(self, *exc):
         self.active_memory_exit = (
-            torch.cuda.memory_stats()["active_bytes.all.current"]
-            if self.device.type == "cuda"
+            torch.accelerator.memory_stats()["active_bytes.all.current"]
+            if self.device.type == "cuda" or self.device.type == "xpu"
             else 0
         )
@@ -126,7 +129,7 @@ class TestCheckpoint(TestCase):
         loss2 = net2(x2).sum()
         loss2.backward()
-        if x.is_cuda:
+        if x.is_cuda or x.is_xpu:
             self.assertTrue(mem2.delta() < mem1.delta())
         for p1, p2 in zip(net1.parameters(), net2.parameters()):
@@ -137,10 +140,10 @@ class TestCheckpoint(TestCase):
         net = ToyModel()
         self._test_tensor_only(net, x)
-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "no cuda/xpu")
     def test_tensor_only_gpu(self):
-        x = torch.randn(20, 100, device="cuda:0")
-        net = ToyModel().to("cuda:0")
+        x = torch.randn(20, 100, device=f"{device_type}:0")
+        net = ToyModel().to(f"{device_type}:0")
         self._test_tensor_only(net, x)
     def test_random_cpu(self):

View File

@@ -47,6 +47,8 @@ from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
     run_tests,
+    TEST_XPU,
+    xfailIf,
 )
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
@@ -58,6 +60,9 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
 from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
+device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
 class SimpleModel(nn.Module):
     def __init__(self):
         super().__init__()
@@ -73,7 +78,7 @@ class SimpleModel(nn.Module):
         return x
     def get_input(self):
-        return torch.rand(4, 5, device="cuda")
+        return torch.rand(4, 5, device=device_type)
 class SimpleModelUneven(nn.Module):
@@ -94,7 +99,7 @@ class SimpleModelUneven(nn.Module):
         return x
     def get_input(self):
-        return torch.rand(4, 5, device="cuda")
+        return torch.rand(4, 5, device=device_type)
 class TestFullyShard2DTraining(FSDPTest):
@@ -105,13 +110,15 @@ class TestFullyShard2DTraining(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(4, torch.cuda.device_count())
+        return min(4, torch.accelerator.device_count())
     def init_global_mesh(self) -> DeviceMesh:
         # Prefer to test with >=4 GPUs, but for 2 GPUs, use 2-way TP
         dp_size = 2 if self.world_size > 2 else 1
         return init_device_mesh(
-            "cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
+            device_type,
+            (dp_size, self.world_size // dp_size),
+            mesh_dim_names=("dp", "tp"),
         )
     @skip_if_lt_x_gpu(2)
@@ -138,7 +145,7 @@ class TestFullyShard2DTraining(FSDPTest):
         torch.manual_seed(42)
         model = MLPStack(mlp_dim)
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         replicate(ref_model, device_ids=[self.rank], process_group=dp_pg)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
         model.parallelize(
@@ -150,9 +157,8 @@ class TestFullyShard2DTraining(FSDPTest):
         optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)
         torch.manual_seed(42 + dp_pg.rank() + 1)
-        device = torch.device("cuda")
         for iter_idx in range(10):
-            inp = torch.randn((8, mlp_dim), device=device)
+            inp = torch.randn((8, mlp_dim), device=device_type)
             losses: list[torch.Tensor] = []
             for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                 _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
@@ -162,6 +168,7 @@ class TestFullyShard2DTraining(FSDPTest):
             self.assertEqual(losses[0], losses[1])
     @skip_if_lt_x_gpu(2)
+    @xfailIf(TEST_XPU)  # https://github.com/intel/torch-xpu-ops/issues/1881
     def test_train_parity_2d_transformer(self):
         self.run_subtests(
             {"use_shard_placement_fn": [False, True]},
@@ -172,12 +179,12 @@ class TestFullyShard2DTraining(FSDPTest):
         torch.manual_seed(42)
         model_args = ModelArgs(n_layers=3, dropout_p=0.0)
         model = Transformer(model_args)
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
         dp_size, tp_size = self.world_size // 2, 2
         global_mesh = init_device_mesh(
-            "cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp")
+            device_type, (dp_size, tp_size), mesh_dim_names=("dp", "tp")
         )
         model = Transformer.parallelize(model, global_mesh["tp"], use_seq_parallel=True)
@@ -205,7 +212,7 @@ class TestFullyShard2DTraining(FSDPTest):
         self.assertEqual(full_param, ref_param)
         torch.manual_seed(42 + global_mesh.get_local_rank("dp"))
-        inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
+        inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type)
         for iter_idx in range(5):
             ref_loss = ref_model(inp).sum()
             loss = model(inp).sum()
@@ -242,15 +249,16 @@ class TestFullyShard2DTraining(FSDPTest):
         self.assertEqual(full_param, ref_param)
     @skip_if_lt_x_gpu(2)
+    @xfailIf(TEST_XPU)  # https://github.com/pytorch/pytorch/issues/156782
     def test_tp_with_fsdp_offloading(self):
         global_mesh = init_device_mesh(
-            "cuda", (1, self.world_size), mesh_dim_names=("dp", "tp")
+            device_type, (1, self.world_size), mesh_dim_names=("dp", "tp")
         )
         dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
         torch.manual_seed(42)
         mlp_dim = 16
         model = MLPStack(mlp_dim)
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
         # Parallelize with N-way TP and 1-way FSDP
         model.parallelize(
@@ -268,7 +276,7 @@ class TestFullyShard2DTraining(FSDPTest):
         # NOTE: We still see the FSDP all-gather/reduce-scatter c10d ops
         # called, but they will just be no-ops without issuing any kernels.
         # We prefer to keep the no-op check at the c10d level, not in FSDP.
-        inp = torch.randn((4, mlp_dim), device="cuda")  # same on all ranks
+        inp = torch.randn((4, mlp_dim), device=device_type)  # same on all ranks
         for _ in range(10):
             ref_optim.zero_grad()
             optim.zero_grad()
@@ -297,6 +305,7 @@ class TestFullyShard2DTraining(FSDPTest):
             ref_optim.step()
     @skip_if_lt_x_gpu(2)
+    @xfailIf(TEST_XPU)  # https://github.com/intel/torch-xpu-ops/issues/1881
     @with_temp_dir
     def test_train_parity_2d_transformer_checkpoint_resume(self):
         """
@@ -352,7 +361,7 @@ class TestFullyShard2DTraining(FSDPTest):
         )
         torch.manual_seed(42 + global_mesh["dp"].get_local_rank() + 1)
-        inp = torch.randint(0, model_args.vocab_size, (3, 16), device="cuda")
+        inp = torch.randint(0, model_args.vocab_size, (3, 16), device=device_type)
         loss_no_cp1 = train_step(model_no_cp, optim_no_cp, inp)
         loss_no_cp2 = train_step(model_no_cp, optim_no_cp, inp)
@@ -410,14 +419,14 @@ class TestFullyShard2DStateDict(DTensorTestBase):
     @property
     def backend(self):
         # need to specify gloo backend for testing cpu offload
-        return "cpu:gloo,cuda:nccl"
+        return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl"
     @with_comms
     @skip_if_lt_x_gpu(4)
     def test_fully_shard_tp_2d_set_full_state_dict(self):
-        dummy_model = SimpleModel().cuda()
+        dummy_model = SimpleModel().to(device_type)
         mesh_2d = init_device_mesh(
-            "cuda",
+            device_type,
             (2, self.world_size // 2),
             mesh_dim_names=("dp", "tp"),
         )
@@ -561,7 +570,7 @@ class TestNew2dParallelTraining(DTensorTestBase):
             self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
         )
         model = FSDP(
-            SimpleModel().cuda(),
+            SimpleModel().to(device_type),
             device_mesh=mesh_2d["dp"],
         )
         fsdp_state = _get_module_fsdp_state(model)
@@ -573,7 +582,7 @@ class TestNew2dParallelTraining(DTensorTestBase):
         recompute_activation=False,
     ) -> None:
         torch.manual_seed(0)
-        model = SimpleModel().cuda(self.rank)
+        model = SimpleModel().to(f"{device_type}:{self.rank}")
         model = FSDP(model, use_orig_params=use_orig_params)
         optim = torch.optim.Adam(model.parameters(), lr=0.01)
@@ -587,7 +596,9 @@ class TestNew2dParallelTraining(DTensorTestBase):
             "net1": ColwiseParallel(),
             "net2": RowwiseParallel(),
         }
-        model_2d = parallelize_module(SimpleModel().cuda(), tp_mesh, parallelize_plan)
+        model_2d = parallelize_module(
+            SimpleModel().to(device_type), tp_mesh, parallelize_plan
+        )
         model_2d = FSDP(
             model_2d,
             device_mesh=dp_mesh,
@@ -615,7 +626,7 @@ class TestNew2dParallelTraining(DTensorTestBase):
             # Ensure all input across TP ranks are same.
             # TODO: add a get_group_rank() to DeviceMesh.
             torch.manual_seed(i + dist.get_rank(dp_mesh.get_group(mesh_dim=0)))
-            input = torch.rand(4, 5).cuda(self.rank)
+            input = torch.rand(4, 5).to(f"{device_type}:{self.rank}")
             output = model(input)
             output_2d = model_2d(input)
             self.assertEqual(output, output_2d)
@@ -652,7 +663,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
     @property
     def backend(self):
         # need to specify gloo backend for testing cpu offload
-        return "cpu:gloo,cuda:nccl"
+        return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl"
     @with_comms
     @skip_if_lt_x_gpu(4)
@@ -669,7 +680,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
             "net3": ColwiseParallel(),
         }
         model_2d = parallelize_module(
-            SimpleModel().cuda(),
+            SimpleModel().to(device_type),
             mesh_2d["tp"],
             parallelize_plan=parallelize_plan,
         )
@@ -679,8 +690,10 @@ class TestNew2dParallelStateDict(DTensorTestBase):
             isinstance(model_2d_fsdp_state._fsdp_extension, DTensorExtensions)
         )
-        mesh_1d = init_device_mesh("cuda", (self.world_size,))
-        model_1d = FSDP(SimpleModel().cuda(), device_mesh=mesh_1d, use_orig_params=True)
+        mesh_1d = init_device_mesh(device_type, (self.world_size,))
+        model_1d = FSDP(
+            SimpleModel().to(device_type), device_mesh=mesh_1d, use_orig_params=True
+        )
         model_1d_fsdp_state = _get_module_fsdp_state(model_1d)
         self.assertEqual(model_1d_fsdp_state._fsdp_extension, None)
@@ -692,7 +705,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
         # Create a model without wrapper
         torch.manual_seed(0)
-        no_wrap_model = simple_model().cuda(self.rank)
+        no_wrap_model = simple_model().to(f"{device_type}:{self.rank}")
         no_wrap_state_dict = no_wrap_model.state_dict()
         # Create a model and sharded it with 2D FSDP + TP
@@ -706,7 +719,9 @@ class TestNew2dParallelStateDict(DTensorTestBase):
             "net1": ColwiseParallel(),
             "net2": RowwiseParallel(),
         }
-        model_2d = parallelize_module(simple_model().cuda(), tp_mesh, parallelize_plan)
+        model_2d = parallelize_module(
+            simple_model().to(device_type), tp_mesh, parallelize_plan
+        )
         model_2d = FSDP(model_2d, device_mesh=dp_mesh, use_orig_params=True)
         FSDP.set_state_dict_type(
@@ -754,7 +769,9 @@ class TestNew2dParallelStateDict(DTensorTestBase):
             "net1": ColwiseParallel(),
             "net2": RowwiseParallel(),
         }
-        model_2d = parallelize_module(simple_model().cuda(), tp_mesh, parallelize_plan)
+        model_2d = parallelize_module(
+            simple_model().to(device_type), tp_mesh, parallelize_plan
+        )
         model_2d = FSDP(model_2d, device_mesh=dp_mesh, use_orig_params=True)
         optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.01)
@@ -768,7 +785,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
         ref_state_dict = deepcopy(model_2d.state_dict())
         # Update the parameters so model.state_dict() will be different from ref_dtensor_sd.
-        model_2d(model_2d.get_input().cuda(self.rank)).sum().backward()
+        model_2d(model_2d.get_input().to(f"{device_type}:{self.rank}")).sum().backward()
         optim_2d.step()
         # Load ref_state_dict back.
@@ -799,9 +816,11 @@ class TestNew2dParallelStateDict(DTensorTestBase):
         # Create a model without wrapper
         torch.manual_seed(0)
-        no_wrap_model = simple_model().cuda(self.rank)
+        no_wrap_model = simple_model().to(f"{device_type}:{self.rank}")
        no_wrap_optim = torch.optim.Adam(no_wrap_model.parameters(), lr=0.01)
-        no_wrap_model(no_wrap_model.get_input().cuda(self.rank)).sum().backward()
+        no_wrap_model(
+            no_wrap_model.get_input().to(f"{device_type}:{self.rank}")
+        ).sum().backward()
         no_wrap_optim.step()
         no_wrap_osd = get_optimizer_state_dict(no_wrap_model, optimizers=no_wrap_optim)
@@ -815,7 +834,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
             "net2": RowwiseParallel(),
         }
         model_2d = parallelize_module(
-            simple_model().cuda(), mesh_2d["tp"], parallelize_plan
+            simple_model().to(device_type), mesh_2d["tp"], parallelize_plan
         )
         model_2d = FSDP(model_2d, device_mesh=mesh_2d["dp"], use_orig_params=True)
         FSDP.set_state_dict_type(
@@ -823,7 +842,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
             StateDictType.SHARDED_STATE_DICT,
         )
         optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.01)
-        model_2d(model_2d.get_input().cuda(self.rank)).sum().backward()
+        model_2d(model_2d.get_input().to(f"{device_type}:{self.rank}")).sum().backward()
         optim_2d.step()
         optim_2d_osd = get_optimizer_state_dict(model_2d, optimizers=optim_2d)
         ref_optim_2d_osd = deepcopy(optim_2d_osd)
@@ -842,7 +861,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
                 # compare with no_wrap state.
                 if isinstance(dist_state, DTensor):
                     dist_state = (
-                        dist_state.cuda()
+                        dist_state.to(device_type)
                         .redistribute(placements=(Replicate(), Replicate()))
                         .to_local()
                     )
@@ -850,7 +869,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
                 self.assertTrue(torch.allclose(state, dist_state))
         # Update the parameters 2d optim states will be different from ref_optim_state_dict.
-        model_2d(model_2d.get_input().cuda(self.rank)).sum().backward()
+        model_2d(model_2d.get_input().to(f"{device_type}:{self.rank}")).sum().backward()
         optim_2d.step()
         set_optimizer_state_dict(
@@ -892,8 +911,8 @@ class TestNew2dParallelStateDict(DTensorTestBase):
        5) dcp.load the state dict from storage
        6) load the state dict into the 2D model
        """
-        dummy_model = SimpleModel().cuda()
-        mesh_1d = init_device_mesh("cuda", (self.world_size,))
+        dummy_model = SimpleModel().to(device_type)
+        mesh_1d = init_device_mesh(device_type, (self.world_size,))
         model = FSDP(dummy_model, device_mesh=mesh_1d)
         optim = torch.optim.Adam(model.parameters(), lr=0.01)
         model(model.get_input()).sum().backward()
@@ -911,9 +930,9 @@ class TestNew2dParallelStateDict(DTensorTestBase):
         dcp.save(state_dict, checkpoint_id=self.temp_dir)
         # initialize 2d model
-        dummy_model = SimpleModel().cuda()
+        dummy_model = SimpleModel().to(device_type)
         mesh_2d = init_device_mesh(
-            "cuda",
+            device_type,
             (2, self.world_size // 2),
             mesh_dim_names=("dp", "tp"),
         )

View File

@@ -30,7 +30,7 @@ from torch.distributed.tensor.parallel import (
 from torch.testing._internal.common_cuda import TEST_MULTIGPU
 from torch.testing._internal.common_distributed import (
     MultiProcessTestCase,
-    requires_nccl,
+    requires_accelerator_dist_backend,
     skip_if_lt_x_gpu,
 )
 from torch.testing._internal.common_utils import (
@@ -38,6 +38,7 @@ from torch.testing._internal.common_utils import (
     parametrize,
     run_tests,
     skip_but_pass_in_sandcastle_if,
+    TEST_XPU,
 )
 from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
@@ -46,6 +47,10 @@ if TYPE_CHECKING:
     from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
+device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+backend = torch.distributed.get_default_backend_for_device(device_type)
 # MLP Layer
 class MLPModule(torch.nn.Module):
     def __init__(self, d_hid: int):
@@ -79,7 +84,7 @@ class ComposabilityTest(MultiProcessTestCase):
     @classmethod
     def backend_str(cls) -> str:
         # Testing with NCCL backend
-        return "nccl"
+        return backend
     def setUp(self):
         super().setUp()
@@ -100,9 +105,11 @@ class ComposabilityTest(MultiProcessTestCase):
     def device(self):
         return self.rank
-    @requires_nccl()
+    @requires_accelerator_dist_backend(["nccl", "xccl"])
     @skip_if_lt_x_gpu(4)
-    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 4+ GPUs")
+    @skip_but_pass_in_sandcastle_if(
+        not TEST_MULTIGPU and not TEST_XPU, "Test requires 4+ GPUs"
+    )
     def test_pp_and_dcp(self):
         """
         Test that pipeline parallelism and distributed checkpointing can be used together and
@@ -143,11 +150,11 @@ class ComposabilityTest(MultiProcessTestCase):
                 x = layer(x)
             return x
-        device = torch.device("cuda", self.device)
-        torch.cuda.set_device(self.device)
+        device = torch.device(device_type, self.device)
+        torch.accelerator.set_device_index(self.device)
         store = torch.distributed.FileStore(self.file_name, self.world_size)
         torch.distributed.init_process_group(
-            backend="nccl",
+            backend=backend,
             store=store,
             rank=self.rank,
             world_size=self.world_size,
@@ -192,9 +199,11 @@ class ComposabilityTest(MultiProcessTestCase):
         _dcp_test(self)
-    @requires_nccl()
+    @requires_accelerator_dist_backend(["nccl", "xccl"])
     @skip_if_lt_x_gpu(8)
-    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 8+ GPUs")
+    @skip_but_pass_in_sandcastle_if(
+        not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
+    )
     @parametrize(
         "ScheduleClass",
         [
@@ -213,11 +222,11 @@ class ComposabilityTest(MultiProcessTestCase):
         ],
     )
     def test_3d_with_tp_dp_pp(self, ScheduleClass, MixedPrecisionParam):
-        _device_raii = torch.device("cuda", self.device)
-        torch.cuda.set_device(self.device)
+        _device_raii = torch.device(device_type, self.device)
+        torch.accelerator.set_device_index(self.device)
         store = torch.distributed.FileStore(self.file_name, self.world_size)
         torch.distributed.init_process_group(
-            backend="nccl",
+            backend=backend,
             store=store,
             rank=self.rank,
             world_size=self.world_size,
@@ -228,7 +237,7 @@ class ComposabilityTest(MultiProcessTestCase):
         num_microbatches = 8
         dp_size = self.world_size // (tp_size * pp_size)
         device_mesh = init_device_mesh(
-            "cuda",
+            device_type,
             mesh_shape=(dp_size, pp_size, tp_size),
             mesh_dim_names=("dp", "pp", "tp"),
         )

View File

@@ -1,6 +1,7 @@
 # Owner(s): ["oncall: distributed"]
 import os
+import unittest
 from copy import deepcopy
 import torch
@@ -14,7 +15,11 @@ from torch.testing._internal.common_distributed import (
     MultiProcessTestCase,
     skip_if_lt_x_gpu,
 )
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_utils import run_tests, TEST_XPU
+device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+device_module = torch.get_device_module(device_type)
 class Net(nn.Module):
@@ -154,6 +159,7 @@ class ReplicateTest(MultiProcessTestCase):
         self._compare_module(model, replicate_model)
     @skip_if_lt_x_gpu(2)
+    @unittest.skipIf(TEST_XPU, "XPU does not support gloo backend")
     def test_replicate_move_args_kwargs_to_device(self):
         class MyNet(nn.Module):
             def __init__(self) -> None:
@@ -166,24 +172,25 @@ class ReplicateTest(MultiProcessTestCase):
                 return self.a(inp)
         self._init_pg()
-        torch.cuda.set_device(self.rank)
-        model = MyNet().cuda()
-        replicate(model, device_id=torch.cuda.current_device())
+        torch.accelerator.set_device_index(self.rank)
+        model = MyNet().to(device_type)
+        replicate(model, device_id=torch.accelerator.current_device_index())
         # CPU input ensures replicate can move arg and kwargs to device.
         a, b = torch.randn(2, 2), torch.randn(2, 2)
         model(a, kwarg=b).sum().backward()
     @skip_if_lt_x_gpu(2)
+    @unittest.skipIf(TEST_XPU, "XPU does not support gloo backend")
     def test_replicate_ignore_module(self):
         self._init_pg()
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)
         # Seed ensures diff input and thus different local grads across ranks.
         torch.manual_seed(self.rank)
-        torch.cuda.manual_seed(self.rank)
-        model = Net().cuda()
+        device_module.manual_seed(self.rank)
+        model = Net().to(device_type)
         replicate(model, ignored_modules=[model.fc1])
         # CPU input ensures that replicate can move input to GPU as DDP does.
-        inp = torch.randn(5, 2, device="cuda") * (self.rank + 1)
+        inp = torch.randn(5, 2, device=device_type) * (self.rank + 1)
         out = model(inp) * 10
         out.sum().backward()
         # FC1 grads should not be synchronized, FC2 and 3 should be.
@@ -221,10 +228,11 @@ class ReplicateTest(MultiProcessTestCase):
         self._compare_module(model, replicate_model)
     @skip_if_lt_x_gpu(2)
+    @unittest.skipIf(TEST_XPU, "XPU does not support gloo backend")
     def test_replicate_device_id(self):
         self._init_pg()
         model = Net()
-        model_cuda = deepcopy(model).cuda()
+        model_cuda = deepcopy(model).to(device_type)
         model_cuda2 = deepcopy(model_cuda)
         replicate(model, device_id=torch.device("cpu"))
         # DDP instance is attached in first pre forward
@@ -233,13 +241,15 @@ class ReplicateTest(MultiProcessTestCase):
         # Should be None for CPU training
         self.assertEqual(None, replicate_ddp_weakref.device_ids)
-        replicate(model_cuda, device_id=torch.device(torch.cuda.current_device()))
+        replicate(
+            model_cuda, device_id=torch.device(torch.accelerator.current_device_index())
+        )
         # DDP instance is attached in first pre forward
         model_cuda(torch.randn(2, 2))
         replicate_ddp_weakref = replicate.state(model_cuda)._ddp_weakref()
         self.assertEqual([0], replicate_ddp_weakref.device_ids)
         # Pass in int as device_id
-        replicate(model_cuda2, device_id=int(torch.cuda.current_device()))
+        replicate(model_cuda2, device_id=int(torch.accelerator.current_device_index()))
         # DDP instance is attached in first pre forward
         model_cuda2(torch.randn(2, 2))
         replicate_ddp_weakref = replicate.state(model_cuda2)._ddp_weakref()
@@ -256,6 +266,7 @@ class ReplicateTest(MultiProcessTestCase):
 class ReplicateFullyShardInit(ReplicateTest):
     @skip_if_lt_x_gpu(2)
+    @unittest.skipIf(TEST_XPU, "XPU does not support gloo backend")
     def test_replicate_fully_shard_init(self):
         class ToyModel(nn.Module):
             def __init__(self, dim: int):
@@ -273,14 +284,14 @@ class ReplicateFullyShardInit(ReplicateTest):
                 return y
         self._init_pg()
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)
         dim = 3
         bz = 2
-        model = ToyModel(dim).cuda()
+        model = ToyModel(dim).to(device_type)
         for linear in model.linears:
             fully_shard(linear)
         fully_shard(model.linears)
-        replicate(model, device_id=torch.cuda.current_device())
+        replicate(model, device_id=torch.accelerator.current_device_index())
         for linear in model.linears:
             self.assertTrue(isinstance(linear.weight, DTensor))
         inp = torch.rand(bz, dim)

View File

@@ -98,6 +98,8 @@ class ReplicateTest(MultiProcessInductorTestCase):
         self.create_pg(device)
         torch._dynamo.config.optimize_ddp = "python_reducer"
         torch.manual_seed(123)
+        if device_type == "xpu":
+            torch.use_deterministic_algorithms(True, warn_only=True)
         model = Net(checkpoint=checkpoint).to(device)
         input = torch.randn([1, DIM], device=device)

View File

@@ -388,6 +388,7 @@ class DTensorTestBase(MultiProcessTestCase):
             "hccl",
             "xccl",
             "fake",
+            "cpu:gloo,xpu:xccl",
         ]:
             raise RuntimeError(f"Backend {backend} not supported!")