pytorch/caffe2/python/cnn.py
Yury Zemlyanskiy c2d28fb874 RNNs API simplification
Summary:
This is a first step in improving our RNN story. It provides a wrapper around the current RecurrentNetworkOp implementation that infers most of the redundant parameters and makes the API much simpler.

Also, in order to support general step nets, I added an extra argument to RecurrentNetworkOp.

Future work:

1. Infer the sizes and types of step net outputs and internal blobs (scratches)
2. Avoid accessing blobs by name in the C++ part
3. Remove the requirement of a 1:1 input/output correspondence in the step net
4. Make the Python API support networks with operators such as Sum on the border of the cell net (currently such networks have an issue where gradient blobs on the boundary are not explicitly created).

Differential Revision: D4268503

fbshipit-source-id: f8a66491c2b55daa730caeed7e9f2b3921541b49
2016-12-21 09:29:43 -08:00


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import core, scope
from caffe2.python.model_helper import ModelHelperBase
from caffe2.proto import caffe2_pb2
class CNNModelHelper(ModelHelperBase):
"""A helper model so we can write CNN models more easily, without having to
manually define parameter initializations and operators separately.
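
Example (a minimal usage sketch; the blob names, shapes and layer choices
below are illustrative, not required by this class; it assumes a "data"
blob of shape (N, 3, 32, 32) is already provided):

model = CNNModelHelper(order="NCHW", name="example")
conv1 = model.Conv("data", "conv1", dim_in=3, dim_out=16, kernel=3)
relu1 = model.Relu(conv1, "relu1")
pool1 = model.MaxPool(relu1, "pool1", kernel=2, stride=2)
pred = model.FC(pool1, "pred", dim_in=16 * 15 * 15, dim_out=10)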
"""
def __init__(self, order="NCHW", name=None,
use_cudnn=True, cudnn_exhaustive_search=False,
ws_nbytes_limit=None, init_params=True,
skip_sparse_optim=False,
param_model=None):
super(CNNModelHelper, self).__init__(
skip_sparse_optim=skip_sparse_optim,
name="CNN" if name is None else name,
init_params=init_params,
param_model=param_model,
)
self.weights = []
self.biases = []
self.order = order
self.use_cudnn = use_cudnn
self.cudnn_exhaustive_search = cudnn_exhaustive_search
self.ws_nbytes_limit = ws_nbytes_limit
if self.order != "NHWC" and self.order != "NCHW":
raise ValueError(
"Cannot understand the CNN storage order %s." % self.order
)
def GetWeights(self, namescope=None):
if namescope is None:
namescope = scope.CurrentNameScope()
if namescope == '':
return self.weights[:]
else:
return [w for w in self.weights if w.GetNameScope() == namescope]
def GetBiases(self, namescope=None):
if namescope is None:
namescope = scope.CurrentNameScope()
if namescope == '':
return self.biases[:]
else:
return [b for b in self.biases if b.GetNameScope() == namescope]
def ImageInput(
self, blob_in, blob_out, **kwargs
):
"""Image Input."""
if self.order == "NCHW":
data, label = self.net.ImageInput(
blob_in, [blob_out[0] + '_nhwc', blob_out[1]], **kwargs)
data = self.net.NHWC2NCHW(data, blob_out[0])
else:
data, label = self.net.ImageInput(
blob_in, blob_out, **kwargs)
return data, label
def Conv(
self, blob_in, blob_out, dim_in, dim_out, kernel, weight_init=None,
bias_init=None, **kwargs
):
"""Convolution. We intentionally do not provide odd kernel/stride/pad
settings in order to discourage the use of odd cases.
"""
use_bias = not kwargs.get("no_bias", False)
weight_init = weight_init if weight_init else ('XavierFill', {})
bias_init = bias_init if bias_init else ('ConstantFill', {})
blob_out = blob_out or self.net.NextName()
weight_shape = (
[dim_out, dim_in, kernel, kernel]
if self.order == "NCHW" else [dim_out, kernel, kernel, dim_in]
)
if self.init_params:
weight = self.param_init_net.__getattr__(weight_init[0])(
[],
blob_out + '_w',
shape=weight_shape,
**weight_init[1]
)
if use_bias:
bias = self.param_init_net.__getattr__(bias_init[0])(
[],
blob_out + '_b',
shape=[dim_out, ],
**bias_init[1]
)
else:
weight = core.ScopedBlobReference(
blob_out + '_w', self.param_init_net)
if use_bias:
bias = core.ScopedBlobReference(
blob_out + '_b', self.param_init_net)
if use_bias:
self.params.extend([weight, bias])
else:
self.params.extend([weight])
self.weights.append(weight)
if use_bias:
self.biases.append(bias)
if self.use_cudnn:
kwargs['engine'] = 'CUDNN'
kwargs['exhaustive_search'] = self.cudnn_exhaustive_search
if self.ws_nbytes_limit:
kwargs['ws_nbytes_limit'] = self.ws_nbytes_limit
inputs = []
if use_bias:
inputs = [blob_in, weight, bias]
else:
inputs = [blob_in, weight]
return self.net.Conv(
inputs,
blob_out,
kernel=kernel,
order=self.order,
**kwargs
)
def ConvTranspose(
self, blob_in, blob_out, dim_in, dim_out, kernel, weight_init=None,
bias_init=None, **kwargs
):
"""ConvTranspose.
"""
weight_init = weight_init if weight_init else ('XavierFill', {})
bias_init = bias_init if bias_init else ('ConstantFill', {})
blob_out = blob_out or self.net.NextName()
weight_shape = (
[dim_in, dim_out, kernel, kernel]
if self.order == "NCHW" else [dim_in, kernel, kernel, dim_out]
)
if self.init_params:
weight = self.param_init_net.__getattr__(weight_init[0])(
[],
blob_out + '_w',
shape=weight_shape,
**weight_init[1]
)
bias = self.param_init_net.__getattr__(bias_init[0])(
[],
blob_out + '_b',
shape=[dim_out, ],
**bias_init[1]
)
else:
weight = core.ScopedBlobReference(
blob_out + '_w', self.param_init_net)
bias = core.ScopedBlobReference(
blob_out + '_b', self.param_init_net)
self.params.extend([weight, bias])
self.weights.append(weight)
self.biases.append(bias)
if self.use_cudnn:
kwargs['engine'] = 'CUDNN'
kwargs['exhaustive_search'] = self.cudnn_exhaustive_search
if self.ws_nbytes_limit:
kwargs['ws_nbytes_limit'] = self.ws_nbytes_limit
return self.net.ConvTranspose(
[blob_in, weight, bias],
blob_out,
kernel=kernel,
order=self.order,
**kwargs
)
def GroupConv(
self,
blob_in,
blob_out,
dim_in,
dim_out,
kernel,
weight_init,
bias_init,
group=1,
**kwargs
):
"""Convolution. We intentionally do not provide odd kernel/stride/pad
settings in order to discourage the use of odd cases.
"""
if self.use_cudnn:
kwargs['engine'] = 'CUDNN'
kwargs['exhaustive_search'] = self.cudnn_exhaustive_search
if self.ws_nbytes_limit:
kwargs['ws_nbytes_limit'] = self.ws_nbytes_limit
if dim_in % group:
raise ValueError("dim_in should be divisible by group.")
if dim_out % group:
raise ValueError("dim_out should be divisible by group.")
splitted_blobs = self.net.DepthSplit(
blob_in,
['_' + blob_out + '_gconv_split_' + str(i) for i in range(group)],
dimensions=[int(dim_in / group) for i in range(group)],
order=self.order
)
weight_shape = (
[dim_out / group, dim_in / group, kernel, kernel]
if self.order == "NCHW" else
[dim_out / group, kernel, kernel, dim_in / group]
)
# Make sure that the shapes are of int format. Especially for py3 where
# int division gives float output.
weight_shape = [int(v) for v in weight_shape]
conv_blobs = []
for i in range(group):
if self.init_params:
weight = self.param_init_net.__getattr__(weight_init[0])(
[],
blob_out + '_gconv_%d_w' % i,
shape=weight_shape,
**weight_init[1]
)
bias = self.param_init_net.__getattr__(bias_init[0])(
[],
blob_out + '_gconv_%d_b' % i,
shape=[int(dim_out / group)],
**bias_init[1]
)
else:
weight = core.ScopedBlobReference(
blob_out + '_gconv_%d_w' % i, self.param_init_net)
bias = core.ScopedBlobReference(
blob_out + '_gconv_%d_b' % i, self.param_init_net)
self.params.extend([weight, bias])
self.weights.append(weight)
self.biases.append(bias)
conv_blobs.append(
splitted_blobs[i].Conv(
[weight, bias],
blob_out + '_gconv_%d' % i,
kernel=kernel,
order=self.order,
**kwargs
)
)
concat, concat_dims = self.net.Concat(
conv_blobs,
[blob_out, "_" + blob_out + "_concat_dims"],
order=self.order
)
return concat
def _FC_or_packed_FC(
self, op_call, blob_in, blob_out, dim_in, dim_out, weight_init=None,
bias_init=None, **kwargs
):
"""FC"""
weight_init = weight_init if weight_init else ('XavierFill', {})
bias_init = bias_init if bias_init else ('ConstantFill', {})
blob_out = blob_out or self.net.NextName()
if self.init_params:
weight = self.param_init_net.__getattr__(weight_init[0])(
[],
blob_out + '_w',
shape=[dim_out, dim_in],
**weight_init[1]
)
bias = self.param_init_net.__getattr__(bias_init[0])(
[],
blob_out + '_b',
shape=[dim_out, ],
**bias_init[1]
)
else:
weight = core.ScopedBlobReference(
blob_out + '_w', self.param_init_net)
bias = core.ScopedBlobReference(
blob_out + '_b', self.param_init_net)
if 'freeze_bias' in kwargs:
self.params.extend([weight])
else:
self.params.extend([weight, bias])
self.weights.append(weight)
self.biases.append(bias)
return op_call([blob_in, weight, bias], blob_out, **kwargs)
def FC(self, *args, **kwargs):
return self._FC_or_packed_FC(self.net.FC, *args, **kwargs)
def PackedFC(self, *args, **kwargs):
return self._FC_or_packed_FC(self.net.PackedFC, *args, **kwargs)
def FC_Decomp(
self, blob_in, blob_out, dim_in, dim_out,
rank_approx=5, weight_init=None,
bias_init=None, **kwargs
):
"""FC_Decomp version
Here we assume that the rank of original input is bigger than 5.
"""
weight_init = weight_init if weight_init else ('XavierFill', {})
bias_init = bias_init if bias_init else ('ConstantFill', {})
blob_out = blob_out or self.net.NextName()
u = self.param_init_net.__getattr__(weight_init[0])(
[],
blob_out + '_u',
shape=[dim_out, rank_approx],
**weight_init[1]
)
v = self.param_init_net.__getattr__(weight_init[0])(
[],
blob_out + '_v',
shape=[dim_in, rank_approx],
**weight_init[1]
)
bias = self.param_init_net.__getattr__(bias_init[0])(
[],
blob_out + '_b',
shape=[dim_out, ],
**bias_init[1]
)
self.params.extend([u, v, bias])
return self.net.FC_Decomp([blob_in, u, v, bias], blob_out, **kwargs)
def FC_Prune(
self, blob_in, blob_out, dim_in, dim_out,
weight_init=None, bias_init=None, mask_init=None,
threshold=0.00001, need_compress_rate=False,
comp_lb=0.05,
**kwargs
):
"""FC_Prune version
Runnable so far. Great!:)
"""
weight_init = weight_init if weight_init else ('XavierFill', {})
bias_init = bias_init if bias_init else ('ConstantFill', {})
mask_init = mask_init if mask_init else ('ConstantFill', {})
blob_out = blob_out or self.net.NextName()
compress_rate = blob_out + '_compress_rate'
if self.init_params:
compress_lb = self.param_init_net.ConstantFill(
[],
blob_out + '_lb',
shape=[1],
value=comp_lb
)
weight = self.param_init_net.__getattr__(weight_init[0])(
[],
blob_out + '_w',
shape=[dim_out, dim_in],
**weight_init[1]
)
mask = self.param_init_net.ConstantFill(
[],
blob_out + '_m',
shape=[dim_out, dim_in],
value=1.0
)
ag_dw = self.param_init_net.__getattr__(mask_init[0])(
[],
blob_out + '_ag_dw',
shape=[dim_out, dim_in],
**mask_init[1]
)
bias = self.param_init_net.__getattr__(bias_init[0])(
[],
blob_out + '_b',
shape=[dim_out, ],
**bias_init[1]
)
mask_seq = self.param_init_net.__getattr__(mask_init[0])(
[],
blob_out + '_mask_seq',
shape=[dim_out, dim_in],
**mask_init[1]
)
thres = self.param_init_net.ConstantFill(
[],
blob_out + '_thres',
shape=[1],
value=threshold
)
else:
compress_lb = core.ScopedBlobReference(
blob_out + '_lb', self.param_init_net)
weight = core.ScopedBlobReference(
blob_out + '_w', self.param_init_net)
bias = core.ScopedBlobReference(
blob_out + '_b', self.param_init_net)
mask = core.ScopedBlobReference(
blob_out + '_m', self.param_init_net)
ag_dw = core.ScopedBlobReference(
blob_out + '_ag_dw', self.param_init_net)
mask_seq = core.ScopedBlobReference(
blob_out + '_mask_seq', self.param_init_net)
thres = core.ScopedBlobReference(
blob_out + '_thres', self.param_init_net)
self.params.extend([weight, bias])
if need_compress_rate:
return self.net.FC_Prune([blob_in, weight, mask,
bias, ag_dw, mask_seq,
thres, compress_lb],
[blob_out, compress_rate], **kwargs)
else:
return self.net.FC_Prune([blob_in, weight, mask,
bias, ag_dw, mask_seq,
thres, compress_lb],
blob_out, **kwargs)
def FC_Sparse(
self, blob_in, blob_out, w_csr, iw, jw, bias,
**kwargs
):
"""FC_Sparse: Only takes in alocated weights"""
if not (w_csr and iw and jw and bias):
print("Warning...")
self.params.extend([w_csr, iw, jw, bias])
return self.net.FC_Sparse([blob_in, w_csr, iw, jw, bias],
blob_out, **kwargs)
def LRN(self, blob_in, blob_out, **kwargs):
"""LRN"""
return self.net.LRN(
blob_in,
[blob_out, "_" + blob_out + "_scale"],
order=self.order,
**kwargs
)[0]
def Dropout(self, blob_in, blob_out, **kwargs):
"""Dropout"""
return self.net.Dropout(
blob_in, [blob_out, "_" + blob_out + "_mask"], **kwargs
)[0]
def MaxPool(self, blob_in, blob_out, **kwargs):
"""Max pooling"""
if self.use_cudnn:
kwargs['engine'] = 'CUDNN'
return self.net.MaxPool(blob_in, blob_out, order=self.order, **kwargs)
def AveragePool(self, blob_in, blob_out, **kwargs):
"""Average pooling"""
if self.use_cudnn:
kwargs['engine'] = 'CUDNN'
return self.net.AveragePool(
blob_in,
blob_out,
order=self.order,
**kwargs
)
def Concat(self, blobs_in, blob_out, **kwargs):
"""Depth Concat."""
return self.net.Concat(
blobs_in,
[blob_out, "_" + blob_out + "_concat_dims"],
order=self.order,
**kwargs
)[0]
def DepthConcat(self, blobs_in, blob_out, **kwargs):
"""The old depth concat function - we should move to use concat."""
print("DepthConcat is deprecated. use Concat instead.")
return self.Concat(blobs_in, blob_out, **kwargs)
def PRelu(self, blob_in, blob_out, num_channels=1, slope_init=None,
**kwargs):
"""PRelu"""
slope_init = (
slope_init if slope_init else ('ConstantFill', {'value': 0.25}))
if self.init_params:
slope = self.param_init_net.__getattr__(slope_init[0])(
[],
blob_out + '_slope',
shape=[num_channels],
**slope_init[1]
)
else:
slope = core.ScopedBlobReference(
blob_out + '_slope', self.param_init_net)
self.params.extend([slope])
return self.net.PRelu([blob_in, slope], [blob_out])
def Relu(self, blob_in, blob_out, **kwargs):
"""Relu."""
if self.use_cudnn:
kwargs['engine'] = 'CUDNN'
return self.net.Relu(blob_in, blob_out, order=self.order, **kwargs)
def Transpose(self, blob_in, blob_out, **kwargs):
"""Transpose."""
return self.net.Transpose(blob_in, blob_out, **kwargs)
def Sum(self, blob_in, blob_out, **kwargs):
"""Sum"""
return self.net.Sum(blob_in, blob_out, **kwargs)
def SpatialBN(self, blob_in, blob_out, dim_in, **kwargs):
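"""Spatial batch normalization. Creates scale and bias parameters plus
running mean / running inverse-variance blobs for the input (initialized
as described below) and returns the normalized output blob."""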
blob_out = blob_out or self.net.NextName()
# Input: input, scale, bias, est_mean, est_inv_var
# Output: output, running_mean, running_inv_var, saved_mean,
# saved_inv_var
# scale: initialize with ones
# bias: initialize with zeros
# est mean: zero
# est var: ones
def init_blob(value, suffix):
return self.param_init_net.ConstantFill(
[], blob_out + "_" + suffix, shape=[dim_in], value=value)
scale, bias = init_blob(1.0, "s"), init_blob(0.0, "b")
running_mean = init_blob(0.0, "rm")
running_inv_var = init_blob(1.0, "riv")
self.params.extend([scale, bias])
self.computed_params.extend([running_mean, running_inv_var])
self.weights.append(scale)
self.biases.append(bias)
blob_outs = [blob_out, running_mean, running_inv_var,
blob_out + "_sm", blob_out + "_siv"]
if 'is_test' in kwargs and kwargs['is_test']:
blob_outputs = self.net.SpatialBN(
[blob_in, scale, bias, blob_outs[1], blob_outs[2]], [blob_out],
order=self.order, **kwargs)
return blob_outputs
else:
blob_outputs = self.net.SpatialBN(
[blob_in, scale, bias, blob_outs[1], blob_outs[2]], blob_outs,
order=self.order, **kwargs)
# Return the output
return blob_outputs[0]
def Iter(self, blob_out, **kwargs):
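"""Creates an int64 iteration counter blob, pinned to CPU and initialized
to zero, and adds an Iter op that increments it on every net run."""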
if 'device_option' in kwargs:
del kwargs['device_option']
self.param_init_net.ConstantFill(
[], blob_out, shape=[1], value=0, dtype=core.DataType.INT64,
device_option=core.DeviceOption(caffe2_pb2.CPU, 0),
**kwargs)
return self.net.Iter(blob_out, blob_out, **kwargs)
@property
def XavierInit(self):
return ('XavierFill', {})
def ConstantInit(self, value):
return ('ConstantFill', dict(value=value))
@property
def MSRAInit(self):
return ('MSRAFill', {})
@property
def ZeroInit(self):
return ('ConstantFill', {})
def AddWeightDecay(self, weight_decay):
"""Adds a decay to weights in the model.
This is a form of L2 regularization.
Args:
weight_decay: strength of the regularization
"""
if weight_decay <= 0.0:
return
wd = self.param_init_net.ConstantFill([], 'wd', shape=[1],
value=weight_decay)
ONE = self.param_init_net.ConstantFill([], "ONE", shape=[1], value=1.0)
for param in self.GetWeights():
# Equivalent to: grad += wd * param
grad = self.param_to_grad[param]
self.net.WeightedSum(
[grad, ONE, param, wd],
grad,
)
@property
def CPU(self):
device_option = caffe2_pb2.DeviceOption()
device_option.device_type = caffe2_pb2.CPU
return device_option
# Note: unlike CPU above, GPU is a plain method rather than a property,
# since a property cannot accept a gpu_id argument.
def GPU(self, gpu_id=0):
device_option = caffe2_pb2.DeviceOption()
device_option.device_type = caffe2_pb2.CUDA
device_option.cuda_gpu_id = gpu_id
return device_option
def LSTM(self, input_blob, seq_lengths, initial_states, dim_in, dim_out,
scope=None):
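"""Adds an LSTM layer built directly on RecurrentNetworkOp, constructing
the forward and backward step nets explicitly.

Usage sketch (illustrative only; the blob names, shapes and the way the
initial states are created are assumptions, not requirements of this
method; it assumes "input_sequence" and "seq_lengths" blobs already exist):

model = CNNModelHelper(name="lstm_example")
hidden_init = model.param_init_net.ConstantFill(
[], "hidden_init", shape=[1, 1, 40], value=0.0)
cell_init = model.param_init_net.ConstantFill(
[], "cell_init", shape=[1, 1, 40], value=0.0)
output, last_hidden, last_cell = model.LSTM(
"input_sequence", "seq_lengths", (hidden_init, cell_init),
dim_in=10, dim_out=40, scope="lstm1")
"""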
def s(name):
# We have to manually scope due to our internal/external blob
# relationships.
scope_name = scope or str(input_blob)
return "{}/{}".format(str(scope_name), str(name))
(hidden_input_blob, cell_input_blob) = initial_states
input_blob = self.FC(input_blob, s("i2h"),
dim_in=dim_in, dim_out=4 * dim_out, axis=2)
step_net = CNNModelHelper(name="LSTM")
step_net.Proto().external_input.extend([
str(seq_lengths),
"input_t",
"timestep",
"hidden_t_prev",
"cell_t_prev",
s("gates_t_w"),
s("gates_t_b"),
])
step_net.Proto().type = "simple"
step_net.Proto().external_output.extend(
["hidden_t", "cell_t", s("gates_t")])
step_net.FC("hidden_t_prev", s("gates_t"),
dim_in=dim_out, dim_out=4 * dim_out, axis=2)
step_net.net.Sum([s("gates_t"), "input_t"], [s("gates_t")])
step_net.net.LSTMUnit(
["cell_t_prev", s("gates_t"), str(seq_lengths), "timestep"],
["hidden_t", "cell_t"])
links = [
("hidden_t_prev", s("hidden"), 0),
("hidden_t", s("hidden"), 1),
("cell_t_prev", s("cell"), 0),
("cell_t", s("cell"), 1),
(s("gates_t"), s("gates"), 0),
("input_t", str(input_blob), 0),
]
link_internal, link_external, link_offset = zip(*links)
# Set up the backward links
backward_ops, backward_mapping = core.GradientRegistry.GetBackwardPass(
step_net.Proto().op,
{"hidden_t": "hidden_t_grad", "cell_t": "cell_t_grad"})
backward_mapping = {str(k): str(v) for k, v
in backward_mapping.items()}
backward_step_net = core.Net("LSTMBackward")
del backward_step_net.Proto().op[:]
backward_step_net.Proto().op.extend(backward_ops)
backward_links = [
("hidden_t_prev_grad", s("hidden_grad"), 0),
("hidden_t_grad", s("hidden_grad"), 1),
("cell_t_prev_grad", s("cell_grad"), 0),
("cell_t_grad", s("cell_grad"), 1),
(s("gates_t_grad"), s("gates_grad"), 0),
]
backward_link_internal, backward_link_external, \
backward_link_offset = zip(*backward_links)
backward_step_net.Proto().external_input.extend(
["hidden_t_grad", "cell_t_grad"])
backward_step_net.Proto().external_input.extend(
step_net.Proto().external_input)
backward_step_net.Proto().external_input.extend(
step_net.Proto().external_output)
output, _, _, hidden_state, cell_state = self.net.RecurrentNetwork(
[input_blob, seq_lengths,
s("gates_t_w"), s("gates_t_b"),
hidden_input_blob, cell_input_blob],
[s("output"), s("hidden"), s("cell"),
s("hidden_output"), s("cell_output")],
param=[str(p) for p in step_net.params],
param_gradient=[backward_mapping[str(p)] for p in step_net.params],
alias_src=[s("hidden"), s("hidden"), s("cell")],
alias_dst=[s("output"), s("hidden_output"), s("cell_output")],
alias_offset=[1, -1, -1],
recurrent_states=[s("hidden"), s("cell")],
recurrent_inputs=[str(hidden_input_blob), str(cell_input_blob)],
recurrent_sizes=[dim_out, dim_out],
link_internal=link_internal,
link_external=link_external,
link_offset=link_offset,
backward_link_internal=backward_link_internal,
backward_link_external=backward_link_external,
backward_link_offset=backward_link_offset,
backward_alias_src=[s("gates_grad")],
backward_alias_dst=[str(input_blob) + "_grad"],
backward_alias_offset=[0],
scratch=[s("gates")],
backward_scratch=[s("gates_grad")],
scratch_sizes=[4 * dim_out],
step_net=str(step_net.Proto()),
backward_step_net=str(backward_step_net.Proto()),
timestep="timestep")
self.param_init_net.Proto().op.extend(
step_net.param_init_net.Proto().op)
self.params += step_net.params
for p in step_net.params:
if str(p) in backward_mapping:
self.param_to_grad[p] = backward_mapping[str(p)]
self.weights += step_net.weights
self.biases += step_net.biases
return output, hidden_state, cell_state