BatchNorm fallback to THNN when eps < CUDNN_BN_MIN_EPSILON (#1742)

Soumith Chintala, 2017-06-07 09:56:28 -04:00 (committed by GitHub)
commit edd41d8d80
parent 352f8b2fa6


@@ -17,6 +17,10 @@ namespace torch { namespace autograd {
 using thpp::Tensor;
 
+#ifndef CUDNN_BN_MIN_EPSILON
+#define CUDNN_BN_MIN_EPSILON 0
+#endif
+
 auto BatchNormForward::apply(const variable_list& inputs) -> variable_list {
   check_input_variables("BatchNorm", inputs, 3, 1);
@@ -41,7 +45,7 @@ auto BatchNormForward::apply(const variable_list& inputs) -> variable_list {
   std::unique_ptr<Tensor> save_std(output->newTensor());
   save_std->resizeAs(*running_var);
 
-  if (use_cudnn) {
+  if (use_cudnn && eps >= CUDNN_BN_MIN_EPSILON) {
 #ifdef WITH_CUDNN
     torch::cudnn::cudnn_batch_norm_forward(
         state,
@@ -123,7 +127,7 @@ auto BatchNormBackward::apply(const variable_list& grad_outputs) -> variable_list {
     }
   }
 
-  if (use_cudnn) {
+  if (use_cudnn && eps >= CUDNN_BN_MIN_EPSILON) {
 #ifdef WITH_CUDNN
     torch::cudnn::cudnn_batch_norm_backward(
         state,
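
For context: cuDNN's batch-norm kernels reject epsilon values below CUDNN_BN_MIN_EPSILON, so with this guard such calls take the THNN path instead of erroring inside cuDNN. Below is a minimal standalone sketch of the dispatch rule the guard implements, not the actual PyTorch code: pick_batch_norm_backend is a hypothetical helper, and the 1e-5 default is an assumption standing in for the value real builds take from cudnn.h (the diff itself defines the macro as 0 only when cuDNN headers are absent, making the check a no-op).

// Sketch of the backend-selection rule added in this commit.
// pick_batch_norm_backend is illustrative only; it mirrors the condition
// `if (use_cudnn && eps >= CUDNN_BN_MIN_EPSILON)` used in both
// BatchNormForward::apply and BatchNormBackward::apply.
#include <cstdio>

#ifndef CUDNN_BN_MIN_EPSILON
#define CUDNN_BN_MIN_EPSILON 1e-5  // assumed threshold; real builds take it from cudnn.h
#endif

static const char* pick_batch_norm_backend(bool use_cudnn, double eps) {
  // Only take the cuDNN path when eps satisfies cuDNN's minimum;
  // otherwise fall back to the THNN implementation.
  return (use_cudnn && eps >= CUDNN_BN_MIN_EPSILON) ? "cudnn" : "thnn";
}

int main() {
  std::printf("%s\n", pick_batch_norm_backend(true, 1e-5));   // cudnn
  std::printf("%s\n", pick_batch_norm_backend(true, 1e-6));   // thnn (eps too small)
  std::printf("%s\n", pick_batch_norm_backend(false, 1e-3));  // thnn (cuDNN unavailable)
  return 0;
}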