mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
BatchNorm fallback to THNN when eps < CUDNN_BN_MIN_EPSILON (#1742)
This commit is contained in:
parent
352f8b2fa6
commit
edd41d8d80
|
|
@ -17,6 +17,10 @@ namespace torch { namespace autograd {
|
|||
|
||||
using thpp::Tensor;
|
||||
|
||||
#ifndef CUDNN_BN_MIN_EPSILON
|
||||
#define CUDNN_BN_MIN_EPSILON 0
|
||||
#endif
|
||||
|
||||
auto BatchNormForward::apply(const variable_list& inputs) -> variable_list {
|
||||
check_input_variables("BatchNorm", inputs, 3, 1);
|
||||
|
||||
|
|
@ -41,7 +45,7 @@ auto BatchNormForward::apply(const variable_list& inputs) -> variable_list {
|
|||
std::unique_ptr<Tensor> save_std(output->newTensor());
|
||||
save_std->resizeAs(*running_var);
|
||||
|
||||
if (use_cudnn) {
|
||||
if (use_cudnn && eps >= CUDNN_BN_MIN_EPSILON) {
|
||||
#ifdef WITH_CUDNN
|
||||
torch::cudnn::cudnn_batch_norm_forward(
|
||||
state,
|
||||
|
|
@ -123,7 +127,7 @@ auto BatchNormBackward::apply(const variable_list& grad_outputs) -> variable_lis
|
|||
}
|
||||
}
|
||||
|
||||
if (use_cudnn) {
|
||||
if (use_cudnn && eps >= CUDNN_BN_MIN_EPSILON) {
|
||||
#ifdef WITH_CUDNN
|
||||
torch::cudnn::cudnn_batch_norm_backward(
|
||||
state,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user