[PT1.11] make static graph stable (#71459)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71459

1. Add a static_graph argument to the DDP constructor.
2. Keep the _set_static_graph() API so that existing use cases are not affected; it is also called internally by the DDP constructor.
3. Four combinations are covered by tests (see the usage sketch after this list):
    static_graph = False, _set_static_graph() is called;
    static_graph = False, _set_static_graph() is not called;
    static_graph = True, _set_static_graph() is not called;
    static_graph = True, _set_static_graph() is called;
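
For illustration, a minimal sketch of the two ways to enable static graph after this change (the model and process-group setup are assumed here, not part of this commit):

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP

    # Assumes torch.distributed.init_process_group() has already run and
    # `model` is an nn.Module placed on the right device.
    ddp_model = DDP(model, static_graph=True)  # new constructor argument

    # Equivalent legacy path, still supported:
    ddp_model = DDP(model)
    ddp_model._set_static_graph()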
ghstack-source-id: 147263797

Test Plan: unit tests

Reviewed By: rohan-varma

Differential Revision: D33646738

fbshipit-source-id: 8c1730591152aab91afce7133d2adf1efd723855
(cherry picked from commit dc246a1129)
Yanli Zhao 2022-01-20 09:33:13 -08:00 committed by PyTorch MergeBot
parent 11d8fe59fd
commit 1c61d8c43f
5 changed files with 76 additions and 60 deletions


@@ -1564,9 +1564,8 @@ class DistributedDataParallelTest(
             process_group=process_group,
             find_unused_parameters=True,
             gradient_as_bucket_view=gradient_as_bucket_view,
+            static_graph=static_graph,
         )
-        if static_graph:
-            cpu_model._set_static_graph()
         run_and_verify_grad(cpu_model)

         # Test on GPU
@@ -1577,9 +1576,8 @@ class DistributedDataParallelTest(
             process_group=process_group,
             find_unused_parameters=True,
             gradient_as_bucket_view=gradient_as_bucket_view,
+            static_graph=static_graph,
         )
-        if static_graph:
-            gpu_model._set_static_graph()
         run_and_verify_grad(gpu_model)

     @requires_gloo()


@@ -1619,11 +1619,9 @@ class DistributedDataParallelTest(
             device_ids=[device_id],
             process_group=process_group,
             gradient_as_bucket_view=gradient_as_bucket_view,
+            static_graph=static_graph,
         )
-        if static_graph:
-            gpu_model._set_static_graph()

         # Register a DDP communication hook if any.
         if hook is not None:
             gpu_model.register_comm_hook(state, hook)
@@ -2194,9 +2192,8 @@ class DistributedDataParallelTest(
             device_ids=[self.rank],
             process_group=process_group,
             find_unused_parameters=find_unused_parameters,
+            static_graph=static_graph,
         )
-        if static_graph:
-            ddp_model._set_static_graph()
         self.assertEqual(
             ddp_model._get_ddp_logging_data().get("static_graph", 0), static_graph
         )


@@ -478,6 +478,33 @@ class DistributedDataParallel(Module, Joinable):
                       gradients. If hitting such errors, please fix it by
                       referring to the :meth:`~torch.optim.Optimizer.zero_grad`
                       function in ``torch/optim/optimizer.py`` as a solution.
+        static_graph (bool): When set to ``True``, DDP knows the trained graph is
+                     static. Static graph means 1) The set of used and unused
+                     parameters will not change during the whole training loop; in
+                     this case, it does not matter whether users set
+                     ``find_unused_parameters = True`` or not. 2) How the graph is trained
+                     will not change during the whole training loop (meaning there is
+                     no control flow depending on iterations).
+                     When static_graph is set to be ``True``, DDP will support cases that
+                     could not be supported in the past:
+                     1) Reentrant backwards.
+                     2) Activation checkpointing multiple times.
+                     3) Activation checkpointing when model has unused parameters.
+                     4) There are model parameters that are outside of forward function.
+                     5) Potentially improve performance when there are unused parameters,
+                     as DDP will not search the graph in each iteration to detect unused
+                     parameters when static_graph is set to be ``True``.
+                     To check whether you can set static_graph to be ``True``, one way is to
+                     check ddp logging data at the end of your previous model training;
+                     if ``ddp_logging_data.get("can_set_static_graph") == True``, mostly you
+                     can set ``static_graph = True`` as well.
+
+                     Example::
+                         >>> model_DDP = torch.nn.parallel.DistributedDataParallel(model)
+                         >>> # Training loop
+                         >>> .....
+                         >>> ddp_logging_data = model_DDP._get_ddp_logging_data()
+                         >>> static_graph = ddp_logging_data.get("can_set_static_graph")

     Attributes:
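
For illustration, a minimal sketch of the ``can_set_static_graph`` check described in the docstring above; the ``model`` and the training-loop details are assumptions, not part of this diff:

    # First run: train normally, then inspect the DDP logging data.
    ddp = torch.nn.parallel.DistributedDataParallel(model)
    for batch in batches:  # assumed training loop
        ddp(batch).sum().backward()
    logging_data = ddp._get_ddp_logging_data()
    if logging_data.get("can_set_static_graph"):
        # Later runs can likely construct DDP with static_graph=True directly.
        ddp = torch.nn.parallel.DistributedDataParallel(model, static_graph=True)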
@@ -501,6 +528,7 @@ class DistributedDataParallel(Module, Joinable):
         find_unused_parameters=False,
         check_reduction=False,
         gradient_as_bucket_view=False,
+        static_graph=False,
     ):
         super(DistributedDataParallel, self).__init__()
@@ -620,6 +648,9 @@ class DistributedDataParallel(Module, Joinable):
         self._ddp_init_helper(parameters, expect_sparse_gradient, param_to_name_mapping)
         self._has_rebuilt_buckets = False

+        if static_graph:
+            self._set_static_graph()
+
     def _sync_params_and_buffers(self, authoritative_rank=0):
         module_states = []
         for name, param in self.module.named_parameters():
@@ -727,7 +758,8 @@ class DistributedDataParallel(Module, Joinable):
         # Builds reducer
         self._ddp_init_helper(parameters, expect_sparse_gradient, param_to_name_mapping)
         if self.static_graph:
-            self._set_static_graph()
+            self.reducer._set_static_graph()
+            self.logger._set_static_graph()

     def _build_params_for_reducer(self):
         # Build tuple of (module, parameter) for all parameters that require grads.
@@ -1635,30 +1667,15 @@ class DistributedDataParallel(Module, Joinable):
     def _set_static_graph(self):
         """
-        Users can explicitly let DDP know the trained graph is static,
-        when 1) the set of used and unused parameters will not change
-        during the whole training loop; in this case, it does not matter
-        whether users set find_unsued_parameters = true or not.
-        2) how the graph is trained will not change during the whole training
-        loop (meaning there is no control flow depending on iterations).
-        When graph is set to be static, DDP will support cases that can not
-        be supported in the past: 1) reentrant backwards
-        2) activation checkpointing multiple times 3)
-        activation checkpointing with find_unused_parameters = true.
-        4) not all output tensors are used in loss calculation.
-        5) there is model parameter that is outside of forward function.
-        6) potentially improve performance when find_unsued_parameters = true
-        or there are unused parameters, as DDP will not search graph in each
-        iteraton to detect unused parameters when static_graph is set to be True.
-        This API should be called after DistributedDataParallel construction, and
-        before training loops starts. Also it should be called in the same way for
-        all ranks. For example:
-            ddp_model = DistributedDataParallel(model)
-            ddp_model._set_static_graph()
-            for i in range(n):
-                .....
+        It is recommended to set static graph in the DDP constructor, which will
+        call this private API internally.
         """
+        # If self.static_graph has been set, no need to set it again
+        if self.static_graph:
+            warnings.warn(
+                "You've set static_graph to be True, no need to set it again."
+            )
+            return
         self.static_graph = True
         self.reducer._set_static_graph()
         self.logger._set_static_graph()
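
Given the guard added above, calling the private API after the constructor flag has been set only emits a warning and returns early; a minimal sketch (``model`` is assumed):

    ddp = torch.nn.parallel.DistributedDataParallel(model, static_graph=True)
    ddp._set_static_graph()  # warns "no need to set it again" and is a no-op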


@@ -3676,6 +3676,7 @@ class DistributedTest:
             output_device=None,
             gradient_as_bucket_view=False,
             static_graph=False,
+            set_static_graph_twice=False,
         ):
             # Run a simple end to end DDP model, use result of single node model
             # as baseline
@@ -3694,8 +3695,10 @@ class DistributedTest:
                 model_DDP,
                 device_ids=gpu_subset,
                 gradient_as_bucket_view=gradient_as_bucket_view,
+                static_graph=static_graph,
             )
-            if static_graph:
+            if set_static_graph_twice:
                 model_DDP._set_static_graph()

             # test serializable/unserializable
@@ -3934,10 +3937,9 @@ class DistributedTest:
                     copy.deepcopy(model).cuda(),
                     device_ids=[self.rank],
                     gradient_as_bucket_view=grad_as_bucket_view,
+                    static_graph=static_graph,
                 )
             )
-            if static_graph:
-                ddp_model_with_optimizer_hook._set_static_graph()

             # Register hook that runs allreduce + functional optimizer
             # step.
@@ -3957,9 +3959,8 @@ class DistributedTest:
                 copy.deepcopy(model).cuda(),
                 device_ids=[self.rank],
                 gradient_as_bucket_view=grad_as_bucket_view,
+                static_graph=static_graph,
             )
-            if static_graph:
-                ddp_model_with_no_hook._set_static_graph()

             mapping = {v: k for k, v in functional_optim_map.items()}
             optimizer_no_hook = mapping.get(functional_optim_cls)(
@@ -4468,6 +4469,15 @@ class DistributedTest:
                 static_graph=static_graph,
             )

+            # test set static graph twice
+            self._test_DistributedDataParallel(
+                gpu_subset=gpus,
+                rank=rank,
+                gradient_as_bucket_view=use_bucket_view,
+                static_graph=static_graph,
+                set_static_graph_twice=True,
+            )
+
             # test output_device
             self._test_DistributedDataParallel(
                 gpu_subset=gpus,
@@ -5216,10 +5226,6 @@ class DistributedTest:
         @sandcastle_skip_if(BACKEND == "nccl", "nccl does not support DDP on CPU models")
         def test_static_graph_api_cpu(self):
             model_DDP = nn.parallel.DistributedDataParallel(DDP_NET)
-            model_DDP._set_static_graph()
-            self.assertEqual(
-                model_DDP._get_ddp_logging_data().get("static_graph"), True
-            )
             expected_err = "should be called before training loop starts"
             with self.assertRaisesRegex(RuntimeError, expected_err):
                 local_bs = 2
@@ -6393,9 +6399,8 @@ class DistributedTest:
                 device_ids=[device_id],
                 find_unused_parameters=find_unused,
                 broadcast_buffers=broadcast_buffers,
+                static_graph=static_graph,
             )
-            if static_graph:
-                ddp._set_static_graph()
             # Materialize new params. These are not registered in DDP and thus
             # don't have autograd hooks installed on them.
             ddp.module.fc2 = nn.Linear(1, 1, bias=False).to(device_id)
@@ -6513,10 +6518,11 @@ class DistributedTest:
             model = ToyModel().to(torch.cuda.current_device())
             for static in [True, False]:
                 ddp_model = torch.nn.parallel.DistributedDataParallel(
-                    copy.deepcopy(model), device_ids=[self.rank], find_unused_parameters=True
+                    copy.deepcopy(model),
+                    device_ids=[self.rank],
+                    find_unused_parameters=True,
+                    static_graph=static,
                 )
-                if static:
-                    ddp_model._set_static_graph()
                 inp = torch.randn(20, 10, device=self.rank)
                 for i in range(6):
                     loss = ddp_model(inp)
@@ -6758,8 +6764,8 @@ class DistributedTest:
             model = torch.nn.parallel.DistributedDataParallel(
                 ControlFlowToyModel().cuda(self.rank),
                 device_ids=[self.rank],
+                static_graph=True,
             )
-            model._set_static_graph()
             random_input = torch.randn(20, 10, device=self.rank)
             ones_input = torch.ones(20, 10, device=self.rank)
             # unused parameter in the first iteration got used
@@ -7286,9 +7292,8 @@ class DistributedTest:
                 device_ids=[self.rank],
                 find_unused_parameters=find_unused_parameters,
                 gradient_as_bucket_view=True,
+                static_graph=static_graph,
             )
-            if static_graph:
-                ddp_model._set_static_graph()
             random_input = torch.randn(20, 10, device=self.rank)
             for i in range(10):
                 out = ddp_model(random_input)
@@ -7906,8 +7911,8 @@ class DistributedTest:
             model_static_graph = torch.nn.parallel.DistributedDataParallel(
                 model,
                 device_ids=[rank],
+                static_graph=True,
             )
-            model_static_graph._set_static_graph()
             inp = torch.randn(10, 100)
             type_mapping = {
                 "list": list,
@@ -7959,10 +7964,9 @@ class DistributedTest:
                 model,
                 device_ids=[self.rank],
                 output_device=self.rank,
-                find_unused_parameters=find_unused
+                find_unused_parameters=find_unused,
+                static_graph=static_graph,
             )
-            if static_graph:
-                ddp._set_static_graph()
             for i in range(6):
                 out = ddp(inp)
                 self.assertFalse(out[0].requires_grad)
@@ -8053,11 +8057,9 @@ class DistributedTest:
                 output_device=self.rank,
                 broadcast_buffers=False,
                 find_unused_parameters=find_unused,
+                static_graph=static_graph,
            )
-            if static_graph:
-                ddp._set_static_graph()
             opt = [None for _ in range(3)]
             for i in range(2):
                 ddp.zero_grad()


@@ -106,9 +106,11 @@ class PipeWithDDPTest(RpcAgentTestFixture):
             layer2
         )
         model = Pipe(model, chunks=2, checkpoint=checkpoint)
-        model = DistributedDataParallel(model, find_unused_parameters=find_unused_parameters)
-        if static_graph:
-            model._set_static_graph()
+        model = DistributedDataParallel(
+            model,
+            find_unused_parameters=find_unused_parameters,
+            static_graph=static_graph,
+        )
         # Ensure inputs are different across ranks to verify that gradient
         # sync indeed occurs.