## @package rnn_cell
# Module caffe2.python.rnn_cell
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np
import random

from caffe2.python.attention import (
    AttentionType,
    apply_regular_attention,
    apply_recurrent_attention,
)
from caffe2.python import core, recurrent, workspace
from caffe2.python.cnn import CNNModelHelper


class RNNCell(object):
    '''
    Base class for writing recurrent / stateful operations.

    One needs to implement 3 methods: _apply, prepare_input and
    get_state_names. In return, the base class provides the
    apply_over_sequence method, which allows you to apply recurrent
    operations over a sequence of any length.
    '''
    def __init__(self, name, forward_only=False):
        self.name = name
        self.recompute_blobs = []
        self.forward_only = forward_only

    def scope(self, name):
        return self.name + '/' + name if self.name is not None else name

    def apply_over_sequence(
        self,
        model,
        inputs,
        seq_lengths,
        initial_states,
        outputs_with_grads=None,
    ):
        preprocessed_inputs = self.prepare_input(model, inputs)
        step_model = CNNModelHelper(name=self.name, param_model=model)
        input_t, timestep = step_model.net.AddScopedExternalInputs(
            'input_t',
            'timestep',
        )
        states_prev = step_model.net.AddScopedExternalInputs(*[
            s + '_prev' for s in self.get_state_names()
        ])
        states = self._apply(
            model=step_model,
            input_t=input_t,
            seq_lengths=seq_lengths,
            states=states_prev,
            timestep=timestep,
        )
        return recurrent.recurrent_net(
            net=model.net,
            cell_net=step_model.net,
            inputs=[(input_t, preprocessed_inputs)],
            initial_cell_inputs=zip(states_prev, initial_states),
            links=dict(zip(states_prev, states)),
            timestep=timestep,
            scope=self.name,
            outputs_with_grads=(
                outputs_with_grads if outputs_with_grads is not None
                else self.get_outputs_with_grads()
            ),
            recompute_blobs_on_backward=self.recompute_blobs,
            forward_only=self.forward_only,
        )

    def apply(self, model, input_t, seq_lengths, states, timestep):
        input_t = self.prepare_input(model, input_t)
        return self._apply(model, input_t, seq_lengths, states, timestep)

    def _apply(self, model, input_t, seq_lengths, states, timestep):
        '''
        A single step of a recurrent network.

        model: CNNModelHelper object new operators would be added to

        input_t: single input with shape (1, batch_size, input_dim)

        seq_lengths: blob containing sequence lengths which would be passed to
        LSTMUnit operator

        states: previous recurrent states

        timestep: current recurrent iteration. Could be used together with
        seq_lengths in order to determine if some shorter sequences
        in the batch have already ended.
        '''
        raise NotImplementedError('Abstract method')

    def prepare_input(self, model, input_blob):
        '''
        If some operations in the _apply method depend only on the input,
        not on recurrent states, they could be computed in advance.

        model: CNNModelHelper object new operators would be added to

        input_blob: either the whole input sequence with shape
        (sequence_length, batch_size, input_dim) or a single input with shape
        (1, batch_size, input_dim).
        '''
        raise NotImplementedError('Abstract method')

    def get_state_names(self):
        '''
        Return the names of the recurrent states.
        It's required by the apply_over_sequence method in order to allocate
        recurrent states for all steps with meaningful names.
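        For example, LSTMCell (defined below) returns its scoped hidden and
        cell blob names, e.g. ('my_lstm/hidden_t', 'my_lstm/cell_t') for a
        cell named 'my_lstm'.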
        '''
        raise NotImplementedError('Abstract method')


class LSTMCell(RNNCell):

    def __init__(
        self,
        input_size,
        hidden_size,
        forget_bias,
        memory_optimization,
        name,
        forward_only=False,
    ):
        super(LSTMCell, self).__init__(name, forward_only)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.forget_bias = float(forget_bias)
        self.memory_optimization = memory_optimization

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
    ):
        hidden_t_prev, cell_t_prev = states
        gates_t = model.FC(
            hidden_t_prev,
            self.scope('gates_t'),
            dim_in=self.hidden_size,
            dim_out=4 * self.hidden_size,
            axis=2,
        )
        model.net.Sum([gates_t, input_t], gates_t)
        hidden_t, cell_t = model.net.LSTMUnit(
            [
                hidden_t_prev,
                cell_t_prev,
                gates_t,
                seq_lengths,
                timestep,
            ],
            list(self.get_state_names()),
            forget_bias=self.forget_bias,
        )
        model.net.AddExternalOutputs(hidden_t, cell_t)
        if self.memory_optimization:
            self.recompute_blobs = [gates_t]
        return hidden_t, cell_t

    def get_input_params(self):
        return {
            'weights': self.scope('i2h') + '_w',
            'biases': self.scope('i2h') + '_b',
        }

    def get_recurrent_params(self):
        return {
            'weights': self.scope('gates_t') + '_w',
            'biases': self.scope('gates_t') + '_b',
        }

    def prepare_input(self, model, input_blob):
        return model.FC(
            input_blob,
            self.scope('i2h'),
            dim_in=self.input_size,
            dim_out=4 * self.hidden_size,
            axis=2,
        )

    def get_state_names(self):
        return (self.scope('hidden_t'), self.scope('cell_t'))

    def get_outputs_with_grads(self):
        return [0]

    def get_output_size(self):
        return self.hidden_size


def LSTM(model, input_blob, seq_lengths, initial_states, dim_in, dim_out,
         scope, outputs_with_grads=(0,), return_params=False,
         memory_optimization=False, forget_bias=0.0, forward_only=False):
    '''
    Adds a standard LSTM recurrent network operator to a model.

    model: CNNModelHelper object new operators would be added to

    input_blob: the input sequence in a format T x N x D
    where T is sequence size, N - batch size and D - input dimension

    seq_lengths: blob containing sequence lengths which would be passed to
    LSTMUnit operator

    initial_states: a tuple of (hidden_input_blob, cell_input_blob)
    which are going to be inputs to the cell net on the first iteration

    dim_in: input dimension

    dim_out: output dimension

    outputs_with_grads : position indices of output blobs which will receive
    external error gradient during backpropagation

    return_params: if True, will return a dictionary of parameters of the LSTM

    memory_optimization: if enabled, the LSTM step is recomputed on backward
    step so that we don't need to store forward activations for each
    timestep. Saves memory at the cost of extra computation.
    forget_bias: forget gate bias (default 0.0)

    forward_only: if True, only the forward pass is created (no backward pass)
    '''
    cell = LSTMCell(
        input_size=dim_in,
        hidden_size=dim_out,
        forget_bias=forget_bias,
        memory_optimization=memory_optimization,
        name=scope,
        forward_only=forward_only,
    )
    result = cell.apply_over_sequence(
        model=model,
        inputs=input_blob,
        seq_lengths=seq_lengths,
        initial_states=initial_states,
        outputs_with_grads=outputs_with_grads,
    )
    if return_params:
        result = list(result) + [{
            'input': cell.get_input_params(),
            'recurrent': cell.get_recurrent_params(),
        }]
    return tuple(result)


def GetLSTMParamNames():
    weight_params = ["input_gate_w", "forget_gate_w", "output_gate_w", "cell_w"]
    bias_params = ["input_gate_b", "forget_gate_b", "output_gate_b", "cell_b"]
    return {'weights': weight_params, 'biases': bias_params}


def InitFromLSTMParams(lstm_pblobs, param_values):
    '''
    Set the parameters of an LSTM based on predefined values
    '''
    weight_params = GetLSTMParamNames()['weights']
    bias_params = GetLSTMParamNames()['biases']
    for input_type in param_values.keys():
        weight_values = [param_values[input_type][w].flatten()
                         for w in weight_params]
        wmat = np.array([])
        for w in weight_values:
            wmat = np.append(wmat, w)
        bias_values = [param_values[input_type][b].flatten()
                       for b in bias_params]
        bm = np.array([])
        for b in bias_values:
            bm = np.append(bm, b)

        weights_blob = lstm_pblobs[input_type]['weights']
        bias_blob = lstm_pblobs[input_type]['biases']
        cur_weight = workspace.FetchBlob(weights_blob)
        cur_biases = workspace.FetchBlob(bias_blob)

        workspace.FeedBlob(
            weights_blob,
            wmat.reshape(cur_weight.shape).astype(np.float32))
        workspace.FeedBlob(
            bias_blob,
            bm.reshape(cur_biases.shape).astype(np.float32))


def cudnn_LSTM(model, input_blob, initial_states, dim_in, dim_out,
               scope, recurrent_params=None, input_params=None,
               num_layers=1, return_params=False):
    '''
    CuDNN version of LSTM for GPUs.

    input_blob          Blob containing the input. Will need to be available
                        when param_init_net is run, because the sequence
                        lengths and batch sizes will be inferred from the
                        size of this blob.
    initial_states      tuple of (hidden_init, cell_init) blobs
    dim_in              input dimension
    dim_out             output/hidden dimension
    scope               namescope to apply
    recurrent_params    dict of blobs containing values for recurrent
                        gate weights, biases (if None, use random init values)
                        See GetLSTMParamNames() for format.
    input_params        dict of blobs containing values for input
                        gate weights, biases (if None, use random init values)
                        See GetLSTMParamNames() for format.
    num_layers          number of LSTM layers
    return_params       if True, returns (param_extract_net, param_mapping)
                        where param_extract_net is a net that, when run, will
                        populate the blobs specified in param_mapping with the
                        current gate weights and biases (input/recurrent).
                        Useful for assigning the values back to a non-cuDNN
                        LSTM.
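    Example (illustrative sketch only; the blob names, shapes and values below
    are assumptions for the example, not part of this API):

        outputs = cudnn_LSTM(
            model, input_blob, (hidden_init, cell_init),
            dim_in=128, dim_out=256, scope="encoder",
            return_params=True,
        )
        output, hidden_out, cell_out, (extract_net, param_mapping) = outputs
        # Running extract_net later populates the blobs listed in
        # param_mapping with the current per-gate weights and biases, which
        # can be fetched and passed to InitFromLSTMParams() to initialize a
        # non-cuDNN LSTM.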
    '''
    with core.NameScope(scope):
        weight_params = GetLSTMParamNames()['weights']
        bias_params = GetLSTMParamNames()['biases']

        input_weight_size = dim_out * dim_in
        upper_layer_input_weight_size = dim_out * dim_out
        recurrent_weight_size = dim_out * dim_out
        input_bias_size = dim_out
        recurrent_bias_size = dim_out

        def init(layer, pname, input_type):
            input_weight_size_for_layer = input_weight_size if layer == 0 else \
                upper_layer_input_weight_size
            if pname in weight_params:
                sz = input_weight_size_for_layer if input_type == 'input' \
                    else recurrent_weight_size
            elif pname in bias_params:
                sz = input_bias_size if input_type == 'input' \
                    else recurrent_bias_size
            else:
                assert False, "unknown parameter type {}".format(pname)
            return model.param_init_net.UniformFill(
                [],
                "lstm_init_{}_{}_{}".format(input_type, pname, layer),
                shape=[sz])

        # Multiply by 4 since we have 4 gates per LSTM unit
        first_layer_sz = input_weight_size + recurrent_weight_size + \
            input_bias_size + recurrent_bias_size
        upper_layer_sz = upper_layer_input_weight_size + \
            recurrent_weight_size + input_bias_size + \
            recurrent_bias_size
        total_sz = 4 * (first_layer_sz + (num_layers - 1) * upper_layer_sz)

        weights = model.param_init_net.UniformFill(
            [], "lstm_weight", shape=[total_sz])

        model.params.append(weights)
        model.weights.append(weights)

        lstm_args = {
            'hidden_size': dim_out,
            'rnn_mode': 'lstm',
            'bidirectional': 0,  # TODO
            'dropout': 1.0,  # TODO
            'input_mode': 'linear',  # TODO
            'num_layers': num_layers,
            'engine': 'CUDNN'
        }

        param_extract_net = core.Net("lstm_param_extractor")
        param_extract_net.AddExternalInputs([input_blob, weights])
        param_extract_mapping = {}

        # Populate the weights-blob from blobs containing parameters for
        # the individual components of the LSTM, such as forget/input gate
        # weights and biases. Also, create a special param_extract_net that
        # can be used to grab those individual params from the black-box
        # weights blob.
        # These results can then be fed to InitFromLSTMParams().
        for input_type in ['input', 'recurrent']:
            param_extract_mapping[input_type] = {}
            p = recurrent_params if input_type == 'recurrent' else input_params
            if p is None:
                p = {}
            for pname in weight_params + bias_params:
                for j in range(0, num_layers):
                    values = p[pname] if pname in p \
                        else init(j, pname, input_type)
                    model.param_init_net.RecurrentParamSet(
                        [input_blob, weights, values],
                        weights,
                        layer=j,
                        input_type=input_type,
                        param_type=pname,
                        **lstm_args
                    )
                    if pname not in param_extract_mapping[input_type]:
                        param_extract_mapping[input_type][pname] = {}
                    b = param_extract_net.RecurrentParamGet(
                        [input_blob, weights],
                        ["lstm_{}_{}_{}".format(input_type, pname, j)],
                        layer=j,
                        input_type=input_type,
                        param_type=pname,
                        **lstm_args
                    )
                    param_extract_mapping[input_type][pname][j] = b

        (hidden_input_blob, cell_input_blob) = initial_states
        output, hidden_output, cell_output, rnn_scratch, dropout_states = \
            model.net.Recurrent(
                [input_blob, hidden_input_blob, cell_input_blob, weights],
                ["lstm_output", "lstm_hidden_output", "lstm_cell_output",
                 "lstm_rnn_scratch", "lstm_dropout_states"],
                seed=random.randint(0, 100000),  # TODO: dropout seed
                **lstm_args
            )
        model.net.AddExternalOutputs(
            hidden_output, cell_output, rnn_scratch, dropout_states)

    if return_params:
        param_extract = param_extract_net, param_extract_mapping
        return output, hidden_output, cell_output, param_extract
    else:
        return output, hidden_output, cell_output


class LSTMWithAttentionCell(RNNCell):

    def __init__(
        self,
        encoder_output_dim,
        encoder_outputs,
        decoder_input_dim,
        decoder_state_dim,
        name,
        attention_type,
        weighted_encoder_outputs,
        forget_bias,
        lstm_memory_optimization,
        attention_memory_optimization,
        forward_only=False,
    ):
        super(LSTMWithAttentionCell, self).__init__(name, forward_only)
        self.encoder_output_dim = encoder_output_dim
        self.encoder_outputs = encoder_outputs
        self.decoder_input_dim = decoder_input_dim
        self.decoder_state_dim = decoder_state_dim
        self.weighted_encoder_outputs = weighted_encoder_outputs
        self.encoder_outputs_transposed = None
        assert attention_type in [
            AttentionType.Regular,
            AttentionType.Recurrent,
        ]
        self.attention_type = attention_type
        self.lstm_memory_optimization = lstm_memory_optimization
        self.attention_memory_optimization = attention_memory_optimization

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
    ):
        (
            hidden_t_prev,
            cell_t_prev,
            attention_weighted_encoder_context_t_prev,
        ) = states
        gates_concatenated_input_t, _ = model.net.Concat(
            [hidden_t_prev, attention_weighted_encoder_context_t_prev],
            [
                self.scope('gates_concatenated_input_t'),
                self.scope('_gates_concatenated_input_t_concat_dims'),
            ],
            axis=2,
        )
        gates_t = model.FC(
            gates_concatenated_input_t,
            self.scope('gates_t'),
            dim_in=self.decoder_state_dim + self.encoder_output_dim,
            dim_out=4 * self.decoder_state_dim,
            axis=2,
        )
        model.net.Sum([gates_t, input_t], gates_t)

        hidden_t_intermediate, cell_t = model.net.LSTMUnit(
            [
                hidden_t_prev,
                cell_t_prev,
                gates_t,
                seq_lengths,
                timestep,
            ],
            ['hidden_t_intermediate', self.scope('cell_t')],
        )
        if self.attention_type == AttentionType.Recurrent:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_recurrent_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                attention_weighted_encoder_context_t_prev=(
                    attention_weighted_encoder_context_t_prev
                ),
            )
        else:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_regular_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
            )
        hidden_t = model.Copy(hidden_t_intermediate, self.scope('hidden_t'))
        model.net.AddExternalOutputs(
            cell_t,
            hidden_t,
            attention_weighted_encoder_context_t,
        )
        if self.attention_memory_optimization:
            self.recompute_blobs.extend(attention_blobs)
        if self.lstm_memory_optimization:
            self.recompute_blobs.append(gates_t)

        return hidden_t, cell_t, attention_weighted_encoder_context_t

    def get_attention_weights(self):
        # [batch_size, encoder_length, 1]
        return self.attention_weights_3d

    def prepare_input(self, model, input_blob):
        if self.encoder_outputs_transposed is None:
            self.encoder_outputs_transposed = model.Transpose(
                self.encoder_outputs,
                self.scope('encoder_outputs_transposed'),
                axes=[1, 2, 0],
            )
        if self.weighted_encoder_outputs is None:
            self.weighted_encoder_outputs = model.FC(
                self.encoder_outputs,
                self.scope('weighted_encoder_outputs'),
                dim_in=self.encoder_output_dim,
                dim_out=self.encoder_output_dim,
                axis=2,
            )
        return model.FC(
            input_blob,
            self.scope('i2h'),
            dim_in=self.decoder_input_dim,
            dim_out=4 * self.decoder_state_dim,
            axis=2,
        )

    def get_state_names(self):
        return (
            self.scope('hidden_t'),
            self.scope('cell_t'),
            self.scope('attention_weighted_encoder_context_t'),
        )

    def get_outputs_with_grads(self):
        return [0, 4]

    def get_output_size(self):
        return self.decoder_state_dim + self.encoder_output_dim


def LSTMWithAttention(
    model,
    decoder_inputs,
    decoder_input_lengths,
    initial_decoder_hidden_state,
    initial_decoder_cell_state,
    initial_attention_weighted_encoder_context,
    encoder_output_dim,
    encoder_outputs,
    decoder_input_dim,
    decoder_state_dim,
    scope,
    attention_type=AttentionType.Regular,
    outputs_with_grads=(0, 4),
    weighted_encoder_outputs=None,
    lstm_memory_optimization=False,
    attention_memory_optimization=False,
    forget_bias=0.0,
    forward_only=False,
):
    '''
    Adds an LSTM with an attention mechanism to a model.

    The implementation is based on https://arxiv.org/abs/1409.0473, with
    a small difference in the order in which the new attention context and
    the new hidden state are computed, similarly to
    https://arxiv.org/abs/1508.04025.

    The model uses encoder-decoder naming conventions, where the decoder is
    the sequence the op is iterating over, while computing the attention
    context over the encoder.

    model: CNNModelHelper object new operators would be added to

    decoder_inputs: the input sequence in a format T x N x D
    where T is sequence size, N - batch size and D - input dimension

    decoder_input_lengths: blob containing sequence lengths
    which would be passed to LSTMUnit operator

    initial_decoder_hidden_state: initial hidden state of LSTM

    initial_decoder_cell_state: initial cell state of LSTM

    initial_attention_weighted_encoder_context: initial attention context

    encoder_output_dim: dimension of encoder outputs

    encoder_outputs: the sequence, on which we compute the attention context
    at every iteration

    decoder_input_dim: input dimension (last dimension of decoder_inputs)

    decoder_state_dim: size of hidden states of LSTM

    attention_type: One of: AttentionType.Regular, AttentionType.Recurrent.
    Determines which type of attention mechanism to use.

    outputs_with_grads : position indices of output blobs which will receive
    external error gradient during backpropagation

    weighted_encoder_outputs: encoder outputs to be used to compute attention
    weights. In the basic case it's just a linear transformation of encoder
    outputs (that's the default, when weighted_encoder_outputs is None).
    However, it can be something more complicated - like a separate
    encoder network (for example, in case of a convolutional encoder)

    lstm_memory_optimization: recompute LSTM activations on backward pass, so
    we don't need to store their values in forward passes

    attention_memory_optimization: recompute attention for backward pass

    forward_only: whether to create only the forward pass
    '''
    cell = LSTMWithAttentionCell(
        encoder_output_dim=encoder_output_dim,
        encoder_outputs=encoder_outputs,
        decoder_input_dim=decoder_input_dim,
        decoder_state_dim=decoder_state_dim,
        name=scope,
        attention_type=attention_type,
        weighted_encoder_outputs=weighted_encoder_outputs,
        forget_bias=forget_bias,
        lstm_memory_optimization=lstm_memory_optimization,
        attention_memory_optimization=attention_memory_optimization,
        forward_only=forward_only,
    )
    return cell.apply_over_sequence(
        model=model,
        inputs=decoder_inputs,
        seq_lengths=decoder_input_lengths,
        initial_states=(
            initial_decoder_hidden_state,
            initial_decoder_cell_state,
            initial_attention_weighted_encoder_context,
        ),
        outputs_with_grads=outputs_with_grads,
    )


class MILSTMCell(LSTMCell):

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
    ):
        (
            hidden_t_prev,
            cell_t_prev,
        ) = states

        # hU^T
        # Shape: [1, batch_size, 4 * hidden_size]
        prev_t = model.FC(
            hidden_t_prev,
            self.scope('prev_t'),
            dim_in=self.hidden_size,
            dim_out=4 * self.hidden_size,
            axis=2)
        # defining MI parameters
        alpha = model.param_init_net.ConstantFill(
            [],
            [self.scope('alpha')],
            shape=[4 * self.hidden_size],
            value=1.0
        )
        beta1 = model.param_init_net.ConstantFill(
            [],
            [self.scope('beta1')],
            shape=[4 * self.hidden_size],
            value=1.0
        )
        beta2 = model.param_init_net.ConstantFill(
            [],
            [self.scope('beta2')],
            shape=[4 * self.hidden_size],
            value=1.0
        )
        b = model.param_init_net.ConstantFill(
            [],
            [self.scope('b')],
            shape=[4 * self.hidden_size],
            value=0.0
        )
        model.params.extend([alpha, beta1, beta2, b])
        # alpha * (xW^T * hU^T)
        # Shape: [1, batch_size, 4 * hidden_size]
        alpha_tdash = model.net.Mul(
            [prev_t, input_t],
            self.scope('alpha_tdash')
        )
        # Shape: [batch_size, 4 * hidden_size]
        alpha_tdash_rs, _ = model.net.Reshape(
            alpha_tdash,
            [self.scope('alpha_tdash_rs'),
             self.scope('alpha_tdash_old_shape')],
            shape=[-1, 4 * self.hidden_size],
        )
        alpha_t = model.net.Mul(
            [alpha_tdash_rs, alpha],
            self.scope('alpha_t'),
            broadcast=1,
            use_grad_hack=1
        )
        # beta1 * hU^T
        # Shape: [batch_size, 4 * hidden_size]
        prev_t_rs, _ = model.net.Reshape(
            prev_t,
            [self.scope('prev_t_rs'), self.scope('prev_t_old_shape')],
            shape=[-1, 4 * self.hidden_size],
        )
        beta1_t = model.net.Mul(
            [prev_t_rs, beta1],
            self.scope('beta1_t'),
            broadcast=1,
            use_grad_hack=1
        )
        # beta2 * xW^T
        # Shape: [batch_size, 4 * hidden_size]
        input_t_rs, _ = model.net.Reshape(
            input_t,
            [self.scope('input_t_rs'), self.scope('input_t_old_shape')],
            shape=[-1, 4 * self.hidden_size],
        )
        beta2_t = model.net.Mul(
            [input_t_rs, beta2],
            self.scope('beta2_t'),
            broadcast=1,
            use_grad_hack=1
        )
        # Add 'em all up
        gates_tdash = model.net.Sum(
            [alpha_t, beta1_t, beta2_t],
            self.scope('gates_tdash')
        )
        gates_t = model.net.Add(
            [gates_tdash, b],
            self.scope('gates_t'),
            broadcast=1,
            use_grad_hack=1
        )
        # Shape: [1, batch_size, 4 * hidden_size]
        gates_t_rs, _ = model.net.Reshape(
            gates_t,
            [self.scope('gates_t_rs'), self.scope('gates_t_old_shape')],
            shape=[1, -1, 4 * self.hidden_size],
        )
        hidden_t_intermediate, cell_t = model.net.LSTMUnit(
            [hidden_t_prev, cell_t_prev, gates_t_rs, seq_lengths, timestep],
            [self.scope('hidden_t_intermediate'), self.scope('cell_t')],
            forget_bias=self.forget_bias,
        )
        hidden_t = model.Copy(hidden_t_intermediate, self.scope('hidden_t'))
        model.net.AddExternalOutputs(
            cell_t,
            hidden_t,
        )
        if self.memory_optimization:
            self.recompute_blobs = [gates_t]
        return hidden_t, cell_t


def MILSTM(model, input_blob, seq_lengths, initial_states, dim_in, dim_out,
           scope, outputs_with_grads=(0,), memory_optimization=False,
           forget_bias=0.0, forward_only=False):
    '''
    Adds the MI flavor of a standard LSTM recurrent network operator to a
    model. See https://arxiv.org/pdf/1606.06630.pdf

    model: CNNModelHelper object new operators would be added to

    input_blob: the input sequence in a format T x N x D
    where T is sequence size, N - batch size and D - input dimension

    seq_lengths: blob containing sequence lengths which would be passed to
    LSTMUnit operator

    initial_states: a tuple of (hidden_input_blob, cell_input_blob)
    which are going to be inputs to the cell net on the first iteration

    dim_in: input dimension

    dim_out: output dimension

    outputs_with_grads : position indices of output blobs which will receive
    external error gradient during backpropagation

    memory_optimization: if enabled, the LSTM step is recomputed on backward
    step so that we don't need to store forward activations for each
    timestep. Saves memory at the cost of extra computation.

    forward_only: run only the forward pass
    '''
    cell = MILSTMCell(
        input_size=dim_in,
        hidden_size=dim_out,
        forget_bias=forget_bias,
        memory_optimization=memory_optimization,
        name=scope,
        forward_only=forward_only,
    )
    result = cell.apply_over_sequence(
        model=model,
        inputs=input_blob,
        seq_lengths=seq_lengths,
        initial_states=initial_states,
        outputs_with_grads=outputs_with_grads,
    )
    return tuple(result)


class MILSTMWithAttentionCell(LSTMWithAttentionCell):

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
    ):
        (
            hidden_t_prev,
            cell_t_prev,
            attention_weighted_encoder_context_t_prev,
        ) = states
        gates_concatenated_input_t, _ = model.net.Concat(
            [hidden_t_prev, attention_weighted_encoder_context_t_prev],
            [
                self.scope('gates_concatenated_input_t'),
                self.scope('_gates_concatenated_input_t_concat_dims'),
            ],
            axis=2,
        )
        # hU^T
        # Shape: [1, batch_size, 4 * hidden_size]
        prev_t = model.FC(
            gates_concatenated_input_t,
            self.scope('prev_t'),
            dim_in=self.decoder_state_dim + self.encoder_output_dim,
            dim_out=4 * self.decoder_state_dim,
            axis=2,
        )
        # defining MI parameters
        alpha = model.param_init_net.ConstantFill(
            [],
            [self.scope('alpha')],
            shape=[4 * self.decoder_state_dim],
            value=1.0
        )
        beta1 = model.param_init_net.ConstantFill(
            [],
            [self.scope('beta1')],
            shape=[4 * self.decoder_state_dim],
            value=1.0
        )
        beta2 = model.param_init_net.ConstantFill(
            [],
            [self.scope('beta2')],
            shape=[4 * self.decoder_state_dim],
            value=1.0
        )
        b = model.param_init_net.ConstantFill(
            [],
            [self.scope('b')],
            shape=[4 * self.decoder_state_dim],
            value=0.0
        )
        model.params.extend([alpha, beta1, beta2, b])
        # alpha * (xW^T * hU^T)
        # Shape: [1, batch_size, 4 * hidden_size]
        alpha_tdash = model.net.Mul(
            [prev_t, input_t],
            self.scope('alpha_tdash')
        )
        # Shape: [batch_size, 4 * hidden_size]
        alpha_tdash_rs, _ = model.net.Reshape(
            alpha_tdash,
            [self.scope('alpha_tdash_rs'),
             self.scope('alpha_tdash_old_shape')],
            shape=[-1, 4 * self.decoder_state_dim],
        )
        alpha_t = model.net.Mul(
            [alpha_tdash_rs, alpha],
            self.scope('alpha_t'),
            broadcast=1,
            use_grad_hack=1
        )
        # beta1 * hU^T
        # Shape: [batch_size, 4 * hidden_size]
        prev_t_rs, _ = model.net.Reshape(
            prev_t,
            [self.scope('prev_t_rs'), self.scope('prev_t_old_shape')],
            shape=[-1, 4 * self.decoder_state_dim],
        )
        beta1_t = model.net.Mul(
            [prev_t_rs, beta1],
            self.scope('beta1_t'),
            broadcast=1,
            use_grad_hack=1
        )
        # beta2 * xW^T
        # Shape: [batch_size, 4 * hidden_size]
        input_t_rs, _ = model.net.Reshape(
            input_t,
            [self.scope('input_t_rs'), self.scope('input_t_old_shape')],
            shape=[-1, 4 * self.decoder_state_dim],
        )
        beta2_t = model.net.Mul(
            [input_t_rs, beta2],
            self.scope('beta2_t'),
            broadcast=1,
            use_grad_hack=1
        )
        # Add 'em all up
        gates_tdash = model.net.Sum(
            [alpha_t, beta1_t, beta2_t],
            self.scope('gates_tdash')
        )
        gates_t = model.net.Add(
            [gates_tdash, b],
            self.scope('gates_t'),
            broadcast=1,
            use_grad_hack=1
        )
        # Shape: [1, batch_size, 4 * hidden_size]
        gates_t_rs, _ = model.net.Reshape(
            gates_t,
            [self.scope('gates_t_rs'), self.scope('gates_t_old_shape')],
            shape=[1, -1, 4 * self.decoder_state_dim],
        )

        hidden_t_intermediate, cell_t = model.net.LSTMUnit(
            [hidden_t_prev, cell_t_prev, gates_t_rs, seq_lengths, timestep],
            [self.scope('hidden_t_intermediate'), self.scope('cell_t')],
        )
        if self.attention_type == AttentionType.Recurrent:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                self.recompute_blobs,
            ) = apply_recurrent_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                attention_weighted_encoder_context_t_prev=(
                    attention_weighted_encoder_context_t_prev
                ),
            )
        else:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                self.recompute_blobs,
            ) = apply_regular_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
            )
        hidden_t = model.Copy(hidden_t_intermediate, self.scope('hidden_t'))
        model.net.AddExternalOutputs(
            cell_t,
            hidden_t,
            attention_weighted_encoder_context_t,
        )
        return hidden_t, cell_t, attention_weighted_encoder_context_t
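

# Usage sketch for the LSTM() helper (illustrative only; the blob names,
# dimensions and batch layout below are assumptions for the example, not
# values required by this module):
#
#     from caffe2.python import rnn_cell
#     from caffe2.python.cnn import CNNModelHelper
#
#     model = CNNModelHelper(name="lstm_example")
#     # "input_seq" is expected in T x N x D layout; "seq_lengths",
#     # "hidden_init" and "cell_init" must already exist in the workspace.
#     lstm_outputs = rnn_cell.LSTM(
#         model, input_blob="input_seq", seq_lengths="seq_lengths",
#         initial_states=("hidden_init", "cell_init"),
#         dim_in=64, dim_out=128, scope="lstm1",
#         outputs_with_grads=(0,),
#     )
#     # With the default outputs_with_grads=(0,), the external gradient is
#     # applied to the first returned blob, the per-timestep hidden states.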