pytorch/caffe2/python/recurrent.py
Aapo Kyrola 1e5140aa76 option to recompute blobs backward pass with massive memory savings
Summary:
This diff adds an option to recurrent_net to define some cell blobs to be recomputed on the backward step, so that they don't need to be stored in the step workspace. This is done by modifying the backward step to automatically include all operators that are needed to produce the output that is to be recomputed, and by storing those blobs in a shared workspace. To enable the shared workspace, I had to modify the stepworkspaces blob to also store a forward shared workspace. Making it a class field won't work since the lifecycle of the blob does not match the lifecycle of the operator.

For basic LSTM, the performance hit is quite modest (about 15% with one setting, but your mileage might vary). For Attention models, I am sure this is beneficial as computing the attention blobs is not expensive.

For basic LSTM, the memory saving is wonderful: each forward workspace only has 4 bytes (for timestep).

I also modified the neural_mt LSTM Cells, but there is no test available, so I am not 100% sure I did it correctly. Please have a look.

Added options to LSTM, MILSTM and LSTMAttention to enable memory mode.

Reviewed By: urikz

Differential Revision: D4853890

fbshipit-source-id: d8d0e0e75a5330d174fbfa39b96d8e4e8c446baa
2017-04-11 13:03:48 -07:00

873 lines
32 KiB
Python

