mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Added optional support for using activation blobs for sharing as well. Doing this change revealed an non-optimal implementation in the blob sharing: we need to prefer to reuse freeblobs by prefering those blobs that are already shared by many other blobs. Otherwise the memory usage can increase when the pool of 'free blobs' grows. Also, my first version only passed "free blobs" (i.e blobs in recycling pool) down the first branch when operators forked. But now we pass those blobs that were not used by the first branch down the second branch and so on. Also added support for blob size information in the heuristic. This uses the shape inference mechanism. I had to also do some small tweaks: - use Sum() operator as a way to match shapes of blobs that had otherwise unknown shapes. This is related to the Sum() operator that is added to combine multiple incoming gradient inputs (with _autosplit gradients). - a couple of random shape inference fixes This reduces the Resnet-50 memory usage on 64 batch from 9.45 Gig to 8.5 Gig. For a 32 batch, the memory usage is 4330 MiB, down from 4800 MB, compared to Torch's 6856MiB (thanks prigoyal for checking this for me). This is unfortunately quite a bunch to review... Reviewed By: asaadaldien Differential Revision: D4393909 fbshipit-source-id: 9c7c94125f96512bea80463ebcb63c215ef95ff9
773 lines
29 KiB
Python
773 lines
29 KiB
Python
## @package data_parallel_model
|
|
# Module caffe2.python.data_parallel_model
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from collections import OrderedDict
|
|
import logging
|
|
import copy
|
|
|
|
from caffe2.python import model_helper, dyndep, scope, workspace, core, memonger
|
|
from caffe2.proto import caffe2_pb2
|
|
|
|
dyndep.InitOpsLibrary("@/caffe2/caffe2/contrib/nccl:nccl_ops")
|
|
dyndep.InitOpsLibrary("@/caffe2/caffe2/contrib/gloo:gloo_ops")
|
|
dyndep.InitOpsLibrary("@/caffe2/caffe2/contrib/gloo:gloo_ops_gpu")
|
|
|
|
log = logging.getLogger("data_parallel_model")
|
|
log.setLevel(logging.INFO)
|
|
|
|
|
|
def Parallelize_GPU(
|
|
model_helper_obj,
|
|
input_builder_fun,
|
|
forward_pass_builder_fun,
|
|
param_update_builder_fun,
|
|
devices=range(0, workspace.NumCudaDevices()),
|
|
rendezvous=None,
|
|
net_type='dag',
|
|
broadcast_computed_params=True,
|
|
optimize_gradient_memory=False,
|
|
):
|
|
'''
|
|
Function to create a model that can run on many GPUs.
|
|
model_helper_obj: an object of ModelHelper, such as CNNModelHelper
|
|
input_builder_fun:
|
|
Function that adds the input operators
|
|
Note: Remember to instantiate reader outside of this
|
|
function so all GPUs share same reader object.
|
|
Signature: input_builder_fun(model)
|
|
forward_pass_builder_fun:
|
|
Function to add the operators to the model.
|
|
Must return list of loss-blob references that
|
|
are used to build the gradient. Loss scale parameter
|
|
is passed, as you should scale the loss of your model
|
|
by 1.0 / the total number of gpus.
|
|
Signature: forward_pass_builder_fun(model, loss_scale)
|
|
param_update_builder_fun:
|
|
Function that adds operators that are run after
|
|
gradient update, such as updating the weights and
|
|
weight decaying.
|
|
Signature: param_update_builder_fun(model)
|
|
devices: List of GPU ids, such as [0, 1, 2, 3],
|
|
rendezvous: used for rendezvous in distributed computation, if None
|
|
then only one node is used. To create rendezvous,
|
|
use <TBD>.
|
|
net_type: Network type
|
|
optimize_gradient_memory: whether to apply 'memonger' to share blobs
|
|
in gradient computation to reduce memory footprint
|
|
|
|
'''
|
|
log.info("Parallelizing model for devices: {}".format(devices))
|
|
extra_workers = 8 if rendezvous is not None else 0 # best-guess
|
|
model_helper_obj.net.Proto().num_workers = len(devices) * 4 + extra_workers
|
|
model_helper_obj.net.Proto().type = net_type
|
|
|
|
# Store some information in the model -- a bit ugly
|
|
model_helper_obj._devices = devices
|
|
model_helper_obj._rendezvous = rendezvous
|
|
model_helper_obj._grad_names = []
|
|
|
|
assert isinstance(model_helper_obj, model_helper.ModelHelper)
|
|
|
|
# Keep track of params that were in the model before: they are not
|
|
# data parallel, so we need to handle them separately
|
|
non_datapar_params = copy.copy(model_helper_obj.params)
|
|
|
|
# Add input and model
|
|
log.info("Create input and model training operators")
|
|
|
|
losses_by_gpu = {}
|
|
num_shards = 1 if rendezvous is None else rendezvous['num_shards']
|
|
loss_scale = 1.0 / (len(devices) * num_shards)
|
|
|
|
for device in devices:
|
|
device_opt = core.DeviceOption(caffe2_pb2.CUDA, device)
|
|
with core.DeviceScope(device_opt):
|
|
with core.NameScope("gpu_{}".format(device)):
|
|
log.info("Model for GPU: {}".format(device))
|
|
input_builder_fun(model_helper_obj)
|
|
losses = forward_pass_builder_fun(model_helper_obj, loss_scale)
|
|
# Losses are not needed for test net
|
|
if param_update_builder_fun is not None:
|
|
assert isinstance(losses, list), \
|
|
'Model builder function must return list of loss blobs'
|
|
for loss in losses:
|
|
assert isinstance(loss, core.BlobReference), \
|
|
'Model builder func must return list of loss blobs'
|
|
|
|
losses_by_gpu[device] = losses
|
|
_ValidateParams(model_helper_obj.params)
|
|
|
|
# Create parameter map
|
|
model_helper_obj._device_grouped_blobs =\
|
|
_GroupByDevice(devices, model_helper_obj.params, non_datapar_params)
|
|
|
|
# computed params
|
|
computed_params_grouped =\
|
|
_GroupByDevice(devices, model_helper_obj.computed_params, [])
|
|
model_helper_obj._device_grouped_blobs.update(computed_params_grouped)
|
|
|
|
model_helper_obj._param_names =\
|
|
model_helper_obj._device_grouped_blobs.keys()
|
|
model_helper_obj._computed_param_names = computed_params_grouped.keys()
|
|
|
|
if (param_update_builder_fun is None):
|
|
log.info("Parameter update function not defined --> only forward")
|
|
_InferBlobDevice(model_helper_obj)
|
|
return
|
|
|
|
log.info("Adding gradient operators")
|
|
_AddGradientOperators(devices, model_helper_obj, losses_by_gpu)
|
|
|
|
_ValidateParams(model_helper_obj.params)
|
|
|
|
# Group gradients by device and register to blob lookup
|
|
param_to_grad = model_helper_obj.param_to_grad
|
|
grads_ordered = [param_to_grad[p] for p in
|
|
model_helper_obj.params if p in param_to_grad]
|
|
non_datapar_grads = [param_to_grad[p] for p in non_datapar_params]
|
|
|
|
gradients_grouped = _GroupByDevice(
|
|
devices,
|
|
grads_ordered,
|
|
non_datapar_grads
|
|
)
|
|
model_helper_obj._device_grouped_blobs.update(gradients_grouped)
|
|
model_helper_obj._grad_names = gradients_grouped.keys()
|
|
model_helper_obj._losses_by_gpu = losses_by_gpu
|
|
|
|
_InferBlobDevice(model_helper_obj)
|
|
|
|
log.info("Add gradient all-reduces for SyncSGD")
|
|
if broadcast_computed_params:
|
|
_BroadcastComputedParams(devices, model_helper_obj, rendezvous)
|
|
|
|
if len(model_helper_obj._grad_names) > 0:
|
|
_AllReduceGradients(devices, model_helper_obj, rendezvous)
|
|
else:
|
|
log.info("NOTE: Param builder function did not create any parameters.")
|
|
|
|
|
|
log.info("Post-iteration operators for updating params")
|
|
num_shards = 1 if rendezvous is None else rendezvous['num_shards']
|
|
# The following check is necessary for ring reduce to work
|
|
if rendezvous is not None:
|
|
assert num_shards > 1, \
|
|
"Please use more than one shard for distributed training"
|
|
for device in devices:
|
|
device_opt = core.DeviceOption(caffe2_pb2.CUDA, device)
|
|
with core.DeviceScope(device_opt):
|
|
with core.NameScope("gpu_{}".format(device)):
|
|
param_update_builder_fun(model_helper_obj)
|
|
|
|
_InferBlobDevice(model_helper_obj)
|
|
_AnalyzeOperators(model_helper_obj)
|
|
|
|
# Configure dagnet to run with only one worker on the first iteration,
|
|
# to prevent concurrency problems with allocs and nccl.
|
|
arg = model_helper_obj.Proto().arg.add()
|
|
arg.name = "first_iter_only_one_worker"
|
|
arg.i = 1
|
|
|
|
# Add initial parameter syncs
|
|
log.info("Add initial parameter sync")
|
|
if (rendezvous is not None):
|
|
_AddDistributedParameterSync(
|
|
devices,
|
|
model_helper_obj,
|
|
model_helper_obj.param_init_net,
|
|
model_helper_obj.param_init_net,
|
|
rendezvous,
|
|
)
|
|
|
|
_SyncParams(devices, model_helper_obj, model_helper_obj.param_init_net)
|
|
|
|
if optimize_gradient_memory:
|
|
_OptimizeGradientMemoryDEPRECATED(model_helper_obj, losses_by_gpu, devices)
|
|
|
|
|
|
def _AddGradientOperators(devices, model, losses_by_gpu):
|
|
def create_grad(lossp):
|
|
return model.ConstantFill(lossp, str(lossp) + "_grad", value=1.0)
|
|
|
|
loss_grad = {}
|
|
# Explicitly need to create gradients on each GPU
|
|
for gpu_id in devices:
|
|
device = core.DeviceOption(caffe2_pb2.CUDA, gpu_id)
|
|
with core.DeviceScope(device):
|
|
for l in losses_by_gpu[gpu_id]:
|
|
lg = create_grad(l)
|
|
loss_grad[str(l)] = str(lg)
|
|
|
|
model.AddGradientOperators(loss_grad)
|
|
|
|
|
|
def ExtractPredictorNet(model, inputs, outputs, device):
|
|
'''
|
|
Returns (net, params) that can be exported to be used as a prediction
|
|
net.
|
|
'''
|
|
master_device = model._devices[0]
|
|
prefix = "gpu_{}/".format(master_device)
|
|
prefix_inputs = [prefix + str(b) for b in inputs]
|
|
prefix_outputs = [prefix + str(b) for b in outputs]
|
|
predictor_net = model_helper.ExtractPredictorNet(
|
|
net_proto=model.net.Proto(),
|
|
input_blobs=prefix_inputs,
|
|
output_blobs=prefix_outputs,
|
|
device=device,
|
|
renames={
|
|
a: b
|
|
for (a, b) in zip(prefix_inputs + prefix_outputs, inputs + outputs)
|
|
}
|
|
)
|
|
|
|
params = set(predictor_net.Proto().external_input) - set(inputs)
|
|
return (predictor_net, params)
|
|
|
|
|
|
def FinalizeAfterCheckpoint(model, blobs, sync_iter=True):
|
|
if not hasattr(model, "_checkpoint_net"):
|
|
uniq_blob_names = [stripParamName(p) for p in blobs]
|
|
|
|
# Synchronize to the blob lookup map, as the provided
|
|
# blobs might have non-parameters, such as momemtum blobs.
|
|
log.info("Creating checkpoint synchronization net")
|
|
devices = model.GetDevices()
|
|
for name in uniq_blob_names:
|
|
if name not in model._device_grouped_blobs:
|
|
grouped = {
|
|
d:
|
|
core.BlobReference("gpu_{}{}{}".format(
|
|
d,
|
|
scope._NAMESCOPE_SEPARATOR,
|
|
name)
|
|
) for d in devices}
|
|
model._device_grouped_blobs[name] = grouped
|
|
|
|
model._checkpoint_net = core.Net("checkpoint_sync_net")
|
|
model._checkpoint_net.RunAllOnGPU()
|
|
|
|
if (model._rendezvous is not None):
|
|
checkpoint_init_net = core.Net("checkpoint_init_net")
|
|
checkpoint_init_net.RunAllOnGPU()
|
|
_AddDistributedParameterSync(
|
|
devices,
|
|
model,
|
|
checkpoint_init_net,
|
|
model._checkpoint_net,
|
|
model._rendezvous,
|
|
uniq_blob_names,
|
|
)
|
|
workspace.RunNetOnce(checkpoint_init_net)
|
|
|
|
# Setup sync of initial params
|
|
_SyncParams(devices, model, model._checkpoint_net, uniq_blob_names)
|
|
|
|
# Sync ITER -- which is in CPU scope
|
|
if sync_iter:
|
|
with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
|
|
for gpu_idx in devices[1:]:
|
|
model._checkpoint_net.Copy(
|
|
"gpu_{}/ITER".format(devices[0]),
|
|
"gpu_{}/ITER".format(gpu_idx),
|
|
)
|
|
workspace.CreateNet(model._checkpoint_net)
|
|
|
|
# Run the sync
|
|
log.info("Run checkpoint net")
|
|
workspace.RunNet(model._checkpoint_net.Proto().name)
|
|
|
|
|
|
def _Broadcast(devices, model, net, param):
|
|
# TODO(akyrola): replace with NCCLBroadcast when it's working
|
|
# Copy params from gpu_0 to other
|
|
master_gpu = devices[0]
|
|
for gpu_idx in devices[1:]:
|
|
if _IsGPUBlob(model, param):
|
|
device_opt = core.DeviceOption(caffe2_pb2.CUDA, gpu_idx)
|
|
else:
|
|
device_opt = core.DeviceOption(caffe2_pb2.CPU, 0)
|
|
with core.DeviceScope(device_opt):
|
|
net.Copy(
|
|
model._device_grouped_blobs[param][master_gpu],
|
|
model._device_grouped_blobs[param][gpu_idx]
|
|
)
|
|
|
|
|
|
def _SyncParams(devices, model, net, unique_param_names=None):
|
|
if unique_param_names is None:
|
|
unique_param_names = model._param_names
|
|
|
|
for param in unique_param_names:
|
|
_Broadcast(devices, model, net, param)
|
|
|
|
|
|
def _AddDistributedParameterSync(
|
|
devices,
|
|
model,
|
|
init_net,
|
|
net,
|
|
rendezvous,
|
|
uniq_param_names=None,
|
|
):
|
|
if uniq_param_names is None:
|
|
uniq_param_names = model._param_names
|
|
|
|
device_opt = core.DeviceOption(caffe2_pb2.CUDA, devices[0])
|
|
|
|
# ITER is in CPU scope :(
|
|
with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
|
|
comm_world = init_net.CreateCommonWorld(
|
|
rendezvous['kv_handler'],
|
|
"iter_cw",
|
|
name=net.Proto().name + ".iter_cw_op",
|
|
size=rendezvous['num_shards'],
|
|
rank=rendezvous['shard_id'],
|
|
engine=rendezvous['engine'],
|
|
)
|
|
net.Broadcast(
|
|
inputs=[comm_world, "gpu_{}/ITER".format(devices[0])],
|
|
outputs=["gpu_{}/ITER".format(devices[0])],
|
|
engine=rendezvous['engine'],
|
|
)
|
|
|
|
# Create a single common world for all broadcast operations.
|
|
# This is not a problem since they are executed sequentially.
|
|
comm_world = None
|
|
for param_name in sorted(uniq_param_names):
|
|
param = model._device_grouped_blobs[param_name][devices[0]]
|
|
|
|
def broadcast(comm_world, param):
|
|
if comm_world is None:
|
|
comm_world = init_net.CreateCommonWorld(
|
|
rendezvous['kv_handler'],
|
|
"broadcast_cw",
|
|
name=net.Proto().name + ".broadcast_cw_op",
|
|
size=rendezvous['num_shards'],
|
|
rank=rendezvous['shard_id'],
|
|
engine=rendezvous['engine'],
|
|
)
|
|
net.Broadcast(
|
|
inputs=[comm_world, param],
|
|
outputs=[param],
|
|
engine=rendezvous['engine'],
|
|
)
|
|
return comm_world
|
|
|
|
if rendezvous['engine'] == 'GLOO':
|
|
with core.DeviceScope(device_opt):
|
|
comm_world = broadcast(comm_world, param)
|
|
else:
|
|
# Copy between GPU and CPU
|
|
with core.DeviceScope(device_opt):
|
|
param_cpu = net.CopyGPUToCPU(param, str(param) + "cpu")
|
|
with core.DeviceOption(caffe2_pb2.CPU):
|
|
comm_world = broadcast(comm_world, param_cpu)
|
|
with core.DeviceScope(device_opt):
|
|
net.CopyCPUToGPU(param_cpu, param)
|
|
|
|
|
|
def _AllReduceGradients(devices, model, rendezvous):
|
|
if rendezvous is None:
|
|
_AllReduceGradientsSingleHost(devices, model)
|
|
else:
|
|
_AllReduceGradientsDistributed(devices, model, rendezvous)
|
|
|
|
|
|
def _AllReduceGradientsDistributed(
|
|
devices,
|
|
model,
|
|
rendezvous,
|
|
):
|
|
num_workers = model.net.Proto().num_workers
|
|
assert num_workers > 1, "Please specify more than 1 worker"
|
|
all_reduce_engine = rendezvous['engine']
|
|
|
|
# Make list of gradients in reverse order
|
|
reverse_ordered_grads = _GetReverseOrderedGrads(model)
|
|
|
|
master_device_opt = core.DeviceOption(caffe2_pb2.CUDA, devices[0])
|
|
reducing_device_opt = master_device_opt
|
|
|
|
# We need to specify a partial order using control_input to ensure
|
|
# progress (all machines need to do same allreduce in parallel)
|
|
num_controls = min(4, num_workers - 1)
|
|
cyclical_controls = []
|
|
|
|
# Since num_controls determines the partial ordering of
|
|
# allreduces, there is no need for more common world instances
|
|
# than there are parallel allreduce operations.
|
|
num_comm_worlds = num_controls
|
|
cyclical_comm_worlds = []
|
|
|
|
counter = 0
|
|
nccl_control_blob = None
|
|
|
|
# Note: sorted order to ensure each host puts the operators in
|
|
# same order.
|
|
for grad_name in reverse_ordered_grads:
|
|
master_grad = model._device_grouped_blobs[grad_name][devices[0]]
|
|
grads_group = model._device_grouped_blobs[grad_name].values()
|
|
|
|
assert master_grad in grads_group
|
|
|
|
# Remark: NCCLReduce does not support in-place modifications
|
|
# so we need a temporary gradient blob
|
|
reduced_grad = str(master_grad) + "_red"
|
|
|
|
control_input = None if len(cyclical_controls) < num_controls \
|
|
else cyclical_controls[counter % num_controls]
|
|
comm_world = None if len(cyclical_comm_worlds) < num_comm_worlds \
|
|
else cyclical_comm_worlds[counter % num_comm_worlds]
|
|
|
|
def allreduce(comm_world, grads):
|
|
with core.DeviceScope(reducing_device_opt):
|
|
if comm_world is None:
|
|
comm_number = len(cyclical_comm_worlds)
|
|
comm_world = model.param_init_net.CreateCommonWorld(
|
|
rendezvous['kv_handler'],
|
|
"allreduce_{}_cw".format(comm_number),
|
|
name="allreduce_{}_cw_op".format(comm_number),
|
|
size=rendezvous['num_shards'],
|
|
rank=rendezvous['shard_id'],
|
|
engine=rendezvous['engine'],
|
|
)
|
|
model.net.Allreduce(
|
|
inputs=[comm_world] + grads,
|
|
outputs=grads,
|
|
name=grad_name,
|
|
engine=all_reduce_engine,
|
|
control_input=control_input,
|
|
)
|
|
return comm_world
|
|
|
|
if rendezvous['engine'] == 'GLOO':
|
|
# With Gloo cross GPU and cross machine allreduce
|
|
# can be executed in a single operation
|
|
comm_world = allreduce(comm_world, grads_group)
|
|
control_output = grads_group[0]
|
|
else:
|
|
# Step 1: sum gradients from local GPUs to master GPU
|
|
with core.DeviceScope(master_device_opt):
|
|
model.ConstantFill(master_grad, reduced_grad, value=0.0)
|
|
|
|
# Temp fix since NCCLReduce does not work
|
|
model.net.NCCLAllreduce(
|
|
grads_group,
|
|
grads_group,
|
|
control_input=nccl_control_blob,
|
|
)
|
|
nccl_control_blob = grads_group[0]
|
|
model.net.Copy(master_grad, reduced_grad)
|
|
|
|
# Step 2: allreduce between all hosts, between master GPUs
|
|
comm_world = allreduce(comm_world, [reduced_grad])
|
|
control_output = reduced_grad
|
|
|
|
with core.DeviceScope(master_device_opt):
|
|
model.net.Copy(reduced_grad, master_grad)
|
|
|
|
# Step 3: broadcast locally
|
|
_Broadcast(devices, model, model.net, grad_name)
|
|
|
|
if len(cyclical_controls) < num_controls:
|
|
cyclical_controls.append(control_output)
|
|
else:
|
|
cyclical_controls[counter % num_controls] = control_output
|
|
|
|
if len(cyclical_comm_worlds) < num_comm_worlds:
|
|
cyclical_comm_worlds.append(comm_world)
|
|
else:
|
|
assert cyclical_comm_worlds[counter % num_comm_worlds] == comm_world
|
|
|
|
counter += 1
|
|
|
|
|
|
def _AllReduceGradientsSingleHost(devices, model):
|
|
"""Performs NCCL AllReduce to distribute gradients to all the GPUs."""
|
|
|
|
if len(devices) == 1:
|
|
return
|
|
|
|
# Gradients in reverse order
|
|
reverse_ordered_grads = _GetReverseOrderedGrads(model)
|
|
assert(len(reverse_ordered_grads) > 0)
|
|
|
|
# Now we need to Allreduce gradients on all the GPUs.
|
|
# Pick GPU #0 as a master GPU.
|
|
master_device_opt = core.DeviceOption(caffe2_pb2.CUDA, devices[0])
|
|
last_out = None
|
|
|
|
for grad_name in reverse_ordered_grads:
|
|
# Group by grads for reduce.
|
|
grads_group = model._device_grouped_blobs[grad_name].values()
|
|
assert len(grads_group) == len(devices), \
|
|
"Each GPU from {}, should have a copy of {}.".format(
|
|
devices, grad_name)
|
|
|
|
if _IsGPUBlob(model, grad_name):
|
|
with core.DeviceScope(master_device_opt):
|
|
if not isinstance(grads_group[0], core.GradientSlice):
|
|
model.NCCLAllreduce(
|
|
grads_group,
|
|
grads_group,
|
|
control_input=last_out,
|
|
)
|
|
|
|
# last_out is used to serialize the execution of nccls
|
|
last_out = grads_group[0]
|
|
else:
|
|
# Sparse gradients: all-gather for indices and values
|
|
master_ns = "gpu_{}".format(devices[0])
|
|
grad_idx_concat, _ = model.net.Concat(
|
|
[g.indices for g in grads_group],
|
|
["{}/{}_index_concat".format(master_ns, grad_name),
|
|
"{}/{}_index_splitinfo".format(master_ns, grad_name)],
|
|
axis=0,
|
|
name="note:data_parallel_model")
|
|
for gpu, g in model._device_grouped_blobs[grad_name].items():
|
|
device_opt = core.DeviceOption(caffe2_pb2.CUDA, gpu)
|
|
with core.DeviceScope(device_opt):
|
|
model.Copy(grad_idx_concat, g.indices)
|
|
|
|
grad_val_concat, _ = model.net.Concat(
|
|
[g.values for g in grads_group],
|
|
["{}/{}_val_concat".format(master_ns, grad_name),
|
|
"{}/{}_val_splitinfo".format(master_ns, grad_name)],
|
|
axis=0, name="note:data_parallel_model")
|
|
for gpu, g in model._device_grouped_blobs[grad_name].items():
|
|
device_opt = core.DeviceOption(caffe2_pb2.CUDA, gpu)
|
|
with core.DeviceScope(device_opt):
|
|
model.Copy(grad_val_concat, g.values)
|
|
|
|
else:
|
|
assert isinstance(grads_group[0], core.GradientSlice), \
|
|
"Synchronizing gradient slices not supported"
|
|
with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
|
|
# Poor man's allreduce
|
|
model.Sum(grads_group, grads_group[0])
|
|
_Broadcast(devices, model, grad_name)
|
|
|
|
|
|
def _BroadcastComputedParams(devices, model, rendezvous):
|
|
if rendezvous is None:
|
|
_BroadcastComputedParamsSingleHost(devices, model)
|
|
else:
|
|
_BroadcastComputedParamsDistributed(devices, model, rendezvous)
|
|
|
|
|
|
def _BroadcastComputedParamsDistributed(
|
|
devices,
|
|
model,
|
|
rendezvous,
|
|
):
|
|
_BroadcastComputedParamsSingleHost(devices, model)
|
|
log.warn("Distributed computed params all-reduce not implemented yet")
|
|
|
|
|
|
def _BroadcastComputedParamsSingleHost(devices, model):
|
|
'''
|
|
Average computed params over all devices
|
|
'''
|
|
if len(devices) == 1:
|
|
return
|
|
|
|
for param_name in model._computed_param_names:
|
|
# Copy from master to others -- averaging would be perhaps better,
|
|
# but currently NCCLAllReduce is too prone to deadlock
|
|
_Broadcast(devices, model, model.net, param_name)
|
|
|
|
|
|
def _GetReverseOrderedGrads(model):
|
|
'''
|
|
Returns the gradients in reverse order (namespace stripped),
|
|
for the optimal synchronization order.
|
|
'''
|
|
return list(reversed(model._grad_names))
|
|
|
|
|
|
# A helper function to extract a parameter's name
|
|
def stripParamName(param):
|
|
# Format is "a/b/c/d" -> "b/c/d"
|
|
if isinstance(param, core.GradientSlice):
|
|
return stripParamName(param.indices) + ":" + stripParamName(param.values)
|
|
else:
|
|
name = str(param)
|
|
return name[name.index(scope._NAMESCOPE_SEPARATOR) + 1:]
|
|
|
|
|
|
def _AnalyzeOperators(model):
|
|
'''
|
|
Look at all the operators and check that they do not cross device scopes
|
|
'''
|
|
for op in model.Proto().op:
|
|
if "NCCL" in op.type or "Copy" in op.type or "Concat" in op.type:
|
|
continue
|
|
if "Allreduce" in op.type and "GLOO" in op.engine:
|
|
continue
|
|
|
|
op_dev = op.device_option
|
|
op_gpu = op_dev.cuda_gpu_id
|
|
|
|
# This avoids failing on operators that are only for CPU
|
|
if op_dev.device_type == caffe2_pb2.CPU:
|
|
continue
|
|
|
|
namescope = "gpu_{}/".format(op_gpu)
|
|
for inp in list(op.input) + list(op.output):
|
|
if inp.startswith("gpu_") and not inp.startswith(namescope):
|
|
raise Exception(
|
|
"Blob {} of op {}, should have namescope {}. Op: {}".format(
|
|
inp, op.type, "gpu_{}/".format(op_gpu), str(op),
|
|
))
|
|
|
|
|
|
def _InferBlobDevice(model):
|
|
'''
|
|
Assign blob to device option based on the operator outputing it
|
|
'''
|
|
mapping = {}
|
|
|
|
def map_ops(proto):
|
|
for op in proto.op:
|
|
for b in list(op.input) + list(op.output):
|
|
mapping[b] = op.device_option
|
|
if op.type.startswith('RecurrentNetwork'):
|
|
import google.protobuf.text_format as protobuftx
|
|
step_args = [a for a in op.arg if a.name.endswith("step_net")]
|
|
for step_arg in step_args:
|
|
step_proto = caffe2_pb2.NetDef()
|
|
protobuftx.Merge(step_arg.s, step_proto)
|
|
map_ops(step_proto)
|
|
map_ops(model.net.Proto())
|
|
|
|
model._blob_to_device = mapping
|
|
|
|
|
|
def _IsGPUBlob(model, blob_name):
|
|
if blob_name in model._blob_to_device:
|
|
return model._blob_to_device[blob_name].device_type == caffe2_pb2.CUDA
|
|
else:
|
|
blob_name = "gpu_{}/{}".format(model._devices[0], blob_name)
|
|
if blob_name not in model._blob_to_device:
|
|
return True
|
|
return model._blob_to_device[blob_name].device_type == caffe2_pb2.CUDA
|
|
|
|
|
|
def _GroupByDevice(devices, params, non_data_params):
|
|
'''
|
|
Groups blobs by device, returning a map of [blobname] = {0: BlobRef, 1: ..}.
|
|
Returns ordered dictionary, ensuring the original order.
|
|
'''
|
|
grouped = OrderedDict()
|
|
# Only consider params that were created to be "data parallel"
|
|
params = params[len(non_data_params):]
|
|
assert len(params) % len(devices) == 0,\
|
|
"There should be equal number of params per device"
|
|
|
|
num_params_per_device = int(len(params) / len(devices))
|
|
|
|
for i, p in enumerate(params):
|
|
assert isinstance(p, core.BlobReference) or \
|
|
isinstance(p, core.GradientSlice), \
|
|
"Param {} is not BlobReference or GradientSlice".format(p)
|
|
|
|
name = stripParamName(p)
|
|
gpuid = devices[i // num_params_per_device]
|
|
|
|
if isinstance(p, core.BlobReference):
|
|
assert "gpu_{}/".format(gpuid) in p.GetNameScope(),\
|
|
"Param {} expected to have namescope 'gpu_{}'".format(str(p), gpuid)
|
|
else:
|
|
assert "gpu_{}/".format(gpuid) in p.indices.GetNameScope(),\
|
|
"Indices {} expected to have namescope 'gpu_{}'".format(str(p), gpuid)
|
|
assert "gpu_{}/".format(gpuid) in p.values.GetNameScope(),\
|
|
"Values {} expected to have namescope 'gpu_{}'".format(str(p), gpuid)
|
|
|
|
if name not in grouped:
|
|
grouped[name] = {}
|
|
grouped[name][gpuid] = p
|
|
|
|
# Confirm consistency
|
|
for j, (p, ps) in enumerate(grouped.items()):
|
|
assert \
|
|
len(ps) == len(devices), \
|
|
"Param {} does not have value for each device (only {}: {})".format(
|
|
p, len(ps), ps,
|
|
)
|
|
# Ensure ordering
|
|
if (ps[devices[0]] != params[j]):
|
|
log.error("Params: {}".format(params))
|
|
log.error("Grouped: {}".format(grouped.keys()))
|
|
assert ps[devices[0]] == params[j], \
|
|
"Incorrect ordering: {}".format(ps)
|
|
|
|
return grouped
|
|
|
|
|
|
def _ValidateParams(params):
|
|
set_params = set(params)
|
|
if len(params) > len(set_params):
|
|
dupes = []
|
|
sp = sorted(params)
|
|
for j, p in enumerate(sp):
|
|
if j > 0 and params[j - 1] == p:
|
|
dupes.append(p)
|
|
|
|
assert len(params) == len(set_params), \
|
|
"Duplicate entries in params: {}".format(dupes)
|
|
|
|
|
|
def _OptimizeGradientMemoryDEPRECATED(model, losses_by_gpu, devices):
|
|
log.warning("------- DEPRECATED API, please use " +
|
|
"data_parallel_model.OptimizeGradientMemory() ----- ")
|
|
for device in devices:
|
|
namescope = "gpu_{}/".format(device)
|
|
model.net._net = memonger.share_grad_blobs(
|
|
model.net,
|
|
losses_by_gpu[device],
|
|
set(model.param_to_grad.values()),
|
|
namescope,
|
|
share_activations=False,
|
|
)
|
|
|
|
|
|
def OptimizeGradientMemory(model,
|
|
input_shapes,
|
|
excluded_blobs,
|
|
recycle_activations):
|
|
"""
|
|
Optimize memory usage of the backward pass by recycling blobs for gradient
|
|
inputs that have been 'used'.
|
|
input_shapes: dict of blob name to shape for the inputs of the model.
|
|
Pass empty dictionary if not known.
|
|
excluded_blobs: list of blobs that cannot be recycled. These are blobs
|
|
that you will access externally.
|
|
recycle_activations: whether to also recycle forward pass activations
|
|
"""
|
|
input_shapes_all_devices = {}
|
|
for b, shp in input_shapes.items():
|
|
for d in model._devices:
|
|
input_shapes_all_devices["gpu_{}/{}".format(d, b)] = shp
|
|
|
|
(shapes, types) = workspace.InferShapesAndTypes(
|
|
[model.param_init_net, model.net],
|
|
input_shapes_all_devices,
|
|
)
|
|
|
|
for device in model._devices:
|
|
namescope = "gpu_{}/".format(device)
|
|
excluded_blobs_by_device = set([namescope + b for b in excluded_blobs])
|
|
model.net._net = memonger.share_grad_blobs(
|
|
model.net,
|
|
model._losses_by_gpu[device],
|
|
set(model.param_to_grad.values()),
|
|
namescope,
|
|
dont_share_blobs=excluded_blobs_by_device,
|
|
share_activations=recycle_activations,
|
|
blob_shapes=shapes,
|
|
)
|