Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
Revert "Tests Generelization for multiple accelerator devices (#139184)"
This reverts commit b576a8c318.
Reverted https://github.com/pytorch/pytorch/pull/139184 on behalf of https://github.com/clee2000 due to Failing internally when trying to pickle distributed test files D67098795 ([comment](https://github.com/pytorch/pytorch/pull/139184#issuecomment-2539610187))
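For orientation before the diff: the reverted PR replaced hard-coded CUDA calls in these FSDP tests with accelerator-generic lookups, and this revert swaps them back. The sketch below is illustrative only and is not taken from the patch itself; it assumes a recent PyTorch build where torch.get_device_module and the internal torch.testing._internal.common_fsdp.get_devtype helper are available (both appear in the diff), and a CUDA runtime for the second half.

import torch
import torch.nn as nn
from torch.testing._internal.common_fsdp import get_devtype

# Accelerator-generic style (what the reverted PR introduced): resolve the
# available accelerator once and route device-specific calls through it.
device_type = torch.device(get_devtype())
model = nn.Linear(8, 8).to(device_type.type)
torch.get_device_module(device_type.type).reset_peak_memory_stats()

# CUDA-specific style (what this revert restores in the tests below).
if torch.cuda.is_available():
    model = nn.Linear(8, 8).cuda()
    torch.cuda.reset_peak_memory_stats()

Either style exercises the same FSDP code paths; per the comment above, the revert was made because of the internal pickling failure, D67098795.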
This commit is contained in: parent 2f0fe82f6d, commit c85323c5e8
@@ -1,6 +1,7 @@
# Owner(s): ["oncall: distributed"]

import contextlib
import unittest
from copy import deepcopy
from functools import partial
@@ -15,7 +16,6 @@ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
OffloadWrapper,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils.checkpoint import checkpoint
@@ -23,8 +23,6 @@ from torch.utils.checkpoint import checkpoint
_SAVED_PREFIX = "_saved_"
GRAD_FN_NEXT_FUNCTIONS = "next_functions"

device_type = torch.device(get_devtype())


class CheckpointWrapperTest(TestCase):
def test_load_activation_checkpointed_module(self):
@@ -132,6 +130,7 @@ class CheckpointWrapperTest(TestCase):
m(torch.randn(2, 1)).sum().backward()
self.assertEqual(2, count)

@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
def test_checkpoint_wrapper_parity(self):
"""
Tests that using checkpoint_wrapper or the functional
@@ -156,11 +155,9 @@ class CheckpointWrapperTest(TestCase):
self.use_reentrant = use_reentrant
wrp = partial(
checkpoint_wrapper,
checkpoint_impl=(
CheckpointImpl.REENTRANT
if use_reentrant
else CheckpointImpl.NO_REENTRANT
),
checkpoint_impl=CheckpointImpl.REENTRANT
if use_reentrant
else CheckpointImpl.NO_REENTRANT,
)
for i in range(self.n):
l = nn.Sequential(
@@ -187,12 +184,12 @@ class CheckpointWrapperTest(TestCase):
use_checkpointing,
use_wrapper=use_wrapper,
use_reentrant=use_reentrant,
).to(device_type.type)
x = torch.randn(10000, 256, requires_grad=True).to(device_type.type)
torch.get_device_module(device_type.type).reset_peak_memory_stats()
).cuda()
x = torch.randn(10000, 256, requires_grad=True).cuda()
torch.cuda.reset_peak_memory_stats()
loss = a(x).sum()
loss.backward()
return torch.get_device_module(device_type.type).max_memory_allocated()
return torch.cuda.max_memory_allocated()

functional_no_reentrant = test(
use_checkpointing=True, use_wrapper=False, use_reentrant=False
@@ -336,12 +333,13 @@ class CheckpointWrapperTest(TestCase):
for fqn, _ in lin.named_parameters():
self.assertTrue(fqn in state_dict, msg=f"{fqn} not in state_dict.")

@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
def test_checkpoint_wrapper_cpu_offload(self):
model = nn.Sequential(
nn.Linear(10, 10),
nn.Linear(10, 10),
nn.Linear(10, 10),
).to(device_type.type)
).cuda()

# Patch saved_tensor_hooks to make the unpack keep the tensor on CPU for
# testing, otherwise the tensor access during the DFS will cause orig
@@ -360,7 +358,7 @@ class CheckpointWrapperTest(TestCase):

model = offload_wrapper(model)

inp = torch.randn(3, 10, device=device_type.type)
inp = torch.randn(3, 10, device="cuda")
loss = model(inp).sum()

# All autograd saved tensors should be offloaded to CPU.
@@ -8,10 +8,10 @@ from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter, loa
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
from torch.distributed.fsdp.wrap import enable_wrap, wrap
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, SkipModel
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TEST_WITH_DEV_DBG_ASAN,
@@ -85,7 +85,7 @@ class TestDistributedCheckpoint(FSDPTest):
# TODO: add resharding test case.


devices = ("cuda", "hpu")
instantiate_device_type_tests(TestDistributedCheckpoint, globals(), only_for=devices)
instantiate_parametrized_tests(TestDistributedCheckpoint)

if __name__ == "__main__":
run_tests()
@@ -6,13 +6,11 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
DEVICEInitMode,
FSDPInitMode,
FSDPTest,
get_devtype,
NestedWrappedModule,
TransformerWithSharedParams,
)
@@ -30,8 +28,6 @@ if TEST_WITH_DEV_DBG_ASAN:
)
sys.exit(0)

device_type = torch.device(get_devtype())


class TestApply(FSDPTest):
@property
@@ -71,12 +67,10 @@ class TestApply(FSDPTest):
def test_nested_module_apply(self):
"""Tests that ``apply()`` modifies parameter values in-place on a
non-FSDP-root nested FSDP-wrapped model."""
fsdp_kwargs = {"device_id": device_type.type}
nested_wrapped_module = NestedWrappedModule.init(
self.process_group,
FSDPInitMode.RECURSIVE,
DEVICEInitMode.DEVICE_AFTER,
fsdp_kwargs=fsdp_kwargs,
)
self._check_apply(nested_wrapped_module)
@@ -84,12 +78,10 @@ class TestApply(FSDPTest):
def test_transformer_module_apply(self):
"""Tests that ``apply()`` modifies parameter values in-place on an
FSDP-wrapped transformer model with shared parameters."""
fsdp_kwargs = {"device_id": device_type.type}
transformer = TransformerWithSharedParams.init(
self.process_group,
FSDPInitMode.RECURSIVE,
DEVICEInitMode.DEVICE_AFTER,
fsdp_kwargs=fsdp_kwargs,
)
self._check_apply(transformer)
@@ -97,19 +89,15 @@ class TestApply(FSDPTest):
def test_apply_in_summon_raises_error(self):
"""Tests that calling ``apply()`` on an FSDP instance inside the
``summon_full_params()`` context raises an error."""
fsdp_kwargs = {"device_id": device_type.type}
transformer = TransformerWithSharedParams.init(
self.process_group,
FSDPInitMode.RECURSIVE,
DEVICEInitMode.DEVICE_AFTER,
fsdp_kwargs=fsdp_kwargs,
)
with transformer.summon_full_params(transformer):
with self.assertRaisesRegex(ValueError, "expected to be in states"):
transformer.apply(self._init_linear_weights)


devices = ("cuda", "hpu")
instantiate_device_type_tests(TestApply, globals(), only_for=devices)
if __name__ == "__main__":
run_tests()
@ -16,12 +16,10 @@ from torch.distributed.fsdp._runtime_utils import (
|
|||
)
|
||||
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype
|
||||
from torch.testing._internal.common_fsdp import FSDPTest
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
|
||||
|
||||
|
||||
device_type = torch.device(get_devtype())
|
||||
|
||||
NUM_ITERS = 2
|
||||
DECODER_PARAM_FQNS = [
|
||||
"decoder.layers.{index}.self_attn.in_proj_weight",
|
||||
|
|
@ -83,13 +81,14 @@ class TestBackwardPrefetch(FSDPTest):
|
|||
def _dist_train(self, backward_prefetch=BackwardPrefetch.BACKWARD_PRE):
|
||||
rank = self.rank
|
||||
orig_get_handle_to_prefetch = _get_handle_to_prefetch
|
||||
|
||||
torch.manual_seed(0)
|
||||
policy = ModuleWrapPolicy(
|
||||
{nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
|
||||
)
|
||||
model = FSDP(
|
||||
nn.Transformer(d_model=1024, nhead=8, device=device_type),
|
||||
device_id=device_type.type,
|
||||
nn.Transformer(d_model=1024, nhead=8, device="cuda"),
|
||||
device_id=torch.cuda.current_device(),
|
||||
auto_wrap_policy=policy,
|
||||
use_orig_params=True,
|
||||
backward_prefetch=backward_prefetch,
|
||||
|
|
@ -98,8 +97,8 @@ class TestBackwardPrefetch(FSDPTest):
|
|||
|
||||
# prepare input
|
||||
torch.manual_seed(rank + 1)
|
||||
src = torch.randn((10, 1, 1024), device=device_type)
|
||||
tgt = torch.randn((20, 1, 1024), device=device_type)
|
||||
src = torch.randn((10, 1, 1024), device="cuda")
|
||||
tgt = torch.randn((20, 1, 1024), device="cuda")
|
||||
|
||||
# monkey patch
|
||||
all_handle_fqns: List[List[str]] = []
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import contextlib
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
|
|
@ -16,9 +17,8 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import (
|
|||
CPUOffload,
|
||||
FullyShardedDataParallel as FSDP,
|
||||
)
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_fsdp import _maybe_wrap_fsdp, FSDPTest, get_devtype
|
||||
from torch.testing._internal.common_fsdp import _maybe_wrap_fsdp, FSDPTest
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
|
|
@ -28,17 +28,18 @@ from torch.testing._internal.common_utils import (
|
|||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
|
||||
device_type = torch.device(get_devtype())
|
||||
|
||||
if not dist.is_available():
|
||||
print("Distributed not available, skipping tests", file=sys.stderr)
|
||||
sys.exit(0)
|
||||
|
||||
if TEST_WITH_DEV_DBG_ASAN:
|
||||
print(
|
||||
"Skip dev-asan as torch + multiprocessing spawn have known issues",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
_save_on_cpu_called = False
|
||||
|
||||
|
||||
|
|
@ -82,18 +83,22 @@ class TestFSDPCheckpoint(FSDPTest):
|
|||
**fsdp_kwargs,
|
||||
):
|
||||
torch.manual_seed(0)
|
||||
torch.cuda.manual_seed(0)
|
||||
super().__init__()
|
||||
l1 = nn.Linear(3, 3).to(device_type.type)
|
||||
l2 = nn.Linear(3, 3).to(device_type.type)
|
||||
l3 = nn.Linear(3, 3).to(device_type.type)
|
||||
l1 = nn.Linear(3, 3).cuda()
|
||||
l2 = nn.Linear(3, 3).cuda()
|
||||
l3 = nn.Linear(3, 3).cuda()
|
||||
|
||||
if checkpoint_layer:
|
||||
if offload_activations:
|
||||
ckpt_wrapper = offload_wrapper
|
||||
else:
|
||||
ckpt_wrapper = checkpoint_wrapper
|
||||
|
||||
l1 = ckpt_wrapper(l1)
|
||||
l2 = ckpt_wrapper(l2)
|
||||
l3 = ckpt_wrapper(l3)
|
||||
|
||||
fsdp_wrapper = partial(
|
||||
_maybe_wrap_fsdp, *fsdp_args, wrap_fsdp=wrap_fsdp, **fsdp_kwargs
|
||||
)
|
||||
|
|
@ -110,9 +115,11 @@ class TestFSDPCheckpoint(FSDPTest):
|
|||
assert losses
|
||||
assert outputs
|
||||
assert models
|
||||
|
||||
for l, o in zip(losses[1:], outputs[1:]):
|
||||
self.assertEqual(losses[0], l)
|
||||
self.assertEqual(outputs[0], o)
|
||||
|
||||
# Verify grads
|
||||
ref_model = models[0]
|
||||
ref_grads = [p.grad for p in ref_model.parameters()]
|
||||
|
|
@ -139,6 +146,7 @@ class TestFSDPCheckpoint(FSDPTest):
|
|||
wrapper_to_use = offload_wrapper
|
||||
else:
|
||||
wrapper_to_use = checkpoint_wrapper
|
||||
|
||||
fsdp_kwargs = {"cpu_offload": cpu_offload, "use_orig_params": use_orig_params}
|
||||
ckpt_sequential_wrapped_fsdp = wrapper_to_use(
|
||||
TestFSDPCheckpoint.SequentialModule(
|
||||
|
|
@ -153,13 +161,16 @@ class TestFSDPCheckpoint(FSDPTest):
|
|||
wrap_fsdp=True,
|
||||
**fsdp_kwargs,
|
||||
)
|
||||
|
||||
baseline = TestFSDPCheckpoint.SequentialModule(
|
||||
wrap_fsdp=True,
|
||||
**fsdp_kwargs,
|
||||
)
|
||||
|
||||
# note that reentrant-based checkpointing requires inputs to have grad
|
||||
# flag set.
|
||||
inp = torch.randn(10, 3, device=device_type.type, requires_grad=True)
|
||||
inp = torch.randn(10, 3, device=torch.cuda.current_device(), requires_grad=True)
|
||||
|
||||
global _save_on_cpu_called
|
||||
models = [ckpt_sequential_wrapped_fsdp, inner_ckpt, baseline]
|
||||
with patch_save_on_cpu(get_patched_save_on_cpu()):
|
||||
|
|
@ -178,7 +189,9 @@ class TestFSDPCheckpoint(FSDPTest):
|
|||
loss.backward()
|
||||
losses.append(loss)
|
||||
outputs.append(out)
|
||||
|
||||
self._verify_parity(losses, outputs, models)
|
||||
|
||||
dist.barrier()
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
|
|
@ -197,7 +210,7 @@ class TestFSDPCheckpoint(FSDPTest):
|
|||
fsdp_kwargs = {"cpu_offload": cpu_offload, "use_orig_params": use_orig_params}
|
||||
global _save_on_cpu_called
|
||||
with patch_save_on_cpu(get_patched_save_on_cpu()):
|
||||
seq = TestFSDPCheckpoint.SequentialModule().to(device_type.type)
|
||||
seq = TestFSDPCheckpoint.SequentialModule().to(torch.cuda.current_device())
|
||||
# Runs FSDP with no checkpointing
|
||||
fsdp_only_seq = FSDP(deepcopy(seq), **fsdp_kwargs)
|
||||
# Runs checkpoint-wrapped FSDP
|
||||
|
|
@ -205,6 +218,7 @@ class TestFSDPCheckpoint(FSDPTest):
|
|||
wrapper_to_use = offload_wrapper
|
||||
else:
|
||||
wrapper_to_use = checkpoint_wrapper
|
||||
|
||||
checkpointed_fsdp = wrapper_to_use(
|
||||
FSDP(deepcopy(seq), **fsdp_kwargs),
|
||||
)
|
||||
|
|
@ -217,7 +231,11 @@ class TestFSDPCheckpoint(FSDPTest):
|
|||
fsdp_call_checkpoint = FSDP(deepcopy(seq), **fsdp_kwargs)
|
||||
# note that reentrant-based checkpointing requires inputs to have grad
|
||||
# flag set.
|
||||
inp = torch.randn(10, 3, device=device_type.type, requires_grad=True)
|
||||
|
||||
inp = torch.randn(
|
||||
10, 3, device=torch.cuda.current_device(), requires_grad=True
|
||||
)
|
||||
|
||||
models = [
|
||||
fsdp_only_seq,
|
||||
checkpointed_fsdp,
|
||||
|
|
@ -247,6 +265,7 @@ class TestFSDPCheckpoint(FSDPTest):
|
|||
# _save_on_cpu should not be called yet
|
||||
self.assertFalse(_save_on_cpu_called)
|
||||
out = m(inp)
|
||||
|
||||
if check_offload:
|
||||
self.assertTrue(_save_on_cpu_called)
|
||||
loss = out.sum()
|
||||
|
|
@ -254,7 +273,9 @@ class TestFSDPCheckpoint(FSDPTest):
|
|||
losses.append(loss)
|
||||
outputs.append(out)
|
||||
_save_on_cpu_called = False
|
||||
|
||||
self._verify_parity(losses, outputs, models)
|
||||
|
||||
dist.barrier()
|
||||
|
||||
|
||||
|
|
@ -306,27 +327,35 @@ class TestFSDPCheckpointSubmodule(FSDPTest):
|
|||
# TODO: grad value checks occasionally fails when use_reentrant = True
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@parametrize("use_reentrant", [False])
|
||||
def test_checkpoint_submodule(self, device, use_reentrant: bool):
|
||||
model = TestModel(use_reentrant=use_reentrant).to(device_type.type)
|
||||
def test_checkpoint_submodule(self, use_reentrant: bool):
|
||||
model = TestModel(use_reentrant=use_reentrant).cuda()
|
||||
model_ac = deepcopy(model)
|
||||
|
||||
for _, m in model_ac.named_modules():
|
||||
if isinstance(m, CheckpointModule):
|
||||
m.checkpoint = True
|
||||
|
||||
self.assertTrue(model_ac.checkpoint1.s1.checkpoint)
|
||||
self.assertTrue(model_ac.checkpoint2.s2.checkpoint)
|
||||
|
||||
fsdp_kwargs = {
|
||||
"device_id": device_type.type,
|
||||
"device_id": torch.cuda.current_device(),
|
||||
"sharding_strategy": ShardingStrategy.NO_SHARD,
|
||||
}
|
||||
|
||||
# Wrap no checkpointing model submodules with FSDP
|
||||
model.checkpoint1 = FSDP(module=model.checkpoint1, **fsdp_kwargs)
|
||||
model.checkpoint2 = FSDP(module=model.checkpoint2, **fsdp_kwargs)
|
||||
|
||||
# Wrap checkpointing model submodules with FSDP
|
||||
model_ac.checkpoint1 = FSDP(module=model_ac.checkpoint1, **fsdp_kwargs)
|
||||
model_ac.checkpoint2 = FSDP(module=model_ac.checkpoint2, **fsdp_kwargs)
|
||||
x = torch.randn(2, 100, device=self.device_type)
|
||||
|
||||
x = torch.randn(2, 100, device="cuda")
|
||||
|
||||
model(x).sum().backward()
|
||||
model_ac(x).sum().backward()
|
||||
|
||||
for (n1, p1), (n2, p2) in zip(
|
||||
model.named_parameters(), model_ac.named_parameters()
|
||||
):
|
||||
|
|
@ -334,7 +363,8 @@ class TestFSDPCheckpointSubmodule(FSDPTest):
|
|||
self.assertTrue(p1.grad.allclose(p2.grad))
|
||||
|
||||
|
||||
devices = ("cuda", "hpu")
|
||||
instantiate_device_type_tests(TestFSDPCheckpointSubmodule, globals(), only_for=devices)
|
||||
instantiate_parametrized_tests(TestFSDPCheckpointSubmodule)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import itertools
|
||||
import sys
|
||||
from typing import Union
|
||||
|
|
@ -15,24 +16,25 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import (
|
|||
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
|
||||
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_fsdp import (
|
||||
DEVICEInitMode,
|
||||
FSDPInitMode,
|
||||
FSDPTest,
|
||||
get_devtype,
|
||||
NestedWrappedModule,
|
||||
TransformerWithSharedParams,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
run_tests,
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
)
|
||||
|
||||
|
||||
device_type = torch.device(get_devtype())
|
||||
|
||||
if not dist.is_available():
|
||||
print("Distributed not available, skipping tests", file=sys.stderr)
|
||||
sys.exit(0)
|
||||
|
||||
if TEST_WITH_DEV_DBG_ASAN:
|
||||
print(
|
||||
"Skip dev-asan as torch + multiprocessing spawn have known issues",
|
||||
|
|
@ -45,7 +47,7 @@ class TestClipGradNorm(FSDPTest):
|
|||
"""Tests :meth:`FullyShardedDataParallel.clip_grad_norm_`."""
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_non_root(self, device):
|
||||
def test_non_root(self):
|
||||
"""
|
||||
Tests that calling ``clip_grad_norm_()`` on a non-root FSDP instance
|
||||
raises an error.
|
||||
|
|
@ -60,24 +62,22 @@ class TestClipGradNorm(FSDPTest):
|
|||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.lin2(self.lin1(x))
|
||||
|
||||
model = Model().to(device_type.type)
|
||||
model = Model().cuda()
|
||||
model.lin2 = FSDP(model.lin2)
|
||||
fsdp_model = FSDP(model)
|
||||
# fsdp_model(torch.randn((2, 5), device=torch.device(self.device_type))).sum().backward()
|
||||
fsdp_model(torch.randn((2, 5), device=device_type)).sum().backward()
|
||||
fsdp_model(torch.randn((2, 5), device=torch.device("cuda"))).sum().backward()
|
||||
error_regex = "should only be called on the root FSDP instance"
|
||||
with self.assertRaisesRegex(RuntimeError, error_regex):
|
||||
fsdp_model.lin2.clip_grad_norm_(max_norm=2)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_ddp_parity(self, device):
|
||||
def test_ddp_parity(self):
|
||||
"""
|
||||
Tests FSDP with ``FullyShardedDataParallel.clip_grad_norm_()` against
|
||||
DDP with ``torch.nn.utils.clip_grad_norm_()` when using full precision.
|
||||
"""
|
||||
self.run_subtests(
|
||||
{
|
||||
"device": [device],
|
||||
"max_norm": [1, 2.5],
|
||||
"norm_type": [1, 2, float("inf")],
|
||||
"sharding_strategy": [
|
||||
|
|
@ -93,7 +93,6 @@ class TestClipGradNorm(FSDPTest):
|
|||
|
||||
def _test_ddp_parity(
|
||||
self,
|
||||
device,
|
||||
max_norm: Union[float, int],
|
||||
norm_type: Union[float, int],
|
||||
sharding_strategy: Union[ShardingStrategy, str],
|
||||
|
|
@ -106,11 +105,10 @@ class TestClipGradNorm(FSDPTest):
|
|||
DEVICEInitMode.DEVICE_BEFORE,
|
||||
deterministic=True,
|
||||
)
|
||||
ddp_model = DDP(local_model, device_ids=[device_type])
|
||||
ddp_model = DDP(local_model, device_ids=[self.rank])
|
||||
fsdp_kwargs = {
|
||||
"cpu_offload": CPUOffload(offload_params=offload_params),
|
||||
"use_orig_params": use_orig_params,
|
||||
"device_id": device_type.type,
|
||||
}
|
||||
if sharding_strategy == "mixed_strategy":
|
||||
fsdp_model = TransformerWithSharedParams.init(
|
||||
|
|
@ -158,7 +156,7 @@ class TestClipGradNorm(FSDPTest):
|
|||
LR = 1e-2
|
||||
ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR)
|
||||
fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=LR)
|
||||
device = torch.device(self.device_type)
|
||||
device = torch.device("cuda")
|
||||
LARGE_FACTOR = 100
|
||||
inp = ddp_model.module.get_input(device)
|
||||
for model in (ddp_model, fsdp_model):
|
||||
|
|
@ -168,6 +166,7 @@ class TestClipGradNorm(FSDPTest):
|
|||
else:
|
||||
loss = model.get_loss(inp, out)
|
||||
loss.backward()
|
||||
|
||||
# Multiply gradients by a large factor to ensure that gradients will
|
||||
# actually be clipped
|
||||
for param in itertools.chain(ddp_model.parameters(), fsdp_model.parameters()):
|
||||
|
|
@ -182,6 +181,7 @@ class TestClipGradNorm(FSDPTest):
|
|||
param.grad.detach().clone() if param.grad is not None else None
|
||||
for param in fsdp_model.parameters()
|
||||
]
|
||||
|
||||
ddp_total_norm = torch.nn.utils.clip_grad_norm_(
|
||||
ddp_model.parameters(),
|
||||
max_norm=max_norm,
|
||||
|
|
@ -191,6 +191,7 @@ class TestClipGradNorm(FSDPTest):
|
|||
max_norm=max_norm, norm_type=norm_type
|
||||
)
|
||||
self.assertEqual(ddp_total_norm, fsdp_total_norm)
|
||||
|
||||
# Check that the gradients were modified by `clip_grad_norm_()`
|
||||
for param, orig_grad in zip(ddp_model.parameters(), orig_ddp_grads):
|
||||
assert not torch.equal(param.grad, orig_grad)
|
||||
|
|
@ -199,6 +200,7 @@ class TestClipGradNorm(FSDPTest):
|
|||
self.assertEqual(param.grad, orig_grad) # `None`
|
||||
else:
|
||||
assert not torch.equal(param.grad, orig_grad)
|
||||
|
||||
# Run an optimizer step to ensure gradients matched after clipping
|
||||
ddp_optim.step()
|
||||
fsdp_optim.step()
|
||||
|
|
@ -209,11 +211,13 @@ class TestClipGradNorm(FSDPTest):
|
|||
):
|
||||
self.assertEqual(n1, n2)
|
||||
self.assertEqual(p1, p2)
|
||||
|
||||
if offload_params:
|
||||
# TODO: Gradient computation on CPU and GPU differ slightly causing
|
||||
# drift unrelated to `clip_grad_norm_()`.
|
||||
# https://github.com/pytorch/pytorch/issues/89133
|
||||
return
|
||||
|
||||
# Run a few more iterations
|
||||
# TODO: We cannot run too many iterations, or else there is drift:
|
||||
# https://github.com/pytorch/pytorch/issues/89136
|
||||
|
|
@ -238,11 +242,10 @@ class TestClipGradNorm(FSDPTest):
|
|||
fsdp_optim.step()
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_low_precision_grads(self, device):
|
||||
def test_low_precision_grads(self):
|
||||
"""Tests ``clip_grad_norm_()`` when using low precision gradients."""
|
||||
self.run_subtests(
|
||||
{
|
||||
"device": [device],
|
||||
"max_norm": [1, 2.5],
|
||||
"norm_type": [1, 2, float("inf")],
|
||||
"sharding_strategy": [
|
||||
|
|
@ -256,7 +259,6 @@ class TestClipGradNorm(FSDPTest):
|
|||
|
||||
def _test_low_precision_grads(
|
||||
self,
|
||||
device,
|
||||
max_norm: Union[float, int],
|
||||
norm_type: Union[float, int],
|
||||
sharding_strategy: ShardingStrategy,
|
||||
|
|
@ -270,7 +272,6 @@ class TestClipGradNorm(FSDPTest):
|
|||
reduce_dtype=torch.float16,
|
||||
keep_low_precision_grads=True,
|
||||
),
|
||||
"device_id": device_type.type,
|
||||
}
|
||||
fsdp_model = FSDP(
|
||||
NestedWrappedModule.init(
|
||||
|
|
@ -282,7 +283,7 @@ class TestClipGradNorm(FSDPTest):
|
|||
),
|
||||
**fsdp_kwargs,
|
||||
)
|
||||
inp = fsdp_model.module.get_input(torch.device(self.device_type))
|
||||
inp = fsdp_model.module.get_input(torch.device("cuda"))
|
||||
out = fsdp_model(*inp)
|
||||
out.sum().backward()
|
||||
for param in fsdp_model.parameters():
|
||||
|
|
@ -301,17 +302,17 @@ class TestClipGradNorm(FSDPTest):
|
|||
)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_no_gradients(self, device):
|
||||
def test_no_gradients(self):
|
||||
"""
|
||||
Tests that calling ``clip_grad_norm_()`` when the FDSP module has no
|
||||
gradients simply returns a scalar zero tensor in FP32 without erroring.
|
||||
"""
|
||||
self.run_subtests(
|
||||
{"device": [device], "use_orig_params": [False, True]},
|
||||
{"use_orig_params": [False, True]},
|
||||
self._test_no_gradients,
|
||||
)
|
||||
|
||||
def _test_no_gradients(self, device, use_orig_params: bool):
|
||||
def _test_no_gradients(self, use_orig_params: bool):
|
||||
lin_module = nn.Linear(24, 24)
|
||||
mixed_precision_config = MixedPrecision(
|
||||
param_dtype=torch.float16,
|
||||
|
|
@ -322,10 +323,10 @@ class TestClipGradNorm(FSDPTest):
|
|||
lin_module,
|
||||
sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
|
||||
mixed_precision=mixed_precision_config,
|
||||
device_id=device_type.type,
|
||||
device_id=self.rank,
|
||||
use_orig_params=use_orig_params,
|
||||
)
|
||||
inp = torch.randn(32, 24, device=self.device_type)
|
||||
inp = torch.randn(32, 24, device="cuda")
|
||||
fsdp_module(inp)
|
||||
with self.assertWarnsRegex(
|
||||
expected_warning=UserWarning,
|
||||
|
|
@ -335,10 +336,10 @@ class TestClipGradNorm(FSDPTest):
|
|||
):
|
||||
total_norm = fsdp_module.clip_grad_norm_(1)
|
||||
self.assertEqual(total_norm.dtype, torch.float32)
|
||||
self.assertEqual(total_norm, torch.tensor(0.0, device=self.device_type))
|
||||
self.assertEqual(total_norm, torch.tensor(0.0, device="cuda"))
|
||||
|
||||
|
||||
devices = ("cuda", "hpu")
|
||||
instantiate_device_type_tests(TestClipGradNorm, globals(), only_for=devices)
|
||||
instantiate_parametrized_tests(TestClipGradNorm)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import sys
|
||||
from contextlib import nullcontext
|
||||
from enum import auto, Enum
|
||||
|
|
@ -9,34 +10,31 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import distributed as dist
|
||||
from torch._utils import _get_device_module
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
|
||||
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_fsdp import (
|
||||
DEVICEInitMode,
|
||||
FSDPInitMode,
|
||||
FSDPTest,
|
||||
get_devtype,
|
||||
MLP,
|
||||
NestedWrappedModule,
|
||||
TransformerWithSharedParams,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
)
|
||||
|
||||
|
||||
device_type = torch.device(get_devtype())
|
||||
|
||||
if not dist.is_available():
|
||||
print("Distributed not available, skipping tests", file=sys.stderr)
|
||||
sys.exit(0)
|
||||
|
||||
if TEST_WITH_DEV_DBG_ASAN:
|
||||
print(
|
||||
"Skip dev-asan as torch + multiprocessing spawn have known issues",
|
||||
|
|
@ -56,14 +54,11 @@ class TestCommunication(FSDPTest):
|
|||
|
||||
def _init_model(
|
||||
self,
|
||||
device,
|
||||
nested_model: bool,
|
||||
sharding_strategy: ShardingStrategy,
|
||||
device: torch.device,
|
||||
):
|
||||
fsdp_kwargs = {
|
||||
"sharding_strategy": sharding_strategy,
|
||||
"device_id": device_type.type,
|
||||
}
|
||||
fsdp_kwargs = {"sharding_strategy": sharding_strategy}
|
||||
if nested_model:
|
||||
model = NestedWrappedModule.init(
|
||||
self.process_group,
|
||||
|
|
@ -75,7 +70,7 @@ class TestCommunication(FSDPTest):
|
|||
model,
|
||||
self.process_group,
|
||||
**fsdp_kwargs,
|
||||
)
|
||||
).to(device)
|
||||
else:
|
||||
fsdp_model: FSDP = TransformerWithSharedParams.init(
|
||||
self.process_group,
|
||||
|
|
@ -208,7 +203,6 @@ class TestCommunication(FSDPTest):
|
|||
@parametrize("sharding_strategy", [ShardingStrategy.SHARD_GRAD_OP, None])
|
||||
def test_communication(
|
||||
self,
|
||||
device,
|
||||
nested_model: bool,
|
||||
use_no_sync: bool,
|
||||
sharding_strategy: Optional[ShardingStrategy],
|
||||
|
|
@ -216,6 +210,7 @@ class TestCommunication(FSDPTest):
|
|||
"""
|
||||
Tests FSDP's communication cost in terms of calls to collective
|
||||
communication primitives (i.e. all-gather and reduce-scatter).
|
||||
|
||||
Arguments:
|
||||
nested_model (bool): If ``True``, uses ``NestedWrappedModule``,
|
||||
which has nested FSDP instances; if ``False``, uses the default
|
||||
|
|
@ -230,13 +225,16 @@ class TestCommunication(FSDPTest):
|
|||
# Enable execution order checking
|
||||
dist.set_debug_level(dist.DebugLevel.DETAIL)
|
||||
# Initialize the model and inputs
|
||||
fsdp_model = self._init_model(device_type, nested_model, sharding_strategy)
|
||||
batch = fsdp_model.module.get_input(device_type)
|
||||
device = torch.device("cuda")
|
||||
fsdp_model = self._init_model(nested_model, sharding_strategy, device)
|
||||
batch = fsdp_model.module.get_input(device)
|
||||
|
||||
# Count the number of FSDP instances that manage parameters since the
|
||||
# number of collectives are a function of this number
|
||||
num_fsdp = sum(
|
||||
(isinstance(m, FSDP) and len(m.params) > 0) for m in fsdp_model.modules()
|
||||
)
|
||||
|
||||
# If `use_no_sync=True`, we run `num_iters` iterations inside
|
||||
# `no_sync()` followed by `num_iters` iterations outside `no_sync()`,
|
||||
# and if `use_no_sync=False`, we only run `num_iters` iterations
|
||||
|
|
@ -294,11 +292,11 @@ class TestCommunication(FSDPTest):
|
|||
class TestExplicitUnshard(FSDPTest):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return min(_get_device_module(self.device_type).device_count(), 2)
|
||||
return min(torch.cuda.device_count(), 2)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@parametrize("use_orig_params", [False, True])
|
||||
def test_unshard_async(self, device, use_orig_params: bool):
|
||||
def test_unshard_async(self, use_orig_params: bool):
|
||||
class ReduceModule(nn.Module):
|
||||
def __init__(self, dim: int, group: dist.ProcessGroup):
|
||||
super().__init__()
|
||||
|
|
@ -352,25 +350,28 @@ class TestExplicitUnshard(FSDPTest):
|
|||
group = self.process_group
|
||||
batch_size, dim = 2, 8
|
||||
torch.manual_seed(42)
|
||||
ref_model = DDP(ReduceModel(dim, group).to(device_type), device_ids=[self.rank])
|
||||
ref_model = DDP(ReduceModel(dim, group).cuda(), device_ids=[self.rank])
|
||||
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
|
||||
|
||||
torch.manual_seed(42)
|
||||
model = ReduceModel(dim, group)
|
||||
model.mlps = FSDP(
|
||||
model.mlps,
|
||||
sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
|
||||
auto_wrap_policy=ModuleWrapPolicy((MLP,)),
|
||||
device_id=device_type.type,
|
||||
device_id=self.rank,
|
||||
use_orig_params=use_orig_params,
|
||||
)
|
||||
model.mlps.check_is_root()
|
||||
mlp_params = set(model.mlps.parameters())
|
||||
mlp_param_names = {n for n, p in model.named_parameters() if p in mlp_params}
|
||||
DDP._set_params_and_buffers_to_ignore_for_model(model, mlp_param_names)
|
||||
model = DDP(model.to(device_type), device_ids=[self.rank])
|
||||
model = DDP(model.cuda(), device_ids=[self.rank])
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
|
||||
|
||||
torch.manual_seed(42 + self.rank + 1)
|
||||
inp = torch.randn((batch_size, dim), device=device_type)
|
||||
inp = torch.randn((batch_size, dim), device="cuda")
|
||||
|
||||
for _ in range(10):
|
||||
losses: List[torch.Tensor] = []
|
||||
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
|
||||
|
|
@ -382,8 +383,8 @@ class TestExplicitUnshard(FSDPTest):
|
|||
model.module.mlps._wait_unshard_streams_on_current_stream()
|
||||
|
||||
|
||||
devices = ("cuda", "hpu")
|
||||
instantiate_device_type_tests(TestCommunication, globals(), only_for=devices)
|
||||
instantiate_device_type_tests(TestExplicitUnshard, globals(), only_for=devices)
|
||||
instantiate_parametrized_tests(TestCommunication)
|
||||
instantiate_parametrized_tests(TestExplicitUnshard)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import itertools
|
||||
import sys
|
||||
import unittest
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from unittest import mock
|
||||
|
||||
|
|
@ -19,7 +19,6 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import (
|
|||
)
|
||||
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
|
||||
from torch.distributed.utils import _p_assert
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_fsdp import (
|
||||
AlwaysWrapNestedWrappedModule,
|
||||
|
|
@ -27,7 +26,6 @@ from torch.testing._internal.common_fsdp import (
|
|||
DummyDDP,
|
||||
FSDPInitMode,
|
||||
FSDPTest,
|
||||
get_devtype,
|
||||
MixtureOfExperts,
|
||||
NestedWrappedModule,
|
||||
NestedWrappedModuleWithDelay,
|
||||
|
|
@ -35,9 +33,9 @@ from torch.testing._internal.common_fsdp import (
|
|||
TransformerWithSharedParams,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
TEST_HPU,
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
)
|
||||
|
||||
|
|
@ -45,6 +43,7 @@ from torch.testing._internal.common_utils import (
|
|||
if not dist.is_available():
|
||||
print("Distributed not available, skipping tests", file=sys.stderr)
|
||||
sys.exit(0)
|
||||
|
||||
if TEST_WITH_DEV_DBG_ASAN:
|
||||
print(
|
||||
"Skip dev-asan as torch + multiprocessing spawn have known issues",
|
||||
|
|
@ -52,7 +51,6 @@ if TEST_WITH_DEV_DBG_ASAN:
|
|||
)
|
||||
sys.exit(0)
|
||||
|
||||
device_type = torch.device(get_devtype())
|
||||
params = "cpu_offload,sharding_strategy"
|
||||
cpu_offload_config = [CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
|
||||
sharding_strategy_config = [
|
||||
|
|
@ -67,6 +65,7 @@ test_name_mapping = {
|
|||
str(ShardingStrategy.SHARD_GRAD_OP): "shard_grad_op",
|
||||
str(ShardingStrategy.NO_SHARD): "no_shard",
|
||||
}
|
||||
|
||||
subtest_name = functools.partial(subtest_name, test_name_mapping)
|
||||
|
||||
|
||||
|
|
@ -87,6 +86,7 @@ class TestParityWithDDP(FSDPTest):
|
|||
# on CPU but NCCL only supports GPU.
|
||||
if cpu_offload.offload_params:
|
||||
modes.append(DEVICEInitMode.DEVICE_NEVER)
|
||||
|
||||
return modes
|
||||
|
||||
def _get_subtest_config(self, cpu_offload: CPUOffload) -> Dict[str, List[Any]]:
|
||||
|
|
@ -229,7 +229,6 @@ class TestParityWithDDP(FSDPTest):
|
|||
cpu_offload: CPUOffload,
|
||||
sharding_strategy: Optional[ShardingStrategy],
|
||||
):
|
||||
fsdp_kwargs = {"device_id": device_type.type}
|
||||
self.run_subtests(
|
||||
self._get_subtest_config(cpu_offload),
|
||||
self._test_fsdp_parity,
|
||||
|
|
@ -238,12 +237,8 @@ class TestParityWithDDP(FSDPTest):
|
|||
ref_init_fn=self._dummy_ddp_fn,
|
||||
cpu_offload=cpu_offload,
|
||||
sharding_strategy=sharding_strategy,
|
||||
**fsdp_kwargs,
|
||||
)
|
||||
|
||||
@unittest.skipIf(
|
||||
TEST_HPU, "HPU doesn't has HW sleep API support (like CUDA), skipping"
|
||||
)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@parametrize(params, configs, subtest_name)
|
||||
def test_mixture_of_experts_with_delay_before_free(
|
||||
|
|
@ -251,7 +246,6 @@ class TestParityWithDDP(FSDPTest):
|
|||
cpu_offload: CPUOffload,
|
||||
sharding_strategy: Optional[ShardingStrategy],
|
||||
):
|
||||
fsdp_kwargs = {"device_id": device_type.type}
|
||||
self.run_subtests(
|
||||
self._get_subtest_config(cpu_offload),
|
||||
self._test_fsdp_parity,
|
||||
|
|
@ -261,7 +255,6 @@ class TestParityWithDDP(FSDPTest):
|
|||
cpu_offload=cpu_offload,
|
||||
sharding_strategy=sharding_strategy,
|
||||
init_kwargs={"delay_before_free_ms": 250},
|
||||
**fsdp_kwargs,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -274,7 +267,7 @@ class TestParamInit(FSDPTest):
|
|||
initialization persist.
|
||||
"""
|
||||
# Establish reference behavior
|
||||
fsdp_kwargs = {"device_id": device_type}
|
||||
fsdp_kwargs = {}
|
||||
if mixed_precision:
|
||||
fsdp_kwargs["mixed_precision"] = MixedPrecision()
|
||||
fsdp_model = TransformerWithSharedParams.init(
|
||||
|
|
@ -284,7 +277,7 @@ class TestParamInit(FSDPTest):
|
|||
fsdp_kwargs,
|
||||
deterministic=True,
|
||||
)
|
||||
input = fsdp_model.module.get_input(device_type)
|
||||
input = fsdp_model.module.get_input(torch.device("cuda"))
|
||||
ref_output = fsdp_model(*input)
|
||||
# Initialize the same model but change its first parameter value
|
||||
# in-place after FSDP initialization
|
||||
|
|
@ -311,12 +304,10 @@ class TestHooks(FSDPTest):
|
|||
def test_pre_backward_hook_registration(self, cuda_first: bool):
|
||||
"""Tests that FSDP pre-backward hooks are registered on forward pass
|
||||
outputs."""
|
||||
fsdp_kwargs = {"device_id": device_type.type}
|
||||
fsdp_model = TransformerWithSharedParams.init(
|
||||
self.process_group,
|
||||
FSDPInitMode.RECURSIVE,
|
||||
DEVICEInitMode.DEVICE_BEFORE if cuda_first else DEVICEInitMode.DEVICE_AFTER,
|
||||
fsdp_kwargs,
|
||||
)
|
||||
self._test_pre_backward_hook_registration(fsdp_model)
|
||||
|
||||
|
|
@ -324,12 +315,10 @@ class TestHooks(FSDPTest):
|
|||
def test_pre_backward_hook_registration_after_state_dict(self):
|
||||
"""Tests that FSDP pre-backward hooks are registered on forward pass
|
||||
outputs after saving and loading the model from a checkpoint."""
|
||||
fsdp_kwargs = {"device_id": device_type.type}
|
||||
fsdp_model = TransformerWithSharedParams.init(
|
||||
self.process_group,
|
||||
FSDPInitMode.RECURSIVE,
|
||||
DEVICEInitMode.DEVICE_AFTER,
|
||||
fsdp_kwargs,
|
||||
)
|
||||
self._train_for_several_steps(fsdp_model, num_steps=2, autocast=False)
|
||||
state_dict = fsdp_model.state_dict()
|
||||
|
|
@ -340,11 +329,11 @@ class TestHooks(FSDPTest):
|
|||
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
|
||||
optim.zero_grad()
|
||||
# Inputs always cuda, as computation happens on CUDA device only
|
||||
input = model.module.get_input(device_type)
|
||||
input = model.module.get_input(torch.device("cuda"))
|
||||
output = model(*input)
|
||||
# this is pre-bwd hook
|
||||
self.assertEqual(len(output._backward_hooks), 1)
|
||||
loss = model.module.get_loss(input, output).to(device_type.type)
|
||||
loss = model.module.get_loss(input, output).cuda()
|
||||
loss.backward()
|
||||
# It doesn't get removed
|
||||
self.assertEqual(len(output._backward_hooks), 1)
|
||||
|
|
@ -357,7 +346,7 @@ class TestHooks(FSDPTest):
|
|||
def test_register_functions_called(self, cuda_first: bool, mixed_precision: bool):
|
||||
"""Tests that ``_register_{pre|post}_backward_hooks()`` are called
|
||||
during the FSDP forward."""
|
||||
fsdp_kwargs = {"device_id": device_type.type}
|
||||
fsdp_kwargs = {}
|
||||
if mixed_precision:
|
||||
fsdp_kwargs["mixed_precision"] = MixedPrecision()
|
||||
fsdp_model = TransformerWithSharedParams.init(
|
||||
|
|
@ -366,7 +355,8 @@ class TestHooks(FSDPTest):
|
|||
DEVICEInitMode.DEVICE_BEFORE if cuda_first else DEVICEInitMode.DEVICE_AFTER,
|
||||
fsdp_kwargs,
|
||||
)
|
||||
input = fsdp_model.module.get_input(device_type)
|
||||
input = fsdp_model.module.get_input(torch.device("cuda"))
|
||||
|
||||
# Since `_register_pre_backward_hooks()` modifies the forward output,
|
||||
# we cannot directly mock it. We implement our own counter instead.
|
||||
orig_register_pre_backward_hooks = (
|
||||
|
|
@ -400,7 +390,7 @@ class TestNoGrad(FSDPTest):
|
|||
parameters, after training for one iteration, running a forward pass in
|
||||
``eval()`` mode gives the same output as running a forward pass in
|
||||
``torch.no_grad()``."""
|
||||
fsdp_kwargs = {"device_id": device_type.type}
|
||||
fsdp_kwargs = {}
|
||||
if mixed_precision:
|
||||
fsdp_kwargs["mixed_precision"] = MixedPrecision(
|
||||
param_dtype=torch.float16,
|
||||
|
|
@ -421,7 +411,7 @@ class TestNoGrad(FSDPTest):
|
|||
autocast=False,
|
||||
mixed_precision=fsdp_kwargs["mixed_precision"],
|
||||
)
|
||||
input = fsdp_model.module.get_input(device_type)
|
||||
input = fsdp_model.module.get_input(torch.device("cuda"))
|
||||
# Run a forward in eval mode
|
||||
fsdp_model.eval()
|
||||
ref_output = fsdp_model(*input)
|
||||
|
|
@ -445,7 +435,7 @@ class TestAutograd(FSDPTest):
|
|||
{
|
||||
"sharding_strategy": [
|
||||
ShardingStrategy.FULL_SHARD,
|
||||
ShardingStrategy.SHARD_GRAD_OP,
|
||||
ShardingStrategy.SHARD_GRAD_OP
|
||||
# Skip testing `NO_SHARD` since it doubly uses
|
||||
# `_use_unsharded_views()` for sharded views. Testing
|
||||
# `FULL_SHARD` and `SHARD_GRAD_OP` provides good confidence
|
||||
|
|
@ -485,17 +475,17 @@ class TestAutograd(FSDPTest):
|
|||
"forward_prefetch": forward_prefetch,
|
||||
"backward_prefetch": backward_prefetch,
|
||||
"auto_wrap_policy": ModuleWrapPolicy({nn.Linear}),
|
||||
"device_id": device_type,
|
||||
}
|
||||
device = torch.device("cuda")
|
||||
# Define a model with enough FSDP instances to exercise prefetching
|
||||
NUM_LINEARS = 5
|
||||
model = nn.Sequential(
|
||||
*[nn.Linear(3, 3, device=device_type) for _ in range(NUM_LINEARS)]
|
||||
*[nn.Linear(3, 3, device=device) for _ in range(NUM_LINEARS)]
|
||||
)
|
||||
fsdp_model = FSDP(model, **fsdp_kwargs)
|
||||
self.assertEqual(len(list(FSDP.fsdp_modules(fsdp_model))), NUM_LINEARS + 1)
|
||||
for _ in range(3):
|
||||
inp = torch.randn((2, 3), device=device_type)
|
||||
inp = torch.randn((2, 3), device=device)
|
||||
with self._patch_use_unsharded_views(
|
||||
_use_unsharded_views_assert_as_tensors
|
||||
):
|
||||
|
|
@ -512,11 +502,10 @@ class TestAutograd(FSDPTest):
|
|||
FlatParamHandle._use_unsharded_views = orig_use_unsharded_views
|
||||
|
||||
|
||||
devices = ("cuda", "hpu")
|
||||
instantiate_device_type_tests(TestHooks, globals(), only_for=devices)
|
||||
instantiate_device_type_tests(TestParityWithDDP, globals(), only_for=devices)
|
||||
instantiate_device_type_tests(TestNoGrad, globals(), only_for=devices)
|
||||
instantiate_device_type_tests(TestParamInit, globals(), only_for=devices)
|
||||
instantiate_device_type_tests(TestAutograd, globals(), only_for=devices)
|
||||
instantiate_parametrized_tests(TestHooks)
|
||||
instantiate_parametrized_tests(TestParityWithDDP)
|
||||
instantiate_parametrized_tests(TestNoGrad)
|
||||
instantiate_parametrized_tests(TestParamInit)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import io
|
||||
from copy import deepcopy
|
||||
|
||||
|
|
@ -13,9 +14,11 @@ from torch.distributed.fsdp.api import (
|
|||
ShardedStateDictConfig,
|
||||
StateDictType,
|
||||
)
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_fsdp import get_devtype
|
||||
from torch.testing._internal.common_utils import parametrize, run_tests
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
)
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
DTensorTestBase,
|
||||
skip_if_lt_x_gpu,
|
||||
|
|
@ -23,13 +26,10 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
|
|||
)
|
||||
|
||||
|
||||
device_type = torch.device(get_devtype())
|
||||
|
||||
|
||||
# Simple and boring model to test interface and some corner cases that do not
|
||||
# require complicated wrapping strategy.
|
||||
class TestDummyModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
torch.manual_seed(0)
|
||||
self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
|
||||
|
|
@ -41,11 +41,11 @@ class TestDummyModel(torch.nn.Module):
|
|||
return self.net4(self.net3(self.net2(self.net1(x))))
|
||||
|
||||
def get_input(self):
|
||||
return torch.rand(8, 8, device=device_type.type)
|
||||
return torch.rand(8, 8, device="cuda")
|
||||
|
||||
|
||||
class TestDummyModelUneven(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
torch.manual_seed(0)
|
||||
self.net1 = nn.Sequential(nn.Linear(5, 10), nn.ReLU())
|
||||
|
|
@ -57,7 +57,7 @@ class TestDummyModelUneven(torch.nn.Module):
|
|||
return self.net4(self.net3(self.net2(self.net1(x))))
|
||||
|
||||
def get_input(self):
|
||||
return torch.rand(5, 5, device=device_type.type)
|
||||
return torch.rand(5, 5, device="cuda")
|
||||
|
||||
|
||||
class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
||||
|
|
@ -65,29 +65,34 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
|||
dummy_model = (
|
||||
TestDummyModel() if is_even_sharded_model else TestDummyModelUneven()
|
||||
)
|
||||
model = FSDP(dummy_model.to(device_type), device_mesh=device_mesh)
|
||||
|
||||
model = FSDP(dummy_model.cuda(), device_mesh=device_mesh)
|
||||
optim = torch.optim.Adam(model.parameters(), lr=0.1)
|
||||
model(model.get_input()).sum().backward()
|
||||
optim.step()
|
||||
|
||||
return model, optim
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@parametrize("is_even_sharded_model", [True, False])
|
||||
def test_fsdp_init_with_device_mesh(self, is_even_sharded_model):
|
||||
device_mesh = init_device_mesh(device_type.type, (self.world_size,))
|
||||
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
|
||||
model, optim = self._create_model(is_even_sharded_model, device_mesh)
|
||||
|
||||
FSDP.set_state_dict_type(
|
||||
model,
|
||||
StateDictType.SHARDED_STATE_DICT,
|
||||
)
|
||||
state_dict = model.state_dict()
|
||||
optim_state_dict = FSDP.optim_state_dict(model, optim)
|
||||
|
||||
for v in state_dict.values():
|
||||
self.assertEqual(type(v), DTensor)
|
||||
self.assertEqual(len(v.placements), 1)
|
||||
self.assertEqual(v.placements[0], (Shard(dim=0)))
|
||||
self.assertEqual(v.device_mesh, device_mesh)
|
||||
|
||||
for state in optim_state_dict["state"].values():
|
||||
for k, v in state.items():
|
||||
if k != "step":
|
||||
|
|
@ -95,6 +100,7 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
|||
self.assertEqual(len(v.placements), 1)
|
||||
self.assertEqual(v.placements[0], (Shard(dim=0)))
|
||||
self.assertEqual(v.device_mesh, device_mesh)
|
||||
|
||||
state_dict_type = FSDP.get_state_dict_type(model)
|
||||
# If device_mesh is used when initializing FSDP, the field _use_dtensor will
|
||||
# automatically be set to True if StateDictType is set to SHARDED_STATE_DICT.
|
||||
|
|
@ -108,8 +114,9 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
|||
def test_dtensor_sharded_tensor_state_dict_identical(
|
||||
self, offload_to_cpu, is_even_sharded_model
|
||||
):
|
||||
device_mesh = init_device_mesh(device_type.type, (self.world_size,))
|
||||
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
|
||||
model, optim = self._create_model(is_even_sharded_model, device_mesh)
|
||||
|
||||
FSDP.set_state_dict_type(
|
||||
model,
|
||||
StateDictType.SHARDED_STATE_DICT,
|
||||
|
|
@ -120,6 +127,7 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
|||
)
|
||||
dtensor_sd = model.state_dict()
|
||||
dtensor_osd = FSDP.optim_state_dict(model, optim)
|
||||
|
||||
ref_model, ref_optim = self._create_model(is_even_sharded_model)
|
||||
FSDP.set_state_dict_type(
|
||||
ref_model,
|
||||
|
|
@ -131,6 +139,7 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
|||
)
|
||||
sharded_tensor_sd = ref_model.state_dict()
|
||||
sharded_tensor_osd = FSDP.optim_state_dict(ref_model, ref_optim)
|
||||
|
||||
# Check dtensor and sharded_tensor model state dict values are identical
|
||||
for dtensor_sd_item, sharded_tensor_sd_item in zip(
|
||||
dtensor_sd.items(), sharded_tensor_sd.items()
|
||||
|
|
@ -138,6 +147,7 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
|||
k1, v1 = dtensor_sd_item
|
||||
k2, v2 = sharded_tensor_sd_item
|
||||
self.assertEqual(k1, k2)
|
||||
|
||||
# if the ShardedTensor is an empty shard,
|
||||
# then the local tensor of DTensor should be local_tensor=tensor([])
|
||||
if len(v2.local_shards()) == 0:
|
||||
|
|
@ -149,6 +159,7 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
|||
self.assertEqual(v1.to_local(), v2.local_tensor())
|
||||
# check whether device are the same
|
||||
self.assertEqual(v1.to_local().device, v2.local_tensor().device)
|
||||
|
||||
# Check dtensor and sharde_tensor optim state dict values are identical
|
||||
for dtensor_osd_state, sharded_tensor_osd_state in zip(
|
||||
dtensor_osd["state"].items(), sharded_tensor_osd["state"].items()
|
||||
|
|
@ -162,6 +173,7 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
|||
k1, v1 = dtensor_hyper_param
|
||||
k2, v2 = sharded_tensor_hyper_param
|
||||
self.assertEqual(k1, k2)
|
||||
|
||||
if k1 != "step":
|
||||
# if the ShardedTensor is an empty shard,
|
||||
# then the local tensor of DTensor should be local_tensor=tensor([])
|
||||
|
|
@ -184,8 +196,9 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
|||
def test_dtensor_sharded_optim_load_state_dict(
|
||||
self, offload_to_cpu, is_even_sharded_model
|
||||
):
|
||||
device_mesh = init_device_mesh(device_type.type, (self.world_size,))
|
||||
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
|
||||
model, optim = self._create_model(is_even_sharded_model, device_mesh)
|
||||
|
||||
FSDP.set_state_dict_type(
|
||||
model,
|
||||
StateDictType.SHARDED_STATE_DICT,
|
||||
|
|
@ -193,13 +206,16 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
|||
offload_to_cpu=offload_to_cpu
|
||||
),
|
||||
)
|
||||
|
||||
checkpoint = io.BytesIO()
|
||||
torch.save(FSDP.optim_state_dict(model, optim), checkpoint)
|
||||
# Deepcopy to save current optim_state_dict to compare with the optim_state_dict loaded back below.
|
||||
ref_optim_state_dict = deepcopy(FSDP.optim_state_dict(model, optim))
|
||||
|
||||
# Update the parameters so FSDP.optim_state_dict() will be different from ref_optim_state_dict.
|
||||
model(model.get_input()).sum().backward()
|
||||
optim.step()
|
||||
|
||||
# Load ref_optim_state_dict back.
|
||||
checkpoint.seek(0)
|
||||
load_ref_optim_state_dict = torch.load(checkpoint)
|
||||
|
|
@ -207,6 +223,7 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
|||
FSDP.optim_state_dict_to_load(model, optim, load_ref_optim_state_dict)
|
||||
)
|
||||
new_optim_state_dict = FSDP.optim_state_dict(model, optim)
|
||||
|
||||
# Check whether new_optim_state_dict is the same as ref_optim_state_dict.
|
||||
for new_optim_state_dict_item, ref_optim_state_dict_item in zip(
|
||||
new_optim_state_dict["state"].items(),
|
||||
|
|
@ -220,10 +237,12 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
|||
):
|
||||
k1, v1 = new_optim_hyper_param
|
||||
k2, v2 = ref_optim_hyper_param
|
||||
|
||||
# check whether keys are the same
|
||||
self.assertEqual(k1, k2)
|
||||
# check whether values are the same
|
||||
self.assertEqual(v1, v2)
|
||||
|
||||
if k1 != "step":
|
||||
self.assertEqual(type(v1), DTensor)
|
||||
self.assertEqual(type(v2), DTensor)
|
||||
|
|
@ -235,29 +254,35 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
|||
def test_dtensor_sharded_model_load_state_dict(
|
||||
self, offload_to_cpu, is_even_sharded_model
|
||||
):
|
||||
device_mesh = init_device_mesh(device_type.type, (self.world_size,))
|
||||
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
|
||||
model, optim = self._create_model(is_even_sharded_model, device_mesh)
|
||||
|
||||
FSDP.set_state_dict_type(
|
||||
model,
|
||||
StateDictType.SHARDED_STATE_DICT,
|
||||
state_dict_config=ShardedStateDictConfig(offload_to_cpu=offload_to_cpu),
|
||||
)
|
||||
|
||||
checkpoint = io.BytesIO()
|
||||
torch.save(model.state_dict(), checkpoint)
|
||||
# Deepcopy to save current state_dict to compare with the state_dict loaded back below.
|
||||
ref_state_dict = deepcopy(model.state_dict())
|
||||
|
||||
# Update the parameters so model.state_dict() will be different from ref_dtensor_sd.
|
||||
model(model.get_input()).sum().backward()
|
||||
optim.step()
|
||||
|
||||
# Load ref_state_dict back.
|
||||
checkpoint.seek(0)
|
||||
load_ref_state_dict = torch.load(checkpoint)
|
||||
model.load_state_dict(load_ref_state_dict)
|
||||
new_state_dict = model.state_dict()
|
||||
|
||||
# Check whether new_state_dict is the same as ref_state_dict.
|
||||
for (k1, v1), (k2, v2) in zip(ref_state_dict.items(), new_state_dict.items()):
|
||||
# check whether fqn are the same
|
||||
self.assertEqual(k1, k2)
|
||||
|
||||
self.assertEqual(type(v1), DTensor)
|
||||
self.assertEqual(type(v2), DTensor)
|
||||
# check whether DTensor are the same
|
||||
|
|
@ -266,18 +291,20 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
|||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_raises_warning_or_errors(self):
|
||||
device_mesh = init_device_mesh(device_type.type, (self.world_size,))
|
||||
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
|
||||
model, optim = self._create_model(
|
||||
is_even_sharded_model=True, device_mesh=device_mesh
|
||||
)
|
||||
# initialize optim
|
||||
model(model.get_input()).sum().backward()
|
||||
optim.step()
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "DeviceMesh is not compatible with LOCAL_STATE_DICT."
|
||||
):
|
||||
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
|
||||
state_dict = model.state_dict()
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "DeviceMesh is not compatible with LOCAL_STATE_DICT."
|
||||
):
|
||||
|
|
@ -285,9 +312,6 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
|||
optim_state_dict = FSDP.optim_state_dict(model, optim)
|
||||
|
||||
|
||||
devices = ("cuda", "hpu")
|
||||
instantiate_device_type_tests(
|
||||
TestFSDPWithDeviceMeshAndDTensor, globals(), only_for=devices
|
||||
)
|
||||
instantiate_parametrized_tests(TestFSDPWithDeviceMeshAndDTensor)
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -8,18 +8,16 @@ import torch
|
|||
from torch import distributed as dist
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype
|
||||
from torch.testing._internal.common_fsdp import FSDPTest
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
)
|
||||
|
||||
|
||||
device_type = torch.device(get_devtype())
|
||||
|
||||
if not dist.is_available():
|
||||
print("Distributed not available, skipping tests", file=sys.stderr)
|
||||
sys.exit(0)
|
||||
|
|
@ -65,7 +63,7 @@ class Model(torch.nn.Module):
|
|||
)
|
||||
return z
|
||||
|
||||
def get_input(self, device):
|
||||
def get_input(self, device: torch.device):
|
||||
return (torch.randn((8, 5)).to(device),)
|
||||
|
||||
def get_loss(self, input, output):
|
||||
|
|
@ -88,19 +86,22 @@ class Model(torch.nn.Module):
|
|||
self.use_alt_path = not self.use_alt_path
|
||||
|
||||
@staticmethod
|
||||
def wrap(sharding_strategy: ShardingStrategy, device):
|
||||
def wrap(sharding_strategy: ShardingStrategy, device: torch.device):
|
||||
model = Model()
|
||||
model.layer1 = FSDP(
|
||||
model.layer1, sharding_strategy=sharding_strategy, device_id=device
|
||||
)
|
||||
model.layer2 = FSDP(
|
||||
model.layer2, sharding_strategy=sharding_strategy, device_id=device
|
||||
)
|
||||
fsdp_model = FSDP(model, sharding_strategy=sharding_strategy, device_id=device)
|
||||
model.layer1 = FSDP(model.layer1, sharding_strategy=sharding_strategy)
|
||||
model.layer2 = FSDP(model.layer2, sharding_strategy=sharding_strategy)
|
||||
fsdp_model = FSDP(model, sharding_strategy=sharding_strategy)
|
||||
return fsdp_model.to(device)
|
||||
|
||||
|
||||
class TestFSDPExecOrder(FSDPTest):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return torch.device("cuda")
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@parametrize(
|
||||
"sharding_strategy",
|
||||
|
|
@ -108,7 +109,6 @@ class TestFSDPExecOrder(FSDPTest):
|
|||
)
|
||||
def test_invalid_first_iter_order(
|
||||
self,
|
||||
device,
|
||||
sharding_strategy: ShardingStrategy,
|
||||
):
|
||||
"""Tests that FSDP errors if the all-gather order differs across ranks
|
||||
|
|
@ -116,10 +116,10 @@ class TestFSDPExecOrder(FSDPTest):
|
|||
# Rank 0 runs the forward pass in one order and all other ranks run in
|
||||
# different order
|
||||
dist.set_debug_level(dist.DebugLevel.DETAIL)
|
||||
fsdp_model = Model.wrap(sharding_strategy, device_type)
|
||||
fsdp_model = Model.wrap(sharding_strategy, self.device)
|
||||
if self.rank != 0:
|
||||
fsdp_model.flip_path()
|
||||
inp = fsdp_model.module.get_input(device_type)
|
||||
inp = fsdp_model.module.get_input(self.device)
|
||||
# Match the error message with the following prefix
|
||||
error_regex = "^(Forward order differs across ranks)"
|
||||
with self.assertRaisesRegex(RuntimeError, error_regex):
|
||||
|
|
@ -133,7 +133,6 @@ class TestFSDPExecOrder(FSDPTest):
|
|||
@parametrize("iters_before_path_change", [1, 3])
|
||||
def test_invalid_later_iter_order(
|
||||
self,
|
||||
device,
|
||||
sharding_strategy: ShardingStrategy,
|
||||
iters_before_path_change: int,
|
||||
):
|
||||
|
|
@ -142,11 +141,11 @@ class TestFSDPExecOrder(FSDPTest):
|
|||
dist.set_debug_level(dist.DebugLevel.DETAIL)
|
||||
# On the first iteration, all ranks run the same order, and on the next
|
||||
# iteration, all but rank 0 run in a different order
|
||||
fsdp_model = Model.wrap(sharding_strategy, device_type)
|
||||
fsdp_model = Model.wrap(sharding_strategy, self.device)
|
||||
for _ in range(iters_before_path_change):
|
||||
inp = fsdp_model.module.get_input(device_type)
|
||||
inp = fsdp_model.module.get_input(self.device)
|
||||
output = fsdp_model(*inp)
|
||||
loss = fsdp_model.module.get_loss(inp, output).to(device_type)
|
||||
loss = fsdp_model.module.get_loss(inp, output).to(self.device)
|
||||
fsdp_model.module.run_backward(loss)
|
||||
# Match the warning message with the following prefix
|
||||
regex = (
|
||||
|
|
@ -164,16 +163,16 @@ class TestFSDPExecOrder(FSDPTest):
|
|||
)
|
||||
if self.rank != 0:
|
||||
fsdp_model.flip_path()
|
||||
inp = fsdp_model.module.get_input(device_type)
|
||||
inp = fsdp_model.module.get_input(self.device)
|
||||
# Expect a warning for the forward pass all-gather
|
||||
with context: # warning for forward pass all-gather
|
||||
output = fsdp_model(*inp)
|
||||
loss = fsdp_model.module.get_loss(inp, output).to(device_type)
|
||||
loss = fsdp_model.module.get_loss(inp, output).to(self.device)
|
||||
fsdp_model.module.run_backward(loss)
|
||||
# Run an additional iteration to check that there are no more warnings
|
||||
inp = fsdp_model.module.get_input(device_type)
|
||||
inp = fsdp_model.module.get_input(self.device)
|
||||
output = fsdp_model(*inp)
|
||||
loss = fsdp_model.module.get_loss(inp, output).to(device_type)
|
||||
loss = fsdp_model.module.get_loss(inp, output).to(self.device)
|
||||
fsdp_model.module.run_backward(loss)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
|
|
@ -181,24 +180,24 @@ class TestFSDPExecOrder(FSDPTest):
|
|||
"sharding_strategy",
|
||||
[ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP],
|
||||
)
|
||||
def test_train_eval(self, device, sharding_strategy: ShardingStrategy):
|
||||
def test_train_eval(self, sharding_strategy: ShardingStrategy):
|
||||
dist.set_debug_level(dist.DebugLevel.DETAIL)
|
||||
fsdp_model = Model.wrap(sharding_strategy, device_type)
|
||||
fsdp_model = Model.wrap(sharding_strategy, self.device)
|
||||
NUM_ITERS = 3
|
||||
NUM_EPOCHS = 2
|
||||
with warnings.catch_warnings(record=True) as w: # records warnings to `w`
|
||||
for _ in range(NUM_EPOCHS):
|
||||
fsdp_model.train()
|
||||
for _ in range(NUM_ITERS):
|
||||
inp = fsdp_model.module.get_input(device_type)
|
||||
inp = fsdp_model.module.get_input(self.device)
|
||||
output = fsdp_model(*inp)
|
||||
loss = fsdp_model.module.get_loss(inp, output).to(device_type)
|
||||
loss = fsdp_model.module.get_loss(inp, output).to(self.device)
|
||||
fsdp_model.module.run_backward(loss)
|
||||
fsdp_model.eval()
|
||||
for _ in range(NUM_ITERS):
|
||||
inp = fsdp_model.module.get_input(device_type)
|
||||
inp = fsdp_model.module.get_input(self.device)
|
||||
output = fsdp_model(*inp)
|
||||
fsdp_model.module.get_loss(inp, output).to(device_type)
|
||||
fsdp_model.module.get_loss(inp, output).to(self.device)
|
||||
# Check that the order validation warning was not issued (errors do not
|
||||
# need to be checked since they will be directly reported)
|
||||
warning_prefix = "Forward order differs"
|
||||
|
|
@ -211,7 +210,7 @@ class TestFSDPExecOrder(FSDPTest):
|
|||
# an `AssertionError` will be raised above for both sharding strategies
|
||||
|
||||
|
||||
devices = ("cuda", "hpu")
|
||||
instantiate_device_type_tests(TestFSDPExecOrder, globals(), only_for=devices)
|
||||
instantiate_parametrized_tests(TestFSDPExecOrder)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
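For context on the pattern this revert removes throughout these test files: the generalized variant routes tests through the device-type harness so the same case can run on CUDA or HPU. A minimal sketch of that pattern, using the real helpers imported above but with an illustrative class and test name (not code from this commit):

    import torch
    from torch.testing._internal.common_device_type import instantiate_device_type_tests
    from torch.testing._internal.common_utils import run_tests, TestCase

    class ExampleDeviceTest(TestCase):
        # With instantiate_device_type_tests, each test receives a `device`
        # string ("cuda", "hpu", ...) instead of hard-coding torch.cuda calls.
        def test_add(self, device):
            x = torch.ones(4, device=device)
            self.assertEqual((x + x).sum().item(), 8.0)

    # Generates per-backend variants (e.g. ExampleDeviceTestCUDA) when the
    # listed backends are available; the revert drops this in favor of
    # CUDA-only test bodies.
    devices = ("cuda", "hpu")
    instantiate_device_type_tests(ExampleDeviceTest, globals(), only_for=devices)

    if __name__ == "__main__":
        run_tests()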
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ from unittest import mock
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch._utils import _get_device_module
|
||||
from torch.distributed.fsdp import BackwardPrefetch, CPUOffload, MixedPrecision
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import (
|
||||
FullyShardedDataParallel as FSDP,
|
||||
|
|
@ -15,18 +14,11 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import (
|
|||
)
|
||||
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype
|
||||
from torch.testing._internal.common_utils import (
|
||||
run_tests,
|
||||
TEST_CUDA,
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
)
|
||||
from torch.testing._internal.common_fsdp import FSDPTest
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
|
||||
|
||||
|
||||
device_type = torch.device(get_devtype())
|
||||
|
||||
if not dist.is_available():
|
||||
print("Distributed not available, skipping tests", file=sys.stderr)
|
||||
sys.exit(0)
|
||||
|
|
@ -47,12 +39,12 @@ class LinearUnusedInput(nn.Linear):
|
|||
class ModelUnusedInput(nn.Module):
|
||||
def __init__(self, freeze: bool):
|
||||
super().__init__()
|
||||
self.layer0 = LinearUnusedInput(4, 4)
|
||||
self.layer1_frozen = LinearUnusedInput(4, 4)
|
||||
self.layer0 = LinearUnusedInput(4, 4, device="cuda")
|
||||
self.layer1_frozen = LinearUnusedInput(4, 4, device="cuda")
|
||||
if freeze:
|
||||
for param in self.layer1_frozen.parameters():
|
||||
param.requires_grad = False
|
||||
self.layer2 = LinearUnusedInput(4, 4)
|
||||
self.layer2 = LinearUnusedInput(4, 4, device="cuda")
|
||||
|
||||
def forward(self, frozen_input, learnable_input):
|
||||
x = self.layer0(frozen_input, learnable_input)
|
||||
|
|
@ -68,13 +60,13 @@ class TestFSDPFineTune(FSDPTest):
|
|||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return min(_get_device_module(self.device_type).device_count(), 2)
|
||||
return min(torch.cuda.device_count(), 2)
|
||||
|
||||
def _init_seq_module(self, device) -> nn.Module:
|
||||
def _init_seq_module(self) -> nn.Module:
|
||||
torch.manual_seed(42)
|
||||
modules = []
|
||||
for _ in range(self.NUM_LINEARS):
|
||||
modules += [nn.Linear(5, 5, device=device), nn.ReLU()]
|
||||
modules += [nn.Linear(5, 5, device="cuda"), nn.ReLU()]
|
||||
seq = nn.Sequential(*modules)
|
||||
self._set_seq_module_requires_grad(seq, False)
|
||||
return seq
|
||||
|
|
@ -89,14 +81,13 @@ class TestFSDPFineTune(FSDPTest):
|
|||
param.requires_grad = requires_grad
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_backward_reshard_hooks(self, device):
|
||||
def test_backward_reshard_hooks(self):
|
||||
"""
|
||||
Tests that the post-backward reshard happens even for flat parameters
|
||||
that do not require gradients.
|
||||
"""
|
||||
self.run_subtests(
|
||||
{
|
||||
"device_id": [device],
|
||||
"sharding_strategy": [
|
||||
ShardingStrategy.FULL_SHARD,
|
||||
ShardingStrategy.SHARD_GRAD_OP,
|
||||
|
|
@ -111,21 +102,18 @@ class TestFSDPFineTune(FSDPTest):
|
|||
|
||||
def _test_backward_reshard_hooks(
|
||||
self,
|
||||
device_id,
|
||||
sharding_strategy: ShardingStrategy,
|
||||
use_orig_params: bool,
|
||||
inp_requires_grad: bool,
|
||||
unfreeze_params: bool,
|
||||
):
|
||||
seq = self._init_seq_module(device_type)
|
||||
seq = self._init_seq_module()
|
||||
policy = ModuleWrapPolicy({nn.Linear})
|
||||
fsdp_kwargs = {"device_id": device_type}
|
||||
seq = FSDP(
|
||||
seq,
|
||||
auto_wrap_policy=policy,
|
||||
sharding_strategy=sharding_strategy,
|
||||
use_orig_params=use_orig_params,
|
||||
**fsdp_kwargs,
|
||||
)
|
||||
orig_post_backward_reshard = (
|
||||
torch.distributed.fsdp._runtime_utils._post_backward_reshard
|
||||
|
|
@ -174,7 +162,7 @@ class TestFSDPFineTune(FSDPTest):
|
|||
self._set_seq_module_requires_grad(seq, True)
|
||||
|
||||
inp = torch.randn(
|
||||
(8, 5), device=device_type, requires_grad=inp_requires_grad
|
||||
(8, 5), device="cuda", requires_grad=inp_requires_grad
|
||||
)
|
||||
if step_idx == nograd_step_idx:
|
||||
with torch.no_grad():
|
||||
|
|
@ -187,15 +175,15 @@ class TestFSDPFineTune(FSDPTest):
|
|||
_assert_post_backward_reshard_count(step_idx, num_steps)
|
||||
post_backward_reshard_count = 0
|
||||
|
||||
def _init_multi_traversal_module(self, device) -> nn.Module:
|
||||
def _init_multi_traversal_module(self) -> nn.Module:
|
||||
torch.manual_seed(42)
|
||||
|
||||
class TestModule(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.layer_0 = nn.Linear(5, 5, device=device)
|
||||
self.layer_no_grad = nn.Linear(5, 5, device=device)
|
||||
self.layer_with_grad = nn.Linear(5, 5, device=device)
|
||||
self.layer_0 = nn.Linear(5, 5, device="cuda")
|
||||
self.layer_no_grad = nn.Linear(5, 5, device="cuda")
|
||||
self.layer_with_grad = nn.Linear(5, 5, device="cuda")
|
||||
self.layer_no_grad.requires_grad_(False)
|
||||
|
||||
def forward(self, x):
|
||||
|
|
@ -240,26 +228,22 @@ class TestFSDPFineTune(FSDPTest):
|
|||
inp_requires_grad: bool,
|
||||
forward_prefetch: bool,
|
||||
):
|
||||
seq = self._init_multi_traversal_module(device_type.type)
|
||||
seq = self._init_multi_traversal_module()
|
||||
policy = ModuleWrapPolicy({nn.Linear})
|
||||
fsdp_kwargs = {"device_id": device_type}
|
||||
fsdp_seq = FSDP(
|
||||
copy.deepcopy(seq),
|
||||
auto_wrap_policy=policy,
|
||||
sharding_strategy=sharding_strategy,
|
||||
use_orig_params=use_orig_params,
|
||||
forward_prefetch=forward_prefetch,
|
||||
**fsdp_kwargs,
|
||||
)
|
||||
ddp_seq = DDP(copy.deepcopy(seq), device_ids=[device_type])
|
||||
ddp_seq = DDP(copy.deepcopy(seq), device_ids=[self.rank])
|
||||
fsdp_optim = torch.optim.Adam(fsdp_seq.parameters(), lr=1e-2)
|
||||
ddp_optim = torch.optim.Adam(ddp_seq.parameters(), lr=1e-2)
|
||||
torch.manual_seed(self.rank + 1)
|
||||
losses = []
|
||||
for _ in range(6):
|
||||
inp = torch.randn(
|
||||
(8, 5), device=device_type, requires_grad=inp_requires_grad
|
||||
)
|
||||
inp = torch.randn((8, 5), device="cuda", requires_grad=inp_requires_grad)
|
||||
for seq, optim in ((fsdp_seq, fsdp_optim), (ddp_seq, ddp_optim)):
|
||||
loss = seq(inp).sum()
|
||||
losses.append(loss)
|
||||
|
|
@ -292,37 +276,32 @@ class TestFSDPFineTune(FSDPTest):
|
|||
sharding_strategy: ShardingStrategy,
|
||||
use_orig_params: bool,
|
||||
):
|
||||
seq = self._init_seq_module(device_type)
|
||||
seq = self._init_seq_module()
|
||||
policy = ModuleWrapPolicy({nn.Linear})
|
||||
fsdp_kwargs = {"device_id": device_type}
|
||||
fsdp_seq = FSDP(
|
||||
copy.deepcopy(seq),
|
||||
auto_wrap_policy=policy,
|
||||
sharding_strategy=sharding_strategy,
|
||||
use_orig_params=use_orig_params,
|
||||
**fsdp_kwargs,
|
||||
)
|
||||
ddp_seq = DDP(copy.deepcopy(seq), device_ids=[device_type])
|
||||
ddp_seq = DDP(copy.deepcopy(seq), device_ids=[self.rank])
|
||||
fsdp_optim = torch.optim.Adam(fsdp_seq.parameters(), lr=1e-2)
|
||||
ddp_optim = torch.optim.Adam(ddp_seq.parameters(), lr=1e-2)
|
||||
torch.manual_seed(self.rank + 1)
|
||||
losses = []
|
||||
for _ in range(6):
|
||||
inp = torch.randn((8, 5), device=device_type.type)
|
||||
inp = torch.randn((8, 5), device="cuda")
|
||||
for seq, optim in ((fsdp_seq, fsdp_optim), (ddp_seq, ddp_optim)):
|
||||
loss = seq(inp).sum()
|
||||
losses.append(loss)
|
||||
loss.backward()
|
||||
optim.step()
|
||||
optim.zero_grad()
|
||||
if TEST_CUDA:
|
||||
torch.testing.assert_close(losses[0], losses[1])
|
||||
else:
|
||||
torch.testing.assert_close(losses[0], losses[1], atol=1e-03, rtol=1e-03)
|
||||
torch.testing.assert_close(losses[0], losses[1])
|
||||
losses.clear()
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_parity_with_non_frozen_fsdp(self, device):
|
||||
def test_parity_with_non_frozen_fsdp(self):
|
||||
"""
|
||||
For frozen modules with unused input, reshard could happen without unshard
|
||||
Verify numerical parity between `_post_backward_reshard_only_hook` and
|
||||
|
|
@ -330,7 +309,6 @@ class TestFSDPFineTune(FSDPTest):
|
|||
"""
|
||||
self.run_subtests(
|
||||
{
|
||||
"device_id": [device],
|
||||
"sharding_strategy": [
|
||||
ShardingStrategy.FULL_SHARD,
|
||||
ShardingStrategy.SHARD_GRAD_OP,
|
||||
|
|
@ -355,7 +333,6 @@ class TestFSDPFineTune(FSDPTest):
|
|||
|
||||
def _test_parity_with_non_frozen_fsdp(
|
||||
self,
|
||||
device_id,
|
||||
sharding_strategy: ShardingStrategy,
|
||||
use_orig_params: bool,
|
||||
offload_params: bool,
|
||||
|
|
@ -363,11 +340,10 @@ class TestFSDPFineTune(FSDPTest):
|
|||
backward_prefetch: BackwardPrefetch,
|
||||
):
|
||||
torch.manual_seed(42)
|
||||
model = ModelUnusedInput(freeze=True).to(device_type)
|
||||
model = ModelUnusedInput(freeze=True)
|
||||
torch.manual_seed(42)
|
||||
ref_model = ModelUnusedInput(freeze=False).to(device_type)
|
||||
ref_model = ModelUnusedInput(freeze=False)
|
||||
fsdp_kwargs = {
|
||||
"device_id": device_type,
|
||||
"auto_wrap_policy": ModuleWrapPolicy({LinearUnusedInput}),
|
||||
"sharding_strategy": sharding_strategy,
|
||||
"use_orig_params": use_orig_params,
|
||||
|
|
@ -389,10 +365,8 @@ class TestFSDPFineTune(FSDPTest):
|
|||
torch.manual_seed(self.rank + 1)
|
||||
losses = []
|
||||
for idx in range(6):
|
||||
frozen_input = torch.randn((4, 4), device=device_type, requires_grad=False)
|
||||
learnable_input = torch.randn(
|
||||
(4, 4), device=device_type, requires_grad=True
|
||||
)
|
||||
frozen_input = torch.randn((4, 4), device="cuda", requires_grad=False)
|
||||
learnable_input = torch.randn((4, 4), device="cuda", requires_grad=True)
|
||||
for _model, _optim in ((model, model_optim), (ref_model, ref_model_optim)):
|
||||
loss = _model(frozen_input, frozen_input).sum()
|
||||
losses.append(loss)
|
||||
|
|
@ -407,7 +381,5 @@ class TestFSDPFineTune(FSDPTest):
|
|||
self.assertEqual(param, ref_param)
|
||||
|
||||
|
||||
devices = ("cuda", "hpu")
|
||||
instantiate_device_type_tests(TestFSDPFineTune, globals(), only_for=devices)
|
||||
if __name__ == "__main__":
|
||||
run_tests()
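On the relaxed parity check used for non-CUDA backends in this file: torch.testing.assert_close accepts explicit atol/rtol overrides, which is how the generalized tests loosened the comparison away from the tight float32 defaults. A small self-contained illustration (the tensor values are invented for the example):

    import torch

    a = torch.full((3,), 1.0000)
    b = torch.full((3,), 1.0005)

    # The much tighter float32 defaults would reject this ~5e-4 difference:
    # torch.testing.assert_close(a, b)  # raises AssertionError
    # With explicit tolerances, as in the relaxed branch, it passes.
    torch.testing.assert_close(a, b, atol=1e-3, rtol=1e-3)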
|
||||
|
|
|
|||
|
|
@ -1,8 +1,12 @@
|
|||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import torch
|
||||
from torch.distributed.fsdp._trace_utils import _ExecOrderTracer
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
run_tests,
|
||||
TestCase,
|
||||
)
|
||||
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
|
|
@ -113,7 +117,7 @@ class TestSymbolicTracing(TestCase):
|
|||
self.assertEqual(exec_info.visited_params, set(exec_info.param_forward_order))
|
||||
|
||||
|
||||
devices = ("cuda", "hpu")
|
||||
instantiate_device_type_tests(TestSymbolicTracing, globals(), only_for=devices)
|
||||
instantiate_parametrized_tests(TestSymbolicTracing)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
|
@ -6,10 +7,10 @@ from torch import distributed as dist
|
|||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.nn import Linear, Module
|
||||
from torch.optim import SGD
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_fsdp import FSDPTest
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
subtest,
|
||||
|
|
@ -20,6 +21,7 @@ from torch.testing._internal.common_utils import (
|
|||
if not dist.is_available():
|
||||
print("Distributed not available, skipping tests", file=sys.stderr)
|
||||
sys.exit(0)
|
||||
|
||||
if TEST_WITH_DEV_DBG_ASAN:
|
||||
print(
|
||||
"Skip dev-asan as torch + multiprocessing spawn have known issues",
|
||||
|
|
@ -35,11 +37,11 @@ class TestInput(FSDPTest):
|
|||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
@parametrize("input_cls", [subtest(dict, name="dict"), subtest(list, name="list")])
|
||||
def test_input_type(self, device, input_cls):
|
||||
def test_input_type(self, input_cls):
|
||||
"""Test FSDP with input being a list or a dict, only single GPU."""
|
||||
|
||||
class Model(Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.layer = Linear(4, 4)
|
||||
|
||||
|
|
@ -51,26 +53,25 @@ class TestInput(FSDPTest):
|
|||
input = input["in"]
|
||||
return self.layer(input)
|
||||
|
||||
fsdp_kwargs = {
|
||||
"device_id": device,
|
||||
}
|
||||
model = FSDP(Model().to(device), **fsdp_kwargs)
|
||||
model = FSDP(Model()).cuda()
|
||||
optim = SGD(model.parameters(), lr=0.1)
|
||||
|
||||
for _ in range(5):
|
||||
in_data = torch.rand(64, 4).to(device)
|
||||
in_data = torch.rand(64, 4).cuda()
|
||||
in_data.requires_grad = True
|
||||
if input_cls is list:
|
||||
in_data = [in_data]
|
||||
else:
|
||||
self.assertTrue(input_cls is dict)
|
||||
in_data = {"in": in_data}
|
||||
|
||||
out = model(in_data)
|
||||
out.sum().backward()
|
||||
optim.step()
|
||||
optim.zero_grad()
|
||||
|
||||
|
||||
devices = ("cuda", "hpu")
|
||||
instantiate_device_type_tests(TestInput, globals(), only_for=devices)
|
||||
instantiate_parametrized_tests(TestInput)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
|
@ -14,7 +13,6 @@ from torch.testing._internal.common_utils import (
|
|||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
TEST_HPU,
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
)
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
|
@ -160,7 +158,6 @@ class TestFSDPMemory(FSDPTest):
|
|||
output = cmp(results, expected)
|
||||
self.assertEqual(output, "")
|
||||
|
||||
@unittest.skipIf(TEST_HPU, "Memory will be differnt for CUDA and HPU, skipping")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@parametrize("ckpt", ["no_ckpt", "ckpt"])
|
||||
def test_fsdp_memory(self, ckpt):
|
||||
|
|
@ -232,5 +229,7 @@ class TestFSDPMemory(FSDPTest):
|
|||
|
||||
|
||||
instantiate_parametrized_tests(TestFSDPMemory)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
|
@ -7,17 +8,15 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
|||
from torch.nn import Linear, Module
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.optim import SGD
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype, get_full_params
|
||||
from torch.testing._internal.common_fsdp import FSDPTest, get_full_params
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
|
||||
|
||||
|
||||
device_type = torch.device(get_devtype())
|
||||
|
||||
if not dist.is_available():
|
||||
print("Distributed not available, skipping tests", file=sys.stderr)
|
||||
sys.exit(0)
|
||||
|
||||
if TEST_WITH_DEV_DBG_ASAN:
|
||||
print(
|
||||
"Skip dev-asan as torch + multiprocessing spawn have known issues",
|
||||
|
|
@ -47,33 +46,37 @@ class TestMultiForward(FSDPTest):
|
|||
def _dist_train(self, wrap_fsdp):
|
||||
# keep everything deterministic for input data
|
||||
torch.manual_seed(0)
|
||||
model = Model(wrap_fsdp).to(device_type.type)
|
||||
|
||||
model = Model(wrap_fsdp).cuda()
|
||||
if wrap_fsdp:
|
||||
model = FSDP(model, device_id=device_type.type)
|
||||
model = FSDP(model)
|
||||
else:
|
||||
model = DistributedDataParallel(model, device_ids=[device_type.type])
|
||||
model = DistributedDataParallel(model, device_ids=[self.rank])
|
||||
optim = SGD(model.parameters(), lr=0.1)
|
||||
in_data = torch.rand(64, 4).to(device_type.type)
|
||||
|
||||
in_data = torch.rand(64, 4).cuda()
|
||||
in_data.requires_grad = True
|
||||
for _ in range(3):
|
||||
out = model(in_data)
|
||||
out.sum().backward()
|
||||
optim.step()
|
||||
optim.zero_grad()
|
||||
|
||||
if wrap_fsdp:
|
||||
return get_full_params(model)
|
||||
|
||||
return list(model.parameters())
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_multi_forward(self):
|
||||
# DDP
|
||||
ddp_state = self._dist_train(wrap_fsdp=False)
|
||||
|
||||
# FSDP
|
||||
fsdp_state = self._dist_train(wrap_fsdp=True)
|
||||
|
||||
self.assertEqual(ddp_state, fsdp_state)
|
||||
|
||||
|
||||
devices = ("cpu", "hpu")
|
||||
instantiate_device_type_tests(TestMultiForward, globals(), only_for=devices)
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
|
@ -6,17 +7,15 @@ from torch import distributed as dist
|
|||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.nn import Linear, Module, Sequential
|
||||
from torch.optim import SGD
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype
|
||||
from torch.testing._internal.common_fsdp import FSDPTest
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
|
||||
|
||||
|
||||
device_type = torch.device(get_devtype())
|
||||
|
||||
if not dist.is_available():
|
||||
print("Distributed not available, skipping tests", file=sys.stderr)
|
||||
sys.exit(0)
|
||||
|
||||
if TEST_WITH_DEV_DBG_ASAN:
|
||||
print(
|
||||
"Skip dev-asan as torch + multiprocessing spawn have known issues",
|
||||
|
|
@ -26,9 +25,9 @@ if TEST_WITH_DEV_DBG_ASAN:
|
|||
|
||||
|
||||
class InnerModel(Module):
|
||||
def __init__(self, device):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.layers = Sequential(FSDP(Linear(5, 5), device_id=device_type.type))
|
||||
self.layers = Sequential(FSDP(Linear(5, 5)))
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
|
@ -36,32 +35,32 @@ class InnerModel(Module):
|
|||
|
||||
class TestMultipleWrapping(FSDPTest):
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_multiple_wrapping(self, device):
|
||||
def test_multiple_wrapping(self):
|
||||
"""
|
||||
This test simulates wrapping the module after training to run inference.
|
||||
This is required in cases where later in a session, the model is wrapped again in FSDP but
|
||||
contains nested FSDP wrappers within the module.
|
||||
"""
|
||||
inner_model = InnerModel(device)
|
||||
model = FSDP(inner_model).to(device_type.type)
|
||||
inner_model = InnerModel()
|
||||
model = FSDP(inner_model).cuda()
|
||||
optim = SGD(model.parameters(), lr=0.1)
|
||||
|
||||
for i in range(3):
|
||||
input = torch.rand((1, 5), dtype=torch.float).to(device_type.type)
|
||||
input = torch.rand((1, 5), dtype=torch.float).cuda()
|
||||
input.requires_grad = True
|
||||
output = model(input)
|
||||
output.sum().backward()
|
||||
optim.step()
|
||||
optim.zero_grad()
|
||||
input = torch.rand((1, 5), dtype=torch.float).to(device_type.type)
|
||||
input = torch.rand((1, 5), dtype=torch.float).cuda()
|
||||
output = model(input)
|
||||
|
||||
# second time to rewrap the inner model
|
||||
# rewrapped_model = FSDP(inner_model, device_id=device)
|
||||
rewrapped_model = FSDP(inner_model).to(device_type.type)
|
||||
rewrapped_model = FSDP(inner_model).cuda()
|
||||
rewrapped_output = rewrapped_model(input)
|
||||
|
||||
self.assertEqual(output, rewrapped_output)
|
||||
|
||||
|
||||
devices = ("cuda", "hpu")
|
||||
instantiate_device_type_tests(TestMultipleWrapping, globals(), only_for=devices)
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
import sys
|
||||
import time
|
||||
import unittest
|
||||
from statistics import mean
|
||||
from unittest.mock import patch
|
||||
|
||||
|
|
@ -11,13 +10,11 @@ import torch.nn as nn
|
|||
from torch import distributed as dist
|
||||
from torch.cuda import Event
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_fsdp import FSDPTest
|
||||
from torch.testing._internal.common_utils import (
|
||||
get_cycles_per_ms,
|
||||
run_tests,
|
||||
TEST_HPU,
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
)
|
||||
|
||||
|
|
@ -244,7 +241,6 @@ class TestForwardOverlapWorldSizeOne(FSDPTest):
|
|||
both = e4["gpu_total"]
|
||||
self.assertTrue(compute_only + all_gather_only > 1.1 * both)
|
||||
|
||||
@unittest.skipIf(TEST_HPU, "HPU doesn't has HW sleep API support, skipping")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_forward_overlap(self):
|
||||
self._dist_train()
|
||||
|
|
@ -256,9 +252,5 @@ class TestForwardOverlapWorldSizeTwo(TestForwardOverlapWorldSizeOne):
|
|||
return 2
|
||||
|
||||
|
||||
devices = ("cuda", "hpu")
|
||||
instantiate_device_type_tests(
|
||||
TestForwardOverlapWorldSizeOne, globals(), only_for=devices
|
||||
)
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -10,20 +10,20 @@ from torch.distributed.fsdp import (
|
|||
FullyShardedDataParallel as FSDP,
|
||||
MixedPrecision,
|
||||
)
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_fsdp import (
|
||||
DEVICEInitMode,
|
||||
FSDPInitMode,
|
||||
FSDPTest,
|
||||
get_devtype,
|
||||
NestedWrappedModule,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
run_tests,
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
)
|
||||
|
||||
|
||||
device_type = torch.device(get_devtype())
|
||||
|
||||
if not dist.is_available():
|
||||
print("Distributed not available, skipping tests", file=sys.stderr)
|
||||
sys.exit(0)
|
||||
|
|
@ -37,6 +37,11 @@ if TEST_WITH_DEV_DBG_ASAN:
|
|||
|
||||
|
||||
class TestPureFP16(FSDPTest):
|
||||
@property
|
||||
def world_size(self):
|
||||
# Test fails due to inaccuracies when using more than 4 GPUs
|
||||
return min(4, super().world_size)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_pure_fp16_training(self):
|
||||
"""Tests pure FP16 training, including when the parameter's dtype is
|
||||
|
|
@ -97,13 +102,11 @@ class TestPureFP16(FSDPTest):
|
|||
self.process_group,
|
||||
FSDPInitMode.NO_FSDP,
|
||||
DEVICEInitMode.DEVICE_NEVER,
|
||||
{
|
||||
"device_id": device_type,
|
||||
},
|
||||
{},
|
||||
)
|
||||
fsdp_kwargs = {
|
||||
"use_orig_params": use_orig_params,
|
||||
"device_id": device_type,
|
||||
"device_id": torch.cuda.current_device(),
|
||||
"mixed_precision": mixed_precision,
|
||||
}
|
||||
if to_half_before_fsdp_init:
|
||||
|
|
@ -115,7 +118,7 @@ class TestPureFP16(FSDPTest):
|
|||
self.assertEqual(param.dtype, torch.float16)
|
||||
inp = tuple(
|
||||
t.half() if torch.is_tensor(t) else t
|
||||
for t in fsdp_model.module.get_input(self.device_type)
|
||||
for t in fsdp_model.module.get_input(torch.device("cuda"))
|
||||
)
|
||||
out = fsdp_model(*inp)
|
||||
out.sum().backward()
|
||||
|
|
@ -151,7 +154,7 @@ class TestPureFP16(FSDPTest):
|
|||
self.assertEqual(param.grad.dtype, torch.float16)
|
||||
|
||||
|
||||
devices = ("cuda", "hpu")
|
||||
instantiate_device_type_tests(TestPureFP16, globals(), only_for=devices)
|
||||
instantiate_parametrized_tests(TestPureFP16)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import sys
|
||||
|
||||
from torch import distributed as dist
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_fsdp import (
|
||||
DEVICEInitMode,
|
||||
|
|
@ -17,6 +17,7 @@ from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_AS
|
|||
if not dist.is_available():
|
||||
print("Distributed not available, skipping tests", file=sys.stderr)
|
||||
sys.exit(0)
|
||||
|
||||
if TEST_WITH_DEV_DBG_ASAN:
|
||||
print(
|
||||
"Skip dev-asan as torch + multiprocessing spawn have known issues",
|
||||
|
|
@ -56,7 +57,5 @@ class TestTraversal(FSDPTest):
|
|||
)
|
||||
|
||||
|
||||
devices = ("cuda", "hpu")
|
||||
instantiate_device_type_tests(TestTraversal, globals(), only_for=devices)
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -7,9 +7,8 @@ from torch import distributed as dist
|
|||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.nn import Linear
|
||||
from torch.optim import SGD
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype
|
||||
from torch.testing._internal.common_fsdp import FSDPTest
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
|
||||
|
||||
|
||||
|
|
@ -24,39 +23,37 @@ if TEST_WITH_DEV_DBG_ASAN:
|
|||
)
|
||||
sys.exit(0)
|
||||
|
||||
device_type = torch.device(get_devtype())
|
||||
|
||||
|
||||
class TestUnevenParamShard(FSDPTest):
|
||||
def _get_ref_results(self, device, model, input, my_lr):
|
||||
def _get_ref_results(self, model, input, my_lr):
|
||||
with torch.no_grad():
|
||||
# Compute one iteration local output.
|
||||
weight = model.weight.T.clone().to(device_type)
|
||||
v = torch.Tensor(input[self.rank]).to(device_type)
|
||||
weight = model.weight.T.clone().to(self.rank)
|
||||
v = torch.Tensor(input[self.rank]).to(self.rank)
|
||||
ref_forward_output_my_rank = torch.matmul(v, weight)
|
||||
# Compute one iteration global weight update.
|
||||
v = torch.Tensor(input[: self.world_size]).to(device_type)
|
||||
v = torch.Tensor(input[: self.world_size]).to(self.rank)
|
||||
grad = v.float().sum(0).repeat(weight.shape[0], 1).div(self.world_size)
|
||||
ref_weight_out = weight - grad.T * my_lr
|
||||
|
||||
return ref_forward_output_my_rank, ref_weight_out
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_one_iteration(self, device):
|
||||
def test_one_iteration(self):
|
||||
"""Test FSDP with uneven divide of parameter shards."""
|
||||
model = Linear(3, 3, bias=False)
|
||||
input = torch.rand(8, 3)
|
||||
my_lr = 0.1
|
||||
|
||||
ref_forward_output_my_rank, ref_weight_out = self._get_ref_results(
|
||||
device, model, input, my_lr
|
||||
model, input, my_lr
|
||||
)
|
||||
|
||||
model.to(device_type)
|
||||
model.to(self.rank)
|
||||
model = FSDP(model)
|
||||
optim = SGD(model.parameters(), lr=my_lr)
|
||||
self.assertTrue(len(input) >= self.world_size)
|
||||
in_data = torch.Tensor(input[self.rank]).to(device_type)
|
||||
in_data = torch.Tensor(input[self.rank]).to(self.rank)
|
||||
out = model(in_data)
|
||||
out.float().sum().backward()
|
||||
optim.step()
|
||||
|
|
@ -68,7 +65,5 @@ class TestUnevenParamShard(FSDPTest):
|
|||
self.assertEqual(ref_weight_out, weight_out)
|
||||
|
||||
|
||||
devices = ("cuda", "hpu")
|
||||
instantiate_device_type_tests(TestUnevenParamShard, globals(), only_for=devices)
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -25,15 +25,12 @@ from torch.testing._internal.common_fsdp import (
|
|||
DEVICEInitMode,
|
||||
FSDPInitMode,
|
||||
FSDPTest,
|
||||
get_devtype,
|
||||
NestedWrappedModule,
|
||||
TransformerWithSharedParams,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
|
||||
|
||||
|
||||
device_type = torch.device(get_devtype())
|
||||
|
||||
if not dist.is_available():
|
||||
print("Distributed not available, skipping tests", file=sys.stderr)
|
||||
sys.exit(0)
|
||||
|
|
@ -51,6 +48,10 @@ class TestUnshardParamsBase(FSDPTest):
|
|||
This contains any methods common to both the sharded and non-sharded cases.
|
||||
"""
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return torch.device("cuda", self.rank)
|
||||
|
||||
def _test_unshard_params_writeback(
|
||||
self,
|
||||
writeback: bool,
|
||||
|
|
@ -58,8 +59,8 @@ class TestUnshardParamsBase(FSDPTest):
|
|||
**fsdp_kwargs: Dict[str, Any],
|
||||
):
|
||||
model = nn.Sequential(
|
||||
nn.Linear(5, 5, bias=False, device=device_type.type),
|
||||
nn.Linear(5, 3, bias=False, device=device_type.type),
|
||||
nn.Linear(5, 5, bias=False, device=self.device),
|
||||
nn.Linear(5, 3, bias=False, device=self.device),
|
||||
)
|
||||
model[0] = FSDP(model[0], **fsdp_kwargs)
|
||||
model = FSDP(model, **fsdp_kwargs)
|
||||
|
|
@ -125,7 +126,7 @@ class TestUnshardParamsBase(FSDPTest):
|
|||
self.process_group,
|
||||
FSDPInitMode.NO_FSDP,
|
||||
DEVICEInitMode.DEVICE_BEFORE,
|
||||
fsdp_kwargs={"device_id": device_type.type},
|
||||
fsdp_kwargs={},
|
||||
deterministic=True,
|
||||
)
|
||||
# Apply FSDP such that the root module does not have FSDP applied,
|
||||
|
|
@ -237,7 +238,7 @@ class TestUnshardParams(TestUnshardParamsBase):
|
|||
testing that writing to padding does not persist.
|
||||
NOTE: This method depends on FSDP internals.
|
||||
"""
|
||||
model = FSDP(nn.Linear(1, 1, bias=False, device=device_type.type))
|
||||
model = FSDP(nn.Linear(1, 1, bias=False, device=self.device))
|
||||
flat_param = model._handle.flat_param
|
||||
self.assertEqual(1, flat_param.numel())
|
||||
# Write a known value to the *sharded* `FlatParameter`
|
||||
|
|
@ -290,10 +291,8 @@ class TestUnshardParams(TestUnshardParamsBase):
|
|||
}
|
||||
model = FSDP(
|
||||
nn.Sequential(
|
||||
FSDP(
|
||||
nn.Linear(5, 5, bias=False, device=device_type.type), **fsdp_kwargs
|
||||
),
|
||||
nn.Linear(5, 3, bias=False, device=device_type.type),
|
||||
FSDP(nn.Linear(5, 5, bias=False, device=self.device), **fsdp_kwargs),
|
||||
nn.Linear(5, 3, bias=False, device=self.device),
|
||||
),
|
||||
**fsdp_kwargs,
|
||||
)
|
||||
|
|
@ -310,7 +309,7 @@ class TestUnshardParams(TestUnshardParamsBase):
|
|||
# Validate the expected behavior: the root does not reshard after
|
||||
# forward; the non-root reshards after forward; and both reshard after
|
||||
# backward
|
||||
output = model(torch.zeros(5, device=device_type.type))
|
||||
output = model(torch.zeros(5, device=self.device))
|
||||
self.assertEqual(
|
||||
expected_outer_flat_param_unsharded_numel,
|
||||
_get_unsharded_storage_size(outer_flat_param),
|
||||
|
|
@ -322,7 +321,7 @@ class TestUnshardParams(TestUnshardParamsBase):
|
|||
|
||||
# Check that with parameter unsharding in between forward and backward
|
||||
# as well as after backward, the reshard behavior matches
|
||||
output = model(torch.zeros(5, device=device_type.type))
|
||||
output = model(torch.zeros(5, device=self.device))
|
||||
with FSDP.summon_full_params(
|
||||
model,
|
||||
rank0_only=rank0_only,
|
||||
|
|
@ -379,10 +378,8 @@ class TestUnshardParams(TestUnshardParamsBase):
|
|||
}
|
||||
model = FSDP(
|
||||
nn.Sequential(
|
||||
FSDP(
|
||||
nn.Linear(5, 5, bias=False, device=device_type.type), **fsdp_kwargs
|
||||
),
|
||||
nn.Linear(5, 3, bias=False, device=device_type.type),
|
||||
FSDP(nn.Linear(5, 5, bias=False, device=self.device), **fsdp_kwargs),
|
||||
nn.Linear(5, 3, bias=False, device=self.device),
|
||||
),
|
||||
**fsdp_kwargs,
|
||||
)
|
||||
|
|
@ -560,7 +557,7 @@ class TestUnshardParams(TestUnshardParamsBase):
|
|||
DEVICEInitMode.DEVICE_BEFORE,
|
||||
deterministic=True,
|
||||
)
|
||||
ddp_model = DDP(model, device_ids=[device_type])
|
||||
ddp_model = DDP(model, device_ids=[self.rank])
|
||||
fsdp_model = TransformerWithSharedParams.init(
|
||||
self.process_group,
|
||||
FSDPInitMode.RECURSIVE,
|
||||
|
|
@ -569,7 +566,6 @@ class TestUnshardParams(TestUnshardParamsBase):
|
|||
fsdp_kwargs={
|
||||
"use_orig_params": use_orig_params,
|
||||
"sharding_strategy": sharding_strategy,
|
||||
"device_id": device_type.type,
|
||||
},
|
||||
)
|
||||
with FSDP.summon_full_params(fsdp_model):
|
||||
|
|
@ -577,7 +573,7 @@ class TestUnshardParams(TestUnshardParamsBase):
|
|||
assert torch.all(torch.isclose(p1, p2))
|
||||
|
||||
# Check calling after backward
|
||||
inp = fsdp_model.get_input(torch.device(device_type))
|
||||
inp = fsdp_model.get_input(torch.device("cuda"))
|
||||
ddp_out = ddp_model(*inp)
|
||||
fsdp_out = fsdp_model(*inp)
|
||||
ddp_out.sum().backward()
|
||||
|
|
@ -587,7 +583,7 @@ class TestUnshardParams(TestUnshardParamsBase):
|
|||
_check_grads(ddp_model, fsdp_model, old_fsdp_grads)
|
||||
|
||||
# Check calling between forward and backward
|
||||
inp = fsdp_model.get_input(torch.device(device_type))
|
||||
inp = fsdp_model.get_input(torch.device("cuda"))
|
||||
ddp_out = ddp_model(*inp)
|
||||
fsdp_out = fsdp_model(*inp)
|
||||
old_fsdp_grads = _get_fsdp_grads(fsdp_model, is_supported)
|
||||
|
|
@ -620,7 +616,6 @@ class TestUnshardParams(TestUnshardParamsBase):
|
|||
fsdp_kwargs={
|
||||
"use_orig_params": True,
|
||||
"sharding_strategy": sharding_strategy,
|
||||
"device_id": device_type.type,
|
||||
},
|
||||
)
|
||||
for fsdp_module in FSDP.fsdp_modules(fsdp_model):
|
||||
|
|
@ -635,7 +630,7 @@ class TestUnshardParams(TestUnshardParamsBase):
|
|||
model = nn.Sequential(
|
||||
nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)),
|
||||
nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)),
|
||||
).to(device_type.type)
|
||||
).cuda()
|
||||
model = FSDP(model, auto_wrap_policy=ModuleWrapPolicy((nn.Sequential,)))
|
||||
with FSDP.summon_full_params(model[0]):
|
||||
# Check that the summoned module does not have its flat parameter
|
||||
|
|
@ -689,7 +684,7 @@ class TestUnshardParamsErrors(TestUnshardParamsBase):
|
|||
with fsdp_module.summon_full_params(fsdp_module):
|
||||
pass
|
||||
|
||||
model = FSDP(MyModule()).to(device_type.type)
|
||||
model = FSDP(MyModule()).cuda(self.rank)
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError, "Cannot manually unshard parameters during forward/backward"
|
||||
):
|
||||
|
|
@ -697,8 +692,8 @@ class TestUnshardParamsErrors(TestUnshardParamsBase):
|
|||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_unshard_params_from_backward_raises(self):
|
||||
model = FSDP(nn.Linear(2, 1, device=device_type.type))
|
||||
output = model(torch.ones(2, device=device_type.type))
|
||||
model = FSDP(nn.Linear(2, 1, device=self.device))
|
||||
output = model(torch.ones(2, device=self.device))
|
||||
|
||||
def invalid_backward_hook(*args, **kwargs):
|
||||
with FSDP.summon_full_params(model):
|
||||
|
|
|
|||
|
|
@ -16,9 +16,11 @@ from torch.distributed.fsdp.api import (
|
|||
ShardingStrategy,
|
||||
StateDictType,
|
||||
)
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_fsdp import get_devtype
|
||||
from torch.testing._internal.common_utils import parametrize, run_tests
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
)
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
DTensorTestBase,
|
||||
skip_if_lt_x_gpu,
|
||||
|
|
@ -26,9 +28,6 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
|
|||
)
|
||||
|
||||
|
||||
device_type = torch.device(get_devtype())
|
||||
|
||||
|
||||
# Simple and boring model to test interface and some corner cases that do not
|
||||
# require complicated wrapping strategy.
|
||||
class DenseModel(torch.nn.Module):
|
||||
|
|
@ -43,17 +42,16 @@ class DenseModel(torch.nn.Module):
|
|||
def forward(self, x):
|
||||
return self.net4(self.net3(self.net2(self.net1(x))))
|
||||
|
||||
def get_input(self, device):
|
||||
return torch.rand(4, 8, device=device)
|
||||
def get_input(self):
|
||||
return torch.rand(4, 8, device="cuda")
|
||||
|
||||
|
||||
# TODO: Consolidate DeviceMesh based FSDP and HSDP test cases.
|
||||
class TestHSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
||||
def _create_model(self, device, device_mesh=None):
|
||||
device_id = device_type
|
||||
def _create_model(self, device_mesh=None):
|
||||
if device_mesh:
|
||||
model = FSDP(
|
||||
DenseModel().to(device_id),
|
||||
DenseModel().cuda(),
|
||||
device_mesh=device_mesh,
|
||||
sharding_strategy=ShardingStrategy.HYBRID_SHARD,
|
||||
)
|
||||
|
|
@ -62,23 +60,22 @@ class TestHSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
|||
intra_node_pg = mesh_2d.get_group(mesh_dim=1)
|
||||
inter_node_pg = mesh_2d.get_group(mesh_dim=0)
|
||||
model = FSDP(
|
||||
DenseModel().to(device_id),
|
||||
DenseModel().cuda(),
|
||||
process_group=(intra_node_pg, inter_node_pg),
|
||||
sharding_strategy=ShardingStrategy.HYBRID_SHARD,
|
||||
)
|
||||
|
||||
optim = torch.optim.Adam(model.parameters(), lr=0.1)
|
||||
model(model.get_input(device_id)).sum().backward()
|
||||
model(model.get_input()).sum().backward()
|
||||
optim.step()
|
||||
|
||||
return model, optim
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_hsdp_init_with_device_mesh(self, device):
|
||||
device_id = device_type
|
||||
def test_hsdp_init_with_device_mesh(self):
|
||||
mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
|
||||
model, optim = self._create_model(device_id, mesh_2d)
|
||||
model, optim = self._create_model(mesh_2d)
|
||||
|
||||
FSDP.set_state_dict_type(
|
||||
model,
|
||||
|
|
@ -110,11 +107,10 @@ class TestHSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
|||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@parametrize("offload_to_cpu", [True, False])
|
||||
def test_dtensor_sharded_tensor_state_dict_identical(self, device, offload_to_cpu):
|
||||
device_id = device_type
|
||||
def test_dtensor_sharded_tensor_state_dict_identical(self, offload_to_cpu):
|
||||
mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
|
||||
model, optim = self._create_model(mesh_2d)
|
||||
|
||||
model, optim = self._create_model(device_id, mesh_2d)
|
||||
FSDP.set_state_dict_type(
|
||||
model,
|
||||
StateDictType.SHARDED_STATE_DICT,
|
||||
|
|
@ -126,7 +122,7 @@ class TestHSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
|||
dtensor_sd = model.state_dict()
|
||||
dtensor_osd = FSDP.optim_state_dict(model, optim)
|
||||
|
||||
ref_model, ref_optim = self._create_model(device_id)
|
||||
ref_model, ref_optim = self._create_model()
|
||||
FSDP.set_state_dict_type(
|
||||
ref_model,
|
||||
StateDictType.SHARDED_STATE_DICT,
|
||||
|
|
@ -180,10 +176,9 @@ class TestHSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
|||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@parametrize("offload_to_cpu", [True, False])
|
||||
def test_dtensor_sharded_optim_load_state_dict(self, device, offload_to_cpu):
|
||||
device_id = device_type
|
||||
def test_dtensor_sharded_optim_load_state_dict(self, offload_to_cpu):
|
||||
mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
|
||||
model, optim = self._create_model(device_id, mesh_2d)
|
||||
model, optim = self._create_model(mesh_2d)
|
||||
|
||||
FSDP.set_state_dict_type(
|
||||
model,
|
||||
|
|
@ -199,7 +194,7 @@ class TestHSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
|||
ref_optim_state_dict = deepcopy(FSDP.optim_state_dict(model, optim))
|
||||
|
||||
# Update the parameters so FSDP.optim_state_dict() will be different from ref_optim_state_dict.
|
||||
model(model.get_input(device_id)).sum().backward()
|
||||
model(model.get_input()).sum().backward()
|
||||
optim.step()
|
||||
|
||||
# Load ref_optim_state_dict back.
|
||||
|
|
@ -235,10 +230,9 @@ class TestHSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
|||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@parametrize("offload_to_cpu", [True, False])
|
||||
def test_dtensor_sharded_model_load_state_dict(self, device, offload_to_cpu):
|
||||
device_id = device_type
|
||||
def test_dtensor_sharded_model_load_state_dict(self, offload_to_cpu):
|
||||
mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
|
||||
model, optim = self._create_model(device_id, mesh_2d)
|
||||
model, optim = self._create_model(mesh_2d)
|
||||
|
||||
FSDP.set_state_dict_type(
|
||||
model,
|
||||
|
|
@ -252,7 +246,7 @@ class TestHSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
|||
ref_state_dict = deepcopy(model.state_dict())
|
||||
|
||||
# Update the parameters so model.state_dict() will be different from ref_dtensor_sd.
|
||||
model(model.get_input(device_id)).sum().backward()
|
||||
model(model.get_input()).sum().backward()
|
||||
optim.step()
|
||||
|
||||
# Load ref_state_dict back.
|
||||
|
|
@ -273,15 +267,13 @@ class TestHSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
|||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_root_module_is_not_FSDP(self, device):
|
||||
device_id = device_type
|
||||
|
||||
def test_root_module_is_not_FSDP(self):
|
||||
class FakeMPModel(torch.nn.Module):
|
||||
def __init__(self, device_mesh):
|
||||
super().__init__()
|
||||
torch.manual_seed(0)
|
||||
self.dense = FSDP(
|
||||
DenseModel().to(device_id),
|
||||
DenseModel().cuda(),
|
||||
use_orig_params=True,
|
||||
sharding_strategy=ShardingStrategy.HYBRID_SHARD,
|
||||
device_mesh=device_mesh,
|
||||
|
|
@ -300,10 +292,10 @@ class TestHSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
|||
return self.dense(sparse)
|
||||
|
||||
mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
|
||||
model = FakeMPModel(device_mesh=mesh_2d).to(device_id)
|
||||
model = FakeMPModel(device_mesh=mesh_2d).cuda()
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
|
||||
|
||||
batch = torch.rand(5, 8, device=self.device_type)
|
||||
batch = torch.rand(5, 8, device=torch.device("cuda"))
|
||||
model(batch).sum().backward()
|
||||
optim.step()
|
||||
osd = optim.state_dict()
|
||||
|
|
@ -324,9 +316,6 @@ class TestHSDPWithDeviceMeshAndDTensor(DTensorTestBase):
|
|||
self.assertIsInstance(state["exp_avg_sq"], torch.Tensor)
|
||||
|
||||
|
||||
devices = ("cuda", "hpu")
|
||||
instantiate_device_type_tests(
|
||||
TestHSDPWithDeviceMeshAndDTensor, globals(), only_for=devices
|
||||
)
|
||||
instantiate_parametrized_tests(TestHSDPWithDeviceMeshAndDTensor)
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import random
|
||||
import sys
|
||||
import unittest
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
|
@ -10,13 +11,11 @@ import torch
|
|||
import torch.nn as nn
|
||||
from torch import distributed as dist
|
||||
from torch.distributed.utils import _apply_to_tensors, _replace_by_prefix
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
subtest,
|
||||
TEST_HPU,
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
TestCase,
|
||||
)
|
||||
|
|
@ -33,25 +32,22 @@ if TEST_WITH_DEV_DBG_ASAN:
|
|||
)
|
||||
sys.exit(0)
|
||||
|
||||
list_device = "hpu" if TEST_HPU else "cuda"
|
||||
|
||||
|
||||
class TestUtils(TestCase):
|
||||
@parametrize(
|
||||
"device_list",
|
||||
[
|
||||
["cpu"],
|
||||
[list_device],
|
||||
subtest(["cpu", list_device], name=f"cpu_{list_device}"),
|
||||
],
|
||||
"devices", [["cpu"], ["cuda"], subtest(["cpu", "cuda"], name="cpu_cuda")]
|
||||
)
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_apply_to_tensors(self, device_list):
|
||||
def test_apply_to_tensors(self, devices):
|
||||
if "cuda" in devices and (
|
||||
not torch.cuda.is_available() or torch.cuda.device_count() < 1
|
||||
):
|
||||
raise unittest.SkipTest("Skipped due to lack of GPU")
|
||||
|
||||
expected = 0
|
||||
|
||||
def get_a_tensor():
|
||||
"""Return a random tensor on random device."""
|
||||
dev = random.choice(device_list)
|
||||
dev = random.choice(devices)
|
||||
shape = random.choice(((1), (2, 3), (4, 5, 6), (7, 8, 9, 10)))
|
||||
t = torch.rand(shape).to(dev)
|
||||
nonlocal expected
|
||||
|
|
@ -95,7 +91,6 @@ class TestUtils(TestCase):
|
|||
for i, v in enumerate(data):
|
||||
self.assertEqual(type(new_data[i]), type(v))
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_replace_by_prefix(self):
|
||||
state_dict = {
|
||||
"layer.a": torch.tensor(1),
|
||||
|
|
@ -112,7 +107,6 @@ class TestUtils(TestCase):
|
|||
_replace_by_prefix(state_dict, "module.layer.", "layer.")
|
||||
assert state_dict == original_state_dict
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_packed_sequence(self):
|
||||
"""Test to ensure RNN packed sequences are modified correctly."""
|
||||
rnn = nn.RNN(5, 5)
|
||||
|
|
@ -130,7 +124,7 @@ class TestUtils(TestCase):
|
|||
self.assertEqual(torch.sum(x), 0)
|
||||
|
||||
|
||||
devices = ("cuda", "hpu")
|
||||
instantiate_device_type_tests(TestUtils, globals(), only_for=devices)
|
||||
instantiate_parametrized_tests(TestUtils)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -164,10 +164,6 @@ def _assert_module_states(
|
|||
assert_fn(p1, p2)
|
||||
|
||||
|
||||
def get_devtype():
|
||||
return torch.device(DEVICE_TYPE)
|
||||
|
||||
|
||||
def _zero_model(
|
||||
model: nn.Module,
|
||||
zero_buffers: bool = False,
|
||||
|
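For context on the get_devtype helper deleted in the hunk above: it resolves the accelerator chosen by the test harness (DEVICE_TYPE) into a torch.device that the FSDP tests can reuse. A rough, hypothetical stand-in — resolve_device_type and the torch.hpu probe are assumptions for illustration, not code from this commit:

    import torch

    def resolve_device_type() -> str:
        # Prefer CUDA when present; fall back to an HPU build if the Habana
        # plugin registered a torch.hpu module (assumption), else CPU.
        if torch.cuda.is_available():
            return "cuda"
        if getattr(torch, "hpu", None) is not None and torch.hpu.is_available():
            return "hpu"
        return "cpu"

    device_type = torch.device(resolve_device_type())
    x = torch.randn(2, 2, device=device_type)
    print(x.device)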
|
@ -659,7 +655,7 @@ class ModuleWithDelay(FSDPTestModel):
|
|||
loss = self.module.get_loss(input, output) # type: ignore[operator]
|
||||
if self.delay_after_loss_ms > 0:
|
||||
if TEST_HPU:
|
||||
time.sleep(self.delay_after_loss_ms / 1000)
|
||||
time.sleep(self.delay_before_reduction_ms / 1000)
|
||||
elif TEST_CUDA:
|
||||
torch.cuda._sleep(int(self.delay_after_loss_ms * get_cycles_per_ms()))
|
||||
|
||||
|
|
@ -1374,12 +1370,7 @@ class FSDPTest(MultiProcessTestCase):
|
|||
**init_kwargs,
|
||||
)
|
||||
if ref_init_fn is None:
|
||||
if TEST_HPU:
|
||||
ref_model = DDP(
|
||||
model, device_ids=[DEVICE_TYPE], output_device=DEVICE_TYPE
|
||||
)
|
||||
else:
|
||||
ref_model = DDP(model, device_ids=[rank], output_device=rank)
|
||||
ref_model = DDP(model, device_ids=[rank], output_device=rank)
|
||||
else:
|
||||
ref_model = ref_init_fn(model)
|
||||
if use_pure_fp16:
|
||||
|