## @package recurrent
# Module caffe2.python.recurrent
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import random
import numpy as np
from caffe2.python import core, workspace
from caffe2.python.scope import CurrentNameScope
from caffe2.python.cnn import CNNModelHelper
from caffe2.python.attention import (
apply_regular_attention,
apply_recurrent_attention,
AttentionType,
)
def recurrent_net(
        net, cell_net, inputs, initial_cell_inputs,
        links, timestep=None, scope=None, outputs_with_grads=(0,),
        recompute_blobs_on_backward=None,
):
    '''
    Add a RecurrentNetwork operator to `net` that runs `cell_net` once per
    timestep, and build the matching backward step net.

    net: the main net operator should be added to

    cell_net: cell_net which is executed in a recurrent fashion

    inputs: sequences to be fed into the recurrent net. Currently only one
    input is supported. It has to be in a format T x N x (D1...Dk) where T is
    the length of the sequence, N is a batch size and (D1...Dk) are the rest
    of the dimensions. Each entry is a pair
        (cell_net_input_name, external_blob_with_data)

    initial_cell_inputs: inputs of the cell_net for the 0 timestamp.
    Format for each input is:
        (cell_net_input_name, external_blob_with_data)

    links: a dictionary from cell_net input names in moment t+1 and
    output names of moment t. Currently we assume that each output becomes
    an input for the next timestep.

    timestep: name of the timestep blob to be used. If not provided "timestep"
    is used.

    scope: Internal blobs are going to be scoped in a format
    <scope_name>/<blob_name>
    If not provided we generate a scope name automatically

    outputs_with_grads : position indices of output blobs which will receive
    error gradient (from outside recurrent network) during backpropagation

    recompute_blobs_on_backward: specify a list of blobs that will be
                 recomputed for backward pass, and thus need not to be
                 stored for each forward timestep.

    Returns the outputs of the RecurrentNetwork operator without the trailing
    step-workspaces blob, i.e. for every entry of initial_cell_inputs (in
    order) the blobs "<cell_output>_all" and "<cell_output>_last".
    '''
    assert len(inputs) == 1, "Only one input blob is supported so far"

    # Validate scoping
    current_scope = CurrentNameScope()
    for einp in cell_net.Proto().external_input:
        assert einp.startswith(current_scope), (
            "\nCell net external inputs are not properly scoped, use\n"
            "AddScopedExternalInputs() when creating them\n"
        )

    input_blobs = [str(i[0]) for i in inputs]
    initial_input_blobs = [str(x[1]) for x in initial_cell_inputs]
    op_name = net.NextName('recurrent')

    def s(name):
        # We have to manually scope due to our internal/external blob
        # relationships.
        scope_name = op_name if scope is None else scope
        return "{}/{}".format(str(scope_name), str(name))

    # determine inputs that are considered to be references
    # it is those that are not referred to in inputs or initial_cell_inputs
    # Built as a real list: map() is a one-shot iterator on Python 3 and
    # supports neither += nor append.
    known_inputs = input_blobs + initial_input_blobs
    known_inputs += [str(x[0]) for x in initial_cell_inputs]
    if timestep is not None:
        known_inputs.append(str(timestep))
    references = [
        core.BlobReference(b) for b in cell_net.Proto().external_input
        if b not in known_inputs]

    inner_outputs = list(cell_net.Proto().external_output)
    # These gradients are expected to be available during the backward pass
    inner_outputs_map = {o: o + '_grad' for o in inner_outputs}

    # compute the backward pass of the cell net
    backward_ops, backward_mapping = core.GradientRegistry.GetBackwardPass(
        cell_net.Proto().op, inner_outputs_map)
    backward_mapping = {str(k): v for k, v in backward_mapping.items()}
    backward_cell_net = core.Net("RecurrentBackwardStep")
    del backward_cell_net.Proto().op[:]

    if recompute_blobs_on_backward is not None:
        # Insert operators to re-compute the specified blobs.
        # They are added in the same order as for the forward pass, thus
        # the order is correct.
        recompute_blobs_on_backward = set(
            str(b) for b in recompute_blobs_on_backward
        )
        for op in cell_net.Proto().op:
            if not recompute_blobs_on_backward.isdisjoint(set(op.output)):
                backward_cell_net.Proto().op.extend([op])
                # An op is recomputed as a whole, so every one of its
                # outputs has to be listed for recomputation.
                assert set(op.output).issubset(recompute_blobs_on_backward), \
                    'Outputs {} are output by op but not recomputed: {}'.format(
                        set(op.output) - recompute_blobs_on_backward,
                        op
                    )
    else:
        recompute_blobs_on_backward = set()

    backward_cell_net.Proto().op.extend(backward_ops)

    # compute blobs used but not defined in the backward pass
    backward_ssa, backward_blob_versions = core.get_ssa(
        backward_cell_net.Proto())
    undefined = core.get_undefined_blobs(backward_ssa)

    # also add to the output list the intermediate outputs of fwd_step that
    # are used by backward.
    _, blob_versions = core.get_ssa(cell_net.Proto())
    scratches = [
        blob for (blob, ver) in blob_versions.items()
        if ver > 0 and
        blob in undefined and
        blob not in cell_net.Proto().external_output]
    backward_cell_net.Proto().external_input.extend(scratches)

    all_inputs = [i[1] for i in inputs] + [
        x[1] for x in initial_cell_inputs] + references
    all_outputs = []

    # NOTE(review): the forward step net is tagged 'prof_dag' while the
    # backward step net is 'simple'. 'prof_dag' looks like a profiling net
    # type -- confirm this asymmetry is intentional.
    cell_net.Proto().type = 'prof_dag'
    backward_cell_net.Proto().type = 'simple'

    # Internal arguments used by RecurrentNetwork operator

    # Links are in the format blob_name, recurrent_states, offset.
    # In the moment t we know that corresponding data block is at
    # t + offset position in the recurrent_states tensor
    forward_links = []
    backward_links = []

    # Aliases are used to expose outputs to external world
    # Format (internal_blob, external_blob, offset)
    # Negative offset stands for going from the end,
    # positive - from the beginning
    aliases = []

    # States held inputs to the cell net
    recurrent_states = []

    for cell_input, _ in initial_cell_inputs:
        cell_input = str(cell_input)
        # Recurrent_states is going to be (T + 1) x ...
        # It stores all inputs and outputs of the cell net over time.
        # Or their gradients in the case of the backward pass.
        state = s(cell_input + "_states")
        states_grad = state + "_grad"
        cell_output = links[str(cell_input)]
        forward_links.append((cell_input, state, 0))
        forward_links.append((cell_output, state, 1))
        backward_links.append((cell_output + "_grad", states_grad, 1))

        backward_cell_net.Proto().external_input.append(
            str(cell_output) + "_grad")
        aliases.append((state, cell_output + "_all", 1))
        aliases.append((state, cell_output + "_last", -1))
        all_outputs.extend([cell_output + "_all", cell_output + "_last"])

        recurrent_states.append(state)

        recurrent_input_grad = cell_input + "_grad"
        if not backward_blob_versions.get(recurrent_input_grad, 0):
            # If nobody writes to this recurrent input gradient, we need
            # to make sure it gets to the states grad blob after all.
            # We do this by using backward_links which triggers an alias
            # This logic is being used for example in a SumOp case
            backward_links.append(
                (backward_mapping[cell_input], states_grad, 0))
        else:
            backward_links.append((cell_input + "_grad", states_grad, 0))

    for reference in references:
        # Similar to above, in a case of a SumOp we need to write our parameter
        # gradient to an external blob. In this case we can be sure that
        # reference + "_grad" is a correct parameter name as we know how
        # RecurrentNetworkOp gradient schema looks like.
        reference_grad = reference + "_grad"
        if (reference in backward_mapping and
                reference_grad != str(backward_mapping[reference])):
            # We can use an Alias because after each timestep
            # RNN op adds value from reference_grad into an _acc blob
            # which accumulates gradients for corresponding parameter across
            # timesteps. Then in the end of RNN op these two are being
            # swapped and reference_grad blob becomes a real blob instead of
            # being an alias
            backward_cell_net.Alias(
                backward_mapping[reference], reference_grad)

    for input_t, input_blob in inputs:
        forward_links.append((str(input_t), str(input_blob), 0))
        backward_links.append((
            backward_mapping[str(input_t)], str(input_blob) + "_grad", 0
        ))
    backward_cell_net.Proto().external_input.extend(
        cell_net.Proto().external_input)
    backward_cell_net.Proto().external_input.extend(
        cell_net.Proto().external_output)

    def unpack_triple(x):
        # Transpose a list of 3-tuples into three parallel sequences.
        if x:
            a, b, c = zip(*x)
            return a, b, c
        return [], [], []

    # Splitting to separate lists so we can pass them to c++
    # where we assemble them back
    link_internal, link_external, link_offset = unpack_triple(forward_links)
    backward_link_internal, backward_link_external, backward_link_offset = \
        unpack_triple(backward_links)
    alias_src, alias_dst, alias_offset = unpack_triple(aliases)

    params = [x for x in references if x in backward_mapping]

    # Operator arguments are materialized as lists rather than map() objects,
    # so they can be iterated more than once regardless of Python version.
    results = net.RecurrentNetwork(
        all_inputs,
        all_outputs + [s("step_workspaces")],
        param=[all_inputs.index(p) for p in params],
        alias_src=alias_src,
        alias_dst=[str(a) for a in alias_dst],
        alias_offset=alias_offset,
        recurrent_states=recurrent_states,
        initial_recurrent_state_ids=[
            all_inputs.index(b) for b in initial_input_blobs
        ],
        link_internal=[str(l) for l in link_internal],
        link_external=[str(l) for l in link_external],
        link_offset=link_offset,
        backward_link_internal=[str(l) for l in backward_link_internal],
        backward_link_external=[str(l) for l in backward_link_external],
        backward_link_offset=backward_link_offset,
        step_net=str(cell_net.Proto()),
        backward_step_net=str(backward_cell_net.Proto()),
        timestep="timestep" if timestep is None else str(timestep),
        outputs_with_grads=outputs_with_grads,
        # Sorted so that the emitted argument does not depend on set order.
        recompute_blobs_on_backward=sorted(recompute_blobs_on_backward)
    )
    # The last output is a list of step workspaces,
    # which is only needed internally for gradient propagation
    return results[:-1]
