[wip] enhance DDPSink to work for general outputs (#57073)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57073

Enhances use of DDPSink to work for all output types DDP supports, as per https://github.com/pytorch/pytorch/issues/55876.

TODO: Add additional testing for tuple, list, dict return types.

ghstack-source-id: 128726768

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D27756985

fbshipit-source-id: 2e0408649fb2d6a46d6c33155a24c4c1723dd799
parent: 4faa427383
commit: c52700dbcd
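For orientation before the diff hunks below: the change lets a DDP-wrapped module whose forward returns a nested container of tensors (tuple, list, dict) work with DDP's internal backward sink. A hypothetical user-side sketch of what this enables; the module, shapes, and loss are illustrative, and it assumes a process group has already been initialized for the given rank:

import torch
import torch.nn as nn


class DictOutputModel(nn.Module):
    # Illustrative module whose forward returns a nested container
    # rather than a single tensor.
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(8, 4)

    def forward(self, x):
        return {"logits": self.lin(x), "aux": (self.lin(x),)}


def run(rank):
    # Assumes torch.distributed.init_process_group(...) has already been
    # called for this rank and that one GPU per rank is available.
    torch.cuda.set_device(rank)
    model = torch.nn.parallel.DistributedDataParallel(
        DictOutputModel().cuda(rank), device_ids=[rank]
    )
    model._set_static_graph()  # internal hook; the first iteration installs the sink

    out = model(torch.randn(2, 8).cuda(rank))
    loss = out["logits"].sum() + out["aux"][0].sum()
    loss.backward()  # gradients flow back through every tensor in the nested output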
@@ -10,6 +10,7 @@ from typing import NamedTuple
 import torch
 import torch.distributed as dist
 from torch.autograd import Variable, Function
+from torch.utils._pytree import tree_flatten, tree_unflatten

 RPC_AVAILABLE = False
 if dist.is_available():
@@ -106,19 +107,20 @@ class _DDPUnevenInputsConfig(NamedTuple):
     ddp_join_divide_by_initial_world_size: bool
     ddp_join_throw_on_early_termination: bool

-# Add a DDPSink to queue call back of out-most backward/graph task,
+# Add a DDPSink to run various functions when backwards starts, such as
+# queueing call back of out-most backward/graph task,
 # this helps call back is fired after all gradients' calculation
 # is completed.
-class DDPSink(Function):
+class _DDPSink(Function):
     @staticmethod
-    def forward(ctx, input, reducer):
+    def forward(ctx, reducer, *inputs):
         ctx.reducer = reducer
-        return input
+        return inputs

     @staticmethod
-    def backward(ctx, input):
+    def backward(ctx, *grad_outputs):
         Variable._execution_engine.queue_callback(ctx.reducer._delay_all_reduce)
-        return input, None
+        return (None, *grad_outputs)

 class DistributedDataParallel(Module):
     r"""Implements distributed data parallelism that is based on
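A note on the signature change above: a custom autograd Function's backward must return one value per argument passed to forward, with None for non-tensor arguments such as the reducer, and the *inputs / *grad_outputs varargs let the sink accept however many output tensors the model produced. A minimal standalone sketch of the same passthrough pattern, with a hypothetical no-op callback standing in for reducer._delay_all_reduce:

import torch
from torch.autograd import Function, Variable


class PassthroughSink(Function):
    # Sketch of the _DDPSink idea: forward returns its tensor inputs
    # unchanged; backward queues a callback on the autograd engine and
    # passes the incoming gradients straight through.
    @staticmethod
    def forward(ctx, callback, *inputs):
        ctx.callback = callback
        return inputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        # Fires after the outermost backward graph task completes.
        Variable._execution_engine.queue_callback(ctx.callback)
        # One gradient per forward argument: None for `callback`,
        # then the unchanged gradients for each tensor input.
        return (None, *grad_outputs)


x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)
out_x, out_y = PassthroughSink.apply(lambda: print("backward done"), x, y)
(out_x.sum() + out_y.sum()).backward()
print(x.grad, y.grad)  # gradients are passed through unchanged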
@@ -812,12 +814,20 @@ class DistributedDataParallel(Module):
         else:
             self.require_forward_param_sync = False

-        # TODO. Right now we add this sink for static_graph training only. once
-        # this feature is stable, we will add this sink for all cases. E.g.
-        # This sink can help capture more accuracte backward start time as well.
-        if self.static_graph and self.num_iterations == 1:
-            output = DDPSink.apply(output, self.reducer)
-        return output
+        # TODO. Right now we add this sink for static_graph training only. once
+        # this feature is stable, we will add this sink for all cases. E.g.
+        # This sink can help capture more accuracte backward start time as well.
+        if self.static_graph and self.num_iterations == 1:
+            # Need to grab list of tensors from user output in order to pass
+            # to custom autograd function.
+            output_tensor_list, treespec = tree_flatten(output)
+            passthrough_tensor_list = _DDPSink.apply(
+                self.reducer,
+                *output_tensor_list
+            )
+            # Reconstruct output data structure.
+            output = tree_unflatten(passthrough_tensor_list, treespec)
+        return output

     def scatter(self, inputs, kwargs, device_ids):
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
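The forward() change above relies on torch.utils._pytree (a private utility) to turn an arbitrarily nested output into a flat list of tensors for the varargs autograd Function and then rebuild the original container. A minimal standalone sketch of that round trip, independent of DDP:

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# A nested output of the kind DDP's forward() may now return.
output = {"a": torch.ones(2), "b": (torch.zeros(2), [torch.full((2,), 3.0)])}

# Flatten into leaf tensors plus a spec describing the container structure.
leaves, treespec = tree_flatten(output)
print(len(leaves))  # 3 leaf tensors

# ... the leaves are what gets routed through _DDPSink.apply(...) ...

# Rebuild the original nested structure around the (possibly replaced) leaves.
rebuilt = tree_unflatten(leaves, treespec)
print(type(rebuilt), list(rebuilt.keys()))  # <class 'dict'> ['a', 'b']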
@@ -6375,3 +6375,91 @@ class DistributedTest:
             # Ensure sync does not occur in eval() mode.
             all_gather_calls = get_profiling_event("all_gather", prof)
             self.assertEqual([], all_gather_calls)
+
+        @skip_if_lt_x_gpu(2)
+        @unittest.skipIf(
+            BACKEND != "nccl" and BACKEND != "gloo",
+            "Only Nccl & Gloo backend support DistributedDataParallel",
+        )
+        def test_ddp_static_graph_nested_types(self):
+            # Tests for static graph training when outputs are not just tensors
+            # but can be (nested) tuple, list, dict, etc.
+            rank = self.rank
+            torch.cuda.set_device(rank)
+
+            class NestedOutputModule(torch.nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.lin = nn.Linear(100, 1, bias=False)
+
+                def forward(self, inp, output_type):
+                    if output_type == "tuple":
+                        return (
+                            self.lin(inp),
+                            (
+                                self.lin(inp),
+                                self.lin(inp),
+                            ),
+                        )
+                    elif output_type == "list":
+                        return [
+                            self.lin(inp),
+                            [
+                                self.lin(inp),
+                                self.lin(inp),
+                            ],
+                        ]
+                    elif output_type == "dict":
+                        return {
+                            "a": self.lin(inp),
+                            "b": {
+                                "c": self.lin(inp),
+                            },
+                        }
+
+            def get_loss(model_output):
+                loss = 0.0
+                if isinstance(model_output, torch.Tensor):
+                    return model_output.sum()
+                elif isinstance(model_output, dict):
+                    for value in model_output.values():
+                        loss += get_loss(value)
+                elif isinstance(model_output, tuple) or isinstance(model_output, list):
+                    for x in model_output:
+                        loss += get_loss(x)
+                else:
+                    raise ValueError(f"Unknown model output type {type(model_output)}")
+                return loss
+
+            model = NestedOutputModule().cuda(rank)
+            model_static_graph = copy.deepcopy(model)
+            model = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[rank],
+            )
+            model_static_graph = torch.nn.parallel.DistributedDataParallel(
+                model_static_graph,
+                device_ids=[rank],
+            )
+            model_static_graph._set_static_graph()
+            inp = torch.randn(10, 100)
+            type_mapping = {
+                "list": list,
+                "tuple": tuple,
+                "dict": dict,
+            }
+            for output_type in type_mapping.keys():
+                for i in range(6):
+                    out = model(inp, output_type=output_type)
+                    loss = get_loss(out)
+                    loss.backward()
+                    self._model_step(model)
+                    out_static = model_static_graph(inp, output_type=output_type)
+                    self.assertTrue(isinstance(out_static, type_mapping[output_type]))
+                    loss_static = get_loss(out_static)
+                    loss_static.backward()
+                    self._model_step(model_static_graph)
+                    for (p, p_static) in zip(
+                        model.parameters(), model_static_graph.parameters()
+                    ):
+                        self.assertEqual(p, p_static)