Revert "Tests Generelization for multiple accelerator devices (#139184)"

This reverts commit b576a8c318.

Reverted https://github.com/pytorch/pytorch/pull/139184 on behalf of https://github.com/clee2000 due to Failing internally when trying to pickle distributed test files D67098795 ([comment](https://github.com/pytorch/pytorch/pull/139184#issuecomment-2539610187))
PyTorch MergeBot 2024-12-12 17:48:30 +00:00
parent 2f0fe82f6d
commit c85323c5e8
24 changed files with 381 additions and 416 deletions
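
For context, a minimal sketch (not part of the diff) of the pattern this revert undoes: the reverted PR #139184 routed the FSDP tests through a device-agnostic helper instead of hard-coded CUDA calls, and this commit restores the CUDA-specific form. Helper names are taken from the hunks below (get_devtype from torch.testing._internal.common_fsdp, torch.get_device_module); treat this as an illustrative sketch, not part of the change itself.

# Illustrative sketch only, assuming the helpers shown in the hunks below.
import torch
from torch.testing._internal.common_fsdp import get_devtype

device_type = torch.device(get_devtype())  # resolves to e.g. "cuda" or "hpu"

# Device-generic form (removed by this revert):
model = torch.nn.Linear(3, 3).to(device_type.type)
torch.get_device_module(device_type.type).reset_peak_memory_stats()
x = torch.randn(2, 3, device=device_type.type)

# CUDA-specific form (restored by this revert):
model = torch.nn.Linear(3, 3).cuda()
torch.cuda.reset_peak_memory_stats()
x = torch.randn(2, 3, device="cuda")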

@ -1,6 +1,7 @@
# Owner(s): ["oncall: distributed"]
import contextlib
import unittest
from copy import deepcopy
from functools import partial
@ -15,7 +16,6 @@ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
OffloadWrapper,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils.checkpoint import checkpoint
@ -23,8 +23,6 @@ from torch.utils.checkpoint import checkpoint
_SAVED_PREFIX = "_saved_"
GRAD_FN_NEXT_FUNCTIONS = "next_functions"
device_type = torch.device(get_devtype())
class CheckpointWrapperTest(TestCase):
def test_load_activation_checkpointed_module(self):
@ -132,6 +130,7 @@ class CheckpointWrapperTest(TestCase):
m(torch.randn(2, 1)).sum().backward()
self.assertEqual(2, count)
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
def test_checkpoint_wrapper_parity(self):
"""
Tests that using checkpoint_wrapper or the functional
@ -156,11 +155,9 @@ class CheckpointWrapperTest(TestCase):
self.use_reentrant = use_reentrant
wrp = partial(
checkpoint_wrapper,
checkpoint_impl=(
CheckpointImpl.REENTRANT
if use_reentrant
else CheckpointImpl.NO_REENTRANT
),
checkpoint_impl=CheckpointImpl.REENTRANT
if use_reentrant
else CheckpointImpl.NO_REENTRANT,
)
for i in range(self.n):
l = nn.Sequential(
@ -187,12 +184,12 @@ class CheckpointWrapperTest(TestCase):
use_checkpointing,
use_wrapper=use_wrapper,
use_reentrant=use_reentrant,
).to(device_type.type)
x = torch.randn(10000, 256, requires_grad=True).to(device_type.type)
torch.get_device_module(device_type.type).reset_peak_memory_stats()
).cuda()
x = torch.randn(10000, 256, requires_grad=True).cuda()
torch.cuda.reset_peak_memory_stats()
loss = a(x).sum()
loss.backward()
return torch.get_device_module(device_type.type).max_memory_allocated()
return torch.cuda.max_memory_allocated()
functional_no_reentrant = test(
use_checkpointing=True, use_wrapper=False, use_reentrant=False
@ -336,12 +333,13 @@ class CheckpointWrapperTest(TestCase):
for fqn, _ in lin.named_parameters():
self.assertTrue(fqn in state_dict, msg=f"{fqn} not in state_dict.")
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
def test_checkpoint_wrapper_cpu_offload(self):
model = nn.Sequential(
nn.Linear(10, 10),
nn.Linear(10, 10),
nn.Linear(10, 10),
).to(device_type.type)
).cuda()
# Patch saved_tensor_hooks to make the unpack keep the tensor on CPU for
# testing, otherwise the tensor access during the DFS will cause orig
@ -360,7 +358,7 @@ class CheckpointWrapperTest(TestCase):
model = offload_wrapper(model)
inp = torch.randn(3, 10, device=device_type.type)
inp = torch.randn(3, 10, device="cuda")
loss = model(inp).sum()
# All autograd saved tensors should be offloaded to CPU.

@ -8,10 +8,10 @@ from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter, loa
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
from torch.distributed.fsdp.wrap import enable_wrap, wrap
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, SkipModel
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TEST_WITH_DEV_DBG_ASAN,
@ -85,7 +85,7 @@ class TestDistributedCheckpoint(FSDPTest):
# TODO: add resharding test case.
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestDistributedCheckpoint, globals(), only_for=devices)
instantiate_parametrized_tests(TestDistributedCheckpoint)
if __name__ == "__main__":
run_tests()

@ -6,13 +6,11 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
DEVICEInitMode,
FSDPInitMode,
FSDPTest,
get_devtype,
NestedWrappedModule,
TransformerWithSharedParams,
)
@ -30,8 +28,6 @@ if TEST_WITH_DEV_DBG_ASAN:
)
sys.exit(0)
device_type = torch.device(get_devtype())
class TestApply(FSDPTest):
@property
@ -71,12 +67,10 @@ class TestApply(FSDPTest):
def test_nested_module_apply(self):
"""Tests that ``apply()`` modifies parameter values in-place on a
non-FSDP-root nested FSDP-wrapped model."""
fsdp_kwargs = {"device_id": device_type.type}
nested_wrapped_module = NestedWrappedModule.init(
self.process_group,
FSDPInitMode.RECURSIVE,
DEVICEInitMode.DEVICE_AFTER,
fsdp_kwargs=fsdp_kwargs,
)
self._check_apply(nested_wrapped_module)
@ -84,12 +78,10 @@ class TestApply(FSDPTest):
def test_transformer_module_apply(self):
"""Tests that ``apply()`` modifies parameter values in-place on an
FSDP-wrapped transformer model with shared parameters."""
fsdp_kwargs = {"device_id": device_type.type}
transformer = TransformerWithSharedParams.init(
self.process_group,
FSDPInitMode.RECURSIVE,
DEVICEInitMode.DEVICE_AFTER,
fsdp_kwargs=fsdp_kwargs,
)
self._check_apply(transformer)
@ -97,19 +89,15 @@ class TestApply(FSDPTest):
def test_apply_in_summon_raises_error(self):
"""Tests that calling ``apply()`` on an FSDP instance inside the
``summon_full_params()`` context raises an error."""
fsdp_kwargs = {"device_id": device_type.type}
transformer = TransformerWithSharedParams.init(
self.process_group,
FSDPInitMode.RECURSIVE,
DEVICEInitMode.DEVICE_AFTER,
fsdp_kwargs=fsdp_kwargs,
)
with transformer.summon_full_params(transformer):
with self.assertRaisesRegex(ValueError, "expected to be in states"):
transformer.apply(self._init_linear_weights)
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestApply, globals(), only_for=devices)
if __name__ == "__main__":
run_tests()

@ -16,12 +16,10 @@ from torch.distributed.fsdp._runtime_utils import (
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
device_type = torch.device(get_devtype())
NUM_ITERS = 2
DECODER_PARAM_FQNS = [
"decoder.layers.{index}.self_attn.in_proj_weight",
@ -83,13 +81,14 @@ class TestBackwardPrefetch(FSDPTest):
def _dist_train(self, backward_prefetch=BackwardPrefetch.BACKWARD_PRE):
rank = self.rank
orig_get_handle_to_prefetch = _get_handle_to_prefetch
torch.manual_seed(0)
policy = ModuleWrapPolicy(
{nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
)
model = FSDP(
nn.Transformer(d_model=1024, nhead=8, device=device_type),
device_id=device_type.type,
nn.Transformer(d_model=1024, nhead=8, device="cuda"),
device_id=torch.cuda.current_device(),
auto_wrap_policy=policy,
use_orig_params=True,
backward_prefetch=backward_prefetch,
@ -98,8 +97,8 @@ class TestBackwardPrefetch(FSDPTest):
# prepare input
torch.manual_seed(rank + 1)
src = torch.randn((10, 1, 1024), device=device_type)
tgt = torch.randn((20, 1, 1024), device=device_type)
src = torch.randn((10, 1, 1024), device="cuda")
tgt = torch.randn((20, 1, 1024), device="cuda")
# monkey patch
all_handle_fqns: List[List[str]] = []

@ -1,4 +1,5 @@
# Owner(s): ["oncall: distributed"]
import contextlib
import sys
from copy import deepcopy
@ -16,9 +17,8 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import (
CPUOffload,
FullyShardedDataParallel as FSDP,
)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import _maybe_wrap_fsdp, FSDPTest, get_devtype
from torch.testing._internal.common_fsdp import _maybe_wrap_fsdp, FSDPTest
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@ -28,17 +28,18 @@ from torch.testing._internal.common_utils import (
from torch.utils.checkpoint import checkpoint
device_type = torch.device(get_devtype())
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
file=sys.stderr,
)
sys.exit(0)
_save_on_cpu_called = False
@ -82,18 +83,22 @@ class TestFSDPCheckpoint(FSDPTest):
**fsdp_kwargs,
):
torch.manual_seed(0)
torch.cuda.manual_seed(0)
super().__init__()
l1 = nn.Linear(3, 3).to(device_type.type)
l2 = nn.Linear(3, 3).to(device_type.type)
l3 = nn.Linear(3, 3).to(device_type.type)
l1 = nn.Linear(3, 3).cuda()
l2 = nn.Linear(3, 3).cuda()
l3 = nn.Linear(3, 3).cuda()
if checkpoint_layer:
if offload_activations:
ckpt_wrapper = offload_wrapper
else:
ckpt_wrapper = checkpoint_wrapper
l1 = ckpt_wrapper(l1)
l2 = ckpt_wrapper(l2)
l3 = ckpt_wrapper(l3)
fsdp_wrapper = partial(
_maybe_wrap_fsdp, *fsdp_args, wrap_fsdp=wrap_fsdp, **fsdp_kwargs
)
@ -110,9 +115,11 @@ class TestFSDPCheckpoint(FSDPTest):
assert losses
assert outputs
assert models
for l, o in zip(losses[1:], outputs[1:]):
self.assertEqual(losses[0], l)
self.assertEqual(outputs[0], o)
# Verify grads
ref_model = models[0]
ref_grads = [p.grad for p in ref_model.parameters()]
@ -139,6 +146,7 @@ class TestFSDPCheckpoint(FSDPTest):
wrapper_to_use = offload_wrapper
else:
wrapper_to_use = checkpoint_wrapper
fsdp_kwargs = {"cpu_offload": cpu_offload, "use_orig_params": use_orig_params}
ckpt_sequential_wrapped_fsdp = wrapper_to_use(
TestFSDPCheckpoint.SequentialModule(
@ -153,13 +161,16 @@ class TestFSDPCheckpoint(FSDPTest):
wrap_fsdp=True,
**fsdp_kwargs,
)
baseline = TestFSDPCheckpoint.SequentialModule(
wrap_fsdp=True,
**fsdp_kwargs,
)
# note that reentrant-based checkpointing requires inputs to have grad
# flag set.
inp = torch.randn(10, 3, device=device_type.type, requires_grad=True)
inp = torch.randn(10, 3, device=torch.cuda.current_device(), requires_grad=True)
global _save_on_cpu_called
models = [ckpt_sequential_wrapped_fsdp, inner_ckpt, baseline]
with patch_save_on_cpu(get_patched_save_on_cpu()):
@ -178,7 +189,9 @@ class TestFSDPCheckpoint(FSDPTest):
loss.backward()
losses.append(loss)
outputs.append(out)
self._verify_parity(losses, outputs, models)
dist.barrier()
@skip_if_lt_x_gpu(2)
@ -197,7 +210,7 @@ class TestFSDPCheckpoint(FSDPTest):
fsdp_kwargs = {"cpu_offload": cpu_offload, "use_orig_params": use_orig_params}
global _save_on_cpu_called
with patch_save_on_cpu(get_patched_save_on_cpu()):
seq = TestFSDPCheckpoint.SequentialModule().to(device_type.type)
seq = TestFSDPCheckpoint.SequentialModule().to(torch.cuda.current_device())
# Runs FSDP with no checkpointing
fsdp_only_seq = FSDP(deepcopy(seq), **fsdp_kwargs)
# Runs checkpoint-wrapped FSDP
@ -205,6 +218,7 @@ class TestFSDPCheckpoint(FSDPTest):
wrapper_to_use = offload_wrapper
else:
wrapper_to_use = checkpoint_wrapper
checkpointed_fsdp = wrapper_to_use(
FSDP(deepcopy(seq), **fsdp_kwargs),
)
@ -217,7 +231,11 @@ class TestFSDPCheckpoint(FSDPTest):
fsdp_call_checkpoint = FSDP(deepcopy(seq), **fsdp_kwargs)
# note that reentrant-based checkpointing requires inputs to have grad
# flag set.
inp = torch.randn(10, 3, device=device_type.type, requires_grad=True)
inp = torch.randn(
10, 3, device=torch.cuda.current_device(), requires_grad=True
)
models = [
fsdp_only_seq,
checkpointed_fsdp,
@ -247,6 +265,7 @@ class TestFSDPCheckpoint(FSDPTest):
# _save_on_cpu should not be called yet
self.assertFalse(_save_on_cpu_called)
out = m(inp)
if check_offload:
self.assertTrue(_save_on_cpu_called)
loss = out.sum()
@ -254,7 +273,9 @@ class TestFSDPCheckpoint(FSDPTest):
losses.append(loss)
outputs.append(out)
_save_on_cpu_called = False
self._verify_parity(losses, outputs, models)
dist.barrier()
@ -306,27 +327,35 @@ class TestFSDPCheckpointSubmodule(FSDPTest):
# TODO: grad value checks occasionally fails when use_reentrant = True
@skip_if_lt_x_gpu(2)
@parametrize("use_reentrant", [False])
def test_checkpoint_submodule(self, device, use_reentrant: bool):
model = TestModel(use_reentrant=use_reentrant).to(device_type.type)
def test_checkpoint_submodule(self, use_reentrant: bool):
model = TestModel(use_reentrant=use_reentrant).cuda()
model_ac = deepcopy(model)
for _, m in model_ac.named_modules():
if isinstance(m, CheckpointModule):
m.checkpoint = True
self.assertTrue(model_ac.checkpoint1.s1.checkpoint)
self.assertTrue(model_ac.checkpoint2.s2.checkpoint)
fsdp_kwargs = {
"device_id": device_type.type,
"device_id": torch.cuda.current_device(),
"sharding_strategy": ShardingStrategy.NO_SHARD,
}
# Wrap no checkpointing model submodules with FSDP
model.checkpoint1 = FSDP(module=model.checkpoint1, **fsdp_kwargs)
model.checkpoint2 = FSDP(module=model.checkpoint2, **fsdp_kwargs)
# Wrap checkpointing model submodules with FSDP
model_ac.checkpoint1 = FSDP(module=model_ac.checkpoint1, **fsdp_kwargs)
model_ac.checkpoint2 = FSDP(module=model_ac.checkpoint2, **fsdp_kwargs)
x = torch.randn(2, 100, device=self.device_type)
x = torch.randn(2, 100, device="cuda")
model(x).sum().backward()
model_ac(x).sum().backward()
for (n1, p1), (n2, p2) in zip(
model.named_parameters(), model_ac.named_parameters()
):
@ -334,7 +363,8 @@ class TestFSDPCheckpointSubmodule(FSDPTest):
self.assertTrue(p1.grad.allclose(p2.grad))
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestFSDPCheckpointSubmodule, globals(), only_for=devices)
instantiate_parametrized_tests(TestFSDPCheckpointSubmodule)
if __name__ == "__main__":
run_tests()

@ -1,4 +1,5 @@
# Owner(s): ["oncall: distributed"]
import itertools
import sys
from typing import Union
@ -15,24 +16,25 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import (
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
DEVICEInitMode,
FSDPInitMode,
FSDPTest,
get_devtype,
NestedWrappedModule,
TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
run_tests,
TEST_WITH_DEV_DBG_ASAN,
)
device_type = torch.device(get_devtype())
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
@ -45,7 +47,7 @@ class TestClipGradNorm(FSDPTest):
"""Tests :meth:`FullyShardedDataParallel.clip_grad_norm_`."""
@skip_if_lt_x_gpu(2)
def test_non_root(self, device):
def test_non_root(self):
"""
Tests that calling ``clip_grad_norm_()`` on a non-root FSDP instance
raises an error.
@ -60,24 +62,22 @@ class TestClipGradNorm(FSDPTest):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.lin2(self.lin1(x))
model = Model().to(device_type.type)
model = Model().cuda()
model.lin2 = FSDP(model.lin2)
fsdp_model = FSDP(model)
# fsdp_model(torch.randn((2, 5), device=torch.device(self.device_type))).sum().backward()
fsdp_model(torch.randn((2, 5), device=device_type)).sum().backward()
fsdp_model(torch.randn((2, 5), device=torch.device("cuda"))).sum().backward()
error_regex = "should only be called on the root FSDP instance"
with self.assertRaisesRegex(RuntimeError, error_regex):
fsdp_model.lin2.clip_grad_norm_(max_norm=2)
@skip_if_lt_x_gpu(2)
def test_ddp_parity(self, device):
def test_ddp_parity(self):
"""
Tests FSDP with ``FullyShardedDataParallel.clip_grad_norm_()` against
DDP with ``torch.nn.utils.clip_grad_norm_()` when using full precision.
"""
self.run_subtests(
{
"device": [device],
"max_norm": [1, 2.5],
"norm_type": [1, 2, float("inf")],
"sharding_strategy": [
@ -93,7 +93,6 @@ class TestClipGradNorm(FSDPTest):
def _test_ddp_parity(
self,
device,
max_norm: Union[float, int],
norm_type: Union[float, int],
sharding_strategy: Union[ShardingStrategy, str],
@ -106,11 +105,10 @@ class TestClipGradNorm(FSDPTest):
DEVICEInitMode.DEVICE_BEFORE,
deterministic=True,
)
ddp_model = DDP(local_model, device_ids=[device_type])
ddp_model = DDP(local_model, device_ids=[self.rank])
fsdp_kwargs = {
"cpu_offload": CPUOffload(offload_params=offload_params),
"use_orig_params": use_orig_params,
"device_id": device_type.type,
}
if sharding_strategy == "mixed_strategy":
fsdp_model = TransformerWithSharedParams.init(
@ -158,7 +156,7 @@ class TestClipGradNorm(FSDPTest):
LR = 1e-2
ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR)
fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=LR)
device = torch.device(self.device_type)
device = torch.device("cuda")
LARGE_FACTOR = 100
inp = ddp_model.module.get_input(device)
for model in (ddp_model, fsdp_model):
@ -168,6 +166,7 @@ class TestClipGradNorm(FSDPTest):
else:
loss = model.get_loss(inp, out)
loss.backward()
# Multiply gradients by a large factor to ensure that gradients will
# actually be clipped
for param in itertools.chain(ddp_model.parameters(), fsdp_model.parameters()):
@ -182,6 +181,7 @@ class TestClipGradNorm(FSDPTest):
param.grad.detach().clone() if param.grad is not None else None
for param in fsdp_model.parameters()
]
ddp_total_norm = torch.nn.utils.clip_grad_norm_(
ddp_model.parameters(),
max_norm=max_norm,
@ -191,6 +191,7 @@ class TestClipGradNorm(FSDPTest):
max_norm=max_norm, norm_type=norm_type
)
self.assertEqual(ddp_total_norm, fsdp_total_norm)
# Check that the gradients were modified by `clip_grad_norm_()`
for param, orig_grad in zip(ddp_model.parameters(), orig_ddp_grads):
assert not torch.equal(param.grad, orig_grad)
@ -199,6 +200,7 @@ class TestClipGradNorm(FSDPTest):
self.assertEqual(param.grad, orig_grad) # `None`
else:
assert not torch.equal(param.grad, orig_grad)
# Run an optimizer step to ensure gradients matched after clipping
ddp_optim.step()
fsdp_optim.step()
@ -209,11 +211,13 @@ class TestClipGradNorm(FSDPTest):
):
self.assertEqual(n1, n2)
self.assertEqual(p1, p2)
if offload_params:
# TODO: Gradient computation on CPU and GPU differ slightly causing
# drift unrelated to `clip_grad_norm_()`.
# https://github.com/pytorch/pytorch/issues/89133
return
# Run a few more iterations
# TODO: We cannot run too many iterations, or else there is drift:
# https://github.com/pytorch/pytorch/issues/89136
@ -238,11 +242,10 @@ class TestClipGradNorm(FSDPTest):
fsdp_optim.step()
@skip_if_lt_x_gpu(2)
def test_low_precision_grads(self, device):
def test_low_precision_grads(self):
"""Tests ``clip_grad_norm_()`` when using low precision gradients."""
self.run_subtests(
{
"device": [device],
"max_norm": [1, 2.5],
"norm_type": [1, 2, float("inf")],
"sharding_strategy": [
@ -256,7 +259,6 @@ class TestClipGradNorm(FSDPTest):
def _test_low_precision_grads(
self,
device,
max_norm: Union[float, int],
norm_type: Union[float, int],
sharding_strategy: ShardingStrategy,
@ -270,7 +272,6 @@ class TestClipGradNorm(FSDPTest):
reduce_dtype=torch.float16,
keep_low_precision_grads=True,
),
"device_id": device_type.type,
}
fsdp_model = FSDP(
NestedWrappedModule.init(
@ -282,7 +283,7 @@ class TestClipGradNorm(FSDPTest):
),
**fsdp_kwargs,
)
inp = fsdp_model.module.get_input(torch.device(self.device_type))
inp = fsdp_model.module.get_input(torch.device("cuda"))
out = fsdp_model(*inp)
out.sum().backward()
for param in fsdp_model.parameters():
@ -301,17 +302,17 @@ class TestClipGradNorm(FSDPTest):
)
@skip_if_lt_x_gpu(2)
def test_no_gradients(self, device):
def test_no_gradients(self):
"""
Tests that calling ``clip_grad_norm_()`` when the FSDP module has no
gradients simply returns a scalar zero tensor in FP32 without erroring.
"""
self.run_subtests(
{"device": [device], "use_orig_params": [False, True]},
{"use_orig_params": [False, True]},
self._test_no_gradients,
)
def _test_no_gradients(self, device, use_orig_params: bool):
def _test_no_gradients(self, use_orig_params: bool):
lin_module = nn.Linear(24, 24)
mixed_precision_config = MixedPrecision(
param_dtype=torch.float16,
@ -322,10 +323,10 @@ class TestClipGradNorm(FSDPTest):
lin_module,
sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
mixed_precision=mixed_precision_config,
device_id=device_type.type,
device_id=self.rank,
use_orig_params=use_orig_params,
)
inp = torch.randn(32, 24, device=self.device_type)
inp = torch.randn(32, 24, device="cuda")
fsdp_module(inp)
with self.assertWarnsRegex(
expected_warning=UserWarning,
@ -335,10 +336,10 @@ class TestClipGradNorm(FSDPTest):
):
total_norm = fsdp_module.clip_grad_norm_(1)
self.assertEqual(total_norm.dtype, torch.float32)
self.assertEqual(total_norm, torch.tensor(0.0, device=self.device_type))
self.assertEqual(total_norm, torch.tensor(0.0, device="cuda"))
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestClipGradNorm, globals(), only_for=devices)
instantiate_parametrized_tests(TestClipGradNorm)
if __name__ == "__main__":
run_tests()

@ -1,4 +1,5 @@
# Owner(s): ["oncall: distributed"]
import sys
from contextlib import nullcontext
from enum import auto, Enum
@ -9,34 +10,31 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributed as dist
from torch._utils import _get_device_module
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
DEVICEInitMode,
FSDPInitMode,
FSDPTest,
get_devtype,
MLP,
NestedWrappedModule,
TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TEST_WITH_DEV_DBG_ASAN,
)
device_type = torch.device(get_devtype())
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
@ -56,14 +54,11 @@ class TestCommunication(FSDPTest):
def _init_model(
self,
device,
nested_model: bool,
sharding_strategy: ShardingStrategy,
device: torch.device,
):
fsdp_kwargs = {
"sharding_strategy": sharding_strategy,
"device_id": device_type.type,
}
fsdp_kwargs = {"sharding_strategy": sharding_strategy}
if nested_model:
model = NestedWrappedModule.init(
self.process_group,
@ -75,7 +70,7 @@ class TestCommunication(FSDPTest):
model,
self.process_group,
**fsdp_kwargs,
)
).to(device)
else:
fsdp_model: FSDP = TransformerWithSharedParams.init(
self.process_group,
@ -208,7 +203,6 @@ class TestCommunication(FSDPTest):
@parametrize("sharding_strategy", [ShardingStrategy.SHARD_GRAD_OP, None])
def test_communication(
self,
device,
nested_model: bool,
use_no_sync: bool,
sharding_strategy: Optional[ShardingStrategy],
@ -216,6 +210,7 @@ class TestCommunication(FSDPTest):
"""
Tests FSDP's communication cost in terms of calls to collective
communication primitives (i.e. all-gather and reduce-scatter).
Arguments:
nested_model (bool): If ``True``, uses ``NestedWrappedModule``,
which has nested FSDP instances; if ``False``, uses the default
@ -230,13 +225,16 @@ class TestCommunication(FSDPTest):
# Enable execution order checking
dist.set_debug_level(dist.DebugLevel.DETAIL)
# Initialize the model and inputs
fsdp_model = self._init_model(device_type, nested_model, sharding_strategy)
batch = fsdp_model.module.get_input(device_type)
device = torch.device("cuda")
fsdp_model = self._init_model(nested_model, sharding_strategy, device)
batch = fsdp_model.module.get_input(device)
# Count the number of FSDP instances that manage parameters since the
# number of collectives are a function of this number
num_fsdp = sum(
(isinstance(m, FSDP) and len(m.params) > 0) for m in fsdp_model.modules()
)
# If `use_no_sync=True`, we run `num_iters` iterations inside
# `no_sync()` followed by `num_iters` iterations outside `no_sync()`,
# and if `use_no_sync=False`, we only run `num_iters` iterations
@ -294,11 +292,11 @@ class TestCommunication(FSDPTest):
class TestExplicitUnshard(FSDPTest):
@property
def world_size(self) -> int:
return min(_get_device_module(self.device_type).device_count(), 2)
return min(torch.cuda.device_count(), 2)
@skip_if_lt_x_gpu(2)
@parametrize("use_orig_params", [False, True])
def test_unshard_async(self, device, use_orig_params: bool):
def test_unshard_async(self, use_orig_params: bool):
class ReduceModule(nn.Module):
def __init__(self, dim: int, group: dist.ProcessGroup):
super().__init__()
@ -352,25 +350,28 @@ class TestExplicitUnshard(FSDPTest):
group = self.process_group
batch_size, dim = 2, 8
torch.manual_seed(42)
ref_model = DDP(ReduceModel(dim, group).to(device_type), device_ids=[self.rank])
ref_model = DDP(ReduceModel(dim, group).cuda(), device_ids=[self.rank])
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
torch.manual_seed(42)
model = ReduceModel(dim, group)
model.mlps = FSDP(
model.mlps,
sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
auto_wrap_policy=ModuleWrapPolicy((MLP,)),
device_id=device_type.type,
device_id=self.rank,
use_orig_params=use_orig_params,
)
model.mlps.check_is_root()
mlp_params = set(model.mlps.parameters())
mlp_param_names = {n for n, p in model.named_parameters() if p in mlp_params}
DDP._set_params_and_buffers_to_ignore_for_model(model, mlp_param_names)
model = DDP(model.to(device_type), device_ids=[self.rank])
model = DDP(model.cuda(), device_ids=[self.rank])
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((batch_size, dim), device=device_type)
inp = torch.randn((batch_size, dim), device="cuda")
for _ in range(10):
losses: List[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
@ -382,8 +383,8 @@ class TestExplicitUnshard(FSDPTest):
model.module.mlps._wait_unshard_streams_on_current_stream()
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestCommunication, globals(), only_for=devices)
instantiate_device_type_tests(TestExplicitUnshard, globals(), only_for=devices)
instantiate_parametrized_tests(TestCommunication)
instantiate_parametrized_tests(TestExplicitUnshard)
if __name__ == "__main__":
run_tests()

@ -1,9 +1,9 @@
# Owner(s): ["oncall: distributed"]
import contextlib
import functools
import itertools
import sys
import unittest
from typing import Any, Callable, Dict, List, Optional
from unittest import mock
@ -19,7 +19,6 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import (
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.utils import _p_assert
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
AlwaysWrapNestedWrappedModule,
@ -27,7 +26,6 @@ from torch.testing._internal.common_fsdp import (
DummyDDP,
FSDPInitMode,
FSDPTest,
get_devtype,
MixtureOfExperts,
NestedWrappedModule,
NestedWrappedModuleWithDelay,
@ -35,9 +33,9 @@ from torch.testing._internal.common_fsdp import (
TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TEST_HPU,
TEST_WITH_DEV_DBG_ASAN,
)
@ -45,6 +43,7 @@ from torch.testing._internal.common_utils import (
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
@ -52,7 +51,6 @@ if TEST_WITH_DEV_DBG_ASAN:
)
sys.exit(0)
device_type = torch.device(get_devtype())
params = "cpu_offload,sharding_strategy"
cpu_offload_config = [CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
sharding_strategy_config = [
@ -67,6 +65,7 @@ test_name_mapping = {
str(ShardingStrategy.SHARD_GRAD_OP): "shard_grad_op",
str(ShardingStrategy.NO_SHARD): "no_shard",
}
subtest_name = functools.partial(subtest_name, test_name_mapping)
@ -87,6 +86,7 @@ class TestParityWithDDP(FSDPTest):
# on CPU but NCCL only supports GPU.
if cpu_offload.offload_params:
modes.append(DEVICEInitMode.DEVICE_NEVER)
return modes
def _get_subtest_config(self, cpu_offload: CPUOffload) -> Dict[str, List[Any]]:
@ -229,7 +229,6 @@ class TestParityWithDDP(FSDPTest):
cpu_offload: CPUOffload,
sharding_strategy: Optional[ShardingStrategy],
):
fsdp_kwargs = {"device_id": device_type.type}
self.run_subtests(
self._get_subtest_config(cpu_offload),
self._test_fsdp_parity,
@ -238,12 +237,8 @@ class TestParityWithDDP(FSDPTest):
ref_init_fn=self._dummy_ddp_fn,
cpu_offload=cpu_offload,
sharding_strategy=sharding_strategy,
**fsdp_kwargs,
)
@unittest.skipIf(
TEST_HPU, "HPU doesn't has HW sleep API support (like CUDA), skipping"
)
@skip_if_lt_x_gpu(2)
@parametrize(params, configs, subtest_name)
def test_mixture_of_experts_with_delay_before_free(
@ -251,7 +246,6 @@ class TestParityWithDDP(FSDPTest):
cpu_offload: CPUOffload,
sharding_strategy: Optional[ShardingStrategy],
):
fsdp_kwargs = {"device_id": device_type.type}
self.run_subtests(
self._get_subtest_config(cpu_offload),
self._test_fsdp_parity,
@ -261,7 +255,6 @@ class TestParityWithDDP(FSDPTest):
cpu_offload=cpu_offload,
sharding_strategy=sharding_strategy,
init_kwargs={"delay_before_free_ms": 250},
**fsdp_kwargs,
)
@ -274,7 +267,7 @@ class TestParamInit(FSDPTest):
initialization persist.
"""
# Establish reference behavior
fsdp_kwargs = {"device_id": device_type}
fsdp_kwargs = {}
if mixed_precision:
fsdp_kwargs["mixed_precision"] = MixedPrecision()
fsdp_model = TransformerWithSharedParams.init(
@ -284,7 +277,7 @@ class TestParamInit(FSDPTest):
fsdp_kwargs,
deterministic=True,
)
input = fsdp_model.module.get_input(device_type)
input = fsdp_model.module.get_input(torch.device("cuda"))
ref_output = fsdp_model(*input)
# Initialize the same model but change its first parameter value
# in-place after FSDP initialization
@ -311,12 +304,10 @@ class TestHooks(FSDPTest):
def test_pre_backward_hook_registration(self, cuda_first: bool):
"""Tests that FSDP pre-backward hooks are registered on forward pass
outputs."""
fsdp_kwargs = {"device_id": device_type.type}
fsdp_model = TransformerWithSharedParams.init(
self.process_group,
FSDPInitMode.RECURSIVE,
DEVICEInitMode.DEVICE_BEFORE if cuda_first else DEVICEInitMode.DEVICE_AFTER,
fsdp_kwargs,
)
self._test_pre_backward_hook_registration(fsdp_model)
@ -324,12 +315,10 @@ class TestHooks(FSDPTest):
def test_pre_backward_hook_registration_after_state_dict(self):
"""Tests that FSDP pre-backward hooks are registered on forward pass
outputs after saving and loading the model from a checkpoint."""
fsdp_kwargs = {"device_id": device_type.type}
fsdp_model = TransformerWithSharedParams.init(
self.process_group,
FSDPInitMode.RECURSIVE,
DEVICEInitMode.DEVICE_AFTER,
fsdp_kwargs,
)
self._train_for_several_steps(fsdp_model, num_steps=2, autocast=False)
state_dict = fsdp_model.state_dict()
@ -340,11 +329,11 @@ class TestHooks(FSDPTest):
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optim.zero_grad()
# Inputs always cuda, as computation happens on CUDA device only
input = model.module.get_input(device_type)
input = model.module.get_input(torch.device("cuda"))
output = model(*input)
# this is pre-bwd hook
self.assertEqual(len(output._backward_hooks), 1)
loss = model.module.get_loss(input, output).to(device_type.type)
loss = model.module.get_loss(input, output).cuda()
loss.backward()
# It doesn't get removed
self.assertEqual(len(output._backward_hooks), 1)
@ -357,7 +346,7 @@ class TestHooks(FSDPTest):
def test_register_functions_called(self, cuda_first: bool, mixed_precision: bool):
"""Tests that ``_register_{pre|post}_backward_hooks()`` are called
during the FSDP forward."""
fsdp_kwargs = {"device_id": device_type.type}
fsdp_kwargs = {}
if mixed_precision:
fsdp_kwargs["mixed_precision"] = MixedPrecision()
fsdp_model = TransformerWithSharedParams.init(
@ -366,7 +355,8 @@ class TestHooks(FSDPTest):
DEVICEInitMode.DEVICE_BEFORE if cuda_first else DEVICEInitMode.DEVICE_AFTER,
fsdp_kwargs,
)
input = fsdp_model.module.get_input(device_type)
input = fsdp_model.module.get_input(torch.device("cuda"))
# Since `_register_pre_backward_hooks()` modifies the forward output,
# we cannot directly mock it. We implement our own counter instead.
orig_register_pre_backward_hooks = (
@ -400,7 +390,7 @@ class TestNoGrad(FSDPTest):
parameters, after training for one iteration, running a forward pass in
``eval()`` mode gives the same output as running a forward pass in
``torch.no_grad()``."""
fsdp_kwargs = {"device_id": device_type.type}
fsdp_kwargs = {}
if mixed_precision:
fsdp_kwargs["mixed_precision"] = MixedPrecision(
param_dtype=torch.float16,
@ -421,7 +411,7 @@ class TestNoGrad(FSDPTest):
autocast=False,
mixed_precision=fsdp_kwargs["mixed_precision"],
)
input = fsdp_model.module.get_input(device_type)
input = fsdp_model.module.get_input(torch.device("cuda"))
# Run a forward in eval mode
fsdp_model.eval()
ref_output = fsdp_model(*input)
@ -445,7 +435,7 @@ class TestAutograd(FSDPTest):
{
"sharding_strategy": [
ShardingStrategy.FULL_SHARD,
ShardingStrategy.SHARD_GRAD_OP,
ShardingStrategy.SHARD_GRAD_OP
# Skip testing `NO_SHARD` since it doubly uses
# `_use_unsharded_views()` for sharded views. Testing
# `FULL_SHARD` and `SHARD_GRAD_OP` provides good confidence
@ -485,17 +475,17 @@ class TestAutograd(FSDPTest):
"forward_prefetch": forward_prefetch,
"backward_prefetch": backward_prefetch,
"auto_wrap_policy": ModuleWrapPolicy({nn.Linear}),
"device_id": device_type,
}
device = torch.device("cuda")
# Define a model with enough FSDP instances to exercise prefetching
NUM_LINEARS = 5
model = nn.Sequential(
*[nn.Linear(3, 3, device=device_type) for _ in range(NUM_LINEARS)]
*[nn.Linear(3, 3, device=device) for _ in range(NUM_LINEARS)]
)
fsdp_model = FSDP(model, **fsdp_kwargs)
self.assertEqual(len(list(FSDP.fsdp_modules(fsdp_model))), NUM_LINEARS + 1)
for _ in range(3):
inp = torch.randn((2, 3), device=device_type)
inp = torch.randn((2, 3), device=device)
with self._patch_use_unsharded_views(
_use_unsharded_views_assert_as_tensors
):
@ -512,11 +502,10 @@ class TestAutograd(FSDPTest):
FlatParamHandle._use_unsharded_views = orig_use_unsharded_views
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestHooks, globals(), only_for=devices)
instantiate_device_type_tests(TestParityWithDDP, globals(), only_for=devices)
instantiate_device_type_tests(TestNoGrad, globals(), only_for=devices)
instantiate_device_type_tests(TestParamInit, globals(), only_for=devices)
instantiate_device_type_tests(TestAutograd, globals(), only_for=devices)
instantiate_parametrized_tests(TestHooks)
instantiate_parametrized_tests(TestParityWithDDP)
instantiate_parametrized_tests(TestNoGrad)
instantiate_parametrized_tests(TestParamInit)
if __name__ == "__main__":
run_tests()

@ -1,4 +1,5 @@
# Owner(s): ["oncall: distributed"]
import io
from copy import deepcopy
@ -13,9 +14,11 @@ from torch.distributed.fsdp.api import (
ShardedStateDictConfig,
StateDictType,
)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import parametrize, run_tests
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
skip_if_lt_x_gpu,
@ -23,13 +26,10 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
)
device_type = torch.device(get_devtype())
# Simple and boring model to test interface and some corner cases that do not
# require complicated wrapping strategy.
class TestDummyModel(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
torch.manual_seed(0)
self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
@ -41,11 +41,11 @@ class TestDummyModel(torch.nn.Module):
return self.net4(self.net3(self.net2(self.net1(x))))
def get_input(self):
return torch.rand(8, 8, device=device_type.type)
return torch.rand(8, 8, device="cuda")
class TestDummyModelUneven(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
torch.manual_seed(0)
self.net1 = nn.Sequential(nn.Linear(5, 10), nn.ReLU())
@ -57,7 +57,7 @@ class TestDummyModelUneven(torch.nn.Module):
return self.net4(self.net3(self.net2(self.net1(x))))
def get_input(self):
return torch.rand(5, 5, device=device_type.type)
return torch.rand(5, 5, device="cuda")
class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
@ -65,29 +65,34 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
dummy_model = (
TestDummyModel() if is_even_sharded_model else TestDummyModelUneven()
)
model = FSDP(dummy_model.to(device_type), device_mesh=device_mesh)
model = FSDP(dummy_model.cuda(), device_mesh=device_mesh)
optim = torch.optim.Adam(model.parameters(), lr=0.1)
model(model.get_input()).sum().backward()
optim.step()
return model, optim
@with_comms
@skip_if_lt_x_gpu(2)
@parametrize("is_even_sharded_model", [True, False])
def test_fsdp_init_with_device_mesh(self, is_even_sharded_model):
device_mesh = init_device_mesh(device_type.type, (self.world_size,))
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
model, optim = self._create_model(is_even_sharded_model, device_mesh)
FSDP.set_state_dict_type(
model,
StateDictType.SHARDED_STATE_DICT,
)
state_dict = model.state_dict()
optim_state_dict = FSDP.optim_state_dict(model, optim)
for v in state_dict.values():
self.assertEqual(type(v), DTensor)
self.assertEqual(len(v.placements), 1)
self.assertEqual(v.placements[0], (Shard(dim=0)))
self.assertEqual(v.device_mesh, device_mesh)
for state in optim_state_dict["state"].values():
for k, v in state.items():
if k != "step":
@ -95,6 +100,7 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
self.assertEqual(len(v.placements), 1)
self.assertEqual(v.placements[0], (Shard(dim=0)))
self.assertEqual(v.device_mesh, device_mesh)
state_dict_type = FSDP.get_state_dict_type(model)
# If device_mesh is used when initializing FSDP, the field _use_dtensor will
# automatically be set to True if StateDictType is set to SHARDED_STATE_DICT.
@ -108,8 +114,9 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
def test_dtensor_sharded_tensor_state_dict_identical(
self, offload_to_cpu, is_even_sharded_model
):
device_mesh = init_device_mesh(device_type.type, (self.world_size,))
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
model, optim = self._create_model(is_even_sharded_model, device_mesh)
FSDP.set_state_dict_type(
model,
StateDictType.SHARDED_STATE_DICT,
@ -120,6 +127,7 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
)
dtensor_sd = model.state_dict()
dtensor_osd = FSDP.optim_state_dict(model, optim)
ref_model, ref_optim = self._create_model(is_even_sharded_model)
FSDP.set_state_dict_type(
ref_model,
@ -131,6 +139,7 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
)
sharded_tensor_sd = ref_model.state_dict()
sharded_tensor_osd = FSDP.optim_state_dict(ref_model, ref_optim)
# Check dtensor and sharded_tensor model state dict values are identical
for dtensor_sd_item, sharded_tensor_sd_item in zip(
dtensor_sd.items(), sharded_tensor_sd.items()
@ -138,6 +147,7 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
k1, v1 = dtensor_sd_item
k2, v2 = sharded_tensor_sd_item
self.assertEqual(k1, k2)
# if the ShardedTensor is an empty shard,
# then the local tensor of DTensor should be local_tensor=tensor([])
if len(v2.local_shards()) == 0:
@ -149,6 +159,7 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
self.assertEqual(v1.to_local(), v2.local_tensor())
# check whether device are the same
self.assertEqual(v1.to_local().device, v2.local_tensor().device)
# Check dtensor and sharded_tensor optim state dict values are identical
for dtensor_osd_state, sharded_tensor_osd_state in zip(
dtensor_osd["state"].items(), sharded_tensor_osd["state"].items()
@ -162,6 +173,7 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
k1, v1 = dtensor_hyper_param
k2, v2 = sharded_tensor_hyper_param
self.assertEqual(k1, k2)
if k1 != "step":
# if the ShardedTensor is an empty shard,
# then the local tensor of DTensor should be local_tensor=tensor([])
@ -184,8 +196,9 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
def test_dtensor_sharded_optim_load_state_dict(
self, offload_to_cpu, is_even_sharded_model
):
device_mesh = init_device_mesh(device_type.type, (self.world_size,))
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
model, optim = self._create_model(is_even_sharded_model, device_mesh)
FSDP.set_state_dict_type(
model,
StateDictType.SHARDED_STATE_DICT,
@ -193,13 +206,16 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
offload_to_cpu=offload_to_cpu
),
)
checkpoint = io.BytesIO()
torch.save(FSDP.optim_state_dict(model, optim), checkpoint)
# Deepcopy to save current optim_state_dict to compare with the optim_state_dict loaded back below.
ref_optim_state_dict = deepcopy(FSDP.optim_state_dict(model, optim))
# Update the parameters so FSDP.optim_state_dict() will be different from ref_optim_state_dict.
model(model.get_input()).sum().backward()
optim.step()
# Load ref_optim_state_dict back.
checkpoint.seek(0)
load_ref_optim_state_dict = torch.load(checkpoint)
@ -207,6 +223,7 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
FSDP.optim_state_dict_to_load(model, optim, load_ref_optim_state_dict)
)
new_optim_state_dict = FSDP.optim_state_dict(model, optim)
# Check whether new_optim_state_dict is the same as ref_optim_state_dict.
for new_optim_state_dict_item, ref_optim_state_dict_item in zip(
new_optim_state_dict["state"].items(),
@ -220,10 +237,12 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
):
k1, v1 = new_optim_hyper_param
k2, v2 = ref_optim_hyper_param
# check whether keys are the same
self.assertEqual(k1, k2)
# check whether values are the same
self.assertEqual(v1, v2)
if k1 != "step":
self.assertEqual(type(v1), DTensor)
self.assertEqual(type(v2), DTensor)
@ -235,29 +254,35 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
def test_dtensor_sharded_model_load_state_dict(
self, offload_to_cpu, is_even_sharded_model
):
device_mesh = init_device_mesh(device_type.type, (self.world_size,))
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
model, optim = self._create_model(is_even_sharded_model, device_mesh)
FSDP.set_state_dict_type(
model,
StateDictType.SHARDED_STATE_DICT,
state_dict_config=ShardedStateDictConfig(offload_to_cpu=offload_to_cpu),
)
checkpoint = io.BytesIO()
torch.save(model.state_dict(), checkpoint)
# Deepcopy to save current state_dict to compare with the state_dict loaded back below.
ref_state_dict = deepcopy(model.state_dict())
# Update the parameters so model.state_dict() will be different from ref_dtensor_sd.
model(model.get_input()).sum().backward()
optim.step()
# Load ref_state_dict back.
checkpoint.seek(0)
load_ref_state_dict = torch.load(checkpoint)
model.load_state_dict(load_ref_state_dict)
new_state_dict = model.state_dict()
# Check whether new_state_dict is the same as ref_state_dict.
for (k1, v1), (k2, v2) in zip(ref_state_dict.items(), new_state_dict.items()):
# check whether fqn are the same
self.assertEqual(k1, k2)
self.assertEqual(type(v1), DTensor)
self.assertEqual(type(v2), DTensor)
# check whether DTensor are the same
@ -266,18 +291,20 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
@with_comms
@skip_if_lt_x_gpu(4)
def test_raises_warning_or_errors(self):
device_mesh = init_device_mesh(device_type.type, (self.world_size,))
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
model, optim = self._create_model(
is_even_sharded_model=True, device_mesh=device_mesh
)
# initialize optim
model(model.get_input()).sum().backward()
optim.step()
with self.assertRaisesRegex(
RuntimeError, "DeviceMesh is not compatible with LOCAL_STATE_DICT."
):
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
state_dict = model.state_dict()
with self.assertRaisesRegex(
RuntimeError, "DeviceMesh is not compatible with LOCAL_STATE_DICT."
):
@ -285,9 +312,6 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
optim_state_dict = FSDP.optim_state_dict(model, optim)
devices = ("cuda", "hpu")
instantiate_device_type_tests(
TestFSDPWithDeviceMeshAndDTensor, globals(), only_for=devices
)
instantiate_parametrized_tests(TestFSDPWithDeviceMeshAndDTensor)
if __name__ == "__main__":
run_tests()

@ -8,18 +8,16 @@ import torch
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TEST_WITH_DEV_DBG_ASAN,
)
device_type = torch.device(get_devtype())
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
@ -65,7 +63,7 @@ class Model(torch.nn.Module):
)
return z
def get_input(self, device):
def get_input(self, device: torch.device):
return (torch.randn((8, 5)).to(device),)
def get_loss(self, input, output):
@ -88,19 +86,22 @@ class Model(torch.nn.Module):
self.use_alt_path = not self.use_alt_path
@staticmethod
def wrap(sharding_strategy: ShardingStrategy, device):
def wrap(sharding_strategy: ShardingStrategy, device: torch.device):
model = Model()
model.layer1 = FSDP(
model.layer1, sharding_strategy=sharding_strategy, device_id=device
)
model.layer2 = FSDP(
model.layer2, sharding_strategy=sharding_strategy, device_id=device
)
fsdp_model = FSDP(model, sharding_strategy=sharding_strategy, device_id=device)
model.layer1 = FSDP(model.layer1, sharding_strategy=sharding_strategy)
model.layer2 = FSDP(model.layer2, sharding_strategy=sharding_strategy)
fsdp_model = FSDP(model, sharding_strategy=sharding_strategy)
return fsdp_model.to(device)
class TestFSDPExecOrder(FSDPTest):
def setUp(self):
super().setUp()
@property
def device(self):
return torch.device("cuda")
@skip_if_lt_x_gpu(2)
@parametrize(
"sharding_strategy",
@ -108,7 +109,6 @@ class TestFSDPExecOrder(FSDPTest):
)
def test_invalid_first_iter_order(
self,
device,
sharding_strategy: ShardingStrategy,
):
"""Tests that FSDP errors if the all-gather order differs across ranks
@ -116,10 +116,10 @@ class TestFSDPExecOrder(FSDPTest):
# Rank 0 runs the forward pass in one order and all other ranks run in
# different order
dist.set_debug_level(dist.DebugLevel.DETAIL)
fsdp_model = Model.wrap(sharding_strategy, device_type)
fsdp_model = Model.wrap(sharding_strategy, self.device)
if self.rank != 0:
fsdp_model.flip_path()
inp = fsdp_model.module.get_input(device_type)
inp = fsdp_model.module.get_input(self.device)
# Match the error message with the following prefix
error_regex = "^(Forward order differs across ranks)"
with self.assertRaisesRegex(RuntimeError, error_regex):
@ -133,7 +133,6 @@ class TestFSDPExecOrder(FSDPTest):
@parametrize("iters_before_path_change", [1, 3])
def test_invalid_later_iter_order(
self,
device,
sharding_strategy: ShardingStrategy,
iters_before_path_change: int,
):
@ -142,11 +141,11 @@ class TestFSDPExecOrder(FSDPTest):
dist.set_debug_level(dist.DebugLevel.DETAIL)
# On the first iteration, all ranks run the same order, and on the next
# iteration, all but rank 0 run in a different order
fsdp_model = Model.wrap(sharding_strategy, device_type)
fsdp_model = Model.wrap(sharding_strategy, self.device)
for _ in range(iters_before_path_change):
inp = fsdp_model.module.get_input(device_type)
inp = fsdp_model.module.get_input(self.device)
output = fsdp_model(*inp)
loss = fsdp_model.module.get_loss(inp, output).to(device_type)
loss = fsdp_model.module.get_loss(inp, output).to(self.device)
fsdp_model.module.run_backward(loss)
# Match the warning message with the following prefix
regex = (
@ -164,16 +163,16 @@ class TestFSDPExecOrder(FSDPTest):
)
if self.rank != 0:
fsdp_model.flip_path()
inp = fsdp_model.module.get_input(device_type)
inp = fsdp_model.module.get_input(self.device)
# Expect a warning for the forward pass all-gather
with context: # warning for forward pass all-gather
output = fsdp_model(*inp)
loss = fsdp_model.module.get_loss(inp, output).to(device_type)
loss = fsdp_model.module.get_loss(inp, output).to(self.device)
fsdp_model.module.run_backward(loss)
# Run an additional iteration to check that there are no more warnings
inp = fsdp_model.module.get_input(device_type)
inp = fsdp_model.module.get_input(self.device)
output = fsdp_model(*inp)
loss = fsdp_model.module.get_loss(inp, output).to(device_type)
loss = fsdp_model.module.get_loss(inp, output).to(self.device)
fsdp_model.module.run_backward(loss)
@skip_if_lt_x_gpu(2)
@ -181,24 +180,24 @@ class TestFSDPExecOrder(FSDPTest):
"sharding_strategy",
[ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP],
)
def test_train_eval(self, device, sharding_strategy: ShardingStrategy):
def test_train_eval(self, sharding_strategy: ShardingStrategy):
dist.set_debug_level(dist.DebugLevel.DETAIL)
fsdp_model = Model.wrap(sharding_strategy, device_type)
fsdp_model = Model.wrap(sharding_strategy, self.device)
NUM_ITERS = 3
NUM_EPOCHS = 2
with warnings.catch_warnings(record=True) as w: # records warnings to `w`
for _ in range(NUM_EPOCHS):
fsdp_model.train()
for _ in range(NUM_ITERS):
inp = fsdp_model.module.get_input(device_type)
inp = fsdp_model.module.get_input(self.device)
output = fsdp_model(*inp)
loss = fsdp_model.module.get_loss(inp, output).to(device_type)
loss = fsdp_model.module.get_loss(inp, output).to(self.device)
fsdp_model.module.run_backward(loss)
fsdp_model.eval()
for _ in range(NUM_ITERS):
inp = fsdp_model.module.get_input(device_type)
inp = fsdp_model.module.get_input(self.device)
output = fsdp_model(*inp)
fsdp_model.module.get_loss(inp, output).to(device_type)
fsdp_model.module.get_loss(inp, output).to(self.device)
# Check that the order validation warning was not issued (errors do not
# need to be checked since they will be directly reported)
warning_prefix = "Forward order differs"
@ -211,7 +210,7 @@ class TestFSDPExecOrder(FSDPTest):
# an `AssertionError` will be raised above for both sharding strategies
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestFSDPExecOrder, globals(), only_for=devices)
instantiate_parametrized_tests(TestFSDPExecOrder)
if __name__ == "__main__":
run_tests()

@ -7,7 +7,6 @@ from unittest import mock
import torch
import torch.distributed as dist
import torch.nn as nn
from torch._utils import _get_device_module
from torch.distributed.fsdp import BackwardPrefetch, CPUOffload, MixedPrecision
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullyShardedDataParallel as FSDP,
@ -15,18 +14,11 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import (
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype
from torch.testing._internal.common_utils import (
run_tests,
TEST_CUDA,
TEST_WITH_DEV_DBG_ASAN,
)
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
device_type = torch.device(get_devtype())
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
@ -47,12 +39,12 @@ class LinearUnusedInput(nn.Linear):
class ModelUnusedInput(nn.Module):
def __init__(self, freeze: bool):
super().__init__()
self.layer0 = LinearUnusedInput(4, 4)
self.layer1_frozen = LinearUnusedInput(4, 4)
self.layer0 = LinearUnusedInput(4, 4, device="cuda")
self.layer1_frozen = LinearUnusedInput(4, 4, device="cuda")
if freeze:
for param in self.layer1_frozen.parameters():
param.requires_grad = False
self.layer2 = LinearUnusedInput(4, 4)
self.layer2 = LinearUnusedInput(4, 4, device="cuda")
def forward(self, frozen_input, learnable_input):
x = self.layer0(frozen_input, learnable_input)
@ -68,13 +60,13 @@ class TestFSDPFineTune(FSDPTest):
@property
def world_size(self) -> int:
return min(_get_device_module(self.device_type).device_count(), 2)
return min(torch.cuda.device_count(), 2)
def _init_seq_module(self, device) -> nn.Module:
def _init_seq_module(self) -> nn.Module:
torch.manual_seed(42)
modules = []
for _ in range(self.NUM_LINEARS):
modules += [nn.Linear(5, 5, device=device), nn.ReLU()]
modules += [nn.Linear(5, 5, device="cuda"), nn.ReLU()]
seq = nn.Sequential(*modules)
self._set_seq_module_requires_grad(seq, False)
return seq
@ -89,14 +81,13 @@ class TestFSDPFineTune(FSDPTest):
param.requires_grad = requires_grad
@skip_if_lt_x_gpu(2)
def test_backward_reshard_hooks(self, device):
def test_backward_reshard_hooks(self):
"""
Tests that the post-backward reshard happens even for flat parameters
that do not require gradients.
"""
self.run_subtests(
{
"device_id": [device],
"sharding_strategy": [
ShardingStrategy.FULL_SHARD,
ShardingStrategy.SHARD_GRAD_OP,
@ -111,21 +102,18 @@ class TestFSDPFineTune(FSDPTest):
def _test_backward_reshard_hooks(
self,
device_id,
sharding_strategy: ShardingStrategy,
use_orig_params: bool,
inp_requires_grad: bool,
unfreeze_params: bool,
):
seq = self._init_seq_module(device_type)
seq = self._init_seq_module()
policy = ModuleWrapPolicy({nn.Linear})
fsdp_kwargs = {"device_id": device_type}
seq = FSDP(
seq,
auto_wrap_policy=policy,
sharding_strategy=sharding_strategy,
use_orig_params=use_orig_params,
**fsdp_kwargs,
)
orig_post_backward_reshard = (
torch.distributed.fsdp._runtime_utils._post_backward_reshard
@ -174,7 +162,7 @@ class TestFSDPFineTune(FSDPTest):
self._set_seq_module_requires_grad(seq, True)
inp = torch.randn(
(8, 5), device=device_type, requires_grad=inp_requires_grad
(8, 5), device="cuda", requires_grad=inp_requires_grad
)
if step_idx == nograd_step_idx:
with torch.no_grad():
@ -187,15 +175,15 @@ class TestFSDPFineTune(FSDPTest):
_assert_post_backward_reshard_count(step_idx, num_steps)
post_backward_reshard_count = 0
def _init_multi_traversal_module(self, device) -> nn.Module:
def _init_multi_traversal_module(self) -> nn.Module:
torch.manual_seed(42)
class TestModule(nn.Module):
def __init__(self) -> None:
super().__init__()
self.layer_0 = nn.Linear(5, 5, device=device)
self.layer_no_grad = nn.Linear(5, 5, device=device)
self.layer_with_grad = nn.Linear(5, 5, device=device)
self.layer_0 = nn.Linear(5, 5, device="cuda")
self.layer_no_grad = nn.Linear(5, 5, device="cuda")
self.layer_with_grad = nn.Linear(5, 5, device="cuda")
self.layer_no_grad.requires_grad_(False)
def forward(self, x):
@ -240,26 +228,22 @@ class TestFSDPFineTune(FSDPTest):
inp_requires_grad: bool,
forward_prefetch: bool,
):
seq = self._init_multi_traversal_module(device_type.type)
seq = self._init_multi_traversal_module()
policy = ModuleWrapPolicy({nn.Linear})
fsdp_kwargs = {"device_id": device_type}
fsdp_seq = FSDP(
copy.deepcopy(seq),
auto_wrap_policy=policy,
sharding_strategy=sharding_strategy,
use_orig_params=use_orig_params,
forward_prefetch=forward_prefetch,
**fsdp_kwargs,
)
ddp_seq = DDP(copy.deepcopy(seq), device_ids=[device_type])
ddp_seq = DDP(copy.deepcopy(seq), device_ids=[self.rank])
fsdp_optim = torch.optim.Adam(fsdp_seq.parameters(), lr=1e-2)
ddp_optim = torch.optim.Adam(ddp_seq.parameters(), lr=1e-2)
torch.manual_seed(self.rank + 1)
losses = []
for _ in range(6):
inp = torch.randn(
(8, 5), device=device_type, requires_grad=inp_requires_grad
)
inp = torch.randn((8, 5), device="cuda", requires_grad=inp_requires_grad)
for seq, optim in ((fsdp_seq, fsdp_optim), (ddp_seq, ddp_optim)):
loss = seq(inp).sum()
losses.append(loss)
@ -292,37 +276,32 @@ class TestFSDPFineTune(FSDPTest):
sharding_strategy: ShardingStrategy,
use_orig_params: bool,
):
seq = self._init_seq_module(device_type)
seq = self._init_seq_module()
policy = ModuleWrapPolicy({nn.Linear})
fsdp_kwargs = {"device_id": device_type}
fsdp_seq = FSDP(
copy.deepcopy(seq),
auto_wrap_policy=policy,
sharding_strategy=sharding_strategy,
use_orig_params=use_orig_params,
**fsdp_kwargs,
)
ddp_seq = DDP(copy.deepcopy(seq), device_ids=[device_type])
ddp_seq = DDP(copy.deepcopy(seq), device_ids=[self.rank])
fsdp_optim = torch.optim.Adam(fsdp_seq.parameters(), lr=1e-2)
ddp_optim = torch.optim.Adam(ddp_seq.parameters(), lr=1e-2)
torch.manual_seed(self.rank + 1)
losses = []
for _ in range(6):
inp = torch.randn((8, 5), device=device_type.type)
inp = torch.randn((8, 5), device="cuda")
for seq, optim in ((fsdp_seq, fsdp_optim), (ddp_seq, ddp_optim)):
loss = seq(inp).sum()
losses.append(loss)
loss.backward()
optim.step()
optim.zero_grad()
if TEST_CUDA:
torch.testing.assert_close(losses[0], losses[1])
else:
torch.testing.assert_close(losses[0], losses[1], atol=1e-03, rtol=1e-03)
torch.testing.assert_close(losses[0], losses[1])
losses.clear()
@skip_if_lt_x_gpu(2)
def test_parity_with_non_frozen_fsdp(self, device):
def test_parity_with_non_frozen_fsdp(self):
"""
For frozen modules with unused input, a reshard could happen without an unshard.
Verify numerical parity between `_post_backward_reshard_only_hook` and
@ -330,7 +309,6 @@ class TestFSDPFineTune(FSDPTest):
"""
self.run_subtests(
{
"device_id": [device],
"sharding_strategy": [
ShardingStrategy.FULL_SHARD,
ShardingStrategy.SHARD_GRAD_OP,
@ -355,7 +333,6 @@ class TestFSDPFineTune(FSDPTest):
def _test_parity_with_non_frozen_fsdp(
self,
device_id,
sharding_strategy: ShardingStrategy,
use_orig_params: bool,
offload_params: bool,
@ -363,11 +340,10 @@ class TestFSDPFineTune(FSDPTest):
backward_prefetch: BackwardPrefetch,
):
torch.manual_seed(42)
model = ModelUnusedInput(freeze=True).to(device_type)
model = ModelUnusedInput(freeze=True)
torch.manual_seed(42)
ref_model = ModelUnusedInput(freeze=False).to(device_type)
ref_model = ModelUnusedInput(freeze=False)
fsdp_kwargs = {
"device_id": device_type,
"auto_wrap_policy": ModuleWrapPolicy({LinearUnusedInput}),
"sharding_strategy": sharding_strategy,
"use_orig_params": use_orig_params,
@ -389,10 +365,8 @@ class TestFSDPFineTune(FSDPTest):
torch.manual_seed(self.rank + 1)
losses = []
for idx in range(6):
frozen_input = torch.randn((4, 4), device=device_type, requires_grad=False)
learnable_input = torch.randn(
(4, 4), device=device_type, requires_grad=True
)
frozen_input = torch.randn((4, 4), device="cuda", requires_grad=False)
learnable_input = torch.randn((4, 4), device="cuda", requires_grad=True)
for _model, _optim in ((model, model_optim), (ref_model, ref_model_optim)):
loss = _model(frozen_input, frozen_input).sum()
losses.append(loss)
@ -407,7 +381,5 @@ class TestFSDPFineTune(FSDPTest):
self.assertEqual(param, ref_param)
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestFSDPFineTune, globals(), only_for=devices)
if __name__ == "__main__":
run_tests()

View File

@ -1,8 +1,12 @@
# Owner(s): ["oncall: distributed"]
import torch
from torch.distributed.fsdp._trace_utils import _ExecOrderTracer
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
run_tests,
TestCase,
)
class Model(torch.nn.Module):
@ -113,7 +117,7 @@ class TestSymbolicTracing(TestCase):
self.assertEqual(exec_info.visited_params, set(exec_info.param_forward_order))
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestSymbolicTracing, globals(), only_for=devices)
instantiate_parametrized_tests(TestSymbolicTracing)
if __name__ == "__main__":
run_tests()

View File

@ -1,4 +1,5 @@
# Owner(s): ["oncall: distributed"]
import sys
import torch
@ -6,10 +7,10 @@ from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Linear, Module
from torch.optim import SGD
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
subtest,
@ -20,6 +21,7 @@ from torch.testing._internal.common_utils import (
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
@ -35,11 +37,11 @@ class TestInput(FSDPTest):
@skip_if_lt_x_gpu(1)
@parametrize("input_cls", [subtest(dict, name="dict"), subtest(list, name="list")])
def test_input_type(self, device, input_cls):
def test_input_type(self, input_cls):
"""Test FSDP with input being a list or a dict, only single GPU."""
class Model(Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.layer = Linear(4, 4)
@ -51,26 +53,25 @@ class TestInput(FSDPTest):
input = input["in"]
return self.layer(input)
fsdp_kwargs = {
"device_id": device,
}
model = FSDP(Model().to(device), **fsdp_kwargs)
model = FSDP(Model()).cuda()
optim = SGD(model.parameters(), lr=0.1)
for _ in range(5):
in_data = torch.rand(64, 4).to(device)
in_data = torch.rand(64, 4).cuda()
in_data.requires_grad = True
if input_cls is list:
in_data = [in_data]
else:
self.assertTrue(input_cls is dict)
in_data = {"in": in_data}
out = model(in_data)
out.sum().backward()
optim.step()
optim.zero_grad()
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestInput, globals(), only_for=devices)
instantiate_parametrized_tests(TestInput)
if __name__ == "__main__":
run_tests()

View File

@ -1,7 +1,6 @@
# Owner(s): ["oncall: distributed"]
import sys
import unittest
import torch
import torch.nn as nn
@ -14,7 +13,6 @@ from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TEST_HPU,
TEST_WITH_DEV_DBG_ASAN,
)
from torch.utils.checkpoint import checkpoint
@ -160,7 +158,6 @@ class TestFSDPMemory(FSDPTest):
output = cmp(results, expected)
self.assertEqual(output, "")
@unittest.skipIf(TEST_HPU, "Memory will be differnt for CUDA and HPU, skipping")
@skip_if_lt_x_gpu(2)
@parametrize("ckpt", ["no_ckpt", "ckpt"])
def test_fsdp_memory(self, ckpt):
@ -232,5 +229,7 @@ class TestFSDPMemory(FSDPTest):
instantiate_parametrized_tests(TestFSDPMemory)
if __name__ == "__main__":
run_tests()

View File

@ -1,4 +1,5 @@
# Owner(s): ["oncall: distributed"]
import sys
import torch
@ -7,17 +8,15 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Linear, Module
from torch.nn.parallel import DistributedDataParallel
from torch.optim import SGD
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype, get_full_params
from torch.testing._internal.common_fsdp import FSDPTest, get_full_params
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
device_type = torch.device(get_devtype())
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
@ -47,33 +46,37 @@ class TestMultiForward(FSDPTest):
def _dist_train(self, wrap_fsdp):
# keep everything deterministic for input data
torch.manual_seed(0)
model = Model(wrap_fsdp).to(device_type.type)
model = Model(wrap_fsdp).cuda()
if wrap_fsdp:
model = FSDP(model, device_id=device_type.type)
model = FSDP(model)
else:
model = DistributedDataParallel(model, device_ids=[device_type.type])
model = DistributedDataParallel(model, device_ids=[self.rank])
optim = SGD(model.parameters(), lr=0.1)
in_data = torch.rand(64, 4).to(device_type.type)
in_data = torch.rand(64, 4).cuda()
in_data.requires_grad = True
for _ in range(3):
out = model(in_data)
out.sum().backward()
optim.step()
optim.zero_grad()
if wrap_fsdp:
return get_full_params(model)
return list(model.parameters())
@skip_if_lt_x_gpu(2)
def test_multi_forward(self):
# DDP
ddp_state = self._dist_train(wrap_fsdp=False)
# FSDP
fsdp_state = self._dist_train(wrap_fsdp=True)
self.assertEqual(ddp_state, fsdp_state)
devices = ("cpu", "hpu")
instantiate_device_type_tests(TestMultiForward, globals(), only_for=devices)
if __name__ == "__main__":
run_tests()

View File

@ -1,4 +1,5 @@
# Owner(s): ["oncall: distributed"]
import sys
import torch
@ -6,17 +7,15 @@ from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Linear, Module, Sequential
from torch.optim import SGD
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
device_type = torch.device(get_devtype())
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
@ -26,9 +25,9 @@ if TEST_WITH_DEV_DBG_ASAN:
class InnerModel(Module):
def __init__(self, device):
def __init__(self) -> None:
super().__init__()
self.layers = Sequential(FSDP(Linear(5, 5), device_id=device_type.type))
self.layers = Sequential(FSDP(Linear(5, 5)))
def forward(self, x):
return self.layers(x)
@ -36,32 +35,32 @@ class InnerModel(Module):
class TestMultipleWrapping(FSDPTest):
@skip_if_lt_x_gpu(2)
def test_multiple_wrapping(self, device):
def test_multiple_wrapping(self):
"""
This test simulates wrapping the module after training to run inference.
This is required in cases where later in a session, the model is wrapped again in FSDP but
contains nested FSDP wrappers within the module.
"""
inner_model = InnerModel(device)
model = FSDP(inner_model).to(device_type.type)
inner_model = InnerModel()
model = FSDP(inner_model).cuda()
optim = SGD(model.parameters(), lr=0.1)
for i in range(3):
input = torch.rand((1, 5), dtype=torch.float).to(device_type.type)
input = torch.rand((1, 5), dtype=torch.float).cuda()
input.requires_grad = True
output = model(input)
output.sum().backward()
optim.step()
optim.zero_grad()
input = torch.rand((1, 5), dtype=torch.float).to(device_type.type)
input = torch.rand((1, 5), dtype=torch.float).cuda()
output = model(input)
# second time to rewrap the inner model
# rewrapped_model = FSDP(inner_model, device_id=device)
rewrapped_model = FSDP(inner_model).to(device_type.type)
rewrapped_model = FSDP(inner_model).cuda()
rewrapped_output = rewrapped_model(input)
self.assertEqual(output, rewrapped_output)
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestMultipleWrapping, globals(), only_for=devices)
if __name__ == "__main__":
run_tests()

View File

@ -2,7 +2,6 @@
import sys
import time
import unittest
from statistics import mean
from unittest.mock import patch
@ -11,13 +10,11 @@ import torch.nn as nn
from torch import distributed as dist
from torch.cuda import Event
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
get_cycles_per_ms,
run_tests,
TEST_HPU,
TEST_WITH_DEV_DBG_ASAN,
)
@ -244,7 +241,6 @@ class TestForwardOverlapWorldSizeOne(FSDPTest):
both = e4["gpu_total"]
self.assertTrue(compute_only + all_gather_only > 1.1 * both)
@unittest.skipIf(TEST_HPU, "HPU doesn't has HW sleep API support, skipping")
@skip_if_lt_x_gpu(2)
def test_forward_overlap(self):
self._dist_train()
@ -256,9 +252,5 @@ class TestForwardOverlapWorldSizeTwo(TestForwardOverlapWorldSizeOne):
return 2
devices = ("cuda", "hpu")
instantiate_device_type_tests(
TestForwardOverlapWorldSizeOne, globals(), only_for=devices
)
if __name__ == "__main__":
run_tests()

View File

@ -10,20 +10,20 @@ from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
DEVICEInitMode,
FSDPInitMode,
FSDPTest,
get_devtype,
NestedWrappedModule,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
run_tests,
TEST_WITH_DEV_DBG_ASAN,
)
device_type = torch.device(get_devtype())
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
@ -37,6 +37,11 @@ if TEST_WITH_DEV_DBG_ASAN:
class TestPureFP16(FSDPTest):
@property
def world_size(self):
# Test fails due to inaccuracies when using more than 4 GPUs
return min(4, super().world_size)
@skip_if_lt_x_gpu(2)
def test_pure_fp16_training(self):
"""Tests pure FP16 training, including when the parameter's dtype is
@ -97,13 +102,11 @@ class TestPureFP16(FSDPTest):
self.process_group,
FSDPInitMode.NO_FSDP,
DEVICEInitMode.DEVICE_NEVER,
{
"device_id": device_type,
},
{},
)
fsdp_kwargs = {
"use_orig_params": use_orig_params,
"device_id": device_type,
"device_id": torch.cuda.current_device(),
"mixed_precision": mixed_precision,
}
if to_half_before_fsdp_init:
@ -115,7 +118,7 @@ class TestPureFP16(FSDPTest):
self.assertEqual(param.dtype, torch.float16)
inp = tuple(
t.half() if torch.is_tensor(t) else t
for t in fsdp_model.module.get_input(self.device_type)
for t in fsdp_model.module.get_input(torch.device("cuda"))
)
out = fsdp_model(*inp)
out.sum().backward()
@ -151,7 +154,7 @@ class TestPureFP16(FSDPTest):
self.assertEqual(param.grad.dtype, torch.float16)
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestPureFP16, globals(), only_for=devices)
instantiate_parametrized_tests(TestPureFP16)
if __name__ == "__main__":
run_tests()

View File

@ -1,9 +1,9 @@
# Owner(s): ["oncall: distributed"]
import sys
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
DEVICEInitMode,
@ -17,6 +17,7 @@ from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_AS
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
@ -56,7 +57,5 @@ class TestTraversal(FSDPTest):
)
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestTraversal, globals(), only_for=devices)
if __name__ == "__main__":
run_tests()

View File

@ -7,9 +7,8 @@ from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Linear
from torch.optim import SGD
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
@ -24,39 +23,37 @@ if TEST_WITH_DEV_DBG_ASAN:
)
sys.exit(0)
device_type = torch.device(get_devtype())
class TestUnevenParamShard(FSDPTest):
def _get_ref_results(self, device, model, input, my_lr):
def _get_ref_results(self, model, input, my_lr):
with torch.no_grad():
# Compute one iteration local output.
weight = model.weight.T.clone().to(device_type)
v = torch.Tensor(input[self.rank]).to(device_type)
weight = model.weight.T.clone().to(self.rank)
v = torch.Tensor(input[self.rank]).to(self.rank)
ref_forward_output_my_rank = torch.matmul(v, weight)
# Compute one iteration global weight update.
v = torch.Tensor(input[: self.world_size]).to(device_type)
v = torch.Tensor(input[: self.world_size]).to(self.rank)
grad = v.float().sum(0).repeat(weight.shape[0], 1).div(self.world_size)
ref_weight_out = weight - grad.T * my_lr
return ref_forward_output_my_rank, ref_weight_out
@skip_if_lt_x_gpu(2)
def test_one_iteration(self, device):
def test_one_iteration(self):
"""Test FSDP with uneven divide of parameter shards."""
model = Linear(3, 3, bias=False)
input = torch.rand(8, 3)
my_lr = 0.1
ref_forward_output_my_rank, ref_weight_out = self._get_ref_results(
device, model, input, my_lr
model, input, my_lr
)
model.to(device_type)
model.to(self.rank)
model = FSDP(model)
optim = SGD(model.parameters(), lr=my_lr)
self.assertTrue(len(input) >= self.world_size)
in_data = torch.Tensor(input[self.rank]).to(device_type)
in_data = torch.Tensor(input[self.rank]).to(self.rank)
out = model(in_data)
out.float().sum().backward()
optim.step()
@ -68,7 +65,5 @@ class TestUnevenParamShard(FSDPTest):
self.assertEqual(ref_weight_out, weight_out)
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestUnevenParamShard, globals(), only_for=devices)
if __name__ == "__main__":
run_tests()

View File

@ -25,15 +25,12 @@ from torch.testing._internal.common_fsdp import (
DEVICEInitMode,
FSDPInitMode,
FSDPTest,
get_devtype,
NestedWrappedModule,
TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
device_type = torch.device(get_devtype())
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
@ -51,6 +48,10 @@ class TestUnshardParamsBase(FSDPTest):
This contains any methods common to both the sharded and non-sharded cases.
"""
@property
def device(self) -> torch.device:
return torch.device("cuda", self.rank)
def _test_unshard_params_writeback(
self,
writeback: bool,
@ -58,8 +59,8 @@ class TestUnshardParamsBase(FSDPTest):
**fsdp_kwargs: Dict[str, Any],
):
model = nn.Sequential(
nn.Linear(5, 5, bias=False, device=device_type.type),
nn.Linear(5, 3, bias=False, device=device_type.type),
nn.Linear(5, 5, bias=False, device=self.device),
nn.Linear(5, 3, bias=False, device=self.device),
)
model[0] = FSDP(model[0], **fsdp_kwargs)
model = FSDP(model, **fsdp_kwargs)
@ -125,7 +126,7 @@ class TestUnshardParamsBase(FSDPTest):
self.process_group,
FSDPInitMode.NO_FSDP,
DEVICEInitMode.DEVICE_BEFORE,
fsdp_kwargs={"device_id": device_type.type},
fsdp_kwargs={},
deterministic=True,
)
# Apply FSDP such that the root module does not have FSDP applied,
@ -237,7 +238,7 @@ class TestUnshardParams(TestUnshardParamsBase):
testing that writing to padding does not persist.
NOTE: This method depends on FSDP internals.
"""
model = FSDP(nn.Linear(1, 1, bias=False, device=device_type.type))
model = FSDP(nn.Linear(1, 1, bias=False, device=self.device))
flat_param = model._handle.flat_param
self.assertEqual(1, flat_param.numel())
# Write a known value to the *sharded* `FlatParameter`
@ -290,10 +291,8 @@ class TestUnshardParams(TestUnshardParamsBase):
}
model = FSDP(
nn.Sequential(
FSDP(
nn.Linear(5, 5, bias=False, device=device_type.type), **fsdp_kwargs
),
nn.Linear(5, 3, bias=False, device=device_type.type),
FSDP(nn.Linear(5, 5, bias=False, device=self.device), **fsdp_kwargs),
nn.Linear(5, 3, bias=False, device=self.device),
),
**fsdp_kwargs,
)
@ -310,7 +309,7 @@ class TestUnshardParams(TestUnshardParamsBase):
# Validate the expected behavior: the root does not reshard after
# forward; the non-root reshards after forward; and both reshard after
# backward
output = model(torch.zeros(5, device=device_type.type))
output = model(torch.zeros(5, device=self.device))
self.assertEqual(
expected_outer_flat_param_unsharded_numel,
_get_unsharded_storage_size(outer_flat_param),
@ -322,7 +321,7 @@ class TestUnshardParams(TestUnshardParamsBase):
# Check that with parameter unsharding in between forward and backward
# as well as after backward, the reshard behavior matches
output = model(torch.zeros(5, device=device_type.type))
output = model(torch.zeros(5, device=self.device))
with FSDP.summon_full_params(
model,
rank0_only=rank0_only,
@ -379,10 +378,8 @@ class TestUnshardParams(TestUnshardParamsBase):
}
model = FSDP(
nn.Sequential(
FSDP(
nn.Linear(5, 5, bias=False, device=device_type.type), **fsdp_kwargs
),
nn.Linear(5, 3, bias=False, device=device_type.type),
FSDP(nn.Linear(5, 5, bias=False, device=self.device), **fsdp_kwargs),
nn.Linear(5, 3, bias=False, device=self.device),
),
**fsdp_kwargs,
)
@ -560,7 +557,7 @@ class TestUnshardParams(TestUnshardParamsBase):
DEVICEInitMode.DEVICE_BEFORE,
deterministic=True,
)
ddp_model = DDP(model, device_ids=[device_type])
ddp_model = DDP(model, device_ids=[self.rank])
fsdp_model = TransformerWithSharedParams.init(
self.process_group,
FSDPInitMode.RECURSIVE,
@ -569,7 +566,6 @@ class TestUnshardParams(TestUnshardParamsBase):
fsdp_kwargs={
"use_orig_params": use_orig_params,
"sharding_strategy": sharding_strategy,
"device_id": device_type.type,
},
)
with FSDP.summon_full_params(fsdp_model):
@ -577,7 +573,7 @@ class TestUnshardParams(TestUnshardParamsBase):
assert torch.all(torch.isclose(p1, p2))
# Check calling after backward
inp = fsdp_model.get_input(torch.device(device_type))
inp = fsdp_model.get_input(torch.device("cuda"))
ddp_out = ddp_model(*inp)
fsdp_out = fsdp_model(*inp)
ddp_out.sum().backward()
@ -587,7 +583,7 @@ class TestUnshardParams(TestUnshardParamsBase):
_check_grads(ddp_model, fsdp_model, old_fsdp_grads)
# Check calling between forward and backward
inp = fsdp_model.get_input(torch.device(device_type))
inp = fsdp_model.get_input(torch.device("cuda"))
ddp_out = ddp_model(*inp)
fsdp_out = fsdp_model(*inp)
old_fsdp_grads = _get_fsdp_grads(fsdp_model, is_supported)
@ -620,7 +616,6 @@ class TestUnshardParams(TestUnshardParamsBase):
fsdp_kwargs={
"use_orig_params": True,
"sharding_strategy": sharding_strategy,
"device_id": device_type.type,
},
)
for fsdp_module in FSDP.fsdp_modules(fsdp_model):
@ -635,7 +630,7 @@ class TestUnshardParams(TestUnshardParamsBase):
model = nn.Sequential(
nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)),
nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)),
).to(device_type.type)
).cuda()
model = FSDP(model, auto_wrap_policy=ModuleWrapPolicy((nn.Sequential,)))
with FSDP.summon_full_params(model[0]):
# Check that the summoned module does not have its flat parameter
@ -689,7 +684,7 @@ class TestUnshardParamsErrors(TestUnshardParamsBase):
with fsdp_module.summon_full_params(fsdp_module):
pass
model = FSDP(MyModule()).to(device_type.type)
model = FSDP(MyModule()).cuda(self.rank)
with self.assertRaisesRegex(
AssertionError, "Cannot manually unshard parameters during forward/backward"
):
@ -697,8 +692,8 @@ class TestUnshardParamsErrors(TestUnshardParamsBase):
@skip_if_lt_x_gpu(2)
def test_unshard_params_from_backward_raises(self):
model = FSDP(nn.Linear(2, 1, device=device_type.type))
output = model(torch.ones(2, device=device_type.type))
model = FSDP(nn.Linear(2, 1, device=self.device))
output = model(torch.ones(2, device=self.device))
def invalid_backward_hook(*args, **kwargs):
with FSDP.summon_full_params(model):

View File

@ -16,9 +16,11 @@ from torch.distributed.fsdp.api import (
ShardingStrategy,
StateDictType,
)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import parametrize, run_tests
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
skip_if_lt_x_gpu,
@ -26,9 +28,6 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
)
device_type = torch.device(get_devtype())
# Simple and boring model to test the interface and some corner cases that do
# not require a complicated wrapping strategy.
class DenseModel(torch.nn.Module):
@ -43,17 +42,16 @@ class DenseModel(torch.nn.Module):
def forward(self, x):
return self.net4(self.net3(self.net2(self.net1(x))))
def get_input(self, device):
return torch.rand(4, 8, device=device)
def get_input(self):
return torch.rand(4, 8, device="cuda")
# TODO: Consolidate DeviceMesh based FSDP and HSDP test cases.
class TestHSDPWithDeviceMeshAndDTensor(DTensorTestBase):
def _create_model(self, device, device_mesh=None):
device_id = device_type
def _create_model(self, device_mesh=None):
if device_mesh:
model = FSDP(
DenseModel().to(device_id),
DenseModel().cuda(),
device_mesh=device_mesh,
sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)
@ -62,23 +60,22 @@ class TestHSDPWithDeviceMeshAndDTensor(DTensorTestBase):
intra_node_pg = mesh_2d.get_group(mesh_dim=1)
inter_node_pg = mesh_2d.get_group(mesh_dim=0)
model = FSDP(
DenseModel().to(device_id),
DenseModel().cuda(),
process_group=(intra_node_pg, inter_node_pg),
sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)
optim = torch.optim.Adam(model.parameters(), lr=0.1)
model(model.get_input(device_id)).sum().backward()
model(model.get_input()).sum().backward()
optim.step()
return model, optim
@with_comms
@skip_if_lt_x_gpu(4)
def test_hsdp_init_with_device_mesh(self, device):
device_id = device_type
def test_hsdp_init_with_device_mesh(self):
mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
model, optim = self._create_model(device_id, mesh_2d)
model, optim = self._create_model(mesh_2d)
FSDP.set_state_dict_type(
model,
@ -110,11 +107,10 @@ class TestHSDPWithDeviceMeshAndDTensor(DTensorTestBase):
@with_comms
@skip_if_lt_x_gpu(4)
@parametrize("offload_to_cpu", [True, False])
def test_dtensor_sharded_tensor_state_dict_identical(self, device, offload_to_cpu):
device_id = device_type
def test_dtensor_sharded_tensor_state_dict_identical(self, offload_to_cpu):
mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
model, optim = self._create_model(mesh_2d)
model, optim = self._create_model(device_id, mesh_2d)
FSDP.set_state_dict_type(
model,
StateDictType.SHARDED_STATE_DICT,
@ -126,7 +122,7 @@ class TestHSDPWithDeviceMeshAndDTensor(DTensorTestBase):
dtensor_sd = model.state_dict()
dtensor_osd = FSDP.optim_state_dict(model, optim)
ref_model, ref_optim = self._create_model(device_id)
ref_model, ref_optim = self._create_model()
FSDP.set_state_dict_type(
ref_model,
StateDictType.SHARDED_STATE_DICT,
@ -180,10 +176,9 @@ class TestHSDPWithDeviceMeshAndDTensor(DTensorTestBase):
@with_comms
@skip_if_lt_x_gpu(4)
@parametrize("offload_to_cpu", [True, False])
def test_dtensor_sharded_optim_load_state_dict(self, device, offload_to_cpu):
device_id = device_type
def test_dtensor_sharded_optim_load_state_dict(self, offload_to_cpu):
mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
model, optim = self._create_model(device_id, mesh_2d)
model, optim = self._create_model(mesh_2d)
FSDP.set_state_dict_type(
model,
@ -199,7 +194,7 @@ class TestHSDPWithDeviceMeshAndDTensor(DTensorTestBase):
ref_optim_state_dict = deepcopy(FSDP.optim_state_dict(model, optim))
# Update the parameters so FSDP.optim_state_dict() will be different from ref_optim_state_dict.
model(model.get_input(device_id)).sum().backward()
model(model.get_input()).sum().backward()
optim.step()
# Load ref_optim_state_dict back.
@ -235,10 +230,9 @@ class TestHSDPWithDeviceMeshAndDTensor(DTensorTestBase):
@with_comms
@skip_if_lt_x_gpu(4)
@parametrize("offload_to_cpu", [True, False])
def test_dtensor_sharded_model_load_state_dict(self, device, offload_to_cpu):
device_id = device_type
def test_dtensor_sharded_model_load_state_dict(self, offload_to_cpu):
mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
model, optim = self._create_model(device_id, mesh_2d)
model, optim = self._create_model(mesh_2d)
FSDP.set_state_dict_type(
model,
@ -252,7 +246,7 @@ class TestHSDPWithDeviceMeshAndDTensor(DTensorTestBase):
ref_state_dict = deepcopy(model.state_dict())
# Update the parameters so model.state_dict() will be different from ref_dtensor_sd.
model(model.get_input(device_id)).sum().backward()
model(model.get_input()).sum().backward()
optim.step()
# Load ref_state_dict back.
@ -273,15 +267,13 @@ class TestHSDPWithDeviceMeshAndDTensor(DTensorTestBase):
@with_comms
@skip_if_lt_x_gpu(4)
def test_root_module_is_not_FSDP(self, device):
device_id = device_type
def test_root_module_is_not_FSDP(self):
class FakeMPModel(torch.nn.Module):
def __init__(self, device_mesh):
super().__init__()
torch.manual_seed(0)
self.dense = FSDP(
DenseModel().to(device_id),
DenseModel().cuda(),
use_orig_params=True,
sharding_strategy=ShardingStrategy.HYBRID_SHARD,
device_mesh=device_mesh,
@ -300,10 +292,10 @@ class TestHSDPWithDeviceMeshAndDTensor(DTensorTestBase):
return self.dense(sparse)
mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
model = FakeMPModel(device_mesh=mesh_2d).to(device_id)
model = FakeMPModel(device_mesh=mesh_2d).cuda()
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
batch = torch.rand(5, 8, device=self.device_type)
batch = torch.rand(5, 8, device=torch.device("cuda"))
model(batch).sum().backward()
optim.step()
osd = optim.state_dict()
@ -324,9 +316,6 @@ class TestHSDPWithDeviceMeshAndDTensor(DTensorTestBase):
self.assertIsInstance(state["exp_avg_sq"], torch.Tensor)
devices = ("cuda", "hpu")
instantiate_device_type_tests(
TestHSDPWithDeviceMeshAndDTensor, globals(), only_for=devices
)
instantiate_parametrized_tests(TestHSDPWithDeviceMeshAndDTensor)
if __name__ == "__main__":
run_tests()

View File

@ -2,6 +2,7 @@
import random
import sys
import unittest
from collections import OrderedDict
from dataclasses import dataclass
from typing import List
@ -10,13 +11,11 @@ import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed.utils import _apply_to_tensors, _replace_by_prefix
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
subtest,
TEST_HPU,
TEST_WITH_DEV_DBG_ASAN,
TestCase,
)
@ -33,25 +32,22 @@ if TEST_WITH_DEV_DBG_ASAN:
)
sys.exit(0)
list_device = "hpu" if TEST_HPU else "cuda"
class TestUtils(TestCase):
@parametrize(
"device_list",
[
["cpu"],
[list_device],
subtest(["cpu", list_device], name=f"cpu_{list_device}"),
],
"devices", [["cpu"], ["cuda"], subtest(["cpu", "cuda"], name="cpu_cuda")]
)
@skip_if_lt_x_gpu(1)
def test_apply_to_tensors(self, device_list):
def test_apply_to_tensors(self, devices):
if "cuda" in devices and (
not torch.cuda.is_available() or torch.cuda.device_count() < 1
):
raise unittest.SkipTest("Skipped due to lack of GPU")
expected = 0
def get_a_tensor():
"""Return a random tensor on random device."""
dev = random.choice(device_list)
dev = random.choice(devices)
shape = random.choice(((1), (2, 3), (4, 5, 6), (7, 8, 9, 10)))
t = torch.rand(shape).to(dev)
nonlocal expected
@ -95,7 +91,6 @@ class TestUtils(TestCase):
for i, v in enumerate(data):
self.assertEqual(type(new_data[i]), type(v))
@skip_if_lt_x_gpu(1)
def test_replace_by_prefix(self):
state_dict = {
"layer.a": torch.tensor(1),
@ -112,7 +107,6 @@ class TestUtils(TestCase):
_replace_by_prefix(state_dict, "module.layer.", "layer.")
assert state_dict == original_state_dict
@skip_if_lt_x_gpu(1)
def test_packed_sequence(self):
"""Test to ensure RNN packed sequences are modified correctly."""
rnn = nn.RNN(5, 5)
@ -130,7 +124,7 @@ class TestUtils(TestCase):
self.assertEqual(torch.sum(x), 0)
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestUtils, globals(), only_for=devices)
instantiate_parametrized_tests(TestUtils)
if __name__ == "__main__":
run_tests()

View File

@ -164,10 +164,6 @@ def _assert_module_states(
assert_fn(p1, p2)
def get_devtype():
return torch.device(DEVICE_TYPE)
def _zero_model(
model: nn.Module,
zero_buffers: bool = False,
@ -659,7 +655,7 @@ class ModuleWithDelay(FSDPTestModel):
loss = self.module.get_loss(input, output) # type: ignore[operator]
if self.delay_after_loss_ms > 0:
if TEST_HPU:
time.sleep(self.delay_after_loss_ms / 1000)
time.sleep(self.delay_before_reduction_ms / 1000)
elif TEST_CUDA:
torch.cuda._sleep(int(self.delay_after_loss_ms * get_cycles_per_ms()))
@ -1374,12 +1370,7 @@ class FSDPTest(MultiProcessTestCase):
**init_kwargs,
)
if ref_init_fn is None:
if TEST_HPU:
ref_model = DDP(
model, device_ids=[DEVICE_TYPE], output_device=DEVICE_TYPE
)
else:
ref_model = DDP(model, device_ids=[rank], output_device=rank)
ref_model = DDP(model, device_ids=[rank], output_device=rank)
else:
ref_model = ref_init_fn(model)
if use_pure_fp16: