pytorch/caffe2/python/attention.py
Aapo Kyrola 1e5140aa76 option to recompute blobs backward pass with massive memory savings
Summary:
This diff adds an option to recurrent_net to define some cell blobs to be recomputed on backward step, and thus they don't need to be stored in the step workspace. This is done by modifying the backward step to automatically include all operators that are needed to produce the output that is to be recomputed, and by storing those blobs in a shared workspace. To enable the shared workspace, i had to modify the stepworkspaces blob to also store a forward shared workspace. Making it a class field won't work since the lifecycle of the blob does not match the lifecycle of the operator.

For basic LSTM, the performance hit is quite modest (about 15% with one setting, but your mileage might vary. For Attention models, I am sure this is beneficial as computing the attention blobs is not expensive.

For basic LSTM, the memory saving is wonderful: each forward workspace only has 4 bytes (for timestep).

I also modified the neural_mt LSTM Cells, but there is no test available, so I am not 100% sure I did it correctly. Please have a look.

Added options to LSTM, MILSTM and LSTMAttention to enable memory mode.

Reviewed By: urikz

Differential Revision: D4853890

fbshipit-source-id: d8d0e0e75a5330d174fbfa39b96d8e4e8c446baa
2017-04-11 13:03:48 -07:00

270 lines
7.6 KiB
Python

## @package attention
# Module caffe2.python.attention
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
class AttentionType:
Regular, Recurrent = range(2)
def s(scope, name):
# We have to manually scope due to our internal/external blob
# relationships.
return "{}/{}".format(str(scope), str(name))
# c_i = \sum_j w_{ij}\textbf{s}_j
def _calc_weighted_context(
model,
encoder_outputs_transposed,
encoder_output_dim,
attention_weights_3d,
scope,
):
# [batch_size, encoder_output_dim, 1]
attention_weighted_encoder_context = model.net.BatchMatMul(
[encoder_outputs_transposed, attention_weights_3d],
s(scope, 'attention_weighted_encoder_context'),
)
# TODO: somehow I cannot use Squeeze in-place op here
# [batch_size, encoder_output_dim]
attention_weighted_encoder_context, _ = model.net.Reshape(
attention_weighted_encoder_context,
[
attention_weighted_encoder_context,
s(scope, 'attention_weighted_encoder_context_old_shape')
],
shape=[1, -1, encoder_output_dim],
)
return attention_weighted_encoder_context
# Calculate a softmax over the passed in attention energy logits
def _calc_attention_weights(
model,
attention_logits_transposed,
scope
):
# TODO: we could try to force some attention weights to be zeros,
# based on encoder_lengths.
# [batch_size, encoder_length]
attention_weights = model.Softmax(
attention_logits_transposed,
s(scope, 'attention_weights'),
engine='CUDNN',
)
# TODO: make this operation in-place
# [batch_size, encoder_length, 1]
attention_weights_3d = model.net.ExpandDims(
attention_weights,
s(scope, 'attention_weights_3d'),
dims=[2],
)
return attention_weights_3d
# e_{ij} = \textbf{v}^T tanh \alpha(\textbf{h}_{i-1}, \textbf{s}_j)
def _calc_attention_logits_from_sum_match(
model,
decoder_hidden_encoder_outputs_sum,
encoder_output_dim,
scope
):
# [encoder_length, batch_size, encoder_output_dim]
decoder_hidden_encoder_outputs_sum = model.net.Tanh(
decoder_hidden_encoder_outputs_sum,
decoder_hidden_encoder_outputs_sum,
)
attention_v = model.param_init_net.XavierFill(
[],
s(scope, 'attention_v'),
shape=[1, encoder_output_dim],
)
model.add_param(attention_v)
attention_zeros = model.param_init_net.ConstantFill(
[],
s(scope, 'attention_zeros'),
value=0.0,
shape=[1],
)
# [encoder_length, batch_size, 1]
attention_logits = model.net.FC(
[decoder_hidden_encoder_outputs_sum, attention_v, attention_zeros],
[s(scope, 'attention_logits')],
axis=2
)
# [encoder_length, batch_size]
attention_logits = model.net.Squeeze(
[attention_logits],
[attention_logits],
dims=[2],
)
# [batch_size, encoder_length]
attention_logits_transposed = model.Transpose(
attention_logits,
s(scope, 'attention_logits_transposed'),
axes=[1, 0],
)
return attention_logits_transposed
# \textbf{W}^\alpha used in the context of \alpha_{sum}(a,b)
def _apply_fc_weight_for_sum_match(
model,
input,
dim_in,
dim_out,
scope,
name
):
output = model.FC(
input,
s(scope, name),
dim_in=dim_in,
dim_out=dim_out,
axis=2,
)
output = model.net.Squeeze(
output,
output,
dims=[0]
)
return output
# Implement RecAtt due to section 4.1 in http://arxiv.org/abs/1601.03317
def apply_recurrent_attention(
model,
encoder_output_dim,
encoder_outputs_transposed,
weighted_encoder_outputs,
decoder_hidden_state_t,
decoder_hidden_state_dim,
attention_weighted_encoder_context_t_prev,
scope,
):
weighted_prev_attention_context = _apply_fc_weight_for_sum_match(
model=model,
input=attention_weighted_encoder_context_t_prev,
dim_in=encoder_output_dim,
dim_out=encoder_output_dim,
scope=scope,
name='weighted_prev_attention_context'
)
weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
model=model,
input=decoder_hidden_state_t,
dim_in=decoder_hidden_state_dim,
dim_out=encoder_output_dim,
scope=scope,
name='weighted_decoder_hidden_state'
)
# [encoder_length, batch_size, encoder_output_dim]
decoder_hidden_encoder_outputs_sum_tmp = model.net.Add(
[
weighted_encoder_outputs,
weighted_decoder_hidden_state
],
s(scope, 'decoder_hidden_encoder_outputs_sum_tmp'),
broadcast=1,
use_grad_hack=1,
)
# [encoder_length, batch_size, encoder_output_dim]
decoder_hidden_encoder_outputs_sum = model.net.Add(
[
decoder_hidden_encoder_outputs_sum_tmp,
weighted_prev_attention_context
],
s(scope, 'decoder_hidden_encoder_outputs_sum'),
broadcast=1,
use_grad_hack=1,
)
attention_logits_transposed = _calc_attention_logits_from_sum_match(
model=model,
decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
encoder_output_dim=encoder_output_dim,
scope=scope
)
# [batch_size, encoder_length, 1]
attention_weights_3d = _calc_attention_weights(
model=model,
attention_logits_transposed=attention_logits_transposed,
scope=scope
)
# [batch_size, encoder_output_dim, 1]
attention_weighted_encoder_context = _calc_weighted_context(
model=model,
encoder_outputs_transposed=encoder_outputs_transposed,
encoder_output_dim=encoder_output_dim,
attention_weights_3d=attention_weights_3d,
scope=scope
)
return attention_weighted_encoder_context, attention_weights_3d, [
decoder_hidden_encoder_outputs_sum_tmp,
decoder_hidden_encoder_outputs_sum
]
def apply_regular_attention(
model,
encoder_output_dim,
encoder_outputs_transposed,
weighted_encoder_outputs,
decoder_hidden_state_t,
decoder_hidden_state_dim,
scope,
):
weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
model=model,
input=decoder_hidden_state_t,
dim_in=decoder_hidden_state_dim,
dim_out=encoder_output_dim,
scope=scope,
name='weighted_decoder_hidden_state'
)
# [encoder_length, batch_size, encoder_output_dim]
decoder_hidden_encoder_outputs_sum = model.net.Add(
[weighted_encoder_outputs, weighted_decoder_hidden_state],
s(scope, 'decoder_hidden_encoder_outputs_sum'),
broadcast=1,
use_grad_hack=1,
)
attention_logits_transposed = _calc_attention_logits_from_sum_match(
model=model,
decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
encoder_output_dim=encoder_output_dim,
scope=scope
)
# [batch_size, encoder_length, 1]
attention_weights_3d = _calc_attention_weights(
model=model,
attention_logits_transposed=attention_logits_transposed,
scope=scope
)
# [batch_size, encoder_output_dim, 1]
attention_weighted_encoder_context = _calc_weighted_context(
model=model,
encoder_outputs_transposed=encoder_outputs_transposed,
encoder_output_dim=encoder_output_dim,
attention_weights_3d=attention_weights_3d,
scope=scope
)
return attention_weighted_encoder_context, attention_weights_3d, [
decoder_hidden_encoder_outputs_sum
]