# pytorch/caffe2/python/attention.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
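

# Attention variants implemented below: Regular computes sum-match
# attention from the current decoder hidden state only; Recurrent also
# feeds the previous attention context back into the energy computation
# (RecAtt, section 4.1 of http://arxiv.org/abs/1601.03317).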
class AttentionType:
Regular, Recurrent = range(2)


def s(scope, name):
# We have to manually scope due to our internal/external blob
# relationships.
return "{}/{}".format(str(scope), str(name))


# c_i = \sum_j w_{ij} \textbf{s}_j
def _calc_weighted_context(
model,
encoder_outputs_transposed,
encoder_output_dim,
attention_weights_3d,
scope,
):
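    """Computes the attention-weighted context c_i as a batched matrix
    product of the encoder outputs and the attention weights.

    encoder_outputs_transposed: [batch_size, encoder_output_dim, encoder_length]
    attention_weights_3d: [batch_size, encoder_length, 1]
    Returns a [batch_size, encoder_output_dim] blob.
    """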
# [batch_size, encoder_output_dim, 1]
attention_weighted_encoder_context = model.net.BatchMatMul(
[encoder_outputs_transposed, attention_weights_3d],
s(scope, 'attention_weighted_encoder_context'),
)
    # TODO: figure out why the in-place Squeeze op cannot be used here
# [batch_size, encoder_output_dim]
attention_weighted_encoder_context, _ = model.net.Reshape(
attention_weighted_encoder_context,
[
attention_weighted_encoder_context,
s(scope, 'attention_weighted_encoder_context_old_shape')
],
shape=[-1, encoder_output_dim],
)
return attention_weighted_encoder_context


# Calculates a softmax over the passed-in attention energy logits
def _calc_attention_weights(
model,
attention_logits_transposed,
scope
):
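    """Normalizes the attention energy logits with a softmax over the
    encoder positions and expands the result to 3D for BatchMatMul.

    attention_logits_transposed: [batch_size, encoder_length]
    Returns attention_weights_3d: [batch_size, encoder_length, 1]
    """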
    # TODO: we could force the attention weights of padded positions
    # to zero, based on encoder_lengths.
# [batch_size, encoder_length]
attention_weights = model.Softmax(
attention_logits_transposed,
s(scope, 'attention_weights'),
)
# TODO: make this operation in-place
# [batch_size, encoder_length, 1]
attention_weights_3d = model.net.ExpandDims(
attention_weights,
s(scope, 'attention_weights_3d'),
dims=[2],
)
return attention_weights_3d


# e_{ij} = \textbf{v}^T \tanh(\alpha(\textbf{h}_{i-1}, \textbf{s}_j))
def _calc_attention_logits_from_sum_match(
model,
decoder_hidden_encoder_outputs_sum,
encoder_output_dim,
batch_size,
scope
):
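    """Computes the attention energies: applies tanh to the summed
    decoder/encoder representations, projects the result onto the learned
    vector attention_v, and transposes to batch-major layout.

    decoder_hidden_encoder_outputs_sum:
        [encoder_length, batch_size, encoder_output_dim]
    Returns attention logits of shape [batch_size, encoder_length].
    """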
# [encoder_length, batch_size, encoder_output_dim]
decoder_hidden_encoder_outputs_sum = model.net.Tanh(
decoder_hidden_encoder_outputs_sum,
decoder_hidden_encoder_outputs_sum,
)
# [encoder_length * batch_size, encoder_output_dim]
decoder_hidden_encoder_outputs_sum_tanh_2d, _ = model.net.Reshape(
decoder_hidden_encoder_outputs_sum,
[
s(scope, 'decoder_hidden_encoder_outputs_sum_tanh_2d'),
s(scope, 'decoder_hidden_encoder_outputs_sum_tanh_t_old_shape'),
],
shape=[-1, encoder_output_dim],
)
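    # The learned vector v of the energy formula above, Xavier-initialized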
attention_v = model.param_init_net.XavierFill(
[],
s(scope, 'attention_v'),
shape=[encoder_output_dim, 1],
)
model.add_param(attention_v)
# [encoder_length * batch_size, 1]
attention_logits = model.net.MatMul(
[decoder_hidden_encoder_outputs_sum_tanh_2d, attention_v],
s(scope, 'attention_logits'),
)
# [encoder_length, batch_size]
attention_logits, _ = model.net.Reshape(
attention_logits,
[
attention_logits,
s(scope, 'attention_logits_old_shape'),
],
shape=[-1, batch_size],
)
# [batch_size, encoder_length]
attention_logits_transposed = model.net.Transpose(
attention_logits,
s(scope, 'attention_logits_transposed'),
axes=[1, 0],
)
return attention_logits_transposed


# Applies \textbf{W}^\alpha, used in the context of \alpha_{sum}(a, b)
def _apply_fc_weight_for_sum_match(
model,
input,
dim_in,
dim_out,
scope,
name
):
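    """Projects `input` with a learned weight matrix (the W^alpha of the
    sum-match energy function) via an FC along axis 2, then squeezes the
    leading singleton dimension:
    [1, batch_size, dim_in] -> [batch_size, dim_out].
    """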
output = model.FC(
input,
s(scope, name),
dim_in=dim_in,
dim_out=dim_out,
axis=2,
)
output = model.net.Squeeze(
output,
output,
dims=[0]
)
return output


# Implements RecAtt, following section 4.1 of http://arxiv.org/abs/1601.03317
def apply_recurrent_attention(
model,
encoder_output_dim,
encoder_outputs_transposed,
weighted_encoder_outputs,
decoder_hidden_state_t,
decoder_hidden_state_dim,
attention_weighted_encoder_context_t_prev,
    # TODO: we have to provide batch_size for some Reshape ops;
    # ideally, we should not need to specify it
batch_size,
scope,
):
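    """One step of recurrent attention: like regular attention, but the
    previous step's attention context is also projected and added into
    the energy computation before the softmax.

    Returns the attention-weighted encoder context,
    a [batch_size, encoder_output_dim] blob.
    """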
weighted_prev_attention_context = _apply_fc_weight_for_sum_match(
model=model,
input=attention_weighted_encoder_context_t_prev,
dim_in=encoder_output_dim,
dim_out=encoder_output_dim,
scope=scope,
name='weighted_prev_attention_context'
)
weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
model=model,
input=decoder_hidden_state_t,
dim_in=decoder_hidden_state_dim,
dim_out=encoder_output_dim,
scope=scope,
name='weighted_decoder_hidden_state'
)
    # TODO: remove this excessive Copy once RecurrentNetwork supports
    # a Sum op at the beginning of step_net
weighted_encoder_outputs_copy = model.net.Copy(
weighted_encoder_outputs,
s(scope, 'weighted_encoder_outputs_copy'),
)
# [encoder_length, batch_size, encoder_output_dim]
decoder_hidden_encoder_outputs_sum_tmp = model.net.Add(
[
weighted_encoder_outputs_copy,
weighted_decoder_hidden_state
],
s(scope, 'decoder_hidden_encoder_outputs_sum_tmp'),
broadcast=1,
use_grad_hack=1,
)
# [encoder_length, batch_size, encoder_output_dim]
decoder_hidden_encoder_outputs_sum = model.net.Add(
[
decoder_hidden_encoder_outputs_sum_tmp,
weighted_prev_attention_context
],
s(scope, 'decoder_hidden_encoder_outputs_sum'),
broadcast=1,
use_grad_hack=1,
)
attention_logits_transposed = _calc_attention_logits_from_sum_match(
model=model,
decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
encoder_output_dim=encoder_output_dim,
batch_size=batch_size,
scope=scope
)
attention_weights_3d = _calc_attention_weights(
model=model,
attention_logits_transposed=attention_logits_transposed,
scope=scope
)
# [batch_size, encoder_output_dim, 1]
attention_weighted_encoder_context = _calc_weighted_context(
model=model,
encoder_outputs_transposed=encoder_outputs_transposed,
encoder_output_dim=encoder_output_dim,
attention_weights_3d=attention_weights_3d,
scope=scope
)
return attention_weighted_encoder_context


def apply_regular_attention(
model,
encoder_output_dim,
encoder_outputs_transposed,
weighted_encoder_outputs,
decoder_hidden_state_t,
decoder_hidden_state_dim,
    # TODO: we have to provide batch_size for some Reshape ops;
    # ideally, we should not need to specify it
batch_size,
scope,
):
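    """One step of regular sum-match attention over the encoder outputs,
    conditioned on the current decoder hidden state.

    Returns the attention-weighted encoder context,
    a [batch_size, encoder_output_dim] blob.
    """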
weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
model=model,
input=decoder_hidden_state_t,
dim_in=decoder_hidden_state_dim,
dim_out=encoder_output_dim,
scope=scope,
name='weighted_decoder_hidden_state'
)
    # TODO: remove this excessive Copy once RecurrentNetwork supports
    # a Sum op at the beginning of step_net
weighted_encoder_outputs_copy = model.net.Copy(
weighted_encoder_outputs,
s(scope, 'weighted_encoder_outputs_copy'),
)
# [encoder_length, batch_size, encoder_output_dim]
decoder_hidden_encoder_outputs_sum = model.net.Add(
[weighted_encoder_outputs_copy, weighted_decoder_hidden_state],
s(scope, 'decoder_hidden_encoder_outputs_sum'),
broadcast=1,
use_grad_hack=1,
)
attention_logits_transposed = _calc_attention_logits_from_sum_match(
model=model,
decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
encoder_output_dim=encoder_output_dim,
batch_size=batch_size,
scope=scope
)
attention_weights_3d = _calc_attention_weights(
model=model,
attention_logits_transposed=attention_logits_transposed,
scope=scope
)
# [batch_size, encoder_output_dim, 1]
attention_weighted_encoder_context = _calc_weighted_context(
model=model,
encoder_outputs_transposed=encoder_outputs_transposed,
encoder_output_dim=encoder_output_dim,
attention_weights_3d=attention_weights_3d,
scope=scope
)
return attention_weighted_encoder_context
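

# A minimal usage sketch (illustrative only, not part of this module):
# one way to wire apply_regular_attention into a decoder step. It assumes
# a `model` object exposing the interface used above (net, param_init_net,
# FC, Softmax, add_param) and input blobs already computed with the
# documented shapes; every name and dimension below is made up.
#
#   context_t = apply_regular_attention(
#       model=model,
#       encoder_output_dim=256,
#       # [batch_size, 256, encoder_length]
#       encoder_outputs_transposed='encoder_outputs_transposed',
#       # encoder outputs already projected for the sum match (done outside)
#       weighted_encoder_outputs='weighted_encoder_outputs',
#       # [1, batch_size, 512], the decoder hidden state at step t
#       decoder_hidden_state_t='hidden_t',
#       decoder_hidden_state_dim=512,
#       batch_size=32,
#       scope='decoder',
#   )
#   # context_t: [batch_size, 256]; typically concatenated with the decoder
#   # state before the output projection.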