[PT1.11] make static graph to be stable (#71459)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71459
1. Add a static_graph argument to the DDP constructor.
2. Keep the _set_static_graph() API so that existing use cases are not affected; it is also called internally by the DDP constructor.
3. Cover four cases:
   static_graph = False, _set_static_graph() is called;
   static_graph = False, _set_static_graph() is not called;
   static_graph = True, _set_static_graph() is not called;
   static_graph = True, _set_static_graph() is called.
ghstack-source-id: 147263797
Test Plan: unit tests
Reviewed By: rohan-varma
Differential Revision: D33646738
fbshipit-source-id: 8c1730591152aab91afce7133d2adf1efd723855
(cherry picked from commit dc246a1129)
This commit is contained in:
parent 11d8fe59fd
commit 1c61d8c43f
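The change described above boils down to two ways of enabling static-graph mode. A minimal sketch follows (it is not part of the diff): it spins up a throwaway single-process gloo group purely so the constructors can run, and the nn.Linear model is a stand-in for illustration only.

import os

import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

# Single-process "gloo" group so DDP can be constructed; real jobs would be
# launched with one process per rank (e.g. via torchrun).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# New path added by this commit: request a static graph in the constructor.
ddp_new = DistributedDataParallel(nn.Linear(10, 10), static_graph=True)

# Legacy path, kept so existing callers keep working: construct first, then
# call the private API before the training loop starts.
ddp_old = DistributedDataParallel(nn.Linear(10, 10))
ddp_old._set_static_graph()

dist.destroy_process_group()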
@@ -1564,9 +1564,8 @@ class DistributedDataParallelTest(
             process_group=process_group,
             find_unused_parameters=True,
             gradient_as_bucket_view=gradient_as_bucket_view,
+            static_graph=static_graph,
         )
-        if static_graph:
-            cpu_model._set_static_graph()
         run_and_verify_grad(cpu_model)

         # Test on GPU
@@ -1577,9 +1576,8 @@ class DistributedDataParallelTest(
             process_group=process_group,
             find_unused_parameters=True,
             gradient_as_bucket_view=gradient_as_bucket_view,
+            static_graph=static_graph,
         )
-        if static_graph:
-            gpu_model._set_static_graph()
         run_and_verify_grad(gpu_model)

     @requires_gloo()
@@ -1619,11 +1619,9 @@ class DistributedDataParallelTest(
             device_ids=[device_id],
             process_group=process_group,
             gradient_as_bucket_view=gradient_as_bucket_view,
+            static_graph=static_graph,
         )

-        if static_graph:
-            gpu_model._set_static_graph()
-
         # Register a DDP communication hook if any.
         if hook is not None:
             gpu_model.register_comm_hook(state, hook)
@@ -2194,9 +2192,8 @@ class DistributedDataParallelTest(
             device_ids=[self.rank],
             process_group=process_group,
             find_unused_parameters=find_unused_parameters,
+            static_graph=static_graph,
         )
-        if static_graph:
-            ddp_model._set_static_graph()
         self.assertEqual(
             ddp_model._get_ddp_logging_data().get("static_graph", 0), static_graph
         )
@@ -478,6 +478,33 @@ class DistributedDataParallel(Module, Joinable):
                       gradients. If hitting such errors, please fix it by
                       referring to the :meth:`~torch.optim.Optimizer.zero_grad`
                       function in ``torch/optim/optimizer.py`` as a solution.
+        static_graph (bool): When set to ``True``, DDP knows the trained graph is
+                     static. Static graph means 1) The set of used and unused
+                     parameters will not change during the whole training loop; in
+                     this case, it does not matter whether users set
+                     ``find_unused_parameters = True`` or not. 2) How the graph is trained
+                     will not change during the whole training loop (meaning there is
+                     no control flow depending on iterations).
+                     When static_graph is set to be ``True``, DDP will support cases that
+                     can not be supported in the past:
+                     1) Reentrant backwards.
+                     2) Activation checkpointing multiple times.
+                     3) Activation checkpointing when model has unused parameters.
+                     4) There are model parameters that are outside of forward function.
+                     5) Potentially improve performance when there are unused parameters,
+                     as DDP will not search graph in each iteraton to detect unused
+                     parameters when static_graph is set to be ``True``.
+                     To check whether you can set static_graph to be ``True``, one way is to
+                     check ddp logging data at the end of your previous model training,
+                     if ``ddp_logging_data.get("can_set_static_graph") == True``, mostly you
+                     can set ``static_graph = True`` as well.
+
+                     Example::
+                         >>> model_DDP = torch.nn.parallel.DistributedDataParallel(model)
+                         >>> # Training loop
+                         >>> .....
+                         >>> ddp_logging_data = model_DDP._get_ddp_logging_data()
+                         >>> static_graph = ddp_logging_data.get("can_set_static_graph")


     Attributes:
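The docstring above points at ddp_logging_data.get("can_set_static_graph") as the signal for whether the flag is safe to enable. A hedged sketch of that two-phase workflow (illustration only, reusing the single-process gloo setup from the earlier sketch; the nn.Linear model and random inputs are throwaway stand-ins):

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

ddp = DistributedDataParallel(nn.Linear(10, 1))

# Phase 1: run a few ordinary iterations so DDP can observe how the graph is used.
for _ in range(4):
    ddp(torch.randn(8, 10)).sum().backward()
    ddp.zero_grad()

# Phase 2: if the logging data reports that the observed graph was static,
# rebuild the wrapper with static_graph=True for the real training run.
if ddp._get_ddp_logging_data().get("can_set_static_graph"):
    ddp = DistributedDataParallel(nn.Linear(10, 1), static_graph=True)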
@@ -501,6 +528,7 @@ class DistributedDataParallel(Module, Joinable):
         find_unused_parameters=False,
         check_reduction=False,
         gradient_as_bucket_view=False,
+        static_graph=False,
     ):

         super(DistributedDataParallel, self).__init__()
@@ -620,6 +648,9 @@ class DistributedDataParallel(Module, Joinable):
         self._ddp_init_helper(parameters, expect_sparse_gradient, param_to_name_mapping)
         self._has_rebuilt_buckets = False

+        if static_graph:
+            self._set_static_graph()
+
     def _sync_params_and_buffers(self, authoritative_rank=0):
         module_states = []
         for name, param in self.module.named_parameters():
@@ -727,7 +758,8 @@ class DistributedDataParallel(Module, Joinable):
         # Builds reducer
         self._ddp_init_helper(parameters, expect_sparse_gradient, param_to_name_mapping)
         if self.static_graph:
-            self._set_static_graph()
+            self.reducer._set_static_graph()
+            self.logger._set_static_graph()

     def _build_params_for_reducer(self):
         # Build tuple of (module, parameter) for all parameters that require grads.
@@ -1635,30 +1667,15 @@ class DistributedDataParallel(Module, Joinable):

     def _set_static_graph(self):
         """
-        Users can explicitly let DDP know the trained graph is static,
-        when 1) the set of used and unused parameters will not change
-        during the whole training loop; in this case, it does not matter
-        whether users set find_unsued_parameters = true or not.
-        2) how the graph is trained will not change during the whole training
-        loop (meaning there is no control flow depending on iterations).
-        When graph is set to be static, DDP will support cases that can not
-        be supported in the past: 1) reentrant backwards
-        2) activation checkpointing multiple times 3)
-        activation checkpointing with find_unused_parameters = true.
-        4) not all output tensors are used in loss calculation.
-        5) there is model parameter that is outside of forward function.
-        6) potentially improve performance when find_unsued_parameters = true
-        or there are unused parameters, as DDP will not search graph in each
-        iteraton to detect unused parameters when static_graph is set to be True.
-
-        This API should be called after DistributedDataParallel construction, and
-        before training loops starts. Also it should be called in the same way for
-        all ranks. For example:
-            ddp_model = DistributedDataParallel(model)
-            ddp_model._set_static_graph()
-            for i in range(n):
-                .....
+        It is recommended to set static graph in the DDP constructor, which will
+        call this private API internally.
         """
+        # If self.static_graph has been set, no need to set it again
+        if self.static_graph:
+            warnings.warn(
+                "You've set static_graph to be True, no need to set it again."
+            )
+            return
         self.static_graph = True
         self.reducer._set_static_graph()
         self.logger._set_static_graph()
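As rewritten above, _set_static_graph now warns and returns early when the flag is already set, which is exactly the situation the new set_static_graph_twice test path below exercises. A hedged sketch of that interaction (again assuming the single-process setup from the first sketch):

import warnings

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

ddp = DistributedDataParallel(nn.Linear(10, 10), static_graph=True)

# Calling the private API after the constructor has already enabled static
# graph is harmless: it only emits a warning and returns early.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ddp._set_static_graph()
assert any("no need to set it again" in str(w.message) for w in caught)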
@@ -3676,6 +3676,7 @@ class DistributedTest:
             output_device=None,
             gradient_as_bucket_view=False,
             static_graph=False,
+            set_static_graph_twice=False,
         ):
             # Run a simple end to end DDP model, use result of single node model
             # as baseline
@@ -3694,8 +3695,10 @@ class DistributedTest:
                 model_DDP,
                 device_ids=gpu_subset,
                 gradient_as_bucket_view=gradient_as_bucket_view,
+                static_graph=static_graph,
             )
-            if static_graph:
+
+            if set_static_graph_twice:
                 model_DDP._set_static_graph()

             # test serializable/unserializable
@@ -3934,10 +3937,9 @@ class DistributedTest:
                     copy.deepcopy(model).cuda(),
                     device_ids=[self.rank],
                     gradient_as_bucket_view=grad_as_bucket_view,
+                    static_graph=static_graph,
                 )
             )
-            if static_graph:
-                ddp_model_with_optimizer_hook._set_static_graph()

             # Register hook that runs allreduce + functional optimizer
             # step.
@@ -3957,9 +3959,8 @@ class DistributedTest:
                 copy.deepcopy(model).cuda(),
                 device_ids=[self.rank],
                 gradient_as_bucket_view=grad_as_bucket_view,
+                static_graph=static_graph,
             )
-            if static_graph:
-                ddp_model_with_no_hook._set_static_graph()

             mapping = {v: k for k, v in functional_optim_map.items()}
             optimizer_no_hook = mapping.get(functional_optim_cls)(
@@ -4468,6 +4469,15 @@ class DistributedTest:
                 static_graph=static_graph,
             )

+            # test set static graph twice
+            self._test_DistributedDataParallel(
+                gpu_subset=gpus,
+                rank=rank,
+                gradient_as_bucket_view=use_bucket_view,
+                static_graph=static_graph,
+                set_static_graph_twice=True,
+            )
+
             # test output_device
             self._test_DistributedDataParallel(
                 gpu_subset=gpus,
@@ -5216,10 +5226,6 @@ class DistributedTest:
         @sandcastle_skip_if(BACKEND == "nccl", "nccl does not support DDP on CPU models")
         def test_static_graph_api_cpu(self):
             model_DDP = nn.parallel.DistributedDataParallel(DDP_NET)
-            model_DDP._set_static_graph()
-            self.assertEqual(
-                model_DDP._get_ddp_logging_data().get("static_graph"), True
-            )
             expected_err = "should be called before training loop starts"
             with self.assertRaisesRegex(RuntimeError, expected_err):
                 local_bs = 2
@@ -6393,9 +6399,8 @@ class DistributedTest:
                 device_ids=[device_id],
                 find_unused_parameters=find_unused,
                 broadcast_buffers=broadcast_buffers,
+                static_graph=static_graph,
             )
-            if static_graph:
-                ddp._set_static_graph()
             # Materialize new params. These are not registered in DDP and thus
             # don't have autograd hooks installed on them.
             ddp.module.fc2 = nn.Linear(1, 1, bias=False).to(device_id)
@@ -6513,10 +6518,11 @@ class DistributedTest:
             model = ToyModel().to(torch.cuda.current_device())
             for static in [True, False]:
                 ddp_model = torch.nn.parallel.DistributedDataParallel(
-                    copy.deepcopy(model), device_ids=[self.rank], find_unused_parameters=True
+                    copy.deepcopy(model),
+                    device_ids=[self.rank],
+                    find_unused_parameters=True,
+                    static_graph=static,
                 )
-                if static:
-                    ddp_model._set_static_graph()
                 inp = torch.randn(20, 10, device=self.rank)
                 for i in range(6):
                     loss = ddp_model(inp)
@@ -6758,8 +6764,8 @@ class DistributedTest:
             model = torch.nn.parallel.DistributedDataParallel(
                 ControlFlowToyModel().cuda(self.rank),
                 device_ids=[self.rank],
+                static_graph=True,
             )
-            model._set_static_graph()
             random_input = torch.randn(20, 10, device=self.rank)
             ones_input = torch.ones(20, 10, device=self.rank)
             # unused parameter in the first iteration got used
@@ -7286,9 +7292,8 @@ class DistributedTest:
                 device_ids=[self.rank],
                 find_unused_parameters=find_unused_parameters,
                 gradient_as_bucket_view=True,
+                static_graph=static_graph,
             )
-            if static_graph:
-                ddp_model._set_static_graph()
             random_input = torch.randn(20, 10, device=self.rank)
             for i in range(10):
                 out = ddp_model(random_input)
@@ -7906,8 +7911,8 @@ class DistributedTest:
             model_static_graph = torch.nn.parallel.DistributedDataParallel(
                 model,
                 device_ids=[rank],
+                static_graph=True,
             )
-            model_static_graph._set_static_graph()
             inp = torch.randn(10, 100)
             type_mapping = {
                 "list": list,
@@ -7959,10 +7964,9 @@ class DistributedTest:
                 model,
                 device_ids=[self.rank],
                 output_device=self.rank,
-                find_unused_parameters=find_unused
+                find_unused_parameters=find_unused,
+                static_graph=static_graph,
             )
-            if static_graph:
-                ddp._set_static_graph()
             for i in range(6):
                 out = ddp(inp)
                 self.assertFalse(out[0].requires_grad)
@@ -8053,11 +8057,9 @@ class DistributedTest:
                 output_device=self.rank,
                 broadcast_buffers=False,
                 find_unused_parameters=find_unused,
+                static_graph=static_graph,
             )

-            if static_graph:
-                ddp._set_static_graph()
-
             opt = [None for _ in range(3)]
             for i in range(2):
                 ddp.zero_grad()
@@ -106,9 +106,11 @@ class PipeWithDDPTest(RpcAgentTestFixture):
                 layer2
             )
             model = Pipe(model, chunks=2, checkpoint=checkpoint)
-            model = DistributedDataParallel(model, find_unused_parameters=find_unused_parameters)
-            if static_graph:
-                model._set_static_graph()
+            model = DistributedDataParallel(
+                model,
+                find_unused_parameters=find_unused_parameters,
+                static_graph=static_graph,
+            )

             # Ensure inputs are different across ranks to verify that gradient
             # sync indeed occurs.