This PR removes unnecessary `pass` statements. This is semantically safe because the bytecode for the Python code does not change. Note that if a function has a docstring, an otherwise empty function does not need a `pass` statement as a placeholder. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133200 Approved by: https://github.com/malfet, https://github.com/eqy, https://github.com/kit1980
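A quick way to check the bytecode claim (illustrative standalone snippet, not part of the test file below; the two throwaway function names are made up for this example): a function whose body is only a docstring compiles to the same instructions with or without a trailing `pass`, because `pass` emits no bytecode.

import dis


def only_docstring(output, input, group=None):
    """Docstring alone is a valid function body."""


def docstring_and_pass(output, input, group=None):
    """Docstring alone is a valid function body."""
    pass


# `pass` compiles to no instructions, so the raw instruction streams should match.
print(only_docstring.__code__.co_code == docstring_and_pass.__code__.co_code)  # True
dis.dis(only_docstring)  # just loads/returns None; exact opcodes vary by CPython version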
# Owner(s): ["oncall: distributed"]

import sys
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributed as dist
from torch.distributed.algorithms._comm_hooks import default_hooks
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_distributed import (
    requires_nccl,
    requires_nccl_version,
    skip_but_pass_in_sandcastle_if,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

# bfloat16 is only supported by CUDA 11+
BFLOAT16_AVAILABLE = torch.cuda.is_available() and (
    torch.version.cuda is not None or torch.version.hip is not None
)


class Net(nn.Module):
    def __init__(self, has_wrapping, sharding_strategy, mixed_precision=None):
        # to ensure determinism
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)
        super().__init__()

        if has_wrapping:
            self.net = FSDP(
                nn.Sequential(
                    nn.Linear(8, 16),
                    nn.ReLU(),
                    FSDP(
                        nn.Linear(16, 8),
                        device_id=torch.cuda.current_device(),
                        sharding_strategy=sharding_strategy,
                        mixed_precision=mixed_precision,
                    ),
                ),
                device_id=torch.cuda.current_device(),
                sharding_strategy=sharding_strategy,
                mixed_precision=mixed_precision,
            )
        else:
            self.net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))

        self.out = nn.Linear(8, 4)

    def forward(self, x):
        return self.out(F.relu(self.net(x)))


class DummyState:
    __slots__ = ["process_group", "noise"]

    def __init__(self, process_group: dist.ProcessGroup, noise: int):
        self.process_group = process_group
        self.noise = noise


class DummyHook:
    def dummy_hook_for_no_shard_fsdp(self, state: DummyState, grad: torch.Tensor):
        """
        This communication hook is for illustration and testing purposes only.
        This communication hook is used during FSDP ``NO_SHARD`` training. It adds some noise to
        the provided ``grad`` parameter and uses ``all_reduce`` to communicate full, flattened,
        unsharded gradient.
        """
        grad.add_(state.noise)
        dist.all_reduce(grad, group=state.process_group)

    def custom_reduce_scatter(self, output, input, group=None):
        """
        This function is for illustrative purposes only.
        It is meant to implement a custom reduce-scatter
        of a flattened tensor to all processes in a group.
        Currently a no-op.
        """

    def dummy_hook_for_sharded_fsdp(
        self, state: DummyState, grad: torch.Tensor, output: torch.Tensor
    ):
        """
        This communication hook is for illustration and testing purposes only.
        This communication hook is used during FSDP ``FULL_SHARD`` or ``SHARD_GRAD_OP`` training.
        It adds some noise to the provided ``grad`` parameter, uses
        ``reduce_scatter`` for gradient communication and stores a sharded gradient in ``output``.
        """
        grad.add_(state.noise)
        self.custom_reduce_scatter(output, grad, group=state.process_group)


class TestCommunicationHooks(FSDPTest):
    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy",
        [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
        ],
    )
    def test_default_communication_hook_behavior(
        self, sharding_strategy: Optional[ShardingStrategy]
    ):
        """
        Tests FSDP's default communication hook's behavior and correctness.
        This test creates a simple linear net with weight shape ``1 X N``,
        where ``N`` is the number of workers.
        For sharded cases, each worker gets 1 element of the weight parameter. This test
        checks that after backward, each worker has a proper value in its chunk of
        the gradient, or the whole gradient on every worker is equal to an expected value.

        Arguments:
            sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm.
        """
        out_dim = self.world_size
        net = torch.nn.Linear(1, out_dim, bias=False)
        inpt = torch.tensor([self.rank]).float().cuda(self.rank)

        net_default_hook = FSDP(
            net,
            device_id=torch.cuda.current_device(),
            sharding_strategy=sharding_strategy,
        ).to(self.rank)

        # Check that by default, `_comm_hook` is None
        for entry in FSDP.fsdp_modules(net_default_hook):
            self.assertEqual(entry._comm_hook, None)

        for _ in range(4):
            # Clear gradients
            net_default_hook.zero_grad()
            loss = net_default_hook(inpt).sum()
            loss.backward()

            # For each worker, the gradient on the weight should be worker_rank.
            grad = net_default_hook.params[0].grad
            expected_grad = (
                sum(i for i in range(dist.get_world_size())) / dist.get_world_size()
            )
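            # The default hook averages the local gradients across ranks, so every
            # element of the (sharded or unsharded) gradient should equal the mean
            # of 0, 1, ..., world_size - 1 computed above.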
            # Verify default hook produces expected gradients
            self.assertEqual(
                grad[0].item(),
                expected_grad,
                msg=f"Expected hook grad of {expected_grad} but got {grad[0].item()}",
            )

    def _get_submodules(self, fsdp_net):
        return [
            submodule
            for submodule in FSDP.fsdp_modules(fsdp_net)
            if not submodule.check_is_root()
        ]

    def _init_model(self, core, sharding_strategy, mixed_precision=None):
        device = torch.device("cuda")
        return FSDP(
            core,
            device_id=torch.cuda.current_device(),
            sharding_strategy=sharding_strategy,
            mixed_precision=mixed_precision,
        ).to(device)

    @skip_if_lt_x_gpu(2)
    @parametrize("has_wrapping", [True, False])
    @parametrize(
        "sharding_strategy",
        [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
        ],
    )
    def test_default_communication_hook_initialization(
        self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy]
    ):
        """
        Tests FSDP's communication hook interface behavior.

        Arguments:
            has_wrapping (bool): Configures wrapping of a module.
            sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm.
        """

        # Initialize a model
        fsdp_model_with_hook = self._init_model(
            Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy),
            sharding_strategy=sharding_strategy,
        )

        # Check that by default, `_comm_hook` is None
        for fsdp_module in FSDP.fsdp_modules(fsdp_model_with_hook):
            self.assertEqual(fsdp_module._comm_hook, None)

        dummy_state = DummyState(process_group=None, noise=1234)
        dummy_hook = (
            DummyHook.dummy_hook_for_no_shard_fsdp
            if sharding_strategy != ShardingStrategy.NO_SHARD
            else DummyHook.dummy_hook_for_sharded_fsdp
        )

        fsdp_model_with_hook.register_comm_hook(dummy_state, dummy_hook)

        # Check that we can't register comm hook twice
        with self.assertRaisesRegex(
            AssertionError, "^A communication hook is already registered$"
        ):
            fsdp_model_with_hook.register_comm_hook(dummy_state, dummy_hook)

        # Check dummy hook was registered for the root and all submodules if any
        for fsdp_module in FSDP.fsdp_modules(fsdp_model_with_hook):
            self.assertEqual(fsdp_module._comm_hook, dummy_hook)
            self.assertEqual(fsdp_module._comm_hook_state, dummy_state)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy",
        [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
        ],
    )
    def test_registering_hook_non_root(
        self, sharding_strategy: Optional[ShardingStrategy]
    ):
        """
        Tests FSDP's communication hook registering for submodules.
        Make sure it can't be registered for non-root submodules.
        Currently tests only ``NO_SHARD`` strategy.

        Arguments:
            sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm.
        """

        fsdp_model_with_hook = self._init_model(
            Net(has_wrapping=True, sharding_strategy=sharding_strategy),
            sharding_strategy=sharding_strategy,
        )
        dummy_state = DummyState(process_group=None, noise=1234)
        dummy_hook = (
            DummyHook.dummy_hook_for_no_shard_fsdp
            if sharding_strategy != ShardingStrategy.NO_SHARD
            else DummyHook.dummy_hook_for_sharded_fsdp
        )
        # Creating a list of non-root submodules to test
        submodules = self._get_submodules(fsdp_model_with_hook)
        # Check that assertion is raised for registering a comm hook on a non-root
        with self.assertRaisesRegex(
            AssertionError,
            "^register_comm_hook can only be called on a root instance.$",
        ):
            submodules[1].register_comm_hook(dummy_state, dummy_hook)

    @skip_if_lt_x_gpu(2)
    def test_registering_hook_hybrid_strategy(self):
        for sharding_strategy in (
            ShardingStrategy.HYBRID_SHARD,
            ShardingStrategy._HYBRID_SHARD_ZERO2,
        ):
            model = Net(False, None, None).cuda()
            fsdp_model = FSDP(
                model,
                auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
                sharding_strategy=sharding_strategy,
            )
            dummy_state = DummyState(process_group=None, noise=1234)
            dummy_hook = DummyHook.dummy_hook_for_sharded_fsdp
            with self.assertRaisesRegex(
                AssertionError,
                "Communication hook is not supported for hybrid strategies",
            ):
                fsdp_model.register_comm_hook(dummy_state, dummy_hook)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy",
        [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
        ],
    )
    def test_registering_hook_submodules(
        self, sharding_strategy: Optional[ShardingStrategy]
    ):
        """
        Tests FSDP's communication hook registering for submodules.
        Checks behavior if a hook was registered for a non-root submodule.
        Currently tests only ``NO_SHARD`` strategy.

        Arguments:
            sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm.
        """

        fsdp_model_with_hook = self._init_model(
            Net(has_wrapping=True, sharding_strategy=sharding_strategy),
            sharding_strategy=sharding_strategy,
        )
        dummy_state = DummyState(process_group=None, noise=1234)
        dummy_hook = (
            DummyHook.dummy_hook_for_no_shard_fsdp
            if sharding_strategy != ShardingStrategy.NO_SHARD
            else DummyHook.dummy_hook_for_sharded_fsdp
        )
        submodules = self._get_submodules(fsdp_model_with_hook)

        # Simulate a registration of a hook on a submodule
        submodules[1]._comm_hook = dummy_hook
        # Check that an error is raised when some submodules have a non-default hook assigned
        with self.assertRaisesRegex(
            AssertionError, "^A communication hook is already registered$"
        ):
            fsdp_model_with_hook.register_comm_hook(dummy_state, dummy_hook)

    def _check_low_precision_hook(
        self, state, hook, sharding_strategy, dtype, has_wrapping
    ):
        # keep everything deterministic for input data
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        fsdp_with_hook = self._init_model(
            Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy),
            sharding_strategy=sharding_strategy,
        )
        fsdp_with_hook.register_comm_hook(state, hook)

        mp_only_grad = MixedPrecision(reduce_dtype=dtype)
        fsdp_with_mp = self._init_model(
            Net(
                has_wrapping=has_wrapping,
                sharding_strategy=sharding_strategy,
                mixed_precision=mp_only_grad,
            ),
            sharding_strategy=sharding_strategy,
            mixed_precision=mp_only_grad,
        )

        optim_hook = torch.optim.SGD(fsdp_with_hook.parameters(), lr=0.1)
        optim_mp = torch.optim.SGD(fsdp_with_mp.parameters(), lr=0.1)

        in_data = torch.rand(16, 8).cuda()
        fsdp_with_hook.train()
        fsdp_with_mp.train()
        loss_hook = fsdp_with_hook(in_data).sum()
        loss_mp = fsdp_with_mp(in_data).sum()
        loss_hook.backward()
        # Make sure grads were cast to the parameter's precision
        self.assertEqual(fsdp_with_hook.params[0].grad.dtype, state.parameter_type)
        loss_mp.backward()
        optim_hook.step()
        optim_mp.step()

        dist.barrier()

        for hook_param, mp_param in zip(
            fsdp_with_hook.parameters(), fsdp_with_mp.parameters()
        ):
            self.assertEqual(hook_param.grad, mp_param.grad)

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    @parametrize("has_wrapping", [True, False])
    @parametrize(
        "sharding_strategy",
        [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
        ],
    )
    def test_fp16_hook(
        self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy]
    ):
        state = default_hooks.LowPrecisionState(process_group=_get_default_group())
        hook = default_hooks.fp16_compress_hook

        self._check_low_precision_hook(
            state, hook, sharding_strategy, torch.float16, has_wrapping
        )

    @requires_nccl()
    @requires_nccl_version((2, 10), "Need NCCL 2.10+ for BF16_COMPRESS")
    @skip_but_pass_in_sandcastle_if(
        not BFLOAT16_AVAILABLE,
        "BFloat16 is only supported by CUDA 11+",
    )
    @skip_if_lt_x_gpu(2)
    @parametrize("has_wrapping", [True, False])
    @parametrize(
        "sharding_strategy",
        [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
        ],
    )
    def test_bf16_hook(
        self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy]
    ):
        state = default_hooks.LowPrecisionState(process_group=_get_default_group())
        hook = default_hooks.bf16_compress_hook

        self._check_low_precision_hook(
            state, hook, sharding_strategy, torch.bfloat16, has_wrapping
        )


instantiate_parametrized_tests(TestCommunicationHooks)

if __name__ == "__main__":
    run_tests()