def LSTM(model, input_blob, seq_lengths, initial_states, dim_in, dim_out,
         scope, outputs_with_grads=(0,), return_params=False,
         memory_optimization=False):
    '''
    Build a standard LSTM on top of recurrent_net and add it to `model`.

    model: CNNModelHelper object the new operators are added to

    input_blob: the input sequence in a format T x N x D
    where T is sequence size, N - batch size and D - input dimension

    seq_lengths: blob containing sequence lengths which is passed to the
    LSTMUnit operator

    initial_states: a tuple of (hidden_input_blob, cell_input_blob)
    which are the inputs to the cell net on the first iteration

    dim_in: input dimension

    dim_out: output dimension

    scope: prefix used for the names of the blobs created here

    outputs_with_grads : position indices of output blobs which will receive
    external error gradient during backpropagation

    return_params: if True, a dictionary with the names of the LSTM parameter
    blobs is appended to the returned tuple

    memory_optimization: if enabled, the gates blob is recomputed on the
    backward step instead of being stored for each forward timestep.
    Saves memory at the cost of computation.

    Returns (output, last_output, all_states, last_state), followed by the
    parameter dictionary when return_params is True.
    '''
    def scoped(name):
        # Blob names are scoped by hand because of the internal/external
        # blob relationships of the recurrent net.
        return "{}/{}".format(str(scope), str(name))

    gate_dim = 4 * dim_out

    # Initial bulk fully-connected layer applied to the input sequence.
    input_blob = model.FC(
        input_blob, scoped('i2h'), dim_in=dim_in, dim_out=gate_dim, axis=2)

    # The step net.
    step_model = CNNModelHelper(name='lstm_cell', param_model=model)
    step_inputs = step_model.net.AddScopedExternalInputs(
        'input_t', 'timestep', 'cell_t_prev', 'hidden_t_prev')
    input_t, timestep, cell_t_prev, hidden_t_prev = step_inputs

    gates_t = step_model.FC(
        hidden_t_prev, scoped('gates_t'), dim_in=dim_out,
        dim_out=gate_dim, axis=2)
    # Add the input projection into the gates blob in place.
    step_model.net.Sum([gates_t, input_t], gates_t)
    unit_inputs = [hidden_t_prev, cell_t_prev, gates_t, seq_lengths, timestep]
    hidden_t, cell_t = step_model.net.LSTMUnit(
        unit_inputs,
        [scoped('hidden_t'), scoped('cell_t')],
    )
    step_model.net.AddExternalOutputs(cell_t, hidden_t)

    # The recurrent network.
    hidden_input_blob, cell_input_blob = initial_states
    if memory_optimization:
        blobs_to_recompute = [gates_t]
    else:
        blobs_to_recompute = None
    output, last_output, all_states, last_state = recurrent_net(
        net=model.net,
        cell_net=step_model.net,
        inputs=[(input_t, input_blob)],
        initial_cell_inputs=[
            (hidden_t_prev, hidden_input_blob),
            (cell_t_prev, cell_input_blob),
        ],
        links={
            hidden_t_prev: hidden_t,
            cell_t_prev: cell_t,
        },
        timestep=timestep,
        scope=scope,
        outputs_with_grads=outputs_with_grads,
        recompute_blobs_on_backward=blobs_to_recompute
    )

    outputs = (output, last_output, all_states, last_state)
    if not return_params:
        return outputs

    params = {
        'input': {
            'weights': input_blob + "_w",
            'biases': input_blob + '_b',
        },
        'recurrent': {
            'weights': gates_t + "_w",
            'biases': gates_t + '_b',
        },
    }
    return outputs + (params,)
