pytorch/caffe2/python/cnn.py
Aapo Kyrola eddf23ca0f Handle parameters that are computed but not optimized
Summary:
prigoyal sharply noticed a bug in the Resnet models: we have not been checkpointing, nor synchronizing between GPUs, the moving average and variance computed by the SpatialBN ops. The first problem in particular is serious, since models restarted from a checkpoint would have started from a null state for SpatialBN. Not synchronizing within the data parallel model is less tragic, since each GPU should see very similar data.

Thus I propose keeping track of "computed params", i.e. params that are computed from data but not optimized. I don't know if there are other examples, but SpatialBN's moving avg and var definitely are one.

- I modified the checkpointing for the xray model to store those blobs and also ensure they are synchronized (a rough sketch of the idea follows below)
- I modified data parallel model to broadcast those params from gpu0. I first tried averaging, but hit some NCCL deadlocks ... :(
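A rough sketch of the checkpointing idea (a hypothetical save_checkpoint helper, not part of this diff; it only relies on workspace.FetchBlob and on the params / computed_params lists that the model helper maintains):

import pickle
from caffe2.python import workspace

def save_checkpoint(model, path):
    # Persist the optimized params *and* the computed (non-optimized) params,
    # e.g. SpatialBN's running mean/var, so that restarting from a checkpoint
    # does not reset them to their initial values.
    blobs = list(model.params) + list(model.computed_params)
    state = {str(b): workspace.FetchBlob(str(b)) for b in blobs}
    with open(path, 'wb') as f:
        pickle.dump(state, f)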

Differential Revision: D4281265

fbshipit-source-id: 933311afeec4b7e9344a13cf2d38aa939c50ac31
2016-12-15 12:01:28 -08:00


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import core, scope
from caffe2.python.model_helper import ModelHelperBase
from caffe2.proto import caffe2_pb2
class CNNModelHelper(ModelHelperBase):
"""A helper model so we can write CNN models more easily, without having to
manually define parameter initializations and operators separately.
"""
def __init__(self, order="NCHW", name=None,
use_cudnn=True, cudnn_exhaustive_search=False,
ws_nbytes_limit=None, init_params=True,
skip_sparse_optim=False):
super(CNNModelHelper, self).__init__(
name="CNN" if name is None else name, init_params=init_params,
skip_sparse_optim=skip_sparse_optim)
self.weights = []
self.biases = []
self.order = order
self.use_cudnn = use_cudnn
self.cudnn_exhaustive_search = cudnn_exhaustive_search
self.ws_nbytes_limit = ws_nbytes_limit
if self.order != "NHWC" and self.order != "NCHW":
raise ValueError(
"Cannot understand the CNN storage order %s." % self.order
)
def GetWeights(self, namescope=None):
if namescope is None:
namescope = scope.CurrentNameScope()
if namescope == '':
return self.weights[:]
else:
return [w for w in self.weights if w.GetNameScope() == namescope]
def GetBiases(self, namescope=None):
if namescope is None:
namescope = scope.CurrentNameScope()
if namescope == '':
return self.biases[:]
else:
return [b for b in self.biases if b.GetNameScope() == namescope]
def ImageInput(
self, blob_in, blob_out, **kwargs
):
"""Image Input."""
if self.order == "NCHW":
data, label = self.net.ImageInput(
blob_in, [blob_out[0] + '_nhwc', blob_out[1]], **kwargs)
data = self.net.NHWC2NCHW(data, blob_out[0])
else:
data, label = self.net.ImageInput(
blob_in, blob_out, **kwargs)
return data, label
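# Illustrative usage (hypothetical blob names): with order="NCHW" the
# ImageInput op emits NHWC data which is then transposed to NCHW, e.g.
#   data, label = model.ImageInput(reader, ["data", "label"], batch_size=32)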
def Conv(
self, blob_in, blob_out, dim_in, dim_out, kernel, weight_init=None,
bias_init=None, **kwargs
):
"""Convolution. We intentionally do not provide odd kernel/stride/pad
settings in order to discourage the use of odd cases.
"""
weight_init = weight_init if weight_init else ('XavierFill', {})
bias_init = bias_init if bias_init else ('ConstantFill', {})
blob_out = blob_out or self.net.NextName()
weight_shape = (
[dim_out, dim_in, kernel, kernel]
if self.order == "NCHW" else [dim_out, kernel, kernel, dim_in]
)
if self.init_params:
weight = self.param_init_net.__getattr__(weight_init[0])(
[],
blob_out + '_w',
shape=weight_shape,
**weight_init[1]
)
bias = self.param_init_net.__getattr__(bias_init[0])(
[],
blob_out + '_b',
shape=[dim_out, ],
**bias_init[1]
)
else:
weight = core.ScopedBlobReference(
blob_out + '_w', self.param_init_net)
bias = core.ScopedBlobReference(
blob_out + '_b', self.param_init_net)
self.params.extend([weight, bias])
self.weights.append(weight)
self.biases.append(bias)
if self.use_cudnn:
kwargs['engine'] = 'CUDNN'
kwargs['exhaustive_search'] = self.cudnn_exhaustive_search
if self.ws_nbytes_limit:
kwargs['ws_nbytes_limit'] = self.ws_nbytes_limit
return self.net.Conv(
[blob_in, weight, bias],
blob_out,
kernel=kernel,
order=self.order,
**kwargs
)
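# Illustrative usage (hypothetical blob names/shapes): a 3x3, 64 -> 128
# channel convolution; stride/pad are forwarded to the Conv op via kwargs:
#   conv2 = model.Conv('conv1', 'conv2', dim_in=64, dim_out=128, kernel=3,
#                      stride=1, pad=1)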
def ConvTranspose(
self, blob_in, blob_out, dim_in, dim_out, kernel, weight_init=None,
bias_init=None, **kwargs
):
"""ConvTranspose.
"""
weight_init = weight_init if weight_init else ('XavierFill', {})
bias_init = bias_init if bias_init else ('ConstantFill', {})
blob_out = blob_out or self.net.NextName()
weight_shape = (
[dim_in, dim_out, kernel, kernel]
if self.order == "NCHW" else [dim_in, kernel, kernel, dim_out]
)
if self.init_params:
weight = self.param_init_net.__getattr__(weight_init[0])(
[],
blob_out + '_w',
shape=weight_shape,
**weight_init[1]
)
bias = self.param_init_net.__getattr__(bias_init[0])(
[],
blob_out + '_b',
shape=[dim_out, ],
**bias_init[1]
)
else:
weight = core.ScopedBlobReference(
blob_out + '_w', self.param_init_net)
bias = core.ScopedBlobReference(
blob_out + '_b', self.param_init_net)
self.params.extend([weight, bias])
self.weights.append(weight)
self.biases.append(bias)
if self.use_cudnn:
kwargs['engine'] = 'CUDNN'
kwargs['exhaustive_search'] = self.cudnn_exhaustive_search
if self.ws_nbytes_limit:
kwargs['ws_nbytes_limit'] = self.ws_nbytes_limit
return self.net.ConvTranspose(
[blob_in, weight, bias],
blob_out,
kernel=kernel,
order=self.order,
**kwargs
)
def GroupConv(
self,
blob_in,
blob_out,
dim_in,
dim_out,
kernel,
weight_init,
bias_init,
group=1,
**kwargs
):
"""Convolution. We intentionally do not provide odd kernel/stride/pad
settings in order to discourage the use of odd cases.
"""
if self.use_cudnn:
kwargs['engine'] = 'CUDNN'
kwargs['exhaustive_search'] = self.cudnn_exhaustive_search
if self.ws_nbytes_limit:
kwargs['ws_nbytes_limit'] = self.ws_nbytes_limit
if dim_in % group:
raise ValueError("dim_in should be divisible by group.")
splitted_blobs = self.net.DepthSplit(
blob_in,
['_' + blob_out + '_gconv_split_' + str(i) for i in range(group)],
dimensions=[dim_in // group for i in range(group)],
order=self.order
)
weight_shape = (
[dim_out // group, dim_in // group, kernel, kernel]
if self.order == "NCHW" else
[dim_out // group, kernel, kernel, dim_in // group]
)
conv_blobs = []
for i in range(group):
if self.init_params:
weight = self.param_init_net.__getattr__(weight_init[0])(
[],
blob_out + '_gconv_%d_w' % i,
shape=weight_shape,
**weight_init[1]
)
bias = self.param_init_net.__getattr__(bias_init[0])(
[],
blob_out + '_gconv_%d_b' % i,
shape=[dim_out // group],
**bias_init[1]
)
else:
weight = core.ScopedBlobReference(
blob_out + '_gconv_%d_w' % i, self.param_init_net)
bias = core.ScopedBlobReference(
blob_out + '_gconv_%d_b' % i, self.param_init_net)
self.params.extend([weight, bias])
self.weights.append(weight)
self.biases.append(bias)
conv_blobs.append(
splitted_blobs[i].Conv(
[weight, bias],
blob_out + '_gconv_%d' % i,
kernel=kernel,
order=self.order,
**kwargs
)
)
concat, concat_dims = self.net.Concat(
conv_blobs,
[blob_out, "_" + blob_out + "_concat_dims"],
order=self.order
)
return concat
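# Note: grouped convolution is emulated by DepthSplit-ing the input into
# `group` channel slices, running an independent Conv per slice, and then
# Concat-ing the per-group outputs back along the channel dimension.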
def _FC_or_packed_FC(
self, op_call, blob_in, blob_out, dim_in, dim_out, weight_init=None,
bias_init=None, **kwargs
):
"""FC"""
weight_init = weight_init if weight_init else ('XavierFill', {})
bias_init = bias_init if bias_init else ('ConstantFill', {})
blob_out = blob_out or self.net.NextName()
if self.init_params:
weight = self.param_init_net.__getattr__(weight_init[0])(
[],
blob_out + '_w',
shape=[dim_out, dim_in],
**weight_init[1]
)
bias = self.param_init_net.__getattr__(bias_init[0])(
[],
blob_out + '_b',
shape=[dim_out, ],
**bias_init[1]
)
else:
weight = core.ScopedBlobReference(
blob_out + '_w', self.param_init_net)
bias = core.ScopedBlobReference(
blob_out + '_b', self.param_init_net)
if 'freeze_bias' in kwargs:
self.params.extend([weight])
else:
self.params.extend([weight, bias])
self.weights.append(weight)
self.biases.append(bias)
return op_call([blob_in, weight, bias], blob_out, **kwargs)
def FC(self, *args, **kwargs):
return self._FC_or_packed_FC(self.net.FC, *args, **kwargs)
def PackedFC(self, *args, **kwargs):
return self._FC_or_packed_FC(self.net.PackedFC, *args, **kwargs)
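# Illustrative usage (hypothetical names): a fully connected layer mapping a
# 4096-d feature to 1000 outputs; 'fc_w' ([1000, 4096]) and 'fc_b' ([1000])
# are created and registered as params:
#   pred = model.FC('feat', 'fc', dim_in=4096, dim_out=1000)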
def FC_Decomp(
self, blob_in, blob_out, dim_in, dim_out,
rank_approx=5, weight_init=None,
bias_init=None, **kwargs
):
"""FC_Decomp version
Here we assume that the rank of original input is bigger than 5.
"""
weight_init = weight_init if weight_init else ('XavierFill', {})
bias_init = bias_init if bias_init else ('ConstantFill', {})
blob_out = blob_out or self.net.NextName()
u = self.param_init_net.__getattr__(weight_init[0])(
[],
blob_out + '_u',
shape=[dim_out, rank_approx],
**weight_init[1]
)
v = self.param_init_net.__getattr__(weight_init[0])(
[],
blob_out + '_v',
shape=[dim_in, rank_approx],
**weight_init[1]
)
bias = self.param_init_net.__getattr__(bias_init[0])(
[],
blob_out + '_b',
shape=[dim_out, ],
**bias_init[1]
)
self.params.extend([u, v, bias])
return self.net.FC_Decomp([blob_in, u, v, bias], blob_out, **kwargs)
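# The decomposition replaces the dense [dim_out, dim_in] weight with factors
# u ([dim_out, rank_approx]) and v ([dim_in, rank_approx]), roughly
# W ~= u * v^T, reducing the weight parameter count from dim_out * dim_in to
# rank_approx * (dim_out + dim_in).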
def FC_Prune(
self, blob_in, blob_out, dim_in, dim_out,
weight_init=None, bias_init=None, mask_init=None,
threshold=0.00001, need_compress_rate=False,
comp_lb=0.05,
**kwargs
):
"""FC_Prune version
Runnable so far. Great!:)
"""
weight_init = weight_init if weight_init else ('XavierFill', {})
bias_init = bias_init if bias_init else ('ConstantFill', {})
mask_init = mask_init if mask_init else ('ConstantFill', {})
blob_out = blob_out or self.net.NextName()
compress_rate = blob_out + '_compress_rate'
if self.init_params:
compress_lb = self.param_init_net.ConstantFill(
[],
blob_out + '_lb',
shape=[1],
value=comp_lb
)
weight = self.param_init_net.__getattr__(weight_init[0])(
[],
blob_out + '_w',
shape=[dim_out, dim_in],
**weight_init[1]
)
mask = self.param_init_net.ConstantFill(
[],
blob_out + '_m',
shape=[dim_out, dim_in],
value=1.0
)
ag_dw = self.param_init_net.__getattr__(mask_init[0])(
[],
blob_out + '_ag_dw',
shape=[dim_out, dim_in],
**mask_init[1]
)
bias = self.param_init_net.__getattr__(bias_init[0])(
[],
blob_out + '_b',
shape=[dim_out, ],
**bias_init[1]
)
mask_seq = self.param_init_net.__getattr__(mask_init[0])(
[],
blob_out + '_mask_seq',
shape=[dim_out, dim_in],
**mask_init[1]
)
thres = self.param_init_net.ConstantFill(
[],
blob_out + '_thres',
shape=[1],
value=threshold
)
else:
compress_lb = core.ScopedBlobReference(
blob_out + '_lb', self.param_init_net)
weight = core.ScopedBlobReference(
blob_out + '_w', self.param_init_net)
bias = core.ScopedBlobReference(
blob_out + '_b', self.param_init_net)
mask = core.ScopedBlobReference(
blob_out + '_m', self.param_init_net)
ag_dw = core.ScopedBlobReference(
blob_out + '_ag_dw', self.param_init_net)
mask_seq = core.ScopedBlobReference(
blob_out + '_mask_seq', self.param_init_net)
thres = core.ScopedBlobReference(
blob_out + '_thres', self.param_init_net)
self.params.extend([weight, bias])
if need_compress_rate:
return self.net.FC_Prune([blob_in, weight, mask,
bias, ag_dw, mask_seq,
thres, compress_lb],
[blob_out, compress_rate], **kwargs)
else:
return self.net.FC_Prune([blob_in, weight, mask,
bias, ag_dw, mask_seq,
thres, compress_lb],
blob_out, **kwargs)
def FC_Sparse(
self, blob_in, blob_out, w_csr, iw, jw, bias,
**kwargs
):
"""FC_Sparse: Only takes in alocated weights"""
if not (w_csr and iw and jw and bias):
print("Warning...")
self.params.extend([w_csr, iw, jw, bias])
return self.net.FC_Sparse([blob_in, w_csr, iw, jw, bias],
blob_out, **kwargs)
def LRN(self, blob_in, blob_out, **kwargs):
"""LRN"""
return self.net.LRN(
blob_in,
[blob_out, "_" + blob_out + "_scale"],
order=self.order,
**kwargs
)[0]
def Dropout(self, blob_in, blob_out, **kwargs):
"""Dropout"""
return self.net.Dropout(
blob_in, [blob_out, "_" + blob_out + "_mask"], **kwargs
)[0]
def MaxPool(self, blob_in, blob_out, **kwargs):
"""Max pooling"""
if self.use_cudnn:
kwargs['engine'] = 'CUDNN'
return self.net.MaxPool(blob_in, blob_out, order=self.order, **kwargs)
def AveragePool(self, blob_in, blob_out, **kwargs):
"""Average pooling"""
if self.use_cudnn:
kwargs['engine'] = 'CUDNN'
return self.net.AveragePool(
blob_in,
blob_out,
order=self.order,
**kwargs
)
def Concat(self, blobs_in, blob_out, **kwargs):
"""Depth Concat."""
return self.net.Concat(
blobs_in,
[blob_out, "_" + blob_out + "_concat_dims"],
order=self.order,
**kwargs
)[0]
def DepthConcat(self, blobs_in, blob_out, **kwargs):
"""The old depth concat function - we should move to use concat."""
print("DepthConcat is deprecated. use Concat instead.")
return self.Concat(blobs_in, blob_out, **kwargs)
def PRelu(self, blob_in, blob_out, num_channels=1, slope_init=None,
**kwargs):
"""PRelu"""
slope_init = (
slope_init if slope_init else ('ConstantFill', {'value': 0.25}))
if self.init_params:
slope = self.param_init_net.__getattr__(slope_init[0])(
[],
blob_out + '_slope',
shape=[num_channels],
**slope_init[1]
)
else:
slope = core.ScopedBlobReference(
blob_out + '_slope', self.param_init_net)
self.params.extend([slope])
return self.net.PRelu([blob_in, slope], [blob_out])
def Relu(self, blob_in, blob_out, **kwargs):
"""Relu."""
if self.use_cudnn:
kwargs['engine'] = 'CUDNN'
return self.net.Relu(blob_in, blob_out, order=self.order, **kwargs)
def Transpose(self, blob_in, blob_out, **kwargs):
"""Transpose."""
return self.net.Transpose(blob_in, blob_out, **kwargs)
def Sum(self, blob_in, blob_out, **kwargs):
"""Sum"""
return self.net.Sum(blob_in, blob_out, **kwargs)
def SpatialBN(self, blob_in, blob_out, dim_in, **kwargs):
blob_out = blob_out or self.net.NextName()
# Input: input, scale, bias, est_mean, est_inv_var
# Output: output, running_mean, running_inv_var, saved_mean,
# saved_inv_var
# scale: initialize with ones
# bias: initialize with zeros
# est mean: zero
# est var: ones
def init_blob(value, suffix):
return self.param_init_net.ConstantFill(
[], blob_out + "_" + suffix, shape=[dim_in], value=value)
scale, bias = init_blob(1.0, "s"), init_blob(0.0, "b")
running_mean = init_blob(0.0, "rm")
running_inv_var = init_blob(1.0, "riv")
self.params.extend([scale, bias])
self.computed_params.extend([running_mean, running_inv_var])
self.weights.append(scale)
self.biases.append(bias)
blob_outs = [blob_out, running_mean, running_inv_var,
blob_out + "_sm", blob_out + "_siv"]
if 'is_test' in kwargs and kwargs['is_test']:
blob_outputs = self.net.SpatialBN(
[blob_in, scale, bias, blob_outs[1], blob_outs[2]], [blob_out],
order=self.order, **kwargs)
return blob_outputs
else:
blob_outputs = self.net.SpatialBN(
[blob_in, scale, bias, blob_outs[1], blob_outs[2]], blob_outs,
order=self.order, **kwargs)
# Return the output
return blob_outputs[0]
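# Illustrative usage (hypothetical names): in training mode the op also
# updates the running mean / inverse-variance blobs, which are registered in
# computed_params so they can be checkpointed and broadcast across GPUs:
#   bn1 = model.SpatialBN('conv1', 'bn1', dim_in=64, epsilon=1e-5,
#                         momentum=0.9, is_test=False)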
def Iter(self, blob_out, **kwargs):
if 'device_option' in kwargs:
del kwargs['device_option']
self.param_init_net.ConstantFill(
[], blob_out, shape=[1], value=0, dtype=core.DataType.INT64,
device_option=core.DeviceOption(caffe2_pb2.CPU, 0),
**kwargs)
return self.net.Iter(blob_out, blob_out, **kwargs)
@property
def XavierInit(self):
return ('XavierFill', {})
def ConstantInit(self, value):
return ('ConstantFill', dict(value=value))
@property
def MSRAInit(self):
return ('MSRAFill', {})
@property
def ZeroInit(self):
return ('ConstantFill', {})
def AddWeightDecay(self, weight_decay):
"""Adds a decay to weights in the model.
This is a form of L2 regularization.
Args:
weight_decay: strength of the regularization
"""
if weight_decay <= 0.0:
return
wd = self.param_init_net.ConstantFill([], 'wd', shape=[1],
value=weight_decay)
ONE = self.param_init_net.ConstantFill([], "ONE", shape=[1], value=1.0)
for param in self.GetWeights():
# Equivalent to: grad += wd * param
grad = self.param_to_grad[param]
self.net.WeightedSum(
[grad, ONE, param, wd],
grad,
)
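# E.g. AddWeightDecay(1e-4) adds 1e-4 * w to each weight gradient (biases are
# untouched), which corresponds to an L2 penalty of 0.5 * 1e-4 * ||w||^2.
# param_to_grad must already be populated, i.e. call this after the gradient
# operators have been added.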
@property
def CPU(self):
device_option = caffe2_pb2.DeviceOption()
device_option.device_type = caffe2_pb2.CPU
return device_option
@property
def GPU(self, gpu_id=0):
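# Note: because GPU is a property, callers cannot actually pass gpu_id; it
# always takes its default value of 0.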
device_option = caffe2_pb2.DeviceOption()
device_option.device_type = caffe2_pb2.CUDA
device_option.cuda_gpu_id = gpu_id
return device_option
def LSTM(self, input_blob, seq_lengths, initial_states, dim_in, dim_out,
scope=None):
def s(name):
# We have to manually scope due to our internal/external blob
# relationships.
scope_name = scope or str(input_blob)
return "{}/{}".format(str(scope_name), str(name))
(hidden_input_blob, cell_input_blob) = initial_states
input_blob = self.FC(input_blob, s("i2h"),
dim_in=dim_in, dim_out=4 * dim_out, axis=2)
step_net = CNNModelHelper(name="LSTM")
step_net.Proto().external_input.extend([
str(seq_lengths),
"input_t",
"timestep",
"hidden_t_prev",
"cell_t_prev",
s("gates_t_w"),
s("gates_t_b"),
])
step_net.Proto().type = "simple"
step_net.Proto().external_output.extend(
["hidden_t", "cell_t", s("gates_t")])
step_net.FC("hidden_t_prev", s("gates_t"),
dim_in=dim_out, dim_out=4 * dim_out, axis=2)
step_net.net.Sum([s("gates_t"), "input_t"], [s("gates_t")])
step_net.net.LSTMUnit(
["cell_t_prev", s("gates_t"), str(seq_lengths), "timestep"],
["hidden_t", "cell_t"])
links = [
("hidden_t_prev", s("hidden"), 0),
("hidden_t", s("hidden"), 1),
("cell_t_prev", s("cell"), 0),
("cell_t", s("cell"), 1),
(s("gates_t"), s("gates"), 0),
("input_t", str(input_blob), 0),
]
link_internal, link_external, link_offset = zip(*links)
# Set up the backward links
backward_ops, backward_mapping = core.GradientRegistry.GetBackwardPass(
step_net.Proto().op,
{"hidden_t": "hidden_t_grad", "cell_t": "cell_t_grad"})
backward_mapping = {str(k): str(v) for k, v
in backward_mapping.items()}
backward_step_net = core.Net("LSTMBackward")
del backward_step_net.Proto().op[:]
backward_step_net.Proto().op.extend(backward_ops)
backward_links = [
("hidden_t_prev_grad", s("hidden_grad"), 0),
("hidden_t_grad", s("hidden_grad"), 1),
("cell_t_prev_grad", s("cell_grad"), 0),
("cell_t_grad", s("cell_grad"), 1),
(s("gates_t_grad"), s("gates_grad"), 0),
]
backward_link_internal, backward_link_external, \
backward_link_offset = zip(*backward_links)
backward_step_net.Proto().external_input.extend(
["hidden_t_grad", "cell_t_grad"])
backward_step_net.Proto().external_input.extend(
step_net.Proto().external_input)
backward_step_net.Proto().external_input.extend(
step_net.Proto().external_output)
output, _, _, hidden_state, cell_state = self.net.RecurrentNetwork(
[input_blob, seq_lengths,
s("gates_t_w"), s("gates_t_b"),
hidden_input_blob, cell_input_blob],
[s("output"), s("hidden"), s("cell"),
s("hidden_output"), s("cell_output")],
param=[str(p) for p in step_net.params],
param_gradient=[backward_mapping[str(p)] for p in step_net.params],
alias_src=[s("hidden"), s("hidden"), s("cell")],
alias_dst=[s("output"), s("hidden_output"), s("cell_output")],
alias_offset=[1, -1, -1],
recurrent_states=[s("hidden"), s("cell")],
recurrent_inputs=[str(hidden_input_blob), str(cell_input_blob)],
recurrent_sizes=[dim_out, dim_out],
link_internal=link_internal,
link_external=link_external,
link_offset=link_offset,
backward_link_internal=backward_link_internal,
backward_link_external=backward_link_external,
backward_link_offset=backward_link_offset,
backward_alias_src=[s("gates_grad")],
backward_alias_dst=[str(input_blob) + "_grad"],
backward_alias_offset=[0],
scratch=[s("gates")],
backward_scratch=[s("gates_grad")],
scratch_sizes=[4 * dim_out],
step_net=str(step_net.Proto()),
backward_step_net=str(backward_step_net.Proto()),
timestep="timestep")
self.param_init_net.Proto().op.extend(
step_net.param_init_net.Proto().op)
self.params += step_net.params
for p in step_net.params:
if str(p) in backward_mapping:
self.param_to_grad[p] = backward_mapping[str(p)]
self.weights += step_net.weights
self.biases += step_net.biases
return output, hidden_state, cell_state
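# Illustrative usage (hypothetical blob names and sizes): unroll an LSTM over
# a [T, N, 40] input sequence with 128 hidden units; initial_states holds the
# hidden/cell blobs for the step before t = 0:
#   output, last_hidden, last_cell = model.LSTM(
#       'input_seq', 'seq_lengths', ('hidden_init', 'cell_init'),
#       dim_in=40, dim_out=128, scope='lstm1')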