pytorch/test/distributed/fsdp/test_fsdp_comm_hooks.py
Xuehai Pan 758a0a88a2 [BE][Easy] enable ruff rule PIE790: unnecessary pass statement (#133200)
This PR removes unnecessary `pass` statements. This is semantically safe because the bytecode for the Python code does not change.

Note that if a function has a docstring, an empty function body does not need a `pass` statement as a placeholder (see the sketch below).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133200
Approved by: https://github.com/malfet, https://github.com/eqy, https://github.com/kit1980
2024-08-15 15:50:19 +00:00
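
As a minimal illustration of the pattern PIE790 flags, a hedged before/after sketch (the `redundant_pass` / `docstring_only` names are hypothetical, not from this file):

```python
# Before: flagged by PIE790 -- the docstring already serves as the function
# body, so the trailing `pass` is a redundant placeholder.
def redundant_pass():
    """Intentionally does nothing."""
    pass


# After: identical behavior and bytecode, no placeholder statement needed.
def docstring_only():
    """Intentionally does nothing."""
```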

# Owner(s): ["oncall: distributed"]
import sys
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributed as dist
from torch.distributed.algorithms._comm_hooks import default_hooks
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_distributed import (
requires_nccl,
requires_nccl_version,
skip_but_pass_in_sandcastle_if,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
)
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
# bfloat16 collectives require CUDA 11+ (or ROCm); here we only check for a CUDA/HIP build
BFLOAT16_AVAILABLE = torch.cuda.is_available() and (
torch.version.cuda is not None or torch.version.hip is not None
)
class Net(nn.Module):
def __init__(self, has_wrapping, sharding_strategy, mixed_precision=None):
# to ensure determinism
torch.manual_seed(0)
torch.cuda.manual_seed(0)
super().__init__()
if has_wrapping:
self.net = FSDP(
nn.Sequential(
nn.Linear(8, 16),
nn.ReLU(),
FSDP(
nn.Linear(16, 8),
device_id=torch.cuda.current_device(),
sharding_strategy=sharding_strategy,
mixed_precision=mixed_precision,
),
),
device_id=torch.cuda.current_device(),
sharding_strategy=sharding_strategy,
mixed_precision=mixed_precision,
)
else:
self.net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
self.out = nn.Linear(8, 4)
def forward(self, x):
return self.out(F.relu(self.net(x)))
class DummyState:
__slots__ = ["process_group", "noise"]
def __init__(self, process_group: dist.ProcessGroup, noise: int):
self.process_group = process_group
self.noise = noise
class DummyHook:
def dummy_hook_for_no_shard_fsdp(self, state: DummyState, grad: torch.Tensor):
"""
        This communication hook is for illustration and testing purposes only.
This communication hook is used during FSDP ``NO_SHARD`` training. It adds some noise to
the provided ``grad`` parameter and uses ``all_reduce`` to communicate full, flattened,
unsharded gradient.
"""
grad.add_(state.noise)
dist.all_reduce(grad, group=state.process_group)
def custom_reduce_scatter(self, output, input, group=None):
"""
        This function is for illustrative purposes only.
It is meant to implement a custom reduce-scatter
of a flattened tensor to all processes in a group.
Currently a no-op.
"""
def dummy_hook_for_sharded_fsdp(
self, state: DummyState, grad: torch.Tensor, output: torch.Tensor
):
"""
This communication hook is for illustration and testing purposes only.
This communication hook is used during FSDP ``FULL_SHARD`` or ``SHARD_GRAD_OP`` training.
It adds some noise to the provided ``grad`` parameter, uses
``reduce_scatter`` for gradient communication and stores a sharded gradient in ``output``.
"""
grad.add_(state.noise)
self.custom_reduce_scatter(output, grad, group=state.process_group)
class TestCommunicationHooks(FSDPTest):
@skip_if_lt_x_gpu(2)
@parametrize(
"sharding_strategy",
[
ShardingStrategy.NO_SHARD,
ShardingStrategy.FULL_SHARD,
ShardingStrategy.SHARD_GRAD_OP,
],
)
def test_default_communication_hook_behavior(
self, sharding_strategy: Optional[ShardingStrategy]
):
"""
Tests FSDP's default communication hook's behavior and correctness.
This test creates a simple linear net with weight shape ``1 X N``,
where ``N`` is the number of workers.
        For sharded cases, each worker gets one element of the weight parameter. The test
        checks that after backward, each worker's chunk of the gradient (or, for
        ``NO_SHARD``, the full gradient on every worker) equals the expected value.
Arguments:
sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm.
"""
out_dim = self.world_size
net = torch.nn.Linear(1, out_dim, bias=False)
inpt = torch.tensor([self.rank]).float().cuda(self.rank)
net_default_hook = FSDP(
net,
device_id=torch.cuda.current_device(),
sharding_strategy=sharding_strategy,
).to(self.rank)
# Check that by default, `_comm_hook` is None
for entry in FSDP.fsdp_modules(net_default_hook):
self.assertEqual(entry._comm_hook, None)
for _ in range(4):
# Clear gradients
net_default_hook.zero_grad()
loss = net_default_hook(inpt).sum()
loss.backward()
            # After the default averaging, each worker's (chunk of the) weight
            # gradient should equal the mean of all ranks.
grad = net_default_hook.params[0].grad
            expected_grad = sum(range(dist.get_world_size())) / dist.get_world_size()
# Verify default hook produces expected gradients
self.assertEqual(
grad[0].item(),
expected_grad,
msg=f"Expected hook grad of {expected_grad} but got {grad[0].item()}",
)
def _get_submodules(self, fsdp_net):
return [
submodule
for submodule in FSDP.fsdp_modules(fsdp_net)
if not submodule.check_is_root()
]
def _init_model(self, core, sharding_strategy, mixed_precision=None):
device = torch.device("cuda")
return FSDP(
core,
device_id=torch.cuda.current_device(),
sharding_strategy=sharding_strategy,
mixed_precision=mixed_precision,
).to(device)
@skip_if_lt_x_gpu(2)
@parametrize("has_wrapping", [True, False])
@parametrize(
"sharding_strategy",
[
ShardingStrategy.NO_SHARD,
ShardingStrategy.FULL_SHARD,
ShardingStrategy.SHARD_GRAD_OP,
],
)
def test_default_communication_hook_initialization(
self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy]
):
"""
Tests FSDP's communication hook interface behavior.
Arguments:
has_wrapping (bool): Configures wrapping of a module.
sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm.
"""
# Initialize a model
fsdp_model_with_hook = self._init_model(
Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy),
sharding_strategy=sharding_strategy,
)
# Check that by default, `_comm_hook` is None
for fsdp_module in FSDP.fsdp_modules(fsdp_model_with_hook):
self.assertEqual(fsdp_module._comm_hook, None)
dummy_state = DummyState(process_group=None, noise=1234)
        # Select the hook that matches the sharding strategy (see the DummyHook docstrings)
        dummy_hook = (
            DummyHook.dummy_hook_for_no_shard_fsdp
            if sharding_strategy == ShardingStrategy.NO_SHARD
            else DummyHook.dummy_hook_for_sharded_fsdp
        )
fsdp_model_with_hook.register_comm_hook(dummy_state, dummy_hook)
# Check that we can't register comm hook twice
with self.assertRaisesRegex(
AssertionError, "^A communication hook is already registered$"
):
fsdp_model_with_hook.register_comm_hook(dummy_state, dummy_hook)
# Check dummy hook was registered for the root and all submodules if any
for fsdp_module in FSDP.fsdp_modules(fsdp_model_with_hook):
self.assertEqual(fsdp_module._comm_hook, dummy_hook)
self.assertEqual(fsdp_module._comm_hook_state, dummy_state)
@skip_if_lt_x_gpu(2)
@parametrize(
"sharding_strategy",
[
ShardingStrategy.NO_SHARD,
ShardingStrategy.FULL_SHARD,
ShardingStrategy.SHARD_GRAD_OP,
],
)
def test_registering_hook_non_root(
self, sharding_strategy: Optional[ShardingStrategy]
):
"""
Tests FSDP's communication hook registering for submodules.
        Make sure a hook can't be registered on non-root submodules.
        Runs for ``NO_SHARD``, ``FULL_SHARD``, and ``SHARD_GRAD_OP`` strategies.
Arguments:
sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm.
"""
fsdp_model_with_hook = self._init_model(
Net(has_wrapping=True, sharding_strategy=sharding_strategy),
sharding_strategy=sharding_strategy,
)
dummy_state = DummyState(process_group=None, noise=1234)
        # Select the hook that matches the sharding strategy (see the DummyHook docstrings)
        dummy_hook = (
            DummyHook.dummy_hook_for_no_shard_fsdp
            if sharding_strategy == ShardingStrategy.NO_SHARD
            else DummyHook.dummy_hook_for_sharded_fsdp
        )
# Creating a list of non-root submodules to test
submodules = self._get_submodules(fsdp_model_with_hook)
# Check that assertion is raised for registering a comm hook on a non-root
with self.assertRaisesRegex(
AssertionError,
"^register_comm_hook can only be called on a root instance.$",
):
submodules[1].register_comm_hook(dummy_state, dummy_hook)
@skip_if_lt_x_gpu(2)
def test_registering_hook_hybrid_strategy(self):
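        # Communication hooks are not supported together with hybrid sharding
        # strategies, so registration should fail with an assertion error.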
for sharding_strategy in (
ShardingStrategy.HYBRID_SHARD,
ShardingStrategy._HYBRID_SHARD_ZERO2,
):
model = Net(False, None, None).cuda()
fsdp_model = FSDP(
model,
auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
sharding_strategy=sharding_strategy,
)
dummy_state = DummyState(process_group=None, noise=1234)
dummy_hook = DummyHook.dummy_hook_for_sharded_fsdp
with self.assertRaisesRegex(
AssertionError,
"Communication hook is not supported for hybrid strategies",
):
fsdp_model.register_comm_hook(dummy_state, dummy_hook)
@skip_if_lt_x_gpu(2)
@parametrize(
"sharding_strategy",
[
ShardingStrategy.NO_SHARD,
ShardingStrategy.FULL_SHARD,
ShardingStrategy.SHARD_GRAD_OP,
],
)
def test_registering_hook_submodules(
self, sharding_strategy: Optional[ShardingStrategy]
):
"""
Tests FSDP's communication hook registering for submodules.
        Checks the behavior when a hook was already assigned to a non-root submodule.
        Runs for ``NO_SHARD``, ``FULL_SHARD``, and ``SHARD_GRAD_OP`` strategies.
Arguments:
sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm.
"""
fsdp_model_with_hook = self._init_model(
Net(has_wrapping=True, sharding_strategy=sharding_strategy),
sharding_strategy=sharding_strategy,
)
dummy_state = DummyState(process_group=None, noise=1234)
        # Select the hook that matches the sharding strategy (see the DummyHook docstrings)
        dummy_hook = (
            DummyHook.dummy_hook_for_no_shard_fsdp
            if sharding_strategy == ShardingStrategy.NO_SHARD
            else DummyHook.dummy_hook_for_sharded_fsdp
        )
submodules = self._get_submodules(fsdp_model_with_hook)
# Simulate a registration of a hook on a submodule
submodules[1]._comm_hook = dummy_hook
        # Check that an error is raised when some submodules already have a non-default hook assigned
with self.assertRaisesRegex(
AssertionError, "^A communication hook is already registered$"
):
fsdp_model_with_hook.register_comm_hook(dummy_state, dummy_hook)
def _check_low_precision_hook(
self, state, hook, sharding_strategy, dtype, has_wrapping
):
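        # Trains two copies of the same model for one step: one with the given
        # low-precision communication hook registered and one relying on FSDP's
        # built-in ``MixedPrecision(reduce_dtype=dtype)``. Both paths compress
        # gradient communication to ``dtype`` and cast back to the parameter
        # dtype afterwards, so the resulting gradients should match.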
# keep everything deterministic for input data
torch.manual_seed(0)
torch.cuda.manual_seed(0)
fsdp_with_hook = self._init_model(
Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy),
sharding_strategy=sharding_strategy,
)
fsdp_with_hook.register_comm_hook(state, hook)
mp_only_grad = MixedPrecision(reduce_dtype=dtype)
fsdp_with_mp = self._init_model(
Net(
has_wrapping=has_wrapping,
sharding_strategy=sharding_strategy,
mixed_precision=mp_only_grad,
),
sharding_strategy=sharding_strategy,
mixed_precision=mp_only_grad,
)
optim_hook = torch.optim.SGD(fsdp_with_hook.parameters(), lr=0.1)
optim_mp = torch.optim.SGD(fsdp_with_mp.parameters(), lr=0.1)
in_data = torch.rand(16, 8).cuda()
fsdp_with_hook.train()
fsdp_with_mp.train()
loss_hook = fsdp_with_hook(in_data).sum()
loss_mp = fsdp_with_mp(in_data).sum()
loss_hook.backward()
# Make sure grads were cast to the parameter's precision
self.assertEqual(fsdp_with_hook.params[0].grad.dtype, state.parameter_type)
loss_mp.backward()
optim_hook.step()
optim_mp.step()
dist.barrier()
for hook_param, mp_param in zip(
fsdp_with_hook.parameters(), fsdp_with_mp.parameters()
):
self.assertEqual(hook_param.grad, mp_param.grad)
@requires_nccl()
@skip_if_lt_x_gpu(2)
@parametrize("has_wrapping", [True, False])
@parametrize(
"sharding_strategy",
[
ShardingStrategy.NO_SHARD,
ShardingStrategy.FULL_SHARD,
ShardingStrategy.SHARD_GRAD_OP,
],
)
def test_fp16_hook(
self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy]
):
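        # The built-in fp16 compression hook should produce the same gradients
        # as ``MixedPrecision(reduce_dtype=torch.float16)``; see
        # ``_check_low_precision_hook`` above.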
state = default_hooks.LowPrecisionState(process_group=_get_default_group())
hook = default_hooks.fp16_compress_hook
self._check_low_precision_hook(
state, hook, sharding_strategy, torch.float16, has_wrapping
)
@requires_nccl()
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for BF16_COMPRESS")
@skip_but_pass_in_sandcastle_if(
not BFLOAT16_AVAILABLE,
"BFloat16 is only supported by CUDA 11+",
)
@skip_if_lt_x_gpu(2)
@parametrize("has_wrapping", [True, False])
@parametrize(
"sharding_strategy",
[
ShardingStrategy.NO_SHARD,
ShardingStrategy.FULL_SHARD,
ShardingStrategy.SHARD_GRAD_OP,
],
)
def test_bf16_hook(
self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy]
):
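        # Same check as ``test_fp16_hook``, but compressing gradient
        # communication to bfloat16 (requires NCCL 2.10+, see the decorators above).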
state = default_hooks.LowPrecisionState(process_group=_get_default_group())
hook = default_hooks.bf16_compress_hook
self._check_low_precision_hook(
state, hook, sharding_strategy, torch.bfloat16, has_wrapping
)
instantiate_parametrized_tests(TestCommunicationHooks)
if __name__ == "__main__":
run_tests()