def GetLSTMParamNames():
    '''
    Return the names of the per-gate LSTM parameters as a dictionary with
    two lists: 'weights' (names ending in "_w") and 'biases' (names ending
    in "_b"), both in the order input gate, forget gate, output gate, cell.
    '''
    components = ("input_gate", "forget_gate", "output_gate", "cell")
    return {
        'weights': [name + "_w" for name in components],
        'biases': [name + "_b" for name in components],
    }
def InitFromLSTMParams(lstm_pblobs, param_values):
    '''
    Set the parameters of LSTM based on predefined values

    lstm_pblobs: dictionary keyed by input type; each value is a dictionary
    with the blob names under 'weights' and 'biases'.

    param_values: dictionary keyed by input type; each value maps every name
    listed by GetLSTMParamNames() to an array (anything with .flatten()).

    For each input type in param_values the per-gate arrays are flattened and
    concatenated in the GetLSTMParamNames() order, reshaped to the shape of
    the blob currently in the workspace, cast to float32 and fed back.
    '''
    param_names = GetLSTMParamNames()
    weight_params = param_names['weights']
    bias_params = param_names['biases']

    def pack(values, names):
        # One concatenate instead of growing an array with np.append, which
        # copies everything accumulated so far on every call.
        return np.concatenate([values[name].flatten() for name in names])

    for input_type, values in param_values.items():
        wmat = pack(values, weight_params)
        bm = pack(values, bias_params)

        weights_blob = lstm_pblobs[input_type]['weights']
        bias_blob = lstm_pblobs[input_type]['biases']
        # The current blobs are fetched only to learn the target shapes.
        cur_weight = workspace.FetchBlob(weights_blob)
        cur_biases = workspace.FetchBlob(bias_blob)

        workspace.FeedBlob(
            weights_blob,
            wmat.reshape(cur_weight.shape).astype(np.float32))
        workspace.FeedBlob(
            bias_blob,
            bm.reshape(cur_biases.shape).astype(np.float32))
