Summary: Implementation of LSTMWithAttention.

Still TBD:
1. There are problems with backpropagation, because the gradient is not implemented for ops with broadcasting (see the shape sketch below).
2. initial_recurrent_state should be of shape [dim] rather than [1, batch_size, dim], so that batch_size does not have to be provided to LSTMWithAttention.

Differential Revision: D4298735
fbshipit-source-id: 8903fcff4d6a66647ee6d45a6ef28803fc3091e5
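The broadcasting mentioned in TBD item 1 comes from the Add op in apply_regular_attention below, which adds the projected decoder state to the projected encoder outputs with broadcast=1. A minimal NumPy sketch of the forward shapes involved (hypothetical sizes, illustration only):

import numpy as np

encoder_length, batch_size, dim = 5, 2, 4

# [encoder_length, batch_size, dim], as produced outside this step
weighted_encoder_outputs = np.random.randn(encoder_length, batch_size, dim)
# [batch_size, dim], the squeezed FC projection of the decoder state
weighted_decoder_hidden_state = np.random.randn(batch_size, dim)

# The 2-D decoder projection is broadcast across encoder_length; per the
# summary above, the gradient of this broadcast Add is what is still missing.
decoder_hidden_encoder_outputs_sum = (
    weighted_encoder_outputs + weighted_decoder_hidden_state
)
print(decoder_hidden_encoder_outputs_sum.shape)  # (5, 2, 4)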
122 lines | 4.0 KiB | Python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals


def apply_regular_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    # TODO: we need to provide batch_size for some reshape methods,
    # ideally, we should be able to not specify it
    batch_size,
    scope,
):
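    """
    Build the ops for one step of additive (MLP-style) attention.

    Expected shapes, following the shape comments below:
        encoder_outputs_transposed:  [batch_size, encoder_output_dim, encoder_length]
        weighted_encoder_outputs:    [encoder_length, batch_size, encoder_output_dim]
        decoder_hidden_state_t:      [1, batch_size, decoder_hidden_state_dim]

    Returns the attention-weighted encoder context blob of shape
    [batch_size, encoder_output_dim].
    """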
    def s(name):
        # We have to manually scope due to our internal/external blob
        # relationships.
        return "{}/{}".format(str(scope), str(name))

    # [1, batch_size, encoder_output_dim]
    weighted_decoder_hidden_state = model.FC(
        decoder_hidden_state_t,
        s('weighted_decoder_hidden_state'),
        dim_in=decoder_hidden_state_dim,
        dim_out=encoder_output_dim,
        axis=2,
    )
    # [batch_size, encoder_output_dim]
    weighted_decoder_hidden_state = model.net.Squeeze(
        weighted_decoder_hidden_state,
        weighted_decoder_hidden_state,
        dims=[0],
    )
    # TODO: remove this excessive copy once RecurrentNetwork supports
    # a Sum op at the beginning of step_net
    weighted_encoder_outputs_copy = model.net.Copy(
        weighted_encoder_outputs,
        s('weighted_encoder_outputs_copy'),
    )
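    # Additive attention pre-activation: broadcast the projected decoder
    # state across encoder positions, add it to the projected encoder
    # outputs, and apply tanh.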
    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [weighted_encoder_outputs_copy, weighted_decoder_hidden_state],
        s('decoder_hidden_encoder_outputs_sum'),
        broadcast=1,
        use_grad_hack=1,
    )
    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Tanh(
        decoder_hidden_encoder_outputs_sum,
        decoder_hidden_encoder_outputs_sum,
    )
    # [encoder_length * batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum_tanh_2d, _ = model.net.Reshape(
        decoder_hidden_encoder_outputs_sum,
        [
            s('decoder_hidden_encoder_outputs_sum_tanh_2d'),
            s('decoder_hidden_encoder_outputs_sum_tanh_t_old_shape'),
        ],
        shape=[-1, encoder_output_dim],
    )
    attention_v = model.param_init_net.XavierFill(
        [],
        s('attention_v'),
        shape=[encoder_output_dim, 1],
    )
    model.add_param(attention_v)

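    # Unnormalized attention scores: project the tanh output onto the
    # learned attention vector attention_v.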
    # [encoder_length * batch_size, 1]
    attention_logits = model.net.MatMul(
        [decoder_hidden_encoder_outputs_sum_tanh_2d, attention_v],
        s('attention_logits'),
    )
    # [encoder_length, batch_size]
    attention_logits, _ = model.net.Reshape(
        attention_logits,
        [
            attention_logits,
            s('attention_logits_old_shape'),
        ],
        shape=[-1, batch_size],
    )
    # [batch_size, encoder_length]
    attention_logits_transposed = model.net.Transpose(
        attention_logits,
        s('attention_logits_transposed'),
        axes=[1, 0],
    )
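    # Normalize the scores over encoder positions to obtain the attention
    # weights for each example in the batch.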
    # TODO: we could try to force some attention weights to be zeros,
    # based on encoder_lengths.
    # [batch_size, encoder_length]
    attention_weights = model.Softmax(
        attention_logits_transposed,
        s('attention_weights'),
    )
    # TODO: make this operation in-place
    # [batch_size, encoder_length, 1]
    attention_weights_3d = model.net.ExpandDims(
        attention_weights,
        s('attention_weights_3d'),
        dims=[2],
    )
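    # Attention-weighted sum of the encoder outputs (the context vector),
    # computed as a batched matrix multiplication.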
    # [batch_size, encoder_output_dim, 1]
    attention_weighted_encoder_context = model.net.BatchMatMul(
        [encoder_outputs_transposed, attention_weights_3d],
        s('attention_weighted_encoder_context'),
    )
    # TODO: somehow I cannot use the Squeeze op in-place here
    # [batch_size, encoder_output_dim]
    attention_weighted_encoder_context, _ = model.net.Reshape(
        attention_weighted_encoder_context,
        [
            attention_weighted_encoder_context,
            s('attention_weighted_encoder_context_old_shape'),
        ],
        shape=[-1, encoder_output_dim],
    )
    return attention_weighted_encoder_context
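For reference, here is a minimal NumPy sketch of the computation the graph above builds. The function name regular_attention_reference and the weight matrix W_decoder are hypothetical stand-ins, and the FC bias is omitted; this is an illustration of the additive attention step, not part of the Caffe2 file.

import numpy as np

def regular_attention_reference(
    encoder_outputs_transposed,  # [B, D, T]
    weighted_encoder_outputs,    # [T, B, D], encoder outputs already projected
    decoder_hidden_state_t,      # [1, B, H]
    W_decoder,                   # [H, D], stands in for the FC weights (bias omitted)
    attention_v,                 # [D, 1]
):
    T, B, D = weighted_encoder_outputs.shape
    # FC + Squeeze: project the decoder state into the encoder output space -> [B, D]
    weighted_decoder = decoder_hidden_state_t[0] @ W_decoder
    # Broadcast Add + Tanh -> [T, B, D]
    summed = np.tanh(weighted_encoder_outputs + weighted_decoder)
    # MatMul with attention_v, Reshape, Transpose: scores -> [B, T]
    logits = (summed.reshape(-1, D) @ attention_v).reshape(T, B).T
    # Softmax over encoder positions -> [B, T]
    weights = np.exp(logits - logits.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)
    # BatchMatMul + Reshape: attention-weighted context -> [B, D]
    context = np.einsum('bdt,bt->bd', encoder_outputs_transposed, weights)
    return context, weights

For example, with T=5, B=2, D=4, H=3 and random inputs of those shapes, context has shape (2, 4) and weights has shape (2, 5), matching the [batch_size, encoder_output_dim] and [batch_size, encoder_length] comments above.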