[wip] enhance DDPSink to work for general outputs (#57073)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57073

Enhances DDPSink so that it works for all output types DDP supports (tensors as well as nested tuples, lists, and dicts of tensors), as per https://github.com/pytorch/pytorch/issues/55876.

TODO: Add additional testing for tuple, list, dict return types
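
For illustration, the mechanism is a pytree round trip: the (possibly nested) module output is flattened into a flat list of tensors, that list is passed through the sink autograd function, and the original structure is then rebuilt. A minimal sketch of that round trip follows; the example output structure is made up, while tree_flatten/tree_unflatten are the torch.utils._pytree helpers used in the diff below:

    # Sketch of the flatten/unflatten round trip (illustrative output structure).
    import torch
    from torch.utils._pytree import tree_flatten, tree_unflatten

    output = {"a": torch.ones(2), "b": (torch.zeros(2), 3 * torch.ones(2))}

    # Flatten the nested output into a list of tensor leaves plus a spec that
    # records the original container structure.
    flat_tensors, treespec = tree_flatten(output)
    assert len(flat_tensors) == 3

    # ... the flat tensors are what gets passed through _DDPSink.apply(...) ...

    # Rebuild the original nested structure around the pass-through tensors.
    rebuilt = tree_unflatten(flat_tensors, treespec)
    assert isinstance(rebuilt, dict) and isinstance(rebuilt["b"], tuple)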
ghstack-source-id: 128726768

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D27756985

fbshipit-source-id: 2e0408649fb2d6a46d6c33155a24c4c1723dd799
Rohan Varma 2021-05-12 09:44:05 -07:00 committed by Facebook GitHub Bot
parent 4faa427383
commit c52700dbcd
2 changed files with 110 additions and 12 deletions

torch/nn/parallel/distributed.py

@@ -10,6 +10,7 @@ from typing import NamedTuple
 import torch
 import torch.distributed as dist
 from torch.autograd import Variable, Function
+from torch.utils._pytree import tree_flatten, tree_unflatten
 
 RPC_AVAILABLE = False
 if dist.is_available():
@@ -106,19 +107,20 @@ class _DDPUnevenInputsConfig(NamedTuple):
     ddp_join_divide_by_initial_world_size: bool
     ddp_join_throw_on_early_termination: bool
 
-# Add a DDPSink to queue call back of out-most backward/graph task,
+# Add a DDPSink to run various functions when backwards starts, such as
+# queueing call back of out-most backward/graph task,
 # this helps call back is fired after all gradients' calculation
 # is completed.
-class DDPSink(Function):
+class _DDPSink(Function):
     @staticmethod
-    def forward(ctx, input, reducer):
+    def forward(ctx, reducer, *inputs):
         ctx.reducer = reducer
-        return input
+        return inputs
 
     @staticmethod
-    def backward(ctx, input):
+    def backward(ctx, *grad_outputs):
         Variable._execution_engine.queue_callback(ctx.reducer._delay_all_reduce)
-        return input, None
+        return (None, *grad_outputs)
 
 class DistributedDataParallel(Module):
     r"""Implements distributed data parallelism that is based on
@@ -812,12 +814,20 @@ class DistributedDataParallel(Module):
         else:
             self.require_forward_param_sync = False
 
-        # TODO. Right now we add this sink for static_graph training only. once
-        # this feature is stable, we will add this sink for all cases. E.g.
-        # This sink can help capture more accuracte backward start time as well.
-        if self.static_graph and self.num_iterations == 1:
-            output = DDPSink.apply(output, self.reducer)
-        return output
+        # TODO. Right now we add this sink for static_graph training only. once
+        # this feature is stable, we will add this sink for all cases. E.g.
+        # This sink can help capture more accuracte backward start time as well.
+        if self.static_graph and self.num_iterations == 1:
+            # Need to grab list of tensors from user output in order to pass
+            # to custom autograd function.
+            output_tensor_list, treespec = tree_flatten(output)
+            passthrough_tensor_list = _DDPSink.apply(
+                self.reducer,
+                *output_tensor_list
+            )
+            # Reconstruct output data structure.
+            output = tree_unflatten(passthrough_tensor_list, treespec)
+        return output
 
     def scatter(self, inputs, kwargs, device_ids):
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
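
For context, the calling convention of the new _DDPSink (reducer first, then a variable number of tensors; backward returns None for the reducer slot plus one gradient per tensor input) can be exercised outside DDP with a standalone pass-through function. This is a sketch only; _PassThroughSink and hook_fn are illustrative stand-ins, not part of this change:

    # Pass-through autograd.Function mirroring _DDPSink's signature handling.
    import torch
    from torch.autograd import Function, Variable

    class _PassThroughSink(Function):
        @staticmethod
        def forward(ctx, hook_fn, *inputs):
            # First argument is a non-tensor (the reducer in DDP); stash it.
            ctx.hook_fn = hook_fn
            # Return the tensors unchanged, just like _DDPSink.
            return inputs

        @staticmethod
        def backward(ctx, *grad_outputs):
            # Queue a callback that runs once the whole backward pass finishes;
            # _DDPSink queues reducer._delay_all_reduce here instead.
            Variable._execution_engine.queue_callback(ctx.hook_fn)
            # One return value per forward argument: None for hook_fn, then the
            # incoming gradient for each tensor input.
            return (None, *grad_outputs)

    t1 = torch.ones(2, requires_grad=True)
    t2 = torch.ones(2, requires_grad=True)
    out1, out2 = _PassThroughSink.apply(lambda: print("backward finished"), t1, t2)
    (out1.sum() + out2.sum()).backward()  # prints after all grads are computed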

torch/testing/_internal/distributed/distributed_test.py

@@ -6375,3 +6375,91 @@ class DistributedTest:
             # Ensure sync does not occur in eval() mode.
             all_gather_calls = get_profiling_event("all_gather", prof)
             self.assertEqual([], all_gather_calls)
+
+        @skip_if_lt_x_gpu(2)
+        @unittest.skipIf(
+            BACKEND != "nccl" and BACKEND != "gloo",
+            "Only Nccl & Gloo backend support DistributedDataParallel",
+        )
+        def test_ddp_static_graph_nested_types(self):
+            # Tests for static graph training when outputs are not just tensors
+            # but can be (nested) tuple, list, dict, etc.
+            rank = self.rank
+            torch.cuda.set_device(rank)
+
+            class NestedOutputModule(torch.nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.lin = nn.Linear(100, 1, bias=False)
+
+                def forward(self, inp, output_type):
+                    if output_type == "tuple":
+                        return (
+                            self.lin(inp),
+                            (
+                                self.lin(inp),
+                                self.lin(inp),
+                            ),
+                        )
+                    elif output_type == "list":
+                        return [
+                            self.lin(inp),
+                            [
+                                self.lin(inp),
+                                self.lin(inp),
+                            ],
+                        ]
+                    elif output_type == "dict":
+                        return {
+                            "a": self.lin(inp),
+                            "b": {
+                                "c": self.lin(inp),
+                            },
+                        }
+
+            def get_loss(model_output):
+                loss = 0.0
+                if isinstance(model_output, torch.Tensor):
+                    return model_output.sum()
+                elif isinstance(model_output, dict):
+                    for value in model_output.values():
+                        loss += get_loss(value)
+                elif isinstance(model_output, tuple) or isinstance(model_output, list):
+                    for x in model_output:
+                        loss += get_loss(x)
+                else:
+                    raise ValueError(f"Unknown model output type {type(model_output)}")
+                return loss
+
+            model = NestedOutputModule().cuda(rank)
+            model_static_graph = copy.deepcopy(model)
+            model = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[rank],
+            )
+            model_static_graph = torch.nn.parallel.DistributedDataParallel(
+                model_static_graph,
+                device_ids=[rank],
+            )
+            model_static_graph._set_static_graph()
+            inp = torch.randn(10, 100)
+            type_mapping = {
+                "list": list,
+                "tuple": tuple,
+                "dict": dict,
+            }
+            for output_type in type_mapping.keys():
+                for i in range(6):
+                    out = model(inp, output_type=output_type)
+                    loss = get_loss(out)
+                    loss.backward()
+                    self._model_step(model)
+                    out_static = model_static_graph(inp, output_type=output_type)
+                    self.assertTrue(isinstance(out_static, type_mapping[output_type]))
+                    loss_static = get_loss(out_static)
+                    loss_static.backward()
+                    self._model_step(model_static_graph)
+                    for (p, p_static) in zip(
+                        model.parameters(), model_static_graph.parameters()
+                    ):
+                        self.assertEqual(p, p_static)
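
From a user's point of view, the behavior exercised by this test looks roughly like the sketch below: a DDP-wrapped model that returns a nested dict, static graph enabled, and a backward pass over a loss computed from the flattened outputs. This is an illustrative single-process CPU/gloo setup; DictOutputModel and the shapes are made up, and _set_static_graph is the same private API the test calls:

    import os
    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.utils._pytree import tree_flatten

    # Single-process process group so DDP can be constructed for the demo.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    class DictOutputModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = nn.Linear(100, 1, bias=False)

        def forward(self, x):
            # Nested dict output, the kind of structure this change supports.
            return {"a": self.lin(x), "b": {"c": self.lin(x)}}

    model = DDP(DictOutputModel())
    model._set_static_graph()  # private API, as in the test above

    out = model(torch.randn(10, 100))
    # Reduce the flattened leaves to a scalar loss, as get_loss does above.
    loss = sum(t.sum() for t in tree_flatten(out)[0])
    loss.backward()
    dist.destroy_process_group()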