def cudnn_LSTM(model, input_blob, initial_states, dim_in, dim_out,
               scope, recurrent_params=None, input_params=None,
               num_layers=1, return_params=False):
    '''
    CuDNN version of LSTM for GPUs.

    input_blob          Blob containing the input. Will need to be available
                        when param_init_net is run, because the sequence
                        lengths and batch sizes will be inferred from the
                        size of this blob.
    initial_states      tuple of (hidden_init, cell_init) blobs
    dim_in              input dimensions
    dim_out             output/hidden dimension
    scope               namescope to apply
    recurrent_params    dict of blobs containing values for recurrent
                        gate weights, biases (if None, use random init
                        values). See GetLSTMParamNames() for format.
    input_params        dict of blobs containing values for input
                        gate weights, biases (if None, use random init
                        values). See GetLSTMParamNames() for format.
    num_layers          number of LSTM layers
    return_params       if True, the returned tuple gets a fourth element
                        (param_extract_net, param_mapping) where
                        param_extract_net is a net that when run, will
                        populate the blobs specified in param_mapping with
                        the current gate weights and biases
                        (input/recurrent). Useful for assigning the values
                        back to non-cuDNN LSTM.

    Returns (output, hidden_output, cell_output), plus the parameter
    extraction pair described above when return_params is True.
    '''
    with core.NameScope(scope):
        param_names = GetLSTMParamNames()
        weight_params = param_names['weights']
        bias_params = param_names['biases']

        input_weight_size = dim_out * dim_in
        recurrent_weight_size = dim_out * dim_out
        input_bias_size = dim_out
        recurrent_bias_size = dim_out

        def init(layer, pname, input_type):
            # Create a uniformly-initialized blob for one gate parameter of
            # one layer, sized by parameter kind and input type.
            if pname in weight_params:
                sz = input_weight_size if input_type == 'input' \
                    else recurrent_weight_size
            elif pname in bias_params:
                sz = input_bias_size if input_type == 'input' \
                    else recurrent_bias_size
            else:
                assert False, "unknown parameter type {}".format(pname)
            return model.param_init_net.UniformFill(
                [],
                "lstm_init_{}_{}_{}".format(input_type, pname, layer),
                shape=[sz])

        # Multiply by 4 since we have 4 gates per LSTM unit
        total_sz = 4 * num_layers * (
            input_weight_size + recurrent_weight_size + input_bias_size +
            recurrent_bias_size
        )

        weights = model.param_init_net.UniformFill(
            [], "lstm_weight", shape=[total_sz])

        model.params.append(weights)
        model.weights.append(weights)

        lstm_args = {
            'hidden_size': dim_out,
            'rnn_mode': 'lstm',
            'bidirectional': 0,  # TODO
            'dropout': 1.0,  # TODO
            'input_mode': 'linear',  # TODO
            'num_layers': num_layers,
            'engine': 'CUDNN'
        }

        param_extract_net = core.Net("lstm_param_extractor")
        param_extract_net.AddExternalInputs([input_blob, weights])
        param_extract_mapping = {}

        # Populate the weights-blob from blobs containing parameters for
        # the individual components of the LSTM, such as forget/input gate
        # weights and biases. Also, create a special param_extract_net that
        # can be used to grab those individual params from the black-box
        # weights blob. These results can be then fed to InitFromLSTMParams()
        for input_type in ['input', 'recurrent']:
            extracted = param_extract_mapping.setdefault(input_type, {})
            p = recurrent_params if input_type == 'recurrent' else input_params
            if p is None:
                p = {}
            for pname in weight_params + bias_params:
                for j in range(num_layers):
                    # init() adds an operator, so it must only be called
                    # when no value was supplied by the caller.
                    values = p[pname] if pname in p \
                        else init(j, pname, input_type)
                    model.param_init_net.RecurrentParamSet(
                        [input_blob, weights, values],
                        weights,
                        layer=j,
                        input_type=input_type,
                        param_type=pname,
                        **lstm_args
                    )

                    b = param_extract_net.RecurrentParamGet(
                        [input_blob, weights],
                        ["lstm_{}_{}_{}".format(input_type, pname, j)],
                        layer=j,
                        input_type=input_type,
                        param_type=pname,
                        **lstm_args
                    )
                    extracted.setdefault(pname, {})[j] = b

        (hidden_input_blob, cell_input_blob) = initial_states
        # Pass the initial hidden state and the initial cell state as two
        # distinct inputs. Previously the cell state was passed in both
        # positions and hidden_input_blob was silently ignored.
        output, hidden_output, cell_output, rnn_scratch, dropout_states = \
            model.net.Recurrent(
                [input_blob, hidden_input_blob, cell_input_blob, weights],
                ["lstm_output", "lstm_hidden_output", "lstm_cell_output",
                 "lstm_rnn_scratch", "lstm_dropout_states"],
                seed=random.randint(0, 100000),  # TODO: dropout seed
                **lstm_args
            )
        model.net.AddExternalOutputs(
            hidden_output, cell_output, rnn_scratch, dropout_states)

        if return_params:
            param_extract = param_extract_net, param_extract_mapping
            return output, hidden_output, cell_output, param_extract
        else:
            return output, hidden_output, cell_output
