#include "caffe2/perfkernels/lstm_unit_cpu-impl.h" namespace caffe2 { namespace perfkernels { namespace { // Explicit initialize for float and AVX2 vectorization template void LstmUnitImpl( const int N, const int D, const int t, const float* H_prev, const float* C_prev, const float* X, const int32_t* seqLengths, const bool drop_states, float* C, float* H, const float forget_bias); template void LstmUnitGradientImpl( int N, int D, int t, const float* C_prev, const float* X, const int32_t* seqLengths, const float* C, const float* H, const float* C_diff, const float* H_diff, bool drop_states, float* H_prev_diff, float* C_prev_diff, float* X_diff, const float forget_bias); } // namespace // Define templated implementation fo LSTM kernels on CPU supporting AVX2 template void LstmUnitImpl__avx2_fma( const int N, const int D, const int t, const T* H_prev, const T* C_prev, const T* X, const int32_t* seqLengths, const bool drop_states, T* C, T* H, const float forget_bias) { LstmUnitImpl( N, D, t, H_prev, C_prev, X, seqLengths, drop_states, C, H, forget_bias); } template void LstmUnitGradientImpl__avx2_fma( int N, int D, int t, const T* C_prev, const T* X, const int32_t* seqLengths, const T* C, const T* H, const T* C_diff, const T* H_diff, bool drop_states, T* H_prev_diff, T* C_prev_diff, T* X_diff, const float forget_bias) { LstmUnitGradientImpl( N, D, t, C_prev, X, seqLengths, C, H, C_diff, H_diff, drop_states, H_prev_diff, C_prev_diff, X_diff, forget_bias); } // Explicit initialize for float template void LstmUnitImpl__avx2_fma( const int N, const int D, const int t, const float* H_prev, const float* C_prev, const float* X, const int32_t* seqLengths, const bool drop_states, float* C, float* H, const float forget_bias); template void LstmUnitGradientImpl__avx2_fma( int N, int D, int t, const float* C_prev, const float* X, const int32_t* seqLengths, const float* C, const float* H, const float* C_diff, const float* H_diff, bool drop_states, float* H_prev_diff, float* C_prev_diff, float* X_diff, const float forget_bias); } // namespace perfkernels } // namespace caffe2