from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals


def apply_regular_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    # TODO: we need to provide batch_size for some Reshape ops;
    # ideally we should not have to specify it
    batch_size,
    scope,
):
    """Build additive (Bahdanau-style) attention ops on the given model.

    Projects the decoder hidden state into the encoder output space, adds it
    to the pre-projected encoder outputs, scores each encoder position with a
    learned vector, softmaxes the scores over encoder positions, and returns
    the attention-weighted encoder context of shape
    [batch_size, encoder_output_dim].
    """
    def s(name):
        # We have to manually scope due to our internal/external blob
        # relationships.
        return "{}/{}".format(str(scope), str(name))

    # [1, batch_size, encoder_output_dim]
    weighted_decoder_hidden_state = model.FC(
        decoder_hidden_state_t,
        s('weighted_decoder_hidden_state'),
        dim_in=decoder_hidden_state_dim,
        dim_out=encoder_output_dim,
        axis=2,
    )
    # [batch_size, encoder_output_dim]
    weighted_decoder_hidden_state = model.net.Squeeze(
        weighted_decoder_hidden_state,
        weighted_decoder_hidden_state,
        dims=[0],
    )
    # TODO: remove this excessive Copy once RecurrentNetwork supports
    # a Sum op at the beginning of step_net
    weighted_encoder_outputs_copy = model.net.Copy(
        weighted_encoder_outputs,
        s('weighted_encoder_outputs_copy'),
    )
    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [weighted_encoder_outputs_copy, weighted_decoder_hidden_state],
        s('decoder_hidden_encoder_outputs_sum'),
        broadcast=1,
        use_grad_hack=1,
    )
    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Tanh(
        decoder_hidden_encoder_outputs_sum,
        decoder_hidden_encoder_outputs_sum,
    )
    # [encoder_length * batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum_tanh_2d, _ = model.net.Reshape(
        decoder_hidden_encoder_outputs_sum,
        [
            s('decoder_hidden_encoder_outputs_sum_tanh_2d'),
            s('decoder_hidden_encoder_outputs_sum_tanh_t_old_shape'),
        ],
        shape=[-1, encoder_output_dim],
    )
    attention_v = model.param_init_net.XavierFill(
        [],
        s('attention_v'),
        shape=[encoder_output_dim, 1],
    )
    model.add_param(attention_v)
    # [encoder_length * batch_size, 1]
    attention_logits = model.net.MatMul(
        [decoder_hidden_encoder_outputs_sum_tanh_2d, attention_v],
        s('attention_logits'),
    )
    # [encoder_length, batch_size]
    attention_logits, _ = model.net.Reshape(
        attention_logits,
        [
            attention_logits,
            s('attention_logits_old_shape'),
        ],
        shape=[-1, batch_size],
    )
    # [batch_size, encoder_length]
    attention_logits_transposed = model.net.Transpose(
        attention_logits,
        s('attention_logits_transposed'),
        axes=[1, 0],
    )
    # TODO: we could try to force some attention weights to be zeros,
    # based on encoder_lengths.
    # [batch_size, encoder_length]
    attention_weights = model.Softmax(
        attention_logits_transposed,
        s('attention_weights'),
    )
    # TODO: make this operation in-place
    # [batch_size, encoder_length, 1]
    attention_weights_3d = model.net.ExpandDims(
        attention_weights,
        s('attention_weights_3d'),
        dims=[2],
    )
    # [batch_size, encoder_output_dim, 1]
    attention_weighted_encoder_context = model.net.BatchMatMul(
        [encoder_outputs_transposed, attention_weights_3d],
        s('attention_weighted_encoder_context'),
    )
    # TODO: the Squeeze op cannot be used in-place here
    # [batch_size, encoder_output_dim]
    attention_weighted_encoder_context, _ = model.net.Reshape(
        attention_weighted_encoder_context,
        [
            attention_weighted_encoder_context,
            s('attention_weighted_encoder_context_old_shape'),
        ],
        shape=[-1, encoder_output_dim],
    )
    return attention_weighted_encoder_context
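

# The sketch below is not part of the original module: it is a minimal NumPy
# reference of the same additive (Bahdanau-style) attention computation that
# apply_regular_attention assembles as a Caffe2 op graph, useful for
# sanity-checking shapes and outputs. The weight names W_d, b_d (the FC
# projection of the decoder state) and attention_v are hypothetical handles
# for parameters that the graph above creates or that callers pass in.
def regular_attention_reference(
    encoder_outputs,           # [encoder_length, batch_size, encoder_output_dim]
    weighted_encoder_outputs,  # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_state,      # [batch_size, decoder_hidden_state_dim]
    W_d,                       # [decoder_hidden_state_dim, encoder_output_dim]
    b_d,                       # [encoder_output_dim]
    attention_v,               # [encoder_output_dim, 1]
):
    import numpy as np

    # FC projection of the decoder state: [batch_size, encoder_output_dim]
    weighted_decoder = decoder_hidden_state.dot(W_d) + b_d
    # Broadcast-add over encoder positions, then tanh:
    # [encoder_length, batch_size, encoder_output_dim]
    summed = np.tanh(weighted_encoder_outputs + weighted_decoder[np.newaxis, :, :])
    # Score each position with the learned vector: [encoder_length, batch_size]
    logits = summed.dot(attention_v).squeeze(-1)
    # Softmax over encoder positions, per batch element: [batch_size, encoder_length]
    logits_t = logits.T
    exp = np.exp(logits_t - logits_t.max(axis=1, keepdims=True))
    attention_weights = exp / exp.sum(axis=1, keepdims=True)
    # Attention-weighted encoder context: [batch_size, encoder_output_dim]
    context = np.einsum('bt,tbd->bd', attention_weights, encoder_outputs)
    return context, attention_weights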