def LSTMWithAttention(
        model,
        decoder_inputs,
        decoder_input_lengths,
        initial_decoder_hidden_state,
        initial_decoder_cell_state,
        initial_attention_weighted_encoder_context,
        encoder_output_dim,
        encoder_outputs,
        decoder_input_dim,
        decoder_state_dim,
        scope,
        attention_type=AttentionType.Regular,
        outputs_with_grads=(0, 4),
        weighted_encoder_outputs=None,
        lstm_memory_optimization=False,
        attention_memory_optimization=False,
):
    '''
    Adds a LSTM with attention mechanism to a model.

    The implementation is based on https://arxiv.org/abs/1409.0473, with
    a small difference in the order in which the new attention context and
    the new hidden state are computed, similarly to
    https://arxiv.org/abs/1508.04025.

    The model uses encoder-decoder naming conventions, where the decoder is
    the sequence the op is iterating over, while the attention context is
    computed over the encoder.

    model: CNNModelHelper object the new operators are added to

    decoder_inputs: the input sequence in a format T x N x D
    where T is sequence size, N - batch size and D - input dimension

    decoder_input_lengths: blob containing sequence lengths
    which is passed to the LSTMUnit operator

    initial_decoder_hidden_state: initial hidden state of LSTM

    initial_decoder_cell_state: initial cell state of LSTM

    initial_attention_weighted_encoder_context: initial attention context

    encoder_output_dim: dimension of encoder outputs

    encoder_outputs: the sequence, on which we compute the attention context
    at every iteration

    decoder_input_dim: input dimension (last dimension on decoder_inputs)

    decoder_state_dim: size of hidden states of LSTM

    scope: prefix used for the names of the blobs created here

    attention_type: One of: AttentionType.Regular, AttentionType.Recurrent.
    Determines which type of attention mechanism to use.

    outputs_with_grads : position indices of output blobs which will receive
    external error gradient during backpropagation

    weighted_encoder_outputs: encoder outputs to be used to compute attention
    weights. In the basic case it's just a linear transformation of the
    encoder outputs (that is the default, when weighted_encoder_outputs is
    None). However, it can be something more complicated - like a separate
    encoder network (for example, in case of convolutional encoder)

    lstm_memory_optimization: recompute LSTM activations on backward pass, so
                 we don't need to store their values in forward passes

    attention_memory_optimization: recompute attention for backward pass

    Returns whatever recurrent_net returns for the three recurrent states
    (hidden state, cell state, attention context), in that order.
    '''
    def scoped(name):
        # Blob names are scoped by hand because of the internal/external
        # blob relationships of the recurrent net.
        return "{}/{}".format(str(scope), str(name))

    gate_dim = 4 * decoder_state_dim

    decoder_inputs = model.FC(
        decoder_inputs,
        scoped('i2h'),
        dim_in=decoder_input_dim,
        dim_out=gate_dim,
        axis=2,
    )

    # [batch_size, encoder_output_dim, encoder_length]
    encoder_outputs_transposed = model.Transpose(
        encoder_outputs,
        scoped('encoder_outputs_transposed'),
        axes=[1, 2, 0],
    )
    if weighted_encoder_outputs is None:
        weighted_encoder_outputs = model.FC(
            encoder_outputs,
            scoped('weighted_encoder_outputs'),
            dim_in=encoder_output_dim,
            dim_out=encoder_output_dim,
            axis=2,
        )

    step_model = CNNModelHelper(
        name='lstm_with_attention_cell',
        param_model=model,
    )
    step_inputs = step_model.net.AddScopedExternalInputs(
        'input_t',
        'timestep',
        'cell_t_prev',
        'hidden_t_prev',
        'attention_weighted_encoder_context_t_prev',
    )
    (
        input_t,
        timestep,
        cell_t_prev,
        hidden_t_prev,
        attention_weighted_encoder_context_t_prev,
    ) = step_inputs
    step_model.net.AddExternalInputs(
        encoder_outputs_transposed,
        weighted_encoder_outputs
    )

    concat_outputs = [
        scoped('gates_concatenated_input_t'),
        scoped('_gates_concatenated_input_t_concat_dims'),
    ]
    gates_concatenated_input_t, _ = step_model.net.Concat(
        [hidden_t_prev, attention_weighted_encoder_context_t_prev],
        concat_outputs,
        axis=2,
    )
    gates_t = step_model.FC(
        gates_concatenated_input_t,
        scoped('gates_t'),
        dim_in=decoder_state_dim + encoder_output_dim,
        dim_out=gate_dim,
        axis=2,
    )
    # Add the input projection into the gates blob in place.
    step_model.net.Sum([gates_t, input_t], gates_t)

    hidden_t_intermediate, cell_t = step_model.net.LSTMUnit(
        [hidden_t_prev, cell_t_prev, gates_t, decoder_input_lengths, timestep],
        ['hidden_t_intermediate', scoped('cell_t')],
    )

    # Both attention flavours share these arguments; the recurrent one also
    # consumes the previous attention context.
    attention_kwargs = dict(
        model=step_model,
        encoder_output_dim=encoder_output_dim,
        encoder_outputs_transposed=encoder_outputs_transposed,
        weighted_encoder_outputs=weighted_encoder_outputs,
        decoder_hidden_state_t=hidden_t_intermediate,
        decoder_hidden_state_dim=decoder_state_dim,
        scope=scope,
    )
    if attention_type == AttentionType.Recurrent:
        apply_attention = apply_recurrent_attention
        attention_kwargs.update(
            attention_weighted_encoder_context_t_prev=(
                attention_weighted_encoder_context_t_prev
            ),
        )
    else:
        apply_attention = apply_regular_attention
    attention_weighted_encoder_context_t, _, attention_blobs = \
        apply_attention(**attention_kwargs)

    hidden_t = step_model.Copy(hidden_t_intermediate, scoped('hidden_t'))
    step_model.net.AddExternalOutputs(
        cell_t,
        hidden_t,
        attention_weighted_encoder_context_t,
    )

    # Blobs to recompute on the backward pass instead of storing them.
    recompute_blobs = (
        list(attention_blobs) if attention_memory_optimization else []
    )
    if lstm_memory_optimization:
        recompute_blobs.append(gates_t)

    initial_cell_inputs = [
        (hidden_t_prev, initial_decoder_hidden_state),
        (cell_t_prev, initial_decoder_cell_state),
        (
            attention_weighted_encoder_context_t_prev,
            initial_attention_weighted_encoder_context,
        ),
    ]
    links = {
        hidden_t_prev: hidden_t,
        cell_t_prev: cell_t,
        attention_weighted_encoder_context_t_prev: (
            attention_weighted_encoder_context_t
        ),
    }
    return recurrent_net(
        net=model.net,
        cell_net=step_model.net,
        inputs=[
            (input_t, decoder_inputs),
        ],
        initial_cell_inputs=initial_cell_inputs,
        links=links,
        timestep=timestep,
        scope=scope,
        outputs_with_grads=outputs_with_grads,
        recompute_blobs_on_backward=recompute_blobs,
    )
