[2/N] Port 5 _composable distributed test to Intel GPU (#159241)
As part of https://github.com/pytorch/pytorch/issues/114850, we are porting distributed tests to Intel GPU. This is the second PR covering the _composable cases; the first was https://github.com/pytorch/pytorch/pull/159118. Intel GPU is enabled with the following changes, while keeping the original code style as much as possible:
- Use "torch.accelerator.current_accelerator()" to determine the accelerator backend
- Enable XPU in some test paths
- Skip test cases that Intel GPU does not support
- Add "cpu:gloo,xpu:xccl" as the distributed backend where needed
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159241
Approved by: https://github.com/guangyey, https://github.com/d4l3k
This commit is contained in:
parent 06bb32d55e
commit 814ba34fa6
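Before the diff itself, a minimal sketch (plain Python, not part of the diff) of the device-agnostic pattern the bullet points above describe. Every call and backend string below appears in the change itself; only the variable name offload_backend is illustrative.

import torch
import torch.distributed as dist

# Resolve the accelerator once at module scope, falling back to CPU.
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"

# Ask the distributed layer for the matching backend ("nccl" for cuda, "xccl" for xpu).
backend = dist.get_default_backend_for_device(device_type)

# Tests that exercise CPU offload pair gloo with the accelerator backend.
offload_backend = "cpu:gloo,xpu:xccl" if device_type == "xpu" else "cpu:gloo,cuda:nccl"

# Tensors and modules are then created against device_type instead of a hard-coded "cuda".
x = torch.randn(4, 5, device=device_type)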
@@ -10,10 +10,13 @@ import torch
 import torch.nn as nn
 from torch.distributed._composable import checkpoint
 from torch.testing._internal.common_cuda import TEST_CUDA
-from torch.testing._internal.common_utils import run_tests, TestCase
+from torch.testing._internal.common_utils import run_tests, TEST_XPU, TestCase
 from torch.utils.checkpoint import CheckpointError
 
 
+device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+
+
 class MemoryDelta(ContextDecorator):
     def __init__(self, device: torch.device):
         self.device: torch.device = device
@@ -22,16 +25,16 @@ class MemoryDelta(ContextDecorator):
 
     def __enter__(self):
         self.active_memory_enter = (
-            torch.cuda.memory_stats()["active_bytes.all.current"]
-            if self.device.type == "cuda"
+            torch.accelerator.memory_stats()["active_bytes.all.current"]
+            if self.device.type == "cuda" or self.device.type == "xpu"
             else 0
         )
         return self
 
     def __exit__(self, *exc):
         self.active_memory_exit = (
-            torch.cuda.memory_stats()["active_bytes.all.current"]
-            if self.device.type == "cuda"
+            torch.accelerator.memory_stats()["active_bytes.all.current"]
+            if self.device.type == "cuda" or self.device.type == "xpu"
             else 0
         )
 
@@ -126,7 +129,7 @@ class TestCheckpoint(TestCase):
         loss2 = net2(x2).sum()
         loss2.backward()
 
-        if x.is_cuda:
+        if x.is_cuda or x.is_xpu:
            self.assertTrue(mem2.delta() < mem1.delta())
 
        for p1, p2 in zip(net1.parameters(), net2.parameters()):
@@ -137,10 +140,10 @@ class TestCheckpoint(TestCase):
         net = ToyModel()
         self._test_tensor_only(net, x)
 
-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "no cuda/xpu")
     def test_tensor_only_gpu(self):
-        x = torch.randn(20, 100, device="cuda:0")
-        net = ToyModel().to("cuda:0")
+        x = torch.randn(20, 100, device=f"{device_type}:0")
+        net = ToyModel().to(f"{device_type}:0")
         self._test_tensor_only(net, x)
 
     def test_random_cpu(self):

@@ -47,6 +47,8 @@ from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
     run_tests,
+    TEST_XPU,
+    xfailIf,
 )
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
@@ -58,6 +60,9 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
 from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
 
 
+device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+
+
 class SimpleModel(nn.Module):
     def __init__(self):
         super().__init__()
@@ -73,7 +78,7 @@ class SimpleModel(nn.Module):
         return x
 
     def get_input(self):
-        return torch.rand(4, 5, device="cuda")
+        return torch.rand(4, 5, device=device_type)
 
 
 class SimpleModelUneven(nn.Module):
@@ -94,7 +99,7 @@ class SimpleModelUneven(nn.Module):
         return x
 
     def get_input(self):
-        return torch.rand(4, 5, device="cuda")
+        return torch.rand(4, 5, device=device_type)
 
 
 class TestFullyShard2DTraining(FSDPTest):
@@ -105,13 +110,15 @@ class TestFullyShard2DTraining(FSDPTest):
 
     @property
     def world_size(self) -> int:
-        return min(4, torch.cuda.device_count())
+        return min(4, torch.accelerator.device_count())
 
     def init_global_mesh(self) -> DeviceMesh:
         # Prefer to test with >=4 GPUs, but for 2 GPUs, use 2-way TP
         dp_size = 2 if self.world_size > 2 else 1
         return init_device_mesh(
-            "cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
+            device_type,
+            (dp_size, self.world_size // dp_size),
+            mesh_dim_names=("dp", "tp"),
         )
 
     @skip_if_lt_x_gpu(2)
@@ -138,7 +145,7 @@ class TestFullyShard2DTraining(FSDPTest):
 
         torch.manual_seed(42)
         model = MLPStack(mlp_dim)
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         replicate(ref_model, device_ids=[self.rank], process_group=dp_pg)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
         model.parallelize(
@@ -150,9 +157,8 @@ class TestFullyShard2DTraining(FSDPTest):
         optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)
 
         torch.manual_seed(42 + dp_pg.rank() + 1)
-        device = torch.device("cuda")
         for iter_idx in range(10):
-            inp = torch.randn((8, mlp_dim), device=device)
+            inp = torch.randn((8, mlp_dim), device=device_type)
             losses: list[torch.Tensor] = []
             for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                 _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
@@ -162,6 +168,7 @@ class TestFullyShard2DTraining(FSDPTest):
             self.assertEqual(losses[0], losses[1])
 
     @skip_if_lt_x_gpu(2)
+    @xfailIf(TEST_XPU)  # https://github.com/intel/torch-xpu-ops/issues/1881
     def test_train_parity_2d_transformer(self):
         self.run_subtests(
             {"use_shard_placement_fn": [False, True]},
@@ -172,12 +179,12 @@ class TestFullyShard2DTraining(FSDPTest):
         torch.manual_seed(42)
         model_args = ModelArgs(n_layers=3, dropout_p=0.0)
         model = Transformer(model_args)
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
 
         dp_size, tp_size = self.world_size // 2, 2
         global_mesh = init_device_mesh(
-            "cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp")
+            device_type, (dp_size, tp_size), mesh_dim_names=("dp", "tp")
         )
         model = Transformer.parallelize(model, global_mesh["tp"], use_seq_parallel=True)
 
@@ -205,7 +212,7 @@ class TestFullyShard2DTraining(FSDPTest):
             self.assertEqual(full_param, ref_param)
 
         torch.manual_seed(42 + global_mesh.get_local_rank("dp"))
-        inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
+        inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type)
         for iter_idx in range(5):
             ref_loss = ref_model(inp).sum()
             loss = model(inp).sum()
@@ -242,15 +249,16 @@ class TestFullyShard2DTraining(FSDPTest):
             self.assertEqual(full_param, ref_param)
 
     @skip_if_lt_x_gpu(2)
+    @xfailIf(TEST_XPU)  # https://github.com/pytorch/pytorch/issues/156782
     def test_tp_with_fsdp_offloading(self):
         global_mesh = init_device_mesh(
-            "cuda", (1, self.world_size), mesh_dim_names=("dp", "tp")
+            device_type, (1, self.world_size), mesh_dim_names=("dp", "tp")
         )
         dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
         torch.manual_seed(42)
         mlp_dim = 16
         model = MLPStack(mlp_dim)
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
         # Parallelize with N-way TP and 1-way FSDP
         model.parallelize(
@@ -268,7 +276,7 @@ class TestFullyShard2DTraining(FSDPTest):
         # NOTE: We still see the FSDP all-gather/reduce-scatter c10d ops
         # called, but they will just be no-ops without issuing any kernels.
         # We prefer to keep the no-op check at the c10d level, not in FSDP.
-        inp = torch.randn((4, mlp_dim), device="cuda")  # same on all ranks
+        inp = torch.randn((4, mlp_dim), device=device_type)  # same on all ranks
         for _ in range(10):
             ref_optim.zero_grad()
             optim.zero_grad()
@@ -297,6 +305,7 @@ class TestFullyShard2DTraining(FSDPTest):
             ref_optim.step()
 
     @skip_if_lt_x_gpu(2)
+    @xfailIf(TEST_XPU)  # https://github.com/intel/torch-xpu-ops/issues/1881
     @with_temp_dir
     def test_train_parity_2d_transformer_checkpoint_resume(self):
         """
@@ -352,7 +361,7 @@ class TestFullyShard2DTraining(FSDPTest):
         )
 
         torch.manual_seed(42 + global_mesh["dp"].get_local_rank() + 1)
-        inp = torch.randint(0, model_args.vocab_size, (3, 16), device="cuda")
+        inp = torch.randint(0, model_args.vocab_size, (3, 16), device=device_type)
         loss_no_cp1 = train_step(model_no_cp, optim_no_cp, inp)
         loss_no_cp2 = train_step(model_no_cp, optim_no_cp, inp)
 
@@ -410,14 +419,14 @@ class TestFullyShard2DStateDict(DTensorTestBase):
     @property
     def backend(self):
         # need to specify gloo backend for testing cpu offload
-        return "cpu:gloo,cuda:nccl"
+        return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl"
 
     @with_comms
     @skip_if_lt_x_gpu(4)
     def test_fully_shard_tp_2d_set_full_state_dict(self):
-        dummy_model = SimpleModel().cuda()
+        dummy_model = SimpleModel().to(device_type)
         mesh_2d = init_device_mesh(
-            "cuda",
+            device_type,
             (2, self.world_size // 2),
             mesh_dim_names=("dp", "tp"),
         )
@@ -561,7 +570,7 @@ class TestNew2dParallelTraining(DTensorTestBase):
             self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
         )
         model = FSDP(
-            SimpleModel().cuda(),
+            SimpleModel().to(device_type),
             device_mesh=mesh_2d["dp"],
         )
         fsdp_state = _get_module_fsdp_state(model)
@@ -573,7 +582,7 @@ class TestNew2dParallelTraining(DTensorTestBase):
         recompute_activation=False,
     ) -> None:
         torch.manual_seed(0)
-        model = SimpleModel().cuda(self.rank)
+        model = SimpleModel().to(f"{device_type}:{self.rank}")
         model = FSDP(model, use_orig_params=use_orig_params)
         optim = torch.optim.Adam(model.parameters(), lr=0.01)
 
@@ -587,7 +596,9 @@ class TestNew2dParallelTraining(DTensorTestBase):
             "net1": ColwiseParallel(),
             "net2": RowwiseParallel(),
         }
-        model_2d = parallelize_module(SimpleModel().cuda(), tp_mesh, parallelize_plan)
+        model_2d = parallelize_module(
+            SimpleModel().to(device_type), tp_mesh, parallelize_plan
+        )
         model_2d = FSDP(
             model_2d,
             device_mesh=dp_mesh,
@@ -615,7 +626,7 @@ class TestNew2dParallelTraining(DTensorTestBase):
             # Ensure all input across TP ranks are same.
             # TODO: add a get_group_rank() to DeviceMesh.
             torch.manual_seed(i + dist.get_rank(dp_mesh.get_group(mesh_dim=0)))
-            input = torch.rand(4, 5).cuda(self.rank)
+            input = torch.rand(4, 5).to(f"{device_type}:{self.rank}")
             output = model(input)
             output_2d = model_2d(input)
             self.assertEqual(output, output_2d)
@@ -652,7 +663,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
     @property
     def backend(self):
         # need to specify gloo backend for testing cpu offload
-        return "cpu:gloo,cuda:nccl"
+        return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl"
 
     @with_comms
     @skip_if_lt_x_gpu(4)
@@ -669,7 +680,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
             "net3": ColwiseParallel(),
         }
         model_2d = parallelize_module(
-            SimpleModel().cuda(),
+            SimpleModel().to(device_type),
             mesh_2d["tp"],
             parallelize_plan=parallelize_plan,
         )
@@ -679,8 +690,10 @@ class TestNew2dParallelStateDict(DTensorTestBase):
             isinstance(model_2d_fsdp_state._fsdp_extension, DTensorExtensions)
         )
 
-        mesh_1d = init_device_mesh("cuda", (self.world_size,))
-        model_1d = FSDP(SimpleModel().cuda(), device_mesh=mesh_1d, use_orig_params=True)
+        mesh_1d = init_device_mesh(device_type, (self.world_size,))
+        model_1d = FSDP(
+            SimpleModel().to(device_type), device_mesh=mesh_1d, use_orig_params=True
+        )
         model_1d_fsdp_state = _get_module_fsdp_state(model_1d)
         self.assertEqual(model_1d_fsdp_state._fsdp_extension, None)
 
@@ -692,7 +705,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
 
         # Create a model without wrapper
         torch.manual_seed(0)
-        no_wrap_model = simple_model().cuda(self.rank)
+        no_wrap_model = simple_model().to(f"{device_type}:{self.rank}")
         no_wrap_state_dict = no_wrap_model.state_dict()
 
         # Create a model and sharded it with 2D FSDP + TP
@@ -706,7 +719,9 @@ class TestNew2dParallelStateDict(DTensorTestBase):
             "net1": ColwiseParallel(),
             "net2": RowwiseParallel(),
         }
-        model_2d = parallelize_module(simple_model().cuda(), tp_mesh, parallelize_plan)
+        model_2d = parallelize_module(
+            simple_model().to(device_type), tp_mesh, parallelize_plan
+        )
         model_2d = FSDP(model_2d, device_mesh=dp_mesh, use_orig_params=True)
 
         FSDP.set_state_dict_type(
@@ -754,7 +769,9 @@ class TestNew2dParallelStateDict(DTensorTestBase):
             "net1": ColwiseParallel(),
             "net2": RowwiseParallel(),
         }
-        model_2d = parallelize_module(simple_model().cuda(), tp_mesh, parallelize_plan)
+        model_2d = parallelize_module(
+            simple_model().to(device_type), tp_mesh, parallelize_plan
+        )
         model_2d = FSDP(model_2d, device_mesh=dp_mesh, use_orig_params=True)
         optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.01)
 
@@ -768,7 +785,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
         ref_state_dict = deepcopy(model_2d.state_dict())
 
         # Update the parameters so model.state_dict() will be different from ref_dtensor_sd.
-        model_2d(model_2d.get_input().cuda(self.rank)).sum().backward()
+        model_2d(model_2d.get_input().to(f"{device_type}:{self.rank}")).sum().backward()
         optim_2d.step()
 
         # Load ref_state_dict back.
@@ -799,9 +816,11 @@ class TestNew2dParallelStateDict(DTensorTestBase):
 
         # Create a model without wrapper
        torch.manual_seed(0)
-        no_wrap_model = simple_model().cuda(self.rank)
+        no_wrap_model = simple_model().to(f"{device_type}:{self.rank}")
         no_wrap_optim = torch.optim.Adam(no_wrap_model.parameters(), lr=0.01)
-        no_wrap_model(no_wrap_model.get_input().cuda(self.rank)).sum().backward()
+        no_wrap_model(
+            no_wrap_model.get_input().to(f"{device_type}:{self.rank}")
+        ).sum().backward()
         no_wrap_optim.step()
         no_wrap_osd = get_optimizer_state_dict(no_wrap_model, optimizers=no_wrap_optim)
 
@@ -815,7 +834,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
         "net2": RowwiseParallel(),
         }
         model_2d = parallelize_module(
-            simple_model().cuda(), mesh_2d["tp"], parallelize_plan
+            simple_model().to(device_type), mesh_2d["tp"], parallelize_plan
         )
         model_2d = FSDP(model_2d, device_mesh=mesh_2d["dp"], use_orig_params=True)
         FSDP.set_state_dict_type(
@@ -823,7 +842,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
             StateDictType.SHARDED_STATE_DICT,
         )
         optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.01)
-        model_2d(model_2d.get_input().cuda(self.rank)).sum().backward()
+        model_2d(model_2d.get_input().to(f"{device_type}:{self.rank}")).sum().backward()
         optim_2d.step()
         optim_2d_osd = get_optimizer_state_dict(model_2d, optimizers=optim_2d)
         ref_optim_2d_osd = deepcopy(optim_2d_osd)
@@ -842,7 +861,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
             # compare with no_wrap state.
             if isinstance(dist_state, DTensor):
                 dist_state = (
-                    dist_state.cuda()
+                    dist_state.to(device_type)
                     .redistribute(placements=(Replicate(), Replicate()))
                     .to_local()
                 )
@@ -850,7 +869,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
             self.assertTrue(torch.allclose(state, dist_state))
 
         # Update the parameters 2d optim states will be different from ref_optim_state_dict.
-        model_2d(model_2d.get_input().cuda(self.rank)).sum().backward()
+        model_2d(model_2d.get_input().to(f"{device_type}:{self.rank}")).sum().backward()
         optim_2d.step()
 
         set_optimizer_state_dict(
@@ -892,8 +911,8 @@ class TestNew2dParallelStateDict(DTensorTestBase):
         5) dcp.load the state dict from storage
         6) load the state dict into the 2D model
         """
-        dummy_model = SimpleModel().cuda()
-        mesh_1d = init_device_mesh("cuda", (self.world_size,))
+        dummy_model = SimpleModel().to(device_type)
+        mesh_1d = init_device_mesh(device_type, (self.world_size,))
         model = FSDP(dummy_model, device_mesh=mesh_1d)
         optim = torch.optim.Adam(model.parameters(), lr=0.01)
         model(model.get_input()).sum().backward()
@@ -911,9 +930,9 @@ class TestNew2dParallelStateDict(DTensorTestBase):
         dcp.save(state_dict, checkpoint_id=self.temp_dir)
 
         # initialize 2d model
-        dummy_model = SimpleModel().cuda()
+        dummy_model = SimpleModel().to(device_type)
         mesh_2d = init_device_mesh(
-            "cuda",
+            device_type,
             (2, self.world_size // 2),
             mesh_dim_names=("dp", "tp"),
         )

@@ -30,7 +30,7 @@ from torch.distributed.tensor.parallel import (
 from torch.testing._internal.common_cuda import TEST_MULTIGPU
 from torch.testing._internal.common_distributed import (
     MultiProcessTestCase,
-    requires_nccl,
+    requires_accelerator_dist_backend,
     skip_if_lt_x_gpu,
 )
 from torch.testing._internal.common_utils import (
@@ -38,6 +38,7 @@ from torch.testing._internal.common_utils import (
     parametrize,
     run_tests,
     skip_but_pass_in_sandcastle_if,
+    TEST_XPU,
 )
 from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
 
@@ -46,6 +47,10 @@ if TYPE_CHECKING:
     from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
 
 
+device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+backend = torch.distributed.get_default_backend_for_device(device_type)
+
+
 # MLP Layer
 class MLPModule(torch.nn.Module):
     def __init__(self, d_hid: int):
@@ -79,7 +84,7 @@ class ComposabilityTest(MultiProcessTestCase):
     @classmethod
     def backend_str(cls) -> str:
         # Testing with NCCL backend
-        return "nccl"
+        return backend
 
     def setUp(self):
         super().setUp()
@@ -100,9 +105,11 @@ class ComposabilityTest(MultiProcessTestCase):
     def device(self):
         return self.rank
 
-    @requires_nccl()
+    @requires_accelerator_dist_backend(["nccl", "xccl"])
     @skip_if_lt_x_gpu(4)
-    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 4+ GPUs")
+    @skip_but_pass_in_sandcastle_if(
+        not TEST_MULTIGPU and not TEST_XPU, "Test requires 4+ GPUs"
+    )
     def test_pp_and_dcp(self):
         """
         Test that pipeline parallelism and distributed checkpointing can be used together and
@@ -143,11 +150,11 @@ class ComposabilityTest(MultiProcessTestCase):
                     x = layer(x)
                 return x
 
-        device = torch.device("cuda", self.device)
-        torch.cuda.set_device(self.device)
+        device = torch.device(device_type, self.device)
+        torch.accelerator.set_device_index(self.device)
         store = torch.distributed.FileStore(self.file_name, self.world_size)
         torch.distributed.init_process_group(
-            backend="nccl",
+            backend=backend,
             store=store,
             rank=self.rank,
             world_size=self.world_size,
@@ -192,9 +199,11 @@ class ComposabilityTest(MultiProcessTestCase):
 
         _dcp_test(self)
 
-    @requires_nccl()
+    @requires_accelerator_dist_backend(["nccl", "xccl"])
     @skip_if_lt_x_gpu(8)
-    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 8+ GPUs")
+    @skip_but_pass_in_sandcastle_if(
+        not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
+    )
     @parametrize(
         "ScheduleClass",
         [
@@ -213,11 +222,11 @@ class ComposabilityTest(MultiProcessTestCase):
         ],
     )
     def test_3d_with_tp_dp_pp(self, ScheduleClass, MixedPrecisionParam):
-        _device_raii = torch.device("cuda", self.device)
-        torch.cuda.set_device(self.device)
+        _device_raii = torch.device(device_type, self.device)
+        torch.accelerator.set_device_index(self.device)
         store = torch.distributed.FileStore(self.file_name, self.world_size)
         torch.distributed.init_process_group(
-            backend="nccl",
+            backend=backend,
             store=store,
             rank=self.rank,
             world_size=self.world_size,
@@ -228,7 +237,7 @@ class ComposabilityTest(MultiProcessTestCase):
         num_microbatches = 8
         dp_size = self.world_size // (tp_size * pp_size)
         device_mesh = init_device_mesh(
-            "cuda",
+            device_type,
             mesh_shape=(dp_size, pp_size, tp_size),
             mesh_dim_names=("dp", "pp", "tp"),
         )

@@ -1,6 +1,7 @@
 # Owner(s): ["oncall: distributed"]
 
 import os
+import unittest
 from copy import deepcopy
 
 import torch
@@ -14,7 +15,11 @@ from torch.testing._internal.common_distributed import (
     MultiProcessTestCase,
     skip_if_lt_x_gpu,
 )
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_utils import run_tests, TEST_XPU
 
 
+device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+device_module = torch.get_device_module(device_type)
+
+
 class Net(nn.Module):
@@ -154,6 +159,7 @@ class ReplicateTest(MultiProcessTestCase):
         self._compare_module(model, replicate_model)
 
     @skip_if_lt_x_gpu(2)
+    @unittest.skipIf(TEST_XPU, "XPU does not support gloo backend")
     def test_replicate_move_args_kwargs_to_device(self):
         class MyNet(nn.Module):
             def __init__(self) -> None:
@@ -166,24 +172,25 @@ class ReplicateTest(MultiProcessTestCase):
                 return self.a(inp)
 
         self._init_pg()
-        torch.cuda.set_device(self.rank)
-        model = MyNet().cuda()
-        replicate(model, device_id=torch.cuda.current_device())
+        torch.accelerator.set_device_index(self.rank)
+        model = MyNet().to(device_type)
+        replicate(model, device_id=torch.accelerator.current_device_index())
         # CPU input ensures replicate can move arg and kwargs to device.
         a, b = torch.randn(2, 2), torch.randn(2, 2)
         model(a, kwarg=b).sum().backward()
 
     @skip_if_lt_x_gpu(2)
+    @unittest.skipIf(TEST_XPU, "XPU does not support gloo backend")
     def test_replicate_ignore_module(self):
         self._init_pg()
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)
         # Seed ensures diff input and thus different local grads across ranks.
         torch.manual_seed(self.rank)
-        torch.cuda.manual_seed(self.rank)
-        model = Net().cuda()
+        device_module.manual_seed(self.rank)
+        model = Net().to(device_type)
         replicate(model, ignored_modules=[model.fc1])
         # CPU input ensures that replicate can move input to GPU as DDP does.
-        inp = torch.randn(5, 2, device="cuda") * (self.rank + 1)
+        inp = torch.randn(5, 2, device=device_type) * (self.rank + 1)
         out = model(inp) * 10
         out.sum().backward()
         # FC1 grads should not be synchronized, FC2 and 3 should be.
@@ -221,10 +228,11 @@ class ReplicateTest(MultiProcessTestCase):
         self._compare_module(model, replicate_model)
 
     @skip_if_lt_x_gpu(2)
+    @unittest.skipIf(TEST_XPU, "XPU does not support gloo backend")
     def test_replicate_device_id(self):
         self._init_pg()
         model = Net()
-        model_cuda = deepcopy(model).cuda()
+        model_cuda = deepcopy(model).to(device_type)
         model_cuda2 = deepcopy(model_cuda)
         replicate(model, device_id=torch.device("cpu"))
         # DDP instance is attached in first pre forward
@@ -233,13 +241,15 @@ class ReplicateTest(MultiProcessTestCase):
         # Should be None for CPU training
         self.assertEqual(None, replicate_ddp_weakref.device_ids)
 
-        replicate(model_cuda, device_id=torch.device(torch.cuda.current_device()))
+        replicate(
+            model_cuda, device_id=torch.device(torch.accelerator.current_device_index())
+        )
         # DDP instance is attached in first pre forward
         model_cuda(torch.randn(2, 2))
         replicate_ddp_weakref = replicate.state(model_cuda)._ddp_weakref()
         self.assertEqual([0], replicate_ddp_weakref.device_ids)
         # Pass in int as device_id
-        replicate(model_cuda2, device_id=int(torch.cuda.current_device()))
+        replicate(model_cuda2, device_id=int(torch.accelerator.current_device_index()))
         # DDP instance is attached in first pre forward
         model_cuda2(torch.randn(2, 2))
         replicate_ddp_weakref = replicate.state(model_cuda2)._ddp_weakref()
@@ -256,6 +266,7 @@ class ReplicateTest(MultiProcessTestCase):
 
 class ReplicateFullyShardInit(ReplicateTest):
     @skip_if_lt_x_gpu(2)
+    @unittest.skipIf(TEST_XPU, "XPU does not support gloo backend")
     def test_replicate_fully_shard_init(self):
         class ToyModel(nn.Module):
             def __init__(self, dim: int):
@@ -273,14 +284,14 @@ class ReplicateFullyShardInit(ReplicateTest):
                 return y
 
         self._init_pg()
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)
         dim = 3
         bz = 2
-        model = ToyModel(dim).cuda()
+        model = ToyModel(dim).to(device_type)
         for linear in model.linears:
             fully_shard(linear)
         fully_shard(model.linears)
-        replicate(model, device_id=torch.cuda.current_device())
+        replicate(model, device_id=torch.accelerator.current_device_index())
         for linear in model.linears:
             self.assertTrue(isinstance(linear.weight, DTensor))
         inp = torch.rand(bz, dim)

@@ -98,6 +98,8 @@ class ReplicateTest(MultiProcessInductorTestCase):
         self.create_pg(device)
         torch._dynamo.config.optimize_ddp = "python_reducer"
         torch.manual_seed(123)
+        if device_type == "xpu":
+            torch.use_deterministic_algorithms(True, warn_only=True)
         model = Net(checkpoint=checkpoint).to(device)
         input = torch.randn([1, DIM], device=device)
 

@@ -388,6 +388,7 @@ class DTensorTestBase(MultiProcessTestCase):
             "hccl",
             "xccl",
             "fake",
+            "cpu:gloo,xpu:xccl",
         ]:
             raise RuntimeError(f"Backend {backend} not supported!")
 