def MILSTM(model, input_blob, seq_lengths, initial_states, dim_in, dim_out,
           scope, outputs_with_grads=(0,), memory_optimization=False):
    '''
    Adds MI flavor of standard LSTM recurrent network operator to a model.
    See https://arxiv.org/pdf/1606.06630.pdf

    model: CNNModelHelper object the new operators are added to

    input_blob: the input sequence in a format T x N x D
    where T is sequence size, N - batch size and D - input dimension

    seq_lengths: blob containing sequence lengths which is passed to the
    LSTMUnit operator

    initial_states: a tuple of (hidden_input_blob, cell_input_blob)
    which are the inputs to the cell net on the first iteration

    dim_in: input dimension

    dim_out: output dimension

    scope: prefix used for the names of the blobs created here

    outputs_with_grads : position indices of output blobs which will receive
    external error gradient during backpropagation

    memory_optimization: if enabled, the gates blob is recomputed on the
    backward step instead of being stored for each forward timestep.
    Saves memory at the cost of computation.

    Returns (output, last_output, all_states, last_state).
    '''
    def scoped(name):
        # Blob names are scoped by hand because of the internal/external
        # blob relationships of the recurrent net.
        return "{}/{}".format(str(scope), str(name))

    gate_dim = 4 * dim_out

    # Initial bulk fully-connected layer applied to the input sequence.
    input_blob = model.FC(
        input_blob, scoped('i2h'), dim_in=dim_in, dim_out=gate_dim, axis=2)

    # The step net.
    step_model = CNNModelHelper(name='milstm_cell', param_model=model)
    step_inputs = step_model.net.AddScopedExternalInputs(
        'input_t', 'timestep', 'cell_t_prev', 'hidden_t_prev')
    input_t, timestep, cell_t_prev, hidden_t_prev = step_inputs
    step_net = step_model.net

    # hU^T
    # Shape: [1, batch_size, 4 * hidden_size]
    prev_t = step_model.FC(
        hidden_t_prev, scoped('prev_t'), dim_in=dim_out,
        dim_out=gate_dim, axis=2)

    def constant_param(name, value):
        # One MI parameter vector of length 4 * dim_out, constant-initialized.
        return step_model.param_init_net.ConstantFill(
            [],
            [scoped(name)],
            shape=[gate_dim],
            value=value
        )

    def flatten_and_scale(blob, stem, weight, out_name):
        # Reshape `blob` to [batch_size, 4 * hidden_size] and multiply it by
        # the broadcast parameter vector `weight`.
        flat, _ = step_net.Reshape(
            blob,
            [scoped(stem + '_rs'), scoped(stem + '_old_shape')],
            shape=[-1, gate_dim],
        )
        return step_net.Mul(
            [flat, weight],
            scoped(out_name),
            broadcast=1,
            use_grad_hack=1
        )

    # defining MI parameters
    alpha = constant_param('alpha', 1.0)
    beta1 = constant_param('beta1', 1.0)
    beta2 = constant_param('beta2', 1.0)
    b = constant_param('b', 0.0)
    model.params.extend([alpha, beta1, beta2, b])

    # alpha * (xW^T * hU^T)
    # Shape: [1, batch_size, 4 * hidden_size]
    alpha_tdash = step_net.Mul(
        [prev_t, input_t],
        scoped('alpha_tdash')
    )
    alpha_t = flatten_and_scale(alpha_tdash, 'alpha_tdash', alpha, 'alpha_t')
    # beta1 * hU^T
    beta1_t = flatten_and_scale(prev_t, 'prev_t', beta1, 'beta1_t')
    # beta2 * xW^T
    beta2_t = flatten_and_scale(input_t, 'input_t', beta2, 'beta2_t')

    # Sum the three terms, then add the bias.
    gates_tdash = step_net.Sum(
        [alpha_t, beta1_t, beta2_t],
        scoped('gates_tdash')
    )
    gates_t = step_net.Add(
        [gates_tdash, b],
        scoped('gates_t'),
        broadcast=1,
        use_grad_hack=1
    )
    # Shape: [1, batch_size, 4 * hidden_size]
    gates_t_rs, _ = step_net.Reshape(
        gates_t,
        [scoped('gates_t_rs'), scoped('gates_t_old_shape')],
        shape=[1, -1, gate_dim],
    )

    unit_inputs = [
        hidden_t_prev, cell_t_prev, gates_t_rs, seq_lengths, timestep]
    hidden_t, cell_t = step_net.LSTMUnit(
        unit_inputs,
        [scoped('hidden_t'), scoped('cell_t')],
    )
    step_net.AddExternalOutputs(cell_t, hidden_t)

    # The recurrent network.
    hidden_input_blob, cell_input_blob = initial_states
    if memory_optimization:
        blobs_to_recompute = [gates_t]
    else:
        blobs_to_recompute = None
    output, last_output, all_states, last_state = recurrent_net(
        net=model.net,
        cell_net=step_model.net,
        inputs=[(input_t, input_blob)],
        initial_cell_inputs=[
            (hidden_t_prev, hidden_input_blob),
            (cell_t_prev, cell_input_blob),
        ],
        links={
            hidden_t_prev: hidden_t,
            cell_t_prev: cell_t,
        },
        timestep=timestep,
        scope=scope,
        outputs_with_grads=outputs_with_grads,
        recompute_blobs_on_backward=blobs_to_recompute
    )
    return output, last_output, all_states, last_state